Skip to content

Commit 93bc239

Browse files
authored
Merge pull request #173 from JuliaGPU/tb/print
Add printing functionality (+ rewriter splat functionality)
2 parents a6daec6 + f2cf374 commit 93bc239

File tree

10 files changed

+401
-9
lines changed

10 files changed

+401
-9
lines changed

README.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,23 @@ ct.scatter(arr, indices, tile; mask=active_mask)
250250
All atomics accept `memory_order` (default: `ct.MemoryOrder.AcqRel`) and
251251
`memory_scope` (default: `ct.MemScope.Device`) keyword arguments.
252252

253+
### Debugging
254+
| Operation | Description |
255+
|-----------|-------------|
256+
| `print(args...)` | Print values (Base overlay) |
257+
| `println(args...)` | Print values with newline (Base overlay) |
258+
259+
Standard Julia `print`/`println` work inside kernels. String constants and tiles
260+
can be mixed freely; format specifiers are inferred from element types at compile
261+
time. String interpolation is supported.
262+
263+
```julia
264+
println("Block ", ct.bid(1), ": tile=", tile)
265+
println("result=$result") # string interpolation
266+
```
267+
268+
This is a debugging aid and is not optimized for performance.
269+
253270

254271
## Differences from cuTile Python
255272

src/bytecode/encodings.jl

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,42 @@ function encode_AssertOp!(cb::CodeBuilder, condition::Value, message::String)
305305
return new_op!(cb, 0)
306306
end
307307

