Initial media depth project backup
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
__version__ = "0.0.1"
|
||||
@@ -0,0 +1,23 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import pathlib
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
def load_config(config_name: str):
    """Load the YAML config named ``config_name`` stored next to this module.

    ``config_name`` is given without extension: ``".yaml"`` is appended and the
    result is looked up in the (resolved) directory that contains this file,
    then handed to ``OmegaConf.load``.
    """
    configs_dir = pathlib.Path(__file__).parent.resolve()
    return OmegaConf.load(configs_dir / (config_name + ".yaml"))
|
||||
|
||||
|
||||
# Default configuration, read once at import time from "ssl_default_config.yaml" next to this module.
dinov2_default_config = load_config("ssl_default_config")
|
||||
|
||||
|
||||
def load_and_merge_config(config_name: str):
    """Return the default config merged with the config named ``config_name``.

    A new config object is first created from ``dinov2_default_config`` with
    ``OmegaConf.create``; the named config is then loaded through
    ``load_config`` and both are passed to ``OmegaConf.merge`` (defaults first,
    named config second).
    """
    base = OmegaConf.create(dinov2_default_config)
    overrides = load_config(config_name)
    return OmegaConf.merge(base, overrides)
|
||||
@@ -0,0 +1,6 @@
|
||||
student:
|
||||
arch: vit_base
|
||||
patch_size: 14
|
||||
crops:
|
||||
global_crops_size: 518 # this is to set up the position embeddings properly
|
||||
local_crops_size: 98
|
||||
@@ -0,0 +1,7 @@
|
||||
student:
|
||||
arch: vit_giant2
|
||||
patch_size: 14
|
||||
ffn_layer: swiglufused
|
||||
crops:
|
||||
global_crops_size: 518 # this is to set up the position embeddings properly
|
||||
local_crops_size: 98
|
||||
@@ -0,0 +1,6 @@
|
||||
student:
|
||||
arch: vit_large
|
||||
patch_size: 14
|
||||
crops:
|
||||
global_crops_size: 518 # this is to set up the position embeddings properly
|
||||
local_crops_size: 98
|
||||
@@ -0,0 +1,6 @@
|
||||
student:
|
||||
arch: vit_small
|
||||
patch_size: 14
|
||||
crops:
|
||||
global_crops_size: 518 # this is to set up the position embeddings properly
|
||||
local_crops_size: 98
|
||||
@@ -0,0 +1,115 @@
|
||||
MODEL:
|
||||
WEIGHTS: ''
|
||||
compute_precision:
|
||||
grad_scaler: true
|
||||
teacher:
|
||||
backbone:
|
||||
sharding_strategy: SHARD_GRAD_OP
|
||||
mixed_precision:
|
||||
param_dtype: fp16
|
||||
reduce_dtype: fp16
|
||||
buffer_dtype: fp32
|
||||
dino_head:
|
||||
sharding_strategy: SHARD_GRAD_OP
|
||||
mixed_precision:
|
||||
param_dtype: fp16
|
||||
reduce_dtype: fp16
|
||||
buffer_dtype: fp32
|
||||
ibot_head:
|
||||
sharding_strategy: SHARD_GRAD_OP
|
||||
mixed_precision:
|
||||
param_dtype: fp16
|
||||
reduce_dtype: fp16
|
||||
buffer_dtype: fp32
|
||||
student:
|
||||
backbone:
|
||||
sharding_strategy: SHARD_GRAD_OP
|
||||
mixed_precision:
|
||||
param_dtype: fp16
|
||||
reduce_dtype: fp16
|
||||
buffer_dtype: fp32
|
||||
dino_head:
|
||||
sharding_strategy: SHARD_GRAD_OP
|
||||
mixed_precision:
|
||||
param_dtype: fp16
|
||||
reduce_dtype: fp32
|
||||
buffer_dtype: fp32
|
||||
ibot_head:
|
||||
sharding_strategy: SHARD_GRAD_OP
|
||||
mixed_precision:
|
||||
param_dtype: fp16
|
||||
reduce_dtype: fp32
|
||||
buffer_dtype: fp32
|
||||
dino:
|
||||
loss_weight: 1.0
|
||||
head_n_prototypes: 65536
|
||||
head_bottleneck_dim: 256
|
||||
head_nlayers: 3
|
||||
head_hidden_dim: 2048
|
||||
koleo_loss_weight: 0.1
|
||||
ibot:
|
||||
loss_weight: 1.0
|
||||
mask_sample_probability: 0.5
|
||||
mask_ratio_min_max:
|
||||
- 0.1
|
||||
- 0.5
|
||||
separate_head: false
|
||||
head_n_prototypes: 65536
|
||||
head_bottleneck_dim: 256
|
||||
head_nlayers: 3
|
||||
head_hidden_dim: 2048
|
||||
train:
|
||||
batch_size_per_gpu: 64
|
||||
dataset_path: ImageNet:split=TRAIN
|
||||
output_dir: .
|
||||
saveckp_freq: 20
|
||||
seed: 0
|
||||
num_workers: 10
|
||||
OFFICIAL_EPOCH_LENGTH: 1250
|
||||
cache_dataset: true
|
||||
centering: "centering" # or "sinkhorn_knopp"
|
||||
student:
|
||||
arch: vit_large
|
||||
patch_size: 16
|
||||
drop_path_rate: 0.3
|
||||
layerscale: 1.0e-05
|
||||
drop_path_uniform: true
|
||||
pretrained_weights: ''
|
||||
ffn_layer: "mlp"
|
||||
block_chunks: 0
|
||||
qkv_bias: true
|
||||
proj_bias: true
|
||||
ffn_bias: true
|
||||
teacher:
|
||||
momentum_teacher: 0.992
|
||||
final_momentum_teacher: 1
|
||||
warmup_teacher_temp: 0.04
|
||||
teacher_temp: 0.07
|
||||
warmup_teacher_temp_epochs: 30
|
||||
optim:
|
||||
epochs: 100
|
||||
weight_decay: 0.04
|
||||
weight_decay_end: 0.4
|
||||
base_lr: 0.004 # learning rate for a batch size of 1024
|
||||
lr: 0. # will be set after applying scaling rule
|
||||
warmup_epochs: 10
|
||||
min_lr: 1.0e-06
|
||||
clip_grad: 3.0
|
||||
freeze_last_layer_epochs: 1
|
||||
scaling_rule: sqrt_wrt_1024
|
||||
patch_embed_lr_mult: 0.2
|
||||
layerwise_decay: 0.9
|
||||
adamw_beta1: 0.9
|
||||
adamw_beta2: 0.999
|
||||
crops:
|
||||
global_crops_scale:
|
||||
- 0.32
|
||||
- 1.0
|
||||
local_crops_number: 8
|
||||
local_crops_scale:
|
||||
- 0.05
|
||||
- 0.32
|
||||
global_crops_size: 224
|
||||
local_crops_size: 96
|
||||
evaluation:
|
||||
eval_period_iterations: 12500
|
||||
@@ -0,0 +1,26 @@
|
||||
dino:
|
||||
head_n_prototypes: 131072
|
||||
head_bottleneck_dim: 384
|
||||
ibot:
|
||||
separate_head: true
|
||||
head_n_prototypes: 131072
|
||||
train:
|
||||
batch_size_per_gpu: 12
|
||||
dataset_path: ImageNet22k
|
||||
centering: sinkhorn_knopp
|
||||
student:
|
||||
arch: vit_giant2
|
||||
patch_size: 14
|
||||
drop_path_rate: 0.4
|
||||
ffn_layer: swiglufused
|
||||
block_chunks: 4
|
||||
teacher:
|
||||
momentum_teacher: 0.994
|
||||
optim:
|
||||
epochs: 500
|
||||
weight_decay_end: 0.2
|
||||
base_lr: 2.0e-04 # learning rate for a batch size of 1024
|
||||
warmup_epochs: 80
|
||||
layerwise_decay: 1.0
|
||||
crops:
|
||||
local_crops_size: 98
|
||||
@@ -0,0 +1,26 @@
|
||||
dino:
|
||||
head_n_prototypes: 131072
|
||||
head_bottleneck_dim: 384
|
||||
ibot:
|
||||
separate_head: true
|
||||
head_n_prototypes: 131072
|
||||
train:
|
||||
batch_size_per_gpu: 32
|
||||
dataset_path: ImageNet22k
|
||||
centering: sinkhorn_knopp
|
||||
student:
|
||||
arch: vit_large
|
||||
patch_size: 14
|
||||
drop_path_rate: 0.4
|
||||
ffn_layer: swiglufused
|
||||
block_chunks: 4
|
||||
teacher:
|
||||
momentum_teacher: 0.994
|
||||
optim:
|
||||
epochs: 500
|
||||
weight_decay_end: 0.2
|
||||
base_lr: 2.0e-04 # learning rate for a batch size of 1024
|
||||
warmup_epochs: 80
|
||||
layerwise_decay: 1.0
|
||||
crops:
|
||||
local_crops_size: 98
|
||||
@@ -0,0 +1,6 @@
|
||||
# this corresponds to the default config
|
||||
train:
|
||||
dataset_path: ImageNet:split=TRAIN
|
||||
batch_size_per_gpu: 64
|
||||
student:
|
||||
block_chunks: 4
|
||||
@@ -0,0 +1,11 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .adapters import DatasetWithEnumeratedTargets
|
||||
from .loaders import make_data_loader, make_dataset, SamplerType
|
||||
from .collate import collate_data_and_cast
|
||||
from .masking import MaskingGenerator
|
||||
from .augmentations import DataAugmentationDINO
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Tuple
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class DatasetWithEnumeratedTargets(Dataset):
    """Wrap a dataset so that every target is paired with the index of its sample.

    The wrapped dataset must support ``dataset[index]`` returning an
    ``(image, target)`` pair and ``len(dataset)``; ``get_image_data`` and
    ``get_target`` are forwarded to methods of the same name on it.
    """

    def __init__(self, dataset):
        self._dataset = dataset

    def get_image_data(self, index: int) -> bytes:
        """Return whatever the wrapped dataset's ``get_image_data(index)`` returns."""
        return self._dataset.get_image_data(index)

    def get_target(self, index: int) -> Tuple[int, Any]:
        """Return ``(index, target)`` for the sample at ``index``.

        Unlike ``__getitem__``, a ``None`` target is passed through unchanged.
        """
        target = self._dataset.get_target(index)
        return (index, target)

    def __getitem__(self, index: int) -> Tuple[Any, Tuple[int, Any]]:
        """Return ``(image, (index, target))`` for the sample at ``index``."""
        image, target = self._dataset[index]
        # Samples without a target get their own index as a stand-in target.
        target = index if target is None else target
        return image, (index, target)

    def __len__(self) -> int:
        return len(self._dataset)
|
||||
@@ -0,0 +1,119 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
from torchvision import transforms
|
||||
|
||||
from .transforms import (
|
||||
GaussianBlur,
|
||||
make_normalize_transform,
|
||||
)
|
||||
|
||||
|
||||
# Logger registered under the name "dinov2".
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class DataAugmentationDINO(object):
    """Callable that turns one image into two global crops and several local crops.

    Each crop is produced by a random resized crop + horizontal flip, followed
    by colour jittering, a crop-specific blur/solarize step and normalization.
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
    ):
        """Store the crop parameters, log them and build the transform pipelines.

        ``*_scale`` values are passed as ``scale`` and ``*_size`` values as the
        output size of ``transforms.RandomResizedCrop``; ``local_crops_number``
        is the number of local crops generated per call.
        """
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_size = global_crops_size
        self.local_crops_size = local_crops_size

        banner = "###################################"
        logger.info(banner)
        logger.info("Using data augmentation parameters:")
        for label, value in (
            ("global_crops_scale", global_crops_scale),
            ("local_crops_scale", local_crops_scale),
            ("local_crops_number", local_crops_number),
            ("global_crops_size", global_crops_size),
            ("local_crops_size", local_crops_size),
        ):
            logger.info(f"{label}: {value}")
        logger.info(banner)

        def _crop_then_flip(size, scale):
            # Random resized crop (bicubic) followed by a coin-flip horizontal flip.
            return transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        size, scale=scale, interpolation=transforms.InterpolationMode.BICUBIC
                    ),
                    transforms.RandomHorizontalFlip(p=0.5),
                ]
            )

        self.geometric_augmentation_global = _crop_then_flip(global_crops_size, global_crops_scale)
        self.geometric_augmentation_local = _crop_then_flip(local_crops_size, local_crops_scale)

        # Colour distortions, shared by all three photometric pipelines below.
        jitter = transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
            p=0.8,
        )
        color_jittering = transforms.Compose([jitter, transforms.RandomGrayscale(p=0.2)])

        # Crop-specific extras: the first global crop is always blurred, the
        # second is rarely blurred and sometimes solarized, local crops are
        # blurred half of the time.
        first_global_extra = GaussianBlur(p=1.0)
        second_global_extra = transforms.Compose(
            [
                GaussianBlur(p=0.1),
                transforms.RandomSolarize(threshold=128, p=0.2),
            ]
        )
        local_extra = GaussianBlur(p=0.5)

        # Tensor conversion followed by the project's normalization transform.
        self.normalize = transforms.Compose(
            [
                transforms.ToTensor(),
                make_normalize_transform(),
            ]
        )

        self.global_transfo1 = transforms.Compose([color_jittering, first_global_extra, self.normalize])
        self.global_transfo2 = transforms.Compose([color_jittering, second_global_extra, self.normalize])
        self.local_transfo = transforms.Compose([color_jittering, local_extra, self.normalize])

    def __call__(self, image):
        """Return a dict with the augmented crops of ``image``.

        Keys: ``"global_crops"`` and ``"global_crops_teacher"`` (two separate
        lists holding the same two global crops), ``"local_crops"`` (a list of
        ``local_crops_number`` crops) and ``"offsets"`` (an empty tuple).
        """
        # The transforms are applied in a fixed order (global 1, global 2, then
        # the local crops) so the sequence of random draws stays the same.
        first_global = self.global_transfo1(self.geometric_augmentation_global(image))
        second_global = self.global_transfo2(self.geometric_augmentation_global(image))

        local_views = []
        for _ in range(self.local_crops_number):
            local_views.append(self.local_transfo(self.geometric_augmentation_local(image)))

        return {
            "global_crops": [first_global, second_global],
            "global_crops_teacher": [first_global, second_global],
            "local_crops": local_views,
            "offsets": (),
        }
|
||||
@@ -0,0 +1,50 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import random
|
||||
|
||||
|
||||
def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
    """Collate augmented samples into stacked crops plus per-crop patch masks.

    Args:
        samples_list: sequence of samples; element ``[0]`` of each sample is a
            dict with ``"global_crops"`` and ``"local_crops"`` lists of tensors.
        mask_ratio_tuple: ``(min, max)`` bounds passed to ``torch.linspace`` to
            build one masking-ratio interval per masked crop.
        mask_probability: fraction of the stacked global crops that get a
            non-trivial mask (``int(B * mask_probability)`` of them).
        dtype: dtype the stacked crops are cast to.
        n_tokens: multiplied by a masking ratio to get the number of patches to
            mask; it is used as soon as at least one crop is masked.
        mask_generator: callable taking a patch count and returning a mask that
            ``torch.BoolTensor`` accepts; it is called once per global crop.

    Returns:
        dict with the cast crops, the flattened boolean masks, the flat indices
        of the masked patches, their per-patch weights (one over the number of
        masked patches of the owning crop), the accumulated ``upperbound`` and
        the number of masked patches as a one-element long tensor.
    """
    reference = samples_list[0][0]
    global_count = len(reference["global_crops"])
    local_count = len(reference["local_crops"])

    # Crop-major order: all samples' crop 0 first, then all samples' crop 1, ...
    stacked_global = torch.stack(
        [sample[0]["global_crops"][k] for k in range(global_count) for sample in samples_list]
    )
    stacked_local = torch.stack(
        [sample[0]["local_crops"][k] for k in range(local_count) for sample in samples_list]
    )

    batch = len(stacked_global)
    masked_count = int(batch * mask_probability)
    ratio_edges = torch.linspace(*mask_ratio_tuple, masked_count + 1)

    upperbound = 0
    masks = []
    # One mask per consecutive pair of edges: the ratio is drawn uniformly in
    # between, and the upper edge is added to the running upper bound.
    for ratio_low, ratio_high in zip(ratio_edges[:-1], ratio_edges[1:]):
        patches_to_mask = int(n_tokens * random.uniform(ratio_low, ratio_high))
        masks.append(torch.BoolTensor(mask_generator(patches_to_mask)))
        upperbound += int(n_tokens * ratio_high)
    # The remaining crops get a mask generated for zero patches.
    for _ in range(masked_count, batch):
        masks.append(torch.BoolTensor(mask_generator(0)))

    # Shuffle so that masked and unmasked crops are spread across the batch.
    random.shuffle(masks)

    collated_masks = torch.stack(masks).flatten(1)
    mask_indices_list = collated_masks.flatten().nonzero().flatten()

    per_crop_weight = 1 / collated_masks.sum(-1).clamp(min=1.0)
    masks_weight = per_crop_weight.unsqueeze(-1).expand_as(collated_masks)[collated_masks]

    n_masked_patches = torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long)

    return {
        "collated_global_crops": stacked_global.to(dtype),
        "collated_local_crops": stacked_local.to(dtype),
        "collated_masks": collated_masks,
        "mask_indices_list": mask_indices_list,
        "masks_weight": masks_weight,
        "upperbound": upperbound,
        "n_masked_patches": n_masked_patches,
    }
|
||||
@@ -0,0 +1,8 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .image_net import ImageNet
|
||||
from .image_net_22k import ImageNet22k
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Decoder:
    """Base class for decoders; subclasses provide ``decode``."""

    def decode(self) -> Any:
        """Return the decoded object. Not implemented on the base class."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class ImageDataDecoder(Decoder):
    """Decode raw encoded image bytes into an RGB PIL image."""

    def __init__(self, image_data: bytes) -> None:
        self._image_data = image_data

    def decode(self) -> Image.Image:
        """Open the stored bytes with PIL and return the image converted to RGB.

        The return annotation is the ``Image.Image`` class; ``Image`` alone is
        the PIL module.
        """
        f = BytesIO(self._image_data)
        return Image.open(f).convert(mode="RGB")
|
||||
|
||||
|
||||
class TargetDecoder(Decoder):
    """Pass-through decoder: ``decode`` returns the target it was built with."""

    def __init__(self, target: Any):
        self._target = target

    def decode(self) -> Any:
        """Return the stored target unchanged."""
        stored = self._target
        return stored
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Tuple
|
||||
|
||||
from torchvision.datasets import VisionDataset
|
||||
|
||||
from .decoders import TargetDecoder, ImageDataDecoder
|
||||
|
||||
|
||||
class ExtendedVisionDataset(VisionDataset):
    """``VisionDataset`` whose samples are built from raw image bytes and a target.

    Subclasses implement ``get_image_data``, ``get_target`` and ``__len__``;
    ``__getitem__`` decodes both and applies ``self.transforms`` when set.
    """

    def __init__(self, *args, **kwargs) -> None:
        # All arguments are forwarded untouched to VisionDataset.
        super().__init__(*args, **kwargs)  # type: ignore

    def get_image_data(self, index: int) -> bytes:
        """Return the encoded image bytes of sample ``index`` (subclass hook)."""
        raise NotImplementedError()

    def get_target(self, index: int) -> Any:
        """Return the target of sample ``index`` (subclass hook)."""
        raise NotImplementedError()

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, target)`` pair for sample ``index``.

        Raises:
            RuntimeError: if reading or decoding the image raises any
                ``Exception``; the original error is chained as the cause.
        """
        try:
            image = ImageDataDecoder(self.get_image_data(index)).decode()
        except Exception as error:
            raise RuntimeError(f"can not read image for sample {index}") from error

        target = TargetDecoder(self.get_target(index)).decode()

        if self.transforms is None:
            return image, target

        # The joint transform takes and returns the (image, target) pair.
        image, target = self.transforms(image, target)
        return image, target

    def __len__(self) -> int:
        """Return the number of samples (subclass hook)."""
        raise NotImplementedError()
|
||||
@@ -0,0 +1,291 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import csv
|
||||
from enum import Enum
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .extended import ExtendedVisionDataset
|
||||
|
||||
|
||||
# Logger registered under the name "dinov2".
logger = logging.getLogger("dinov2")
|
||||
# Type of a sample target: an integer (ImageNet.get_target returns int(class_index)).
_Target = int
|
||||
|
||||
|
||||
class _Split(Enum):
    """Dataset split, with helpers to build and parse image paths for it."""

    TRAIN = "train"
    VAL = "val"
    TEST = "test"  # NOTE: torchvision does not support the test split

    @property
    def length(self) -> int:
        """Expected number of samples in this split."""
        if self is _Split.TRAIN:
            return 1_281_167
        if self is _Split.VAL:
            return 50_000
        return 100_000  # _Split.TEST

    def get_dirname(self, class_id: Optional[str] = None) -> str:
        """Return the split directory, extended with ``class_id`` when one is given."""
        if class_id is None:
            return self.value
        return os.path.join(self.value, class_id)

    def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str:
        """Return the relative path of the image numbered ``actual_index``.

        TRAIN files are named ``<class_id>_<index>.JPEG``; VAL and TEST files
        are named ``ILSVRC2012_<split>_<index zero-padded to 8 digits>.JPEG``.
        """
        if self == _Split.TRAIN:
            stem = f"{class_id}_{actual_index}"
        else:  # VAL or TEST
            stem = f"ILSVRC2012_{self.value}_{actual_index:08d}"
        return os.path.join(self.get_dirname(class_id), stem + ".JPEG")

    def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]:
        """Split a relative image path into ``(class_id, actual_index)``.

        The class id is the name of the directory holding the file; the index
        is the integer after the last underscore of the file stem. Must not be
        called on the TEST split.
        """
        assert self != _Split.TEST
        parent_dir, file_name = os.path.split(image_relpath)
        stem = os.path.splitext(file_name)[0]
        return os.path.basename(parent_dir), int(stem.rsplit("_", 1)[-1])
|
||||
|
||||
|
||||
class ImageNet(ExtendedVisionDataset):
    """ImageNet dataset driven by precomputed ``.npy`` metadata ("extra") files.

    Image files are read from ``root`` using paths rebuilt from an entries
    table; the entries, class-id and class-name tables live in the ``extra``
    directory and are memory-mapped on first use. ``dump_extra`` generates
    those tables from the image tree under ``root`` and its ``labels.txt``.
    """

    Target = Union[_Target]
    Split = Union[_Split]

    def __init__(
        self,
        *,
        split: "ImageNet.Split",
        root: str,
        extra: str,
        transforms: Optional[Callable] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """Record the split and the metadata directory; nothing is loaded yet.

        Args:
            split: which split this dataset serves.
            root: directory containing the images (and ``labels.txt``).
            extra: directory holding the ``.npy`` metadata tables.
            transforms, transform, target_transform: forwarded to the base class.
        """
        super().__init__(root, transforms, transform, target_transform)
        self._extra_root = extra
        self._split = split

        # Metadata tables, filled in lazily by the _get_* accessors.
        self._entries = None
        self._class_ids = None
        self._class_names = None

    @property
    def split(self) -> "ImageNet.Split":
        """The split this dataset was created for."""
        return self._split

    def _get_extra_full_path(self, extra_path: str) -> str:
        """Return ``extra_path`` joined onto the metadata directory."""
        return os.path.join(self._extra_root, extra_path)

    def _load_extra(self, extra_path: str) -> np.ndarray:
        """Memory-map (read-only) the ``.npy`` file ``extra_path`` of the metadata directory."""
        extra_full_path = self._get_extra_full_path(extra_path)
        return np.load(extra_full_path, mmap_mode="r")

    def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
        """Save ``extra_array`` as ``extra_path``, creating the metadata directory if needed."""
        extra_full_path = self._get_extra_full_path(extra_path)
        os.makedirs(self._extra_root, exist_ok=True)
        np.save(extra_full_path, extra_array)

    @property
    def _entries_path(self) -> str:
        """File name of the entries table for this split."""
        return f"entries-{self._split.value.upper()}.npy"

    @property
    def _class_ids_path(self) -> str:
        """File name of the class-ids table for this split."""
        return f"class-ids-{self._split.value.upper()}.npy"

    @property
    def _class_names_path(self) -> str:
        """File name of the class-names table for this split."""
        return f"class-names-{self._split.value.upper()}.npy"

    def _get_entries(self) -> np.ndarray:
        """Return the entries table, loading it on first call."""
        if self._entries is None:
            self._entries = self._load_extra(self._entries_path)
        assert self._entries is not None
        return self._entries

    def _get_class_ids(self) -> np.ndarray:
        """Return the class-ids table, loading it on first call.

        Raises:
            AssertionError: on the TEST split, which has no class IDs.
        """
        if self._split == _Split.TEST:
            # Raised explicitly (not ``assert False``) so the guard survives ``python -O``.
            raise AssertionError("Class IDs are not available in TEST split")
        if self._class_ids is None:
            self._class_ids = self._load_extra(self._class_ids_path)
        assert self._class_ids is not None
        return self._class_ids

    def _get_class_names(self) -> np.ndarray:
        """Return the class-names table, loading it on first call.

        Raises:
            AssertionError: on the TEST split, which has no class names.
        """
        if self._split == _Split.TEST:
            # Raised explicitly (not ``assert False``) so the guard survives ``python -O``.
            raise AssertionError("Class names are not available in TEST split")
        if self._class_names is None:
            self._class_names = self._load_extra(self._class_names_path)
        assert self._class_names is not None
        return self._class_names

    def find_class_id(self, class_index: int) -> str:
        """Return the class ID stored at ``class_index`` of the class-ids table."""
        class_ids = self._get_class_ids()
        return str(class_ids[class_index])

    def find_class_name(self, class_index: int) -> str:
        """Return the class name stored at ``class_index`` of the class-names table."""
        class_names = self._get_class_names()
        return str(class_names[class_index])

    def get_image_data(self, index: int) -> bytes:
        """Return the raw bytes of the image file of sample ``index``."""
        entries = self._get_entries()
        actual_index = entries[index]["actual_index"]

        class_id = self.get_class_id(index)

        image_relpath = self.split.get_image_relpath(actual_index, class_id)
        image_full_path = os.path.join(self.root, image_relpath)
        with open(image_full_path, mode="rb") as f:
            image_data = f.read()
        return image_data

    def get_target(self, index: int) -> Optional[Target]:
        """Return the class index of sample ``index`` as an int, or ``None`` on the TEST split."""
        entries = self._get_entries()
        class_index = entries[index]["class_index"]
        return None if self.split == _Split.TEST else int(class_index)

    def get_targets(self) -> Optional[np.ndarray]:
        """Return the ``class_index`` column of the entries table, or ``None`` on the TEST split."""
        entries = self._get_entries()
        return None if self.split == _Split.TEST else entries["class_index"]

    def get_class_id(self, index: int) -> Optional[str]:
        """Return the class ID of sample ``index``, or ``None`` on the TEST split."""
        entries = self._get_entries()
        class_id = entries[index]["class_id"]
        return None if self.split == _Split.TEST else str(class_id)

    def get_class_name(self, index: int) -> Optional[str]:
        """Return the class name of sample ``index``, or ``None`` on the TEST split."""
        entries = self._get_entries()
        class_name = entries[index]["class_name"]
        return None if self.split == _Split.TEST else str(class_name)

    def __len__(self) -> int:
        """Return the number of entries, checked against the split's expected length."""
        entries = self._get_entries()
        assert len(entries) == self.split.length
        return len(entries)

    def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
        """Read ``(class_id, class_name)`` rows from the CSV file ``labels_path`` under ``root``.

        Raises:
            RuntimeError: if the file cannot be opened or read (``OSError``).
        """
        labels_full_path = os.path.join(self.root, labels_path)
        labels = []

        try:
            with open(labels_full_path, "r") as f:
                reader = csv.reader(f)
                for row in reader:
                    class_id, class_name = row
                    labels.append((class_id, class_name))
        except OSError as e:
            raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e

        return labels

    def _dump_entries(self) -> None:
        """Build the entries table for this split and save it to the metadata directory.

        Each entry holds ``actual_index``, ``class_index``, ``class_id`` and
        ``class_name``. For TRAIN/VAL they come from a torchvision
        ``ImageFolder`` scan of the split directory combined with
        ``labels.txt``; for TEST the entries are numbered 1..N and carry a
        sentinel class index with empty id and name.
        """
        split = self.split
        if split == ImageNet.Split.TEST:
            dataset = None
            sample_count = split.length
            max_class_id_length, max_class_name_length = 0, 0
        else:
            labels_path = "labels.txt"
            logger.info(f'loading labels from "{labels_path}"')
            labels = self._load_labels(labels_path)

            # NOTE: Using torchvision ImageFolder for consistency
            from torchvision.datasets import ImageFolder

            dataset_root = os.path.join(self.root, split.get_dirname())
            dataset = ImageFolder(dataset_root)
            sample_count = len(dataset)
            # String field widths are sized to the longest id / name in use.
            max_class_id_length, max_class_name_length = -1, -1
            for sample in dataset.samples:
                _, class_index = sample
                class_id, class_name = labels[class_index]
                max_class_id_length = max(len(class_id), max_class_id_length)
                max_class_name_length = max(len(class_name), max_class_name_length)

        dtype = np.dtype(
            [
                ("actual_index", "<u4"),
                ("class_index", "<u4"),
                ("class_id", f"U{max_class_id_length}"),
                ("class_name", f"U{max_class_name_length}"),
            ]
        )
        entries_array = np.empty(sample_count, dtype=dtype)

        if split == ImageNet.Split.TEST:
            # TEST has no labels: store the largest uint32 as a sentinel. It is
            # built from np.iinfo because np.uint32(-1) raises OverflowError on
            # NumPy 2; the stored value (4294967295) is the same.
            class_index = np.uint32(np.iinfo(np.uint32).max)
            class_id, class_name = "", ""

            old_percent = -1
            for index in range(sample_count):
                percent = 100 * (index + 1) // sample_count
                if percent > old_percent:
                    logger.info(f"creating entries: {percent}%")
                    old_percent = percent

                actual_index = index + 1
                entries_array[index] = (actual_index, class_index, class_id, class_name)
        else:
            class_names = {class_id: class_name for class_id, class_name in labels}

            assert dataset
            old_percent = -1
            for index in range(sample_count):
                percent = 100 * (index + 1) // sample_count
                if percent > old_percent:
                    logger.info(f"creating entries: {percent}%")
                    old_percent = percent

                image_full_path, class_index = dataset.samples[index]
                image_relpath = os.path.relpath(image_full_path, self.root)
                class_id, actual_index = split.parse_image_relpath(image_relpath)
                class_name = class_names[class_id]
                entries_array[index] = (actual_index, class_index, class_id, class_name)

        logger.info(f'saving entries to "{self._entries_path}"')
        self._save_extra(entries_array, self._entries_path)

    def _dump_class_ids_and_names(self) -> None:
        """Derive the class-ids and class-names tables from the saved entries table.

        Does nothing on the TEST split. Both tables are indexed by class index
        and have ``max(class_index) + 1`` rows.
        """
        split = self.split
        if split == ImageNet.Split.TEST:
            return

        entries_array = self._load_extra(self._entries_path)

        # First pass: find the table size and the string widths required.
        max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1
        for entry in entries_array:
            class_index, class_id, class_name = (
                entry["class_index"],
                entry["class_id"],
                entry["class_name"],
            )
            max_class_index = max(int(class_index), max_class_index)
            max_class_id_length = max(len(str(class_id)), max_class_id_length)
            max_class_name_length = max(len(str(class_name)), max_class_name_length)

        class_count = max_class_index + 1
        class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}")
        class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}")
        # Second pass: fill both tables, one slot per class index.
        for entry in entries_array:
            class_index, class_id, class_name = (
                entry["class_index"],
                entry["class_id"],
                entry["class_name"],
            )
            class_ids_array[class_index] = class_id
            class_names_array[class_index] = class_name

        logger.info(f'saving class IDs to "{self._class_ids_path}"')
        self._save_extra(class_ids_array, self._class_ids_path)

        logger.info(f'saving class names to "{self._class_names_path}"')
        self._save_extra(class_names_array, self._class_names_path)

    def dump_extra(self) -> None:
        """Generate and save all metadata tables: entries first, then class ids and names."""
        self._dump_entries()
        self._dump_class_ids_and_names()
