From 149eddd338b036dd9231ed5fca3391d7523750de Mon Sep 17 00:00:00 2001
From: Severin Magel
Date: Tue, 26 May 2026 21:03:55 -0400
Subject: [PATCH 1/2] Add CNN base class and IceCubeDNN model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Introduce `graphnet.models.cnn`, the convolutional-network counterpart
to the existing GNN backbones, intended to consume the image data
representation added in the image-representation PR.

- `CNN` — abstract base class (analogue of `GNN`) defining the
  interface for convolutional backbones operating on image-shaped
  `Data` objects.
- `IceCubeDNN` — configurable CNN backbone following the IceCube DNN
  reconstruction architecture.

`cnn/__init__.py` exports `CNN` and `IceCubeDNN`; the `LCSC` model
lands in a follow-up PR.

Split from #813.
---
 src/graphnet/models/cnn/__init__.py    |   4 +
 src/graphnet/models/cnn/cnn.py         |  35 ++
 src/graphnet/models/cnn/icecube_dnn.py | 421 +++++++++++++++++++++++++
 3 files changed, 460 insertions(+)
 create mode 100644 src/graphnet/models/cnn/__init__.py
 create mode 100644 src/graphnet/models/cnn/cnn.py
 create mode 100644 src/graphnet/models/cnn/icecube_dnn.py