308+
"""
309+
encode_PrintTkoOp!(cb, token_type, args; token, format_string) -> Union{Value, Nothing}
310+
311+
Print a formatted string with optional token ordering.
312+
Opcode: 85
313+
314+
Returns the result token for synchronization (v13.2+), or nothing (v13.1).
315+
"""
316+
function encode_PrintTkoOp!(cb::CodeBuilder,
317+
token_type::Union{TypeId, Nothing},
318+
args::Vector{Value};
319+
token::Union{Value, Nothing}=nothing,
320+
format_string::String)
321+
encode_varint!(cb.buf, Opcode.PrintOp)
322+
# Variadic result types: [token] in v13.2+, empty in v13.1
323+
if cb.version >= v"13.2"
324+
encode_typeid_seq!(cb.buf, [token_type])
325+
else
326+
encode_typeid_seq!(cb.buf, TypeId[])
327+
end
328+
# Flags: bit 0 = has input token (v13.2+ only)
329+
if cb.version >= v"13.2"
330+
encode_varint!(cb.buf, token !== nothing ? 1 : 0)
331+
end
332+
# Attributes: format string
333+
encode_opattr_str!(cb, format_string)
334+
# Operands: sized variadic args + optional token
335+
encode_sized_operands!(cb.buf, args)
336+
if cb.version >= v"13.2"
337+
encode_optional_operand!(cb.buf, token)
338+
end
339+
num_results = cb.version >= v"13.2" ? 1 : 0
340+
result = new_op!(cb, num_results)
341+
return result # Value for v13.2+, nothing for v13.1
342+
end
343+
308344
"""
309345
encode_AssumeOp!(cb, result_type, value, predicate) -> Value
310346

src/compiler/intrinsics/misc.jl

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,83 @@ function emit_assume_ops!(ctx::CGCtx, array_val::Value, size_vals::Vector{Value}
6363
return array_val, size_vals, stride_vals
6464
end
6565

66-
# TODO: cuda_tile.print_tko
66+
# cuda_tile.print_tko
67+
68+
# Format specifier inference for print_tko
69+
function infer_format_specifier(::Type{T}) where T
70+
if T <: Union{Bool, Int8, Int16, Int32, UInt8, UInt16, UInt32}
71+
return "%d"
72+
elseif T <: Union{Int64, UInt64}
73+
return "%ld"
74+
elseif T <: AbstractFloat # Float16, BFloat16, Float32, TFloat32, Float64
75+
return "%f"
76+
else
77+
throw(IRError("print: unsupported element type $T"))
78+
end
79+
end
80+
81+
# Escape literal `%` as `%%` for C printf format strings
82+
escape_printf(s::String) = replace(s, "%" => "%%")
83+
84+
@intrinsic print_tko(xs...)
85+
tfunc(𝕃, ::typeof(Intrinsics.print_tko), @nospecialize(args...)) = Nothing
86+
efunc(::typeof(Intrinsics.print_tko), effects::CC.Effects) =
87+
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
88+
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.print_tko), args)
89+
cb = ctx.cb
90+
tt = ctx.tt
91+
92+
# Extract input token from last arg (added by token_order_pass!)
93+
input_token = extract_token_arg!(ctx, args)
94+
95+
# Build format string and collect tile operands
96+
format_parts = String[]
97+
tile_args = Value[]
98+
99+
for arg in args
100+
c = get_constant(ctx, arg)
101+
if c !== nothing
102+
val = something(c)
103+
if val isa String
104+
push!(format_parts, escape_printf(val))
105+
elseif val isa Number
106+
push!(format_parts, escape_printf(string(val)))
107+
else
108+
throw(IRError("print: unsupported constant type $(typeof(val))"))
109+
end
110+
else
111+
tv = emit_value!(ctx, arg)
112+
tv === nothing && throw(IRError("print: cannot resolve argument"))
113+
jltype = CC.widenconst(tv.jltype)
114+
elem_type = jltype <: Tile ? eltype(jltype) : jltype
115+
push!(format_parts, infer_format_specifier(elem_type))
116+
push!(tile_args, tv.v)
117+
end
118+
end
119+
120+
format_string = join(format_parts)
121+
token_type = Token(tt)
122+
123+
result = encode_PrintTkoOp!(cb, token_type, tile_args;
124+
token=input_token, format_string)
125+
126+
# Store result token for TokenResultNode
127+
# v13.2+ returns a token Value; v13.1 returns nothing (no token support)
128+
new_token = if result isa Value
129+
result
130+
else
131+
# Pre-13.2: create a fresh token to satisfy the token chain
132+
encode_MakeTokenOp!(cb, token_type)
133+
end
134+
ctx.result_tokens[ctx.current_ssa_idx] = new_token
135+
136+
nothing # print returns Nothing
137+
end
138+
139+
# cuda_tile.format_string (used by string interpolation fusion)
140+
@intrinsic format_string(xs...)
141+
tfunc(𝕃, ::typeof(Intrinsics.format_string), @nospecialize(args...)) = String
142+
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.format_string), args)
143+
throw(IRError("format_string intrinsic should have been fused into print_tko by the print fusion pass. " *
144+
"Standalone string() with Tile arguments is not supported in cuTile kernels."))
145+
end

src/compiler/passes/pipeline.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,20 @@
44
# defined inline here; complex imperative passes live in their own files
55
# (alias_analysis.jl, token_order.jl, dce.jl) and are called from run_passes!.
66

7+
#=============================================================================
8+
Print Fusion (rewrite)
9+
=============================================================================#
10+
11+
# Fuse format_string (from string interpolation overlay) into print_tko.
12+
# Julia lowers `print("hello $x")` → `print(string("hello ", x))`, which our
13+
# overlays compile to `print_tko(format_string("hello ", x), "\n")`.
14+
# This rule inlines format_string's args into the print_tko call.
15+
16+
const PRINT_FUSION_RULES = RewriteRule[
17+
@rewrite Intrinsics.print_tko(Intrinsics.format_string(~parts...), ~rest...) =>
18+
Intrinsics.print_tko(~parts..., ~rest...)
19+
]
20+
721
#=============================================================================
822
FMA Fusion (rewrite)
923
=============================================================================#
@@ -287,6 +301,8 @@ and subprogram compilation.
287301
function run_passes!(sci::StructuredIRCode)
288302
canonicalize!(sci)
289303

304+
rewrite_patterns!(sci, PRINT_FUSION_RULES)
305+
290306
constants = propagate_constants(sci)
291307
rewrite_patterns!(sci, OPTIMIZATION_RULES; constants)
292308

src/compiler/passes/rewrite.jl

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@ struct PBind <: PatternNode; name::Symbol; end
2727
struct PTypedBind <: PatternNode; name::Symbol; type::Type; end
2828
struct POneUse <: PatternNode; inner::PatternNode; end
2929
struct PLiteral <: PatternNode; val::Any; end
30+
struct PSplat <: PatternNode; name::Symbol; end # ~x... — captures remaining operands
3031

3132
abstract type RewriteNode end
3233
struct RCall <: RewriteNode; func::Any; operands::Vector{RewriteNode}; end
3334
struct RBind <: RewriteNode; name::Symbol; end
3435
struct RConst <: RewriteNode; val::Any; end
36+
struct RSplat <: RewriteNode; name::Symbol; end # ~x... — expands splat binding
3537

3638
"""
3739
RFunc(func)
@@ -117,6 +119,15 @@ function compile_lhs(ex)
117119
if ex isa Expr && ex.head === :$
118120
return :(PLiteral($(ex.args[1])))
119121
end
122+
# ~x... on the LHS: splat capture of remaining operands
123+
# Julia parses `~x...` as Expr(:..., Expr(:call, :~, :x))
124+
if ex isa Expr && ex.head === :... && length(ex.args) == 1
125+
inner = ex.args[1]
126+
if inner isa Expr && inner.head === :call && inner.args[1] === :~ && length(inner.args) == 2
127+
name = inner.args[2]
128+
return :(PSplat($(QuoteNode(name))))
129+
end
130+
end
120131
ex isa Expr && ex.head === :call || error("@rewrite LHS: expected call, got $ex")
121132
f = ex.args[1]
122133
if f === :~
@@ -134,6 +145,14 @@ function compile_rhs(ex)
134145
if ex isa Expr && ex.head === :$
135146
return :(RConst($(ex.args[1])))
136147
end
148+
# ~x... on the RHS: expand splat binding
149+
if ex isa Expr && ex.head === :... && length(ex.args) == 1
150+
inner = ex.args[1]
151+
if inner isa Expr && inner.head === :call && inner.args[1] === :~ && length(inner.args) == 2
152+
name = inner.args[2]
153+
return :(RSplat($(QuoteNode(name))))
154+
end
155+
end
137156
ex isa Expr && ex.head === :call || error("@rewrite RHS: expected call or \$const, got $ex")
138157
f = ex.args[1]
139158
f === :~ && return :(RBind($(QuoteNode(ex.args[2]))))
@@ -283,14 +302,23 @@ function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall,
283302

