Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ members = [
"zluda_inject",
"zluda_ld",
"zluda_ml",
"zluda_replay",
"zluda_redirect",
"zluda_sparse",
"compiler",
Expand Down
2 changes: 2 additions & 0 deletions comgr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ pub fn compile_bitcode(
compile_to_exec.set_isa_name(gcn_arch)?;
compile_to_exec.set_language(Language::LlvmIr)?;
let common_options = [
c"-Xlinker",
c"--no-undefined",
c"-mllvm",
c"-ignore-tti-inline-compatible",
// c"-mllvm",
Expand Down
2 changes: 1 addition & 1 deletion cuda_macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ edition = "2021"
quote = "1.0"
syn = { version = "2.0", features = ["full", "visit-mut", "extra-traits"] }
proc-macro2 = "1.0"
rustc-hash = "1.1.0"
rustc-hash = "2.0.0"

[lib]
proc-macro = true
46 changes: 22 additions & 24 deletions ptx/src/pass/llvm/emit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1656,25 +1656,23 @@ impl<'a> MethodEmitContext<'a> {
.ok_or_else(|| error_mismatched_type())?,
);
let src2 = self.resolver.value(src2)?;
self.resolver.with_result(arguments.dst, |dst| {
let vec = unsafe {
LLVMBuildInsertElement(
self.builder,
LLVMGetPoison(dst_type),
llvm_fn(self.builder, src, packed_type, LLVM_UNNAMED.as_ptr()),
LLVMConstInt(LLVMInt32TypeInContext(self.context), 1, false as i32),
LLVM_UNNAMED.as_ptr(),
)
};
unsafe {
LLVMBuildInsertElement(
self.builder,
vec,
llvm_fn(self.builder, src2, packed_type, LLVM_UNNAMED.as_ptr()),
LLVMConstInt(LLVMInt32TypeInContext(self.context), 0, false as i32),
dst,
)
}
let vec = unsafe {
LLVMBuildInsertElement(
self.builder,
LLVMGetPoison(dst_type),
llvm_fn(self.builder, src, packed_type, LLVM_UNNAMED.as_ptr()),
LLVMConstInt(LLVMInt32TypeInContext(self.context), 1, false as i32),
LLVM_UNNAMED.as_ptr(),
)
};
self.resolver.with_result(arguments.dst, |dst| unsafe {
LLVMBuildInsertElement(
self.builder,
vec,
llvm_fn(self.builder, src2, packed_type, LLVM_UNNAMED.as_ptr()),
LLVMConstInt(LLVMInt32TypeInContext(self.context), 0, false as i32),
dst,
)
})
} else {
self.resolver.with_result(arguments.dst, |dst| unsafe {
Expand Down Expand Up @@ -2200,7 +2198,7 @@ impl<'a> MethodEmitContext<'a> {
Some(&ast::ScalarType::F32.into()),
vec![(
self.resolver.value(arguments.src)?,
get_scalar_type(self.context, ast::ScalarType::F32.into()),
get_scalar_type(self.context, ast::ScalarType::F32),
)],
)?;
Ok(())
Expand Down Expand Up @@ -2703,14 +2701,14 @@ impl<'a> MethodEmitContext<'a> {

let load = unsafe { LLVMBuildLoad2(self.builder, from_type, from, LLVM_UNNAMED.as_ptr()) };
unsafe {
LLVMSetAlignment(load, (cp_size.as_u64() as u32) * 8);
LLVMSetAlignment(load, cp_size.as_u64() as u32);
}

let extended = unsafe { LLVMBuildZExt(self.builder, load, to_type, LLVM_UNNAMED.as_ptr()) };

unsafe { LLVMBuildStore(self.builder, extended, to) };
let store = unsafe { LLVMBuildStore(self.builder, extended, to) };
unsafe {
LLVMSetAlignment(load, (cp_size.as_u64() as u32) * 8);
LLVMSetAlignment(store, cp_size.as_u64() as u32);
}
Ok(())
}
Expand Down Expand Up @@ -2990,7 +2988,7 @@ fn get_scope_membar(scope: ast::MemScope) -> Result<*const i8, TranslateError> {
Ok(match scope {
ast::MemScope::Cta => c"workgroup",
ast::MemScope::Gpu => c"agent",
ast::MemScope::Sys => c"",
ast::MemScope::Sys => c"system",
ast::MemScope::Cluster => todo!(),
}
.as_ptr())
Expand Down
35 changes: 28 additions & 7 deletions ptx_parser/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use derive_more::Display;
use logos::Logos;
use ptx_parser_macros::derive_parser;
use rustc_hash::FxHashMap;
use std::alloc::Layout;
use std::fmt::Debug;
use std::num::{NonZeroU8, ParseFloatError, ParseIntError};
use std::{iter, usize};
Expand Down Expand Up @@ -226,8 +227,9 @@ fn int_immediate<'a, 'input>(input: &mut PtxParser<'a, 'input>) -> PResult<ast::
take_error((opt(Token::Minus), num).map(|(neg, x)| {
let (num, radix, is_unsigned) = x;
if neg.is_some() {
match i64::from_str_radix(num, radix) {
Ok(x) => Ok(ast::ImmediateValue::S64(-x)),
let full_number = format!("-{num}");
match i64::from_str_radix(&full_number, radix) {
Ok(x) => Ok(ast::ImmediateValue::S64(x)),
Err(err) => Err((ast::ImmediateValue::S64(0), PtxError::from(err))),
}
} else if is_unsigned {
Expand Down Expand Up @@ -345,7 +347,9 @@ fn reg_or_immediate<'a, 'input>(
.parse_next(stream)
}

pub fn parse_for_errors<'input>(text: &'input str) -> Vec<PtxError<'input>> {
pub fn parse_for_errors_and_params<'input>(
text: &'input str,
) -> (Vec<PtxError<'input>>, FxHashMap<String, Vec<Layout>>) {
let (tokens, mut errors) = lex_with_span_unchecked(text);
let parse_result = {
let state = PtxParserState::new(text, &mut errors);
Expand All @@ -357,13 +361,30 @@ pub fn parse_for_errors<'input>(text: &'input str) -> Vec<PtxError<'input>> {
.parse(parser)
.map_err(|err| PtxError::Parser(err.into_inner()))
};
match parse_result {
Ok(_) => {}
let params = match parse_result {
Ok(module) => module
.directives
.into_iter()
.filter_map(|directive| {
if let ast::Directive::Method(_, func) = directive {
let layouts = func
.func_directive
.input_arguments
.iter()
.map(|arg| arg.info.v_type.layout())
.collect();
Some((func.func_directive.name().to_string(), layouts))
} else {
None
}
})
.collect(),
Err(err) => {
errors.push(err);
FxHashMap::default()
}
}
errors
};
(errors, params)
}

fn lex_with_span_unchecked<'input>(
Expand Down
2 changes: 1 addition & 1 deletion zluda/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ num_enum = "0.4"
lz4-sys = "1.9"
tempfile = "3"
paste = "1.0"
rustc-hash = "1.1"
rustc-hash = "2.0.0"
zluda_common = { path = "../zluda_common" }
blake3 = "1.8.2"
serde = "1.0.219"
Expand Down
49 changes: 43 additions & 6 deletions zluda/src/impl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,25 @@ pub(crate) unsafe fn push_current_v2(ctx: CUcontext) -> CUresult {
push_current(ctx)
}

pub(crate) unsafe fn pop_current(ctx: &mut CUcontext) -> CUresult {
STACK.with(|stack| {
if let Some((_ctx, _)) = stack.borrow_mut().pop() {
*ctx = _ctx;
}
pub(crate) unsafe fn pop_current(result: Option<&mut CUcontext>) -> CUresult {
let old_ctx_and_new_device = STACK.with(|stack| {
let mut stack = stack.borrow_mut();
stack
.pop()
.map(|(ctx, _)| (ctx, stack.last().map(|(_, dev)| *dev)))
});
let ctx = match old_ctx_and_new_device {
Some((old_ctx, new_device)) => {
if let Some(new_device) = new_device {
hipSetDevice(new_device)?;
}
old_ctx
}
None => return CUresult::ERROR_INVALID_CONTEXT,
};
if let Some(out) = result {
*out = ctx;
}
Ok(())
}

Expand All @@ -213,7 +226,7 @@ pub(crate) unsafe fn destroy_v2(ctx: CUcontext) -> CUresult {
zluda_common::drop_checked::<Context>(ctx)
}

pub(crate) unsafe fn pop_current_v2(ctx: &mut CUcontext) -> CUresult {
pub(crate) unsafe fn pop_current_v2(ctx: Option<&mut CUcontext>) -> CUresult {
pop_current(ctx)
}

Expand Down Expand Up @@ -241,3 +254,27 @@ pub(crate) unsafe fn get_api_version(
*version = 3020;
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
use crate::tests::CudaApi;
use cuda_macros::test_cuda;
use std::mem;

#[test_cuda]
fn empty_pop_fails(api: impl CudaApi) {
api.cuInit(0);
assert_eq!(
api.cuCtxPopCurrent_v2_unchecked(&mut unsafe { mem::zeroed() }),
CUresult::ERROR_INVALID_CONTEXT
);
}

#[test_cuda]
fn pop_into_null_succeeds(api: impl CudaApi) {
api.cuInit(0);
api.cuCtxCreate_v2(&mut unsafe { mem::zeroed() }, 0, 0);
api.cuCtxPopCurrent_v2(ptr::null_mut());
}
}
2 changes: 1 addition & 1 deletion zluda_bindgen/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ syn = { version = "2.0", features = ["full", "visit-mut"] }
proc-macro2 = "1.0.89"
quote = "1.0"
prettyplease = "0.2.25"
rustc-hash = "1.1.0"
rustc-hash = "2.0.0"
libloading = "0.8"
cuda_types = { path = "../cuda_types" }
17 changes: 17 additions & 0 deletions zluda_replay/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "zluda_replay"
version = "0.0.0"
authors = ["Andrzej Janik <vosen@vosen.pl>"]
edition = "2021"

[[bin]]
name = "zluda_replay"

[dependencies]
zluda_trace_common = { path = "../zluda_trace_common" }
cuda_macros = { path = "../cuda_macros" }
cuda_types = { path = "../cuda_types" }
libloading = "0.8"

[package.metadata.zluda]
debug_only = true
Loading
Loading