Skip to content

Commit c373daa

Browse files
authored
Merge pull request #152 from JuliaGPU/tb/rewrite_improvements
Improve the rewriter
2 parents aa5f34e + 1419e59 commit c373daa

File tree

10 files changed

+254
-204
lines changed

10 files changed

+254
-204
lines changed

src/compiler/codegen.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
# Codegen: Julia IR -> Tile IR bytecode
22

3-
include("codegen/utils.jl")
4-
include("codegen/passes/normalize.jl") # normalize_ir!
53
include("codegen/passes/token_keys.jl") # TokenKey, TokenRole, ACQUIRE_TOKEN_KEY
64
include("codegen/passes/rewrite.jl") # @rewrite, rewrite_patterns! framework
7-
include("codegen/passes/rewrite_patterns.jl") # scalar_view_elim_pass!, fma_fusion_pass!
85
include("codegen/passes/alias_analysis.jl") # alias_analysis_pass!
96
include("codegen/passes/token_order.jl") # token_order_pass!
107
include("codegen/passes/dce.jl") # dce_pass!
8+
include("codegen/passes/pipeline.jl") # run_passes!
119
include("codegen/kernel.jl")
1210
include("codegen/control_flow.jl")
1311
include("codegen/statements.jl")

src/compiler/codegen/kernel.jl

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -145,20 +145,8 @@ function emit_kernel!(writer::BytecodeWriter, func_buf::Vector{UInt8},
145145
# ReturnNode terminators to YieldOp, which the token pass then extends.
146146
hoist_returns!(ctx.sci.entry)
147147

148-
# Normalize Julia Core intrinsics to cuTile Intrinsics equivalents.
149-
normalize_ir!(sci)
150-
151-
# Eliminate redundant to_scalar/from_scalar chains from broadcast wrapping.
152-
scalar_view_elim_pass!(sci)
153-
154-
# Fuse mul+add/sub into fma to reduce register pressure.
155-
fma_fusion_pass!(sci)
156-
157-
158-
# Run alias analysis and token ordering pass on the structured IR.
159-
alias_result = alias_analysis_pass!(sci)
160-
token_order_pass!(sci, alias_result)
161-
dce_pass!(sci)
148+
# Run the pass pipeline (normalize, optimize, token ordering, DCE).
149+
run_passes!(sci)
162150

163151
# Cache the token bytecode type for codegen
164152
ctx.token_type = Token(tt)
@@ -326,8 +314,8 @@ function emit_subprogram!(ctx::CGCtx, func, arg_types::Vector,
326314
compile_hook[] = old_hook
327315
end
328316

329-
# 2b. Normalize Julia intrinsics in subprogram IR
330-
normalize_ir!(sci)
317+
# 2b. Run the pass pipeline on subprogram IR
318+
run_passes!(sci)
331319

332320
# 3. Create sub-context
333321
sub_ctx = CGCtx(; ctx.cb, ctx.tt, sci,

src/compiler/codegen/passes/normalize.jl

Lines changed: 0 additions & 113 deletions
This file was deleted.

src/compiler/codegen/passes/pipeline.jl

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# Pass Pipeline
2+
#
3+
# Defines all IR passes and their execution order. Rewrite-based passes are
4+
# defined inline here; complex imperative passes live in their own files
5+
# (alias_analysis.jl, token_order.jl, dce.jl) and are called from run_passes!.
6+
7+
#=============================================================================
8+
IR Normalization (rewrite)
9+
=============================================================================#
10+
11+
# Lowers Julia Core intrinsics and builtins to cuTile Intrinsics.
12+
# Core intrinsics appear in the SCI either because:
13+
# - IRStructurizer introduces them for control flow (loop bounds, increments)
14+
# - Julia's type inference inlined Base functions down to Core intrinsics
15+
# (e.g., Base.:-(x::Int32, y::Int32) → Core.Intrinsics.sub_int(x, y))
16+
17+
const NORMALIZE_RULES = RewriteRule[
18+
# Integer arithmetic
19+
@rewrite Core.Intrinsics.add_int(~x, ~y) => Intrinsics.addi(~x, ~y)
20+
@rewrite Core.Intrinsics.sub_int(~x, ~y) => Intrinsics.subi(~x, ~y)
21+
@rewrite Core.Intrinsics.mul_int(~x, ~y) => Intrinsics.muli(~x, ~y)
22+
@rewrite Core.Intrinsics.neg_int(~x) => Intrinsics.negi(~x)
23+
24+
# Integer comparison
25+
@rewrite Core.Intrinsics.slt_int(~x, ~y) =>
26+
Intrinsics.cmpi(~x, ~y, $(ComparisonPredicate.LessThan), $(Signedness.Signed))
27+
@rewrite Core.Intrinsics.sle_int(~x, ~y) =>
28+
Intrinsics.cmpi(~x, ~y, $(ComparisonPredicate.LessThanOrEqual), $(Signedness.Signed))
29+
@rewrite Core.Intrinsics.ult_int(~x, ~y) =>
30+
Intrinsics.cmpi(~x, ~y, $(ComparisonPredicate.LessThan), $(Signedness.Unsigned))
31+
32+
# Bitwise
33+
@rewrite Core.Intrinsics.and_int(~x, ~y) => Intrinsics.andi(~x, ~y)
34+
@rewrite Core.Intrinsics.or_int(~x, ~y) => Intrinsics.ori(~x, ~y)
35+
@rewrite Core.Intrinsics.xor_int(~x, ~y) => Intrinsics.xori(~x, ~y)
36+
37+
# not_int: xori with all-ones constant (type-dependent)
38+
@rewrite Core.Intrinsics.not_int(~x::Bool) => Intrinsics.xori(~x, $(true))
39+
@rewrite Core.Intrinsics.not_int(~x::Int32) => Intrinsics.xori(~x, $(Int32(-1)))
40+
@rewrite Core.Intrinsics.not_int(~x::Int64) => Intrinsics.xori(~x, $(Int64(-1)))
41+
@rewrite Core.Intrinsics.not_int(~x::UInt32) => Intrinsics.xori(~x, $(~UInt32(0)))
42+
@rewrite Core.Intrinsics.not_int(~x::UInt64) => Intrinsics.xori(~x, $(~UInt64(0)))
43+
44+
# Float arithmetic
45+
@rewrite Core.Intrinsics.add_float(~x, ~y) => Intrinsics.addf(~x, ~y)
46+
@rewrite Core.Intrinsics.sub_float(~x, ~y) => Intrinsics.subf(~x, ~y)
47+
@rewrite Core.Intrinsics.mul_float(~x, ~y) => Intrinsics.mulf(~x, ~y)
48+
@rewrite Core.Intrinsics.div_float(~x, ~y) => Intrinsics.divf(~x, ~y)
49+
@rewrite Core.Intrinsics.neg_float(~x) => Intrinsics.negf(~x)
50+
51+
# Float comparison
52+
@rewrite Core.Intrinsics.lt_float(~x, ~y) =>
53+
Intrinsics.cmpf(~x, ~y, $(ComparisonPredicate.LessThan))
54+
@rewrite Core.Intrinsics.le_float(~x, ~y) =>
55+
Intrinsics.cmpf(~x, ~y, $(ComparisonPredicate.LessThanOrEqual))
56+
@rewrite Core.Intrinsics.eq_float(~x, ~y) =>
57+
Intrinsics.cmpf(~x, ~y, $(ComparisonPredicate.Equal))
58+
@rewrite Core.Intrinsics.ne_float(~x, ~y) =>
59+
Intrinsics.cmpf(~x, ~y, $(ComparisonPredicate.NotEqual))
60+
61+
# Builtins
62+
@rewrite (===)(~x, ~y) =>
63+
Intrinsics.cmpi(~x, ~y, $(ComparisonPredicate.Equal), $(Signedness.Signed))
64+
@rewrite Core.ifelse(~c, ~x, ~y) => Intrinsics.select(~c, ~x, ~y)
65+
]
66+
67+
normalize_pass!(sci::StructuredIRCode) = rewrite_patterns!(sci, NORMALIZE_RULES)
68+
69+
#=============================================================================
70+
Scalar View Elimination (rewrite)
71+
=============================================================================#
72+
73+
# Removes redundant to_scalar(from_scalar(x, S)) chains from Julia's broadcast
74+
# system. Transparent op tracing handles intermediate broadcasts.
75+
76+
const SVE_RULES = RewriteRule[
77+
@rewrite Intrinsics.to_scalar(Intrinsics.from_scalar(~x, ~_)) => ~x
78+
]
79+
80+
scalar_view_elim_pass!(sci::StructuredIRCode) = rewrite_patterns!(sci, SVE_RULES)
81+
82+
#=============================================================================
83+
FMA Fusion (rewrite)
84+
=============================================================================#
85+
86+
# mul+add/sub → fma to reduce register pressure.
87+
# Mirrors cuTile Python's fuse_mul_addsub in rewrite_patterns.py.
88+
#
89+
# Two rule variants per pattern: 2-arg (default RM/FTZ from normalization) and
90+
# 4-arg (explicit RM/FTZ). Repeated binds ~rm/~ftz enforce consistency between
91+
# mul and add/sub — mismatched flags cause the pattern match to fail, preventing
92+
# incorrect fusion.
93+
94+
const FMA_RULES = RewriteRule[
95+
# Default RM/FTZ (2-arg forms from normalization)
96+
@rewrite Intrinsics.addf(one_use(Intrinsics.mulf(~x, ~y)), ~z) =>
97+
Intrinsics.fma(~x, ~y, ~z)
98+
@rewrite Intrinsics.addf(~z, one_use(Intrinsics.mulf(~x, ~y))) =>
99+
Intrinsics.fma(~x, ~y, ~z)
100+
@rewrite Intrinsics.subf(one_use(Intrinsics.mulf(~x, ~y)), ~z) =>
101+
Intrinsics.fma(~x, ~y, Intrinsics.negf(~z))
102+
103+
# Explicit RM/FTZ: repeated ~rm/~ftz binds require mul and add/sub to agree
104+
@rewrite Intrinsics.addf(one_use(Intrinsics.mulf(~x, ~y, ~rm, ~ftz)), ~z, ~rm, ~ftz) =>
105+
Intrinsics.fma(~x, ~y, ~z, ~rm, ~ftz)
106+
@rewrite Intrinsics.addf(~z, one_use(Intrinsics.mulf(~x, ~y, ~rm, ~ftz)), ~rm, ~ftz) =>
107+
Intrinsics.fma(~x, ~y, ~z, ~rm, ~ftz)
108+
@rewrite Intrinsics.subf(one_use(Intrinsics.mulf(~x, ~y, ~rm, ~ftz)), ~z, ~rm, ~ftz) =>
109+
Intrinsics.fma(~x, ~y, Intrinsics.negf(~z), ~rm, ~ftz)
110+
]
111+
112+
fma_fusion_pass!(sci::StructuredIRCode) = rewrite_patterns!(sci, FMA_RULES)
113+
114+
#=============================================================================
115+
Pass Pipeline
116+
=============================================================================#
117+
118+
"""
119+
run_passes!(sci::StructuredIRCode)
120+
121+
Run the full pass pipeline on a StructuredIRCode. Called for both kernel
122+
and subprogram compilation.
123+
"""
124+
function run_passes!(sci::StructuredIRCode)
125+
# Rewrite passes (order matters: normalize before optimize, SVE before FMA)
126+
normalize_pass!(sci)
127+
scalar_view_elim_pass!(sci)
128+
fma_fusion_pass!(sci)
129+
130+
# Memory ordering
131+
alias_result = alias_analysis_pass!(sci)
132+
token_order_pass!(sci, alias_result)
133+
134+
# Cleanup
135+
dce_pass!(sci)
136+
end

0 commit comments

Comments (0)