Skip to content

Commit c511ec5

Browse files
maleadtclaude
andcommitted
Algebra rules: replace _const_scalar with const_value, generalize same_const
Remove the redundant _const_scalar helper (near-duplicate of const_value) and make same_const a factory taking arbitrary binding names, so it can be reused with rules matching more than two constants. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a6eca37 commit c511ec5

File tree

1 file changed

+11
-24
lines changed

1 file changed

+11
-24
lines changed

src/compiler/passes/pipeline.jl

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -43,29 +43,16 @@ fma_fusion_pass!(sci::StructuredIRCode) = rewrite_patterns!(sci, FMA_RULES)
4343
# Cancel inverse addi/subi pairs: x+c-c → x, x-c+c → x.
4444
# Repeated ~c binds enforce that both operands are the same value.
4545

46-
# Guard: check that ~c0 and ~c1 resolve to the same constant value.
47-
# Like MLIR's ConstantLikeMatcher: matches on attribute value, not SSA identity.
48-
function same_const(match, driver)
49-
c0 = match.bindings[:c0]
50-
c1 = match.bindings[:c1]
51-
v0 = _const_scalar(driver.constants, c0)
52-
v1 = _const_scalar(driver.constants, c1)
53-
v0 !== nothing && v1 !== nothing && v0 == v1
54-
end
55-
56-
# Resolve an SSA value or literal to its scalar constant, or nothing.
57-
function _const_scalar(constants, @nospecialize(op))
58-
if op isa Number
59-
return op
60-
elseif op isa SSAValue
61-
c = get(constants, op, nothing)
62-
c isa AbstractArray || return nothing
63-
isempty(c) && return nothing
64-
v = first(c)
65-
all(==(v), c) && return v
66-
return nothing
46+
# Guard factory: check that the given bindings all resolve to the same constant
47+
# value. Like MLIR's ConstantLikeMatcher: matches on attribute value, not SSA
48+
# identity. Returns a guard function for use with @rewrite.
49+
function same_const(keys::Symbol...)
50+
(match, driver) -> begin
51+
vals = map(keys) do k
52+
const_value(driver.constants, match.bindings[k])
53+
end
54+
all(!isnothing, vals) && allequal(vals)
6755
end
68-
return nothing
6956
end
7057

7158
const ALGEBRA_RULES = RewriteRule[
@@ -77,8 +64,8 @@ const ALGEBRA_RULES = RewriteRule[
7764
# (different SSA defs, same constant). Catches 1-based indexing patterns where
7865
# arange(N)+1 produces one broadcast(1) and gather's -1 produces another.
7966
# Generalizes MLIR's arith.addi/subi canonicalization for matching constants.
80-
@rewrite(Intrinsics.subi(Intrinsics.addi(~x, ~c0), ~c1) => ~x, same_const)
81-
@rewrite(Intrinsics.addi(Intrinsics.subi(~x, ~c0), ~c1) => ~x, same_const)
67+
@rewrite(Intrinsics.subi(Intrinsics.addi(~x, ~c0), ~c1) => ~x, same_const(:c0, :c1))
68+
@rewrite(Intrinsics.addi(Intrinsics.subi(~x, ~c0), ~c1) => ~x, same_const(:c0, :c1))
8269
]
8370

8471
algebra_pass!(sci::StructuredIRCode) = rewrite_patterns!(sci, ALGEBRA_RULES)

0 commit comments

Comments
 (0)