diff --git a/spimquant/config/snakebids.yml b/spimquant/config/snakebids.yml index 795b33d..7fd01cc 100644 --- a/spimquant/config/snakebids.yml +++ b/spimquant/config/snakebids.yml @@ -250,15 +250,20 @@ parse_args: action: store nargs: '+' - --contrast_column: - help: "Column name in participants.tsv to use for defining group contrasts (e.g., 'treatment', 'genotype'). Required for group-level statistical analysis." + --model: + help: "Statistical model formula in patsy/statsmodels format, with 'metric' as a placeholder for the response variable (e.g. 'metric ~ C(treatment) * C(genotype) + age'). Required for group-level analysis." default: null type: str - --contrast_values: - help: "Two group values for contrast comparison (e.g., 'control' 'drug'). Used with --contrast_column for statistical testing. Provide exactly 2 values." + --pairwise: + help: "Factor(s) in participants.tsv for which all pairwise comparisons should be computed. Can be specified multiple times (e.g. '--pairwise treatment --pairwise genotype')." default: null - nargs: 2 + nargs: '+' + + --within: + help: "Factors defining strata within which pairwise contrasts are computed (e.g. '--within genotype sex'). All combinations of these factor levels will be used as separate strata." 
+ default: null + nargs: '+' #--- workflow specific configuration -- diff --git a/spimquant/workflow/Snakefile b/spimquant/workflow/Snakefile index 0a123fa..5cd2ad1 100644 --- a/spimquant/workflow/Snakefile +++ b/spimquant/workflow/Snakefile @@ -1,5 +1,7 @@ import os +from itertools import combinations as _combinations, product as _product +import pandas as _pd from zarrnii import ZarrNii from snakemake.utils import format from snakebids import bids, generate_inputs, get_wildcard_constraints, set_bids_spec @@ -141,19 +143,80 @@ for seg in config["crop_atlas_segs"]: ) else: crop_atlas_segs.append(seg) -# Validate that contrast arguments are provided when using group analysis level +# Validate that model/pairwise arguments are provided when using group analysis level +# and generate pairwise contrast labels from participants.tsv at planning time +pairwise_contrast_labels = [] +pairwise_contrast_info = {} # label -> dict with factor, levelA, levelB, strata + if config.get("analysis_level") == "group": - if config.get("contrast_column") is None or config.get("contrast_values") is None: + if not config.get("model"): + raise ValueError( + "When using group analysis level, --model must be specified " + "(e.g. 'metric ~ C(treatment) + age')." + ) + if not config.get("pairwise"): raise ValueError( - "When using group analysis level, both --contrast_column and " - "--contrast_values must be specified for filtering group data by contrasts." + "When using group analysis level, --pairwise must be specified " + "(e.g. '--pairwise treatment')." ) + _participants_tsv = os.path.join(config["bids_dir"], "participants.tsv") + if not os.path.exists(_participants_tsv): + raise ValueError( + f"participants.tsv not found at '{_participants_tsv}'. " + "This file is required for group-level analysis." 
+ ) + _participants_df = _pd.read_csv(_participants_tsv, sep="\t") + + _pairwise_factors = config["pairwise"] + _within_factors = config.get("within") or [] + + for _factor in _pairwise_factors: + if _factor not in _participants_df.columns: + raise ValueError( + f"Pairwise factor '{_factor}' not found in participants.tsv. " + f"Available columns: {list(_participants_df.columns)}" + ) + _levels = sorted(str(v) for v in _participants_df[_factor].dropna().unique()) + + if _within_factors: + _within_level_lists = [] + for _wf in _within_factors: + if _wf not in _participants_df.columns: + raise ValueError( + f"Within factor '{_wf}' not found in participants.tsv. " + f"Available columns: {list(_participants_df.columns)}" + ) + _within_level_lists.append( + sorted(str(v) for v in _participants_df[_wf].dropna().unique()) + ) + for _lA, _lB in _combinations(_levels, 2): + for _stratum in _product(*_within_level_lists): + _strata_dict = dict(zip(_within_factors, _stratum)) + _strata_str = "+".join(f"{f}-{v}" for f, v in _strata_dict.items()) + _label = f"{_factor}+{_lA}vs{_lB}+{_strata_str}" + pairwise_contrast_labels.append(_label) + pairwise_contrast_info[_label] = { + "factor": _factor, + "levelA": _lA, + "levelB": _lB, + "strata": _strata_dict, + } + else: + for _lA, _lB in _combinations(_levels, 2): + _label = f"{_factor}+{_lA}vs{_lB}" + pairwise_contrast_labels.append(_label) + pairwise_contrast_info[_label] = { + "factor": _factor, + "levelA": _lA, + "levelB": _lB, + "strata": {}, + } + wildcard_constraints: stain="[a-zA-Z0-9]+", - contrast_column="[a-zA-Z0-9_]+", - contrast_value="[a-zA-Z0-9_]+", + pairwise_contrast="[a-zA-Z0-9+_-]+", rule all_templatereg_deform_zooms: @@ -511,6 +574,7 @@ rule all_imaris_crops: rule all_group_stats: """Target rule for group-level statistical analysis.""" input: + # Per-contrast groupstats TSV and PNG expand( bids( root=root, @@ -518,13 +582,16 @@ rule all_group_stats: seg="{seg}", from_="{template}", desc="{desc}", + 
contrast="{pairwise_contrast}", suffix="groupstats.{ext}", ), seg=atlas_segs, desc=config["seg_method"], template=config["template"], + pairwise_contrast=pairwise_contrast_labels, ext=["png", "tsv"], ), + # Per-contrast NIfTI stat maps for each seg metric expand( bids( root=root, @@ -532,6 +599,7 @@ rule all_group_stats: seg="{seg}", space="{template}", desc="{desc}", + contrast="{pairwise_contrast}", metric="{stain}+{metric}", suffix="{stat}.nii.gz", ), @@ -541,7 +609,9 @@ rule all_group_stats: stain=stains_for_seg, metric=config["seg_metrics"], stat=config["stats_maps"], + pairwise_contrast=pairwise_contrast_labels, ), + # All-subjects count maps (not contrast-specific) expand( bids( root=root, @@ -556,48 +626,11 @@ rule all_group_stats: level=range(4), stain=stains_for_seg, ), - # Add contrast-filtered outputs - expand( - bids( - root=root, - datatype="group", - level="{level}", - space="{template}", - desc="{desc}", - contrast="{contrast_column}+{contrast_value}", - suffix="{stain}+count.nii.gz", - ), - desc=config["seg_method"], - template=config["template"], - level=range(4), - stain=stains_for_seg, - contrast_column=config["contrast_column"], - contrast_value=config["contrast_values"], - ), - # Add group-averaged segstats maps by contrast - expand( - bids( - root=root, - datatype="group", - seg="{seg}", - space="{template}", - desc="{desc}", - contrast="{contrast_column}+{contrast_value}", - metric="{stain}+{metric}", - suffix="groupavg.nii.gz", - ), - seg=atlas_segs, - desc=config["seg_method"], - template=config["template"], - stain=stains_for_seg, - metric=config["seg_metrics"], - contrast_column=config["contrast_column"], - contrast_value=config["contrast_values"], - ), rule all_group_stats_coloc: input: + # Per-contrast NIfTI stat maps for coloc metrics expand( bids( root=root, @@ -605,6 +638,7 @@ rule all_group_stats_coloc: seg="{seg}", space="{template}", desc="{desc}", + contrast="{pairwise_contrast}", metric="coloc+{metric}", suffix="{stat}.nii.gz", 
), @@ -613,7 +647,9 @@ rule all_group_stats_coloc: template=config["template"], metric=config["coloc_seg_metrics"], stat=config["stats_maps"], + pairwise_contrast=pairwise_contrast_labels, ), + # All-subjects colocalization count maps (not contrast-specific) expand( bids( root=root, @@ -627,40 +663,6 @@ rule all_group_stats_coloc: template=config["template"], level=range(4), ), - expand( - bids( - root=root, - datatype="group", - space="{template}", - level="{level}", - desc="{desc}", - contrast="{contrast_column}+{contrast_value}", - suffix="coloccount.nii.gz", - ), - desc=config["seg_method"], - template=config["template"], - level=range(4), - contrast_column=config["contrast_column"], - contrast_value=config["contrast_values"], - ), - expand( - bids( - root=root, - datatype="group", - seg="{seg}", - space="{template}", - desc="{desc}", - contrast="{contrast_column}+{contrast_value}", - metric="coloc+{metric}", - suffix="groupavg.nii.gz", - ), - seg=atlas_segs, - desc=config["seg_method"], - template=config["template"], - metric=config["coloc_seg_metrics"], - contrast_column=config["contrast_column"], - contrast_value=config["contrast_values"], - ), rule all_qc: diff --git a/spimquant/workflow/rules/groupstats.smk b/spimquant/workflow/rules/groupstats.smk index 3e41c6f..0691394 100644 --- a/spimquant/workflow/rules/groupstats.smk +++ b/spimquant/workflow/rules/groupstats.smk @@ -1,18 +1,19 @@ """ Group-level statistical analysis rules for SPIMquant. -This module performs group-based statistical tests on segmentation statistics -(e.g., fieldfrac, density, volume) across participants, using metadata from -participants.tsv to define contrasts. +This module performs formula-based statistical modelling on segmentation +statistics (e.g., fieldfrac, density, volume) across participants, using +metadata from participants.tsv to fit OLS models and compute pairwise contrasts. 
""" rule perform_group_stats: - """Perform group-based statistical tests on segmentation statistics. - - This rule reads segstats.tsv files from all participants and performs - statistical tests based on contrasts defined in participants.tsv. - """ + """Perform formula-based group statistical tests on segmentation statistics. + +Fits a single global OLS model per region/metric using the user-supplied +formula, then computes pairwise contrast statistics (t-stat, p-value, +Cohen's d) for the specified contrast using the model's covariance matrix. +""" input: segstats_tsvs=lambda wildcards: inputs["spim"].expand( bids( @@ -26,15 +27,6 @@ rule perform_group_stats: ) ), participants_tsv=os.path.join(config["bids_dir"], "participants.tsv"), - params: - contrast_column=config.get("contrast_column", None), - contrast_values=config.get("contrast_values", None), - metric_columns=expand( - "{stain}+{metric}", stain=stains_for_seg, metric=config["seg_metrics"] - ), - coloc_metric_columns=expand( - "coloc+{metric}", metric=config["coloc_seg_metrics"] - ), output: stats_tsv=bids( root=root, @@ -42,22 +34,35 @@ rule perform_group_stats: seg="{seg}", from_="{template}", desc="{desc}", + contrast="{pairwise_contrast}", suffix="groupstats.tsv", ), threads: 1 resources: mem_mb=1500, runtime=10, + params: + model=config.get("model", None), + pairwise_contrast_info=lambda wc: pairwise_contrast_info.get( + wc.pairwise_contrast, {} + ), + within_factors=config.get("within") or [], + metric_columns=expand( + "{stain}+{metric}", stain=stains_for_seg, metric=config["seg_metrics"] + ), + coloc_metric_columns=expand( + "coloc+{metric}", metric=config["coloc_seg_metrics"] + ), script: "../scripts/perform_group_stats.py" rule create_stats_heatmap: """Create heatmap visualizations from group statistics results. - - This rule takes the group statistics TSV and creates heatmaps for - visualization of significant differences across brain regions. 
- """ + +This rule takes the group statistics TSV and creates heatmaps for +visualization of significant differences across brain regions. +""" input: stats_tsv=bids( root=root, @@ -65,16 +70,10 @@ rule create_stats_heatmap: seg="{seg}", from_="{template}", desc="{desc}", + contrast="{pairwise_contrast}", suffix="groupstats.tsv", ), label_tsv=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"), - params: - metric_columns=expand( - "{stain}+{metric}", stain=stains_for_seg, metric=config["seg_metrics"] - ), - coloc_metric_columns=expand( - "coloc+{metric}", metric=config["coloc_seg_metrics"] - ), output: heatmap_png=bids( root=root, @@ -82,22 +81,30 @@ rule create_stats_heatmap: seg="{seg}", from_="{template}", desc="{desc}", + contrast="{pairwise_contrast}", suffix="groupstats.png", ), threads: 1 resources: mem_mb=8000, runtime=15, + params: + metric_columns=expand( + "{stain}+{metric}", stain=stains_for_seg, metric=config["seg_metrics"] + ), + coloc_metric_columns=expand( + "coloc+{metric}", metric=config["coloc_seg_metrics"] + ), script: "../scripts/create_stats_heatmap.py" rule map_groupstats_to_template_nii: """Map group statistics to template space as NIfTI files. - - This rule paints brain regions with statistical values (e.g., t-statistics, - p-values) to create volumetric heatmaps for 3D visualization. - """ + +This rule paints brain regions with statistical values (e.g., t-statistics, +p-values) to create volumetric heatmaps for 3D visualization. 
+""" input: tsv=bids( root=root, @@ -105,13 +112,11 @@ rule map_groupstats_to_template_nii: seg="{seg}", from_="{template}", desc="{desc}", + contrast="{pairwise_contrast}", suffix="groupstats.tsv", ), dseg=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.nii.gz"), label_tsv=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"), - params: - label_column="index", - feature_column="{metric}_{stat}", output: nii=bids( root=root, @@ -119,6 +124,7 @@ rule map_groupstats_to_template_nii: seg="{seg}", space="{template}", desc="{desc}", + contrast="{pairwise_contrast}", metric="{metric}", suffix="{stat}.nii.gz", ), @@ -126,18 +132,21 @@ rule map_groupstats_to_template_nii: resources: mem_mb=16000, runtime=15, + params: + label_column="index", + feature_column="{metric}_{stat}", script: "../scripts/map_tsv_dseg_to_nii.py" rule concat_subj_parquet: """Concatenate parquet files across all subjects. - - This rule collects regionprops.parquet or coloc.parquet files - from all participants, adds a participant_id column to - identify each subject's data, and merges with participant - metadata from participants.tsv. - """ + +This rule collects regionprops.parquet or coloc.parquet files +from all participants, adds a participant_id column to +identify each subject's data, and merges with participant +metadata from participants.tsv. 
+""" input: parquet_files=inputs["spim"].expand( bids( @@ -169,7 +178,7 @@ rule concat_subj_parquet: rule group_counts_per_voxel: """Calculate counts per voxel based on concatenated points - in template space""" +in template space""" input: template=bids(root=root, template="{template}", suffix="anat.nii.gz"), regionprops_parquet=bids( @@ -179,8 +188,6 @@ rule group_counts_per_voxel: desc="{desc}", suffix="regionprops.parquet", ), - params: - coord_column_names=config["template_coord_column_names"], output: counts_nii=bids( root=root, @@ -194,13 +201,15 @@ rule group_counts_per_voxel: resources: mem_mb=200000, runtime=30, + params: + coord_column_names=config["template_coord_column_names"], script: "../scripts/counts_per_voxel_template.py" rule group_coloc_counts_per_voxel: """Calculate counts per voxel based on concatenated coloc points - in template space""" +in template space""" input: template=bids(root=root, template="{template}", suffix="anat.nii.gz"), coloc_parquet=bids( @@ -210,8 +219,6 @@ rule group_coloc_counts_per_voxel: desc="{desc}", suffix="coloc.parquet", ), - params: - coord_column_names=config["template_coloc_coord_column_names"], output: counts_nii=bids( root=root, @@ -225,204 +232,7 @@ rule group_coloc_counts_per_voxel: resources: mem_mb=15000, runtime=10, - script: - "../scripts/coloc_per_voxel_template.py" - - -rule concat_subj_parquet_contrast: - """Concatenate parquet files across subjects filtered by contrast. - - This rule collects regionprops.parquet or coloc.parquet files - from all participants, adds a participant_id column to - identify each subject's data, merges with participant - metadata from participants.tsv, and filters to include only - rows where the contrast_column matches the contrast_value. 
- """ - input: - parquet_files=inputs["spim"].expand( - bids( - root=root, - datatype="tabular", - space="{template}", - desc="{desc}", - suffix="{suffix}.parquet", - **inputs["spim"].wildcards, - ), - allow_missing=True, - ), - participants_tsv=os.path.join(config["bids_dir"], "participants.tsv"), - params: - contrast_column="{contrast_column}", - contrast_value="{contrast_value}", - output: - parquet=bids( - root=root, - datatype="group", - space="{template}", - desc="{desc}", - contrast="{contrast_column}+{contrast_value}", - suffix="{suffix,regionprops|coloc}.parquet", - ), - threads: 1 - resources: - mem_mb=1500, - runtime=10, - script: - "../scripts/concat_subj_parquet_contrast.py" - - -rule group_counts_per_voxel_contrast: - """Calculate counts per voxel based on concatenated points - in template space, filtered by contrast""" - input: - template=bids(root=root, template="{template}", suffix="anat.nii.gz"), - regionprops_parquet=bids( - root=root, - datatype="group", - space="{template}", - desc="{desc}", - contrast="{contrast_column}+{contrast_value}", - suffix="regionprops.parquet", - ), - params: - coord_column_names=config["template_coord_column_names"], - output: - counts_nii=bids( - root=root, - datatype="group", - space="{template}", - level="{level}", - desc="{desc}", - contrast="{contrast_column}+{contrast_value}", - suffix="{stain}+count.nii.gz", - ), - threads: 16 - resources: - mem_mb=200000, - runtime=30, - script: - "../scripts/counts_per_voxel_template.py" - - -rule group_coloc_counts_per_voxel_contrast: - """Calculate counts per voxel based on concatenated coloc points - in template space, filtered by contrast""" - input: - template=bids(root=root, template="{template}", suffix="anat.nii.gz"), - coloc_parquet=bids( - root=root, - datatype="group", - space="{template}", - desc="{desc}", - contrast="{contrast_column}+{contrast_value}", - suffix="coloc.parquet", - ), params: coord_column_names=config["template_coloc_coord_column_names"], - 
output: - counts_nii=bids( - root=root, - datatype="group", - space="{template}", - level="{level}", - desc="{desc}", - contrast="{contrast_column}+{contrast_value}", - suffix="coloccount.nii.gz", - ), - threads: 16 - resources: - mem_mb=15000, - runtime=10, script: "../scripts/coloc_per_voxel_template.py" - - -rule concat_subj_segstats_contrast: - """Concatenate segstats.tsv files across subjects filtered by contrast - and compute group averages. - - This rule collects mergedsegstats.tsv files from all participants, - adds a participant_id column to identify each subject's data, - merges with participant metadata from participants.tsv, filters - to include only rows where the contrast_column matches the - contrast_value, and computes group averages for each atlas region. - """ - input: - segstats_tsvs=lambda wildcards: inputs["spim"].expand( - bids( - root=root, - datatype="tabular", - seg=wildcards.seg, - from_=wildcards.template, - desc=wildcards.desc, - suffix="mergedsegstats.tsv", - **inputs["spim"].wildcards, - ) - ), - participants_tsv=os.path.join(config["bids_dir"], "participants.tsv"), - params: - contrast_column="{contrast_column}", - contrast_value="{contrast_value}", - metric_columns=expand( - "{stain}+{metric}", stain=stains_for_seg, metric=config["seg_metrics"] - ), - coloc_metric_columns=expand( - "coloc+{metric}", metric=config["coloc_seg_metrics"] - ), - output: - tsv=bids( - root=root, - datatype="group", - seg="{seg}", - from_="{template}", - desc="{desc}", - contrast="{contrast_column}+{contrast_value}", - suffix="groupavgsegstats.tsv", - ), - threads: 1 - resources: - mem_mb=1500, - runtime=10, - script: - "../scripts/concat_subj_segstats_contrast.py" - - -rule map_groupavg_segstats_to_template_nii: - """Map group-averaged segstats to template space as NIfTI files. - - This rule takes the group-averaged segstats for a specific contrast - and paints brain regions with the averaged metric values to create - volumetric maps for 3D visualization. 
- """ - input: - tsv=bids( - root=root, - datatype="group", - seg="{seg}", - from_="{template}", - desc="{desc}", - contrast="{contrast_column}+{contrast_value}", - suffix="groupavgsegstats.tsv", - ), - dseg=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.nii.gz"), - label_tsv=bids(root=root, template="{template}", seg="{seg}", suffix="dseg.tsv"), - params: - label_column="index", - feature_column="{metric}", - output: - nii=bids( - root=root, - datatype="group", - seg="{seg}", - space="{template}", - desc="{desc}", - contrast="{contrast_column}+{contrast_value}", - metric="{metric}", - suffix="groupavg.nii.gz", - ), - threads: 8 - resources: - mem_mb=16000, - runtime=15, - script: - "../scripts/map_tsv_dseg_to_nii.py" diff --git a/spimquant/workflow/scripts/concat_subj_parquet_contrast.py b/spimquant/workflow/scripts/concat_subj_parquet_contrast.py deleted file mode 100644 index 22e9d22..0000000 --- a/spimquant/workflow/scripts/concat_subj_parquet_contrast.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Concatenate subject-level parquet files filtered by contrast. - -This script concatenates parquet files (regionprops or coloc) from multiple -participants, adds a participant_id column to identify each subject's data, -merges with participant metadata from participants.tsv, and filters the data -to include only rows where the specified contrast_column matches the -contrast_value. - -This is a Snakemake script that expects the `snakemake` object to be available, -which is automatically provided when executed as part of a Snakemake workflow. -""" - -import os -from pathlib import Path - -import pandas as pd - - -def extract_participant_id(path): - """Extract participant_id from a BIDS file path. 
- - Parameters - ---------- - path : str - Path to a BIDS file - - Returns - ------- - str or None - Participant ID (e.g., 'sub-01') or None if not found - """ - parts = Path(path).parts - for part in parts: - if part.startswith("sub-"): - return part - return None - - -def load_parquets_with_metadata(parquet_paths, participants_df): - """Load all parquet files and merge with participant metadata. - - Parameters - ---------- - parquet_paths : list - List of paths to parquet files - participants_df : pd.DataFrame - DataFrame containing participant metadata - - Returns - ------- - pd.DataFrame - Combined dataframe with all subjects' data and participant metadata - """ - all_data = [] - - for path in parquet_paths: - if not os.path.exists(path): - continue - - # Extract subject ID from path - subject_id = extract_participant_id(path) - - if subject_id is None: - continue - - # Load parquet file - df = pd.read_parquet(path) - df["participant_id"] = subject_id - - all_data.append(df) - - if not all_data: - raise ValueError( - f"No valid parquet files found. Attempted to load {len(parquet_paths)} " - f"file(s). Check that the files exist and contain valid parquet data with " - f"subject IDs in the file paths (e.g., 'sub-01')." 
- ) - - # Combine all data - combined = pd.concat(all_data, ignore_index=True) - - # Merge with participants metadata - merged = combined.merge(participants_df, on="participant_id", how="left") - - return merged - - -def main(): - """Main function - uses snakemake object provided by Snakemake workflow.""" - # Load participants metadata - participants_df = pd.read_csv(snakemake.input.participants_tsv, sep="\t") - - # Validate participants.tsv has required columns - if "participant_id" not in participants_df.columns: - raise ValueError("participants.tsv must contain a 'participant_id' column") - - # Get contrast filtering parameters - contrast_column = snakemake.params.contrast_column - contrast_value = snakemake.params.contrast_value - - # Validate contrast column exists in participants.tsv - if contrast_column not in participants_df.columns: - raise ValueError( - f"Contrast column '{contrast_column}' not found in participants.tsv. " - f"Available columns: {list(participants_df.columns)}" - ) - - # Load and combine all parquet files - combined_data = load_parquets_with_metadata( - snakemake.input.parquet_files, participants_df - ) - - # Filter data based on contrast column and value - filtered_data = combined_data[combined_data[contrast_column] == contrast_value] - - if filtered_data.empty: - raise ValueError( - f"No data found for {contrast_column}={contrast_value}. 
" - f"Available values in {contrast_column}: " - f"{combined_data[contrast_column].unique().tolist()}" - ) - - # Save filtered parquet file - filtered_data.to_parquet(snakemake.output.parquet, index=False) - - -if __name__ == "__main__": - main() diff --git a/spimquant/workflow/scripts/concat_subj_segstats_contrast.py b/spimquant/workflow/scripts/concat_subj_segstats_contrast.py deleted file mode 100644 index 0dfa50d..0000000 --- a/spimquant/workflow/scripts/concat_subj_segstats_contrast.py +++ /dev/null @@ -1,188 +0,0 @@ -"""Concatenate subject-level segstats.tsv files filtered by contrast and compute group averages. - -This script concatenates segstats.tsv files from multiple participants, adds a -participant_id column to identify each subject's data, merges with participant -metadata from participants.tsv, filters the data to include only rows where the -specified contrast_column matches the contrast_value, and computes group averages -for each atlas region and metric. - -This is a Snakemake script that expects the `snakemake` object to be available, -which is automatically provided when executed as part of a Snakemake workflow. -""" - -import os -from pathlib import Path - -import pandas as pd -import numpy as np - - -def extract_participant_id(path): - """Extract participant_id from a BIDS file path. - - Parameters - ---------- - path : str - Path to a BIDS file - - Returns - ------- - str or None - Participant ID (e.g., 'sub-01') or None if not found - """ - parts = Path(path).parts - for part in parts: - if part.startswith("sub-"): - return part - return None - - -def load_segstats_with_metadata(segstats_paths, participants_df): - """Load all segstats files and merge with participant metadata. 
- - Parameters - ---------- - segstats_paths : list - List of paths to segstats.tsv files - participants_df : pd.DataFrame - DataFrame containing participant metadata - - Returns - ------- - pd.DataFrame - Combined dataframe with segstats and participant metadata - """ - all_data = [] - - for path in segstats_paths: - if not os.path.exists(path): - continue - - # Extract subject ID from path - subject_id = extract_participant_id(path) - - if subject_id is None: - continue - - # Load segstats - df = pd.read_csv(path, sep="\t") - df["participant_id"] = subject_id - - all_data.append(df) - - if not all_data: - raise ValueError( - f"No valid segstats files found. Attempted to load {len(segstats_paths)} " - f"file(s). Check that the files exist and contain valid data with " - f"subject IDs in the file paths (e.g., 'sub-01')." - ) - - # Combine all segstats - combined = pd.concat(all_data, ignore_index=True) - - # Merge with participants metadata - merged = combined.merge(participants_df, on="participant_id", how="left") - - return merged - - -def compute_group_averages(data, metric_columns): - """Compute group averages for each atlas region and metric. 
- - Parameters - ---------- - data : pd.DataFrame - Combined dataframe with segstats data - metric_columns : list - List of metric column names to average - - Returns - ------- - pd.DataFrame - Dataframe with group averages for each region - """ - # Group by region (index and name) - groupby_cols = ["index", "name"] - - # Build aggregation dict - only include columns that exist - agg_dict = {} - missing_columns = [] - for col in metric_columns: - if col in data.columns: - agg_dict[col] = "mean" - else: - missing_columns.append(col) - - # Warn about missing columns to aid debugging - if missing_columns: - print( - f"Warning: The following metric columns were not found in data: {missing_columns}" - ) - print(f"Available columns: {list(data.columns)}") - - if not agg_dict: - raise ValueError( - f"None of the specified metric columns were found in the data. " - f"Requested: {metric_columns}. Available: {list(data.columns)}" - ) - - # Compute group averages - group_avg = data.groupby(groupby_cols, as_index=False).agg(agg_dict) - - return group_avg - - -def main(): - """Main function - uses snakemake object provided by Snakemake workflow.""" - # Load participants metadata - participants_df = pd.read_csv(snakemake.input.participants_tsv, sep="\t") - - # Validate participants.tsv has required columns - if "participant_id" not in participants_df.columns: - raise ValueError("participants.tsv must contain a 'participant_id' column") - - # Get contrast filtering parameters - contrast_column = snakemake.params.contrast_column - contrast_value = snakemake.params.contrast_value - - # Validate contrast column exists in participants.tsv - if contrast_column not in participants_df.columns: - raise ValueError( - f"Contrast column '{contrast_column}' not found in participants.tsv. 
" - f"Available columns: {list(participants_df.columns)}" - ) - - # Load and combine all segstats files - combined_data = load_segstats_with_metadata( - snakemake.input.segstats_tsvs, participants_df - ) - - # Filter data based on contrast column and value - filtered_data = combined_data[combined_data[contrast_column] == contrast_value] - - if filtered_data.empty: - raise ValueError( - f"No data found for {contrast_column}={contrast_value}. " - f"Available values in {contrast_column}: " - f"{combined_data[contrast_column].unique().tolist()}" - ) - - # Get metric columns to average - handle None values - metric_columns = [] - if snakemake.params.metric_columns is not None: - metric_columns.extend(snakemake.params.metric_columns) - if snakemake.params.coloc_metric_columns is not None: - metric_columns.extend(snakemake.params.coloc_metric_columns) - - if not metric_columns: - raise ValueError("No metric columns specified for averaging") - - # Compute group averages - group_avg = compute_group_averages(filtered_data, metric_columns) - - # Save averaged segstats table - group_avg.to_csv(snakemake.output.tsv, sep="\t", index=False) - - -if __name__ == "__main__": - main() diff --git a/spimquant/workflow/scripts/perform_group_stats.py b/spimquant/workflow/scripts/perform_group_stats.py index da952bb..be1c0cb 100644 --- a/spimquant/workflow/scripts/perform_group_stats.py +++ b/spimquant/workflow/scripts/perform_group_stats.py @@ -1,19 +1,31 @@ -"""Perform group-based statistical analysis on segmentation statistics. - -This script reads segstats.tsv files from multiple participants and performs -statistical tests (e.g., t-tests, ANOVA) based on contrasts defined in the -participants.tsv file. - -This is a Snakemake script that expects the `snakemake` object to be available, -which is automatically provided when executed as part of a Snakemake workflow. +"""Perform formula-based group statistical analysis on segmentation statistics. 
"""Formula-based group statistics for per-region segmentation metrics.

