@@ -268,9 +268,31 @@ def parse_args():
         default=False,
         help="Only quantize `o_proj` layers, useful for faster GPTQ runs for debugging",
     )
+    parser.add_argument(
+        "--olmoe-layers-10-to-15-experts-only",
+        action="store_true",
+        default=False,
+        help="For the OLMoE model, only quantize the experts (gate_up_proj and down_proj) on layers 10 through 15, useful for faster GPTQ runs for debugging",
+    )
     return parser.parse_args()
 
 
+def get_fqn_to_config(config, olmoe_layers_10_to_15_experts_only=False):
+    if olmoe_layers_10_to_15_experts_only:
+        return FqnToConfig(
+            {
+                r"re:model\.layers\.(10|11|12|13|14|15)\.mlp\.experts\.gate_up_proj": config,
+                r"re:model\.layers\.(10|11|12|13|14|15)\.mlp\.experts\.down_proj": config,
+            }
+        )
+    return FqnToConfig(
+        {
+            r"re:.*\.experts\.gate_up_proj": config,
+            r"re:.*\.experts\.down_proj": config,
+        }
+    )
+
+
 OLMOE_MODEL_ID = "allenai/OLMoE-1B-7B-0924"
 
 
@@ -374,6 +396,7 @@ def skip_lm_head_o_proj(module, fqn):
 
 filter_fn_to_use = skip_lm_head
 if args.o_proj_only:
+    assert not is_olmoe, "o_proj_only is not supported for OLMoE"
     filter_fn_to_use = skip_lm_head_o_proj
 
 if args.quantization == "int4-rtn":
@@ -396,15 +419,14 @@ def skip_lm_head_o_proj(module, fqn):
 if is_olmoe:
     quantize_(
         model,
-        FqnToConfig(
-            {
-                r"re:.*\.experts\.gate_up_proj": config,
-                r"re:.*\.experts\.down_proj": config,
-            }
+        get_fqn_to_config(
+            config,
+            olmoe_layers_10_to_15_experts_only=args.olmoe_layers_10_to_15_experts_only,
         ),
         filter_fn=None,
     )
-    _verify_olmoe_experts_quantized(model)
+    if not args.olmoe_layers_10_to_15_experts_only:
+        _verify_olmoe_experts_quantized(model)
 else:
     quantize_(model, config, filter_fn=filter_fn_to_use)
 print(model)
@@ -444,11 +466,9 @@ def skip_lm_head_o_proj(module, fqn):
     if is_olmoe:
         quantize_(
             model,
-            FqnToConfig(
-                {
-                    r"re:.*\.experts\.gate_up_proj": observe_config,
-                    r"re:.*\.experts\.down_proj": observe_config,
-                }
+            get_fqn_to_config(
+                observe_config,
+                olmoe_layers_10_to_15_experts_only=args.olmoe_layers_10_to_15_experts_only,
             ),
             filter_fn=None,
         )
@@ -495,15 +515,14 @@ def skip_lm_head_o_proj(module, fqn):
     if is_olmoe:
         quantize_(
             model,
-            FqnToConfig(
-                {
-                    r"re:.*\.experts\.gate_up_proj": convert_config,
-                    r"re:.*\.experts\.down_proj": convert_config,
-                }
+            get_fqn_to_config(
+                convert_config,
+                olmoe_layers_10_to_15_experts_only=args.olmoe_layers_10_to_15_experts_only,
             ),
             filter_fn=None,
         )
-        _verify_olmoe_experts_quantized(model)
+        if not args.olmoe_layers_10_to_15_experts_only:
+            _verify_olmoe_experts_quantized(model)
     else:
         quantize_(model, convert_config, filter_fn=filter_fn_to_use)
 else:  # sequential
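
For context on what the new helper restricts: FqnToConfig keys prefixed with "re:" are regular-expression patterns matched against module fully-qualified names, so the layer-restricted patterns select only the expert projections on layers 10 through 15 instead of every layer. Below is a minimal standalone sketch (not part of the commit) that checks this with plain `re`; the sample FQNs are assumptions inferred from the patterns themselves rather than dumped from the real OLMoE module tree, and treating the patterns as full matches is likewise an assumption about how torchao applies them.

import re

layer_restricted = [
    r"model\.layers\.(10|11|12|13|14|15)\.mlp\.experts\.gate_up_proj",
    r"model\.layers\.(10|11|12|13|14|15)\.mlp\.experts\.down_proj",
]
catch_all = [r".*\.experts\.gate_up_proj", r".*\.experts\.down_proj"]

# Hypothetical sample FQNs, written to mirror the pattern structure.
for fqn in (
    "model.layers.9.mlp.experts.gate_up_proj",   # outside 10-15
    "model.layers.12.mlp.experts.gate_up_proj",  # inside 10-15
    "model.layers.15.mlp.experts.down_proj",     # inside 10-15
):
    restricted = any(re.fullmatch(p, fqn) for p in layer_restricted)
    broad = any(re.fullmatch(p, fqn) for p in catch_all)
    print(f"{fqn}: restricted={restricted}, catch_all={broad}")

# model.layers.9...  -> restricted=False, catch_all=True
# model.layers.12... -> restricted=True,  catch_all=True
# model.layers.15... -> restricted=True,  catch_all=True

Since most experts are intentionally left unquantized in this debug mode, the commit also skips _verify_olmoe_experts_quantized when the flag is set; that check would presumably fail on the untouched layers.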