Skip to content
This repository was archived by the owner on Jan 22, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions d2go/trainer/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,20 @@

from d2go.evaluation.api import AccuracyDict, MetricsDict

# TODO (T127368935) Split to TrainNetOutput and TestNetOutput

# Result record for a training run: evaluation accuracy, evaluation metrics,
# the model configs produced by training, and an optional TensorBoard log
# directory.
# NOTE(review): the field lines below appear without indentation in this view
# (leading whitespace seems to have been stripped); they belong to the class
# body.
@dataclass
class TrainNetOutput:
# Evaluation results. The callers visible in this view pass the same object
# for both `accuracy` and `metrics`.
accuracy: AccuracyDict[float]
metrics: MetricsDict[float]
# Optional, because we use None to distinguish "not used" from
# empty model configs. With T127368935, this should be reverted to dict.
# NOTE(review): `model_configs` is annotated twice below. This looks like the
# removed and the added line of a diff rendered together -- confirm that only
# the plain `Dict[str, str]` form is meant to remain. If both stay, the later
# annotation is the one that ends up in the class annotations.
model_configs: Optional[Dict[str, str]]
model_configs: Dict[str, str]
# TODO (T127368603): decide if `tensorboard_log_dir` should be part of output
# Defaults to None, so callers may omit it.
tensorboard_log_dir: Optional[str] = None


# Result record for an evaluation-only run: the same fields as a training
# result except that there is no `model_configs` entry.
# NOTE(review): the field lines below appear without indentation in this view
# (leading whitespace seems to have been stripped); they belong to the class
# body.
@dataclass
class TestNetOutput:
# Evaluation results. The callers visible in this view pass the same object
# for both `accuracy` and `metrics`.
accuracy: AccuracyDict[float]
metrics: MetricsDict[float]
# TODO (T127368603): decide if `tensorboard_log_dir` should be part of output
# Defaults to None. Of the eval-only call sites visible in this view, one
# passes the trainer logger's `log_dir` and the other omits the argument.
tensorboard_log_dir: Optional[str] = None

Expand Down
24 changes: 14 additions & 10 deletions tools/lightning_train_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from d2go.runner.callbacks.quantization import QuantizationAwareTraining
from d2go.runner.lightning_task import DefaultTask
from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch
from d2go.trainer.api import TrainNetOutput
from d2go.trainer.api import TestNetOutput, TrainNetOutput
from d2go.trainer.helper import parse_precision_from_string
from d2go.trainer.lightning.training_loop import _do_test, _do_train
from detectron2.utils.file_io import PathManager
Expand Down Expand Up @@ -103,7 +103,7 @@ def main(
output_dir: str,
runner_class: Union[str, Type[DefaultTask]],
eval_only: bool = False,
) -> TrainNetOutput:
) -> Union[TrainNetOutput, TestNetOutput]:
"""Main function for launching a training with lightning trainer
Args:
cfg: D2go config node
Expand All @@ -123,18 +123,22 @@ def main(
logger.info(f"Resuming training from checkpoint: {last_checkpoint}.")

trainer = pl.Trainer(**trainer_params)
model_configs = None

if eval_only:
_do_test(trainer, task)
return TestNetOutput(
tensorboard_log_dir=trainer_params["logger"].log_dir,
accuracy=task.eval_res,
metrics=task.eval_res,
)
else:
model_configs = _do_train(cfg, trainer, task)

return TrainNetOutput(
tensorboard_log_dir=trainer_params["logger"].log_dir,
accuracy=task.eval_res,
metrics=task.eval_res,
model_configs=model_configs,
)
return TrainNetOutput(
tensorboard_log_dir=trainer_params["logger"].log_dir,
accuracy=task.eval_res,
metrics=task.eval_res,
model_configs=model_configs,
)


def argument_parser():
Expand Down
8 changes: 3 additions & 5 deletions tools/train_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import sys
from typing import List, Type, Union

import detectron2.utils.comm as comm
from d2go.config import CfgNode
from d2go.distributed import launch
from d2go.runner import BaseRunner
Expand All @@ -22,7 +21,7 @@
setup_before_launch,
setup_root_logger,
)
from d2go.trainer.api import TrainNetOutput
from d2go.trainer.api import TestNetOutput, TrainNetOutput
from d2go.trainer.fsdp import create_ddp_model_with_sharding
from d2go.utils.misc import (
dump_trained_model_configs,
Expand All @@ -40,7 +39,7 @@ def main(
runner_class: Union[str, Type[BaseRunner]],
eval_only: bool = False,
resume: bool = True, # NOTE: always enable resume when running on cluster
) -> TrainNetOutput:
) -> Union[TrainNetOutput, TestNetOutput]:
runner = setup_after_launch(cfg, output_dir, runner_class)

model = runner.build_model(cfg)
Expand All @@ -58,9 +57,8 @@ def main(
model.eval()
metrics = runner.do_test(cfg, model, train_iter=train_iter)
print_metrics_table(metrics)
return TrainNetOutput(
return TestNetOutput(
accuracy=metrics,
model_configs={},
metrics=metrics,
)

Expand Down