Initial media depth project backup
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .adapters import DatasetWithEnumeratedTargets
|
||||
from .loaders import make_data_loader, make_dataset, SamplerType
|
||||
from .collate import collate_data_and_cast
|
||||
from .masking import MaskingGenerator
|
||||
from .augmentations import DataAugmentationDINO
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Tuple
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class DatasetWithEnumeratedTargets(Dataset):
|
||||
def __init__(self, dataset):
|
||||
self._dataset = dataset
|
||||
|
||||
def get_image_data(self, index: int) -> bytes:
|
||||
return self._dataset.get_image_data(index)
|
||||
|
||||
def get_target(self, index: int) -> Tuple[Any, int]:
|
||||
target = self._dataset.get_target(index)
|
||||
return (index, target)
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]:
|
||||
image, target = self._dataset[index]
|
||||
target = index if target is None else target
|
||||
return image, (index, target)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._dataset)
|
||||
@@ -0,0 +1,119 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
from torchvision import transforms
|
||||
|
||||
from .transforms import (
|
||||
GaussianBlur,
|
||||
make_normalize_transform,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class DataAugmentationDINO(object):
|
||||
def __init__(
|
||||
self,
|
||||
global_crops_scale,
|
||||
local_crops_scale,
|
||||
local_crops_number,
|
||||
global_crops_size=224,
|
||||
local_crops_size=96,
|
||||
):
|
||||
self.global_crops_scale = global_crops_scale
|
||||
self.local_crops_scale = local_crops_scale
|
||||
self.local_crops_number = local_crops_number
|
||||
self.global_crops_size = global_crops_size
|
||||
self.local_crops_size = local_crops_size
|
||||
|
||||
logger.info("###################################")
|
||||
logger.info("Using data augmentation parameters:")
|
||||
logger.info(f"global_crops_scale: {global_crops_scale}")
|
||||
logger.info(f"local_crops_scale: {local_crops_scale}")
|
||||
logger.info(f"local_crops_number: {local_crops_number}")
|
||||
logger.info(f"global_crops_size: {global_crops_size}")
|
||||
logger.info(f"local_crops_size: {local_crops_size}")
|
||||
logger.info("###################################")
|
||||
|
||||
# random resized crop and flip
|
||||
self.geometric_augmentation_global = transforms.Compose(
|
||||
[
|
||||
transforms.RandomResizedCrop(
|
||||
global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
|
||||
),
|
||||
transforms.RandomHorizontalFlip(p=0.5),
|
||||
]
|
||||
)
|
||||
|
||||
self.geometric_augmentation_local = transforms.Compose(
|
||||
[
|
||||
transforms.RandomResizedCrop(
|
||||
local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
|
||||
),
|
||||
transforms.RandomHorizontalFlip(p=0.5),
|
||||
]
|
||||
)
|
||||
|
||||
# color distorsions / blurring
|
||||
color_jittering = transforms.Compose(
|
||||
[
|
||||
transforms.RandomApply(
|
||||
[transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
|
||||
p=0.8,
|
||||
),
|
||||
transforms.RandomGrayscale(p=0.2),
|
||||
]
|
||||
)
|
||||
|
||||
global_transfo1_extra = GaussianBlur(p=1.0)
|
||||
|
||||
global_transfo2_extra = transforms.Compose(
|
||||
[
|
||||
GaussianBlur(p=0.1),
|
||||
transforms.RandomSolarize(threshold=128, p=0.2),
|
||||
]
|
||||
)
|
||||
|
||||
local_transfo_extra = GaussianBlur(p=0.5)
|
||||
|
||||
# normalization
|
||||
self.normalize = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
make_normalize_transform(),
|
||||
]
|
||||
)
|
||||
|
||||
self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
|
||||
self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
|
||||
self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])
|
||||
|
||||
def __call__(self, image):
|
||||
output = {}
|
||||
|
||||
# global crops:
|
||||
im1_base = self.geometric_augmentation_global(image)
|
||||
global_crop_1 = self.global_transfo1(im1_base)
|
||||
|
||||
im2_base = self.geometric_augmentation_global(image)
|
||||
global_crop_2 = self.global_transfo2(im2_base)
|
||||
|
||||
output["global_crops"] = [global_crop_1, global_crop_2]
|
||||
|
||||
# global crops for teacher:
|
||||
output["global_crops_teacher"] = [global_crop_1, global_crop_2]
|
||||
|
||||
# local crops:
|
||||
local_crops = [
|
||||
self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
|
||||
]
|
||||
output["local_crops"] = local_crops
|
||||
output["offsets"] = ()
|
||||
|
||||
return output
|
||||
@@ -0,0 +1,50 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import random
|
||||
|
||||
|
||||
def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
|
||||
# dtype = torch.half # TODO: Remove
|
||||
|
||||
n_global_crops = len(samples_list[0][0]["global_crops"])
|
||||
n_local_crops = len(samples_list[0][0]["local_crops"])
|
||||
|
||||
collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list])
|
||||
|
||||
collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list])
|
||||
|
||||
B = len(collated_global_crops)
|
||||
N = n_tokens
|
||||
n_samples_masked = int(B * mask_probability)
|
||||
probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
|
||||
upperbound = 0
|
||||
masks_list = []
|
||||
for i in range(0, n_samples_masked):
|
||||
prob_min = probs[i]
|
||||
prob_max = probs[i + 1]
|
||||
masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
|
||||
upperbound += int(N * prob_max)
|
||||
for i in range(n_samples_masked, B):
|
||||
masks_list.append(torch.BoolTensor(mask_generator(0)))
|
||||
|
||||
random.shuffle(masks_list)
|
||||
|
||||
collated_masks = torch.stack(masks_list).flatten(1)
|
||||
mask_indices_list = collated_masks.flatten().nonzero().flatten()
|
||||
|
||||
masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]
|
||||
|
||||
return {
|
||||
"collated_global_crops": collated_global_crops.to(dtype),
|
||||
"collated_local_crops": collated_local_crops.to(dtype),
|
||||
"collated_masks": collated_masks,
|
||||
"mask_indices_list": mask_indices_list,
|
||||
"masks_weight": masks_weight,
|
||||
"upperbound": upperbound,
|
||||
"n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .image_net import ImageNet
|
||||
from .image_net_22k import ImageNet22k
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Decoder:
|
||||
def decode(self) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ImageDataDecoder(Decoder):
|
||||
def __init__(self, image_data: bytes) -> None:
|
||||
self._image_data = image_data
|
||||
|
||||
def decode(self) -> Image:
|
||||
f = BytesIO(self._image_data)
|
||||
return Image.open(f).convert(mode="RGB")
|
||||
|
||||
|
||||
class TargetDecoder(Decoder):
|
||||
def __init__(self, target: Any):
|
||||
self._target = target
|
||||
|
||||
def decode(self) -> Any:
|
||||
return self._target
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Tuple
|
||||
|
||||
from torchvision.datasets import VisionDataset
|
||||
|
||||
from .decoders import TargetDecoder, ImageDataDecoder
|
||||
|
||||
|
||||
class ExtendedVisionDataset(VisionDataset):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs) # type: ignore
|
||||
|
||||
def get_image_data(self, index: int) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_target(self, index: int) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
||||
try:
|
||||
image_data = self.get_image_data(index)
|
||||
image = ImageDataDecoder(image_data).decode()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"can not read image for sample {index}") from e
|
||||
target = self.get_target(index)
|
||||
target = TargetDecoder(target).decode()
|
||||
|
||||
if self.transforms is not None:
|
||||
image, target = self.transforms(image, target)
|
||||
|
||||
return image, target
|
||||
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,291 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import csv
|
||||
from enum import Enum
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .extended import ExtendedVisionDataset
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
_Target = int
|
||||
|
||||
|
||||
class _Split(Enum):
|
||||
TRAIN = "train"
|
||||
VAL = "val"
|
||||
TEST = "test" # NOTE: torchvision does not support the test split
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
split_lengths = {
|
||||
_Split.TRAIN: 1_281_167,
|
||||
_Split.VAL: 50_000,
|
||||
_Split.TEST: 100_000,
|
||||
}
|
||||
return split_lengths[self]
|
||||
|
||||
def get_dirname(self, class_id: Optional[str] = None) -> str:
|
||||
return self.value if class_id is None else os.path.join(self.value, class_id)
|
||||
|
||||
def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str:
|
||||
dirname = self.get_dirname(class_id)
|
||||
if self == _Split.TRAIN:
|
||||
basename = f"{class_id}_{actual_index}"
|
||||
else: # self in (_Split.VAL, _Split.TEST):
|
||||
basename = f"ILSVRC2012_{self.value}_{actual_index:08d}"
|
||||
return os.path.join(dirname, basename + ".JPEG")
|
||||
|
||||
def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]:
|
||||
assert self != _Split.TEST
|
||||
dirname, filename = os.path.split(image_relpath)
|
||||
class_id = os.path.split(dirname)[-1]
|
||||
basename, _ = os.path.splitext(filename)
|
||||
actual_index = int(basename.split("_")[-1])
|
||||
return class_id, actual_index
|
||||
|
||||
|
||||
class ImageNet(ExtendedVisionDataset):
|
||||
Target = Union[_Target]
|
||||
Split = Union[_Split]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
split: "ImageNet.Split",
|
||||
root: str,
|
||||
extra: str,
|
||||
transforms: Optional[Callable] = None,
|
||||
transform: Optional[Callable] = None,
|
||||
target_transform: Optional[Callable] = None,
|
||||
) -> None:
|
||||
super().__init__(root, transforms, transform, target_transform)
|
||||
self._extra_root = extra
|
||||
self._split = split
|
||||
|
||||
self._entries = None
|
||||
self._class_ids = None
|
||||
self._class_names = None
|
||||
|
||||
@property
|
||||
def split(self) -> "ImageNet.Split":
|
||||
return self._split
|
||||
|
||||
def _get_extra_full_path(self, extra_path: str) -> str:
|
||||
return os.path.join(self._extra_root, extra_path)
|
||||
|
||||
def _load_extra(self, extra_path: str) -> np.ndarray:
|
||||
extra_full_path = self._get_extra_full_path(extra_path)
|
||||
return np.load(extra_full_path, mmap_mode="r")
|
||||
|
||||
def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
|
||||
extra_full_path = self._get_extra_full_path(extra_path)
|
||||
os.makedirs(self._extra_root, exist_ok=True)
|
||||
np.save(extra_full_path, extra_array)
|
||||
|
||||
@property
|
||||
def _entries_path(self) -> str:
|
||||
return f"entries-{self._split.value.upper()}.npy"
|
||||
|
||||
@property
|
||||
def _class_ids_path(self) -> str:
|
||||
return f"class-ids-{self._split.value.upper()}.npy"
|
||||
|
||||
@property
|
||||
def _class_names_path(self) -> str:
|
||||
return f"class-names-{self._split.value.upper()}.npy"
|
||||
|
||||
def _get_entries(self) -> np.ndarray:
|
||||
if self._entries is None:
|
||||
self._entries = self._load_extra(self._entries_path)
|
||||
assert self._entries is not None
|
||||
return self._entries
|
||||
|
||||
def _get_class_ids(self) -> np.ndarray:
|
||||
if self._split == _Split.TEST:
|
||||
assert False, "Class IDs are not available in TEST split"
|
||||
if self._class_ids is None:
|
||||
self._class_ids = self._load_extra(self._class_ids_path)
|
||||
assert self._class_ids is not None
|
||||
return self._class_ids
|
||||
|
||||
def _get_class_names(self) -> np.ndarray:
|
||||
if self._split == _Split.TEST:
|
||||
assert False, "Class names are not available in TEST split"
|
||||
if self._class_names is None:
|
||||
self._class_names = self._load_extra(self._class_names_path)
|
||||
assert self._class_names is not None
|
||||
return self._class_names
|
||||
|
||||
def find_class_id(self, class_index: int) -> str:
|
||||
class_ids = self._get_class_ids()
|
||||
return str(class_ids[class_index])
|
||||
|
||||
def find_class_name(self, class_index: int) -> str:
|
||||
class_names = self._get_class_names()
|
||||
return str(class_names[class_index])
|
||||
|
||||
def get_image_data(self, index: int) -> bytes:
|
||||
entries = self._get_entries()
|
||||
actual_index = entries[index]["actual_index"]
|
||||
|
||||
class_id = self.get_class_id(index)
|
||||
|
||||
image_relpath = self.split.get_image_relpath(actual_index, class_id)
|
||||
image_full_path = os.path.join(self.root, image_relpath)
|
||||
with open(image_full_path, mode="rb") as f:
|
||||
image_data = f.read()
|
||||
return image_data
|
||||
|
||||
def get_target(self, index: int) -> Optional[Target]:
|
||||
entries = self._get_entries()
|
||||
class_index = entries[index]["class_index"]
|
||||
return None if self.split == _Split.TEST else int(class_index)
|
||||
|
||||
def get_targets(self) -> Optional[np.ndarray]:
|
||||
entries = self._get_entries()
|
||||
return None if self.split == _Split.TEST else entries["class_index"]
|
||||
|
||||
def get_class_id(self, index: int) -> Optional[str]:
|
||||
entries = self._get_entries()
|
||||
class_id = entries[index]["class_id"]
|
||||
return None if self.split == _Split.TEST else str(class_id)
|
||||
|
||||
def get_class_name(self, index: int) -> Optional[str]:
|
||||
entries = self._get_entries()
|
||||
class_name = entries[index]["class_name"]
|
||||
return None if self.split == _Split.TEST else str(class_name)
|
||||
|
||||
def __len__(self) -> int:
|
||||
entries = self._get_entries()
|
||||
assert len(entries) == self.split.length
|
||||
return len(entries)
|
||||
|
||||
def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
|
||||
labels_full_path = os.path.join(self.root, labels_path)
|
||||
labels = []
|
||||
|
||||
try:
|
||||
with open(labels_full_path, "r") as f:
|
||||
reader = csv.reader(f)
|
||||
for row in reader:
|
||||
class_id, class_name = row
|
||||
labels.append((class_id, class_name))
|
||||
except OSError as e:
|
||||
raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e
|
||||
|
||||
return labels
|
||||
|
||||
def _dump_entries(self) -> None:
|
||||
split = self.split
|
||||
if split == ImageNet.Split.TEST:
|
||||
dataset = None
|
||||
sample_count = split.length
|
||||
max_class_id_length, max_class_name_length = 0, 0
|
||||
else:
|
||||
labels_path = "labels.txt"
|
||||
logger.info(f'loading labels from "{labels_path}"')
|
||||
labels = self._load_labels(labels_path)
|
||||
|
||||
# NOTE: Using torchvision ImageFolder for consistency
|
||||
from torchvision.datasets import ImageFolder
|
||||
|
||||
dataset_root = os.path.join(self.root, split.get_dirname())
|
||||
dataset = ImageFolder(dataset_root)
|
||||
sample_count = len(dataset)
|
||||
max_class_id_length, max_class_name_length = -1, -1
|
||||
for sample in dataset.samples:
|
||||
_, class_index = sample
|
||||
class_id, class_name = labels[class_index]
|
||||
max_class_id_length = max(len(class_id), max_class_id_length)
|
||||
max_class_name_length = max(len(class_name), max_class_name_length)
|
||||
|
||||
dtype = np.dtype(
|
||||
[
|
||||
("actual_index", "<u4"),
|
||||
("class_index", "<u4"),
|
||||
("class_id", f"U{max_class_id_length}"),
|
||||
("class_name", f"U{max_class_name_length}"),
|
||||
]
|
||||
)
|
||||
entries_array = np.empty(sample_count, dtype=dtype)
|
||||
|
||||
if split == ImageNet.Split.TEST:
|
||||
old_percent = -1
|
||||
for index in range(sample_count):
|
||||
percent = 100 * (index + 1) // sample_count
|
||||
if percent > old_percent:
|
||||
logger.info(f"creating entries: {percent}%")
|
||||
old_percent = percent
|
||||
|
||||
actual_index = index + 1
|
||||
class_index = np.uint32(-1)
|
||||
class_id, class_name = "", ""
|
||||
entries_array[index] = (actual_index, class_index, class_id, class_name)
|
||||
else:
|
||||
class_names = {class_id: class_name for class_id, class_name in labels}
|
||||
|
||||
assert dataset
|
||||
old_percent = -1
|
||||
for index in range(sample_count):
|
||||
percent = 100 * (index + 1) // sample_count
|
||||
if percent > old_percent:
|
||||
logger.info(f"creating entries: {percent}%")
|
||||
old_percent = percent
|
||||
|
||||
image_full_path, class_index = dataset.samples[index]
|
||||
image_relpath = os.path.relpath(image_full_path, self.root)
|
||||
class_id, actual_index = split.parse_image_relpath(image_relpath)
|
||||
class_name = class_names[class_id]
|
||||
entries_array[index] = (actual_index, class_index, class_id, class_name)
|
||||
|
||||
logger.info(f'saving entries to "{self._entries_path}"')
|
||||
self._save_extra(entries_array, self._entries_path)
|
||||
|
||||
def _dump_class_ids_and_names(self) -> None:
|
||||
split = self.split
|
||||
if split == ImageNet.Split.TEST:
|
||||
return
|
||||
|
||||
entries_array = self._load_extra(self._entries_path)
|
||||
|
||||
max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1
|
||||
for entry in entries_array:
|
||||
class_index, class_id, class_name = (
|
||||
entry["class_index"],
|
||||
entry["class_id"],
|
||||
entry["class_name"],
|
||||
)
|
||||
max_class_index = max(int(class_index), max_class_index)
|
||||
max_class_id_length = max(len(str(class_id)), max_class_id_length)
|
||||
max_class_name_length = max(len(str(class_name)), max_class_name_length)
|
||||
|
||||
class_count = max_class_index + 1
|
||||
class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}")
|
||||
class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}")
|
||||
for entry in entries_array:
|
||||
class_index, class_id, class_name = (
|
||||
entry["class_index"],
|
||||
entry["class_id"],
|
||||
entry["class_name"],
|
||||
)
|
||||
class_ids_array[class_index] = class_id
|
||||
class_names_array[class_index] = class_name
|
||||
|
||||
logger.info(f'saving class IDs to "{self._class_ids_path}"')
|
||||
self._save_extra(class_ids_array, self._class_ids_path)
|
||||
|
||||
logger.info(f'saving class names to "{self._class_names_path}"')
|
||||
self._save_extra(class_names_array, self._class_names_path)
|
||||
|
||||
def dump_extra(self) -> None:
|
||||
self._dump_entries()
|
||||
self._dump_class_ids_and_names()
|
||||
@@ -0,0 +1,303 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from gzip import GzipFile
|
||||
from io import BytesIO
|
||||
from mmap import ACCESS_READ, mmap
|
||||
import os
|
||||
from typing import Any, Callable, List, Optional, Set, Tuple
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .extended import ExtendedVisionDataset
|
||||
|
||||
|
||||
_Labels = int
|
||||
|
||||
_DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ClassEntry:
|
||||
block_offset: int
|
||||
maybe_filename: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Entry:
|
||||
class_index: int # noqa: E701
|
||||
start_offset: int
|
||||
end_offset: int
|
||||
filename: str
|
||||
|
||||
|
||||
class _Split(Enum):
|
||||
TRAIN = "train"
|
||||
VAL = "val"
|
||||
|
||||
@property
|
||||
def length(self) -> int:
|
||||
return {
|
||||
_Split.TRAIN: 11_797_647,
|
||||
_Split.VAL: 561_050,
|
||||
}[self]
|
||||
|
||||
def entries_path(self):
|
||||
return f"imagenet21kp_{self.value}.txt"
|
||||
|
||||
|
||||
def _get_tarball_path(class_id: str) -> str:
|
||||
return f"{class_id}.tar"
|
||||
|
||||
|
||||
def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
|
||||
@lru_cache(maxsize=mmap_cache_size)
|
||||
def _mmap_tarball(class_id: str) -> mmap:
|
||||
tarball_path = _get_tarball_path(class_id)
|
||||
tarball_full_path = os.path.join(tarballs_root, tarball_path)
|
||||
with open(tarball_full_path) as f:
|
||||
return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)
|
||||
|
||||
return _mmap_tarball
|
||||
|
||||
|
||||
class ImageNet22k(ExtendedVisionDataset):
|
||||
_GZIPPED_INDICES: Set[int] = {
|
||||
841_545,
|
||||
1_304_131,
|
||||
2_437_921,
|
||||
2_672_079,
|
||||
2_795_676,
|
||||
2_969_786,
|
||||
6_902_965,
|
||||
6_903_550,
|
||||
6_903_628,
|
||||
7_432_557,
|
||||
7_432_589,
|
||||
7_813_809,
|
||||
8_329_633,
|
||||
10_296_990,
|
||||
10_417_652,
|
||||
10_492_265,
|
||||
10_598_078,
|
||||
10_782_398,
|
||||
10_902_612,
|
||||
11_203_736,
|
||||
11_342_890,
|
||||
11_397_596,
|
||||
11_589_762,
|
||||
11_705_103,
|
||||
12_936_875,
|
||||
13_289_782,
|
||||
}
|
||||
Labels = _Labels
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
root: str,
|
||||
extra: str,
|
||||
transforms: Optional[Callable] = None,
|
||||
transform: Optional[Callable] = None,
|
||||
target_transform: Optional[Callable] = None,
|
||||
mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
|
||||
) -> None:
|
||||
super().__init__(root, transforms, transform, target_transform)
|
||||
self._extra_root = extra
|
||||
|
||||
entries_path = self._get_entries_path(root)
|
||||
self._entries = self._load_extra(entries_path)
|
||||
|
||||
class_ids_path = self._get_class_ids_path(root)
|
||||
self._class_ids = self._load_extra(class_ids_path)
|
||||
|
||||
self._gzipped_indices = ImageNet22k._GZIPPED_INDICES
|
||||
self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)
|
||||
|
||||
def _get_entries_path(self, root: Optional[str] = None) -> str:
|
||||
return "entries.npy"
|
||||
|
||||
def _get_class_ids_path(self, root: Optional[str] = None) -> str:
|
||||
return "class-ids.npy"
|
||||
|
||||
def _find_class_ids(self, path: str) -> List[str]:
|
||||
class_ids = []
|
||||
|
||||
with os.scandir(path) as entries:
|
||||
for entry in entries:
|
||||
root, ext = os.path.splitext(entry.name)
|
||||
if ext != ".tar":
|
||||
continue
|
||||
class_ids.append(root)
|
||||
|
||||
return sorted(class_ids)
|
||||
|
||||
def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]:
|
||||
root = self.get_root(root)
|
||||
entries: List[_Entry] = []
|
||||
class_ids = self._find_class_ids(root)
|
||||
|
||||
for class_index, class_id in enumerate(class_ids):
|
||||
path = os.path.join(root, "blocks", f"{class_id}.log")
|
||||
class_entries = []
|
||||
|
||||
try:
|
||||
with open(path) as f:
|
||||
for line in f:
|
||||
line = line.rstrip()
|
||||
block, filename = line.split(":")
|
||||
block_offset = int(block[6:])
|
||||
filename = filename[1:]
|
||||
|
||||
maybe_filename = None
|
||||
if filename != "** Block of NULs **":
|
||||
maybe_filename = filename
|
||||
_, ext = os.path.splitext(filename)
|
||||
# assert ext == ".JPEG"
|
||||
|
||||
class_entry = _ClassEntry(block_offset, maybe_filename)
|
||||
class_entries.append(class_entry)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f'can not read blocks file "{path}"') from e
|
||||
|
||||
assert class_entries[-1].maybe_filename is None
|
||||
|
||||
for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]):
|
||||
assert class_entry1.block_offset <= class_entry2.block_offset
|
||||
start_offset = 512 * class_entry1.block_offset
|
||||
end_offset = 512 * class_entry2.block_offset
|
||||
assert class_entry1.maybe_filename is not None
|
||||
filename = class_entry1.maybe_filename
|
||||
entry = _Entry(class_index, start_offset, end_offset, filename)
|
||||
# Skip invalid image files (PIL throws UnidentifiedImageError)
|
||||
if filename == "n06470073_47249.JPEG":
|
||||
continue
|
||||
entries.append(entry)
|
||||
|
||||
return entries, class_ids
|
||||
|
||||
def _load_extra(self, extra_path: str) -> np.ndarray:
|
||||
extra_root = self._extra_root
|
||||
extra_full_path = os.path.join(extra_root, extra_path)
|
||||
return np.load(extra_full_path, mmap_mode="r")
|
||||
|
||||
def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
|
||||
extra_root = self._extra_root
|
||||
extra_full_path = os.path.join(extra_root, extra_path)
|
||||
os.makedirs(extra_root, exist_ok=True)
|
||||
np.save(extra_full_path, extra_array)
|
||||
|
||||
@property
|
||||
def _tarballs_root(self) -> str:
|
||||
return self.root
|
||||
|
||||
def find_class_id(self, class_index: int) -> str:
|
||||
return str(self._class_ids[class_index])
|
||||
|
||||
def get_image_data(self, index: int) -> bytes:
|
||||
entry = self._entries[index]
|
||||
class_id = entry["class_id"]
|
||||
class_mmap = self._mmap_tarball(class_id)
|
||||
|
||||
start_offset, end_offset = entry["start_offset"], entry["end_offset"]
|
||||
try:
|
||||
mapped_data = class_mmap[start_offset:end_offset]
|
||||
data = mapped_data[512:] # Skip entry header block
|
||||
|
||||
if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B):
|
||||
assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}"
|
||||
with GzipFile(fileobj=BytesIO(data)) as g:
|
||||
data = g.read()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e
|
||||
|
||||
return data
|
||||
|
||||
def get_target(self, index: int) -> Any:
|
||||
return int(self._entries[index]["class_index"])
|
||||
|
||||
def get_targets(self) -> np.ndarray:
|
||||
return self._entries["class_index"]
|
||||
|
||||
def get_class_id(self, index: int) -> str:
|
||||
return str(self._entries[index]["class_id"])
|
||||
|
||||
def get_class_ids(self) -> np.ndarray:
|
||||
return self._entries["class_id"]
|
||||
|
||||
def __getitem__(self, index: int) -> Tuple[Any, Any]:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
return super().__getitem__(index)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._entries)
|
||||
|
||||
def _dump_entries(self, *args, **kwargs) -> None:
|
||||
entries, class_ids = self._load_entries_class_ids(*args, **kwargs)
|
||||
|
||||
max_class_id_length, max_filename_length, max_class_index = -1, -1, -1
|
||||
for entry in entries:
|
||||
class_id = class_ids[entry.class_index]
|
||||
max_class_index = max(entry.class_index, max_class_index)
|
||||
max_class_id_length = max(len(class_id), max_class_id_length)
|
||||
max_filename_length = max(len(entry.filename), max_filename_length)
|
||||
|
||||
dtype = np.dtype(
|
||||
[
|
||||
("class_index", "<u4"),
|
||||
("class_id", f"U{max_class_id_length}"),
|
||||
("start_offset", "<u4"),
|
||||
("end_offset", "<u4"),
|
||||
("filename", f"U{max_filename_length}"),
|
||||
]
|
||||
)
|
||||
sample_count = len(entries)
|
||||
entries_array = np.empty(sample_count, dtype=dtype)
|
||||
for i, entry in enumerate(entries):
|
||||
class_index = entry.class_index
|
||||
class_id = class_ids[class_index]
|
||||
start_offset = entry.start_offset
|
||||
end_offset = entry.end_offset
|
||||
filename = entry.filename
|
||||
entries_array[i] = (
|
||||
class_index,
|
||||
class_id,
|
||||
start_offset,
|
||||
end_offset,
|
||||
filename,
|
||||
)
|
||||
|
||||
entries_path = self._get_entries_path(*args, **kwargs)
|
||||
self._save_extra(entries_array, entries_path)
|
||||
|
||||
def _dump_class_ids(self, *args, **kwargs) -> None:
|
||||
entries_path = self._get_entries_path(*args, **kwargs)
|
||||
entries_array = self._load_extra(entries_path)
|
||||
|
||||
max_class_id_length, max_class_index = -1, -1
|
||||
for entry in entries_array:
|
||||
class_index, class_id = entry["class_index"], entry["class_id"]
|
||||
max_class_index = max(int(class_index), max_class_index)
|
||||
max_class_id_length = max(len(str(class_id)), max_class_id_length)
|
||||
|
||||
class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}")
|
||||
for entry in entries_array:
|
||||
class_index, class_id = entry["class_index"], entry["class_id"]
|
||||
class_ids_array[class_index] = class_id
|
||||
class_ids_path = self._get_class_ids_path(*args, **kwargs)
|
||||
self._save_extra(class_ids_array, class_ids_path)
|
||||
|
||||
def _dump_extra(self, *args, **kwargs) -> None:
|
||||
self._dump_entries(*args, *kwargs)
|
||||
self._dump_class_ids(*args, *kwargs)
|
||||
|
||||
def dump_extra(self, root: Optional[str] = None) -> None:
|
||||
return self._dump_extra(root)
|
||||
@@ -0,0 +1,223 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from .datasets import ImageNet, ImageNet22k
|
||||
from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class SamplerType(Enum):
|
||||
DISTRIBUTED = 0
|
||||
EPOCH = 1
|
||||
INFINITE = 2
|
||||
SHARDED_INFINITE = 3
|
||||
SHARDED_INFINITE_NEW = 4
|
||||
|
||||
|
||||
def _make_bool_str(b: bool) -> str:
|
||||
return "yes" if b else "no"
|
||||
|
||||
|
||||
def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
|
||||
def transform(sample):
|
||||
image, target = sample
|
||||
if image_transform is not None:
|
||||
image = image_transform(image)
|
||||
if target_transform is not None:
|
||||
target = target_transform(target)
|
||||
return image, target
|
||||
|
||||
return transform
|
||||
|
||||
|
||||
def _parse_dataset_str(dataset_str: str):
|
||||
tokens = dataset_str.split(":")
|
||||
|
||||
name = tokens[0]
|
||||
kwargs = {}
|
||||
|
||||
for token in tokens[1:]:
|
||||
key, value = token.split("=")
|
||||
assert key in ("root", "extra", "split")
|
||||
kwargs[key] = value
|
||||
|
||||
if name == "ImageNet":
|
||||
class_ = ImageNet
|
||||
if "split" in kwargs:
|
||||
kwargs["split"] = ImageNet.Split[kwargs["split"]]
|
||||
elif name == "ImageNet22k":
|
||||
class_ = ImageNet22k
|
||||
else:
|
||||
raise ValueError(f'Unsupported dataset "{name}"')
|
||||
|
||||
return class_, kwargs
|
||||
|
||||
|
||||
def make_dataset(
|
||||
*,
|
||||
dataset_str: str,
|
||||
transform: Optional[Callable] = None,
|
||||
target_transform: Optional[Callable] = None,
|
||||
):
|
||||
"""
|
||||
Creates a dataset with the specified parameters.
|
||||
|
||||
Args:
|
||||
dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
|
||||
transform: A transform to apply to images.
|
||||
target_transform: A transform to apply to targets.
|
||||
|
||||
Returns:
|
||||
The created dataset.
|
||||
"""
|
||||
logger.info(f'using dataset: "{dataset_str}"')
|
||||
|
||||
class_, kwargs = _parse_dataset_str(dataset_str)
|
||||
dataset = class_(transform=transform, target_transform=target_transform, **kwargs)
|
||||
|
||||
logger.info(f"# of dataset samples: {len(dataset):,d}")
|
||||
|
||||
# Aggregated datasets do not expose (yet) these attributes, so add them.
|
||||
if not hasattr(dataset, "transform"):
|
||||
setattr(dataset, "transform", transform)
|
||||
if not hasattr(dataset, "target_transform"):
|
||||
setattr(dataset, "target_transform", target_transform)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def _make_sampler(
|
||||
*,
|
||||
dataset,
|
||||
type: Optional[SamplerType] = None,
|
||||
shuffle: bool = False,
|
||||
seed: int = 0,
|
||||
size: int = -1,
|
||||
advance: int = 0,
|
||||
) -> Optional[Sampler]:
|
||||
sample_count = len(dataset)
|
||||
|
||||
if type == SamplerType.INFINITE:
|
||||
logger.info("sampler: infinite")
|
||||
if size > 0:
|
||||
raise ValueError("sampler size > 0 is invalid")
|
||||
return InfiniteSampler(
|
||||
sample_count=sample_count,
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
advance=advance,
|
||||
)
|
||||
elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
|
||||
logger.info("sampler: sharded infinite")
|
||||
if size > 0:
|
||||
raise ValueError("sampler size > 0 is invalid")
|
||||
# TODO: Remove support for old shuffling
|
||||
use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW
|
||||
return ShardedInfiniteSampler(
|
||||
sample_count=sample_count,
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
advance=advance,
|
||||
use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice,
|
||||
)
|
||||
elif type == SamplerType.EPOCH:
|
||||
logger.info("sampler: epoch")
|
||||
if advance > 0:
|
||||
raise NotImplementedError("sampler advance > 0 is not supported")
|
||||
size = size if size > 0 else sample_count
|
||||
logger.info(f"# of samples / epoch: {size:,d}")
|
||||
return EpochSampler(
|
||||
size=size,
|
||||
sample_count=sample_count,
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
)
|
||||
elif type == SamplerType.DISTRIBUTED:
|
||||
logger.info("sampler: distributed")
|
||||
if size > 0:
|
||||
raise ValueError("sampler size > 0 is invalid")
|
||||
if advance > 0:
|
||||
raise ValueError("sampler advance > 0 is invalid")
|
||||
return torch.utils.data.DistributedSampler(
|
||||
dataset=dataset,
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
drop_last=False,
|
||||
)
|
||||
|
||||
logger.info("sampler: none")
|
||||
return None
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def make_data_loader(
|
||||
*,
|
||||
dataset,
|
||||
batch_size: int,
|
||||
num_workers: int,
|
||||
shuffle: bool = True,
|
||||
seed: int = 0,
|
||||
sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
|
||||
sampler_size: int = -1,
|
||||
sampler_advance: int = 0,
|
||||
drop_last: bool = True,
|
||||
persistent_workers: bool = False,
|
||||
collate_fn: Optional[Callable[[List[T]], Any]] = None,
|
||||
):
|
||||
"""
|
||||
Creates a data loader with the specified parameters.
|
||||
|
||||
Args:
|
||||
dataset: A dataset (third party, LaViDa or WebDataset).
|
||||
batch_size: The size of batches to generate.
|
||||
num_workers: The number of workers to use.
|
||||
shuffle: Whether to shuffle samples.
|
||||
seed: The random seed to use.
|
||||
sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
|
||||
sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
|
||||
sampler_advance: How many samples to skip (when applicable).
|
||||
drop_last: Whether the last non-full batch of data should be dropped.
|
||||
persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once.
|
||||
collate_fn: Function that performs batch collation
|
||||
"""
|
||||
|
||||
sampler = _make_sampler(
|
||||
dataset=dataset,
|
||||
type=sampler_type,
|
||||
shuffle=shuffle,
|
||||
seed=seed,
|
||||
size=sampler_size,
|
||||
advance=sampler_advance,
|
||||
)
|
||||
|
||||
logger.info("using PyTorch data loader")
|
||||
data_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
sampler=sampler,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=drop_last,
|
||||
persistent_workers=persistent_workers,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
try:
|
||||
logger.info(f"# of batches: {len(data_loader):,d}")
|
||||
except TypeError: # data loader has no length
|
||||
logger.info("infinite data loader")
|
||||
return data_loader
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MaskingGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
num_masking_patches=None,
|
||||
min_num_patches=4,
|
||||
max_num_patches=None,
|
||||
min_aspect=0.3,
|
||||
max_aspect=None,
|
||||
):
|
||||
if not isinstance(input_size, tuple):
|
||||
input_size = (input_size,) * 2
|
||||
self.height, self.width = input_size
|
||||
|
||||
self.num_patches = self.height * self.width
|
||||
self.num_masking_patches = num_masking_patches
|
||||
|
||||
self.min_num_patches = min_num_patches
|
||||
self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
|
||||
|
||||
max_aspect = max_aspect or 1 / min_aspect
|
||||
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
|
||||
|
||||
def __repr__(self):
|
||||
repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
|
||||
self.height,
|
||||
self.width,
|
||||
self.min_num_patches,
|
||||
self.max_num_patches,
|
||||
self.num_masking_patches,
|
||||
self.log_aspect_ratio[0],
|
||||
self.log_aspect_ratio[1],
|
||||
)
|
||||
return repr_str
|
||||
|
||||
def get_shape(self):
|
||||
return self.height, self.width
|
||||
|
||||
def _mask(self, mask, max_mask_patches):
|
||||
delta = 0
|
||||
for _ in range(10):
|
||||
target_area = random.uniform(self.min_num_patches, max_mask_patches)
|
||||
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
|
||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
if w < self.width and h < self.height:
|
||||
top = random.randint(0, self.height - h)
|
||||
left = random.randint(0, self.width - w)
|
||||
|
||||
num_masked = mask[top : top + h, left : left + w].sum()
|
||||
# Overlap
|
||||
if 0 < h * w - num_masked <= max_mask_patches:
|
||||
for i in range(top, top + h):
|
||||
for j in range(left, left + w):
|
||||
if mask[i, j] == 0:
|
||||
mask[i, j] = 1
|
||||
delta += 1
|
||||
|
||||
if delta > 0:
|
||||
break
|
||||
return delta
|
||||
|
||||
def __call__(self, num_masking_patches=0):
|
||||
mask = np.zeros(shape=self.get_shape(), dtype=bool)
|
||||
mask_count = 0
|
||||
while mask_count < num_masking_patches:
|
||||
max_mask_patches = num_masking_patches - mask_count
|
||||
max_mask_patches = min(max_mask_patches, self.max_num_patches)
|
||||
|
||||
delta = self._mask(mask, max_mask_patches)
|
||||
if delta == 0:
|
||||
break
|
||||
else:
|
||||
mask_count += delta
|
||||
|
||||
return mask
|
||||
@@ -0,0 +1,230 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
from typing import Any, Optional
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
import dinov2.distributed as distributed
|
||||
|
||||
|
||||
class EpochSampler(Sampler):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
size: int,
|
||||
sample_count: int,
|
||||
shuffle: bool = False,
|
||||
seed: int = 0,
|
||||
start: Optional[int] = None,
|
||||
step: Optional[int] = None,
|
||||
):
|
||||
self._size = size
|
||||
self._sample_count = sample_count
|
||||
self._shuffle = shuffle
|
||||
self._seed = seed
|
||||
self._start = distributed.get_global_rank() if start is None else start
|
||||
self._step = distributed.get_global_size() if step is None else step
|
||||
self._epoch = 0
|
||||
|
||||
def __iter__(self):
|
||||
count = (self._size + self._sample_count - 1) // self._sample_count
|
||||
tiled_indices = np.tile(np.arange(self._sample_count), count)
|
||||
if self._shuffle:
|
||||
seed = self._seed * self._epoch if self._seed != 0 else self._epoch
|
||||
rng = np.random.default_rng(seed)
|
||||
iterable = rng.choice(tiled_indices, self._size, replace=False)
|
||||
else:
|
||||
iterable = tiled_indices[: self._size]
|
||||
|
||||
yield from itertools.islice(iterable, self._start, None, self._step)
|
||||
|
||||
def __len__(self):
|
||||
return (self._size - self._start + self._step - 1) // self._step
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self._epoch = epoch
|
||||
|
||||
|
||||
def _get_numpy_dtype(size: int) -> Any:
|
||||
return np.int32 if size <= 2**31 else np.int64
|
||||
|
||||
|
||||
def _get_torch_dtype(size: int) -> Any:
|
||||
return torch.int32 if size <= 2**31 else torch.int64
|
||||
|
||||
|
||||
def _generate_randperm_indices(*, size: int, generator: torch.Generator):
|
||||
"""Generate the indices of a random permutation."""
|
||||
dtype = _get_torch_dtype(size)
|
||||
# This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
|
||||
perm = torch.arange(size, dtype=dtype)
|
||||
for i in range(size):
|
||||
j = torch.randint(i, size, size=(1,), generator=generator).item()
|
||||
|
||||
# Always swap even if no-op
|
||||
value = perm[j].item()
|
||||
perm[j] = perm[i].item()
|
||||
perm[i] = value
|
||||
yield value
|
||||
|
||||
|
||||
class InfiniteSampler(Sampler):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sample_count: int,
|
||||
shuffle: bool = False,
|
||||
seed: int = 0,
|
||||
start: Optional[int] = None,
|
||||
step: Optional[int] = None,
|
||||
advance: int = 0,
|
||||
):
|
||||
self._sample_count = sample_count
|
||||
self._seed = seed
|
||||
self._shuffle = shuffle
|
||||
self._start = distributed.get_global_rank() if start is None else start
|
||||
self._step = distributed.get_global_size() if step is None else step
|
||||
self._advance = advance
|
||||
|
||||
def __iter__(self):
|
||||
if self._shuffle:
|
||||
iterator = self._shuffled_iterator()
|
||||
else:
|
||||
iterator = self._iterator()
|
||||
|
||||
yield from itertools.islice(iterator, self._advance, None)
|
||||
|
||||
def _iterator(self):
|
||||
assert not self._shuffle
|
||||
|
||||
while True:
|
||||
iterable = range(self._sample_count)
|
||||
yield from itertools.islice(iterable, self._start, None, self._step)
|
||||
|
||||
def _shuffled_iterator(self):
|
||||
assert self._shuffle
|
||||
|
||||
# Instantiate a generator here (rather than in the ctor) to keep the class
|
||||
# picklable (requirement of mp.spawn)
|
||||
generator = torch.Generator().manual_seed(self._seed)
|
||||
|
||||
while True:
|
||||
iterable = _generate_randperm_indices(size=self._sample_count, generator=generator)
|
||||
yield from itertools.islice(iterable, self._start, None, self._step)
|
||||
|
||||
|
||||
# The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
|
||||
# but avoids a full in-place random permutation generation.
|
||||
def _shuffle_tensor_slice(
|
||||
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
|
||||
) -> np.ndarray:
|
||||
stop = len(tensor)
|
||||
count = stop // step
|
||||
drop_count = stop - step * count
|
||||
if drop_count:
|
||||
warnings.warn(f"# of dropped samples: {drop_count}")
|
||||
|
||||
dtype = _get_numpy_dtype(stop)
|
||||
result = np.empty(count, dtype=dtype)
|
||||
|
||||
for i in range(count):
|
||||
j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0
|
||||
|
||||
result[i] = result[j]
|
||||
result[j] = tensor[start + i * step].item()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _new_shuffle_tensor_slice(
|
||||
*, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
|
||||
) -> np.ndarray:
|
||||
stop = len(tensor)
|
||||
count = stop // step
|
||||
dtype = torch.int64 # Needed for using randperm result as indices
|
||||
count = stop // step
|
||||
drop_count = stop - step * count
|
||||
if drop_count:
|
||||
warnings.warn(f"# of dropped samples: {drop_count}")
|
||||
indices = torch.randperm(count, dtype=dtype, generator=generator)
|
||||
return tensor[start::step][indices].numpy()
|
||||
|
||||
|
||||
def _make_seed(seed: int, start: int, iter_count: int) -> int:
|
||||
# NOTE: Tried a few variants (including iter_count << 32), this one worked best.
|
||||
return seed + start + (iter_count << 24)
|
||||
|
||||
|
||||
class ShardedInfiniteSampler(Sampler):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sample_count: int,
|
||||
shuffle: bool = False,
|
||||
seed: int = 0,
|
||||
start: Optional[int] = None,
|
||||
step: Optional[int] = None,
|
||||
advance: int = 0,
|
||||
use_new_shuffle_tensor_slice: bool = False,
|
||||
):
|
||||
self._sample_count = sample_count
|
||||
self._seed = seed
|
||||
self._shuffle = shuffle
|
||||
self._start = distributed.get_global_rank() if start is None else start
|
||||
self._step = distributed.get_global_size() if step is None else step
|
||||
self._advance = advance
|
||||
self._iter_count = 0
|
||||
self._shuffle_tensor_slice_fn = (
|
||||
_new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
iter_count = self._advance // self._sample_count
|
||||
if iter_count > 0:
|
||||
self._advance -= iter_count * self._sample_count
|
||||
self._iter_count += iter_count
|
||||
|
||||
if self._shuffle:
|
||||
iterator = self._shuffled_iterator()
|
||||
else:
|
||||
iterator = self._iterator()
|
||||
|
||||
yield from itertools.islice(iterator, self._advance, None)
|
||||
|
||||
def _iterator(self):
|
||||
assert not self._shuffle
|
||||
|
||||
while True:
|
||||
iterable = range(self._sample_count)
|
||||
yield from itertools.islice(iterable, self._start, None, self._step)
|
||||
|
||||
def _shuffled_iterator(self):
|
||||
assert self._shuffle
|
||||
|
||||
# Instantiate a generator here (rather than in the ctor) to be keep the class
|
||||
# picklable (requirement of mp.spawn)
|
||||
generator = torch.Generator()
|
||||
|
||||
# Always shuffle everything first
|
||||
generator.manual_seed(self._seed)
|
||||
dtype = _get_torch_dtype(self._sample_count)
|
||||
perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)
|
||||
|
||||
while True:
|
||||
# Re-seed on each iteration to allow skipping whole permutations
|
||||
seed = _make_seed(self._seed, self._start, self._iter_count)
|
||||
generator.manual_seed(seed)
|
||||
|
||||
iterable = self._shuffle_tensor_slice_fn(
|
||||
tensor=perm, start=self._start, step=self._step, generator=generator
|
||||
)
|
||||
yield from iterable
|
||||
self._iter_count += 1
|
||||
@@ -0,0 +1,92 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Sequence
|
||||
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
class GaussianBlur(transforms.RandomApply):
|
||||
"""
|
||||
Apply Gaussian Blur to the PIL image.
|
||||
"""
|
||||
|
||||
def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
|
||||
# NOTE: torchvision is applying 1 - probability to return the original image
|
||||
keep_p = 1 - p
|
||||
transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
|
||||
super().__init__(transforms=[transform], p=keep_p)
|
||||
|
||||
|
||||
class MaybeToTensor(transforms.ToTensor):
|
||||
"""
|
||||
Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor.
|
||||
"""
|
||||
|
||||
def __call__(self, pic):
|
||||
"""
|
||||
Args:
|
||||
pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
|
||||
Returns:
|
||||
Tensor: Converted image.
|
||||
"""
|
||||
if isinstance(pic, torch.Tensor):
|
||||
return pic
|
||||
return super().__call__(pic)
|
||||
|
||||
|
||||
# Use timm's names
|
||||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
|
||||
def make_normalize_transform(
|
||||
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
|
||||
std: Sequence[float] = IMAGENET_DEFAULT_STD,
|
||||
) -> transforms.Normalize:
|
||||
return transforms.Normalize(mean=mean, std=std)
|
||||
|
||||
|
||||
# This roughly matches torchvision's preset for classification training:
|
||||
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
|
||||
def make_classification_train_transform(
|
||||
*,
|
||||
crop_size: int = 224,
|
||||
interpolation=transforms.InterpolationMode.BICUBIC,
|
||||
hflip_prob: float = 0.5,
|
||||
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
|
||||
std: Sequence[float] = IMAGENET_DEFAULT_STD,
|
||||
):
|
||||
transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
|
||||
if hflip_prob > 0.0:
|
||||
transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
|
||||
transforms_list.extend(
|
||||
[
|
||||
MaybeToTensor(),
|
||||
make_normalize_transform(mean=mean, std=std),
|
||||
]
|
||||
)
|
||||
return transforms.Compose(transforms_list)
|
||||
|
||||
|
||||
# This matches (roughly) torchvision's preset for classification evaluation:
|
||||
# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
|
||||
def make_classification_eval_transform(
|
||||
*,
|
||||
resize_size: int = 256,
|
||||
interpolation=transforms.InterpolationMode.BICUBIC,
|
||||
crop_size: int = 224,
|
||||
mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
|
||||
std: Sequence[float] = IMAGENET_DEFAULT_STD,
|
||||
) -> transforms.Compose:
|
||||
transforms_list = [
|
||||
transforms.Resize(resize_size, interpolation=interpolation),
|
||||
transforms.CenterCrop(crop_size),
|
||||
MaybeToTensor(),
|
||||
make_normalize_transform(mean=mean, std=std),
|
||||
]
|
||||
return transforms.Compose(transforms_list)
|
||||
Reference in New Issue
Block a user