Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions alan_compiler/src/std/root.ln
Original file line number Diff line number Diff line change
Expand Up @@ -1850,8 +1850,8 @@ fn{Rs} id(b: GBufferRaw) = {"alan_std::buffer_id" <- RootBacking :: GBufferRaw -
fn{Js} id(b: GBufferRaw) = {"alan_std.bufferid" <- RootBacking :: GBufferRaw -> string}(b);
fn id{T}(b: GBuffer{T}) = b.rawBuffer.id;
fn id(b: GBufferTagged) = b.rawBuffer.id;
fn{Rs} optimalLocalGroup() = {"alan_std::optimal_local_group" <- RootBacking :: () -> Buffer{i64, 3}}();
fn{Js} optimalLocalGroup() = {"alan_std.optimalLocalGroup" <- RootBacking :: () -> Buffer{i64, 3}}();
fn{Rs} optimalLocalGroup(global: Buffer{i64, 3}) = {"alan_std::optimal_local_group" <- RootBacking :: Deref{Buffer{i64, 3}} -> Buffer{i64, 3}}(global);
fn{Js} optimalLocalGroup(global: Buffer{i64, 3}) = {"alan_std.optimalLocalGroup" <- RootBacking :: Buffer{i64, 3} -> Buffer{i64, 3}}(global);
fn{Rs} GPGPU "alan_std::GPGPU::new" <- RootBacking :: (Own{string}, Own{Array{Array{GBufferRaw}}}, Deref{Buffer{i64, 3}}, Deref{Buffer{i64, 3}}) -> GPGPU;
fn{Js} GPGPU (src: string, gbuffers: Array{Array{GBufferRaw}}, idx: Buffer{i64, 3}, local: Buffer{i64, 3}) = {"new alan_std.GPGPU" <- RootBacking :: (string, Array{Array{GBufferRaw}}, Buffer{i32, 3}, Buffer{i32, 3}) -> GPGPU}(src, gbuffers, idx.map(i32), local.map(i32));
fn GPGPU{T}(src: string, buf: GBuffer{T}, local: Buffer{i64, 3}) -> GPGPU {
Expand Down Expand Up @@ -1906,10 +1906,9 @@ fn build{N}(ret: N) {
maxGlobalIdArray[1] ?? 0,
maxGlobalIdArray[2] ?? 0
);
let localGroup = optimalLocalGroup();
// Compute optimal local group size based on total work to perform
let localGroup = optimalLocalGroup(maxGlobalId);
// Clamp each local dimension so it doesn't exceed the corresponding global dimension.
// If local > global, one workgroup would launch more invocations than there are valid
// indices, causing out-of-bounds writes since shaders have no bounds checking.
let localX = max(min(localGroup.0, maxGlobalId.0), 1);
let localY = max(min(localGroup.1, maxGlobalId.1), 1);
let localZ = max(min(localGroup.2, maxGlobalId.2), 1);
Expand Down Expand Up @@ -2695,6 +2694,7 @@ fn gFor{T}(x: T, y: T, z: T) = gFor(x.u32, y.u32, z.u32);
fn gFor{T}(x: T, y: T) = gFor(x.u32, y.u32, 1.u32);
fn gFor{T}(x: T) = gFor(x.u32, 1.u32, 1.u32).x;
// TODO: More hackery to eliminate
fn gFor(x: u32, y: u32) = gFor(x, y, 1.u32);
fn gFor(x: i64, y: i64) {
let initialStatement = "@builtin(global_invocation_id) id: vec3u";
let statements = Dict(initialStatement, x.string.concat(',').concat(y.string).concat(',1'));
Expand Down Expand Up @@ -5088,7 +5088,11 @@ fn{Rs} context(f: Frame) = GBuffer{u32}({Property{"context.clone()"} :: Frame ->
fn{Js} context(f: Frame) = GBuffer{u32}({"alan_std.frameContext" <- RootBacking :: Frame -> GBufferRaw}(f));
fn{Rs} framebuffer(f: Frame) = GBuffer{u32}({Property{"framebuffer.clone()"} :: Frame -> GBufferRaw}(f));
fn{Js} framebuffer(f: Frame) = GBuffer{u32}({"alan_std.frameFramebuffer" <- RootBacking :: Frame -> GBufferRaw}(f));
fn pixel Frame = gFor(-1, -2); // Magic numbers for the binding
// Frame dimension accessors. The Rust backend reads the `width`/`height` properties of the
// Frame directly; the JS backend calls the matching `alan_std` helper function.
fn{Rs} width(f: Frame) = {Property{"width"} :: Frame -> u32}(f);
fn{Js} width(f: Frame) = {"alan_std.frameWidth" <- RootBacking :: Frame -> u32}(f);
fn{Rs} height(f: Frame) = {Property{"height"} :: Frame -> u32}(f);
fn{Js} height(f: Frame) = {"alan_std.frameHeight" <- RootBacking :: Frame -> u32}(f);
// Builds the `gFor` for a frame from the frame's own width and height.
fn pixel(f: Frame) = gFor(f.width, f.height);

/// Process exit-related functions
fn{Rs} ExitCode "std::process::ExitCode::from" :: Own{u8} -> ExitCode;
Expand Down
130 changes: 68 additions & 62 deletions alan_std.js
Original file line number Diff line number Diff line change
Expand Up @@ -792,26 +792,57 @@ export class GPU {
}

let GPUS = null;
let OPTIMAL_LOCAL_GROUP = null;

export function optimalLocalGroup() {
if (OPTIMAL_LOCAL_GROUP === null) {
let SUBGROUP_MAX_SIZE = null;

// Typical subgroup_max_size values from wgpu docs (https://docs.rs/wgpu/28.0.0/wgpu/struct.AdapterInfo.html#structfield.subgroup_max_size)
// Keys are lowercase on purpose: WebGPU exposes `GPUAdapterInfo.vendor` as a normalized
// identifier (lowercase ASCII letters, digits and hyphens), so capitalized names such as
// "Intel" or "NVIDIA" would never match an exact lookup on that value.
const SUBGROUP_MAX_SIZE_BY_VENDOR = {
  "intel": 32, // Intel: 16 or 32, using upper bound
  "amd": 64, // AMD GCN/Vega: 64, RDNA+: 64
  "apple": 32, // Apple M-series (Metal backend)
  "google": 32, // ChromeOS (typically AMD/Intel)
  "qualcomm": 128, // Qualcomm: 128
  "nvidia": 32, // NVIDIA: 32
  "microsoft": 128, // WARP software rasterizer: 4 or 128, using upper bound
};

function getSubgroupMaxSize() {
if (SUBGROUP_MAX_SIZE === null) {
if (GPUS !== null && GPUS.length > 0) {
let maxInvocations = GPUS[0].device.limits.maxComputeInvocationsPerWorkgroup;
let n = maxInvocations;
let sqrt = Math.floor(Math.sqrt(n));
if (sqrt * sqrt === n) {
OPTIMAL_LOCAL_GROUP = [sqrt, sqrt, 1];
return OPTIMAL_LOCAL_GROUP;
}
if (n % 8 === 0) {
OPTIMAL_LOCAL_GROUP = [n / 8, 8, 1];
return OPTIMAL_LOCAL_GROUP;
}
let vendor = GPUS[0].adapter.info.vendor || "";
// Exact-match lookup by vendor name only; unrecognized vendors use the fallback value
SUBGROUP_MAX_SIZE = SUBGROUP_MAX_SIZE_BY_VENDOR[vendor] || 64; // 64 is safe middle ground
} else {
SUBGROUP_MAX_SIZE = 64;
}
OPTIMAL_LOCAL_GROUP = [8, 8, 1];
}
return OPTIMAL_LOCAL_GROUP;
return SUBGROUP_MAX_SIZE;
}

/**
 * Picks a compute workgroup (local group) shape for a dispatch of `global` invocations.
 *
 * The per-workgroup invocation count is chosen so that the number of workgroups is
 * roughly the subgroup max size, then bounded to [8, device invocation limit] and
 * rounded up to a multiple of 8.
 *
 * @param global Three-element array of global extents (x, y, z); each element is
 *   converted with `Number(...)`.
 * @returns A three-element array of `I64`: `[s, s, 1]` when the chosen size is a
 *   perfect square with `s >= 8`, otherwise `[size / 8, 8, 1]`. Returns `[1, 1, 1]`
 *   when any extent is zero, negative or not a number (no work to do).
 */
export function optimalLocalGroup(global) {
  const extents = [Number(global[0]), Number(global[1]), Number(global[2])];
  // `!(d > 0)` also catches NaN, which a plain `=== 0` test would let through
  if (extents.some((d) => !(d > 0))) {
    return [new I64(1), new I64(1), new I64(1)];
  }
  const totalGlobal = extents[0] * extents[1] * extents[2];
  const subMax = getSubgroupMaxSize();
  let maxInvocations = 256;
  if (GPUS !== null && GPUS.length > 0) {
    maxInvocations = GPUS[0].device.limits.maxComputeInvocationsPerWorkgroup;
  }
  // Largest multiple of 8 that fits the device limit. Clamping against this (rather than
  // the raw limit) guarantees the round-up below can never exceed the limit.
  const cap = Math.max(8, Math.floor(maxInvocations / 8) * 8);
  // Target totalInvocationsPerWorkgroup so that totalWorkgroups ~ subgroup_max_size
  let target = Math.ceil(totalGlobal / subMax);
  // Clamp to [8, cap]
  target = Math.max(8, Math.min(target, cap));
  // Snap up to a multiple of 8 (hardware alignment)
  target = Math.ceil(target / 8) * 8;
  // Shape: prefer S*S*1, otherwise D*8*1 (target is always a multiple of 8 here)
  const side = Math.floor(Math.sqrt(target));
  if (side >= 8 && side * side === target) {
    return [new I64(side), new I64(side), new I64(1)];
  }
  return [new I64(target / 8), new I64(8), new I64(1)];
}

export async function gpu() {
Expand Down Expand Up @@ -1105,6 +1136,14 @@ export function frameFramebuffer(frame) {
return frame.framebuffer;
}

/**
 * Returns the `width` property of a window frame object.
 */
export function frameWidth(frame) {
  const { width } = frame;
  return width;
}

/**
 * Returns the `height` property of a window frame object.
 */
export function frameHeight(frame) {
  const { height } = frame;
  return height;
}

export async function runWindow(initialContextFn, contextFn, gpgpuShaderFn) {
// None of this can run before `document.body` exists, so let's wait for that
if (document.readyState !== "complete" && document.readyState !== "loaded") {
Expand Down Expand Up @@ -1177,39 +1216,25 @@ export async function runWindow(initialContextFn, contextFn, gpgpuShaderFn) {
label: `buffer_${uuidv4().replaceAll('-', '_')}`,
});
buffer.ValKind = U32;
let gpgpuShaders = await gpgpuShaderFn({ context: contextBuffer, framebuffer: buffer });
let gpgpuShaders = await gpgpuShaderFn({ context: contextBuffer, framebuffer: buffer, width: width, height: height });
let redraw = async function() {
// First resize things if necessary
if (width !== context.canvas.width || height !== context.canvas.height) {
width = context.canvas.width;
height = context.canvas.height;
width = Math.max(1, context.canvas.width);
height = Math.max(1, context.canvas.height);
context.bufferWidth = (4 * width) % 256 === 0 ? 4 * width : 4 * width + (256 - ((4 * width) % 256));
bufferHeight = height;
bufferSize = context.bufferWidth * bufferHeight;
let oldBufferId = buffer.label;
let newBuffer = await device.createBuffer({
size: bufferSize,
usage: storageBufferType(),
label: `buffer_${uuidv4().replaceAll('-', '_')}`,
});
newBuffer.ValKind = U32;
for (let shader of gpgpuShaders) {
for (let group of shader.buffers) {
let idx = undefined;
for (let i = 0; i < group.length; i++) {
let buffer = group[i];
if (buffer.label == oldBufferId) {
idx = i;
break;
}
}
if (typeof(idx) !== 'undefined') {
group[idx] = newBuffer;
}
}
}
buffer.destroy();
buffer = newBuffer;
// Re-invoke the shader callback with new dimensions so shaders are regenerated
gpgpuShaders = await gpgpuShaderFn({ context: contextBuffer, framebuffer: buffer, width: width, height: height });
}
// Now, actually start drawing
let frame = surface.getCurrentTexture();
Expand Down Expand Up @@ -1260,32 +1285,13 @@ export async function runWindow(initialContextFn, contextFn, gpgpuShaderFn) {
for (let i = 0; i < gg.buffers.length; i++) {
cpass.setBindGroup(i, bindGroups[i]);
}
let x = 0;
let y = 0;
let lx = gg.localWorkgroupSize[0] ?? 8;
let ly = gg.localWorkgroupSize[1] ?? 8;
switch (gg.workgroupSizes[0].val) {
case -1:
x = Math.ceil(width / lx);
break;
case -2:
x = Math.ceil(height / lx);
break;
default:
x = Math.ceil(gg.workgroupSizes[0].val / lx);
}
switch (gg.workgroupSizes[1].val) {
case -1:
y = Math.ceil(width / ly);
break;
case -2:
y = Math.ceil(height / ly);
break;
default:
y = Math.ceil(gg.workgroupSizes[1].val / ly);
}
let z = gg.workgroupSizes[2].val;
cpass.dispatchWorkgroups(x, y, z);
let lx = Number(gg.localWorkgroupSize[0]) ?? 8;
let ly = Number(gg.localWorkgroupSize[1]) ?? 8;
cpass.dispatchWorkgroups(
Math.ceil(Number(gg.workgroupSizes[0]) / lx),
Math.ceil(Number(gg.workgroupSizes[1]) / ly),
Number(gg.workgroupSizes[2])
);
cpass.end();
}
encoder.copyBufferToTexture({
Expand Down
85 changes: 42 additions & 43 deletions alan_std/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@
}

static GPUS: OnceLock<Vec<GPU>> = OnceLock::new();
static OPTIMAL_LOCAL_GROUP: OnceLock<[i64; 3]> = OnceLock::new();
static SUBGROUP_MAX_SIZE: OnceLock<u32> = OnceLock::new();

fn gpu() -> &'static GPU {
match GPUS.get_or_init(|| GPU::init(GPU::list())).first() {
Expand All @@ -834,23 +834,38 @@
}
}

pub fn optimal_local_group() -> [i64; 3] {
*OPTIMAL_LOCAL_GROUP.get_or_init(|| {
fn subgroup_max_size() -> u32 {
*SUBGROUP_MAX_SIZE.get_or_init(|| {
let g = gpu();
let max_invocations = g.adapter.limits().max_compute_invocations_per_workgroup;
let n = max_invocations as u64;
let sqrt = (n as f64).sqrt() as u64;
if sqrt * sqrt == n {
return [sqrt as i64, sqrt as i64, 1];
}
if n % 8 == 0 {
let d = n / 8;
return [d as i64, 8, 1];
}
[8, 8, 1]
g.adapter.get_info().subgroup_max_size
})
}

/// Picks a compute workgroup (local group) shape for a dispatch of `global` invocations.
///
/// The per-workgroup invocation count is chosen so that the number of workgroups is roughly
/// the adapter's `subgroup_max_size`, then bounded to
/// `[8, max_compute_invocations_per_workgroup]` and rounded up to a multiple of 8.
///
/// Returns `[s, s, 1]` when the chosen size is a perfect square with `s >= 8`, otherwise
/// `[size / 8, 8, 1]`. Returns `[1, 1, 1]` when any extent is zero or negative (no work).
pub fn optimal_local_group(global: [i64; 3]) -> [i64; 3] {
    // Negative extents are treated as "no work" rather than wrapping to huge values through
    // an `as u64` cast; the product saturates instead of overflowing.
    let total_global = global
        .iter()
        .map(|&extent| u64::try_from(extent).unwrap_or(0))
        .fold(1u64, u64::saturating_mul);
    if total_global == 0 {
        return [1, 1, 1];
    }
    // Guard the divisor in case the adapter reports a subgroup size of zero
    let sub_max = u64::from(subgroup_max_size()).max(1);
    let max_invocations = gpu().adapter.limits().max_compute_invocations_per_workgroup as u64;
    // Largest multiple of 8 that fits the device limit. Clamping against this (rather than
    // the raw limit) guarantees the round-up below can never exceed the limit.
    let cap = (max_invocations / 8 * 8).max(8);
    // Target totalInvocationsPerWorkgroup so that totalWorkgroups ~ subgroup_max_size,
    // clamped to [8, cap] and snapped up to a multiple of 8 (hardware alignment)
    let target = total_global.div_ceil(sub_max).clamp(8, cap).div_ceil(8) * 8;
    // Shape: prefer S*S*1, otherwise D*8*1 (target is always a multiple of 8 here)
    let side = (target as f64).sqrt() as u64;
    if side >= 8 && side * side == target {
        return [side as i64, side as i64, 1];
    }
    [(target / 8) as i64, 8, 1]
}

#[derive(Clone)]
pub struct GBuffer {
buffer: Rc<wgpu::Buffer>,
Expand Down Expand Up @@ -1312,6 +1327,8 @@
/// Per-frame data handed to the GPGPU shader callback of an `AlanWindow`.
pub struct AlanWindowFrame {
/// GPU buffer holding the window context data.
pub context: GBuffer,
/// GPU buffer the shaders render into.
pub framebuffer: GBuffer,
/// Window width at the time the frame is built, taken from the window size.
pub width: u32,
/// Window height at the time the frame is built, taken from the window size.
pub height: u32,
}

pub struct AlanWindow<C, R>
Expand Down Expand Up @@ -1427,6 +1444,8 @@
self.gpgpu_shaders = Some((self.gpgpu_shader_fn)(&AlanWindowFrame {
context: self.context_buffer.as_ref().unwrap().clone(),
framebuffer: self.buffer.as_ref().unwrap().clone(),
width: size.width,
height: size.height,
}));
}
self.inited = true;
Expand Down Expand Up @@ -1475,8 +1494,6 @@
if !self.inited {
self.window_gpu_init();
}
// We need to create a new buffer with the right size *and* replace all instances
// of the old buffer in the GPGPU array with the new one.
let device = self.device.as_ref().unwrap();
new_size.width = new_size.width.max(1);
new_size.height = new_size.height.max(1);
Expand All @@ -1488,7 +1505,6 @@
let buffer_height = new_size.height;
let buffer_size =
(self.context.buffer_width.unwrap() as u64) * (buffer_height as u64);
let old_buffer_id = self.buffer.as_ref().unwrap().id.clone();
let new_buffer = GBuffer {
buffer: Rc::new(device.create_buffer(&wgpu::BufferDescriptor {
label: None,
Expand All @@ -1499,24 +1515,17 @@
id: format!("buffer_{}", format!("{}", Uuid::new_v4()).replace("-", "_")),
element_size: 4,
};
for shader in self.gpgpu_shaders.as_mut().unwrap() {
for group in &mut shader.buffers {
let mut idx = None;
for (i, buffer) in group.iter().enumerate() {
if buffer.id == old_buffer_id {
idx = Some(i);
break;
}
}
if let Some(id) = idx {
group[id] = new_buffer.clone();
}
}
}
if let Some(b) = &self.buffer {
b.destroy();
}
self.buffer = Some(new_buffer);
// Re-invoke the shader callback with new dimensions so shaders are regenerated
self.gpgpu_shaders = Some((self.gpgpu_shader_fn)(&AlanWindowFrame {
context: self.context_buffer.as_ref().unwrap().clone(),
framebuffer: self.buffer.as_ref().unwrap().clone(),
width: new_size.width,
height: new_size.height,
}));
self.context.window.as_ref().unwrap().request_redraw();
}
WindowEvent::RedrawRequested => {
Expand Down Expand Up @@ -1624,19 +1633,9 @@
}
let lx = gg.local_workgroup_size[0];
let ly = gg.local_workgroup_size[1];
let wx = match gg.workgroup_sizes[0] {
-1 => size.width as i64,
-2 => size.height as i64,
_ => gg.workgroup_sizes[0],
};
let wy = match gg.workgroup_sizes[1] {
-1 => size.width as i64,
-2 => size.height as i64,
_ => gg.workgroup_sizes[1],
};
cpass.dispatch_workgroups(
((wx + lx - 1) / lx) as u32,
((wy + ly - 1) / ly) as u32,
((gg.workgroup_sizes[0] + lx - 1) / lx) as u32,
((gg.workgroup_sizes[1] + ly - 1) / ly) as u32,
gg.workgroup_sizes[2] as u32,
);
}
Expand Down
Loading