
feat: add Branch-Train-MiX method implementation for expert LLMs#211

Open
tanganke wants to merge 1 commit into main from method/btx

Conversation


@tanganke tanganke commented Mar 23, 2026

Summary by CodeRabbit

  • New Features
    • Added Branch-Train-MiX method for Qwen3 models, enabling users to mix multiple expert language models into a Mixture-of-Experts configuration with configurable expert parameters and routing strategies.


coderabbitai Bot commented Mar 23, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1ef030e4-5ffa-4fcc-8f9f-e17db75944d2

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.



@tanganke tanganke marked this pull request as ready for review March 23, 2026 10:07
Copilot AI review requested due to automatic review settings March 23, 2026 10:07

Copilot AI left a comment


Pull request overview

This PR introduces an initial Branch-Train-MiX implementation scaffold for Qwen3-family LLMs, focusing on mixing multiple Qwen3 expert models into a single Qwen3 MoE (Mixture-of-Experts) model.

Changes:

  • Added Qwen3 → Qwen3-MoE config construction helper and a function to mix expert weights into a Qwen3MoeForCausalLM.
  • Added a package docstring describing the Branch-Train-MiX method and its citation.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.

File                                               Description
fusion_bench/method/branch_train_mix/qwen3.py      Adds Qwen3 MoE config construction and expert-weight mixing logic into a Qwen3 MoE model.
fusion_bench/method/branch_train_mix/__init__.py   Adds module-level documentation for the Branch-Train-MiX method and paper reference.

Comment on lines +86 to +90
def mix_qwen3_models_to_moe(
    base_model: Qwen3ForCausalLM,
    expert_models: list[Qwen3ForCausalLM],
    **kwargs,
) -> Qwen3MoeForCausalLM:

Copilot AI Mar 23, 2026


mix_qwen3_models_to_moe accepts **kwargs but never uses it. Either remove **kwargs or use it (e.g., to accept/override Qwen3MoEArgs fields) so callers don’t think the function supports additional options when it doesn’t.
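One way to address this (a hypothetical sketch, not code from the PR) is to reject unexpected keyword arguments explicitly rather than dropping **kwargs silently:

```python
# Hypothetical sketch: fail fast on unexpected keyword arguments instead of
# silently ignoring them. The signature mirrors the PR; the body is illustrative.
def mix_qwen3_models_to_moe(base_model, expert_models, **kwargs):
    if kwargs:
        # Surface typos and unsupported options immediately.
        raise TypeError(
            "mix_qwen3_models_to_moe() got unexpected keyword arguments: "
            f"{sorted(kwargs)}"
        )
    # ... construct and return the MoE model as before ...
```

With this guard, a caller passing a misspelled option gets an immediate TypeError instead of a silently ignored argument.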

Copilot uses AI. Check for mistakes.
Comment on lines +50 to +54
    """Constructs a Qwen3MoeConfig from a base Qwen3Config by copying relevant parameters and setting the number of experts.

    Args:
        base_config: The base Qwen3Config to copy parameters from.


Copilot AI Mar 23, 2026


construct_qwen3_moe_config_from_base takes moe_args but the docstring doesn’t document it. Please add moe_args to the Args section and clarify which MoE-specific settings come from moe_args vs which are copied from base_config.

Suggested change
-    """Constructs a Qwen3MoeConfig from a base Qwen3Config by copying relevant parameters and setting the number of experts.
-
-    Args:
-        base_config: The base Qwen3Config to copy parameters from.
+    """Constructs a Qwen3MoeConfig from a base Qwen3Config and MoE arguments.
+
+    All standard transformer configuration fields (e.g., vocabulary size, hidden
+    size, number of layers/heads, RoPE parameters, and dropout settings) are
+    copied directly from ``base_config``. MoE-specific settings (such as the
+    number of experts, experts-per-token, MoE intermediate size, sparsity
+    pattern, and router/auxiliary-loss behavior) are taken from ``moe_args``.
+
+    Args:
+        base_config: The base Qwen3Config to copy non-MoE model parameters from.
+        moe_args: MoE-specific configuration to apply on top of ``base_config``,
+            providing values for fields like ``num_experts``, ``num_experts_per_tok``,
+            ``moe_intermediate_size``, ``decoder_sparse_step``, ``norm_topk_prob``,
+            ``output_router_logits``, ``router_aux_loss_coef``, and
+            ``mlp_only_layers``.

Comment on lines +91 to +98
    """Mixes the parameters of a base Qwen3 model and multiple expert Qwen3 models into a single Qwen3MoeForCausalLM model.

    Args:
        base_model: The base Qwen3 model to use as the foundation for the MoE model.
        expert_models: A list of expert Qwen3 models whose parameters will be mixed into the MoE model.

    Returns:
        A Qwen3MoeForCausalLM model with parameters mixed from the base and expert models.

Copilot AI Mar 23, 2026


base_model is only used to read base_model.config; none of its weights are copied into moe_model despite the docstring describing it as the “foundation”. This likely leaves any non-overwritten MoE/router parameters at random init. Consider initializing from base_model weights where possible, or remove/rename base_model and update the docstring to match.

Suggested change
-    """Mixes the parameters of a base Qwen3 model and multiple expert Qwen3 models into a single Qwen3MoeForCausalLM model.
-
-    Args:
-        base_model: The base Qwen3 model to use as the foundation for the MoE model.
-        expert_models: A list of expert Qwen3 models whose parameters will be mixed into the MoE model.
-
-    Returns:
-        A Qwen3MoeForCausalLM model with parameters mixed from the base and expert models.
+    """Construct a Qwen3MoeForCausalLM model using the config from a base Qwen3 model
+    and parameters mixed from multiple expert Qwen3 models.
+
+    Args:
+        base_model: A Qwen3 model whose configuration (``base_model.config``) is used
+            to derive the corresponding MoE configuration. Its weights are not copied
+            into the resulting MoE model.
+        expert_models: A list of expert Qwen3 models whose parameters are averaged
+            into the corresponding non-MoE submodules of the MoE model (e.g. embeddings,
+            attention, layer norms, and, when applicable, MLP experts).
+
+    Returns:
+        A Qwen3MoeForCausalLM model whose architecture is determined by ``base_model``
+        via its config, and whose non-MoE parameters are initialized from
+        ``expert_models`` via simple averaging. MoE-specific/router parameters that
+        are not explicitly set remain at their default initialization.

Comment on lines +100 to +103
    moe_args = Qwen3MoEArgs(num_experts=len(expert_models))

    # Construct the MoE config from the base model's config and the provided MoE arguments
    moe_config = construct_qwen3_moe_config_from_base(base_model.config, moe_args)

Copilot AI Mar 23, 2026


There’s no guard for expert_models being empty. With an empty list, simple_average will assert and the later torch.stack(...) will fail. Add an explicit validation early (with a clear error message) so failures are predictable.
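A minimal sketch of that guard (hypothetical; the function name mirrors the PR, the body is illustrative):

```python
# Hypothetical sketch: validate expert_models up front so failures are
# predictable, rather than surfacing later in simple_average or torch.stack.
def mix_qwen3_models_to_moe(base_model, expert_models):
    if not expert_models:
        raise ValueError(
            "expert_models must contain at least one Qwen3 model; got an empty list"
        )
    # ... proceed with MoE construction ...
```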

Comment on lines +109 to +112
    def _average_expert_parameters(module_name):
        print(f"Averaging parameters for module: {module_name}")
        base_module = moe_model.get_submodule(module_name)
        expert_modules = [

Copilot AI Mar 23, 2026


_average_expert_parameters uses print(...) for per-module progress. For large models this will be very noisy and slow, and it bypasses the project’s logging configuration. Please switch to logging.getLogger(__name__) (or gate output behind a verbosity flag).

Comment on lines +86 to +90
def mix_qwen3_models_to_moe(
    base_model: Qwen3ForCausalLM,
    expert_models: list[Qwen3ForCausalLM],
    **kwargs,
) -> Qwen3MoeForCausalLM:

Copilot AI Mar 23, 2026


New model-mixing behavior is introduced here but there are no tests. The repo has method-level unit tests (e.g., tests/test_depth_upscaling.py, tests/test_simple_average.py); consider adding a small test that builds minimal Qwen3/Qwen3MoE configs (or mocks modules) and verifies the resulting MoE expert weight tensors have expected shapes and are populated from the experts.

Comment on lines +31 to +35
        mlp_only_layers: Layer indices that use a plain MLP instead of a MoE block.
            If ``None``, defaults to an empty list (use ``decoder_sparse_step`` to
            determine sparsity).
    """


Copilot AI Mar 23, 2026


Docstring says mlp_only_layers defaults to an empty list when None, but the dataclass leaves it as None and passes that through to Qwen3MoeConfig. Either default it to field(default_factory=list) or adjust the docstring so it matches the actual behavior.
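If the docstring's behavior is the intended one, the dataclass can default the field with a factory. A sketch (only illustrative fields shown; the real dataclass has more):

```python
# Sketch of aligning the dataclass with its docstring: mutable defaults in
# dataclasses require default_factory. Field subset shown for illustration only.
from dataclasses import dataclass, field

@dataclass
class Qwen3MoEArgs:
    num_experts: int = 8
    mlp_only_layers: list[int] = field(default_factory=list)
```

Each instance then gets its own fresh list, so mutating one instance's mlp_only_layers cannot leak into another, and the value passed through to Qwen3MoeConfig is always a list rather than None.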


coderabbitai Bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (2)
fusion_bench/method/branch_train_mix/__init__.py (1)

1-5: Module lacks re-exports and registry entry for lazy import system.

The __init__.py only contains a docstring but doesn't re-export the public API (Qwen3MoEArgs, construct_qwen3_moe_config_from_base, mix_qwen3_models_to_moe) from qwen3.py. Additionally, branch_train_mix is not registered in fusion_bench/method/__init__.py's _import_structure dictionary.

To make this module accessible through the framework's lazy import system:

  1. Add exports here using LazyImporter (per coding guidelines to avoid heavy imports at module level)
  2. Update fusion_bench/method/__init__.py to include branch_train_mix in _import_structure

Based on learnings: "Update fusion_bench/method/__init__.py with the import structure when adding new fusion algorithms to make them available through the lazy import system."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fusion_bench/method/branch_train_mix/__init__.py` around lines 1 - 5, The
module only has a docstring and must re-export its public API via the
lazy-import system: in fusion_bench/method/branch_train_mix/__init__.py use the
project's LazyImporter pattern to expose Qwen3MoEArgs,
construct_qwen3_moe_config_from_base, and mix_qwen3_models_to_moe from qwen3.py
(so imports are deferred), and then add "branch_train_mix": ["Qwen3MoEArgs",
"construct_qwen3_moe_config_from_base", "mix_qwen3_models_to_moe"] to the
_import_structure in fusion_bench/method/__init__.py so the package registry can
lazily import this module.
fusion_bench/method/branch_train_mix/qwen3.py (1)

108-131: Replace print() with proper logging; verify rotary_emb averaging is needed.

  1. Line 110: Using print() for status output is not ideal for a library. Consider using Python's logging module or the framework's logging utilities.

  2. Line 124: model.rotary_emb (RoPE embeddings) typically contains computed position embeddings based on configuration, not trainable parameters. Averaging may be unnecessary or could cause issues if the configs differ.

♻️ Suggested logging fix
+import logging
+
+log = logging.getLogger(__name__)
+
 def mix_qwen3_models_to_moe(...):
     ...
     def _average_expert_parameters(module_name):
-        print(f"Averaging parameters for module: {module_name}")
+        log.info(f"Averaging parameters for module: {module_name}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fusion_bench/method/branch_train_mix/qwen3.py` around lines 108 - 131,
Replace the ad-hoc print in _average_expert_parameters with the package logger
(use logging.getLogger(__name__) or the project logger) and log at an
appropriate level (info/debug) so status messages integrate with the app’s
logging; locate _average_expert_parameters, moe_model, expert_models and
simple_average to update the call site. Also do not blindly average
"model.rotary_emb": before calling simple_average for the "model.rotary_emb"
module, check that moe_model.get_submodule("model.rotary_emb") and each expert
get_submodule have trainable parameters (or a flag like requires_grad) and only
call simple_average when parameters exist and configs match; otherwise skip
logging that you skipped it. Ensure the loop that uses
moe_config.num_hidden_layers still averages attention and layernorm modules as
before.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 2e547f8f-ab8c-4bb4-b109-b7b82dd1c956

📥 Commits

Reviewing files that changed from the base of the PR and between 9a8df3c and a90a8d1.

📒 Files selected for processing (2)
  • fusion_bench/method/branch_train_mix/__init__.py
  • fusion_bench/method/branch_train_mix/qwen3.py

Comment on lines +1 to +16
from dataclasses import dataclass
from typing import cast

import torch
from transformers import (
    Qwen3Config,
    Qwen3ForCausalLM,
    Qwen3MoeConfig,
    Qwen3MoeForCausalLM,
)
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
    Qwen3MoeExperts,
    Qwen3MoeMLP,
)

from fusion_bench.method.simple_average import simple_average


🛠️ Refactor suggestion | 🟠 Major

Move heavy imports inside function bodies per coding guidelines.

PyTorch and Transformers are imported at module level, which violates the project's coding guidelines. These should be deferred to function/method bodies to avoid loading heavy dependencies when the module is imported.

As per coding guidelines: "Avoid importing PyTorch and Transformers at module level in method, modelpool, and taskpool implementations; defer imports to function/method bodies."

♻️ Suggested refactor
 from dataclasses import dataclass
 from typing import cast
+from typing import TYPE_CHECKING
 
-import torch
-from transformers import (
-    Qwen3Config,
-    Qwen3ForCausalLM,
-    Qwen3MoeConfig,
-    Qwen3MoeForCausalLM,
-)
-from transformers.models.qwen3_moe.modeling_qwen3_moe import (
-    Qwen3MoeExperts,
-    Qwen3MoeMLP,
-)
+if TYPE_CHECKING:
+    from transformers import (
+        Qwen3Config,
+        Qwen3ForCausalLM,
+        Qwen3MoeConfig,
+        Qwen3MoeForCausalLM,
+    )
 
 from fusion_bench.method.simple_average import simple_average

Then add the imports at the start of construct_qwen3_moe_config_from_base and mix_qwen3_models_to_moe:

def construct_qwen3_moe_config_from_base(...):
    from transformers import Qwen3MoeConfig
    # ... rest of function

def mix_qwen3_models_to_moe(...):
    import torch
    from transformers import Qwen3MoeForCausalLM
    from transformers.models.qwen3_moe.modeling_qwen3_moe import (
        Qwen3MoeExperts,
        Qwen3MoeMLP,
    )
    # ... rest of function
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fusion_bench/method/branch_train_mix/qwen3.py` around lines 1 - 16, Top-level
heavy imports (torch, transformers and Qwen3*/Qwen3Moe* classes) should be moved
into the functions to avoid importing heavy dependencies on module import;
remove module-level imports of torch, Qwen3Config, Qwen3ForCausalLM,
Qwen3MoeConfig, Qwen3MoeForCausalLM, Qwen3MoeExperts, and Qwen3MoeMLP and
instead add the specific imports at the start of
construct_qwen3_moe_config_from_base (import Qwen3MoeConfig) and at the start of
mix_qwen3_models_to_moe (import torch, Qwen3MoeForCausalLM and the
Qwen3MoeExperts/Qwen3MoeMLP symbols) so the heavy libraries are only loaded when
those functions are executed.

Comment on lines +46 to +83
def construct_qwen3_moe_config_from_base(
    base_config: Qwen3Config,
    moe_args: Qwen3MoEArgs,
) -> Qwen3MoeConfig:
    """Constructs a Qwen3MoeConfig from a base Qwen3Config by copying relevant parameters and setting the number of experts.

    Args:
        base_config: The base Qwen3Config to copy parameters from.

    """
    return Qwen3MoeConfig(
        vocab_size=base_config.vocab_size,
        hidden_size=base_config.hidden_size,
        intermediate_size=base_config.intermediate_size,
        num_hidden_layers=base_config.num_hidden_layers,
        num_attention_heads=base_config.num_attention_heads,
        num_key_value_heads=base_config.num_key_value_heads,
        hidden_act=base_config.hidden_act,
        max_position_embeddings=base_config.max_position_embeddings,
        initializer_range=base_config.initializer_range,
        rms_norm_eps=base_config.rms_norm_eps,
        use_cache=base_config.use_cache,
        tie_word_embeddings=base_config.tie_word_embeddings,
        rope_parameters=base_config.rope_parameters,
        attention_bias=base_config.attention_bias,
        use_sliding_window=base_config.use_sliding_window,
        sliding_window=base_config.sliding_window,
        attention_dropout=base_config.attention_dropout,
        # MoE arguments
        num_experts=moe_args.num_experts,
        num_experts_per_tok=moe_args.num_experts_per_tok,
        moe_intermediate_size=moe_args.moe_intermediate_size,
        decoder_sparse_step=moe_args.decoder_sparse_step,
        norm_topk_prob=moe_args.norm_topk_prob,
        output_router_logits=moe_args.output_router_logits,
        router_aux_loss_coef=moe_args.router_aux_loss_coef,
        mlp_only_layers=moe_args.mlp_only_layers,
    )


⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🌐 Web query:

What version of Hugging Face transformers added Qwen3MoeConfig and what attributes does it support?

💡 Result:

Qwen3MoeConfig was added in Hugging Face Transformers version 4.52.2 or later, as evidenced by its documentation availability in v4.52.2 and presence in v5.3.0 (current as of 2026). It is documented across multiple versions including v4.52.2, v5.3.0, and main, with source code links to GitHub at those tags. No earlier versions mention it, and it aligns with the Qwen3 MoE model release. Supported attributes (from official docs) include:

  • vocab_size (int, default 151936): Vocabulary size.
  • hidden_size (int, default 2048): Hidden dimension.
  • intermediate_size (int, default 6144): MLP dimension.
  • num_hidden_layers (int, default 24): Number of layers.
  • num_attention_heads (int): Attention heads.
  • num_key_value_heads (int): KV heads for GQA.
  • num_experts_per_tok (int, default 8): Selected experts per token.
  • num_experts (int, default 128): Total routed experts.
  • mlp_only_layers (list[int], default []): Layers using MLP instead of MoE.
  • moe_intermediate_size (int, default 768): MoE intermediate size.
  • norm_topk_prob (bool): Normalize top-k probabilities.
  • output_router_logits (bool): Output router logits.
  • eos_token_id (int): End-of-sequence token.
  • rope_theta (float, default 10000.0): RoPE scaling.
  • rms_norm_eps (float, default 1e-6): RMSNorm epsilon.
  • use_cache (bool, default True): KV caching.
  • And standard PretrainedConfig attributes like max_position_embeddings (32768), hidden_act ('silu'), initializer_range (0.02), etc.


Fix invalid parameter names passed to Qwen3MoeConfig constructor.

The function passes several attributes that are not supported by Qwen3MoeConfig:

  • Line 69: rope_parameters should be rope_theta (renamed in the Transformers API)
  • Lines 70-73: attention_bias, use_sliding_window, sliding_window, and attention_dropout are not valid Qwen3MoeConfig parameters
  • Line 82: mlp_only_layers expects a list (default []), not None

These will cause a TypeError at runtime. Remove or replace unsupported parameters with their correct equivalents from the official Qwen3MoeConfig API.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fusion_bench/method/branch_train_mix/qwen3.py` around lines 46 - 83, The
Qwen3MoeConfig constructor is receiving invalid/renamed params: replace
rope_parameters with rope_theta (use base_config.rope_theta), remove unsupported
params attention_bias, use_sliding_window, sliding_window, and attention_dropout
from the Qwen3MoeConfig call, and ensure mlp_only_layers is a list (e.g. pass
moe_args.mlp_only_layers or [] if None) in construct_qwen3_moe_config_from_base
so only valid Qwen3MoeConfig fields are supplied.
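One defensive option (a generic sketch, not the PR's code) is to introspect which explicit parameters a config class declares before relying on a hard-coded field list. Note that PretrainedConfig subclasses typically accept **kwargs and store unknown names silently, so unsupported fields may not raise at construction time:

```python
# Generic sketch: list the explicit parameters a class constructor declares,
# so a hard-coded field list can be checked against the installed version.
import inspect

def supported_init_params(cls) -> set[str]:
    params = inspect.signature(cls.__init__).parameters
    return {
        name
        for name, p in params.items()
        if name != "self" and p.kind is not inspect.Parameter.VAR_KEYWORD
    }
```

Checking e.g. `"rope_theta" in supported_init_params(Qwen3MoeConfig)` at runtime would make the version mismatch explicit for whichever Transformers release is installed.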

Comment on lines +86 to +106
def mix_qwen3_models_to_moe(
    base_model: Qwen3ForCausalLM,
    expert_models: list[Qwen3ForCausalLM],
    **kwargs,
) -> Qwen3MoeForCausalLM:
    """Mixes the parameters of a base Qwen3 model and multiple expert Qwen3 models into a single Qwen3MoeForCausalLM model.

    Args:
        base_model: The base Qwen3 model to use as the foundation for the MoE model.
        expert_models: A list of expert Qwen3 models whose parameters will be mixed into the MoE model.

    Returns:
        A Qwen3MoeForCausalLM model with parameters mixed from the base and expert models.
    """
    moe_args = Qwen3MoEArgs(num_experts=len(expert_models))

    # Construct the MoE config from the base model's config and the provided MoE arguments
    moe_config = construct_qwen3_moe_config_from_base(base_model.config, moe_args)

    # Initialize a new MoE model with the constructed config
    moe_model = Qwen3MoeForCausalLM(moe_config)


⚠️ Potential issue | 🟠 Major

Potential dimension mismatch and unused **kwargs.

Two concerns:

  1. **kwargs is unused (line 89) - This silently ignores any additional arguments passed by callers, which could hide bugs.

  2. Hardcoded moe_intermediate_size=768 (line 100) - The function creates Qwen3MoEArgs with default moe_intermediate_size=768, but this may not match the base model's intermediate_size. When copying MLP weights from expert models to the MoE model, mismatched dimensions would cause runtime errors or silent truncation.

Consider accepting moe_args as a parameter or deriving moe_intermediate_size from the base model's config:

🐛 Proposed fix
 def mix_qwen3_models_to_moe(
     base_model: Qwen3ForCausalLM,
     expert_models: list[Qwen3ForCausalLM],
-    **kwargs,
+    moe_args: Qwen3MoEArgs | None = None,
 ) -> Qwen3MoeForCausalLM:
-    moe_args = Qwen3MoEArgs(num_experts=len(expert_models))
+    if moe_args is None:
+        moe_args = Qwen3MoEArgs(
+            num_experts=len(expert_models),
+            moe_intermediate_size=base_model.config.intermediate_size,
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@fusion_bench/method/branch_train_mix/qwen3.py` around lines 86 - 106, The
function mix_qwen3_models_to_moe currently ignores **kwargs and constructs
Qwen3MoEArgs with a hardcoded default moe_intermediate_size which can mismatch
the base model; change the signature to accept an optional moe_args:
Qwen3MoEArgs = None and/or consume **kwargs (or remove **kwargs) so callers can
supply arguments, and when moe_args is None derive moe_intermediate_size from
the base model (e.g., base_model.config.intermediate_size or
base_model.config.hidden_size * appropriate factor) before creating
Qwen3MoEArgs; update the call sites to use construct_qwen3_moe_config_from_base
and Qwen3MoeForCausalLM with this derived or provided moe_args so MLP weight
dimensions align.
