@@ -88,10 +88,14 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
8888 # Save token before branches
8989 token_before = ctx. token
9090
91+ # Save token_map before branches
92+ token_map_before = copy (ctx. token_map)
93+
9194 # Emit IfOp with callback-based region building
9295 then_body = function (_)
9396 saved_block_args = copy (ctx. block_args)
9497 ctx. token = token_before # Reset to pre-branch token
98+ ctx. token_map = copy (token_map_before) # Reset token_map too
9599 emit_block! (ctx, then_blk)
96100 if then_blk. terminator === nothing
97101 encode_YieldOp! (ctx. cb, [ctx. token])
@@ -102,6 +106,7 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
102106 else_body = function (_)
103107 saved_block_args = copy (ctx. block_args)
104108 ctx. token = token_before # Reset to pre-branch token
109+ ctx. token_map = copy (token_map_before) # Reset token_map too
105110 emit_block! (ctx, else_blk)
106111 if else_blk. terminator === nothing
107112 encode_YieldOp! (ctx. cb, [ctx. token])
@@ -114,6 +119,12 @@ function emit_if_op!(ctx::CGCtx, op::IfOp, @nospecialize(parent_result_type), n_
114119 # Last result is the merged token from both branches
115120 ctx. token = results[end ]
116121
122+ # Merge token_map from both branches
123+ # Conservatively point every key at the merged token (results[end]) yielded by the branches
124+ for key in keys (ctx. token_map)
125+ ctx. token_map[key] = results[end ]
126+ end
127+
117128 # Store results at IfOp's SSA index (may be empty for void-returning ifs)
118129 ctx. values[ssa_idx] = CGVal (results[1 : n_user_results], parent_result_type)
119130end
@@ -164,6 +175,9 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
164175 # Number of user result types (excluding token)
165176 n_user_results = n_carries
166177
178+ # Save token_map before loop
179+ token_map_before = copy (ctx. token_map)
180+
167181 # Emit ForOp with callback-based region building
168182 body_builder = function (block_args)
169183 saved_block_args = copy (ctx. block_args)
@@ -196,6 +210,12 @@ function emit_for_op!(ctx::CGCtx, op::ForOp, @nospecialize(parent_result_type),
196210 # Last result is the token
197211 ctx. token = results[end ]
198212
213+ # Update token_map after loop
214+ # Conservatively update all keys to the merged token
215+ for key in keys (token_map_before)
216+ ctx. token_map[key] = results[end ]
217+ end
218+
199219 # Store results at the loop's SSA index (may be empty for void-returning loops)
200220 ctx. values[ssa_idx] = CGVal (results[1 : n_user_results], parent_result_type)
201221end
@@ -230,6 +250,9 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
230250 # Number of user result types (excluding token)
231251 n_user_results = n_carries
232252
253+ # Save token_map before loop
254+ token_map_before = copy (ctx. token_map)
255+
233256 # Emit LoopOp with callback-based region building
234257 body_builder = function (block_args)
235258 saved_block_args = copy (ctx. block_args)
@@ -266,6 +289,12 @@ function emit_loop_op!(ctx::CGCtx, op::LoopOp, @nospecialize(parent_result_type)
266289 # Last result is the token
267290 ctx. token = results[end ]
268291
292+ # Update token_map after loop
293+ # Conservatively update all keys to the merged token
294+ for key in keys (token_map_before)
295+ ctx. token_map[key] = results[end ]
296+ end
297+
269298 # Store results at the loop's SSA index (may be empty for void-returning loops)
270299 ctx. values[ssa_idx] = CGVal (results[1 : n_user_results], parent_result_type)
271300end
@@ -301,6 +330,9 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
301330 # Number of user result types (excluding token)
302331 n_user_results = n_carries
303332
333+ # Save token_map before loop
334+ token_map_before = copy (ctx. token_map)
335+
304336 # Emit WhileOp as cuda_tile.loop with conditional break pattern
305337 # MLIR structure: before { stmts; condition(cond) args } do { stmts; yield vals }
306338 # Emitted as: loop { before_stmts; if(!cond) { break } else { yield }; after_stmts; continue }
@@ -396,6 +428,12 @@ function emit_while_op!(ctx::CGCtx, op::WhileOp, @nospecialize(parent_result_typ
396428 # Last result is the token
397429 ctx. token = results[end ]
398430
431+ # Update token_map after loop
432+ # Conservatively update all keys to the merged token
433+ for key in keys (token_map_before)
434+ ctx. token_map[key] = results[end ]
435+ end
436+
399437 # Store results at the loop's SSA index (may be empty for void-returning loops)
400438 ctx. values[ssa_idx] = CGVal (results[1 : n_user_results], parent_result_type)
401439end
0 commit comments