Skip to content

Commit b94d60d

Browse files
authored
Simplify using new IRStructurizer.jl tools. (#148)
1 parent dd600c8 commit b94d60d

9 files changed

Lines changed: 257 additions & 484 deletions

File tree

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ DLFP8Types = "f4c16678-4a16-415b-82ef-ed337c5d6c7c"
2323
CUDAExt = "CUDA"
2424
DLFP8TypesExt = "DLFP8Types"
2525

26+
[sources]
27+
IRStructurizer = {url = "https://github.com/maleadt/IRStructurizer.jl", rev = "main"}
28+
2629
[compat]
2730
BFloat16s = "0.6"
2831
CUDA_Compiler_jll = "0.4"

src/compiler/codegen.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Codegen: Julia IR -> Tile IR bytecode
22

33
include("codegen/utils.jl")
4-
include("codegen/irutils.jl") # SSAMap/Block mutation helpers
54
include("codegen/passes/token_keys.jl") # TokenKey, TokenRole, ACQUIRE_TOKEN_KEY
65
include("codegen/passes/alias_analysis.jl") # alias_analysis_pass!
76
include("codegen/passes/token_order.jl") # token_order_pass!

src/compiler/codegen/control_flow.jl

Lines changed: 44 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,5 @@
11
# Structured IR Emission
22

3-
"""
4-
result_count(T) -> Int
5-
6-
Compute the number of results from a Block.types entry.
7-
Block.types contains Julia types:
8-
- For Statement: Julia type → 1 result
9-
- For ControlFlowOp with 0 results: Nothing → 0 results
10-
- For ControlFlowOp with 1 result: Julia type → 1 result
11-
- For ControlFlowOp with N results: Tuple{T1, T2, ...} → N results
12-
"""
13-
function result_count(@nospecialize(T))
14-
T === Nothing && return 0
15-
T <: Tuple && return length(T.parameters)
16-
return 1
17-
end
18-
193
"""
204
emit_block!(ctx, block::Block)
215
@@ -24,38 +8,37 @@ All SSA values use original Julia SSA indices (no local renumbering).
248
Values are stored in ctx.values by their original index.
259
"""
2610
function emit_block!(ctx::CGCtx, block::Block; skip_terminator::Bool=false)
27-
# SSAVector iteration yields (ssa_idx, entry) where entry has .stmt and .typ
28-
for (ssa_idx, entry) in block.body
29-
if entry.stmt isa ControlFlowOp
30-
n_results = result_count(entry.typ)
31-
emit_control_flow_op!(ctx, entry.stmt, entry.typ, n_results, ssa_idx)
11+
for inst in instructions(block)
12+
s = stmt(inst)
13+
if s isa ControlFlowOp
14+
emit_control_flow_op!(ctx, s, value_type(inst), inst.ssa_idx)
3215
else
33-
emit_statement!(ctx, entry.stmt, ssa_idx, entry.typ)
16+
emit_statement!(ctx, s, inst.ssa_idx, value_type(inst))
3417
end
3518
end
36-
if !skip_terminator && block.terminator !== nothing
37-
emit_terminator!(ctx, block.terminator)
19+
if !skip_terminator && terminator(block) !== nothing
20+
emit_terminator!(ctx, terminator(block))
3821
end
3922
end
4023

4124
"""
42-
emit_control_flow_op!(ctx, op::ControlFlowOp, result_type, n_results, original_idx)
25+
emit_control_flow_op!(ctx, op::ControlFlowOp, result_type, original_idx)
4326
4427
Emit bytecode for a structured control flow operation.
4528
Uses multiple dispatch on the concrete ControlFlowOp type.
4629
Results are stored at indices assigned AFTER nested regions (DFS order).
4730
original_idx is the original Julia SSA index for cross-block references.
4831
"""
49-
emit_control_flow_op!(ctx::CGCtx, op::IfOp, @nospecialize(rt), n::Int, idx::Int) = emit_if_op!(ctx, op, rt, n, idx)
50-
emit_control_flow_op!(ctx::CGCtx, op::ForOp, @nospecialize(rt), n::Int, idx::Int) = emit_for_op!(ctx, op, rt, n, idx)
51-
emit_control_flow_op!(ctx::CGCtx, op::WhileOp, @nospecialize(rt), n::Int, idx::Int) = emit_while_op!(ctx, op, rt, n, idx)
52-
emit_control_flow_op!(ctx::CGCtx, op::LoopOp, @nospecialize(rt), n::Int, idx::Int) = emit_loop_op!(ctx, op, rt, n, idx)
32+
emit_control_flow_op!(ctx::CGCtx, op::IfOp, @nospecialize(rt), idx::Int) = emit_if_op!(ctx, op, rt, idx)
33+
emit_control_flow_op!(ctx::CGCtx, op::ForOp, @nospecialize(rt), idx::Int) = emit_for_op!(ctx, op, rt, idx)
34+
emit_control_flow_op!(ctx::CGCtx, op::WhileOp, @nospecialize(rt), idx::Int) = emit_while_op!(ctx, op, rt, idx)
35+
emit_control_flow_op!(ctx::CGCtx, op::LoopOp, @nospecialize(rt), idx::Int) = emit_loop_op!(ctx, op, rt, idx)
5336

5437
#=============================================================================
5538
IfOp
5639
=============================================================================#
5740

58-
function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int)
41+
function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), ssa_idx::Int)
5942
cb = ctx.cb
6043