This script reads segstats.tsv files from multiple participants, fits a single
global OLS model per region/metric using the user-supplied patsy/statsmodels
formula, and computes pairwise contrast statistics using the model's covariance
matrix.

The contrast is specified via ``snakemake.params.pairwise_contrast_info``, a
dict with keys:
- ``factor``: the column in participants.tsv to compare levels of
- ``levelA`` / ``levelB``: the two levels to contrast (levelA - levelB)
- ``strata``: a dict of {column: value} pairs for stratified analyses;
  if non-empty the model is fit with *all* data but marginal means are
  evaluated at the given stratum values.

This is a Snakemake script that expects the ``snakemake`` object to be
available, which is automatically provided when executed as part of a
Snakemake workflow.
"""

import os
import re
from pathlib import Path

import numpy as np
import pandas as pd

# Whole-word match only, so covariates such as ``volumetric_age`` or
# ``metric_site`` are never rewritten by the placeholder substitution.
_METRIC_PLACEHOLDER = re.compile(r"\bmetric\b")


def load_segstats_with_metadata(segstats_paths, participants_df):
    """Load all segstats files and merge with participant metadata.

    Paths that do not exist on disk, or that contain no ``sub-*`` path
    component, are skipped silently.

    Parameters
    ----------
    segstats_paths : list
        List of paths to segstats.tsv files.
    participants_df : pd.DataFrame
        DataFrame containing participant metadata from participants.tsv.

    Returns
    -------
    pd.DataFrame
        Combined dataframe with segstats and participant metadata
        (left-joined on ``participant_id``).

    Raises
    ------
    ValueError
        If none of the paths yields a usable table.
    """
    tables = []

    for path in segstats_paths:
        if not os.path.exists(path):
            continue

        # The subject label is the first path component starting with "sub-".
        subject_id = next(
            (part for part in Path(path).parts if part.startswith("sub-")), None
        )
        if subject_id is None:
            continue

        table = pd.read_csv(path, sep="\t")
        table["participant_id"] = subject_id
        tables.append(table)

    if not tables:
        raise ValueError("No valid segstats files found")

    combined = pd.concat(tables, ignore_index=True)
    return combined.merge(participants_df, on="participant_id", how="left")


def resolve_formula(formula_template, metric):
    """Substitute the ``metric`` placeholder with a patsy-quoted column name.

    patsy has no backtick quoting; arbitrary column names are referenced with
    its builtin ``Q("name")``. Only the stand-alone word ``metric`` is
    replaced, so identifiers that merely contain it are left untouched.

    Parameters
    ----------
    formula_template : str
        Formula with the literal word ``metric`` as the response placeholder,
        e.g. ``"metric ~ C(treatment) + age"``.
    metric : str
        Actual metric column name.

    Returns
    -------
    str
        The formula with every stand-alone ``metric`` replaced by
        ``Q('<metric>')``.
    """
    quoted = f"Q({str(metric)!r})"
    # A callable replacement avoids backslash/group-reference interpretation.
    return _METRIC_PLACEHOLDER.sub(lambda _match: quoted, formula_template)


def build_prediction_row(region_data, pairwise_factor, level, strata):
    """Build a one-row DataFrame for marginal mean prediction.

    Continuous variables are held at their mean; categorical variables are
    held at their mode. The pairwise factor and any strata variables are set
    to the supplied values.

    Parameters
    ----------
    region_data : pd.DataFrame
        Rows for one region; every column except ``participant_id`` is
        represented in the output.
    pairwise_factor : str
        Column whose value is forced to ``level``.
    level : str
        Level of ``pairwise_factor`` to predict at.
    strata : dict
        ``{column: value}`` pairs forced in the output row (may be empty or
        None).

    Returns
    -------
    pd.DataFrame
        A single-row frame.
    """
    reference = {}
    for column in region_data.columns:
        if column == "participant_id":
            continue
        values = region_data[column]
        if pd.api.types.is_numeric_dtype(values):
            reference[column] = [values.mean()]
        else:
            modes = values.mode()
            reference[column] = [modes.iloc[0] if len(modes) > 0 else None]

    # Explicit settings override the reference values computed above.
    reference[pairwise_factor] = [level]
    for column, value in (strata or {}).items():
        reference[column] = [value]

    return pd.DataFrame(reference)


def _stat_placeholders(
    metric, level_a, level_b, n_a=np.nan, n_b=np.nan, mean_a=np.nan, mean_b=np.nan
):
    """Return the result dict for one metric with all test statistics NaN."""
    return {
        f"n_{level_a}": n_a,
        f"n_{level_b}": n_b,
        f"{metric}_mean_{level_a}": mean_a,
        f"{metric}_mean_{level_b}": mean_b,
        f"{metric}_tstat": np.nan,
        f"{metric}_pval": np.nan,
        f"{metric}_cohensd": np.nan,
    }


def _select_group(region_data, metric, pairwise_factor, level, strata):
    """Return non-missing ``metric`` values for one factor level and stratum."""
    group = region_data[region_data[pairwise_factor] == level]
    for column, value in strata.items():
        group = group[group[column] == value]
    return group[metric].dropna()


def _pooled_cohens_d(values_a, values_b):
    """Return Cohen's d from the pooled standard deviation, or NaN."""
    n_a, n_b = len(values_a), len(values_b)
    dof = n_a + n_b - 2
    if n_a < 2 or n_b < 2 or dof <= 0:
        return np.nan
    pooled_var = (
        (n_a - 1) * values_a.std() ** 2 + (n_b - 1) * values_b.std() ** 2
    ) / dof
    pooled_std = np.sqrt(pooled_var)
    # ``not (x > 0)`` also rejects NaN.
    if not pooled_std > 0:
        return np.nan
    return float((values_a.mean() - values_b.mean()) / pooled_std)


