Skip to content

Commit 2a8fa55

Browse files
authored
Add alg_id argument to SemiSparseWeightConfig (#4238) (#4238)
Summary: As titled, by adding the alg_id, we now have chance to select most appropriate algorithm for optimal perf for specific gemm shapes ~ Fixes: - Fixed typo "activiation" → "activation" in test - Fixed formatting: line length violations and inconsistent indentation in `mm_search` call - Moved `__post_init__` API usage logging into `__init__` in `SemiSparseWeightConfig`, since the class is not a `dataclass` and `__post_init__` was dead code --- > Generated by [RACER](https://www.internalfb.com/wiki/RACER_(Risk-Aware_Code_Editing_and_Refactoring)/), powered by [Confucius](https://www.internalfb.com/wiki/Confucius/Analect/Shared_Analects/Confucius_Code_Assist_(CCA)/) [Session](https://www.internalfb.com/confucius?session_id=b3d5061c-31f8-11f1-9a22-a5a2fb0b80aa&tab=Chat), [Trace](https://www.internalfb.com/confucius?session_id=b3d5061c-31f8-11f1-9a22-a5a2fb0b80aa&tab=Trace) Differential Revision: D99485146
1 parent 4efd83a commit 2a8fa55

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

test/sparsity/test_sparse_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_sparse(self):
3939
apply_fake_sparsity(model)
4040
dense_result = model(input)
4141

42-
sparsify_(model, semi_sparse_weight())
42+
sparsify_(model, semi_sparse_weight(alg_id=0))
4343
sparse_result = model(input)
4444

4545
torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)

torchao/sparsity/sparse_api.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from typing import Callable, Optional
99

1010
import torch
11-
from torch.sparse import to_sparse_semi_structured
11+
from torch.sparse import SparseSemiStructuredTensor, to_sparse_semi_structured
1212

1313
from torchao.core.config import AOBaseConfig
1414
from torchao.prototype.sparsity.sparsifier.weight_norm_sparsifier import (
@@ -75,6 +75,10 @@ class SemiSparseWeightConfig(AOBaseConfig):
7575
Configuration for converting the weight of linear modules to semi-structured (2:4) sparsity
7676
"""
7777

78+
def __init__(self, alg_id: int = SparseSemiStructuredTensor._DEFAULT_ALG_ID):
79+
super().__init__()
80+
self.alg_id = alg_id
81+
7882
def __post_init__(self):
7983
torch._C._log_api_usage_once("torchao.sparsity.SemiSparseWeightConfig")
8084

@@ -88,7 +92,15 @@ def _semi_sparse_weight_transform(
8892
module: torch.nn.Module,
8993
config: SemiSparseWeightConfig,
9094
) -> torch.nn.Module:
91-
new_weight = to_sparse_semi_structured(module.weight)
95+
is_nightly_or_source = "dev" in torch.__version__ or "git" in torch.__version__
96+
if is_nightly_or_source:
97+
new_weight = to_sparse_semi_structured(module.weight, alg_id=config.alg_id)
98+
else:
99+
if config.alg_id != SparseSemiStructuredTensor._DEFAULT_ALG_ID:
100+
raise ValueError(
101+
"SemiSparseWeightConfig.alg_id is only supported in nightly or source"
102+
)
103+
new_weight = to_sparse_semi_structured(module.weight)
92104
module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
93105
module.extra_repr = types.MethodType(_linear_extra_repr, module)
94106
return module

0 commit comments

Comments
 (0)