6144
# Get condition value
@@ -78,13 +61,13 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
7861
then_body = function(_)
7962
saved = copy(ctx.block_args)
8063
emit_block!(ctx, op.then_region)
81-
op.then_region.terminator === nothing && encode_YieldOp!(ctx.cb, Value[])
64+
terminator(op.then_region) === nothing && encode_YieldOp!(ctx.cb, Value[])
8265
empty!(ctx.block_args); merge!(ctx.block_args, saved)
8366
end
8467
else_body = function(_)
8568
saved = copy(ctx.block_args)
8669
emit_block!(ctx, op.else_region)
87-
op.else_region.terminator === nothing && encode_YieldOp!(ctx.cb, Value[])
70+
terminator(op.else_region) === nothing && encode_YieldOp!(ctx.cb, Value[])
8871
empty!(ctx.block_args); merge!(ctx.block_args, saved)
8972
end
9073
results = encode_IfOp!(then_body, else_body, cb, result_types, cond_tv.v)
@@ -96,7 +79,7 @@ end
9679
ForOp
9780
=============================================================================#
9881

99-
function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int)
82+
function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type), ssa_idx::Int)
10083
cb = ctx.cb
10184
body_blk = op.body
10285

@@ -138,7 +121,7 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
138121
end
139122
emit_block!(ctx, body_blk)
140123
# If body has no terminator, emit a ContinueOp with all carried values
141-
if body_blk.terminator === nothing
124+
if terminator(body_blk) === nothing
142125
encode_ContinueOp!(ctx.cb, block_args[2:end])
143126
end
144127
empty!(ctx.block_args); merge!(ctx.block_args, saved)
@@ -153,7 +136,7 @@ end
153136
LoopOp
154137
=============================================================================#
155138

156-
function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int)
139+
function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type), ssa_idx::Int)
157140
cb = ctx.cb
158141
body_blk = op.body
159142

@@ -181,7 +164,7 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
181164
# In Tile IR, if the loop body ends with an IfOp (even one with continue/break
182165
# in all branches), the if is NOT a terminator. We need an explicit terminator
183166
# after the if. Add an unreachable ContinueOp as fallback terminator.
184-
if body_blk.terminator === nothing
167+
if terminator(body_blk) === nothing
185168
encode_ContinueOp!(ctx.cb, copy(block_args))
186169
end
187170
empty!(ctx.block_args); merge!(ctx.block_args, saved)
@@ -200,7 +183,7 @@ end
200183
nested region issues when "after" contains loops.
201184
=============================================================================#
202185

203-
function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_type), n_results::Int, ssa_idx::Int)
186+
function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_type), ssa_idx::Int)
204187
cb = ctx.cb
205188
before_blk = op.before
206189
after_blk = op.after
@@ -230,7 +213,7 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
230213
emit_block!(ctx, before_blk)
231214

232215
# Get condition from ConditionOp terminator
233-
cond_op = before_blk.terminator
216+
cond_op = terminator(before_blk)
234217
cond_op isa ConditionOp || throw(IRError("WhileOp before region must end with ConditionOp"))
235218

236219
cond_tv = emit_value!(ctx, cond_op.condition)
@@ -242,7 +225,7 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
242225
else_body = function(_)
243226
# Break with ConditionOp args (become loop results)
244227
break_operands = Value[]
245-
for arg in cond_op.args
228+
for arg in operands(cond_op)
246229
tv = emit_value!(ctx, arg)
247230
tv !== nothing && tv.v !== nothing && push!(break_operands, tv.v)
248231
end
@@ -261,10 +244,11 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
261244

262245
# Map "after" region block args from ConditionOp.args (user carries)
263246
# and block_args (token carries beyond ConditionOp.args)
247+
cond_operands = operands(cond_op)
264248
for i in 1:length(after_blk.args)
265249
arg = after_blk.args[i]
266-
if i <= length(cond_op.args)
267-
tv = emit_value!(ctx, cond_op.args[i])
250+
if i <= length(cond_operands)
251+
tv = emit_value!(ctx, cond_operands[i])
268252
if tv !== nothing
269253
ctx[arg] = tv
270254
else
@@ -283,8 +267,9 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
283267