def compute_contrast_for_metric(
    region_data,
    formula_template,
    metric,
    pairwise_factor,
    level_a,
    level_b,
    strata,
):
    """Fit a global OLS model and compute one pairwise contrast for *metric*.

    The model is fit on all rows of *region_data*. Marginal means are
    computed by predicting at reference covariate values with the pairwise
    factor set to each level in turn (and strata variables fixed at the
    requested values). Raw group means, counts and Cohen's d are computed
    from the rows matching the level and strata.

    Parameters
    ----------
    region_data : pd.DataFrame
        Rows for a single region, including participant metadata.
    formula_template : str
        Formula with the literal word ``metric`` as the response variable
        placeholder, e.g. ``"metric ~ C(treatment) + age"``.
    metric : str
        Actual metric column name; replaces ``metric`` in the formula.
    pairwise_factor : str
        Column whose levels are contrasted.
    level_a, level_b : str
        Levels to contrast (level_a - level_b).
    strata : dict
        ``{column: value}`` pairs restricting the raw groups and fixing the
        prediction rows (may be empty or None).

    Returns
    -------
    dict
        Keys ``{metric}_tstat``, ``{metric}_pval``, ``{metric}_cohensd``,
        ``{metric}_mean_{level_a}``, ``{metric}_mean_{level_b}``,
        ``n_{level_a}``, ``n_{level_b}``. Test statistics stay NaN when either
        group has fewer than two observations or the model cannot be fit.
    """
    strata = strata or {}

    values_a = _select_group(region_data, metric, pairwise_factor, level_a, strata)
    values_b = _select_group(region_data, metric, pairwise_factor, level_b, strata)
    n_a, n_b = len(values_a), len(values_b)

    result = _stat_placeholders(
        metric,
        level_a,
        level_b,
        n_a=n_a,
        n_b=n_b,
        mean_a=float(values_a.mean()) if n_a > 0 else np.nan,
        mean_b=float(values_b.mean()) if n_b > 0 else np.nan,
    )

    if n_a < 2 or n_b < 2:
        return result

    # Imported here, outside the best-effort ``try`` below, so that a missing
    # modelling dependency fails loudly instead of being reported as NaN, and
    # so the data helpers above stay usable without the modelling stack.
    import statsmodels.formula.api as smf
    from patsy import dmatrix

    actual_formula = resolve_formula(formula_template, metric)

    try:
        fitted = smf.ols(actual_formula, data=region_data).fit()

        # Build prediction rows for each level at the desired strata.
        pred_a = build_prediction_row(region_data, pairwise_factor, level_a, strata)
        pred_b = build_prediction_row(region_data, pairwise_factor, level_b, strata)

        # Reuse the model's design_info for consistent dummy encoding.
        design_info = fitted.model.data.design_info
        dm_a = np.asarray(dmatrix(design_info, pred_a, return_type="matrix"))
        dm_b = np.asarray(dmatrix(design_info, pred_b, return_type="matrix"))
        contrast_vec = (dm_a - dm_b)[0]

        contrast_test = fitted.t_test(contrast_vec)
        tstat = float(np.asarray(contrast_test.tvalue).item())
        pval = float(np.asarray(contrast_test.pvalue).item())
    except Exception as exc:  # noqa: BLE001 - deliberate per-region best effort
        print(
            f"Warning: model fitting failed for metric '{metric}' "
            f"(factor='{pairwise_factor}', {level_a} vs {level_b}, "
            f"strata={strata}): {exc}. Leaving stats as NaN."
        )
        return result

    result[f"{metric}_tstat"] = tstat
    result[f"{metric}_pval"] = pval
    result[f"{metric}_cohensd"] = _pooled_cohens_d(values_a, values_b)
    return result


