|
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.optim as optim
- from torchvision import datasets, transforms
- import accelerate
- from accelerate import Accelerator
-
# Report whether an Ascend NPU backend is visible to Accelerate.
npu_available = accelerate.utils.is_npu_available()
print(npu_available)
-
class BasicModel(nn.Module):
    """Small CNN for 10-class classification of 1x28x28 inputs (MNIST-style).

    Two 3x3 convolutions -> 2x2 max-pool -> two fully connected layers,
    with dropout after the pool and after fc1; returns per-class
    log-probabilities via log_softmax.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12 spatial after the pool (for 28x28 input).
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)
        self.act = F.relu

    def forward(self, x):
        """Return log-probabilities of shape (batch, 10) for input x."""
        features = self.act(self.conv1(x))
        features = self.act(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        flat = torch.flatten(features, 1)
        hidden = self.dropout2(self.act(self.fc1(flat)))
        return F.log_softmax(self.fc2(hidden), dim=1)
-
-
def train_ddp_accelerate():
    """Train BasicModel on MNIST under Hugging Face Accelerate, then report
    test-set accuracy.

    Accelerate handles device placement and (when launched via
    ``accelerate launch``) distributed data parallelism; after
    ``accelerator.prepare`` no manual ``.to(device)`` calls are needed —
    batches from the prepared loaders are already on the right device.
    """
    accelerator = Accelerator()

    # Build DataLoaders. Normalize expects per-channel sequences:
    # (0.1307) is just a parenthesized float, so the mean/std must be
    # 1-tuples — (0.1307,) and (0.3081,) — for the single MNIST channel.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dset = datasets.MNIST('data', train=True, download=True, transform=transform)
    test_dset = datasets.MNIST('data', train=False, transform=transform)

    train_loader = torch.utils.data.DataLoader(train_dset, shuffle=True, batch_size=64)
    test_loader = torch.utils.data.DataLoader(test_dset, shuffle=False, batch_size=64)

    # Build model
    model = BasicModel()

    # Build optimizer
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)

    # `prepare` moves the model to the accelerator's device and wraps the
    # loaders so each yielded batch is already device-resident (and sharded
    # across processes in distributed runs).
    train_loader, test_loader, model, optimizer = accelerator.prepare(
        train_loader, test_loader, model, optimizer
    )

    # Train for 10 epochs.
    model.train()
    for epoch in range(10):
        for batch_idx, (data, target) in enumerate(train_loader):
            output = model(data)
            loss = F.nll_loss(output, target)
            print('Epoch:{}, Loss:{}'.format(epoch, loss))
            # Use accelerator.backward so gradient scaling / DDP hooks apply.
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

    # Evaluate. The prepared test_loader already yields batches on the
    # correct device, so the previous hard-coded `.to('npu')` (which crashed
    # on any host without an NPU) was redundant and has been removed.
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    print(f'Accuracy: {100. * correct / len(test_loader.dataset)}')
-
if __name__ == '__main__':
    # Script entry point: run the Accelerate training/evaluation loop.
    train_ddp_accelerate()
|