[BUG][DeepCompile] DeepCompile fails on Qwen1.5-MoE-A2.7B-Chat with RuntimeError: 'weight' must be 2-D (LLaMA works in the same environment) #7942

@zuoyanzhang

Describe the bug
When using DeepCompile to train Qwen1.5-MoE-A2.7B-Chat (a Qwen-series MoE model), the run fails during the first forward pass with:

RuntimeError: 'weight' must be 2-D

The error happens at the embedding lookup:
inputs_embeds = self.embed_tokens(input_ids)

and eventually fails inside:

torch.embedding(weight, input, ...)

In the same environment, DeepCompile can run LLaMA models successfully, but Qwen MoE models fail with the error above.

At least from the traceback, the failure happens before entering the MoE block itself, in the embedding layer. This suggests that under the DeepCompile path, embed_tokens.weight may no longer be a 2-D tensor when F.embedding(...) is called.
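The error itself can be reproduced in isolation by passing a flattened (1-D) weight to F.embedding. This is only a sketch of the suspected failure mode, not DeepSpeed code: it shows that a partitioned/placeholder parameter that was never gathered back to 2-D would produce exactly the error in the traceback.

```python
import torch
import torch.nn.functional as F

vocab_size, hidden = 8, 4
weight_2d = torch.randn(vocab_size, hidden)
input_ids = torch.tensor([[1, 2, 3]])

# Normal case: a 2-D weight works as expected.
out = F.embedding(input_ids, weight_2d)
assert out.shape == (1, 3, hidden)

# A flattened (1-D) weight -- e.g. a partitioned placeholder that was
# never gathered before the lookup -- raises the same RuntimeError
# seen in the traceback.
weight_1d = weight_2d.flatten()
try:
    F.embedding(input_ids, weight_1d)
except RuntimeError as e:
    print(e)  # 'weight' must be 2-D
```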

Model
• Qwen/Qwen1.5-MoE-A2.7B-Chat

What works
• LLaMA models run successfully with DeepCompile in the same environment.

What fails
• Qwen-series MoE models
• Reproduced with Qwen1.5-MoE-A2.7B-Chat

Full traceback
[rank1]: Traceback (most recent call last):
[rank1]: File "<deepspeed_root>/examples/deepcompile/run_bench_lm.py", line 365, in
[rank1]: main()
[rank1]: File "<deepspeed_root>/examples/deepcompile/run_bench_lm.py", line 288, in main
[rank1]: outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids, use_cache=False)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1766, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "<deepspeed_root>/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank1]: ret_val = func(*args, **kwargs)
[rank1]: File "<deepspeed_root>/deepspeed/runtime/engine.py", line 2237, in forward
[rank1]: loss = self.module(*inputs, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1753, in _wrapped_call_impl
[rank1]: return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 662, in _fn
[rank1]: return fn(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1766, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank1]: return func(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/transformers/models/qwen2_moe/modeling_qwen2_moe.py", line 1317, in forward
[rank1]: outputs = self.model(
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1766, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/transformers/models/qwen2_moe/modeling_qwen2_moe.py", line 974, in forward
[rank1]: inputs_embeds = self.embed_tokens(input_ids)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1766, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 190, in forward
[rank1]: return F.embedding(
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/functional.py", line 2551, in embedding
[rank1]: return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
[rank1]: RuntimeError: 'weight' must be 2-D

Environment
• Python: 3.10.12
• CUDA: 12.4
• torch: 2.7.0+cu124
• torchaudio: 2.7.0+cu124
• torchvision: 0.22.0+cu124
• transformers: 4.49.0
• tokenizers: 0.21.4
• deepspeed: 0.18.4+28009dff
• accelerate: 1.13.0
• datasets: 3.1.0
• huggingface_hub: 0.36.2
• safetensors: 0.7.0
• triton: 3.0.0
• einops: 0.8.1
• numpy: 1.26.4

Additional context

A few observations:
1. The same setup works for LLaMA models, so this does not look like a generic DeepCompile failure.
2. The failure happens in the embedding layer, not inside the expert computation itself.
3. The failing model is a Qwen MoE model, so this may be related to how DeepCompile handles this architecture under compilation / sharding.
4. The traceback goes through self._compiled_call_impl(...), so the issue appears on the compiled path.
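One way to confirm observation 4 is to log the embedding weight's shape right before the lookup. This is a diagnostic sketch, not DeepSpeed code; the hook name is illustrative, and hooks may interact with torch.compile, so the printout should be read as a best-effort check:

```python
import torch
import torch.nn as nn

def check_embedding_weight(module, args):
    # Forward pre-hook: report the embedding weight's dimensionality
    # right before F.embedding runs, to see whether it is still 2-D
    # on the compiled path.
    w = module.weight
    print(f"embed_tokens.weight: dim={w.dim()}, shape={tuple(w.shape)}")

# Hypothetical usage on the HF model, registered before the failing forward:
# model.model.embed_tokens.register_forward_pre_hook(check_embedding_weight)

# Minimal standalone demonstration with a plain nn.Embedding:
emb = nn.Embedding(8, 4)
emb.register_forward_pre_hook(check_embedding_weight)
emb(torch.tensor([[1, 2]]))
```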

Metadata

Labels

• bug (Something isn't working)
• training
