From a39da7a71a619e10eca384f3c4cbeba7e28b3f39 Mon Sep 17 00:00:00 2001 From: Hugo Devillers Date: Fri, 24 Oct 2025 18:00:20 +0200 Subject: [PATCH 1/3] lower offloading calls at the IR level --- src/thorin/CMakeLists.txt | 2 + src/thorin/be/codegen.cpp | 114 ++---- src/thorin/be/codegen.h | 6 +- src/thorin/be/lower_offload_intrinsics.cpp | 413 +++++++++++++++++++++ src/thorin/be/lower_offload_intrinsics.h | 15 + src/thorin/be/runtime.h | 33 ++ 6 files changed, 507 insertions(+), 76 deletions(-) create mode 100644 src/thorin/be/lower_offload_intrinsics.cpp create mode 100644 src/thorin/be/lower_offload_intrinsics.h diff --git a/src/thorin/CMakeLists.txt b/src/thorin/CMakeLists.txt index 7910f0dfa..a512bd992 100644 --- a/src/thorin/CMakeLists.txt +++ b/src/thorin/CMakeLists.txt @@ -33,6 +33,8 @@ set(THORIN_SOURCES be/emitter.h be/c/c.cpp be/c/c.h + be/lower_offload_intrinsics.cpp + be/lower_offload_intrinsics.h be/runtime.h be/kernel_config.h tables/allnodes.h diff --git a/src/thorin/be/codegen.cpp b/src/thorin/be/codegen.cpp index 36aa34849..a9d2c21fb 100644 --- a/src/thorin/be/codegen.cpp +++ b/src/thorin/be/codegen.cpp @@ -21,14 +21,17 @@ #include "thorin/transform/hls_channels.h" #include "thorin/transform/hls_kernel_launch.h" +#include "lower_offload_intrinsics.h" namespace thorin { void Backend::prepare_kernel_configs() { device_code_.opt(); + Cont2Config adjusted_configs_map; + auto conts = device_code_.world().copy_continuations(); - for (auto continuation : kernels_) { + for (auto& [continuation, config] : kernel_configs_) { // recover the imported continuation (lost after the call to opt) Continuation* imported = nullptr; for (auto original_cont : conts) { @@ -37,27 +40,13 @@ void Backend::prepare_kernel_configs() { if (original_cont->name() == continuation->name()) imported = original_cont; } + assert(imported && "we lost a kernel ?"); if (!imported) continue; - visit_uses(continuation, [&] (Continuation* use) { - assert(use->has_body()); - - auto handler = 
backends_.intrinsics_.find(use->body()->callee()->as()->intrinsic()); - assert(handler != backends_.intrinsics_.end()); - auto [backend2, get_config] = handler->second; - assert(backend2 == this); - - auto config = get_config(use->body(), imported); - if (config) { - auto p = kernel_configs_.emplace(imported, std::move(config)); - assert_unused(p.second && "single kernel config entry expected"); - } - return false; - }, true); - - continuation->world().make_external(continuation); - continuation->destroy("codegen"); + adjusted_configs_map[imported] = std::move(config); } + + std::swap(kernel_configs_, adjusted_configs_map); } static const App* get_alloc_call(const Def* def) { @@ -211,7 +200,7 @@ struct ShadyBackend : public Backend { struct HLSBackend : public Backend { explicit HLSBackend(DeviceBackends& b, World& src, std::string& hls_flags) : Backend(b, src), hls_flags_(hls_flags) { - b.register_intrinsic(Intrinsic::HLS, *this, [&](const App* app, Continuation* imported) { + b.register_intrinsic(Intrinsic::HLS, *this, [&](const App* app, Continuation* kernel) { HLSKernelConfig::Param2Size param_sizes; for (size_t i = hls_free_vars_offset, e = app->num_args(); i != e; ++i) { auto arg = app->arg(i); @@ -237,7 +226,7 @@ struct HLSBackend : public Backend { b.world().edef(arg, "only pointers to arrays of primitive types are supported"); auto num_elems = size / (multiplier * num_bits(prim_type->primtype_tag()) / 8); // imported has type: fn (mem, fn (mem), ...) 
- param_sizes.emplace(imported->param(i - hls_free_vars_offset + 2), num_elems); + param_sizes.emplace(kernel->param(i - hls_free_vars_offset + 2), num_elems); } return std::make_unique(param_sizes); }); @@ -257,13 +246,13 @@ struct HLSBackend : public Backend { std::string& hls_flags_; }; -DeviceBackends::DeviceBackends(thorin::World& world, int opt, bool debug, std::string& hls_flags) : world_(world), opt_(opt), debug_(debug) { - register_backend(std::make_unique(*this, world)); - register_backend(std::make_unique(*this, world)); +DeviceBackends::DeviceBackends(World& world, int opt, bool debug, std::string& hls_flags) : world_(world), opt_(opt), debug_(debug) { + register_backend(std::make_unique(*this, world_)); + register_backend(std::make_unique(*this, world_)); #if THORIN_ENABLE_LLVM - register_backend(std::make_unique(*this, world)); - register_backend(std::make_unique(*this, world)); - register_backend(std::make_unique(*this, world)); + register_backend(std::make_unique(*this, world_)); + register_backend(std::make_unique(*this, world_)); + register_backend(std::make_unique(*this, world_)); #endif #if THORIN_ENABLE_SHADY register_backend(std::make_unique(*this, world)) @@ -272,9 +261,17 @@ DeviceBackends::DeviceBackends(thorin::World& world, int opt, bool debug, std::s register_backend(std::make_unique(*this, world)); register_backend(std::make_unique(*this, world)); #endif - register_backend(std::make_unique(*this, world, hls_flags)); + register_backend(std::make_unique(*this, world_, hls_flags)); + + lower_offload_intrinsics(world, *this); + + for (auto& backend : backends_) { + if (backend->thorin().world().empty()) + continue; - search_for_device_code(); + backend->prepare_kernel_configs(); + cgs.emplace_back(backend->create_cg()); + } } void DeviceBackends::register_backend(std::unique_ptr backend) { @@ -289,50 +286,21 @@ void DeviceBackends::register_intrinsic(thorin::Intrinsic intrinsic, Backend& ba intrinsics_[intrinsic] = std::make_pair(&backend, 
f); } -void DeviceBackends::search_for_device_code() { - // determine different parts of the world which need to be compiled differently - ScopesForest(world_).for_each([&] (const Scope& scope) { - auto continuation = scope.entry(); - Continuation* imported = nullptr; - - Intrinsic intrinsic = Intrinsic::None; - visit_capturing_intrinsics(continuation, [&] (Continuation* continuation) { - if (continuation->is_offload_intrinsic()) { - intrinsic = continuation->intrinsic(); - return true; - } - return false; - }); - - if (intrinsic == Intrinsic::None) - return; - - auto handler = intrinsics_.find(intrinsic); - assert(handler != intrinsics_.end()); - auto [backend, get_config] = handler->second; - - imported = backend->importer_->import(continuation)->as_nom(); - if (imported == nullptr) - return; - - // Necessary so that the names match in the original and imported worlds - imported->set_name(continuation->unique_name()); - continuation->set_name(continuation->unique_name()); - for (size_t i = 0, e = continuation->num_params(); i != e; ++i) - imported->param(i)->set_name(continuation->param(i)->name()); - imported->world().make_external(imported); - imported->attributes().cc = CC::C; - - backend->kernels_.emplace_back(continuation); - }); - - for (auto& backend : backends_) { - if (backend->thorin().world().empty()) - continue; - - backend->prepare_kernel_configs(); - cgs.emplace_back(backend->create_cg()); - } +void DeviceBackends::register_kernel_for_offloading(const App* launch, Continuation* kernel) { + Continuation* intrinsic_cont = launch->callee()->as_nom(); + auto handler = intrinsics_.find(intrinsic_cont->intrinsic()); + assert(handler != intrinsics_.end()); + auto [backend, get_config] = handler->second; + + // Import the continuation in the destination world + Continuation* imported = backend->importer_->import(kernel)->as_nom(); + assert(imported); + imported->world().make_external(imported); + imported->attributes().cc = CC::C; + + // Obtain the kernel 
config now + auto config = get_config(launch, kernel); + backend->kernel_configs_[kernel] = std::move(config); } CodeGen::CodeGen(Thorin& thorin, bool debug) diff --git a/src/thorin/be/codegen.h b/src/thorin/be/codegen.h index d4d804e94..b5fa0c6bd 100644 --- a/src/thorin/be/codegen.h +++ b/src/thorin/be/codegen.h @@ -43,7 +43,6 @@ struct Backend { Thorin device_code_; std::unique_ptr importer_; - std::vector kernels_; Cont2Config kernel_configs_; void prepare_kernel_configs(); @@ -53,6 +52,8 @@ struct Backend { struct DeviceBackends { DeviceBackends(World& world, int opt, bool debug, std::string& hls_flags); + DeviceBackends(DeviceBackends&) = delete; + World& world(); std::vector> cgs; @@ -63,6 +64,7 @@ struct DeviceBackends { using GetKernelConfigFn = std::function(const App*, Continuation*)>; void register_intrinsic(Intrinsic, Backend&, GetKernelConfigFn); + void register_kernel_for_offloading(const App* launch, Continuation*); private: World& world_; std::vector> backends_; @@ -70,8 +72,6 @@ struct DeviceBackends { int opt_; bool debug_; - - void search_for_device_code(); friend Backend; }; diff --git a/src/thorin/be/lower_offload_intrinsics.cpp b/src/thorin/be/lower_offload_intrinsics.cpp new file mode 100644 index 000000000..b890fd770 --- /dev/null +++ b/src/thorin/be/lower_offload_intrinsics.cpp @@ -0,0 +1,413 @@ +#include "lower_offload_intrinsics.h" + +#include "runtime.h" + +namespace thorin { + +struct RuntimeAPI { + World& world_; + DeviceBackends& backends_; + ContinuationMap unique_kernel_names_; + + const Def* anydsl_alloc; + const Def* anydsl_alloc_unified; + const Def* anydsl_release; + const Def* anydsl_launch_kernel; + const Def* anydsl_parallel_for; + const Def* anydsl_fibers_spawn; + const Def* anydsl_spawn_thread; + const Def* anydsl_sync_thread; + const Def* anydsl_create_graph; + const Def* anydsl_create_task; + const Def* anydsl_create_edge; + const Def* anydsl_execute_graph; + + std::string register_kernel_for_offloading(const App* launch, 
Continuation* kernel) { + auto found = unique_kernel_names_.find(kernel); + if (found != unique_kernel_names_.end()) + return found->second; + kernel->set_name(kernel->unique_name()); + unique_kernel_names_[kernel] = kernel->name(); + backends_.register_kernel_for_offloading(launch, kernel); + + kernel->world().make_external(kernel); + kernel->destroy("codegen"); + return kernel->name(); + } + + RuntimeAPI(World& world, DeviceBackends& backends) : world_(world), backends_(backends) { + auto get_api_fn = [&](Types dom, const Type* codom, std::string name) { + auto mem_ty = world.mem_type(); + auto found = world.find_cont(name.c_str()); + if (found) + return found; + auto r = codom ? world.return_type({mem_ty, codom}) : world.return_type({mem_ty}); + Array p = concat(mem_ty, concat(dom, r)); + auto c = world.continuation(world.fn_type(p), name); + c->attributes_.cc = CC::C; + world.make_external(c); + return c; + }; + + auto i32 = world.type_qs32(); + auto i64 = world.type_qs64(); + auto ptr_ty = world.ptr_type(world.indefinite_array_type(world.type_qu8())); + + anydsl_alloc = get_api_fn({ i32, i64 }, ptr_ty, "anydsl_alloc"); + anydsl_alloc_unified = get_api_fn({ i32, i64 }, ptr_ty, "anydsl_alloc_unified"); + anydsl_release = get_api_fn({ i32 }, nullptr, "anydsl_release"); + anydsl_launch_kernel = get_api_fn({ i32, ptr_ty, ptr_ty, ptr_ty, ptr_ty, ptr_ty, ptr_ty, ptr_ty, ptr_ty, ptr_ty, i32 }, nullptr, "anydsl_launch_kernel"); + anydsl_parallel_for = get_api_fn({ i32, i32, i32, ptr_ty, ptr_ty }, nullptr, "anydsl_parallel_for"); + anydsl_fibers_spawn = get_api_fn({ i32, i32, i32, ptr_ty, ptr_ty }, nullptr, "anydsl_fibers_spawn"); + anydsl_spawn_thread = get_api_fn({ ptr_ty, ptr_ty }, i32, "anydsl_spawn_thread"); + anydsl_sync_thread = get_api_fn({ i32 }, nullptr, "anydsl_sync_thread"); + anydsl_create_graph = get_api_fn({ i32 }, i32, "anydsl_create_graph"); + anydsl_create_task = get_api_fn({ i32, world.tuple_type({ ptr_ty, i64 }) }, i32, "anydsl_create_task"); + 
anydsl_create_edge = get_api_fn({ i32, i32}, nullptr, "anydsl_create_edge"); + anydsl_execute_graph = get_api_fn({ i32, i32}, nullptr, "anydsl_execute_graph"); + } +}; + +static bool contains_ptrtype(const Type* type) { + switch (type->tag()) { + case Node_PtrType: return false; + case Node_IndefiniteArrayType: return contains_ptrtype(type->as()->elem_type()); + case Node_DefiniteArrayType: return contains_ptrtype(type->as()->elem_type()); + case Node_FnType: return false; + case Node_StructType: { + bool good = true; + auto struct_type = type->as(); + for (auto& t : struct_type->types()) + good &= contains_ptrtype(t); + return good; + } + case Node_TupleType: { + bool good = true; + auto tuple = type->as(); + for (auto& t : tuple->types()) + good &= contains_ptrtype(t); + return good; + } + default: return true; + } +} + +void emit_host_code(RuntimeAPI& api, const App* launch, Platform platform, const std::string& ext, Continuation* continuation) { + World& world = continuation->world(); + + assert(continuation->has_body()); + auto body = continuation->body(); + // to-target is the desired kernel call + // target(mem, device, (dim.x, dim.y, dim.z), (block.x, block.y, block.z), body, return, free_vars) + auto target = body->callee()->as_nom(); + assert_unused(target->is_intrinsic()); + assert(body->num_args() >= KernelLaunchArgs::Num && "required arguments are missing"); + + // arguments + const Def* mem = body->arg(KernelLaunchArgs::Mem); + auto ret = body->arg(KernelLaunchArgs::Return); + + auto target_device_id = body->arg(KernelLaunchArgs::Device); + auto target_platform = world.literal_qs32(static_cast(platform), {}); + auto target_device = world.arithop_or(target_platform, world.arithop_shl(target_device_id, world.literal_qs32(4, {}))); + + auto it_space = body->arg(KernelLaunchArgs::Space); + auto it_config = body->arg(KernelLaunchArgs::Config); + auto kernel = body->arg(KernelLaunchArgs::Body)->as_nom(); + + //auto kernel_name = 
builder.CreateGlobalString(kernel->name() == "hls_top" ? kernel->name() : kernel->name()); + auto kernel_name = world.global_immutable_string(api.register_kernel_for_offloading(launch, kernel)); + auto file_name = world.global_immutable_string(world.name() + ext); + const size_t num_kernel_args = body->num_args() - KernelLaunchArgs::Num; + + auto ptr_ty = world.ptr_type(world.indefinite_array_type(world.type_qu8())); + + auto alloc = [&](const Type* t, std::string name) { + auto a = world.alloc(t, mem, { name }); + mem = world.extract(a, static_cast(0)); + return world.extract(a, 1); + }; + + auto store = [&](const Def* val, const Def* ptr) { + mem = world.store(mem, ptr, val); + }; + + // allocate argument pointers, sizes, and types + const Def* args = alloc(world.definite_array_type(ptr_ty, num_kernel_args), "args"); + const Def* sizes = alloc(world.definite_array_type(world.type_pu32(), num_kernel_args), "sizes"); + const Def* aligns = alloc(world.definite_array_type(world.type_pu32(), num_kernel_args), "aligns"); + const Def* allocs = alloc(world.definite_array_type(world.type_pu32(), num_kernel_args), "allocs"); + const Def* types = alloc(world.definite_array_type(world.type_qu8(), num_kernel_args), "types"); + + // fill array of arguments + for (size_t i = 0; i < num_kernel_args; ++i) { + auto target_arg = body->arg(i + KernelLaunchArgs::Num); + //const auto target_val = code_gen.emit(target_arg); + auto target_val = target_arg; + + KernelArgType arg_type; + const Def* void_ptr; + if (target_arg->type()->isa() || + target_arg->type()->isa() || + target_arg->type()->isa()) { + // definite array | struct | tuple + auto alloca = alloc(target_arg->type(), target_arg->name()); + store(target_val, alloca); + + // check if argument type contains pointers + if (!contains_ptrtype(target_arg->type())) + world.wdef(target_arg, "argument '{}' of aggregate type '{}' contains pointer (not supported in OpenCL 1.2)", target_arg, target_arg->type()); + + void_ptr = 
world.bitcast(ptr_ty, alloca); + arg_type = KernelArgType::Struct; + } else if (target_arg->type()->isa()) { + auto ptr = target_arg->type()->as(); + auto rtype = ptr->pointee(); + + //if (!rtype->isa()) + // world.edef(target_arg, "currently only pointers to arrays supported as kernel argument; argument has different type: {}", ptr); + + auto alloca = alloc(ptr_ty, target_arg->name()); + auto target_ptr = world.bitcast(ptr_ty, target_val); + store(target_ptr, alloca); + void_ptr = world.bitcast(ptr_ty, alloca); + arg_type = KernelArgType::Ptr; + } else { + // normal variable + auto alloca = alloc(target_arg->type(), target_arg->name()); + store(target_val, alloca); + + void_ptr = world.bitcast(ptr_ty, alloca); + arg_type = KernelArgType::Val; + } + + auto arg_ptr = world.lea(args, world.literal_pu32(i, {}), {}); + auto size_ptr = world.lea(sizes, world.literal_pu32(i, {}), {}); + auto align_ptr = world.lea(aligns, world.literal_pu32(i, {}), {}); + auto alloc_ptr = world.lea(allocs, world.literal_pu32(i, {}), {}); + auto type_ptr = world.lea(types, world.literal_pu32(i, {}), {}); + + auto size = world.size_of(target_arg->type()); + + //if (auto struct_type = llvm::dyn_cast(target_val->getType())) { + // // In the case of a structure, do not include the padding at the end in the size + // auto last_elem = struct_type->getStructNumElements() - 1; + // auto last_offset = layout_.getStructLayout(struct_type)->getElementOffset(last_elem); + // size = last_offset + layout_.getTypeStoreSize(struct_type->getStructElementType(last_elem)).getFixedValue(); + //} + + store(void_ptr, arg_ptr); + store(size, size_ptr); + store(world.align_of(target_arg->type()), align_ptr); + store(world.size_of(target_arg->type()), alloc_ptr); + store(world.literal_qu8((uint8_t)arg_type, {}), type_ptr); + } + + // allocate arrays for the grid and block size + const Def* grid_array = world.definite_array(world.type_qs32(), { + world.extract(it_space, 0_u32), + world.extract(it_space, 1_u32), + 
world.extract(it_space, 2_u32), + }); + const Def* grid_size = alloc(world.definite_array_type(world.type_qs32(), 3), "grid_size_alloc"); + store(grid_array, grid_size); + + const Def* block_array = world.definite_array(world.type_qs32(), { + world.extract(it_config, 0_u32), + world.extract(it_config, 1_u32), + world.extract(it_config, 2_u32), + }); + const Def* block_size = alloc(world.definite_array_type(world.type_qs32(), 3), "block_size_alloc"); + store(block_array, block_size); + + grid_size = world.bitcast(ptr_ty, grid_size); + block_size = world.bitcast(ptr_ty, block_size); + args = world.bitcast(ptr_ty, args); + sizes = world.bitcast(ptr_ty, sizes); + aligns = world.bitcast(ptr_ty, aligns); + allocs = world.bitcast(ptr_ty, allocs); + types = world.bitcast(ptr_ty, types); + + file_name = world.bitcast(ptr_ty, file_name); + kernel_name = world.bitcast(ptr_ty, kernel_name); + + continuation->set_body(world.app(api.anydsl_launch_kernel, {mem, target_device, file_name, kernel_name, grid_size, block_size, args, sizes, aligns, allocs, types, world.literal_qs32(num_kernel_args, {}), ret})); +} + +std::tuple> spill(const Def*& mem, const Defs& args, const Def*& wrapper_mem, const Def* wrapper_ptr) { + World& world = mem->world(); + StructType* st = world.struct_type("spillbox", args.size()); + for (size_t i = 0; i < args.size(); i++) + st->set_op(i, args[i]->type()); + + auto alloc = [&](const Type* t, std::string name) { + auto a = world.alloc(t, mem, { name }); + mem = world.extract(a, static_cast(0)); + return world.extract(a, 1); + }; + + auto store = [&](const Def* val, const Def* ptr) { + mem = world.store(mem, ptr, val); + }; + + const Def* spill_alloca = alloc(st, "spill"); + std::vector restored; + wrapper_ptr = world.bitcast(spill_alloca->type(), wrapper_ptr); + for (size_t i = 0; i < args.size(); i++) { + store(args[i], world.lea(spill_alloca, world.literal_pu32(i, {}), {})); + auto l = world.load(wrapper_mem, world.lea(wrapper_ptr, world.literal_pu32(i, 
{}), {})); + wrapper_mem = world.extract(l, static_cast(0)); + restored.push_back(world.extract(l, 1)); + } + + auto ptr_ty = world.ptr_type(world.indefinite_array_type(world.type_qu8())); + return std::make_tuple(world.bitcast(ptr_ty, spill_alloca), restored); +} + +enum class RuntimeParallelForArgs { + Mem = 0, + NumThreads, + Lower, + Upper, + Args, + Fun, + Return, +}; + +void emit_parallel(RuntimeAPI& api, Continuation* continuation) { + World& world = continuation->world(); + auto ptr_ty = world.ptr_type(world.indefinite_array_type(world.type_qu8())); + + assert(continuation->has_body()); + auto body = continuation->body(); + const Def* mem = body->arg(static_cast(ParallelForArgs::Mem)); + auto numthreads = body->arg(static_cast(ParallelForArgs::NumThreads)); + auto lower = body->arg(static_cast(ParallelForArgs::Lower)); + auto upper = body->arg(static_cast(ParallelForArgs::Upper)); + auto fun = body->arg(static_cast(ParallelForArgs::Fun))->as(); + auto ret = body->arg(static_cast(ParallelForArgs::Return)); + + // create loop iterating over range: + // for (int i=lower; i); + + auto wrapper = world.continuation(world.fn_type({world.mem_type(), ptr_ty, world.type_qs32(), world.type_qs32(), world.return_type({world.mem_type()})})); + world.make_external(wrapper); + const Def* wrapper_mem = wrapper->mem_param(); + auto inner_lower = wrapper->param(2); + auto inner_upper = wrapper->param(3); + auto [args, recovered] = spill(mem, body->args().skip_front(static_cast(ParallelForArgs::Num)), wrapper_mem, wrapper->param(1)); + + auto loop_head = world.continuation(world.fn_type({world.mem_type(), world.type_qs32()}), "loop_head"); + auto loop_body = world.continuation(world.fn_type({world.mem_type()}), "loop_body"); + auto loop_continue = world.continuation(world.fn_type({world.mem_type()}), "loop_continue"); + auto loop_exit = world.continuation(world.fn_type({world.mem_type()}), "loop_exit"); + loop_head->branch(loop_head->mem_param(), 
world.cmp_lt(loop_head->param(1), inner_upper), loop_body, loop_exit); + + Array prefix {loop_body->mem_param(), world.bitcast(fun->param(1)->type(), loop_head->param(1)), world.return_point(loop_continue) }; + loop_body->jump(fun, concat(prefix, recovered)); + loop_continue->jump(loop_head, { loop_continue->mem_param(), world.arithop_add(loop_head->param(1), world.literal_qs32(1, {})) }); + loop_exit->jump(wrapper->ret_param(), loop_exit->params_as_defs()); + wrapper->jump(loop_head, { wrapper->mem_param(), inner_lower }); + + continuation->set_body(world.app(api.anydsl_parallel_for, {mem, numthreads, lower, upper, world.bitcast(ptr_ty, args), world.bitcast(ptr_ty, wrapper), ret})); +} + +enum class RuntimeSpawnFibersArgs { + Mem = 0, + NumThreads, + NumBlocks, + NumWarps, + Args, + Fun, + Return, +}; + +void emit_fibers(RuntimeAPI& api, Continuation* continuation) { + World& world = continuation->world(); + auto ptr_ty = world.ptr_type(world.type_qu8()); + auto i32 = world.type_qs32(); + + assert(continuation->has_body()); + auto body = continuation->body(); + const Def* mem = body->arg(static_cast(SpawnFibersArgs::Mem)); + auto threads = body->arg(static_cast(SpawnFibersArgs::NumThreads)); + auto blocks = body->arg(static_cast(SpawnFibersArgs::NumBlocks)); + auto warps = body->arg(static_cast(SpawnFibersArgs::NumWarps)); + auto fun = body->arg(static_cast(SpawnFibersArgs::Fun))->as(); + auto ret = body->arg(static_cast(SpawnFibersArgs::Return)); + + auto wrapper = world.continuation(world.fn_type({world.mem_type(), ptr_ty, i32, i32, world.return_type({world.mem_type()})})); + const Def* wrapper_mem = wrapper->mem_param(); + auto [args, recovered] = spill(mem, body->args().skip_front(static_cast(SpawnThreadArgs::Num)), wrapper_mem, wrapper->param(1)); + Array prefix = {mem, wrapper->param(2), wrapper->param(3)}; + wrapper->jump(fun, concat(concat(prefix, recovered), {wrapper->ret_param()})); + + continuation->set_body(world.app(api.anydsl_fibers_spawn, {mem, 
threads, blocks, warps, world.bitcast(ptr_ty, args), world.bitcast(ptr_ty, wrapper), ret})); +} + +enum class RuntimeSpawnThreadArgs { + Mem = 0, + Args, + Fun, + Return, +}; + +void emit_spawn(RuntimeAPI& api, Continuation* continuation) { + World& world = continuation->world(); + auto ptr_ty = world.ptr_type(world.indefinite_array_type(world.type_qu8())); + + assert(continuation->has_body()); + auto body = continuation->body(); + const Def* mem = body->arg(static_cast(SpawnThreadArgs::Mem)); + auto fun = body->arg(static_cast(SpawnThreadArgs::Fun)); + auto ret = body->arg(static_cast(SpawnThreadArgs::Return)); + + auto wrapper = world.continuation(world.fn_type({world.mem_type(), ptr_ty, world.return_type({world.mem_type()})})); + const Def* wrapper_mem = wrapper->mem_param(); + auto [args, recovered] = spill(mem, body->args().skip_front(static_cast(SpawnThreadArgs::Num)), wrapper_mem, wrapper->param(1)); + wrapper->jump(fun, concat(concat(mem, recovered.ref()), {wrapper->ret_param()})); + + continuation->set_body(world.app(api.anydsl_spawn_thread, {mem, world.bitcast(ptr_ty, args), world.bitcast(ptr_ty, wrapper), ret})); +} + +void emit_sync(RuntimeAPI& api, Continuation* continuation) { + World& world = continuation->world(); + + assert(continuation->has_body()); + auto body = continuation->body(); + const Def* mem = body->arg(static_cast(SyncArgs::Mem)); + auto id = body->arg(static_cast(SyncArgs::Id)); + auto ret = body->arg(static_cast(SyncArgs::Return)); + + continuation->set_body(world.app(api.anydsl_sync_thread, {mem, id, ret})); +} + +void lower_offload_intrinsics(World& world, DeviceBackends& backends) { + RuntimeAPI api(world, backends); + + for (auto continuation : world.copy_continuations()) { + if (!continuation->has_body()) continue; + auto call = continuation->body(); + if (auto callee = call->callee()->isa()) { + switch (callee->intrinsic()) { + case Intrinsic::CUDA: emit_host_code(api, call, Platform::CUDA_PLATFORM, ".cu", continuation); break; + case Intrinsic::NVVM: emit_host_code(api, call, 
Platform::CUDA_PLATFORM, ".nvvm", continuation); break; + case Intrinsic::OpenCL: emit_host_code(api, call, Platform::OPENCL_PLATFORM, ".cl", continuation); break; + case Intrinsic::OpenCL_SPIRV: emit_host_code(api, call, Platform::OPENCL_PLATFORM, ".spv", continuation); break; + case Intrinsic::LevelZero_SPIRV: emit_host_code(api, call, Platform::LEVEL_ZERO_PLATFORM, ".spv", continuation); break; + case Intrinsic::AMDGPUHSA: emit_host_code(api, call, Platform::HSA_PLATFORM, ".amdgpu", continuation); break; + case Intrinsic::AMDGPUPAL: emit_host_code(api, call, Platform::PAL_PLATFORM, ".amdgpu", continuation); break; + case Intrinsic::VulkanCS_SPIRV: emit_host_code(api, call, Platform::VULKAN_PLATFORM, ".spv", continuation); break; + + case Intrinsic::Parallel: emit_parallel(api, continuation); break; + case Intrinsic::Fibers: emit_fibers(api, continuation); break; + case Intrinsic::Spawn: emit_spawn(api, continuation); break; + case Intrinsic::Sync: emit_sync(api, continuation); break; + default: continue; + } + } + } +} + +} diff --git a/src/thorin/be/lower_offload_intrinsics.h b/src/thorin/be/lower_offload_intrinsics.h new file mode 100644 index 000000000..d0b5f76ef --- /dev/null +++ b/src/thorin/be/lower_offload_intrinsics.h @@ -0,0 +1,15 @@ +#ifndef THORIN_OFFLOAD_H +#define THORIN_OFFLOAD_H + +#include "codegen.h" +#include "thorin/world.h" + +namespace thorin { + +enum class KernelArgType : uint8_t { Val = 0, Ptr, Struct }; + +void lower_offload_intrinsics(World&, DeviceBackends&); + +} + +#endif //THORIN_OFFLOAD_H diff --git a/src/thorin/be/runtime.h b/src/thorin/be/runtime.h index 68cdd9b97..e2ddab211 100644 --- a/src/thorin/be/runtime.h +++ b/src/thorin/be/runtime.h @@ -24,6 +24,39 @@ enum KernelLaunchArgs { Num }; +enum class ParallelForArgs { + Mem = 0, + NumThreads, + Lower, + Upper, + Fun, + Return, + Num +}; + +enum class SpawnFibersArgs { + Mem = 0, + NumThreads, + NumBlocks, + NumWarps, + Fun, + Return, + Num +}; + +enum class SpawnThreadArgs { + 
Mem = 0, + Fun, + Return, + Num +}; + +enum class SyncArgs { + Mem = 0, + Id, + Return, +}; + } #endif From 9ef6f18b98be5f5ce38a3934f0a55ed994455710 Mon Sep 17 00:00:00 2001 From: Hugo Devillers Date: Fri, 24 Oct 2025 18:10:37 +0200 Subject: [PATCH 2/3] backport to master --- src/thorin/be/lower_offload_intrinsics.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/thorin/be/lower_offload_intrinsics.cpp b/src/thorin/be/lower_offload_intrinsics.cpp index b890fd770..67504f462 100644 --- a/src/thorin/be/lower_offload_intrinsics.cpp +++ b/src/thorin/be/lower_offload_intrinsics.cpp @@ -41,7 +41,7 @@ struct RuntimeAPI { auto found = world.find_cont(name.c_str()); if (found) return found; - auto r = codom ? world.return_type({mem_ty, codom}) : world.return_type({mem_ty}); + auto r = codom ? world.fn_type({mem_ty, codom}) : world.fn_type({mem_ty}); Array p = concat(mem_ty, concat(dom, r)); auto c = world.continuation(world.fn_type(p), name); c->attributes_.cc = CC::C; @@ -291,7 +291,7 @@ void emit_parallel(RuntimeAPI& api, Continuation* continuation) { // for (int i=lower; i); - auto wrapper = world.continuation(world.fn_type({world.mem_type(), ptr_ty, world.type_qs32(), world.type_qs32(), world.return_type({world.mem_type()})})); + auto wrapper = world.continuation(world.fn_type({world.mem_type(), ptr_ty, world.type_qs32(), world.type_qs32(), world.fn_type({world.mem_type()})})); world.make_external(wrapper); const Def* wrapper_mem = wrapper->mem_param(); auto inner_lower = wrapper->param(2); @@ -304,7 +304,7 @@ void emit_parallel(RuntimeAPI& api, Continuation* continuation) { auto loop_exit = world.continuation(world.fn_type({world.mem_type()}), "loop_exit"); loop_head->branch(loop_head->mem_param(), world.cmp_lt(loop_head->param(1), inner_upper), loop_body, loop_exit); - Array prefix {loop_body->mem_param(), world.bitcast(fun->param(1)->type(), loop_head->param(1)), world.return_point(loop_continue) }; + Array prefix 
{loop_body->mem_param(), world.bitcast(fun->param(1)->type(), loop_head->param(1)), loop_continue }; loop_body->jump(fun, concat(prefix, recovered)); loop_continue->jump(loop_head, { loop_continue->mem_param(), world.arithop_add(loop_head->param(1), world.literal_qs32(1, {})) }); loop_exit->jump(wrapper->ret_param(), loop_exit->params_as_defs()); @@ -337,7 +337,7 @@ void emit_fibers(RuntimeAPI& api, Continuation* continuation) { auto fun = body->arg(static_cast(SpawnFibersArgs::Fun))->as(); auto ret = body->arg(static_cast(SpawnFibersArgs::Return)); - auto wrapper = world.continuation(world.fn_type({world.mem_type(), ptr_ty, i32, i32, world.return_type({world.mem_type()})})); + auto wrapper = world.continuation(world.fn_type({world.mem_type(), ptr_ty, i32, i32, world.fn_type({world.mem_type()})})); const Def* wrapper_mem = wrapper->mem_param(); auto [args, recovered] = spill(mem, body->args().skip_front(static_cast(SpawnThreadArgs::Num)), wrapper_mem, wrapper->param(1)); Array prefix = {mem, wrapper->param(2), wrapper->param(3)}; @@ -363,7 +363,7 @@ void emit_spawn(RuntimeAPI& api, Continuation* continuation) { auto fun = body->arg(static_cast(SpawnThreadArgs::Fun)); auto ret = body->arg(static_cast(SpawnThreadArgs::Return)); - auto wrapper = world.continuation(world.fn_type({world.mem_type(), ptr_ty, world.return_type({world.mem_type()})})); + auto wrapper = world.continuation(world.fn_type({world.mem_type(), ptr_ty, world.fn_type({world.mem_type()})})); const Def* wrapper_mem = wrapper->mem_param(); auto [args, recovered] = spill(mem, body->args().skip_front(static_cast(SpawnThreadArgs::Num)), wrapper_mem, wrapper->param(1)); wrapper->jump(fun, concat(concat(mem, recovered.ref()), {wrapper->ret_param()})); @@ -398,7 +398,6 @@ void lower_offload_intrinsics(World& world, DeviceBackends& backends) { case Intrinsic::LevelZero_SPIRV: emit_host_code(api, call, Platform::LEVEL_ZERO_PLATFORM, ".spv", continuation); break; case Intrinsic::AMDGPUHSA: emit_host_code(api, 
call, Platform::HSA_PLATFORM, ".amdgpu", continuation); break; case Intrinsic::AMDGPUPAL: emit_host_code(api, call, Platform::PAL_PLATFORM, ".amdgpu", continuation); break; - case Intrinsic::VulkanCS_SPIRV: emit_host_code(api, call, Platform::VULKAN_PLATFORM, ".spv", continuation); break; case Intrinsic::Parallel: emit_parallel(api, continuation); break; case Intrinsic::Fibers: emit_fibers(api, continuation); break; From dbef006ec2afc43a26fca74efe8684f816aeca85 Mon Sep 17 00:00:00 2001 From: Hugo Devillers Date: Fri, 24 Oct 2025 18:11:30 +0200 Subject: [PATCH 3/3] remove LLVM-based runtime offloading support --- src/thorin/CMakeLists.txt | 4 - src/thorin/be/llvm/llvm.cpp | 27 +-- src/thorin/be/llvm/llvm.h | 7 +- src/thorin/be/llvm/parallel.cpp | 283 -------------------------------- src/thorin/be/llvm/runtime.cpp | 236 -------------------------- src/thorin/be/llvm/runtime.h | 62 ------- src/thorin/be/llvm/runtime.inc | 19 --- 7 files changed, 10 insertions(+), 628 deletions(-) delete mode 100644 src/thorin/be/llvm/parallel.cpp delete mode 100644 src/thorin/be/llvm/runtime.cpp delete mode 100644 src/thorin/be/llvm/runtime.h delete mode 100644 src/thorin/be/llvm/runtime.inc diff --git a/src/thorin/CMakeLists.txt b/src/thorin/CMakeLists.txt index a512bd992..ae534af93 100644 --- a/src/thorin/CMakeLists.txt +++ b/src/thorin/CMakeLists.txt @@ -106,10 +106,6 @@ if(LLVM_FOUND) be/llvm/amdgpu_pal.h be/llvm/nvvm.cpp be/llvm/nvvm.h - be/llvm/parallel.cpp - be/llvm/runtime.inc - be/llvm/runtime.cpp - be/llvm/runtime.h be/llvm/vectorize.cpp ) endif() diff --git a/src/thorin/be/llvm/llvm.cpp b/src/thorin/be/llvm/llvm.cpp index 4f7d80451..076d5d524 100644 --- a/src/thorin/be/llvm/llvm.cpp +++ b/src/thorin/be/llvm/llvm.cpp @@ -55,9 +55,16 @@ CodeGen::CodeGen( , function_calling_convention_(function_calling_convention) , device_calling_convention_(device_calling_convention) , kernel_calling_convention_(kernel_calling_convention) - , runtime_(std::make_unique(context(), module())) {} 
+llvm::Function* CodeGen::get(CodeGen& code_gen, const char* name) { + auto result = llvm::cast(module_->getOrInsertFunction(name, module_->getFunction(name)->getFunctionType()).getCallee()->stripPointerCasts()); + result->addFnAttr("target-cpu", code_gen.machine().getTargetCPU()); + result->addFnAttr("target-features", code_gen.machine().getTargetFeatureString()); + assert(result != nullptr && "Required runtime function could not be resolved"); + return result; +} + void CodeGen::optimize() { llvm::PassBuilder PB; llvm::OptimizationLevel opt_level; @@ -362,10 +369,6 @@ CodeGen::emit_module() { verify(); optimize(); - - // We need to delete the runtime at this point, since the ownership of - // the context and module is handed away. - runtime_.reset(); return std::pair { std::move(context_), std::move(module_) }; } @@ -1041,7 +1044,7 @@ void CodeGen::emit_phi_arg(llvm::IRBuilder<>& irbuilder, const Param* param, llv */ llvm::Value* CodeGen::emit_alloc(llvm::IRBuilder<>& irbuilder, const Type* type, const Def* extra) { - auto llvm_malloc = runtime_->get(*this, get_alloc_name().c_str()); + auto llvm_malloc = get(*this, get_alloc_name().c_str()); auto alloced_type = convert(type); llvm::CallInst* void_ptr; auto layout = module().getDataLayout(); @@ -1303,19 +1306,7 @@ std::vector CodeGen::emit_intrinsic(llvm::IRBuilder<>& irbuilder, case Intrinsic::CmpXchgWeak: return emit_cmpxchg(irbuilder, continuation, true); case Intrinsic::Fence: emit_fence(irbuilder, continuation); break; case Intrinsic::Reserve: return { emit_reserve(irbuilder, continuation) }; - case Intrinsic::CUDA: runtime_->emit_host_code(*this, irbuilder, Platform::CUDA_PLATFORM, ".cu", continuation); break; - case Intrinsic::NVVM: runtime_->emit_host_code(*this, irbuilder, Platform::CUDA_PLATFORM, ".nvvm", continuation); break; - case Intrinsic::OpenCL: runtime_->emit_host_code(*this, irbuilder, Platform::OPENCL_PLATFORM, ".cl", continuation); break; - case Intrinsic::OpenCL_SPIRV: 
runtime_->emit_host_code(*this, irbuilder, Platform::OPENCL_PLATFORM, ".spv", continuation); break; - case Intrinsic::LevelZero_SPIRV: runtime_->emit_host_code(*this, irbuilder, Platform::LEVEL_ZERO_PLATFORM, ".spv", continuation); break; - case Intrinsic::AMDGPUHSA: runtime_->emit_host_code(*this, irbuilder, Platform::HSA_PLATFORM, ".amdgpu", continuation); break; - case Intrinsic::AMDGPUPAL: runtime_->emit_host_code(*this, irbuilder, Platform::PAL_PLATFORM, ".amdgpu", continuation); break; - case Intrinsic::ShadyCompute: runtime_->emit_host_code(*this, irbuilder, Platform::SHADY_PLATFORM, ".shady", continuation); break; case Intrinsic::HLS: emit_hls(irbuilder, continuation); break; - case Intrinsic::Parallel: emit_parallel(irbuilder, continuation); break; - case Intrinsic::Fibers: emit_fibers(irbuilder, continuation); break; - case Intrinsic::Spawn: return { emit_spawn(irbuilder, continuation) }; - case Intrinsic::Sync: emit_sync(irbuilder, continuation); break; #if THORIN_ENABLE_RV case Intrinsic::Vectorize: emit_vectorize_continuation(irbuilder, continuation); break; #else diff --git a/src/thorin/be/llvm/llvm.h b/src/thorin/be/llvm/llvm.h index 7d5414331..8bc27d0d5 100644 --- a/src/thorin/be/llvm/llvm.h +++ b/src/thorin/be/llvm/llvm.h @@ -11,7 +11,6 @@ #include "thorin/analyses/schedule.h" #include "thorin/be/codegen.h" #include "thorin/be/emitter.h" -#include "thorin/be/llvm/runtime.h" #include "thorin/be/kernel_config.h" #include "thorin/transform/importer.h" @@ -43,6 +42,7 @@ class CodeGen : public thorin::CodeGen, public thorin::Emitter&, Continuation*); std::vector emit_intrinsic(llvm::IRBuilder<>&, Continuation*); void emit_hls(llvm::IRBuilder<>&, Continuation*); - void emit_parallel(llvm::IRBuilder<>&, Continuation*); - void emit_fibers(llvm::IRBuilder<>&, Continuation*); - llvm::Value* emit_spawn(llvm::IRBuilder<>&, Continuation*); - void emit_sync(llvm::IRBuilder<>&, Continuation*); void emit_vectorize_continuation(llvm::IRBuilder<>&, Continuation*); 
llvm::Value* emit_atomic(llvm::IRBuilder<>&, Continuation*); std::vector emit_cmpxchg(llvm::IRBuilder<>&, Continuation*, bool); @@ -136,7 +132,6 @@ class CodeGen : public thorin::CodeGen, public thorin::Emitter runtime_; #if THORIN_ENABLE_RV std::vector> vec_todo_; #endif diff --git a/src/thorin/be/llvm/parallel.cpp b/src/thorin/be/llvm/parallel.cpp deleted file mode 100644 index 96ab7142d..000000000 --- a/src/thorin/be/llvm/parallel.cpp +++ /dev/null @@ -1,283 +0,0 @@ -#include "thorin/be/llvm/llvm.h" - -namespace thorin::llvm { - -enum { - PAR_ARG_MEM, - PAR_ARG_NUMTHREADS, - PAR_ARG_LOWER, - PAR_ARG_UPPER, - PAR_ARG_BODY, - PAR_ARG_RETURN, - PAR_NUM_ARGS -}; - -void CodeGen::emit_parallel(llvm::IRBuilder<>& irbuilder, Continuation* continuation) { - assert(continuation->has_body()); - auto body = continuation->body(); - // Emit memory dependencies up to this point - emit_unsafe(body->arg(PAR_ARG_MEM)); - - // arguments - assert(body->num_args() >= PAR_NUM_ARGS && "required arguments are missing"); - auto num_threads = emit(body->arg(PAR_ARG_NUMTHREADS)); - auto lower = emit(body->arg(PAR_ARG_LOWER)); - auto upper = emit(body->arg(PAR_ARG_UPPER)); - auto kernel = body->arg(PAR_ARG_BODY)->as()->init()->as_nom(); - - const size_t num_kernel_args = body->num_args() - PAR_NUM_ARGS; - - // build parallel-function signature - Array par_args(num_kernel_args + 1); - par_args[0] = irbuilder.getInt32Ty(); // loop index - for (size_t i = 0; i < num_kernel_args; ++i) { - auto type = body->arg(i + PAR_NUM_ARGS)->type(); - par_args[i + 1] = convert(type); - } - - // fetch values and create a unified struct which contains all values (closure) - auto closure_type = convert(world().tuple_type(continuation->body()->callee()->type()->as()->types().skip_front(PAR_NUM_ARGS))); - llvm::Value* closure = llvm::UndefValue::get(closure_type); - if (num_kernel_args != 1) { - for (size_t i = 0; i < num_kernel_args; ++i) - closure = irbuilder.CreateInsertValue(closure, emit(body->arg(i + 
PAR_NUM_ARGS)), unsigned(i)); - } else { - closure = emit(body->arg(PAR_NUM_ARGS)); - } - - // allocate closure object and write values into it - auto ptr = emit_alloca(irbuilder, closure_type, "parallel_closure"); - irbuilder.CreateStore(closure, ptr, false); - - // create wrapper function and call the runtime - // wrapper(void* closure, int lower, int upper) - llvm::Type* wrapper_arg_types[] = { irbuilder.getPtrTy(), irbuilder.getInt32Ty(), irbuilder.getInt32Ty() }; - auto wrapper_ft = llvm::FunctionType::get(irbuilder.getVoidTy(), wrapper_arg_types, false); - auto wrapper_name = kernel->unique_name() + "_parallel_for"; - auto wrapper = (llvm::Function*)module_->getOrInsertFunction(wrapper_name, wrapper_ft).getCallee()->stripPointerCasts(); - wrapper->addFnAttr("target-cpu", machine_->getTargetCPU()); - wrapper->addFnAttr("target-features", machine_->getTargetFeatureString()); - runtime_->parallel_for(*this, irbuilder, num_threads, lower, upper, ptr, wrapper); - - // set insert point to the wrapper function - auto old_bb = irbuilder.GetInsertBlock(); - auto bb = llvm::BasicBlock::Create(*context_, wrapper_name, wrapper); - irbuilder.SetInsertPoint(bb); - - // extract all arguments from the closure - auto wrapper_args = wrapper->arg_begin(); - auto val = irbuilder.CreateLoad(closure_type, &*wrapper_args); - std::vector target_args(num_kernel_args + 1); - if (num_kernel_args != 1) { - for (size_t i = 0; i < num_kernel_args; ++i) - target_args[i + 1] = irbuilder.CreateExtractValue(val, { unsigned(i) }); - } else { - target_args[1] = val; - } - - // create loop iterating over range: - // for (int i=lower; i); - auto wrapper_lower = &*(++wrapper_args); - auto wrapper_upper = &*(++wrapper_args); - create_loop(irbuilder, wrapper_lower, wrapper_upper, irbuilder.getInt32(1), wrapper, [&](llvm::Value* counter) { - // call kernel body - target_args[0] = counter; // loop index - auto par_type = llvm::FunctionType::get(irbuilder.getVoidTy(), llvm_ref(par_args), false); - auto 
kernel_par_func = (llvm::Function*)module_->getOrInsertFunction(kernel->unique_name(), par_type).getCallee()->stripPointerCasts(); - irbuilder.CreateCall(kernel_par_func, target_args); - }); - irbuilder.CreateRetVoid(); - - // restore old insert point - irbuilder.SetInsertPoint(old_bb); -} - -enum { - FIB_ARG_MEM, - FIB_ARG_NUMTHREADS, - FIB_ARG_NUMBLOCKS, - FIB_ARG_NUMWARPS, - FIB_ARG_BODY, - FIB_ARG_RETURN, - FIB_NUM_ARGS -}; - -void CodeGen::emit_fibers(llvm::IRBuilder<>& irbuilder, Continuation* continuation) { - assert(continuation->has_body()); - auto body = continuation->body(); - // Emit memory dependencies up to this point - emit_unsafe(body->arg(FIB_ARG_MEM)); - - // arguments - assert(body->num_args() >= FIB_NUM_ARGS && "required arguments are missing"); - auto num_threads = emit(body->arg(FIB_ARG_NUMTHREADS)); - auto num_blocks = emit(body->arg(FIB_ARG_NUMBLOCKS)); - auto num_warps = emit(body->arg(FIB_ARG_NUMWARPS)); - auto kernel = body->arg(FIB_ARG_BODY)->as()->init()->as_nom(); - - const size_t num_kernel_args = body->num_args() - FIB_NUM_ARGS; - - // build fibers-function signature - Array fib_args(num_kernel_args + 2); - fib_args[0] = irbuilder.getInt32Ty(); // block index - fib_args[1] = irbuilder.getInt32Ty(); // warp index - for (size_t i = 0; i < num_kernel_args; ++i) { - auto type = body->arg(i + FIB_NUM_ARGS)->type(); - fib_args[i + 2] = convert(type); - } - - // fetch values and create a unified struct which contains all values (closure) - auto closure_type = convert(world().tuple_type(continuation->body()->callee()->type()->as()->types().skip_front(FIB_NUM_ARGS))); - llvm::Value* closure = llvm::UndefValue::get(closure_type); - if (num_kernel_args != 1) { - for (size_t i = 0; i < num_kernel_args; ++i) - closure = irbuilder.CreateInsertValue(closure, emit(body->arg(i + FIB_NUM_ARGS)), unsigned(i)); - } else { - closure = emit(body->arg(FIB_NUM_ARGS)); - } - - // allocate closure object and write values into it - auto ptr = 
emit_alloca(irbuilder, closure_type, "fibers_closure"); - irbuilder.CreateStore(closure, ptr, false); - - // create wrapper function and call the runtime - // wrapper(void* closure, int lower, int upper) - llvm::Type* wrapper_arg_types[] = { irbuilder.getPtrTy(), irbuilder.getInt32Ty(), irbuilder.getInt32Ty() }; - auto wrapper_ft = llvm::FunctionType::get(irbuilder.getVoidTy(), wrapper_arg_types, false); - auto wrapper_name = kernel->unique_name() + "_fibers"; - auto wrapper = (llvm::Function*)module_->getOrInsertFunction(wrapper_name, wrapper_ft).getCallee()->stripPointerCasts(); - wrapper->addFnAttr("target-cpu", machine_->getTargetCPU()); - wrapper->addFnAttr("target-features", machine_->getTargetFeatureString()); - runtime_->spawn_fibers(*this, irbuilder, num_threads, num_blocks, num_warps, ptr, wrapper); - - // set insert point to the wrapper function - auto old_bb = irbuilder.GetInsertBlock(); - auto bb = llvm::BasicBlock::Create(*context_, wrapper_name, wrapper); - irbuilder.SetInsertPoint(bb); - - // extract all arguments from the closure - auto wrapper_args = wrapper->arg_begin(); - auto val = irbuilder.CreateLoad(closure_type, &*wrapper_args); - std::vector target_args(num_kernel_args + 2); - if (num_kernel_args != 1) { - for (size_t i = 0; i < num_kernel_args; ++i) - target_args[i + 2] = irbuilder.CreateExtractValue(val, { unsigned(i) }); - } else { - target_args[2] = val; - } - - auto wrapper_block = &*(++wrapper_args); - auto wrapper_warp = &*(++wrapper_args); - - target_args[0] = wrapper_block; - target_args[1] = wrapper_warp; - - // call kernel body - auto fib_type = llvm::FunctionType::get(irbuilder.getVoidTy(), llvm_ref(fib_args), false); - auto kernel_fib_func = (llvm::Function*)module_->getOrInsertFunction(kernel->unique_name(), fib_type).getCallee()->stripPointerCasts(); - irbuilder.CreateCall(kernel_fib_func, target_args); - irbuilder.CreateRetVoid(); - - // restore old insert point - irbuilder.SetInsertPoint(old_bb); -} - -enum { - 
SPAWN_ARG_MEM, - SPAWN_ARG_BODY, - SPAWN_ARG_RETURN, - SPAWN_NUM_ARGS -}; - -llvm::Value* CodeGen::emit_spawn(llvm::IRBuilder<>& irbuilder, Continuation* continuation) { - assert(continuation->has_body()); - auto body = continuation->body(); - assert(body->num_args() >= SPAWN_NUM_ARGS && "required arguments are missing"); - - // Emit memory dependencies up to this point - emit_unsafe(body->arg(FIB_ARG_MEM)); - - auto kernel = body->arg(SPAWN_ARG_BODY)->as()->init()->as_nom(); - const size_t num_kernel_args = body->num_args() - SPAWN_NUM_ARGS; - - // build parallel-function signature - Array par_args(num_kernel_args); - for (size_t i = 0; i < num_kernel_args; ++i) { - auto type = body->arg(i + SPAWN_NUM_ARGS)->type(); - par_args[i] = convert(type); - } - - // fetch values and create a unified struct which contains all values (closure) - auto closure_type = convert(world().tuple_type(continuation->body()->callee()->type()->as()->types().skip_front(SPAWN_NUM_ARGS))); - llvm::Value* closure = nullptr; - if (closure_type->isStructTy()) { - closure = llvm::UndefValue::get(closure_type); - for (size_t i = 0; i < num_kernel_args; ++i) - closure = irbuilder.CreateInsertValue(closure, emit(body->arg(i + SPAWN_NUM_ARGS)), unsigned(i)); - } else { - closure = emit(body->arg(0 + SPAWN_NUM_ARGS)); - } - - // allocate closure object and write values into it - auto ptr = irbuilder.CreateAlloca(closure_type, nullptr); - irbuilder.CreateStore(closure, ptr, false); - - // create wrapper function and call the runtime - // wrapper(void* closure) - llvm::Type* wrapper_arg_types[] = { irbuilder.getPtrTy() }; - auto wrapper_ft = llvm::FunctionType::get(irbuilder.getVoidTy(), wrapper_arg_types, false); - auto wrapper_name = kernel->unique_name() + "_spawn_thread"; - auto wrapper = (llvm::Function*)module_->getOrInsertFunction(wrapper_name, wrapper_ft).getCallee()->stripPointerCasts(); - wrapper->addFnAttr("target-cpu", machine_->getTargetCPU()); - wrapper->addFnAttr("target-features", 
machine_->getTargetFeatureString()); - auto call = runtime_->spawn_thread(*this, irbuilder, ptr, wrapper); - - // set insert point to the wrapper function - auto old_bb = irbuilder.GetInsertBlock(); - auto bb = llvm::BasicBlock::Create(*context_, wrapper_name, wrapper); - irbuilder.SetInsertPoint(bb); - - // extract all arguments from the closure - auto wrapper_args = wrapper->arg_begin(); - auto val = irbuilder.CreateLoad(closure_type, &*wrapper_args); - std::vector target_args(num_kernel_args); - if (val->getType()->isStructTy()) { - for (size_t i = 0; i < num_kernel_args; ++i) - target_args[i] = irbuilder.CreateExtractValue(val, { unsigned(i) }); - } else { - target_args[0] = val; - } - - // call kernel body - auto par_type = llvm::FunctionType::get(irbuilder.getVoidTy(), llvm_ref(par_args), false); - auto kernel_par_func = (llvm::Function*)module_->getOrInsertFunction(kernel->unique_name(), par_type).getCallee()->stripPointerCasts(); - irbuilder.CreateCall(kernel_par_func, target_args); - irbuilder.CreateRetVoid(); - - // restore old insert point - irbuilder.SetInsertPoint(old_bb); - - return call; -} - -enum { - SYNC_ARG_MEM, - SYNC_ARG_ID, - SYNC_ARG_RETURN, - SYNC_NUM_ARGS -}; - -void CodeGen::emit_sync(llvm::IRBuilder<>& irbuilder, Continuation* continuation) { - assert(continuation->has_body()); - auto body = continuation->body(); - assert(body->num_args() == SYNC_NUM_ARGS && "wrong number of arguments"); - - // Emit memory dependencies up to this point - emit_unsafe(body->arg(FIB_ARG_MEM)); - - auto id = emit(body->arg(SYNC_ARG_ID)); - runtime_->sync_thread(*this, irbuilder, id); -} - -} diff --git a/src/thorin/be/llvm/runtime.cpp b/src/thorin/be/llvm/runtime.cpp deleted file mode 100644 index 310957fc6..000000000 --- a/src/thorin/be/llvm/runtime.cpp +++ /dev/null @@ -1,236 +0,0 @@ -#include "thorin/be/llvm/runtime.h" - -#include -#include - -#include -#include -#include -#include -#include - -#include "thorin/primop.h" -#include "thorin/be/llvm/llvm.h" 
-#include "thorin/be/llvm/runtime.inc" - -namespace thorin::llvm { - -Runtime::Runtime( - llvm::LLVMContext& context, - llvm::Module& target) - : target_(target) - , layout_(target.getDataLayout()) -{ - llvm::SMDiagnostic diag; - auto mem_buf = llvm::MemoryBuffer::getMemBuffer(runtime_definitions); - runtime_ = llvm::parseIR(*mem_buf.get(), diag, context); - if (runtime_ == nullptr) - throw std::logic_error("runtime could not be loaded"); -} - -llvm::Function* Runtime::get(CodeGen& code_gen, const char* name) { - auto result = llvm::cast(target_.getOrInsertFunction(name, runtime_->getFunction(name)->getFunctionType()).getCallee()->stripPointerCasts()); - result->addFnAttr("target-cpu", code_gen.machine().getTargetCPU()); - result->addFnAttr("target-features", code_gen.machine().getTargetFeatureString()); - assert(result != nullptr && "Required runtime function could not be resolved"); - return result; -} - -static bool contains_ptrtype(const Type* type) { - switch (type->tag()) { - case Node_PtrType: return false; - case Node_IndefiniteArrayType: return contains_ptrtype(type->as()->elem_type()); - case Node_DefiniteArrayType: return contains_ptrtype(type->as()->elem_type()); - case Node_FnType: return false; - case Node_StructType: { - bool good = true; - auto struct_type = type->as(); - for (auto& t : struct_type->types()) - good &= contains_ptrtype(t); - return good; - } - case Node_TupleType: { - bool good = true; - auto tuple = type->as(); - for (auto& t : tuple->types()) - good &= contains_ptrtype(t); - return good; - } - default: return true; - } -} - -void Runtime::emit_host_code(CodeGen& code_gen, llvm::IRBuilder<>& builder, Platform platform, const std::string& ext, Continuation* continuation) { - assert(continuation->has_body()); - auto body = continuation->body(); - // to-target is the desired kernel call - // target(mem, device, (dim.x, dim.y, dim.z), (block.x, block.y, block.z), body, return, free_vars) - auto target = body->callee()->as_nom(); - 
assert_unused(target->is_intrinsic()); - assert(body->num_args() >= KernelLaunchArgs::Num && "required arguments are missing"); - - // arguments - auto target_device_id = code_gen.emit(body->arg(KernelLaunchArgs::Device)); - auto target_platform = builder.getInt32(platform); - auto target_device = builder.CreateOr(target_platform, builder.CreateShl(target_device_id, builder.getInt32(4))); - - auto it_space = body->arg(KernelLaunchArgs::Space); - auto it_config = body->arg(KernelLaunchArgs::Config); - auto kernel = body->arg(KernelLaunchArgs::Body)->as()->init()->as(); - - auto& world = continuation->world(); - //auto kernel_name = builder.CreateGlobalString(kernel->name() == "hls_top" ? kernel->name() : kernel->name()); - auto kernel_name = builder.CreateGlobalString(kernel->name()); - auto file_name = builder.CreateGlobalString(world.name() + ext); - const size_t num_kernel_args = body->num_args() - KernelLaunchArgs::Num; - - // allocate argument pointers, sizes, and types - llvm::Value* args = code_gen.emit_alloca(builder, llvm::ArrayType::get(builder.getPtrTy(), num_kernel_args), "args"); - llvm::Value* sizes = code_gen.emit_alloca(builder, llvm::ArrayType::get(builder.getInt32Ty(), num_kernel_args), "sizes"); - llvm::Value* aligns = code_gen.emit_alloca(builder, llvm::ArrayType::get(builder.getInt32Ty(), num_kernel_args), "aligns"); - llvm::Value* allocs = code_gen.emit_alloca(builder, llvm::ArrayType::get(builder.getInt32Ty(), num_kernel_args), "allocs"); - llvm::Value* types = code_gen.emit_alloca(builder, llvm::ArrayType::get(builder.getInt8Ty(), num_kernel_args), "types"); - - // fill array of arguments - for (size_t i = 0; i < num_kernel_args; ++i) { - auto target_arg = body->arg(i + KernelLaunchArgs::Num); - const auto target_val = code_gen.emit(target_arg); - - KernelArgType arg_type; - llvm::Value* void_ptr; - if (target_arg->type()->isa() || - target_arg->type()->isa() || - target_arg->type()->isa()) { - // definite array | struct | tuple - auto alloca 
= code_gen.emit_alloca(builder, target_val->getType(), target_arg->name()); - builder.CreateStore(target_val, alloca); - - // check if argument type contains pointers - if (!contains_ptrtype(target_arg->type())) - world.wdef(target_arg, "argument '{}' of aggregate type '{}' contains pointer (not supported in OpenCL 1.2)", target_arg, target_arg->type()); - - void_ptr = builder.CreatePointerCast(alloca, builder.getPtrTy()); - arg_type = KernelArgType::Struct; - } else if (target_arg->type()->isa()) { - auto ptr = target_arg->type()->as(); - auto rtype = ptr->pointee(); - - if (!rtype->isa()) - world.edef(target_arg, "currently only pointers to arrays supported as kernel argument; argument has different type: {}", ptr); - - auto alloca = code_gen.emit_alloca(builder, builder.getPtrTy(), target_arg->name()); - auto target_ptr = builder.CreatePointerCast(target_val, builder.getPtrTy()); - builder.CreateStore(target_ptr, alloca); - void_ptr = builder.CreatePointerCast(alloca, builder.getPtrTy()); - arg_type = KernelArgType::Ptr; - } else { - // normal variable - auto alloca = code_gen.emit_alloca(builder, target_val->getType(), target_arg->name()); - builder.CreateStore(target_val, alloca); - - void_ptr = builder.CreatePointerCast(alloca, builder.getPtrTy()); - arg_type = KernelArgType::Val; - } - - auto arg_ptr = builder.CreateInBoundsGEP(llvm::cast(args)->getAllocatedType(), args, llvm::ArrayRef{builder.getInt32(0), builder.getInt32(i)}); - auto size_ptr = builder.CreateInBoundsGEP(llvm::cast(sizes)->getAllocatedType(), sizes, llvm::ArrayRef{builder.getInt32(0), builder.getInt32(i)}); - auto align_ptr = builder.CreateInBoundsGEP(llvm::cast(aligns)->getAllocatedType(), aligns, llvm::ArrayRef{builder.getInt32(0), builder.getInt32(i)}); - auto alloc_ptr = builder.CreateInBoundsGEP(llvm::cast(allocs)->getAllocatedType(), allocs, llvm::ArrayRef{builder.getInt32(0), builder.getInt32(i)}); - auto type_ptr = builder.CreateInBoundsGEP(llvm::cast(types)->getAllocatedType(), 
types, llvm::ArrayRef{builder.getInt32(0), builder.getInt32(i)}); - - auto size = layout_.getTypeStoreSize(target_val->getType()).getFixedValue(); - if (auto struct_type = llvm::dyn_cast(target_val->getType())) { - // In the case of a structure, do not include the padding at the end in the size - auto last_elem = struct_type->getStructNumElements() - 1; - auto last_offset = layout_.getStructLayout(struct_type)->getElementOffset(last_elem); - size = last_offset + layout_.getTypeStoreSize(struct_type->getStructElementType(last_elem)).getFixedValue(); - } - - builder.CreateStore(void_ptr, arg_ptr); - builder.CreateStore(builder.getInt32(size), size_ptr); - builder.CreateStore(builder.getInt32(layout_.getABITypeAlign(target_val->getType()).value()), align_ptr); - builder.CreateStore(builder.getInt32(layout_.getTypeAllocSize(target_val->getType())), alloc_ptr); - builder.CreateStore(builder.getInt8((uint8_t)arg_type), type_ptr); - } - - // allocate arrays for the grid and block size - const auto get_u32 = [&](const Def* def) { return builder.CreateSExt(code_gen.emit(def), builder.getInt32Ty()); }; - - llvm::Value* grid_array = llvm::UndefValue::get(llvm::ArrayType::get(builder.getInt32Ty(), 3)); - grid_array = builder.CreateInsertValue(grid_array, get_u32(world.extract(it_space, 0_u32)), 0); - grid_array = builder.CreateInsertValue(grid_array, get_u32(world.extract(it_space, 1_u32)), 1); - grid_array = builder.CreateInsertValue(grid_array, get_u32(world.extract(it_space, 2_u32)), 2); - llvm::Value* grid_size = code_gen.emit_alloca(builder, grid_array->getType(), ""); - builder.CreateStore(grid_array, grid_size); - - llvm::Value* block_array = llvm::UndefValue::get(llvm::ArrayType::get(builder.getInt32Ty(), 3)); - block_array = builder.CreateInsertValue(block_array, get_u32(world.extract(it_config, 0_u32)), 0); - block_array = builder.CreateInsertValue(block_array, get_u32(world.extract(it_config, 1_u32)), 1); - block_array = builder.CreateInsertValue(block_array, 
get_u32(world.extract(it_config, 2_u32)), 2); - llvm::Value* block_size = code_gen.emit_alloca(builder, block_array->getType(), ""); - builder.CreateStore(block_array, block_size); - - std::vector gep_first_elem{builder.getInt32(0), builder.getInt32(0)}; - grid_size = builder.CreateInBoundsGEP(llvm::cast(grid_size)->getAllocatedType(), grid_size, gep_first_elem); - block_size = builder.CreateInBoundsGEP(llvm::cast(block_size)->getAllocatedType(), block_size, gep_first_elem); - args = builder.CreateInBoundsGEP(llvm::cast(args)->getAllocatedType(), args, gep_first_elem); - sizes = builder.CreateInBoundsGEP(llvm::cast(sizes)->getAllocatedType(), sizes, gep_first_elem); - aligns = builder.CreateInBoundsGEP(llvm::cast(aligns)->getAllocatedType(), aligns, gep_first_elem); - allocs = builder.CreateInBoundsGEP(llvm::cast(allocs)->getAllocatedType(), allocs, gep_first_elem); - types = builder.CreateInBoundsGEP(llvm::cast(types)->getAllocatedType(), types, gep_first_elem); - - launch_kernel(code_gen, builder, target_device, - file_name, kernel_name, - grid_size, block_size, - args, sizes, aligns, allocs, types, - builder.getInt32(num_kernel_args)); -} - -llvm::Value* Runtime::launch_kernel( - CodeGen& code_gen, llvm::IRBuilder<>& builder, llvm::Value* device, - llvm::Value* file, llvm::Value* kernel, - llvm::Value* grid, llvm::Value* block, - llvm::Value* args, llvm::Value* sizes, llvm::Value* aligns, llvm::Value* allocs, llvm::Value* types, - llvm::Value* num_args) -{ - llvm::Value* launch_args[] = { device, file, kernel, grid, block, args, sizes, aligns, allocs, types, num_args }; - return builder.CreateCall(get(code_gen, "anydsl_launch_kernel"), launch_args); -} - -llvm::Value* Runtime::parallel_for( - CodeGen& code_gen, llvm::IRBuilder<>& builder, llvm::Value* num_threads, llvm::Value* lower, llvm::Value* upper, - llvm::Value* closure_ptr, llvm::Value* fun_ptr) -{ - llvm::Value* parallel_args[] = { - num_threads, lower, upper, - builder.CreatePointerCast(closure_ptr, 
builder.getPtrTy()), - builder.CreatePointerCast(fun_ptr, builder.getPtrTy()) - }; - return builder.CreateCall(get(code_gen, "anydsl_parallel_for"), parallel_args); -} - -llvm::Value* Runtime::spawn_fibers( - CodeGen& code_gen, llvm::IRBuilder<>& builder, llvm::Value* num_threads, llvm::Value* num_blocks, llvm::Value* num_warps, - llvm::Value* closure_ptr, llvm::Value* fun_ptr) -{ - llvm::Value* fibers_args[] = { - num_threads, num_blocks, num_warps, - builder.CreatePointerCast(closure_ptr, builder.getPtrTy()), - builder.CreatePointerCast(fun_ptr, builder.getPtrTy()) - }; - return builder.CreateCall(get(code_gen, "anydsl_fibers_spawn"), fibers_args); -} - -llvm::Value* Runtime::spawn_thread(CodeGen& code_gen, llvm::IRBuilder<>& builder, llvm::Value* closure_ptr, llvm::Value* fun_ptr) { - llvm::Value* spawn_args[] = { - builder.CreatePointerCast(closure_ptr, builder.getPtrTy()), - builder.CreatePointerCast(fun_ptr, builder.getPtrTy()) - }; - return builder.CreateCall(get(code_gen, "anydsl_spawn_thread"), spawn_args); -} - -llvm::Value* Runtime::sync_thread(CodeGen& code_gen, llvm::IRBuilder<>& builder, llvm::Value* id) { - return builder.CreateCall(get(code_gen, "anydsl_sync_thread"), id); -} - -} diff --git a/src/thorin/be/llvm/runtime.h b/src/thorin/be/llvm/runtime.h deleted file mode 100644 index fca64aa6f..000000000 --- a/src/thorin/be/llvm/runtime.h +++ /dev/null @@ -1,62 +0,0 @@ -#ifndef THORIN_BE_LLVM_RUNTIME_H -#define THORIN_BE_LLVM_RUNTIME_H - -#include - -#include -#include - -#include "thorin/world.h" -#include "thorin/be/runtime.h" - -namespace thorin::llvm { - -namespace llvm = ::llvm; - -class CodeGen; - -class Runtime { -public: - Runtime(llvm::LLVMContext&, llvm::Module&); - - /// Emits a call to anydsl_launch_kernel. 
- llvm::Value* launch_kernel( - CodeGen&, llvm::IRBuilder<>&, llvm::Value* device, - llvm::Value* file, llvm::Value* kernel, - llvm::Value* grid, llvm::Value* block, - llvm::Value* args, llvm::Value* sizes, llvm::Value* aligns, llvm::Value* allocs, llvm::Value* types, - llvm::Value* num_args); - - /// Emits a call to anydsl_parallel_for. - llvm::Value* parallel_for( - CodeGen&, llvm::IRBuilder<>&, - llvm::Value* num_threads, llvm::Value* lower, llvm::Value* upper, - llvm::Value* closure_ptr, llvm::Value* fun_ptr); - - /// Emits a call to anydsl_fibers_spawn. - llvm::Value* spawn_fibers( - CodeGen&, llvm::IRBuilder<>&, - llvm::Value* num_threads, llvm::Value* num_blocks, llvm::Value* num_warps, - llvm::Value* closure_ptr, llvm::Value* fun_ptr); - - /// Emits a call to anydsl_spawn_thread. - llvm::Value* spawn_thread(CodeGen&, llvm::IRBuilder<>&, llvm::Value* closure_ptr, llvm::Value* fun_ptr); - /// Emits a call to anydsl_sync_thread. - llvm::Value* sync_thread(CodeGen&, llvm::IRBuilder<>&, llvm::Value* id); - - void emit_host_code( - CodeGen& code_gen, llvm::IRBuilder<>& builder, - Platform platform, const std::string& ext, Continuation* continuation); - - llvm::Function* get(CodeGen& code_gen, const char* name); - -protected: - llvm::Module& target_; - const llvm::DataLayout& layout_; - - std::unique_ptr runtime_; -}; - -} - -#endif diff --git a/src/thorin/be/llvm/runtime.inc b/src/thorin/be/llvm/runtime.inc deleted file mode 100644 index 44ee7f9f8..000000000 --- a/src/thorin/be/llvm/runtime.inc +++ /dev/null @@ -1,19 +0,0 @@ -namespace thorin { - enum class KernelArgType : uint8_t { Val = 0, Ptr, Struct }; - - static const char* runtime_definitions = R"( - ; Module anydsl runtime decls - declare noalias ptr @anydsl_alloc(i32, i64); - declare noalias ptr @anydsl_alloc_unified(i32, i64); - declare void @anydsl_release(i32, ptr); - declare void @anydsl_launch_kernel(i32, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, i32); - declare void @anydsl_parallel_for(i32, i32, 
i32, ptr, ptr); - declare void @anydsl_fibers_spawn(i32, i32, i32, ptr, ptr); - declare i32 @anydsl_spawn_thread(ptr, ptr); - declare void @anydsl_sync_thread(i32); - declare i32 @anydsl_create_graph(); - declare i32 @anydsl_create_task(i32, { ptr, i64 }); - declare void @anydsl_create_edge(i32, i32); - declare void @anydsl_execute_graph(i32, i32); - )"; -}