-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathmisc.jl
More file actions
145 lines (125 loc) · 5.69 KB
/
misc.jl
File metadata and controls
145 lines (125 loc) · 5.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# miscellaneous intrinsics
# cuda_tile.assert
@intrinsic assert(cond::Bool, message::String)
tfunc(𝕃, ::typeof(Intrinsics.assert), @nospecialize(cond), @nospecialize(message)) = Nothing
efunc(::typeof(Intrinsics.assert), effects::CC.Effects) =
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.assert), args)
cond = @something emit_value!(ctx, args[1]) throw(IRError("assert: cannot resolve condition"))
message = @something get_constant(ctx, args[2]) throw(IRError("assert: requires constant message"))
encode_AssertOp!(ctx.cb, cond.v, message)
nothing # no result value
end
# XXX: cuda_tile.assume
# make this a pass?
function emit_assume_ops!(ctx::CGCtx, array_val::Value, size_vals::Vector{Value},
stride_vals::Vector{Value}, array_spec::ArraySpec, dtype::TypeId, scalar_type::TypeId;
tv_strides::Union{Vector{Int64}, Nothing}=nothing)
cb = ctx.cb
tt = ctx.tt
# Pointer alignment
if array_spec.alignment > 0
ptr_dtype = pointer_type!(tt, dtype)
ptr_tile_type = tile_type!(tt, ptr_dtype, RowMajorShape(()))
array_val = encode_AssumeOp!(cb, ptr_tile_type, array_val, DivBy(array_spec.alignment))
end
# Bounds assumes for sizes
size_vals = Value[encode_AssumeOp!(cb, scalar_type, v, Bounded(0, nothing)) for v in size_vals]
# Bounds assumes for strides - only for dynamic strides
if tv_strides !== nothing
stride_vals = Value[tv_strides[i] == DYNAMIC_SHAPE ?
encode_AssumeOp!(cb, scalar_type, v, Bounded(0, nothing)) : v
for (i, v) in enumerate(stride_vals)]
else
stride_vals = Value[encode_AssumeOp!(cb, scalar_type, v, Bounded(0, nothing)) for v in stride_vals]
end
# Divisibility assumes for sizes
# ArraySpec fields are in Julia order; size_vals are in Tile IR order (reversed)
ndim = length(size_vals)
for (julia_i, div_by) in enumerate(array_spec.shape_div_by)
tileir_i = ndim + 1 - julia_i # Reverse index mapping
if div_by > 0 && tileir_i <= length(size_vals)
size_vals[tileir_i] = encode_AssumeOp!(cb, scalar_type, size_vals[tileir_i], DivBy(div_by))
end
end
# Divisibility assumes for strides - only for dynamic strides
for (julia_i, div_by) in enumerate(array_spec.stride_div_by)
tileir_i = ndim + 1 - julia_i # Reverse index mapping
if div_by > 0 && tileir_i <= length(stride_vals)
# Skip if this stride is static (not DYNAMIC_SHAPE)
if tv_strides === nothing || tv_strides[tileir_i] == DYNAMIC_SHAPE
stride_vals[tileir_i] = encode_AssumeOp!(cb, scalar_type, stride_vals[tileir_i], DivBy(div_by))
end
end
end
return array_val, size_vals, stride_vals
end
# cuda_tile.print_tko
# Format specifier inference for print_tko
function infer_format_specifier(::Type{T}) where T
if T <: Union{Bool, Int8, Int16, Int32, UInt8, UInt16, UInt32}
return "%d"
elseif T <: Union{Int64, UInt64}
return "%ld"
elseif T <: AbstractFloat # Float16, BFloat16, Float32, TFloat32, Float64
return "%f"
else
throw(IRError("print: unsupported element type $T"))
end
end
# Escape literal `%` as `%%` for C printf format strings
escape_printf(s::String) = replace(s, "%" => "%%")
@intrinsic print_tko(xs...)
tfunc(𝕃, ::typeof(Intrinsics.print_tko), @nospecialize(args...)) = Nothing
efunc(::typeof(Intrinsics.print_tko), effects::CC.Effects) =
CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.print_tko), args)
cb = ctx.cb
tt = ctx.tt
# Extract input token from last arg (added by token_order_pass!)
input_token = extract_token_arg!(ctx, args)
# Build format string and collect tile operands
format_parts = String[]
tile_args = Value[]
for arg in args
c = get_constant(ctx, arg)
if c !== nothing
val = something(c)
if val isa String
push!(format_parts, escape_printf(val))
elseif val isa Number
push!(format_parts, escape_printf(string(val)))
else
throw(IRError("print: unsupported constant type $(typeof(val))"))
end
else
tv = emit_value!(ctx, arg)
tv === nothing && throw(IRError("print: cannot resolve argument"))
jltype = CC.widenconst(tv.jltype)
elem_type = jltype <: Tile ? eltype(jltype) : jltype
push!(format_parts, infer_format_specifier(elem_type))
push!(tile_args, tv.v)
end
end
format_string = join(format_parts)
token_type = Token(tt)
result = encode_PrintTkoOp!(cb, token_type, tile_args;
token=input_token, format_string)
# Store result token for TokenResultNode
# v13.2+ returns a token Value; v13.1 returns nothing (no token support)
new_token = if result isa Value
result
else
# Pre-13.2: create a fresh token to satisfy the token chain
encode_MakeTokenOp!(cb, token_type)
end
ctx.result_tokens[ctx.current_ssa_idx] = new_token
nothing # print returns Nothing
end
# cuda_tile.format_string (used by string interpolation fusion)
@intrinsic format_string(xs...)
tfunc(𝕃, ::typeof(Intrinsics.format_string), @nospecialize(args...)) = String
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.format_string), args)
throw(IRError("format_string intrinsic should have been fused into print_tko by the print fusion pass. " *
"Standalone string() with Tile arguments is not supported in cuTile kernels."))
end