Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,23 @@ ct.scatter(arr, indices, tile; mask=active_mask)
All atomics accept `memory_order` (default: `ct.MemoryOrder.AcqRel`) and
`memory_scope` (default: `ct.MemScope.Device`) keyword arguments.

### Debugging
| Operation | Description |
|-----------|-------------|
| `print(args...)` | Print values (Base overlay) |
| `println(args...)` | Print values with newline (Base overlay) |

Standard Julia `print`/`println` work inside kernels. String constants and tiles
can be mixed freely; format specifiers are inferred from element types at compile
time. String interpolation is supported.

```julia
println("Block ", ct.bid(1), ": tile=", tile)
println("result=$result") # string interpolation
```

This is a debugging aid and is not optimized for performance.


## Differences from cuTile Python

Expand Down
36 changes: 36 additions & 0 deletions src/bytecode/encodings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,42 @@ function encode_AssertOp!(cb::CodeBuilder, condition::Value, message::String)
return new_op!(cb, 0)
end

"""
encode_PrintTkoOp!(cb, token_type, args; token, format_string) -> Union{Value, Nothing}

Print a formatted string with optional token ordering.
Opcode: 85

Returns the result token for synchronization (v13.2+), or nothing (v13.1).
"""
function encode_PrintTkoOp!(cb::CodeBuilder,
token_type::Union{TypeId, Nothing},
args::Vector{Value};
token::Union{Value, Nothing}=nothing,
format_string::String)
encode_varint!(cb.buf, Opcode.PrintOp)
# Variadic result types: [token] in v13.2+, empty in v13.1
if cb.version >= v"13.2"
encode_typeid_seq!(cb.buf, [token_type])
else
encode_typeid_seq!(cb.buf, TypeId[])
end
# Flags: bit 0 = has input token (v13.2+ only)
if cb.version >= v"13.2"
encode_varint!(cb.buf, token !== nothing ? 1 : 0)
end
# Attributes: format string
encode_opattr_str!(cb, format_string)
# Operands: sized variadic args + optional token
encode_sized_operands!(cb.buf, args)
if cb.version >= v"13.2"
encode_optional_operand!(cb.buf, token)
end
num_results = cb.version >= v"13.2" ? 1 : 0
result = new_op!(cb, num_results)
return result # Value for v13.2+, nothing for v13.1
end