|
||||
@@ -0,0 +1,303 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from gzip import GzipFile
|
||||
from io import BytesIO
|
||||
from mmap import ACCESS_READ, mmap
|
||||
import os
|
||||
from typing import Any, Callable, List, Optional, Set, Tuple
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .extended import ExtendedVisionDataset
|
||||
|
||||
|
||||
# Alias for the label type (a plain int); re-exported as ImageNet22k.Labels.
_Labels = int
|
||||
|
||||
# Default for ImageNet22k's ``mmap_cache_size``, i.e. the ``maxsize`` of the LRU cache of tarball mmaps.
_DEFAULT_MMAP_CACHE_SIZE = 16  # Warning: This can exhaust file descriptors
|
||||
|
||||
|
||||
@dataclass
class _ClassEntry:
    """One line of a per-class blocks log: a block offset and an optional file name."""

    # Block number parsed from the log line.
    block_offset: int
    # File name on that line; None for the "** Block of NULs **" marker.
    maybe_filename: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class _Entry:
    """One image inside a class tarball: its class, byte range and file name."""

    # Position of the class in the sorted list of class ids.
    class_index: int  # noqa: E701
    # Byte offsets (512 * block number) delimiting the image's blocks in the tarball.
    start_offset: int
    end_offset: int
    # Name of the image file inside the tarball.
    filename: str
|
||||
|
||||
|
||||
class _Split(Enum):
    """Dataset split, with its expected size and the name of its entries file."""

    TRAIN = "train"
    VAL = "val"

    @property
    def length(self) -> int:
        """Expected number of samples in this split."""
        if self is _Split.TRAIN:
            return 11_797_647
        return 561_050  # _Split.VAL

    def entries_path(self):
        """Return the name of the text file listing this split's entries."""
        return "imagenet21kp_{}.txt".format(self.value)
|
||||
|
||||
|
||||
def _get_tarball_path(class_id: str) -> str:
    """Return the tarball file name for ``class_id`` (the id followed by ``.tar``)."""
    return "{}.tar".format(class_id)
|
||||
|
||||
|
||||
def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
    """Build a cached opener mapping a class id to a read-only mmap of its tarball.

    Args:
        tarballs_root: directory containing the ``<class_id>.tar`` files.
        mmap_cache_size: ``maxsize`` of the LRU cache holding the mappings.

    Returns:
        A function ``class_id -> mmap``. Repeated calls with the same class id
        return the cached mapping until it is evicted from the LRU cache.
    """

    @lru_cache(maxsize=mmap_cache_size)
    def _mmap_tarball(class_id: str) -> mmap:
        tarball_path = _get_tarball_path(class_id)
        tarball_full_path = os.path.join(tarballs_root, tarball_path)
        # Only the descriptor is used, and the content is binary: open in "rb"
        # so no text decoding layer is created. The mapping is built inside the
        # ``with`` block, so the file object itself is closed on return.
        with open(tarball_full_path, "rb") as f:
            return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)

    return _mmap_tarball
|
||||
|
||||
|
||||
class ImageNet22k(ExtendedVisionDataset):
|
||||
_GZIPPED_INDICES: Set[int] = {
|
||||
841_545,
|
||||
1_304_131,
|
||||
2_437_921,
|
||||
2_672_079,
|
||||
2_795_676,
|
||||
2_969_786,
|
||||
6_902_965,
|
||||
6_903_550,
|
||||
6_903_628,
|
||||
7_432_557,
|
||||
7_432_589,
|
||||
7_813_809,
|
||||
8_329_633,
|
||||
10_296_990,
|
||||
10_417_652,
|
||||
10_492_265,
|
||||
10_598_078,
|
||||
10_782_398,
|
||||
10_902_612,
|
||||
11_203_736,
|
||||
11_342_890,
|
||||
11_397_596,
|
||||
11_589_762,
|
||||
11_705_103,
|
||||
12_936_875,
|
||||
13_289_782,
|
||||
}
|
||||
Labels = _Labels
|
||||
|
||||
def __init__(
    self,
    *,
    root: str,
    extra: str,
    transforms: Optional[Callable] = None,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
) -> None:
    """Load the entries and class-id tables and set up the tarball mmap cache.

    Args:
        root: dataset root, forwarded to the base class.
        extra: directory holding the ``.npy`` metadata tables.
        transforms, transform, target_transform: forwarded to the base class.
        mmap_cache_size: maximum number of tarball mappings kept cached.
    """
    super().__init__(root, transforms, transform, target_transform)
    # Must be set before _load_extra is called: it reads self._extra_root.
    self._extra_root = extra

    self._entries = self._load_extra(self._get_entries_path(root))
    self._class_ids = self._load_extra(self._get_class_ids_path(root))

    self._gzipped_indices = ImageNet22k._GZIPPED_INDICES
    self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)
|
||||
|
||||
def _get_entries_path(self, root: Optional[str] = None) -> str:
    """Return the file name of the entries table; ``root`` is accepted but unused."""
    entries_file = "entries.npy"
    return entries_file
|
||||
|
||||
def _get_class_ids_path(self, root: Optional[str] = None) -> str:
    """Return the file name of the class-ids table; ``root`` is accepted but unused."""
    class_ids_file = "class-ids.npy"
    return class_ids_file
|
||||
|
||||
def _find_class_ids(self, path: str) -> List[str]:
    """Return the sorted stems of all ``*.tar`` entries found directly in ``path``.

    Only the final extension is checked, so ``x.tar.gz`` is not a match.
    """
    with os.scandir(path) as listing:
        name_parts = [os.path.splitext(item.name) for item in listing]
    stems = [stem for stem, suffix in name_parts if suffix == ".tar"]
    stems.sort()
    return stems
|
||||
|
||||
def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]:
|
||||
root = self.get_root(root)
|
||||
entries: List[_Entry] = []
|
||||
class_ids = self._find_class_ids(root)
|
||||
|
||||
for class_index, class_id in enumerate(class_ids):
|
||||
path = os.path.join(root, "blocks", f"{class_id}.log")
|
||||
class_entries = []
|
||||
|
||||
try:
|
||||
with open(path) as f:
|
||||
for line in f:
|
||||
line = line.rstrip()
|
||||
block, filename = line.split(":")
|
||||
block_offset = int(block[6:])
|
||||
filename = filename[1:]
|
||||
|
||||
maybe_filename = None
|
||||
if filename != "** Block of NULs **":
|
||||
maybe_filename = filename
|
||||
_, ext = os.path.splitext(filename)
|
||||
# assert ext == ".JPEG"
|
||||
|
||||
class_entry = _ClassEntry(block_offset, maybe_filename)
|
||||
class_entries.append(class_entry)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f'can not read blocks file "{path}"') from e
|
||||
|
||||
assert class_entries[-1].maybe_filename is None
|
||||
|
||||
for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]):
|
||||
assert class_entry1.block_offset <= class_entry2.block_offset
|
||||
start_offset = 512 * class_entry1.block_offset
|
||||
end_offset = 512 * class_entry2.block_offset
|
||||
assert class_entry1.maybe_filename is not None
|
||||
filename = class_entry1.maybe_filename
|
||||
entry = _Entry(class_index, start_offset, end_offset, filename)
|
||||
# Skip invalid image files (PIL throws UnidentifiedImageError)
|
||||
if filename == "n06470073_47249.JPEG":
|
||||
continue
|
||||
entries.append(entry)
|
||||
|
||||
return entries, class_ids
|
||||
|
||||
def _load_extra(self, extra_path: str) -> np.ndarray:
|
||||
extra_root = self._extra_root
|
||||
extra_full_path = os.path.join(extra_root, extra_path)
|
||||
return np.load(extra_full_path, mmap_mode="r")
|
||||
|
||||
def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
|
||||
extra_root = self._extra_root
|
||||
extra_full_path = os.path.join(extra_root, extra_path)
|
||||
os.makedirs(extra_root, exist_ok=True)
|
||||
np.save(extra_full_path, extra_array)
|
||||
|
||||
@property
|
||||
def _tarballs_root(self) -> str:
|
||||
return self.root
|
||||
|
||||
def find_class_id(self, class_index: int) -> str:
|
||||
return str(self._class_ids[class_index])
|
||||
|
||||
def get_image_data(self, index: int) -> bytes:
|
||||
entry = self._entries[index]
|
||||
class_id = entry["class_id"]
|
||||
class_mmap = self._mmap_tarball(class_id)
|
||||
|
||||
start_offset, end_offset = entry["start_offset"], entry["end_offset"]
|
||||
try:
|
||||
mapped_data = class_mmap[start_offset:end_offset]
|
||||
data = mapped_data[512:] # Skip entry header block
|
||||
|
||||
if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B):
|
||||
assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}"
|
||||
with GzipFile(fileobj=BytesIO(data)) as g:
|
||||
data = g.read()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e
|
||||
|
||||
return data
|
||||
|
||||
def get_target(self, index: int) -> Any:
|
||||
return int(self._entries[index]["class_index"])
|
||||
|
||||
def get_targets(self) -> np.ndarray:
|
||||
return self._entries["class_index"]
|
||||
|
||||
def get_class_id(self, index: int) -> str:
|
||||
return str(self._entries[index]["class_id"])
|
||||
|
||||
def get_class_ids(self) -> np.ndarray:
|
||||
return self._entries["class_id"]
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
return super().__getitem__(index)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._entries)
|
||||
|
||||
def _dump_entries(self, *args, **kwargs) -> None:
|
||||
entries, class_ids = self._load_entries_class_ids(*args, **kwargs)
|
||||
|
||||
max_class_id_length, max_filename_length, max_class_index = -1, -1, -1
|
||||
for entry in entries:
|
||||
class_id = class_ids[entry.class_index]
|
||||
max_class_index = max(entry.class_index, max_class_index)
|
||||
max_class_id_length = max(len(class_id), max_class_id_length)
|
||||
max_filename_length = max(len(entry.filename), max_filename_length)
|
||||
|
||||
dtype = np.dtype(
|
||||
[
|
||||
("class_index", "<u4"),
|
||||
("class_id", f"U{max_class_id_length}"),
|
||||
("start_offset", "<u4"),
|
||||
("end_offset", "<u4"),
|
||||
("filename", f"U{max_filename_length}"),
|
||||
]
|
||||
)
|
||||
sample_count = len(entries)
|
||||
entries_array = np.empty(sample_count, dtype=dtype)
|
||||
for i, entry in enumerate(entries):
|
||||
class_index = entry.class_index
|
||||
class_id = class_ids[class_index]
|
||||
start_offset = entry.start_offset
|
||||
end_offset = entry.end_offset
|
||||
filename = entry.filename
|
||||
entries_array[i] = (
|
||||
class_index,
|
||||
class_id,
|
||||
start_offset,
|
||||
end_offset,
|
||||
filename,
|
||||
)
|
||||
|
||||
entries_path = self._get_entries_path(*args, **kwargs)
|
||||
self._save_extra(entries_array, entries_path)
|
||||
|
||||
def _dump_class_ids(self, *args, **kwargs) -> None:
|
||||
entries_path = self._get_entries_path(*args, **kwargs)
|
||||
entries_array = self._load_extra(entries_path)
|
||||
|
||||
max_class_id_length, max_class_index = -1, -1
|
||||
for entry in entries_array:
|
||||
class_index, class_id = entry["class_index"], entry["class_id"]
|
||||
max_class_index = max(int(class_index), max_class_index)
|
||||
max_class_id_length = max(len(str(class_id)), max_class_id_length)
|
||||
|
||||
class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}")
|
||||
for entry in entries_array:
|
||||
class_index, class_id = entry["class_index"], entry["class_id"]
|
||||
class_ids_array[class_index] = class_id
|
||||
class_ids_path = self._get_class_ids_path(*args, **kwargs)
|
||||
self._save_extra(class_ids_array, class_ids_path)
|
||||
|
||||
def _dump_extra(self, *args, **kwargs) -> None:
|
||||
self._dump_entries(*args, *kwargs)
|
||||
self._dump_class_ids(*args, *kwargs)
|
||||
|
||||
def dump_extra(self, root: Optional[str] = None) -> None:
|
||||
return self._dump_extra(root)
|
||||
@@ -0,0 +1,223 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from .datasets import ImageNet, ImageNet22k
|
||||
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class SamplerType(Enum):
    """Identifiers of the sampling strategies understood by ``_make_sampler``."""

    DISTRIBUTED = 0
    EPOCH = 1
    INFINITE = 2
    SHARDED_INFINITE = 3
    # Same sampler as SHARDED_INFINITE, but selects the newer slice-shuffling function.
    SHARDED_INFINITE_NEW = 4
|
||||
|
||||
|
||||
def _make_bool_str(b: bool) -> str:
|
||||
return "yes" if b else "no"
|
||||
|
||||
|
||||
def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
|
||||
def transform(sample):
|
||||
image, target = sample
|
||||
if image_transform is not None:
|
||||
image = image_transform(image)
|
||||
if target_transform is not None:
|
||||
target = target_transform(target)
|
||||
return image, target
|
||||
|
||||
return transform
|
||||
|
||||
|
||||
def _parse_dataset_str(dataset_str: str):
    """Parse a dataset description string such as ``ImageNet:split=TRAIN``.

    The string is a dataset name followed by ``:``-separated ``key=value``
    tokens, where ``key`` is one of ``root``, ``extra`` or ``split``.

    Args:
        dataset_str: The dataset description.

    Returns:
        A tuple ``(dataset_class, kwargs)``. For ``ImageNet`` the ``split``
        value, when present, is converted to the matching ``ImageNet.Split`` member.

    Raises:
        ValueError: If the dataset name is not supported or a token has no ``=``.
    """
    name, *tokens = dataset_str.split(":")
    kwargs = {}

    for token in tokens:
        # Split on the first "=" only, so that values (e.g. paths) may contain "=".
        key, value = token.split("=", 1)
        assert key in ("root", "extra", "split")
        kwargs[key] = value

    if name == "ImageNet":
        class_ = ImageNet
        if "split" in kwargs:
            kwargs["split"] = ImageNet.Split[kwargs["split"]]
    elif name == "ImageNet22k":
        class_ = ImageNet22k
    else:
        raise ValueError(f'Unsupported dataset "{name}"')

    return class_, kwargs
|
||||
|
||||
|
||||
def make_dataset(
    *,
    dataset_str: str,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
):
    """
    Instantiate the dataset described by ``dataset_str``.

    Args:
        dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
        transform: A transform to apply to images.
        target_transform: A transform to apply to targets.

    Returns:
        The created dataset.
    """
    logger.info(f'using dataset: "{dataset_str}"')

    dataset_class, dataset_kwargs = _parse_dataset_str(dataset_str)
    dataset = dataset_class(transform=transform, target_transform=target_transform, **dataset_kwargs)

    logger.info(f"# of dataset samples: {len(dataset):,d}")

    # Aggregated datasets do not expose (yet) these attributes, so add them.
    fallback_attributes = (
        ("transform", transform),
        ("target_transform", target_transform),
    )
    for attribute_name, attribute_value in fallback_attributes:
        if not hasattr(dataset, attribute_name):
            setattr(dataset, attribute_name, attribute_value)

    return dataset
|
||||
|
||||
|
||||
def _make_sampler(
    *,
    dataset,
    type: Optional[SamplerType] = None,
    shuffle: bool = False,
    seed: int = 0,
    size: int = -1,
    advance: int = 0,
) -> Optional[Sampler]:
    """Create the sampler matching ``type`` for ``dataset``.

    Args:
        dataset: The dataset to sample from; its length is the sample count.
        type: The requested sampler type, or ``None`` for no sampler.
        shuffle: Whether to shuffle samples.
        seed: The random seed forwarded to the sampler.
        size: Samples per epoch (EPOCH only); must not be positive otherwise.
        advance: Number of samples to skip (infinite samplers only).

    Returns:
        The sampler, or ``None`` when ``type`` matches no known sampler type.

    Raises:
        ValueError: If ``size`` or ``advance`` is set for a sampler that does
            not accept it.
        NotImplementedError: If ``advance`` is positive for the EPOCH sampler.
    """
    sample_count = len(dataset)

    if type == SamplerType.INFINITE:
        logger.info("sampler: infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        return InfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
        )

    if type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
        logger.info("sampler: sharded infinite")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        # TODO: Remove support for old shuffling
        return ShardedInfiniteSampler(
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
            advance=advance,
            use_new_shuffle_tensor_slice=(type == SamplerType.SHARDED_INFINITE_NEW),
        )

    if type == SamplerType.EPOCH:
        logger.info("sampler: epoch")
        if advance > 0:
            raise NotImplementedError("sampler advance > 0 is not supported")
        # A non-positive size means "one pass over the whole dataset".
        epoch_size = size if size > 0 else sample_count
        logger.info(f"# of samples / epoch: {epoch_size:,d}")
        return EpochSampler(
            size=epoch_size,
            sample_count=sample_count,
            shuffle=shuffle,
            seed=seed,
        )

    if type == SamplerType.DISTRIBUTED:
        logger.info("sampler: distributed")
        if size > 0:
            raise ValueError("sampler size > 0 is invalid")
        if advance > 0:
            raise ValueError("sampler advance > 0 is invalid")
        return torch.utils.data.DistributedSampler(
            dataset=dataset,
            shuffle=shuffle,
            seed=seed,
            drop_last=False,
        )

    logger.info("sampler: none")
    return None
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def make_data_loader(
    *,
    dataset,
    batch_size: int,
    num_workers: int,
    shuffle: bool = True,
    seed: int = 0,
    sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
    sampler_size: int = -1,
    sampler_advance: int = 0,
    drop_last: bool = True,
    persistent_workers: bool = False,
    collate_fn: Optional[Callable[[List[T]], Any]] = None,
):
    """
    Build a PyTorch data loader, with pinned memory, around ``dataset``.

    Args:
        dataset: A dataset (third party, LaViDa or WebDataset).
        batch_size: The size of batches to generate.
        num_workers: The number of workers to use.
        shuffle: Whether to shuffle samples.
        seed: The random seed to use.
        sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
        sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
        sampler_advance: How many samples to skip (when applicable).
        drop_last: Whether the last non-full batch of data should be dropped.
        persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once.
        collate_fn: Function that performs batch collation

    Returns:
        The created ``torch.utils.data.DataLoader``.
    """
    sampler = _make_sampler(
        dataset=dataset,
        type=sampler_type,
        shuffle=shuffle,
        seed=seed,
        size=sampler_size,
        advance=sampler_advance,
    )

    logger.info("using PyTorch data loader")
    loader_options = dict(
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=drop_last,
        persistent_workers=persistent_workers,
        collate_fn=collate_fn,
    )
    data_loader = torch.utils.data.DataLoader(dataset, **loader_options)

    try:
        batch_count = len(data_loader)
        logger.info(f"# of batches: {batch_count:,d}")
    except TypeError:  # data loader has no length
        logger.info("infinite data loader")
    return data_loader
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MaskingGenerator:
    """Generate boolean patch masks made of random rectangular regions."""

    def __init__(
        self,
        input_size,
        num_masking_patches=None,
        min_num_patches=4,
        max_num_patches=None,
        min_aspect=0.3,
        max_aspect=None,
    ):
        """
        Args:
            input_size: Grid size in patches, either ``(height, width)`` or a
                single value used for both dimensions.
            num_masking_patches: Stored on the instance and used as the default
                for ``max_num_patches``.
            min_num_patches: Lower bound of the area sampled for one rectangle.
            max_num_patches: Upper bound of the patches added by one rectangle;
                ``None`` (with ``num_masking_patches`` also ``None``) means no cap.
            min_aspect: Minimum rectangle aspect ratio.
            max_aspect: Maximum aspect ratio; defaults to ``1 / min_aspect``.
        """
        if not isinstance(input_size, tuple):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches

        max_aspect = max_aspect or 1 / min_aspect
        # The aspect ratio is sampled uniformly in log space.
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def __repr__(self):
        def _count_str(value):
            # The patch counts may legitimately be None (unset); "%d" would raise.
            return "None" if value is None else "%d" % value

        repr_str = "Generator(%d, %d -> [%d ~ %s], max = %s, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            _count_str(self.max_num_patches),
            _count_str(self.num_masking_patches),
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )
        return repr_str

    def get_shape(self):
        """Return the mask shape as ``(height, width)``."""
        return self.height, self.width

    def _mask(self, mask, max_mask_patches):
        """Try (up to 10 times) to mask one random rectangle of ``mask`` in place.

        Args:
            mask: The 2D mask to update; masked cells are truthy.
            max_mask_patches: Maximum number of cells the rectangle may newly mask.

        Returns:
            The number of newly masked cells (0 if every attempt failed).
        """
        delta = 0
        for _ in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < self.width and h < self.height:
                top = random.randint(0, self.height - h)
                left = random.randint(0, self.width - w)

                # Basic slicing yields a view, so writing to it updates ``mask``.
                region = mask[top : top + h, left : left + w]
                newly_masked = h * w - int(region.sum())
                # Overlap
                if 0 < newly_masked <= max_mask_patches:
                    region[...] = 1
                    delta += newly_masked

            if delta > 0:
                break
        return delta

    def __call__(self, num_masking_patches=0):
        """Return a boolean mask with at most ``num_masking_patches`` masked cells.

        Rectangles are added until the target is reached or an attempt fails
        to add any cell.
        """
        mask = np.zeros(shape=self.get_shape(), dtype=bool)
        mask_count = 0
        while mask_count < num_masking_patches:
            max_mask_patches = num_masking_patches - mask_count
            # No per-rectangle cap was configured when max_num_patches is None.
            if self.max_num_patches is not None:
                max_mask_patches = min(max_mask_patches, self.max_num_patches)

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                break
            mask_count += delta

        return mask
|
||||
@@ -0,0 +1,230 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
import dinov2.distributed as distributed
|
||||
|
||||
|
||||
class EpochSampler(Sampler):
    """Sampler yielding a strided share of a fixed-size epoch of indices."""

    def __init__(
        self,
        *,
        size: int,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
    ):
        """
        Args:
            size: Number of indices in one epoch (across all strides).
            sample_count: Number of samples in the dataset.
            shuffle: Whether to draw the epoch indices at random.
            seed: Base random seed used when shuffling.
            start: Offset of the first index taken; defaults to the global rank.
            step: Stride between taken indices; defaults to the global size.
        """
        if start is None:
            start = distributed.get_global_rank()
        if step is None:
            step = distributed.get_global_size()

        self._size = size
        self._sample_count = sample_count
        self._shuffle = shuffle
        self._seed = seed
        self._start = start
        self._step = step
        self._epoch = 0

    def __iter__(self):
        # Repeat the index range often enough to cover ``size`` entries.
        repeats = (self._size + self._sample_count - 1) // self._sample_count
        pool = np.tile(np.arange(self._sample_count), repeats)

        if not self._shuffle:
            epoch_indices = pool[: self._size]
        else:
            # The seed depends on the current epoch (see set_epoch).
            epoch_seed = self._epoch if self._seed == 0 else self._seed * self._epoch
            rng = np.random.default_rng(epoch_seed)
            epoch_indices = rng.choice(pool, self._size, replace=False)

        yield from itertools.islice(epoch_indices, self._start, None, self._step)

    def __len__(self):
        remaining = self._size - self._start
        return (remaining + self._step - 1) // self._step

    def set_epoch(self, epoch):
        """Set the epoch number used to derive the shuffling seed."""
        self._epoch = epoch
|
||||
|
||||
|
||||
def _get_numpy_dtype(size: int) -> Any:
|
||||
return np.int32 if size <= 2**31 else np.int64
|
||||
|
||||
|
||||
def _get_torch_dtype(size: int) -> Any:
|
||||
return torch.int32 if size <= 2**31 else torch.int64
|
||||
|
||||
|
||||
def _generate_randperm_indices(*, size: int, generator: torch.Generator):
    """Lazily yield the elements of a random permutation of ``range(size)``.

    One random number is drawn from ``generator`` per yielded element.
    """
    # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
    perm = torch.arange(size, dtype=_get_torch_dtype(size))
    for position in range(size):
        # Pick the element that lands at ``position`` among the remaining ones.
        picked = torch.randint(position, size, size=(1,), generator=generator).item()

        # Always swap even if no-op
        value = perm[picked].item()
        perm[picked] = perm[position].item()
        perm[position] = value
        yield value
|
||||
|
||||
|
||||
class InfiniteSampler(Sampler):
    """Endless sampler yielding a strided share of successive passes over the indices."""

    def __init__(
        self,
        *,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
        advance: int = 0,
    ):
        """
        Args:
            sample_count: Number of samples in the dataset.
            shuffle: Whether each pass is a random permutation.
            seed: Seed of the random generator used when shuffling.
            start: Offset of the first index taken; defaults to the global rank.
            step: Stride between taken indices; defaults to the global size.
            advance: Number of leading yielded indices to skip.
        """
        if start is None:
            start = distributed.get_global_rank()
        if step is None:
            step = distributed.get_global_size()

        self._sample_count = sample_count
        self._seed = seed
        self._shuffle = shuffle
        self._start = start
        self._step = step
        self._advance = advance

    def __iter__(self):
        source = self._shuffled_iterator() if self._shuffle else self._iterator()
        yield from itertools.islice(source, self._advance, None)

    def _iterator(self):
        """Yield this stride's share of ``range(sample_count)``, forever."""
        assert not self._shuffle

        while True:
            ordered = range(self._sample_count)
            yield from itertools.islice(ordered, self._start, None, self._step)

    def _shuffled_iterator(self):
        """Yield this stride's share of successive random permutations, forever."""
        assert self._shuffle

        # Instantiate a generator here (rather than in the ctor) to keep the class
        # picklable (requirement of mp.spawn)
        generator = torch.Generator().manual_seed(self._seed)

        while True:
            permuted = _generate_randperm_indices(size=self._sample_count, generator=generator)
            yield from itertools.islice(permuted, self._start, None, self._step)
