Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 83 additions & 36 deletions genie.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jax
import jax.numpy as jnp
import flax.linen as nn
import einops

from models.dynamics import DynamicsMaskGIT
from models.lam import LatentActionModel
Expand Down Expand Up @@ -87,25 +88,42 @@ def __call__(self, batch: Dict[str, Any], training: bool = True) -> Dict[str, An
def sample(
self,
batch: Dict[str, Any],
seq_len: int,
steps: int = 25,
temperature: int = 1,
temperature: float = 1,
sample_argmax: bool = False,
) -> Any:
"""
Autoregressively samples up to `seq_len` future frames, following Figure 8 of the paper.

- Input frames are tokenized once.
- Future frames are generated autoregressively in token space.
- All frames are detokenized in a single pass.

Note:
- For interactive or step-wise sampling, detokenization should occur after each action.
- To maintain consistent tensor shapes across timesteps, all current and future frames are decoded at every step.
- Temporal causal structure is preserved by
a) reapplying the mask before each decoding step.
b) a temporal causal mask is applied within each ST-transformer block.

Dimension keys:
B: batch size
T: number of input (conditioning) frames
N: patches per frame
S: sequence length
A: action space
D: model latent dimension
"""
# --- Encode videos and actions ---
tokenizer_out = self.tokenizer.vq_encode(batch["videos"], training=False)
token_idxs = tokenizer_out["indices"]
new_frame_idxs = jnp.zeros_like(token_idxs)[:, 0]
token_idxs = tokenizer_out["indices"] # (B, T, N)
B, T, N = token_idxs.shape
pad_shape = (B, seq_len - T, N)
pad = jnp.zeros(pad_shape, dtype=token_idxs.dtype)
token_idxs = jnp.concatenate([token_idxs, pad], axis=1) # (B, S, N)
action_tokens = self.lam.vq.get_codes(batch["latent_actions"])

# --- Initialize MaskGIT ---
init_mask = jnp.ones_like(token_idxs, dtype=bool)[:, 0]
init_carry = (
batch["rng"],
new_frame_idxs,
init_mask,
token_idxs,
action_tokens,
)
MaskGITLoop = nn.scan(
MaskGITStep,
variable_broadcast="params",
Expand All @@ -123,13 +141,45 @@ def sample(
sample_argmax=sample_argmax,
steps=steps,
)
final_carry, _ = loop_fn(init_carry, jnp.arange(steps))
new_frame_idxs = final_carry[1]
new_frame_pixels = self.tokenizer.decode(
jnp.expand_dims(new_frame_idxs, 1),

def generation_step_fn(carry, step_t):
rng, current_token_idxs = carry
rng, step_rng = jax.random.split(rng)

# Mask current frame (future frames are masked by default using causal mask in ST-transformer)
mask = jnp.arange(seq_len) == step_t # (S,)
mask = jnp.broadcast_to(mask[None, :, None], (B, seq_len, N)) # (B, S, N)
mask = mask.astype(bool)
masked_token_idxs = current_token_idxs * ~mask

# --- Initialize and run MaskGIT loop ---
init_carry_maskgit = (
step_rng,
masked_token_idxs,
mask,
action_tokens,
)
final_carry_maskgit, _ = loop_fn(init_carry_maskgit, jnp.arange(steps))
updated_token_idxs = final_carry_maskgit[1]
new_carry = (rng, updated_token_idxs)
return new_carry, None

# --- Run the autoregressive generation using scan ---
initial_carry = (batch["rng"], token_idxs)
timesteps_to_scan = jnp.arange(T, seq_len)
final_carry, _ = jax.lax.scan(
generation_step_fn,
initial_carry,
timesteps_to_scan
)
final_token_idxs = final_carry[1]

# --- Decode all tokens at once at the end ---
final_frames = self.tokenizer.decode(
final_token_idxs,
video_hw=batch["videos"].shape[2:4],
)
return new_frame_pixels
return final_frames

def vq_encode(self, batch, training) -> Dict[str, Any]:
# --- Preprocess videos ---
Expand All @@ -146,28 +196,22 @@ class MaskGITStep(nn.Module):

@nn.compact
def __call__(self, carry, x):
rng, final_token_idxs, mask, token_idxs, action_tokens = carry
rng, token_idxs, mask, action_tokens = carry
step = x
B, T, N = token_idxs.shape[:3]
N = token_idxs.shape[2]

# --- Construct + encode video ---
vid_token_idxs = jnp.concatenate(
(token_idxs, jnp.expand_dims(final_token_idxs, 1)), axis=1
)
vid_embed = self.dynamics.patch_embed(vid_token_idxs)
curr_masked_frame = jnp.where(
jnp.expand_dims(mask, -1),
self.dynamics.mask_token[0],
vid_embed[:, -1],
)
vid_embed = vid_embed.at[:, -1].set(curr_masked_frame)
vid_embed = self.dynamics.patch_embed(token_idxs) # (B, S, N, D)
mask_token = self.dynamics.mask_token # (1, 1, 1, D,)
mask_expanded = mask[..., None] # (B, S, N, 1)
vid_embed = jnp.where(mask_expanded, mask_token, vid_embed)

# --- Predict transition ---
act_embed = self.dynamics.action_up(action_tokens)
vid_embed += jnp.pad(act_embed, ((0, 0), (1, 0), (0, 0), (0, 0)))
unmasked_ratio = jnp.cos(jnp.pi * (step + 1) / (self.steps * 2))
step_temp = self.temperature * (1.0 - unmasked_ratio)
final_logits = self.dynamics.dynamics(vid_embed)[:, -1] / step_temp
final_logits = self.dynamics.dynamics(vid_embed) / step_temp

# --- Sample new tokens for final frame ---
if self.sample_argmax:
Expand All @@ -179,20 +223,23 @@ def __call__(self, carry, x):
jnp.argmax(final_logits, axis=-1),
jax.random.categorical(_rng, final_logits),
)
gather_fn = jax.vmap(jax.vmap(lambda x, y: x[y]))
gather_fn = jax.vmap(jax.vmap(jax.vmap(lambda x, y: x[y])))
final_token_probs = gather_fn(jax.nn.softmax(final_logits), sampled_token_idxs)
final_token_probs += ~mask
# Update masked tokens only
new_token_idxs = jnp.where(mask, sampled_token_idxs, final_token_idxs)
token_idxs = jnp.where(mask, sampled_token_idxs, token_idxs)

# --- Update mask ---
num_unmasked_tokens = jnp.round(N * (1.0 - unmasked_ratio)).astype(int)
idx_mask = jnp.arange(final_token_probs.shape[-1]) > num_unmasked_tokens
sorted_idxs = jnp.argsort(final_token_probs, axis=-1, descending=True)
idx_mask = jnp.arange(final_token_probs.shape[-1]) <= N - num_unmasked_tokens
final_token_probs_flat = einops.rearrange(final_token_probs, "b s n -> b (s n)")
sorted_idxs = jnp.argsort(final_token_probs_flat, axis=-1)
mask_update_fn = jax.vmap(lambda msk, ids: msk.at[ids].set(idx_mask))
new_mask = mask_update_fn(mask, sorted_idxs)
mask_flat = einops.rearrange(mask, "b s n -> b (s n)")
new_mask_flat = mask_update_fn(mask_flat, sorted_idxs)
new_mask = einops.rearrange(new_mask_flat, "b (s n) -> b s n", n=N)

new_carry = (rng, new_token_idxs, new_mask, token_idxs, action_tokens)
new_carry = (rng, token_idxs, new_mask, action_tokens)
return new_carry, None


Expand Down
28 changes: 13 additions & 15 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import einops
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from orbax.checkpoint import PyTreeCheckpointer
from PIL import Image, ImageDraw
Expand Down Expand Up @@ -85,24 +86,21 @@ class Args:
ckpt = PyTreeCheckpointer().restore(args.checkpoint)["model"]["params"]["params"]
params["params"].update(ckpt)


def _sampling_wrapper(module, batch):
return module.sample(batch, args.seq_len, args.maskgit_steps, args.temperature, args.sample_argmax)

# --- Define autoregressive sampling loop ---
def _autoreg_sample(rng, video_batch, action_batch):
vid = video_batch[:, : args.start_frame + 1]
for frame_idx in range(args.start_frame + 1, args.seq_len):
# --- Sample next frame ---
print("Frame", frame_idx)
rng, _rng = jax.random.split(rng)
batch = dict(videos=vid, latent_actions=action_batch[:, :frame_idx], rng=_rng)
new_frame = genie.apply(
params,
batch,
args.maskgit_steps,
args.temperature,
args.sample_argmax,
method=Genie.sample,
)
vid = jnp.concatenate([vid, new_frame], axis=1)
return vid
sampling_fn = jax.jit(nn.apply(_sampling_wrapper, genie))
rng, _rng = jax.random.split(rng)
batch = dict(videos=vid, latent_actions=action_batch, rng=_rng)
generated_vid = sampling_fn(
params,
batch
)
return generated_vid


# --- Get video + latent actions ---
Expand Down