Skip to content

Commit 905df45

Browse files
committed
abstract custom models, support v3/v4
1 parent fc59d7f commit 905df45

File tree

1 file changed

+47
-39
lines changed

1 file changed

+47
-39
lines changed

active_plugins/runcellpose.py

Lines changed: 47 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -88,21 +88,27 @@
8888
CELLPOSE_DOCKERS = {'v2': ["cellprofiler/runcellpose_no_pretrained:2.3.2",
8989
"cellprofiler/runcellpose_with_pretrained:2.3.2",
9090
"cellprofiler/runcellpose_with_pretrained:2.2"],
91-
'v3': ["docker3"], #TODO
91+
'v3': ["biocontainers/cellpose:3.1.0_cv1"], #TODO
9292
'v4': ["docker4"]} #TODO
9393

9494
"Detection mode"
9595
MODEL_NAMES = {'v2':['cyto','nuclei','tissuenet','livecell', 'cyto2', 'general',
9696
'CP', 'CPx', 'TN1', 'TN2', 'TN3', 'LC1', 'LC2', 'LC3', 'LC4', 'custom'],
9797
'v3':[ "cyto3", "nuclei", "cyto2_cp3", "tissuenet_cp3", "livecell_cp3", "yeast_PhC_cp3",
9898
"yeast_BF_cp3", "bact_phase_cp3", "bact_fluor_cp3", "deepbacs_cp3", "cyto2", "cyto"],
99-
'v4':['model4']} #TODO
99+
'v4':['cpsam']}
100100

101101
DENOISER_NAMES = ['denoise_cyto3', 'deblur_cyto3', 'upsample_cyto3',
102102
'denoise_nuclei', 'deblur_nuclei', 'upsample_nuclei']
103103
# Only these models support size scaling
104104
SIZED_MODELS = {"cyto3", "cyto2", "cyto", "nuclei"}
105105

106+
def get_custom_model_vars(self):
107+
model_file = self.model_file_name.value
108+
model_directory = self.model_directory.get_absolute_path()
109+
model_path = os.path.join(model_directory, model_file)
110+
return model_file, model_directory, model_path
111+
106112
class RunCellpose(ImageSegmentation):
107113
category = "Object Processing"
108114

