Skip to content

Commit 2424224

Browse files
authored
Hot fix: add MultiStepLR_HotFix so milestones are kept as a plain list
1 parent 67a2654 commit 2424224

1 file changed

Lines changed: 9 additions & 3 deletions

File tree

FastAutoAugment/lr_scheduler.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
2+
from torch.optim.lr_scheduler import MultiStepLR
33
from theconf import Config as C
44

55

@@ -10,8 +10,14 @@ def adjust_learning_rate_resnet(optimizer):
1010
"""
1111

1212
if C.get()['epoch'] == 90:
13-
return torch.optim.lr_scheduler.MultiStepLR(optimizer, [30, 60, 80])
13+
return MultiStepLR_HotFix(optimizer, [30, 60, 80])
1414
elif C.get()['epoch'] == 270: # autoaugment
15-
return torch.optim.lr_scheduler.MultiStepLR(optimizer, [90, 180, 240])
15+
return MultiStepLR_HotFix(optimizer, [90, 180, 240])
1616
else:
1717
raise ValueError('invalid epoch=%d for resnet scheduler' % C.get()['epoch'])
18+
19+
20+
class MultiStepLR_HotFix(MultiStepLR):
    """Drop-in ``MultiStepLR`` that keeps ``milestones`` as a plain ``list``.

    After the parent initializer has run, ``self.milestones`` is overwritten
    with a ``list`` copy of the caller-supplied milestones.

    NOTE(review): the parent class may internally store milestones in a
    different form; this re-assignment is presumably a workaround for code
    elsewhere that expects a plain list — confirm against the callers.
    """

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):
        # Let the stock scheduler perform its full setup first.
        super().__init__(optimizer, milestones, gamma, last_epoch)
        # Then pin the attribute to a plain-list copy of the input.
        self.milestones = list(milestones)

0 commit comments

Comments
 (0)