Skip to content

Commit 80b7eac

Browse files
authored
Merge pull request #47 from lucidrains/fix
fix types
2 parents 372e3d2 + ec3796d commit 80b7eac

2 files changed

Lines changed: 4 additions & 2 deletions

File tree

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
run: |
1818
python -m pip install uv
1919
python -m uv pip install --upgrade pip
20-
python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
20+
python -m uv pip install torch==2.8.0 --index-url https://download.pytorch.org/whl/cpu
2121
python -m uv pip install -e .[test]
2222
- name: Test with pytest
2323
run: |

transfusion_pytorch/transfusion.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2176,7 +2176,9 @@ def forward(
21762176
tuple[tuple[Float['b _ _'], GetPredFlows], Tensor] |
21772177
Scalar |
21782178
tuple[Scalar, LossBreakdown] |
2179-
list[Float['b _ _']]
2179+
list[Float['b _ _']] |
2180+
tuple[Float['b _ l'], Tensor] |
2181+
list[list[Tensor]] # predicted flows from return_only_pred_flows = True
21802182
):
21812183
is_decoding = exists(decoding_text_or_modality)
21822184

0 commit comments

Comments
 (0)