|
||||
|
||||
|
||||
# The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
|
||||
# but avoids a full in-place random permutation generation.
|
||||
def _shuffle_tensor_slice(
    *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
    """Return the strided slice ``tensor[start::step]`` in shuffled order.

    Only ``len(tensor) // step`` elements are kept; a warning reports how many
    trailing elements of ``tensor`` do not fit into whole strides.

    Args:
        tensor: The 1D tensor to slice.
        start: Offset of the first element taken.
        step: Stride between taken elements.
        generator: Random generator driving the shuffle.

    Returns:
        A NumPy array holding the shuffled slice.
    """
    stop = len(tensor)
    count = stop // step
    drop_count = stop - step * count
    if drop_count:
        warnings.warn(f"# of dropped samples: {drop_count}")

    result = np.empty(count, dtype=_get_numpy_dtype(stop))

    # Shuffle while filling: each new element is inserted at a random position
    # among those filled so far, and the displaced value moves to the new slot.
    for position in range(count):
        if position > 0:
            target = torch.randint(0, position + 1, size=(1,), generator=generator).item()
        else:
            target = 0

        result[position] = result[target]
        result[target] = tensor[start + position * step].item()

    return result
|
||||
|
||||
|
||||
def _new_shuffle_tensor_slice(
|
||||
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
|
||||
) -> np.ndarray:
|
||||
stop = len(tensor)
|
||||
count = stop // step
|
||||
dtype = torch.int64 # Needed for using randperm result as indices
|
||||
count = stop // step
|
||||
drop_count = stop - step * count
|
||||
if drop_count:
|
||||
warnings.warn(f"# of dropped samples: {drop_count}")
|
||||
indices = torch.randperm(count, dtype=dtype, generator=generator)
|
||||
return tensor[start::step][indices].numpy()
|
||||
|
||||
|
||||
def _make_seed(seed: int, start: int, iter_count: int) -> int:
|
||||
# NOTE: Tried a few variants (including iter_count << 32), this one worked best.
|
||||
return seed + start + (iter_count << 24)
|
||||
|
||||
|
||||
class ShardedInfiniteSampler(Sampler):
    """Endless sampler re-shuffling its own strided shard of a fixed permutation."""

    def __init__(
        self,
        *,
        sample_count: int,
        shuffle: bool = False,
        seed: int = 0,
        start: Optional[int] = None,
        step: Optional[int] = None,
        advance: int = 0,
        use_new_shuffle_tensor_slice: bool = False,
    ):
        """
        Args:
            sample_count: Number of samples in the dataset.
            shuffle: Whether to shuffle samples.
            seed: Base random seed.
            start: Offset of the first index taken; defaults to the global rank.
            step: Stride between taken indices; defaults to the global size.
            advance: Number of leading yielded indices to skip.
            use_new_shuffle_tensor_slice: Selects ``_new_shuffle_tensor_slice``
                instead of ``_shuffle_tensor_slice`` for shuffling the shard.
        """
        if start is None:
            start = distributed.get_global_rank()
        if step is None:
            step = distributed.get_global_size()

        self._sample_count = sample_count
        self._seed = seed
        self._shuffle = shuffle
        self._start = start
        self._step = step
        self._advance = advance
        self._iter_count = 0
        if use_new_shuffle_tensor_slice:
            self._shuffle_tensor_slice_fn = _new_shuffle_tensor_slice
        else:
            self._shuffle_tensor_slice_fn = _shuffle_tensor_slice

    def __iter__(self):
        # Whole multiples of the sample count are skipped by bumping the
        # iteration counter; only the remainder is skipped element by element.
        skipped_iterations = self._advance // self._sample_count
        if skipped_iterations > 0:
            self._advance -= skipped_iterations * self._sample_count
            self._iter_count += skipped_iterations

        source = self._shuffled_iterator() if self._shuffle else self._iterator()
        yield from itertools.islice(source, self._advance, None)

    def _iterator(self):
        """Yield this stride's share of ``range(sample_count)``, forever."""
        assert not self._shuffle

        while True:
            ordered = range(self._sample_count)
            yield from itertools.islice(ordered, self._start, None, self._step)

    def _shuffled_iterator(self):
        """Yield re-shuffled versions of this stride's shard, forever."""
        assert self._shuffle

        # Instantiate a generator here (rather than in the ctor) to keep the class
        # picklable (requirement of mp.spawn)
        generator = torch.Generator()

        # Always shuffle everything first
        generator.manual_seed(self._seed)
        perm = torch.randperm(
            self._sample_count,
            dtype=_get_torch_dtype(self._sample_count),
            generator=generator,
        )

        while True:
            # Re-seed on each iteration to allow skipping whole permutations
            generator.manual_seed(_make_seed(self._seed, self._start, self._iter_count))

            shard = self._shuffle_tensor_slice_fn(
                tensor=perm, start=self._start, step=self._step, generator=generator
            )
            yield from shard
            self._iter_count += 1
|
||||
@@ -0,0 +1,92 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class GaussianBlur(transforms.RandomApply):
    """
    Apply Gaussian Blur to the PIL image with probability ``p``.
    """

    def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
        """
        Args:
            p: Probability of blurring the image.
            radius_min: Lower bound of the sampled blur sigma.
            radius_max: Upper bound of the sampled blur sigma.
        """
        # NOTE: torchvision's RandomApply applies the wrapped transforms with
        # probability ``p`` (the input is returned unchanged with probability
        # ``1 - p``), so ``p`` is forwarded as is. Passing ``1 - p`` would invert
        # the blur probability.
        transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
        super().__init__(transforms=[transform], p=p)
|
||||
|
||||
|
||||
class MaybeToTensor(transforms.ToTensor):
    """
    ``ToTensor`` variant that passes through inputs that are already tensors.
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
        Returns:
            Tensor: Converted image.
        """
        if not isinstance(pic, torch.Tensor):
            return super().__call__(pic)
        return pic
|
||||
|
||||
|
||||
# Use timm's names
|
||||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
|
||||
def make_normalize_transform(
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
    """Create a ``Normalize`` transform for the given per-channel statistics.

    Args:
        mean: Per-channel means; defaults to ``IMAGENET_DEFAULT_MEAN``.
        std: Per-channel standard deviations; defaults to ``IMAGENET_DEFAULT_STD``.
    """
    normalize = transforms.Normalize(mean=mean, std=std)
    return normalize
|
||||
|
||||
|
||||
# This roughly matches torchvision's preset for classification training:
|
||||
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
|
||||
def make_classification_train_transform(
    *,
    crop_size: int = 224,
    interpolation=transforms.InterpolationMode.BICUBIC,
    hflip_prob: float = 0.5,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
):
    """Compose the training pipeline: random resized crop, optional flip, tensor, normalize.

    Args:
        crop_size: Output size of the random resized crop.
        interpolation: Interpolation mode used by the crop.
        hflip_prob: Horizontal flip probability; the flip is omitted when not positive.
        mean: Per-channel means used for normalization.
        std: Per-channel standard deviations used for normalization.

    Returns:
        The composed transform.
    """
    pipeline = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
    if hflip_prob > 0.0:
        pipeline.append(transforms.RandomHorizontalFlip(hflip_prob))
    pipeline.append(MaybeToTensor())
    pipeline.append(make_normalize_transform(mean=mean, std=std))
    return transforms.Compose(pipeline)
|
||||
|
||||
|
||||
# This matches (roughly) torchvision's preset for classification evaluation:
|
||||
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
|
||||
def make_classification_eval_transform(
    *,
    resize_size: int = 256,
    interpolation=transforms.InterpolationMode.BICUBIC,
    crop_size: int = 224,
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
    """Compose the evaluation pipeline: resize, center crop, tensor, normalize.

    Args:
        resize_size: Size passed to the resize step.
        interpolation: Interpolation mode used by the resize step.
        crop_size: Size of the center crop.
        mean: Per-channel means used for normalization.
        std: Per-channel standard deviations used for normalization.

    Returns:
        The composed transform.
    """
    resize = transforms.Resize(resize_size, interpolation=interpolation)
    center_crop = transforms.CenterCrop(crop_size)
    to_tensor = MaybeToTensor()
    normalize = make_normalize_transform(mean=mean, std=std)
    return transforms.Compose([resize, center_crop, to_tensor, normalize])
|
||||
@@ -0,0 +1,271 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import socket
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
_LOCAL_RANK = -1
|
||||
_LOCAL_WORLD_SIZE = -1
|
||||
|
||||
|
||||
def is_enabled() -> bool:
    """
    Returns:
        True if torch.distributed is both available and initialized.
    """
    if not dist.is_available():
        return False
    return dist.is_initialized()
|
||||
|
||||
|
||||
def get_global_size() -> int:
    """
    Returns:
        The number of processes in the process group, or 1 when distributed
        mode is not enabled.
    """
    if not is_enabled():
        return 1
    return dist.get_world_size()
|
||||
|
||||
|
||||
def get_global_rank() -> int:
    """
    Returns:
        The rank of the current process within the global process group, or 0
        when distributed mode is not enabled.
    """
    if not is_enabled():
        return 0
    return dist.get_rank()
|
||||
|
||||
|
||||
def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process
        group, or 0 when distributed mode is not enabled.
    """
    if is_enabled():
        # The module-level values must have been set consistently beforehand.
        assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
        return _LOCAL_RANK
    return 0
|
||||
|
||||
|
||||
def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group, i.e. the number of processes
        per machine, or 1 when distributed mode is not enabled.
    """
    if is_enabled():
        # The module-level values must have been set consistently beforehand.
        assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
        return _LOCAL_WORLD_SIZE
    return 1
|
||||
|
||||
|
||||
def is_main_process() -> bool:
    """
    Returns:
        True if the current process has global rank 0.
    """
    global_rank = get_global_rank()
    return global_rank == 0
|
||||
|
||||
|
||||
def _restrict_print_to_main_process() -> None:
    """
    Replace the builtin ``print`` so that only the main process prints.

    The replacement accepts an extra ``force`` keyword argument; when true, the
    message is printed from any process.
    """
    import builtins

    original_print = builtins.print

    def print(*args, **kwargs):
        # ``force`` is removed before delegating: the builtin does not accept it.
        force = kwargs.pop("force", False)
        if is_main_process() or force:
            original_print(*args, **kwargs)

    builtins.print = print
|
||||
|
||||
|
||||
def _get_master_port(seed: int = 0) -> int:
|
||||
MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)
|
||||
|
||||
master_port_str = os.environ.get("MASTER_PORT")
|
||||
if master_port_str is None:
|
||||
rng = random.Random(seed)
|
||||
return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
|
||||
|
||||
return int(master_port_str)
|
||||
|
||||
|
||||
def _get_available_port() -> int:
    """Return a port number picked by the OS by binding a temporary TCP socket to port 0."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with sock:
        # A "" host address means INADDR_ANY i.e. binding to all interfaces.
        # Note this is not compatible with IPv6.
        sock.bind(("", 0))
        bound_address = sock.getsockname()
    return bound_address[1]
|
||||
|
||||
|
||||
_TORCH_DISTRIBUTED_ENV_VARS = (
|
||||
"MASTER_ADDR",
|
||||
"MASTER_PORT",
|
||||
"RANK",
|
||||
"WORLD_SIZE",
|
||||
"LOCAL_RANK",
|
||||
"LOCAL_WORLD_SIZE",
|
||||
)
|
||||
|
||||
|
||||
def _collect_env_vars() -> Dict[str, str]:
|
||||
return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ}
|
||||
|
||||
|
||||
def _is_slurm_job_process() -> bool:
|
||||
return "SLURM_JOB_ID" in os.environ
|
||||
|
||||
|
||||
def _parse_slurm_node_list(s: str) -> List[str]:
|
||||
nodes = []
|
||||
# Extract "hostname", "hostname[1-2,3,4-5]," substrings
|
||||
p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?")
|
||||
for m in p.finditer(s):
|
||||
prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)]
|
||||
for suffix in suffixes.split(","):
|
||||
span = suffix.split("-")
|
||||
if len(span) == 1:
|
||||
nodes.append(prefix + suffix)
|
||||
else:
|
||||
width = len(span[0])
|
||||
start, end = int(span[0]), int(span[1]) + 1
|
||||
nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)])
|
||||
return nodes
|
||||
|
||||
|
||||
def _check_env_variable(key: str, new_value: str):
    """Raise if environment variable ``key`` is already set to a value other than ``new_value``.

    Raises:
        RuntimeError: when ``key`` is present in the environment with a
            different value.
    """
    # Only check for difference with preset environment variables
    preset_value = os.environ.get(key)
    if preset_value is None or preset_value == new_value:
        return
    raise RuntimeError(f"Cannot export environment variables as {key} is already set")
|
||||
|
||||
|
||||
class _TorchDistributedEnvironment:
    """Rendezvous settings (master address/port, ranks, world sizes) of this process.

    On construction the values are resolved from the first applicable source:

    1. the Slurm job environment, when ``SLURM_JOB_ID`` is set;
    2. the torch.distributed environment variables, when all of them are preset;
    3. a single-process local setup, when at least one CUDA device is visible.

    Raises:
        RuntimeError: if only some of the torch.distributed environment
            variables are set, or if none of the sources above applies.
    """

    def __init__(self):
        # Placeholder values, replaced by one of the `_set_from_*` methods.
        self.master_addr = "127.0.0.1"
        self.master_port = 0
        self.rank = -1
        self.world_size = -1
        self.local_rank = -1
        self.local_world_size = -1

        if _is_slurm_job_process():
            self._set_from_slurm_env()
            return

        env_vars = _collect_env_vars()
        if env_vars:
            if len(env_vars) != len(_TORCH_DISTRIBUTED_ENV_VARS):
                # Environment is partially set
                collected_env_vars = ", ".join(env_vars.keys())
                raise RuntimeError(f"Partially set environment: {collected_env_vars}")
            # Environment is fully set
            self._set_from_preset_env()
            return

        # Environment is not set
        if torch.cuda.device_count() > 0:
            self._set_from_local()
            return

        raise RuntimeError("Can't initialize PyTorch distributed environment")

    # Slurm job created with sbatch, submitit, etc...
    def _set_from_slurm_env(self):
        """Fill the settings from the ``SLURM_*`` environment variables."""
        # logger.info("Initialization from Slurm environment")
        job_id = int(os.environ["SLURM_JOB_ID"])
        node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
        nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
        assert len(nodes) == node_count

        self.master_addr = nodes[0]
        # Seeding with the job id makes every process of the job pick the same port.
        self.master_port = _get_master_port(seed=job_id)
        self.rank = int(os.environ["SLURM_PROCID"])
        self.world_size = int(os.environ["SLURM_NTASKS"])
        assert self.rank < self.world_size
        self.local_rank = int(os.environ["SLURM_LOCALID"])
        self.local_world_size = self.world_size // node_count
        assert self.local_rank < self.local_world_size

    # Single node job with preset environment (i.e. torchrun)
    def _set_from_preset_env(self):
        """Fill the settings from the preset torch.distributed environment variables."""
        # logger.info("Initialization from preset environment")
        self.master_addr = os.environ["MASTER_ADDR"]
        # Stored as an int, like in the Slurm and local initialization paths.
        self.master_port = int(os.environ["MASTER_PORT"])
        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        assert self.rank < self.world_size
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        assert self.local_rank < self.local_world_size

    # Single node and GPU job (i.e. local script run)
    def _set_from_local(self):
        """Fill the settings for a single process on the local machine."""
        # logger.info("Initialization from local")
        self.master_addr = "127.0.0.1"
        self.master_port = _get_available_port()
        self.rank = 0
        self.world_size = 1
        self.local_rank = 0
        self.local_world_size = 1

    def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment":
        """Write the settings to the process environment and return ``self``.

        Args:
            overwrite: If False, fail when a variable is already set to a
                different value instead of replacing it.

        Raises:
            RuntimeError: if ``overwrite`` is False and a variable is already
                set to a different value.
        """
        # See the "Environment variable initialization" section from
        # https://pytorch.org/docs/stable/distributed.html for the complete list of
        # environment variables required for the env:// initialization method.
        env_vars = {
            "MASTER_ADDR": self.master_addr,
            "MASTER_PORT": str(self.master_port),
            "RANK": str(self.rank),
            "WORLD_SIZE": str(self.world_size),
            "LOCAL_RANK": str(self.local_rank),
            "LOCAL_WORLD_SIZE": str(self.local_world_size),
        }
        if not overwrite:
            for k, v in env_vars.items():
                _check_env_variable(k, v)

        os.environ.update(env_vars)
        return self
|
||||
|
||||
|
||||
def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False):
    """Enable distributed mode

    Args:
        set_cuda_current_device: If True, call torch.cuda.set_device() to set the
            current PyTorch CUDA device to the one matching the local rank.
        overwrite: If True, overwrites already set variables. Else fails.
        allow_nccl_timeout: If True, set NCCL_ASYNC_ERROR_HANDLING=1 in the
            environment before the process group is created.

    Raises:
        RuntimeError: if distributed mode has already been enabled, or if
            ``overwrite`` is False and a variable to export is already set to
            a different value.
    """

    global _LOCAL_RANK, _LOCAL_WORLD_SIZE
    not_yet_enabled = _LOCAL_RANK < 0 and _LOCAL_WORLD_SIZE < 0
    if not not_yet_enabled:
        raise RuntimeError("Distributed mode has already been enabled")

    torch_env = _TorchDistributedEnvironment()
    torch_env.export(overwrite=overwrite)

    if set_cuda_current_device:
        torch.cuda.set_device(torch_env.local_rank)

    if allow_nccl_timeout:
        # This allows to use torch distributed timeout in a NCCL backend
        nccl_key = "NCCL_ASYNC_ERROR_HANDLING"
        nccl_value = "1"
        if not overwrite:
            _check_env_variable(nccl_key, nccl_value)
        os.environ[nccl_key] = nccl_value

    dist.init_process_group(backend="nccl")
    dist.barrier()

    # Finalize setup: record the local rank/world size only once the process
    # group is up, then silence print() on non-main processes.
    _LOCAL_RANK, _LOCAL_WORLD_SIZE = torch_env.local_rank, torch_env.local_world_size
    _restrict_print_to_main_process()
|
||||
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
@@ -0,0 +1,405 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
from functools import partial
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.functional import one_hot, softmax
|
||||
|
||||
import dinov2.distributed as distributed
|
||||
from dinov2.data import SamplerType, make_data_loader, make_dataset
|
||||
from dinov2.data.transforms import make_classification_eval_transform
|
||||
from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric
|
||||
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
|
||||
from dinov2.eval.setup import setup_and_build_model
|
||||
from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features
|
||||
|
||||
|
||||
# Module-level logger, registered under the "dinov2" name.
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Build the command-line parser for the k-NN evaluation.

    The parser extends the evaluation setup parser (itself built on top of
    ``parents``) with the k-NN specific options and their defaults.

    Args:
        description: Description shown in the help output.
        parents: Parsers whose arguments are inherited, if any.
        add_help: Whether to add the ``-h/--help`` option.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parents = parents or []
    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
    parents = [setup_args_parser]
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    parser.add_argument(
        "--train-dataset",
        dest="train_dataset_str",
        type=str,
        help="Training dataset",
    )
    parser.add_argument(
        "--val-dataset",
        dest="val_dataset_str",
        type=str,
        help="Validation dataset",
    )
    parser.add_argument(
        "--nb_knn",
        nargs="+",
        type=int,
        help="Number of NN to use. 20 is usually working the best.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature used in the voting coefficient",
    )
    parser.add_argument(
        "--gather-on-cpu",
        action="store_true",
        # The two literals are concatenated: keep the trailing space on the first one.
        help="Whether to gather the train features on cpu, slower "
        "but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch size.",
    )
    parser.add_argument(
        "--n-per-class-list",
        nargs="+",
        type=int,
        help="Number to take per class",
    )
    parser.add_argument(
        "--n-tries",
        type=int,
        help="Number of tries",
    )
    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        nb_knn=[10, 20, 100, 200],
        temperature=0.07,
        batch_size=256,
        n_per_class_list=[-1],
        n_tries=1,
    )
    return parser
|
||||
|
||||
|
||||
class KnnModule(torch.nn.Module):
    """
    Gets knn of test features from all processes on a chunk of the train features

    Every rank holds one chunk of the train features and receives one chunk of the
    test features. `compute_neighbors` goes through the ranks one after the other:
    the test chunk of the current rank is broadcast to all devices, each device
    computes partial knns against its own train chunk, and the partial results are
    gathered and merged on the rank the test chunk came from.
    """

    def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
        super().__init__()

        self.global_rank = distributed.get_global_rank()
        self.global_size = distributed.get_global_size()

        self.device = device
        # Keep only this rank's chunk of the train set; the features are stored
        # transposed so that similarities are a single matrix product.
        features_chunk = train_features.chunk(self.global_size)[self.global_rank]
        labels_chunk = train_labels.chunk(self.global_size)[self.global_rank]
        self.train_features_rank_T = features_chunk.T.to(self.device)
        self.candidates = labels_chunk.view(1, -1).to(self.device)

        self.nb_knn = nb_knn
        self.max_k = max(self.nb_knn)
        self.T = T
        self.num_classes = num_classes

    def _get_knn_sims_and_labels(self, similarity, train_labels):
        """Return the `max_k` largest similarities per row and the labels at those positions."""
        best_sims, best_positions = similarity.topk(self.max_k, largest=True, sorted=True)
        best_labels = torch.gather(train_labels, 1, best_positions)
        return best_sims, best_labels

    def _similarity_for_rank(self, features_rank, source_rank):
        """Broadcast the features of `source_rank` and match them against the local train chunk."""
        # Send the features from `source_rank` to all ranks
        shape_to_receive = torch.tensor(features_rank.shape).to(self.device)
        torch.distributed.broadcast(shape_to_receive, source_rank)

        if self.global_rank == source_rank:
            shared_features = features_rank
        else:
            # Receiving ranks allocate a buffer of the broadcast shape.
            shared_features = torch.zeros(*shape_to_receive, dtype=features_rank.dtype, device=self.device)
        torch.distributed.broadcast(shared_features, source_rank)

        # Compute the neighbors for `source_rank` among `train_features_rank_T`
        local_similarity = torch.mm(shared_features, self.train_features_rank_T)
        local_labels = self.candidates.expand(len(local_similarity), -1)
        return self._get_knn_sims_and_labels(local_similarity, local_labels)

    def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
        """Gather the partial knns on `target_rank`; returns the merged top-k there, None elsewhere."""
        # Gather all neighbors for `target_rank`
        on_target = self.global_rank == target_rank
        sims_per_rank = labels_per_rank = None
        if on_target:
            sims_per_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)]
            labels_per_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)]

        torch.distributed.gather(topk_sims, sims_per_rank, dst=target_rank)
        torch.distributed.gather(neighbors_labels, labels_per_rank, dst=target_rank)

        if not on_target:
            return None

        # Perform a second top-k on the k * global_size retrieved neighbors
        all_sims = torch.cat(sims_per_rank, dim=1)
        all_labels = torch.cat(labels_per_rank, dim=1)
        return self._get_knn_sims_and_labels(all_sims, all_labels)

    def compute_neighbors(self, features_rank):
        """Run the broadcast/gather round for every rank and return this rank's merged knns."""
        for rank in range(self.global_size):
            partial_sims, partial_labels = self._similarity_for_rank(features_rank, rank)
            gathered = self._gather_all_knn_for_rank(partial_sims, partial_labels, rank)
            if gathered is None:
                continue
            topk_sims_rank, neighbors_labels_rank = gathered
        return topk_sims_rank, neighbors_labels_rank

    def forward(self, features_rank):
        """
        Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k`
        """
        assert all(k <= self.max_k for k in self.nb_knn)

        topk_sims, neighbors_labels = self.compute_neighbors(features_rank)
        batch_size = neighbors_labels.shape[0]
        # Temperature-scaled softmax over the neighbors gives the voting weights.
        voting_weights = softmax(topk_sims / self.T, 1)
        weighted_votes = torch.mul(
            one_hot(neighbors_labels, num_classes=self.num_classes),
            voting_weights.view(batch_size, -1, 1),
        )
        probas_for_k = {}
        for k in self.nb_knn:
            # Sum the votes of the k closest neighbors only.
            probas_for_k[k] = torch.sum(weighted_votes[:, :k, :], 1)
        return probas_for_k
|
||||
|
||||
|
||||
class DictKeysModule(torch.nn.Module):
    """Select one entry of a nested dict of predictions by following a fixed key path."""

    def __init__(self, keys):
        super().__init__()
        self.keys = keys

    def forward(self, features_dict, targets):
        """Descend `features_dict` along `self.keys` and pair the result with `targets`."""
        selected = features_dict
        for key in self.keys:
            selected = selected[key]
        return dict(preds=selected, target=targets)
|
||||
|
||||
|
||||
def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels):
    """Instantiate one k-NN module per (samples-per-class, try) configuration.

    Args:
        module: Callable building a k-NN module from the ``train_features``,
            ``train_labels`` and ``nb_knn`` keyword arguments.
        n_per_class_list: Numbers of training samples to keep per class; a
            negative value means the full training set.
        n_tries: Number of random subsets drawn for each non-negative entry of
            ``n_per_class_list``.
        nb_knn: Sequence (list or tuple) of neighbor counts to evaluate.
        train_features: Features of the training samples.
        train_labels: Labels of the training samples.

    Returns:
        A ``ModuleDictWithForward`` keyed by ``"full"`` or ``"<n> per class"``,
        each value being a ``ModuleDictWithForward`` keyed by the try index.
    """
    modules = {}
    mapping = create_class_indices_mapping(train_labels)
    for npc in n_per_class_list:
        if npc < 0:  # Only one try needed when using the full data
            full_module = module(
                train_features=train_features,
                train_labels=train_labels,
                nb_knn=nb_knn,
            )
            modules["full"] = ModuleDictWithForward({"1": full_module})
            continue

        # Evaluate the requested k values that fit within `npc`, plus `npc` itself.
        # `nb_knn` is unpacked rather than concatenated so that tuples are accepted;
        # the result does not depend on the try, so it is computed once.
        k_list = sorted({k for k in [*nb_knn, npc] if k <= npc})

        all_tries = {}
        for t in range(n_tries):
            # The try index doubles as the seed of the random subset.
            final_indices = filter_train(mapping, npc, seed=t)
            all_tries[str(t)] = module(
                train_features=train_features[final_indices],
                train_labels=train_labels[final_indices],
                nb_knn=k_list,
            )
        modules[f"{npc} per class"] = ModuleDictWithForward(all_tries)

    return ModuleDictWithForward(modules)
|
||||
|
||||
|
||||
def filter_train(mapping, n_per_class, seed):
    """Draw up to `n_per_class` random sample indices from every class of `mapping`.

    The global torch random generator is reseeded with `seed`, which makes the
    selection reproducible. The per-class selections are concatenated and
    squeezed before being returned.
    """
    torch.manual_seed(seed)
    selected_per_class = []
    for class_indices in mapping.values():
        shuffled_positions = torch.randperm(len(class_indices))
        selected_per_class.append(class_indices[shuffled_positions[:n_per_class]])
    return torch.cat(selected_per_class).squeeze()
|
||||
|
||||
|
||||
def create_class_indices_mapping(labels):
    """Map every distinct label to the positions in `labels` holding that label.

    Keys are the 0-dim tensor elements of the unique labels; values are the
    `nonzero()` outputs of the matching masks.
    """
    unique_labels, inverse = torch.unique(labels, return_inverse=True)
    mapping = {}
    for class_position, label in enumerate(unique_labels):
        mapping[label] = (inverse == class_position).nonzero()
    return mapping
|
||||
|
||||
|
||||
class ModuleDictWithForward(torch.nn.ModuleDict):
    """A `ModuleDict` whose forward pass feeds the same arguments to every child."""

    def forward(self, *args, **kwargs):
        """Return a dict mapping each child's key to the output of that child."""
        outputs = {}
        for name, child in self._modules.items():
            outputs[name] = child(*args, **kwargs)
        return outputs
|
||||
|
||||
|
||||
def eval_knn(
    model,
    train_dataset,
    val_dataset,
    accuracy_averaging,
    nb_knn,
    temperature,
    batch_size,
    num_workers,
    gather_on_cpu,
    n_per_class_list=[-1],
    n_tries=1,
):
    """Run the k-NN evaluation of `model` and return the metrics per configuration.

    Train features are extracted once, one k-NN module is created for every
    (samples-per-class, try) configuration, and all of them are evaluated on
    the validation set in a single pass. The returned dict is keyed by
    `(n_per_class, k)`, with the metric values averaged over the tries.
    """
    model = ModelWithNormalize(model)

    logger.info("Extracting features for train set...")
    train_features, train_labels = extract_features(
        model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
    )
    logger.info(f"Train features created, shape {train_features.shape}.")

    val_dataloader = make_data_loader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
        persistent_workers=True,
    )
    # The number of classes is taken to be the largest train label plus one.
    num_classes = train_labels.max() + 1
    metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes)

    device = torch.cuda.current_device()
    partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes)
    knn_module_dict = create_module_dict(
        module=partial_module,
        n_per_class_list=n_per_class_list,
        n_tries=n_tries,
        nb_knn=nb_knn,
        train_features=train_features,
        train_labels=train_labels,
    )

    # One postprocessor and one metric per (n_per_class, try, k) combination.
    postprocessors, metrics = {}, {}
    for n_per_class, knn_module in knn_module_dict.items():
        for t, knn_try in knn_module.items():
            for k in knn_try.nb_knn:
                postprocessors[(n_per_class, t, k)] = DictKeysModule([n_per_class, t, k])
            for k in knn_try.nb_knn:
                metrics[(n_per_class, t, k)] = metric_collection.clone()
    model_with_knn = torch.nn.Sequential(model, knn_module_dict)

    # ============ evaluation ... ============
    logger.info("Start the k-NN classification.")
    _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device)

    # Averaging the results over the n tries for each value of n_per_class
    for n_per_class, knn_module in knn_module_dict.items():
        tries = list(knn_module.keys())
        first_try = tries[0]
        for k in knn_module[first_try].nb_knn:
            # keys are e.g. `top-1` and `top-5`
            metric_names = results_dict[(n_per_class, first_try, k)].keys()
            averaged = {}
            for name in metric_names:
                per_try_values = [results_dict[(n_per_class, t, k)][name] for t in knn_module.keys()]
                averaged[name] = torch.mean(torch.stack(per_try_values))
            results_dict[(n_per_class, k)] = averaged
            # The per-try entries are replaced by their average.
            for t in knn_module.keys():
                del results_dict[(n_per_class, t, k)]

    return results_dict
|
||||
|
||||
|
||||
def eval_knn_with_model(
    model,
    output_dir,
    train_dataset_str="ImageNet:split=TRAIN",
    val_dataset_str="ImageNet:split=VAL",
    nb_knn=(10, 20, 100, 200),
    temperature=0.07,
    autocast_dtype=torch.float,
    accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
    transform=None,
    gather_on_cpu=False,
    batch_size=256,
    num_workers=5,
    n_per_class_list=[-1],
    n_tries=1,
):
    """Build the datasets, run the k-NN evaluation and log/save the accuracies.

    The top-1 and top-5 values (in percent) are collected on the main process
    only and appended, one JSON object per line, to `results_eval_knn.json`
    in `output_dir`. Returns the dict of collected values, which is empty on
    the other processes.
    """
    if not transform:
        transform = make_classification_eval_transform()

    train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform)
    val_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform)

    with torch.cuda.amp.autocast(dtype=autocast_dtype):
        results_dict_knn = eval_knn(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            accuracy_averaging=accuracy_averaging,
            nb_knn=nb_knn,
            temperature=temperature,
            batch_size=batch_size,
            num_workers=num_workers,
            gather_on_cpu=gather_on_cpu,
            n_per_class_list=n_per_class_list,
            n_tries=n_tries,
        )

    results_dict = {}
    if distributed.is_main_process():
        for knn_, knn_metrics in results_dict_knn.items():
            # Accuracies are reported in percent.
            top1 = knn_metrics["top-1"].item() * 100.0
            top5 = knn_metrics["top-5"].item() * 100.0
            results_dict[f"{knn_} Top 1"] = top1
            results_dict[f"{knn_} Top 5"] = top5
            logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}")

    metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
    with open(metrics_file_path, "a") as f:
        for name, value in results_dict.items():
            f.write(json.dumps({name: value}) + "\n")

    if distributed.is_enabled():
        torch.distributed.barrier()
    return results_dict
|
||||
|
||||
|
||||
def main(args):
    """Build the model from the parsed arguments, run the k-NN evaluation and return 0."""
    model, autocast_dtype = setup_and_build_model(args)
    # Options taken from the command line, followed by the fixed ones.
    eval_options = dict(
        model=model,
        output_dir=args.output_dir,
        train_dataset_str=args.train_dataset_str,
        val_dataset_str=args.val_dataset_str,
        nb_knn=args.nb_knn,
        temperature=args.temperature,
        autocast_dtype=autocast_dtype,
        gather_on_cpu=args.gather_on_cpu,
        batch_size=args.batch_size,
        n_per_class_list=args.n_per_class_list,
        n_tries=args.n_tries,
        accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
        transform=None,
        num_workers=5,
    )
    eval_knn_with_model(**eval_options)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the command line, run the evaluation and use
    # the value returned by main() as the process exit status.
    description = "DINOv2 k-NN evaluation"
    args_parser = get_args_parser(description=description)
    args = args_parser.parse_args()
    exit_status = main(args)
    sys.exit(exit_status)
|
||||
@@ -0,0 +1,626 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
from functools import partial
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
|
||||
|
||||
from dinov2.data import SamplerType, make_data_loader, make_dataset
|
||||
from dinov2.data.transforms import make_classification_eval_transform, make_classification_train_transform
|
||||
import dinov2.distributed as distributed
|
||||
from dinov2.eval.metrics import MetricType, build_metric
|
||||
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
|
||||
from dinov2.eval.setup import setup_and_build_model
|
||||
from dinov2.eval.utils import ModelWithIntermediateLayers, evaluate
|
||||
from dinov2.logging import MetricLogger
|
||||
|
||||
|
||||
# Module-level logger, registered under the "dinov2" name.
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Build the command-line parser for the linear evaluation.

    The parser extends the evaluation setup parser (itself built on top of
    ``parents``) with the linear-probing specific options and their defaults.

    Args:
        description: Description shown in the help output.
        parents: Parsers whose arguments are inherited, if any.
        add_help: Whether to add the ``-h/--help`` option.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parents = parents or []
    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
    parents = [setup_args_parser]
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    parser.add_argument(
        "--train-dataset",
        dest="train_dataset_str",
        type=str,
        help="Training dataset",
    )
    parser.add_argument(
        "--val-dataset",
        dest="val_dataset_str",
        type=str,
        help="Validation dataset",
    )
    parser.add_argument(
        "--test-datasets",
        dest="test_dataset_strs",
        type=str,
        nargs="+",
        help="Test datasets, none to reuse the validation dataset",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch Size (per GPU)",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        help="Number of workers",
    )
    parser.add_argument(
        "--epoch-length",
        type=int,
        help="Length of an epoch in number of iterations",
    )
    parser.add_argument(
        "--save-checkpoint-frequency",
        type=int,
        help="Number of epochs between two named checkpoint saves.",
    )
    parser.add_argument(
        "--eval-period-iterations",
        type=int,
        help="Number of iterations between two evaluations.",
    )
    parser.add_argument(
        "--learning-rates",
        nargs="+",
        type=float,
        help="Learning rates to grid search.",
    )
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Whether to not resume from existing checkpoints",
    )
    parser.add_argument(
        "--val-metric-type",
        type=MetricType,
        choices=list(MetricType),
        help="Validation metric",
    )
    parser.add_argument(
        "--test-metric-types",
        type=MetricType,
        choices=list(MetricType),
        nargs="+",
        help="Evaluation metric",
    )
    parser.add_argument(
        "--classifier-fpath",
        type=str,
        help="Path to a file containing pretrained linear classifiers",
    )
    parser.add_argument(
        "--val-class-mapping-fpath",
        type=str,
        help="Path to a file containing a mapping to adjust classifier outputs",
    )
    parser.add_argument(
        "--test-class-mapping-fpaths",
        nargs="+",
        type=str,
        help="Path to a file containing a mapping to adjust classifier outputs",
    )
    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        test_dataset_strs=None,
        epochs=10,
        batch_size=128,
        num_workers=8,
        epoch_length=1250,
        save_checkpoint_frequency=20,
        eval_period_iterations=1250,
        learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 0.1],
        val_metric_type=MetricType.MEAN_ACCURACY,
        test_metric_types=None,
        classifier_fpath=None,
        val_class_mapping_fpath=None,
        test_class_mapping_fpaths=[None],
    )
    return parser
|
||||
|
||||
|
||||
def has_ddp_wrapper(m: nn.Module) -> bool:
    """Return True when `m` is a `DistributedDataParallel` instance."""
    is_wrapped = isinstance(m, DistributedDataParallel)
    return is_wrapped
|
||||
|
||||
|
||||
def remove_ddp_wrapper(m: nn.Module) -> nn.Module:
    """Return the module wrapped by `m` when `m` is a `DistributedDataParallel`, else `m` itself."""
    if isinstance(m, DistributedDataParallel):
        return m.module
    return m
|
||||
|
||||
|
||||
def _pad_and_collate(batch):
    """Collate `(image, targets)` samples whose target arrays differ in length.

    Every target array is right-padded with -1 up to the longest one of the
    batch, then the samples are handed to torch's default collate function.
    """
    longest = max(len(targets) for _, targets in batch)
    padded_batch = []
    for image, targets in batch:
        missing = longest - len(targets)
        padded_batch.append((image, np.pad(targets, (0, missing), constant_values=-1)))
    return torch.utils.data.default_collate(padded_batch)
|
||||
|
||||
|
||||
def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool):
    """Assemble the linear classifier input from the last blocks' outputs.

    The class tokens (second element of each pair) of the last `use_n_blocks`
    entries of `x_tokens_list` are concatenated along the last dimension. With
    `use_avgpool`, the patch tokens (first element) of the very last entry are
    averaged over dimension 1 and appended. The result is returned as float.
    """
    last_blocks = x_tokens_list[-use_n_blocks:]
    class_tokens = [class_token for _, class_token in last_blocks]
    output = torch.cat(class_tokens, dim=-1)
    if use_avgpool:
        final_patch_tokens = last_blocks[-1][0]
        pooled_patches = torch.mean(final_patch_tokens, dim=1)
        output = torch.cat((output, pooled_patches), dim=-1)
        output = output.reshape(output.shape[0], -1)
    return output.float()
|
||||
|
||||
|
||||
class LinearClassifier(nn.Module):
    """Linear layer to train on top of frozen features"""

    def __init__(self, out_dim, use_n_blocks, use_avgpool, num_classes=1000):
        super().__init__()
        self.out_dim = out_dim
        self.use_n_blocks = use_n_blocks
        self.use_avgpool = use_avgpool
        self.num_classes = num_classes
        self.linear = nn.Linear(out_dim, num_classes)
        # Small Gaussian weights (std 0.01) and zero bias.
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x_tokens_list):
        """Build the linear input from `x_tokens_list` and apply the linear layer to it."""
        linear_input = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool)
        return self.linear(linear_input)
|
||||
|
||||
|
||||
class AllClassifiers(nn.Module):
    """Container applying every classifier of a dict to the same inputs."""

    def __init__(self, classifiers_dict):
        super().__init__()
        self.classifiers_dict = nn.ModuleDict(classifiers_dict)

    def forward(self, inputs):
        """Return a dict mapping each classifier name to its output for `inputs`."""
        outputs = {}
        for name, classifier in self.classifiers_dict.items():
            outputs[name] = classifier.forward(inputs)
        return outputs

    def __len__(self):
        """Number of classifiers held by the container."""
        return len(self.classifiers_dict)
|
||||
|
||||
|
||||
class LinearPostprocessor(nn.Module):
    """Apply a linear classifier and optionally reindex its output columns with a class mapping."""

    def __init__(self, linear_classifier, class_mapping=None):
        super().__init__()
        self.linear_classifier = linear_classifier
        if class_mapping is None:
            mapping_buffer = None
        else:
            mapping_buffer = torch.LongTensor(class_mapping)
        self.register_buffer("class_mapping", mapping_buffer)

    def forward(self, samples, targets):
        """Return the (possibly remapped) predictions together with the untouched targets."""
        preds = self.linear_classifier(samples)
        if self.class_mapping is not None:
            # Keep/reorder the prediction columns listed in the mapping.
            preds = preds[:, self.class_mapping]
        return {"preds": preds, "target": targets}
