Canonicalize the IR before optimization #158

Merged
maleadt merged 7 commits into main from tb/canonicalize
Mar 31, 2026

maleadt commented Mar 31, 2026

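Specifically, convert every scalar SSA value into a Tile: for example, the Int32 result of get_tile_block_id becomes cuTile.IntTile{Tuple{}, Int32}, a tile with an empty shape. This simplifies compilation and matches cuTile Python.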

Before:

julia> ct.code_structured(vadd, Tuple{ct.TileArray{Float32, 1, ct.ArraySpec{1}(128, true, (0,), (32,))},
                                 ct.TileArray{Float32, 1, ct.ArraySpec{1}(128, true, (0,), (32,))},
                                 ct.TileArray{Float32, 1, ct.ArraySpec{1}(128, true, (0,), (32,))},
                                 ct.Constant{Int64, 16}})
1-element Vector{Pair{IRStructurizer.StructuredIRCode, DataType}}:
 StructuredIRCode(
│ %20 = cuTile.MakeTokenNode()::cuTile.TokenType()
│ %1  = cuTile.Intrinsics.get_tile_block_id(0)::Int32
│ %3  = cuTile.Intrinsics.make_tensor_view(_2)::cuTile.TensorView{Float32, 1}
│ %4  = cuTile.Intrinsics.make_partition_view(%3, (16,), cuTile.PaddingMode.Undetermined, cuTile.nothing)::cuTile.PartitionView{Float32, 1, Tuple{16}}
│ %6  = Core.tuple(%1)::Tuple{Int32}
│ %7  = cuTile.Intrinsics.load_partition_view(%4, cuTile.nothing, cuTile.nothing, %6, %20)::cuTile.FloatTile{Tuple{16}, Float32}
│ %8  = cuTile.Intrinsics.make_tensor_view(_3)::cuTile.TensorView{Float32, 1}
│ %9  = cuTile.Intrinsics.make_partition_view(%8, (16,), cuTile.PaddingMode.Undetermined, cuTile.nothing)::cuTile.PartitionView{Float32, 1, Tuple{16}}
│ %11 = Core.tuple(%1)::Tuple{Int32}
│ %12 = cuTile.Intrinsics.load_partition_view(%9, cuTile.nothing, cuTile.nothing, %11, %20)::cuTile.FloatTile{Tuple{16}, Float32}
│ %13 = cuTile.Intrinsics.addf(%7, %12)::cuTile.FloatTile{Tuple{16}, Float32}
│ %15 = Core.tuple(%1)::Tuple{Int32}
│ %16 = cuTile.Intrinsics.make_tensor_view(_4)::cuTile.TensorView{Float32, 1}
│ %17 = cuTile.Intrinsics.make_partition_view(%16, (16,), cuTile.PaddingMode.Undetermined, nothing)::cuTile.PartitionView{Float32, 1, Tuple{16}}
│ %18 = cuTile.Intrinsics.store_partition_view(%17, %13, nothing, nothing, %15, %20)::Nothing
└ return nothing
) => Nothing

After:

julia> ct.code_structured(vadd, Tuple{ct.TileArray{Float32, 1, ct.ArraySpec{1}(128, true, (0,), (32,))},
                                 ct.TileArray{Float32, 1, ct.ArraySpec{1}(128, true, (0,), (32,))},
                                 ct.TileArray{Float32, 1, ct.ArraySpec{1}(128, true, (0,), (32,))},
                                 ct.Constant{Int64, 16}})
1-element Vector{Pair{IRStructurizer.StructuredIRCode, DataType}}:
 StructuredIRCode(
│ %20 = cuTile.MakeTokenNode()::cuTile.TokenType()
│ %1  = cuTile.Intrinsics.get_tile_block_id(0)::cuTile.IntTile{Tuple{}, Int32}
│ %3  = cuTile.Intrinsics.make_tensor_view(_2)::cuTile.TensorView{Float32, 1}
│ %4  = cuTile.Intrinsics.make_partition_view(%3, (16,), cuTile.PaddingMode.Undetermined, cuTile.nothing)::cuTile.PartitionView{Float32, 1, Tuple{16}}
│ %6  = Core.tuple(%1)::Tuple{cuTile.IntTile{Tuple{}, Int32}}
│ %7  = cuTile.Intrinsics.load_partition_view(%4, cuTile.nothing, cuTile.nothing, %6, %20)::cuTile.FloatTile{Tuple{16}, Float32}
│ %8  = cuTile.Intrinsics.make_tensor_view(_3)::cuTile.TensorView{Float32, 1}
│ %9  = cuTile.Intrinsics.make_partition_view(%8, (16,), cuTile.PaddingMode.Undetermined, cuTile.nothing)::cuTile.PartitionView{Float32, 1, Tuple{16}}
│ %11 = Core.tuple(%1)::Tuple{cuTile.IntTile{Tuple{}, Int32}}
│ %12 = cuTile.Intrinsics.load_partition_view(%9, cuTile.nothing, cuTile.nothing, %11, %20)::cuTile.FloatTile{Tuple{16}, Float32}
│ %13 = cuTile.Intrinsics.addf(%7, %12)::cuTile.FloatTile{Tuple{16}, Float32}
│ %15 = Core.tuple(%1)::Tuple{cuTile.IntTile{Tuple{}, Int32}}
│ %16 = cuTile.Intrinsics.make_tensor_view(_4)::cuTile.TensorView{Float32, 1}
│ %17 = cuTile.Intrinsics.make_partition_view(%16, (16,), cuTile.PaddingMode.Undetermined, nothing)::cuTile.PartitionView{Float32, 1, Tuple{16}}
│ %18 = cuTile.Intrinsics.store_partition_view(%17, %13, nothing, nothing, %15, %20)::Nothing
└ return nothing
) => Nothing

Motivation: this does away with the existing scalar elimination pass and should make it easier to add rewrite rules, which no longer have to distinguish scalars from tiles.
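
To illustrate the type-level effect visible in the two listings above, here is a minimal sketch of the mapping, not the pass as implemented in this PR. `ToyTile` and `canonicalize_type` are hypothetical names introduced only for this example; they are not part of cuTile. The sketch assumes that scalars are the `Number` subtypes and that tuple element types are rewritten recursively, which matches how `Tuple{Int32}` changes in the output above.

```julia
# Hypothetical stand-in for a tile type; NOT cuTile's actual definition.
struct ToyTile{Shape, T} end

# A scalar number type becomes a tile with an empty shape.
canonicalize_type(::Type{T}) where {T<:Number} = ToyTile{Tuple{}, T}

# Tuple types are rebuilt with each element type canonicalized.
canonicalize_type(::Type{T}) where {T<:Tuple} =
    Tuple{map(canonicalize_type, fieldtypes(T))...}

# Everything else (existing tiles, Nothing, views, ...) is left unchanged.
canonicalize_type(::Type{T}) where {T} = T

# Toy "statement types", loosely modelled on the listing above.
stmt_types = Any[Int32, Tuple{Int32}, ToyTile{Tuple{16}, Float32}, Nothing]
canonical_types = map(canonicalize_type, stmt_types)

for (before, after) in zip(stmt_types, canonical_types)
    println(before, "  =>  ", after)
end
```

With a mapping like this applied up front, a later rewrite rule would only need to match tile-typed operands, which is the simplification the motivation describes.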

maleadt merged commit 0167be0 into main on Mar 31, 2026
9 checks passed
maleadt deleted the tb/canonicalize branch on March 31, 2026 15:12