Describe the bug
When using DeepCompile to train Qwen1.5-MoE-A2.7B-Chat (a Qwen-series MoE model), the run fails during the first forward pass with:
RuntimeError: 'weight' must be 2-D
The error happens at the embedding lookup:
inputs_embeds = self.embed_tokens(input_ids)
and eventually fails inside:
torch.embedding(weight, input, ...)
In the same environment, DeepCompile can run LLaMA models successfully, but Qwen MoE models fail with the error above.
Judging from the traceback, the failure happens before entering the MoE block itself, in the embedding layer. This suggests that under the DeepCompile path, embed_tokens.weight may no longer be a 2-D tensor by the time F.embedding(...) is called.
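One way to confirm this hypothesis is a pre-forward hook on the embedding module that fails fast with the weight's actual shape instead of the opaque torch.embedding error. This is a debugging sketch only; the hook name and the tiny stand-in nn.Embedding are illustrative, not part of the original repro:

```python
import torch
import torch.nn as nn

def assert_2d_weight(module, args):
    # Pre-forward hook: report the embedding table's dimensionality
    # right before the lookup, with the shape included in the message.
    if module.weight.dim() != 2:
        raise RuntimeError(
            f"embedding weight is {module.weight.dim()}-D "
            f"(shape {tuple(module.weight.shape)}), expected 2-D"
        )

# Illustrative stand-in for model.model.embed_tokens:
emb = nn.Embedding(10, 4)
emb.register_forward_pre_hook(assert_2d_weight)
out = emb(torch.tensor([1, 2, 3]))  # passes: weight is still 2-D
```

Registering the same hook on model.model.embed_tokens before the DeepCompile run should reveal whether the weight is left in a flattened or sharded state by the time the forward pass reaches it.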
Model
• Qwen/Qwen1.5-MoE-A2.7B-Chat
What works
• LLaMA models run successfully with DeepCompile in the same environment.
What fails
• Qwen-series MoE models
• Reproduced with Qwen1.5-MoE-A2.7B-Chat
Full traceback
[rank1]: Traceback (most recent call last):
[rank1]: File "<deepspeed_root>/examples/deepcompile/run_bench_lm.py", line 365, in <module>
[rank1]: main()
[rank1]: File "<deepspeed_root>/examples/deepcompile/run_bench_lm.py", line 288, in main
[rank1]: outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids, use_cache=False)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1766, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "<deepspeed_root>/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank1]: ret_val = func(*args, **kwargs)
[rank1]: File "<deepspeed_root>/deepspeed/runtime/engine.py", line 2237, in forward
[rank1]: loss = self.module(*inputs, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1753, in _wrapped_call_impl
[rank1]: return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 662, in _fn
[rank1]: return fn(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1766, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank1]: return func(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/transformers/models/qwen2_moe/modeling_qwen2_moe.py", line 1317, in forward
[rank1]: outputs = self.model(
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1766, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/transformers/models/qwen2_moe/modeling_qwen2_moe.py", line 974, in forward
[rank1]: inputs_embeds = self.embed_tokens(input_ids)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1766, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/modules/sparse.py", line 190, in forward
[rank1]: return F.embedding(
[rank1]: File "<conda_env>/lib/python3.10/site-packages/torch/nn/functional.py", line 2551, in embedding
[rank1]: return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
[rank1]: RuntimeError: 'weight' must be 2-D
Environment
• Python: 3.10.12
• CUDA: 12.4
• torch: 2.7.0+cu124
• torchaudio: 2.7.0+cu124
• torchvision: 0.22.0+cu124
• transformers: 4.49.0
• deepspeed: 0.18.4+28009dff
• accelerate: 1.13.0
• triton: 3.0.0
• tokenizers: 0.21.4
• safetensors: 0.7.0
• datasets: 3.1.0
• einops: 0.8.1
• huggingface_hub: 0.36.2
• numpy: 1.26.4
Additional context
A few observations:
1. The same setup works for LLaMA models, so this does not look like a generic DeepCompile failure.
2. The failure happens in the embedding layer, not inside the expert computation itself.
3. The failing model is a Qwen MoE model, so this may be related to how DeepCompile handles this architecture under compilation / sharding.
4. The traceback goes through self._compiled_call_impl(...), so the issue appears on the compiled path.
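For reference, the error itself is easy to reproduce in isolation: torch.embedding requires a 2-D weight, and a flattened 1-D table, which is what a sharded/flattened parameter would look like if it were not gathered back before the lookup, raises exactly this message. A minimal sketch, independent of DeepSpeed (the ZeRO-style flattening connection is my assumption, not confirmed):

```python
import torch
import torch.nn.functional as F

weight = torch.randn(10, 4)     # normal 2-D embedding table
ids = torch.tensor([1, 3, 5])
out = F.embedding(ids, weight)  # fine: output is (3, 4)

flat = weight.flatten()         # 1-D, as a flattened/sharded param would be
try:
    F.embedding(ids, flat)
    msg = None
except RuntimeError as e:
    msg = str(e)                # contains "'weight' must be 2-D"
```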