def perform_model_based_contrast(
    data, formula, pairwise_factor, level_a, level_b, strata, metrics
):
    """Compute pairwise contrast statistics for every region and metric.

    Parameters
    ----------
    data : pd.DataFrame
        Combined dataframe with segstats and participant metadata; must have
        ``index`` and ``name`` columns identifying regions.
    formula : str
        Model formula (patsy/statsmodels), with ``metric`` as placeholder.
    pairwise_factor : str
        Column whose levels are contrasted.
    level_a, level_b : str
        Levels to contrast (level_a - level_b).
    strata : dict
        ``{column: value}`` pairs for stratified contrasts.
    metrics : list[str]
        Metric columns to analyse; metrics absent from ``data`` yield NaN.

    Returns
    -------
    pd.DataFrame
        One row per region with columns for each metric's statistics.
    """
    regions = data[["index", "name"]].drop_duplicates()
    available = set(data.columns)

    rows = []
    for region_index, region_name in zip(regions["index"], regions["name"]):
        region_data = data[data["index"] == region_index].copy()
        row = {"index": region_index, "name": region_name}

        for metric in metrics:
            if metric not in available:
                # The n_<level> keys are shared by all metrics: only fill the
                # gaps so a real count from another metric is not overwritten.
                for key, value in _stat_placeholders(
                    metric, level_a, level_b
                ).items():
                    row.setdefault(key, value)
                continue

            row.update(
                compute_contrast_for_metric(
                    region_data,
                    formula,
                    metric,
                    pairwise_factor,
                    level_a,
                    level_b,
                    strata,
                )
            )

        rows.append(row)

    return pd.DataFrame(rows)


