Skip to content

Commit 7152867

Browse files
committed
Update to v0.2.0
1 parent 0925377 commit 7152867

31 files changed

+272
-1237
lines changed

DEVELOPER.md

Lines changed: 0 additions & 39 deletions
This file was deleted.

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ install_pyg_dependencies
3737
(Optional) You can install [RAPIDS](https://docs.rapids.ai/install) to accelerate visualization.
3838

3939
```sh
40-
mamba create -n decipher -c conda-forge -c rapidsai -c nvidia python=3.11 rapids=24.12 'cuda-version>=12.0,<=12.2' -y && conda activate decipher
40+
mamba create -n decipher -c conda-forge -c rapidsai -c nvidia python=3.11 rapids=25.06 'cuda-version>=12.0,<=12.8' -y && conda activate decipher
4141
pip install cell-decipher
4242
install_pyg_dependencies
4343
```

decipher/cls.py

Lines changed: 5 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,8 @@
88

99
import numpy as np
1010
import pandas as pd
11-
import scanpy as sc
1211
import scipy.sparse as sps
1312
import torch
14-
import torch.nn.functional as F
1513
import yaml
1614
from addict import Dict
1715
from anndata import AnnData
@@ -26,8 +24,7 @@
2624
from .explain.gene.mixin import GeneSelectMixin
2725
from .explain.regress.mixin import RegressMixin
2826
from .graphic.build import build_graph
29-
from .plot import plot_sc
30-
from .utils import CFG, estimate_spot_size, global_seed, scanpy_viz, sync_config
27+
from .utils import CFG, global_seed, l2norm, sync_config
3128

3229

3330
class DECIPHER(RegressMixin, GeneSelectMixin, MNNMixin, DDPMixin):
@@ -157,60 +154,24 @@ def fit_omics(self) -> None:
157154
mnn_dataset = MNNDataset(self.x, self.valid_cellidx, self.mnn_dict)
158155
logger.info(f"Using MNN with {len(np.unique(self.batch))} batches.")
159156
# train model
160-
sc_model, center_emb_pretrain = sc_emb(
161-
self.x, self.cfg.omics, mnn_dataset, self.meta, self.batch
162-
)
157+
sc_model, center_emb_pretrain = sc_emb(self.x, self.cfg.omics, mnn_dataset, self.batch)
163158
center_emb, self.nbr_emb = spatial_emb(
164159
self.x,
165160
self.edge_index,
166161
self.cfg.omics,
167162
mnn_dataset,
168-
self.meta,
169163
sc_model,
170164
self.batch,
171165
)
172166
self.center_emb = center_emb_pretrain if center_emb_pretrain else center_emb
173-
# as float
174-
self.center_emb = self.center_emb.astype(np.float32)
175-
self.nbr_emb = self.nbr_emb.astype(np.float32)
167+
# norm
168+
self.center_emb = l2norm(self.center_emb.astype(np.float32))
169+
self.nbr_emb = l2norm(self.nbr_emb.astype(np.float32))
176170
# save embeddings
177171
np.save(self.work_dir / "center_emb.npy", self.center_emb)
178172
np.save(self.work_dir / "nbr_emb.npy", self.nbr_emb)
179173
logger.info(f"Results saved to {self.work_dir}")
180174

181-
def visualize(self, resolution: float = 0.5) -> None:
182-
r"""
183-
Visualize results, should run after `fit_omics`
184-
185-
Parameters
186-
----------
187-
resolution
188-
resolution for clustering
189-
"""
190-
if (self.work_dir / "embedding.h5ad").exists():
191-
adata = sc.read_h5ad(self.work_dir / "embedding.h5ad")
192-
else:
193-
norm_center = F.normalize(torch.tensor(self.center_emb)).numpy()
194-
norm_nbr = F.normalize(torch.tensor(self.nbr_emb)).numpy()
195-
adata = sc.AnnData(
196-
X=np.zeros((self.center_emb.shape[0], 1)),
197-
obsm={
198-
"X_center": self.center_emb,
199-
"X_nbr": self.nbr_emb,
200-
"X_merge": np.hstack([norm_center, norm_nbr]),
201-
"spatial": self.coords,
202-
},
203-
obs=self.meta.astype(str),
204-
)
205-
adata.uns["spot_size"] = estimate_spot_size(adata.obsm["spatial"])
206-
adata = scanpy_viz(adata, resolution=resolution)
207-
adata.write_h5ad(self.work_dir / "embedding.h5ad")
208-
color_vars = ["leiden_center", "leiden_nbr"]
209-
for var in ["_celltype", "_batch"]:
210-
if var in adata.obs.columns:
211-
color_vars.append(var)
212-
plot_sc(adata, color_vars)
213-
214175
def load(self, from_dir: str | Path = None) -> None:
215176
r"""
216177
Load saved results, should run after `register_data`

decipher/data/mnn_dataset.py

Lines changed: 11 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -19,13 +19,6 @@
1919
from torch_geometric.data import Data
2020
from torch_geometric.data.lightning import LightningNodeData
2121

22-
try:
23-
import cupy as cp
24-
25-
CUPY_AVAILABLE = True
26-
except ImportError:
27-
CUPY_AVAILABLE = False
28-
2922
from ..graphic.knn import knn
3023
from ..utils import l2norm
3124

@@ -175,7 +168,7 @@ def train_dataloader(self):
175168
combined_loader = CombinedLoader(loaders, mode="max_size_cycle")
176169
return combined_loader
177170

178-
def val_dataloader(self):
171+
def test_dataloader(self):
179172
val_cfg = deepcopy(self.loader_config)
180173
val_cfg.update({"batch_size": 1024, "shuffle": False, "drop_last": False})
181174
return DataLoader(self.val_dataset, **val_cfg)
@@ -275,7 +268,7 @@ def svd(x: np.ndarray, y: np.ndarray, k_components: int = 20) -> tuple[np.ndarra
275268
"""
276269
logger.debug(f"x shape: {x.shape}, y shape: {y.shape}")
277270
if x.shape[0] > 1_000_000 or y.shape[0] > 1_000_000:
278-
logger.debug("Use harmony for large dataset.")
271+
logger.debug("Use harmony-based SVD for large dataset.")
279272
from harmony import harmonize
280273

281274
# batch
@@ -287,22 +280,16 @@ def svd(x: np.ndarray, y: np.ndarray, k_components: int = 20) -> tuple[np.ndarra
287280
# harmonize
288281
z_norm = harmonize(z, batch, "batch", use_gpu=True)
289282
return z_norm
290-
elif x.shape[0] > 200_000 or y.shape[0] > 200_000:
291-
logger.debug("Use CPU for middle dataset")
292-
dot = torch.from_numpy(x) @ torch.from_numpy(y).T # faster than np
283+
284+
try:
285+
dot = torch.from_numpy(x).cuda().half() @ torch.from_numpy(y).T.cuda().half()
286+
dot = dot.cpu().float().numpy()
287+
logger.info("Use CUDA for small dataset")
288+
except: # noqa
289+
logger.error(f"CUDA failed: {x.shape}, {y.shape}, use CPU instead.")
290+
dot = torch.from_numpy(x) @ torch.from_numpy(y).T
293291
dot = dot.numpy()
294-
else:
295-
try:
296-
dot = torch.from_numpy(x).cuda().half() @ torch.from_numpy(y).T.cuda().half()
297-
if CUPY_AVAILABLE:
298-
dot = cp.asarray(dot.to(torch.float32)).get()
299-
else:
300-
dot = dot.cpu().float().numpy()
301-
logger.info("Use CUDA for small dataset")
302-
except: # noqa
303-
logger.error("CUDA failed")
304-
dot = torch.from_numpy(x) @ torch.from_numpy(y).T
305-
dot = dot.numpy()
292+
torch.cuda.empty_cache()
306293
u, s, vh = randomized_svd(dot, n_components=k_components, random_state=0)
307294
z = np.vstack([u, vh.T]) # gene x k_components
308295
z = z @ np.sqrt(np.diag(s)) # will reduce the MNN pairs number greatly

decipher/ddp.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@ def fit_ddp(self, gpus: int = -1, ddp_pretrain: bool = False) -> None:
4141
logger.warning("Using DDP with < 500k cells is not recommended.")
4242

4343
max_gpus = torch.cuda.device_count()
44+
assert max_gpus > 1, "DDP requires at least 2 GPUs."
4445
gpus = min(gpus, max_gpus) if gpus > 0 else max_gpus
4546

4647
if ddp_pretrain:

decipher/emb.py

Lines changed: 9 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
from pathlib import Path
77

88
import numpy as np
9-
import pandas as pd
109
import torch
1110
from addict import Dict
1211
from loguru import logger
@@ -58,7 +57,6 @@ def spatial_emb(
5857
spatial_edge: Tensor,
5958
config: Dict,
6059
mnn_dataset: Dataset = None,
61-
meta: pd.DataFrame = None,
6260
pretrained_model: ScSimCLR = None,
6361
batch: np.ndarray = None,
6462
DDP: bool = False,
@@ -76,8 +74,6 @@ def spatial_emb(
7674
model config
7775
mnn_dataset
7876
mnn dataset
79-
meta
80-
meta of cells
8177
pretrained_model
8278
pre-trained single cell model
8379
batch
@@ -99,9 +95,9 @@ def spatial_emb(
9995
datamodule = get_graph_datamodule(graph, config, mnn_dataset)
10096

10197
if mnn_dataset is None:
102-
model = OmicsSpatialSimCLR(config.model, meta)
98+
model = OmicsSpatialSimCLR(config.model)
10399
else:
104-
model = OmicsSpatialSimCLRMNN(config.model, meta)
100+
model = OmicsSpatialSimCLRMNN(config.model)
105101

106102
if pretrained_model is not None:
107103
model.center_encoder = deepcopy(pretrained_model.center_encoder)
@@ -121,7 +117,6 @@ def sc_emb(
121117
x: np.ndarray,
122118
config: Dict,
123119
mnn_dataset: Dataset = None,
124-
meta: pd.DataFrame = None,
125120
batch: np.ndarray = None,
126121
) -> tuple[ScSimCLR, np.ndarray | None]:
127122
r"""
@@ -135,8 +130,6 @@ def sc_emb(
135130
model config
136131
mnn_dataset:
137132
mnn dataset
138-
meta:
139-
meta of cells
140133
batch:
141134
batch index
142135
@@ -152,7 +145,7 @@ def sc_emb(
152145
mnn_flag = True if mnn_dataset is not None else False
153146
if not config.pretrain.force:
154147
try:
155-
return load_sc_model(config, mnn_flag, meta), None
148+
return load_sc_model(config, mnn_flag), None
156149
except Exception as e: # noqa
157150
logger.info(f"Not found pre-trained model: {e}")
158151

@@ -166,9 +159,9 @@ def sc_emb(
166159
datamodule = LightningScMNNData(config.loader, train_dataset, val_dataset, mnn_dataset)
167160

168161
if mnn_flag:
169-
model = ScSimCLRMNN(meta, config.model)
162+
model = ScSimCLRMNN(config.model)
170163
else:
171-
model = ScSimCLR(meta, config.model)
164+
model = ScSimCLR(config.model)
172165

173166
if config.model.fix_sc:
174167
fit_and_inference(model, datamodule, config.model, show_name="single cell")
@@ -179,7 +172,7 @@ def sc_emb(
179172
return model, center_emb
180173

181174

182-
def load_sc_model(config, mnn_flag: bool, meta: pd.DataFrame = None):
175+
def load_sc_model(config, mnn_flag: bool):
183176
r"""
184177
Load omics encoder model
185178
@@ -189,14 +182,12 @@ def load_sc_model(config, mnn_flag: bool, meta: pd.DataFrame = None):
189182
model config
190183
mnn_flag
191184
whether use mnn
192-
meta
193-
meta of cells
194185
"""
195186
model_path = Path(config.model.work_dir) / "pretrain"
196187
# sort by modification time
197188
model_path = sorted(model_path.glob("*.ckpt"), key=os.path.getmtime)[-1]
198189
logger.info(f"Loading model from {model_path}")
199-
kwargs = {"meta": meta, "config": config.model}
190+
kwargs = {"config": config.model}
200191
if mnn_flag:
201192
sc_model = ScSimCLRMNN.load_from_checkpoint(model_path, **kwargs)
202193
else:
@@ -205,7 +196,7 @@ def load_sc_model(config, mnn_flag: bool, meta: pd.DataFrame = None):
205196
return sc_model
206197

207198

208-
def load_spatial_model(config, mnn_flag: bool, meta: pd.DataFrame = None):
199+
def load_spatial_model(config, mnn_flag: bool):
209200
r"""
210201
Load decipher spatial model
211202
@@ -215,14 +206,12 @@ def load_spatial_model(config, mnn_flag: bool, meta: pd.DataFrame = None):
215206
model config
216207
mnn_flag
217208
whether use mnn
218-
meta
219-
meta of cells
220209
"""
221210
model_path = Path(config.model.work_dir) / "model"
222211
model_path = sorted(model_path.glob("*.ckpt"), key=os.path.getmtime)[-1]
223212
logger.info(f"Loading model from {model_path}")
224213
config.model.device_num = 1
225-
kwargs = {"config": config.model, "meta": meta}
214+
kwargs = {"config": config.model}
226215
if mnn_flag:
227216
model = OmicsSpatialSimCLRMNN.load_from_checkpoint(model_path, **kwargs)
228217
else:

decipher/graphic/knn.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -6,13 +6,7 @@
66
from annoy import AnnoyIndex
77
from loguru import logger
88

9-
try:
10-
from cuml.neighbors import NearestNeighbors as cuNearestNeighbors
11-
12-
CUML_FLAG = True
13-
except ImportError:
14-
CUML_FLAG = False
15-
logger.warning("cuML is not available.")
9+
from ..utils import RSC_FLAG
1610

1711

1812
def knn(
@@ -52,7 +46,7 @@ def knn(
5246
if method == "auto":
5347
method = ["cuml", "faiss", "annoy"]
5448
method = method if isinstance(method, list) else [method]
55-
if not CUML_FLAG and "cuml" in method:
49+
if not RSC_FLAG and "cuml" in method:
5650
method.remove("cuml")
5751
if not approx and "annoy" in method:
5852
method.remove("annoy")
@@ -123,6 +117,8 @@ def knn_cuml(
123117
r"""
124118
Build k-NN graph by cuML
125119
"""
120+
from cuml.neighbors import NearestNeighbors as cuNearestNeighbors
121+
126122
model = cuNearestNeighbors(n_neighbors=k, metric=metric)
127123
model.fit(ref)
128124
distances, indices = model.kneighbors(query)

0 commit comments

Comments (0)