Initial media depth project backup

This commit is contained in:
Codex
2026-05-20 12:25:12 +08:00
commit 4a0aebb2bd
358 changed files with 182095 additions and 0 deletions

View File

@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .adapters import DatasetWithEnumeratedTargets
from .loaders import make_data_loader, make_dataset, SamplerType
from .collate import collate_data_and_cast
from .masking import MaskingGenerator
from .augmentations import DataAugmentationDINO

View File

@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Tuple
from torch.utils.data import Dataset
class DatasetWithEnumeratedTargets(Dataset):
def __init__(self, dataset):
self._dataset = dataset
def get_image_data(self, index: int) -> bytes:
return self._dataset.get_image_data(index)
def get_target(self, index: int) -> Tuple[Any, int]:
target = self._dataset.get_target(index)
return (index, target)
def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]:
image, target = self._dataset[index]
target = index if target is None else target
return image, (index, target)
def __len__(self) -> int:
return len(self._dataset)

View File

@@ -0,0 +1,119 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from torchvision import transforms
from .transforms import (
GaussianBlur,
make_normalize_transform,
)
logger = logging.getLogger("dinov2")
class DataAugmentationDINO(object):
def __init__(
self,
global_crops_scale,
local_crops_scale,
local_crops_number,
global_crops_size=224,
local_crops_size=96,
):
self.global_crops_scale = global_crops_scale
self.local_crops_scale = local_crops_scale
self.local_crops_number = local_crops_number
self.global_crops_size = global_crops_size
self.local_crops_size = local_crops_size
logger.info("###################################")
logger.info("Using data augmentation parameters:")
logger.info(f"global_crops_scale: {global_crops_scale}")
logger.info(f"local_crops_scale: {local_crops_scale}")
logger.info(f"local_crops_number: {local_crops_number}")
logger.info(f"global_crops_size: {global_crops_size}")
logger.info(f"local_crops_size: {local_crops_size}")
logger.info("###################################")
# random resized crop and flip
self.geometric_augmentation_global = transforms.Compose(
[
transforms.RandomResizedCrop(
global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
),
transforms.RandomHorizontalFlip(p=0.5),
]
)
self.geometric_augmentation_local = transforms.Compose(
[
transforms.RandomResizedCrop(
local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
),
transforms.RandomHorizontalFlip(p=0.5),
]
)
# color distorsions / blurring
color_jittering = transforms.Compose(
[
transforms.RandomApply(
[transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
p=0.8,
),
transforms.RandomGrayscale(p=0.2),
]
)
global_transfo1_extra = GaussianBlur(p=1.0)
global_transfo2_extra = transforms.Compose(
[
GaussianBlur(p=0.1),
transforms.RandomSolarize(threshold=128, p=0.2),
]
)
local_transfo_extra = GaussianBlur(p=0.5)
# normalization
self.normalize = transforms.Compose(
[
transforms.ToTensor(),
make_normalize_transform(),
]
)
self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])
def __call__(self, image):
output = {}
# global crops:
im1_base = self.geometric_augmentation_global(image)
global_crop_1 = self.global_transfo1(im1_base)
im2_base = self.geometric_augmentation_global(image)
global_crop_2 = self.global_transfo2(im2_base)
output["global_crops"] = [global_crop_1, global_crop_2]
# global crops for teacher:
output["global_crops_teacher"] = [global_crop_1, global_crop_2]
# local crops:
local_crops = [
self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
]
output["local_crops"] = local_crops
output["offsets"] = ()
return output

View File

@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import random
def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
# dtype = torch.half # TODO: Remove
n_global_crops = len(samples_list[0][0]["global_crops"])
n_local_crops = len(samples_list[0][0]["local_crops"])
collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list])
collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list])
B = len(collated_global_crops)
N = n_tokens
n_samples_masked = int(B * mask_probability)
probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
upperbound = 0
masks_list = []
for i in range(0, n_samples_masked):
prob_min = probs[i]
prob_max = probs[i + 1]
masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
upperbound += int(N * prob_max)
for i in range(n_samples_masked, B):
masks_list.append(torch.BoolTensor(mask_generator(0)))
random.shuffle(masks_list)
collated_masks = torch.stack(masks_list).flatten(1)
mask_indices_list = collated_masks.flatten().nonzero().flatten()
masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]
return {
"collated_global_crops": collated_global_crops.to(dtype),
"collated_local_crops": collated_local_crops.to(dtype),
"collated_masks": collated_masks,
"mask_indices_list": mask_indices_list,
"masks_weight": masks_weight,
"upperbound": upperbound,
"n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
}

