diff --git a/inference/infer.py b/inference/infer.py index d8f2b21..61bed03 100644 --- a/inference/infer.py +++ b/inference/infer.py @@ -84,7 +84,7 @@ def seed_everything(seed=42): model = AutoModelForCausalLM.from_pretrained( stage1_model, torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn + attn_implementation="sdpa", # Use SDPA instead of flash_attention_2 # device_map="auto", ) # to device, if gpu is available @@ -262,7 +262,7 @@ def split_lyrics(lyrics): model_stage2 = AutoModelForCausalLM.from_pretrained( stage2_model, torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", + attn_implementation="sdpa", # Use SDPA instead of flash_attention_2 # device_map="auto", ) model_stage2.to(device)