|
| 1 | +#' @title Ratio Image CNN for Biomarker Panel Classification |
| 2 | +#' |
| 3 | +#' @description |
| 4 | +#' Converts biomarker expression profiles into pairwise log-ratio images and |
| 5 | +#' classifies them using a CNN. This approach is inherently within-sample |
| 6 | +#' normalized because pairwise ratios cancel multiplicative batch effects. |
| 7 | +#' |
| 8 | +#' @details |
| 9 | +#' The analogy to histopathology image analysis is direct:
| 10 | +#' \itemize{ |
| 11 | +#' \item Batch effects in biomarkers ≈ brightness/contrast variation in images |
| 12 | +#' \item Pairwise log-ratios ≈ color normalization (relative channel intensities) |
| 13 | +#' \item CNN on ratio images ≈ learning spatial patterns invariant to staining |
| 14 | +#' } |
| 15 | +#' |
| 16 | +#' For a panel of p biomarkers, each sample becomes a p×p image where |
| 17 | +#' pixel(i,j) = log2(biomarker_i / biomarker_j). On log-scale data, this |
| 18 | +#' is simply the difference. The CNN then learns non-linear patterns in the |
| 19 | +#' pairwise relationship structure. |
| 20 | +#' |
| 21 | +#' @references |
| 22 | +#' Stawiski K et al. (2026). Pairwise ratio images for batch-effect-free |
| 23 | +#' biomarker classification. (in preparation) |
| 24 | +#' |
| 25 | +#' Sharma A et al. (2019). DeepInsight: A methodology to transform a |
| 26 | +#' non-image data to an image for convolution neural network architecture. |
| 27 | +#' Scientific Reports, 9(1), 11399. |
| 28 | +#' |
| 29 | +#' @name ratio-image-cnn |
| 30 | +NULL |
| 31 | + |
| 32 | + |
#' @title Create Pairwise Ratio Image from Expression Vector
#'
#' @description
#' Converts a single sample's expression vector into a p x p pairwise
#' log-ratio matrix. If the input is on log scale, pixel(i, j) = x[i] - x[j],
#' which equals log(expr_i / expr_j) on the original scale.
#'
#' @param x Numeric vector of length p (one sample's expression values,
#'   assumed log-scale). Names, if present, are propagated to the dimnames
#'   of the result (standard \code{outer()} behavior).
#'
#' @return A p x p numeric matrix (the ratio image). The diagonal is zero
#'   and the matrix is antisymmetric: \code{img[i, j] == -img[j, i]}.
#'
#' @export
make_ratio_image <- function(x) {
  # Fail fast on non-numeric input rather than letting outer() produce a
  # confusing downstream error.
  stopifnot(is.numeric(x))
  # On log-scale data the pairwise log-ratio is a plain difference, so the
  # whole image is one vectorized outer subtraction.
  outer(x, x, "-")
}
| 49 | + |
| 50 | + |
#' @title Batch-Create Ratio Images from Expression Matrix
#'
#' @description
#' Builds the pairwise log-ratio image for every sample (row) of an
#' expression matrix in a single vectorized step.
#'
#' @param mat Numeric matrix (samples x features), log-scale
#'
#' @return A 3D numeric array (samples x features x features) where slice
#'   \code{[s, , ]} is the ratio image of sample s.
#'
#' @export
make_ratio_images <- function(mat) {
  n_samples <- nrow(mat)
  n_features <- ncol(mat)
  # Recycling the n x p data into dims (n, p, p) gives lhs[s, i, j] = mat[s, i]
  # (column-major fill cycles through the matrix once per third-dim slice).
  # Swapping the last two axes then gives rhs[s, i, j] = mat[s, j], so a
  # single array subtraction yields every pairwise difference at once.
  lhs <- array(as.numeric(mat), dim = c(n_samples, n_features, n_features))
  lhs - aperm(lhs, c(1, 3, 2))
}
| 70 | + |
| 71 | + |
#' @title Train Ratio Image CNN for Binary Classification
#'
#' @description
#' Trains a lightweight CNN on pairwise ratio images for binary classification
#' (e.g., cancer vs healthy). Uses torch for GPU acceleration.
#'
#' @details
#' Pipeline: (1) each sample row is expanded into a p x p pairwise-difference
#' image via \code{make_ratio_images()}; (2) images are z-scored using
#' statistics computed on the training set only (no test-set leakage);
#' (3) a two-conv-layer network with adaptive average pooling (so any panel
#' size p works without architecture changes) is trained with
#' BCE-with-logits loss, optionally weighted by the negative/positive class
#' ratio.
#'
#' @param X_train Numeric matrix (N_train x p features), log-scale
#' @param y_train Integer/numeric vector (0/1 labels)
#' @param X_test Numeric matrix (N_test x p features), log-scale
#' @param epochs Number of training epochs (default: 30)
#' @param lr Learning rate (default: 0.001)
#' @param batch_size Batch size (default: 256)
#' @param class_weight Whether to apply inverse-frequency class weighting (default: TRUE)
#' @param device "cuda" or "cpu" (default: auto-detect)
#' @param verbose Print progress (default: TRUE)
#'
#' @return A list with:
#' \itemize{
#'   \item predictions: numeric vector of predicted probabilities for test set
#'   \item model: trained torch model
#'   \item train_images: ratio images used for training (3D array; note these
#'     are the normalized images, not the raw ratios)
#'   \item image_stats: mean and sd used for image normalization (apply the
#'     same transform to any new data before prediction)
#' }
#'
#' @examples
#' \dontrun{
#' result <- train_ratio_cnn(X_train, y_train, X_test, epochs = 30)
#' auc <- pROC::auc(pROC::roc(y_test, result$predictions))
#' }
#'
#' @export
train_ratio_cnn <- function(X_train, y_train, X_test,
                            epochs = 30L, lr = 0.001,
                            batch_size = 256L,
                            class_weight = TRUE,
                            device = NULL,
                            verbose = TRUE) {

  # torch is a heavy optional dependency; fail early with install guidance.
  if (!requireNamespace("torch", quietly = TRUE)) {
    stop("Package 'torch' is required for CNN training. Install with: install.packages('torch')")
  }
  # NOTE(review): the training loop below also calls coro::loop(). coro ships
  # as a torch dependency but is not checked here -- confirm it is declared
  # in this package's DESCRIPTION.

  # Auto-detect device
  if (is.null(device)) {
    device <- if (torch::cuda_is_available()) "cuda" else "cpu"
  }
  if (verbose) message("Training ratio-image CNN on: ", device)

  # Create ratio images: (N, p, p) arrays of pairwise log-ratio differences.
  train_imgs <- make_ratio_images(X_train)
  test_imgs <- make_ratio_images(X_test)

  # Z-score both sets with TRAINING statistics only, so no information from
  # the test set leaks into preprocessing. The img_sd > 0 guard skips
  # normalization for degenerate (constant) inputs instead of dividing by 0.
  img_mean <- mean(train_imgs)
  img_sd <- sd(as.vector(train_imgs))
  if (img_sd > 0) {
    train_imgs <- (train_imgs - img_mean) / img_sd
    test_imgs <- (test_imgs - img_mean) / img_sd
  }

  # NOTE(review): p is computed but never used below; kept as-is.
  p <- ncol(X_train)

  # Convert to torch tensors. unsqueeze(2) inserts a channel axis at
  # position 2 (torch for R is 1-indexed), giving shape (N, 1, p, p).
  X_tr <- torch::torch_tensor(train_imgs, dtype = torch::torch_float())$unsqueeze(2)
  y_tr <- torch::torch_tensor(as.numeric(y_train), dtype = torch::torch_float())
  X_te <- torch::torch_tensor(test_imgs, dtype = torch::torch_float())$unsqueeze(2)

  if (device == "cuda") {
    X_tr <- X_tr$cuda()
    y_tr <- y_tr$cuda()
    X_te <- X_te$cuda()
  }

  # Positive-class weight for BCEWithLogits: n_neg / n_pos up-weights the
  # minority positive class. max(n_pos, 1) guards against division by zero
  # when the training set contains no positives.
  pos_weight <- if (class_weight) {
    n_pos <- sum(y_train == 1)
    n_neg <- sum(y_train == 0)
    torch::torch_tensor(n_neg / max(n_pos, 1), dtype = torch::torch_float())
  } else {
    torch::torch_tensor(1.0)
  }
  if (device == "cuda") pos_weight <- pos_weight$cuda()

  # Define CNN: two 3x3 conv layers (1 -> 16 -> 32 channels, padding keeps
  # spatial size), then adaptive average pooling to a fixed 3x3 grid so the
  # flattened size (32 * 9) is independent of the panel size p.
  model <- torch::nn_module(
    initialize = function() {
      self$conv1 <- torch::nn_conv2d(1, 16, kernel_size = 3, padding = 1)
      self$conv2 <- torch::nn_conv2d(16, 32, kernel_size = 3, padding = 1)
      self$pool <- torch::nn_adaptive_avg_pool2d(3)
      self$fc1 <- torch::nn_linear(32 * 9, 64)
      self$fc2 <- torch::nn_linear(64, 1)
      self$relu <- torch::nn_relu()
      self$dropout <- torch::nn_dropout(0.3)
    },
    forward = function(x) {
      x <- self$relu(self$conv1(x))
      x <- self$relu(self$conv2(x))
      x <- self$pool(x)
      # Flatten all non-batch dims; returns raw logits (no sigmoid here --
      # the loss below is BCE *with logits*).
      x <- x$view(c(x$size(1), -1))
      x <- self$dropout(self$relu(self$fc1(x)))
      self$fc2(x)
    }
  )

  net <- model()
  if (device == "cuda") net <- net$cuda()

  # Adam with mild L2 regularization via weight_decay.
  optimizer <- torch::optim_adam(net$parameters, lr = lr, weight_decay = 1e-4)
  criterion <- torch::nn_bce_with_logits_loss(pos_weight = pos_weight)

  # Minimal in-memory dataset wrapping the pre-built tensors. .getitem keeps
  # the channel axis so batches stack to (B, 1, p, p).
  dataset <- torch::dataset(
    initialize = function(X, y) {
      self$X <- X
      self$y <- y
    },
    .getitem = function(i) {
      list(x = self$X[i, , , ], y = self$y[i])
    },
    .length = function() {
      self$X$size(1)
    }
  )(X_tr, y_tr)

  loader <- torch::dataloader(dataset, batch_size = batch_size, shuffle = TRUE)

  net$train()
  for (epoch in seq_len(epochs)) {
    # epoch_loss is the SUM of per-batch losses (not a mean), so its scale
    # grows with the number of batches -- it is only a progress indicator.
    epoch_loss <- 0
    coro::loop(for (batch in loader) {
      optimizer$zero_grad()
      # squeeze() drops size-1 dims so logits match y's shape. NOTE(review):
      # a batch of size 1 would squeeze to a 0-d tensor -- presumably the
      # loss still broadcasts, but confirm with batch_size-remainder batches.
      output <- net(batch$x)$squeeze()
      loss <- criterion(output, batch$y)
      loss$backward()
      optimizer$step()
      epoch_loss <- epoch_loss + loss$item()
    })
    if (verbose && epoch %% 10 == 0) {
      message(sprintf(" Epoch %d/%d, loss: %.4f", epoch, epochs, epoch_loss))
    }
  }

  # Prediction: eval mode disables dropout; no_grad avoids building the
  # autograd graph for the forward pass.
  net$eval()
  torch::with_no_grad({
    logits <- net(X_te)$squeeze()
    # Move to CPU before conversion; the $to(dtype = ...) appears redundant
    # since logits are already float -- presumably kept to guarantee a
    # numeric type ahead of as.numeric().
    preds <- torch::torch_sigmoid(logits)$cpu()$to(dtype = torch::torch_float())
  })

  list(
    predictions = as.numeric(preds),
    model = net,
    train_images = train_imgs,
    image_stats = list(mean = img_mean, sd = img_sd)
  )
}
0 commit comments