View File

@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .image_net import ImageNet
from .image_net_22k import ImageNet22k

View File

@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from io import BytesIO
from typing import Any
from PIL import Image
class Decoder:
def decode(self) -> Any:
raise NotImplementedError
class ImageDataDecoder(Decoder):
def __init__(self, image_data: bytes) -> None:
self._image_data = image_data
def decode(self) -> Image:
f = BytesIO(self._image_data)
return Image.open(f).convert(mode="RGB")
class TargetDecoder(Decoder):
def __init__(self, target: Any):
self._target = target
def decode(self) -> Any:
return self._target

View File

@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Tuple
from torchvision.datasets import VisionDataset
from .decoders import TargetDecoder, ImageDataDecoder
class ExtendedVisionDataset(VisionDataset):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) # type: ignore
def get_image_data(self, index: int) -> bytes:
raise NotImplementedError
def get_target(self, index: int) -> Any:
raise NotImplementedError
def __getitem__(self, index: int) -> Tuple[Any, Any]:
try:
image_data = self.get_image_data(index)
image = ImageDataDecoder(image_data).decode()
except Exception as e:
raise RuntimeError(f"can not read image for sample {index}") from e
target = self.get_target(index)
target = TargetDecoder(target).decode()
if self.transforms is not None:
image, target = self.transforms(image, target)
return image, target
def __len__(self) -> int:
raise NotImplementedError

View File

