|
-
- import numpy as np
- import torch
- from torch import nn
-
-
class NBeatsBlock(nn.Module):
    """
    N-BEATS basic block.

    A chain of fully-connected ReLU layers maps the block input to a coefficient vector
    ``theta`` (through a final linear layer with no activation). A basis module then turns
    ``theta`` into a ``(backcast, forecast)`` pair.

    :param input_size: width of the block input; also used as the backcast length.
    :param layer_size: width of every hidden fully-connected layer.
    :param output_size: forecast length.
    :param theta_size: width of ``theta``; when ``None`` it is derived from ``block_type``.
    :param n_layers: number of hidden-to-hidden layers after the input projection (see note below).
    :param block_type: one of ``'generic'``, ``'trend'``, ``'seasonality'``.
    :param poly_degree: polynomial degree of the trend basis; required for ``'trend'``.
    :param harmonics: harmonics setting of the Fourier basis; required for ``'seasonality'``.
    :raises ValueError: if ``block_type`` is not one of the supported values.
    :raises TypeError: if the basis-specific argument required by ``block_type`` is missing.
    """
    def __init__(self, input_size, layer_size, output_size, theta_size=None, n_layers=4,
                 block_type='generic', poly_degree=None, harmonics=None):
        super().__init__()
        # Validate before allocating any layers, so a misconfiguration fails fast and clearly.
        if block_type not in ('generic', 'trend', 'seasonality'):
            raise ValueError("Argument block_type should be one of ['generic', 'trend', 'seasonality'].")
        if block_type == 'trend' and poly_degree is None:
            raise TypeError("Argument poly_degree is required when block_type='trend'.")
        if block_type == 'seasonality' and harmonics is None:
            raise TypeError("Argument harmonics is required when block_type='seasonality'.")

        # NOTE(review): this builds n_layers + 1 Linear layers in total: one input projection
        # plus n_layers hidden ones. The N-BEATS reference appears to use n_layers in total;
        # confirm intent. Left as-is so existing checkpoints keep loading.
        self.layers = nn.ModuleList([nn.Linear(in_features=input_size, out_features=layer_size)] +
                                    [nn.Linear(in_features=layer_size, out_features=layer_size)
                                     for _ in range(n_layers)])

        if block_type == 'generic':
            # Generic basis slices theta directly: backcast head + forecast tail.
            if theta_size is None:
                theta_size = input_size + output_size
            self.out_layer = nn.Linear(in_features=layer_size, out_features=theta_size)
            self.basis_function = GenericBasis(backcast_size=input_size,
                                               forecast_size=output_size)
        elif block_type == 'trend':
            # Two coefficient sets (forecast, backcast) of poly_degree + 1 terms each.
            if theta_size is None:
                theta_size = 2 * (poly_degree + 1)
            self.out_layer = nn.Linear(in_features=layer_size, out_features=theta_size)
            self.basis_function = TrendBasis(backcast_size=input_size,
                                             forecast_size=output_size,
                                             poly_degree=poly_degree)
        else:
            # Four coefficient groups (cos/sin for forecast and backcast), one entry per
            # frequency produced by SeasonalityBasis for these harmonics/output_size.
            if theta_size is None:
                theta_size = 4 * int(np.ceil(harmonics / 2 * output_size) - (harmonics - 1))
            self.out_layer = nn.Linear(in_features=layer_size, out_features=theta_size)
            self.basis_function = SeasonalityBasis(backcast_size=input_size,
                                                   forecast_size=output_size,
                                                   harmonics=harmonics)

    def forward(self, x: torch.Tensor):
        """
        Run the block.

        :param x: block input; cast to float32 before the first layer.
        :return: the ``(backcast, forecast)`` tuple produced by ``self.basis_function``.
        """
        hidden = x.float()
        for layer in self.layers:
            hidden = torch.relu(layer(hidden))
        theta = self.out_layer(hidden)
        return self.basis_function(theta)
