From 7510fbbc88eaa36724ac9082a4b3e9c52976ce10 Mon Sep 17 00:00:00 2001 From: Yanghan Wang Date: Thu, 15 Dec 2022 16:45:53 -0800 Subject: [PATCH] use "legacy" dataclass at operator level and separate TestNetOutput from TrainNetOutput Summary: Pull Request resolved: https://github.com/facebookresearch/d2go/pull/444 Differential Revision: D41828774 fbshipit-source-id: 084248e21de6cffd6c3b6b3ff673bd77d08b5a97 --- d2go/trainer/api.py | 14 ++++++++++---- tools/lightning_train_net.py | 24 ++++++++++++++---------- tools/train_net.py | 8 +++----- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/d2go/trainer/api.py b/d2go/trainer/api.py index 19b2133f..7fa1e83a 100644 --- a/d2go/trainer/api.py +++ b/d2go/trainer/api.py @@ -11,14 +11,20 @@ from d2go.evaluation.api import AccuracyDict, MetricsDict -# TODO (T127368935) Split to TrainNetOutput and TestNetOutput + @dataclass class TrainNetOutput: accuracy: AccuracyDict[float] metrics: MetricsDict[float] - # Optional, because we use None to distinguish "not used" from - # empty model configs. With T127368935, this should be reverted to dict. - model_configs: Optional[Dict[str, str]] + model_configs: Dict[str, str] + # TODO (T127368603): decide if `tensorboard_log_dir` should be part of output + tensorboard_log_dir: Optional[str] = None + + +@dataclass +class TestNetOutput: + accuracy: AccuracyDict[float] + metrics: MetricsDict[float] # TODO (T127368603): decide if `tensorboard_log_dir` should be part of output tensorboard_log_dir: Optional[str] = None diff --git a/tools/lightning_train_net.py b/tools/lightning_train_net.py index f8668724..b9eaaf1b 100644 --- a/tools/lightning_train_net.py +++ b/tools/lightning_train_net.py @@ -12,7 +12,7 @@ from d2go.runner.callbacks.quantization import QuantizationAwareTraining from d2go.runner.lightning_task import DefaultTask from d2go.setup import basic_argument_parser, prepare_for_launch, setup_after_launch -from d2go.trainer.api import TrainNetOutput +from d2go.trainer.api import TestNetOutput, TrainNetOutput from d2go.trainer.helper import parse_precision_from_string from d2go.trainer.lightning.training_loop import _do_test, _do_train from detectron2.utils.file_io import PathManager @@ -103,7 +103,7 @@ def main( output_dir: str, runner_class: Union[str, Type[DefaultTask]], eval_only: bool = False, -) -> TrainNetOutput: +) -> Union[TrainNetOutput, TestNetOutput]: """Main function for launching a training with lightning trainer Args: cfg: D2go config node @@ -123,18 +123,22 @@ def main( logger.info(f"Resuming training from checkpoint: {last_checkpoint}.") trainer = pl.Trainer(**trainer_params) - model_configs = None + if eval_only: _do_test(trainer, task) + return TestNetOutput( + tensorboard_log_dir=trainer_params["logger"].log_dir, + accuracy=task.eval_res, + metrics=task.eval_res, + ) else: model_configs = _do_train(cfg, trainer, task) - - return TrainNetOutput( - tensorboard_log_dir=trainer_params["logger"].log_dir, - accuracy=task.eval_res, - metrics=task.eval_res, - model_configs=model_configs, - ) + return TrainNetOutput( + tensorboard_log_dir=trainer_params["logger"].log_dir, + accuracy=task.eval_res, + metrics=task.eval_res, + model_configs=model_configs, + ) def argument_parser(): diff --git a/tools/train_net.py b/tools/train_net.py index fb4f7ec8..3612398a 100644 --- a/tools/train_net.py +++ b/tools/train_net.py @@ -9,7 +9,6 @@ import sys from typing import List, Type, Union -import detectron2.utils.comm as comm from d2go.config import CfgNode from d2go.distributed import launch from d2go.runner import BaseRunner @@ -22,7 +21,7 @@ setup_before_launch, setup_root_logger, ) -from d2go.trainer.api import TrainNetOutput +from d2go.trainer.api import TestNetOutput, TrainNetOutput from d2go.trainer.fsdp import create_ddp_model_with_sharding from d2go.utils.misc import ( dump_trained_model_configs, @@ -40,7 +39,7 @@ def main( runner_class: Union[str, Type[BaseRunner]], eval_only: bool = False, resume: bool = True, # NOTE: always enable resume when running on cluster -) -> TrainNetOutput: +) -> Union[TrainNetOutput, TestNetOutput]: runner = setup_after_launch(cfg, output_dir, runner_class) model = runner.build_model(cfg) @@ -58,9 +57,8 @@ def main( model.eval() metrics = runner.do_test(cfg, model, train_iter=train_iter) print_metrics_table(metrics) - return TrainNetOutput( + return TestNetOutput( accuracy=metrics, - model_configs={}, metrics=metrics, )