@@ -0,0 +1,291 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import csv
from enum import Enum
import logging
import os
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
from .extended import ExtendedVisionDataset
logger = logging.getLogger("dinov2")
_Target = int
class _Split(Enum):
TRAIN = "train"
VAL = "val"
TEST = "test" # NOTE: torchvision does not support the test split
@property
def length(self) -> int:
split_lengths = {
_Split.TRAIN: 1_281_167,
_Split.VAL: 50_000,
_Split.TEST: 100_000,
}
return split_lengths[self]
def get_dirname(self, class_id: Optional[str] = None) -> str:
return self.value if class_id is None else os.path.join(self.value, class_id)
def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str:
dirname = self.get_dirname(class_id)
if self == _Split.TRAIN:
basename = f"{class_id}_{actual_index}"
else: # self in (_Split.VAL, _Split.TEST):
basename = f"ILSVRC2012_{self.value}_{actual_index:08d}"
return os.path.join(dirname, basename + ".JPEG")
def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]:
assert self != _Split.TEST
dirname, filename = os.path.split(image_relpath)
class_id = os.path.split(dirname)[-1]
basename, _ = os.path.splitext(filename)
actual_index = int(basename.split("_")[-1])
return class_id, actual_index
class ImageNet(ExtendedVisionDataset):
Target = Union[_Target]
Split = Union[_Split]
def __init__(
self,
*,
split: "ImageNet.Split",
root: str,
extra: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super().__init__(root, transforms, transform, target_transform)
self._extra_root = extra
self._split = split
self._entries = None
self._class_ids = None
self._class_names = None
@property
def split(self) -> "ImageNet.Split":
return self._split
def _get_extra_full_path(self, extra_path: str) -> str:
return os.path.join(self._extra_root, extra_path)
def _load_extra(self, extra_path: str) -> np.ndarray:
extra_full_path = self._get_extra_full_path(extra_path)
return np.load(extra_full_path, mmap_mode="r")
def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
extra_full_path = self._get_extra_full_path(extra_path)
os.makedirs(self._extra_root, exist_ok=True)
np.save(extra_full_path, extra_array)
@property
def _entries_path(self) -> str:
return f"entries-{self._split.value.upper()}.npy"
@property
def _class_ids_path(self) -> str:
return f"class-ids-{self._split.value.upper()}.npy"
@property
def _class_names_path(self) -> str:
return f"class-names-{self._split.value.upper()}.npy"
def _get_entries(self) -> np.ndarray:
if self._entries is None:
self._entries = self._load_extra(self._entries_path)
assert self._entries is not None
return self._entries
def _get_class_ids(self) -> np.ndarray:
if self._split == _Split.TEST:
assert False, "Class IDs are not available in TEST split"
if self._class_ids is None:
self._class_ids = self._load_extra(self._class_ids_path)
assert self._class_ids is not None
return self._class_ids
def _get_class_names(self) -> np.ndarray:
if self._split == _Split.TEST:
assert False, "Class names are not available in TEST split"
if self._class_names is None:
self._class_names = self._load_extra(self._class_names_path)
assert self._class_names is not None
return self._class_names
def find_class_id(self, class_index: int) -> str:
class_ids = self._get_class_ids()
return str(class_ids[class_index])
def find_class_name(self, class_index: int) -> str:
class_names = self._get_class_names()
return str(class_names[class_index])
def get_image_data(self, index: int) -> bytes:
entries = self._get_entries()
actual_index = entries[index]["actual_index"]
class_id = self.get_class_id(index)
image_relpath = self.split.get_image_relpath(actual_index, class_id)
image_full_path = os.path.join(self.root, image_relpath)
with open(image_full_path, mode="rb") as f:
image_data = f.read()
return image_data
def get_target(self, index: int) -> Optional[Target]:
entries = self._get_entries()
class_index = entries[index]["class_index"]
return None if self.split == _Split.TEST else int(class_index)
def get_targets(self) -> Optional[np.ndarray]:
entries = self._get_entries()
return None if self.split == _Split.TEST else entries["class_index"]
def get_class_id(self, index: int) -> Optional[str]:
entries = self._get_entries()
class_id = entries[index]["class_id"]
return None if self.split == _Split.TEST else str(class_id)
def get_class_name(self, index: int) -> Optional[str]:
entries = self._get_entries()
class_name = entries[index]["class_name"]
return None if self.split == _Split.TEST else str(class_name)
def __len__(self) -> int:
entries = self._get_entries()
assert len(entries) == self.split.length
return len(entries)
def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
labels_full_path = os.path.join(self.root, labels_path)
labels = []
try:
with open(labels_full_path, "r") as f:
reader = csv.reader(f)
for row in reader:
class_id, class_name = row
labels.append((class_id, class_name))
except OSError as e:
raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e
return labels
def _dump_entries(self) -> None:
split = self.split
if split == ImageNet.Split.TEST:
dataset = None
sample_count = split.length
max_class_id_length, max_class_name_length = 0, 0
else:
labels_path = "labels.txt"
logger.info(f'loading labels from "{labels_path}"')
labels = self._load_labels(labels_path)
# NOTE: Using torchvision ImageFolder for consistency
from torchvision.datasets import ImageFolder
dataset_root = os.path.join(self.root, split.get_dirname())
dataset = ImageFolder(dataset_root)
sample_count = len(dataset)
max_class_id_length, max_class_name_length = -1, -1
for sample in dataset.samples:
_, class_index = sample
class_id, class_name = labels[class_index]
max_class_id_length = max(len(class_id), max_class_id_length)
max_class_name_length = max(len(class_name), max_class_name_length)
dtype = np.dtype(
[
("actual_index", "<u4"),
("class_index", "<u4"),
("class_id", f"U{max_class_id_length}"),
("class_name", f"U{max_class_name_length}"),
]
)
entries_array = np.empty(sample_count, dtype=dtype)
if split == ImageNet.Split.TEST:
old_percent = -1
for index in range(sample_count):
percent = 100 * (index + 1) // sample_count
if percent > old_percent:
logger.info(f"creating entries: {percent}%")
old_percent = percent
actual_index = index + 1
class_index = np.uint32(-1)
class_id, class_name = "", ""
entries_array[index] = (actual_index, class_index, class_id, class_name)
else:
class_names = {class_id: class_name for class_id, class_name in labels}
assert dataset
old_percent = -1
for index in range(sample_count):
percent = 100 * (index + 1) // sample_count
if percent > old_percent:
logger.info(f"creating entries: {percent}%")
old_percent = percent
image_full_path, class_index = dataset.samples[index]
image_relpath = os.path.relpath(image_full_path, self.root)
class_id, actual_index = split.parse_image_relpath(image_relpath)
class_name = class_names[class_id]
entries_array[index] = (actual_index, class_index, class_id, class_name)
logger.info(f'saving entries to "{self._entries_path}"')
self._save_extra(entries_array, self._entries_path)
def _dump_class_ids_and_names(self) -> None:
split = self.split
if split == ImageNet.Split.TEST:
return
entries_array = self._load_extra(self._entries_path)
max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1
for entry in entries_array:
class_index, class_id, class_name = (
entry["class_index"],
entry["class_id"],
entry["class_name"],
)
max_class_index = max(int(class_index), max_class_index)
max_class_id_length = max(len(str(class_id)), max_class_id_length)
max_class_name_length = max(len(str(class_name)), max_class_name_length)
class_count = max_class_index + 1
class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}")
class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}")
for entry in entries_array:
class_index, class_id, class_name = (
entry["class_index"],
entry["class_id"],
entry["class_name"],
)
class_ids_array[class_index] = class_id
class_names_array[class_index] = class_name
logger.info(f'saving class IDs to "{self._class_ids_path}"')
self._save_extra(class_ids_array, self._class_ids_path)
logger.info(f'saving class names to "{self._class_names_path}"')
self._save_extra(class_names_array, self._class_names_path)
def dump_extra(self) -> None:
self._dump_entries()
self._dump_class_ids_and_names()

