From 76851f66be44e6dbe6b45eb66da9a02e20c5c489 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 21 May 2026 21:31:00 +0000 Subject: [PATCH] feat: add atlas annotation rule for vessel nodes parquet Agent-Logs-Url: https://github.com/khanlab/SPIMquant/sessions/b428b7e3-b550-4d82-a264-b62ff69bf10c Co-authored-by: akhanf <11492701+akhanf@users.noreply.github.com> --- spimquant/workflow/rules/counts.smk | 18 +- spimquant/workflow/rules/fieldfrac.smk | 16 +- spimquant/workflow/rules/groupstats.smk | 168 +++++++++--------- spimquant/workflow/rules/masking.smk | 86 ++++----- spimquant/workflow/rules/patches.smk | 88 ++++----- spimquant/workflow/rules/segstats.smk | 18 +- spimquant/workflow/rules/sidecars.smk | 12 +- spimquant/workflow/rules/vessels.smk | 91 ++++++++-- .../scripts/map_atlas_to_vessel_nodes.py | 22 +++ 9 files changed, 298 insertions(+), 221 deletions(-) create mode 100644 spimquant/workflow/scripts/map_atlas_to_vessel_nodes.py diff --git a/spimquant/workflow/rules/counts.smk b/spimquant/workflow/rules/counts.smk index c13ce3b..4e6a295 100644 --- a/spimquant/workflow/rules/counts.smk +++ b/spimquant/workflow/rules/counts.smk @@ -10,9 +10,6 @@ rule counts_per_voxel: suffix="regionprops.parquet", **inputs["spim"].wildcards, ), - params: - coord_column_names=config["coord_column_names"], - zarrnii_kwargs=zarrnii_in_kwargs, output: counts_nii=bids( root=root, @@ -27,13 +24,16 @@ rule counts_per_voxel: resources: mem_mb=15000, runtime=20, + params: + coord_column_names=config["coord_column_names"], + zarrnii_kwargs=zarrnii_in_kwargs, script: "../scripts/counts_per_voxel.py" rule counts_per_voxel_template: """Calculate counts per voxel based on points - in template space""" +in template space""" input: template=bids(root=root, template="{template}", suffix="anat.nii.gz"), regionprops_parquet=bids( @@ -44,8 +44,6 @@ rule counts_per_voxel_template: suffix="regionprops.parquet", **inputs["spim"].wildcards, ), - params: - coord_column_names=config["template_coord_column_names"], output: counts_nii=bids( root=root, @@ -60,13 +58,15 @@ rule counts_per_voxel_template: resources: mem_mb=64000, runtime=30, + params: + coord_column_names=config["template_coord_column_names"], script: "../scripts/counts_per_voxel_template.py" rule coloc_per_voxel_template: """Calculate coloc counts per voxel based on points - in template space""" +in template space""" input: template=bids(root=root, template="{template}", suffix="anat.nii.gz"), coloc_parquet=bids( @@ -77,8 +77,6 @@ rule coloc_per_voxel_template: suffix="coloc.parquet", **inputs["spim"].wildcards, ), - params: - coord_column_names=config["template_coloc_coord_column_names"], output: counts_nii=bids( root=root, @@ -92,5 +90,7 @@ rule coloc_per_voxel_template: resources: mem_mb=64000, runtime=30, + params: + coord_column_names=config["template_coloc_coord_column_names"], script: "../scripts/coloc_per_voxel_template.py" diff --git a/spimquant/workflow/rules/fieldfrac.smk b/spimquant/workflow/rules/fieldfrac.smk index 7a55900..faf96ba 100644 --- a/spimquant/workflow/rules/fieldfrac.smk +++ b/spimquant/workflow/rules/fieldfrac.smk @@ -1,12 +1,12 @@ rule fieldfrac: """Calculate field fraction from binary mask. - - Computes the fraction of brain tissue occupied by the segmented pathology at each - voxel by downsampling the high-resolution mask. The output resolution (level) can - differ from the input mask resolution, with the downsampling factor calculated - automatically. Field fraction values range from 0-100. - """ + +Computes the fraction of brain tissue occupied by the segmented pathology at each +voxel by downsampling the high-resolution mask. The output resolution (level) can +differ from the input mask resolution, with the downsampling factor calculated +automatically. Field fraction values range from 0-100. +""" input: mask=bids( root=root, @@ -17,8 +17,6 @@ rule fieldfrac: suffix="mask.ozx", **inputs["spim"].wildcards, ), - params: - hires_level=config["segmentation_level"], output: fieldfrac_nii=bids( root=root, @@ -33,6 +31,8 @@ rule fieldfrac: resources: mem_mb=16000, runtime=30, + params: + hires_level=config["segmentation_level"], script: "../scripts/fieldfrac.py" diff --git a/spimquant/workflow/rules/groupstats.smk b/spimquant/workflow/rules/groupstats.smk index 3e41c6f..6ff8f5f 100644 --- a/spimquant/workflow/rules/groupstats.smk +++ b/spimquant/workflow/rules/groupstats.smk @@ -9,10 +9,10 @@ participants.tsv to define contrasts. rule perform_group_stats: """Perform group-based statistical tests on segmentation statistics. - - This rule reads segstats.tsv files from all participants and performs - statistical tests based on contrasts defined in participants.tsv. - """ + +This rule reads segstats.tsv files from all participants and performs +statistical tests based on contrasts defined in participants.tsv. +""" input: segstats_tsvs=lambda wildcards: inputs["spim"].expand( bids( @@ -26,15 +26,6 @@ rule perform_group_stats: ) ), participants_tsv=os.path.join(config["bids_dir"], "participants.tsv"), - params: - contrast_column=config.get("contrast_column", None), - contrast_values=config.get("contrast_values", None), - metric_columns=expand( - "{stain}+{metric}", stain=stains_for_seg, metric=config["seg_metrics"] - ), - coloc_metric_columns=expand( - "coloc+{metric}", metric=config["coloc_seg_metrics"] - ), output: stats_tsv=bids( root=root, @@ -48,16 +39,25 @@ rule perform_group_stats: resources: mem_mb=1500, runtime=10, + params: + contrast_column=config.get("contrast_column", None), + contrast_values=config.get("contrast_values", None), + metric_columns=expand( + "{stain}+{metric}", stain=stains_for_seg, metric=config["seg_metrics"] + ), + coloc_metric_columns=expand( + "coloc+{metric}", metric=config["coloc_seg_metrics"] + ), script: "../scripts/perform_group_stats.py" rule create_stats_heatmap: """Create heatmap visualizations from group statistics results. - - This rule takes the group statistics TSV and creates heatmaps for - visualization of significant differences across brain regions. - """ + +This rule takes the group statistics TSV and creates heatmaps for +visualization of significant differences across brain regions. +""" input: stats_tsv=bids( root=root, @@ -68,13 +68,6 @@ rule create_stats_heatmap: suffix="groupstats.tsv", ), label_tsv=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"), - params: - metric_columns=expand( - "{stain}+{metric}", stain=stains_for_seg, metric=config["seg_metrics"] - ), - coloc_metric_columns=expand( - "coloc+{metric}", metric=config["coloc_seg_metrics"] - ), output: heatmap_png=bids( root=root, @@ -88,16 +81,23 @@ rule create_stats_heatmap: resources: mem_mb=8000, runtime=15, + params: + metric_columns=expand( + "{stain}+{metric}", stain=stains_for_seg, metric=config["seg_metrics"] + ), + coloc_metric_columns=expand( + "coloc+{metric}", metric=config["coloc_seg_metrics"] + ), script: "../scripts/create_stats_heatmap.py" rule map_groupstats_to_template_nii: """Map group statistics to template space as NIfTI files. - - This rule paints brain regions with statistical values (e.g., t-statistics, - p-values) to create volumetric heatmaps for 3D visualization. - """ + +This rule paints brain regions with statistical values (e.g., t-statistics, +p-values) to create volumetric heatmaps for 3D visualization. +""" input: tsv=bids( root=root, @@ -109,9 +109,6 @@ rule map_groupstats_to_template_nii: ), dseg=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.nii.gz"), label_tsv=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"), - params: - label_column="index", - feature_column="{metric}_{stat}", output: nii=bids( root=root, @@ -126,18 +123,21 @@ rule map_groupstats_to_template_nii: resources: mem_mb=16000, runtime=15, + params: + label_column="index", + feature_column="{metric}_{stat}", script: "../scripts/map_tsv_dseg_to_nii.py" rule concat_subj_parquet: """Concatenate parquet files across all subjects. - - This rule collects regionprops.parquet or coloc.parquet files - from all participants, adds a participant_id column to - identify each subject's data, and merges with participant - metadata from participants.tsv. - """ + +This rule collects regionprops.parquet or coloc.parquet files +from all participants, adds a participant_id column to +identify each subject's data, and merges with participant +metadata from participants.tsv. +""" input: parquet_files=inputs["spim"].expand( bids( @@ -169,7 +169,7 @@ rule concat_subj_parquet: rule group_counts_per_voxel: """Calculate counts per voxel based on concatenated points - in template space""" +in template space""" input: template=bids(root=root, template="{template}", suffix="anat.nii.gz"), regionprops_parquet=bids( @@ -179,8 +179,6 @@ rule group_counts_per_voxel: desc="{desc}", suffix="regionprops.parquet", ), - params: - coord_column_names=config["template_coord_column_names"], output: counts_nii=bids( root=root, @@ -194,13 +192,15 @@ rule group_counts_per_voxel: resources: mem_mb=200000, runtime=30, + params: + coord_column_names=config["template_coord_column_names"], script: "../scripts/counts_per_voxel_template.py" rule group_coloc_counts_per_voxel: """Calculate counts per voxel based on concatenated coloc points - in template space""" +in template space""" input: template=bids(root=root, template="{template}", suffix="anat.nii.gz"), coloc_parquet=bids( @@ -210,8 +210,6 @@ rule group_coloc_counts_per_voxel: desc="{desc}", suffix="coloc.parquet", ), - params: - coord_column_names=config["template_coloc_coord_column_names"], output: counts_nii=bids( root=root, @@ -225,19 +223,21 @@ rule group_coloc_counts_per_voxel: resources: mem_mb=15000, runtime=10, + params: + coord_column_names=config["template_coloc_coord_column_names"], script: "../scripts/coloc_per_voxel_template.py" rule concat_subj_parquet_contrast: """Concatenate parquet files across subjects filtered by contrast. - - This rule collects regionprops.parquet or coloc.parquet files - from all participants, adds a participant_id column to - identify each subject's data, merges with participant - metadata from participants.tsv, and filters to include only - rows where the contrast_column matches the contrast_value. - """ + +This rule collects regionprops.parquet or coloc.parquet files +from all participants, adds a participant_id column to +identify each subject's data, merges with participant +metadata from participants.tsv, and filters to include only +rows where the contrast_column matches the contrast_value. +""" input: parquet_files=inputs["spim"].expand( bids( @@ -251,9 +251,6 @@ rule concat_subj_parquet_contrast: allow_missing=True, ), participants_tsv=os.path.join(config["bids_dir"], "participants.tsv"), - params: - contrast_column="{contrast_column}", - contrast_value="{contrast_value}", output: parquet=bids( root=root, @@ -267,13 +264,16 @@ rule concat_subj_parquet_contrast: resources: mem_mb=1500, runtime=10, + params: + contrast_column="{contrast_column}", + contrast_value="{contrast_value}", script: "../scripts/concat_subj_parquet_contrast.py" rule group_counts_per_voxel_contrast: """Calculate counts per voxel based on concatenated points - in template space, filtered by contrast""" +in template space, filtered by contrast""" input: template=bids(root=root, template="{template}", suffix="anat.nii.gz"), regionprops_parquet=bids( @@ -284,8 +284,6 @@ rule group_counts_per_voxel_contrast: contrast="{contrast_column}+{contrast_value}", suffix="regionprops.parquet", ), - params: - coord_column_names=config["template_coord_column_names"], output: counts_nii=bids( root=root, @@ -300,13 +298,15 @@ rule group_counts_per_voxel_contrast: resources: mem_mb=200000, runtime=30, + params: + coord_column_names=config["template_coord_column_names"], script: "../scripts/counts_per_voxel_template.py" rule group_coloc_counts_per_voxel_contrast: """Calculate counts per voxel based on concatenated coloc points - in template space, filtered by contrast""" +in template space, filtered by contrast""" input: template=bids(root=root, template="{template}", suffix="anat.nii.gz"), coloc_parquet=bids( @@ -317,8 +317,6 @@ rule group_coloc_counts_per_voxel_contrast: contrast="{contrast_column}+{contrast_value}", suffix="coloc.parquet", ), - params: - coord_column_names=config["template_coloc_coord_column_names"], output: counts_nii=bids( root=root, @@ -333,20 +331,22 @@ rule group_coloc_counts_per_voxel_contrast: resources: mem_mb=15000, runtime=10, + params: + coord_column_names=config["template_coloc_coord_column_names"], script: "../scripts/coloc_per_voxel_template.py" rule concat_subj_segstats_contrast: """Concatenate segstats.tsv files across subjects filtered by contrast - and compute group averages. - - This rule collects mergedsegstats.tsv files from all participants, - adds a participant_id column to identify each subject's data, - merges with participant metadata from participants.tsv, filters - to include only rows where the contrast_column matches the - contrast_value, and computes group averages for each atlas region. - """ +and compute group averages. + +This rule collects mergedsegstats.tsv files from all participants, +adds a participant_id column to identify each subject's data, +merges with participant metadata from participants.tsv, filters +to include only rows where the contrast_column matches the +contrast_value, and computes group averages for each atlas region. +""" input: segstats_tsvs=lambda wildcards: inputs["spim"].expand( bids( @@ -360,15 +360,6 @@ rule concat_subj_segstats_contrast: ) ), participants_tsv=os.path.join(config["bids_dir"], "participants.tsv"), - params: - contrast_column="{contrast_column}", - contrast_value="{contrast_value}", - metric_columns=expand( - "{stain}+{metric}", stain=stains_for_seg, metric=config["seg_metrics"] - ), - coloc_metric_columns=expand( - "coloc+{metric}", metric=config["coloc_seg_metrics"] - ), output: tsv=bids( root=root, @@ -383,17 +374,26 @@ rule concat_subj_segstats_contrast: resources: mem_mb=1500, runtime=10, + params: + contrast_column="{contrast_column}", + contrast_value="{contrast_value}", + metric_columns=expand( + "{stain}+{metric}", stain=stains_for_seg, metric=config["seg_metrics"] + ), + coloc_metric_columns=expand( + "coloc+{metric}", metric=config["coloc_seg_metrics"] + ), script: "../scripts/concat_subj_segstats_contrast.py" rule map_groupavg_segstats_to_template_nii: """Map group-averaged segstats to template space as NIfTI files. - - This rule takes the group-averaged segstats for a specific contrast - and paints brain regions with the averaged metric values to create - volumetric maps for 3D visualization. - """ + +This rule takes the group-averaged segstats for a specific contrast +and paints brain regions with the averaged metric values to create +volumetric maps for 3D visualization. +""" input: tsv=bids( root=root, @@ -406,9 +406,6 @@ rule map_groupavg_segstats_to_template_nii: ), dseg=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.nii.gz"), label_tsv=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"), - params: - label_column="index", - feature_column="{metric}", output: nii=bids( root=root, @@ -424,5 +421,8 @@ rule map_groupavg_segstats_to_template_nii: resources: mem_mb=16000, runtime=15, + params: + label_column="index", + feature_column="{metric}", script: "../scripts/map_tsv_dseg_to_nii.py" diff --git a/spimquant/workflow/rules/masking.smk b/spimquant/workflow/rules/masking.smk index 13a79b6..35c9dc9 100644 --- a/spimquant/workflow/rules/masking.smk +++ b/spimquant/workflow/rules/masking.smk @@ -2,7 +2,7 @@ Brain masking workflow for SPIMquant. This module creates brain tissue masks using Gaussian mixture modeling (Atropos) -combined with template-derived priors. The masks separate brain tissue from +combined with template-derived priors. The masks separate brain tissue from background and are used in subsequent intensity correction and registration steps. Key workflow stages: @@ -19,11 +19,11 @@ tissue classification results, improving mask quality especially at brain bounda rule pre_atropos: """Prepare image for Atropos segmentation. - - Downsamples and preprocesses the SPIM image for efficient tissue classification. - Applies log transformation and intensity normalization to improve GMM convergence. - Also creates an initial foreground mask to restrict computation to brain regions. - """ + +Downsamples and preprocesses the SPIM image for efficient tissue classification. +Applies log transformation and intensity normalization to improve GMM convergence. +Also creates an initial foreground mask to restrict computation to brain regions. +""" input: nii=bids( root=root, @@ -33,12 +33,6 @@ rule pre_atropos: suffix="SPIM.nii.gz", **inputs["spim"].wildcards, ), - params: - downsampling=( - "10%" - if config["sloppy"] - else config["masking"]["pre_atropos_downsampling"] - ), output: downsampled=temp( bids( @@ -62,12 +56,18 @@ rule pre_atropos: **inputs["spim"].wildcards, ), ), + conda: + "../envs/c3d.yaml" threads: 1 resources: mem_mb=3000, runtime=15, - conda: - "../envs/c3d.yaml" + params: + downsampling=( + "10%" + if config["sloppy"] + else config["masking"]["pre_atropos_downsampling"] + ), shell: "c3d {input.nii} -resample {params.downsampling} -shift 1 -log -stretch 2% 98% 0 100 -clip 0 100 -o {output.downsampled} -scale 0 -shift 1 -o {output.mask}" @@ -75,21 +75,16 @@ rule pre_atropos: rule atropos_seg: """Perform tissue classification using Atropos (k-class GMM). - Uses ANTs Atropos to classify tissue into k intensity classes via Gaussian - mixture modeling with Markov random field (MRF) spatial smoothing. Outputs - a discrete segmentation and posterior probability maps for each class. +Uses ANTs Atropos to classify tissue into k intensity classes via Gaussian +mixture modeling with Markov random field (MRF) spatial smoothing. Outputs +a discrete segmentation and posterior probability maps for each class. - Automatically decrements k from init_k down to min_k if Atropos fails, - handling cases where the image lacks enough distinct intensity classes. - """ +Automatically decrements k from init_k down to min_k if Atropos fails, +handling cases where the image lacks enough distinct intensity classes. +""" input: downsampled=rules.pre_atropos.output.downsampled, mask=rules.pre_atropos.output.mask, - params: - mrf_smoothing=0.3, - mrf_radius="2x2x2", - init_k=config["masking"]["gmm_init_k"], - min_k=config["masking"]["gmm_min_k"], output: dseg=temp( bids( @@ -121,6 +116,11 @@ rule atropos_seg: resources: mem_mb=32000, runtime=45, + params: + mrf_smoothing=0.3, + mrf_radius="2x2x2", + init_k=config["masking"]["gmm_init_k"], + min_k=config["masking"]["gmm_min_k"], script: "../scripts/atropos_seg.py" @@ -148,23 +148,23 @@ rule post_atropos: **inputs["spim"].wildcards, ), ), + conda: + "../envs/c3d.yaml" threads: 1 resources: mem_mb=3000, runtime=15, - conda: - "../envs/c3d.yaml" shell: "c3d -interpolation NearestNeighbor {input.ref} {input.dseg} -reslice-identity -o {output.dseg}" rule init_affine_reg: """Perform initial affine registration for obtaining mask priors. - - Registers subject SPIM to template using 12-DOF affine transformation. - This initial alignment enables template brain masks to be warped to subject - space as priors for brain masking, even before final registration. - """ + +Registers subject SPIM to template using 12-DOF affine transformation. +This initial alignment enables template brain masks to be warped to subject +space as priors for brain masking, even before final registration. +""" input: template=get_template_for_reg, subject=bids( @@ -175,8 +175,6 @@ rule init_affine_reg: suffix="SPIM.nii.gz", **inputs["spim"].wildcards, ), - params: - iters="10x0x0" if config["sloppy"] else "100x100", output: xfm_ras=temp( bids( @@ -212,6 +210,8 @@ rule init_affine_reg: resources: mem_mb=16000, runtime=15, + params: + iters="10x0x0" if config["sloppy"] else "100x100", shell: "greedy -threads {threads} -d 3 -i {input.template} {input.subject} " " -a -dof 12 -ia-image-centers -m NMI -o {output.xfm_ras} -n {params.iters} && " @@ -256,11 +256,11 @@ rule affine_transform_template_mask_to_subject: rule create_mask_from_gmm_and_prior: """Create final brain mask by combining GMM classes with template prior. - - Combines tissue classification results from Atropos with the template-derived - brain mask to create a refined brain mask. Uses both intensity-based tissue - classification and spatial prior information for improved accuracy. - """ + +Combines tissue classification results from Atropos with the template-derived +brain mask to create a refined brain mask. Uses both intensity-based tissue +classification and spatial prior information for improved accuracy. +""" input: tissue_dseg=bids( root=root, @@ -308,8 +308,6 @@ rule create_mask_from_gmm: suffix="dseg.nii", **inputs["spim"].wildcards, ), - params: - bg_label=config["masking"]["gmm_bg_class"], output: mask=bids( root=root, @@ -320,11 +318,13 @@ rule create_mask_from_gmm: suffix="mask.nii.gz", **inputs["spim"].wildcards, ), + conda: + "../envs/c3d.yaml" threads: 1 resources: mem_mb=4000, runtime=15, - conda: - "../envs/c3d.yaml" + params: + bg_label=config["masking"]["gmm_bg_class"], shell: "c3d {input} -threshold {params.bg_label} {params.bg_label} 0 1 -o {output}" diff --git a/spimquant/workflow/rules/patches.smk b/spimquant/workflow/rules/patches.smk index beed675..8ba2f7c 100644 --- a/spimquant/workflow/rules/patches.smk +++ b/spimquant/workflow/rules/patches.smk @@ -23,13 +23,13 @@ Patches are named with atlas abbreviation and patch number for easy identificati rule create_spim_patches: """Create patches from SPIM zarr data based on atlas regions. - This rule extracts fixed-size patches from the SPIM zarr data at locations - sampled from specified atlas regions. Patches are saved as NIfTI files - named with the atlas, label abbreviation, and patch number. +This rule extracts fixed-size patches from the SPIM zarr data at locations +sampled from specified atlas regions. Patches are saved as NIfTI files +named with the atlas, label abbreviation, and patch number. - Level in the output file is the resolution of the patches, (e.g. level 0) - but level in the dseg input is the (downsampled) registration_level. - """ +Level in the output file is the resolution of the patches, (e.g. level 0) +but level in the dseg input is the (downsampled) registration_level. +""" input: spim=inputs["spim"].path, dseg=bids( @@ -42,14 +42,6 @@ rule create_spim_patches: **inputs["spim"].wildcards, ), label_tsv=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"), - params: - patch_size=config.get("patch_size", [256, 256, 256]), - n_patches=config.get("n_patches_per_label", 10), - patch_labels=config.get("patch_labels", None), - hires_level=0, #input is the raw data - seed=config.get("patch_seed", 42), - zarrnii_kwargs=zarrnii_in_kwargs, - patch_uint8=not config.get("no_patch_uint8", False), output: patches_dir=directory( bids( @@ -68,6 +60,14 @@ rule create_spim_patches: resources: mem_mb=32000, runtime=30, + params: + patch_size=config.get("patch_size", [256, 256, 256]), + n_patches=config.get("n_patches_per_label", 10), + patch_labels=config.get("patch_labels", None), + hires_level=0, #input is the raw data + seed=config.get("patch_seed", 42), + zarrnii_kwargs=zarrnii_in_kwargs, + patch_uint8=not config.get("no_patch_uint8", False), script: "../scripts/create_patches.py" @@ -75,10 +75,10 @@ rule create_spim_patches: rule create_mask_patches: """Create patches from segmentation mask zarr data based on atlas regions. - This rule extracts fixed-size patches from the cleaned segmentation mask - zarr data at locations sampled from specified atlas regions. Patches are - saved as NIfTI files named with the atlas, label abbreviation, and patch number. - """ +This rule extracts fixed-size patches from the cleaned segmentation mask +zarr data at locations sampled from specified atlas regions. Patches are +saved as NIfTI files named with the atlas, label abbreviation, and patch number. +""" input: mask=bids( root=root, @@ -99,13 +99,6 @@ rule create_mask_patches: **inputs["spim"].wildcards, ), label_tsv=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"), - params: - patch_size=config.get("patch_size", [256, 256, 256]), - n_patches=config.get("n_patches_per_label", 10), - patch_labels=config.get("patch_labels", None), - seed=config.get("patch_seed", 42), - hires_level=config["segmentation_level"], - patch_uint8=not config.get("no_patch_uint8", False), output: patches_dir=directory( bids( @@ -124,6 +117,13 @@ rule create_mask_patches: resources: mem_mb=32000, runtime=30, + params: + patch_size=config.get("patch_size", [256, 256, 256]), + n_patches=config.get("n_patches_per_label", 10), + patch_labels=config.get("patch_labels", None), + seed=config.get("patch_seed", 42), + hires_level=config["segmentation_level"], + patch_uint8=not config.get("no_patch_uint8", False), script: "../scripts/create_patches.py" @@ -131,10 +131,10 @@ rule create_mask_patches: rule create_corrected_spim_patches: """Create patches from corrected SPIM zarr data based on atlas regions. - This rule extracts fixed-size patches from the intensity-corrected SPIM - zarr data at locations sampled from specified atlas regions. Patches are - saved as NIfTI files named with the atlas, label abbreviation, and patch number. - """ +This rule extracts fixed-size patches from the intensity-corrected SPIM +zarr data at locations sampled from specified atlas regions. Patches are +saved as NIfTI files named with the atlas, label abbreviation, and patch number. +""" input: corrected=bids( root=work, @@ -155,13 +155,6 @@ rule create_corrected_spim_patches: **inputs["spim"].wildcards, ), label_tsv=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"), - params: - patch_size=config.get("patch_size", [256, 256, 256]), - n_patches=config.get("n_patches_per_label", 10), - patch_labels=config.get("patch_labels", None), - seed=config.get("patch_seed", 42), - hires_level=config["segmentation_level"], - patch_uint8=not config.get("no_patch_uint8", False), output: patches_dir=directory( bids( @@ -180,6 +173,13 @@ rule create_corrected_spim_patches: resources: mem_mb=32000, runtime=30, + params: + patch_size=config.get("patch_size", [256, 256, 256]), + n_patches=config.get("n_patches_per_label", 10), + patch_labels=config.get("patch_labels", None), + seed=config.get("patch_seed", 42), + hires_level=config["segmentation_level"], + patch_uint8=not config.get("no_patch_uint8", False), script: "../scripts/create_patches.py" @@ -187,10 +187,10 @@ rule create_corrected_spim_patches: rule create_imaris_crops: """Create high-resolution Imaris datasets from SPIM data based on atlas region bounding boxes. - This rule extracts crops from SPIM zarr data based on bounding boxes of - specified atlas regions. Crops are saved as Imaris datasets using zarrnii's - to_imaris() function. Level defaults to 0 for high-resolution output. - """ +This rule extracts crops from SPIM zarr data based on bounding boxes of +specified atlas regions. Crops are saved as Imaris datasets using zarrnii's +to_imaris() function. Level defaults to 0 for high-resolution output. +""" input: spim=inputs["spim"].path, dseg=bids( @@ -203,10 +203,6 @@ rule create_imaris_crops: **inputs["spim"].wildcards, ), label_tsv=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"), - params: - crop_labels=config.get("crop_labels", None), - hires_level=0, # input is the raw data - zarrnii_kwargs=zarrnii_in_kwargs, output: crops_dir=directory( bids( @@ -224,5 +220,9 @@ rule create_imaris_crops: resources: mem_mb=32000, runtime=60, + params: + crop_labels=config.get("crop_labels", None), + hires_level=0, # input is the raw data + zarrnii_kwargs=zarrnii_in_kwargs, script: "../scripts/create_imaris_crops.py" diff --git a/spimquant/workflow/rules/segstats.smk b/spimquant/workflow/rules/segstats.smk index 38c5a0f..187c9d4 100644 --- a/spimquant/workflow/rules/segstats.smk +++ b/spimquant/workflow/rules/segstats.smk @@ -12,8 +12,6 @@ rule map_regionprops_to_atlas_rois: **inputs["spim"].wildcards, ), label_tsv=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"), - params: - coord_column_names=config["coord_column_names"], output: regionprops_tsv=bids( root=root, @@ -43,6 +41,8 @@ rule map_regionprops_to_atlas_rois: resources: mem_mb=32000, runtime=15, + params: + coord_column_names=config["coord_column_names"], script: "../scripts/map_atlas_to_regionprops.py" @@ -59,8 +59,6 @@ rule map_coloc_to_atlas_rois: ), dseg=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.nii.gz"), label_tsv=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"), - params: - coord_column_names=["template_coloc_x", "template_coloc_y", "template_coloc_z"], output: coloc_tsv=temp( bids( @@ -88,6 +86,8 @@ rule map_coloc_to_atlas_rois: resources: mem_mb=32000, runtime=15, + params: + coord_column_names=["template_coloc_x", "template_coloc_y", "template_coloc_z"], script: "../scripts/map_atlas_to_coloc.py" @@ -150,7 +150,7 @@ rule merge_into_segstats_tsv: rule merge_into_colocsegstats_tsv: - """ also includes fieldfracstats.tsv to obtain the volume to turn count into density""" + """also includes fieldfracstats.tsv to obtain the volume to turn count into density""" input: coloc_tsv=bids( root=root, @@ -181,8 +181,6 @@ rule merge_into_colocsegstats_tsv: suffix="fieldfracstats.tsv", **inputs["spim"].wildcards, ), - params: - columns_to_drop=["fieldfrac"], output: tsv=temp( bids( @@ -199,6 +197,8 @@ rule merge_into_colocsegstats_tsv: resources: mem_mb=1500, runtime=15, + params: + columns_to_drop=["fieldfrac"], script: "../scripts/merge_into_segstats_tsv.py" @@ -239,8 +239,6 @@ rule merge_indiv_and_coloc_segstats_tsv: stain=stains_for_seg, allow_missing=True, ), - params: - stains=stains_for_seg, output: merged_tsv=bids( root=root, @@ -255,5 +253,7 @@ rule merge_indiv_and_coloc_segstats_tsv: resources: mem_mb=1500, runtime=15, + params: + stains=stains_for_seg, script: "../scripts/merge_indiv_and_coloc_segstats_tsv.py" diff --git a/spimquant/workflow/rules/sidecars.smk b/spimquant/workflow/rules/sidecars.smk index ee956c4..cf7e9a7 100644 --- a/spimquant/workflow/rules/sidecars.smk +++ b/spimquant/workflow/rules/sidecars.smk @@ -12,13 +12,13 @@ rule create_tabular_json_sidecar_from_tsv: "{prefix}.tsv", output: "{prefix}.json", - params: - column_descriptions=config["column_descriptions"], - stats_maps=config["stats_maps"], threads: 1 resources: mem_mb=500, runtime=5, + params: + column_descriptions=config["column_descriptions"], + stats_maps=config["stats_maps"], script: "../scripts/create_json_sidecar.py" @@ -29,12 +29,12 @@ rule create_tabular_json_sidecar_from_parquet: "{prefix}.parquet", output: "{prefix}.json", - params: - column_descriptions=config["column_descriptions"], - stats_maps=config["stats_maps"], threads: 1 resources: mem_mb=500, runtime=5, + params: + column_descriptions=config["column_descriptions"], + stats_maps=config["stats_maps"], script: "../scripts/create_json_sidecar.py" diff --git a/spimquant/workflow/rules/vessels.smk b/spimquant/workflow/rules/vessels.smk index 4528b96..7efde7e 100644 --- a/spimquant/workflow/rules/vessels.smk +++ b/spimquant/workflow/rules/vessels.smk @@ -12,12 +12,6 @@ rule run_vesselfm: input: spim=inputs["spim"].path, model_path="resources/models/vesselfm.pt", - params: - zarrnii_kwargs=zarrnii_in_kwargs, - vesselfm_kwargs=lambda wildcards, input: { - "chunk_size": (1, 128, 128, 128), - "model_path": input.model_path, - }, output: mask=bids( root=root, @@ -34,6 +28,12 @@ rule run_vesselfm: cpus_per_gpu=32, mem_mb=64000, runtime=lambda wildcards: max(1, int(200.0 / (3.0 ** float(wildcards.level)))), # rough estimate, clamped to >=1 + params: + zarrnii_kwargs=zarrnii_in_kwargs, + vesselfm_kwargs=lambda wildcards, input: { + "chunk_size": (1, 128, 128, 128), + "model_path": input.model_path, + }, script: "../scripts/vesselfm.py" @@ -41,12 +41,12 @@ rule run_vesselfm: rule signed_distance_transform: """Compute signed distance transform from a binary mask. - Applies the chamfer distance transform (distance_transform_cdt from scipy) - to a binary mask using dask map_overlap for chunked, parallel processing. - The output is a signed distance transform computed as dt_outside - dt_inside, - where negative values indicate the interior and positive values indicate - the exterior of the mask. - """ +Applies the chamfer distance transform (distance_transform_cdt from scipy) +to a binary mask using dask map_overlap for chunked, parallel processing. +The output is a signed distance transform computed as dt_outside - dt_inside, +where negative values indicate the interior and positive values indicate +the exterior of the mask. +""" input: mask=bids( root=root, @@ -57,8 +57,6 @@ rule signed_distance_transform: suffix="mask.ozx", **inputs["spim"].wildcards, ), - params: - overlap_depth=32, output: dist=bids( root=root, @@ -74,6 +72,8 @@ rule signed_distance_transform: mem_mb=256000, disk_mb=2097152, runtime=360, + params: + overlap_depth=32, script: "../scripts/signed_distance_transform.py" @@ -90,8 +90,6 @@ rule skeletonize_vessels_mask: suffix="mask.ozx", **inputs["spim"].wildcards, ), - params: - overlap_depth=32, output: mask=bids( root=root, @@ -107,6 +105,8 @@ rule skeletonize_vessels_mask: mem_mb=256000, disk_mb=2097152, runtime=360, + params: + overlap_depth=32, script: "../scripts/skeletonize_vessels_mask.py" @@ -132,8 +132,6 @@ rule vessel_skeleton_graph: suffix="dist.ozx", **inputs["spim"].wildcards, ), - params: - overlap_depth=32, output: graph_parquet=bids( root=root, @@ -149,6 +147,8 @@ rule vessel_skeleton_graph: mem_mb=256000, disk_mb=2097152, runtime=360, + params: + overlap_depth=32, script: "../scripts/skeleton_graph_from_sdt.py" @@ -190,3 +190,58 @@ rule vessel_graph_to_nodes_edges: runtime=360, script: "../scripts/convert_vessel_graph_to_nodes_edges.py" + + +rule map_atlas_to_vessel_nodes: + """Map atlas regions to vessel nodes using physical coordinates.""" + input: + nodes_parquet=bids( + root=root, + datatype="vessels", + stain="{stain}", + level="{level}", + desc="{desc}+skeleton", + suffix="nodes.parquet", + **inputs["spim"].wildcards, + ), + dseg=bids( + root=root, + datatype="parc", + seg="{seg}", + level="{level}", + from_="{template}", + suffix="dseg.nii.gz", + **inputs["spim"].wildcards, + ), + label_tsv=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"), + output: + nodes_parquet=bids( + root=root, + datatype="vessels", + stain="{stain}", + level="{level}", + desc="{desc}+skeleton", + seg="{seg}", + from_="{template}", + suffix="nodes.parquet", + **inputs["spim"].wildcards, + ), + counts_tsv=bids( + root=root, + datatype="vessels", + stain="{stain}", + level="{level}", + desc="{desc}+skeleton", + seg="{seg}", + from_="{template}", + suffix="nodecounts.tsv", + **inputs["spim"].wildcards, + ), + threads: 4 + resources: + mem_mb=32000, + runtime=15, + params: + coord_column_names=["x", "y", "z"], + script: + "../scripts/map_atlas_to_vessel_nodes.py" diff --git a/spimquant/workflow/scripts/map_atlas_to_vessel_nodes.py b/spimquant/workflow/scripts/map_atlas_to_vessel_nodes.py new file mode 100644 index 0000000..08a8247 --- /dev/null +++ b/spimquant/workflow/scripts/map_atlas_to_vessel_nodes.py @@ -0,0 +1,22 @@ +"""Map vessel node coordinates to atlas regions and generate labeled statistics. + +This script takes vessel node coordinates (from a nodes parquet file) and maps +them to atlas regions using a ZarrNiiAtlas object, generating two outputs: +1. Annotated nodes parquet with atlas region labels assigned to each node +2. Count statistics showing the number of nodes per atlas region +""" + +import pandas as pd +from zarrnii import ZarrNiiAtlas + +nodes = pd.read_parquet(snakemake.input.nodes_parquet).to_dict(orient="list") +atlas = ZarrNiiAtlas.from_files(snakemake.input.dseg, snakemake.input.label_tsv) + +df_nodes, df_counts = atlas.label_region_properties( + nodes, + coord_column_names=snakemake.params.coord_column_names, + include_names=True, +) + +df_nodes.to_parquet(snakemake.output.nodes_parquet, index=False) +df_counts.to_csv(snakemake.output.counts_tsv, sep="\t", index=False)