Skip to content

Commit d522a1d

Browse files
committed
Add alias-aware token threading for memory operations.
Introduce alias analysis based token threading: - Group pointers into alias sets. - Maintain per-alias-set token chains. - Thread tokens only between potentially aliasing operations. - Conservatively fall back to the global set for unknown pointers. - Preserve existing control-flow token merging semantics. Enables independent memory operations to execute without unnecessary serialization.
1 parent 538508a commit d522a1d

9 files changed

Lines changed: 641 additions & 46 deletions

File tree

src/compiler/codegen.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Codegen: Julia IR -> Tile IR bytecode
22

33
include("codegen/utils.jl")
4+
include("codegen/token_keys.jl") # Defines TokenKey, TokenRole, ACQUIRE_TOKEN_KEY
5+
include("codegen/alias_analysis.jl") # Defines alias_analysis_pass!
6+
include("codegen/token_order.jl") # Defines get_alias_set, get_input_token!
47
include("codegen/kernel.jl")
58
include("codegen/control_flow.jl")
69
include("codegen/statements.jl")
Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
"""
    AliasTracker

Mutable state used by the fixed-point alias analysis.

# Fields
- `dirty::Bool`: raised by `setindex!` whenever the stored alias set for a key is
  replaced by a different object; the driver loop clears it before each sweep.
- `aliases::Dict{Any, AliasSet}`: alias set recorded for each value
  (SSAValue/Argument/SlotNumber keys).
"""
mutable struct AliasTracker
    dirty::Bool
    aliases::Dict{Any, AliasSet}
end

"""
    AliasTracker()

Create a tracker with no recorded alias sets and a cleared `dirty` flag.
"""
function AliasTracker()
    return AliasTracker(false, Dict{Any, AliasSet}())
end
12+
13+
# Look up the alias set recorded for `key`. Keys that have no entry yet are
# reported as ALIAS_UNIVERSE, i.e. "may alias anything".
Base.getindex(tracker::AliasTracker, key) = get(tracker.aliases, key, ALIAS_UNIVERSE)
16+
17+
# Record `value` as the alias set of `key`. The comparison is by identity (`===`):
# re-storing the very same object is a no-op, anything else updates the entry and
# marks the tracker dirty so the fixed-point loop runs another sweep.
function Base.setindex!(tracker::AliasTracker, value::AliasSet, key)
    previous = get(tracker.aliases, key, nothing)
    previous === value && return nothing
    tracker.aliases[key] = value
    tracker.dirty = true
    return nothing
end
25+
26+
"""
    alias_analysis_pass!(sci::StructuredIRCode; max_iterations=100) -> Dict{Any, AliasSet}

Perform fixed-point alias analysis on structured IR.

Every argument whose (widened) type satisfies `contains_pointers` is seeded with a
singleton alias set containing its own `Argument` reference. The entry block is then
re-analyzed until a sweep leaves the tracker unchanged or `max_iterations` sweeps have
run.

# Keywords
- `max_iterations::Integer = 100`: upper bound on the number of sweeps.

# Returns
The mapping from values to alias sets accumulated by the tracker. If the iteration
cap is reached while the tracker is still changing, a warning is logged and the
(possibly incomplete) mapping is returned.
"""
function alias_analysis_pass!(sci::StructuredIRCode; max_iterations::Integer = 100)
    tracker = AliasTracker()

    # Initialize: each pointer-carrying argument gets its own alias set.
    for (idx, argtype) in enumerate(sci.argtypes)
        if contains_pointers(CC.widenconst(argtype))
            arg_ref = Argument(idx)
            tracker[arg_ref] = Set{Any}([arg_ref])
        end
    end

    # Fixed-point iteration. `dirty` is forced to true so that at least one sweep
    # runs even when no argument was seeded above.
    iteration = 0
    tracker.dirty = true
    while tracker.dirty && iteration < max_iterations
        tracker.dirty = false
        iteration += 1
        analyze_block!(tracker, sci.entry)
    end

    if tracker.dirty
        # The loop stopped because of the cap, not because a sweep was change-free.
        @warn "Alias analysis did not converge after $iteration iterations; alias sets may be incomplete"
    else
        @debug "Alias analysis converged in $iteration iterations"
    end

    return tracker.aliases
