feat: add Branch-Train-MiX method implementation for expert LLMs#211
Conversation
Pull request overview
This PR introduces an initial Branch-Train-MiX implementation scaffold for Qwen3-family LLMs, focusing on mixing multiple Qwen3 expert models into a single Qwen3 MoE (Mixture-of-Experts) model.
Changes:
- Added Qwen3 → Qwen3-MoE config construction helper and a function to mix expert weights into a Qwen3MoeForCausalLM.
- Added a package docstring describing the Branch-Train-MiX method and its citation.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| fusion_bench/method/branch_train_mix/qwen3.py | Adds Qwen3 MoE config construction and expert-weight mixing logic into a Qwen3 MoE model. |
| fusion_bench/method/branch_train_mix/__init__.py | Adds module-level documentation for the Branch-Train-MiX method and paper reference. |
```python
def mix_qwen3_models_to_moe(
    base_model: Qwen3ForCausalLM,
    expert_models: list[Qwen3ForCausalLM],
    **kwargs,
) -> Qwen3MoeForCausalLM:
```
mix_qwen3_models_to_moe accepts **kwargs but never uses it. Either remove **kwargs or use it (e.g., to accept/override Qwen3MoEArgs fields) so callers don’t think the function supports additional options when it doesn’t.
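A minimal sketch of the second option (forwarding the extra arguments into the MoE settings), assuming `Qwen3MoEArgs` accepts the forwarded fields as keyword arguments:

```python
# Inside mix_qwen3_models_to_moe: let callers override MoE settings via **kwargs.
# Sketch only; any key not defined on Qwen3MoEArgs would raise a TypeError here.
moe_args = Qwen3MoEArgs(num_experts=len(expert_models), **kwargs)
```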
| """Constructs a Qwen3MoeConfig from a base Qwen3Config by copying relevant parameters and setting the number of experts. | ||
|
|
||
| Args: | ||
| base_config: The base Qwen3Config to copy parameters from. | ||
|
|
construct_qwen3_moe_config_from_base takes moe_args but the docstring doesn’t document it. Please add moe_args to the Args section and clarify which MoE-specific settings come from moe_args vs which are copied from base_config.
| """Constructs a Qwen3MoeConfig from a base Qwen3Config by copying relevant parameters and setting the number of experts. | |
| Args: | |
| base_config: The base Qwen3Config to copy parameters from. | |
| """Constructs a Qwen3MoeConfig from a base Qwen3Config and MoE arguments. | |
| All standard transformer configuration fields (e.g., vocabulary size, hidden | |
| size, number of layers/heads, RoPE parameters, and dropout settings) are | |
| copied directly from ``base_config``. MoE-specific settings (such as the | |
| number of experts, experts-per-token, MoE intermediate size, sparsity | |
| pattern, and router/auxiliary-loss behavior) are taken from ``moe_args``. | |
| Args: | |
| base_config: The base Qwen3Config to copy non-MoE model parameters from. | |
| moe_args: MoE-specific configuration to apply on top of ``base_config``, | |
| providing values for fields like ``num_experts``, ``num_experts_per_tok``, | |
| ``moe_intermediate_size``, ``decoder_sparse_step``, ``norm_topk_prob``, | |
| ``output_router_logits``, ``router_aux_loss_coef``, and | |
| ``mlp_only_layers``. |
| """Mixes the parameters of a base Qwen3 model and multiple expert Qwen3 models into a single Qwen3MoeForCausalLM model. | ||
|
|
||
| Args: | ||
| base_model: The base Qwen3 model to use as the foundation for the MoE model. | ||
| expert_models: A list of expert Qwen3 models whose parameters will be mixed into the MoE model. | ||
|
|
||
| Returns: | ||
| A Qwen3MoeForCausalLM model with parameters mixed from the base and expert models. |
base_model is only used to read base_model.config; none of its weights are copied into moe_model despite the docstring describing it as the “foundation”. This likely leaves any non-overwritten MoE/router parameters at random init. Consider initializing from base_model weights where possible, or remove/rename base_model and update the docstring to match.
| """Mixes the parameters of a base Qwen3 model and multiple expert Qwen3 models into a single Qwen3MoeForCausalLM model. | |
| Args: | |
| base_model: The base Qwen3 model to use as the foundation for the MoE model. | |
| expert_models: A list of expert Qwen3 models whose parameters will be mixed into the MoE model. | |
| Returns: | |
| A Qwen3MoeForCausalLM model with parameters mixed from the base and expert models. | |
| """Construct a Qwen3MoeForCausalLM model using the config from a base Qwen3 model | |
| and parameters mixed from multiple expert Qwen3 models. | |
| Args: | |
| base_model: A Qwen3 model whose configuration (``base_model.config``) is used | |
| to derive the corresponding MoE configuration. Its weights are not copied | |
| into the resulting MoE model. | |
| expert_models: A list of expert Qwen3 models whose parameters are averaged | |
| into the corresponding non-MoE submodules of the MoE model (e.g. embeddings, | |
| attention, layer norms, and, when applicable, MLP experts). | |
| Returns: | |
| A Qwen3MoeForCausalLM model whose architecture is determined by ``base_model`` | |
| via its config, and whose non-MoE parameters are initialized from | |
| ``expert_models`` via simple averaging. MoE-specific/router parameters that | |
| are not explicitly set remain at their default initialization. |
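If the other resolution is preferred (actually initializing the MoE model from `base_model` rather than only documenting the current behavior), a minimal sketch of the idea, placed right after `moe_model` is constructed:

```python
# Copy every base-model tensor whose name matches a tensor in the MoE model
# (embeddings, attention, norms, lm_head). Dense-MLP weights with no MoE
# counterpart come back as `unexpected`; router/expert weights that still need
# to be filled from the experts come back as `missing`. Sketch only, not the PR's code.
missing, unexpected = moe_model.load_state_dict(base_model.state_dict(), strict=False)
```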
```python
    moe_args = Qwen3MoEArgs(num_experts=len(expert_models))

    # Construct the MoE config from the base model's config and the provided MoE arguments
    moe_config = construct_qwen3_moe_config_from_base(base_model.config, moe_args)
```
There’s no guard for expert_models being empty. With an empty list, simple_average will assert and the later torch.stack(...) will fail. Add an explicit validation early (with a clear error message) so failures are predictable.
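A minimal sketch of such a guard at the top of `mix_qwen3_models_to_moe`:

```python
# Fail fast with a clear message instead of an opaque assert/stack error later.
if not expert_models:
    raise ValueError(
        "mix_qwen3_models_to_moe requires at least one expert model; "
        "got an empty `expert_models` list."
    )
```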
```python
    def _average_expert_parameters(module_name):
        print(f"Averaging parameters for module: {module_name}")
        base_module = moe_model.get_submodule(module_name)
        expert_modules = [
```
_average_expert_parameters uses print(...) for per-module progress. For large models this will be very noisy and slow, and it bypasses the project’s logging configuration. Please switch to logging.getLogger(__name__) (or gate output behind a verbosity flag).
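One possible shape of that change, sketched (the `verbose` flag is an illustrative addition, not something in the PR):

```python
import logging

log = logging.getLogger(__name__)


def _average_expert_parameters(module_name: str, verbose: bool = False) -> None:
    # Emit per-module progress at debug level by default; opt into info-level
    # output with the verbosity flag.
    log.log(
        logging.INFO if verbose else logging.DEBUG,
        "Averaging parameters for module: %s",
        module_name,
    )
    ...
```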
```python
def mix_qwen3_models_to_moe(
    base_model: Qwen3ForCausalLM,
    expert_models: list[Qwen3ForCausalLM],
    **kwargs,
) -> Qwen3MoeForCausalLM:
```
New model-mixing behavior is introduced here but there are no tests. The repo has method-level unit tests (e.g., tests/test_depth_upscaling.py, tests/test_simple_average.py); consider adding a small test that builds minimal Qwen3/Qwen3MoE configs (or mocks modules) and verifies the resulting MoE expert weight tensors have expected shapes and are populated from the experts.
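A rough sketch of such a test, assuming the mixing function runs on tiny configs as intended (the config-construction issues flagged elsewhere in this review would need to be fixed first); the file name, config values, and assertions are illustrative:

```python
import torch
from transformers import Qwen3Config, Qwen3ForCausalLM

from fusion_bench.method.branch_train_mix.qwen3 import mix_qwen3_models_to_moe


def test_mix_qwen3_models_to_moe_populates_experts():
    # Tiny config so the test stays fast; values are illustrative only.
    config = Qwen3Config(
        vocab_size=64,
        hidden_size=16,
        intermediate_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        num_key_value_heads=1,
        head_dim=8,
    )
    base = Qwen3ForCausalLM(config)
    experts = [Qwen3ForCausalLM(config) for _ in range(2)]

    moe_model = mix_qwen3_models_to_moe(base, experts)

    # One routed expert per source model, and MoE expert weights actually exist.
    assert moe_model.config.num_experts == 2
    assert any("experts" in name for name, _ in moe_model.named_parameters())

    # Shared (non-MoE) weights should be the simple average of the expert weights.
    avg_embed = torch.stack([e.model.embed_tokens.weight for e in experts]).mean(dim=0)
    assert torch.allclose(moe_model.model.embed_tokens.weight, avg_embed)
```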
```python
        mlp_only_layers: Layer indices that use a plain MLP instead of a MoE block.
            If ``None``, defaults to an empty list (use ``decoder_sparse_step`` to
            determine sparsity).
    """
```
Docstring says mlp_only_layers defaults to an empty list when None, but the dataclass leaves it as None and passes that through to Qwen3MoeConfig. Either default it to field(default_factory=list) or adjust the docstring so it matches the actual behavior.
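A sketch of the first option (other `Qwen3MoEArgs` fields elided; the `num_experts` default shown is illustrative):

```python
from dataclasses import dataclass, field


@dataclass
class Qwen3MoEArgs:
    num_experts: int = 8  # illustrative; the PR's actual defaults may differ
    # Mutable default handled via default_factory so the docstring's
    # "defaults to an empty list" claim matches the runtime behavior.
    mlp_only_layers: list[int] = field(default_factory=list)
```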
Actionable comments posted: 3
🧹 Nitpick comments (2)
fusion_bench/method/branch_train_mix/__init__.py (1)
1-5: Module lacks re-exports and registry entry for lazy import system.

The `__init__.py` only contains a docstring but doesn't re-export the public API (`Qwen3MoEArgs`, `construct_qwen3_moe_config_from_base`, `mix_qwen3_models_to_moe`) from `qwen3.py`. Additionally, `branch_train_mix` is not registered in `fusion_bench/method/__init__.py`'s `_import_structure` dictionary.

To make this module accessible through the framework's lazy import system:

- Add exports here using `LazyImporter` (per coding guidelines to avoid heavy imports at module level)
- Update `fusion_bench/method/__init__.py` to include `branch_train_mix` in `_import_structure`

Based on learnings: "Update `fusion_bench/method/__init__.py` with the import structure when adding new fusion algorithms to make them available through the lazy import system."
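For reference, the registry entry this comment asks for would look roughly like the following; only the new entry is shown and the rest of the dictionary in `fusion_bench/method/__init__.py` is elided:

```python
_import_structure = {
    # ... existing method entries ...
    "branch_train_mix": [
        "Qwen3MoEArgs",
        "construct_qwen3_moe_config_from_base",
        "mix_qwen3_models_to_moe",
    ],
}
```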
🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@fusion_bench/method/branch_train_mix/__init__.py` around lines 1 - 5, The module only has a docstring and must re-export its public API via the lazy-import system: in fusion_bench/method/branch_train_mix/__init__.py use the project's LazyImporter pattern to expose Qwen3MoEArgs, construct_qwen3_moe_config_from_base, and mix_qwen3_models_to_moe from qwen3.py (so imports are deferred), and then add "branch_train_mix": ["Qwen3MoEArgs", "construct_qwen3_moe_config_from_base", "mix_qwen3_models_to_moe"] to the _import_structure in fusion_bench/method/__init__.py so the package registry can lazily import this module.

fusion_bench/method/branch_train_mix/qwen3.py (1)
108-131: Replace `print()` with proper logging; verify `rotary_emb` averaging is needed.

Line 110: Using `print()` for status output is not ideal for a library. Consider using Python's `logging` module or the framework's logging utilities.

Line 124: `model.rotary_emb` (RoPE embeddings) typically contains computed position embeddings based on configuration, not trainable parameters. Averaging may be unnecessary or could cause issues if the configs differ.

♻️ Suggested logging fix

```diff
+import logging
+
+log = logging.getLogger(__name__)
+
 def mix_qwen3_models_to_moe(...):
     ...
     def _average_expert_parameters(module_name):
-        print(f"Averaging parameters for module: {module_name}")
+        log.info(f"Averaging parameters for module: {module_name}")
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@fusion_bench/method/branch_train_mix/qwen3.py` around lines 108 - 131, Replace the ad-hoc print in _average_expert_parameters with the package logger (use logging.getLogger(__name__) or the project logger) and log at an appropriate level (info/debug) so status messages integrate with the app’s logging; locate _average_expert_parameters, moe_model, expert_models and simple_average to update the call site. Also do not blindly average "model.rotary_emb": before calling simple_average for the "model.rotary_emb" module, check that moe_model.get_submodule("model.rotary_emb") and each expert get_submodule have trainable parameters (or a flag like requires_grad) and only call simple_average when parameters exist and configs match; otherwise skip logging that you skipped it. Ensure the loop that uses moe_config.num_hidden_layers still averages attention and layernorm modules as before.
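A sketch of the guarded `model.rotary_emb` handling this prompt describes (`log` follows the suggested logging fix above; whether `rotary_emb` actually exposes trainable parameters depends on the transformers version):

```python
# Only average model.rotary_emb if it has trainable parameters;
# otherwise skip it and note the skip in the log.
rotary = moe_model.get_submodule("model.rotary_emb")
if any(p.requires_grad for p in rotary.parameters()):
    _average_expert_parameters("model.rotary_emb")
else:
    log.debug("Skipping model.rotary_emb: no trainable parameters to average.")
```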
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@fusion_bench/method/branch_train_mix/qwen3.py`:
- Around line 86-106: The function mix_qwen3_models_to_moe currently ignores
**kwargs and constructs Qwen3MoEArgs with a hardcoded default
moe_intermediate_size which can mismatch the base model; change the signature to
accept an optional moe_args: Qwen3MoEArgs = None and/or consume **kwargs (or
remove **kwargs) so callers can supply arguments, and when moe_args is None
derive moe_intermediate_size from the base model (e.g.,
base_model.config.intermediate_size or base_model.config.hidden_size *
appropriate factor) before creating Qwen3MoEArgs; update the call sites to use
construct_qwen3_moe_config_from_base and Qwen3MoeForCausalLM with this derived
or provided moe_args so MLP weight dimensions align.
- Around line 1-16: Top-level heavy imports (torch, transformers and
Qwen3*/Qwen3Moe* classes) should be moved into the functions to avoid importing
heavy dependencies on module import; remove module-level imports of torch,
Qwen3Config, Qwen3ForCausalLM, Qwen3MoeConfig, Qwen3MoeForCausalLM,
Qwen3MoeExperts, and Qwen3MoeMLP and instead add the specific imports at the
start of construct_qwen3_moe_config_from_base (import Qwen3MoeConfig) and at the
start of mix_qwen3_models_to_moe (import torch, Qwen3MoeForCausalLM and the
Qwen3MoeExperts/Qwen3MoeMLP symbols) so the heavy libraries are only loaded when
those functions are executed.
- Around line 46-83: The Qwen3MoeConfig constructor is receiving invalid/renamed
params: replace rope_parameters with rope_theta (use base_config.rope_theta),
remove unsupported params attention_bias, use_sliding_window, sliding_window,
and attention_dropout from the Qwen3MoeConfig call, and ensure mlp_only_layers
is a list (e.g. pass moe_args.mlp_only_layers or [] if None) in
construct_qwen3_moe_config_from_base so only valid Qwen3MoeConfig fields are
supplied.
---
Nitpick comments:
In `@fusion_bench/method/branch_train_mix/__init__.py`:
- Around line 1-5: The module only has a docstring and must re-export its public
API via the lazy-import system: in
fusion_bench/method/branch_train_mix/__init__.py use the project's LazyImporter
pattern to expose Qwen3MoEArgs, construct_qwen3_moe_config_from_base, and
mix_qwen3_models_to_moe from qwen3.py (so imports are deferred), and then add
"branch_train_mix": ["Qwen3MoEArgs", "construct_qwen3_moe_config_from_base",
"mix_qwen3_models_to_moe"] to the _import_structure in
fusion_bench/method/__init__.py so the package registry can lazily import this
module.
In `@fusion_bench/method/branch_train_mix/qwen3.py`:
- Around line 108-131: Replace the ad-hoc print in _average_expert_parameters
with the package logger (use logging.getLogger(__name__) or the project logger)
and log at an appropriate level (info/debug) so status messages integrate with
the app’s logging; locate _average_expert_parameters, moe_model, expert_models
and simple_average to update the call site. Also do not blindly average
"model.rotary_emb": before calling simple_average for the "model.rotary_emb"
module, check that moe_model.get_submodule("model.rotary_emb") and each expert
get_submodule have trainable parameters (or a flag like requires_grad) and only
call simple_average when parameters exist and configs match; otherwise skip
logging that you skipped it. Ensure the loop that uses
moe_config.num_hidden_layers still averages attention and layernorm modules as
before.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2e547f8f-ab8c-4bb4-b109-b7b82dd1c956
📒 Files selected for processing (2)
fusion_bench/method/branch_train_mix/__init__.py
fusion_bench/method/branch_train_mix/qwen3.py
```python
from dataclasses import dataclass
from typing import cast

import torch
from transformers import (
    Qwen3Config,
    Qwen3ForCausalLM,
    Qwen3MoeConfig,
    Qwen3MoeForCausalLM,
)
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
    Qwen3MoeExperts,
    Qwen3MoeMLP,
)

from fusion_bench.method.simple_average import simple_average
```
🛠️ Refactor suggestion | 🟠 Major
Move heavy imports inside function bodies per coding guidelines.
PyTorch and Transformers are imported at module level, which violates the project's coding guidelines. These should be deferred to function/method bodies to avoid loading heavy dependencies when the module is imported.
As per coding guidelines: "Avoid importing PyTorch and Transformers at module level in method, modelpool, and taskpool implementations; defer imports to function/method bodies."
♻️ Suggested refactor
```diff
 from dataclasses import dataclass
 from typing import cast
+from typing import TYPE_CHECKING

-import torch
-from transformers import (
-    Qwen3Config,
-    Qwen3ForCausalLM,
-    Qwen3MoeConfig,
-    Qwen3MoeForCausalLM,
-)
-from transformers.models.qwen3_moe.modeling_qwen3_moe import (
-    Qwen3MoeExperts,
-    Qwen3MoeMLP,
-)
+if TYPE_CHECKING:
+    from transformers import (
+        Qwen3Config,
+        Qwen3ForCausalLM,
+        Qwen3MoeConfig,
+        Qwen3MoeForCausalLM,
+    )

 from fusion_bench.method.simple_average import simple_average
```

Then add the imports at the start of construct_qwen3_moe_config_from_base and mix_qwen3_models_to_moe:
```python
def construct_qwen3_moe_config_from_base(...):
    from transformers import Qwen3MoeConfig

    # ... rest of function


def mix_qwen3_models_to_moe(...):
    import torch
    from transformers import Qwen3MoeForCausalLM
    from transformers.models.qwen3_moe.modeling_qwen3_moe import (
        Qwen3MoeExperts,
        Qwen3MoeMLP,
    )

    # ... rest of function
```

🤖 Prompt for AI Agents
# ... rest of function🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fusion_bench/method/branch_train_mix/qwen3.py` around lines 1 - 16, Top-level
heavy imports (torch, transformers and Qwen3*/Qwen3Moe* classes) should be moved
into the functions to avoid importing heavy dependencies on module import;
remove module-level imports of torch, Qwen3Config, Qwen3ForCausalLM,
Qwen3MoeConfig, Qwen3MoeForCausalLM, Qwen3MoeExperts, and Qwen3MoeMLP and
instead add the specific imports at the start of
construct_qwen3_moe_config_from_base (import Qwen3MoeConfig) and at the start of
mix_qwen3_models_to_moe (import torch, Qwen3MoeForCausalLM and the
Qwen3MoeExperts/Qwen3MoeMLP symbols) so the heavy libraries are only loaded when
those functions are executed.
```python
def construct_qwen3_moe_config_from_base(
    base_config: Qwen3Config,
    moe_args: Qwen3MoEArgs,
) -> Qwen3MoeConfig:
    """Constructs a Qwen3MoeConfig from a base Qwen3Config by copying relevant parameters and setting the number of experts.

    Args:
        base_config: The base Qwen3Config to copy parameters from.

    """
    return Qwen3MoeConfig(
        vocab_size=base_config.vocab_size,
        hidden_size=base_config.hidden_size,
        intermediate_size=base_config.intermediate_size,
        num_hidden_layers=base_config.num_hidden_layers,
        num_attention_heads=base_config.num_attention_heads,
        num_key_value_heads=base_config.num_key_value_heads,
        hidden_act=base_config.hidden_act,
        max_position_embeddings=base_config.max_position_embeddings,
        initializer_range=base_config.initializer_range,
        rms_norm_eps=base_config.rms_norm_eps,
        use_cache=base_config.use_cache,
        tie_word_embeddings=base_config.tie_word_embeddings,
        rope_parameters=base_config.rope_parameters,
        attention_bias=base_config.attention_bias,
        use_sliding_window=base_config.use_sliding_window,
        sliding_window=base_config.sliding_window,
        attention_dropout=base_config.attention_dropout,
        # MoE arguments
        num_experts=moe_args.num_experts,
        num_experts_per_tok=moe_args.num_experts_per_tok,
        moe_intermediate_size=moe_args.moe_intermediate_size,
        decoder_sparse_step=moe_args.decoder_sparse_step,
        norm_topk_prob=moe_args.norm_topk_prob,
        output_router_logits=moe_args.output_router_logits,
        router_aux_loss_coef=moe_args.router_aux_loss_coef,
        mlp_only_layers=moe_args.mlp_only_layers,
    )
```
🧩 Analysis chain
🌐 Web query:
What version of Hugging Face transformers added Qwen3MoeConfig and what attributes does it support?
💡 Result:
Qwen3MoeConfig was added in Hugging Face Transformers version 4.52.2 or later, as evidenced by its documentation availability in v4.52.2 and presence in v5.3.0 (current as of 2026). It is documented across multiple versions including v4.52.2, v5.3.0, and main, with source code links to GitHub at those tags. No earlier versions mention it, and it aligns with the Qwen3 MoE model release.

Supported attributes (from official docs) include:
- vocab_size (int, default 151936): Vocabulary size.
- hidden_size (int, default 2048): Hidden dimension.
- intermediate_size (int, default 6144): MLP dimension.
- num_hidden_layers (int, default 24): Number of layers.
- num_attention_heads (int): Attention heads.
- num_key_value_heads (int): KV heads for GQA.
- num_experts_per_tok (int, default 8): Selected experts per token.
- num_experts (int, default 128): Total routed experts.
- mlp_only_layers (list[int], default []): Layers using MLP instead of MoE.
- moe_intermediate_size (int, default 768): MoE intermediate size.
- norm_topk_prob (bool): Normalize top-k probabilities.
- output_router_logits (bool): Output router logits.
- eos_token_id (int): End-of-sequence token.
- rope_theta (float, default 10000.0): RoPE scaling.
- rms_norm_eps (float, default 1e-6): RMSNorm epsilon.
- use_cache (bool, default True): KV caching.
- Standard PretrainedConfig attributes like max_position_embeddings (32768), hidden_act ('silu'), initializer_range (0.02), etc.
Citations:
- 1: https://huggingface.co/docs/transformers/v5.3.0/model_doc/qwen3_moe
- 2: https://huggingface.co/docs/transformers/en/model_doc/qwen3_moe
- 3: https://huggingface.co/docs/transformers/model_doc/qwen3_moe
- 4: https://huggingface.co/docs/transformers/v4.52.2/model_doc/qwen3_moe
- 5: https://www.huggingface.co
- 6: https://huggingface.co/docs/transformers/model_doc/qwen3_next
Fix invalid parameter names passed to Qwen3MoeConfig constructor.
The function passes several attributes that are not supported by Qwen3MoeConfig:
- Line 69: `rope_parameters` should be `rope_theta` (renamed in the Transformers API)
- Lines 70-73: `attention_bias`, `use_sliding_window`, `sliding_window`, and `attention_dropout` are not valid `Qwen3MoeConfig` parameters
- Line 82: `mlp_only_layers` expects a list (default `[]`), not `None`
These will cause a TypeError at runtime. Remove or replace unsupported parameters with their correct equivalents from the official Qwen3MoeConfig API.
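A sketch of the corrected constructor call, assuming the attribute list from the web-query result above; fields not supported by `Qwen3MoeConfig` are simply dropped:

```python
return Qwen3MoeConfig(
    vocab_size=base_config.vocab_size,
    hidden_size=base_config.hidden_size,
    intermediate_size=base_config.intermediate_size,
    num_hidden_layers=base_config.num_hidden_layers,
    num_attention_heads=base_config.num_attention_heads,
    num_key_value_heads=base_config.num_key_value_heads,
    hidden_act=base_config.hidden_act,
    max_position_embeddings=base_config.max_position_embeddings,
    initializer_range=base_config.initializer_range,
    rms_norm_eps=base_config.rms_norm_eps,
    use_cache=base_config.use_cache,
    tie_word_embeddings=base_config.tie_word_embeddings,
    rope_theta=base_config.rope_theta,  # renamed from rope_parameters
    # MoE arguments
    num_experts=moe_args.num_experts,
    num_experts_per_tok=moe_args.num_experts_per_tok,
    moe_intermediate_size=moe_args.moe_intermediate_size,
    decoder_sparse_step=moe_args.decoder_sparse_step,
    norm_topk_prob=moe_args.norm_topk_prob,
    output_router_logits=moe_args.output_router_logits,
    router_aux_loss_coef=moe_args.router_aux_loss_coef,
    mlp_only_layers=moe_args.mlp_only_layers or [],  # ensure a list, not None
)
```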
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fusion_bench/method/branch_train_mix/qwen3.py` around lines 46 - 83, The
Qwen3MoeConfig constructor is receiving invalid/renamed params: replace
rope_parameters with rope_theta (use base_config.rope_theta), remove unsupported
params attention_bias, use_sliding_window, sliding_window, and attention_dropout
from the Qwen3MoeConfig call, and ensure mlp_only_layers is a list (e.g. pass
moe_args.mlp_only_layers or [] if None) in construct_qwen3_moe_config_from_base
so only valid Qwen3MoeConfig fields are supplied.
```python
def mix_qwen3_models_to_moe(
    base_model: Qwen3ForCausalLM,
    expert_models: list[Qwen3ForCausalLM],
    **kwargs,
) -> Qwen3MoeForCausalLM:
    """Mixes the parameters of a base Qwen3 model and multiple expert Qwen3 models into a single Qwen3MoeForCausalLM model.

    Args:
        base_model: The base Qwen3 model to use as the foundation for the MoE model.
        expert_models: A list of expert Qwen3 models whose parameters will be mixed into the MoE model.

    Returns:
        A Qwen3MoeForCausalLM model with parameters mixed from the base and expert models.
    """
    moe_args = Qwen3MoEArgs(num_experts=len(expert_models))

    # Construct the MoE config from the base model's config and the provided MoE arguments
    moe_config = construct_qwen3_moe_config_from_base(base_model.config, moe_args)

    # Initialize a new MoE model with the constructed config
    moe_model = Qwen3MoeForCausalLM(moe_config)
```
Potential dimension mismatch and unused **kwargs.
Two concerns:
- `**kwargs` is unused (line 89): this silently ignores any additional arguments passed by callers, which could hide bugs.
- Hardcoded `moe_intermediate_size=768` (line 100): the function creates `Qwen3MoEArgs` with the default `moe_intermediate_size=768`, but this may not match the base model's `intermediate_size`. When copying MLP weights from expert models to the MoE model, mismatched dimensions would cause runtime errors or silent truncation.
Consider accepting moe_args as a parameter or deriving moe_intermediate_size from the base model's config:
🐛 Proposed fix
```diff
 def mix_qwen3_models_to_moe(
     base_model: Qwen3ForCausalLM,
     expert_models: list[Qwen3ForCausalLM],
-    **kwargs,
+    moe_args: Qwen3MoEArgs | None = None,
 ) -> Qwen3MoeForCausalLM:
-    moe_args = Qwen3MoEArgs(num_experts=len(expert_models))
+    if moe_args is None:
+        moe_args = Qwen3MoEArgs(
+            num_experts=len(expert_models),
+            moe_intermediate_size=base_model.config.intermediate_size,
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@fusion_bench/method/branch_train_mix/qwen3.py` around lines 86 - 106, The
function mix_qwen3_models_to_moe currently ignores **kwargs and constructs
Qwen3MoEArgs with a hardcoded default moe_intermediate_size which can mismatch
the base model; change the signature to accept an optional moe_args:
Qwen3MoEArgs = None and/or consume **kwargs (or remove **kwargs) so callers can
supply arguments, and when moe_args is None derive moe_intermediate_size from
the base model (e.g., base_model.config.intermediate_size or
base_model.config.hidden_size * appropriate factor) before creating
Qwen3MoEArgs; update the call sites to use construct_qwen3_moe_config_from_base
and Qwen3MoeForCausalLM with this derived or provided moe_args so MLP weight
dimensions align.
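For illustration, a hypothetical call site after the proposed fix (the checkpoint paths are placeholders, not models referenced by this PR):

```python
from transformers import Qwen3ForCausalLM

from fusion_bench.method.branch_train_mix.qwen3 import Qwen3MoEArgs, mix_qwen3_models_to_moe

base = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
experts = [
    Qwen3ForCausalLM.from_pretrained(path)
    for path in ("./experts/math", "./experts/code")  # placeholder expert checkpoints
]
moe_model = mix_qwen3_models_to_moe(
    base,
    experts,
    moe_args=Qwen3MoEArgs(
        num_experts=len(experts),
        moe_intermediate_size=base.config.intermediate_size,
    ),
)
```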