284303
if entry.func === pat.func
285304
ops = def_operands(entry)
286-
if length(ops) == length(pat.operands)
305+
has_splat = !isempty(pat.operands) && last(pat.operands) isa PSplat
306+
n_fixed = has_splat ? length(pat.operands) - 1 : length(pat.operands)
307+
308+
if has_splat ? length(ops) >= n_fixed : length(ops) == n_fixed
287309
result = MatchResult(Dict{Symbol,Any}(), SSAValue[val])
288-
for (op, sub) in zip(ops, pat.operands)
289-
m = pattern_match(driver, op, sub, entry.block)
310+
# Match fixed operands
311+
for i in 1:n_fixed
312+
m = pattern_match(driver, ops[i], pat.operands[i], entry.block)
290313
m === nothing && return nothing
291314
merge_bindings!(result.bindings, m.bindings) || return nothing
292315
append!(result.matched_ssas, m.matched_ssas)
293316
end
317+
# Capture remaining operands into the splat binding
318+
if has_splat
319+
splat_name = pat.operands[end]::PSplat
320+
result.bindings[splat_name.name] = ops[n_fixed+1:end]
321+
end
294322
return result
295323
end
296324
end
@@ -343,7 +371,15 @@ their type from the first SSA operand, since element-wise ops preserve type."""
343371
resolve_rhs(driver, block, ref, op::RBind, bindings, root_typ) = bindings[op.name]
344372
resolve_rhs(driver, block, ref, op::RConst, bindings, root_typ) = op.val
345373
function resolve_rhs(driver::RewriteDriver, block, ref, op::RCall, bindings, root_typ)
346-
operands = Any[resolve_rhs(driver, block, ref, sub, bindings, root_typ) for sub in op.operands]
374+
# Flatten RSplat nodes: each RSplat expands to multiple operands
375+
operands = Any[]
376+
for sub in op.operands
377+
if sub isa RSplat
378+
append!(operands, bindings[sub.name])
379+
else
380+
push!(operands, resolve_rhs(driver, block, ref, sub, bindings, root_typ))
381+
end
382+
end
347383
# Infer type from first SSA operand — correct for element-wise ops (addi,
348384
# subi, negf, etc.) whose result type matches their operands. Falls back to
349385
# root_typ when no SSA operand is available.
@@ -477,8 +513,15 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match
477513
end
478514
pos = findfirst(==(val.id), block.body.ssa_idxes)
479515
typ = block.body.types[pos]
480-
operands = Any[resolve_rhs(driver, block, val, op, match.bindings, typ)
481-
for op in rule.rhs.operands]
516+
# Build operands, flattening RSplat nodes into multiple operands
517+
operands = Any[]
518+
for op in rule.rhs.operands
519+
if op isa RSplat
520+
append!(operands, match.bindings[op.name])
521+
else
522+
push!(operands, resolve_rhs(driver, block, val, op, match.bindings, typ))
523+
end
524+
end
482525
# Recompute pos: resolve_rhs may insert instructions before val
483526
# (e.g. negf in subf→fma), shifting positions.
484527
pos = findfirst(==(val.id), block.body.ssa_idxes)

src/compiler/passes/token_order.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ function classify_memory_op(resolved_func)
7272
elseif resolved_func === Intrinsics.store_partition_view ||
7373
resolved_func === Intrinsics.store_ptr_tko
7474
return MEM_STORE
75+
elseif resolved_func === Intrinsics.print_tko
76+
return MEM_STORE
7577
elseif is_atomic_intrinsic(resolved_func)
7678
return MEM_STORE
7779
else

src/language/overlays.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,35 @@ for F in Floats
111111
end
112112

113113

114+
#=============================================================================
115+
Printing
116+
=============================================================================#
117+
118+
# Override all print/println entry points from coreio.jl to bypass stdout
119+
# and route directly to the print_tko Tile IR instruction.
120+
# Uses @consistent_overlay (not @overlay) because print has side effects.
121+
Base.Experimental.@consistent_overlay cuTileMethodTable Base.print(x) =
122+
Intrinsics.print_tko(x)
123+
Base.Experimental.@consistent_overlay cuTileMethodTable Base.print(x1, x2) =
124+
Intrinsics.print_tko(x1, x2)
125+
Base.Experimental.@consistent_overlay cuTileMethodTable Base.print(xs...) =
126+
Intrinsics.print_tko(xs...)
127+
Base.Experimental.@consistent_overlay cuTileMethodTable Base.println() =
128+
Intrinsics.print_tko("\n")
129+
Base.Experimental.@consistent_overlay cuTileMethodTable Base.println(x) =
130+
Intrinsics.print_tko(x, "\n")
131+
Base.Experimental.@consistent_overlay cuTileMethodTable Base.println(x1, x2) =
132+
Intrinsics.print_tko(x1, x2, "\n")
133+
Base.Experimental.@consistent_overlay cuTileMethodTable Base.println(xs...) =
134+
Intrinsics.print_tko(xs..., "\n")
135+
136+
# String interpolation support: route string() to format_string intrinsic.
137+
# For all-constant args, the interpreter constant-folds via :foldable effects.
138+
# For args containing Tiles, the format_string intrinsic is emitted in the IR
139+
# and later fused into print_tko by the print fusion pass.
140+
@overlay Base.string(xs...) = Intrinsics.format_string(xs...)
141+
142+
114143
#=============================================================================
115144
Tile Constructors
116145
=============================================================================#

test/codegen/integration.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,8 @@ end
588588
@test_throws "Unsupported function call during Tile IR compilation" begin
589589
code_tiled(Tuple{ct.TileArray{Float32,1,spec}}) do a
590590
tile = ct.load(a, ct.bid(1), (16,))
591-
print(tile)
591+
# write() has no overlay — should fail as unsupported
592+
write(stdout, tile)
592593
return
593594
end
594595
end

0 commit comments

Comments
 (0)