Initial media depth project backup
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .dino_clstoken_loss import DINOLoss
|
||||
from .ibot_patch_loss import iBOTPatchLoss
|
||||
from .koleo_loss import KoLeoLoss
|
||||
@@ -0,0 +1,100 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class DINOLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
out_dim,
|
||||
student_temp=0.1,
|
||||
center_momentum=0.9,
|
||||
):
|
||||
super().__init__()
|
||||
self.student_temp = student_temp
|
||||
self.center_momentum = center_momentum
|
||||
self.register_buffer("center", torch.zeros(1, out_dim))
|
||||
self.updated = True
|
||||
self.reduce_handle = None
|
||||
self.len_teacher_output = None
|
||||
self.async_batch_center = None
|
||||
|
||||
@torch.no_grad()
|
||||
def softmax_center_teacher(self, teacher_output, teacher_temp):
|
||||
self.apply_center_update()
|
||||
# teacher centering and sharpening
|
||||
return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1)
|
||||
|
||||
@torch.no_grad()
|
||||
def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3):
|
||||
teacher_output = teacher_output.float()
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper
|
||||
B = Q.shape[1] * world_size # number of samples to assign
|
||||
K = Q.shape[0] # how many prototypes
|
||||
|
||||
# make the matrix sums to 1
|
||||
sum_Q = torch.sum(Q)
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(sum_Q)
|
||||
Q /= sum_Q
|
||||
|
||||
for it in range(n_iterations):
|
||||
# normalize each row: total weight per prototype must be 1/K
|
||||
sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(sum_of_rows)
|
||||
Q /= sum_of_rows
|
||||
Q /= K
|
||||
|
||||
# normalize each column: total weight per sample must be 1/B
|
||||
Q /= torch.sum(Q, dim=0, keepdim=True)
|
||||
Q /= B
|
||||
|
||||
Q *= B # the columns must sum to 1 so that Q is an assignment
|
||||
return Q.t()
|
||||
|
||||
def forward(self, student_output_list, teacher_out_softmaxed_centered_list):
|
||||
"""
|
||||
Cross-entropy between softmax outputs of the teacher and student networks.
|
||||
"""
|
||||
# TODO: Use cross_entropy_distribution here
|
||||
total_loss = 0
|
||||
for s in student_output_list:
|
||||
lsm = F.log_softmax(s / self.student_temp, dim=-1)
|
||||
for t in teacher_out_softmaxed_centered_list:
|
||||
loss = torch.sum(t * lsm, dim=-1)
|
||||
total_loss -= loss.mean()
|
||||
return total_loss
|
||||
|
||||
@torch.no_grad()
|
||||
def update_center(self, teacher_output):
|
||||
self.reduce_center_update(teacher_output)
|
||||
|
||||
@torch.no_grad()
|
||||
def reduce_center_update(self, teacher_output):
|
||||
self.updated = False
|
||||
self.len_teacher_output = len(teacher_output)
|
||||
self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
|
||||
if dist.is_initialized():
|
||||
self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def apply_center_update(self):
|
||||
if self.updated is False:
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
|
||||
if self.reduce_handle is not None:
|
||||
self.reduce_handle.wait()
|
||||
_t = self.async_batch_center / (self.len_teacher_output * world_size)
|
||||
|
||||
self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)
|
||||
|
||||
self.updated = True
|
||||
@@ -0,0 +1,152 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
try:
|
||||
from xformers.ops import cross_entropy
|
||||
|
||||
def lossfunc(t, s, temp):
|
||||
s = s.float()
|
||||
t = t.float()
|
||||
if s.ndim == 2:
|
||||
return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0)
|
||||
elif s.ndim == 3:
|
||||
return -cross_entropy(s, t, temp, bw_inplace=True)
|
||||
|
||||
except ImportError:
|
||||
|
||||
def lossfunc(t, s, temp):
|
||||
return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1)
|
||||
|
||||
|
||||
class iBOTPatchLoss(nn.Module):
|
||||
def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9):
|
||||
super().__init__()
|
||||
self.student_temp = student_temp
|
||||
self.center_momentum = center_momentum
|
||||
self.register_buffer("center", torch.zeros(1, 1, patch_out_dim))
|
||||
self.updated = True
|
||||
self.reduce_handle = None
|
||||
self.len_teacher_patch_tokens = None
|
||||
self.async_batch_center = None
|
||||
|
||||
@torch.no_grad()
|
||||
def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp):
|
||||
self.apply_center_update()
|
||||
# teacher centering and sharpening
|
||||
#
|
||||
# WARNING:
|
||||
# as self.center is a float32, everything gets casted to float32 afterwards
|
||||
#
|
||||
# teacher_patch_tokens = teacher_patch_tokens.float()
|
||||
# return F.softmax((teacher_patch_tokens.sub_(self.center.to(teacher_patch_tokens.dtype))).mul_(1 / teacher_temp), dim=-1)
|
||||
|
||||
return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1)
|
||||
|
||||
# this is experimental, keep everything in float16 and let's see what happens:
|
||||
# return F.softmax((teacher_patch_tokens.sub_(self.center)) / teacher_temp, dim=-1)
|
||||
|
||||
@torch.no_grad()
|
||||
def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3):
|
||||
teacher_output = teacher_output.float()
|
||||
# world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper
|
||||
# B = Q.shape[1] * world_size # number of samples to assign
|
||||
B = n_masked_patches_tensor
|
||||
dist.all_reduce(B)
|
||||
K = Q.shape[0] # how many prototypes
|
||||
|
||||
# make the matrix sums to 1
|
||||
sum_Q = torch.sum(Q)
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(sum_Q)
|
||||
Q /= sum_Q
|
||||
|
||||
for it in range(n_iterations):
|
||||
# normalize each row: total weight per prototype must be 1/K
|
||||
sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
|
||||
if dist.is_initialized():
|
||||
dist.all_reduce(sum_of_rows)
|
||||
Q /= sum_of_rows
|
||||
Q /= K
|
||||
|
||||
# normalize each column: total weight per sample must be 1/B
|
||||
Q /= torch.sum(Q, dim=0, keepdim=True)
|
||||
Q /= B
|
||||
|
||||
Q *= B # the columns must sum to 1 so that Q is an assignment
|
||||
return Q.t()
|
||||
|
||||
def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
|
||||
"""
|
||||
Cross-entropy between softmax outputs of the teacher and student networks.
|
||||
student_patch_tokens: (B, N, D) tensor
|
||||
teacher_patch_tokens: (B, N, D) tensor
|
||||
student_masks_flat: (B, N) tensor
|
||||
"""
|
||||
t = teacher_patch_tokens
|
||||
s = student_patch_tokens
|
||||
loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
|
||||
loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0)
|
||||
return -loss.mean()
|
||||
|
||||
def forward_masked(
|
||||
self,
|
||||
student_patch_tokens_masked,
|
||||
teacher_patch_tokens_masked,
|
||||
student_masks_flat,
|
||||
n_masked_patches=None,
|
||||
masks_weight=None,
|
||||
):
|
||||
t = teacher_patch_tokens_masked
|
||||
s = student_patch_tokens_masked
|
||||
# loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
|
||||
loss = lossfunc(t, s, self.student_temp)
|
||||
if masks_weight is None:
|
||||
masks_weight = (
|
||||
(1 / student_masks_flat.sum(-1).clamp(min=1.0))
|
||||
.unsqueeze(-1)
|
||||
.expand_as(student_masks_flat)[student_masks_flat]
|
||||
)
|
||||
if n_masked_patches is not None:
|
||||
loss = loss[:n_masked_patches]
|
||||
loss = loss * masks_weight
|
||||
return -loss.sum() / student_masks_flat.shape[0]
|
||||
|
||||
@torch.no_grad()
|
||||
def update_center(self, teacher_patch_tokens):
|
||||
self.reduce_center_update(teacher_patch_tokens)
|
||||
|
||||
@torch.no_grad()
|
||||
def reduce_center_update(self, teacher_patch_tokens):
|
||||
self.updated = False
|
||||
self.len_teacher_patch_tokens = len(teacher_patch_tokens)
|
||||
self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True)
|
||||
if dist.is_initialized():
|
||||
self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def apply_center_update(self):
|
||||
if self.updated is False:
|
||||
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||
|
||||
if self.reduce_handle is not None:
|
||||
self.reduce_handle.wait()
|
||||
_t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size)
|
||||
|
||||
self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)
|
||||
|
||||
self.updated = True
|
||||
@@ -0,0 +1,49 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# import torch.distributed as dist
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class KoLeoLoss(nn.Module):
|
||||
"""Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.pdist = nn.PairwiseDistance(2, eps=1e-8)
|
||||
|
||||
def pairwise_NNs_inner(self, x):
|
||||
"""
|
||||
Pairwise nearest neighbors for L2-normalized vectors.
|
||||
Uses Torch rather than Faiss to remain on GPU.
|
||||
"""
|
||||
# parwise dot products (= inverse distance)
|
||||
dots = torch.mm(x, x.t())
|
||||
n = x.shape[0]
|
||||
dots.view(-1)[:: (n + 1)].fill_(-1) # Trick to fill diagonal with -1
|
||||
# max inner prod -> min distance
|
||||
_, I = torch.max(dots, dim=1) # noqa: E741
|
||||
return I
|
||||
|
||||
def forward(self, student_output, eps=1e-8):
|
||||
"""
|
||||
Args:
|
||||
student_output (BxD): backbone output of student
|
||||
"""
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
student_output = F.normalize(student_output, eps=eps, p=2, dim=-1)
|
||||
I = self.pairwise_NNs_inner(student_output) # noqa: E741
|
||||
distances = self.pdist(student_output, student_output[I]) # BxD, BxD -> B
|
||||
loss = -torch.log(distances + eps).mean()
|
||||
return loss
|
||||
Reference in New Issue
Block a user