|
41 | 41 | "default_make_latent", |
42 | 42 | "engine_apply_transform", |
43 | 43 | "default_metric_cmp_fn", |
| 44 | + "GradientAccumulation", |
44 | 45 | ] |
45 | 46 |
|
46 | 47 |
|
@@ -360,3 +361,121 @@ def default_metric_cmp_fn(current_metric: float, prev_best: float) -> bool: |
360 | 361 |
|
361 | 362 | """ |
362 | 363 | return current_metric > prev_best |
| 364 | + |
| 365 | + |
| 366 | +def _noop(*args: Any, **kwargs: Any) -> None: |
| 367 | + """No-op callable used to suppress optimizer/scaler methods during gradient accumulation.""" |
| 368 | + |
| 369 | + |
class GradientAccumulation:
    """
    Callable class implementing gradient accumulation for use with ``SupervisedTrainer``.

    Gradients are accumulated over ``accumulation_steps`` mini-batches before calling
    ``optimizer.step()``, simulating a larger effective batch size on memory-constrained
    hardware.

    Pass an instance as ``iteration_update`` when constructing ``SupervisedTrainer``::

        trainer = SupervisedTrainer(
            ...,
            iteration_update=GradientAccumulation(accumulation_steps=4),
        )

    All ``IterationEvents`` (``FORWARD_COMPLETED``, ``LOSS_COMPLETED``,
    ``BACKWARD_COMPLETED``, ``MODEL_COMPLETED``) still fire on every mini-batch, so
    existing handlers (checkpoint savers, metric loggers, etc.) are unaffected.

    When ``epoch_length`` is known, the optimizer is flushed at the end of each epoch
    even if ``epoch_length % accumulation_steps != 0``, so no gradients are silently
    discarded. Note that the mini-batches of such a short final cycle are still scaled
    by ``1 / accumulation_steps``, so the flushed update contributes proportionally
    less than a full cycle. For iterable datasets (``epoch_length=None``) this flush
    does not apply.

    The loss stored in ``engine.state.output[CommonKeys.LOSS]`` is the **unscaled**
    original loss value, so metrics and loggers report the true loss. Internally
    the loss is divided by ``accumulation_steps`` for the backward pass only.

    Args:
        accumulation_steps: number of mini-batches to accumulate before updating
            weights. Must be a positive integer. Default: 2.

    Raises:
        ValueError: when ``accumulation_steps`` is not a positive integer.
    """

    def __init__(self, accumulation_steps: int = 2) -> None:
        if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
            raise ValueError(
                f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}."
            )
        self.accumulation_steps = accumulation_steps

    def __repr__(self) -> str:
        return f"GradientAccumulation(accumulation_steps={self.accumulation_steps})"

    def __call__(self, engine: Any, batchdata: dict) -> dict:
        """
        Execute one iteration with gradient accumulation.

        Args:
            engine: the Ignite engine (usually ``SupervisedTrainer``).
            batchdata: batch data for this iteration.

        Returns:
            the output dict from ``engine._iteration()``.
        """
        acc = self.accumulation_steps

        if acc == 1:
            # Nothing to accumulate: run the stock iteration untouched.
            return engine._iteration(engine, batchdata)

        # engine.state.iteration is 1-indexed and already incremented before __call__.
        epoch_length = engine.state.epoch_length  # None for iterable datasets
        if epoch_length is not None:
            local_iter = (engine.state.iteration - 1) % epoch_length  # 0-indexed within epoch
            should_zero_grad = local_iter % acc == 0
            # Step at the end of an accumulation cycle OR at the last iteration of the
            # epoch, so a short tail cycle is never silently discarded.
            should_step = (local_iter + 1) % acc == 0 or (local_iter + 1) == epoch_length
        else:
            local_iter = engine.state.iteration - 1  # 0-indexed global
            should_zero_grad = local_iter % acc == 0
            should_step = (local_iter + 1) % acc == 0

        # Snapshot every attribute we may patch BEFORE installing any patch, and do the
        # patching itself inside the try block. This way the `finally` clause can always
        # restore a consistent state, even if patching or `_iteration` raises part-way
        # through (previously a failure after patching `zero_grad` but before the try
        # would leave the optimizer permanently no-op'd).
        original_zero_grad = engine.optimizer.zero_grad
        original_loss_fn = engine.loss_function
        original_step = engine.optimizer.step
        scaler = getattr(engine, "scaler", None)  # GradScaler, present for AMP training
        original_scaler_step = scaler.step if scaler is not None else None
        original_scaler_update = scaler.update if scaler is not None else None

        try:
            # Only clear gradients at the start of an accumulation cycle.
            if not should_zero_grad:
                engine.optimizer.zero_grad = _noop

            # Scale the loss by 1/accumulation_steps for the backward pass only, so the
            # accumulated gradients average to the same value a full batch would give.
            engine.loss_function = lambda *args, **kwargs: original_loss_fn(*args, **kwargs) / acc

            # Only update weights at the end of an accumulation cycle. Also suppress
            # GradScaler.step/update for mixed-precision training.
            if not should_step:
                engine.optimizer.step = _noop
                if scaler is not None:
                    scaler.step = _noop
                    scaler.update = _noop

            result = engine._iteration(engine, batchdata)
        finally:
            engine.optimizer.zero_grad = original_zero_grad
            engine.loss_function = original_loss_fn
            engine.optimizer.step = original_step
            if scaler is not None:
                scaler.step = original_scaler_step
                scaler.update = original_scaler_update

        # Restore the unscaled loss for logging and metrics. The backward pass already
        # used the scaled value, so this only affects what handlers see. Guard on dict
        # so a None/list output from a custom `_iteration` cannot raise here.
        if isinstance(result, dict) and CommonKeys.LOSS in result:
            result[CommonKeys.LOSS] = result[CommonKeys.LOSS] * acc

        return result
0 commit comments