View File

@@ -0,0 +1,303 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from gzip import GzipFile
from io import BytesIO
from mmap import ACCESS_READ, mmap
import os
from typing import Any, Callable, List, Optional, Set, Tuple
import warnings
import numpy as np
from .extended import ExtendedVisionDataset
_Labels = int
_DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors
@dataclass
class _ClassEntry:
block_offset: int
maybe_filename: Optional[str] = None
@dataclass
class _Entry:
class_index: int # noqa: E701
start_offset: int
end_offset: int
filename: str
class _Split(Enum):
TRAIN = "train"
VAL = "val"
@property
def length(self) -> int:
return {
_Split.TRAIN: 11_797_647,
_Split.VAL: 561_050,
}[self]
def entries_path(self):
return f"imagenet21kp_{self.value}.txt"
def _get_tarball_path(class_id: str) -> str:
return f"{class_id}.tar"
def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
@lru_cache(maxsize=mmap_cache_size)
def _mmap_tarball(class_id: str) -> mmap:
tarball_path = _get_tarball_path(class_id)
tarball_full_path = os.path.join(tarballs_root, tarball_path)
with open(tarball_full_path) as f:
return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)
return _mmap_tarball
class ImageNet22k(ExtendedVisionDataset):
_GZIPPED_INDICES: Set[int] = {
841_545,
1_304_131,
2_437_921,
2_672_079,
2_795_676,
2_969_786,
6_902_965,
6_903_550,
6_903_628,
7_432_557,
7_432_589,
7_813_809,
8_329_633,
10_296_990,
10_417_652,
10_492_265,
10_598_078,
10_782_398,
10_902_612,
11_203_736,
11_342_890,
11_397_596,
11_589_762,
11_705_103,
12_936_875,
13_289_782,
}
Labels = _Labels
def __init__(
self,
*,
root: str,
extra: str,
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
) -> None:
super().__init__(root, transforms, transform, target_transform)
self._extra_root = extra
entries_path = self._get_entries_path(root)
self._entries = self._load_extra(entries_path)
class_ids_path = self._get_class_ids_path(root)
self._class_ids = self._load_extra(class_ids_path)
self._gzipped_indices = ImageNet22k._GZIPPED_INDICES
self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)
def _get_entries_path(self, root: Optional[str] = None) -> str:
return "entries.npy"
def _get_class_ids_path(self, root: Optional[str] = None) -> str:
return "class-ids.npy"
def _find_class_ids(self, path: str) -> List[str]:
class_ids = []
with os.scandir(path) as entries:
for entry in entries:
root, ext = os.path.splitext(entry.name)
if ext != ".tar":
continue
class_ids.append(root)
return sorted(class_ids)
def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]:
root = self.get_root(root)
entries: List[_Entry] = []
class_ids = self._find_class_ids(root)
for class_index, class_id in enumerate(class_ids):
path = os.path.join(root, "blocks", f"{class_id}.log")
class_entries = []
try:
with open(path) as f:
for line in f:
line = line.rstrip()
block, filename = line.split(":")
block_offset = int(block[6:])
filename = filename[1:]
maybe_filename = None
if filename != "** Block of NULs **":
maybe_filename = filename
_, ext = os.path.splitext(filename)
# assert ext == ".JPEG"
class_entry = _ClassEntry(block_offset, maybe_filename)
class_entries.append(class_entry)
except OSError as e:
raise RuntimeError(f'can not read blocks file "{path}"') from e
assert class_entries[-1].maybe_filename is None
for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]):
assert class_entry1.block_offset <= class_entry2.block_offset
start_offset = 512 * class_entry1.block_offset
end_offset = 512 * class_entry2.block_offset
assert class_entry1.maybe_filename is not None
filename = class_entry1.maybe_filename
entry = _Entry(class_index, start_offset, end_offset, filename)
# Skip invalid image files (PIL throws UnidentifiedImageError)
if filename == "n06470073_47249.JPEG":
continue
entries.append(entry)
return entries, class_ids
def _load_extra(self, extra_path: str) -> np.ndarray:
extra_root = self._extra_root
extra_full_path = os.path.join(extra_root, extra_path)
return np.load(extra_full_path, mmap_mode="r")
def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
extra_root = self._extra_root
extra_full_path = os.path.join(extra_root, extra_path)
os.makedirs(extra_root, exist_ok=True)
np.save(extra_full_path, extra_array)
@property
def _tarballs_root(self) -> str:
return self.root
def find_class_id(self, class_index: int) -> str:
return str(self._class_ids[class_index])
def get_image_data(self, index: int) -> bytes:
entry = self._entries[index]
class_id = entry["class_id"]
class_mmap = self._mmap_tarball(class_id)
start_offset, end_offset = entry["start_offset"], entry["end_offset"]
try:
mapped_data = class_mmap[start_offset:end_offset]
data = mapped_data[512:] # Skip entry header block
if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B):
assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}"
with GzipFile(fileobj=BytesIO(data)) as g:
data = g.read()
except Exception as e:
raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e
return data
def get_target(self, index: int) -> Any:
return int(self._entries[index]["class_index"])
def get_targets(self) -> np.ndarray:
return self._entries["class_index"]
def get_class_id(self, index: int) -> str:
return str(self._entries[index]["class_id"])
def get_class_ids(self) -> np.ndarray:
return self._entries["class_id"]
def __getitem__(self, index: int) -> Tuple[Any, Any]:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
return super().__getitem__(index)
def __len__(self) -> int:
return len(self._entries)
def _dump_entries(self, *args, **kwargs) -> None:
entries, class_ids = self._load_entries_class_ids(*args, **kwargs)
max_class_id_length, max_filename_length, max_class_index = -1, -1, -1
for entry in entries:
class_id = class_ids[entry.class_index]
max_class_index = max(entry.class_index, max_class_index)
max_class_id_length = max(len(class_id), max_class_id_length)
max_filename_length = max(len(entry.filename), max_filename_length)
dtype = np.dtype(
[
("class_index", "<u4"),
("class_id", f"U{max_class_id_length}"),
("start_offset", "<u4"),
("end_offset", "<u4"),
("filename", f"U{max_filename_length}"),
]
)
sample_count = len(entries)
entries_array = np.empty(sample_count, dtype=dtype)
for i, entry in enumerate(entries):
class_index = entry.class_index
class_id = class_ids[class_index]
start_offset = entry.start_offset
end_offset = entry.end_offset
filename = entry.filename
entries_array[i] = (
class_index,
class_id,
start_offset,
end_offset,
filename,
)
entries_path = self._get_entries_path(*args, **kwargs)
self._save_extra(entries_array, entries_path)
def _dump_class_ids(self, *args, **kwargs) -> None:
entries_path = self._get_entries_path(*args, **kwargs)
entries_array = self._load_extra(entries_path)
max_class_id_length, max_class_index = -1, -1
for entry in entries_array:
class_index, class_id = entry["class_index"], entry["class_id"]
max_class_index = max(int(class_index), max_class_index)
max_class_id_length = max(len(str(class_id)), max_class_id_length)
class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}")
for entry in entries_array:
class_index, class_id = entry["class_index"], entry["class_id"]
class_ids_array[class_index] = class_id
class_ids_path = self._get_class_ids_path(*args, **kwargs)
self._save_extra(class_ids_array, class_ids_path)
def _dump_extra(self, *args, **kwargs) -> None:
self._dump_entries(*args, *kwargs)
self._dump_class_ids(*args, *kwargs)
def dump_extra(self, root: Optional[str] = None) -> None:
return self._dump_extra(root)

