Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 61 additions & 54 deletions climanet/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import warnings

import numpy as np
from .utils import add_month_day_dims, calc_stats
from .utils import add_month_day_dims, calc_stats, add_month_hour_dims
from .geo_embedding_utils import (
calculate_sh_geo_pos_embeddings,
compute_patch_geo_pos_embedding,
Expand All @@ -28,7 +28,21 @@ def __init__(
sh_pos_table: str = None, # Optional; str formatted path to precomputed table of sh
sh_embed_dim: int = 96, # sh_embed_dim should <= (sh_order_L + 1)**2
sh_order_L: int = 10,
is_hourly: bool = False,
):
"""Initialize the dataset with daily and monthly data, and optional land mask.

Args:
daily_da: xarray DataArray with daily data (M, time, H, W)
monthly_da: xarray DataArray with monthly data (M, H, W)
land_mask: Optional xarray DataArray with land mask (H, W) or (1, H, W)
time_dim: Name of the time dimension in the input data
spatial_dims: Tuple of (lat_dim, lon_dim) names in the input data
patch_size: Tuple of (patch_height, patch_width) in pixels
stride: Tuple of (stride_height, stride_width) in pixels. If None, defaults to patch_size (non-overlapping patches).
is_hourly: Whether the daily data is hourly (T=31*24) or daily (T=31).

"""
self.spatial_dims = spatial_dims
self.patch_size = patch_size
self.daily_da = daily_da
Expand All @@ -53,46 +67,55 @@ def __init__(
f"Patch size {patch_size} is larger than data dimensions {daily_da.sizes[spatial_dims]}"
)

# Reshape daily → (M, T=31, H, W), monthly → (M, H, W),
# and get padded_days_mask → (M, T=31)
daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_day_dims(
daily_da, monthly_da, time_dim=time_dim
)
if is_hourly:
# hours_per_day == 24
# Reshape daily → (M, T=31*24, H, W), monthly → (M, H, W),
# and get padded_days_mask → (M, T=31*24)
daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_hour_dims(
daily_da, monthly_da, time_dim=time_dim
)
else:
# Reshape daily → (M, T=31, H, W), monthly → (M, H, W),
# and get padded_days_mask → (M, T=31)
daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_day_dims(
daily_da, monthly_da, time_dim=time_dim
)

# Convert to numpy once — all __getitem__ calls use these
self.daily_np = daily_mt.to_numpy().copy().astype(np.float32) # (M, T=31, H, W) float
self.monthly_np = monthly_m.to_numpy().copy().astype(np.float32) # (M, H, W) float
self.padded_mask_np = padded_days_mask.to_numpy().copy() # (M, T=31) bool
self.daily_timef_np = daily_timef.to_numpy().copy().astype(np.float32) # (M,T=31, 4)
self.daily_t = torch.from_numpy(daily_mt.values.astype(np.float32)) # (M, T=31, H, W)
self.monthly_t = torch.from_numpy(monthly_m.values.astype(np.float32)) # (M, H, W)
self.padded_days_tensor = torch.from_numpy(padded_days_mask.values.copy()).bool() # (M, T=31)
self.daily_timef_t = torch.from_numpy(daily_timef.values.astype(np.float32)) # (M, T=31, 4)

# Store coordinate arrays
self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy()
self.lon_coords = daily_da[spatial_dims[1]].to_numpy().copy()

if land_mask is not None:
lm = land_mask.to_numpy().copy()
lm = torch.from_numpy(land_mask.values.copy()).bool()
if lm.ndim == 3:
lm = lm.squeeze(0) # (1, H, W) → (H, W)
self.land_mask_np = lm
self.land_mask_t = lm
else:
self.land_mask_np = None
self.land_mask_t = None

# Precompute the NaN mask before filling NaNs
# daily_mask: True where NaN (i.e. missing ocean data, not land)
self.daily_nan_mask = np.isnan(self.daily_np) # (M, T=31, H, W)
self.daily_nan_mask = torch.isnan(self.daily_t) # (M, T=31, H, W)

# NaNs will be filled with 0 in-place
np.nan_to_num(self.daily_np, copy=False, nan=0.0)
self.daily_t.nan_to_num_(nan=0.0)

# Stats will be set later via set_stats() for train/test datasets
self.daily_mean = None
self.daily_std = None

# Precompute padded_days_mask as a tensor (same for all patches)
self.padded_days_tensor = torch.from_numpy(self.padded_mask_np).bool()
# Pre-build zero land tensor for the no-mask case
ph, pw = self.patch_size
self._zero_land = torch.zeros(ph, pw, dtype=torch.bool)

# Precompute lazy index mapping for patches
H, W = self.daily_np.shape[2], self.daily_np.shape[3]
H, W = self.daily_t.shape[2], self.daily_t.shape[3]
self.patch_indices = self._compute_patch_indices(H, W)

# Precompute geoposition and scale embeddings for patches
Expand All @@ -101,6 +124,9 @@ def __init__(
self.patch_geo_embeddings, self.patch_scale_features = (
self._compute_geoscalepatch_embeddings()
)
self.scale_f_dim = torch.tensor(self.patch_scale_features.shape[-1])
self.sh_embed_dim_t = torch.tensor(self.sh_embed_dim)
self.harmonic_order_t = torch.tensor(self.sh_order_L)

def _get_geo_pos(self, sh_pos_table: str):
"""Calculate or retrieve spherical harmonics based geo position embeddings."""
Expand Down Expand Up @@ -205,33 +231,19 @@ def __getitem__(self, idx):
ph, pw = self.patch_size

# Extract the spatial patch by slicing the precomputed array/tensor directly — faster than xarray indexing
daily_patch = self.daily_np[
:, :, i : i + ph, j : j + pw
] # (M, T, H, W) -> (M,T,pH, pW)
monthly_patch = self.monthly_np[
:, i : i + ph, j : j + pw
] # (M, H, W) -> (M, pH, pW)
daily_nan_mask = self.daily_nan_mask[
:, :, i : i + ph, j : j + pw
] # (M, T, H, W) -> (M, T, pH, pW)

if self.land_mask_np is not None:
land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W)
land_tensor = torch.from_numpy(np.ascontiguousarray(land_patch)).bool()
else:
land_tensor = torch.zeros(ph, pw, dtype=torch.bool)
# (M, T, H, W) -> (M,T,pH, pW)
daily_tensor = self.daily_t[:, :, i : i + ph, j : j + pw ].unsqueeze(0)

# geo_pos_tensor = self.sh_geo_pos[i: i + ph, j: j + pw] # (H,W, sh_emb_dim) -> (pH, pW, sh_embed_dim)
# (M, H, W) -> (M, pH, pW)
monthly_tensor = self.monthly_t[:, i : i + ph, j : j + pw]

# Convert to tensors (from_numpy is zero-copy on contiguous arrays)
# (1, M, T, H, W)
daily_tensor = torch.from_numpy(daily_patch).unsqueeze(0)
# (M, H, W)
monthly_tensor = torch.from_numpy(monthly_patch)
# (1, M, T, H, W)
daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0)
# ( M, T, 2)
daily_timef_tensor = torch.from_numpy(self.daily_timef_np)
# (M, T, H, W) -> (M, T, pH, pW)
daily_nan_mask = self.daily_nan_mask[:, :, i : i + ph, j : j + pw].unsqueeze(0)

