Fix #8462: embed patch sizes in einops pattern for einops >= 0.8 compatibility #8834
base: dev
Changes from 2 commits
```diff
@@ -29,6 +29,32 @@
 SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos", "fourier"}
 
 
+class _PatchRearrange(nn.Module):
+    """Fallback patch rearrangement using pure PyTorch, for einops compatibility."""
+
+    def __init__(self, spatial_dims: int, patch_size: tuple) -> None:
+        super().__init__()
+        self.spatial_dims = spatial_dims
+        self.patch_size = patch_size
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        B, C = x.shape[0], x.shape[1]
+        sp = x.shape[2:]
+        g = tuple(s // p for s, p in zip(sp, self.patch_size))
+        v: list[int] = [B, C]
+        for gi, pi in zip(g, self.patch_size):
+            v += [gi, pi]
+        x = x.view(*v)
+        n = self.spatial_dims
+        gdims = list(range(2, 2 + 2 * n, 2))
+        pdims = list(range(3, 3 + 2 * n, 2))
+        x = x.permute(0, *gdims, *pdims, 1).contiguous()
+        n_patches = 1
+        for gi in g:
+            n_patches *= gi
+        return x.reshape(B, n_patches, -1)
+
+
 class PatchEmbeddingBlock(nn.Module):
     """
     A patch embedding block, based on: "Dosovitskiy et al.,
```
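For intuition, the einops pattern `"b c (h p1) (w p2) -> b (h w) (p1 p2 c)"` and the fallback above produce the same token and feature ordering. The sketch below is a hypothetical torch-free re-implementation for the 2D case using nested lists (`patchify_2d` is illustrative only, not part of the PR), which makes that ordering explicit:

```python
def patchify_2d(x, patch_size):
    """Rearrange a nested list x[b][c][h][w] into x[b][token][feature],
    enumerating features in (p1, p2, c) order like the einops pattern."""
    B, C = len(x), len(x[0])
    H, W = len(x[0][0]), len(x[0][0][0])
    p1, p2 = patch_size
    gh, gw = H // p1, W // p2  # patch-grid size
    out = []
    for b in range(B):
        tokens = []
        for h in range(gh):          # token order: h outer, w inner -> "(h w)"
            for w in range(gw):
                feat = []
                for i in range(p1):  # feature order: "(p1 p2 c)"
                    for j in range(p2):
                        for c in range(C):
                            feat.append(x[b][c][h * p1 + i][w * p2 + j])
                tokens.append(feat)
        out.append(tokens)
    return out

# A 1x1x4x4 image holding 0..15, split into 2x2 patches:
img = [[[[r * 4 + c for c in range(4)] for r in range(4)]]]
# patchify_2d(img, (2, 2))[0]
# -> [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]
```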
|
```diff
@@ -97,14 +123,16 @@ def __init__(
                 in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
             )
         elif self.proj_type == "perceptron":
-            # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
+            # for 3d: "b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)"
             chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
             from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
             to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
             axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)}
-            self.patch_embeddings = nn.Sequential(
-                Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size)
-            )
+            try:
+                rearrange_layer: nn.Module = Rearrange(f"{from_chars} -> {to_chars}", **axes_len)
+            except TypeError:
+                rearrange_layer = _PatchRearrange(spatial_dims, tuple(int(p) for p in patch_size))
+            self.patch_embeddings = nn.Sequential(rearrange_layer, nn.Linear(self.patch_dim, hidden_size))
```
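The per-dimension pattern construction in this branch can be exercised on its own. `build_pattern` below is an illustrative extraction of the diff's string-building logic, not a MONAI function:

```python
def build_pattern(spatial_dims, patch_size):
    """Rebuild the einops pattern string exactly as the diff does."""
    chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
    from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
    to_chars = f"b ({' '.join(c[0] for c in chars)}) ({' '.join(c[1] for c in chars)} c)"
    axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)}
    return f"{from_chars} -> {to_chars}", axes_len

# build_pattern(3, (2, 2, 2)) ->
#   ("b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)", {"p1": 2, "p2": 2, "p3": 2})
# build_pattern(2, (4, 4)) ->
#   ("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", {"p1": 4, "p2": 4})
```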
**Comment on lines +126 to +135** (Contributor)

🧩 Analysis

🌐 Web query result: einops `Rearrange` patterns do not support standalone integer literals (e.g., `'b 2 h w'`) as axis specifications. Patterns use symbolic axis names (letters like `b`, `h`), anonymous axes (names in parentheses like `(b1 b2)`), ellipsis (`...`), or underscores (`_` for skipping). Integer values are provided via the `axes_lengths` keyword arguments (e.g., `rearrange(x, 'b (h1 h) w c -> ...', h1=2)`) or parsed from tensor shapes via `parse_shape`. This syntax has been consistent across versions, including 0.8.0+ (released 2024-04-28), with no release notes, changelog entries, or docs introducing integer-literal support; official documentation and tutorial examples confirm named axes only.

Retracting the simplification suggestion; the current try/except approach is necessary. Since numeric axis values must be passed via `axes_lengths`, the guarded construction cannot be simplified away. However, the remaining implementation issues in the fallback should still be addressed.
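The retained try/except follows a guarded-construction idiom: attempt the preferred layer, and fall back when its constructor rejects the arguments. A minimal self-contained sketch of that idiom follows; `PreferredLayer`, `FallbackLayer`, and `make_layer` are hypothetical stand-ins, not einops or MONAI classes:

```python
class PreferredLayer:
    """Stand-in for einops Rearrange; rejects construction when no axes
    lengths are given, mimicking an incompatible einops version."""

    def __init__(self, pattern: str, **axes_lengths: int) -> None:
        if not axes_lengths:
            raise TypeError("axes lengths must be supplied as keywords")
        self.pattern = pattern
        self.axes_lengths = axes_lengths


class FallbackLayer:
    """Stand-in for the pure-PyTorch _PatchRearrange fallback."""

    def __init__(self, pattern: str) -> None:
        self.pattern = pattern


def make_layer(pattern: str, axes_lengths: dict):
    # Try the preferred implementation first; fall back only when its
    # constructor raises TypeError, as the diff does for Rearrange.
    try:
        return PreferredLayer(pattern, **axes_lengths)
    except TypeError:
        return FallbackLayer(pattern)
```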
**Comment on lines +131 to +135** (Contributor)

The fallback path isn't deterministically covered by tests. As per coding guidelines: "Ensure new or modified definitions will be covered by existing or new unit tests."
```diff
         self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
         self.dropout = nn.Dropout(dropout_rate)
```
🛠️ Refactor suggestion | 🟠 Major

If the fallback is kept: tighten the type hint, prefer `reshape`, and add Google-style docstrings.

Three points on `_PatchRearrange`:

- `x.view(*v)` at line 47 will raise on non-contiguous inputs; `reshape` is safer and no slower here.
- `patch_size: tuple` is too loose; use `tuple[int, ...]`.
- `Args:`/`Returns:` sections are expected on `__init__` and `forward`; the current one-line class docstring doesn't cover them.

As per coding guidelines: "Docstrings should be present for all definitions and describe each variable, return value, and raised exception in the appropriate section of Google-style docstrings."

🪛 Ruff (0.15.10)

- [warning] 43-43 (B905): `zip()` without an explicit `strict=` parameter; add an explicit value for `strict=`.
- [warning] 45-45 (B905): `zip()` without an explicit `strict=` parameter; add an explicit value for `strict=`.