Commit 7ebe769

Update
[ghstack-poisoned]

1 parent c2e3c87 commit 7ebe769
1 file changed: torchao/prototype/gptq/gptq_example.py (36 additions & 17 deletions)
```diff
@@ -268,9 +268,31 @@ def parse_args():
         default=False,
         help="Only quantize `o_proj` layers, useful for faster GPTQ runs for debugging",
     )
+    parser.add_argument(
+        "--olmoe-layers-10-to-15-experts-only",
+        action="store_true",
+        default=False,
+        help="For olmoe model, only quantize the experts (gate_up_proj and down_proj) on layers 10 through 15, useful for faster GPTQ runs for debugging",
+    )
     return parser.parse_args()
 
 
+def get_fqn_to_config(config, olmoe_layers_10_to_15_experts_only=False):
+    if olmoe_layers_10_to_15_experts_only:
+        return FqnToConfig(
+            {
+                r"re:model\.layers\.(10|11|12|13|14|15)\.mlp\.experts\.gate_up_proj": config,
+                r"re:model\.layers\.(10|11|12|13|14|15)\.mlp\.experts\.down_proj": config,
+            }
+        )
+    return FqnToConfig(
+        {
+            r"re:.*\.experts\.gate_up_proj": config,
+            r"re:.*\.experts\.down_proj": config,
+        }
+    )
+
+
 OLMOE_MODEL_ID = "allenai/OLMoE-1B-7B-0924"
 
 
```
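As a quick illustration (not part of the commit): a minimal sketch of what the two pattern sets above select, assuming the `re:` prefix in `FqnToConfig` keys denotes a regular expression matched against a module's full FQN. The example FQNs are hypothetical, modeled on the OLMoE module names the patterns target.

```python
import re

# Patterns from get_fqn_to_config above, with the "re:" prefix stripped.
DEBUG_PATTERNS = [
    r"model\.layers\.(10|11|12|13|14|15)\.mlp\.experts\.gate_up_proj",
    r"model\.layers\.(10|11|12|13|14|15)\.mlp\.experts\.down_proj",
]
DEFAULT_PATTERNS = [
    r".*\.experts\.gate_up_proj",
    r".*\.experts\.down_proj",
]

# Hypothetical FQNs, modeled on the naming the patterns above target.
example_fqns = [
    "model.layers.12.mlp.experts.gate_up_proj",  # expert proj on layer 12: in 10-15
    "model.layers.3.mlp.experts.down_proj",      # expert proj, but outside 10-15
    "model.layers.11.self_attn.o_proj",          # not an expert projection at all
]

for fqn in example_fqns:
    debug = any(re.fullmatch(p, fqn) for p in DEBUG_PATTERNS)
    default = any(re.fullmatch(p, fqn) for p in DEFAULT_PATTERNS)
    print(f"{fqn}: debug={debug}, default={default}")
```

With the flag set, only the first FQN is selected; without it, the first two are, which is why the debug mode runs faster.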
```diff
@@ -374,6 +396,7 @@ def skip_lm_head_o_proj(module, fqn):
 
     filter_fn_to_use = skip_lm_head
     if args.o_proj_only:
+        assert not is_olmoe, "unsupported"
         filter_fn_to_use = skip_lm_head_o_proj
 
     if args.quantization == "int4-rtn":
@@ -396,15 +419,14 @@ def skip_lm_head_o_proj(module, fqn):
     if is_olmoe:
         quantize_(
             model,
-            FqnToConfig(
-                {
-                    r"re:.*\.experts\.gate_up_proj": config,
-                    r"re:.*\.experts\.down_proj": config,
-                }
+            get_fqn_to_config(
+                config,
+                olmoe_layers_10_to_15_experts_only=args.olmoe_layers_10_to_15_experts_only,
             ),
             filter_fn=None,
         )
-        _verify_olmoe_experts_quantized(model)
+        if not args.olmoe_layers_10_to_15_experts_only:
+            _verify_olmoe_experts_quantized(model)
     else:
         quantize_(model, config, filter_fn=filter_fn_to_use)
     print(model)
```
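The remaining hunks apply the same swap at the GPTQ observe and convert stages, so all three call sites derive their FqnToConfig from the new helper and stay in sync with the debug flag.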
```diff
@@ -444,11 +466,9 @@ def skip_lm_head_o_proj(module, fqn):
     if is_olmoe:
         quantize_(
             model,
-            FqnToConfig(
-                {
-                    r"re:.*\.experts\.gate_up_proj": observe_config,
-                    r"re:.*\.experts\.down_proj": observe_config,
-                }
+            get_fqn_to_config(
+                observe_config,
+                olmoe_layers_10_to_15_experts_only=args.olmoe_layers_10_to_15_experts_only,
             ),
             filter_fn=None,
         )
@@ -495,15 +515,14 @@ def skip_lm_head_o_proj(module, fqn):
     if is_olmoe:
         quantize_(
             model,
-            FqnToConfig(
-                {
-                    r"re:.*\.experts\.gate_up_proj": convert_config,
-                    r"re:.*\.experts\.down_proj": convert_config,
-                }
+            get_fqn_to_config(
+                convert_config,
+                olmoe_layers_10_to_15_experts_only=args.olmoe_layers_10_to_15_experts_only,
             ),
             filter_fn=None,
         )
-        _verify_olmoe_experts_quantized(model)
+        if not args.olmoe_layers_10_to_15_experts_only:
+            _verify_olmoe_experts_quantized(model)
     else:
         quantize_(model, convert_config, filter_fn=filter_fn_to_use)
 else: # sequential
```
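Usage note: for a faster debugging pass, GPTQ can now be restricted to the experts on layers 10 through 15 by passing the new flag, e.g. `python torchao/prototype/gptq/gptq_example.py --olmoe-layers-10-to-15-experts-only` (any other flags the script requires, such as the quantization mode, are omitted here). In this mode `_verify_olmoe_experts_quantized` is intentionally skipped, since most expert layers are deliberately left unquantized.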
