|
- """
- warm up step learning rate.
- """
- from collections import Counter
- import numpy as np
-
- from .linear_warmup import linear_warmup_lr
-
-
def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
    """Build a warmup-then-quadratic-decay learning-rate schedule.

    The schedule has ``steps_per_epoch * total_epochs`` entries. For steps
    before ``steps_per_epoch * warmup_epochs`` the rate rises linearly from
    ``lr_init`` in increments of ``(lr_max - lr_init) / warmup_steps``. From
    that step on, the rate is ``lr_max * (1 - progress) ** 2`` where
    ``progress`` is the fraction of the post-warmup steps already elapsed.
    Negative values are clamped to 0.0.

    Args:
        global_step: Index of the first schedule entry to return.
        lr_init: Learning rate at step 0 of the warmup phase.
        lr_max: Learning rate at the first step after warmup.
        warmup_epochs: Number of epochs spent in the warmup phase.
        total_epochs: Number of epochs covered by the whole schedule.
        steps_per_epoch: Number of steps in one epoch.

    Returns:
        A float32 numpy array holding the schedule from ``global_step`` onward.
    """
    n_total = steps_per_epoch * total_epochs
    n_warmup = steps_per_epoch * warmup_epochs
    # With no warmup phase the slope is never used; 0 is only a placeholder.
    slope = (float(lr_max) - float(lr_init)) / float(n_warmup) if n_warmup != 0 else 0

    def rate_at(step):
        """Return the (clamped) learning rate for a single step index."""
        if step < n_warmup:
            value = float(lr_init) + slope * float(step)
        else:
            remaining = 1.0 - (float(step) - float(n_warmup)) / (float(n_total) - float(n_warmup))
            value = float(lr_max) * remaining * remaining
        return 0.0 if value < 0.0 else value

    schedule = np.array([rate_at(step) for step in range(n_total)]).astype(np.float32)
    return schedule[global_step:]
-
-
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
    """Build a per-step schedule: warmup first, then step decay at milestones.

    For the first ``int(warmup_epochs * steps_per_epoch)`` steps the rate is
    whatever ``linear_warmup_lr(step + 1, warmup_steps, lr, 0)`` returns
    (helper imported from ``.linear_warmup``; presumably a linear ramp from 0
    up to ``lr`` -- confirm against that module). After warmup, the running
    rate is multiplied by ``gamma`` once for every occurrence of the current
    step among the milestone steps (``epoch * steps_per_epoch`` for each
    entry of ``lr_epochs``). Milestones that fall inside the warmup phase are
    not applied.

    Args:
        lr: Base learning rate, also the target of the warmup phase.
        lr_epochs: Iterable of epoch indices at which the rate decays.
        steps_per_epoch: Number of steps in one epoch.
        warmup_epochs: Number of epochs spent in the warmup phase.
        max_epoch: Number of epochs covered by the whole schedule.
        gamma: Multiplicative decay factor applied at each milestone.

    Returns:
        A float32 numpy array with ``int(max_epoch * steps_per_epoch)`` entries.
    """
    peak_lr = lr
    n_total = int(max_epoch * steps_per_epoch)
    n_warmup = int(warmup_epochs * steps_per_epoch)
    # A Counter yields 0 for non-milestone steps (gamma ** 0 leaves the rate
    # unchanged) and handles duplicated milestones by decaying more than once.
    decay_hits = Counter(epoch * steps_per_epoch for epoch in lr_epochs)

    schedule = []
    current = peak_lr
    for step in range(n_total):
        if step >= n_warmup:
            current = current * gamma ** decay_hits[step]
        else:
            current = linear_warmup_lr(step + 1, n_warmup, peak_lr, 0)
        schedule.append(current)

    return np.array(schedule).astype(np.float32)
-
-
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
    """Build a step-decay schedule with no warmup phase.

    Delegates to ``warmup_step_lr`` with ``warmup_epochs`` fixed at 0, so the
    rate starts at ``lr`` and is scaled by ``gamma`` at each milestone epoch.

    Args:
        lr: Base learning rate.
        milestones: Iterable of epoch indices at which the rate decays.
        steps_per_epoch: Number of steps in one epoch.
        max_epoch: Number of epochs covered by the whole schedule.
        gamma: Multiplicative decay factor applied at each milestone.

    Returns:
        The array returned by ``warmup_step_lr``.
    """
    return warmup_step_lr(
        lr=lr,
        lr_epochs=milestones,
        steps_per_epoch=steps_per_epoch,
        warmup_epochs=0,
        max_epoch=max_epoch,
        gamma=gamma,
    )
-
-
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
    """Build a schedule that decays every ``epoch_size`` epochs.

    Every epoch index in ``[1, max_epoch)`` that is a multiple of
    ``epoch_size`` becomes a milestone, and the resulting list is handed to
    ``multi_step_lr``.

    Args:
        lr: Base learning rate.
        epoch_size: Interval, in epochs, between consecutive decays.
        steps_per_epoch: Number of steps in one epoch.
        max_epoch: Number of epochs covered by the whole schedule.
        gamma: Multiplicative decay factor applied at each milestone.

    Returns:
        The array returned by ``multi_step_lr``.
    """
    decay_epochs = [epoch for epoch in range(1, max_epoch) if epoch % epoch_size == 0]
    return multi_step_lr(lr, decay_epochs, steps_per_epoch, max_epoch, gamma=gamma)
|