diff --git a/d2go/modeling/backbone/fbnet_cfg.py b/d2go/modeling/backbone/fbnet_cfg.py index dfc22162..38be0afc 100644 --- a/d2go/modeling/backbone/fbnet_cfg.py +++ b/d2go/modeling/backbone/fbnet_cfg.py @@ -66,6 +66,9 @@ def add_fbnet_v2_default_configs(_C): # https://github.com/rbgirshick/yacs/blob/master/yacs/config.py#L410 _C.MODEL.FBNET_V2.NORM_ARGS = [] + # Set up operators fusion option in the backbone's blocks + _C.MODEL.FBNET_V2.FUSE_OPS = False + _C.MODEL.VT_FPN = CN() _C.MODEL.VT_FPN.IN_FEATURES = ["res2", "res3", "res4", "res5"] diff --git a/d2go/modeling/backbone/fbnet_v2.py b/d2go/modeling/backbone/fbnet_v2.py index fcfb50b9..078e88a4 100644 --- a/d2go/modeling/backbone/fbnet_v2.py +++ b/d2go/modeling/backbone/fbnet_v2.py @@ -197,12 +197,16 @@ def build_fbnet(cfg, name, in_channels): stages = [] trunk_stride_per_stage = _get_stride_per_stage(arch_def_blocks) shape_spec_per_stage = [] + fuse_ops = getattr(cfg.MODEL.FBNET_V2, "FUSE_OPS", False) + is_qat = getattr(cfg, "QAT", False) for i, stride_i in enumerate(trunk_stride_per_stage): stages.append( builder.build_blocks( arch_def_blocks, stage_indices=[i], prefix_name=FBNET_BUILDER_IDENTIFIER + "_", + fuse_ops=fuse_ops, + is_qat=is_qat, ) ) shape_spec_per_stage.append(