|
||||
|
||||
|
||||
def scale_lr(learning_rates, batch_size):
    """Scale `learning_rates` by the global batch size (per-process batch size times process count) over 256."""
    global_batch_size = batch_size * distributed.get_global_size()
    return learning_rates * global_batch_size / 256.0
|
||||
|
||||
|
||||
def setup_linear_classifiers(sample_output, n_last_blocks_list, learning_rates, batch_size, num_classes=1000):
    """Create one linear classifier per (number of blocks, avgpool, learning rate) combination.

    Args:
        sample_output: Example feature-model output, used to infer the
            classifier input dimension.
        n_last_blocks_list: Numbers of last blocks to build classifiers for.
        learning_rates: Base learning rates; each is scaled with `scale_lr`.
        batch_size: Batch size passed to `scale_lr`.
        num_classes: Number of output classes of every classifier.

    Returns:
        A tuple of the classifiers container (wrapped in
        `DistributedDataParallel` when distributed mode is enabled) and the
        list of optimizer parameter groups, one per classifier.
    """
    linear_classifiers_dict = nn.ModuleDict()
    optim_param_groups = []
    for n in n_last_blocks_list:
        for avgpool in (False, True):
            for base_lr in learning_rates:
                lr = scale_lr(base_lr, batch_size)
                sample_input = create_linear_input(sample_output, use_n_blocks=n, use_avgpool=avgpool)
                out_dim = sample_input.shape[1]
                linear_classifier = LinearClassifier(
                    out_dim, use_n_blocks=n, use_avgpool=avgpool, num_classes=num_classes
                ).cuda()
                # Dots are replaced in the name before it is used as a ModuleDict key.
                classifier_name = f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}".replace(".", "_")
                linear_classifiers_dict[classifier_name] = linear_classifier
                optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr})

    linear_classifiers = AllClassifiers(linear_classifiers_dict)
    if distributed.is_enabled():
        linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers)

    return linear_classifiers, optim_param_groups
|
||||
|
||||
|
||||
@torch.no_grad()
def evaluate_linear_classifiers(
    feature_model,
    linear_classifiers,
    data_loader,
    metric_type,
    metrics_file_path,
    training_num_classes,
    iteration,
    prefixstring="",
    class_mapping=None,
    best_classifier_on_val=None,
):
    """Evaluate every linear classifier and report the best one.

    When `best_classifier_on_val` is None, the best classifier is the one with
    the highest top-1 value; otherwise it is the classifier with that name.
    On the main process the result is appended to `metrics_file_path`.

    Returns:
        A dict whose "best_classifier" entry holds the name and accuracy of
        the selected classifier.
    """
    logger.info("running validation !")

    if class_mapping is not None:
        num_classes = len(class_mapping)
    else:
        num_classes = training_num_classes
    metric = build_metric(metric_type, num_classes=num_classes)

    # One postprocessor and one independent metric copy per classifier.
    postprocessors = {}
    for name, classifier in linear_classifiers.classifiers_dict.items():
        postprocessors[name] = LinearPostprocessor(classifier, class_mapping)
    metrics = {name: metric.clone() for name in linear_classifiers.classifiers_dict}

    _, results_dict_temp = evaluate(
        feature_model,
        data_loader,
        postprocessors,
        metrics,
        torch.cuda.current_device(),
    )

    logger.info("")
    results_dict = {}
    max_accuracy = 0
    best_classifier = ""
    for classifier_string, metric in results_dict_temp.items():
        logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}")
        if best_classifier_on_val is None:
            is_selected = metric["top-1"].item() > max_accuracy
        else:
            is_selected = classifier_string == best_classifier_on_val
        if is_selected:
            max_accuracy = metric["top-1"].item()
            best_classifier = classifier_string

    results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy}

    logger.info(f"best classifier: {results_dict['best_classifier']}")

    if distributed.is_main_process():
        with open(metrics_file_path, "a") as f:
            f.write(f"iter: {iteration}\n")
            for key, value in results_dict.items():
                f.write(json.dumps({key: value}) + "\n")
            f.write("\n")

    return results_dict
|
||||
|
||||
|
||||
def eval_linear(
    *,
    feature_model,
    linear_classifiers,
    train_data_loader,
    val_data_loader,
    metrics_file_path,
    optimizer,
    scheduler,
    output_dir,
    max_iter,
    checkpoint_period,  # In number of iter, creates a new file every period
    running_checkpoint_period,  # Period to update main checkpoint file
    eval_period,
    metric_type,
    training_num_classes,
    resume=True,
    classifier_fpath=None,
    val_class_mapping=None,
):
    """Train the linear classifiers on top of ``feature_model`` and validate them.

    Training resumes from the iteration stored in the checkpoint returned by
    ``Checkpointer.resume_or_load`` (or from 0 when none is stored). Each step
    sums one cross-entropy loss per classifier output, then steps the
    optimizer and the scheduler. Validation runs every ``eval_period``
    iterations (except on the last one) and once more after the loop.

    Returns:
        ``(val_results_dict, feature_model, linear_classifiers, iteration)``
        where ``val_results_dict`` comes from the final validation.
    """
    checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler)
    resume_state = checkpointer.resume_or_load(classifier_fpath or "", resume=resume)
    start_iter = resume_state.get("iteration", -1) + 1

    periodic_checkpointer = PeriodicCheckpointer(checkpointer, checkpoint_period, max_iter=max_iter)
    logger.info("Starting training from iteration {}".format(start_iter))
    metric_logger = MetricLogger(delimiter=" ")
    criterion = nn.CrossEntropyLoss()

    def _validate(current_iteration, prefixstring=""):
        # Validation always runs on the classifiers without their DDP wrapper.
        return evaluate_linear_classifiers(
            feature_model=feature_model,
            linear_classifiers=remove_ddp_wrapper(linear_classifiers),
            data_loader=val_data_loader,
            metrics_file_path=metrics_file_path,
            prefixstring=prefixstring,
            metric_type=metric_type,
            training_num_classes=training_num_classes,
            iteration=current_iteration,
            class_mapping=val_class_mapping,
        )

    iteration = start_iter
    batches = metric_logger.log_every(train_data_loader, 10, "Training", max_iter, start_iter)
    for data, labels in batches:
        data = data.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        outputs = linear_classifiers(feature_model(data))

        # One loss per classifier; the optimized objective is their sum.
        losses = {f"loss_{name}": criterion(logits, labels) for name, logits in outputs.items()}
        loss = sum(losses.values())

        # compute the gradients
        optimizer.zero_grad()
        loss.backward()

        # step
        optimizer.step()
        scheduler.step()

        # log
        if iteration % 10 == 0:
            torch.cuda.synchronize()
            current_lr = optimizer.param_groups[0]["lr"]
            metric_logger.update(loss=loss.item())
            metric_logger.update(lr=current_lr)
            print("lr", current_lr)

        # Skip the running checkpoint for the first iterations after (re)start.
        if iteration - start_iter > 5 and iteration % running_checkpoint_period == 0:
            torch.cuda.synchronize()
            if distributed.is_main_process():
                logger.info("Checkpointing running_checkpoint")
                periodic_checkpointer.save("running_checkpoint_linear_eval", iteration=iteration)
            torch.cuda.synchronize()
        periodic_checkpointer.step(iteration)

        is_last_iteration = iteration == max_iter - 1
        if eval_period > 0 and (iteration + 1) % eval_period == 0 and not is_last_iteration:
            _validate(iteration, prefixstring=f"ITER: {iteration}")
            torch.cuda.synchronize()

        iteration = iteration + 1

    val_results_dict = _validate(iteration)
    return val_results_dict, feature_model, linear_classifiers, iteration
|
||||
|
||||
|
||||
def make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type):
    """Build a non-shuffled, distributed data loader for an evaluation dataset.

    The dataset is created from ``test_dataset_str`` with the classification
    eval transform. ``_pad_and_collate`` is used as collate function only for
    ``MetricType.IMAGENET_REAL_ACCURACY``; otherwise no collate function is
    passed.
    """
    eval_dataset = make_dataset(
        dataset_str=test_dataset_str,
        transform=make_classification_eval_transform(),
    )
    if metric_type == MetricType.IMAGENET_REAL_ACCURACY:
        collate_fn = _pad_and_collate
    else:
        collate_fn = None
    return make_data_loader(
        dataset=eval_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
        persistent_workers=False,
        collate_fn=collate_fn,
    )
|
||||
|
||||
|
||||
def test_on_datasets(
    feature_model,
    linear_classifiers,
    test_dataset_strs,
    batch_size,
    num_workers,
    test_metric_types,
    metrics_file_path,
    training_num_classes,
    iteration,
    best_classifier_on_val,
    prefixstring="",
    test_class_mappings=None,
):
    """Evaluate the selected classifier on each test dataset.

    Args:
        test_dataset_strs: dataset strings, one per test dataset.
        test_metric_types: metric type for each test dataset.
        best_classifier_on_val: name of the classifier whose accuracy is
            reported (forwarded to ``evaluate_linear_classifiers``).
        prefixstring: prefix forwarded to the per-classifier log lines.
        test_class_mappings: optional class mapping per test dataset. When
            omitted (None), no class mapping is used for any dataset. An
            explicitly passed sequence is zipped with the datasets, so the
            shortest sequence decides how many datasets are evaluated.

    Returns:
        dict mapping ``"<dataset_str>_accuracy"`` to the accuracy scaled by 100.
    """
    if test_class_mappings is None:
        # Avoid a shared mutable default and cover every dataset: the former
        # ``[None]`` default silently restricted the zip below to one dataset.
        test_class_mappings = [None] * len(test_dataset_strs)

    results_dict = {}
    for test_dataset_str, class_mapping, metric_type in zip(test_dataset_strs, test_class_mappings, test_metric_types):
        logger.info(f"Testing on {test_dataset_str}")
        test_data_loader = make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type)
        dataset_results_dict = evaluate_linear_classifiers(
            feature_model,
            remove_ddp_wrapper(linear_classifiers),
            test_data_loader,
            metric_type,
            metrics_file_path,
            training_num_classes,
            iteration,
            # Forward the caller's prefix; it used to be dropped in favor of "".
            prefixstring=prefixstring,
            class_mapping=class_mapping,
            best_classifier_on_val=best_classifier_on_val,
        )
        results_dict[f"{test_dataset_str}_accuracy"] = 100.0 * dataset_results_dict["best_classifier"]["accuracy"]
    return results_dict
|
||||
|
||||
|
||||
def run_eval_linear(
|
||||
model,
|
||||
output_dir,
|
||||
train_dataset_str,
|
||||
val_dataset_str,
|
||||
batch_size,
|
||||
epochs,
|
||||
epoch_length,
|
||||
num_workers,
|
||||
save_checkpoint_frequency,
|
||||
eval_period_iterations,
|
||||
learning_rates,
|
||||
autocast_dtype,
|
||||
test_dataset_strs=None,
|
||||
resume=True,
|
||||
classifier_fpath=None,
|
||||
val_class_mapping_fpath=None,
|
||||
test_class_mapping_fpaths=[None],
|
||||
val_metric_type=MetricType.MEAN_ACCURACY,
|
||||
test_metric_types=None,
|
||||
):
|
||||
seed = 0
|
||||
|
||||
if test_dataset_strs is None:
|
||||
test_dataset_strs = [val_dataset_str]
|
||||
if test_metric_types is None:
|
||||
test_metric_types = [val_metric_type] * len(test_dataset_strs)
|
||||
else:
|
||||
assert len(test_metric_types) == len(test_dataset_strs)
|
||||
assert len(test_dataset_strs) == len(test_class_mapping_fpaths)
|
||||
|
||||
train_transform = make_classification_train_transform()
|
||||
train_dataset = make_dataset(
|
||||
dataset_str=train_dataset_str,
|
||||
transform=train_transform,
|
||||
)
|
||||
training_num_classes = len(torch.unique(torch.Tensor(train_dataset.get_targets().astype(int))))
|
||||
sampler_type = SamplerType.SHARDED_INFINITE
|
||||
# sampler_type = SamplerType.INFINITE
|
||||
|
||||
n_last_blocks_list = [1, 4]
|
||||
n_last_blocks = max(n_last_blocks_list)
|
||||
autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
|
||||
feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx)
|
||||
sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda())
|
||||
|
||||
linear_classifiers, optim_param_groups = setup_linear_classifiers(
|
||||
sample_output,
|
||||
n_last_blocks_list,
|
||||
learning_rates,
|
||||
batch_size,
|
||||
training_num_classes,
|
||||
)
|
||||
|
||||
optimizer = torch.optim.SGD(optim_param_groups, momentum=0.9, weight_decay=0)
|
||||
max_iter = epochs * epoch_length
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)
|
||||
checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler)
|
||||
start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1
|
||||
train_data_loader = make_data_loader(
|
||||
dataset=train_dataset,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
shuffle=True,
|
||||
seed=seed,
|
||||
sampler_type=sampler_type,
|
||||
sampler_advance=start_iter,
|
||||
drop_last=True,
|
||||
persistent_workers=True,
|
||||
)
|
||||
val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type)
|
||||
|
||||
checkpoint_period = save_checkpoint_frequency * epoch_length
|
||||
|
||||
if val_class_mapping_fpath is not None:
|
||||
logger.info(f"Using class mapping from {val_class_mapping_fpath}")
|
||||
val_class_mapping = np.load(val_class_mapping_fpath)
|
||||
else:
|
||||
val_class_mapping = None
|
||||
|
||||
test_class_mappings = []
|
||||
for class_mapping_fpath in test_class_mapping_fpaths:
|
||||
if class_mapping_fpath is not None and class_mapping_fpath != "None":
|
||||
logger.info(f"Using class mapping from {class_mapping_fpath}")
|
||||
class_mapping = np.load(class_mapping_fpath)
|
||||
else:
|
||||
class_mapping = None
|
||||
test_class_mappings.append(class_mapping)
|
||||
|
||||
metrics_file_path = os.path.join(output_dir, "results_eval_linear.json")
|
||||
val_results_dict, feature_model, linear_classifiers, iteration = eval_linear(
|
||||
feature_model=feature_model,
|
||||
linear_classifiers=linear_classifiers,
|
||||
train_data_loader=train_data_loader,
|
||||
val_data_loader=val_data_loader,
|
||||
metrics_file_path=metrics_file_path,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
output_dir=output_dir,
|
||||
max_iter=max_iter,
|
||||
checkpoint_period=checkpoint_period,
|
||||
running_checkpoint_period=epoch_length,
|
||||
eval_period=eval_period_iterations,
|
||||
metric_type=val_metric_type,
|
||||
training_num_classes=training_num_classes,
|
||||
resume=resume,
|
||||
val_class_mapping=val_class_mapping,
|
||||
classifier_fpath=classifier_fpath,
|
||||
)
|
||||
results_dict = {}
|
||||
if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str:
|
||||
results_dict = test_on_datasets(
|
||||
feature_model,
|
||||
linear_classifiers,
|
||||
test_dataset_strs,
|
||||
batch_size,
|
||||
0, # num_workers,
|
||||
test_metric_types,
|
||||
metrics_file_path,
|
||||
training_num_classes,
|
||||
iteration,
|
||||
val_results_dict["best_classifier"]["name"],
|
||||
prefixstring="",
|
||||
test_class_mappings=test_class_mappings,
|
||||
)
|
||||
results_dict["best_classifier"] = val_results_dict["best_classifier"]["name"]
|
||||
results_dict[f"{val_dataset_str}_accuracy"] = 100.0 * val_results_dict["best_classifier"]["accuracy"]
|
||||
logger.info("Test Results Dict " + str(results_dict))
|
||||
|
||||
return results_dict
|
||||
|
||||
|
||||
def main(args):
    """Build the model from ``args`` and run the linear evaluation; return exit code 0."""
    model, autocast_dtype = setup_and_build_model(args)
    eval_kwargs = {
        "model": model,
        "output_dir": args.output_dir,
        "train_dataset_str": args.train_dataset_str,
        "val_dataset_str": args.val_dataset_str,
        "test_dataset_strs": args.test_dataset_strs,
        "batch_size": args.batch_size,
        "epochs": args.epochs,
        "epoch_length": args.epoch_length,
        "num_workers": args.num_workers,
        "save_checkpoint_frequency": args.save_checkpoint_frequency,
        "eval_period_iterations": args.eval_period_iterations,
        "learning_rates": args.learning_rates,
        "autocast_dtype": autocast_dtype,
        # The CLI exposes the negated flag.
        "resume": not args.no_resume,
        "classifier_fpath": args.classifier_fpath,
        "val_metric_type": args.val_metric_type,
        "test_metric_types": args.test_metric_types,
        "val_class_mapping_fpath": args.val_class_mapping_fpath,
        "test_class_mapping_fpaths": args.test_class_mapping_fpaths,
    }
    run_eval_linear(**eval_kwargs)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the command line and use main()'s return
    # value as the process exit status.
    description = "DINOv2 linear evaluation"
    args_parser = get_args_parser(description=description)
    args = args_parser.parse_args()
    sys.exit(main(args))
|
||||
@@ -0,0 +1,445 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
from cuml.linear_model import LogisticRegression
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.distributed
|
||||
from torch import nn
|
||||
from torch.utils.data import TensorDataset
|
||||
from torchmetrics import MetricTracker
|
||||
|
||||
from dinov2.data import make_dataset
|
||||
from dinov2.data.transforms import make_classification_eval_transform
|
||||
from dinov2.distributed import get_global_rank, get_global_size
|
||||
from dinov2.eval.metrics import MetricType, build_metric
|
||||
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
|
||||
from dinov2.eval.setup import setup_and_build_model
|
||||
from dinov2.eval.utils import evaluate, extract_features
|
||||
from dinov2.utils.dtype import as_torch_dtype
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")

# Default iteration cap passed to the logistic regression estimator.
DEFAULT_MAX_ITER = 1_000
# 45 evenly spaced exponents in [-6, 5]; the sweep tries C = 10 ** exponent.
C_POWER_RANGE = torch.linspace(-6, 5, 45)
# Features on this device are converted to numpy arrays before fit/predict.
_CPU_DEVICE = torch.device("cpu")
|
||||
|
||||
|
||||
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Create the argument parser for the logistic regression evaluation.

    The parser inherits from the setup parser (itself built on ``parents``)
    and adds the dataset, metric and training options of this evaluation.
    """
    setup_args_parser = get_setup_args_parser(parents=parents or [], add_help=False)
    parser = argparse.ArgumentParser(
        description=description,
        parents=[setup_args_parser],
        add_help=add_help,
    )

    # (flag, add_argument keyword options), registered in this order.
    argument_specs = [
        ("--train-dataset", {"dest": "train_dataset_str", "type": str, "help": "Training dataset"}),
        ("--val-dataset", {"dest": "val_dataset_str", "type": str, "help": "Validation dataset"}),
        (
            "--finetune-dataset-str",
            {"dest": "finetune_dataset_str", "type": str, "help": "Fine-tuning dataset"},
        ),
        (
            "--finetune-on-val",
            {
                "action": "store_true",
                "help": "If there is no finetune dataset, whether to choose the "
                "hyperparameters on the val set instead of 10%% of the train dataset",
            },
        ),
        ("--metric-type", {"type": MetricType, "choices": list(MetricType), "help": "Metric type"}),
        (
            "--train-features-device",
            {
                "type": str,
                "help": "Device to gather train features (cpu, cuda, cuda:0, etc.), default: %(default)s",
            },
        ),
        (
            "--train-dtype",
            {"type": str, "help": "Data type to convert the train features to (default: %(default)s)"},
        ),
        (
            "--max-train-iters",
            {"type": int, "help": "Maximum number of train iterations (default: %(default)s)"},
        ),
    ]
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)

    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        finetune_dataset_str=None,
        metric_type=MetricType.MEAN_ACCURACY,
        train_features_device="cpu",
        train_dtype="float64",
        max_train_iters=DEFAULT_MAX_ITER,
        finetune_on_val=False,
    )
    return parser
|
||||
|
||||
|
||||
class LogRegModule(nn.Module):
    """``nn.Module`` wrapper around an L2-penalized ``LogisticRegression`` estimator.

    Inputs are converted to ``dtype`` / ``device`` before being handed to the
    estimator; on the CPU device they are additionally turned into numpy
    arrays.
    """

    def __init__(
        self,
        C,
        max_iter=DEFAULT_MAX_ITER,
        dtype=torch.float64,
        device=_CPU_DEVICE,
    ):
        super().__init__()
        self.dtype = dtype
        self.device = device
        estimator_options = {
            "penalty": "l2",
            "C": C,
            "max_iter": max_iter,
            "output_type": "numpy",
            "tol": 1e-12,
            "linesearch_max_iter": 50,
        }
        self.estimator = LogisticRegression(**estimator_options)

    def _as_estimator_input(self, tensor):
        """Cast ``tensor`` to the module's dtype/device, as numpy when on CPU."""
        converted = tensor.to(dtype=self.dtype, device=self.device)
        if self.device == _CPU_DEVICE:
            # both cuML and sklearn only work with numpy arrays on CPU
            converted = converted.numpy()
        return converted

    def forward(self, samples, targets):
        """Return predicted probabilities (on the samples' device) with the untouched targets."""
        source_device = samples.device
        probas = self.estimator.predict_proba(self._as_estimator_input(samples))
        return {"preds": torch.from_numpy(probas).to(source_device), "target": targets}

    def fit(self, train_features, train_labels):
        """Fit the estimator on the converted features and labels."""
        features = self._as_estimator_input(train_features)
        labels = self._as_estimator_input(train_labels)
        self.estimator.fit(features, labels)
|
||||
|
||||
|
||||
def evaluate_model(*, logreg_model, logreg_metric, test_data_loader, device):
    """Score ``logreg_model`` with ``logreg_metric`` on ``test_data_loader``.

    The loader's samples go through an identity model, so the logistic
    regression (registered as postprocessor under the ``"metrics"`` key)
    receives them unchanged.
    """
    key = "metrics"
    return evaluate(
        nn.Identity(),
        test_data_loader,
        {key: logreg_model},
        {key: logreg_metric},
        device,
    )
|
||||
|
||||
|
||||
def train_for_C(*, C, max_iter, train_features, train_labels, dtype=torch.float64, device=_CPU_DEVICE):
    """Create a ``LogRegModule`` for regularization value ``C``, fit it and return it."""
    fitted_model = LogRegModule(C, max_iter=max_iter, dtype=dtype, device=device)
    fitted_model.fit(train_features, train_labels)
    return fitted_model
|
||||
|
||||
|
||||
def train_and_evaluate(
    *,
    C,
    max_iter,
    train_features,
    train_labels,
    logreg_metric,
    test_data_loader,
    train_dtype=torch.float64,
    train_features_device,
    eval_device,
):
    """Fit a logistic regression for ``C`` and return its evaluation on ``test_data_loader``."""
    fit_options = {
        "C": C,
        "max_iter": max_iter,
        "train_features": train_features,
        "train_labels": train_labels,
        "dtype": train_dtype,
        "device": train_features_device,
    }
    fitted_model = train_for_C(**fit_options)
    return evaluate_model(
        logreg_model=fitted_model,
        logreg_metric=logreg_metric,
        test_data_loader=test_data_loader,
        device=eval_device,
    )
|
||||
|
||||
|
||||
def sweep_C_values(
    *,
    train_features,
    train_labels,
    test_data_loader,
    metric_type,
    num_classes,
    train_dtype=torch.float64,
    train_features_device=_CPU_DEVICE,
    max_train_iters=DEFAULT_MAX_ITER,
):
    """Train one logistic regression per candidate ``C`` and pick the best.

    The candidates are ``10 ** C_POWER_RANGE``. Each rank trains the
    candidates at positions ``rank, rank + world_size, ...``; the fitted
    models are then exchanged with ``all_gather_object`` and every rank
    evaluates all of them on ``test_data_loader`` through a ``MetricTracker``.

    Returns:
        ``(best_stats, best_C)``: the tracker's best metric values and the
        ``C`` of the step holding the best ``"top-1"``.
    """
    if metric_type == MetricType.PER_CLASS_ACCURACY:
        # If we want to output per-class accuracy, we select the hyperparameters with mean per class
        metric_type = MetricType.MEAN_PER_CLASS_ACCURACY
    metric_tracker = MetricTracker(build_metric(metric_type, num_classes=num_classes), maximize=True)
    c_values = (10**C_POWER_RANGE).tolist()

    train_features = train_features.to(dtype=train_dtype, device=train_features_device)
    train_labels = train_labels.to(device=train_features_device)

    # Train this rank's share of the candidates.
    local_models = {}
    for C in c_values[get_global_rank() :: get_global_size()]:
        logger.info(
            f"Training for C = {C:.5f}, dtype={train_dtype}, "
            f"features: {train_features.shape}, {train_features.dtype}, "
            f"labels: {train_labels.shape}, {train_labels.dtype}"
        )
        local_models[C] = train_for_C(
            C=C,
            max_iter=max_train_iters,
            train_features=train_features,
            train_labels=train_labels,
            dtype=train_dtype,
            device=train_features_device,
        )

    # Exchange the fitted models so that every rank holds all of them.
    per_rank_models = [None] * get_global_size()
    torch.distributed.all_gather_object(per_rank_models, local_models)
    all_models = {}
    for rank_models in per_rank_models:
        all_models.update(rank_models)

    for step, C in enumerate(c_values):
        # One tracker step per candidate, so the best step index maps back to C.
        metric_tracker.increment()
        evals = evaluate_model(
            logreg_model=all_models[C],
            logreg_metric=metric_tracker,
            test_data_loader=test_data_loader,
            device=torch.cuda.current_device(),
        )
        logger.info(f"Trained for C = {C:.5f}, accuracies = {evals}")

        best_stats, which_epoch = metric_tracker.best_metric(return_step=True)
        best_stats_100 = {name: 100.0 * value for name, value in best_stats.items()}
        if which_epoch["top-1"] == step:
            best_C = C
        logger.info(f"Sweep best {best_stats_100}, best C = {best_C:.6f}")

    return best_stats, best_C
|
||||
|
||||
|
||||
def eval_log_regression(
    *,
    model,
    train_dataset,
    val_dataset,
    finetune_dataset,
    metric_type,
    batch_size,
    num_workers,
    finetune_on_val=False,
    train_dtype=torch.float64,
    train_features_device=_CPU_DEVICE,
    max_train_iters=DEFAULT_MAX_ITER,
):
    """
    Implements the "standard" process for log regression evaluation:
    The value of C is chosen by training on train_dataset and evaluating on
    finetune_dataset. Then, the final model is trained on a concatenation of
    train_dataset and finetune_dataset, and is evaluated on val_dataset.
    If there is no finetune_dataset, the value of C is the one that yields
    the best results on a random 10% subset of the train dataset
    (or on the val dataset when finetune_on_val is set, in which case the
    features are not concatenated for the final training).

    Returns the metrics of the final model on val_dataset, with the selected
    C stored under the "best_C" key.
    """
    start = time.time()
    gather_on_cpu = train_features_device == _CPU_DEVICE

    train_features, train_labels = extract_features(
        model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
    )
    val_features, val_labels = extract_features(
        model, val_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
    )
    val_data_loader = torch.utils.data.DataLoader(
        TensorDataset(val_features, val_labels),
        batch_size=batch_size,
        drop_last=False,
        num_workers=0,
        persistent_workers=False,
    )

    if finetune_dataset is not None:
        logger.info("Choosing hyperparameters on the finetune dataset")
        finetune_features, finetune_labels = extract_features(
            model, finetune_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
        )
    elif finetune_on_val:
        logger.info("Choosing hyperparameters on the val dataset")
        finetune_features, finetune_labels = val_features, val_labels
    else:
        logger.info("Choosing hyperparameters on 10% of the train dataset")
        # Fixed seed: the held-out split is the same on every run.
        torch.manual_seed(0)
        shuffled = torch.randperm(len(train_features), device=train_features.device)
        holdout_size = len(train_features) // 10
        finetune_index, train_index = shuffled[:holdout_size], shuffled[holdout_size:]
        finetune_features, finetune_labels = train_features[finetune_index], train_labels[finetune_index]
        train_features, train_labels = train_features[train_index], train_labels[train_index]

    # release the model - free GPU memory
    del model
    gc.collect()
    torch.cuda.empty_cache()
    finetune_data_loader = torch.utils.data.DataLoader(
        TensorDataset(finetune_features, finetune_labels),
        batch_size=batch_size,
        drop_last=False,
    )

    # Labels with more than one dimension carry one column per class;
    # otherwise the class count is derived from the largest label.
    if len(train_labels.shape) > 1:
        num_classes = train_labels.shape[1]
    else:
        num_classes = train_labels.max() + 1

    logger.info("Using cuML for logistic regression")

    best_stats, best_C = sweep_C_values(
        train_features=train_features,
        train_labels=train_labels,
        test_data_loader=finetune_data_loader,
        metric_type=metric_type,
        num_classes=num_classes,
        train_dtype=train_dtype,
        train_features_device=train_features_device,
        max_train_iters=max_train_iters,
    )

    if not finetune_on_val:
        logger.info("Best parameter found, concatenating features")
        train_features = torch.cat((train_features, finetune_features))
        train_labels = torch.cat((train_labels, finetune_labels))

    logger.info("Training final model")
    final_metric = build_metric(metric_type, num_classes=num_classes)
    evals = train_and_evaluate(
        C=best_C,
        max_iter=max_train_iters,
        train_features=train_features,
        train_labels=train_labels,
        logreg_metric=final_metric.clone(),
        test_data_loader=val_data_loader,
        eval_device=torch.cuda.current_device(),
        train_dtype=train_dtype,
        train_features_device=train_features_device,
    )

    # The sweep statistics are replaced by the final model's validation metrics.
    best_stats = evals[1]["metrics"]
    best_stats["best_C"] = best_C

    logger.info(f"Log regression evaluation done in {int(time.time() - start)}s")
    return best_stats
|
||||
|
||||
|
||||
def eval_log_regression_with_model(
    model,
    train_dataset_str="ImageNet:split=TRAIN",
    val_dataset_str="ImageNet:split=VAL",
    finetune_dataset_str=None,
    autocast_dtype=torch.float,
    finetune_on_val=False,
    metric_type=MetricType.MEAN_ACCURACY,
    train_dtype=torch.float64,
    train_features_device=_CPU_DEVICE,
    max_train_iters=DEFAULT_MAX_ITER,
):
    """Build the datasets, run the logistic regression evaluation and log the result.

    All datasets use the classification eval transform with ``resize_size=224``
    and the evaluation runs under CUDA autocast with ``autocast_dtype``.

    Returns:
        dict with ``"top-1"`` and ``"top-5"`` (scaled by 100; top-5 is 0 when
        the metric does not provide it) and ``"best_C"``.
    """
    cudnn.benchmark = True

    transform = make_classification_eval_transform(resize_size=224)

    def _build_dataset(dataset_str):
        return make_dataset(dataset_str=dataset_str, transform=transform, target_transform=None)

    train_dataset = _build_dataset(train_dataset_str)
    val_dataset = _build_dataset(val_dataset_str)
    finetune_dataset = None
    if finetune_dataset_str is not None:
        finetune_dataset = _build_dataset(finetune_dataset_str)

    with torch.cuda.amp.autocast(dtype=autocast_dtype):
        results_dict_logreg = eval_log_regression(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            finetune_dataset=finetune_dataset,
            metric_type=metric_type,
            batch_size=256,
            num_workers=0,  # 5,
            finetune_on_val=finetune_on_val,
            train_dtype=train_dtype,
            train_features_device=train_features_device,
            max_train_iters=max_train_iters,
        )

    top1 = results_dict_logreg["top-1"].cpu().numpy() * 100.0
    top5 = results_dict_logreg.get("top-5", torch.tensor(0.0)).cpu().numpy() * 100.0
    results_dict = {
        "top-1": top1,
        "top-5": top5,
        "best_C": results_dict_logreg["best_C"],
    }

    summary_lines = [
        "Training of the supervised logistic regression on frozen features completed.\n"
        "Top-1 test accuracy: {acc:.1f}".format(acc=results_dict["top-1"]),
        "Top-5 test accuracy: {acc:.1f}".format(acc=results_dict["top-5"]),
        "obtained for C = {c:.6f}".format(c=results_dict["best_C"]),
    ]
    logger.info("\n".join(summary_lines))

    torch.distributed.barrier()
    return results_dict