View File

@@ -0,0 +1,223 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from enum import Enum
from typing import Any, Callable, List, Optional, TypeVar
import torch
from torch.utils.data import Sampler
from .datasets import ImageNet, ImageNet22k
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
logger = logging.getLogger("dinov2")
class SamplerType(Enum):
DISTRIBUTED = 0
EPOCH = 1
INFINITE = 2
SHARDED_INFINITE = 3
SHARDED_INFINITE_NEW = 4
def _make_bool_str(b: bool) -> str:
return "yes" if b else "no"
def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
def transform(sample):
image, target = sample
if image_transform is not None:
image = image_transform(image)
if target_transform is not None:
target = target_transform(target)
return image, target
return transform
def _parse_dataset_str(dataset_str: str):
tokens = dataset_str.split(":")
name = tokens[0]
kwargs = {}
for token in tokens[1:]:
key, value = token.split("=")
assert key in ("root", "extra", "split")
kwargs[key] = value
if name == "ImageNet":
class_ = ImageNet
if "split" in kwargs:
kwargs["split"] = ImageNet.Split[kwargs["split"]]
elif name == "ImageNet22k":
class_ = ImageNet22k
else:
raise ValueError(f'Unsupported dataset "{name}"')
return class_, kwargs
def make_dataset(
*,
dataset_str: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
):
"""
Creates a dataset with the specified parameters.
Args:
dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
transform: A transform to apply to images.
target_transform: A transform to apply to targets.
Returns:
The created dataset.
"""
logger.info(f'using dataset: "{dataset_str}"')
class_, kwargs = _parse_dataset_str(dataset_str)
dataset = class_(transform=transform, target_transform=target_transform, **kwargs)
logger.info(f"# of dataset samples: {len(dataset):,d}")
# Aggregated datasets do not expose (yet) these attributes, so add them.
if not hasattr(dataset, "transform"):
setattr(dataset, "transform", transform)
if not hasattr(dataset, "target_transform"):
setattr(dataset, "target_transform", target_transform)
return dataset
def _make_sampler(
*,
dataset,
type: Optional[SamplerType] = None,
shuffle: bool = False,
seed: int = 0,
size: int = -1,
advance: int = 0,
) -> Optional[Sampler]:
sample_count = len(dataset)
if type == SamplerType.INFINITE:
logger.info("sampler: infinite")
if size > 0:
raise ValueError("sampler size > 0 is invalid")
return InfiniteSampler(
sample_count=sample_count,
shuffle=shuffle,
seed=seed,
advance=advance,
)
elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
logger.info("sampler: sharded infinite")
if size > 0:
raise ValueError("sampler size > 0 is invalid")
# TODO: Remove support for old shuffling
use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW
return ShardedInfiniteSampler(
sample_count=sample_count,
shuffle=shuffle,
seed=seed,
advance=advance,
use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice,
)
elif type == SamplerType.EPOCH:
logger.info("sampler: epoch")
if advance > 0:
raise NotImplementedError("sampler advance > 0 is not supported")
size = size if size > 0 else sample_count
logger.info(f"# of samples / epoch: {size:,d}")
return EpochSampler(
size=size,
sample_count=sample_count,
shuffle=shuffle,
seed=seed,
)
elif type == SamplerType.DISTRIBUTED:
logger.info("sampler: distributed")
if size > 0:
raise ValueError("sampler size > 0 is invalid")
if advance > 0:
raise ValueError("sampler advance > 0 is invalid")
return torch.utils.data.DistributedSampler(
dataset=dataset,
shuffle=shuffle,
seed=seed,
drop_last=False,
)
logger.info("sampler: none")
return None
T = TypeVar("T")
def make_data_loader(
*,
dataset,
batch_size: int,
num_workers: int,
shuffle: bool = True,
seed: int = 0,
sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
sampler_size: int = -1,
sampler_advance: int = 0,
drop_last: bool = True,
persistent_workers: bool = False,
collate_fn: Optional[Callable[[List[T]], Any]] = None,
):
"""
Creates a data loader with the specified parameters.
Args:
dataset: A dataset (third party, LaViDa or WebDataset).
batch_size: The size of batches to generate.
num_workers: The number of workers to use.
shuffle: Whether to shuffle samples.
seed: The random seed to use.
sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
sampler_advance: How many samples to skip (when applicable).
drop_last: Whether the last non-full batch of data should be dropped.
persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once.
collate_fn: Function that performs batch collation
"""
sampler = _make_sampler(
dataset=dataset,
type=sampler_type,
shuffle=shuffle,
seed=seed,
size=sampler_size,
advance=sampler_advance,
)
logger.info("using PyTorch data loader")
data_loader = torch.utils.data.DataLoader(
dataset,
sampler=sampler,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=True,
drop_last=drop_last,
persistent_workers=persistent_workers,
collate_fn=collate_fn,
)
try:
logger.info(f"# of batches: {len(data_loader):,d}")
except TypeError: # data loader has no length
logger.info("infinite data loader")
return data_loader

