Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
.cache/
docs/reference/*
./examples/MP/experiments
./examples/QM9s
./examples/QM9
*doctrees*
/site

Expand Down
48 changes: 48 additions & 0 deletions src/electrai/configs/MP/config_resnet_lcn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Dataset / loader parameters
data:
  _target_: src.electrai.dataloader.dataset.RhoRead
  root: /scratch/gpfs/ROSENGROUP/common/globus_share_OA/mp/dataset_2/mp_filelist.txt
  # Alternative explicit split file, currently disabled:
  # /scratch/gpfs/ROSENGROUP/common/globus_share_OA/mp/dataset_2/split.json
  split_file: null
  # Loader dtype key (distinct from the top-level `precision` below).
  precision: f32
  batch_size: 1
  train_workers: 8
  val_workers: 2
  pin_memory: false
  val_frac: 0.005
  drop_last: false
  augmentation: false
  random_seed: 42
  # downsample_label: 0
  # downsample_data: 0

# Model (instantiated from `cfg.model` via its `_target_`)
model:
  _target_: src.electrai.model.resnet_LCN.GeneratorResNet
  n_residual_blocks: 32
  n_channels: 32
  kernel_size1: 5
  kernel_size2: 5
  normalize: true
  use_checkpoint: false
  use_lattice_conv: true
  use_radial_embedding: true
  num_gaussians: 500
  use_positional_embedding: true
  pos_embed_dim: 500

# Training parameters
# NOTE(review): presumably the trainer's numeric precision -- confirm against the training entry point.
precision: 32
epochs: 50
lr: 0.01
weight_decay: 0.0
warmup_length: 1
beta1: 0.9
beta2: 0.99

# Weights and biases
wandb_mode: online
entity: PrinceOA
wb_pname: mp-experiment

# checkpoints
ckpt_path: ./checkpoints
6 changes: 3 additions & 3 deletions src/electrai/dataloader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(self, datapath: str, precision: str, augmentation: bool, **kwargs):
else:
raise ValueError("No filename found.")

self.category = Path(datapath).name.split("_")[0] # example: mp_filelist.txt
self.category = Path(datapath).name.split("_")[0]
self.root = Path(datapath).parent
self.member_list = member_list

Expand All @@ -116,7 +116,7 @@ def __len__(self):

def __getitem__(self, index):
index = self.member_list[index]
data, label = utils.load_numpy_rho(
data, label, lattice = utils.load_numpy_rho(
root=self.root,
category=self.category,
index=index,
Expand All @@ -125,4 +125,4 @@ def __getitem__(self, index):
)
data = data.unsqueeze(0)
label = label.unsqueeze(0)
return {"data": data, "label": label, "index": index}
return {"data": data, "label": label, "index": index, "lattice": lattice}
12 changes: 9 additions & 3 deletions src/electrai/dataloader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,28 @@ def load_numpy_rho(
"""
root = Path(root)
if category == "mp":
data, label = load_chgcar(root, index)
data, label, lattice = load_chgcar(root, index)
elif category == "qm9":
data, label = load_npy(root, index)
data = torch.tensor(data, dtype=dtype_map[precision])
label = torch.tensor(label, dtype=dtype_map[precision])
lattice = torch.tensor(lattice, dtype=dtype_map[precision])
grid_shape = torch.tensor(
data.shape, dtype=dtype_map[precision], device=lattice.device
)
lattice = lattice / grid_shape[:, None]
if augmentation:
data, label = rand_rotate([data, label])
return data, label
return data, label, lattice


def load_chgcar(root: str | bytes | os.PathLike, index: str):
    """Load one data/label CHGCAR pair and the lattice of the data file.

    Reads ``<root>/data/<index>.CHGCAR`` and ``<root>/label/<index>.CHGCAR``,
    takes the ``"total"`` entry of each file and divides it by the volume of
    that file's own lattice.

    Args:
        root: Directory containing the ``data`` and ``label`` sub-directories.
        index: File stem of the sample (without the ``.CHGCAR`` suffix).

    Returns:
        A ``(data, label, lattice)`` tuple, where ``lattice`` is the lattice
        matrix of the structure stored in the *data* file.
    """
    # Coerce once so that plain string roots support the `/` joins below.
    root = Path(root)
    data = Chgcar.from_file(root / "data" / f"{index}.CHGCAR")
    label = Chgcar.from_file(root / "label" / f"{index}.CHGCAR")
    # Grab the lattice before `data` is rebound to the raw grid values.
    lattice = data.structure.lattice.matrix
    data = data.data["total"] / data.structure.lattice.volume
    label = label.data["total"] / label.structure.lattice.volume
    return data, label, lattice


def load_npy(root: str | bytes | os.PathLike, index: str):
Expand Down
80 changes: 75 additions & 5 deletions src/electrai/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def __init__(self, cfg):
self.model = instantiate(cfg.model)
self.loss_fn = NormMAE()

def forward(self, x, lattice_vectors=None):
    """Run the wrapped model on ``x``.

    Args:
        x: Input forwarded unchanged to ``self.model``.
        lattice_vectors: Optional second positional argument for
            ``self.model``; ``None`` is passed through when omitted.

    Returns:
        Whatever ``self.model(x, lattice_vectors)`` returns.
    """
    # The model is always called with both arguments, so it has to accept a
    # lattice argument even when the caller does not supply one.
    return self.model(x, lattice_vectors)

def training_step(self, batch):
loss = self._loss_calculation(batch)
Expand All @@ -27,6 +27,13 @@ def training_step(self, batch):
on_epoch=True,
sync_dist=False,
)
if hasattr(self.model, "conv1") and hasattr(
self.model.conv1, "last_debug_stats"
):
stats = self.model.conv1.last_debug_stats
for key, values in stats.items():
for metric, val in values.items():
self.log(f"debug/{key}/{metric}", val, on_step=True, on_epoch=False)
return loss

def validation_step(self, batch):
Expand All @@ -36,18 +43,81 @@ def validation_step(self, batch):
)
return loss

# def _log_gaussian_params(self, prefix="train_"):
# for name, module in self.model.named_modules():
# if isinstance(module, torch.nn.Module) and hasattr(
# module, "gaussian_smear"
# ):
# gaussian_smear = module.gaussian_smear

# if hasattr(gaussian_smear, "centers"):
# centers = gaussian_smear.centers
# self.log(
# f"{prefix}gaussian/centers_mean",
# centers.mean(),
# on_step=True,
# on_epoch=True,
# )
# self.log(
# f"{prefix}gaussian/centers_std",
# centers.std(),
# on_step=True,
# on_epoch=True,
# )
# self.log(
# f"{prefix}gaussian/centers_min",
# centers.min(),
# on_step=True,
# on_epoch=True,
# )
# self.log(
# f"{prefix}gaussian/centers_max",
# centers.max(),
# on_step=True,
# on_epoch=True,
# )

# if hasattr(gaussian_smear, "widths"):
# widths = gaussian_smear.widths
# self.log(
# f"{prefix}gaussian/widths_mean",
# widths.mean(),
# on_step=True,
# on_epoch=True,
# )
# self.log(
# f"{prefix}gaussian/widths_std",
# widths.std(),
# on_step=True,
# on_epoch=True,
# )
# self.log(
# f"{prefix}gaussian/widths_min",
# widths.min(),
# on_step=True,
# on_epoch=True,
# )
# self.log(
# f"{prefix}gaussian/widths_max",
# widths.max(),
# on_step=True,
# on_epoch=True,
# )
# break

def _loss_calculation(self, batch):
x = batch["data"]
y = batch["label"]
A = batch["lattice"]
if isinstance(x, list):
losses = []
for x_i, y_i in zip(x, y, strict=True):
pred = self(x_i.unsqueeze(0))
for x_i, y_i, A_i in zip(x, y, strict=True):
pred = self(x_i.unsqueeze(0), A_i.unsqueeze(0))
loss = self.loss_fn(pred, y_i.unsqueeze(0))
losses.append(loss)
loss = torch.stack(losses).mean()
else:
pred = self(x)
pred = self(x, A)
loss = self.loss_fn(pred, y)
return loss

Expand Down
Loading