 import gc
 import subprocess
 import time
+from contextlib import nullcontext
 from typing import Any, List, Optional
 
 import torch
 from torchao.prototype.mx_formats.inference_workflow import (
     NVFP4DynamicActivationNVFP4WeightConfig,
 )
-from torchao.quantization import Int4WeightOnlyConfig, Int8WeightOnlyConfig, quantize_
+from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
+from torchao.quantization import (
+    FqnToConfig,
+    Int4WeightOnlyConfig,
+    Int8WeightOnlyConfig,
+    quantize_,
+)
 from torchao.quantization.granularity import PerRow
 
 """
@@ -264,9 +271,37 @@ def parse_args():
     return parser.parse_args()
 
 
+OLMOE_MODEL_ID = "allenai/OLMoE-1B-7B-0924"
+
+
+def _verify_olmoe_experts_quantized(model):
+    """Assert every OlmoeExperts module has NVFP4Tensor for both expert weights."""
+    from transformers.models.olmoe.modeling_olmoe import OlmoeExperts
+
+    found = 0
+    for name, mod in model.named_modules():
+        if not isinstance(mod, OlmoeExperts):
+            continue
+        for pname in ("gate_up_proj", "down_proj"):
+            param = getattr(mod, pname)
+            assert isinstance(param, NVFP4Tensor), (
+                f"{name}.{pname} is {type(param).__name__}, expected NVFP4Tensor"
+            )
+        found += 1
+    assert found > 0, "no OlmoeExperts modules found to verify"
+    print(f"Verified NVFP4 quantization on {found} OlmoeExperts modules")
+
+
 def main():
     args = parse_args()
 
+    is_olmoe = args.model_id == OLMOE_MODEL_ID
+    if is_olmoe and args.quantization not in ("nvfp4-rtn", "nvfp4-gptq-nonsequential"):
+        raise ValueError(
+            f"model {args.model_id} only supports 'nvfp4-rtn' or "
+            f"'nvfp4-gptq-nonsequential', got '{args.quantization}'"
+        )
+
     # lm_eval batch_size="auto" with nvfp4 gptq causes the error in
     # MSLK nvfp4 triton kernel, likely an unsupported shape:
     # https://gist.github.com/vkuzo/b71ca46365dee017d1602e9638d91603
@@ -284,10 +319,11 @@ def main():
     dtype = dtype_map.get("bfloat16", torch.bfloat16)
 
     print(f"Loading model {args.model_id}...")
+    from_pretrained_kwargs = dict(device_map="cuda:0", dtype=dtype)
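+    # For OLMoE, request the grouped_mm experts implementation so the expert weights
+    # are exposed as gate_up_proj / down_proj parameters on OlmoeExperts modules,
+    # which is what the FQN-based quantization below targets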
+    if is_olmoe:
+        from_pretrained_kwargs["experts_implementation"] = "grouped_mm"
     model = AutoModelForCausalLM.from_pretrained(
-        args.model_id,
-        device_map="cuda:0",
-        dtype=dtype,
+        args.model_id, **from_pretrained_kwargs
     )
     tokenizer = AutoTokenizer.from_pretrained(args.model_id)
 
@@ -353,7 +389,20 @@ def skip_lm_head_o_proj(module, fqn):
             use_dynamic_per_tensor_scale=True,
             use_triton_kernel=True,
         )
-        quantize_(model, config, filter_fn=filter_fn_to_use)
+        if is_olmoe:
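+            # the expert weights live on OlmoeExperts modules rather than nn.Linear
+            # layers, so select them by FQN regex instead of a module filter_fn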
+            quantize_(
+                model,
+                FqnToConfig(
+                    {
+                        r"re:.*\.experts\.gate_up_proj": config,
+                        r"re:.*\.experts\.down_proj": config,
+                    }
+                ),
+                filter_fn=None,
+            )
+            _verify_olmoe_experts_quantized(model)
+        else:
+            quantize_(model, config, filter_fn=filter_fn_to_use)
 
     elif args.quantization in [
         "int4-gptq-sequential",
@@ -387,7 +436,19 @@ def skip_lm_head_o_proj(module, fqn):
             percdamp=args.percdamp,
             gptq_quantize_block_size=args.gptq_block_size,
         )
-        quantize_(model, observe_config, filter_fn=filter_fn_to_use)
+        if is_olmoe:
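+            # same FQN-regex targeting as the nvfp4-rtn path above, but installing
+            # the GPTQ observer config on the expert weights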
+            quantize_(
+                model,
+                FqnToConfig(
+                    {
+                        r"re:.*\.experts\.gate_up_proj": observe_config,
+                        r"re:.*\.experts\.down_proj": observe_config,
+                    }
+                ),
+                filter_fn=None,
+            )
+        else:
+            quantize_(model, observe_config, filter_fn=filter_fn_to_use)
 
         # Prepare calibration dataset
         print(
@@ -425,12 +486,31 @@ def skip_lm_head_o_proj(module, fqn):
                 num_gptq_weights += 1
             print(f"Total GPTQ weights to convert: {num_gptq_weights}")
             # Apply quantization
-            quantize_(model, convert_config, filter_fn=filter_fn_to_use)
+            if is_olmoe:
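+                # convert the observed expert weights in place using the same FQN regexes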
+                quantize_(
+                    model,
+                    FqnToConfig(
+                        {
+                            r"re:.*\.experts\.gate_up_proj": convert_config,
+                            r"re:.*\.experts\.down_proj": convert_config,
+                        }
+                    ),
+                    filter_fn=None,
+                )
+                _verify_olmoe_experts_quantized(model)
+            else:
+                quantize_(model, convert_config, filter_fn=filter_fn_to_use)
         else:  # sequential
             print(f"Applying {quant_type} GPTQ quantization (sequential)...")
             assert filter_fn_to_use == skip_lm_head, "unsupported"
             sequential_quantize(model, dataset, convert_config)
 
+    if is_olmoe:
+        # generate() switches to batched_mm for decoding, which doesn't support
+        # NVFP4Tensor (needs aten.index.Tensor). Override to keep grouped_mm.
+        # TODO(future): remove when NVFP4 MoE supports bmm-style decode
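+        # assumption: _optimize_model_for_decode is entered as a context manager during
+        # generate(); pointing it at nullcontext makes that step a no-op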
+        model._optimize_model_for_decode = nullcontext
+
     quantization_end_time = time.time()
     quantization_time = quantization_end_time - quantization_start_time
 
@@ -456,6 +536,12 @@ def skip_lm_head_o_proj(module, fqn):
456536 "Please install a version that includes "
457537 "https://github.com/huggingface/transformers/pull/45573"
458538 )
539+ if is_olmoe and "gate_up_proj" not in source :
540+ raise RuntimeError (
541+ "Your version of `transformers` does not support NVFP4 MoE serialization. "
542+ "Please install a version that includes "
543+ "https://github.com/huggingface/transformers/pull/45609"
544+ )
459545
460546 if args .quantization != "none" :
461547 # Attach hf_quantizer so save_pretrained uses the flatten path for tensor