View File

@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import math
import numpy as np
class MaskingGenerator:
def __init__(
self,
input_size,
num_masking_patches=None,
min_num_patches=4,
max_num_patches=None,
min_aspect=0.3,
max_aspect=None,
):
if not isinstance(input_size, tuple):
input_size = (input_size,) * 2
self.height, self.width = input_size
self.num_patches = self.height * self.width
self.num_masking_patches = num_masking_patches
self.min_num_patches = min_num_patches
self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
max_aspect = max_aspect or 1 / min_aspect
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
def __repr__(self):
repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
self.height,
self.width,
self.min_num_patches,
self.max_num_patches,
self.num_masking_patches,
self.log_aspect_ratio[0],
self.log_aspect_ratio[1],
)
return repr_str
def get_shape(self):
return self.height, self.width
def _mask(self, mask, max_mask_patches):
delta = 0
for _ in range(10):
target_area = random.uniform(self.min_num_patches, max_mask_patches)
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < self.width and h < self.height:
top = random.randint(0, self.height - h)
left = random.randint(0, self.width - w)
num_masked = mask[top : top + h, left : left + w].sum()
# Overlap
if 0 < h * w - num_masked <= max_mask_patches:
for i in range(top, top + h):
for j in range(left, left + w):
if mask[i, j] == 0:
mask[i, j] = 1
delta += 1
if delta > 0:
break
return delta
def __call__(self, num_masking_patches=0):
mask = np.zeros(shape=self.get_shape(), dtype=bool)
mask_count = 0
while mask_count < num_masking_patches:
max_mask_patches = num_masking_patches - mask_count
max_mask_patches = min(max_mask_patches, self.max_num_patches)
delta = self._mask(mask, max_mask_patches)
if delta == 0:
break
else:
mask_count += delta
return mask

