convert_exported_program_to_serialized_trt_engine fails with AssertionError on TorchExportableModuleWithStaticCache exports #4162

@Mgluhovskoi

Description

Bug Description

convert_exported_program_to_serialized_trt_engine fails with an AssertionError when given an ExportedProgram produced from TorchExportableModuleWithStaticCache (HuggingFace transformers). The error originates in run_decompositions(), which the converter calls internally.

The root cause is in torch._functorch._aot_autograd.graph_compile.aot_stage2_export:

assert isinstance(compiled_fn, torch.fx.GraphModule)
AssertionError

TorchExportableModuleWithStaticCache wraps a causal LM with a StaticCache registered as module state. The wrapper exports cleanly via torch.export.export(strict=False), but run_decompositions() then fails during the AOT re-export step.

This is not the same as #3226 (torchao import order), which has been fixed; this bug occurs without torchao installed.

To Reproduce

import torch
import torch_tensorrt
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations.executorch import TorchExportableModuleWithStaticCache

config = AutoConfig.from_pretrained("gpt2")
config.n_layer = 1
model = AutoModelForCausalLM.from_config(config).eval().half()
model.generation_config.cache_implementation = "static"
model.generation_config.use_cache = True

wrapper = TorchExportableModuleWithStaticCache(model, batch_size=1, max_cache_len=16)

input_ids = torch.tensor([[42]], dtype=torch.long)
cache_position = torch.tensor([0], dtype=torch.long)

exported = torch.export.export(
    wrapper, (), kwargs={"input_ids": input_ids, "cache_position": cache_position}, strict=False
)

# This fails:
engine = torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
    exported,
    inputs=[
        torch_tensorrt.Input(shape=input_ids.shape, dtype=input_ids.dtype),
        torch_tensorrt.Input(shape=cache_position.shape, dtype=cache_position.dtype),
    ],
    use_explicit_typing=True,
    min_block_size=1,
)

Error

File ".../torch/export/exported_program.py", line 1484, in run_decompositions
    return _decompose_exported_program(
File ".../torch/export/exported_program.py", line 967, in _decompose_exported_program
    ) = _decompose_and_get_gm_with_new_signature_constants(
File ".../torch/export/exported_program.py", line 476, in _decompose_and_get_gm_with_new_signature_constants
    aten_export_artifact = _export_to_aten_ir(
File ".../torch/export/_trace.py", line 985, in _export_to_aten_ir
    gm, graph_signature = transform(_aot_export_joint_with_descriptors)(
File ".../torch/export/_trace.py", line 924, in _aot_export_joint_with_descriptors
    gm, fw_metadata = aot_stage2_export(
File ".../torch/_functorch/_aot_autograd/graph_compile.py", line 288, in aot_stage2_export
    assert isinstance(compiled_fn, torch.fx.GraphModule)
AssertionError

Workaround

Instead of using TorchExportableModuleWithStaticCache, create a fully stateless wrapper that accepts KV cache tensors as explicit inputs and returns the updated caches as explicit outputs, with no internal buffer mutations. This sidesteps the failing run_decompositions() code path entirely: the stateless wrapper exports and converts to a TRT engine successfully.

Environment

  • torch: 2.10.0+cu128
  • torch_tensorrt: 2.10.0+cu130
  • transformers: 5.2.0
  • CUDA: 12.8
  • GPU: NVIDIA GeForce RTX 4090
  • OS: Ubuntu 22.04 (Docker container)
