35 changes: 17 additions & 18 deletions examples/lightning/training.py
@@ -2,7 +2,6 @@
 import math
 import os
 
-from dataclasses import _MISSING_TYPE
 from dataclasses import dataclass
 
 import datasets
@@ -22,6 +21,8 @@
 from liger_kernel.transformers import AutoLigerKernelForCausalLM
 from liger_kernel.utils import infer_device
 
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
+
 _RETAIN_COLUMNS = {"input_ids", "attention_mask", "labels"}
 QUESTION = "<Question>"
 CHOICES = "<Choices>"
@@ -33,7 +34,6 @@ class Args:
     data: str = "cais/mmlu"
     output_dir: str = "mmlu_finetuning"
     max_length: int = 2048
-    # for llama3 8B, deepspeed will OOM with batch_size 16 on 8xA100 80G and with 8 on 8xA100 40G
Collaborator (review comment on the removed line above): why removing comments?
     batch_size: int = 4
     lr: float = 6e-6
     weight_decay: float = 0.05
@@ -46,10 +46,8 @@ class Args:
 def warmup_cosine_schedule(warmup_steps, total_steps, min_lr=0):
     def lr_lambda(current_step):
         if current_step < warmup_steps:
-            # Linear warmup
             return float(current_step) / float(max(1, warmup_steps))
         else:
-            # Cosine annealing
             progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
             return max(min_lr, 0.5 * (1 + math.cos(math.pi * progress)))
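For readers skimming the hunk above: the factory builds a multiplier for torch's LambdaLR, linear during warmup and cosine-annealed afterwards (the tail of the function is truncated in the diff; it presumably returns lr_lambda). A minimal usage sketch; the 100/1000 step counts and the throwaway parameter are illustrative, not values from this PR:

import math

import torch

def warmup_cosine_schedule(warmup_steps, total_steps, min_lr=0):
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            # linear ramp from 0 to 1 over the warmup window
            return float(current_step) / float(max(1, warmup_steps))
        # cosine decay from 1 toward min_lr over the remaining steps
        progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        return max(min_lr, 0.5 * (1 + math.cos(math.pi * progress)))

    return lr_lambda

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=6e-6)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine_schedule(100, 1000))

for step in range(5):
    optimizer.step()
    scheduler.step()  # rescales the base lr by lr_lambda(current step)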

@@ -61,7 +59,7 @@ def parse_args() -> Args:
     for k, v in Args.__dataclass_fields__.items():
         parser.add_argument(f"--{k}", type=v.type, default=v.default)
     parsed = parser.parse_args()
-    return Args(**{k: v for k, v in vars(parsed).items() if not isinstance(v, _MISSING_TYPE)})
+    return Args(**vars(parsed))
 
 
 class LanguageModel(pl.LightningModule):
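The simplification looks sound because every Args field declares a default: field.default is only the _MISSING_TYPE sentinel when a field has no default, so the old filter could never fire. A small sketch of the idiom (the Demo dataclass is hypothetical, not from this PR):

import argparse
from dataclasses import dataclass

@dataclass
class Demo:
    lr: float = 6e-6
    batch_size: int = 4

parser = argparse.ArgumentParser()
for k, v in Demo.__dataclass_fields__.items():
    # v.default is a real value here, never dataclasses.MISSING,
    # because every field above has a default.
    parser.add_argument(f"--{k}", type=v.type, default=v.default)

args = Demo(**vars(parser.parse_args([])))  # Demo(lr=6e-06, batch_size=4)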
@@ -72,7 +70,6 @@ def __init__(self, args: Args, tokenizer):
         self.model = None
 
     def configure_model(self):
-        # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#speed-up-model-initialization
         if self.model is not None:
             return
         self.model = AutoLigerKernelForCausalLM.from_pretrained(
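Even with its doc link removed, the pattern above deserves a gloss: Lightning calls configure_model after the strategy is set up, so FSDP can shard weights as they are created instead of materializing a full copy on every rank, and the None guard matters because the hook can run more than once. A minimal sketch, with a stand-in layer replacing the real from_pretrained call and the lightning.pytorch alias assumed:

import torch
import lightning.pytorch as pl  # assumption: this is the example's `pl` alias


class DeferredModel(pl.LightningModule):  # hypothetical minimal module
    def __init__(self):
        super().__init__()
        self.model = None  # construction is deferred to configure_model

    def configure_model(self):
        # Runs once the strategy (e.g. FSDP) is live; guard against
        # re-entry because Lightning may call this hook again.
        if self.model is not None:
            return
        self.model = torch.nn.Linear(8, 8)  # stand-in for from_pretrained(...)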
@@ -89,7 +86,7 @@ def training_step(self, batch):
         outputs = self.model(
             input_ids=batch["input_ids"],
             attention_mask=batch["attention_mask"],
-            labels=batch["labels"],
+            labels=batch.get("labels"),
         )
         loss = outputs.loss
         self.log_dict(
@@ -107,11 +104,11 @@ def validation_step(self, batch):
         outputs = self.model(
             input_ids=batch["input_ids"],
             attention_mask=batch["attention_mask"],
-            labels=batch["labels"],
+            labels=batch.get("labels"),
         )
         loss = outputs.loss
         self.log_dict(
-            {"val_loss": outputs.loss},
+            {"val_loss": loss},
             on_step=True,
             on_epoch=True,
             prog_bar=True,
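On the labels change in both steps: dict.get returns None for a missing key, and Hugging Face causal-LM forwards skip the loss computation when labels is None, so an unlabeled batch no longer raises KeyError. One caveat worth flagging in review: outputs.loss is then None, so the log_dict call would need a guard on such batches. A toy illustration:

# Hypothetical batch with no "labels" key.
batch = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}

print(batch.get("labels"))  # None -> a HF model would return outputs.loss = None

try:
    batch["labels"]
except KeyError:
    print("plain indexing would have crashed the step")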
@@ -182,10 +179,10 @@ def setup(self, stage) -> None:
         dataset = datasets.load_dataset(self.args.data, "auxiliary_train")
         flattened_data = [
             {
-                "answer": x["train"]["answer"],
-                "choices": x["train"]["choices"],
-                "question": x["train"]["question"],
-                "subject": x["train"]["subject"],
+                "answer": x["answer"],
+                "choices": x["choices"],
+                "question": x["question"],
+                "subject": x["subject"],
             }
             for x in dataset["train"]
         ]
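The reindexing above tracks a change in how the MMLU auxiliary_train split is served: rows now appear to arrive flat rather than nested under a "train" key. A sketch of the two row shapes this code assumes (values are made up):

# Shape the old code expected: each row wrapped in a "train" dict.
row_nested = {"train": {"question": "q?", "choices": ["a", "b"], "answer": 0, "subject": "s"}}
assert row_nested["train"]["answer"] == 0

# Shape the new code expects: flat rows.
row_flat = {"question": "q?", "choices": ["a", "b"], "answer": 0, "subject": "s"}
assert row_flat["answer"] == 0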
@@ -237,11 +234,11 @@ def train():
 
     if args.strategy == "fsdp":
         strategy = FSDPStrategy(
-            auto_wrap_policy=layers,
+            auto_wrap_policy=transformer_auto_wrap_policy,
             sharding_strategy="FULL_SHARD",
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
             sync_module_states=True,
-            activation_checkpointing_policy=layers,
+            activation_checkpointing_policy=transformer_auto_wrap_policy,
             mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
             forward_prefetch=True,
         )
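A point reviewers may want to verify: torch's transformer_auto_wrap_policy takes a required transformer_layer_cls keyword, so it is conventionally bound with functools.partial before being handed to FSDP; passed bare, as in this hunk, it may fail when FSDP invokes it. A sketch of the usual binding (LlamaDecoderLayer is illustrative; this PR does not import it):

import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# FSDP calls the policy as policy(module, recurse, nonwrapped_numel), so the
# layer classes to wrap must be pre-bound.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)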
@@ -251,16 +248,18 @@ def train():
         precision = "bf16-mixed"
     elif args.strategy == "ddp":
         strategy = "ddp"
-        precision = "bf16-true"
+        precision = "bf16-mixed"
     else:
         strategy = "auto"
-        precision = "bf16-true"
+        precision = "bf16-mixed"
 
     device = infer_device()
+    devices = args.num_gpu or (torch.cuda.device_count() if torch.cuda.is_available() else 1)
+
     trainer = pl.Trainer(
         accelerator=device,
         strategy=strategy,
-        devices=(getattr(torch, device).device_count() if args.num_gpu is None else args.num_gpu),
+        devices=devices,
         default_root_dir=args.output_dir,
         log_every_n_steps=1,
         max_epochs=1,
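Two behavioral notes on the final hunk. The precision switch is not cosmetic: bf16-true casts the weights themselves to bfloat16, while bf16-mixed keeps fp32 parameters and autocasts compute, trading memory for numerical headroom. And the new devices expression uses `or`, so it falls back to auto-detection for num_gpu=0 as well as None, where the old `is None` check treated 0 as an explicit request. A sketch; pick_devices is a hypothetical helper mirroring the expression:

import torch

def pick_devices(num_gpu=None):
    # Falls back when num_gpu is None *or* 0 (both are falsy).
    return num_gpu or (torch.cuda.device_count() if torch.cuda.is_available() else 1)

print(pick_devices())   # auto-detected GPU count, or 1 on a CPU-only host
print(pick_devices(4))  # explicit request -> 4
print(pick_devices(0))  # 0 is falsy -> also falls back, unlike the old `is None` check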