From 26487c243c76bc1b60b8c43594e2264e5d22490e Mon Sep 17 00:00:00 2001 From: Eddie Tsai Date: Tue, 25 Nov 2025 10:54:11 +0800 Subject: [PATCH 01/13] [GLM4MOE] Add support for Liger kernel patches in GLM-4MOE models --- src/liger_kernel/transformers/__init__.py | 3 + .../transformers/model/glm4_moe.py | 153 ++++++++++++++++++ src/liger_kernel/transformers/monkey_patch.py | 77 +++++++++ test/convergence/bf16/test_mini_models.py | 60 +++++++ .../bf16/test_mini_models_with_logits.py | 61 +++++++ test/convergence/fp32/test_mini_models.py | 57 +++++++ .../fp32/test_mini_models_with_logits.py | 57 +++++++ test/transformers/test_monkey_patch.py | 81 ++++++++++ test/utils.py | 10 ++ 9 files changed, 559 insertions(+) create mode 100644 src/liger_kernel/transformers/model/glm4_moe.py diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 30862f02e..5f8e6d594 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -39,6 +39,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4_moe # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401 @@ -109,6 +110,7 @@ def __getattr__(name: str): "apply_liger_kernel_to_gemma3_text", "apply_liger_kernel_to_glm4", "apply_liger_kernel_to_glm4v", + "apply_liger_kernel_to_glm4_moe", "apply_liger_kernel_to_glm4v_moe", "apply_liger_kernel_to_granite", "apply_liger_kernel_to_internvl", @@ -185,6 +187,7 @@ def __getattr__(name: str): "apply_liger_kernel_to_gemma3", "apply_liger_kernel_to_gemma3_text", "apply_liger_kernel_to_glm4", + "apply_liger_kernel_to_glm4_moe", "apply_liger_kernel_to_glm4v", "apply_liger_kernel_to_glm4v_moe", "apply_liger_kernel_to_granite", diff --git a/src/liger_kernel/transformers/model/glm4_moe.py b/src/liger_kernel/transformers/model/glm4_moe.py new file mode 100644 index 000000000..675ddb73c --- /dev/null +++ b/src/liger_kernel/transformers/model/glm4_moe.py @@ -0,0 +1,153 @@ +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.utils.deprecation import deprecate_kwarg + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss + + +@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") +def lce_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Example: + + ```python + >>> from transformers import AutoProcessor, Glm4MoeForCausalLM + >>> import torch + + >>> MODEL_PATH = "meta-glm4_moe/Glm4Moe-2-7b-hf" + >>> messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png" + }, + { + "type": "text", + "text": "describe this image" + } + ], + } + ] + >>> processor = AutoProcessor.from_pretrained(MODEL_PATH) + >>> model = Glm4MoeForCausalLM.from_pretrained( + pretrained_model_name_or_path=MODEL_PATH, + dtype="auto", + device_map="auto", + ) + >>> inputs = processor.apply_chat_template( + messages, + tokenize=True, + add_generation_prompt=True, + return_dict=True, + return_tensors="pt" + ).to(model.device) + >>> inputs.pop("token_type_ids", None) + >>> generated_ids = model.generate(**inputs, max_new_tokens=8192) + >>> output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False) + ``` + """ + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + + if skip_logits and labels is None and shift_labels is None: + raise ValueError("skip_logits is True, but labels and shift_labels are None") + + if skip_logits is None: + # By default, if in training mode, don't materialize logits + skip_logits = self.training and (labels is not None or shift_labels is not None) + + if skip_logits: + loss = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + + else: + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=outputs.rope_deltas, + ) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index df3f74951..164f3e01b 100755 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -2064,6 +2064,83 @@ def apply_liger_kernel_to_glm4( _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False) _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False) +def apply_liger_kernel_to_glm4moe( + rope: bool = False, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace GLM-4MOE models. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is False. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + assert not( cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." + ) + from transformers.models.glm4_moe import modeling_glm4_moe + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeModel + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE + from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4 + from liger_kernel.transformers.model.glm4_moe import lce_forward as glm4_moe_lce_forward + + + if rope: + modeling_glm4_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast + if rms_norm: + modeling_glm4_moe.Glm4MoeRMSNorm = LigerRMSNormForGlm4 + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(glm4_moe_lce_forward, model) + else: + modeling_glm4_moe.Glm4MoeForCausalLM.forward = glm4_moe_lce_forward + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: Glm4MoeModel = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + + for decoder_layer in base_model.layers: + if swiglu: + _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP) + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + # patch MOE layers + if isinstance(Glm4MoeMoE, type) and isinstance(decoder_layer.mlp): + experts = getattr(decoder_layer.mlp, "experts", None) + if experts is not None: + for expert in experts: + _patch_swiglu_module(expert, LigerSwiGLUMLP) + if decoder_layer.mlp.shared_experts is not None: + _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP) + for decoder_layer.mlp in base_model.layers: + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + + def apply_liger_kernel_to_glm4v( rope: bool = False, diff --git a/test/convergence/bf16/test_mini_models.py b/test/convergence/bf16/test_mini_models.py index af528ed1c..7ba74376c 100644 --- a/test/convergence/bf16/test_mini_models.py +++ b/test/convergence/bf16/test_mini_models.py @@ -27,6 +27,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text from liger_kernel.transformers import apply_liger_kernel_to_glm4 +from liger_kernel.transformers import apply_liger_kernel_to_glm4_moe from liger_kernel.transformers import apply_liger_kernel_to_glm4v from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe from liger_kernel.transformers import apply_liger_kernel_to_granite @@ -62,6 +63,7 @@ from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_gemma3_text from test.utils import revert_liger_kernel_to_glm4 +from test.utils import revert_liger_kernel_to_glm4_moe from test.utils import revert_liger_kernel_to_glm4v from test.utils import revert_liger_kernel_to_glm4v_moe from test.utils import revert_liger_kernel_to_granite @@ -214,6 +216,14 @@ except ImportError: GLM4_AVAILABLE = False +try: + from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeForCausalLM + + GLM4_MOE_AVAILABLE = True +except ImportError: + GLM4_MOE_AVAILABLE = False + try: # Glm4v is only available in transformers>=4.51.3 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig @@ -1080,6 +1090,37 @@ ), ) +if GLM4_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe, + model_class=Glm4MoeForCausalLM, + mini_model_config=Glm4MoeConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + rope_scaling=None, + rope_theta=500_000, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + if GLM4V_AVAILABLE: MINI_MODEL_SETUPS["mini_glm4v"] = MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_glm4v, @@ -1829,6 +1870,25 @@ def run_mini_model( ), ], ), + pytest.param( + "mini_glm4_moe", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GLM4_MOE_AVAILABLE, + reason="Glm4_moe not available in this version of transformers", + ), + ], + ), pytest.param( "mini_glm4v", 32, diff --git a/test/convergence/bf16/test_mini_models_with_logits.py b/test/convergence/bf16/test_mini_models_with_logits.py index f1d29b381..78a7c0022 100644 --- a/test/convergence/bf16/test_mini_models_with_logits.py +++ b/test/convergence/bf16/test_mini_models_with_logits.py @@ -27,6 +27,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text from liger_kernel.transformers import apply_liger_kernel_to_glm4 +from liger_kernel.transformers import apply_liger_kernel_to_glm4_moe from liger_kernel.transformers import apply_liger_kernel_to_glm4v from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe from liger_kernel.transformers import apply_liger_kernel_to_granite @@ -62,6 +63,7 @@ from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_gemma3_text from test.utils import revert_liger_kernel_to_glm4 +from test.utils import revert_liger_kernel_to_glm4_moe from test.utils import revert_liger_kernel_to_glm4v from test.utils import revert_liger_kernel_to_glm4v_moe from test.utils import revert_liger_kernel_to_granite @@ -206,6 +208,14 @@ except ImportError: GLM4_AVAILABLE = False +try: + from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeForCausalLM + + GLM4_MOE_AVAILABLE = True +except ImportError: + GLM4_MOE_AVAILABLE = False + try: # Glm4v is only available in transformers>=4.51.3 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig @@ -1075,6 +1085,38 @@ attn_implementation="sdpa", # default value, pytorch native attention ), ) + +if GLM4_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe, + model_class=Glm4MoeForCausalLM, + mini_model_config=Glm4MoeConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + rope_scaling=None, + rope_theta=500_000, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + if GLM4V_AVAILABLE: MINI_MODEL_SETUPS["mini_glm4v"] = MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_glm4v, @@ -1846,6 +1888,25 @@ def run_mini_model( ), ], ), + pytest.param( + "mini_glm4_moe", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GLM4_MOE_AVAILABLE, + reason="Glm4_moe not available in this version of transformers", + ), + ], + ), pytest.param( "mini_glm4v", 32, diff --git a/test/convergence/fp32/test_mini_models.py b/test/convergence/fp32/test_mini_models.py index 263a07323..1b41f2b05 100644 --- a/test/convergence/fp32/test_mini_models.py +++ b/test/convergence/fp32/test_mini_models.py @@ -27,6 +27,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text from liger_kernel.transformers import apply_liger_kernel_to_glm4 +from liger_kernel.transformers import apply_liger_kernel_to_glm4_moe from liger_kernel.transformers import apply_liger_kernel_to_glm4v from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe from liger_kernel.transformers import apply_liger_kernel_to_granite @@ -62,6 +63,7 @@ from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_gemma3_text from test.utils import revert_liger_kernel_to_glm4 +from test.utils import revert_liger_kernel_to_glm4_moe from test.utils import revert_liger_kernel_to_glm4v from test.utils import revert_liger_kernel_to_glm4v_moe from test.utils import revert_liger_kernel_to_granite @@ -202,6 +204,14 @@ except ImportError: GLM4_AVAILABLE = False +try: + from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeForCausalLM + + GLM4_MOE_AVAILABLE = True +except ImportError: + GLM4_MOE_AVAILABLE = False + try: # Glm4v is only available in transformers>=4.51.3 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig @@ -1016,6 +1026,37 @@ ), ) +if GLM4_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe, + model_class=Glm4MoeForCausalLM, + mini_model_config=Glm4MoeConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=32768, + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + rope_scaling=None, + rope_theta=500_000, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + if GLM4V_AVAILABLE: MINI_MODEL_SETUPS["mini_glm4v"] = MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_glm4v, @@ -1721,6 +1762,22 @@ def run_mini_model( reason="Glm4 not available in this version of transformers", ), ), + pytest.param( + "mini_glm4_moe", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GLM4_MOE_AVAILABLE, + reason="Glm4_moe not available in this version of transformers", + ), + ), pytest.param( "mini_glm4v", 32, diff --git a/test/convergence/fp32/test_mini_models_with_logits.py b/test/convergence/fp32/test_mini_models_with_logits.py index f14d13218..3f5ba33d3 100644 --- a/test/convergence/fp32/test_mini_models_with_logits.py +++ b/test/convergence/fp32/test_mini_models_with_logits.py @@ -27,6 +27,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text from liger_kernel.transformers import apply_liger_kernel_to_glm4 +from liger_kernel.transformers import apply_liger_kernel_to_glm4_moe from liger_kernel.transformers import apply_liger_kernel_to_glm4v from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe from liger_kernel.transformers import apply_liger_kernel_to_granite @@ -62,6 +63,7 @@ from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_gemma3_text from test.utils import revert_liger_kernel_to_glm4 +from test.utils import revert_liger_kernel_to_glm4_moe from test.utils import revert_liger_kernel_to_glm4v from test.utils import revert_liger_kernel_to_glm4v_moe from test.utils import revert_liger_kernel_to_granite @@ -226,6 +228,14 @@ except ImportError: GLM4_AVAILABLE = False +try: + from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeForCausalLM + + GLM4_MOE_AVAILABLE = True +except ImportError: + GLM4_MOE_AVAILABLE = False + try: # Glm4v is only available in transformers>=4.51.3 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig @@ -1093,6 +1103,37 @@ ), ) +if GLM4_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe, + model_class=Glm4MoeForCausalLM, + mini_model_config=Glm4MoeConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=0.5, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + rope_scaling=None, + rope_theta=500_000, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + if GLM4V_AVAILABLE: MINI_MODEL_SETUPS["mini_glm4v"] = MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_glm4v, @@ -1725,6 +1766,22 @@ def run_mini_model( reason="Glm4 not available in this version of transformers", ), ), + pytest.param( + "mini_glm4_moe", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GLM4_MOE_AVAILABLE, + reason="Glm4_moe not available in this version of transformers", + ), + ), pytest.param( "mini_glm4v", 32, diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index 51fa02660..b60fa311c 100755 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -160,6 +160,14 @@ def is_glm4_available(): except ImportError: return False +def is_glm4_moe_available(): + try: + import transformers.models.glm4_moe # noqa: F401 + + return True + except ImportError: + return False + def is_glm4v_available(): try: @@ -232,6 +240,7 @@ def test_import_from_root(): from liger_kernel.transformers import apply_liger_kernel_to_gemma3 # noqa: F401 from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text # noqa: F401 from liger_kernel.transformers import apply_liger_kernel_to_glm4 # noqa: F401 + from liger_kernel.transformers import apply_liger_kernel_to_glm4_moe # noqa: F401 from liger_kernel.transformers import apply_liger_kernel_to_glm4v # noqa: F401 from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe # noqa: F401 from liger_kernel.transformers import apply_liger_kernel_to_internvl # noqa: F401 @@ -2579,6 +2588,78 @@ def test_apply_liger_kernel_to_instance_for_glm4(): pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}") +@pytest.mark.skipif(not is_glm4_moe_available(), reason="glm4 module not available") +def test_apply_liger_kernel_to_instance_for_glm4_moe(): + # Ensure any monkey patching is cleaned up for subsequent tests + with patch("transformers.models.glm4_moe.modeling_glm4_moe"): + from liger_kernel.transformers.model.glm4_moe import lce_forward as glm4_moe_lce_forward + from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4 + + # Instantiate a dummy model + config = transformers.models.glm4_moe._moe.Glm4MoeConfig( + dtype=torch.bfloat16, + rms_norm_eps=1e-5, + hidden_size=32, + intermediate_size=64, + hidden_act="silu", + num_hidden_layers=2, + n_routed_experts=1, + n_shared_experts=1, + ) + dummy_model_instance = AutoModelForCausalLM.from_config(config) + + # Check that model instance variables are not yet patched with Liger modules + assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(glm4_moe_lce_forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) + for layer in dummy_model_instance.model.layers: + assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + + for decoder_layer in dummy_model_instance.base_model.layers: + # https://github.com/huggingface/transformers/blob/69f003696b55de75b7f18888c03111909a7cd537/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L438 + if decoder_layer.mlp.experts is not None: + for expert in decoder_layer.mlp.experts: + assert inspect.getsource(expert.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + if decoder_layer.mlp.shared_experts is not None: + assert inspect.getsource(decoder_layer.mlp.shared_experts.forward) != inspect.getsource( + LigerSwiGLUMLP.forward + ) + + # Test applying kernels to the model instance + _apply_liger_kernel_to_instance(model=dummy_model_instance) + + # Check that the model's instance variables were correctly patched with Liger modules + assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(glm4_moe_lce_forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) + for layer in dummy_model_instance.model.layers: + assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + for decoder_layer in dummy_model_instance.base_model.layers: + if decoder_layer.mlp is not None: + assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource( + LigerRMSNormForGlm4.forward + ) + assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource( + LigerRMSNormForGlm4.forward + ) + if getattr(decoder_layer.mlp, "experts", None) is not None: + for expert in decoder_layer.mlp.experts: + assert inspect.getsource(expert.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + if getattr(decoder_layer.mlp, "shared_experts", None) is not None: + assert inspect.getsource(decoder_layer.mlp.shared_experts.forward) == inspect.getsource( + LigerSwiGLUMLP.forward + ) + + + try: + print(dummy_model_instance) + except Exception as e: + pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}") + + @pytest.mark.skipif(not is_glm4v_available(), reason="glm4v module not available") def test_apply_liger_kernel_to_instance_for_glm4v(): # Ensure any monkey patching is cleaned up for subsequent tests diff --git a/test/utils.py b/test/utils.py index e7c84777d..1ddd61069 100644 --- a/test/utils.py +++ b/test/utils.py @@ -573,6 +573,16 @@ def revert_liger_kernel_to_glm4(model_config: MiniModelConfig): model_config.model_class = modeling_glm4.Glm4ForCausalLM print("Liger kernel patches have been reverted.") +def revert_liger_kernel_to_glm4_moe(model_config: MiniModelConfig): + """ + Revert all Liger kernel patches applied to Glm4_moe. + """ + + from transformers.models.glm4_moe import modeling_glm4_moe + + importlib.reload(modeling_glm4_moe) + model_config.model_class = modeling_glm4_moe.Glm4MoeForCausalLM + print("Liger kernel patches have been reverted.") def revert_liger_kernel_to_glm4v(model_config: MiniModelConfig): """ From dac15e9e132bd717286b3cd1d5baaad6788a0ec6 Mon Sep 17 00:00:00 2001 From: Eddie Tsai Date: Tue, 25 Nov 2025 10:57:16 +0800 Subject: [PATCH 02/13] [GLM4MOE] Formatting functions --- src/liger_kernel/transformers/monkey_patch.py | 8 ++++---- test/transformers/test_monkey_patch.py | 4 ++-- test/utils.py | 2 ++ 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 164f3e01b..1deb1465e 100755 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -2064,6 +2064,7 @@ def apply_liger_kernel_to_glm4( _patch_rms_norm_module(decoder_layer.post_self_attn_layernorm, in_place=False) _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False) + def apply_liger_kernel_to_glm4moe( rope: bool = False, cross_entropy: bool = False, @@ -2087,15 +2088,15 @@ def apply_liger_kernel_to_glm4moe( model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been loaded. Default is None. """ - assert not( cross_entropy and fused_linear_cross_entropy), ( + assert not (cross_entropy and fused_linear_cross_entropy), ( "cross_entropy and fused_linear_cross_entropy cannot both be True." ) from transformers.models.glm4_moe import modeling_glm4_moe from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeModel from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE - from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4 - from liger_kernel.transformers.model.glm4_moe import lce_forward as glm4_moe_lce_forward + from liger_kernel.transformers.model.glm4_moe import lce_forward as glm4_moe_lce_forward + from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4 if rope: modeling_glm4_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast @@ -2141,7 +2142,6 @@ def apply_liger_kernel_to_glm4moe( _patch_rms_norm_module(decoder_layer.post_attention_layernorm) - def apply_liger_kernel_to_glm4v( rope: bool = False, cross_entropy: bool = False, diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index b60fa311c..cd254d512 100755 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -160,6 +160,7 @@ def is_glm4_available(): except ImportError: return False + def is_glm4_moe_available(): try: import transformers.models.glm4_moe # noqa: F401 @@ -2615,7 +2616,7 @@ def test_apply_liger_kernel_to_instance_for_glm4_moe(): assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - + for decoder_layer in dummy_model_instance.base_model.layers: # https://github.com/huggingface/transformers/blob/69f003696b55de75b7f18888c03111909a7cd537/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L438 if decoder_layer.mlp.experts is not None: @@ -2653,7 +2654,6 @@ def test_apply_liger_kernel_to_instance_for_glm4_moe(): LigerSwiGLUMLP.forward ) - try: print(dummy_model_instance) except Exception as e: diff --git a/test/utils.py b/test/utils.py index 1ddd61069..f0290416a 100644 --- a/test/utils.py +++ b/test/utils.py @@ -573,6 +573,7 @@ def revert_liger_kernel_to_glm4(model_config: MiniModelConfig): model_config.model_class = modeling_glm4.Glm4ForCausalLM print("Liger kernel patches have been reverted.") + def revert_liger_kernel_to_glm4_moe(model_config: MiniModelConfig): """ Revert all Liger kernel patches applied to Glm4_moe. @@ -584,6 +585,7 @@ def revert_liger_kernel_to_glm4_moe(model_config: MiniModelConfig): model_config.model_class = modeling_glm4_moe.Glm4MoeForCausalLM print("Liger kernel patches have been reverted.") + def revert_liger_kernel_to_glm4v(model_config: MiniModelConfig): """ Revert all Liger kernel patches applied to Glm4v. From 14cfb908f7c7af36c10e56a62f74dfdd694b2e37 Mon Sep 17 00:00:00 2001 From: Eddie Tsai Date: Tue, 25 Nov 2025 11:17:51 +0800 Subject: [PATCH 03/13] Rename function for GLM-4MOE kernel application and update model type mapping --- src/liger_kernel/transformers/monkey_patch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 1deb1465e..fe9de3b95 100755 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -2065,7 +2065,7 @@ def apply_liger_kernel_to_glm4( _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False) -def apply_liger_kernel_to_glm4moe( +def apply_liger_kernel_to_glm4_moe( rope: bool = False, cross_entropy: bool = False, fused_linear_cross_entropy: bool = True, @@ -2827,6 +2827,7 @@ def apply_liger_kernel_to_hunyuan_v1_moe( "gemma3_text": apply_liger_kernel_to_gemma3_text, "gemma3": apply_liger_kernel_to_gemma3, "glm4": apply_liger_kernel_to_glm4, + "glm4_moe": apply_liger_kernel_to_glm4_moe, "glm4v": apply_liger_kernel_to_glm4v, "glm4v_moe": apply_liger_kernel_to_glm4v_moe, "internvl": apply_liger_kernel_to_internvl, From 973e418d1260e69721fbd7066b9074fd9327affc Mon Sep 17 00:00:00 2001 From: Eddie Tsai Date: Tue, 25 Nov 2025 11:24:20 +0800 Subject: [PATCH 04/13] Refactor lce_forward function: update return type and remove deprecated parameters --- .../transformers/model/glm4_moe.py | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/src/liger_kernel/transformers/model/glm4_moe.py b/src/liger_kernel/transformers/model/glm4_moe.py index 675ddb73c..c8ca9bbc5 100644 --- a/src/liger_kernel/transformers/model/glm4_moe.py +++ b/src/liger_kernel/transformers/model/glm4_moe.py @@ -1,13 +1,14 @@ +from typing import List from typing import Optional from typing import Tuple from typing import Union import torch -from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.utils.deprecation import deprecate_kwarg from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") @@ -16,19 +17,18 @@ def lce_forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: +) -> Union[Tuple, LigerCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -94,14 +94,14 @@ def lce_forward( # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, - pixel_values=pixel_values, - pixel_values_videos=pixel_values_videos, - image_grid_thw=image_grid_thw, - video_grid_thw=video_grid_thw, - position_ids=position_ids, attention_mask=attention_mask, + position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, cache_position=cache_position, **kwargs, ) @@ -114,6 +114,7 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits and labels is None and shift_labels is None: raise ValueError("skip_logits is True, but labels and shift_labels are None") @@ -143,11 +144,11 @@ def lce_forward( **kwargs, ) - return CausalLMOutputWithPast( + return LigerCausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - rope_deltas=outputs.rope_deltas, + token_accuracy=token_accuracy, ) From 39e7d180d47e3f138e3e1206cfd5fbb1cdf98339 Mon Sep 17 00:00:00 2001 From: Eddie Tsai Date: Tue, 25 Nov 2025 12:22:48 +0800 Subject: [PATCH 05/13] Fix import path for Glm4MoeConfig in test_apply_liger_kernel_to_instance_for_glm4_moe --- test/transformers/test_monkey_patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index cd254d512..6d07422f5 100755 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -2597,7 +2597,7 @@ def test_apply_liger_kernel_to_instance_for_glm4_moe(): from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4 # Instantiate a dummy model - config = transformers.models.glm4_moe._moe.Glm4MoeConfig( + config = transformers.models.glm4_moe.Glm4MoeConfig( dtype=torch.bfloat16, rms_norm_eps=1e-5, hidden_size=32, From ca272421fde58623ba7eadec9b1bf0b2315fcb35 Mon Sep 17 00:00:00 2001 From: Eddie Tsai Date: Tue, 25 Nov 2025 13:36:56 +0800 Subject: [PATCH 06/13] fix tests --- src/liger_kernel/transformers/monkey_patch.py | 22 +++++----- test/transformers/test_monkey_patch.py | 44 ++++++++++--------- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index fe9de3b95..bf85e5224 100755 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -2128,18 +2128,16 @@ def apply_liger_kernel_to_glm4_moe( if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) - # patch MOE layers - if isinstance(Glm4MoeMoE, type) and isinstance(decoder_layer.mlp): - experts = getattr(decoder_layer.mlp, "experts", None) - if experts is not None: - for expert in experts: - _patch_swiglu_module(expert, LigerSwiGLUMLP) - if decoder_layer.mlp.shared_experts is not None: - _patch_swiglu_module(decoder_layer.mlp.shared_experts, LigerSwiGLUMLP) - for decoder_layer.mlp in base_model.layers: - if rms_norm: - _patch_rms_norm_module(decoder_layer.input_layernorm) - _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + # patch MOE layers + if isinstance(decoder_layer.mlp, Glm4MoeMoE): + experts = decoder_layer.mlp.experts + if experts is not None: + for expert in experts: + _patch_swiglu_module(expert, LigerSwiGLUMLP) + + shared_experts = decoder_layer.mlp.shared_experts + if shared_experts is not None: + _patch_swiglu_module(shared_experts, LigerSwiGLUMLP) def apply_liger_kernel_to_glm4v( diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index 6d07422f5..354c9aa11 100755 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -2606,23 +2606,28 @@ def test_apply_liger_kernel_to_instance_for_glm4_moe(): num_hidden_layers=2, n_routed_experts=1, n_shared_experts=1, + num_attention_heads=4, + num_key_value_heads=4, + head_dim=8, + first_k_dense_replace=1, ) dummy_model_instance = AutoModelForCausalLM.from_config(config) # Check that model instance variables are not yet patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(glm4_moe_lce_forward) assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) - for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) - - for decoder_layer in dummy_model_instance.base_model.layers: + for decoder_layer in dummy_model_instance.model.layers: + assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(decoder_layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) != inspect.getsource( + LigerRMSNorm.forward + ) # https://github.com/huggingface/transformers/blob/69f003696b55de75b7f18888c03111909a7cd537/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L438 - if decoder_layer.mlp.experts is not None: - for expert in decoder_layer.mlp.experts: - assert inspect.getsource(expert.forward) != inspect.getsource(LigerSwiGLUMLP.forward) - if decoder_layer.mlp.shared_experts is not None: + if hasattr(decoder_layer.mlp, "experts") and decoder_layer.mlp.experts is not None: + if getattr(decoder_layer.mlp, "experts", None) is not None: + for expert in decoder_layer.mlp.experts: + assert inspect.getsource(expert.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + if getattr(decoder_layer.mlp, "shared_experts", None) is not None: assert inspect.getsource(decoder_layer.mlp.shared_experts.forward) != inspect.getsource( LigerSwiGLUMLP.forward ) @@ -2633,10 +2638,6 @@ def test_apply_liger_kernel_to_instance_for_glm4_moe(): # Check that the model's instance variables were correctly patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(glm4_moe_lce_forward) assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) - for layer in dummy_model_instance.model.layers: - assert inspect.getsource(layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) - assert inspect.getsource(layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) - assert inspect.getsource(layer.post_attention_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) for decoder_layer in dummy_model_instance.base_model.layers: if decoder_layer.mlp is not None: assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) @@ -2646,13 +2647,14 @@ def test_apply_liger_kernel_to_instance_for_glm4_moe(): assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource( LigerRMSNormForGlm4.forward ) - if getattr(decoder_layer.mlp, "experts", None) is not None: - for expert in decoder_layer.mlp.experts: - assert inspect.getsource(expert.forward) == inspect.getsource(LigerSwiGLUMLP.forward) - if getattr(decoder_layer.mlp, "shared_experts", None) is not None: - assert inspect.getsource(decoder_layer.mlp.shared_experts.forward) == inspect.getsource( - LigerSwiGLUMLP.forward - ) + if hasattr(decoder_layer.mlp, "experts") and decoder_layer.mlp.experts is not None: + if getattr(decoder_layer.mlp, "experts", None) is not None: + for expert in decoder_layer.mlp.experts: + assert inspect.getsource(expert.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + if getattr(decoder_layer.mlp, "shared_experts", None) is not None: + assert inspect.getsource(decoder_layer.mlp.shared_experts.forward) == inspect.getsource( + LigerSwiGLUMLP.forward + ) try: print(dummy_model_instance) From 5af9d164b171cc1a6224217d8ebfd13a6c267cc9 Mon Sep 17 00:00:00 2001 From: Eddie Tsai Date: Tue, 25 Nov 2025 20:42:35 +0800 Subject: [PATCH 07/13] modify to adapt to new API --- src/liger_kernel/transformers/model/glm4_moe.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/liger_kernel/transformers/model/glm4_moe.py b/src/liger_kernel/transformers/model/glm4_moe.py index c8ca9bbc5..f8b70be28 100644 --- a/src/liger_kernel/transformers/model/glm4_moe.py +++ b/src/liger_kernel/transformers/model/glm4_moe.py @@ -8,6 +8,7 @@ from transformers.utils.deprecation import deprecate_kwarg from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast @@ -91,6 +92,12 @@ def lce_forward( ``` """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, @@ -124,7 +131,7 @@ def lce_forward( skip_logits = self.training and (labels is not None or shift_labels is not None) if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -132,6 +139,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: logits = self.lm_head(kept_hidden_states) @@ -144,6 +152,13 @@ def lce_forward( **kwargs, ) + if not return_dict: + output = (logits,) + outputs[1:] + output = ((loss,) + output) if loss is not None else output + output = output + (token_accuracy,) if token_accuracy is not None else output + return output + + # Return custom output class with accuracy field return LigerCausalLMOutputWithPast( loss=loss, logits=logits, From 7f0ebf466e34d7e88d793a5d6cd86cc5a203b4a1 Mon Sep 17 00:00:00 2001 From: Eddie Tsai Date: Mon, 29 Dec 2025 15:40:37 +0800 Subject: [PATCH 08/13] Enhance GLM4-MoE support by adding MLP handling in monkey patching --- src/liger_kernel/transformers/monkey_patch.py | 12 ++++- test/transformers/test_monkey_patch.py | 53 +++++++++---------- 2 files changed, 37 insertions(+), 28 deletions(-) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 6f0ab82db..b3816ef1b 100755 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -2167,6 +2167,7 @@ def apply_liger_kernel_to_glm4_moe( "cross_entropy and fused_linear_cross_entropy cannot both be True." ) from transformers.models.glm4_moe import modeling_glm4_moe + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMLP from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeModel from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE @@ -2199,7 +2200,16 @@ def apply_liger_kernel_to_glm4_moe( for decoder_layer in base_model.layers: if swiglu: - _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP) + if isinstance(decoder_layer.mlp, Glm4MoeMoE): + experts = decoder_layer.mlp.experts + if experts is not None: + for expert in experts: + _patch_swiglu_module(expert, LigerSwiGLUMLP) + shared_experts = decoder_layer.mlp.shared_experts + if shared_experts is not None: + _patch_swiglu_module(shared_experts, LigerSwiGLUMLP) + elif isinstance(decoder_layer.mlp, Glm4MoeMLP): + _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP) if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index 7d7b86876..ffb8cd2c3 100755 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -2591,10 +2591,11 @@ def test_apply_liger_kernel_to_instance_for_glm4(): @pytest.mark.skipif(not is_glm4_moe_available(), reason="glm4 module not available") def test_apply_liger_kernel_to_instance_for_glm4_moe(): + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMLP + # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.glm4_moe.modeling_glm4_moe"): from liger_kernel.transformers.model.glm4_moe import lce_forward as glm4_moe_lce_forward - from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4 # Instantiate a dummy model config = transformers.models.glm4_moe.Glm4MoeConfig( @@ -2617,20 +2618,20 @@ def test_apply_liger_kernel_to_instance_for_glm4_moe(): assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(glm4_moe_lce_forward) assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for decoder_layer in dummy_model_instance.model.layers: - assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + if isinstance(decoder_layer.mlp, Glm4MoeMOE): + experts = decoder_layer.mlp.experts + if experts is not None: + for expert in experts: + assert inspect.getsource(expert.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + shared_experts = decoder_layer.mlp.shared_experts + if shared_experts is not None: + assert inspect.getsource(shared_experts.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + elif isinstance(decoder_layer.mlp, Glm4MoeMLP): + assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) assert inspect.getsource(decoder_layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) != inspect.getsource( LigerRMSNorm.forward ) - # https://github.com/huggingface/transformers/blob/69f003696b55de75b7f18888c03111909a7cd537/src/transformers/models/glm4_moe/modeling_glm4_moe.py#L438 - if hasattr(decoder_layer.mlp, "experts") and decoder_layer.mlp.experts is not None: - if getattr(decoder_layer.mlp, "experts", None) is not None: - for expert in decoder_layer.mlp.experts: - assert inspect.getsource(expert.forward) != inspect.getsource(LigerSwiGLUMLP.forward) - if getattr(decoder_layer.mlp, "shared_experts", None) is not None: - assert inspect.getsource(decoder_layer.mlp.shared_experts.forward) != inspect.getsource( - LigerSwiGLUMLP.forward - ) # Test applying kernels to the model instance _apply_liger_kernel_to_instance(model=dummy_model_instance) @@ -2638,23 +2639,21 @@ def test_apply_liger_kernel_to_instance_for_glm4_moe(): # Check that the model's instance variables were correctly patched with Liger modules assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(glm4_moe_lce_forward) assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) - for decoder_layer in dummy_model_instance.base_model.layers: - if decoder_layer.mlp is not None: - assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) - assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource( - LigerRMSNormForGlm4.forward - ) - assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource( - LigerRMSNormForGlm4.forward - ) - if hasattr(decoder_layer.mlp, "experts") and decoder_layer.mlp.experts is not None: - if getattr(decoder_layer.mlp, "experts", None) is not None: - for expert in decoder_layer.mlp.experts: + for decoder_layer in dummy_model_instance.model.layers: + if isinstance(decoder_layer.mlp, Glm4MoeMOE): + experts = decoder_layer.mlp.experts + if experts is not None: + for expert in experts: assert inspect.getsource(expert.forward) == inspect.getsource(LigerSwiGLUMLP.forward) - if getattr(decoder_layer.mlp, "shared_experts", None) is not None: - assert inspect.getsource(decoder_layer.mlp.shared_experts.forward) == inspect.getsource( - LigerSwiGLUMLP.forward - ) + shared_experts = decoder_layer.mlp.shared_experts + if shared_experts is not None: + assert inspect.getsource(shared_experts.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + elif isinstance(decoder_layer.mlp, Glm4MoeMLP): + assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource( + LigerRMSNorm.forward + ) try: print(dummy_model_instance) From 5cf027e38656f0d2da3588107650e36804bbec49 Mon Sep 17 00:00:00 2001 From: Eddie Tsai Date: Mon, 29 Dec 2025 16:32:05 +0800 Subject: [PATCH 09/13] fix typo --- test/transformers/test_monkey_patch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index ffb8cd2c3..427f6533d 100755 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -2592,6 +2592,7 @@ def test_apply_liger_kernel_to_instance_for_glm4(): @pytest.mark.skipif(not is_glm4_moe_available(), reason="glm4 module not available") def test_apply_liger_kernel_to_instance_for_glm4_moe(): from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMLP + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMOE # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.glm4_moe.modeling_glm4_moe"): From 876d0750d65efca291db8b2096f4b3a882c74601 Mon Sep 17 00:00:00 2001 From: Eddie Tsai Date: Mon, 29 Dec 2025 16:40:25 +0800 Subject: [PATCH 10/13] fix typo --- test/transformers/test_monkey_patch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index 427f6533d..c0d92d7ae 100755 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -2619,7 +2619,7 @@ def test_apply_liger_kernel_to_instance_for_glm4_moe(): assert inspect.getsource(dummy_model_instance.forward) != inspect.getsource(glm4_moe_lce_forward) assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) for decoder_layer in dummy_model_instance.model.layers: - if isinstance(decoder_layer.mlp, Glm4MoeMOE): + if isinstance(decoder_layer.mlp, Glm4MoeMoE): experts = decoder_layer.mlp.experts if experts is not None: for expert in experts: @@ -2641,7 +2641,7 @@ def test_apply_liger_kernel_to_instance_for_glm4_moe(): assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(glm4_moe_lce_forward) assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) for decoder_layer in dummy_model_instance.model.layers: - if isinstance(decoder_layer.mlp, Glm4MoeMOE): + if isinstance(decoder_layer.mlp, Glm4MoeMoE): experts = decoder_layer.mlp.experts if experts is not None: for expert in experts: From cf9f21243f8c42b8581e19b0474efe51f335f03b Mon Sep 17 00:00:00 2001 From: Eddie Tsai Date: Tue, 6 Jan 2026 15:57:56 +0800 Subject: [PATCH 11/13] fix: update Glm4Moe import and clean up MOE layer patching in monkey_patch.py --- src/liger_kernel/transformers/monkey_patch.py | 11 +---------- test/transformers/test_monkey_patch.py | 2 +- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index b3816ef1b..6ac71aba8 100755 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -2200,6 +2200,7 @@ def apply_liger_kernel_to_glm4_moe( for decoder_layer in base_model.layers: if swiglu: + # patch moe layer if isinstance(decoder_layer.mlp, Glm4MoeMoE): experts = decoder_layer.mlp.experts if experts is not None: @@ -2213,16 +2214,6 @@ def apply_liger_kernel_to_glm4_moe( if rms_norm: _patch_rms_norm_module(decoder_layer.input_layernorm) _patch_rms_norm_module(decoder_layer.post_attention_layernorm) - # patch MOE layers - if isinstance(decoder_layer.mlp, Glm4MoeMoE): - experts = decoder_layer.mlp.experts - if experts is not None: - for expert in experts: - _patch_swiglu_module(expert, LigerSwiGLUMLP) - - shared_experts = decoder_layer.mlp.shared_experts - if shared_experts is not None: - _patch_swiglu_module(shared_experts, LigerSwiGLUMLP) def apply_liger_kernel_to_glm4v( diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index cb2fa3add..b2af3be83 100755 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -2594,7 +2594,7 @@ def test_apply_liger_kernel_to_instance_for_glm4(): @pytest.mark.skipif(not is_glm4_moe_available(), reason="glm4 module not available") def test_apply_liger_kernel_to_instance_for_glm4_moe(): from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMLP - from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMOE + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE # Ensure any monkey patching is cleaned up for subsequent tests with patch("transformers.models.glm4_moe.modeling_glm4_moe"): From ccc308a130cd10b6cf57fc92dd44efa5546d5059 Mon Sep 17 00:00:00 2001 From: Eddie Tsai Date: Wed, 7 Jan 2026 15:20:37 +0800 Subject: [PATCH 12/13] fix: update rotary position embedding assignment in apply_liger_kernel_to_glm4_moe function --- src/liger_kernel/transformers/monkey_patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 6ac71aba8..feced61c6 100755 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -2175,7 +2175,7 @@ def apply_liger_kernel_to_glm4_moe( from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4 if rope: - modeling_glm4_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast + modeling_glm4_moe.apply_rotary_pos_emb = liger_rotary_pos_emb if rms_norm: modeling_glm4_moe.Glm4MoeRMSNorm = LigerRMSNormForGlm4 if cross_entropy: From 375903c822912dc591d571965ec263237cdb8e4f Mon Sep 17 00:00:00 2001 From: Eddie Tsai Date: Wed, 7 Jan 2026 15:20:45 +0800 Subject: [PATCH 13/13] feat: add MOE configuration parameters for GLM4_MOE in test models --- test/convergence/bf16/test_mini_models.py | 8 ++++++++ test/convergence/bf16/test_mini_models_with_logits.py | 8 ++++++++ test/convergence/fp32/test_mini_models.py | 8 ++++++++ test/convergence/fp32/test_mini_models_with_logits.py | 8 ++++++++ 4 files changed, 32 insertions(+) diff --git a/test/convergence/bf16/test_mini_models.py b/test/convergence/bf16/test_mini_models.py index 4034d9be4..97790bad1 100644 --- a/test/convergence/bf16/test_mini_models.py +++ b/test/convergence/bf16/test_mini_models.py @@ -1148,6 +1148,14 @@ eos_token_id=2, # 151329, 151336, 151338 pad_token_id=2, # 151329 partial_rotary_factor=0.5, + moe_intermediate_size=1408, + num_experts_per_tok=2, + n_shared_experts=1, + n_routed_experts=8, + routed_scaling_factor=1.0, + n_group=1, + topk_group=1, + first_k_dense_replace=1, cross_attention_layers=None, dropout=0, hidden_act="silu", diff --git a/test/convergence/bf16/test_mini_models_with_logits.py b/test/convergence/bf16/test_mini_models_with_logits.py index 303f66c34..82a5c197f 100644 --- a/test/convergence/bf16/test_mini_models_with_logits.py +++ b/test/convergence/bf16/test_mini_models_with_logits.py @@ -1096,6 +1096,14 @@ eos_token_id=2, # 151329, 151336, 151338 pad_token_id=2, # 151329 partial_rotary_factor=0.5, + moe_intermediate_size=1408, + num_experts_per_tok=2, + n_shared_experts=1, + n_routed_experts=8, + routed_scaling_factor=1.0, + n_group=1, + topk_group=1, + first_k_dense_replace=1, cross_attention_layers=None, dropout=0, hidden_act="silu", diff --git a/test/convergence/fp32/test_mini_models.py b/test/convergence/fp32/test_mini_models.py index eb59d964d..6e3841a11 100644 --- a/test/convergence/fp32/test_mini_models.py +++ b/test/convergence/fp32/test_mini_models.py @@ -1084,6 +1084,14 @@ eos_token_id=2, # 151329, 151336, 151338 pad_token_id=2, # 151329 partial_rotary_factor=0.5, + moe_intermediate_size=1408, + num_experts_per_tok=2, + n_shared_experts=1, + n_routed_experts=8, + routed_scaling_factor=1.0, + n_group=1, + topk_group=1, + first_k_dense_replace=1, cross_attention_layers=None, dropout=0, hidden_act="silu", diff --git a/test/convergence/fp32/test_mini_models_with_logits.py b/test/convergence/fp32/test_mini_models_with_logits.py index c0504ed08..583c0987f 100644 --- a/test/convergence/fp32/test_mini_models_with_logits.py +++ b/test/convergence/fp32/test_mini_models_with_logits.py @@ -1113,6 +1113,14 @@ eos_token_id=2, # 151329, 151336, 151338 pad_token_id=2, # 151329 partial_rotary_factor=0.5, + moe_intermediate_size=1408, + num_experts_per_tok=2, + n_shared_experts=1, + n_routed_experts=8, + routed_scaling_factor=1.0, + n_group=1, + topk_group=1, + first_k_dense_replace=1, cross_attention_layers=None, dropout=0, hidden_act="silu",