View File

@@ -0,0 +1,230 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import itertools
from typing import Any, Optional
import warnings
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
import dinov2.distributed as distributed
class EpochSampler(Sampler):
def __init__(
self,
*,
size: int,
sample_count: int,
shuffle: bool = False,
seed: int = 0,
start: Optional[int] = None,
step: Optional[int] = None,
):
self._size = size
self._sample_count = sample_count
self._shuffle = shuffle
self._seed = seed
self._start = distributed.get_global_rank() if start is None else start
self._step = distributed.get_global_size() if step is None else step
self._epoch = 0
def __iter__(self):
count = (self._size + self._sample_count - 1) // self._sample_count
tiled_indices = np.tile(np.arange(self._sample_count), count)
if self._shuffle:
seed = self._seed * self._epoch if self._seed != 0 else self._epoch
rng = np.random.default_rng(seed)
iterable = rng.choice(tiled_indices, self._size, replace=False)
else:
iterable = tiled_indices[: self._size]
yield from itertools.islice(iterable, self._start, None, self._step)
def __len__(self):
return (self._size - self._start + self._step - 1) // self._step
def set_epoch(self, epoch):
self._epoch = epoch
def _get_numpy_dtype(size: int) -> Any:
return np.int32 if size <= 2**31 else np.int64
def _get_torch_dtype(size: int) -> Any:
return torch.int32 if size <= 2**31 else torch.int64
def _generate_randperm_indices(*, size: int, generator: torch.Generator):
"""Generate the indices of a random permutation."""
dtype = _get_torch_dtype(size)
# This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
perm = torch.arange(size, dtype=dtype)
for i in range(size):
j = torch.randint(i, size, size=(1,), generator=generator).item()
# Always swap even if no-op
value = perm[j].item()
perm[j] = perm[i].item()
perm[i] = value
yield value
class InfiniteSampler(Sampler):
def __init__(
self,
*,
sample_count: int,
shuffle: bool = False,
seed: int = 0,
start: Optional[int] = None,
step: Optional[int] = None,
advance: int = 0,
):
self._sample_count = sample_count
self._seed = seed
self._shuffle = shuffle
self._start = distributed.get_global_rank() if start is None else start
self._step = distributed.get_global_size() if step is None else step
self._advance = advance
def __iter__(self):
if self._shuffle:
iterator = self._shuffled_iterator()
else:
iterator = self._iterator()
yield from itertools.islice(iterator, self._advance, None)
def _iterator(self):
assert not self._shuffle
while True:
iterable = range(self._sample_count)
yield from itertools.islice(iterable, self._start, None, self._step)
def _shuffled_iterator(self):
assert self._shuffle
# Instantiate a generator here (rather than in the ctor) to keep the class
# picklable (requirement of mp.spawn)
generator = torch.Generator().manual_seed(self._seed)
while True:
iterable = _generate_randperm_indices(size=self._sample_count, generator=generator)
yield from itertools.islice(iterable, self._start, None, self._step)
# The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
# but avoids a full in-place random permutation generation.
def _shuffle_tensor_slice(
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
stop = len(tensor)
count = stop // step
drop_count = stop - step * count
if drop_count:
warnings.warn(f"# of dropped samples: {drop_count}")
dtype = _get_numpy_dtype(stop)
result = np.empty(count, dtype=dtype)
for i in range(count):
j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0
result[i] = result[j]
result[j] = tensor[start + i * step].item()
return result
def _new_shuffle_tensor_slice(
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
) -> np.ndarray:
stop = len(tensor)
count = stop // step
dtype = torch.int64 # Needed for using randperm result as indices
count = stop // step
drop_count = stop - step * count
if drop_count:
warnings.warn(f"# of dropped samples: {drop_count}")
indices = torch.randperm(count, dtype=dtype, generator=generator)
return tensor[start::step][indices].numpy()
def _make_seed(seed: int, start: int, iter_count: int) -> int:
# NOTE: Tried a few variants (including iter_count << 32), this one worked best.
return seed + start + (iter_count << 24)
class ShardedInfiniteSampler(Sampler):
def __init__(
self,
*,
sample_count: int,
shuffle: bool = False,
seed: int = 0,
start: Optional[int] = None,
step: Optional[int] = None,
advance: int = 0,
use_new_shuffle_tensor_slice: bool = False,
):
self._sample_count = sample_count
self._seed = seed
self._shuffle = shuffle
self._start = distributed.get_global_rank() if start is None else start
self._step = distributed.get_global_size() if step is None else step
self._advance = advance
self._iter_count = 0
self._shuffle_tensor_slice_fn = (
_new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
)
def __iter__(self):
iter_count = self._advance // self._sample_count
if iter_count > 0:
self._advance -= iter_count * self._sample_count
self._iter_count += iter_count
if self._shuffle:
iterator = self._shuffled_iterator()
else:
iterator = self._iterator()
yield from itertools.islice(iterator, self._advance, None)
def _iterator(self):
assert not self._shuffle
while True:
iterable = range(self._sample_count)
yield from itertools.islice(iterable, self._start, None, self._step)
def _shuffled_iterator(self):
assert self._shuffle
# Instantiate a generator here (rather than in the ctor) to be keep the class
# picklable (requirement of mp.spawn)
generator = torch.Generator()
# Always shuffle everything first
generator.manual_seed(self._seed)
dtype = _get_torch_dtype(self._sample_count)
perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)
while True:
# Re-seed on each iteration to allow skipping whole permutations
seed = _make_seed(self._seed, self._start, self._iter_count)
generator.manual_seed(seed)
iterable = self._shuffle_tensor_slice_fn(
tensor=perm, start=self._start, step=self._step, generator=generator
)
yield from iterable
self._iter_count += 1