@@ -197,7 +203,7 @@ def create_settings(self):
197203
In Cellpose 1-3, Cellpose models come with a pre-defined object diameter. Your image will be resized during detection to attempt to
198204
match the diameter expected by the model. The default models have an expected diameter of ~16 pixels, if trying to
199205
detect much smaller objects it may be more efficient to resize the image first using the Resize module.
200-
If set to 0 in Cellpose 1-3, it will attempt to automatically detect object size. Note that automatic diameter mode does not work when running on 3D images.
206+
If set to 0 in Cellpose 1-3, it will attempt to automatically detect object size.
201207
Note that automatic diameter mode does not work when running on 3D images.
202208
""",
203209
)
@@ -497,11 +503,11 @@ def visible_settings(self):
497503
if self.docker_or_python.value == "Python":
498504
vis_settings += [self.omni]
499505

500-
if self.mode_v2.value != "nuclei":
506+
if self.mode.value != "nuclei":
501507
vis_settings += [self.supply_nuclei]
502508
if self.supply_nuclei.value:
503509
vis_settings += [self.nuclei_image]
504-
if self.mode_v2.value == "custom":
510+
if self.mode.value == "custom":
505511
vis_settings += [
506512
self.model_directory,
507513
self.model_file_name,
@@ -516,9 +522,10 @@ def visible_settings(self):
516522
self.min_size,
517523
self.flow_threshold,
518524
self.y_name,
519-
self.invert,
520525
self.save_probabilities,
521526
]
527+
if self.cellpose_version.value in ['v2','v3']:
528+
vis_settings += [self.invert]
522529

523530
vis_settings += [self.do_3D, self.stitch_threshold, self.remove_edge_masks]
524531

@@ -546,7 +553,7 @@ def visible_settings(self):
546553
def validate_module(self, pipeline):
547554
"""If using custom model, validate the model file opens and works"""
548555
from cellpose import models
549-
if self.mode_v2.value == "custom":
556+
if self.mode.value == "custom":
550557
model_file = self.model_file_name.value
551558
model_directory = self.model_directory.get_absolute_path()
552559
model_path = os.path.join(model_directory, model_file)
@@ -602,7 +609,7 @@ def run(self, workspace):
602609
"Color images are not currently supported. Please provide greyscale images."
603610
)
604611

605-
if self.mode_v2.value != "nuclei" and self.supply_nuclei.value:
612+
if self.mode.value != "nuclei" and self.supply_nuclei.value:
606613
nuc_image = images.get_image(self.nuclei_image.value)
607614
# CellPose 1-3 expects RGB, we'll have a blank red channel, cells in green and nuclei in blue.
608615
if self.do_3D.value:
@@ -630,23 +637,19 @@ def run(self, workspace):
630637
if self.cellpose_version.value == 'v2':
631638
assert int(self.cellpose_ver[0])<=2, "Cellpose version selected in RunCellpose module doesn't match version in Python"
632639
if float(self.cellpose_ver[0:3]) >= 0.6 and int(self.cellpose_ver[0])<2:
633-
if self.mode_v2.value != 'custom':
634-
model = models.Cellpose(model_type= self.mode_v2.value,
640+
if self.mode.value != 'custom':
641+
model = models.Cellpose(model_type= self.mode.value,
635642
gpu=self.use_gpu.value)
636643
else:
637-
model_file = self.model_file_name.value
638-
model_directory = self.model_directory.get_absolute_path()
639-
model_path = os.path.join(model_directory, model_file)
644+
model_file, model_directory, model_path = get_custom_model_vars(self)
640645
model = models.CellposeModel(pretrained_model=model_path, gpu=self.use_gpu.value)
641646

642647
else:
643-
if self.mode_v2.value != 'custom':
644-
model = models.CellposeModel(model_type= self.mode_v2.value,
648+
if self.mode.value != 'custom':
649+
model = models.CellposeModel(model_type= self.mode.value,
645650
gpu=self.use_gpu.value)
646651
else:
647-
model_file = self.model_file_name.value
648-
model_directory = self.model_directory.get_absolute_path()
649-
model_path = os.path.join(model_directory, model_file)
652+
model_file, model_directory, model_path = get_custom_model_vars(self)
650653
model = models.CellposeModel(pretrained_model=model_path, gpu=self.use_gpu.value)
651654

652655
try:
@@ -691,26 +694,24 @@ def run(self, workspace):
691694

692695
elif self.cellpose_version.value == 'v3':
693696
assert int(self.cellpose_ver[0])==3, "Cellpose version selected in RunCellpose module doesn't match version in Python"
694-
if self.mode_v3.value == 'custom':
695-
model_file = self.model_file_name.value
696-
model_directory = self.model_directory.get_absolute_path()
697-
model_path = os.path.join(model_directory, model_file)
698-
model_params = (self.mode_v3.value, self.use_gpu.value)
699-
LOGGER.info(f"Loading new model: {self.mode_v3.value}")
700-
if self.mode_v3.value in SIZED_MODELS:
697+
if self.mode.value == 'custom':
698+
model_file, model_directory, model_path = get_custom_model_vars(self)
699+
model_params = (self.mode.value, self.use_gpu.value)
700+
LOGGER.info(f"Loading new model: {self.mode.value}")
701+
if self.mode.value in SIZED_MODELS:
701702
self.current_model = models.Cellpose(
702-
model_type=self.mode_v3.value, gpu=self.use_gpu.value)
703+
model_type=self.mode.value, gpu=self.use_gpu.value)
703704
else:
704705
self.current_model = models.CellposeModel(
705-
model_type=self.mode_v3.value, gpu=self.use_gpu.value)
706+
model_type=self.mode.value, gpu=self.use_gpu.value)
706707
self.current_model_params = model_params
707708

708709
if self.denoise.value:
709710
from cellpose import denoise
710711
recon_params = (
711712
self.denoise_type.value,
712713
self.use_gpu.value,
713-
self.mode_v3.value != "nuclei" and self.supply_nuclei.value
714+
self.mode.value != "nuclei" and self.supply_nuclei.value
714715
)
715716
self.recon_model = denoise.DenoiseModel(
716717
model_type=recon_params[0],
@@ -729,7 +730,7 @@ def run(self, workspace):
729730
elif self.denoise_type.value == "upsample_nuclei":
730731
diam = 17
731732
# Result only includes input channels
732-
if self.mode_v2.value != "nuclei" and self.supply_nuclei.value:
733+
if self.mode.value != "nuclei" and self.supply_nuclei.value:
733734
channels = [0, 1]
734735
else:
735736
input_data = x_data
@@ -782,12 +783,11 @@ def run(self, workspace):
782783
os.makedirs(temp_img_dir, exist_ok=True)
783784

784785
temp_img_path = os.path.join(temp_img_dir, unique_name+".tiff")
785-
if self.mode_v2.value == "custom":
786-
model_file = self.model_file_name.value
787-
model_directory = self.model_directory.get_absolute_path()
788-
model_path = os.path.join(model_directory, model_file)
786+
if self.mode.value == "custom":
787+
model_file, model_directory, model_path = get_custom_model_vars(self)
788+
model = models.CellposeModel(pretrained_model=model_path, gpu=self.use_gpu.value)
789+
789790
temp_model_dir = os.path.join(temp_dir, "model")
790-
791791
os.makedirs(temp_model_dir, exist_ok=True)
792792
# Copy the model
793793
shutil.copy(model_path, os.path.join(temp_model_dir, model_file))
@@ -799,19 +799,27 @@ def run(self, workspace):
799799
if self.use_gpu.value:
800800
cmd += ['--gpus', 'all']
801801
cmd += ['cellpose', '--verbose', '--dir', '/data/img', '--pretrained_model']
802-
if self.mode_v2.value !='custom':
803-
cmd += [self.mode_v2.value]
802+
if self.mode.value !='custom':
803+
cmd += [self.mode.value]
804804
else:
805805
cmd += ['/data/model/' + model_file]
806-
cmd += ['--chan', str(channels[0]), '--chan2', str(channels[1]), '--diameter', str(diam)]
806+
if self.cellpose_version.value == 'v3':
807+
if self.denoise.value:
808+
cmd += ['--denoise', self.denoise_type.value]
809+
if self.cellpose_version.value in ['v2','v3']:
810+
cmd += ['--chan', str(channels[0]), '--chan2', str(channels[1]), '--diameter', str(diam)]
811+
if self.cellpose_version.value in ['v4']:
812+
if self.specify_diameter.value:
813+
cmd += ['--diameter', str(diam)]
807814
if self.use_averaging.value:
808815
cmd += ['--net_avg']
809816
if self.do_3D.value:
810817
cmd += ['--do_3D']
811818
cmd += ['--anisotropy', str(anisotropy), '--flow_threshold', str(self.flow_threshold.value), '--cellprob_threshold',
812819
str(self.cellprob_threshold.value), '--stitch_threshold', str(self.stitch_threshold.value), '--min_size', str(self.min_size.value)]
813-
if self.invert.value:
814-
cmd += ['--invert']
820+
if self.cellpose_version.value in ['v2','v3']:
821+
if self.invert.value:
822+
cmd += ['--invert']
815823
if self.remove_edge_masks.value:
816824
cmd += ['--exclude_on_edges']
817825
print(cmd)

0 commit comments

Comments
 (0)