|
- from torch import optim as optim
-
-
def build_optimizer(config, model):
    """
    Build an optimizer for *model*, with weight decay disabled for
    normalization/bias-like parameters by default.

    Parameter grouping is delegated to ``set_weight_decay``; a model may
    additionally opt parameters out of weight decay via the optional
    ``no_weight_decay()`` / ``no_weight_decay_keywords()`` hooks.

    Args:
        model: module providing ``named_parameters()`` (and optionally the
            two hooks above).

    Returns:
        A ``torch.optim.SGD`` or ``torch.optim.AdamW`` instance, chosen by
        ``config.TRAIN.OPTIMIZER.NAME`` (case-insensitive).

    Raises:
        ValueError: if ``config.TRAIN.OPTIMIZER.NAME`` is neither 'sgd'
            nor 'adamw'. (Previously an unknown name silently returned
            ``None``, surfacing later as an opaque AttributeError.)
    """
    skip = {}
    skip_keywords = {}
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keywords = model.no_weight_decay_keywords()
    parameters = set_weight_decay(model, skip, skip_keywords, config.TRAIN.BASE_LR)

    opt_lower = config.TRAIN.OPTIMIZER.NAME.lower()
    if opt_lower == 'sgd':
        return optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True,
                         lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    if opt_lower == 'adamw':
        return optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS,
                           lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY)
    raise ValueError(f"Unsupported optimizer: {config.TRAIN.OPTIMIZER.NAME!r}")
-
def set_weight_decay(model, skip_list=(), skip_keywords=(), lr=0.0):
    """
    Split *model*'s trainable parameters into two optimizer param groups:
    one using the optimizer's default weight decay, and one with weight
    decay forced to 0.

    Weight decay is disabled for: 1-D parameters (norm scales, biases and
    similar), any parameter whose name ends with ``".bias"``, exact names
    listed in *skip_list*, and names containing any substring from
    *skip_keywords*.

    Args:
        model: module providing ``named_parameters()``.
        skip_list: exact parameter names to exclude from weight decay.
        skip_keywords: substrings; a match excludes the parameter.
        lr: unused; retained for backward compatibility with callers that
            pass it positionally (an earlier variant used it for a
            high-LR 'meta' group).

    Returns:
        A list of two param-group dicts suitable for a torch optimizer.
    """
    has_decay = []
    no_decay = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights take part in no group
        if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \
                check_keywords_in_name(name, skip_keywords):
            no_decay.append(param)
        else:
            has_decay.append(param)
    return [{'params': has_decay},
            {'params': no_decay, 'weight_decay': 0.}]
-
-
def check_keywords_in_name(name, keywords=()):
    """Return True if any of *keywords* occurs as a substring of *name*.

    Uses ``any()``, which short-circuits on the first match — the original
    loop kept scanning every remaining keyword after a hit.
    """
    return any(keyword in name for keyword in keywords)
|