end
60+
61+
"""
    propagate!(tracker::AliasTracker, from, to)

Propagate the alias set of `from` into `to`.

- If `from` is `ALIAS_UNIVERSE` (including the case where `from` has no entry yet),
  `to` becomes `ALIAS_UNIVERSE`.
- If `to` has no entry yet, it receives `from`'s set directly.
- If `to` is already `ALIAS_UNIVERSE`, nothing changes: the universe absorbs any union.
- Otherwise `to` becomes the union of both sets, stored only when it actually grew.
"""
function propagate!(tracker::AliasTracker, from, to)
    from_aliases = tracker[from]

    if from_aliases === ALIAS_UNIVERSE
        # Propagating UNIVERSE is always conservative.
        tracker[to] = ALIAS_UNIVERSE
        return
    end

    to_aliases = get(tracker.aliases, to, nothing)
    if to_aliases === nothing
        # Target not yet analyzed: assign directly. The set object is shared, which
        # is safe because sets are only ever replaced here, never mutated in place.
        tracker[to] = from_aliases
    elseif to_aliases === ALIAS_UNIVERSE
        # Target is already the most conservative answer. Skip the union so we never
        # try to merge a concrete set with the universe marker.
    else
        # Target already has a concrete alias set: union with it.
        merged = union(from_aliases, to_aliases)
        if merged != to_aliases
            tracker[to] = merged
        end
    end
    return
end
89+
90+
"""
    analyze_block!(tracker::AliasTracker, block)

Walk every `(ssa_idx, entry)` pair of `block.body`. Control-flow ops are handed to
`analyze_control_flow!` (which descends into their regions); every other statement is
handed to `analyze_statement!` under its `SSAValue(ssa_idx)`.
"""
function analyze_block!(tracker::AliasTracker, block)
    for (ssa_idx, entry) in block.body
        stmt = entry.stmt
        if !(stmt isa ControlFlowOp)
            analyze_statement!(tracker, SSAValue(ssa_idx), stmt)
            continue
        end
        analyze_control_flow!(tracker, stmt)
    end
    return
end
105+
106+
# Descend into the regions owned by each structured control-flow op.

# IfOp: analyze the `then` region, then the `else` region.
function analyze_control_flow!(tracker::AliasTracker, op::IfOp)
    for region in (op.then_region, op.else_region)
        analyze_block!(tracker, region)
    end
    return
end

# ForOp: single body region.
analyze_control_flow!(tracker::AliasTracker, op::ForOp) = analyze_block!(tracker, op.body)

# WhileOp: analyze the `before` region, then the `after` region.
function analyze_control_flow!(tracker::AliasTracker, op::WhileOp)
    for region in (op.before, op.after)
        analyze_block!(tracker, region)
    end
    return
end

# LoopOp: single body region.
analyze_control_flow!(tracker::AliasTracker, op::LoopOp) = analyze_block!(tracker, op.body)

# Any other control-flow op: nothing to descend into.
analyze_control_flow!(::AliasTracker, ::ControlFlowOp) = nothing
129+
130+
# Resolve a callee found in a `:call`/`:invoke` expression to the object it names.
# A `GlobalRef` is looked up in its module (yielding `nothing` when the binding is
# not defined); anything else is assumed to already be the callee itself.
function _resolve_callee(func::GlobalRef)
    return isdefined(func.mod, func.name) ? getfield(func.mod, func.name) : nothing
end
_resolve_callee(func) = func