if self.land_mask_t is not None:
land_tensor = self.land_mask_t[i : i + ph, j : j + pw] # (H, W)
else:
land_tensor = self._zero_land

# daily_mask: NaN locations that are NOT land
# Reshape land_tensor for broadcasting: (pH, pW) → (1, 1, 1, pH, pW)
Expand All @@ -249,24 +261,19 @@ def __getitem__(self, idx):
# get scale feature for patch
scale_feature_tensor = self.patch_scale_features[idx] # (10,)

# create tensors to pass sh embedding dimension, harmonic order, and scale feature dim
sh_embed_dim = torch.tensor(self.sh_embed_dim)
harmonic_order = torch.tensor(self.sh_order_L)
scale_f_dim = torch.tensor(len(scale_feature_tensor))

# Convert to tensors
return {
"daily_patch": daily_tensor, # (C=1, M, T=31, pH, pW)
"monthly_patch": monthly_tensor, # (M, pH, pW)
"daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, pH, pW)
"land_mask_patch": land_tensor, # (pH,pW) True=Land
"daily_timef_patch": daily_timef_tensor, # (M, T=31, 2)
"daily_timef_patch": self.daily_timef_t, # (M, T=31, 2)
"padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded
"scale_feature_patch": scale_feature_tensor, # (10,)
"geo_pos_embedding_patch": geo_pos_embedding_tensor, # (sh_embed_dim,)
"sh_embed_dim": sh_embed_dim,
"harmonic_order": harmonic_order,
"scale_f_dim": scale_f_dim,
"sh_embed_dim": self.sh_embed_dim_t,
"harmonic_order": self.harmonic_order_t,
"scale_f_dim": self.scale_f_dim,
"coords": (i, j),
"lat_patch": lat_patch, # (pH,)
"lon_patch": lon_patch, # (pW,)
Expand All @@ -282,14 +289,14 @@ def compute_stats(self, indices: list = None) -> Tuple[np.ndarray, np.ndarray]:
Tuple of (mean, std) arrays
"""
if indices is None:
data = self.monthly_np # (M, H, W)
data = self.monthly_t.numpy() # (M, H, W)
else:
# Stack selected spatial patches
ph, pw = self.patch_size
patches = []
for idx in indices:
i, j = self.patch_indices[idx]
patch = self.monthly_np[:, i : i + ph, j : j + pw]
patch = self.monthly_t[:, i : i + ph, j : j + pw].numpy()
patches.append(patch)
data = np.concatenate(patches, axis=-1)

Expand Down
3 changes: 1 addition & 2 deletions climanet/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def predict_monthly_var(
run_dir: str = ".",
verbose: bool = True,
dataloader_num_workers: int = 2,
predict_threads: int | None = None,
):
"""
Predicts monthly variable values using a trained model and a provided dataset.
Expand Down Expand Up @@ -107,7 +106,7 @@ def predict_monthly_var(
# Initialize an empty list to store predictions
base_dataset = dataset.dataset if hasattr(dataset, "dataset") else dataset

M = base_dataset.monthly_np.shape[0]
M = base_dataset.monthly_t.shape[0]
H, W = base_dataset.patch_size
all_predictions = torch.empty(len(dataset), M, H, W)

Expand Down
77 changes: 28 additions & 49 deletions climanet/st_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
import math
import torch
import torch.nn as nn
from einops import rearrange


device = "cuda" if torch.cuda.is_available() else "cpu"


class VideoEncoder(nn.Module):
Expand Down Expand Up @@ -69,11 +65,16 @@ def forward(self, x, mask):
as an additional input channel
"""
# x: (B,1,T,H,W), mask: (B,1,T,H,W) where True means missing
valid = (~mask).float()
valid = mask.logical_not().to(x.dtype)
x = x * valid # zero-out missing values
x = torch.cat([x, valid], dim=1) # add validity as a channel

x = self.proj(x) # (B, C, T', H', W')
x = rearrange(x, "b c t h w -> b (t h w) c")

B, C, Tp, Hp, Wp = x.shape
x = x.contiguous()
x = x.permute(0, 2, 3, 4, 1).reshape(B, Tp * Hp * Wp, C)

x = self.norm(x)
x = self.drop(x)
return x # (B, N_patches, embed_dim)
Expand Down Expand Up @@ -232,16 +233,13 @@ class TemporalAttentionAggregator(nn.Module):
months.
"""

def __init__(self, embed_dim=128, max_days=31, max_months=12, dropout=0.0):
def __init__(self, embed_dim=128, max_months=12, dropout=0.0):
"""Initialize the temporal attention aggregator.

Args:
embed_dim: Dimension of the embedding. The default is 128.
Many vision transformers use embedding dimensions that are multiples
of 64 (e.g., 64, 128, 256). This can be tuned.
max_days: Maximum length of the temporal dimension to precompute
encodings for. Default is 31, which is sufficient for a month of
daily data.
max_months: Maximum number of months (temporal patches) to precompute
encodings for. Default is 12, which is sufficient for a year of monthly data.
dropout: Dropout rate for regularization in the day scorer and
Expand Down Expand Up @@ -282,6 +280,10 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12, dropout=0.0):
nn.Dropout(dropout),
)

# Pre-compute and register as buffer — auto-moves with .to(device/dtype)
pe = self.pos_months(max_months) # (max_months, C)
self.register_buffer("pe_months_cache", pe) # tracks device/dtype automatically

def forward(self, x, M, T, H, W, time_features, padded_days_mask=None):
"""
Args:
Expand All @@ -298,41 +300,35 @@ def forward(self, x, M, T, H, W, time_features, padded_days_mask=None):
Tensor of shape (B, M, H*W, C) holding one temporally aggregated token per month and spatial location, where C is the embedding dimension.
"""
B, M, Tp, Hp, Wp, C = x.shape
HW = Hp * Wp

# Reshape to (B, Hp*Wp, M, Tp, C) for temporal processing
seq = x.permute(0, 3, 4, 1, 2, 5).reshape(B, Hp * Wp, M, Tp, C)
seq = x.permute(0, 3, 4, 1, 2, 5).reshape(B, HW, M, Tp, C)

temp_emb = self.time_embed(time_features) # (B,M,T,emd_dim)
# expand spatially
temp_emb = temp_emb[:, None, :, :, :] # [B, 1, M, T, C]
temp_emb = temp_emb.expand(-1, H * W, -1, -1, -1)
pe_months = self.pos_months(M).to(seq.device).to(seq.dtype) # (M, C)
temp_emb = self.time_embed(time_features)

seq = seq + temp_emb # add temporal embeddings
seq = seq + pe_months[None, None, :, None, :] # add month PE
pe_months = self.pe_months_cache[:M]
token_emb = temp_emb + pe_months[None, :, None, :] # (B, M, T, C)

# Day attention per month
day_logits = self.day_scorer(seq).squeeze(-1) # (B, HW, M, T)
day_logits = self.day_scorer(token_emb).squeeze(-1) # (B, M, T)

# padded_days_mask is (B, M, T) true=padded, -> (B, HW, M, T)
if padded_days_mask is not None:
pad = padded_days_mask[:, None, :, :].expand(B, H * W, M, T)
day_logits = day_logits.masked_fill(pad, float("-inf"))
day_logits = day_logits.masked_fill(padded_days_mask, float("-inf"))
day_w = torch.softmax(day_logits, dim=-1) # (B, M, T)

month_tokens = torch.einsum("bmt,bhmtc->bhmc", day_w, seq) # (B, HW, M, C)

day_w = torch.softmax(day_logits, dim=-1) # turns inf to 0
month_tokens = (seq * day_w.unsqueeze(-1)).sum(dim=3) # (B, HW, M, C)
month_emb = torch.einsum("bmt,bmtc->bmc", day_w, token_emb) # (B, M, C)
month_tokens = month_tokens + month_emb[:, None, :, :]

# Cross-month attention at each spatial location
z = rearrange(month_tokens, "b s m c -> (b s) m c")
z = month_tokens.reshape(B * HW, M, C)
z_ln = self.month_ln(z)
attn_out, _ = self.month_attn(z_ln, z_ln, z_ln, need_weights=False)
z = z + attn_out
z = z + self.month_ffn(z)
out = z.view(B, HW, M, C).permute(0, 2, 1, 3)

# Back to (B, M, Hp*Wp, C)
z = z.view(B, Hp * Wp, M, C)
out = z.permute(0, 2, 1, 3) # (B, M, Hp*Wp, C)
return out # (B, M, H*W, C) C: embedding dimension
return out # (B, M, H*W, C) C: embedding dimension


class MonthlyConvDecoder(nn.Module):
Expand Down Expand Up @@ -631,7 +627,6 @@ def __init__(
in_chans=1,
embed_dim=128,
patch_size=(1, 4, 4),
max_days=31,
max_months=12,
num_months=12,
hidden=256,
Expand All @@ -648,7 +643,6 @@ def __init__(
in_chans: Number of input channels (e.g., 1 for SST, additional channels possible)
embed_dim: Dimension of the patch embedding
patch_size: Tuple of (T, H, W) patch sizes for temporal and spatial patching
max_days: Maximum number of days for temporal positional encoding
max_months: Maximum number of months for temporal positional encoding
num_months: Number of months to predict (output channels in decoder)
hidden: Hidden dimension used in the decoder
Expand Down Expand Up @@ -676,7 +670,6 @@ def __init__(
)
self.temporal = TemporalAttentionAggregator(
embed_dim=embed_dim,
max_days=max_days,
max_months=max_months,
dropout=dropout,
)
Expand Down Expand Up @@ -742,21 +735,6 @@ def forward(
)
assert T % self.patch_size[0] == 0, "T must be divisible by patch size"

if self.patch_size[0] > 1:
daily_timef = daily_timef.view(B, M, Tp, self.patch_size[0], 4).mean(
dim=3
) # -> (B,M, Tp, 4)

if padded_days_mask is not None and self.patch_size[0] > 1:
B, M, T_days = padded_days_mask.shape
if T_days % self.patch_size[0] != 0:
raise ValueError(
f"T_days={T_days} must be divisible by patch_size[0]={self.patch_size[0]}"
)
padded_days_mask = padded_days_mask.view(
B, M, T_days // self.patch_size[0], self.patch_size[0]
).all(dim=-1) # (B, M, Tp)

# Step 1: Encode spatio-temporal patches
# each month independently by folding M into batch
# encoder input shape = (B, C, T, H, W) where C is channel.
Expand Down Expand Up @@ -793,6 +771,7 @@ def forward(
# Step 4: Spatial mixing with Transformer
# spatial transformer input shape = (B, N, C), output shape = (B, N, C) C: embedding dimension
# M is folded in B.

C = x.shape[-1]
x = x.reshape(B * M, Hp * Wp, C)
x = self.spatial_tr(x)
Expand Down
Loading