Skip to content

Commit 597e086

Browse files
committed
Feat: Implement GradientAccumulation for SupervisedTrainer (Issue #6100)
Closes #6100 Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent 894068a commit 597e086

3 files changed

Lines changed: 474 additions & 0 deletions

File tree

monai/engines/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer
1616
from .utils import (
1717
DiffusionPrepareBatch,
18+
GradientAccumulation,
1819
IterationEvents,
1920
PrepareBatch,
2021
PrepareBatchDefault,

monai/engines/utils.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
"default_make_latent",
4242
"engine_apply_transform",
4343
"default_metric_cmp_fn",
44+
"GradientAccumulation",
4445
]
4546

4647

@@ -360,3 +361,121 @@ def default_metric_cmp_fn(current_metric: float, prev_best: float) -> bool:
360361
361362
"""
362363
return current_metric > prev_best
364+
365+
366+
def _noop(*args: Any, **kwargs: Any) -> None:
367+
"""No-op callable used to suppress optimizer/scaler methods during gradient accumulation."""
368+
369+
370+
class GradientAccumulation:
    """
    Callable class implementing gradient accumulation for use with ``SupervisedTrainer``.

    Gradients are accumulated over ``accumulation_steps`` mini-batches before calling
    ``optimizer.step()``, simulating a larger effective batch size on memory-constrained
    hardware.

    Pass an instance as ``iteration_update`` when constructing ``SupervisedTrainer``::

        trainer = SupervisedTrainer(
            ...,
            iteration_update=GradientAccumulation(accumulation_steps=4),
        )

    All ``IterationEvents`` (``FORWARD_COMPLETED``, ``LOSS_COMPLETED``,
    ``BACKWARD_COMPLETED``, ``MODEL_COMPLETED``) still fire on every mini-batch, so
    existing handlers (checkpoint savers, metric loggers, etc.) are unaffected.

    When ``epoch_length`` is known, the optimizer is flushed at the end of each epoch
    even if ``epoch_length % accumulation_steps != 0``, so no gradients are silently
    discarded. For iterable datasets (``epoch_length=None``) this flush does not apply.

    The loss stored in ``engine.state.output[Keys.LOSS]`` is the **unscaled**
    original loss value, so metrics and loggers report the true loss. Internally
    the loss is divided by ``accumulation_steps`` for the backward pass only.

    Args:
        accumulation_steps: number of mini-batches to accumulate before updating
            weights. Must be a positive integer. Default: 2.

    Raises:
        ValueError: when ``accumulation_steps`` is not a positive integer.
    """

    def __init__(self, accumulation_steps: int = 2) -> None:
        if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
            raise ValueError(
                f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}."
            )
        self.accumulation_steps = accumulation_steps

    def __repr__(self) -> str:
        return f"GradientAccumulation(accumulation_steps={self.accumulation_steps})"

    def __call__(self, engine: Any, batchdata: dict) -> dict:
        """
        Execute one iteration with gradient accumulation.

        Args:
            engine: the Ignite engine (usually ``SupervisedTrainer``).
            batchdata: batch data for this iteration.

        Returns:
            the output dict from ``engine._iteration()``.
        """
        acc = self.accumulation_steps

        # No accumulation requested: delegate straight to the default iteration.
        if acc == 1:
            return engine._iteration(engine, batchdata)

        # engine.state.iteration is 1-indexed and already incremented before __call__
        epoch_length = engine.state.epoch_length  # None for iterable datasets
        if epoch_length is not None:
            local_iter = (engine.state.iteration - 1) % epoch_length  # 0-indexed within epoch
            should_zero_grad = local_iter % acc == 0
            # Flush at end of epoch so a partial final accumulation group is not dropped.
            should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length
        else:
            local_iter = engine.state.iteration - 1  # 0-indexed global
            should_zero_grad = local_iter % acc == 0
            should_step = (local_iter + 1) % acc == 0

        # Conditionally suppress zero_grad: only clear gradients at the start of an
        # accumulation cycle. Track exactly what was patched (sentinel None = untouched)
        # so the `finally` block restores only those attributes. Re-assigning
        # unconditionally would install instance attributes that permanently shadow
        # the optimizer's class methods even on iterations where nothing was suppressed.
        original_zero_grad = None
        if not should_zero_grad:
            original_zero_grad = engine.optimizer.zero_grad
            engine.optimizer.zero_grad = _noop

        # Wrap loss_function to scale by 1/accumulation_steps. This ensures the
        # per-mini-batch gradient contribution is correct: the scaled loss is
        # backpropagated, so accumulated gradients average to the same value they
        # would have with the full effective batch.
        original_loss_fn = engine.loss_function
        engine.loss_function = lambda *args, **kwargs: original_loss_fn(*args, **kwargs) / acc

        # Conditionally suppress optimizer.step: only update weights at the end of an
        # accumulation cycle. Also suppress GradScaler.step/update when present, so
        # mixed-precision training accumulates instead of stepping every mini-batch.
        original_step = None
        original_scaler_step = None
        original_scaler_update = None
        if not should_step:
            original_step = engine.optimizer.step
            engine.optimizer.step = _noop
            if getattr(engine, "scaler", None) is not None:
                original_scaler_step = engine.scaler.step
                original_scaler_update = engine.scaler.update
                engine.scaler.step = _noop
                engine.scaler.update = _noop

        try:
            result = engine._iteration(engine, batchdata)
        finally:
            # Restore only what was actually patched, even if _iteration raised.
            engine.loss_function = original_loss_fn
            if original_zero_grad is not None:
                engine.optimizer.zero_grad = original_zero_grad
            if original_step is not None:
                engine.optimizer.step = original_step
            if original_scaler_step is not None:
                engine.scaler.step = original_scaler_step
                engine.scaler.update = original_scaler_update

        # Restore the unscaled loss for logging and metrics. The backward pass
        # already used the scaled value, so this only affects what handlers see.
        if CommonKeys.LOSS in result:
            result[CommonKeys.LOSS] = result[CommonKeys.LOSS] * acc

        return result

0 commit comments

Comments
 (0)