-
-
class GenericBasis(nn.Module):
    """
    Generic (identity) basis for block output.

    No projection is applied: the leading ``backcast_size`` columns of ``theta`` are the
    backcast and the trailing ``forecast_size`` columns are the forecast.
    """
    def __init__(self, backcast_size, forecast_size):
        super().__init__()
        self.backcast_size, self.forecast_size = backcast_size, forecast_size

    def forward(self, theta):
        # Columns live along dim 1. NOTE(review): forecast_size == 0 would slice with -0,
        # i.e. return every column; callers presumably never pass 0.
        head = theta[:, :self.backcast_size]
        tail = theta[:, -self.forecast_size:]
        return head, tail
-
-
class TrendBasis(nn.Module):
    """
    Polynomial trend basis for block output.

    ``theta`` holds two coefficient sets of ``poly_degree + 1`` terms each. The first set
    weights the forecast polynomial and the second set weights the backcast polynomial.
    Term ``i`` is ``(t / size) ** i`` for ``t = 0 .. size - 1``.
    """
    def __init__(self, backcast_size, forecast_size, poly_degree):
        super().__init__()
        # Despite its name, this stores the number of polynomial terms (degrees 0..poly_degree).
        self.poly_degree = poly_degree + 1

        def _power_rows(size):
            # One row per power of the normalized time grid t / size (values in [0, 1)).
            grid = np.arange(size) / size
            rows = np.stack([np.power(grid, i) for i in range(self.poly_degree)])
            # Fixed (non-trainable) basis stored as a float32 parameter.
            return nn.Parameter(torch.tensor(rows, dtype=torch.float32), requires_grad=False)

        self.backcast_base = _power_rows(backcast_size)
        self.forecast_base = _power_rows(forecast_size)

    def forward(self, theta):
        n_terms = self.poly_degree
        forecast = theta[:, :n_terms] @ self.forecast_base
        backcast = theta[:, n_terms:] @ self.backcast_base
        return backcast, forecast
-
-
class SeasonalityBasis(nn.Module):
    """
    Fourier (seasonality) basis for block output.

    ``theta`` is read as four equal coefficient groups, in order: forecast-cos, forecast-sin,
    backcast-cos, backcast-sin. Each group weights cosine/sine waves at the frequencies in
    ``self.frequency``.
    """
    def __init__(self, backcast_size, forecast_size, harmonics):
        super().__init__()
        # Frequency 0 followed by harmonics, harmonics + 1, ... (< harmonics / 2 * forecast_size),
        # each divided by harmonics. Shaped (1, n_freq) so it broadcasts against a time column.
        steps = np.arange(harmonics, harmonics / 2 * forecast_size, dtype=np.float32) / harmonics
        self.frequency = np.append(np.zeros(1, dtype=np.float32), steps)[None, :]

        def _angles(length, sign):
            # sign * 2*pi * (t / forecast_size) * frequency for t = 0 .. length - 1.
            # Both grids are normalized by forecast_size; the backcast uses the negative sign.
            return sign * 2 * np.pi * (
                np.arange(length, dtype=np.float32)[:, None] / forecast_size) * self.frequency

        def _frozen(values):
            # Transposed to (n_freq, length) so that coefficients @ base -> (batch, length).
            return nn.Parameter(torch.tensor(values.T, dtype=torch.float32), requires_grad=False)

        backcast_angles = _angles(backcast_size, -1)
        forecast_angles = _angles(forecast_size, 1)
        self.backcast_cos_base = _frozen(np.cos(backcast_angles))
        self.backcast_sin_base = _frozen(np.sin(backcast_angles))
        self.forecast_cos_base = _frozen(np.cos(forecast_angles))
        self.forecast_sin_base = _frozen(np.sin(forecast_angles))

    def forward(self, theta):
        n = theta.shape[1] // 4
        forecast = (theta[:, :n] @ self.forecast_cos_base
                    + theta[:, n:2 * n] @ self.forecast_sin_base)
        # The last group runs to the end of theta (not just 4 * n columns).
        backcast = (theta[:, 2 * n:3 * n] @ self.backcast_cos_base
                    + theta[:, 3 * n:] @ self.backcast_sin_base)
        return backcast, forecast
