-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathpipeline.jl
More file actions
313 lines (262 loc) · 12.8 KB
/
pipeline.jl
File metadata and controls
313 lines (262 loc) · 12.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
# Pass Pipeline
#
# Defines all IR passes and their execution order. Rewrite-based passes are
# defined inline here; complex imperative passes live in their own files
# (alias_analysis.jl, token_order.jl, dce.jl) and are called from run_passes!.
#=============================================================================
Print Fusion (rewrite)
=============================================================================#
# Fuse format_string (from string interpolation overlay) into print_tko.
# Julia lowers `print("hello $x")` → `print(string("hello ", x))`, which our
# overlays compile to `print_tko(format_string("hello ", x), "\n")`.
# This rule inlines format_string's args into the print_tko call.
const PRINT_FUSION_RULES = RewriteRule[
@rewrite Intrinsics.print_tko(Intrinsics.format_string(~parts...), ~rest...) =>
Intrinsics.print_tko(~parts..., ~rest...)
]
#=============================================================================
FMA Fusion (rewrite)
=============================================================================#
# mul+add/sub → fma to reduce register pressure.
# Mirrors cuTile Python's fuse_mul_addsub in rewrite_patterns.py.
# RM/FTZ are applied by the ambient @fpmode scope during codegen, not per-op.
const FMA_RULES = RewriteRule[
@rewrite Intrinsics.addf(one_use(Intrinsics.mulf(~x, ~y)), ~z) =>
Intrinsics.fma(~x, ~y, ~z)
@rewrite Intrinsics.addf(~z, one_use(Intrinsics.mulf(~x, ~y))) =>
Intrinsics.fma(~x, ~y, ~z)
@rewrite Intrinsics.subf(one_use(Intrinsics.mulf(~x, ~y)), ~z) =>
Intrinsics.fma(~x, ~y, Intrinsics.negf(~z))
]
fma_fusion_pass!(sci::StructuredIRCode) = rewrite_patterns!(sci, FMA_RULES)
#=============================================================================
Algebraic Simplification (rewrite)
=============================================================================#
# Cancel inverse addi/subi pairs: x+c-c → x, x-c+c → x.
# Repeated ~c binds enforce that both operands are the same value.
# Guard factory: check that the given bindings all resolve to the same constant
# value. Like MLIR's ConstantLikeMatcher: matches on attribute value, not SSA
# identity. Returns a guard function for use with @rewrite.
function same_const(keys::Symbol...)
(match, driver) -> begin
vals = map(keys) do k
const_value(driver.constants, match.bindings[k])
end
all(!isnothing, vals) && allequal(vals)
end
end
"""Commute addi/subi past a transparent op (reshape or broadcast) by recreating
the constant at the pre-transparent shape. The transparent op is determined from
the matched intermediate instruction, so one function handles both reshape and
broadcast patterns."""
function commute_arith_transparent(sci, block, inst, match, driver)
x = match.bindings[:x]
c = match.bindings[:c]
scalar = const_value(driver.constants, c)
scalar === nothing && return false
# Don't commute when x is also a constant (would loop).
const_value(driver.constants, x) !== nothing && return false
x_type = value_type(block, x)
x_type === nothing && return false
xT = CC.widenconst(x_type)
xT <: Tile || return false
val = SSAValue(inst)
root_func = driver.defs[val].func # Intrinsics.subi or Intrinsics.addi
# Determine transparent op (reshape or broadcast) from first operand
inner_val = def_operands(driver.defs[val])[1]
transparent_func = driver.defs[inner_val].func
# Don't commute through identity ops (same shape in/out) — that's a
# pointless restructuring that breaks pattern matching for other rules.
# Identity transparent ops are handled by IDENTITY_RULES instead.
inst_type = value_type(block, val)
inst_type === nothing && return false
size(CC.widenconst(inst_type)) == size(xT) && return false
# Insert broadcast of the scalar to x's shape and register as constant
x_shape = size(xT)
bc_type = Tile{eltype(xT), Tuple{x_shape...}}
bc = insert_before!(block, val, Expr(:call, Intrinsics.broadcast, scalar, x_shape), bc_type)
notify_insert!(driver, block, bc)
driver.constants[SSAValue(bc)] = fill(convert(eltype(xT), scalar), x_shape)
# Insert op(x, broadcast) with x's type
op = insert_before!(block, val, Expr(:call, root_func, x, SSAValue(bc)), xT)
notify_insert!(driver, block, op)
# Replace root with transparent_op(op_result, s)
pos = findfirst(==(val.id), block.body.ssa_idxes)
block.body.stmts[pos] = Expr(:call, transparent_func, SSAValue(op), match.bindings[:s])
driver.defs[val] = DefEntry(block, val, transparent_func)
push!(driver.worklist, val)
add_users_to_worklist!(driver, val)
return true
end
const ALGEBRA_RULES = RewriteRule[
# SSA-identity cancellation: subi(addi(x, c), c) where c is the same SSA value
@rewrite Intrinsics.subi(Intrinsics.addi(~x, ~c), ~c) => ~x
@rewrite Intrinsics.addi(Intrinsics.subi(~x, ~c), ~c) => ~x
# Constant-value cancellation: subi(addi(x, c0), c1) where c0 == c1 as values
# (different SSA defs, same constant). Catches 1-based indexing patterns where
# arange(N)+1 produces one broadcast(1) and gather's -1 produces another.
# Generalizes MLIR's arith.addi/subi canonicalization for matching constants.
@rewrite(Intrinsics.subi(Intrinsics.addi(~x, ~c0), ~c1) => ~x, same_const(:c0, :c1))
@rewrite(Intrinsics.addi(Intrinsics.subi(~x, ~c0), ~c1) => ~x, same_const(:c0, :c1))
# Nested cancellation: (a + (b + c)) - c → a + b
# Catches arange pattern where iota+1 is added to an offset, then gather/scatter
# subtracts 1: subi(addi(offset, addi(iota, 1)), 1) → addi(offset, iota).
@rewrite(Intrinsics.subi(Intrinsics.addi(~a, Intrinsics.addi(~b, ~c0)), ~c1) =>
Intrinsics.addi(~a, ~b), same_const(:c0, :c1))
@rewrite(Intrinsics.addi(Intrinsics.subi(~a, Intrinsics.subi(~b, ~c0)), ~c1) =>
Intrinsics.subi(~a, ~b), same_const(:c0, :c1))
# Commute addi/subi through transparent ops (reshape, broadcast).
# Moves arithmetic past the transparent op so cancellation rules above can fire.
@rewriter Intrinsics.subi(Intrinsics.reshape(~x, ~s), ~c) => commute_arith_transparent
@rewriter Intrinsics.addi(Intrinsics.reshape(~x, ~s), ~c) => commute_arith_transparent
@rewriter Intrinsics.subi(Intrinsics.broadcast(~x, ~s), ~c) => commute_arith_transparent
@rewriter Intrinsics.addi(Intrinsics.broadcast(~x, ~s), ~c) => commute_arith_transparent
]
algebra_pass!(sci::StructuredIRCode) = rewrite_patterns!(sci, ALGEBRA_RULES)
#=============================================================================
Identity Fold (rewrite)
=============================================================================#
# Eliminate identity broadcasts and reshapes (same shape in/out). These are
# no-ops left behind by the broadcast system after scalar elimination.
function is_identity_op(match, driver)
x = match.bindings[:x]
val = first(match.matched_ssas)
entry = driver.defs[val]
in_t = value_type(entry.block, x)
out_t = value_type(entry.block, val)
in_t === nothing && return false
out_t === nothing && return false
in_T = CC.widenconst(in_t)
out_T = CC.widenconst(out_t)
in_T <: Tile && out_T <: Tile || return false
return size(in_T) == size(out_T)
end
const IDENTITY_RULES = RewriteRule[
@rewrite(Intrinsics.broadcast(~x, ~shape) => ~x, is_identity_op)
@rewrite(Intrinsics.reshape(~x, ~shape) => ~x, is_identity_op)
]
#=============================================================================
Comparison Strength Reduction (rewrite)
=============================================================================#
# (x + 1) <= y → x < y for signed integers.
# Canonicalizes Julia's 1-based `arange(N) .+ 1 .<= limit` mask pattern
# into 0-based `arange(N) .< limit`, eliminating the tile-wide addi(iota, 1).
const COMPARISON_RULES = RewriteRule[
# Direct: cmpi(addi(x, 1), y, <=, signed) → cmpi(x, y, <, signed)
@rewrite Intrinsics.cmpi(Intrinsics.addi(~x, $(1)), ~y,
$(ComparisonPredicate.LessThanOrEqual), $(Signedness.Signed)) =>
Intrinsics.cmpi(~x, ~y, $(ComparisonPredicate.LessThan), $(Signedness.Signed))
# Nested: cmpi(addi(a, addi(b, 1)), y, <=, signed) → cmpi(addi(a, b), y, <, signed)
# Uses inplace=true to modify the existing addi and cmpi ops' operands rather
# than creating new ones (which would cascade the worklist).
@rewrite(inplace=true,
Intrinsics.cmpi(Intrinsics.addi(~a, Intrinsics.addi(~b, $(1))), ~y,
$(ComparisonPredicate.LessThanOrEqual), $(Signedness.Signed)) =>
Intrinsics.cmpi(Intrinsics.addi(~a, ~b), ~y,
$(ComparisonPredicate.LessThan), $(Signedness.Signed)))
]
#=============================================================================
Combined Rule Set
=============================================================================#
const OPTIMIZATION_RULES = RewriteRule[
IDENTITY_RULES...,
ALGEBRA_RULES...,
FMA_RULES...,
COMPARISON_RULES...,
]
#=============================================================================
Constant Propagation (analysis)
=============================================================================#
# Tracks which SSA values have known constant values. Constants are represented
# as Julia Arrays matching the tile's element type and shape. This enables O(1)
# constant lookups in the rewrite pattern matcher (PLiteral).
"""
propagate_constants(sci) -> Dict{SSAValue, Any}
Build a map from SSA values to their known constant values. Walks all blocks
in program order so transitive constants (e.g. reshape of a broadcast of a
literal) resolve correctly.
"""
function propagate_constants(sci::StructuredIRCode)
constants = Dict{SSAValue, Any}()
propagate_constants!(constants, sci.entry)
return constants
end
function propagate_constants!(constants::Dict{SSAValue, Any}, block::Block)
# Recurse into nested control flow first
for inst in instructions(block)
s = stmt(inst)
if s isa ForOp
propagate_constants!(constants, s.body)
elseif s isa IfOp
propagate_constants!(constants, s.then_region)
propagate_constants!(constants, s.else_region)
elseif s isa WhileOp
propagate_constants!(constants, s.before)
propagate_constants!(constants, s.after)
elseif s isa LoopOp
propagate_constants!(constants, s.body)
end
end
for inst in instructions(block)
call = resolve_call(block, inst)
call === nothing && continue
func, ops = call
# Seed constants from constant ops
if func === Intrinsics.constant && length(ops) >= 2
scalar = const_value(constants, ops[2])
scalar === nothing && continue
vt = value_type(block, SSAValue(inst))
vt === nothing && continue
T = CC.widenconst(vt)
T <: Tile || continue
S = size(T)
all(>=(0), S) || continue # skip invalid shapes (caught later by validation)
constants[SSAValue(inst)] = fill(convert(eltype(T), scalar), S)
end
# Transparent ops (broadcast, reshape) propagate constants from operand
if (func === Intrinsics.broadcast || func === Intrinsics.reshape) &&
length(ops) >= 1
scalar = const_value(constants, ops[1])
scalar === nothing && continue
vt = value_type(block, SSAValue(inst))
vt === nothing && continue
T = CC.widenconst(vt)
T <: Tile || continue
S = size(T)
constants[SSAValue(inst)] = fill(convert(eltype(T), scalar), S)
end
end
end
"""Resolve an operand to its scalar constant value, or `nothing`."""
function const_value(constants::Dict{SSAValue, Any}, @nospecialize(op))
if op isa Number
return op
elseif op isa QuoteNode && op.value isa Number
return op.value
elseif op isa SSAValue
c = get(constants, op, nothing)
c isa AbstractArray || return nothing
isempty(c) && return nothing
v = first(c)
all(==(v), c) && return v
return nothing
end
return nothing
end
#=============================================================================
Pass Pipeline
=============================================================================#
"""
run_passes!(sci::StructuredIRCode)
Run the full pass pipeline on a StructuredIRCode. Called for both kernel
and subprogram compilation.
"""
function run_passes!(sci::StructuredIRCode)
canonicalize!(sci)
rewrite_patterns!(sci, PRINT_FUSION_RULES)
constants = propagate_constants(sci)
rewrite_patterns!(sci, OPTIMIZATION_RULES; constants)
alias_result = alias_analysis_pass!(sci)
token_order_pass!(sci, alias_result)
dce_pass!(sci)
end