"""
    analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)

Analyze a single statement and update the alias set of `ssa`.

Handles both `:call` (`args = [func, operands...]`) and `:invoke`
(`args = [MethodInstance, func, operands...]`) expressions. The callee is resolved to
the function object first, so a `GlobalRef` and a direct function value are treated
alike:

- `getfield(x, :ptr)`: propagate the alias set of `x`; any other field is marked
  `ALIAS_UNIVERSE`.
- `+` / `-`: propagate from the first operand that has a concrete alias set; if there
  is none, `ssa` is left without an entry.
- view constructors and pointer pass-through functions: propagate from the first
  operand.
- any other call, and any non-call statement other than `ReturnNode`: mark
  `ALIAS_UNIVERSE`.

`ReturnNode` statements are ignored.
"""
function analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)
    # Returns define no value, so there is nothing to record.
    stmt isa ReturnNode && return

    if !(stmt isa Expr && (stmt.head === :call || stmt.head === :invoke))
        # Unknown statement type -> conservative.
        tracker[ssa] = ALIAS_UNIVERSE
        return
    end

    # Normalize :call and :invoke into (func, operands).
    callee_idx = stmt.head === :call ? 1 : 2
    func = stmt.args[callee_idx]
    operands = @view stmt.args[(callee_idx + 1):end]

    # Match on the resolved callee rather than on the raw expression argument, so
    # that direct function values (common in :invoke) are recognized too.
    resolved_func = _resolve_callee(func)

    if resolved_func === Core.getfield && !isempty(operands)
        field = length(operands) >= 2 ? operands[2] : nothing
        if field isa QuoteNode && field.value === :ptr
            # TileArray.ptr field access: the pointer aliases its parent.
            propagate!(tracker, operands[1], ssa)
        else
            # Conservatively mark as UNIVERSE for non-pointer fields.
            tracker[ssa] = ALIAS_UNIVERSE
        end

    elseif resolved_func === Base.:+ || resolved_func === Base.:-
        # Pointer arithmetic: the result aliases whichever operand is the pointer.
        for arg in operands
            arg_aliases = tracker[arg]
            if arg_aliases !== ALIAS_UNIVERSE && arg_aliases isa Set
                propagate!(tracker, arg, ssa)
                break
            end
        end

    elseif is_view_constructor(resolved_func) || is_pointer_passthrough(resolved_func)
        # Views and pass-through ops keep the alias identity of their first operand.
        if !isempty(operands)
            propagate!(tracker, operands[1], ssa)
        end

    else
        # Unknown operation -> UNIVERSE.
        tracker[ssa] = ALIAS_UNIVERSE
    end
    return
end
204+
205+
# Helper functions
206+
# True when type `T` is a `Ptr`, a `TileArray`, or a `Tile` whose element type is a
# `Ptr`. Checks are evaluated in that order and stop at the first match.
function contains_pointers(T)
    T <: Ptr && return true
    T <: TileArray && return true
    return T <: Tile && eltype(T) <: Ptr
end
207+
208+
"""
    is_view_constructor(func) -> Bool

Return `true` when `func` is identical to `Intrinsics.make_tensor_view` or
`Intrinsics.make_partition_view`. The alias analysis propagates the alias set of the
first operand through calls to these functions.
"""
function is_view_constructor(func)
    if func === Intrinsics.make_tensor_view
        return true
    end
    return func === Intrinsics.make_partition_view
end
218+
219+
# Return `true` when `func` forwards its pointer operand unchanged as far as aliasing
# is concerned: either the `GlobalRef` for `Core.Intrinsics.bitcast`, or any function
# whose name is `bitcast`, `assume_div_by`, `assume_bounded` or `assume_aligned`.
# Matching is by name so that no intrinsic binding has to be referenced directly.
function is_pointer_passthrough(func)
    if func === GlobalRef(Core.Intrinsics, :bitcast)
        return true
    elseif !(func isa Core.IntrinsicFunction || func isa Function)
        return false
    end
    return nameof(func) in (:bitcast, :assume_div_by, :assume_bounded, :assume_aligned)
end

src/compiler/codegen/control_flow.jl

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,14 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
8888
# Save token before branches
8989
token_before = ctx.token
9090

91+
# Save token_map before branches
92+
token_map_before = copy(ctx.token_map)
93+
9194
# Emit IfOp with callback-based region building
9295
then_body = function(_)
9396
saved_block_args = copy(ctx.block_args)
9497
ctx.token = token_before # Reset to pre-branch token
98+
ctx.token_map = copy(token_map_before) # Reset token_map too
9599
emit_block!(ctx, then_blk)
96100
if then_blk.terminator === nothing
97101
encode_YieldOp!(ctx.cb, [ctx.token])
@@ -102,6 +106,7 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
102106
else_body = function(_)
103107
saved_block_args = copy(ctx.block_args)
104108
ctx.token = token_before # Reset to pre-branch token
109+
ctx.token_map = copy(token_map_before) # Reset token_map too
105110
emit_block!(ctx, else_blk)
106111
if else_blk.terminator === nothing
107112
encode_YieldOp!(ctx.cb, [ctx.token])
@@ -114,6 +119,12 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
114119
# Last result is the merged token from both branches
115120
ctx.token = results[end]
116121