-
-
class NBeats(nn.Module):
    """
    N-BEATS architecture.

    Chains ``NBeatsBlock`` modules with doubly-residual connections. Each block sees the
    current residual of the time-reversed input. Its backcast is subtracted from that residual,
    and the result is multiplied by ``input_mask``. Its forecast is added to the running output.

    Two configurations are built:

    * generic (``interpretable=False``): ``generic_num_blocks`` separately constructed
      generic blocks.
    * interpretable (``interpretable=True``): ``trend_num_blocks`` trend blocks followed by
      ``seasonality_num_blocks`` seasonality blocks. One block instance is repeated per stack,
      so all blocks within a stack share weights.

    Additional keyword arguments (``**kwargs``) are accepted and ignored.
    """
    def __init__(self, input_size, output_size, interpretable=False,
                 generic_num_blocks=None, trend_num_blocks=None, seasonality_num_blocks=None,
                 generic_num_layers=None, trend_num_layers=None, seasonality_num_layers=None,
                 generic_layers_size=None, trend_layers_size=None, seasonality_layers_size=None,
                 poly_degree=None, harmonics=None, **kwargs):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.interpretable = interpretable

        if not interpretable:
            self.blocks = nn.ModuleList([NBeatsBlock(input_size=input_size,
                                                     layer_size=generic_layers_size,
                                                     output_size=output_size,
                                                     n_layers=generic_num_layers,
                                                     block_type='generic')
                                         for _ in range(generic_num_blocks)])
        else:
            trend_block = NBeatsBlock(input_size=input_size,
                                      layer_size=trend_layers_size,
                                      output_size=output_size,
                                      n_layers=trend_num_layers,
                                      block_type='trend',
                                      poly_degree=poly_degree)
            seasonality_block = NBeatsBlock(input_size=input_size,
                                            layer_size=seasonality_layers_size,
                                            output_size=output_size,
                                            n_layers=seasonality_num_layers,
                                            block_type='seasonality',
                                            harmonics=harmonics)

            self.trend_num_blocks = trend_num_blocks
            self.seasonality_num_blocks = seasonality_num_blocks
            # The same instance is repeated, so blocks within each stack share weights.
            self.blocks = nn.ModuleList([trend_block for _ in range(trend_num_blocks)] +
                                        [seasonality_block for _ in range(seasonality_num_blocks)])

    @staticmethod
    def _run_stack(blocks, resid, input_mask, out):
        """
        Apply ``blocks`` in order and return the updated ``(out, resid)``.

        Each block's forecast is added to ``out``. Its backcast is subtracted from ``resid``,
        and the difference is masked.
        """
        for block in blocks:
            backcast, forecast = block(resid)
            # Out-of-place on purpose: the initial ``out`` is a view into the caller's ``x``,
            # so ``+=`` would write forecasts back into the input tensor.
            out = out + forecast
            resid = (resid - backcast) * input_mask
        return out, resid

    def forward(self, x: torch.Tensor, input_mask: torch.Tensor = None):
        """
        Forecast from the input history ``x``. ``x`` is not modified.

        :param x: input history with time along dim 1 (presumably ``(batch, input_size)``;
            confirm against callers).
        :param input_mask: optional mask laid out like ``x``. It multiplies the residual after
            every block. Defaults to all ones.
        :return: ``out`` in generic mode, or ``(out, trend_out, seasonality_out)`` in
            interpretable mode, where ``out = trend_out + seasonality_out``.
        """
        # Blocks consume the history time-reversed; the mask is reversed to stay aligned.
        resid = x.flip(dims=(1,))
        if input_mask is None:
            input_mask = torch.ones_like(x)
        else:
            input_mask = input_mask.flip(dims=(1,))

        # Starting level: the last ``output_size`` input values. NOTE(review): shapes only line
        # up when input_size >= output_size (or input_size == 1, via broadcasting).
        level = x[:, -self.output_size:]

        if self.interpretable:
            blocks = list(self.blocks)
            n_trend = self.trend_num_blocks
            trend_out, resid = self._run_stack(blocks[:n_trend], resid, input_mask, level)
            # The seasonality stack continues from the trend stack's residual.
            seasonality_out, _ = self._run_stack(
                blocks[n_trend:n_trend + self.seasonality_num_blocks],
                resid, input_mask, torch.zeros_like(trend_out))
            out = trend_out + seasonality_out
            return out, trend_out, seasonality_out

        out, _ = self._run_stack(self.blocks, resid, input_mask, level)
        return out
|