Skip to content

Commit 432cec3

Browse files
committed
Add alias-aware token threading for memory operations.
Introduce alias analysis–based token threading: - Group pointers into alias sets. - Maintain per-alias-set token chains. - Thread tokens only between potentially aliasing operations. - Conservatively fall back to the global set for unknown pointers. - Preserve existing control-flow token merging semantics. Enables independent memory operations to execute without unnecessary serialization.
1 parent d595fcc commit 432cec3

9 files changed

Lines changed: 497 additions & 17 deletions

File tree

src/compiler/codegen.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Codegen: Julia IR -> Tile IR bytecode
22

33
include("codegen/utils.jl")
4+
include("codegen/token_keys.jl") # Defines TokenKey, TokenRole, ACQUIRE_TOKEN_KEY
5+
include("codegen/alias_analysis.jl") # Defines alias_analysis_pass!
6+
include("codegen/token_order.jl") # Defines get_alias_set, get_input_token!
47
include("codegen/kernel.jl")
58
include("codegen/control_flow.jl")
69
include("codegen/statements.jl")
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""
2+
AliasTracker
3+
4+
Tracks alias sets for each SSA value during fixed-point analysis.
5+
"""
6+
mutable struct AliasTracker
7+
dirty::Bool
8+
aliases::Dict{Any, AliasSet} # SSAValue/Argument/SlotNumber -> AliasSet
9+
end
10+
11+
AliasTracker() = AliasTracker(false, Dict{Any, AliasSet}())
12+
13+
function Base.getindex(tracker::AliasTracker, key)
14+
return get(tracker.aliases, key, ALIAS_UNIVERSE)
15+
end
16+
17+
function Base.setindex!(tracker::AliasTracker, value::AliasSet, key)
18+
current = get(tracker.aliases, key, nothing)
19+
return if current !== value
20+
tracker.dirty = true
21+
tracker.aliases[key] = value
22+
end
23+
end
24+
25+
"""
26+
alias_analysis_pass!(sci::StructuredIRCode) -> Dict{Any, AliasSet}
27+
28+
Perform fixed-point alias analysis on structured IR.
29+
Returns mapping from SSA values to alias sets.
30+
"""
31+
function alias_analysis_pass!(sci::StructuredIRCode)
32+
tracker = AliasTracker()
33+
34+
# Initialize: each argument gets its own alias set
35+
for (idx, argtype) in enumerate(sci.argtypes)
36+
argtype_unwrapped = CC.widenconst(argtype)
37+
if contains_pointers(argtype_unwrapped)
38+
arg_ref = Argument(idx)
39+
tracker[arg_ref] = Set{Any}([arg_ref])
40+
end
41+
end
42+
43+
# Fixed-point iteration
44+
iteration = 0
45+
max_iterations = 100
46+
47+
tracker.dirty = true
48+
while tracker.dirty && iteration < max_iterations
49+
tracker.dirty = false
50+
iteration += 1
51+
52+
analyze_block!(tracker, sci.entry)
53+
end
54+
55+
@debug "Alias analysis converged in $iteration iterations"
56+
57+
return tracker.aliases
58+
end
59+
60+
"""
61+
propagate!(tracker::AliasTracker, from, to)
62+
63+
Propagate alias set from `from` to `to` (union operation).
64+
"""
65+
function propagate!(tracker::AliasTracker, from, to)
66+
from_aliases = tracker[from]
67+
to_aliases = tracker[to]
68+
69+
# Union the alias sets
70+
new_aliases = union(from_aliases, to_aliases)
71+
72+
return if new_aliases != to_aliases
73+
tracker[to] = new_aliases
74+
end
75+
end
76+
77+
"""
78+
analyze_block!(tracker::AliasTracker, block)
79+
80+
Analyze all statements in a block.
81+
"""
82+
function analyze_block!(tracker::AliasTracker, block)
83+
# Block has args, body, terminator
84+
# body is an iterator that yields (ssa_idx, entry) where entry has .stmt and .typ
85+
for (ssa_idx, entry) in block.body
86+
analyze_statement!(tracker, SSAValue(ssa_idx), entry.stmt)
87+
end
88+
return
89+
end
90+
91+
"""
92+
analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)
93+
94+
Analyze a single statement and propagate aliases.
95+
"""
96+
function analyze_statement!(tracker::AliasTracker, ssa::SSAValue, stmt)
97+
return if stmt isa Expr && stmt.head === :call
98+
func = stmt.args[1]
99+
100+
# getfield: propagate from parent
101+
if func === GlobalRef(Core, :getfield) && length(stmt.args) >= 2
102+
parent = stmt.args[2]
103+
field = length(stmt.args) >= 3 ? stmt.args[3] : nothing
104+
105+
# For TileArray.ptr field access, propagate pointer alias
106+
if field isa QuoteNode && field.value === :ptr
107+
propagate!(tracker, parent, ssa)
108+
else
109+
# Conservatively mark as UNIVERSE for non-pointer fields
110+
tracker[ssa] = ALIAS_UNIVERSE
111+
end
112+
113+
# Pointer arithmetic: propagate from pointer operand
114+
elseif func === GlobalRef(Base, :+) || func === GlobalRef(Base, :-)
115+
for arg in stmt.args[2:end]
116+
# Find the pointer argument and propagate
117+
arg_aliases = tracker[arg]
118+
if arg_aliases !== ALIAS_UNIVERSE || arg_aliases isa Set
119+
propagate!(tracker, arg, ssa)
120+
break
121+
end
122+
end
123+
124+
# TileArray construction: propagate from pointer argument
125+
elseif is_tile_array_constructor(func)
126+
# First argument is typically the pointer
127+
if length(stmt.args) >= 2
128+
propagate!(tracker, stmt.args[2], ssa)
129+
end
130+
131+
# Default: unknown operation -> UNIVERSE
132+
else
133+
tracker[ssa] = ALIAS_UNIVERSE
134+
end
135+
136+
# Control flow operations need special handling
137+
elseif stmt isa ReturnNode
138+
# No alias propagation needed
139+
140+
else
141+
# Unknown statement type -> conservative
142+
tracker[ssa] = ALIAS_UNIVERSE
143+
end
144+
end
145+
146+
# Helper functions
147+
contains_pointers(T) = T <: Ptr || T <: TileArray || (T <: Tile && eltype(T) <: Ptr)
148+
149+
function is_tile_array_constructor(func)
150+
# Check if this is a TileArray constructor call
151+
# You'll need to detect the specific GlobalRef for TileArray
152+
return false # TODO: implement
153+
end

src/compiler/codegen/control_flow.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,14 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
8888
# Save token before branches
8989
token_before = ctx.token
9090

91+
# Save token_map before branches
92+
token_map_before = copy(ctx.token_map)
93+
9194
# Emit IfOp with callback-based region building
9295
then_body = function(_)
9396
saved_block_args = copy(ctx.block_args)
9497
ctx.token = token_before # Reset to pre-branch token
98+
ctx.token_map = copy(token_map_before) # Reset token_map too
9599
emit_block!(ctx, then_blk)
96100
if then_blk.terminator === nothing
97101
encode_YieldOp!(ctx.cb, [ctx.token])
@@ -102,6 +106,7 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
102106
else_body = function(_)
103107
saved_block_args = copy(ctx.block_args)
104108
ctx.token = token_before # Reset to pre-branch token
109+
ctx.token_map = copy(token_map_before) # Reset token_map too
105110
emit_block!(ctx, else_blk)
106111
if else_blk.terminator === nothing
107112
encode_YieldOp!(ctx.cb, [ctx.token])
@@ -114,6 +119,12 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
114119
# Last result is the merged token from both branches
115120
ctx.token = results[end]
116121

122+
# Merge token_map from both branches
123+
# Conservatively reset to token_before for all keys
124+
for key in keys(ctx.token_map)
125+
ctx.token_map[key] = results[end]
126+
end
127+
117128
# Store results at IfOp's SSA index (may be empty for void-returning ifs)
118129
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
119130
end
@@ -164,6 +175,9 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
164175
# Number of user result types (excluding token)
165176
n_user_results = n_carries
166177

178+
# Save token_map before loop
179+
token_map_before = copy(ctx.token_map)
180+
167181
# Emit ForOp with callback-based region building
168182
body_builder = function(block_args)
169183
saved_block_args = copy(ctx.block_args)
@@ -196,6 +210,12 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
196210
# Last result is the token
197211
ctx.token = results[end]
198212

213+
# Update token_map after loop
214+
# Conservatively update all keys to the merged token
215+
for key in keys(token_map_before)
216+
ctx.token_map[key] = results[end]
217+
end
218+
199219
# Store results at the loop's SSA index (may be empty for void-returning loops)
200220
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
201221
end
@@ -230,6 +250,9 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
230250
# Number of user result types (excluding token)
231251
n_user_results = n_carries
232252

253+
# Save token_map before loop
254+
token_map_before = copy(ctx.token_map)
255+
233256
# Emit LoopOp with callback-based region building
234257
body_builder = function(block_args)
235258
saved_block_args = copy(ctx.block_args)
@@ -266,6 +289,12 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
266289
# Last result is the token
267290
ctx.token = results[end]
268291

292+
# Update token_map after loop
293+
# Conservatively update all keys to the merged token
294+
for key in keys(token_map_before)
295+
ctx.token_map[key] = results[end]
296+
end
297+
269298
# Store results at the loop's SSA index (may be empty for void-returning loops)
270299
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
271300
end
@@ -301,6 +330,9 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
301330
# Number of user result types (excluding token)
302331
n_user_results = n_carries
303332

333+
# Save token_map before loop
334+
token_map_before = copy(ctx.token_map)
335+
304336
# Emit WhileOp as cuda_tile.loop with conditional break pattern
305337
# MLIR structure: before { stmts; condition(cond) args } do { stmts; yield vals }
306338
# Emitted as: loop { before_stmts; if(!cond) { break } else { yield }; after_stmts; continue }
@@ -396,6 +428,12 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
396428
# Last result is the token
397429
ctx.token = results[end]
398430

431+
# Update token_map after loop
432+
# Conservatively update all keys to the merged token
433+
for key in keys(token_map_before)
434+
ctx.token_map[key] = results[end]
435+
end
436+
399437
# Store results at the loop's SSA index (may be empty for void-returning loops)
400438
ctx.values[ssa_idx] = CGVal(results[1:n_user_results], parent_result_type)
401439
end

src/compiler/codegen/kernel.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,30 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
152152
cache_tensor_view!(ctx, arg_idx)
153153
end
154154

155+
# Run alias analysis FIRST
156+
alias_result = alias_analysis_pass!(sci)
157+
ctx.alias_result = alias_result
158+
155159
# Create memory ordering token
156160
token_type = Token(tt)
157161
ctx.token_type = token_type
158-
ctx.token = encode_MakeTokenOp!(cb, token_type)
162+
root_token = encode_MakeTokenOp!(cb, token_type)
163+
164+
ctx.global_token = root_token
165+
ctx.token = root_token
166+
167+
# Initialize token map with root token for all alias sets
168+
# Default: all tokens start at root
169+
ctx.token_map = Dict{TokenKey, Value}()
170+
171+
unique_alias_sets = Set(values(alias_result))
172+
for alias_set in unique_alias_sets
173+
ctx.token_map[last_op_key(alias_set)] = root_token
174+
ctx.token_map[last_store_key(alias_set)] = root_token
175+
end
176+
177+
# ACQUIRE token also starts at root
178+
ctx.token_map[ACQUIRE_TOKEN_KEY] = root_token
159179

160180
# Emit the structured IR (uses original Julia SSA indices everywhere)
161181
emit_block!(ctx, ctx.sci.entry)

src/compiler/codegen/token_keys.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Token role enum
2+
@enum TokenRole LAST_OP LAST_STORE
3+
4+
# Acquire token key (singleton)
5+
struct AcquireTokenKey end
6+
const ACQUIRE_TOKEN_KEY = AcquireTokenKey()
7+
8+
# Alias token key (per alias set and role)
9+
struct AliasTokenKey
10+
alias_set::AliasSet
11+
role::TokenRole
12+
end
13+
14+
# Union type for all token keys
15+
const TokenKey = Union{AliasTokenKey, AcquireTokenKey}
16+
17+
# Helper constructors
18+
"""
19+
last_op_key(alias_set::AliasSet) -> AliasTokenKey
20+
21+
Create a TokenKey for the last operation (load or store) on an alias set.
22+
"""
23+
last_op_key(alias_set::AliasSet) = AliasTokenKey(alias_set, LAST_OP)
24+
25+
"""
26+
last_store_key(alias_set::AliasSet) -> AliasTokenKey
27+
28+
Create a TokenKey for the last store operation on an alias set.
29+
"""
30+
last_store_key(alias_set::AliasSet) = AliasTokenKey(alias_set, LAST_STORE)
31+
32+
# Make TokenKey hashable for use in Dict
33+
Base.hash(key::AliasTokenKey, h::UInt) = hash((key.alias_set, key.role), h)
34+
Base.:(==)(a::AliasTokenKey, b::AliasTokenKey) =
35+
a.alias_set == b.alias_set && a.role == b.role
36+
37+
Base.hash(::AcquireTokenKey, h::UInt) = hash(:ACQUIRE_TOKEN_KEY, h)
38+
Base.:(==)(::AcquireTokenKey, ::AcquireTokenKey) = true

0 commit comments

Comments
 (0)