def main():
    """Run the group contrast using the ``snakemake`` object.

    Reads ``snakemake.input.participants_tsv`` and
    ``snakemake.input.segstats_tsvs``, takes the formula, contrast and metric
    lists from ``snakemake.params``, and writes a tab-separated table to
    ``snakemake.output.stats_tsv``.

    Raises
    ------
    ValueError
        If participants.tsv lacks ``participant_id``, the formula or contrast
        information is missing, or a required column is absent after merging.
    """
    participants_df = pd.read_csv(snakemake.input.participants_tsv, sep="\t")

    if "participant_id" not in participants_df.columns:
        raise ValueError("participants.tsv must contain a 'participant_id' column")

    combined_data = load_segstats_with_metadata(
        snakemake.input.segstats_tsvs, participants_df
    )

    formula = snakemake.params.model
    contrast_info = snakemake.params.pairwise_contrast_info
    # Either list may be None when that class of metric is not configured.
    metrics = list(snakemake.params.metric_columns or []) + list(
        snakemake.params.coloc_metric_columns or []
    )

    if not formula:
        raise ValueError("No model formula supplied (--model is required).")
    if not contrast_info:
        raise ValueError(
            "No pairwise contrast information found for this wildcard. "
            "Check that --pairwise matches a column in participants.tsv."
        )

    pairwise_factor = contrast_info["factor"]
    level_a = contrast_info["levelA"]
    level_b = contrast_info["levelB"]
    strata = contrast_info.get("strata") or {}

    # Validate required columns exist
    for col in [pairwise_factor, *strata]:
        if col not in combined_data.columns:
            raise ValueError(
                f"Column '{col}' not found in data after merging with "
                f"participants.tsv. Available columns: {list(combined_data.columns)}"
            )

    results = perform_model_based_contrast(
        combined_data,
        formula,
        pairwise_factor,
        level_a,
        level_b,
        strata,
        metrics,
    )

    results.to_csv(snakemake.output.stats_tsv, sep="\t", index=False)