Skip to content

Commit 5879321

Browse files
maleadt and claude authored
Constant propagation and identity fold for the rewrite framework (#160)
Replace the ad-hoc _is_transparent tracing in PCall pattern matching with three proper mechanisms: 1. Identity fold rewrite rules: eliminate identity broadcasts and reshapes (same shape in/out) that are no-ops left behind by the broadcast system. FMA patterns now match directly without transparent-op tracing. 2. Constant propagation analysis: propagate_constants() builds a Dict{SSAValue, Any} mapping SSA values to arrays holding their known constant tile contents (e.g., broadcast(1, (16,)) → fill(Int32(1), 16)). This feeds into the rewriter via a constants kwarg. 3. PLiteral pattern node: $() on the LHS matches literal values via O(1) constants dict lookup instead of recursive _traces_to_const tracing. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6a8b83f commit 5879321

File tree

2 files changed

+130
-30
lines changed

2 files changed

+130
-30
lines changed

src/compiler/passes/pipeline.jl

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,117 @@ const ALGEBRA_RULES = RewriteRule[
5050

5151
# Apply the `ALGEBRA_RULES` rewrite rules to `sci` via `rewrite_patterns!`.
algebra_pass!(sci::StructuredIRCode) = rewrite_patterns!(sci, ALGEBRA_RULES)
5252

53+
#=============================================================================
54+
Identity Fold (rewrite)
55+
=============================================================================#
56+
57+
# Eliminate identity broadcasts and reshapes (same shape in/out). These are
58+
# no-ops left behind by the broadcast system after scalar elimination.
59+
60+
"""
    is_identity_op(match, driver) -> Bool

Guard for the identity-fold rules. Return `true` when the operand bound to `:x` and
the first matched SSA value both have a type that widens to a `Tile` and the two
tile types have the same `size`; return `false` otherwise, including when either
type is unknown (`value_type` returned `nothing`).
"""
function is_identity_op(match, driver)
    operand = match.bindings[:x]
    result = first(match.matched_ssas)
    blk = driver.defs[result].block

    # Look up both types in the block that defines the matched value.
    operand_type = value_type(blk, operand)
    result_type = value_type(blk, result)
    if operand_type === nothing || result_type === nothing
        return false
    end

    src = CC.widenconst(operand_type)
    dst = CC.widenconst(result_type)
    if !(src <: Tile && dst <: Tile)
        return false
    end

    # Same shape in and out means the op does not change the tile.
    return size(src) == size(dst)
end
73+
74+
# Rewrite rules that replace `Intrinsics.broadcast(x, shape)` and
# `Intrinsics.reshape(x, shape)` with `x` itself. Both rules are guarded by
# `is_identity_op`, so they only fire when input and output tiles have the same size.
const IDENTITY_RULES = RewriteRule[
75+
@rewrite(Intrinsics.broadcast(~x, ~shape) => ~x, is_identity_op)
76+
@rewrite(Intrinsics.reshape(~x, ~shape) => ~x, is_identity_op)
77+
]
78+
5379
#=============================================================================
5480
Combined Rule Set
5581
=============================================================================#
5682

5783
# Combined rule list handed to `rewrite_patterns!` by `run_passes!`: the identity-fold
# rules are spliced in first, followed by the algebra rules and then the FMA rules.
# NOTE(review): whether list order affects rule priority depends on how
# `rewrite_patterns!` builds its dispatch table — confirm before reordering.
const OPTIMIZATION_RULES = RewriteRule[
84+
IDENTITY_RULES...,
5885
ALGEBRA_RULES...,
5986
FMA_RULES...,
6087
]
6188

89+
#=============================================================================
90+
Constant Propagation (analysis)
91+
=============================================================================#
92+
93+
# Tracks which SSA values have known constant values. Constants are represented
94+
# as Julia Arrays matching the tile's element type and shape. This enables O(1)
95+
# constant lookups in the rewrite pattern matcher (PLiteral).
96+
97+
"""
98+
propagate_constants(sci) -> Dict{SSAValue, Any}
99+
100+
Build a map from SSA values to their known constant values. Walks all blocks
101+
in program order so transitive constants (e.g. reshape of a broadcast of a
102+
literal) resolve correctly.
103+
"""
104+
function propagate_constants(sci::StructuredIRCode)
    # Start from an empty map and let the mutating walker fill it in, beginning at
    # the entry block of the structured IR.
    known = Dict{SSAValue, Any}()
    propagate_constants!(known, sci.entry)
    return known
end
109+
110+
"""
    propagate_constants!(constants, block)

Record known constant tiles produced in `block`, and in every region nested in its
`ForOp`, `IfOp`, `WhileOp` and `LoopOp` statements, into `constants`.

Instructions are visited in order and a nested region is entered at the position of
its control-flow statement, so the region can see constants already recorded for
earlier instructions of the enclosing block.

Only `Intrinsics.broadcast` and `Intrinsics.reshape` results are recorded, and only
when their first operand resolves to a uniform scalar through `const_value` and the
result type widens to a `Tile`. The stored value is an array of the tile's size
filled with that scalar converted to the tile's element type.

Returns `constants`.
"""
function propagate_constants!(constants::Dict{SSAValue, Any}, block::Block)
    for inst in instructions(block)
        s = stmt(inst)

        # Control-flow statements: descend into their regions here, after the
        # preceding instructions of this block have been processed.
        if s isa ForOp
            propagate_constants!(constants, s.body)
            continue
        elseif s isa IfOp
            propagate_constants!(constants, s.then_region)
            propagate_constants!(constants, s.else_region)
            continue
        elseif s isa WhileOp
            propagate_constants!(constants, s.before)
            propagate_constants!(constants, s.after)
            continue
        elseif s isa LoopOp
            propagate_constants!(constants, s.body)
            continue
        end

        # Skip anything `resolve_call` cannot resolve to a (function, operands) pair.
        call = resolve_call(block, inst)
        call === nothing && continue
        func, ops = call

        # Transparent ops (broadcast, reshape) propagate constants from operand
        is_transparent = func === Intrinsics.broadcast || func === Intrinsics.reshape
        (is_transparent && !isempty(ops)) || continue

        scalar = const_value(constants, ops[1])
        scalar === nothing && continue

        vt = value_type(block, SSAValue(inst))
        vt === nothing && continue
        T = CC.widenconst(vt)
        T <: Tile || continue

        # Materialize the constant with the result tile's element type and shape.
        constants[SSAValue(inst)] = fill(convert(eltype(T), scalar), size(T))
    end
    return constants
end
146+
147+
"""Resolve an operand to its scalar constant value, or `nothing`."""
148+
function const_value(constants::Dict{SSAValue, Any}, @nospecialize(op))
    # Literal operands: a bare number, or a number wrapped in a QuoteNode.
    op isa Number && return op
    if op isa QuoteNode
        return op.value isa Number ? op.value : nothing
    end

    # Anything that is not an SSA reference cannot be looked up.
    op isa SSAValue || return nothing

    # SSA operands: only a non-empty array entry in `constants` can yield a scalar.
    c = get(constants, op, nothing)
    (c isa AbstractArray && !isempty(c)) || return nothing

    # The array collapses to a scalar only when every element is the same value.
    # `isequal` (not `==`) so that a NaN-filled tile still counts as uniform;
    # `NaN == NaN` is false and would otherwise hide such constants.
    v = first(c)
    return all(isequal(v), c) ? v : nothing
end
163+
62164
#=============================================================================
63165
Pass Pipeline
64166
=============================================================================#
@@ -72,7 +174,8 @@ and subprogram compilation.
72174
function run_passes!(sci::StructuredIRCode)
73175
canonicalize!(sci)
74176

75-
rewrite_patterns!(sci, OPTIMIZATION_RULES)
177+
constants = propagate_constants(sci)
178+
rewrite_patterns!(sci, OPTIMIZATION_RULES; constants)
76179

77180
alias_result = alias_analysis_pass!(sci)
78181
token_order_pass!(sci, alias_result)

src/compiler/passes/rewrite.jl

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ struct PCall <: PatternNode; func::Any; operands::Vector{PatternNode}; end
2626
# Pattern node carrying the `name` under which a matched operand is bound.
struct PBind <: PatternNode; name::Symbol; end
2727
# Pattern node carrying a binding `name` together with a `type`.
# NOTE(review): presumably the bound operand must have this type — confirm in the matcher.
struct PTypedBind <: PatternNode; name::Symbol; type::Type; end
2828
# Pattern node wrapping an `inner` pattern.
# NOTE(review): the name suggests a single-use requirement on the matched value — confirm.
struct POneUse <: PatternNode; inner::PatternNode; end
29+
# Pattern node holding a literal `val`; produced by `_compile_lhs` for `$(expr)` on a
# rule's left-hand side and compared against operands by `pattern_match`.
struct PLiteral <: PatternNode; val::Any; end
2930

3031
abstract type RewriteNode end
3132
struct RCall <: RewriteNode; func::Any; operands::Vector{RewriteNode}; end
@@ -86,6 +87,10 @@ macro rewriter(ex)
8687
end
8788

8889
function _compile_lhs(ex)
90+
# $(expr) on the LHS: match a literal value
91+
if ex isa Expr && ex.head === :$
92+
return :(PLiteral($(ex.args[1])))
93+
end
8994
ex isa Expr && ex.head === :call || error("@rewrite LHS: expected call, got $ex")
9095
f = ex.args[1]
9196
if f === :~
@@ -172,17 +177,14 @@ mutable struct RewriteDriver
172177
defs::Dict{SSAValue, DefEntry}
173178
dispatch::Dict{Any, Vector{RewriteRule}}
174179
worklist::Worklist
180+
constants::Dict{SSAValue, Any} # SSA → constant value (from propagate_constants)
175181
max_rewrites::Int
176182
end
177183

178184
"""Compute fresh use count for an SSA value."""
179185
function _use_count(driver::RewriteDriver, val::SSAValue)
    # Recomputed from the entry block on every call rather than cached.
    return length(uses(driver.sci.entry, val))
end
181187

182-
# Codegen no-ops that pattern matching traces through transparently.
183-
_is_transparent(func) = func === Intrinsics.broadcast ||
184-
func === Intrinsics.reshape
185-
186188
#=============================================================================
187189
Notifications
188190
=============================================================================#
@@ -266,29 +268,6 @@ function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall,
266268
end
267269
end
268270

269-
# Trace through single-use transparent ops to find the underlying operation
270-
if _is_transparent(entry.func)
271-
_use_count(driver, val) == 1 || return nothing
272-
ops = _def_operands(entry)
273-
isempty(ops) && return nothing
274-
if entry.func === Intrinsics.broadcast
275-
inner = ops[1]
276-
if inner isa SSAValue
277-
inner_entry = get(driver.defs, inner, nothing)
278-
if inner_entry !== nothing
279-
it = value_type(entry.block, inner)
280-
ot = value_type(entry.block, val)
281-
it !== nothing && ot !== nothing || return nothing
282-
CC.widenconst(it) <: Tile && CC.widenconst(ot) <: Tile || return nothing
283-
size(CC.widenconst(it)) == size(CC.widenconst(ot)) || return nothing
284-
end
285-
end
286-
end
287-
result = pattern_match(driver, ops[1], pat, entry.block)
288-
result === nothing && return nothing
289-
push!(result.matched_ssas, val)
290-
return result
291-
end
292271
return nothing
293272
end
294273

@@ -309,6 +288,23 @@ function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::POneUse,
309288
pattern_match(driver, val, pat.inner, block)
310289
end
311290

291+
# PLiteral: succeed when the operand is the literal carried by the pattern.
#   * Any operand that is `===` to the literal matches directly.
#   * An SSA operand is looked up in `driver.constants`: an array entry matches when
#     every element `==` the literal, any other non-`nothing` entry when it `==` the
#     literal.
# A successful match binds nothing and records no matched SSA values.
function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PLiteral,
                       block::Block=driver.sci.entry)
    literal = pat.val
    empty_match() = MatchResult(Dict{Symbol,Any}(), SSAValue[])

    val === literal && return empty_match()
    val isa SSAValue || return nothing

    known = get(driver.constants, val, nothing)
    known === nothing && return nothing

    is_hit = known isa AbstractArray ? all(==(literal), known) : known == literal
    return is_hit ? empty_match() : nothing
end
307+
312308
#=============================================================================
313309
Rewrite Application
314310
=============================================================================#
@@ -373,7 +369,8 @@ Rules are tried until no more matches fire or `max_rewrites` is reached.
373369
Dead code left behind is cleaned up by the pipeline's `dce_pass!`.
374370
"""
375371
function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule};
376-
max_rewrites::Int=10_000)
372+
max_rewrites::Int=10_000,
373+
constants::Dict{SSAValue, Any}=Dict{SSAValue, Any}())
377374
# Build dispatch table
378375
dispatch = Dict{Any, Vector{RewriteRule}}()
379376
for rule in rules
@@ -401,7 +398,7 @@ function rewrite_patterns!(sci::StructuredIRCode, rules::Vector{RewriteRule};
401398
end
402399
end
403400

404-
driver = RewriteDriver(sci, defs, dispatch, wl, max_rewrites)
401+
driver = RewriteDriver(sci, defs, dispatch, wl, constants, max_rewrites)
405402

406403
num_rewrites = 0
407404
while !isempty(driver.worklist) && num_rewrites < driver.max_rewrites

0 commit comments

Comments
 (0)