284268
# Emit ContinueOp with yield values from after region's YieldOp
285269
continue_operands = Value[]
286-
if after_blk.terminator isa YieldOp
287-
for val in after_blk.terminator.values
270+
after_term = terminator(after_blk)
271+
if after_term isa YieldOp
272+
for val in operands(after_term)
288273
tv = emit_value!(ctx, val)
289274
tv !== nothing && tv.v !== nothing && push!(continue_operands, tv.v)
290275
end
@@ -304,7 +289,6 @@ end
304289

305290
#=============================================================================
306291
Terminators
307-
Token values are already in op.values (appended by token_order_pass!).
308292
=============================================================================#
309293

310294
"""
@@ -314,31 +298,17 @@ Emit bytecode for a block terminator.
314298
"""
315299
emit_terminator!(ctx::CGCtx, node::ReturnNode) = emit_return!(ctx, node)
316300

317-
function emit_terminator!(ctx::CGCtx, op::YieldOp)
318-
operands = Value[]
319-
for val in op.values
320-
tv = emit_value!(ctx, val)
321-
tv !== nothing && tv.v !== nothing && push!(operands, tv.v)
322-
end
323-
encode_YieldOp!(ctx.cb, operands)
324-
end
301+
_encode_term!(cb, ::YieldOp, v) = encode_YieldOp!(cb, v)
302+
_encode_term!(cb, ::ContinueOp, v) = encode_ContinueOp!(cb, v)
303+
_encode_term!(cb, ::BreakOp, v) = encode_BreakOp!(cb, v)
325304

326-
function emit_terminator!(ctx::CGCtx, op::ContinueOp)
327-
operands = Value[]
328-
for val in op.values
305+
function emit_terminator!(ctx::CGCtx, op::Union{YieldOp, ContinueOp, BreakOp})
306+
vals = Value[]
307+
for val in operands(op)
329308
tv = emit_value!(ctx, val)
330-
tv !== nothing && tv.v !== nothing && push!(operands, tv.v)
309+
tv !== nothing && tv.v !== nothing && push!(vals, tv.v)
331310
end
332-
encode_ContinueOp!(ctx.cb, operands)
333-
end
334-
335-
function emit_terminator!(ctx::CGCtx, op::BreakOp)
336-
operands = Value[]
337-
for val in op.values
338-
tv = emit_value!(ctx, val)
339-
tv !== nothing && tv.v !== nothing && push!(operands, tv.v)
340-
end
341-
encode_BreakOp!(ctx.cb, operands)
311+
_encode_term!(ctx.cb, op, vals)
342312
end
343313

344314
emit_terminator!(ctx::CGCtx, ::Nothing) = nothing
@@ -366,28 +336,14 @@ ReturnNode (REGION_TERMINATION with 3 children). The 2-child case
366336
(early return inside a loop) is not handled.
367337
"""
368338
function hoist_returns!(block::Block)
369-
for (_, entry) in block.body
370-
stmt = entry.stmt
371-
if stmt isa IfOp
372-
hoist_returns!(stmt.then_region)
373-
hoist_returns!(stmt.else_region)
374-
elseif stmt isa ForOp
375-
hoist_returns!(stmt.body)
376-
elseif stmt isa WhileOp
377-
hoist_returns!(stmt.before)
378-
hoist_returns!(stmt.after)
379-
elseif stmt isa LoopOp
380-
hoist_returns!(stmt.body)
381-
end
382-
end
383-
for (_, entry) in block.body
384-
entry.stmt isa IfOp || continue
385-
op = entry.stmt::IfOp
386-
op.then_region.terminator isa ReturnNode || continue
387-
op.else_region.terminator isa ReturnNode || continue
388-
op.then_region.terminator = YieldOp()
389-
op.else_region.terminator = YieldOp()
390-
block.terminator = ReturnNode(nothing)
339+
walk(block; order=:postorder) do inst, blk
340+
s = stmt(inst)
341+
s isa IfOp || return
342+
terminator(s.then_region) isa ReturnNode || return
343+
terminator(s.else_region) isa ReturnNode || return
344+
terminator!(s.then_region, YieldOp())
345+
terminator!(s.else_region, YieldOp())
346+
terminator!(blk, ReturnNode(nothing))
391347
end
392348
end
393349

src/compiler/codegen/irutils.jl

Lines changed: 0 additions & 104 deletions
This file was deleted.

src/compiler/codegen/kernel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ function emit_subprogram!(ctx::CGCtx, func, arg_types::Vector,
364364
emit_block!(sub_ctx, sci.entry; skip_terminator=true)
365365

366366
# 6. Extract return value and yield
367-
ret = sci.entry.terminator::ReturnNode
367+
ret = terminator(sci.entry)::ReturnNode
368368
tv = emit_value!(sub_ctx, ret.val)
369369
if tv.tuple !== nothing
370370
# Tuple return: resolve each component to a concrete Value

0 commit comments

Comments
 (0)