|
||||
|
||||
|
||||
def main(args):
    """Build the model from ``args`` and run the logistic regression evaluation; return exit code 0."""
    model, autocast_dtype = setup_and_build_model(args)
    eval_kwargs = {
        "model": model,
        "train_dataset_str": args.train_dataset_str,
        "val_dataset_str": args.val_dataset_str,
        "finetune_dataset_str": args.finetune_dataset_str,
        "autocast_dtype": autocast_dtype,
        "finetune_on_val": args.finetune_on_val,
        "metric_type": args.metric_type,
        # The CLI carries the dtype and device as strings.
        "train_dtype": as_torch_dtype(args.train_dtype),
        "train_features_device": torch.device(args.train_features_device),
        "max_train_iters": args.max_train_iters,
    }
    eval_log_regression_with_model(**eval_kwargs)
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the command line and use main()'s return
    # value as the process exit status.
    description = "DINOv2 logistic regression evaluation"
    args_parser = get_args_parser(description=description)
    args = args_parser.parse_args()
    sys.exit(main(args))
|
||||
@@ -0,0 +1,114 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torchmetrics import Metric, MetricCollection
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
from torchmetrics.utilities.data import dim_zero_cat, select_topk
|
||||
|
||||
|
||||
# Module logger, registered under the "dinov2" name.
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class MetricType(Enum):
    """Metric kinds accepted by ``build_metric``; ``str()`` yields the raw value."""

    MEAN_ACCURACY = "mean_accuracy"
    MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy"
    PER_CLASS_ACCURACY = "per_class_accuracy"
    IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy"

    @property
    def accuracy_averaging(self):
        """``AccuracyAveraging`` member with the same name, or None when there is none."""
        try:
            return AccuracyAveraging[self.name]
        except KeyError:
            return None

    def __str__(self):
        return self.value
|
||||
|
||||
|
||||
class AccuracyAveraging(Enum):
    """Averaging modes passed as ``average=`` to ``MulticlassAccuracy``.

    Member names mirror the ``MetricType`` members they apply to.
    """

    MEAN_ACCURACY = "micro"
    MEAN_PER_CLASS_ACCURACY = "macro"
    PER_CLASS_ACCURACY = "none"

    def __str__(self):
        # Render as the raw averaging string rather than "AccuracyAveraging.X".
        return self.value
|
||||
|
||||
|
||||
def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None):
    """Build the top-k metric collection for ``metric_type``.

    ``ks`` defaults to ``(1, 5)``. Metric types that have an accuracy
    averaging mode get top-k accuracies with that averaging;
    ``IMAGENET_REAL_ACCURACY`` gets the ImageNet-ReaL accuracies.

    Raises:
        ValueError: for any other metric type.
    """
    top_ks = (1, 5) if ks is None else ks
    averaging = metric_type.accuracy_averaging

    if averaging is not None:
        return build_topk_accuracy_metric(
            average_type=averaging,
            num_classes=num_classes,
            ks=top_ks,
        )
    if metric_type == MetricType.IMAGENET_REAL_ACCURACY:
        return build_topk_imagenet_real_accuracy_metric(
            num_classes=num_classes,
            ks=top_ks,
        )

    raise ValueError(f"Unknown metric type {metric_type}")
|
||||
|
||||
|
||||
def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)):
    """Return a ``MetricCollection`` with one ``MulticlassAccuracy`` per k, keyed ``"top-<k>"``."""
    metrics: Dict[str, Metric] = {}
    for k in ks:
        metrics[f"top-{k}"] = MulticlassAccuracy(
            top_k=k,
            num_classes=int(num_classes),
            average=average_type.value,
        )
    return MetricCollection(metrics)
|
||||
|
||||
|
||||
def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)):
    """Return a ``MetricCollection`` with one ``ImageNetReaLAccuracy`` per k, keyed ``"top-<k>"``."""
    metrics: Dict[str, Metric] = {}
    for k in ks:
        metrics[f"top-{k}"] = ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes))
    return MetricCollection(metrics)
|
||||
|
||||
|
||||
class ImageNetReaLAccuracy(Metric):
    """Top-k accuracy where each sample may have several valid targets.

    A sample counts as correct when at least one of its top-k predictions
    matches one of its targets. Target entries equal to -1 are treated as
    undefined, and samples without any defined target are ignored.
    """

    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False

    def __init__(
        self,
        num_classes: int,
        top_k: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.top_k = top_k
        # Per-batch hit vectors, concatenated when reduced across processes.
        self.add_state("tp", [], dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Accumulate hits for one batch; ``target`` is left unmodified."""
        # preds [B, D]
        # target [B, A]
        # preds_oh [B, D] with 0 and 1
        # select top K highest probabilities, use one hot representation
        preds_oh = select_topk(preds, self.top_k)
        # target_oh [B, D + 1] with 0 and 1
        target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32)
        # Force a copy: ``target.long()`` hands back the very same tensor when
        # it is already int64, and the masked assignment below would then
        # overwrite the caller's targets in place.
        target = target.to(dtype=torch.long, copy=True)
        # for undefined targets (-1) use a fake value `num_classes`
        # NOTE(review): the fake column is index D (preds' class dimension);
        # this assumes num_classes == D - confirm against callers.
        target[target == -1] = self.num_classes
        # fill targets, use one hot representation
        target_oh.scatter_(1, target, 1)
        # target_oh [B, D] (remove the fake target at index `num_classes`)
        target_oh = target_oh[:, :-1]
        # tp [B] with 0 and 1
        tp = (preds_oh * target_oh == 1).sum(dim=1)
        # at least one match between prediction and target
        tp.clip_(max=1)
        # ignore instances where no targets are defined
        mask = target_oh.sum(dim=1) > 0
        tp = tp[mask]
        self.tp.append(tp)  # type: ignore

    def compute(self) -> Tensor:
        """Return the fraction of counted samples with at least one hit."""
        tp = dim_zero_cat(self.tp)  # type: ignore
        return tp.float().mean()
|
||||
@@ -0,0 +1,76 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
||||
from dinov2.models import build_model_from_cfg
|
||||
from dinov2.utils.config import setup
|
||||
import dinov2.utils.utils as dinov2_utils
|
||||
|
||||
|
||||
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Build the command-line parser shared by the evaluation entry points.

    Args:
        description: Text shown in the parser's help output.
        parents: Parsers whose arguments are inherited; ``None`` or an empty
            list means no parents.
        add_help: Whether the parser gets its own ``-h/--help`` option.

    Returns:
        An ``argparse.ArgumentParser`` exposing ``--config-file``,
        ``--pretrained-weights``, ``--output-dir`` and ``--opts``.
    """
    parent_parsers = parents if parents else []
    parser = argparse.ArgumentParser(description=description, parents=parent_parsers, add_help=add_help)

    # The table is rebuilt on every call so each parser owns a fresh default
    # list for --opts.
    option_specs = (
        ("--config-file", {"type": str, "help": "Model configuration file"}),
        ("--pretrained-weights", {"type": str, "help": "Pretrained model weights"}),
        ("--output-dir", {"default": "", "type": str, "help": "Output directory to write results and logs"}),
        ("--opts", {"help": "Extra configuration options", "default": [], "nargs": "+"}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser
|
||||
|
||||
|
||||
def get_autocast_dtype(config):
    """Translate the teacher backbone's ``param_dtype`` setting to a torch dtype.

    ``"fp16"`` maps to ``torch.half``, ``"bf16"`` to ``torch.bfloat16``, and
    any other value to ``torch.float``.
    """
    precision_name = config.compute_precision.teacher.backbone.mixed_precision.param_dtype
    if precision_name == "fp16":
        return torch.half
    if precision_name == "bf16":
        return torch.bfloat16
    return torch.float
|
||||
|
||||
|
||||
def build_model_for_eval(config, pretrained_weights):
    """Build the teacher-only model, load its weights and prepare it for inference.

    The model is built with ``only_teacher=True``, weights are loaded under
    the ``"teacher"`` checkpoint key, and the model is switched to eval mode
    and moved to CUDA before being returned.
    """
    teacher, _unused = build_model_from_cfg(config, only_teacher=True)
    dinov2_utils.load_pretrained_weights(teacher, pretrained_weights, "teacher")
    teacher.eval()
    teacher.cuda()
    return teacher
|
||||
|
||||
|
||||
def setup_and_build_model(args) -> Tuple[Any, torch.dtype]:
    """Create the config from ``args`` and return ``(model, autocast_dtype)``.

    Also enables the cuDNN benchmark flag before anything else is built.
    """
    cudnn.benchmark = True
    cfg = setup(args)
    eval_model = build_model_for_eval(cfg, args.pretrained_weights)
    return eval_model, get_autocast_dtype(cfg)
|
||||
@@ -0,0 +1,147 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchmetrics import MetricCollection
|
||||
|
||||
from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader
|
||||
import dinov2.distributed as distributed
|
||||
from dinov2.logging import MetricLogger
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class ModelWithNormalize(torch.nn.Module):
    """Wrap a model so that its output is L2-normalized along dimension 1."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, samples):
        embeddings = self.model(samples)
        return nn.functional.normalize(embeddings, dim=1, p=2)
|
||||
|
||||
|
||||
class ModelWithIntermediateLayers(nn.Module):
    """Expose the last ``n_last_blocks`` intermediate outputs of a feature model.

    The wrapped model is put in eval mode at construction time, and every
    forward pass runs under ``torch.inference_mode()`` plus the context
    produced by calling ``autocast_ctx()``.
    """

    def __init__(self, feature_model, n_last_blocks, autocast_ctx):
        super().__init__()
        self.feature_model = feature_model
        self.feature_model.eval()
        self.n_last_blocks = n_last_blocks
        # Stored as a callable: a fresh context manager is created per forward.
        self.autocast_ctx = autocast_ctx

    def forward(self, images):
        with torch.inference_mode(), self.autocast_ctx():
            layer_outputs = self.feature_model.get_intermediate_layers(
                images, self.n_last_blocks, return_class_token=True
            )
        return layer_outputs
|
||||
|
||||
|
||||
@torch.inference_mode()
def evaluate(
    model: nn.Module,
    data_loader,
    postprocessors: Dict[str, nn.Module],
    metrics: Dict[str, MetricCollection],
    device: torch.device,
    criterion: Optional[nn.Module] = None,
):
    """Run the model over a data loader and collect metric results.

    Args:
        model: Model to evaluate; switched to eval mode.
        data_loader: Iterable yielding batches whose first two items are the
            samples and the targets (any further items are ignored).
        postprocessors: Per-metric callables turning ``(outputs, targets)``
            into the keyword arguments of the matching metric's ``update``.
        metrics: Metrics keyed like ``postprocessors``; moved to ``device``.
        device: Device the samples, targets and metrics are moved to.
        criterion: Optional loss module; when given, its value is logged.

    Returns:
        ``(metric_logger_stats, stats)``: the global averages of the logger's
        meters, and the result of ``compute()`` for every metric.
    """
    model.eval()
    use_criterion = criterion is not None
    if use_criterion:
        criterion.eval()

    for collection in metrics.values():
        collection.to(device)

    metric_logger = MetricLogger(delimiter=" ")
    header = "Test:"

    for samples, targets, *_extra in metric_logger.log_every(data_loader, 10, header):
        outputs = model(samples.to(device))
        targets = targets.to(device)

        if use_criterion:
            loss = criterion(outputs, targets)
            metric_logger.update(loss=loss.item())

        for key, collection in metrics.items():
            update_kwargs = postprocessors[key](outputs, targets)
            collection.update(**update_kwargs)

    metric_logger.synchronize_between_processes()
    logger.info(f"Averaged stats: {metric_logger}")

    stats = {key: collection.compute() for key, collection in metrics.items()}
    metric_logger_stats = {name: meter.global_avg for name, meter in metric_logger.meters.items()}
    return metric_logger_stats, stats
|
||||
|
||||
|
||||
def all_gather_and_flatten(tensor_rank):
    """Gather ``tensor_rank`` from every process and merge the first two dims.

    A buffer with one leading slot per process is allocated, the per-slot
    views are handed to ``torch.distributed.all_gather`` as the output list,
    and the buffer is returned flattened over its first two dimensions.
    """
    world_size = distributed.get_global_size()
    gathered = torch.empty(
        world_size,
        *tensor_rank.shape,
        dtype=tensor_rank.dtype,
        device=tensor_rank.device,
    )
    # unbind yields views into ``gathered``, so filling them fills the buffer.
    per_rank_views = list(gathered.unbind(0))
    torch.distributed.all_gather(per_rank_views, tensor_rank.contiguous())
    return gathered.flatten(end_dim=1)
|
||||
|
||||
|
||||
def extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=False):
    """Extract features for every sample of ``dataset``.

    The dataset is wrapped in ``DatasetWithEnumeratedTargets``, served through
    a non-shuffled, non-dropping distributed data loader, and handed to
    ``extract_features_with_dataloader`` together with the sample count.
    """
    enumerated = DatasetWithEnumeratedTargets(dataset)
    loader = make_data_loader(
        dataset=enumerated,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
    )
    return extract_features_with_dataloader(model, loader, len(enumerated), gather_on_cpu)
|
||||
|
||||
|
||||
@torch.inference_mode()
def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False):
    """Compute features for all samples and gather them from every process.

    Args:
        model: Callable mapping a batch of samples to a feature tensor.
        data_loader: Yields ``(samples, (index, labels))`` batches, where
            ``index`` gives each sample's row in the output tensors.
        sample_count: Total number of rows to allocate in the outputs.
        gather_on_cpu: Keep the gathered tensors on CPU instead of CUDA.

    Returns:
        ``(features, all_labels)``, each with ``sample_count`` rows.

    Raises:
        AssertionError: If any label entry is still ``-1`` (or lower) after
            the loop, i.e. some row was never written.
    """
    gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda")
    metric_logger = MetricLogger(delimiter=" ")
    features, all_labels = None, None
    for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10):
        samples = samples.cuda(non_blocking=True)
        labels_rank = labels_rank.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)
        features_rank = model(samples).float()

        # init storage feature matrix (shapes are only known after the first batch)
        if features is None:
            features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device)
            labels_shape = list(labels_rank.shape)
            labels_shape[0] = sample_count
            # Allocate with the labels' own dtype: index_copy_ below requires
            # source and destination dtypes to match, so relying on the
            # default integer dtype breaks for non-int64 labels.
            all_labels = torch.full(labels_shape, fill_value=-1, dtype=labels_rank.dtype, device=gather_device)
            logger.info(f"Storing features into tensor of shape {features.shape}")

        # share indexes, features and labels between processes
        index_all = all_gather_and_flatten(index).to(gather_device)
        features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device)
        labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device)

        # update storage feature matrix
        if len(index_all) > 0:
            features.index_copy_(0, index_all, features_all_ranks)
            all_labels.index_copy_(0, index_all, labels_all_ranks)

    logger.info(f"Features shape: {tuple(features.shape)}")
    logger.info(f"Labels shape: {tuple(all_labels.shape)}")

    # -1 is the fill value, so any remaining -1 marks a row that was never written
    assert torch.all(all_labels > -1)

    return features, all_labels
|
||||
@@ -0,0 +1,158 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import dinov2.distributed as distributed
|
||||
from functools import partial
|
||||
from fvcore.common.checkpoint import Checkpointer
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp import ShardingStrategy
|
||||
from torch.distributed.fsdp import MixedPrecision
|
||||
from torch.distributed.fsdp import StateDictType
|
||||
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from torch.distributed.fsdp._runtime_utils import _reshard
|
||||
|
||||
|
||||
def get_fsdp_wrapper(model_cfg, modules_to_wrap=None):
    """Return a partially-applied ``FSDP`` constructor configured from ``model_cfg``.

    Args:
        model_cfg: Config exposing ``sharding_strategy`` (one of
            ``"NO_SHARD"``, ``"SHARD_GRAD_OP"``, ``"FULL_SHARD"``) and
            ``mixed_precision.{param_dtype,reduce_dtype,buffer_dtype}``
            (each one of ``"fp32"``, ``"fp16"``, ``"bf16"``).
        modules_to_wrap: Module classes passed to ``ModuleWrapPolicy``;
            ``None`` (the default) means an empty set.

    Returns:
        A ``functools.partial`` of ``FSDP``; call it with a module to wrap it.

    Raises:
        KeyError: If a strategy or dtype name is not one of the values above.
    """
    # A fresh set per call instead of a shared mutable default argument.
    if modules_to_wrap is None:
        modules_to_wrap = set()

    sharding_strategy_dict = {
        "NO_SHARD": ShardingStrategy.NO_SHARD,
        "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
        "FULL_SHARD": ShardingStrategy.FULL_SHARD,
    }

    dtype_dict = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    mixed_precision_config = MixedPrecision(
        param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype],
        reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype],
        buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype],
    )

    sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy]

    local_rank = distributed.get_local_rank()

    fsdp_wrapper = partial(
        FSDP,
        sharding_strategy=sharding_strategy_config,
        mixed_precision=mixed_precision_config,
        device_id=local_rank,
        sync_module_states=True,
        use_orig_params=True,
        auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap),
    )
    return fsdp_wrapper
|
||||
|
||||
|
||||
def is_fsdp(x):
    """Tell whether ``x`` is an ``FSDP``-wrapped module."""
    wrapped = isinstance(x, FSDP)
    return wrapped
|
||||
|
||||
|
||||
def is_sharded_fsdp(x):
    """Tell whether ``x`` is an FSDP module using any strategy other than ``NO_SHARD``."""
    if not is_fsdp(x):
        return False
    return x.sharding_strategy is not ShardingStrategy.NO_SHARD
|
||||
|
||||
|
||||
def free_if_fsdp(x):
    """Reshard ``x`` through FSDP's private ``_reshard`` when it is a sharded FSDP module.

    Does nothing for any other object.
    """
    if not is_sharded_fsdp(x):
        return
    handles = x._handles
    # One ``True`` per handle is passed as the third argument.
    # NOTE(review): presumably the per-handle "free" flags — confirm against
    # the private FSDP API of the installed torch version.
    free_flags = [True] * len(handles)
    _reshard(x, handles, free_flags)
|
||||
|
||||
|
||||
def get_fsdp_modules(x):
    """Return whatever ``FSDP.fsdp_modules`` reports for ``x``."""
    found = FSDP.fsdp_modules(x)
    return found
|
||||
|
||||
|
||||
def reshard_fsdp_model(x):
    """Apply ``free_if_fsdp`` to every FSDP module found in ``x``."""
    for fsdp_module in get_fsdp_modules(x):
        free_if_fsdp(fsdp_module)
|
||||
|
||||
|
||||
def rankstr():
    """Return ``"rank_<global rank>"`` for the current process."""
    rank = distributed.get_global_rank()
    return f"rank_{rank}"
|
||||
|
||||
|
||||
class FSDPCheckpointer(Checkpointer):
    """Checkpointer that writes and tracks a separate file for each rank.

    Checkpoint files are named ``<name>.<rankstr>.pth`` and the pointer to
    the most recent one lives in ``last_checkpoint.<rankstr>``. The model
    state is read and written under FSDP's ``LOCAL_STATE_DICT`` mode.
    """

    def save(self, name: str, **kwargs: Any) -> None:
        """
        Dump model and checkpointables to a file.

        Args:
            name (str): name of the file.
            kwargs (dict): extra arbitrary data to save.
        """
        if not (self.save_dir and self.save_to_disk):
            return

        with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
            data = {"model": self.model.state_dict()}

        # Later entries win on key clashes: checkpointables first, then kwargs.
        data.update((key, obj.state_dict()) for key, obj in self.checkpointables.items())
        data.update(kwargs)

        basename = f"{name}.{rankstr()}.pth"
        save_file = os.path.join(self.save_dir, basename)
        # Guards against a ``name`` that contains path separators.
        assert os.path.basename(save_file) == basename, basename
        self.logger.info("Saving checkpoint to {}".format(save_file))
        with self.path_manager.open(save_file, "wb") as f:
            torch.save(data, f)
        self.tag_last_checkpoint(basename)

    def load(self, *args, **kwargs):
        """Delegate to the parent ``load`` while in ``LOCAL_STATE_DICT`` mode."""
        with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
            return super().load(*args, **kwargs)

    def has_checkpoint(self) -> bool:
        """
        Returns:
            bool: whether a checkpoint exists in the target directory.
        """
        tag_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}")
        return self.path_manager.exists(tag_file)

    def get_checkpoint_file(self) -> str:
        """
        Returns:
            str: The latest checkpoint file in target directory.
        """
        tag_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}")
        try:
            with self.path_manager.open(tag_file, "r") as f:
                last_saved = f.read().strip()
        except IOError:
            # if file doesn't exist, maybe because it has just been
            # deleted by a separate process
            return ""
        # pyre-fixme[6]: For 2nd param expected `Union[PathLike[str], str]` but got
        #  `Union[bytes, str]`.
        return os.path.join(self.save_dir, last_saved)

    def tag_last_checkpoint(self, last_filename_basename: str) -> None:
        """
        Tag the last checkpoint.

        Args:
            last_filename_basename (str): the basename of the last filename.
        """
        # Wait for all processes before recording the tag.
        if distributed.is_enabled():
            torch.distributed.barrier()
        tag_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}")
        with self.path_manager.open(tag_file, "w") as f:
            f.write(last_filename_basename)  # pyre-ignore
|
||||
|
||||
|
||||
# Rebinds the imported name to itself; presumably kept so that
# ShardedGradScaler counts as an explicit attribute of this module for
# callers importing it from here — TODO confirm intent.
ShardedGradScaler = ShardedGradScaler
|
||||
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .dino_head import DINOHead
|
||||
from .mlp import Mlp
|
||||
from .patch_embed import PatchEmbed
|
||||
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
||||
from .block import NestedTensorBlock
|
||||
from .attention import MemEffAttention
|
||||
@@ -0,0 +1,81 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||
|
||||
import logging
|
||||
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")


# Probe for the optional xFormers dependency; the flag records the outcome so
# the attention classes below can fall back to the plain implementation.
try:
    from xformers.ops import memory_efficient_attention, unbind, fmha
except ImportError:
    logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False
else:
    XFORMERS_AVAILABLE = True
|
||||
|
||||
|
||||
class Attention(nn.Module):
    """Multi-head self-attention over an input of shape (B, N, C).

    Args:
        dim: Channel size C of the input and output.
        num_heads: Number of heads; each head works on ``C // num_heads`` channels.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        proj_bias: Whether the output projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        # Queries are multiplied by head_dim ** -0.5 before the dot product.
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_dim), then split into q, k, v.
        fused = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)
        q = q * self.scale

        weights = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        weights = self.attn_drop(weights)

        # (B, heads, N, head_dim) -> (B, N, C)
        mixed = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(mixed))
|
||||
|
||||
|
||||
class MemEffAttention(Attention):
    """Attention that uses xFormers' ``memory_efficient_attention`` when available.

    Without xFormers it falls back to the parent implementation, which cannot
    honour an attention bias.
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        """Apply self-attention to ``x`` of shape (B, N, C).

        Args:
            x: Input tokens.
            attn_bias: Optional bias forwarded to
                ``memory_efficient_attention``; requires xFormers.

        Raises:
            AssertionError: If ``attn_bias`` is given but xFormers is missing.
        """
        if not XFORMERS_AVAILABLE:
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; silently dropping the bias would change results.
            if attn_bias is not None:
                raise AssertionError("xFormers is required for nested tensors usage")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
||||
@@ -0,0 +1,252 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||
|
||||
import logging
|
||||
from typing import Callable, List, Any, Tuple, Dict
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
from .attention import Attention, MemEffAttention
|
||||
from .drop_path import DropPath
|
||||
from .layer_scale import LayerScale
|
||||
from .mlp import Mlp
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")


# Probe for the optional xFormers dependency; the nested-tensor code paths
# below need these names and are guarded by the flag.
try:
    from xformers.ops import fmha
    from xformers.ops import scaled_index_add, index_select_cat
except ImportError:
    logger.warning("xFormers not available")
    XFORMERS_AVAILABLE = False
else:
    XFORMERS_AVAILABLE = True
|
||||
|
||||
|
||||
class Block(nn.Module):
    """Transformer block with an attention branch and a feed-forward branch.

    Each branch is ``norm -> sublayer -> layer scale`` and is added back to
    its input as a residual.

    Args:
        dim: Channel size of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden size of the feed-forward layer as a multiple of ``dim``.
        qkv_bias: Bias flag for the attention's qkv projection.
        proj_bias: Bias flag for the attention's output projection.
        ffn_bias: Bias flag for the feed-forward layer.
        drop: Dropout probability for the attention projection and the FFN.
        attn_drop: Dropout probability for the attention weights.
        init_values: Initial layer-scale value; a falsy value disables layer scale.
        drop_path: Stochastic-depth rate; ``0.0`` disables it.
        act_layer: Activation factory passed to the feed-forward layer.
        norm_layer: Factory for both normalization layers.
        attn_class: Factory for the attention module.
        ffn_layer: Factory for the feed-forward module.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Apply both residual branches to ``x`` and return the result."""

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Each branch uses its own drop-path module; both are built from
            # the same rate in __init__.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
|
||||
|
||||
|
||||
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Add a residual computed on a random subset of the batch only.

    A random subset of ``max(int(b * (1 - sample_drop_ratio)), 1)`` samples
    is selected, ``residual_func`` is evaluated on that subset, and the result
    is added back at the selected rows scaled by ``b / subset_size``. Rows
    that were not selected are returned unchanged.

    Args:
        x: Input of shape (b, n, d).
        residual_func: Maps the selected samples to their residual.
        sample_drop_ratio: Fraction of the batch to leave out.

    Returns:
        A tensor shaped like ``x``.
    """
    # 1) pick the subset through a random permutation
    batch, _, _ = x.shape
    kept = max(int(batch * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch, device=x.device)[:kept]

    # 2) residual of the selected samples, one flat row per sample
    residual = residual_func(x[kept_rows]).flatten(1)

    # 3) scaled scatter-add into the flattened input
    scale = batch / kept
    summed = torch.index_add(x.flatten(1), 0, kept_rows, residual.to(dtype=x.dtype), alpha=scale)
    return summed.view_as(x)
|
||||
|
||||
|
||||
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Draw the random batch subset used for stochastic depth.

    Returns:
        ``(brange, residual_scale_factor)``: the selected row indices of the
        (b, n, d) input ``x`` and the factor ``b / subset_size``.
    """
    batch, _, _ = x.shape
    kept = max(int(batch * (1 - sample_drop_ratio)), 1)
    shuffled = torch.randperm(batch, device=x.device)
    return shuffled[:kept], batch / kept
|
||||
|
||||
|
||||
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` to the rows ``brange`` of ``x``, scaled by ``residual_scale_factor``.

    With a ``scaling_vector`` the work is delegated to xFormers'
    ``scaled_index_add``. Without one, ``x`` and ``residual`` are flattened
    to one row per sample first, so the result is returned in that flattened
    form.
    """
    residual = residual.to(dtype=x.dtype)
    if scaling_vector is not None:
        return scaled_index_add(x, brange, residual, scaling=scaling_vector, alpha=residual_scale_factor)
    return torch.index_add(x.flatten(1), 0, brange, residual.flatten(1), alpha=residual_scale_factor)
|
||||
|
||||
|
||||
# Attention biases already built, keyed by the (batch size, sequence length)
# pair of every tensor in the list.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    if branges is None:
        batch_sizes = [x.shape[0] for x in x_list]
    else:
        batch_sizes = [b.shape[0] for b in branges]

    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache:
        # One sequence-length entry per sample of every tensor.
        seqlens = [x.shape[1] for b, x in zip(batch_sizes, x_list) for _ in range(b)]
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is None:
        # Fold each batch into a single sequence, then join along the token axis.
        cat_tensors = torch.cat(tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list), dim=1)
    else:
        flat_rows = [x.flatten(1) for x in x_list]
        cat_tensors = index_select_cat(flat_rows, branges).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[all_shapes], cat_tensors
|
||||
|
||||
|
||||
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """Stochastic-depth residual for a list of tensors processed together.

    Note that the value returned is a list with one tensor per entry of
    ``x_list``, each shaped like its input.
    """
    # 1) generate random set of indices for dropping samples in the batch
    picks = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [brange for brange, _ in picks]
    residual_scale_factors = [scale for _, scale in picks]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    return [
        add_residual(x, brange, residual, scale, scaling_vector).view_as(x)
        for x, brange, residual, scale in zip(x_list, branges, residual_list, residual_scale_factors)
    ]
|
||||
|
||||
|
||||
class NestedTensorBlock(Block):
    """Block that also accepts a list of tensors, processed jointly via xFormers."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # Layer scale is left out of the residual functions here; it is
            # handed to the stochastic-depth helper as ``scaling_vector``.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Check the module whose gamma is actually read (ls2, not ls1).
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on the input type.

        A ``Tensor`` goes through the parent ``forward``; a ``list`` goes
        through ``forward_nested`` and requires xFormers.

        Raises:
            AssertionError: For a list input without xFormers, or for any
                other input type.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError(f"Unsupported input type: {type(x_or_x_list).__name__}")
|
||||
@@ -0,0 +1,59 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.init import trunc_normal_
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
class DINOHead(nn.Module):
    """Projection head: MLP, L2 normalization, then a weight-normalized linear layer.

    Args:
        in_dim: Input feature size.
        out_dim: Output size of the last layer.
        use_bn: Insert ``BatchNorm1d`` after each hidden linear layer of the MLP.
        nlayers: Number of linear layers in the MLP (values below 1 count as 1).
        hidden_dim: Width of the MLP's hidden layers.
        bottleneck_dim: Output size of the MLP and input size of the last layer.
        mlp_bias: Bias flag for the MLP's linear layers.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        # Custom init runs before the last layer exists, so only the MLP is touched.
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        """Truncated-normal weights and zero bias for linear layers; others untouched."""
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        hidden = self.mlp(x)
        # A larger epsilon is used for float16 inputs than for other dtypes.
        eps = 1e-6 if hidden.dtype == torch.float16 else 1e-12
        unit = nn.functional.normalize(hidden, dim=-1, p=2, eps=eps)
        return self.last_layer(unit)
|
||||
|
||||
|
||||
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    """Build the head's MLP.

    With ``nlayers == 1`` a single ``Linear(in_dim, bottleneck_dim)`` is
    returned. Otherwise the result is a ``Sequential`` of hidden stages
    (``Linear`` -> optional ``BatchNorm1d`` -> ``GELU``) followed by a final
    ``Linear(hidden_dim, bottleneck_dim)``.
    """
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)

    def hidden_stage(width_in):
        # Linear -> (BatchNorm1d) -> GELU, created in that order.
        stage = [nn.Linear(width_in, hidden_dim, bias=bias)]
        if use_bn:
            stage.append(nn.BatchNorm1d(hidden_dim))
        stage.append(nn.GELU())
        return stage

    layers = hidden_stage(in_dim)
    for _ in range(nlayers - 2):
        layers.extend(hidden_stage(hidden_dim))
    layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
    return nn.Sequential(*layers)
|
||||
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
||||
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Zero whole samples of ``x`` at random and rescale the surviving ones.

    ``x`` itself is returned when ``drop_prob`` is 0 or ``training`` is false.
    Otherwise one Bernoulli draw with probability ``1 - drop_prob`` is made
    per sample (first dimension); kept samples are divided by that
    probability when it is positive.
    """
    if drop_prob == 0.0 or not training:
        return x

    keep_prob = 1 - drop_prob
    # One value per sample, broadcast over the remaining dimensions
    # (work with diff dim tensors, not just 2D ConvNets).
    mask_shape = (x.shape[0], *([1] * (x.ndim - 1)))
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask
|
||||
|
||||
|
||||
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        """Store the drop probability; ``None`` is kept as given and means no dropping."""
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # With the default ``drop_prob=None`` the training path would compute
        # ``1 - None`` and raise TypeError; treat None as a probability of 0.
        drop_prob = 0.0 if self.drop_prob is None else self.drop_prob
        return drop_path(x, drop_prob, self.training)
|
||||
@@ -0,0 +1,28 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LayerScale(nn.Module):
    """Multiply the input by a learnable vector ``gamma`` of size ``dim``.

    Args:
        dim: Length of ``gamma``.
        init_values: Initial value(s) of ``gamma``.
        inplace: Multiply the input tensor in place instead of returning a new one.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
||||
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
||||
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Mlp(nn.Module):
    """Two-layer feed-forward network: linear, activation, dropout, linear, dropout.

    Args:
        in_features: Input size.
        hidden_features: Hidden size; falls back to ``in_features`` when falsy.
        out_features: Output size; falls back to ``in_features`` when falsy.
        act_layer: Factory for the activation module.
        drop: Dropout probability, shared by both dropout applications.
        bias: Bias flag for both linear layers.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, width_hidden, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
||||
@@ -0,0 +1,89 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def make_2tuple(x):
    """Return ``x`` as a pair.

    A tuple is returned unchanged after asserting it has exactly two items;
    an int ``v`` becomes ``(v, v)``. Any other input fails an assertion.
    """
    if not isinstance(x, tuple):
        assert isinstance(x, int)
        return (x, x)

    assert len(x) == 2
    return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size (int or (H, W) pair); only used to derive the
            nominal patch grid stored on the module.
        patch_size: Patch token size (int or (H, W) pair).
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer factory called with ``embed_dim``;
            ``nn.Identity`` is used when it is falsy.
        flatten_embedding: When False, ``forward`` returns a (B, H', W', D)
            grid instead of the flattened (B, N, D) sequence.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patches: kernel size and stride are both the patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Project an image batch to patch tokens.

        The input height and width must be multiples of the patch size; they
        do not have to match ``img_size``.
        """
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B D H' W' (patch grid)
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B H'W' D
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H' W' D
        return x

    def flops(self) -> float:
        """Estimate multiply-accumulate count for the nominal ``img_size`` grid."""
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # ``self.norm`` is nn.Identity when no norm layer was requested, so an
        # ``is not None`` test alone would always count normalization cost.
        if self.norm is not None and not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
