import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.models as models


class Model(nn.Module):
    def __init__(self, arch, head, out_dim):
        super().__init__()

        # Build a torchvision ResNet and adapt it for small inputs (e.g. CIFAR):
        # replace the 7x7 stride-2 stem with a 3x3 stride-1 convolution and
        # drop the initial max-pooling layer.
        model = getattr(models, arch)()
        inplanes = model.inplanes  # final feature dimension of the backbone
        backbone = list(model.children())[:-1]  # drop the classification layer
        backbone[0] = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        del backbone[3]  # remove the max-pooling layer
        backbone.append(nn.Flatten(start_dim=1))
        self.backbone = nn.Sequential(*backbone)

        if head == 'proj':
            # Two-layer MLP projection head (SimCLR-style).
            self.head = nn.Sequential(
                nn.Linear(inplanes, 512, bias=False),
                nn.BatchNorm1d(512),
                nn.ReLU(inplace=True),
                nn.Linear(512, out_dim),
            )
        else:
            # Single linear projection head.
            self.head = nn.Sequential(
                nn.Linear(inplanes, out_dim),
            )

    def forward_feature(self, x):
        return self.backbone(x)

    def forward(self, x):
        return self.head(self.backbone(x))
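
# Usage sketch (illustrative, not part of the module's API; 'resnet18' and
# out_dim=128 are assumed example values):
#
#   model = Model(arch='resnet18', head='proj', out_dim=128)
#   x = torch.randn(8, 3, 32, 32)         # CIFAR-sized inputs
#   feats = model.forward_feature(x)      # (8, 512) backbone features
#   proj = model(x)                       # (8, 128) projections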


class ContrastLoss(nn.Module):
    """Implementation of the contrastive cross-entropy (InfoNCE) loss.
    The loss keeps a memory bank of negative samples and therefore behaves
    like the one described in the MoCo [1] paper; the underlying objective is
    the NT-Xent loss introduced in the SimCLR [0] paper.
    - [0] SimCLR, 2020, https://arxiv.org/abs/2002.05709
    - [1] MoCo, 2020, https://arxiv.org/abs/1911.05722

    Attributes:
        embed_dim:
            Dimensionality of the embeddings stored in the memory bank.
        temperature:
            Logits are scaled by the inverse of the temperature.
        bank_size:
            Number of negative samples stored in the memory bank.
            For MoCo we typically use values such as 4096 or 65536.
        reduction:
            Reduction applied by the underlying cross-entropy loss.
    Raises:
        ValueError: If abs(temperature) < 1e-8, to prevent division by zero.
    Examples:
        >>> # initialize the loss function
        >>> loss_fn = ContrastLoss(embed_dim=128, temperature=0.5)
        >>>
        >>> # generate two random transforms of images
        >>> t0 = transforms(images)
        >>> t1 = transforms(images)
        >>>
        >>> # feed each view through a SimCLR or MoCo model
        >>> query = model(t0)
        >>> key = model(t1)
        >>>
        >>> # calculate loss
        >>> loss = loss_fn(query, key)
    """

    eps = 1e-8

    def __init__(self, embed_dim: int = 128, temperature: float = 0.5,
                 bank_size: int = 2**12, reduction: str = "mean"):
        super().__init__()
        if abs(temperature) < self.eps:
            raise ValueError(f'Illegal temperature: abs({temperature}) < 1e-8')
        self.embed_dim = embed_dim
        self.temperature = temperature
        self.cross_entropy = nn.CrossEntropyLoss(reduction=reduction)

        # The memory bank holds unit-normalized negative embeddings together
        # with a write pointer. Both are non-persistent buffers: they follow
        # the module across devices but are not saved in checkpoints.
        self.bank_size = bank_size
        self.register_buffer(
            "bank",
            tensor=F.normalize(torch.randn(self.bank_size, self.embed_dim, dtype=torch.float), dim=-1),
            persistent=False,
        )
        self.register_buffer("bank_ptr", tensor=torch.zeros(1, dtype=torch.long), persistent=False)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, batch: torch.Tensor):
        """Dequeue the oldest entries and enqueue the latest batch.
        Args:
            batch:
                The latest batch of keys to add to the memory bank.
        """
        batch_size = batch.size(0)
        ptr = int(self.bank_ptr.item())

        if ptr + batch_size >= self.bank_size:
            # Fill the bank up to its end and wrap the pointer around. Keys
            # that do not fit are dropped, so pick bank_size as a multiple of
            # the batch size to avoid discarding keys.
            self.bank[ptr:] = batch[:self.bank_size - ptr].detach()
            self.bank_ptr[0] = 0
        else:
            self.bank[ptr:ptr + batch_size] = batch.detach()
            self.bank_ptr[0] = ptr + batch_size
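
    # Ring-buffer illustration (assumed example values): with bank_size=8 and
    # batch_size=3, successive calls fill slots [0:3] and [3:6]; the third
    # call finds ptr=6, and since 6 + 3 >= 8 it writes the first two keys of
    # the batch into slots [6:8], drops the remaining key, and resets the
    # pointer to 0.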

    def forward(self, query: torch.Tensor, key: torch.Tensor):
        """Forward pass through the contrastive cross-entropy loss.
        The samples stored in the memory bank are used as negative examples.
        Args:
            query:
                Output projections of the first set of transformed images.
                Shape: (batch_size, embedding_size)
            key:
                Output projections of the second set of transformed images.
                Shape: (batch_size, embedding_size)
        Returns:
            Contrastive cross-entropy loss value.
        """
        device = query.device
        batch_size = query.size(0)

        # Normalize the projections to unit length so that dot products
        # correspond to cosine similarities.
        query = F.normalize(query, dim=1)
        key = F.normalize(key, dim=1)

        # Take the negative samples from the memory bank. The bank is cloned
        # and detached so that no gradients flow into the stored keys.
        # negatives: shape: (memory_bank_size, embedding_size)
        negatives = self.bank.clone().detach().to(device)

        # We use the cosine similarity, expressed as a dot product (einsum),
        # as all vectors are already normalized to unit length.
        # Einsum notation: n = batch_size, c = embedding_size, k = memory_bank_size.

        # sim_pos has shape (batch_size, 1); sim_pos[i] is the similarity
        # of the i-th sample in the batch to its positive pair.
        sim_pos = torch.einsum('nc,nc->n', query, key).unsqueeze(-1)

        # sim_neg has shape (batch_size, memory_bank_size); sim_neg[i, j] is
        # the similarity of the i-th sample to the j-th negative sample.
        sim_neg = torch.einsum('nc,kc->nk', query, negatives)

        # Put the positive similarity in column 0 and set all labels to class
        # 0, so that cross entropy maximizes sim_pos relative to sim_neg.
        logits = torch.cat([sim_pos, sim_neg], dim=1)
        labels = torch.zeros(batch_size, device=device, dtype=torch.long)

        loss = self.cross_entropy(logits / self.temperature, labels)

        # Update the memory bank with the latest keys.
        self._dequeue_and_enqueue(key)

        return loss
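

# Minimal end-to-end smoke test, a sketch only: the batch size, image size,
# and hyperparameters below are illustrative assumptions, not values
# prescribed by this module.
if __name__ == "__main__":
    model = Model(arch='resnet18', head='proj', out_dim=128)
    loss_fn = ContrastLoss(embed_dim=128, temperature=0.5, bank_size=2**12)

    # Two augmented views of the same images would normally come from a data
    # augmentation pipeline; random tensors stand in for them here.
    view0 = torch.randn(8, 3, 32, 32)
    view1 = torch.randn(8, 3, 32, 32)

    query = model(view0)
    key = model(view1)

    loss = loss_fn(query, key)
    loss.backward()
    print(f"loss: {loss.item():.4f}")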