@@ -613,16 +613,17 @@ def __init__(
613613 def forward (
614614 self ,
615615 x ,
616- attn_mask : Tensor | None = None ,
616+ attn_mask : Tensor | None = None , # for manual masking
617617 rotary_emb : Tensor | None = None ,
618618 cache : Tensor | None = None ,
619619 causal = False ,
620- block_mask = None ,
620+ block_mask = None , # only passed in for flex attention
621621 return_kv_cache = False
622622 ):
623623 device = x .device
624624
625625 assert not (exists (block_mask ) and exists (attn_mask ))
626+ assert not (not self .use_flex_attn and exists (block_mask )), 'you cannot pass in the `block_mask` if `use_flex_attn` was not set to be `True`'
626627
627628 x = self .norm (x )
628629
@@ -748,7 +749,10 @@ def forward(
748749 causal_mask : bool = False ,
749750 return_kv_cache = False
750751 ):
751- batch , seq_len , device = x .shape [0 ], x .shape [- 2 ], x .device
752+ batch , seq_len , device , input_is_cuda = x .shape [0 ], x .shape [- 2 ], x .device , x .is_cuda
753+
754+ should_use_flex_attn = input_is_cuda and self .use_flex_attn
755+
752756 assert not (exists (attn_mask ) and exists (modality_positions ))
753757
754758 # handle time
@@ -766,7 +770,7 @@ def forward(
766770 attn_mask_kwargs = dict ()
767771
768772 if causal_mask :
769- if self . use_flex_attn :
773+ if should_use_flex_attn :
770774 block_mask = create_block_mask (causal , B = None , H = None , Q_LEN = seq_len , KV_LEN = seq_len , device = device )
771775 attn_mask_kwargs .update (block_mask = block_mask )
772776 else :
@@ -775,7 +779,7 @@ def forward(
775779 if exists (modality_positions ):
776780 assert not causal_mask
777781
778- if self . use_flex_attn :
782+ if should_use_flex_attn :
779783 transfusion_mask_fn = transfusion_attn_mask (modality_positions )
780784 block_mask = create_block_mask (transfusion_mask_fn , B = None , H = None , Q_LEN = seq_len , KV_LEN = seq_len , device = device )
781785 attn_mask_kwargs .update (block_mask = block_mask )
0 commit comments