"""
encode_AssumeOp!(cb, result_type, value, predicate) -> Value

Expand Down
81 changes: 80 additions & 1 deletion src/compiler/intrinsics/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,83 @@ function emit_assume_ops!(ctx::CGCtx, array_val::Value, size_vals::Vector{Value}
return array_val, size_vals, stride_vals
end

# TODO: cuda_tile.print_tko
# cuda_tile.print_tko

# Format specifier inference for print_tko
function infer_format_specifier(::Type{T}) where T
if T <: Union{Bool, Int8, Int16, Int32, UInt8, UInt16, UInt32}
return "%d"
elseif T <: Union{Int64, UInt64}
return "%ld"
elseif T <: AbstractFloat # Float16, BFloat16, Float32, TFloat32, Float64
return "%f"
else
throw(IRError("print: unsupported element type $T"))
end
end

# Escape literal `%` as `%%` for C printf format strings
escape_printf(s::String) = replace(s, "%" => "%%")

@intrinsic print_tko(xs...)
tfunc(𝕃, ::typeof(Intrinsics.print_tko), @nospecialize(args...)) = Nothing
efunc(::typeof(Intrinsics.print_tko), effects::CC.Effects) =
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.print_tko), args)
cb = ctx.cb
tt = ctx.tt

# Extract input token from last arg (added by token_order_pass!)
input_token = extract_token_arg!(ctx, args)

# Build format string and collect tile operands
format_parts = String[]
tile_args = Value[]

for arg in args
c = get_constant(ctx, arg)
if c !== nothing
val = something(c)
if val isa String
push!(format_parts, escape_printf(val))
elseif val isa Number
push!(format_parts, escape_printf(string(val)))
else
throw(IRError("print: unsupported constant type $(typeof(val))"))
end
else
tv = emit_value!(ctx, arg)
tv === nothing && throw(IRError("print: cannot resolve argument"))
jltype = CC.widenconst(tv.jltype)
elem_type = jltype <: Tile ? eltype(jltype) : jltype
push!(format_parts, infer_format_specifier(elem_type))
push!(tile_args, tv.v)
end
end

format_string = join(format_parts)
token_type = Token(tt)

result = encode_PrintTkoOp!(cb, token_type, tile_args;
token=input_token, format_string)

# Store result token for TokenResultNode
# v13.2+ returns a token Value; v13.1 returns nothing (no token support)
new_token = if result isa Value
result
else
# Pre-13.2: create a fresh token to satisfy the token chain
encode_MakeTokenOp!(cb, token_type)
end
ctx.result_tokens[ctx.current_ssa_idx] = new_token

nothing # print returns Nothing
end

# cuda_tile.format_string (used by string interpolation fusion)
@intrinsic format_string(xs...)
tfunc(𝕃, ::typeof(Intrinsics.format_string), @nospecialize(args...)) = String
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.format_string), args)
throw(IRError("format_string intrinsic should have been fused into print_tko by the print fusion pass. " *
"Standalone string() with Tile arguments is not supported in cuTile kernels."))
end
16 changes: 16 additions & 0 deletions src/compiler/passes/pipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@
# defined inline here; complex imperative passes live in their own files
# (alias_analysis.jl, token_order.jl, dce.jl) and are called from run_passes!.

#=============================================================================
Print Fusion (rewrite)
=============================================================================#

# Fuse format_string (from string interpolation overlay) into print_tko.
# Julia lowers `print("hello $x")` → `print(string("hello ", x))`, which our
# overlays compile to `print_tko(format_string("hello ", x), "\n")`.
# This rule inlines format_string's args into the print_tko call.

const PRINT_FUSION_RULES = RewriteRule[
@rewrite Intrinsics.print_tko(Intrinsics.format_string(~parts...), ~rest...) =>
Intrinsics.print_tko(~parts..., ~rest...)
]

#=============================================================================
FMA Fusion (rewrite)
=============================================================================#
Expand Down Expand Up @@ -287,6 +301,8 @@ and subprogram compilation.
function run_passes!(sci::StructuredIRCode)
canonicalize!(sci)

rewrite_patterns!(sci, PRINT_FUSION_RULES)

constants = propagate_constants(sci)
rewrite_patterns!(sci, OPTIMIZATION_RULES; constants)

Expand Down
55 changes: 49 additions & 6 deletions src/compiler/passes/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ struct PBind <: PatternNode; name::Symbol; end
struct PTypedBind <: PatternNode; name::Symbol; type::Type; end
struct POneUse <: PatternNode; inner::PatternNode; end
struct PLiteral <: PatternNode; val::Any; end
struct PSplat <: PatternNode; name::Symbol; end # ~x... — captures remaining operands

abstract type RewriteNode end
struct RCall <: RewriteNode; func::Any; operands::Vector{RewriteNode}; end
struct RBind <: RewriteNode; name::Symbol; end
struct RConst <: RewriteNode; val::Any; end
struct RSplat <: RewriteNode; name::Symbol; end # ~x... — expands splat binding

"""
RFunc(func)
Expand Down Expand Up @@ -117,6 +119,15 @@ function compile_lhs(ex)
if ex isa Expr && ex.head === :$
return :(PLiteral($(ex.args[1])))
end
# ~x... on the LHS: splat capture of remaining operands
# Julia parses `~x...` as Expr(:..., Expr(:call, :~, :x))
if ex isa Expr && ex.head === :... && length(ex.args) == 1
inner = ex.args[1]
if inner isa Expr && inner.head === :call && inner.args[1] === :~ && length(inner.args) == 2
name = inner.args[2]
return :(PSplat($(QuoteNode(name))))
end
end
ex isa Expr && ex.head === :call || error("@rewrite LHS: expected call, got $ex")
f = ex.args[1]
if f === :~
Expand All @@ -134,6 +145,14 @@ function compile_rhs(ex)
if ex isa Expr && ex.head === :$
return :(RConst($(ex.args[1])))
end
# ~x... on the RHS: expand splat binding
if ex isa Expr && ex.head === :... && length(ex.args) == 1
inner = ex.args[1]
if inner isa Expr && inner.head === :call && inner.args[1] === :~ && length(inner.args) == 2
name = inner.args[2]
return :(RSplat($(QuoteNode(name))))
end
end
ex isa Expr && ex.head === :call || error("@rewrite RHS: expected call or \$const, got $ex")
f = ex.args[1]
f === :~ && return :(RBind($(QuoteNode(ex.args[2]))))
Expand Down Expand Up @@ -283,14 +302,23 @@ function pattern_match(driver::RewriteDriver, @nospecialize(val), pat::PCall,

if entry.func === pat.func
ops = def_operands(entry)
if length(ops) == length(pat.operands)
has_splat = !isempty(pat.operands) && last(pat.operands) isa PSplat
n_fixed = has_splat ? length(pat.operands) - 1 : length(pat.operands)

if has_splat ? length(ops) >= n_fixed : length(ops) == n_fixed
result = MatchResult(Dict{Symbol,Any}(), SSAValue[val])
for (op, sub) in zip(ops, pat.operands)
m = pattern_match(driver, op, sub, entry.block)
# Match fixed operands
for i in 1:n_fixed
m = pattern_match(driver, ops[i], pat.operands[i], entry.block)
m === nothing && return nothing
merge_bindings!(result.bindings, m.bindings) || return nothing
append!(result.matched_ssas, m.matched_ssas)
end
# Capture remaining operands into the splat binding
if has_splat
splat_name = pat.operands[end]::PSplat
result.bindings[splat_name.name] = ops[n_fixed+1:end]
end
return result
end
end
Expand Down Expand Up @@ -343,7 +371,15 @@ their type from the first SSA operand, since element-wise ops preserve type."""
resolve_rhs(driver, block, ref, op::RBind, bindings, root_typ) = bindings[op.name]
resolve_rhs(driver, block, ref, op::RConst, bindings, root_typ) = op.val
function resolve_rhs(driver::RewriteDriver, block, ref, op::RCall, bindings, root_typ)
operands = Any[resolve_rhs(driver, block, ref, sub, bindings, root_typ) for sub in op.operands]
# Flatten RSplat nodes: each RSplat expands to multiple operands
operands = Any[]
for sub in op.operands
if sub isa RSplat
append!(operands, bindings[sub.name])
else
push!(operands, resolve_rhs(driver, block, ref, sub, bindings, root_typ))
end
end
# Infer type from first SSA operand — correct for element-wise ops (addi,
# subi, negf, etc.) whose result type matches their operands. Falls back to
# root_typ when no SSA operand is available.
Expand Down Expand Up @@ -477,8 +513,15 @@ function apply_rewrite!(driver::RewriteDriver, block, val::SSAValue, rule, match
end
pos = findfirst(==(val.id), block.body.ssa_idxes)
typ = block.body.types[pos]
operands = Any[resolve_rhs(driver, block, val, op, match.bindings, typ)
for op in rule.rhs.operands]
# Build operands, flattening RSplat nodes into multiple operands
operands = Any[]
for op in rule.rhs.operands
if op isa RSplat
append!(operands, match.bindings[op.name])
else
push!(operands, resolve_rhs(driver, block, val, op, match.bindings, typ))
end
end
# Recompute pos: resolve_rhs may insert instructions before val
# (e.g. negf in subf→fma), shifting positions.
pos = findfirst(==(val.id), block.body.ssa_idxes)
Expand Down
2 changes: 2 additions & 0 deletions src/compiler/passes/token_order.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ function classify_memory_op(resolved_func)
elseif resolved_func === Intrinsics.store_partition_view ||
resolved_func === Intrinsics.store_ptr_tko
return MEM_STORE
elseif resolved_func === Intrinsics.print_tko
return MEM_STORE
elseif is_atomic_intrinsic(resolved_func)
return MEM_STORE
else
Expand Down
29 changes: 29 additions & 0 deletions src/language/overlays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,35 @@ for F in Floats
end


#=============================================================================
Printing
=============================================================================#

# Override all print/println entry points from coreio.jl to bypass stdout
# and route directly to the print_tko Tile IR instruction.
# Uses @consistent_overlay (not @overlay) because print has side effects.
Base.Experimental.@consistent_overlay cuTileMethodTable Base.print(x) =
Intrinsics.print_tko(x)
Base.Experimental.@consistent_overlay cuTileMethodTable Base.print(x1, x2) =
Intrinsics.print_tko(x1, x2)
Base.Experimental.@consistent_overlay cuTileMethodTable Base.print(xs...) =
Intrinsics.print_tko(xs...)
Base.Experimental.@consistent_overlay cuTileMethodTable Base.println() =
Intrinsics.print_tko("\n")
Base.Experimental.@consistent_overlay cuTileMethodTable Base.println(x) =
Intrinsics.print_tko(x, "\n")
Base.Experimental.@consistent_overlay cuTileMethodTable Base.println(x1, x2) =
Intrinsics.print_tko(x1, x2, "\n")
Base.Experimental.@consistent_overlay cuTileMethodTable Base.println(xs...) =
Intrinsics.print_tko(xs..., "\n")

# String interpolation support: route string() to format_string intrinsic.
# For all-constant args, the interpreter constant-folds via :foldable effects.
# For args containing Tiles, the format_string intrinsic is emitted in the IR
# and later fused into print_tko by the print fusion pass.
@overlay Base.string(xs...) = Intrinsics.format_string(xs...)


#=============================================================================
Tile Constructors
=============================================================================#
Expand Down
3 changes: 2 additions & 1 deletion test/codegen/integration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,8 @@ end
@test_throws "Unsupported function call during Tile IR compilation" begin
code_tiled(Tuple{ct.TileArray{Float32,1,spec}}) do a
tile = ct.load(a, ct.bid(1), (16,))
print(tile)
# write() has no overlay — should fail as unsupported
write(stdout, tile)
return
end
end
Expand Down
Loading
Loading