diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py
index c0b8bf289..c51ac871d 100644
--- a/src/liger_kernel/transformers/__init__.py
+++ b/src/liger_kernel/transformers/__init__.py
@@ -39,6 +39,7 @@
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4  # noqa: F401
+from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4_moe  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe  # noqa: F401
 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss  # noqa: F401
@@ -110,6 +111,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_gemma3_text",
         "apply_liger_kernel_to_glm4",
         "apply_liger_kernel_to_glm4v",
+        "apply_liger_kernel_to_glm4_moe",
         "apply_liger_kernel_to_glm4v_moe",
         "apply_liger_kernel_to_gpt_oss",
         "apply_liger_kernel_to_granite",
@@ -187,6 +189,7 @@ def __getattr__(name: str):
         "apply_liger_kernel_to_gemma3",
         "apply_liger_kernel_to_gemma3_text",
         "apply_liger_kernel_to_glm4",
+        "apply_liger_kernel_to_glm4_moe",
         "apply_liger_kernel_to_glm4v",
         "apply_liger_kernel_to_glm4v_moe",
         "apply_liger_kernel_to_gpt_oss",
diff --git a/src/liger_kernel/transformers/model/glm4_moe.py b/src/liger_kernel/transformers/model/glm4_moe.py
new file mode 100644
index 000000000..f8b70be28
--- /dev/null
+++ b/src/liger_kernel/transformers/model/glm4_moe.py
@@ -0,0 +1,169 @@
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from transformers.utils.deprecation import deprecate_kwarg
+
+from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
+from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result
+from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast
+
+
+@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
+    skip_logits: Optional[bool] = None,
+    **kwargs,
+) -> Union[Tuple, LigerCausalLMOutputWithPast]:
+    r"""
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
+            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
+            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
+            This is useful when using packed tensor format (single dimension for batch and sequence length).
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Glm4MoeForCausalLM
+
+    >>> MODEL_PATH = "path/to/glm4-moe-checkpoint"  # placeholder: any Glm4MoeForCausalLM checkpoint
+    >>> tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
+    >>> model = Glm4MoeForCausalLM.from_pretrained(
+            pretrained_model_name_or_path=MODEL_PATH,
+            dtype="auto",
+            device_map="auto",
+        )
+    >>> messages = [{"role": "user", "content": "Briefly explain mixture-of-experts."}]
+    >>> inputs = tokenizer.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(model.device)
+    >>> generated_ids = model.generate(**inputs, max_new_tokens=128)
+    >>> output_text = tokenizer.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+    ```
+    """
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+        cache_position=cache_position,
+        **kwargs,
+    )
+
+    hidden_states = outputs[0]
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]
+
+    shift_labels = kwargs.pop("shift_labels", None)
+    logits = None
+    loss = None
+    token_accuracy = None
+
+    if skip_logits and labels is None and shift_labels is None:
+        raise ValueError("skip_logits is True, but labels and shift_labels are None")
+
+    if skip_logits is None:
+        # By default, if in training mode, don't materialize logits
+        skip_logits = self.training and (labels is not None or shift_labels is not None)
+
+    if skip_logits:
+        result = LigerForCausalLMLoss(
+            hidden_states=kept_hidden_states,
+            lm_head_weight=self.lm_head.weight,
+            labels=labels,
+            shift_labels=shift_labels,
+            hidden_size=self.config.hidden_size,
+            **kwargs,
+        )
+        loss, _, token_accuracy = unpack_cross_entropy_result(result)
+
+    else:
+        logits = self.lm_head(kept_hidden_states)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.vocab_size,
+                **kwargs,
+            )
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        output = ((loss,) + output) if loss is not None else output
+        output = output + (token_accuracy,) if token_accuracy is not None else output
+        return output
+
+    # Return custom output class with accuracy field
+    return LigerCausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        token_accuracy=token_accuracy,
+    )
diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py
index 8b206486a..feced61c6 100755
--- a/src/liger_kernel/transformers/monkey_patch.py
+++ b/src/liger_kernel/transformers/monkey_patch.py
@@ -2140,6 +2140,82 @@ def apply_liger_kernel_to_glm4(
             _patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)
 
 
+def apply_liger_kernel_to_glm4_moe(
+    rope: bool = False,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
+    rms_norm: bool = True,
+    swiglu: bool = True,
+    model: PreTrainedModel = None,
+) -> None:
+    """
+    Apply Liger kernels to replace the original implementation in HuggingFace GLM-4 MoE models.
+
+    Args:
+        rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
+        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
+        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
+        swiglu (bool): Whether to apply Liger's SwiGLU to Glm4MoeMLP. Default is True.
+        model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
+        loaded. Default is None.
+    """
+    assert not (cross_entropy and fused_linear_cross_entropy), (
+        "cross_entropy and fused_linear_cross_entropy cannot both be True."
+ ) + from transformers.models.glm4_moe import modeling_glm4_moe + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMLP + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeModel + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE + + from liger_kernel.transformers.model.glm4_moe import lce_forward as glm4_moe_lce_forward + from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4 + + if rope: + modeling_glm4_moe.apply_rotary_pos_emb = liger_rotary_pos_emb + if rms_norm: + modeling_glm4_moe.Glm4MoeRMSNorm = LigerRMSNormForGlm4 + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(glm4_moe_lce_forward, model) + else: + modeling_glm4_moe.Glm4MoeForCausalLM.forward = glm4_moe_lce_forward + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: Glm4MoeModel = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + + for decoder_layer in base_model.layers: + if swiglu: + # patch moe layer + if isinstance(decoder_layer.mlp, Glm4MoeMoE): + experts = decoder_layer.mlp.experts + if experts is not None: + for expert in experts: + _patch_swiglu_module(expert, LigerSwiGLUMLP) + shared_experts = decoder_layer.mlp.shared_experts + if shared_experts is not None: + _patch_swiglu_module(shared_experts, LigerSwiGLUMLP) + elif isinstance(decoder_layer.mlp, Glm4MoeMLP): + _patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP) + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + + def apply_liger_kernel_to_glm4v( rope: bool = False, cross_entropy: bool = False, @@ -2828,6 +2904,7 @@ def apply_liger_kernel_to_hunyuan_v1_moe( "gemma3_text": apply_liger_kernel_to_gemma3_text, "gemma3": apply_liger_kernel_to_gemma3, "glm4": apply_liger_kernel_to_glm4, + "glm4_moe": apply_liger_kernel_to_glm4_moe, "glm4v": apply_liger_kernel_to_glm4v, "glm4v_moe": apply_liger_kernel_to_glm4v_moe, "gpt_oss": apply_liger_kernel_to_gpt_oss, diff --git a/test/convergence/bf16/test_mini_models.py b/test/convergence/bf16/test_mini_models.py index 3b1782520..97790bad1 100644 --- a/test/convergence/bf16/test_mini_models.py +++ b/test/convergence/bf16/test_mini_models.py @@ -27,6 +27,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text from liger_kernel.transformers import apply_liger_kernel_to_glm4 +from liger_kernel.transformers import apply_liger_kernel_to_glm4_moe from liger_kernel.transformers import apply_liger_kernel_to_glm4v from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss @@ -63,6 +64,7 @@ from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_gemma3_text from test.utils import revert_liger_kernel_to_glm4 +from test.utils import revert_liger_kernel_to_glm4_moe from test.utils import revert_liger_kernel_to_glm4v from test.utils import revert_liger_kernel_to_glm4v_moe from test.utils import revert_liger_kernel_to_gpt_oss @@ -216,6 +218,14 @@ except ImportError: GLM4_AVAILABLE = False +try: + from 
transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeForCausalLM + + GLM4_MOE_AVAILABLE = True +except ImportError: + GLM4_MOE_AVAILABLE = False + try: # Glm4v is only available in transformers>=4.51.3 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig @@ -1128,6 +1138,45 @@ ), ) +if GLM4_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe, + model_class=Glm4MoeForCausalLM, + mini_model_config=Glm4MoeConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=0.5, + moe_intermediate_size=1408, + num_experts_per_tok=2, + n_shared_experts=1, + n_routed_experts=8, + routed_scaling_factor=1.0, + n_group=1, + topk_group=1, + first_k_dense_replace=1, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + rope_scaling=None, + rope_theta=500_000, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + if GLM4V_AVAILABLE: MINI_MODEL_SETUPS["mini_glm4v"] = MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_glm4v, @@ -1896,6 +1945,25 @@ def run_mini_model( ), ], ), + pytest.param( + "mini_glm4_moe", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GLM4_MOE_AVAILABLE, + reason="Glm4_moe not available in this version of transformers", + ), + ], + ), pytest.param( "mini_glm4v", 32, diff --git a/test/convergence/bf16/test_mini_models_with_logits.py b/test/convergence/bf16/test_mini_models_with_logits.py index 082d3cb2c..82a5c197f 100644 --- a/test/convergence/bf16/test_mini_models_with_logits.py +++ b/test/convergence/bf16/test_mini_models_with_logits.py @@ -27,6 +27,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text from liger_kernel.transformers import apply_liger_kernel_to_glm4 +from liger_kernel.transformers import apply_liger_kernel_to_glm4_moe from liger_kernel.transformers import apply_liger_kernel_to_glm4v from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe from liger_kernel.transformers import apply_liger_kernel_to_granite @@ -62,6 +63,7 @@ from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_gemma3_text from test.utils import revert_liger_kernel_to_glm4 +from test.utils import revert_liger_kernel_to_glm4_moe from test.utils import revert_liger_kernel_to_glm4v from test.utils import revert_liger_kernel_to_glm4v_moe from test.utils import revert_liger_kernel_to_granite @@ -206,6 +208,14 @@ except ImportError: GLM4_AVAILABLE = False +try: + from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeForCausalLM + + GLM4_MOE_AVAILABLE = True +except ImportError: + GLM4_MOE_AVAILABLE = False + try: # Glm4v is only 
available in transformers>=4.51.3 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig @@ -1075,6 +1085,46 @@ attn_implementation="sdpa", # default value, pytorch native attention ), ) + +if GLM4_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe, + model_class=Glm4MoeForCausalLM, + mini_model_config=Glm4MoeConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=0.5, + moe_intermediate_size=1408, + num_experts_per_tok=2, + n_shared_experts=1, + n_routed_experts=8, + routed_scaling_factor=1.0, + n_group=1, + topk_group=1, + first_k_dense_replace=1, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + rope_scaling=None, + rope_theta=500_000, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + if GLM4V_AVAILABLE: MINI_MODEL_SETUPS["mini_glm4v"] = MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_glm4v, @@ -1847,6 +1897,25 @@ def run_mini_model( ), ], ), + pytest.param( + "mini_glm4_moe", + 32, + 1e-5, + torch.bfloat16, + 1e-2, + 1e-2, + 1e-1, + 1e-2, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GLM4_MOE_AVAILABLE, + reason="Glm4_moe not available in this version of transformers", + ), + ], + ), pytest.param( "mini_glm4v", 32, diff --git a/test/convergence/fp32/test_mini_models.py b/test/convergence/fp32/test_mini_models.py index a0e2a95c6..6e3841a11 100644 --- a/test/convergence/fp32/test_mini_models.py +++ b/test/convergence/fp32/test_mini_models.py @@ -27,6 +27,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text from liger_kernel.transformers import apply_liger_kernel_to_glm4 +from liger_kernel.transformers import apply_liger_kernel_to_glm4_moe from liger_kernel.transformers import apply_liger_kernel_to_glm4v from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss @@ -63,6 +64,7 @@ from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_gemma3_text from test.utils import revert_liger_kernel_to_glm4 +from test.utils import revert_liger_kernel_to_glm4_moe from test.utils import revert_liger_kernel_to_glm4v from test.utils import revert_liger_kernel_to_glm4v_moe from test.utils import revert_liger_kernel_to_gpt_oss @@ -204,6 +206,14 @@ except ImportError: GLM4_AVAILABLE = False +try: + from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeForCausalLM + + GLM4_MOE_AVAILABLE = True +except ImportError: + GLM4_MOE_AVAILABLE = False + try: # Glm4v is only available in transformers>=4.51.3 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig @@ -1064,6 +1074,45 @@ ), ) +if GLM4_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4_moe"] = MiniModelConfig( + 
liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe, + model_class=Glm4MoeForCausalLM, + mini_model_config=Glm4MoeConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + partial_rotary_factor=0.5, + moe_intermediate_size=1408, + num_experts_per_tok=2, + n_shared_experts=1, + n_routed_experts=8, + routed_scaling_factor=1.0, + n_group=1, + topk_group=1, + first_k_dense_replace=1, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=32768, + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + rope_scaling=None, + rope_theta=500_000, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + if GLM4V_AVAILABLE: MINI_MODEL_SETUPS["mini_glm4v"] = MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_glm4v, @@ -1785,6 +1834,22 @@ def run_mini_model( reason="Glm4 not available in this version of transformers", ), ), + pytest.param( + "mini_glm4_moe", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GLM4_MOE_AVAILABLE, + reason="Glm4_moe not available in this version of transformers", + ), + ), pytest.param( "mini_glm4v", 32, diff --git a/test/convergence/fp32/test_mini_models_with_logits.py b/test/convergence/fp32/test_mini_models_with_logits.py index d6debfb3a..583c0987f 100644 --- a/test/convergence/fp32/test_mini_models_with_logits.py +++ b/test/convergence/fp32/test_mini_models_with_logits.py @@ -27,6 +27,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_gemma2 from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text from liger_kernel.transformers import apply_liger_kernel_to_glm4 +from liger_kernel.transformers import apply_liger_kernel_to_glm4_moe from liger_kernel.transformers import apply_liger_kernel_to_glm4v from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe from liger_kernel.transformers import apply_liger_kernel_to_granite @@ -62,6 +63,7 @@ from test.utils import revert_liger_kernel_to_gemma2 from test.utils import revert_liger_kernel_to_gemma3_text from test.utils import revert_liger_kernel_to_glm4 +from test.utils import revert_liger_kernel_to_glm4_moe from test.utils import revert_liger_kernel_to_glm4v from test.utils import revert_liger_kernel_to_glm4v_moe from test.utils import revert_liger_kernel_to_granite @@ -226,6 +228,14 @@ except ImportError: GLM4_AVAILABLE = False +try: + from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeForCausalLM + + GLM4_MOE_AVAILABLE = True +except ImportError: + GLM4_MOE_AVAILABLE = False + try: # Glm4v is only available in transformers>=4.51.3 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig @@ -1093,6 +1103,45 @@ ), ) +if GLM4_MOE_AVAILABLE: + MINI_MODEL_SETUPS["mini_glm4_moe"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe, + liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe, + model_class=Glm4MoeForCausalLM, + mini_model_config=Glm4MoeConfig( + bos_token_id=1, # None + eos_token_id=2, # 151329, 151336, 151338 + pad_token_id=2, # 151329 + 
partial_rotary_factor=0.5, + moe_intermediate_size=1408, + num_experts_per_tok=2, + n_shared_experts=1, + n_routed_experts=8, + routed_scaling_factor=1.0, + n_group=1, + topk_group=1, + first_k_dense_replace=1, + cross_attention_layers=None, + dropout=0, + hidden_act="silu", + hidden_size=1024, # 6144 + initializer_range=0.02, + intermediate_size=2048, # 14336 + max_position_embeddings=4096, # 32768 + num_attention_heads=8, # 48 + num_hidden_layers=4, # 61 + num_key_value_heads=2, + rms_norm_eps=1e-5, + rope_scaling=None, + rope_theta=500_000, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32000, # 151552 + attention_bias=True, + attn_implementation="sdpa", # default value, pytorch native attention + ), + ) + if GLM4V_AVAILABLE: MINI_MODEL_SETUPS["mini_glm4v"] = MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_glm4v, @@ -1731,6 +1780,22 @@ def run_mini_model( reason="Glm4 not available in this version of transformers", ), ), + pytest.param( + "mini_glm4_moe", + 32, + 1e-4, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GLM4_MOE_AVAILABLE, + reason="Glm4_moe not available in this version of transformers", + ), + ), pytest.param( "mini_glm4v", 32, diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py index 71ebf592a..b2af3be83 100755 --- a/test/transformers/test_monkey_patch.py +++ b/test/transformers/test_monkey_patch.py @@ -161,6 +161,15 @@ def is_glm4_available(): return False +def is_glm4_moe_available(): + try: + import transformers.models.glm4_moe # noqa: F401 + + return True + except ImportError: + return False + + def is_glm4v_available(): try: import transformers.models.glm4v # noqa: F401 @@ -232,6 +241,7 @@ def test_import_from_root(): from liger_kernel.transformers import apply_liger_kernel_to_gemma3 # noqa: F401 from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text # noqa: F401 from liger_kernel.transformers import apply_liger_kernel_to_glm4 # noqa: F401 + from liger_kernel.transformers import apply_liger_kernel_to_glm4_moe # noqa: F401 from liger_kernel.transformers import apply_liger_kernel_to_glm4v # noqa: F401 from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe # noqa: F401 from liger_kernel.transformers import apply_liger_kernel_to_internvl # noqa: F401 @@ -2581,6 +2591,79 @@ def test_apply_liger_kernel_to_instance_for_glm4(): pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}") +@pytest.mark.skipif(not is_glm4_moe_available(), reason="glm4 module not available") +def test_apply_liger_kernel_to_instance_for_glm4_moe(): + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMLP + from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE + + # Ensure any monkey patching is cleaned up for subsequent tests + with patch("transformers.models.glm4_moe.modeling_glm4_moe"): + from liger_kernel.transformers.model.glm4_moe import lce_forward as glm4_moe_lce_forward + + # Instantiate a dummy model + config = transformers.models.glm4_moe.Glm4MoeConfig( + dtype=torch.bfloat16, + rms_norm_eps=1e-5, + hidden_size=32, + intermediate_size=64, + hidden_act="silu", + num_hidden_layers=2, + n_routed_experts=1, + n_shared_experts=1, + num_attention_heads=4, + num_key_value_heads=4, + head_dim=8, + first_k_dense_replace=1, + ) + dummy_model_instance = AutoModelForCausalLM.from_config(config) + + # Check that model instance variables are not yet patched with Liger modules + assert 
inspect.getsource(dummy_model_instance.forward) != inspect.getsource(glm4_moe_lce_forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) != inspect.getsource(LigerRMSNorm.forward) + for decoder_layer in dummy_model_instance.model.layers: + if isinstance(decoder_layer.mlp, Glm4MoeMoE): + experts = decoder_layer.mlp.experts + if experts is not None: + for expert in experts: + assert inspect.getsource(expert.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + shared_experts = decoder_layer.mlp.shared_experts + if shared_experts is not None: + assert inspect.getsource(shared_experts.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + elif isinstance(decoder_layer.mlp, Glm4MoeMLP): + assert inspect.getsource(decoder_layer.mlp.forward) != inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(decoder_layer.input_layernorm.forward) != inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) != inspect.getsource( + LigerRMSNorm.forward + ) + + # Test applying kernels to the model instance + _apply_liger_kernel_to_instance(model=dummy_model_instance) + + # Check that the model's instance variables were correctly patched with Liger modules + assert inspect.getsource(dummy_model_instance.forward) == inspect.getsource(glm4_moe_lce_forward) + assert inspect.getsource(dummy_model_instance.model.norm.forward) == inspect.getsource(LigerRMSNorm.forward) + for decoder_layer in dummy_model_instance.model.layers: + if isinstance(decoder_layer.mlp, Glm4MoeMoE): + experts = decoder_layer.mlp.experts + if experts is not None: + for expert in experts: + assert inspect.getsource(expert.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + shared_experts = decoder_layer.mlp.shared_experts + if shared_experts is not None: + assert inspect.getsource(shared_experts.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + elif isinstance(decoder_layer.mlp, Glm4MoeMLP): + assert inspect.getsource(decoder_layer.mlp.forward) == inspect.getsource(LigerSwiGLUMLP.forward) + assert inspect.getsource(decoder_layer.input_layernorm.forward) == inspect.getsource(LigerRMSNorm.forward) + assert inspect.getsource(decoder_layer.post_attention_layernorm.forward) == inspect.getsource( + LigerRMSNorm.forward + ) + + try: + print(dummy_model_instance) + except Exception as e: + pytest.fail(f"An exception occured in extra_expr: {type(e).__name__} - {e}") + + @pytest.mark.skipif(not is_glm4v_available(), reason="glm4v module not available") def test_apply_liger_kernel_to_instance_for_glm4v(): # Ensure any monkey patching is cleaned up for subsequent tests diff --git a/test/utils.py b/test/utils.py index 5e7de9586..62c75e07e 100644 --- a/test/utils.py +++ b/test/utils.py @@ -591,6 +591,18 @@ def revert_liger_kernel_to_glm4(model_config: MiniModelConfig): print("Liger kernel patches have been reverted.") +def revert_liger_kernel_to_glm4_moe(model_config: MiniModelConfig): + """ + Revert all Liger kernel patches applied to Glm4_moe. + """ + + from transformers.models.glm4_moe import modeling_glm4_moe + + importlib.reload(modeling_glm4_moe) + model_config.model_class = modeling_glm4_moe.Glm4MoeForCausalLM + print("Liger kernel patches have been reverted.") + + def revert_liger_kernel_to_glm4v(model_config: MiniModelConfig): """ Revert all Liger kernel patches applied to Glm4v.
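Reviewer note: below is a minimal usage sketch of the new `apply_liger_kernel_to_glm4_moe` entry point introduced in this diff, not part of the patch itself. The checkpoint path is a placeholder (any `Glm4MoeForCausalLM` checkpoint), and the keyword arguments mirror the signature added to `monkey_patch.py`.

```python
# Minimal sketch, assuming a GLM-4 MoE checkpoint is available locally or on the Hub.
# "path/to/glm4-moe-checkpoint" is a placeholder, not a real repo id.
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_glm4_moe

# Option 1: patch the transformers glm4_moe module before the model is instantiated.
apply_liger_kernel_to_glm4_moe(
    rope=False,                       # Liger RoPE is off by default in this patch
    cross_entropy=False,
    fused_linear_cross_entropy=True,  # fused linear + CE: logits are not materialized during training
    rms_norm=True,
    swiglu=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "path/to/glm4-moe-checkpoint",
    torch_dtype=torch.bfloat16,
)

# Option 2: patch an already-instantiated model in place.
# model = AutoModelForCausalLM.from_pretrained("path/to/glm4-moe-checkpoint")
# apply_liger_kernel_to_glm4_moe(model=model)
```

The instance path (Option 2) is also reachable through `_apply_liger_kernel_to_instance`, since the diff registers `"glm4_moe": apply_liger_kernel_to_glm4_moe` in the model-type mapping; that is the path exercised by the new `test_apply_liger_kernel_to_instance_for_glm4_moe` test.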