View File

@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Sequence
import torch
from torchvision import transforms
class GaussianBlur(transforms.RandomApply):
"""
Apply Gaussian Blur to the PIL image.
"""
def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
# NOTE: torchvision is applying 1 - probability to return the original image
keep_p = 1 - p
transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
super().__init__(transforms=[transform], p=keep_p)
class MaybeToTensor(transforms.ToTensor):
"""
Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor.
"""
def __call__(self, pic):
"""
Args:
pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
Returns:
Tensor: Converted image.
"""
if isinstance(pic, torch.Tensor):
return pic
return super().__call__(pic)
# Use timm's names
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
def make_normalize_transform(
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Normalize:
return transforms.Normalize(mean=mean, std=std)
# This roughly matches torchvision's preset for classification training:
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
def make_classification_train_transform(
*,
crop_size: int = 224,
interpolation=transforms.InterpolationMode.BICUBIC,
hflip_prob: float = 0.5,
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
std: Sequence[float] = IMAGENET_DEFAULT_STD,
):
transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
if hflip_prob > 0.0:
transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
transforms_list.extend(
[
MaybeToTensor(),
make_normalize_transform(mean=mean, std=std),
]
)
return transforms.Compose(transforms_list)
# This matches (roughly) torchvision's preset for classification evaluation:
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
def make_classification_eval_transform(
*,
resize_size: int = 256,
interpolation=transforms.InterpolationMode.BICUBIC,
crop_size: int = 224,
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
std: Sequence[float] = IMAGENET_DEFAULT_STD,
) -> transforms.Compose:
transforms_list = [
transforms.Resize(resize_size, interpolation=interpolation),
transforms.CenterCrop(crop_size),
MaybeToTensor(),
make_normalize_transform(mean=mean, std=std),
]
return transforms.Compose(transforms_list)