66from pathlib import Path
77
88import numpy as np
9- import pandas as pd
109import torch
1110from addict import Dict
1211from loguru import logger
@@ -58,7 +57,6 @@ def spatial_emb(
5857 spatial_edge : Tensor ,
5958 config : Dict ,
6059 mnn_dataset : Dataset = None ,
61- meta : pd .DataFrame = None ,
6260 pretrained_model : ScSimCLR = None ,
6361 batch : np .ndarray = None ,
6462 DDP : bool = False ,
@@ -76,8 +74,6 @@ def spatial_emb(
7674 model config
7775 mnn_dataset
7876 mnn dataset
79- meta
80- meta of cells
8177 pretrained_model
8278 pre-trained single cell model
8379 batch
@@ -99,9 +95,9 @@ def spatial_emb(
9995 datamodule = get_graph_datamodule (graph , config , mnn_dataset )
10096
10197 if mnn_dataset is None :
102- model = OmicsSpatialSimCLR (config .model , meta )
98+ model = OmicsSpatialSimCLR (config .model )
10399 else :
104- model = OmicsSpatialSimCLRMNN (config .model , meta )
100+ model = OmicsSpatialSimCLRMNN (config .model )
105101
106102 if pretrained_model is not None :
107103 model .center_encoder = deepcopy (pretrained_model .center_encoder )
@@ -121,7 +117,6 @@ def sc_emb(
121117 x : np .ndarray ,
122118 config : Dict ,
123119 mnn_dataset : Dataset = None ,
124- meta : pd .DataFrame = None ,
125120 batch : np .ndarray = None ,
126121) -> tuple [ScSimCLR , np .ndarray | None ]:
127122 r"""
@@ -135,8 +130,6 @@ def sc_emb(
135130 model config
136131 mnn_dataset:
137132 mnn dataset
138- meta:
139- meta of cells
140133 batch:
141134 batch index
142135
@@ -152,7 +145,7 @@ def sc_emb(
152145 mnn_flag = True if mnn_dataset is not None else False
153146 if not config .pretrain .force :
154147 try :
155- return load_sc_model (config , mnn_flag , meta ), None
148+ return load_sc_model (config , mnn_flag ), None
156149 except Exception as e : # noqa
157150 logger .info (f"Not found pre-trained model: { e } " )
158151
@@ -166,9 +159,9 @@ def sc_emb(
166159 datamodule = LightningScMNNData (config .loader , train_dataset , val_dataset , mnn_dataset )
167160
168161 if mnn_flag :
169- model = ScSimCLRMNN (meta , config .model )
162+ model = ScSimCLRMNN (config .model )
170163 else :
171- model = ScSimCLR (meta , config .model )
164+ model = ScSimCLR (config .model )
172165
173166 if config .model .fix_sc :
174167 fit_and_inference (model , datamodule , config .model , show_name = "single cell" )
@@ -179,7 +172,7 @@ def sc_emb(
179172 return model , center_emb
180173
181174
182- def load_sc_model (config , mnn_flag : bool , meta : pd . DataFrame = None ):
175+ def load_sc_model (config , mnn_flag : bool ):
183176 r"""
184177 Load omics encoder model
185178
@@ -189,14 +182,12 @@ def load_sc_model(config, mnn_flag: bool, meta: pd.DataFrame = None):
189182 model config
190183 mnn_flag
191184 whether use mnn
192- meta
193- meta of cells
194185 """
195186 model_path = Path (config .model .work_dir ) / "pretrain"
196187 # sort by modification time
197188 model_path = sorted (model_path .glob ("*.ckpt" ), key = os .path .getmtime )[- 1 ]
198189 logger .info (f"Loading model from { model_path } " )
199- kwargs = {"meta" : meta , " config" : config .model }
190+ kwargs = {"config" : config .model }
200191 if mnn_flag :
201192 sc_model = ScSimCLRMNN .load_from_checkpoint (model_path , ** kwargs )
202193 else :
@@ -205,7 +196,7 @@ def load_sc_model(config, mnn_flag: bool, meta: pd.DataFrame = None):
205196 return sc_model
206197
207198
208- def load_spatial_model (config , mnn_flag : bool , meta : pd . DataFrame = None ):
199+ def load_spatial_model (config , mnn_flag : bool ):
209200 r"""
210201 Load decipher spatial model
211202
@@ -215,14 +206,12 @@ def load_spatial_model(config, mnn_flag: bool, meta: pd.DataFrame = None):
215206 model config
216207 mnn_flag
217208 whether use mnn
218- meta
219- meta of cells
220209 """
221210 model_path = Path (config .model .work_dir ) / "model"
222211 model_path = sorted (model_path .glob ("*.ckpt" ), key = os .path .getmtime )[- 1 ]
223212 logger .info (f"Loading model from { model_path } " )
224213 config .model .device_num = 1
225- kwargs = {"config" : config .model , "meta" : meta }
214+ kwargs = {"config" : config .model }
226215 if mnn_flag :
227216 model = OmicsSpatialSimCLRMNN .load_from_checkpoint (model_path , ** kwargs )
228217 else :
0 commit comments