# Pass Pipeline
#
# Defines all IR passes and their execution order. Rewrite-based passes are
# defined inline here; complex imperative passes live in their own files
# (alias_analysis.jl, token_order.jl, dce.jl) and are called from run_passes!.

#=============================================================================
 IR Normalization (rewrite)
=============================================================================#
# Lowers Julia Core intrinsics and builtins to cuTile Intrinsics.
# Core intrinsics appear in the SCI either because:
#   - IRStructurizer introduces them for control flow (loop bounds, increments)
#   - Julia's type inference inlined Base functions down to Core intrinsics
#     (e.g., Base.:-(x::Int32, y::Int32) → Core.Intrinsics.sub_int(x, y))
#
# Each entry maps one Core form to its cuTile Intrinsics equivalent; `~x`
# binds a pattern variable and `$(...)` splices a compile-time constant
# operand into the replacement.
const NORMALIZE_RULES = RewriteRule[
    # Integer arithmetic
    @rewrite Core.Intrinsics.add_int(~x, ~y) => Intrinsics.addi(~x, ~y)
    @rewrite Core.Intrinsics.sub_int(~x, ~y) => Intrinsics.subi(~x, ~y)
    @rewrite Core.Intrinsics.mul_int(~x, ~y) => Intrinsics.muli(~x, ~y)
    @rewrite Core.Intrinsics.neg_int(~x) => Intrinsics.negi(~x)

    # Integer comparison: the Core intrinsic name encodes signedness
    # (slt/sle = signed, ult = unsigned); cmpi makes it an explicit operand.
    @rewrite Core.Intrinsics.slt_int(~x, ~y) =>
        Intrinsics.cmpi(~x, ~y, $(ComparisonPredicate.LessThan), $(Signedness.Signed))
    @rewrite Core.Intrinsics.sle_int(~x, ~y) =>
        Intrinsics.cmpi(~x, ~y, $(ComparisonPredicate.LessThanOrEqual), $(Signedness.Signed))
    @rewrite Core.Intrinsics.ult_int(~x, ~y) =>
        Intrinsics.cmpi(~x, ~y, $(ComparisonPredicate.LessThan), $(Signedness.Unsigned))

    # Bitwise
    @rewrite Core.Intrinsics.and_int(~x, ~y) => Intrinsics.andi(~x, ~y)
    @rewrite Core.Intrinsics.or_int(~x, ~y) => Intrinsics.ori(~x, ~y)
    @rewrite Core.Intrinsics.xor_int(~x, ~y) => Intrinsics.xori(~x, ~y)

    # not_int: xori with all-ones constant (type-dependent). For Bool the
    # all-ones mask is `true`; for signed widths -1 is all ones in two's
    # complement; for unsigned widths ~0 spells the same mask explicitly.
    @rewrite Core.Intrinsics.not_int(~x::Bool) => Intrinsics.xori(~x, $(true))
    @rewrite Core.Intrinsics.not_int(~x::Int32) => Intrinsics.xori(~x, $(Int32(-1)))
    @rewrite Core.Intrinsics.not_int(~x::Int64) => Intrinsics.xori(~x, $(Int64(-1)))
    @rewrite Core.Intrinsics.not_int(~x::UInt32) => Intrinsics.xori(~x, $(~UInt32(0)))
    @rewrite Core.Intrinsics.not_int(~x::UInt64) => Intrinsics.xori(~x, $(~UInt64(0)))

    # Float arithmetic
    @rewrite Core.Intrinsics.add_float(~x, ~y) => Intrinsics.addf(~x, ~y)
    @rewrite Core.Intrinsics.sub_float(~x, ~y) => Intrinsics.subf(~x, ~y)
    @rewrite Core.Intrinsics.mul_float(~x, ~y) => Intrinsics.mulf(~x, ~y)
    @rewrite Core.Intrinsics.div_float(~x, ~y) => Intrinsics.divf(~x, ~y)
    @rewrite Core.Intrinsics.neg_float(~x) => Intrinsics.negf(~x)

    # Float comparison
    @rewrite Core.Intrinsics.lt_float(~x, ~y) =>
        Intrinsics.cmpf(~x, ~y, $(ComparisonPredicate.LessThan))
    @rewrite Core.Intrinsics.le_float(~x, ~y) =>
        Intrinsics.cmpf(~x, ~y, $(ComparisonPredicate.LessThanOrEqual))
    @rewrite Core.Intrinsics.eq_float(~x, ~y) =>
        Intrinsics.cmpf(~x, ~y, $(ComparisonPredicate.Equal))
    @rewrite Core.Intrinsics.ne_float(~x, ~y) =>
        Intrinsics.cmpf(~x, ~y, $(ComparisonPredicate.NotEqual))

    # Builtins
    # NOTE(review): (===) is lowered to a signed integer equality compare;
    # this assumes operands reaching the pattern are integer/bool-typed —
    # confirm float/pointer egal comparisons never reach this rewrite.
    @rewrite (===)(~x, ~y) =>
        Intrinsics.cmpi(~x, ~y, $(ComparisonPredicate.Equal), $(Signedness.Signed))
    @rewrite Core.ifelse(~c, ~x, ~y) => Intrinsics.select(~c, ~x, ~y)
]

