
Add printing functionality (+ rewriter splat functionality)#173

Merged
maleadt merged 4 commits into main from tb/print on Apr 3, 2026

Conversation


maleadt (Member) commented Apr 3, 2026

Demo:

julia> import cuTile as ct
julia> using CUDA

julia> function vadd(a, b, c, tile_size::Int)
              pid = ct.bid(1)

              tile_a = ct.load(a; index=pid, shape=(tile_size,))
              tile_b = ct.load(b; index=pid, shape=(tile_size,))
              result = tile_a + tile_b

              # print/println work just like in regular Julia
              println("Block $pid: a=", tile_a, " b=", tile_b, " sum=", result)

              ct.store(c; index=pid, tile=result)
              return
       end

julia> n = 16;
julia> a = CUDA.rand(Float32, n);
julia> b = CUDA.rand(Float32, n);
julia> c = CUDA.zeros(Float32, n);

julia> ct.launch(vadd, 1, a, b, c, ct.Constant(n))
Block 1: a=[0.178055, 0.674633, 0.646176, 0.129320, 0.444650, 0.574574, 0.932177, 0.942232, 0.093199, 0.513216, 0.477760, 0.720973, 0.546608, 0.632605, 0.837211, 0.463458] b=[0.976412, 0.223736, 0.531745, 0.285275, 0.526129, 0.461239, 0.318035, 0.677270, 0.134871, 0.357377, 0.712498, 0.847383, 0.027163, 0.844440, 0.328595, 0.341619] sum=[1.154467, 0.898369, 1.177921, 0.414594, 0.970779, 1.035813, 1.250212, 1.619502, 0.228070, 0.870593, 1.190258, 1.568356, 0.573771, 1.477045, 1.165806, 0.805076]

Unlike CUDA.jl's @cuprint, this works with String values in the IR as well as with the standard print functions (though without IO redirection for now).
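
String interpolation is supported too; here is a hypothetical variation on the kernel above, following the commit message's println("bid=$bid") example:

julia> function vadd_verbose(a, b, c, tile_size::Int)
              pid = ct.bid(1)
              tile_a = ct.load(a; index=pid, shape=(tile_size,))
              tile_b = ct.load(b; index=pid, shape=(tile_size,))
              result = tile_a + tile_b
              # "$pid" lowers to a string() call, which the fusion pass
              # inlines into the generated print_tko instruction
              println("bid=$pid sum=", result)
              ct.store(c; index=pid, tile=result)
              return
       end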

Julia IR:

2  1 ─ %1  =   dynamic (cuTile.Intrinsics.get_tile_block_id)(0)::Int32
   │   %2  = intrinsic Base.add_int(%1, 1)::Int32
4  │   %3  =   dynamic (cuTile.Intrinsics.make_tensor_view)(_2)::cuTile.TensorView{Float32, 1}
   │   %4  =   dynamic (cuTile.Intrinsics.make_partition_view)(%3, (16,), $(QuoteNode(cuTile.PaddingMode.Undetermined)), cuTile.nothing)::cuTile.PartitionView{Float32, 1, Tuple{16}}
   │   %5  = intrinsic Base.sub_int(%2, 1)::Int32
   │   %6  =   builtin Core.tuple(%5)::Tuple{Int32}
   │   %7  =   dynamic (cuTile.Intrinsics.load_partition_view)(%4, cuTile.nothing, cuTile.nothing, %6)::cuTile.FloatTile{Tuple{16}, Float32}
5  │   %8  =   dynamic (cuTile.Intrinsics.make_tensor_view)(_3)::cuTile.TensorView{Float32, 1}
   │   %9  =   dynamic (cuTile.Intrinsics.make_partition_view)(%8, (16,), $(QuoteNode(cuTile.PaddingMode.Undetermined)), cuTile.nothing)::cuTile.PartitionView{Float32, 1, Tuple{16}}
   │   %10 = intrinsic Base.sub_int(%2, 1)::Int32
   │   %11 =   builtin Core.tuple(%10)::Tuple{Int32}
   │   %12 =   dynamic (cuTile.Intrinsics.load_partition_view)(%9, cuTile.nothing, cuTile.nothing, %11)::cuTile.FloatTile{Tuple{16}, Float32}
6  │   %13 =   dynamic (cuTile.Intrinsics.addf)(%7, %12)::cuTile.FloatTile{Tuple{16}, Float32}
9  │   %14 =   dynamic (cuTile.Intrinsics.format_string)("Block ", %2, ": a=")::String
   │           dynamic (cuTile.Intrinsics.print_tko)(%14, %7, " b=", %12, " sum=", %13, "\n")::Nothing
11 │   %16 = intrinsic Base.sub_int(%2, 1)::Int32
   │   %17 =   builtin Core.tuple(%16)::Tuple{Int32}
   │   %18 =   dynamic (cuTile.Intrinsics.make_tensor_view)(_4)::cuTile.TensorView{Float32, 1}
   │   %19 =   dynamic (cuTile.Intrinsics.make_partition_view)(%18, (16,), $(QuoteNode(cuTile.PaddingMode.Undetermined)), nothing)::cuTile.PartitionView{Float32, 1, Tuple{16}}
   │           dynamic (cuTile.Intrinsics.store_partition_view)(%19, %13, nothing, nothing, %17)::Nothing
12 └──       return nothing

Structured IR:

StructuredIRCode(
│ %22 = cuTile.MakeTokenNode()::cuTile.TokenType()
│ %1  = cuTile.Intrinsics.get_tile_block_id(0)::cuTile.IntTile{Tuple{}, Int32}
│ %2  = cuTile.Intrinsics.addi(%1, 1)::cuTile.IntTile{Tuple{}, Int32}
│ %3  = cuTile.Intrinsics.make_tensor_view(_2)::cuTile.TensorView{Float32, 1}
│ %4  = cuTile.Intrinsics.make_partition_view(%3, (16,), cuTile.PaddingMode.Undetermined, cuTile.nothing)::cuTile.PartitionView{Float32, 1, Tuple{16}}
│ %6  = Core.tuple(%1)::Tuple{cuTile.IntTile{Tuple{}, Int32}}
│ %7  = cuTile.Intrinsics.load_partition_view(%4, cuTile.nothing, cuTile.nothing, %6, %22)::cuTile.FloatTile{Tuple{16}, Float32}
│ %23 = cuTile.TokenResultNode(7)::cuTile.TokenType()
│ %24 = cuTile.JoinTokensNode(Any[:(%22), :(%23)])::cuTile.TokenType()
│ %8  = cuTile.Intrinsics.make_tensor_view(_3)::cuTile.TensorView{Float32, 1}
│ %9  = cuTile.Intrinsics.make_partition_view(%8, (16,), cuTile.PaddingMode.Undetermined, cuTile.nothing)::cuTile.PartitionView{Float32, 1, Tuple{16}}
│ %11 = Core.tuple(%1)::Tuple{cuTile.IntTile{Tuple{}, Int32}}
│ %12 = cuTile.Intrinsics.load_partition_view(%9, cuTile.nothing, cuTile.nothing, %11, %22)::cuTile.FloatTile{Tuple{16}, Float32}
│ %25 = cuTile.TokenResultNode(12)::cuTile.TokenType()
│ %26 = cuTile.JoinTokensNode(Any[:(%22), :(%25)])::cuTile.TokenType()
│ %13 = cuTile.Intrinsics.addf(%7, %12)::cuTile.FloatTile{Tuple{16}, Float32}
│ %27 = cuTile.JoinTokensNode(Any[:(%22), :(%26), :(%24)])::cuTile.TokenType()
│ %15 = cuTile.Intrinsics.print_tko("Block ", %2, ": a=", %7, " b=", %12, " sum=", %13, "\n", %27)::Nothing
│ %28 = cuTile.TokenResultNode(15)::cuTile.TokenType()
│ %17 = Core.tuple(%1)::Tuple{cuTile.IntTile{Tuple{}, Int32}}
│ %18 = cuTile.Intrinsics.make_tensor_view(_4)::cuTile.TensorView{Float32, 1}
│ %19 = cuTile.Intrinsics.make_partition_view(%18, (16,), cuTile.PaddingMode.Undetermined, nothing)::cuTile.PartitionView{Float32, 1, Tuple{16}}
│ %29 = cuTile.JoinTokensNode(Any[:(%22), :(%28)])::cuTile.TokenType()
│ %20 = cuTile.Intrinsics.store_partition_view(%19, %13, nothing, nothing, %17, %29)::Nothing
└ return nothing
)

Generated Tile IR:

cuda_tile.module @kernels {
  entry @vadd(%arg0: tile<ptr<f32>>, %arg1: tile<i32>, %arg2: tile<i32>, %arg3: tile<ptr<f32>>, %arg4: tile<i32>, %arg5: tile<i32>, %arg6: tile<ptr<f32>>, %arg7: tile<i32>, %arg8: tile<i32>) {
    %cst_16_i64 = constant <i64: 16> : tile<i64>
    %assume = assume div_by<128>, %arg6 : tile<ptr<f32>>
    %assume_0 = assume bounded<0, ?>, %arg7 : tile<i32>
    %assume_assume = assume div_by<16>, %assume_0 : tile<i32>
    %tview = make_tensor_view %assume, shape = [%assume_assume], strides = [1] : tile<i32> -> tensor_view<?xf32, strides=[1]>
    %assume_1 = assume div_by<128>, %arg0 : tile<ptr<f32>>
    %assume_2 = assume bounded<0, ?>, %arg1 : tile<i32>
    %assume_assume_3 = assume div_by<16>, %assume_2 : tile<i32>
    %tview_4 = make_tensor_view %assume_1, shape = [%assume_assume_3], strides = [1] : tile<i32> -> tensor_view<?xf32, strides=[1]>
    %assume_5 = assume div_by<128>, %arg3 : tile<ptr<f32>>
    %assume_6 = assume bounded<0, ?>, %arg4 : tile<i32>
    %assume_assume_7 = assume div_by<16>, %assume_6 : tile<i32>
    %tview_8 = make_tensor_view %assume_5, shape = [%assume_assume_7], strides = [1] : tile<i32> -> tensor_view<?xf32, strides=[1]>
    %0 = make_token : token
    %blockId_x, %blockId_y, %blockId_z = get_tile_block_id : tile<i32>
    %cst_1_i32 = constant <i32: 1> : tile<i32>
    %1 = addi %blockId_x, %cst_1_i32 : tile<i32>
    %pview = make_partition_view %tview_4 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
    %tile, %result_token = load_view_tko weak %pview[%blockId_x] token = %0 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> tile<16xf32>, token
    %2 = join_tokens %0, %result_token : token
    %pview_9 = make_partition_view %tview_8 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
    %tile_10, %result_token_11 = load_view_tko weak %pview_9[%blockId_x] token = %0 : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> tile<16xf32>, token
    %3 = join_tokens %0, %result_token_11 : token
    %4 = addf %tile, %tile_10  : tile<16xf32>
    %5 = join_tokens %0, %3, %2 : token
    %6 = print_tko "Block %d: a=%f b=%f sum=%f\0A", %1, %tile, %tile_10, %4 : tile<i32>, tile<16xf32>, tile<16xf32>, tile<16xf32> -> token
    %7 = make_token : token
    %pview_12 = make_partition_view %tview : partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>
    %8 = join_tokens %0, %7 : token
    %9 = store_view_tko weak %4, %pview_12[%blockId_x] token = %8 : tile<16xf32>, partition_view<tile=(16), tensor_view<?xf32, strides=[1]>>, tile<i32> -> token
    return
  }
}
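
Note how the final print_tko carries a single printf-style format string ("Block %d: a=%f b=%f sum=%f\0A"): the constant string pieces pass through verbatim, while each argument contributes a placeholder chosen from its element type at compile time. A minimal sketch of that assembly step for scalar arguments (illustrative names, not the cuTile API):

placeholder(::Type{<:Integer})       = "%d"
placeholder(::Type{<:AbstractFloat}) = "%f"

function build_format(args...)
    fmt = IOBuffer()
    vals = Any[]
    for a in args
        if a isa String
            print(fmt, a)                      # constant text passes through
        else
            print(fmt, placeholder(typeof(a))) # value becomes a placeholder
            push!(vals, a)                     # ... and a runtime operand
        end
    end
    return String(take!(fmt)), vals
end

build_format("Block ", Int32(1), ": sum=", 1.0f0, "\n")
# => ("Block %d: sum=%f\n", Any[1, 1.0f0])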

Fusing the string-formatting calls into the print call is done with a rewrite rule.
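
Schematically, comparing the two listings above, the rule removes the intermediate format_string call and passes its parts directly as print_tko operands:

# before (Julia IR):
%14 = format_string("Block ", %2, ": a=")
      print_tko(%14, %7, " b=", %12, " sum=", %13, "\n")

# after (Structured IR):
%15 = print_tko("Block ", %2, ": a=", %7, " b=", %12, " sum=", %13, "\n", %27)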

maleadt and others added 4 commits on April 3, 2026 at 12:57
Implement the Tile IR print_tko instruction (opcode 85) with a Julia-native
API: standard print() and println() calls inside kernels compile to print_tko
instructions, with format strings built at compile time from mixed constant
string and tile arguments. String interpolation (e.g., println("bid=$bid"))
is supported via a format_string fusion pass that inlines string() args into
print_tko calls.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Add PSplat/RSplat support to the rewrite pattern framework: `~x...`
captures remaining operands on the LHS and expands them on the RHS.
This enables expressing the format_string→print_tko fusion as a single
declarative rule instead of an imperative pass:

  print_tko(format_string(~parts...), ~rest...) =>
  print_tko(~parts..., ~rest...)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
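
The `~x...` splat capture can be pictured with a small standalone matcher. This is an illustrative sketch of the idea, not the cuTile rewriter's actual implementation:

# A pattern operand is either a plain capture (a Symbol) or a trailing
# splat that absorbs all remaining operands.
struct Splat
    name::Symbol
end

# Match `pat` against `args`; return a binding Dict, or `nothing` on failure.
function match_operands(pat, args)
    bind = Dict{Symbol,Any}()
    i = 1
    for (j, p) in enumerate(pat)
        if p isa Splat
            @assert j == length(pat)    # a splat must come last
            bind[p.name] = args[i:end]  # capture all remaining operands
            return bind
        end
        i > length(args) && return nothing
        bind[p] = args[i]
        i += 1
    end
    return i > length(args) ? bind : nothing
end

# Expand an RHS template, splicing splat captures back in place.
expand(rhs, bind) =
    reduce(vcat, (p isa Splat ? collect(bind[p.name]) : Any[bind[p]] for p in rhs);
           init=Any[])

# e.g. the operand shape of print_tko(fmt, ~rest...):
b = match_operands([:fmt, Splat(:rest)], Any[:fs, :a, " b=", :b])
expand([:fmt, Splat(:rest)], b)   # == Any[:fs, :a, " b=", :b]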
maleadt merged commit 93bc239 into main on Apr 3, 2026
8 of 9 checks passed
maleadt deleted the tb/print branch on April 3, 2026 at 11:37