diff --git a/src/graphnet/models/cnn/__init__.py b/src/graphnet/models/cnn/__init__.py
new file mode 100644
index 000000000..c8ccf9bdc
--- /dev/null
+++ b/src/graphnet/models/cnn/__init__.py
@@ -0,0 +1,4 @@
+"""CNN-specific modules, for performing the main learnable operations."""
+
+from .cnn import CNN
+from .icecube_dnn import IceCubeDNN
diff --git a/src/graphnet/models/cnn/cnn.py b/src/graphnet/models/cnn/cnn.py
new file mode 100644
index 000000000..2453790e4
--- /dev/null
+++ b/src/graphnet/models/cnn/cnn.py
@@ -0,0 +1,35 @@
+"""Base CNN-specific `Model` class(es)."""
+
+from abc import abstractmethod
+
+from torch import Tensor
+from torch_geometric.data import Data
+
+from graphnet.models import Model
+
+
+class CNN(Model):
+    """Base class for all core CNN models in graphnet."""
+
+    def __init__(self, nb_inputs: int, nb_outputs: int) -> None:
+        """Construct `CNN`."""
+        # Base class constructor
+        super().__init__()
+
+        # Member variables
+        self._nb_inputs = nb_inputs
+        self._nb_outputs = nb_outputs
+
+    @property
+    def nb_inputs(self) -> int:
+        """Return number of input features."""
+        return self._nb_inputs
+
+    @property
+    def nb_outputs(self) -> int:
+        """Return number of output features."""
+        return self._nb_outputs
+
+    @abstractmethod
+    def forward(self, data: Data) -> Tensor:
+        """Apply learnable forward pass in model."""
diff --git a/src/graphnet/models/cnn/icecube_dnn.py b/src/graphnet/models/cnn/icecube_dnn.py
new file mode 100644
index 000000000..c899f39a3
--- /dev/null
+++ b/src/graphnet/models/cnn/icecube_dnn.py
@@ -0,0 +1,421 @@
+"""Implementation of the IceCube DNN image convolution model by Theo Glauch.
+
+Based on the `upgoing_muon_energy` model from
+https://github.com/IceCubeOpenSource/i3deepice/tree/master
+"""
+
+from typing import List, Tuple, Union
+
+import torch
+from torch import nn
+from pytorch_lightning import LightningModule
+from torch_geometric.data import Data
+from .cnn import CNN
+
+
+class Conv3dBN(LightningModule):
+    """3D convolution with batch normalization from Theo Glauch's DNN."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Tuple[int, int, int],
+        padding: Union[str, Tuple[int, int, int]],
+        bias: bool = False,
+    ):
+        """Create a Conv3dBN module.
+
+        Args:
+            in_channels: Number of input channels.
+            out_channels: Number of output channels.
+            kernel_size: Size of the kernel.
+            padding: Padding of the kernel.
+            bias: If True, bias is used in the Convolution.
+        """
+        super().__init__()
+
+        self.conv = nn.Conv3d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            padding=padding,
+            bias=bias,
+        )
+
+        self.bn = nn.BatchNorm3d(out_channels)
+        self.activation = nn.ReLU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the Conv3dBN."""
+        return self.activation(self.bn(self.conv(x)))
+
+
+class InceptionBlock4(LightningModule):
+    """Inception block with 4 parallel towers from Theo Glauch's DNN."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        t0: int = 2,
+        t1: int = 4,
+        t2: int = 5,
+        n_pool: int = 3,
+    ):
+        """Create a InceptionBlock4 module.
+
+        Args:
+            in_channels: Number of input channels.
+            out_channels: Number of output channels.
+            t0: Size of the first kernel sequence.
+            t1: Size of the second kernel sequence.
+            t2: Size of the third kernel sequence.
+            n_pool: Size of the pooling kernel.
+        """
+        super().__init__()
+
+        self.tower0 = nn.Sequential(
+            Conv3dBN(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=(t0, 1, 1),
+                padding="same",
+            ),
+            Conv3dBN(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=(1, t0, 1),
+                padding="same",
+            ),
+            Conv3dBN(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=(1, 1, t0),
+                padding="same",
+            ),
+        )
+
+        self.tower1 = nn.Sequential(
+            Conv3dBN(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=(t1, 1, 1),
+                padding="same",
+            ),
+            Conv3dBN(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=(1, t1, 1),
+                padding="same",
+            ),
+            Conv3dBN(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=(1, 1, t1),
+                padding="same",
+            ),
+        )
+
+        self.tower4 = nn.Sequential(
+            Conv3dBN(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=(1, 1, t2),
+                padding="same",
+            ),
+        )
+
+        self.tower3 = nn.Sequential(
+            nn.MaxPool3d(
+                kernel_size=(n_pool, n_pool, n_pool),
+                stride=(1, 1, 1),
+                padding=(n_pool // 2, n_pool // 2, n_pool // 2),
+            ),
+            Conv3dBN(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=(1, 1, 1),
+                padding="same",
+            ),
+        )
+        self.out_channels = out_channels * 4
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the InceptionBlock4."""
+        ret = torch.cat(
+            [
+                self.tower0(x),
+                self.tower1(x),
+                self.tower3(x),
+                self.tower4(x),
+            ],
+            dim=1,
+        )
+        return ret
+
+
+class InceptionResnet(LightningModule):
+    """Inception block with residual connections from Theo Glauch's DNN."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        t1: int = 2,
+        t2: int = 4,
+        n_pool: int = 3,
+        scale: float = 0.1,
+    ):
+        """Create a InceptionResnet module.
+
+        Args:
+            in_channels: Number of input channels.
+            out_channels: Number of output channels.
+            t1: Size of the first kernel sequence.
+            t2: Size of the second kernel sequence.
+            n_pool: Size of the pooling kernel.
+            scale: Scaling factor for the residual connection.
+        """
+        super().__init__()
+        self._scale = scale
+        self.tower1 = nn.Sequential(
+            Conv3dBN(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=(1, 1, 1),
+                padding="same",
+            ),
+            Conv3dBN(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=(t1, 1, 1),
+                padding="same",
+            ),
+            Conv3dBN(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=(1, t1, 1),
+                padding="same",
+            ),
+            Conv3dBN(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=(1, 1, t1),
+                padding="same",
+            ),
+        )
+        self.tower2 = nn.Sequential(
+            Conv3dBN(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=(1, 1, 1),
+                padding="same",
+            ),
+            Conv3dBN(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=(t2, 1, 1),
+                padding="same",
+            ),
+            Conv3dBN(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=(1, t2, 1),
+                padding="same",
+            ),
+            Conv3dBN(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_size=(1, 1, t2),
+                padding="same",
+            ),
+        )
+        self.tower3 = nn.Sequential(
+            nn.MaxPool3d(
+                kernel_size=(n_pool, n_pool, n_pool),
+                stride=(1, 1, 1),
+                padding=(n_pool // 2, n_pool // 2, n_pool // 2),
+            ),
+            Conv3dBN(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=(1, 1, 1),
+                padding="same",
+            ),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the InceptionResnet block."""
+        tmp = torch.cat(
+            [
+                self.tower1(x),
+                self.tower2(x),
+                self.tower3(x),
+            ],
+            dim=1,
+        )
+        return x + self._scale * tmp
+
+
+class IceCubeDNN(CNN):
+    """Implementation of the IceCube DNN by Theo Glauch.
+
+    An inception-based 3D CNN originally used within IceCube. Based on
+    the model from
+    https://github.com/IceCubeOpenSource/i3deepice/tree/master
+    """
+
+    def __init__(
+        self,
+        nb_inputs: int = 15,
+        nb_outputs: int = 16,
+        image_size: Tuple[int, int, int] = (10, 10, 60),
+        inception_out_channels: int = 18,
+        inception_configs: List[Tuple[int, int, int]] = [
+            (2, 5, 8),
+            (2, 3, 7),
+            (2, 4, 8),
+            (3, 5, 9),
+            (2, 8, 9),
+        ],
+        resnet_out_channels: int = 24,
+        resnet_t2_pattern: List[int] = [3, 4, 5],
+        num_resblocks1_repeats: int = 6,
+        num_resblocks2_repeats: int = 6,
+        avgpool1_size: Tuple[int, int, int] = (2, 2, 3),
+        avgpool2_size: Tuple[int, int, int] = (1, 1, 2),
+        avgpool3_size: Tuple[int, int, int] = (1, 1, 2),
+        pointwise_channels: List[int] = [64, 4],
+        mlp_hidden_sizes: List[int] = [120, 64],
+    ) -> None:
+        """Construct `IceCubeDNN`.
+
+        Args:
+            nb_inputs: Number of input features.
+            nb_outputs: Number of output features.
+            image_size: Spatial dimensions of the input image
+                (height, width, depth).
+            inception_out_channels: Output channels per tower in each
+                inception block.
+            inception_configs: List of (t0, t1, t2) kernel size tuples
+                for each InceptionBlock4 layer.
+            resnet_out_channels: Output channels per tower in each
+                inception-resnet block.
+            resnet_t2_pattern: Pattern of t2 kernel sizes repeated in
+                each group of resnet blocks.
+            num_resblocks1_repeats: Number of times to repeat the
+                resnet_t2_pattern in the first resnet stage.
+            num_resblocks2_repeats: Number of times to repeat the
+                resnet_t2_pattern in the second resnet stage.
+            avgpool1_size: Kernel size for the first average pooling.
+            avgpool2_size: Kernel size for the second average pooling.
+            avgpool3_size: Kernel size for the third average pooling.
+            pointwise_channels: Output channels for each 1x1x1
+                convolution layer.
+            mlp_hidden_sizes: Hidden layer sizes for the final MLP.
+                The input size is computed from the preceding layers
+                and the output size is nb_outputs.
+        """
+        super().__init__(nb_inputs, nb_outputs)
+
+        # Inception blocks
+        inception_blocks = []
+        in_ch = nb_inputs
+        for t0, t1, t2 in inception_configs:
+            inception_blocks.append(
+                InceptionBlock4(
+                    in_channels=in_ch,
+                    out_channels=inception_out_channels,
+                    t0=t0,
+                    t1=t1,
+                    t2=t2,
+                )
+            )
+            in_ch = inception_out_channels * 4
+        self.inceptionblocks4 = nn.Sequential(*inception_blocks)
+
+        # All inception/resnet blocks use "same" padding, so spatial
+        # dimensions only change at pooling layers.
+        spatial = list(image_size)
+
+        self.avgpool1 = nn.AvgPool3d(avgpool1_size)
+        spatial = [s // p for s, p in zip(spatial, avgpool1_size)]
+        self.bn1 = nn.BatchNorm3d(in_ch)
+
+        # First resnet stage
+        resnet_in_ch = in_ch
+        tmp = []
+        for _ in range(num_resblocks1_repeats):
+            for t2 in resnet_t2_pattern:
+                tmp.append(
+                    InceptionResnet(
+                        in_channels=resnet_in_ch,
+                        out_channels=resnet_out_channels,
+                        t2=t2,
+                    )
+                )
+                resnet_in_ch = resnet_out_channels * 3
+        self.resblocks1 = nn.Sequential(*tmp)
+
+        self.avgpool2 = nn.AvgPool3d(avgpool2_size)
+        spatial = [s // p for s, p in zip(spatial, avgpool2_size)]
+        self.bn2 = nn.BatchNorm3d(resnet_in_ch)
+
+        # Second resnet stage
+        tmp = []
+        for _ in range(num_resblocks2_repeats):
+            for t2 in resnet_t2_pattern:
+                tmp.append(
+                    InceptionResnet(
+                        in_channels=resnet_in_ch,
+                        out_channels=resnet_out_channels,
+                        t2=t2,
+                    )
+                )
+                resnet_in_ch = resnet_out_channels * 3
+        self.resblocks2 = nn.Sequential(*tmp)
+
+        # Pointwise 1x1x1 convolutions
+        pointwise_layers: List[nn.Module] = []
+        pw_in = resnet_in_ch
+        for pw_out in pointwise_channels:
+            pointwise_layers.append(
+                nn.Conv3d(
+                    in_channels=pw_in,
+                    out_channels=pw_out,
+                    kernel_size=(1, 1, 1),
+                    padding=(0, 0, 0),
+                )
+            )
+            pointwise_layers.append(nn.ReLU())
+            pw_in = pw_out
+        self.convs111 = nn.Sequential(*pointwise_layers)
+
+        self.avgpool3 = nn.AvgPool3d(avgpool3_size)
+        spatial = [s // p for s, p in zip(spatial, avgpool3_size)]
+
+        # MLP head
+        latent_dim = pw_in * spatial[0] * spatial[1] * spatial[2]
+        mlp_sizes = [latent_dim] + mlp_hidden_sizes + [nb_outputs]
+        mlp_layers: List[nn.Module] = []
+        for i in range(len(mlp_sizes) - 1):
+            mlp_layers.append(nn.Linear(mlp_sizes[i], mlp_sizes[i + 1]))
+        self.mlps = nn.Sequential(*mlp_layers)
+
+    def forward(self, data: Data) -> torch.Tensor:
+        """Apply learnable forward pass in model."""
+        assert len(data.x) == 1, "Only one image expected"
+        x = data.x[0]
+        x = self.inceptionblocks4(x)
+        x = self.avgpool1(x)
+        x = self.bn1(x)
+        x = self.resblocks1(x)
+        x = self.avgpool2(x)
+        x = self.bn2(x)
+        x = self.resblocks2(x)
+        x = self.convs111(x)
+        x = self.avgpool3(x)
+        x = torch.flatten(x, start_dim=1)
+        x = self.mlps(x)
+        return x
From 36fd6f209b3b4c69ea8b2ba315195697a7173842 Mon Sep 17 00:00:00 2001
From: Severin Magel
Date: Tue, 26 May 2026 21:09:58 -0400
Subject: [PATCH 2/2] Add LCSC CNN model

Add the LCSC convolutional backbone, building on the `CNN` base class.
Registered in `cnn/__init__.py` alongside `CNN` and `IceCubeDNN`.

Stacked on the CNN-base PR. Split from #813.
---
 src/graphnet/models/cnn/__init__.py |   1 +
 src/graphnet/models/cnn/lcsc.py     | 551 ++++++++++++++++++++++++++++
 2 files changed, 552 insertions(+)
 create mode 100644 src/graphnet/models/cnn/lcsc.py

diff --git a/src/graphnet/models/cnn/__init__.py b/src/graphnet/models/cnn/__init__.py
index c8ccf9bdc..d44dd9f83 100644
--- a/src/graphnet/models/cnn/__init__.py
+++ b/src/graphnet/models/cnn/__init__.py
@@ -2,3 +2,4 @@
 
 from .cnn import CNN
 from .icecube_dnn import IceCubeDNN
+from .lcsc import LCSC
diff --git a/src/graphnet/models/cnn/lcsc.py b/src/graphnet/models/cnn/lcsc.py
new file mode 100644
index 000000000..67e65d8bd
--- /dev/null
+++ b/src/graphnet/models/cnn/lcsc.py
@@ -0,0 +1,551 @@
+"""Module for the Lightning CNN signal classifier (LCSC).
+
+All credits go to Alexander Harnisch (https://github.com/AlexHarn)
+"""
+
+from .cnn import CNN
+import torch
+from torch_geometric.data import Data
+from typing import List, Union, Tuple
+
+
+class LCSC(CNN):
+    """Lightning CNN Signal Classifier (LCSC).
+
+    All credits go to Alexander Harnisch (
+    https://github.com/AlexHarn)
+
+    Works with any single-image representation. The default
+    parameters were tested on IceCube simulation using the
+    Main Array image only.
+    """
+
+    def __init__(
+        self,
+        num_input_features: int,
+        out_put_dim: int = 2,
+        input_norm: bool = True,
+        num_conv_layers: int = 8,
+        conv_filters: List[int] = [50, 50, 50, 50, 50, 50, 50, 10],
+        kernel_size: Union[int, List[Union[int, List[int]]]] = 3,
+        padding: Union[str, int, List[Union[str, int]]] = "Same",
+        pooling_type: List[Union[None, str]] = [
+            None,
+            "Avg",
+            None,
+            "Avg",
+            None,
+            "Avg",
+            None,
+            "Avg",
+        ],
+        pooling_kernel_size: List[Union[None, int, List[int]]] = [
+            None,
+            [1, 1, 2],
+            None,
+            [2, 2, 2],
+            None,
+            [2, 2, 2],
+            None,
+            [2, 2, 2],
+        ],
+        pooling_stride: Union[int, List[Union[None, int, List[int]]]] = [
+            None,
+            [1, 1, 2],
+            None,
+            [2, 2, 2],
+            None,
+            [2, 2, 2],
+            None,
+            [2, 2, 2],
+        ],
+        num_fc_neurons: int = 50,
+        norm_list: Union[bool, List[bool]] = True,
+        norm_type: str = "Batch",
+        image_size: Tuple[int, int, int] = (10, 10, 60),
+    ) -> None:
+        """Initialize the Lightning CNN signal classifier (LCSC).
+
+        Args:
+            num_input_features: Number of input features.
+            out_put_dim: Number of output dimensions of final MLP.
+                Defaults to 2.
+            input_norm: Whether to apply normalization to the input.
+                Defaults to True.
+            num_conv_layers: Number of convolutional layers.
+                Defaults to 8.
+            conv_filters: List of number of convolutional
+                filters to use in hidden layers.
+                Defaults to [50, 50, 50, 50, 50, 50, 50, 10].
+                NOTE needs to have the length of `num_conv_layers`.
+            kernel_size: Size of the convolutional kernels.
+                Options are:
+                    int: single integer for all dimensions
+                    and all layers,
+                    e.g. 3 would equal [3, 3, 3] for each layer.
+                    list: list of integers specifying the kernel size,
+                    for each layer for all dimensions equally,
+                    e.g. [3, 5, 6] would equal [[3,3,3], [5,5,5], [6,6,6]].
+                    NOTE: needs to have the length of `num_conv_layers`.
+                If a list of lists is provided, each list will be used
+                for the corresponding layer as kernel size.
+                NOTE: If a list is passed it needs to have the length
+                of `num_conv_layers`.
+            padding: Padding for the convolutional layers.
+                Options are:
+                    'Same' for same convolutional padding,
+                    int: single integer for all dimensions and all layers,
+                    e.g. 1 would equal [1, 1, 1].
+                    list: list of integers specifying the padding for each
+                    dimension, for each layer equally,
+                    e.g. [1, 2, 3].
+                NOTE: If a list is passed it needs to have the length
+                of `num_conv_layers`.
+                Defaults to 'Same'.
+            pooling_type: List of pooling types for layers.
+                Options are
+                    None : No pooling is used,
+                    'Avg' : Average pooling is used,
+                    'Max' : Max pooling is used
+                Defaults to [
+                    None, 'Avg',
+                    None, 'Avg',
+                    None, 'Avg',
+                    None, 'Avg'
+                ].
+                NOTE: the length of the list must be equal to
+                `num_conv_layers`.
+            pooling_kernel_size: List of pooling kernel sizes for each
+                layer. If an integer is provided, it will be used for
+                all layers. In case of a list the options for its
+                elements are:
+                    list: list of integers for each dimension, e.g. [1, 1, 2].
+                    int: single integer for all dimensions,
+                    e.g. 2 would equal [2, 2, 2].
+                    If None, no pooling is applied.
+                NOTE: If a list is passed it needs to have the length
+                of `num_conv_layers`.
+                Defaults to [
+                    None, [1, 1, 2],
+                    None, [2, 2, 2],
+                    None, [2, 2, 2],
+                    None, [2, 2, 2]
+                ].
+            pooling_stride: List of pooling strides for each layer.
+                If an integer is provided, it will be used for all layers.
+                In case of a list the options for its elements are:
+                    list: list of integers for each dimension, e.g. [1, 1, 2].
+                    int: single integer for all dimensions,
+                    e.g. 2 would equal [2, 2, 2].
+                    If None, no pooling is applied.
+                NOTE: If a list is passed it needs to have the length
+                of `num_conv_layers`.
+                Defaults to [
+                    None, [1, 1, 2],
+                    None, [2, 2, 2],
+                    None, [2, 2, 2],
+                    None, [2, 2, 2]
+                ].
+            num_fc_neurons: Number of neurons in the fully connected
+                layers. Defaults to 50.
+            norm_list: Whether to apply normalization for each
+                convolutional layer. If a boolean is provided, it will
+                be used for all layers. Defaults to True.
+                NOTE: If a list is passed it needs to have the length
+                of `num_conv_layers`.
+            norm_type: Type of normalization to use.
+                Options are 'Batch' or 'Instance'.
+                Defaults to 'Batch'.
+            image_size: Size of the input image in the format
+                (height, width, depth).
+                NOTE: Only needs to be changed if the input image is not
+                the standard IceCube 86 image size.
+        """
+        super().__init__(nb_inputs=num_input_features, nb_outputs=out_put_dim)
+
+        # Check and parse input parameters
+        conv_filters, kernel_size, padding = self._parse_conv_arguments(
+            num_conv_layers=num_conv_layers,
+            conv_filters=conv_filters,
+            kernel_size=kernel_size,
+            padding=padding,
+        )
+        pooling_kernel_size, pooling_stride = self._parse_pooling_arguments(
+            num_conv_layers=num_conv_layers,
+            pooling_kernel_size=pooling_kernel_size,
+            pooling_stride=pooling_stride,
+        )
+        self._norm_list, norm_class = self._parse_norm_arguments(
+            num_conv_layers=num_conv_layers,
+            num_input_features=num_input_features,
+            input_norm=input_norm,
+            norm_list=norm_list,
+            norm_type=norm_type,
+        )
+
+        # Set convolution, pooling, and normalization layers
+        self.input_norm = input_norm
+        dimensions = self._set_conv_layers(
+            num_conv_layers=num_conv_layers,
+            num_input_features=num_input_features,
+            image_size=image_size,
+            conv_filters=conv_filters,
+            kernel_size=kernel_size,
+            padding=padding,
+            pooling_type=pooling_type,
+            pooling_kernel_size=pooling_kernel_size,
+            pooling_stride=pooling_stride,
+            norm_class=norm_class,
+        )
+
+        # Set linear layers
+        latent_dim = (
+            dimensions[0] * dimensions[1] * dimensions[2] * dimensions[3]
+        )
+        self.flatten = torch.nn.Flatten()
+        self.fc1 = torch.nn.Linear(latent_dim, num_fc_neurons)
+        self.fc2 = torch.nn.Linear(num_fc_neurons, out_put_dim)
+
+    def _parse_conv_arguments(
+        self,
+        num_conv_layers: int,
+        conv_filters: Union[int, List[int]],
+        kernel_size: Union[int, List[Union[int, List[int]]]],
+        padding: Union[str, int, List[Union[str, int]]],
+    ) -> Tuple[List[int], List, List]:
+        """Parse and validate convolution arguments.
+
+        Args:
+            num_conv_layers: Number of convolutional layers.
+            conv_filters: Convolutional filters per layer.
+            kernel_size: Kernel sizes per layer.
+            padding: Padding per layer.
+
+        Returns:
+            Parsed conv_filters, kernel_size, and padding as lists.
+        """
+        if isinstance(conv_filters, int):
+            conv_filters = [conv_filters for _ in range(num_conv_layers)]
+        else:
+            if not isinstance(conv_filters, list):
+                raise TypeError(
+                    f"`conv_filters` must be a "
+                    f"list or an integer, not {type(conv_filters)}!"
+                )
+            if len(conv_filters) != num_conv_layers:
+                raise ValueError(
+                    f"`conv_filters` must have {num_conv_layers} "
+                    f"elements, not {len(conv_filters)}!"
+                )
+
+        if isinstance(kernel_size, int):
+            kernel_size = [  # type: ignore[assignment]
+                [kernel_size, kernel_size, kernel_size]
+                for _ in range(num_conv_layers)
+            ]
+        else:
+            if not isinstance(kernel_size, list):
+                raise TypeError(
+                    "`kernel_size` must be a list or an "
+                    f"integer, not {type(kernel_size)}!"
+                )
+            if len(kernel_size) != num_conv_layers:
+                raise ValueError(
+                    f"`kernel_size` must have {num_conv_layers} "
+                    f"elements, not {len(kernel_size)}!"
+                )
+
+        if isinstance(padding, int):
+            padding = [padding for _ in range(num_conv_layers)]
+        elif isinstance(padding, str):
+            if padding.lower() == "same":
+                padding = ["same" for _ in range(num_conv_layers)]
+            else:
+                raise ValueError(
+                    "`padding` must be 'Same' or an integer, "
+                    f"not {padding}!"
+                )
+        else:
+            if not isinstance(padding, list):
+                raise TypeError(
+                    f"`padding` must be a list or "
+                    f"an integer, not {type(padding)}!"
+                )
+            if len(padding) != num_conv_layers:
+                raise ValueError(
+                    f"`padding` must have {num_conv_layers} "
+                    f"elements, not {len(padding)}!"
+                )
+
+        return conv_filters, kernel_size, padding
+
+    def _parse_pooling_arguments(
+        self,
+        num_conv_layers: int,
+        pooling_kernel_size: Union[int, List[Union[None, int, List[int]]]],
+        pooling_stride: Union[int, List[Union[None, int, List[int]]]],
+    ) -> Tuple[List, List]:
+        """Parse and validate pooling arguments.
+
+        Args:
+            num_conv_layers: Number of convolutional layers.
+            pooling_kernel_size: Pooling kernel sizes per layer.
+            pooling_stride: Pooling strides per layer.
+
+        Returns:
+            Parsed pooling_kernel_size and pooling_stride as lists.
+        """
+        if isinstance(pooling_kernel_size, int):
+            pooling_kernel_size = [
+                pooling_kernel_size for _ in range(num_conv_layers)
+            ]
+        else:
+            if not isinstance(pooling_kernel_size, list):
+                raise TypeError(
+                    "`pooling_kernel_size` must be a list or "
+                    f"an integer, not {type(pooling_kernel_size)}!"
+                )
+            if len(pooling_kernel_size) != num_conv_layers:
+                raise ValueError(
+                    f"`pooling_kernel_size` must have "
+                    f"{num_conv_layers} elements, not "
+                    f"{len(pooling_kernel_size)}!"
+                )
+
+        if isinstance(pooling_stride, int):
+            pooling_stride = [pooling_stride for _ in range(num_conv_layers)]
+        else:
+            if not isinstance(pooling_stride, list):
+                raise TypeError(
+                    "`pooling_stride` must be a list or an integer, "
+                    f"not {type(pooling_stride)}!"
+                )
+            if len(pooling_stride) != num_conv_layers:
+                raise ValueError(
+                    f"`pooling_stride` must have {num_conv_layers} "
+                    f"elements, not {len(pooling_stride)}!"
+                )
+
+        return pooling_kernel_size, pooling_stride
+
+    def _parse_norm_arguments(
+        self,
+        num_conv_layers: int,
+        num_input_features: int,
+        input_norm: bool,
+        norm_list: Union[bool, List[bool]],
+        norm_type: str,
+    ) -> Tuple[List[bool], type]:
+        """Parse and validate normalization arguments.
+
+        Args:
+            num_conv_layers: Number of convolutional layers.
+            num_input_features: Number of input features.
+            input_norm: Whether to apply input normalization.
+            norm_list: Per-layer normalization flags.
+            norm_type: Type of normalization ('Batch' or 'Instance').
+
+        Returns:
+            Parsed norm_list and the normalization class.
+        """
+        if isinstance(norm_list, bool):
+            parsed_norm_list = [norm_list for _ in range(num_conv_layers)]
+        else:
+            if not isinstance(norm_list, list):
+                raise TypeError(
+                    "`norm_list` must be a list or a boolean, "
+                    f"not {type(norm_list)}!"
+                )
+            if len(norm_list) != num_conv_layers:
+                raise ValueError(
+                    f"`norm_list` must have {num_conv_layers} "
+                    f"elements, not {len(norm_list)}!"
+                )
+            parsed_norm_list = norm_list
+
+        if norm_type.lower() == "instance":
+            norm_class = torch.nn.InstanceNorm3d
+            if input_norm:
+                self.input_normal = torch.nn.InstanceNorm3d(num_input_features)
+        elif norm_type.lower() == "batch":
+            norm_class = torch.nn.BatchNorm3d
+            if input_norm:
+                self.input_normal = torch.nn.BatchNorm3d(
+                    num_input_features, momentum=None, affine=False
+                )
+        else:
+            raise ValueError(
+                "`norm_type` has to be 'instance' or "
+                f"'batch', not '{norm_type}'!"
+            )
+
+        return parsed_norm_list, norm_class
+
+    def _set_conv_layers(
+        self,
+        num_conv_layers: int,
+        num_input_features: int,
+        image_size: Tuple[int, int, int],
+        conv_filters: List[int],
+        kernel_size: List,
+        padding: List,
+        pooling_type: List[Union[None, str]],
+        pooling_kernel_size: List,
+        pooling_stride: List,
+        norm_class: type,
+    ) -> List[int]:
+        """Build convolution, pooling, and normalization layers.
+
+        Args:
+            num_conv_layers: Number of convolutional layers.
+            num_input_features: Number of input features.
+            image_size: Size of the input image (height, width, depth).
+            conv_filters: Convolutional filters per layer.
+            kernel_size: Kernel sizes per layer.
+            padding: Padding per layer.
+            pooling_type: Pooling type per layer.
+            pooling_kernel_size: Pooling kernel sizes per layer.
+            pooling_stride: Pooling strides per layer.
+            norm_class: Normalization class to use.
+
+        Returns:
+            Output dimensions after all layers.
+        """
+        self.conv = torch.nn.ModuleList()
+        self.pool = torch.nn.ModuleList()
+        self.normal = torch.nn.ModuleList()
+
+        dimensions: List[int] = [num_input_features, *image_size]
+        for i in range(num_conv_layers):
+            self.conv.append(
+                torch.nn.Conv3d(
+                    dimensions[0],
+                    conv_filters[i],
+                    kernel_size=kernel_size[i],
+                    padding=padding[i],
+                )
+            )
+            dimensions = self._calc_output_dimension(
+                dimensions,
+                conv_filters[i],
+                kernel_size[i],
+                padding[i],
+            )
+            if pooling_type[i] is None or pooling_type[i] == "None":
+                self.pool.append(None)
+            elif pooling_type[i] == "Avg":
+                self.pool.append(
+                    torch.nn.AvgPool3d(
+                        kernel_size=pooling_kernel_size[i],
+                        stride=pooling_stride[i],
+                    )
+                )
+                dimensions = self._calc_output_dimension(
+                    dimensions,
+                    out_channels=dimensions[0],
+                    kernel_size=pooling_kernel_size[i],
+                    stride=pooling_stride[i],
+                )
+            elif pooling_type[i] == "Max":
+                self.pool.append(
+                    torch.nn.MaxPool3d(
+                        kernel_size=pooling_kernel_size[i],
+                        stride=pooling_stride[i],
+                    )
+                )
+                dimensions = self._calc_output_dimension(
+                    dimensions,
+                    out_channels=dimensions[0],
+                    kernel_size=pooling_kernel_size[i],
+                    stride=pooling_stride[i],
+                )
+            else:
+                raise ValueError(
+                    "Pooling type must be 'None', 'Avg' or 'Max'!"
+                )
+            if self._norm_list[i]:
+                self.normal.append(norm_class(dimensions[0]))
+            else:
+                self.normal.append(None)
+
+        return dimensions
+
+    def _calc_output_dimension(
+        self,
+        dimensions: List[int],
+        out_channels: int,
+        kernel_size: Union[None, int, List[int]],
+        padding: Union[str, int, List[int]] = 0,
+        stride: Union[None, int, List[int]] = 1,
+    ) -> List[int]:
+        """Calculate the output dimension after a CNN layer.
+
+        Works for Conv3D, MaxPool3D and AvgPool3D layers.
+
+        Args:
+            dimensions: Current dimensions of the input tensor.
+                (C,H,W,D) where C is the number of channels,
+                H is the height, W is the width and D is the depth.
+            out_channels: Number of output channels.
+            kernel_size: Size of the kernel.
+                If an integer is provided, it will be used for all dimensions.
+            padding: Padding size.
+                If an integer is provided, it will be used for all dimensions.
+                If 'Same', the padding will be calculated to keep the
+                output size the same as the input size.
+                Defaults to 0.
+            stride: Stride size.
+                If an integer is provided, it will be used for all dimensions.
+                Defaults to 1.
+
+        Returns:
+            New dimensions after the layer.
+
+        NOTE: For the pooling layers, set out_channels equal to the
+        input channels. Since they do not change the number of channels.
+        """
+        krnl_sz: int
+        if isinstance(padding, str):
+            if not padding.lower() == "same":
+                raise ValueError(
+                    f"`padding` must be 'Same' or an integer, not {padding}!"
+                )
+        else:
+            for i in range(1, 4):
+                if isinstance(kernel_size, list):
+                    krnl_sz = kernel_size[i - 1]
+                else:
+                    assert isinstance(kernel_size, int)
+                    krnl_sz = kernel_size
+                if isinstance(padding, list):
+                    pad = padding[i - 1]
+                else:
+                    pad = padding
+                if isinstance(stride, list):
+                    strd = stride[i - 1]
+                else:
+                    assert isinstance(stride, int)
+                    strd = stride
+                dimensions[i] = (dimensions[i] + 2 * pad - krnl_sz) // strd + 1
+
+        dimensions[0] = out_channels  # channels change for every padding mode
+        return dimensions
+
+    def forward(self, data: Data) -> torch.Tensor:
+        """Forward pass of the LCSC."""
+        assert len(data.x) == 1, "Only a single image is expected"
+        x = data.x[0]
+        if self.input_norm:
+            x = self.input_normal(x)
+        for i in range(len(self.conv)):
+            x = self.conv[i](x)
+            if self.pool[i] is not None:
+                x = self.pool[i](x)
+            x = torch.nn.functional.elu(x)
+            if self.normal[i] is not None:
+                x = self.normal[i](x)
+
+        x = self.flatten(x)
+        x = torch.nn.functional.elu(self.fc1(x))
+        x = torch.nn.functional.elu(self.fc2(x))
+        return x