"""
    normalize_pass!(sci::StructuredIRCode)

Apply `NORMALIZE_RULES` to `sci`, lowering Julia Core intrinsics and
builtins to their cuTile `Intrinsics` equivalents.
"""
function normalize_pass!(sci::StructuredIRCode)
    return rewrite_patterns!(sci, NORMALIZE_RULES)
end

#=============================================================================
 Scalar View Elimination (rewrite)
=============================================================================#

# Removes redundant to_scalar(from_scalar(x, S)) chains from Julia's broadcast
# system. Transparent op tracing handles intermediate broadcasts.
const SVE_RULES = RewriteRule[
    # to_scalar ∘ from_scalar is the identity on x; the second argument of
    # from_scalar (bound as ~_) is irrelevant once the value is unwrapped.
    @rewrite Intrinsics.to_scalar(Intrinsics.from_scalar(~x, ~_)) => ~x
]

"""
    scalar_view_elim_pass!(sci::StructuredIRCode)

Apply `SVE_RULES` to `sci`, eliminating redundant
`to_scalar(from_scalar(x, S))` round-trips.
"""
function scalar_view_elim_pass!(sci::StructuredIRCode)
    return rewrite_patterns!(sci, SVE_RULES)
end

#=============================================================================
 FMA Fusion (rewrite)
=============================================================================#

# mul+add/sub → fma to reduce register pressure.
# Mirrors cuTile Python's fuse_mul_addsub in rewrite_patterns.py.
#
# Two rule variants per pattern: 2-arg (default RM/FTZ from normalization) and
# 4-arg (explicit RM/FTZ). Repeated binds ~rm/~ftz enforce consistency between
# mul and add/sub — mismatched flags cause the pattern match to fail, preventing
# incorrect fusion.
#
# one_use(...) guards each mul: fusion fires only when the add/sub is the
# product's sole consumer, so the standalone mul result is never needed
# elsewhere after the rewrite.
# NOTE(review): there is deliberately no subf(~z, mulf(~x, ~y)) variant —
# z - x*y would need the product negated, not the addend; confirm that case
# is intentionally left unfused.
const FMA_RULES = RewriteRule[
    # Default RM/FTZ (2-arg forms from normalization)
    @rewrite Intrinsics.addf(one_use(Intrinsics.mulf(~x, ~y)), ~z) =>
        Intrinsics.fma(~x, ~y, ~z)
    @rewrite Intrinsics.addf(~z, one_use(Intrinsics.mulf(~x, ~y))) =>
        Intrinsics.fma(~x, ~y, ~z)
    # x*y - z  ⇒  fma(x, y, -z)
    @rewrite Intrinsics.subf(one_use(Intrinsics.mulf(~x, ~y)), ~z) =>
        Intrinsics.fma(~x, ~y, Intrinsics.negf(~z))

    # Explicit RM/FTZ: repeated ~rm/~ftz binds require mul and add/sub to agree
    @rewrite Intrinsics.addf(one_use(Intrinsics.mulf(~x, ~y, ~rm, ~ftz)), ~z, ~rm, ~ftz) =>
        Intrinsics.fma(~x, ~y, ~z, ~rm, ~ftz)
    @rewrite Intrinsics.addf(~z, one_use(Intrinsics.mulf(~x, ~y, ~rm, ~ftz)), ~rm, ~ftz) =>
        Intrinsics.fma(~x, ~y, ~z, ~rm, ~ftz)
    @rewrite Intrinsics.subf(one_use(Intrinsics.mulf(~x, ~y, ~rm, ~ftz)), ~z, ~rm, ~ftz) =>
        Intrinsics.fma(~x, ~y, Intrinsics.negf(~z), ~rm, ~ftz)
]

"""
    fma_fusion_pass!(sci::StructuredIRCode)

Apply `FMA_RULES` to `sci`, fusing single-use mul+add/sub chains into
`Intrinsics.fma` calls.
"""
function fma_fusion_pass!(sci::StructuredIRCode)
    return rewrite_patterns!(sci, FMA_RULES)
end

#=============================================================================
 Pass Pipeline
=============================================================================#

"""
    run_passes!(sci::StructuredIRCode)

Run the full pass pipeline on a StructuredIRCode. Called for both kernel
and subprogram compilation.

Pass order is significant: normalization must precede the optimizing
rewrites, scalar-view elimination precedes FMA fusion, and DCE runs last
to clean up after the rewrites.
"""
function run_passes!(sci::StructuredIRCode)
    # Rewrite passes (order matters: normalize before optimize, SVE before FMA)
    normalize_pass!(sci)
    scalar_view_elim_pass!(sci)
    fma_fusion_pass!(sci)

    # Memory ordering: the alias analysis result feeds token ordering.
    alias_result = alias_analysis_pass!(sci)
    token_order_pass!(sci, alias_result)

    # Cleanup
    dce_pass!(sci)
end