@@ -27,11 +27,13 @@ struct PBind <: PatternNode; name::Symbol; end
2727struct PTypedBind <: PatternNode ; name:: Symbol ; type:: Type ; end
2828struct POneUse <: PatternNode ; inner:: PatternNode ; end
2929struct PLiteral <: PatternNode ; val:: Any ; end
30+ struct PSplat <: PatternNode ; name:: Symbol ; end # ~x... — captures remaining operands
3031
3132abstract type RewriteNode end
3233struct RCall <: RewriteNode ; func:: Any ; operands:: Vector{RewriteNode} ; end
3334struct RBind <: RewriteNode ; name:: Symbol ; end
3435struct RConst <: RewriteNode ; val:: Any ; end
36+ struct RSplat <: RewriteNode ; name:: Symbol ; end # ~x... — expands splat binding
3537
3638"""
3739 RFunc(func)
@@ -117,6 +119,15 @@ function compile_lhs(ex)
117119 if ex isa Expr && ex. head === :$
118120 return :(PLiteral ($ (ex. args[1 ])))
119121 end
122+ # ~x... on the LHS: splat capture of remaining operands
123+ # Julia parses `~x...` as Expr(:..., Expr(:call, :~, :x))
124+ if ex isa Expr && ex. head === :... && length (ex. args) == 1
125+ inner = ex. args[1 ]
126+ if inner isa Expr && inner. head === :call && inner. args[1 ] === :~ && length (inner. args) == 2
127+ name = inner. args[2 ]
128+ return :(PSplat ($ (QuoteNode (name))))
129+ end
130+ end
120131 ex isa Expr && ex. head === :call || error (" @rewrite LHS: expected call, got $ex " )
121132 f = ex. args[1 ]
122133 if f === :~
@@ -134,6 +145,14 @@ function compile_rhs(ex)
134145 if ex isa Expr && ex. head === :$
135146 return :(RConst ($ (ex. args[1 ])))
136147 end
148+ # ~x... on the RHS: expand splat binding
149+ if ex isa Expr && ex. head === :... && length (ex. args) == 1
150+ inner = ex. args[1 ]
151+ if inner isa Expr && inner. head === :call && inner. args[1 ] === :~ && length (inner. args) == 2
152+ name = inner. args[2 ]
153+ return :(RSplat ($ (QuoteNode (name))))
154+ end
155+ end
137156 ex isa Expr && ex. head === :call || error (" @rewrite RHS: expected call or \$ const, got $ex " )
138157 f = ex. args[1 ]
139158 f === :~ && return :(RBind ($ (QuoteNode (ex. args[2 ]))))
@@ -283,14 +302,23 @@ function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall,
283302
284303 if entry. func === pat. func
285304 ops = def_operands (entry)
286- if length (ops) == length (pat. operands)
305+ has_splat = ! isempty (pat. operands) && last (pat. operands) isa PSplat
306+ n_fixed = has_splat ? length (pat. operands) - 1 : length (pat. operands)
307+
308+ if has_splat ? length (ops) >= n_fixed : length (ops) == n_fixed
287309 result = MatchResult (Dict {Symbol,Any} (), SSAValue[val])
288- for (op, sub) in zip (ops, pat. operands)
289- m = pattern_match (driver, op, sub, entry. block)
310+ # Match fixed operands
311+ for i in 1 : n_fixed
312+ m = pattern_match (driver, ops[i], pat. operands[i], entry. block)
290313 m === nothing && return nothing
291314 merge_bindings! (result. bindings, m. bindings) || return nothing
292315 append! (result. matched_ssas, m. matched_ssas)
293316 end
317+ # Capture remaining operands into the splat binding
318+ if has_splat
319+ splat_name = pat. operands[end ]:: PSplat
320+ result. bindings[splat_name. name] = ops[n_fixed+ 1 : end ]
321+ end
294322 return result
295323 end
296324 end
@@ -343,7 +371,15 @@ their type from the first SSA operand, since element-wise ops preserve type."""
343371resolve_rhs (driver, block, ref, op:: RBind , bindings, root_typ) = bindings[op. name]
344372resolve_rhs (driver, block, ref, op:: RConst , bindings, root_typ) = op. val
345373function resolve_rhs (driver:: RewriteDriver , block, ref, op:: RCall , bindings, root_typ)
346- operands = Any[resolve_rhs (driver, block, ref, sub, bindings, root_typ) for sub in op. operands]
374+ # Flatten RSplat nodes: each RSplat expands to multiple operands
375+ operands = Any[]
376+ for sub in op. operands
377+ if sub isa RSplat
378+ append! (operands, bindings[sub. name])
379+ else
380+ push! (operands, resolve_rhs (driver, block, ref, sub, bindings, root_typ))
381+ end
382+ end
347383 # Infer type from first SSA operand — correct for element-wise ops (addi,
348384 # subi, negf, etc.) whose result type matches their operands. Falls back to
349385 # root_typ when no SSA operand is available.
@@ -477,8 +513,15 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match
477513 end
478514 pos = findfirst (== (val. id), block. body. ssa_idxes)
479515 typ = block. body. types[pos]
480- operands = Any[resolve_rhs (driver, block, val, op, match. bindings, typ)
481- for op in rule. rhs. operands]
516+ # Build operands, flattening RSplat nodes into multiple operands
517+ operands = Any[]
518+ for op in rule. rhs. operands
519+ if op isa RSplat
520+ append! (operands, match. bindings[op. name])
521+ else
522+ push! (operands, resolve_rhs (driver, block, val, op, match. bindings, typ))
523+ end
524+ end
482525 # Recompute pos: resolve_rhs may insert instructions before val
483526 # (e.g. negf in subf→fma), shifting positions.
484527 pos = findfirst (== (val. id), block. body. ssa_idxes)
0 commit comments