|
||||
@@ -0,0 +1,63 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from torch import Tensor, nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SwiGLUFFN(nn.Module):
    """Gated feed-forward layer computing ``w3(silu(x1) * x2)``.

    ``x1`` and ``x2`` are the two halves (along the last dimension) of a
    single fused projection ``w12(x)``.

    Args:
        in_features: Size of the last input dimension.
        hidden_features: Width of each gated half; ``in_features`` when falsy.
        out_features: Size of the last output dimension; ``in_features`` when falsy.
        act_layer: Accepted but not used by this implementation.
        drop: Accepted but not used by this implementation.
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        # One projection yields both the gate and the value halves.
        self.w12 = nn.Linear(in_features, 2 * hidden_width, bias=bias)
        self.w3 = nn.Linear(hidden_width, output_width, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        """Return the gated projection of ``x``."""
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
||||
|
||||
|
||||
# Prefer the xformers SwiGLU implementation; fall back to the pure-PyTorch
# SwiGLUFFN when xformers cannot be imported.
try:
    from xformers.ops import SwiGLU
except ImportError:
    XFORMERS_AVAILABLE = False
    SwiGLU = SwiGLUFFN
else:
    XFORMERS_AVAILABLE = True
|
||||
|
||||
|
||||
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU layer whose hidden width is rescaled and aligned before construction.

    The requested hidden width (``in_features`` when falsy) is scaled by 2/3,
    truncated to an int, and rounded up to the next multiple of 8. The result
    is forwarded to the ``SwiGLU`` base class together with ``in_features``,
    ``out_features`` (``in_features`` when falsy) and ``bias``.
    ``act_layer`` and ``drop`` are accepted but not forwarded.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        resolved_out = out_features or in_features
        requested_hidden = hidden_features or in_features
        # Scale by 2/3, then round up to a multiple of 8.
        aligned_hidden = (int(requested_hidden * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=aligned_hidden,
            out_features=resolved_out,
            bias=bias,
        )
|
||||
@@ -0,0 +1,103 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
import dinov2.distributed as distributed
|
||||
from .helpers import MetricLogger, SmoothedValue
|
||||
|
||||
|
||||
# So that calling _configure_logger multiple times won't add many handlers
|
||||
@functools.lru_cache()
def _configure_logger(
    name: Optional[str] = None,
    *,
    level: int = logging.DEBUG,
    output: Optional[str] = None,
):
    """
    Configure a logger.

    Adapted from Detectron2.

    The function is memoized with ``functools.lru_cache``, so calling it again
    with the same arguments returns the cached logger instead of attaching
    the handlers a second time.

    Args:
        name: The name of the logger to configure.
        level: The logging level to use.
        output: A file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/logs/log.txt`.
            On non-main processes ".rank<global rank>" is appended to the file name.

    Returns:
        The configured logger.
    """

    logger = logging.getLogger(name)
    logger.setLevel(level)
    logger.propagate = False

    # Loosely match Google glog format:
    #   [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg
    # but use a shorter timestamp, the process id, and include the logger name:
    #   [IWEF]yyyymmdd hh:mm:ss pid logger file:line] msg
    fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] "
    fmt_message = "%(message)s"
    fmt = fmt_prefix + fmt_message
    datefmt = "%Y%m%d %H:%M:%S"
    formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)

    # stdout logging for main worker only
    if distributed.is_main_process():
        handler = logging.StreamHandler(stream=sys.stdout)
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    # file logging for all workers
    if output:
        if os.path.splitext(output)[-1] in (".txt", ".log"):
            filename = output
        else:
            filename = os.path.join(output, "logs", "log.txt")

        if not distributed.is_main_process():
            global_rank = distributed.get_global_rank()
            filename = filename + ".rank{}".format(global_rank)

        # A bare file name such as "run.log" has an empty dirname, and
        # os.makedirs("") raises; only create a directory when there is one.
        log_dir = os.path.dirname(filename)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        # The file stays open for the lifetime of the (cached) logger.
        handler = logging.StreamHandler(open(filename, "a"))
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
|
||||
|
||||
|
||||
def setup_logging(
    output: Optional[str] = None,
    *,
    name: Optional[str] = None,
    level: int = logging.DEBUG,
    capture_warnings: bool = True,
) -> None:
    """
    Set up logging for the process.

    Args:
        output: Where to write log files. None disables file logging. A value
            ending in ".txt" or ".log" is used as the file name; any other
            value is treated as a directory and the log is written to
            `output/logs/log.txt`.
        name: Name of the logger to configure; None selects the root logger.
        level: Logging level applied to the logger.
        capture_warnings: Passed to ``logging.captureWarnings`` to route
            warnings through the logging system.
    """
    logging.captureWarnings(capture_warnings)
    # Keep this exact call shape: _configure_logger is memoized with
    # functools.lru_cache, whose cache key depends on how arguments are passed.
    _configure_logger(name, level=level, output=output)
|
||||
@@ -0,0 +1,195 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import defaultdict, deque
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
import dinov2.distributed as distributed
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class MetricLogger(object):
    """Collect named ``SmoothedValue`` meters and log progress while iterating.

    Args:
        delimiter: String placed between the fields of a log line.
        output_file: Optional path; when set, the main process appends one
            JSON record per logged iteration to it.
    """

    def __init__(self, delimiter="\t", output_file=None):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter
        self.output_file = output_file

    def update(self, **kwargs):
        """Feed one value per keyword into the meter of the same name."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        """Expose meters as attributes (``logger.loss`` -> ``meters["loss"]``).

        Raises:
            AttributeError: if ``attr`` is neither a meter nor an instance attribute.
        """
        # __getattr__ only runs after normal lookup failed. Read ``meters``
        # through __dict__ so that an instance without it (created via
        # __new__, copy or unpickling) raises AttributeError instead of
        # recursing into __getattr__ forever.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Ask every meter to synchronize itself across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register (or replace) the meter stored under ``name``."""
        self.meters[name] = meter

    def dump_in_output_file(self, iteration, iter_time, data_time):
        """Append a JSON line with timings and each meter's median.

        Does nothing when no output file is configured or when this is not
        the main process.
        """
        if self.output_file is None or not distributed.is_main_process():
            return
        dict_to_dump = dict(
            iteration=iteration,
            iter_time=iter_time,
            data_time=data_time,
        )
        dict_to_dump.update({k: v.median for k, v in self.meters.items()})
        with open(self.output_file, "a") as f:
            f.write(json.dumps(dict_to_dump) + "\n")

    def log_every(self, iterable, print_freq, header=None, n_iterations=None, start_iteration=0):
        """Yield items from ``iterable`` while logging progress.

        Args:
            iterable: Source of items; must support ``len`` when
                ``n_iterations`` is None.
            print_freq: A line is logged when the iteration index is a
                multiple of this value, and on index ``n_iterations - 1``.
            header: Optional prefix for each log line.
            n_iterations: Iteration index at which to stop; defaults to
                ``len(iterable)``.
            start_iteration: Index assigned to the first yielded item.
        """
        i = start_iteration
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.6f}")
        data_time = SmoothedValue(fmt="{avg:.6f}")

        if n_iterations is None:
            n_iterations = len(iterable)

        space_fmt = ":" + str(len(str(n_iterations))) + "d"

        log_list = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            log_list += ["max mem: {memory:.0f}"]

        log_msg = self.delimiter.join(log_list)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == n_iterations - 1:
                self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg)
                eta_seconds = iter_time.global_avg * (n_iterations - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                fields = dict(
                    eta=eta_string,
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time),
                )
                # The memory field only exists in the template when CUDA is available.
                if cuda_available:
                    fields["memory"] = torch.cuda.max_memory_allocated() / MB
                logger.info(log_msg.format(i, n_iterations, **fields))
            i += 1
            end = time.time()
            if i >= n_iterations:
                break
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # Guard the per-iteration average against n_iterations == 0
        # (for example an empty iterable).
        per_iteration = total_time / max(n_iterations, 1)
        logger.info("{} Total time: {} ({:.6f} s / it)".format(header, total_time_str, per_iteration))
|
||||
|
||||
|
||||
class SmoothedValue:
    """Running statistics over a stream of values.

    Keeps the last ``window_size`` values for windowed statistics (median,
    avg, max, value) plus a running total and count for the global average.
    ``fmt`` is the ``str.format`` template used by ``__str__``; it may refer
    to ``median``, ``avg``, ``global_avg``, ``max`` and ``value``.
    """

    def __init__(self, window_size=20, fmt=None):
        template = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = template

    def update(self, value, num=1):
        """Record ``value`` in the window and add it ``num`` times to the totals."""
        self.deque.append(value)
        self.count += num
        self.total += value * num

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes.

        Warning: does not synchronize the deque!
        """
        if not distributed.is_enabled():
            return
        packed = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        torch.distributed.barrier()
        torch.distributed.all_reduce(packed)
        reduced_count, reduced_total = packed.tolist()
        self.count = int(reduced_count)
        self.total = reduced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window, computed in float32."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Running total divided by running count."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        # Every statistic is evaluated, whether or not the template uses it.
        stats = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**stats)
|
||||
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .dino_clstoken_loss import DINOLoss
|
||||
from .ibot_patch_loss import iBOTPatchLoss
|
||||
from .koleo_loss import KoLeoLoss
|
||||
@@ -0,0 +1,100 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class DINOLoss(nn.Module):
    """Cross-entropy between teacher targets and student predictions.

    Holds a running ``center`` buffer of shape (1, out_dim) that is subtracted
    from teacher outputs before the softmax. The center is updated in two
    steps: ``update_center`` starts a (possibly asynchronous) reduction and
    ``apply_center_update`` folds the result into the buffer.

    Args:
        out_dim: Size of the last dimension of the outputs being compared.
        student_temp: Temperature dividing the student outputs.
        center_momentum: Weight of the previous center in the moving average.
    """

    def __init__(
        self,
        out_dim,
        student_temp=0.1,
        center_momentum=0.9,
    ):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, out_dim))
        # State of the pending center update.
        self.updated = True
        self.reduce_handle = None
        self.len_teacher_output = None
        self.async_batch_center = None

    @torch.no_grad()
    def softmax_center_teacher(self, teacher_output, teacher_temp):
        """Apply any pending center update, then center and sharpen the teacher output."""
        self.apply_center_update()
        centered = teacher_output - self.center
        return F.softmax(centered / teacher_temp, dim=-1)

    @torch.no_grad()
    def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3):
        """Turn teacher outputs into assignments by alternating row/column normalization.

        Row sums are all-reduced across processes when torch.distributed is
        initialized. Returns a tensor with the same layout as
        ``teacher_output`` whose rows sum to 1.
        """
        teacher_output = teacher_output.float()
        multi_process = dist.is_initialized()
        world_size = dist.get_world_size() if multi_process else 1
        Q = torch.exp(teacher_output / teacher_temp).t()  # K-by-B: prototypes x local samples
        num_prototypes, local_samples = Q.shape
        total_samples = local_samples * world_size  # samples to assign over all processes

        # Make the whole matrix sum to 1.
        total_mass = torch.sum(Q)
        if multi_process:
            dist.all_reduce(total_mass)
        Q /= total_mass

        for _ in range(n_iterations):
            # Rows: total weight per prototype must be 1/K.
            row_mass = torch.sum(Q, dim=1, keepdim=True)
            if multi_process:
                dist.all_reduce(row_mass)
            Q /= row_mass
            Q /= num_prototypes

            # Columns: total weight per sample must be 1/B.
            Q /= torch.sum(Q, dim=0, keepdim=True)
            Q /= total_samples

        Q *= total_samples  # columns now sum to 1, so Q is an assignment
        return Q.t()

    def forward(self, student_output_list, teacher_out_softmaxed_centered_list):
        """
        Sum, over every (student, teacher) pair, the batch-mean cross-entropy
        between the teacher distribution and the student log-softmax.
        """
        # TODO: Use cross_entropy_distribution here
        total_loss = 0
        for student_logits in student_output_list:
            log_probs = F.log_softmax(student_logits / self.student_temp, dim=-1)
            for teacher_probs in teacher_out_softmaxed_centered_list:
                pair_term = torch.sum(teacher_probs * log_probs, dim=-1)
                total_loss -= pair_term.mean()
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """Start a center update from ``teacher_output``."""
        self.reduce_center_update(teacher_output)

    @torch.no_grad()
    def reduce_center_update(self, teacher_output):
        """Sum the batch over dim 0 and launch an async all-reduce when distributed."""
        self.updated = False
        self.len_teacher_output = len(teacher_output)
        self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        if dist.is_initialized():
            self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)

    @torch.no_grad()
    def apply_center_update(self):
        """Fold the pending batch center into ``center``; no-op when nothing is pending."""
        if self.updated is not False:
            return

        world_size = dist.get_world_size() if dist.is_initialized() else 1

        if self.reduce_handle is not None:
            self.reduce_handle.wait()
        batch_center = self.async_batch_center / (self.len_teacher_output * world_size)

        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)

        self.updated = True
|
||||
@@ -0,0 +1,152 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
# Pick the implementation of ``lossfunc`` once, at import time: the xformers
# cross-entropy op when it can be imported, a plain PyTorch expression otherwise.
try:
    from xformers.ops import cross_entropy
except ImportError:

    def lossfunc(t, s, temp):
        """Return ``sum(t * log_softmax(s / temp))`` over the last dimension."""
        log_probs = F.log_softmax(s / temp, dim=-1)
        return torch.sum(t * log_probs, dim=-1)

else:

    def lossfunc(t, s, temp):
        """Negated xformers cross-entropy of ``s`` against ``t``, computed in float32.

        Handles 2-D inputs (by adding and removing a leading dimension) and
        3-D inputs; any other rank falls through and returns None.
        """
        s = s.float()
        t = t.float()
        if s.ndim == 2:
            batched = cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True)
            return -batched.squeeze(0)
        elif s.ndim == 3:
            return -cross_entropy(s, t, temp, bw_inplace=True)
|
||||
|
||||
|
||||
class iBOTPatchLoss(nn.Module):
    """Patch-level cross-entropy between teacher targets and student predictions.

    Holds a running ``center`` buffer of shape (1, 1, patch_out_dim) that is
    subtracted from teacher patch tokens before the softmax. The center is
    updated in two steps: ``update_center`` starts a (possibly asynchronous)
    reduction and ``apply_center_update`` folds the result into the buffer.

    Args:
        patch_out_dim: Size of the last dimension of the patch tokens.
        student_temp: Temperature dividing the student outputs.
        center_momentum: Weight of the previous center in the moving average.
    """

    def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.register_buffer("center", torch.zeros(1, 1, patch_out_dim))
        # State of the pending center update.
        self.updated = True
        self.reduce_handle = None
        self.len_teacher_patch_tokens = None
        self.async_batch_center = None

    @torch.no_grad()
    def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp):
        """Apply any pending center update, then center and sharpen the teacher tokens."""
        self.apply_center_update()
        # NOTE: subtracting ``self.center`` applies the usual type promotion,
        # so lower-precision tokens take the center's dtype (float32 as
        # created in __init__) from here on.
        return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1)

    @torch.no_grad()
    def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3):
        """Turn teacher outputs into assignments by alternating row/column normalization.

        Args:
            teacher_output: 2-D tensor, one row per sample.
            teacher_temp: Temperature dividing the teacher outputs.
            n_masked_patches_tensor: Tensor used as the sample count ``B``
                (presumably the local number of masked patches — confirm
                against the caller). When torch.distributed is initialized
                it is all-reduced IN PLACE, so the caller's tensor changes.
            n_iterations: Number of row/column normalization rounds.

        Returns:
            Tensor with the same layout as ``teacher_output`` whose rows sum to 1.
        """
        teacher_output = teacher_output.float()
        Q = torch.exp(teacher_output / teacher_temp).t()  # Q is K-by-B for consistency with notations from our paper
        B = n_masked_patches_tensor
        # Guarded like every other collective below, so the method also works
        # when no process group has been initialized.
        if dist.is_initialized():
            dist.all_reduce(B)
        K = Q.shape[0]  # how many prototypes

        # make the matrix sums to 1
        sum_Q = torch.sum(Q)
        if dist.is_initialized():
            dist.all_reduce(sum_Q)
        Q /= sum_Q

        for _ in range(n_iterations):
            # normalize each row: total weight per prototype must be 1/K
            sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
            if dist.is_initialized():
                dist.all_reduce(sum_of_rows)
            Q /= sum_of_rows
            Q /= K

            # normalize each column: total weight per sample must be 1/B
            Q /= torch.sum(Q, dim=0, keepdim=True)
            Q /= B

        Q *= B  # the columns must sum to 1 so that Q is an assignment
        return Q.t()

    def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        student_patch_tokens: (B, N, D) tensor
        teacher_patch_tokens: (B, N, D) tensor
        student_masks_flat: (B, N) tensor

        The per-token loss is averaged over the masked tokens of each sample
        (the divisor is clamped to at least 1), then averaged over the batch.
        """
        t = teacher_patch_tokens
        s = student_patch_tokens
        loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1)
        loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0)
        return -loss.mean()

    def forward_masked(
        self,
        student_patch_tokens_masked,
        teacher_patch_tokens_masked,
        student_masks_flat,
        n_masked_patches=None,
        masks_weight=None,
    ):
        """Loss over tokens that were already gathered at the masked positions.

        Args:
            student_patch_tokens_masked: Student tokens at masked positions.
            teacher_patch_tokens_masked: Matching teacher targets.
            student_masks_flat: Boolean mask; its first dimension is the
                divisor of the final sum, and it is used to derive
                ``masks_weight`` when that is None.
            n_masked_patches: When given, only the first ``n_masked_patches``
                loss entries are kept.
            masks_weight: Per-token weights. Defaults to ``1 / count`` for each
                masked token, where ``count`` is the number of masked tokens in
                its row (clamped to at least 1).
        """
        t = teacher_patch_tokens_masked
        s = student_patch_tokens_masked
        loss = lossfunc(t, s, self.student_temp)
        if masks_weight is None:
            masks_weight = (
                (1 / student_masks_flat.sum(-1).clamp(min=1.0))
                .unsqueeze(-1)
                .expand_as(student_masks_flat)[student_masks_flat]
            )
        if n_masked_patches is not None:
            loss = loss[:n_masked_patches]
        loss = loss * masks_weight
        return -loss.sum() / student_masks_flat.shape[0]

    @torch.no_grad()
    def update_center(self, teacher_patch_tokens):
        """Start a center update from ``teacher_patch_tokens``."""
        self.reduce_center_update(teacher_patch_tokens)

    @torch.no_grad()
    def reduce_center_update(self, teacher_patch_tokens):
        """Average over dim 1, sum over dim 0, and launch an async all-reduce when distributed."""
        self.updated = False
        self.len_teacher_patch_tokens = len(teacher_patch_tokens)
        self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True)
        if dist.is_initialized():
            self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True)

    @torch.no_grad()
    def apply_center_update(self):
        """Fold the pending batch center into ``center``; no-op when nothing is pending."""
        if self.updated is False:
            world_size = dist.get_world_size() if dist.is_initialized() else 1

            if self.reduce_handle is not None:
                self.reduce_handle.wait()
            _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size)

            self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum)

            self.updated = True
|
||||
@@ -0,0 +1,49 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# import torch.distributed as dist
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class KoLeoLoss(nn.Module):
    """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search"""

    def __init__(self):
        super().__init__()
        self.pdist = nn.PairwiseDistance(2, eps=1e-8)

    def pairwise_NNs_inner(self, x):
        """
        Index of each row's nearest neighbour, for L2-normalized vectors.

        Uses Torch rather than Faiss to remain on GPU.
        """
        # For unit vectors the largest dot product is the smallest distance.
        similarities = torch.mm(x, x.t())
        # Put -1 on the diagonal so that a row never selects itself.
        similarities.fill_diagonal_(-1)
        return torch.max(similarities, dim=1).indices

    def forward(self, student_output, eps=1e-8):
        """
        Args:
            student_output (BxD): backbone output of student
            eps: Stabilizer used in the normalization and inside the log.

        Returns:
            Negative mean log-distance between each normalized row and its
            nearest neighbour.
        """
        # CUDA autocast is disabled for the whole computation.
        with torch.cuda.amp.autocast(enabled=False):
            normalized = F.normalize(student_output, eps=eps, p=2, dim=-1)
            neighbours = self.pairwise_NNs_inner(normalized)
            nn_distances = self.pdist(normalized, normalized[neighbours])  # BxD, BxD -> B
            return -torch.log(nn_distances + eps).mean()
|
||||
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
from . import vision_transformer as vits
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def build_model(args, only_teacher=False, img_size=224):
    """Build the teacher (and optionally student) backbone described by ``args``.

    Args:
        args: Config object. ``args.arch`` is rewritten in place with a
            trailing "_memeff" removed. The other attributes read are
            ``patch_size``, ``layerscale``, ``ffn_layer``, ``block_chunks``,
            ``qkv_bias``, ``proj_bias``, ``ffn_bias`` and, for the student,
            ``drop_path_rate`` and ``drop_path_uniform``.
        only_teacher: When True, only the teacher is built.
        img_size: Image size forwarded to the backbone constructor.

    Returns:
        ``(teacher, embed_dim)`` when ``only_teacher`` is True, otherwise
        ``(student, teacher, embed_dim)``.

    Raises:
        NotImplementedError: if ``args.arch`` does not contain "vit".
    """
    args.arch = args.arch.removesuffix("_memeff")
    # Only ViT factories are supported. Fail with a clear message instead of
    # falling through to unbound local variables.
    if "vit" not in args.arch:
        raise NotImplementedError(f"Unsupported architecture: {args.arch!r}")

    vit_kwargs = dict(
        img_size=img_size,
        patch_size=args.patch_size,
        init_values=args.layerscale,
        ffn_layer=args.ffn_layer,
        block_chunks=args.block_chunks,
        qkv_bias=args.qkv_bias,
        proj_bias=args.proj_bias,
        ffn_bias=args.ffn_bias,
    )
    # The factory is looked up by name in the vision_transformer module.
    teacher = vits.__dict__[args.arch](**vit_kwargs)
    if only_teacher:
        return teacher, teacher.embed_dim
    # Only the student receives the drop-path settings.
    student = vits.__dict__[args.arch](
        **vit_kwargs,
        drop_path_rate=args.drop_path_rate,
        drop_path_uniform=args.drop_path_uniform,
    )
    embed_dim = student.embed_dim
    return student, teacher, embed_dim
|
||||
|
||||
|
||||
def build_model_from_cfg(cfg, only_teacher=False):
    """Call ``build_model`` with ``cfg.student`` and ``cfg.crops.global_crops_size`` as the image size."""
    student_cfg = cfg.student
    global_crop_size = cfg.crops.global_crops_size
    return build_model(student_cfg, only_teacher=only_teacher, img_size=global_crop_size)
|
||||
@@ -0,0 +1,358 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||
|
||||
from functools import partial
|
||||
import math
|
||||
import logging
|
||||
from typing import Sequence, Tuple, Union, Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn.init import trunc_normal_
|
||||
|
||||
from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Call ``fn(module=..., name=...)`` on the descendants of ``module``.

    Args:
        fn: Callback taking keyword arguments ``module`` and ``name``.
        module: Root of the traversal.
        name: Dotted prefix used to build the names of the children.
        depth_first: When True a module is visited after its children,
            otherwise before them.
        include_root: Whether ``fn`` is also called on ``module`` itself;
            children are always visited.

    Returns:
        ``module``, unchanged.
    """
    visit_before_children = include_root and not depth_first
    visit_after_children = include_root and depth_first

    if visit_before_children:
        fn(module=module, name=name)
    for child_key, child in module.named_children():
        qualified_name = ".".join((name, child_key)) if name else child_key
        named_apply(fn=fn, module=child, name=qualified_name, depth_first=depth_first, include_root=True)
    if visit_after_children:
        fn(module=module, name=name)
    return module
|
||||
|
||||
|
||||
class BlockChunk(nn.ModuleList):
    """A ``ModuleList`` that can be called: it applies its modules one after another."""

    def forward(self, x):
        out = x
        for block in self:
            out = block(out)
        return out
|
||||
|
||||
|
||||
class DinoVisionTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=True,
|
||||
ffn_bias=True,
|
||||
proj_bias=True,
|
||||
drop_path_rate=0.0,
|
||||
drop_path_uniform=False,
|
||||
init_values=None, # for layerscale: None or 0 => no layerscale
|
||||
embed_layer=PatchEmbed,
|
||||
act_layer=nn.GELU,
|
||||
block_fn=Block,
|
||||
ffn_layer="mlp",
|
||||
block_chunks=1,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
img_size (int, tuple): input image size
|
||||
patch_size (int, tuple): patch size
|
||||
in_chans (int): number of input channels
|
||||
embed_dim (int): embedding dimension
|
||||
depth (int): depth of transformer
|
||||
num_heads (int): number of attention heads
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
||||
qkv_bias (bool): enable bias for qkv if True
|
||||
proj_bias (bool): enable bias for proj in attn if True
|
||||
ffn_bias (bool): enable bias for ffn if True
|
||||
drop_path_rate (float): stochastic depth rate
|
||||
drop_path_uniform (bool): apply uniform drop rate across blocks
|
||||
weight_init (str): weight init scheme
|
||||
init_values (float): layer-scale init values
|
||||
embed_layer (nn.Module): patch embedding layer
|
||||
act_layer (nn.Module): MLP activation layer
|
||||
block_fn (nn.Module): transformer block class
|
||||
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
||||
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
||||
"""
|
||||
super().__init__()
|
||||
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
||||
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
self.num_tokens = 1
|
||||
self.n_blocks = depth
|
||||
self.num_heads = num_heads
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
||||
|
||||
if drop_path_uniform is True:
|
||||
dpr = [drop_path_rate] * depth
|
||||
else:
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
|
||||
if ffn_layer == "mlp":
|
||||
logger.info("using MLP layer as FFN")
|
||||
ffn_layer = Mlp
|
||||
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
||||
logger.info("using SwiGLU layer as FFN")
|
||||
ffn_layer = SwiGLUFFNFused
|
||||
elif ffn_layer == "identity":
|
||||
logger.info("using Identity layer as FFN")
|
||||
|
||||
def f(*args, **kwargs):
|
||||
return nn.Identity()
|
||||
|
||||
ffn_layer = f
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
blocks_list = [
|
||||
block_fn(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
proj_bias=proj_bias,
|
||||
ffn_bias=ffn_bias,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
ffn_layer=ffn_layer,
|
||||
init_values=init_values,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
if block_chunks > 0:
|
||||
self.chunked_blocks = True
|
||||
chunked_blocks = []
|
||||
chunksize = depth // block_chunks
|
||||
for i in range(0, depth, chunksize):
|
||||
# this is to keep the block index consistent if we chunk the block list
|
||||
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
||||
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
||||
else:
|
||||
self.chunked_blocks = False
|
||||
self.blocks = nn.ModuleList(blocks_list)
|
||||
|
||||
self.norm = norm_layer(embed_dim)
|
||||
self.head = nn.Identity()
|
||||
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
    """Initialise the positional embedding, the class token and the sub-modules.

    The statement order is left untouched: the random initialisers run in
    sequence, so reordering them would change the values produced for a
    given seed.
    """
    # Truncated-normal init (std 0.02) for the positional embedding.
    trunc_normal_(self.pos_embed, std=0.02)
    # Normal init with a very small std (1e-6) for the class token.
    nn.init.normal_(self.cls_token, std=1e-6)
    # Hand `init_weights_vit_timm` to `named_apply` for this model
    # (presumably applied to every sub-module -- confirm in `named_apply`).
    named_apply(init_weights_vit_timm, self)
|
||||
|
||||
def interpolate_pos_encoding(self, x, w, h):
    """Return a positional embedding matching the token count of ``x``.

    Args:
        x: token tensor; ``x.shape[1] - 1`` is the number of patch tokens and
            ``x.shape[-1]`` the embedding dimension.
        w, h: the two spatial sizes of the input image, in pixels, in the
            order the caller unpacked them from the image shape.

    Returns:
        ``self.pos_embed`` itself when the patch count matches and ``w == h``;
        otherwise the class-token embedding concatenated with the patch
        embeddings bicubically resampled to ``(w // patch_size, h // patch_size)``,
        cast to the dtype of ``x``.
    """
    input_dtype = x.dtype
    n_input_patches = x.shape[1] - 1
    N = self.pos_embed.shape[1] - 1
    if n_input_patches == N and w == h:
        return self.pos_embed

    embed = self.pos_embed.float()
    cls_embed, grid_embed = embed[:, 0], embed[:, 1:]
    dim = x.shape[-1]

    # The stored patch embeddings are laid out as a square grid.
    side = math.sqrt(N)
    grid_side = int(side)

    # we add a small number to avoid floating point error in the interpolation
    # see discussion at https://github.com/facebookresearch/dino/issues/8
    w0 = w // self.patch_size + 0.1
    h0 = h // self.patch_size + 0.1

    grid_embed = grid_embed.reshape(1, grid_side, grid_side, dim).permute(0, 3, 1, 2)
    grid_embed = nn.functional.interpolate(
        grid_embed,
        scale_factor=(w0 / side, h0 / side),
        mode="bicubic",
    )

    out_rows, out_cols = grid_embed.shape[-2], grid_embed.shape[-1]
    assert int(w0) == out_rows and int(h0) == out_cols
    grid_embed = grid_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((cls_embed.unsqueeze(0), grid_embed), dim=1).to(input_dtype)
|
||||
|
||||
def prepare_tokens_with_masks(self, x, masks=None):
    """Turn an image batch into the token sequence fed to the blocks.

    Args:
        x: 4-D image tensor; its last two sizes are forwarded to
            ``interpolate_pos_encoding``.
        masks: optional boolean tensor selecting patch tokens that are
            replaced by ``self.mask_token``.

    Returns:
        Patch tokens with the class token prepended and the positional
        embedding added.
    """
    _, _, w, h = x.shape
    tokens = self.patch_embed(x)

    if masks is not None:
        # Swap every masked patch token for the mask token.
        mask_token = self.mask_token.to(tokens.dtype).unsqueeze(0)
        tokens = torch.where(masks.unsqueeze(-1), mask_token, tokens)

    cls_tokens = self.cls_token.expand(tokens.shape[0], -1, -1)
    tokens = torch.cat((cls_tokens, tokens), dim=1)
    return tokens + self.interpolate_pos_encoding(tokens, w, h)
|
||||
|
||||
def forward_features_list(self, x_list, masks_list):
    """Run the backbone on several inputs at once.

    Args:
        x_list: list of image tensors.
        masks_list: list of masks (or ``None`` entries), paired with ``x_list``.

    Returns:
        One dict per input with the keys ``x_norm_clstoken``,
        ``x_norm_patchtokens``, ``x_prenorm`` and ``masks``.
    """
    token_lists = [
        self.prepare_tokens_with_masks(image, mask) for image, mask in zip(x_list, masks_list)
    ]
    # Each block receives the whole list and returns a list.
    for block in self.blocks:
        token_lists = block(token_lists)

    def _pack(tokens, mask):
        normed = self.norm(tokens)
        return {
            "x_norm_clstoken": normed[:, 0],
            "x_norm_patchtokens": normed[:, 1:],
            "x_prenorm": tokens,
            "masks": mask,
        }

    return [_pack(tokens, mask) for tokens, mask in zip(token_lists, masks_list)]
|
||||
|
||||
def forward_features(self, x, masks=None):
    """Compute normalised and pre-norm features for ``x``.

    A list input is delegated to ``forward_features_list``. Otherwise the
    tokens are prepared, passed through every block and normalised.

    Returns:
        A dict with the keys ``x_norm_clstoken``, ``x_norm_patchtokens``,
        ``x_prenorm`` and ``masks`` (or a list of such dicts for list input).
    """
    if isinstance(x, list):
        return self.forward_features_list(x, masks)

    tokens = self.prepare_tokens_with_masks(x, masks)
    for block in self.blocks:
        tokens = block(tokens)

    normed = self.norm(tokens)
    features = {
        "x_norm_clstoken": normed[:, 0],
        "x_norm_patchtokens": normed[:, 1:],
        "x_prenorm": tokens,
        "masks": masks,
    }
    return features
|
||||
|
||||
def _get_intermediate_layers_not_chunked(self, x, n=1):
    """Collect block outputs when ``self.blocks`` is a flat sequence.

    Args:
        x: input passed to ``prepare_tokens_with_masks``.
        n: an int (take the ``n`` last blocks) or a collection of block
            indices to take.

    Returns:
        List of the selected block outputs, in block order.
    """
    tokens = self.prepare_tokens_with_masks(x)
    n_blocks = len(self.blocks)
    # If n is an int, take the n last blocks. If it's a list, take them
    if isinstance(n, int):
        blocks_to_take = range(n_blocks - n, n_blocks)
    else:
        blocks_to_take = n

    collected = []
    for index, block in enumerate(self.blocks):
        tokens = block(tokens)
        if index in blocks_to_take:
            collected.append(tokens)

    assert len(collected) == len(blocks_to_take), f"only {len(collected)} / {len(blocks_to_take)} blocks found"
    return collected
|
||||
|
||||
def _get_intermediate_layers_chunked(self, x, n=1):
    """Collect block outputs when ``self.blocks`` is a sequence of chunks.

    The total block count is read from the length of the last chunk, and in
    each chunk only the entries from the running block index onwards are
    executed (the earlier entries are the ``nn.Identity()`` placeholders).

    Args:
        x: input passed to ``prepare_tokens_with_masks``.
        n: an int (take the ``n`` last blocks) or a collection of block
            indices to take.

    Returns:
        List of the selected block outputs, in block order.
    """
    tokens = self.prepare_tokens_with_masks(x)
    n_blocks = len(self.blocks[-1])
    # If n is an int, take the n last blocks. If it's a list, take them
    if isinstance(n, int):
        blocks_to_take = range(n_blocks - n, n_blocks)
    else:
        blocks_to_take = n

    collected = []
    block_index = 0
    for chunk in self.blocks:
        # Slice past the nn.Identity() placeholders at the head of the chunk.
        for block in chunk[block_index:]:
            tokens = block(tokens)
            if block_index in blocks_to_take:
                collected.append(tokens)
            block_index += 1

    assert len(collected) == len(blocks_to_take), f"only {len(collected)} / {len(blocks_to_take)} blocks found"
    return collected
|
||||
|
||||
def get_intermediate_layers(
    self,
    x: torch.Tensor,
    n: Union[int, Sequence] = 1,  # Layers or n last layers to take
    reshape: bool = False,
    return_class_token: bool = False,
    norm=True,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
    """Return the patch tokens (and optionally class tokens) of selected blocks.

    Args:
        x: 4-D input image tensor.
        n: int (take the ``n`` last blocks) or a sequence of block indices.
        reshape: when True, reshape each patch-token tensor to
            ``(B, dim, x.shape[2] // patch_size, x.shape[3] // patch_size)``.
        return_class_token: when True, return ``(patch_tokens, class_token)``
            pairs instead of bare patch tokens.
        norm: when True, apply ``self.norm`` to every selected output first.

    Returns:
        A tuple with one entry per selected block.
    """
    collect = (
        self._get_intermediate_layers_chunked
        if self.chunked_blocks
        else self._get_intermediate_layers_not_chunked
    )
    layer_outputs = collect(x, n)
    if norm:
        layer_outputs = [self.norm(tokens) for tokens in layer_outputs]

    # The first token of every sequence is the class token.
    class_tokens = [tokens[:, 0] for tokens in layer_outputs]
    patch_tokens = [tokens[:, 1:] for tokens in layer_outputs]

    if reshape:
        B, _, w, h = x.shape
        grid_w, grid_h = w // self.patch_size, h // self.patch_size
        patch_tokens = [
            tokens.reshape(B, grid_w, grid_h, -1).permute(0, 3, 1, 2).contiguous()
            for tokens in patch_tokens
        ]

    if not return_class_token:
        return tuple(patch_tokens)
    return tuple(zip(patch_tokens, class_tokens))
|
||||
|
||||
def forward(self, *args, is_training=False, **kwargs):
    """Run ``forward_features`` and post-process its output.

    Args:
        *args, **kwargs: forwarded unchanged to ``forward_features``.
        is_training: when True, return the full feature dict; otherwise
            return ``self.head`` applied to the normalised class token.
    """
    features = self.forward_features(*args, **kwargs)
    if not is_training:
        return self.head(features["x_norm_clstoken"])
    return features
|
||||
|
||||
|
||||
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility).

    Only ``nn.Linear`` modules are touched: the weight gets a truncated-normal
    init (std 0.02) and the bias, when present, is zeroed. ``name`` is unused.
    """
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    bias = module.bias
    if bias is not None:
        nn.init.zeros_(bias)
|
||||
|
||||
|
||||
def vit_small(patch_size=16, **kwargs):
    """Build a ``DinoVisionTransformer`` with embed-dim 384, depth 12 and 6 heads.

    Extra keyword arguments are forwarded to the model constructor.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=attention_block,
        **kwargs,
    )
|
||||
|
||||
|
||||
def vit_base(patch_size=16, **kwargs):
    """Build a ``DinoVisionTransformer`` with embed-dim 768, depth 12 and 12 heads.

    Extra keyword arguments are forwarded to the model constructor.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=attention_block,
        **kwargs,
    )
|
||||
|
||||
|
||||
def vit_large(patch_size=16, **kwargs):
    """Build a ``DinoVisionTransformer`` with embed-dim 1024, depth 24 and 16 heads.

    Extra keyword arguments are forwarded to the model constructor.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=attention_block,
        **kwargs,
    )
|
||||
|
||||
|
||||
def vit_giant2(patch_size=16, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64

    Depth is 40; extra keyword arguments are forwarded to the model constructor.
    """
    attention_block = partial(Block, attn_class=MemEffAttention)
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=attention_block,
        **kwargs,
    )
|
||||
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from dinov2.eval.knn import get_args_parser as get_knn_args_parser
|
||||
from dinov2.logging import setup_logging
|
||||
from dinov2.run.submit import get_args_parser, submit_jobs
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class Evaluator:
    """Callable job object that runs the DINOv2 k-NN evaluation with stored args."""

    def __init__(self, args):
        self.args = args

    def __call__(self):
        """Finalise the arguments for this job, then run the k-NN evaluation."""
        from dinov2.eval.knn import main as knn_main

        self._setup_args()
        knn_main(self.args)

    def checkpoint(self):
        """Return a delayed submission of a fresh task built from the same args."""
        import submitit

        logger.info(f"Requeuing {self.args}")
        return submitit.helpers.DelayedSubmission(type(self)(self.args))

    def _setup_args(self):
        """Substitute the job id for ``%j`` in the output directory and log the setup."""
        import submitit

        env = submitit.JobEnvironment()
        output_dir_template = self.args.output_dir
        self.args.output_dir = output_dir_template.replace("%j", str(env.job_id))
        logger.info(f"Process group: {env.num_tasks} tasks, rank: {env.global_rank}")
        logger.info(f"Args: {self.args}")
|
||||
|
||||
|
||||
def main():
    """Parse the command line and submit the k-NN evaluation job.

    Returns:
        0 on successful submission.

    Raises:
        FileNotFoundError: if ``args.config_file`` does not exist.
    """
    description = "Submitit launcher for DINOv2 k-NN evaluation"
    knn_args_parser = get_knn_args_parser(add_help=False)
    parents = [knn_args_parser]
    args_parser = get_args_parser(description=description, parents=parents)
    args = args_parser.parse_args()

    setup_logging()

    # Explicit check rather than `assert`: assertions are removed under
    # `python -O`, which would silently skip this validation.
    if not os.path.exists(args.config_file):
        raise FileNotFoundError("Configuration file does not exist!")
    submit_jobs(Evaluator, args, name="dinov2:knn")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from dinov2.eval.linear import get_args_parser as get_linear_args_parser
|
||||
from dinov2.logging import setup_logging
|
||||
from dinov2.run.submit import get_args_parser, submit_jobs
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class Evaluator:
    """Callable job object that runs the DINOv2 linear evaluation with stored args."""

    def __init__(self, args):
        self.args = args

    def __call__(self):
        """Finalise the arguments for this job, then run the linear evaluation."""
        from dinov2.eval.linear import main as linear_main

        self._setup_args()
        linear_main(self.args)

    def checkpoint(self):
        """Return a delayed submission of a fresh task built from the same args."""
        import submitit

        logger.info(f"Requeuing {self.args}")
        return submitit.helpers.DelayedSubmission(type(self)(self.args))

    def _setup_args(self):
        """Substitute the job id for ``%j`` in the output directory and log the setup."""
        import submitit

        env = submitit.JobEnvironment()
        output_dir_template = self.args.output_dir
        self.args.output_dir = output_dir_template.replace("%j", str(env.job_id))
        logger.info(f"Process group: {env.num_tasks} tasks, rank: {env.global_rank}")
        logger.info(f"Args: {self.args}")
|
||||
|
||||
|
||||
def main():
    """Parse the command line and submit the linear evaluation job.

    Returns:
        0 on successful submission.

    Raises:
        FileNotFoundError: if ``args.config_file`` does not exist.
    """
    description = "Submitit launcher for DINOv2 linear evaluation"
    linear_args_parser = get_linear_args_parser(add_help=False)
    parents = [linear_args_parser]
    args_parser = get_args_parser(description=description, parents=parents)
    args = args_parser.parse_args()

    setup_logging()

    # Explicit check rather than `assert`: assertions are removed under
    # `python -O`, which would silently skip this validation.
    if not os.path.exists(args.config_file):
        raise FileNotFoundError("Configuration file does not exist!")
    submit_jobs(Evaluator, args, name="dinov2:linear")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from dinov2.eval.log_regression import get_args_parser as get_log_regression_args_parser
|
||||
from dinov2.logging import setup_logging
|
||||
from dinov2.run.submit import get_args_parser, submit_jobs
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class Evaluator:
    """Callable job object that runs the DINOv2 logistic-regression evaluation."""

    def __init__(self, args):
        self.args = args

    def __call__(self):
        """Finalise the arguments for this job, then run the evaluation."""
        from dinov2.eval.log_regression import main as log_regression_main

        self._setup_args()
        log_regression_main(self.args)

    def checkpoint(self):
        """Return a delayed submission of a fresh task built from the same args."""
        import submitit

        logger.info(f"Requeuing {self.args}")
        return submitit.helpers.DelayedSubmission(type(self)(self.args))

    def _setup_args(self):
        """Substitute the job id for ``%j`` in the output directory and log the setup."""
        import submitit

        env = submitit.JobEnvironment()
        output_dir_template = self.args.output_dir
        self.args.output_dir = output_dir_template.replace("%j", str(env.job_id))
        logger.info(f"Process group: {env.num_tasks} tasks, rank: {env.global_rank}")
        logger.info(f"Args: {self.args}")
|
||||
|
||||
|
||||
def main():
    """Parse the command line and submit the logistic-regression evaluation job.

    Returns:
        0 on successful submission.

    Raises:
        FileNotFoundError: if ``args.config_file`` does not exist.
    """
    description = "Submitit launcher for DINOv2 logistic evaluation"
    log_regression_args_parser = get_log_regression_args_parser(add_help=False)
    parents = [log_regression_args_parser]
    args_parser = get_args_parser(description=description, parents=parents)
    args = args_parser.parse_args()

    setup_logging()

    # Explicit check rather than `assert`: assertions are removed under
    # `python -O`, which would silently skip this validation.
    if not os.path.exists(args.config_file):
        raise FileNotFoundError("Configuration file does not exist!")
    submit_jobs(Evaluator, args, name="dinov2:logreg")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
@@ -0,0 +1,123 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import submitit
|
||||
|
||||
from dinov2.utils.cluster import (
|
||||
get_slurm_executor_parameters,
|
||||
get_slurm_partition,
|
||||
get_user_checkpoint_path,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
) -> argparse.ArgumentParser:
    """Build the argument parser holding the job-submission options.

    Args:
        description: text passed to ``ArgumentParser`` as its description.
        parents: parsers whose arguments are inherited; an empty list is used
            when this is ``None`` or empty.
        add_help: whether the parser adds its own ``-h/--help`` option.

    Returns:
        A parser with ``--ngpus``, ``--nodes``, ``--timeout``, ``--partition``,
        ``--use-volta32``, ``--comment`` and ``--exclude`` added on top of the
        parent arguments.
    """
    parent_parsers = parents if parents else []
    # The default partition comes from the cluster helper.
    default_partition = get_slurm_partition()
    parser = argparse.ArgumentParser(description=description, parents=parent_parsers, add_help=add_help)

    add = parser.add_argument
    add("--ngpus", "--gpus", "--gpus-per-node", default=8, type=int, help="Number of GPUs to request on each node")
    add("--nodes", "--nnodes", default=2, type=int, help="Number of nodes to request")
    add("--timeout", default=2800, type=int, help="Duration of the job")
    add("--partition", default=default_partition, type=str, help="Partition where to submit")
    add("--use-volta32", action="store_true", help="Request V100-32GB GPUs")
    add("--comment", default="", type=str, help="Comment to pass to scheduler, e.g. priority message")
    add("--exclude", default="", type=str, help="Nodes to exclude")
    return parser
|
||||
|
||||
|
||||
def get_shared_folder() -> Path:
    """Return the ``experiments`` directory under the user checkpoint path.

    The directory is created if needed, including any missing parent
    directories, so a first run on a fresh checkpoint location does not fail.

    Raises:
        RuntimeError: if the user checkpoint path cannot be determined.
    """
    user_checkpoint_path = get_user_checkpoint_path()
    if user_checkpoint_path is None:
        raise RuntimeError("Path to user checkpoint cannot be determined")
    path = user_checkpoint_path / "experiments"
    # parents=True: the checkpoint directory itself may not exist yet.
    path.mkdir(parents=True, exist_ok=True)
    return path
|
||||
|
||||
|
||||
def submit_jobs(task_class, args, name: str):
    """Submit ``task_class(args)`` through a submitit executor.

    Args:
        task_class: callable class instantiated with ``args`` to form the job.
        args: parsed arguments; ``output_dir`` is filled in with
            ``<shared folder>/%j`` when empty.
        name: job name passed to the executor.
    """
    if not args.output_dir:
        args.output_dir = str(get_shared_folder() / "%j")

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30)

    # Optional scheduler settings are only forwarded when they are set.
    optional_settings = {
        "slurm_constraint": "volta32gb" if args.use_volta32 else None,
        "slurm_comment": args.comment,
        "slurm_exclude": args.exclude,
    }
    extra_settings = {key: value for key, value in optional_settings.items() if value}

    executor.update_parameters(
        name=name,
        **get_slurm_executor_parameters(
            nodes=args.nodes,
            num_gpus_per_node=args.ngpus,
            timeout_min=args.timeout,  # max is 60 * 72
            slurm_signal_delay_s=120,
            slurm_partition=args.partition,
            **extra_settings,
        ),
    )

    job = executor.submit(task_class(args))

    logger.info(f"Submitted job_id: {job.job_id}")
    resolved_output_dir = os.path.abspath(args.output_dir).replace("%j", str(job.job_id))
    logger.info(f"Logs and checkpoints will be saved at: {resolved_output_dir}")
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from dinov2.logging import setup_logging
|
||||
from dinov2.train import get_args_parser as get_train_args_parser
|
||||
from dinov2.run.submit import get_args_parser, submit_jobs
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class Trainer:
    """Callable job object that runs DINOv2 training with stored args."""

    def __init__(self, args):
        self.args = args

    def __call__(self):
        """Finalise the arguments for this job, then run training."""
        from dinov2.train import main as train_main

        self._setup_args()
        train_main(self.args)

    def checkpoint(self):
        """Return a delayed submission of a fresh task built from the same args."""
        import submitit

        logger.info(f"Requeuing {self.args}")
        return submitit.helpers.DelayedSubmission(type(self)(self.args))

    def _setup_args(self):
        """Substitute the job id for ``%j`` in the output directory and log the setup."""
        import submitit

        env = submitit.JobEnvironment()
        output_dir_template = self.args.output_dir
        self.args.output_dir = output_dir_template.replace("%j", str(env.job_id))
        logger.info(f"Process group: {env.num_tasks} tasks, rank: {env.global_rank}")
        logger.info(f"Args: {self.args}")
|
||||
|
||||
|
||||
def main():
    """Parse the command line and submit the training job.

    Returns:
        0 on successful submission.

    Raises:
        FileNotFoundError: if ``args.config_file`` does not exist.
    """
    description = "Submitit launcher for DINOv2 training"
    train_args_parser = get_train_args_parser(add_help=False)
    parents = [train_args_parser]
    args_parser = get_args_parser(description=description, parents=parents)
    args = args_parser.parse_args()

    setup_logging()

    # Explicit check rather than `assert`: assertions are removed under
    # `python -O`, which would silently skip this validation.
    if not os.path.exists(args.config_file):
        raise FileNotFoundError("Configuration file does not exist!")
    submit_jobs(Trainer, args, name="dinov2:train")
    return 0


if __name__ == "__main__":
    sys.exit(main())
|
||||
@@ -0,0 +1,8 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .train import get_args_parser, main
|
||||
from .ssl_meta_arch import SSLMetaArch
|
||||
@@ -0,0 +1,403 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from functools import partial
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from dinov2.loss import DINOLoss, iBOTPatchLoss, KoLeoLoss
|
||||
from dinov2.models import build_model_from_cfg
|
||||
from dinov2.layers import DINOHead
|
||||
from dinov2.utils.utils import has_batchnorms
|
||||
from dinov2.utils.param_groups import get_params_groups_with_decay, fuse_params_groups
|
||||
from dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, reshard_fsdp_model
|
||||
|
||||
from dinov2.models.vision_transformer import BlockChunk
|
||||
|
||||
try:
|
||||
from xformers.ops import fmha
|
||||
|
||||
XFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
XFORMERS_AVAILABLE = False
|
||||
assert XFORMERS_AVAILABLE, "xFormers is required for DINOv2 training"
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class SSLMetaArch(nn.Module):
|
||||
def __init__(self, cfg):
    """Build the student and teacher networks and the losses selected by ``cfg``.

    Reads ``cfg.compute_precision``, ``cfg.student``, ``cfg.dino`` and
    ``cfg.ibot``. Both networks get a backbone from ``build_model_from_cfg``,
    a ``dino_head`` when DINO or iBOT is enabled, and an ``ibot_head`` when
    iBOT uses a separate head. All teacher parameters are frozen.
    """
    super().__init__()
    self.cfg = cfg
    # Gradient scaler is only created when requested by the precision config.
    self.fp16_scaler = ShardedGradScaler() if cfg.compute_precision.grad_scaler else None

    student_model_dict = dict()
    teacher_model_dict = dict()

    student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg)
    student_model_dict["backbone"] = student_backbone
    teacher_model_dict["backbone"] = teacher_backbone
    logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}")

    if cfg.student.pretrained_weights:
        # NOTE(review): no `map_location` is given, so the checkpoint is
        # presumably restored onto the device it was saved from -- confirm.
        chkpt = torch.load(cfg.student.pretrained_weights)
        logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}")
        # strict=False: missing / unexpected keys are tolerated.
        student_backbone.load_state_dict(chkpt["model"], strict=False)

    self.embed_dim = embed_dim
    self.dino_out_dim = cfg.dino.head_n_prototypes

    # A loss is enabled when its weight is strictly positive.
    self.do_dino = cfg.dino.loss_weight > 0
    self.do_koleo = cfg.dino.koleo_loss_weight > 0
    self.do_ibot = cfg.ibot.loss_weight > 0
    self.ibot_separate_head = cfg.ibot.separate_head

    logger.info("OPTIONS -- DINO")
    if self.do_dino:
        logger.info(f"OPTIONS -- DINO -- loss_weight: {cfg.dino.loss_weight}")
        logger.info(f"OPTIONS -- DINO -- head_n_prototypes: {cfg.dino.head_n_prototypes}")
        logger.info(f"OPTIONS -- DINO -- head_bottleneck_dim: {cfg.dino.head_bottleneck_dim}")
        logger.info(f"OPTIONS -- DINO -- head_hidden_dim: {cfg.dino.head_hidden_dim}")
        self.dino_loss_weight = cfg.dino.loss_weight
        # Factory so that student and teacher each get their own head instance.
        dino_head = partial(
            DINOHead,
            in_dim=embed_dim,
            out_dim=cfg.dino.head_n_prototypes,
            hidden_dim=cfg.dino.head_hidden_dim,
            bottleneck_dim=cfg.dino.head_bottleneck_dim,
            nlayers=cfg.dino.head_nlayers,
        )
        self.dino_loss = DINOLoss(self.dino_out_dim)
        if self.do_koleo:
            logger.info("OPTIONS -- DINO -- applying KOLEO regularization")
            self.koleo_loss = KoLeoLoss()

    else:
        logger.info("OPTIONS -- DINO -- not using DINO")

    # NOTE(review): `dino_head` is only bound inside the `if self.do_dino`
    # branch above; with DINO disabled and iBOT enabled the calls below
    # would fail on an unbound name.
    if self.do_dino or self.do_ibot:
        student_model_dict["dino_head"] = dino_head()
        teacher_model_dict["dino_head"] = dino_head()

    logger.info("OPTIONS -- IBOT")
    logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}")
    logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}")
    logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_sample_probability: {cfg.ibot.mask_sample_probability}")
    if self.do_ibot:
        self.ibot_loss_weight = cfg.ibot.loss_weight
        assert max(cfg.ibot.mask_ratio_min_max) > 0, "please provide a positive mask ratio tuple for ibot"
        assert cfg.ibot.mask_sample_probability > 0, "please provide a positive mask probability for ibot"
        # With a shared head the iBOT output size is the DINO prototype count.
        self.ibot_out_dim = cfg.ibot.head_n_prototypes if self.ibot_separate_head else cfg.dino.head_n_prototypes
        self.ibot_patch_loss = iBOTPatchLoss(self.ibot_out_dim)
        if self.ibot_separate_head:
            logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}")
            logger.info(f"OPTIONS -- IBOT -- head_n_prototypes: {cfg.ibot.head_n_prototypes}")
            logger.info(f"OPTIONS -- IBOT -- head_bottleneck_dim: {cfg.ibot.head_bottleneck_dim}")
            logger.info(f"OPTIONS -- IBOT -- head_hidden_dim: {cfg.ibot.head_hidden_dim}")
            ibot_head = partial(
                DINOHead,
                in_dim=embed_dim,
                out_dim=cfg.ibot.head_n_prototypes,
                hidden_dim=cfg.ibot.head_hidden_dim,
                bottleneck_dim=cfg.ibot.head_bottleneck_dim,
                nlayers=cfg.ibot.head_nlayers,
            )
            student_model_dict["ibot_head"] = ibot_head()
            teacher_model_dict["ibot_head"] = ibot_head()
        else:
            logger.info("OPTIONS -- IBOT -- head shared with DINO")

    # Cleared by `fsdp_synchronize_streams` after its first run.
    self.need_to_synchronize_fsdp_streams = True

    self.student = nn.ModuleDict(student_model_dict)
    self.teacher = nn.ModuleDict(teacher_model_dict)

    # there is no backpropagation through the teacher, so no need for gradients
    for p in self.teacher.parameters():
        p.requires_grad = False
    logger.info(f"Student and Teacher are built: they are both {cfg.student.arch} network.")
|
||||
|
||||
def forward(self, inputs):
    """Not supported on this module: always raises ``NotImplementedError``."""
    raise NotImplementedError
|
||||
|
||||
def backprop_loss(self, loss):
    """Back-propagate ``loss``, scaling it first when a grad scaler is set."""
    scaler = self.fp16_scaler
    loss_to_backprop = loss if scaler is None else scaler.scale(loss)
    loss_to_backprop.backward()
|
||||
|
||||
def forward_backward(self, images, teacher_temp):
    """Run one teacher/student forward pass and back-propagate the total loss.

    Args:
        images: dict read with the keys ``collated_global_crops``,
            ``collated_local_crops``, ``collated_masks``, ``mask_indices_list``,
            ``n_masked_patches``, ``upperbound`` and ``masks_weight``. All
            entries except ``upperbound`` are moved to the GPU.
        teacher_temp: temperature forwarded to the teacher softmax/centering.

    Returns:
        Dict of the individual loss terms that were computed, for display.
    """
    # The number of global crops is fixed to 2 throughout this method.
    n_global_crops = 2
    assert n_global_crops == 2
    n_local_crops = self.cfg.crops.local_crops_number

    global_crops = images["collated_global_crops"].cuda(non_blocking=True)
    local_crops = images["collated_local_crops"].cuda(non_blocking=True)

    masks = images["collated_masks"].cuda(non_blocking=True)
    mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True)
    n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True)
    n_masked_patches = mask_indices_list.shape[0]
    # Size of the fixed-length token buffers allocated below
    # (presumably an upper bound on the number of masked patches -- confirm in the collate code).
    upperbound = images["upperbound"]
    masks_weight = images["masks_weight"].cuda(non_blocking=True)

    # Normalisers shared by the local and global DINO loss terms.
    n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1)
    n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops

    do_dino = self.do_dino
    do_ibot = self.do_ibot

    # loss scales
    ibot_loss_scale = 1.0 / n_global_crops

    # teacher output
    @torch.no_grad()
    def get_teacher_output():
        x, n_global_crops_teacher = global_crops, n_global_crops
        teacher_backbone_output_dict = self.teacher.backbone(x, is_training=True)
        teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"]
        teacher_cls_tokens = teacher_cls_tokens.chunk(n_global_crops_teacher)
        # watch out: these are chunked and cat'd in reverse so A is matched to B in the global crops dino loss
        teacher_cls_tokens = torch.cat((teacher_cls_tokens[1], teacher_cls_tokens[0]))
        ibot_teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"]
        _dim = ibot_teacher_patch_tokens.shape[-1]
        n_cls_tokens = teacher_cls_tokens.shape[0]

        if do_ibot and not self.ibot_separate_head:
            # Shared head: class tokens and masked patch tokens are packed into
            # one buffer so the DINO head runs once over both.
            buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound + n_cls_tokens, _dim)
            buffer_tensor_teacher[:n_cls_tokens].copy_(teacher_cls_tokens)
            torch.index_select(
                ibot_teacher_patch_tokens.flatten(0, 1),
                dim=0,
                index=mask_indices_list,
                out=buffer_tensor_teacher[n_cls_tokens : n_cls_tokens + n_masked_patches],
            )
            tokens_after_head = self.teacher.dino_head(buffer_tensor_teacher)
            teacher_cls_tokens_after_head = tokens_after_head[:n_cls_tokens]
            masked_teacher_patch_tokens_after_head = tokens_after_head[
                n_cls_tokens : n_cls_tokens + n_masked_patches
            ]
        elif do_ibot and self.ibot_separate_head:
            # Separate heads: masked patch tokens go through the iBOT head,
            # class tokens through the DINO head.
            buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound, _dim)
            torch.index_select(
                ibot_teacher_patch_tokens.flatten(0, 1),
                dim=0,
                index=mask_indices_list,
                out=buffer_tensor_teacher[:n_masked_patches],
            )
            teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens)
            masked_teacher_patch_tokens_after_head = self.teacher.ibot_head(buffer_tensor_teacher)[
                :n_masked_patches
            ]
        else:
            # iBOT disabled: only the class tokens are processed.
            teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens)
            masked_teacher_ibot_softmaxed_centered = None

        if self.cfg.train.centering == "centering":
            teacher_dino_softmaxed_centered_list = self.dino_loss.softmax_center_teacher(
                teacher_cls_tokens_after_head, teacher_temp=teacher_temp
            ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:])
            self.dino_loss.update_center(teacher_cls_tokens_after_head)
            if do_ibot:
                masked_teacher_patch_tokens_after_head = masked_teacher_patch_tokens_after_head.unsqueeze(0)
                masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.softmax_center_teacher(
                    masked_teacher_patch_tokens_after_head[:, :n_masked_patches], teacher_temp=teacher_temp
                )
                masked_teacher_ibot_softmaxed_centered = masked_teacher_ibot_softmaxed_centered.squeeze(0)
                # NOTE(review): after the unsqueeze(0) above this slice acts on the
                # new leading axis rather than on the patch axis -- confirm intent.
                self.ibot_patch_loss.update_center(masked_teacher_patch_tokens_after_head[:n_masked_patches])

        elif self.cfg.train.centering == "sinkhorn_knopp":
            teacher_dino_softmaxed_centered_list = self.dino_loss.sinkhorn_knopp_teacher(
                teacher_cls_tokens_after_head, teacher_temp=teacher_temp
            ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:])

            if do_ibot:
                masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.sinkhorn_knopp_teacher(
                    masked_teacher_patch_tokens_after_head,
                    teacher_temp=teacher_temp,
                    n_masked_patches_tensor=n_masked_patches_tensor,
                )

        else:
            raise NotImplementedError

        return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered

    teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output()
    reshard_fsdp_model(self.teacher)

    loss_dict = {}

    loss_accumulator = 0  # for backprop
    # Global and local crops go through the student backbone in one call;
    # only the global crops are masked.
    student_global_backbone_output_dict, student_local_backbone_output_dict = self.student.backbone(
        [global_crops, local_crops], masks=[masks, None], is_training=True
    )

    inputs_for_student_head_list = []

    # 1a: local crops cls tokens
    student_local_cls_tokens = student_local_backbone_output_dict["x_norm_clstoken"]
    inputs_for_student_head_list.append(student_local_cls_tokens.unsqueeze(0))

    # 1b: global crops cls tokens
    student_global_cls_tokens = student_global_backbone_output_dict["x_norm_clstoken"]
    inputs_for_student_head_list.append(student_global_cls_tokens.unsqueeze(0))

    # 1c: global crops patch tokens
    if do_ibot:
        _dim = student_global_backbone_output_dict["x_norm_clstoken"].shape[-1]
        ibot_student_patch_tokens = student_global_backbone_output_dict["x_norm_patchtokens"]
        buffer_tensor_patch_tokens = ibot_student_patch_tokens.new_zeros(upperbound, _dim)
        buffer_tensor_patch_tokens[:n_masked_patches].copy_(
            torch.index_select(ibot_student_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list)
        )
        if not self.ibot_separate_head:
            inputs_for_student_head_list.append(buffer_tensor_patch_tokens.unsqueeze(0))
        else:
            student_global_masked_patch_tokens_after_head = self.student.ibot_head(buffer_tensor_patch_tokens)[
                :n_masked_patches
            ]

    # 2: run
    # All head inputs are concatenated, run through the DINO head once, and
    # split back in the order they were appended above.
    _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list)
    outputs_list = _attn_bias.split(self.student.dino_head(cat_inputs))

    # 3a: local crops cls tokens
    student_local_cls_tokens_after_head = outputs_list.pop(0).squeeze(0)

    # 3b: global crops cls tokens
    student_global_cls_tokens_after_head = outputs_list.pop(0).squeeze(0)

    # 3c: global crops patch tokens
    if do_ibot and not self.ibot_separate_head:
        student_global_masked_patch_tokens_after_head = outputs_list.pop(0).squeeze(0)[:n_masked_patches]

    # NOTE(review): this term is gated on the number of local crops only, not
    # on `do_dino`, yet it uses `self.dino_loss` / `self.dino_loss_weight`.
    if n_local_crops > 0:
        dino_local_crops_loss = self.dino_loss(
            student_output_list=student_local_cls_tokens_after_head.chunk(n_local_crops),
            teacher_out_softmaxed_centered_list=teacher_dino_softmaxed_centered_list,
        ) / (n_global_crops_loss_terms + n_local_crops_loss_terms)

        # store for display
        loss_dict["dino_local_crops_loss"] = dino_local_crops_loss

        # accumulate loss
        loss_accumulator += self.dino_loss_weight * dino_local_crops_loss

    # process global crops
    loss_scales = 2  # this is here since we process global crops together

    if do_dino:
        # compute loss
        dino_global_crops_loss = (
            self.dino_loss(
                student_output_list=[student_global_cls_tokens_after_head],
                teacher_out_softmaxed_centered_list=[
                    teacher_dino_softmaxed_centered_list.flatten(0, 1)
                ],  # these were chunked and stacked in reverse so A is matched to B
            )
            * loss_scales
            / (n_global_crops_loss_terms + n_local_crops_loss_terms)
        )

        loss_dict["dino_global_crops_loss"] = dino_global_crops_loss

        # accumulate loss
        loss_accumulator += self.dino_loss_weight * dino_global_crops_loss

        student_cls_tokens = student_global_cls_tokens

        if self.do_koleo:
            koleo_loss = self.cfg.dino.koleo_loss_weight * sum(
                self.koleo_loss(p) for p in student_cls_tokens.chunk(2)
            )  # we don't apply koleo loss between cls tokens of a same image
            loss_accumulator += koleo_loss
            loss_dict["koleo_loss"] = (
                koleo_loss / loss_scales
            )  # this is to display the same losses as before but we can remove eventually

    if do_ibot:
        # compute loss
        ibot_patch_loss = (
            self.ibot_patch_loss.forward_masked(
                student_global_masked_patch_tokens_after_head,
                masked_teacher_ibot_softmaxed_centered,
                student_masks_flat=masks,
                n_masked_patches=n_masked_patches,
                masks_weight=masks_weight,
            )
            * loss_scales
            * ibot_loss_scale
        )

        # store for display
        loss_dict["ibot_loss"] = ibot_patch_loss / 2

        # accumulate loss
        loss_accumulator += self.ibot_loss_weight * ibot_patch_loss

    self.backprop_loss(loss_accumulator)

    self.fsdp_synchronize_streams()

    return loss_dict
|
||||
|
||||
def fsdp_synchronize_streams(self):
    """Make student and teacher sub-models share one ``_streams`` object.

    Does nothing unless ``self.need_to_synchronize_fsdp_streams`` is truthy.
    Otherwise it calls ``torch.cuda.synchronize()``, assigns the teacher
    backbone's ``_streams`` to the student DINO head, the teacher DINO head
    and the student backbone, and finally clears the flag so that the work
    is done only once.
    """
    if not self.need_to_synchronize_fsdp_streams:
        return

    torch.cuda.synchronize()
    shared_streams = self.teacher.backbone._streams
    self.student.dino_head._streams = shared_streams
    self.teacher.dino_head._streams = shared_streams
    self.student.backbone._streams = shared_streams
    self.need_to_synchronize_fsdp_streams = False
|
||||
|
||||
def update_teacher(self, m):
    """Update the teacher parameters in place from the student parameters.

    Parameters are collected pairwise from the modules returned by
    ``get_fsdp_modules`` for every key of ``self.student``; each teacher
    parameter is then multiplied by ``m`` and the matching student
    parameter is added with weight ``1 - m``.

    Args:
        m: weight kept from the current teacher value.
    """
    student_params = []
    teacher_params = []
    with torch.no_grad():
        for key in self.student.keys():
            module_pairs = zip(get_fsdp_modules(self.student[key]), get_fsdp_modules(self.teacher[key]))
            for student_module, teacher_module in module_pairs:
                student_params.extend(student_module.params)
                teacher_params.extend(teacher_module.params)
        # teacher <- m * teacher + (1 - m) * student, applied list-wise
        torch._foreach_mul_(teacher_params, m)
        torch._foreach_add_(teacher_params, student_params, alpha=1 - m)
|
||||
|
||||
def train(self, mode=True):
    """Set the training mode of the module while keeping the teacher in eval mode.

    The previous override accepted no ``mode`` argument and returned
    ``None``. NOTE(review): the parent class is not visible here; it is
    assumed to be ``torch.nn.Module``, whose ``eval()`` calls
    ``self.train(False)`` and whose ``train`` returns ``self``.

    Args:
        mode: forwarded to the parent ``train``; defaults to ``True`` so
            existing ``model.train()`` calls behave as before.

    Returns:
        ``self``.
    """
    super().train(mode)
    # The teacher is never put in training mode, whatever ``mode`` is.
    self.teacher.eval()
    return self
|
||||
|
||||
def get_maybe_fused_params_for_submodel(self, m):
    """Build the fused optimizer parameter groups of one sub-model.

    Args:
        m: sub-model passed to ``get_params_groups_with_decay``.

    Returns:
        The value returned by ``fuse_params_groups``, with ``"foreach"``
        set to ``True`` on every group.
    """
    optim_cfg = self.cfg.optim
    per_param_groups = get_params_groups_with_decay(
        model=m,
        lr_decay_rate=optim_cfg.layerwise_decay,
        patch_embed_lr_mult=optim_cfg.patch_embed_lr_mult,
    )
    fused_groups = fuse_params_groups(per_param_groups)
    logger.info("fusing param groups")

    for group in fused_groups:
        group["foreach"] = True
    return fused_groups
|
||||
|
||||
def get_params_groups(self):
    """Return the parameter groups of all student sub-models as one flat list.

    The groups of each value of ``self.student`` are obtained from
    ``get_maybe_fused_params_for_submodel`` and concatenated in iteration
    order.
    """
    collected = []
    for submodel in self.student.values():
        collected.extend(self.get_maybe_fused_params_for_submodel(submodel))
    return collected
|
||||
|
||||
def prepare_for_distributed_training(self):
    """Copy student weights into the teacher and wrap both with FSDP.

    For every key of ``self.student`` the teacher sub-model first loads the
    student's state dict; student and teacher are then replaced by their
    wrapped versions, built with ``get_fsdp_wrapper`` from the matching
    ``cfg.compute_precision`` entry and ``BlockChunk`` as the module type
    to wrap.

    Raises:
        NotImplementedError: if ``has_batchnorms(self.student)`` is true.
    """
    logger.info("DISTRIBUTED FSDP -- preparing model for distributed training")
    if has_batchnorms(self.student):
        raise NotImplementedError

    # below will synchronize all student subnetworks across gpus:
    for name, _ in self.student.items():
        self.teacher[name].load_state_dict(self.student[name].state_dict())

        wrap_student = get_fsdp_wrapper(self.cfg.compute_precision.student[name], modules_to_wrap={BlockChunk})
        self.student[name] = wrap_student(self.student[name])

        wrap_teacher = get_fsdp_wrapper(self.cfg.compute_precision.teacher[name], modules_to_wrap={BlockChunk})
        self.teacher[name] = wrap_teacher(self.teacher[name])
|
||||
@@ -0,0 +1,319 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from functools import partial
|
||||
|
||||
from fvcore.common.checkpoint import PeriodicCheckpointer
|
||||
import torch
|
||||
|
||||
from dinov2.data import SamplerType, make_data_loader, make_dataset
|
||||
from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator
|
||||
import dinov2.distributed as distributed
|
||||
from dinov2.fsdp import FSDPCheckpointer
|
||||
from dinov2.logging import MetricLogger
|
||||
from dinov2.utils.config import setup
|
||||
from dinov2.utils.utils import CosineScheduler
|
||||
|
||||
from dinov2.train.ssl_meta_arch import SSLMetaArch
|
||||
|
||||
|
||||
# Opt in to TF32 for CUDA matrix multiplications at import time.
torch.backends.cuda.matmul.allow_tf32 = True  # PyTorch 1.12 sets this to False by default
# Logger used by the functions of this module.
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def get_args_parser(add_help: bool = True):
    """Create the argument parser of the training script.

    Args:
        add_help: forwarded to ``argparse.ArgumentParser``.

    Returns:
        A parser with ``--config-file``, ``--no-resume``, ``--eval-only``,
        ``--eval``, the trailing ``opts`` overrides and ``--output-dir``
        (alias ``--output_dir``).
    """
    opts_help = """
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
    """.strip()

    cli = argparse.ArgumentParser("DINOv2 training", add_help=add_help)
    cli.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    cli.add_argument(
        "--no-resume",
        action="store_true",
        help="Whether to not attempt to resume from the checkpoint directory. ",
    )
    cli.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    cli.add_argument("--eval", type=str, default="", help="Eval type to perform")
    # Everything left on the command line is collected as config overrides.
    cli.add_argument("opts", help=opts_help, default=None, nargs=argparse.REMAINDER)
    cli.add_argument(
        "--output-dir",
        "--output_dir",
        default="",
        type=str,
        help="Output directory to save logs and checkpoints",
    )
    return cli
|
||||
|
||||
|
||||
def build_optimizer(cfg, params_groups):
    """Create an AdamW optimizer over ``params_groups``.

    The two beta coefficients are read from ``cfg.optim.adamw_beta1`` and
    ``cfg.optim.adamw_beta2``; every other option keeps the optimizer's
    default.
    """
    beta1 = cfg.optim.adamw_beta1
    beta2 = cfg.optim.adamw_beta2
    return torch.optim.AdamW(params_groups, betas=(beta1, beta2))
|
||||
|
||||
|
||||
def build_schedulers(cfg):
    """Create the per-iteration schedules used by the training loop.

    Returns:
        A tuple ``(lr, weight decay, teacher momentum, teacher temperature,
        last-layer lr)`` of ``CosineScheduler`` objects. The last-layer
        schedule is a separate copy of the lr schedule whose first
        ``freeze_last_layer_epochs`` epochs are set to 0.
    """
    epoch_length = cfg.train.OFFICIAL_EPOCH_LENGTH
    optim_cfg = cfg.optim
    teacher_cfg = cfg.teacher
    total_iters = optim_cfg["epochs"] * epoch_length

    lr_kwargs = dict(
        base_value=optim_cfg["lr"],
        final_value=optim_cfg["min_lr"],
        total_iters=total_iters,
        warmup_iters=optim_cfg["warmup_epochs"] * epoch_length,
        start_warmup_value=0,
    )
    lr_schedule = CosineScheduler(**lr_kwargs)
    wd_schedule = CosineScheduler(
        base_value=optim_cfg["weight_decay"],
        final_value=optim_cfg["weight_decay_end"],
        total_iters=total_iters,
    )
    momentum_schedule = CosineScheduler(
        base_value=teacher_cfg["momentum_teacher"],
        final_value=teacher_cfg["final_momentum_teacher"],
        total_iters=total_iters,
    )
    # The temperature schedule consists of its warmup only: base and final
    # values are equal and the total length equals the warmup length.
    temp_warmup_iters = teacher_cfg["warmup_teacher_temp_epochs"] * epoch_length
    teacher_temp_schedule = CosineScheduler(
        base_value=teacher_cfg["teacher_temp"],
        final_value=teacher_cfg["teacher_temp"],
        total_iters=temp_warmup_iters,
        warmup_iters=temp_warmup_iters,
        start_warmup_value=teacher_cfg["warmup_teacher_temp"],
    )

    last_layer_lr_schedule = CosineScheduler(**lr_kwargs)
    frozen_iters = optim_cfg["freeze_last_layer_epochs"] * epoch_length
    last_layer_lr_schedule.schedule[:frozen_iters] = 0  # mimicking the original schedules

    logger.info("Schedulers ready.")

    return (
        lr_schedule,
        wd_schedule,
        momentum_schedule,
        teacher_temp_schedule,
        last_layer_lr_schedule,
    )
|
||||
|
||||
|
||||
def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr):
    """Write the current learning rate and weight decay into every param group.

    Each group scales the scheduled values by its own ``lr_multiplier`` and
    ``wd_multiplier``; groups flagged ``is_last_layer`` use ``last_layer_lr``
    instead of ``lr``.
    """
    for group in optimizer.param_groups:
        uses_last_layer_lr = group["is_last_layer"]
        lr_scale = group["lr_multiplier"]
        wd_scale = group["wd_multiplier"]

        base_lr = last_layer_lr if uses_last_layer_lr else lr
        group["weight_decay"] = wd * wd_scale
        group["lr"] = base_lr * lr_scale
|
||||
|
||||
|
||||
def do_test(cfg, model, iteration):
    """Save the teacher weights under ``<output_dir>/eval/<iteration>``.

    The teacher state dict is requested on every process, before the
    main-process check; only the main process creates the directory and
    writes ``teacher_checkpoint.pth``.

    Args:
        cfg: config providing ``train.output_dir``.
        model: object exposing a ``teacher`` with ``state_dict()``.
        iteration: label converted with ``str`` to name the sub-directory.
    """
    teacher_state = model.teacher.state_dict()

    if not distributed.is_main_process():
        return

    eval_dir = os.path.join(cfg.train.output_dir, "eval", str(iteration))
    os.makedirs(eval_dir, exist_ok=True)
    # save teacher checkpoint
    checkpoint_path = os.path.join(eval_dir, "teacher_checkpoint.pth")
    torch.save({"teacher": teacher_state}, checkpoint_path)
|
||||
|
||||
|
||||
def do_train(cfg, model, resume=False):
    """Run the training loop.

    Builds the optimizer, the schedules, the checkpointers and the data
    pipeline, then iterates over the data loader: schedules are applied,
    ``model.forward_backward`` is called, gradients are optionally clipped,
    the optimizer steps, the teacher is updated and metrics are logged.

    Args:
        cfg: training configuration.
        model: model exposing ``fp16_scaler``, ``get_params_groups``,
            ``forward_backward``, ``update_teacher`` and ``student``.
        resume: forwarded to ``checkpointer.resume_or_load``.

    Returns:
        A dict mapping each meter name to its ``global_avg`` when the loop
        ends normally; ``None`` when the loop exits through the
        ``iteration > max_iter`` check.

    Raises:
        AssertionError: if the sum of the reduced losses is NaN.
    """
    model.train()
    inputs_dtype = torch.half
    fp16_scaler = model.fp16_scaler  # for mixed precision training

    # setup optimizer

    optimizer = build_optimizer(cfg, model.get_params_groups())
    (
        lr_schedule,
        wd_schedule,
        momentum_schedule,
        teacher_temp_schedule,
        last_layer_lr_schedule,
    ) = build_schedulers(cfg)

    # checkpointer
    checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True)

    # Start one past the stored "iteration", or at 0 when none is stored.
    start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1

    OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH
    max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH

    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer,
        period=3 * OFFICIAL_EPOCH_LENGTH,
        max_iter=max_iter,
        max_to_keep=3,
    )

    # setup data preprocessing

    img_size = cfg.crops.global_crops_size
    patch_size = cfg.student.patch_size
    n_tokens = (img_size // patch_size) ** 2
    # NOTE(review): max_num_patches is evaluated left to right as
    # ((0.5 * img_size) // patch_size * img_size) // patch_size, which in
    # general differs from 0.5 * (img_size // patch_size) ** 2 -- confirm
    # which value is intended.
    mask_generator = MaskingGenerator(
        input_size=(img_size // patch_size, img_size // patch_size),
        max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
    )

    data_transform = DataAugmentationDINO(
        cfg.crops.global_crops_scale,
        cfg.crops.local_crops_scale,
        cfg.crops.local_crops_number,
        global_crops_size=cfg.crops.global_crops_size,
        local_crops_size=cfg.crops.local_crops_size,
    )

    collate_fn = partial(
        collate_data_and_cast,
        mask_ratio_tuple=cfg.ibot.mask_ratio_min_max,
        mask_probability=cfg.ibot.mask_sample_probability,
        n_tokens=n_tokens,
        mask_generator=mask_generator,
        dtype=inputs_dtype,
    )

    # setup data loader

    dataset = make_dataset(
        dataset_str=cfg.train.dataset_path,
        transform=data_transform,
        target_transform=lambda _: (),
    )
    # sampler_type = SamplerType.INFINITE
    sampler_type = SamplerType.SHARDED_INFINITE
    data_loader = make_data_loader(
        dataset=dataset,
        batch_size=cfg.train.batch_size_per_gpu,
        num_workers=cfg.train.num_workers,
        shuffle=True,
        seed=start_iter,  # TODO: Fix this -- cfg.train.seed
        sampler_type=sampler_type,
        sampler_advance=0,  # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu,
        drop_last=True,
        collate_fn=collate_fn,
    )

    # training loop

    iteration = start_iter

    logger.info("Starting training from iteration {}".format(start_iter))
    metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json")
    metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file)
    header = "Training"

    for data in metric_logger.log_every(
        data_loader,
        10,
        header,
        max_iter,
        start_iter,
    ):
        # Half of the first dimension of the collated global crops.
        current_batch_size = data["collated_global_crops"].shape[0] / 2
        # NOTE(review): this exit returns None and skips the final
        # synchronize_between_processes / metrics dict below.
        if iteration > max_iter:
            return

        # apply schedules

        lr = lr_schedule[iteration]
        wd = wd_schedule[iteration]
        mom = momentum_schedule[iteration]
        teacher_temp = teacher_temp_schedule[iteration]
        last_layer_lr = last_layer_lr_schedule[iteration]
        apply_optim_scheduler(optimizer, lr, wd, last_layer_lr)

        # compute losses

        optimizer.zero_grad(set_to_none=True)
        loss_dict = model.forward_backward(data, teacher_temp=teacher_temp)

        # clip gradients

        if fp16_scaler is not None:
            if cfg.optim.clip_grad:
                # Gradients are unscaled before clipping each student sub-model.
                fp16_scaler.unscale_(optimizer)
                for v in model.student.values():
                    v.clip_grad_norm_(cfg.optim.clip_grad)
            fp16_scaler.step(optimizer)
            fp16_scaler.update()
        else:
            if cfg.optim.clip_grad:
                for v in model.student.values():
                    v.clip_grad_norm_(cfg.optim.clip_grad)
            optimizer.step()

        # perform teacher EMA update

        model.update_teacher(mom)

        # logging

        if distributed.get_global_size() > 1:
            for v in loss_dict.values():
                torch.distributed.all_reduce(v)
        # Each loss value is divided by the global size after the reduction.
        loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()}

        if math.isnan(sum(loss_dict_reduced.values())):
            logger.info("NaN detected")
            raise AssertionError
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        metric_logger.update(lr=lr)
        metric_logger.update(wd=wd)
        metric_logger.update(mom=mom)
        metric_logger.update(last_layer_lr=last_layer_lr)
        metric_logger.update(current_batch_size=current_batch_size)
        metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)

        # checkpointing and testing

        if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0:
            do_test(cfg, model, f"training_{iteration}")
            torch.cuda.synchronize()
        periodic_checkpointer.step(iteration)

        iteration = iteration + 1
    metric_logger.synchronize_between_processes()
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
||||
|
||||
|
||||
def main(args):
    """Entry point: build the config and model, then evaluate or train.

    With ``args.eval_only`` the checkpoint is loaded and ``do_test`` is run
    with a ``manual_<iteration>`` label; otherwise ``do_train`` is run.
    Resuming is enabled unless ``args.no_resume`` is set.
    """
    cfg = setup(args)

    model = SSLMetaArch(cfg).to(torch.device("cuda"))
    model.prepare_for_distributed_training()

    logger.info("Model:\n{}".format(model))
    if not args.eval_only:
        do_train(cfg, model, resume=not args.no_resume)
        return

    checkpointer = FSDPCheckpointer(model, save_dir=cfg.train.output_dir)
    checkpoint = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume)
    iteration = checkpoint.get("iteration", -1) + 1
    return do_test(cfg, model, f"manual_{iteration}")
|
||||
|
||||
|
||||
# Script entry point: parse the command line and hand the result to main().
if __name__ == "__main__":
    args = get_args_parser(add_help=True).parse_args()
    main(args)
|
||||
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
@@ -0,0 +1,96 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class ClusterType(Enum):
    """Clusters distinguished by the helpers of this module."""

    # Selected when the kernel release ends with "-aws".
    AWS = "aws"
    # Fallback when no other cluster is recognised.
    FAIR = "fair"
    # Selected when the host name starts with "rsc".
    RSC = "rsc"
|
||||
|
||||
|
||||
def _guess_cluster_type() -> ClusterType:
    """Guess the cluster from the kernel release and the host name.

    Returns:
        ``ClusterType.AWS`` on Linux when the kernel release ends with
        ``-aws``, ``ClusterType.RSC`` on Linux when the host name starts
        with ``rsc``, and ``ClusterType.FAIR`` in every other case.
    """
    # os.uname is only available on Unix; a platform without it cannot be
    # Linux, so it gets the same fallback as any other non-Linux system
    # instead of failing with AttributeError.
    uname_fn = getattr(os, "uname", None)
    if uname_fn is None:
        return ClusterType.FAIR

    uname = uname_fn()
    if uname.sysname == "Linux":
        if uname.release.endswith("-aws"):
            # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
            return ClusterType.AWS
        elif uname.nodename.startswith("rsc"):
            # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
            return ClusterType.RSC

    return ClusterType.FAIR
|
||||
|
||||
|
||||
def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    """Return ``cluster_type``, or the guessed cluster when it is ``None``."""
    return _guess_cluster_type() if cluster_type is None else cluster_type
|
||||
|
||||
|
||||
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the root checkpoint directory of the (given or guessed) cluster.

    Returns ``None`` when no cluster type can be resolved.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    dirname_by_cluster = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    return Path("/", dirname_by_cluster[resolved])
|
||||
|
||||
|
||||
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the checkpoint directory of the current user.

    The directory is the cluster checkpoint path extended with the value of
    the ``USER`` environment variable, which must be set. Returns ``None``
    when the cluster checkpoint path is ``None``.
    """
    base_dir = get_checkpoint_path(cluster_type)
    if base_dir is None:
        return None

    username = os.getenv("USER")
    assert username is not None
    return base_dir / username
|
||||
|
||||
|
||||
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the SLURM partition name of the (given or guessed) cluster.

    Returns ``None`` when no cluster type can be resolved.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    partition_by_cluster = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return partition_by_cluster[resolved]
|
||||
|
||||
|
||||
def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    """Assemble the SLURM executor parameters for a job.

    Args:
        nodes: number of nodes requested.
        num_gpus_per_node: GPUs per node; also used as the task count per node.
        cluster_type: cluster to target; guessed when ``None``.
        **kwargs: extra parameters, applied last so they override the defaults.

    Returns:
        The parameter dictionary.
    """
    # create default parameters
    params = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }

    # apply cluster-specific adjustments
    resolved = get_cluster_type(cluster_type)
    if resolved == ClusterType.AWS or resolved == ClusterType.RSC:
        params["cpus_per_task"] = 12
    if resolved == ClusterType.AWS:
        params.pop("mem_gb")

    # set additional parameters / apply overrides
    params.update(kwargs)
    return params
|
||||
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import logging
|
||||
import os
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import dinov2.distributed as distributed
|
||||
from dinov2.logging import setup_logging
|
||||
from dinov2.utils import utils
|
||||
from dinov2.configs import dinov2_default_config
|
||||
|
||||
|
||||
# Module-level logger; default_setup() rebinds it after logging is configured.
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def apply_scaling_rules_to_cfg(cfg):  # to fix
    """Set ``cfg.optim.lr`` from ``cfg.optim.base_lr`` and the global batch size.

    Only the ``"sqrt_wrt_1024"`` rule is supported: the base learning rate
    is multiplied by ``sqrt(batch_size_per_gpu * global_size / 1024)``.

    Returns:
        The same ``cfg`` object, modified in place.

    Raises:
        NotImplementedError: for any other ``cfg.optim.scaling_rule``.
    """
    if cfg.optim.scaling_rule != "sqrt_wrt_1024":
        raise NotImplementedError

    base_lr = cfg.optim.base_lr
    cfg.optim.lr = base_lr
    global_batch_size = cfg.train.batch_size_per_gpu * distributed.get_global_size()
    cfg.optim.lr *= math.sqrt(global_batch_size / 1024.0)
    logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
    return cfg
|
||||
|
||||
|
||||
def write_config(cfg, output_dir, name="config.yaml"):
    """Log the config as YAML and save it to ``output_dir/name``.

    Returns:
        The path of the written file.
    """
    logger.info(OmegaConf.to_yaml(cfg))
    target_path = os.path.join(output_dir, name)
    with open(target_path, "w") as stream:
        OmegaConf.save(config=cfg, f=stream)
    return target_path
|
||||
|
||||
|
||||
def get_cfg_from_args(args):
    """Build the run configuration from parsed command-line arguments.

    ``args.output_dir`` is made absolute and appended to ``args.opts`` as a
    ``train.output_dir=...`` override. The result merges, in increasing
    priority, the default config, the file at ``args.config_file`` and the
    overrides in ``args.opts``.

    Args:
        args: namespace with ``output_dir``, ``opts`` and ``config_file``;
            ``output_dir`` and ``opts`` are updated in place.

    Returns:
        The merged config.
    """
    args.output_dir = os.path.abspath(args.output_dir)
    # ``opts`` may be None (the parser declares that default, and callers can
    # build the namespace by hand); ``None += [...]`` would raise TypeError.
    if args.opts is None:
        args.opts = []
    args.opts += [f"train.output_dir={args.output_dir}"]
    default_cfg = OmegaConf.create(dinov2_default_config)
    cfg = OmegaConf.load(args.config_file)
    cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
    return cfg
|
||||
|
||||
|
||||
def default_setup(args):
    """Enable distributed mode, configure logging and seed the RNGs.

    The seed is ``args.seed`` (0 when the attribute is missing) plus the
    global rank of the process. The git state and all arguments, sorted by
    name, are logged afterwards.
    """
    global logger

    distributed.enable(overwrite=True)
    base_seed = getattr(args, "seed", 0)
    process_rank = distributed.get_global_rank()

    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    utils.fix_random_seeds(base_seed + process_rank)
    logger.info("git:\n {}\n".format(utils.get_sha()))
    argument_lines = ["%s: %s" % (key, str(value)) for key, value in sorted(vars(args).items())]
    logger.info("\n".join(argument_lines))
|
||||
|
||||
|
||||
def setup(args):
    """
    Build the config from ``args`` and run the basic setup steps.

    In order: create the output directory, run ``default_setup``, apply the
    learning-rate scaling rule and write the config to the output directory.
    Returns the config.
    """
    config = get_cfg_from_args(args)
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    apply_scaling_rules_to_cfg(config)
    write_config(config, args.output_dir)
    return config
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
# Anything accepted as a dtype specification by as_torch_dtype().
TypeSpec = Union[str, np.dtype, torch.dtype]

# NumPy dtypes that can be converted, with their torch counterparts.
_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}


def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    """Convert a dtype specification to a ``torch.dtype``.

    Args:
        dtype: a ``torch.dtype`` (returned unchanged), a ``numpy.dtype``, or
            a string that ``numpy.dtype`` can parse.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        AssertionError: if ``dtype`` is none of the accepted kinds.
        KeyError: if the NumPy dtype has no entry in the conversion table.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    # The message used to misspell "numpy".
    assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
    return _NUMPY_TO_TORCH_DTYPE[dtype]
|
||||
@@ -0,0 +1,94 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
|
||||
|
||||
# Logger used to report the multipliers chosen for every parameter.
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
    """
    Compute the layer-wise lr multiplier of one parameter.

    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.
        force_is_backbone (bool): treat the name as a backbone parameter even
            when it does not start with "backbone".
        chunked_blocks (bool): for names without a ".blocks." component, read
            the block index from the second component after "blocks" instead
            of the first.
    Returns:
        ``lr_decay_rate ** (num_layers + 1 - layer_id)``, where ``layer_id``
        is 0 for embedding/token parameters, block index + 1 for block
        parameters and ``num_layers + 1`` for everything else.
    """
    embedding_keys = ("pos_embed", "patch_embed", "mask_token", "cls_token")
    top_layer = num_layers + 1

    def block_index(marker, position):
        # Integer found `position` dot-separated fields after the start of `marker`.
        return int(name[name.find(marker) :].split(".")[position])

    layer_id = top_layer
    if force_is_backbone or name.startswith("backbone"):
        if any("." + key in name for key in embedding_keys):
            layer_id = 0
        elif force_is_backbone and any(key in name for key in embedding_keys):
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            layer_id = block_index(".blocks.", 2) + 1
        elif chunked_blocks and "blocks." in name and "residual." not in name:
            layer_id = block_index("blocks.", 2) + 1
        elif "blocks." in name and "residual." not in name:
            layer_id = block_index("blocks.", 1) + 1

    return lr_decay_rate ** (top_layer - layer_id)
|
||||
|
||||
|
||||
def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
    """Create one optimizer group per trainable parameter of ``model``.

    Each group holds the parameter, its cleaned name, an ``lr_multiplier``
    (layer-wise decay, additionally scaled by ``patch_embed_lr_mult`` for
    names containing "patch_embed"), a ``wd_multiplier`` (0.0 for biases and
    names containing "norm" or "gamma", 1.0 otherwise) and an
    ``is_last_layer`` flag (names containing "last_layer").

    Returns:
        The list of group dictionaries, in ``named_parameters`` order.
    """
    chunked_blocks = False
    if hasattr(model, "n_blocks"):
        logger.info("chunked fsdp")
        n_blocks = model.n_blocks
        chunked_blocks = model.chunked_blocks
    elif hasattr(model, "blocks"):
        logger.info("first code branch")
        n_blocks = len(model.blocks)
    elif hasattr(model, "backbone"):
        logger.info("second code branch")
        n_blocks = len(model.backbone.blocks)
    else:
        logger.info("else code branch")
        n_blocks = 0

    groups = []
    for raw_name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        name = raw_name.replace("_fsdp_wrapped_module.", "")
        lr_multiplier = get_vit_lr_decay_rate(
            name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
        )
        if "patch_embed" in name:
            lr_multiplier = lr_multiplier * patch_embed_lr_mult

        skip_weight_decay = name.endswith(".bias") or "norm" in name or "gamma" in name
        group = {
            "params": param,
            "is_last_layer": "last_layer" in name,
            "lr_multiplier": lr_multiplier,
            "wd_multiplier": 0.0 if skip_weight_decay else 1.0,
            "name": name,
        }

        groups.append(group)
        logger.info(f"""{name}: lr_multiplier: {group["lr_multiplier"]}, wd_multiplier: {group["wd_multiplier"]}""")

    return groups
|
||||
|
||||
|
||||
def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
    """Merge parameter groups that share the same values for ``keys``.

    Args:
        all_params_groups: iterable of group dicts holding ``"params"`` and
            every entry of ``keys``.
        keys: fields whose values identify a fused group.

    Returns:
        A view over the fused groups, in first-seen order; each one carries
        the ``keys`` values and a ``"params"`` list of the merged parameters.
    """
    fused = {}
    for group in all_params_groups:
        identifier = "".join(key + str(group[key]) + "_" for key in keys)
        bucket = fused.setdefault(identifier, {"params": []})
        for key in keys:
            bucket[key] = group[key]
        bucket["params"].append(group["params"])

    return fused.values()
|
||||
@@ -0,0 +1,96 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
# Logger used by the helpers of this module.
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
    """Load a checkpoint into ``model`` with ``strict=False``.

    Args:
        model: module receiving the weights.
        pretrained_weights: local path, or URL (anything with a URL scheme).
        checkpoint_key: optional key selecting a nested state dict inside
            the loaded checkpoint; ignored when ``None`` or absent.
    """
    looks_like_url = bool(urlparse(pretrained_weights).scheme)
    if looks_like_url:
        state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
    else:
        # NOTE(review): torch.load unpickles the file -- only use it on
        # checkpoints from trusted sources.
        state_dict = torch.load(pretrained_weights, map_location="cpu")

    if checkpoint_key is not None and checkpoint_key in state_dict:
        logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]

    # The two rewrites stay separate passes: key collisions are resolved
    # after each one. Note that str.replace drops every occurrence of the
    # marker, not only a leading one.
    for marker in ("module.", "backbone."):
        state_dict = {key.replace(marker, ""): value for key, value in state_dict.items()}

    load_result = model.load_state_dict(state_dict, strict=False)
    logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, load_result))
|
||||
|
||||
|
||||
def fix_random_seeds(seed=31):
    """
    Seed the Python, NumPy and torch (CPU and all CUDA devices) generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def get_sha():
    """Describe the git state of the checkout containing this file.

    Returns:
        A string ``"sha: ..., status: ..., branch: ..."``. Every git command
        is best effort: after the first failure the remaining fields keep
        their defaults (``N/A`` / ``clean``).
    """
    repo_dir = os.path.dirname(os.path.abspath(__file__))

    def _git(*arguments):
        raw_output = subprocess.check_output(["git", *arguments], cwd=repo_dir)
        return raw_output.decode("ascii").strip()

    sha, diff, branch = "N/A", "clean", "N/A"
    try:
        sha = _git("rev-parse", "HEAD")
        # Output discarded; the call only matters if it fails.
        subprocess.check_output(["git", "diff"], cwd=repo_dir)
        changed_entries = _git("diff-index", "HEAD")
        diff = "has uncommitted changes" if changed_entries else "clean"
        branch = _git("rev-parse", "--abbrev-ref", "HEAD")
    except Exception:
        pass
    return f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||
|
||||
|
||||
class CosineScheduler(object):
    """Precomputed per-iteration schedule: zeros, linear warmup, cosine decay.

    The schedule is ``freeze_iters`` zeros, followed by ``warmup_iters``
    values going linearly from ``start_warmup_value`` to ``base_value``,
    followed by a half-cosine from ``base_value`` towards ``final_value``
    over the remaining iterations. Indexing at or beyond ``total_iters``
    returns ``final_value``.
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters

        frozen_part = np.zeros((freeze_iters))
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)

        steps = np.arange(total_iters - warmup_iters - freeze_iters)
        cosine_factor = 1 + np.cos(np.pi * steps / len(steps))
        cosine_part = final_value + 0.5 * (base_value - final_value) * cosine_factor

        self.schedule = np.concatenate((frozen_part, warmup_part, cosine_part))

        assert len(self.schedule) == self.total_iters

    def __getitem__(self, it):
        return self.final_value if it >= self.total_iters else self.schedule[it]
|
||||
|
||||
|
||||
def has_batchnorms(model):
    """Return True if ``model`` contains any batch-normalization module.

    ``BatchNorm1d``, ``BatchNorm2d``, ``BatchNorm3d`` and ``SyncBatchNorm``
    are looked for among ``model.named_modules()``.
    """
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(module, batchnorm_types) for _, module in model.named_modules())
|
||||
Reference in New Issue
Block a user