122+
# Merge token_map from both branches
123+
# Conservatively point every key at the merged token (results[end])
124+
for key in keys(ctx.token_map)
125+
ctx.token_map[key] = results[end]
126+
end
127+
117128
# Store results at IfOp's SSA index (may be empty for void-returning ifs)
118129
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
119130
end
@@ -164,6 +175,9 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
164175
# Number of user result types (excluding token)
165176
n_user_results = n_carries
166177

178+
# Save token_map before loop
179+
token_map_before = copy(ctx.token_map)
180+
167181
# Emit ForOp with callback-based region building
168182
body_builder = function(block_args)
169183
saved_block_args = copy(ctx.block_args)
@@ -193,8 +207,11 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
193207
end
194208
results = encode_ForOp!(body_builder, cb, result_types, iv_type, lower_tv.v, upper_tv.v, step_tv.v, init_values)
195209

196-
# Last result is the token
197-
ctx.token = results[end]
210+
ctx.token = ctx.global_token
211+
212+
for key in keys(token_map_before)
213+
ctx.token_map[key] = ctx.global_token
214+
end
198215

199216
# Store results at the loop's SSA index (may be empty for void-returning loops)
200217
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
@@ -230,6 +247,9 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
230247
# Number of user result types (excluding token)
231248
n_user_results = n_carries
232249

250+
# Save token_map before loop
251+
token_map_before = copy(ctx.token_map)
252+
233253
# Emit LoopOp with callback-based region building
234254
body_builder = function(block_args)
235255
saved_block_args = copy(ctx.block_args)
@@ -263,8 +283,11 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
263283
end
264284
results = encode_LoopOp!(body_builder, cb, result_types, init_values)
265285

266-
# Last result is the token
267-
ctx.token = results[end]
286+
ctx.token = ctx.global_token
287+
288+
for key in keys(token_map_before)
289+
ctx.token_map[key] = ctx.global_token
290+
end
268291

269292
# Store results at the loop's SSA index (may be empty for void-returning loops)
270293
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
@@ -301,6 +324,9 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
301324
# Number of user result types (excluding token)
302325
n_user_results = n_carries
303326

327+
# Save token_map before loop
328+
token_map_before = copy(ctx.token_map)
329+
304330
# Emit WhileOp as cuda_tile.loop with conditional break pattern
305331
# MLIR structure: before { stmts; condition(cond) args } do { stmts; yield vals }
306332
# Emitted as: loop { before_stmts; if(!cond) { break } else { yield }; after_stmts; continue }
@@ -393,8 +419,11 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
393419
end
394420
results = encode_LoopOp!(body_builder, cb, result_types, init_values)
395421

396-
# Last result is the token
397-
ctx.token = results[end]
422+
ctx.token = ctx.global_token
423+
424+
for key in keys(token_map_before)
425+
ctx.token_map[key] = ctx.global_token
426+
end
398427

399428
# Store results at the loop's SSA index (may be empty for void-returning loops)
400429
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)

src/compiler/codegen/kernel.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,30 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
134134
create_tensor_views!(ctx, arg_idx, argtype, Int[])
135135
end
136136

137+
# Run alias analysis FIRST
138+
alias_result = alias_analysis_pass!(sci)
139+
ctx.alias_result = alias_result
140+
137141
# Create memory ordering token
138142
token_type = Token(tt)
139143
ctx.token_type = token_type
140-
ctx.token = encode_MakeTokenOp!(cb, token_type)
144+
root_token = encode_MakeTokenOp!(cb, token_type)
145+
146+
ctx.global_token = root_token
147+
ctx.token = root_token
148+
149+
# Initialize token map with root token for all alias sets
150+
# Default: all tokens start at root
151+
ctx.token_map = Dict{TokenKey, Value}()
152+
153+
unique_alias_sets = Set(values(alias_result))
154+
for alias_set in unique_alias_sets
155+
ctx.token_map[last_op_key(alias_set)] = root_token
156+
ctx.token_map[last_store_key(alias_set)] = root_token
157+
end
158+
159+
# ACQUIRE token also starts at root
160+
ctx.token_map[ACQUIRE_TOKEN_KEY] = root_token
141161

142162
# Hoist early returns out of IfOp regions (tileiras rejects ReturnOp inside IfOp)
143163
hoist_returns!(ctx.sci.entry)

0 commit comments

Comments
 (0)