Initial media depth project backup
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
@@ -0,0 +1,405 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
from functools import partial
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.functional import one_hot, softmax
|
||||
|
||||
import dinov2.distributed as distributed
|
||||
from dinov2.data import SamplerType, make_data_loader, make_dataset
|
||||
from dinov2.data.transforms import make_classification_eval_transform
|
||||
from dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric
|
||||
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
|
||||
from dinov2.eval.setup import setup_and_build_model
|
||||
from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Build the command-line parser of the k-NN evaluation.

    The returned parser inherits every option of the shared evaluation setup
    parser (``get_setup_args_parser``) and adds the k-NN specific options.

    Args:
        description: Text displayed at the top of the ``--help`` output.
        parents: Optional parsers forwarded to the setup parser as its parents.
        add_help: Whether the returned parser registers ``-h/--help``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parents = parents or []
    # The setup parser is built without its own help flag so that it can be
    # used as a parent of the parser returned here.
    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
    parents = [setup_args_parser]
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    parser.add_argument(
        "--train-dataset",
        dest="train_dataset_str",
        type=str,
        help="Training dataset",
    )
    parser.add_argument(
        "--val-dataset",
        dest="val_dataset_str",
        type=str,
        help="Validation dataset",
    )
    parser.add_argument(
        "--nb_knn",
        nargs="+",
        type=int,
        help="Number of NN to use. 20 is usually working the best.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature used in the voting coefficient",
    )
    parser.add_argument(
        "--gather-on-cpu",
        action="store_true",
        # The two literals are concatenated: the trailing space keeps
        # "slower" and "but" from being glued together in the help output.
        help="Whether to gather the train features on cpu, slower "
        "but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch size.",
    )
    parser.add_argument(
        "--n-per-class-list",
        nargs="+",
        type=int,
        help="Number to take per class",
    )
    parser.add_argument(
        "--n-tries",
        type=int,
        help="Number of tries",
    )
    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        nb_knn=[10, 20, 100, 200],
        temperature=0.07,
        batch_size=256,
        n_per_class_list=[-1],
        n_tries=1,
    )
    return parser
|
||||
|
||||
|
||||
class KnnModule(torch.nn.Module):
    """
    Distributed k-NN voting over a bank of train features split across ranks.

    At construction every rank keeps only its own chunk of the train features
    and labels. In `compute_neighbors`, ranks are visited one after the other:
    the test features of the visited rank are broadcast to all ranks, each rank
    computes a partial top-k against its own train chunk, and the partial
    results are gathered back on the visited rank where a final top-k is taken.
    """

    def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000):
        super().__init__()

        self.global_rank = distributed.get_global_rank()
        self.global_size = distributed.get_global_size()

        self.device = device
        # Keep this rank's chunk only; the features are stored transposed so that
        # a plain matrix product with test features yields the similarities.
        rank_features = train_features.chunk(self.global_size)[self.global_rank]
        rank_labels = train_labels.chunk(self.global_size)[self.global_rank]
        self.train_features_rank_T = rank_features.T.to(self.device)
        self.candidates = rank_labels.view(1, -1).to(self.device)

        self.nb_knn = nb_knn
        self.max_k = max(self.nb_knn)
        self.T = T
        self.num_classes = num_classes

    def _get_knn_sims_and_labels(self, similarity, train_labels):
        # Keep the `max_k` largest similarities of each row (sorted, best first)
        # together with the labels found at the same positions.
        best_sims, best_positions = torch.topk(similarity, self.max_k, largest=True, sorted=True)
        return best_sims, train_labels.gather(1, best_positions)

    def _similarity_for_rank(self, features_rank, source_rank):
        # Share the shape of the features of `source_rank` first, so that the
        # other ranks can allocate a receive buffer of the right size.
        shape = torch.tensor(features_rank.shape).to(self.device)
        torch.distributed.broadcast(shape, source_rank)

        if self.global_rank == source_rank:
            broadcasted = features_rank
        else:
            broadcasted = torch.zeros(*shape, dtype=features_rank.dtype, device=self.device)
        torch.distributed.broadcast(broadcasted, source_rank)

        # Partial neighbors of the `source_rank` features among this rank's train chunk
        similarity_rank = torch.mm(broadcasted, self.train_features_rank_T)
        candidate_labels = self.candidates.expand(len(similarity_rank), -1)
        return self._get_knn_sims_and_labels(similarity_rank, candidate_labels)

    def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank):
        # Only `target_rank` provides receive buffers for the gather
        is_target = self.global_rank == target_rank
        sims_buffers = labels_buffers = None
        if is_target:
            sims_buffers = [torch.zeros_like(topk_sims) for _ in range(self.global_size)]
            labels_buffers = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)]

        torch.distributed.gather(topk_sims, sims_buffers, dst=target_rank)
        torch.distributed.gather(neighbors_labels, labels_buffers, dst=target_rank)

        if not is_target:
            return None

        # Second top-k, this time over the k * global_size gathered candidates
        merged_sims = torch.cat(sims_buffers, dim=1)
        merged_labels = torch.cat(labels_buffers, dim=1)
        return self._get_knn_sims_and_labels(merged_sims, merged_labels)

    def compute_neighbors(self, features_rank):
        # Every rank takes part in every round; only the round whose target is
        # this rank produces a result here.
        for rank in range(self.global_size):
            partial_sims, partial_labels = self._similarity_for_rank(features_rank, rank)
            gathered = self._gather_all_knn_for_rank(partial_sims, partial_labels, rank)
            if gathered is not None:
                topk_sims_rank, neighbors_labels_rank = gathered
        return topk_sims_rank, neighbors_labels_rank

    def forward(self, features_rank):
        """
        Return a dict mapping each k of `self.nb_knn` to class scores obtained from the
        k best of the `self.max_k` retrieved neighbors
        """
        assert all(k <= self.max_k for k in self.nb_knn)

        topk_sims, neighbors_labels = self.compute_neighbors(features_rank)
        batch_size = neighbors_labels.shape[0]
        # Temperature-scaled softmax over the retrieved similarities gives the vote weights
        vote_weights = softmax(topk_sims / self.T, 1)
        weighted_votes = torch.mul(
            one_hot(neighbors_labels, num_classes=self.num_classes),
            vote_weights.view(batch_size, -1, 1),
        )
        # Neighbors are sorted best first, so the first k rows are the k nearest ones
        return {k: torch.sum(weighted_votes[:, :k, :], 1) for k in self.nb_knn}
|
||||
|
||||
|
||||
class DictKeysModule(torch.nn.Module):
    """Postprocessor selecting one entry of a nested dict of outputs by following a fixed key path."""

    def __init__(self, keys):
        super().__init__()
        # Successive keys to follow, outermost first
        self.keys = keys

    def forward(self, features_dict, targets):
        selected = features_dict
        for key in self.keys:
            selected = selected[key]
        return {"preds": selected, "target": targets}
|
||||
|
||||
|
||||
def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels):
    """Instantiate one k-NN module per (samples-per-class, try) configuration.

    Args:
        module: Callable building a k-NN module from the keyword arguments
            ``train_features``, ``train_labels`` and ``nb_knn``.
        n_per_class_list: Numbers of train samples kept per class; a negative
            value means that the full train set is used.
        n_tries: Number of random subsets drawn for each non-negative entry of
            ``n_per_class_list`` (try ``t`` is drawn with seed ``t``).
        nb_knn: Sequence (list or tuple) of neighbor counts to evaluate.
        train_features: Train features, indexable by a tensor of indices.
        train_labels: Train labels aligned with ``train_features``.

    Returns:
        A ``ModuleDictWithForward`` keyed by ``"full"`` or ``"<n> per class"``,
        whose values are ``ModuleDictWithForward`` instances keyed by try.
    """
    modules = {}
    mapping = create_class_indices_mapping(train_labels)
    for npc in n_per_class_list:
        if npc < 0:  # Only one try needed when using the full data
            full_module = module(
                train_features=train_features,
                train_labels=train_labels,
                nb_knn=nb_knn,
            )
            modules["full"] = ModuleDictWithForward({"1": full_module})
            continue

        # Neighbor counts for this subset size: the requested ones plus `npc`
        # itself, restricted to values not larger than `npc`. Unpacking accepts
        # both lists and tuples (`nb_knn + [npc]` fails on a tuple), and the
        # result does not depend on the try, so it is computed once.
        k_list = sorted({k for k in (*nb_knn, npc) if k <= npc})

        all_tries = {}
        for t in range(n_tries):
            final_indices = filter_train(mapping, npc, seed=t)
            all_tries[str(t)] = module(
                train_features=train_features[final_indices],
                train_labels=train_labels[final_indices],
                nb_knn=list(k_list),  # each module gets its own list
            )
        modules[f"{npc} per class"] = ModuleDictWithForward(all_tries)

    return ModuleDictWithForward(modules)
|
||||
|
||||
|
||||
def filter_train(mapping, n_per_class, seed):
    """Draw at most ``n_per_class`` sample indices for every class.

    Args:
        mapping: Dict whose values are tensors of sample indices, one entry per
            class (e.g. the output of ``create_class_indices_mapping``).
        n_per_class: Maximum number of indices kept for each class.
        seed: Seed passed to ``torch.manual_seed`` before drawing. Note that
            this reseeds the global torch RNG.

    Returns:
        A 1-D tensor with the selected indices of all classes, concatenated in
        the iteration order of ``mapping``.
    """
    torch.manual_seed(seed)
    final_indices = []
    for class_indices in mapping.values():
        # Random permutation truncated to `n_per_class`: sampling without replacement
        index = torch.randperm(len(class_indices))[:n_per_class]
        final_indices.append(class_indices[index])
    # Flatten explicitly: a bare `squeeze()` would turn a single selected index
    # into a 0-dim tensor, and indexing with it would drop a dimension.
    return torch.cat(final_indices).reshape(-1)
|
||||
|
||||
|
||||
def create_class_indices_mapping(labels):
    """Group sample positions by label.

    Returns a dict with one entry per distinct value of `labels`: the key is the
    label as an element of the `torch.unique` output (a 0-dim tensor), the value
    is the `nonzero()` result listing the positions holding that label.
    """
    unique_labels, inverse = torch.unique(labels, return_inverse=True)
    mapping = {}
    for class_rank, label in enumerate(unique_labels):
        # `inverse` holds, for every sample, the rank of its label in `unique_labels`
        mapping[label] = (inverse == class_rank).nonzero()
    return mapping
|
||||
|
||||
|
||||
class ModuleDictWithForward(torch.nn.ModuleDict):
    """`ModuleDict` whose forward pass calls every child with the same arguments."""

    def forward(self, *args, **kwargs):
        # One output per child, stored under the child's key
        outputs = {}
        for name, child in self._modules.items():
            outputs[name] = child(*args, **kwargs)
        return outputs
|
||||
|
||||
|
||||
def eval_knn(
    model,
    train_dataset,
    val_dataset,
    accuracy_averaging,
    nb_knn,
    temperature,
    batch_size,
    num_workers,
    gather_on_cpu,
    n_per_class_list=None,
    n_tries=1,
):
    """Run the k-NN evaluation of `model` and return the accuracy metrics.

    Args:
        model: Feature extractor; it is wrapped in `ModelWithNormalize`.
        train_dataset: Dataset providing the train features and labels.
        val_dataset: Dataset on which the k-NN classifiers are evaluated.
        accuracy_averaging: Averaging mode given to `build_topk_accuracy_metric`.
        nb_knn: Neighbor counts to evaluate.
        temperature: Temperature of the softmax applied to the similarities.
        batch_size: Batch size used for feature extraction and evaluation.
        num_workers: Number of data loading workers.
        gather_on_cpu: Forwarded to `extract_features`.
        n_per_class_list: Numbers of train samples per class to evaluate with;
            a negative value means the full train set. Defaults to `[-1]`.
        n_tries: Number of random subsets drawn per non-negative entry of
            `n_per_class_list`.

    Returns:
        A dict keyed by `(n_per_class, k)` whose values map each metric name to
        its mean over the tries.
    """
    # `None` sentinel instead of a mutable list shared between calls
    if n_per_class_list is None:
        n_per_class_list = [-1]

    model = ModelWithNormalize(model)

    logger.info("Extracting features for train set...")
    train_features, train_labels = extract_features(
        model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
    )
    logger.info(f"Train features created, shape {train_features.shape}.")

    val_dataloader = make_data_loader(
        dataset=val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
        persistent_workers=True,
    )
    # Plain Python int rather than a 0-dim tensor, so that every consumer
    # (metric construction, one-hot encoding) receives a regular class count.
    num_classes = int(train_labels.max() + 1)
    metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes)

    device = torch.cuda.current_device()
    partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes)
    knn_module_dict = create_module_dict(
        module=partial_module,
        n_per_class_list=n_per_class_list,
        n_tries=n_tries,
        nb_knn=nb_knn,
        train_features=train_features,
        train_labels=train_labels,
    )
    # One postprocessor and one metric per (n_per_class, try, k) configuration
    postprocessors, metrics = {}, {}
    for n_per_class, knn_module in knn_module_dict.items():
        for t, knn_try in knn_module.items():
            for k in knn_try.nb_knn:
                postprocessors[(n_per_class, t, k)] = DictKeysModule([n_per_class, t, k])
            for k in knn_try.nb_knn:
                metrics[(n_per_class, t, k)] = metric_collection.clone()
    model_with_knn = torch.nn.Sequential(model, knn_module_dict)

    # ============ evaluation ... ============
    logger.info("Start the k-NN classification.")
    _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device)

    # Averaging the results over the n tries for each value of n_per_class
    for n_per_class, knn_module in knn_module_dict.items():
        tries = list(knn_module.keys())
        first_try = tries[0]
        for k in knn_module[first_try].nb_knn:
            metric_names = results_dict[(n_per_class, first_try, k)].keys()  # e.g. `top-1` and `top-5`
            results_dict[(n_per_class, k)] = {
                name: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][name] for t in tries]))
                for name in metric_names
            }
            # The per-try entries are replaced by their average
            for t in tries:
                del results_dict[(n_per_class, t, k)]

    return results_dict
|
||||
|
||||
|
||||
def eval_knn_with_model(
    model,
    output_dir,
    train_dataset_str="ImageNet:split=TRAIN",
    val_dataset_str="ImageNet:split=VAL",
    nb_knn=(10, 20, 100, 200),
    temperature=0.07,
    autocast_dtype=torch.float,
    accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
    transform=None,
    gather_on_cpu=False,
    batch_size=256,
    num_workers=5,
    n_per_class_list=None,
    n_tries=1,
):
    """Build the datasets, run the k-NN evaluation and record the results.

    Args:
        model: Feature extractor to evaluate.
        output_dir: Directory receiving `results_eval_knn.json` (appended to).
        train_dataset_str: Dataset string of the train set.
        val_dataset_str: Dataset string of the validation set.
        nb_knn: Neighbor counts to evaluate.
        temperature: Temperature used in the voting coefficient.
        autocast_dtype: dtype of the CUDA autocast context wrapping the evaluation.
        accuracy_averaging: Averaging mode of the accuracy metric.
        transform: Transform applied to both datasets; defaults to
            `make_classification_eval_transform()`.
        gather_on_cpu: Whether train features are gathered on CPU.
        batch_size: Batch size.
        num_workers: Number of data loading workers.
        n_per_class_list: Numbers of train samples per class; defaults to `[-1]`.
        n_tries: Number of tries per entry of `n_per_class_list`.

    Returns:
        On the main process, a dict mapping `"<config> Top 1"` / `"<config> Top 5"`
        to accuracies in percent; an empty dict on the other processes.
    """
    # `None` sentinel instead of a mutable list shared between calls
    if n_per_class_list is None:
        n_per_class_list = [-1]
    transform = transform or make_classification_eval_transform()

    train_dataset = make_dataset(
        dataset_str=train_dataset_str,
        transform=transform,
    )
    val_dataset = make_dataset(
        dataset_str=val_dataset_str,
        transform=transform,
    )

    with torch.cuda.amp.autocast(dtype=autocast_dtype):
        results_dict_knn = eval_knn(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            accuracy_averaging=accuracy_averaging,
            nb_knn=nb_knn,
            temperature=temperature,
            batch_size=batch_size,
            num_workers=num_workers,
            gather_on_cpu=gather_on_cpu,
            n_per_class_list=n_per_class_list,
            n_tries=n_tries,
        )

    results_dict = {}
    if distributed.is_main_process():
        for knn_, scores in results_dict_knn.items():
            top1 = scores["top-1"].item() * 100.0
            top5 = scores["top-5"].item() * 100.0
            results_dict[f"{knn_} Top 1"] = top1
            results_dict[f"{knn_} Top 5"] = top5
            logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}")

        # Only the main process has results, so it is the only one touching the file
        metrics_file_path = os.path.join(output_dir, "results_eval_knn.json")
        with open(metrics_file_path, "a") as f:
            for k, v in results_dict.items():
                f.write(json.dumps({k: v}) + "\n")

    # Keep the other ranks waiting until the main process is done writing
    if distributed.is_enabled():
        torch.distributed.barrier()
    return results_dict
|
||||
|
||||
|
||||
def main(args):
    """Build the model from the parsed arguments, run the k-NN evaluation and return exit code 0."""
    model, autocast_dtype = setup_and_build_model(args)
    # Averaging mode, transform and worker count are fixed here; everything else comes from `args`
    eval_knn_with_model(
        model=model,
        output_dir=args.output_dir,
        train_dataset_str=args.train_dataset_str,
        val_dataset_str=args.val_dataset_str,
        nb_knn=args.nb_knn,
        temperature=args.temperature,
        autocast_dtype=autocast_dtype,
        accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
        transform=None,
        gather_on_cpu=args.gather_on_cpu,
        batch_size=args.batch_size,
        num_workers=5,
        n_per_class_list=args.n_per_class_list,
        n_tries=args.n_tries,
    )
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the command line, run the evaluation and
    # use the value returned by `main` as the process exit status.
    args_parser = get_args_parser(description="DINOv2 k-NN evaluation")
    sys.exit(main(args_parser.parse_args()))
|
||||
@@ -0,0 +1,626 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
from functools import partial
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel
|
||||
from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
|
||||
|
||||
from dinov2.data import SamplerType, make_data_loader, make_dataset
|
||||
from dinov2.data.transforms import make_classification_eval_transform, make_classification_train_transform
|
||||
import dinov2.distributed as distributed
|
||||
from dinov2.eval.metrics import MetricType, build_metric
|
||||
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
|
||||
from dinov2.eval.setup import setup_and_build_model
|
||||
from dinov2.eval.utils import ModelWithIntermediateLayers, evaluate
|
||||
from dinov2.logging import MetricLogger
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Build the command-line parser of the linear evaluation.

    The returned parser inherits every option of the shared evaluation setup
    parser (``get_setup_args_parser``) and adds the linear-probing options.

    Args:
        description: Text displayed at the top of the ``--help`` output.
        parents: Optional parsers forwarded to the setup parser as its parents.
        add_help: Whether the returned parser registers ``-h/--help``.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parents = parents or []
    # The setup parser is built without its own help flag so that it can be
    # used as a parent of the parser returned here.
    setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
    parents = [setup_args_parser]
    parser = argparse.ArgumentParser(
        description=description,
        parents=parents,
        add_help=add_help,
    )
    parser.add_argument(
        "--train-dataset",
        dest="train_dataset_str",
        type=str,
        help="Training dataset",
    )
    parser.add_argument(
        "--val-dataset",
        dest="val_dataset_str",
        type=str,
        help="Validation dataset",
    )
    parser.add_argument(
        "--test-datasets",
        dest="test_dataset_strs",
        type=str,
        nargs="+",
        help="Test datasets, none to reuse the validation dataset",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch Size (per GPU)",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        help="Number of workers",
    )
    parser.add_argument(
        "--epoch-length",
        type=int,
        help="Length of an epoch in number of iterations",
    )
    parser.add_argument(
        "--save-checkpoint-frequency",
        type=int,
        help="Number of epochs between two named checkpoint saves.",
    )
    parser.add_argument(
        "--eval-period-iterations",
        type=int,
        help="Number of iterations between two evaluations.",
    )
    parser.add_argument(
        "--learning-rates",
        nargs="+",
        type=float,
        help="Learning rates to grid search.",
    )
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Whether to not resume from existing checkpoints",
    )
    parser.add_argument(
        "--val-metric-type",
        type=MetricType,
        choices=list(MetricType),
        help="Validation metric",
    )
    parser.add_argument(
        "--test-metric-types",
        type=MetricType,
        choices=list(MetricType),
        nargs="+",
        help="Evaluation metric",
    )
    parser.add_argument(
        "--classifier-fpath",
        type=str,
        help="Path to a file containing pretrained linear classifiers",
    )
    parser.add_argument(
        "--val-class-mapping-fpath",
        type=str,
        help="Path to a file containing a mapping to adjust classifier outputs",
    )
    parser.add_argument(
        "--test-class-mapping-fpaths",
        nargs="+",
        type=str,
        help="Path to a file containing a mapping to adjust classifier outputs",
    )
    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        test_dataset_strs=None,
        epochs=10,
        batch_size=128,
        num_workers=8,
        epoch_length=1250,
        save_checkpoint_frequency=20,
        eval_period_iterations=1250,
        learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 0.1],
        val_metric_type=MetricType.MEAN_ACCURACY,
        test_metric_types=None,
        classifier_fpath=None,
        val_class_mapping_fpath=None,
        test_class_mapping_fpaths=[None],
    )
    return parser
|
||||
|
||||
|
||||
def has_ddp_wrapper(m: nn.Module) -> bool:
    """Tell whether `m` is a `DistributedDataParallel` instance."""
    is_wrapped = isinstance(m, DistributedDataParallel)
    return is_wrapped
|
||||
|
||||
|
||||
def remove_ddp_wrapper(m: nn.Module) -> nn.Module:
    """Return the module wrapped by `m` when `m` is a DDP wrapper, `m` itself otherwise."""
    if has_ddp_wrapper(m):
        return m.module
    return m
|
||||
|
||||
|
||||
def _pad_and_collate(batch):
|
||||
maxlen = max(len(targets) for image, targets in batch)
|
||||
padded_batch = [
|
||||
(image, np.pad(targets, (0, maxlen - len(targets)), constant_values=-1)) for image, targets in batch
|
||||
]
|
||||
return torch.utils.data.default_collate(padded_batch)
|
||||
|
||||
|
||||
def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool):
    """Assemble the input of a linear classifier from intermediate outputs.

    `x_tokens_list` holds one `(patch_tokens, class_token)` pair per block. The
    class tokens of the last `use_n_blocks` pairs are concatenated along the
    last dimension; with `use_avgpool`, the mean over dim 1 of the patch tokens
    of the last pair is appended. The result is reshaped to two dimensions
    (first dimension kept) and returned as float.
    """
    selected = x_tokens_list[-use_n_blocks:]
    pieces = [class_token for _, class_token in selected]
    features = torch.cat(pieces, dim=-1)
    if use_avgpool:
        last_patch_tokens = selected[-1][0]
        pooled = torch.mean(last_patch_tokens, dim=1)
        features = torch.cat((features, pooled), dim=-1)
        features = features.reshape(features.shape[0], -1)
    return features.float()
|
||||
|
||||
|
||||
class LinearClassifier(nn.Module):
    """Single linear layer trained on top of frozen features"""

    def __init__(self, out_dim, use_n_blocks, use_avgpool, num_classes=1000):
        super().__init__()
        self.out_dim = out_dim
        self.use_n_blocks = use_n_blocks
        self.use_avgpool = use_avgpool
        self.num_classes = num_classes
        self.linear = nn.Linear(out_dim, num_classes)
        # Small Gaussian weights, zero bias
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x_tokens_list):
        # Turn the intermediate outputs into a flat feature vector, then classify it
        linear_input = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool)
        return self.linear(linear_input)
|
||||
|
||||
|
||||
class AllClassifiers(nn.Module):
    """Container running a set of named classifiers on the same inputs."""

    def __init__(self, classifiers_dict):
        super().__init__()
        registry = nn.ModuleDict()
        registry.update(classifiers_dict)
        self.classifiers_dict = registry

    def forward(self, inputs):
        # One output per classifier, stored under the classifier's name
        outputs = {}
        for name, classifier in self.classifiers_dict.items():
            outputs[name] = classifier.forward(inputs)
        return outputs

    def __len__(self):
        # Number of classifiers held
        return len(self.classifiers_dict)
|
||||
|
||||
|
||||
class LinearPostprocessor(nn.Module):
    """Apply a linear classifier and optionally select/reorder its output columns."""

    def __init__(self, linear_classifier, class_mapping=None):
        super().__init__()
        self.linear_classifier = linear_classifier
        # The mapping is stored as a buffer (or `None` when no mapping is given)
        if class_mapping is None:
            mapping_buffer = None
        else:
            mapping_buffer = torch.LongTensor(class_mapping)
        self.register_buffer("class_mapping", mapping_buffer)

    def forward(self, samples, targets):
        preds = self.linear_classifier(samples)
        if self.class_mapping is not None:
            # Keep only the columns listed in the mapping, in that order
            preds = preds[:, self.class_mapping]
        return {"preds": preds, "target": targets}
|
||||
|
||||
|
||||
def scale_lr(learning_rates, batch_size):
    """Scale a learning rate by the ratio of the global batch size to 256."""
    # Global batch size: per-process batch size times the number of processes
    global_batch_size = batch_size * distributed.get_global_size()
    return learning_rates * global_batch_size / 256.0
|
||||
|
||||
|
||||
def setup_linear_classifiers(sample_output, n_last_blocks_list, learning_rates, batch_size, num_classes=1000):
    """Create one linear classifier per (blocks, avgpool, learning rate) combination.

    Args:
        sample_output: Example output of the feature model, used to infer the
            input width of the classifiers.
        n_last_blocks_list: Numbers of last blocks whose class tokens are used.
        learning_rates: Base learning rates; each one is scaled with `scale_lr`.
        batch_size: Batch size given to `scale_lr`.
        num_classes: Number of output classes of every classifier.

    Returns:
        A pair `(linear_classifiers, optim_param_groups)`: the classifiers held
        in an `AllClassifiers` module (wrapped in `DistributedDataParallel` when
        distributed mode is enabled) and one optimizer parameter group per
        classifier, carrying its scaled learning rate.
    """
    linear_classifiers_dict = nn.ModuleDict()
    optim_param_groups = []
    for n in n_last_blocks_list:
        for avgpool in [False, True]:
            # The input width only depends on (n, avgpool): compute it once per
            # pair instead of once per learning rate.
            out_dim = create_linear_input(sample_output, use_n_blocks=n, use_avgpool=avgpool).shape[1]
            for _lr in learning_rates:
                lr = scale_lr(_lr, batch_size)
                linear_classifier = LinearClassifier(
                    out_dim, use_n_blocks=n, use_avgpool=avgpool, num_classes=num_classes
                )
                linear_classifier = linear_classifier.cuda()
                # "." is replaced because it is not allowed in module names
                linear_classifiers_dict[
                    f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}".replace(".", "_")
                ] = linear_classifier
                optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr})

    linear_classifiers = AllClassifiers(linear_classifiers_dict)
    if distributed.is_enabled():
        linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers)

    return linear_classifiers, optim_param_groups
|
||||
|
||||
|
||||
@torch.no_grad()
def evaluate_linear_classifiers(
    feature_model,
    linear_classifiers,
    data_loader,
    metric_type,
    metrics_file_path,
    training_num_classes,
    iteration,
    prefixstring="",
    class_mapping=None,
    best_classifier_on_val=None,
):
    """Evaluate every linear classifier on `data_loader` and report the best one.

    When `best_classifier_on_val` is `None`, the reported classifier is the one
    with the highest `top-1` value; otherwise it is the classifier carrying that
    name. The main process appends the result to `metrics_file_path`.

    Returns:
        A dict with a single `"best_classifier"` entry holding the `"name"` and
        `"accuracy"` of the reported classifier.
    """
    logger.info("running validation !")

    # A class mapping, when given, determines the number of evaluated classes
    if class_mapping is not None:
        num_classes = len(class_mapping)
    else:
        num_classes = training_num_classes
    metric = build_metric(metric_type, num_classes=num_classes)
    classifiers = linear_classifiers.classifiers_dict
    postprocessors = {name: LinearPostprocessor(clf, class_mapping) for name, clf in classifiers.items()}
    metrics = {name: metric.clone() for name in classifiers}

    _, results_dict_temp = evaluate(
        feature_model,
        data_loader,
        postprocessors,
        metrics,
        torch.cuda.current_device(),
    )

    logger.info("")
    results_dict = {}
    max_accuracy = 0
    best_classifier = ""
    for classifier_string, scores in results_dict_temp.items():
        logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {scores}")
        improves = best_classifier_on_val is None and scores["top-1"].item() > max_accuracy
        if improves or classifier_string == best_classifier_on_val:
            max_accuracy = scores["top-1"].item()
            best_classifier = classifier_string

    results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy}

    logger.info(f"best classifier: {results_dict['best_classifier']}")

    if distributed.is_main_process():
        with open(metrics_file_path, "a") as f:
            f.write(f"iter: {iteration}\n")
            for key, value in results_dict.items():
                f.write(json.dumps({key: value}) + "\n")
            f.write("\n")

    return results_dict
|
||||
|
||||
|
||||
def eval_linear(
    *,
    feature_model,
    linear_classifiers,
    train_data_loader,
    val_data_loader,
    metrics_file_path,
    optimizer,
    scheduler,
    output_dir,
    max_iter,
    checkpoint_period,  # In number of iter, creates a new file every period
    running_checkpoint_period,  # Period to update main checkpoint file
    eval_period,
    metric_type,
    training_num_classes,
    resume=True,
    classifier_fpath=None,
    val_class_mapping=None,
):
    """Train the linear classifiers on top of `feature_model`, then validate them.

    Training starts from the iteration found in the resumed/loaded checkpoint
    (0 when none is found). The sum of the cross-entropy losses of all
    classifiers is optimized, with periodic checkpoints and validations along
    the way and a final validation at the end.

    Returns:
        A tuple `(val_results_dict, feature_model, linear_classifiers, iteration)`.
    """
    checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler)
    checkpoint_data = checkpointer.resume_or_load(classifier_fpath or "", resume=resume)
    start_iter = checkpoint_data.get("iteration", -1) + 1

    periodic_checkpointer = PeriodicCheckpointer(checkpointer, checkpoint_period, max_iter=max_iter)
    iteration = start_iter
    logger.info("Starting training from iteration {}".format(start_iter))
    metric_logger = MetricLogger(delimiter="  ")
    header = "Training"

    batches = metric_logger.log_every(
        train_data_loader,
        10,
        header,
        max_iter,
        start_iter,
    )
    for data, labels in batches:
        data = data.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        features = feature_model(data)
        outputs = linear_classifiers(features)

        # One loss per classifier; their sum is optimized
        losses = {f"loss_{k}": nn.CrossEntropyLoss()(v, labels) for k, v in outputs.items()}
        loss = sum(losses.values())

        # compute the gradients
        optimizer.zero_grad()
        loss.backward()

        # step
        optimizer.step()
        scheduler.step()

        # log
        if iteration % 10 == 0:
            torch.cuda.synchronize()
            current_lr = optimizer.param_groups[0]["lr"]
            metric_logger.update(loss=loss.item())
            metric_logger.update(lr=current_lr)
            print("lr", optimizer.param_groups[0]["lr"])

        # The running checkpoint is skipped during the first iterations after (re)start
        if iteration - start_iter > 5 and iteration % running_checkpoint_period == 0:
            torch.cuda.synchronize()
            if distributed.is_main_process():
                logger.info("Checkpointing running_checkpoint")
                periodic_checkpointer.save("running_checkpoint_linear_eval", iteration=iteration)
            torch.cuda.synchronize()
        periodic_checkpointer.step(iteration)

        # Intermediate validation; the last iteration is left to the final validation below
        if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1:
            evaluate_linear_classifiers(
                feature_model=feature_model,
                linear_classifiers=remove_ddp_wrapper(linear_classifiers),
                data_loader=val_data_loader,
                metrics_file_path=metrics_file_path,
                prefixstring=f"ITER: {iteration}",
                metric_type=metric_type,
                training_num_classes=training_num_classes,
                iteration=iteration,
                class_mapping=val_class_mapping,
            )
            torch.cuda.synchronize()

        iteration += 1

    val_results_dict = evaluate_linear_classifiers(
        feature_model=feature_model,
        linear_classifiers=remove_ddp_wrapper(linear_classifiers),
        data_loader=val_data_loader,
        metrics_file_path=metrics_file_path,
        metric_type=metric_type,
        training_num_classes=training_num_classes,
        iteration=iteration,
        class_mapping=val_class_mapping,
    )
    return val_results_dict, feature_model, linear_classifiers, iteration
|
||||
|
||||
|
||||
def make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type):
    """Build an unshuffled, distributed data loader over the dataset named by `test_dataset_str`."""
    # The padding collate function is only used for the ImageNet-ReAL accuracy metric
    if metric_type == MetricType.IMAGENET_REAL_ACCURACY:
        collate_fn = _pad_and_collate
    else:
        collate_fn = None

    eval_transform = make_classification_eval_transform()
    test_dataset = make_dataset(
        dataset_str=test_dataset_str,
        transform=eval_transform,
    )
    return make_data_loader(
        dataset=test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
        persistent_workers=False,
        collate_fn=collate_fn,
    )
|
||||
|
||||
|
||||
def test_on_datasets(
    feature_model,
    linear_classifiers,
    test_dataset_strs,
    batch_size,
    num_workers,
    test_metric_types,
    metrics_file_path,
    training_num_classes,
    iteration,
    best_classifier_on_val,
    prefixstring="",
    test_class_mappings=None,
):
    """Evaluate the linear classifiers on each test dataset.

    Args:
        feature_model: Model producing the features fed to the classifiers.
        linear_classifiers: Classifiers to evaluate (a DDP wrapper is removed).
        test_dataset_strs: Dataset strings of the test sets.
        batch_size: Batch size of the evaluation data loaders.
        num_workers: Number of data loading workers.
        test_metric_types: Metric type of each test set.
        metrics_file_path: File the evaluation results are appended to.
        training_num_classes: Number of classes seen during training.
        iteration: Training iteration recorded with the results.
        best_classifier_on_val: Name of the classifier whose accuracy is reported.
        prefixstring: Prefix of the per-classifier log lines.
        test_class_mappings: Optional class mapping of each test set; defaults
            to `[None]` (a single test set without mapping).

    Returns:
        A dict mapping `"<dataset>_accuracy"` to the accuracy in percent.
    """
    # `None` sentinel instead of a mutable list shared between calls
    if test_class_mappings is None:
        test_class_mappings = [None]

    results_dict = {}
    for test_dataset_str, class_mapping, metric_type in zip(test_dataset_strs, test_class_mappings, test_metric_types):
        logger.info(f"Testing on {test_dataset_str}")
        test_data_loader = make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type)
        dataset_results_dict = evaluate_linear_classifiers(
            feature_model,
            remove_ddp_wrapper(linear_classifiers),
            test_data_loader,
            metric_type,
            metrics_file_path,
            training_num_classes,
            iteration,
            # Forward the caller's prefix instead of discarding it
            prefixstring=prefixstring,
            class_mapping=class_mapping,
            best_classifier_on_val=best_classifier_on_val,
        )
        results_dict[f"{test_dataset_str}_accuracy"] = 100.0 * dataset_results_dict["best_classifier"]["accuracy"]
    return results_dict
|
||||
|
||||
|
||||
def _load_class_mapping(class_mapping_fpath):
    """Load a class mapping with ``np.load``; return None when no path is given.

    Both ``None`` and the literal string ``"None"`` mean "no mapping".
    """
    if class_mapping_fpath is None or class_mapping_fpath == "None":
        return None
    logger.info(f"Using class mapping from {class_mapping_fpath}")
    return np.load(class_mapping_fpath)


def run_eval_linear(
    model,
    output_dir,
    train_dataset_str,
    val_dataset_str,
    batch_size,
    epochs,
    epoch_length,
    num_workers,
    save_checkpoint_frequency,
    eval_period_iterations,
    learning_rates,
    autocast_dtype,
    test_dataset_strs=None,
    resume=True,
    classifier_fpath=None,
    val_class_mapping_fpath=None,
    test_class_mapping_fpaths=None,
    val_metric_type=MetricType.MEAN_ACCURACY,
    test_metric_types=None,
):
    """Train linear classifiers on frozen features and evaluate them.

    Args:
        model: Backbone wrapped in ``ModelWithIntermediateLayers``.
        output_dir: Directory for checkpoints and ``results_eval_linear.json``.
        train_dataset_str: Training dataset specification string.
        val_dataset_str: Validation dataset specification string.
        batch_size: Batch size for the train and eval loaders.
        epochs: Number of epochs; ``max_iter = epochs * epoch_length``.
        epoch_length: Iterations per epoch.
        num_workers: Worker count for the train and validation loaders.
        save_checkpoint_frequency: Checkpoint period, in epochs.
        eval_period_iterations: Evaluation period forwarded to ``eval_linear``.
        learning_rates: Learning rates given to ``setup_linear_classifiers``.
        autocast_dtype: dtype used by the CUDA autocast context.
        test_dataset_strs: Test dataset strings; defaults to the validation set.
        resume: Whether to resume from an existing checkpoint.
        classifier_fpath: Optional classifier checkpoint to load.
        val_class_mapping_fpath: Optional class-mapping file for validation.
        test_class_mapping_fpaths: One optional class-mapping file per test
            dataset; defaults to no mapping for every test dataset.
        val_metric_type: Metric used on the validation set.
        test_metric_types: One metric per test dataset; defaults to
            ``val_metric_type`` for each.

    Returns:
        Dict with the best classifier name, the validation accuracy in percent
        and, when test datasets differ from validation, one accuracy per test
        dataset.
    """
    seed = 0

    if test_dataset_strs is None:
        test_dataset_strs = [val_dataset_str]
    if test_metric_types is None:
        test_metric_types = [val_metric_type] * len(test_dataset_strs)
    else:
        assert len(test_metric_types) == len(test_dataset_strs), "expected one metric type per test dataset"
    # A None sentinel replaces the former mutable ``[None]`` default and scales
    # to any number of test datasets.
    if test_class_mapping_fpaths is None:
        test_class_mapping_fpaths = [None] * len(test_dataset_strs)
    assert len(test_dataset_strs) == len(test_class_mapping_fpaths), "expected one class mapping path per test dataset"

    train_transform = make_classification_train_transform()
    train_dataset = make_dataset(
        dataset_str=train_dataset_str,
        transform=train_transform,
    )
    training_num_classes = len(torch.unique(torch.Tensor(train_dataset.get_targets().astype(int))))
    sampler_type = SamplerType.SHARDED_INFINITE

    n_last_blocks_list = [1, 4]
    n_last_blocks = max(n_last_blocks_list)
    autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype)
    feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx)
    # One forward pass on a single sample yields the output that
    # setup_linear_classifiers uses to build the classifiers.
    sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda())

    linear_classifiers, optim_param_groups = setup_linear_classifiers(
        sample_output,
        n_last_blocks_list,
        learning_rates,
        batch_size,
        training_num_classes,
    )

    optimizer = torch.optim.SGD(optim_param_groups, momentum=0.9, weight_decay=0)
    max_iter = epochs * epoch_length
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)
    checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler)
    # Resume after the last saved iteration, or start at 0 when none is stored.
    start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1
    train_data_loader = make_data_loader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        seed=seed,
        sampler_type=sampler_type,
        sampler_advance=start_iter,
        drop_last=True,
        persistent_workers=True,
    )
    val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type)

    checkpoint_period = save_checkpoint_frequency * epoch_length

    # Validation and test mappings share one loader, so "None" is handled the
    # same way for both.
    val_class_mapping = _load_class_mapping(val_class_mapping_fpath)
    test_class_mappings = [_load_class_mapping(fpath) for fpath in test_class_mapping_fpaths]

    metrics_file_path = os.path.join(output_dir, "results_eval_linear.json")
    val_results_dict, feature_model, linear_classifiers, iteration = eval_linear(
        feature_model=feature_model,
        linear_classifiers=linear_classifiers,
        train_data_loader=train_data_loader,
        val_data_loader=val_data_loader,
        metrics_file_path=metrics_file_path,
        optimizer=optimizer,
        scheduler=scheduler,
        output_dir=output_dir,
        max_iter=max_iter,
        checkpoint_period=checkpoint_period,
        running_checkpoint_period=epoch_length,
        eval_period=eval_period_iterations,
        metric_type=val_metric_type,
        training_num_classes=training_num_classes,
        resume=resume,
        val_class_mapping=val_class_mapping,
        classifier_fpath=classifier_fpath,
    )
    results_dict = {}
    # Skip the extra test pass when the only test dataset is the validation set.
    if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str:
        results_dict = test_on_datasets(
            feature_model,
            linear_classifiers,
            test_dataset_strs,
            batch_size,
            0,  # num_workers is fixed to 0 for the test loaders, as in the original
            test_metric_types,
            metrics_file_path,
            training_num_classes,
            iteration,
            val_results_dict["best_classifier"]["name"],
            prefixstring="",
            test_class_mappings=test_class_mappings,
        )
    results_dict["best_classifier"] = val_results_dict["best_classifier"]["name"]
    results_dict[f"{val_dataset_str}_accuracy"] = 100.0 * val_results_dict["best_classifier"]["accuracy"]
    logger.info("Test Results Dict " + str(results_dict))

    return results_dict
|
||||
|
||||
|
||||
def main(args):
    """Build the model from ``args`` and run the linear evaluation.

    Args:
        args: Parsed command-line namespace; it is read by
            ``setup_and_build_model`` and supplies every ``run_eval_linear``
            option below.

    Returns:
        0, used as the process exit code.
    """
    backbone, autocast_dtype = setup_and_build_model(args)
    run_eval_linear(
        model=backbone,
        autocast_dtype=autocast_dtype,
        output_dir=args.output_dir,
        # datasets
        train_dataset_str=args.train_dataset_str,
        val_dataset_str=args.val_dataset_str,
        test_dataset_strs=args.test_dataset_strs,
        # optimisation schedule
        batch_size=args.batch_size,
        epochs=args.epochs,
        epoch_length=args.epoch_length,
        num_workers=args.num_workers,
        learning_rates=args.learning_rates,
        save_checkpoint_frequency=args.save_checkpoint_frequency,
        eval_period_iterations=args.eval_period_iterations,
        # checkpointing: the CLI exposes the negated flag
        resume=not args.no_resume,
        classifier_fpath=args.classifier_fpath,
        # metrics and class mappings
        val_metric_type=args.val_metric_type,
        test_metric_types=args.test_metric_types,
        val_class_mapping_fpath=args.val_class_mapping_fpath,
        test_class_mapping_fpaths=args.test_class_mapping_fpaths,
    )
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the command line and exit with main()'s code.
    args = get_args_parser(description="DINOv2 linear evaluation").parse_args()
    sys.exit(main(args))
|
||||
@@ -0,0 +1,445 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
from cuml.linear_model import LogisticRegression
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.distributed
|
||||
from torch import nn
|
||||
from torch.utils.data import TensorDataset
|
||||
from torchmetrics import MetricTracker
|
||||
|
||||
from dinov2.data import make_dataset
|
||||
from dinov2.data.transforms import make_classification_eval_transform
|
||||
from dinov2.distributed import get_global_rank, get_global_size
|
||||
from dinov2.eval.metrics import MetricType, build_metric
|
||||
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
|
||||
from dinov2.eval.setup import setup_and_build_model
|
||||
from dinov2.eval.utils import evaluate, extract_features
|
||||
from dinov2.utils.dtype import as_torch_dtype
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
DEFAULT_MAX_ITER = 1_000
|
||||
C_POWER_RANGE = torch.linspace(-6, 5, 45)
|
||||
_CPU_DEVICE = torch.device("cpu")
|
||||
|
||||
|
||||
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Create the command-line parser for the logistic regression evaluation.

    The parser inherits the common eval options from the setup parser and adds
    the dataset, metric and training options of this script.

    Args:
        description: Text shown at the top of ``--help``.
        parents: Extra parent parsers chained through the setup parser.
        add_help: Whether to add the ``-h/--help`` option.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    # The setup parser absorbs the caller's parents and is the only direct parent.
    base_parser = get_setup_args_parser(parents=parents or [], add_help=False)
    parser = argparse.ArgumentParser(
        description=description,
        parents=[base_parser],
        add_help=add_help,
    )

    # Dataset selection.
    parser.add_argument("--train-dataset", dest="train_dataset_str", type=str, help="Training dataset")
    parser.add_argument("--val-dataset", dest="val_dataset_str", type=str, help="Validation dataset")
    parser.add_argument(
        "--finetune-dataset-str", dest="finetune_dataset_str", type=str, help="Fine-tuning dataset"
    )
    parser.add_argument(
        "--finetune-on-val",
        action="store_true",
        help="If there is no finetune dataset, whether to choose the "
        "hyperparameters on the val set instead of 10%% of the train dataset",
    )

    # Metric and training options.
    parser.add_argument("--metric-type", type=MetricType, choices=list(MetricType), help="Metric type")
    parser.add_argument(
        "--train-features-device",
        type=str,
        help="Device to gather train features (cpu, cuda, cuda:0, etc.), default: %(default)s",
    )
    parser.add_argument(
        "--train-dtype",
        type=str,
        help="Data type to convert the train features to (default: %(default)s)",
    )
    parser.add_argument(
        "--max-train-iters",
        type=int,
        help="Maximum number of train iterations (default: %(default)s)",
    )

    parser.set_defaults(
        train_dataset_str="ImageNet:split=TRAIN",
        val_dataset_str="ImageNet:split=VAL",
        finetune_dataset_str=None,
        metric_type=MetricType.MEAN_ACCURACY,
        train_features_device="cpu",
        train_dtype="float64",
        max_train_iters=DEFAULT_MAX_ITER,
        finetune_on_val=False,
    )
    return parser
|
||||
|
||||
|
||||
class LogRegModule(nn.Module):
    """``nn.Module`` wrapper around a cuML ``LogisticRegression`` estimator.

    ``fit`` trains the estimator; ``forward`` returns class probabilities in
    the dict format expected by the metric update step.
    """

    def __init__(
        self,
        C,
        max_iter=DEFAULT_MAX_ITER,
        dtype=torch.float64,
        device=_CPU_DEVICE,
    ):
        """Create an L2-penalised estimator.

        Args:
            C: Value passed as the estimator's ``C`` argument.
            max_iter: Maximum number of solver iterations.
            dtype: dtype the inputs are cast to before reaching the estimator.
            device: Device the inputs are moved to before reaching the estimator.
        """
        super().__init__()
        self.dtype = dtype
        self.device = device
        self.estimator = LogisticRegression(
            penalty="l2",
            C=C,
            max_iter=max_iter,
            output_type="numpy",
            tol=1e-12,
            linesearch_max_iter=50,
        )

    def _to_estimator_input(self, tensor):
        """Cast ``tensor`` to the module dtype/device; on CPU hand over a numpy array."""
        converted = tensor.to(dtype=self.dtype, device=self.device)
        if self.device == _CPU_DEVICE:
            # both cuML and sklearn only work with numpy arrays on CPU
            converted = converted.numpy()
        return converted

    def forward(self, samples, targets):
        """Predict class probabilities for ``samples``.

        Returns:
            Dict with ``"preds"`` (probabilities moved back to the device of
            ``samples``) and ``"target"`` (``targets`` unchanged).
        """
        original_device = samples.device
        probabilities = self.estimator.predict_proba(self._to_estimator_input(samples))
        return {"preds": torch.from_numpy(probabilities).to(original_device), "target": targets}

    def fit(self, train_features, train_labels):
        """Fit the estimator on ``train_features`` / ``train_labels``."""
        features = self._to_estimator_input(train_features)
        labels = self._to_estimator_input(train_labels)
        self.estimator.fit(features, labels)
|
||||
|
||||
|
||||
def evaluate_model(*, logreg_model, logreg_metric, test_data_loader, device):
    """Evaluate a fitted logistic regression model on ``test_data_loader``.

    The loader's samples go through ``nn.Identity`` and are then handed to
    ``logreg_model``, which acts as the post-processor producing the inputs of
    ``logreg_metric``.

    Returns:
        Whatever ``evaluate`` returns for the single ``"metrics"`` entry.
    """
    return evaluate(
        nn.Identity(),
        test_data_loader,
        {"metrics": logreg_model},
        {"metrics": logreg_metric},
        device,
    )
|
||||
|
||||
|
||||
def train_for_C(*, C, max_iter, train_features, train_labels, dtype=torch.float64, device=_CPU_DEVICE):
    """Fit and return a ``LogRegModule`` for one value of ``C``.

    Args:
        C: Value of the estimator's ``C`` argument.
        max_iter: Maximum number of solver iterations.
        train_features: Training features.
        train_labels: Training labels.
        dtype: dtype used by the module for its inputs.
        device: Device used by the module for its inputs.
    """
    fitted = LogRegModule(C, max_iter=max_iter, dtype=dtype, device=device)
    fitted.fit(train_features, train_labels)
    return fitted
|
||||
|
||||
|
||||
def train_and_evaluate(
    *,
    C,
    max_iter,
    train_features,
    train_labels,
    logreg_metric,
    test_data_loader,
    train_dtype=torch.float64,
    train_features_device,
    eval_device,
):
    """Fit a logistic regression model for ``C`` and evaluate it.

    Args:
        C: Value of the estimator's ``C`` argument.
        max_iter: Maximum number of solver iterations.
        train_features: Training features.
        train_labels: Training labels.
        logreg_metric: Metric updated during evaluation.
        test_data_loader: Loader the fitted model is evaluated on.
        train_dtype: dtype used for training.
        train_features_device: Device used for training.
        eval_device: Device used by the evaluation loop.

    Returns:
        The result of ``evaluate_model`` for the fitted model.
    """
    fitted_model = train_for_C(
        C=C,
        max_iter=max_iter,
        train_features=train_features,
        train_labels=train_labels,
        dtype=train_dtype,
        device=train_features_device,
    )
    evaluation = evaluate_model(
        logreg_model=fitted_model,
        logreg_metric=logreg_metric,
        test_data_loader=test_data_loader,
        device=eval_device,
    )
    return evaluation
|
||||
|
||||
|
||||
def sweep_C_values(
    *,
    train_features,
    train_labels,
    test_data_loader,
    metric_type,
    num_classes,
    train_dtype=torch.float64,
    train_features_device=_CPU_DEVICE,
    max_train_iters=DEFAULT_MAX_ITER,
):
    """Sweep ``C`` over ``10**C_POWER_RANGE`` and return the best value.

    Each rank fits the candidates at indices ``rank, rank + world_size, ...``;
    the fitted models are exchanged with ``all_gather_object`` and every rank
    then evaluates all of them on ``test_data_loader``.

    Returns:
        Tuple ``(best_stats, best_C)``: the best metric values reported by the
        tracker and the ``C`` whose evaluation step holds the best ``"top-1"``.
    """
    # If we want to output per-class accuracy, we select the hyperparameters
    # with mean per class.
    if metric_type == MetricType.PER_CLASS_ACCURACY:
        metric_type = MetricType.MEAN_PER_CLASS_ACCURACY
    metric_tracker = MetricTracker(build_metric(metric_type, num_classes=num_classes), maximize=True)
    candidate_cs = (10**C_POWER_RANGE).tolist()

    train_features = train_features.to(dtype=train_dtype, device=train_features_device)
    train_labels = train_labels.to(device=train_features_device)

    # Fit this rank's share of the candidates.
    world_size = get_global_size()
    local_models = {}
    for candidate_c in candidate_cs[get_global_rank() :: world_size]:
        logger.info(
            f"Training for C = {candidate_c:.5f}, dtype={train_dtype}, "
            f"features: {train_features.shape}, {train_features.dtype}, "
            f"labels: {train_labels.shape}, {train_labels.dtype}"
        )
        local_models[candidate_c] = train_for_C(
            C=candidate_c,
            max_iter=max_train_iters,
            train_features=train_features,
            train_labels=train_labels,
            dtype=train_dtype,
            device=train_features_device,
        )

    # Collect every rank's fitted models into one dict keyed by C.
    per_rank_models = [None] * world_size
    torch.distributed.all_gather_object(per_rank_models, local_models)
    all_models = {}
    for rank_models in per_rank_models:
        all_models.update(rank_models)

    # One tracker step per candidate; remember the C of the best top-1 step.
    for step, candidate_c in enumerate(candidate_cs):
        metric_tracker.increment()
        evals = evaluate_model(
            logreg_model=all_models[candidate_c],
            logreg_metric=metric_tracker,
            test_data_loader=test_data_loader,
            device=torch.cuda.current_device(),
        )
        logger.info(f"Trained for C = {candidate_c:.5f}, accuracies = {evals}")

        best_stats, which_epoch = metric_tracker.best_metric(return_step=True)
        best_stats_100 = {k: 100.0 * v for k, v in best_stats.items()}
        if which_epoch["top-1"] == step:
            best_C = candidate_c
    logger.info(f"Sweep best {best_stats_100}, best C = {best_C:.6f}")

    return best_stats, best_C
|
||||
|
||||
|
||||
def eval_log_regression(
    *,
    model,
    train_dataset,
    val_dataset,
    finetune_dataset,
    metric_type,
    batch_size,
    num_workers,
    finetune_on_val=False,
    train_dtype=torch.float64,
    train_features_device=_CPU_DEVICE,
    max_train_iters=DEFAULT_MAX_ITER,
):
    """
    Implements the "standard" process for log regression evaluation:
    The value of C is chosen by training on train_dataset and evaluating on
    finetune_dataset. Then, the final model is trained on a concatenation of
    train_dataset and finetune_dataset, and is evaluated on val_dataset.
    If there is no finetune_dataset, the value of C is the one that yields
    the best results on a random 10% subset of the train dataset
    (or on the val set when finetune_on_val is set, in which case nothing is
    concatenated before the final training).

    Returns the metric values of the final model on val_dataset, with the
    selected C stored under "best_C".
    """
    start = time.time()
    gather_on_cpu = train_features_device == _CPU_DEVICE

    train_features, train_labels = extract_features(
        model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
    )
    val_features, val_labels = extract_features(
        model, val_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
    )
    val_data_loader = torch.utils.data.DataLoader(
        TensorDataset(val_features, val_labels),
        batch_size=batch_size,
        drop_last=False,
        num_workers=0,
        persistent_workers=False,
    )

    # Pick the features used to select the hyperparameter.
    if finetune_dataset is not None:
        logger.info("Choosing hyperparameters on the finetune dataset")
        finetune_features, finetune_labels = extract_features(
            model, finetune_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu
        )
    elif finetune_on_val:
        logger.info("Choosing hyperparameters on the val dataset")
        finetune_features, finetune_labels = val_features, val_labels
    else:
        logger.info("Choosing hyperparameters on 10% of the train dataset")
        # Fixed seed so the 10% / 90% split is reproducible.
        torch.manual_seed(0)
        permutation = torch.randperm(len(train_features), device=train_features.device)
        holdout_size = len(train_features) // 10
        finetune_index = permutation[:holdout_size]
        train_index = permutation[holdout_size:]
        finetune_features, finetune_labels = train_features[finetune_index], train_labels[finetune_index]
        train_features, train_labels = train_features[train_index], train_labels[train_index]

    # release the model - free GPU memory
    del model
    gc.collect()
    torch.cuda.empty_cache()

    finetune_data_loader = torch.utils.data.DataLoader(
        TensorDataset(finetune_features, finetune_labels),
        batch_size=batch_size,
        drop_last=False,
    )

    # Labels with more than one dimension carry one column per class; otherwise
    # the class count is derived from the largest label.
    num_classes = train_labels.shape[1] if len(train_labels.shape) > 1 else train_labels.max() + 1

    logger.info("Using cuML for logistic regression")

    best_stats, best_C = sweep_C_values(
        train_features=train_features,
        train_labels=train_labels,
        test_data_loader=finetune_data_loader,
        metric_type=metric_type,
        num_classes=num_classes,
        train_dtype=train_dtype,
        train_features_device=train_features_device,
        max_train_iters=max_train_iters,
    )

    if not finetune_on_val:
        logger.info("Best parameter found, concatenating features")
        train_features = torch.cat((train_features, finetune_features))
        train_labels = torch.cat((train_labels, finetune_labels))

    logger.info("Training final model")
    final_metric = build_metric(metric_type, num_classes=num_classes).clone()
    evals = train_and_evaluate(
        C=best_C,
        max_iter=max_train_iters,
        train_features=train_features,
        train_labels=train_labels,
        logreg_metric=final_metric,
        test_data_loader=val_data_loader,
        eval_device=torch.cuda.current_device(),
        train_dtype=train_dtype,
        train_features_device=train_features_device,
    )

    # evaluate() returns (logger stats, metric stats); keep the metric stats.
    best_stats = evals[1]["metrics"]
    best_stats["best_C"] = best_C

    logger.info(f"Log regression evaluation done in {int(time.time() - start)}s")
    return best_stats
|
||||
|
||||
|
||||
def eval_log_regression_with_model(
    model,
    train_dataset_str="ImageNet:split=TRAIN",
    val_dataset_str="ImageNet:split=VAL",
    finetune_dataset_str=None,
    autocast_dtype=torch.float,
    finetune_on_val=False,
    metric_type=MetricType.MEAN_ACCURACY,
    train_dtype=torch.float64,
    train_features_device=_CPU_DEVICE,
    max_train_iters=DEFAULT_MAX_ITER,
):
    """Build the datasets and run the logistic regression evaluation.

    The evaluation runs under CUDA autocast with ``autocast_dtype``, with a
    fixed batch size of 256 and no loader workers. All ranks meet at a
    ``torch.distributed.barrier()`` before returning.

    Returns:
        Dict with ``"top-1"`` and ``"top-5"`` accuracies in percent (top-5 is
        0 when the metric does not report it) and the selected ``"best_C"``.
    """
    cudnn.benchmark = True

    def build_dataset(dataset_str):
        # Every split uses the same eval transform and no target transform.
        return make_dataset(
            dataset_str=dataset_str,
            transform=eval_transform,
            target_transform=None,
        )

    eval_transform = make_classification_eval_transform(resize_size=224)
    train_dataset = build_dataset(train_dataset_str)
    val_dataset = build_dataset(val_dataset_str)
    finetune_dataset = None if finetune_dataset_str is None else build_dataset(finetune_dataset_str)

    with torch.cuda.amp.autocast(dtype=autocast_dtype):
        results_dict_logreg = eval_log_regression(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            finetune_dataset=finetune_dataset,
            metric_type=metric_type,
            batch_size=256,
            num_workers=0,
            finetune_on_val=finetune_on_val,
            train_dtype=train_dtype,
            train_features_device=train_features_device,
            max_train_iters=max_train_iters,
        )

    top1 = results_dict_logreg["top-1"].cpu().numpy() * 100.0
    top5 = results_dict_logreg.get("top-5", torch.tensor(0.0)).cpu().numpy() * 100.0
    results_dict = {
        "top-1": top1,
        "top-5": top5,
        "best_C": results_dict_logreg["best_C"],
    }

    summary_lines = [
        "Training of the supervised logistic regression on frozen features completed.\n"
        "Top-1 test accuracy: {acc:.1f}".format(acc=results_dict["top-1"]),
        "Top-5 test accuracy: {acc:.1f}".format(acc=results_dict["top-5"]),
        "obtained for C = {c:.6f}".format(c=results_dict["best_C"]),
    ]
    logger.info("\n".join(summary_lines))

    torch.distributed.barrier()
    return results_dict
|
||||
|
||||
|
||||
def main(args):
    """Build the model from ``args`` and run the logistic regression evaluation.

    Args:
        args: Parsed command-line namespace; it is read by
            ``setup_and_build_model`` and supplies the options below.

    Returns:
        0, used as the process exit code.
    """
    backbone, autocast_dtype = setup_and_build_model(args)
    # The CLI carries dtype and device as strings; convert them here.
    train_dtype = as_torch_dtype(args.train_dtype)
    train_features_device = torch.device(args.train_features_device)
    eval_log_regression_with_model(
        model=backbone,
        train_dataset_str=args.train_dataset_str,
        val_dataset_str=args.val_dataset_str,
        finetune_dataset_str=args.finetune_dataset_str,
        autocast_dtype=autocast_dtype,
        finetune_on_val=args.finetune_on_val,
        metric_type=args.metric_type,
        train_dtype=train_dtype,
        train_features_device=train_features_device,
        max_train_iters=args.max_train_iters,
    )
    return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the command line and exit with main()'s code.
    args = get_args_parser(description="DINOv2 logistic regression evaluation").parse_args()
    sys.exit(main(args))
|
||||
@@ -0,0 +1,114 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torchmetrics import Metric, MetricCollection
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
from torchmetrics.utilities.data import dim_zero_cat, select_topk
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class MetricType(Enum):
    """Evaluation metrics selectable by name; ``str()`` gives the raw value."""

    MEAN_ACCURACY = "mean_accuracy"
    MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy"
    PER_CLASS_ACCURACY = "per_class_accuracy"
    IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy"

    @property
    def accuracy_averaging(self):
        """Return the ``AccuracyAveraging`` member with the same name, or None."""
        return AccuracyAveraging.__members__.get(self.name)

    def __str__(self):
        # Print the plain value so the member reads naturally in CLI help text.
        return f"{self.value}"
|
||||
|
||||
|
||||
class AccuracyAveraging(Enum):
    """Averaging modes handed to ``MulticlassAccuracy`` as ``average``.

    Member names mirror the ``MetricType`` members they correspond to.
    """

    MEAN_ACCURACY = "micro"
    MEAN_PER_CLASS_ACCURACY = "macro"
    PER_CLASS_ACCURACY = "none"

    def __str__(self):
        # Print the plain value rather than "AccuracyAveraging.<NAME>".
        return f"{self.value}"
|
||||
|
||||
|
||||
def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None):
    """Build the top-k metric collection for ``metric_type``.

    Args:
        metric_type: Requested metric.
        num_classes: Number of classes handed to the metric builders.
        ks: Values of k to report; ``(1, 5)`` when None.

    Returns:
        The collection built by the matching ``build_topk_*`` helper.

    Raises:
        ValueError: If ``metric_type`` has no averaging mode and is not
            ``IMAGENET_REAL_ACCURACY``.
    """
    top_ks = (1, 5) if ks is None else ks

    # Metric types with an averaging mode map onto plain multiclass accuracy.
    averaging = metric_type.accuracy_averaging
    if averaging is not None:
        return build_topk_accuracy_metric(
            average_type=averaging,
            num_classes=num_classes,
            ks=top_ks,
        )

    if metric_type == MetricType.IMAGENET_REAL_ACCURACY:
        return build_topk_imagenet_real_accuracy_metric(
            num_classes=num_classes,
            ks=top_ks,
        )

    raise ValueError(f"Unknown metric type {metric_type}")
|
||||
|
||||
|
||||
def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)):
    """Build a ``MetricCollection`` with one ``MulticlassAccuracy`` per k.

    Args:
        average_type: Averaging mode; its value is passed as ``average``.
        num_classes: Number of classes (converted with ``int``).
        ks: Values of k; each yields an entry named ``"top-<k>"``.
    """
    metrics: Dict[str, Metric] = {}
    for k in ks:
        metrics[f"top-{k}"] = MulticlassAccuracy(
            top_k=k,
            num_classes=int(num_classes),
            average=average_type.value,
        )
    return MetricCollection(metrics)
|
||||
|
||||
|
||||
def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)):
    """Build a ``MetricCollection`` with one ``ImageNetReaLAccuracy`` per k.

    Args:
        num_classes: Number of classes (converted with ``int``).
        ks: Values of k; each yields an entry named ``"top-<k>"``.
    """
    metrics: Dict[str, Metric] = {}
    for k in ks:
        metrics[f"top-{k}"] = ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes))
    return MetricCollection(metrics)
|
||||
|
||||
|
||||
class ImageNetReaLAccuracy(Metric):
    """Top-k accuracy against multi-label targets where ``-1`` marks "no label".

    A sample counts as correct when at least one of its top-k predictions is
    among its targets; samples without any defined target are ignored.
    """

    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False

    def __init__(
        self,
        num_classes: int,
        top_k: int = 1,
        **kwargs: Any,
    ) -> None:
        """Create the metric.

        Args:
            num_classes: Number of classes; also used as the placeholder index
                for undefined (``-1``) targets.
            top_k: Number of highest-scoring predictions considered per sample.
            **kwargs: Forwarded to ``Metric.__init__``.
        """
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.top_k = top_k
        # Per-batch hit tensors, concatenated across processes on sync.
        self.add_state("tp", [], dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Accumulate the per-sample hits of one batch.

        Args:
            preds: Scores of shape [B, D].
            target: Class indices of shape [B, A], ``-1`` where undefined.
                The tensor is left unmodified.
        """
        # preds [B, D]
        # target [B, A]
        # preds_oh [B, D] with 0 and 1
        # select top K highest probabilities, use one hot representation
        preds_oh = select_topk(preds, self.top_k)
        # target_oh [B, D + 1] with 0 and 1
        target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32)
        target = target.long()
        # for undefined targets (-1) use a fake value `num_classes`.
        # masked_fill returns a new tensor: `long()` hands back the caller's
        # own tensor when it is already int64, so an in-place write here would
        # silently modify the caller's targets.
        target = target.masked_fill(target == -1, self.num_classes)
        # fill targets, use one hot representation
        target_oh.scatter_(1, target, 1)
        # target_oh [B, D] (remove the fake target at index `num_classes`)
        target_oh = target_oh[:, :-1]
        # tp [B] with 0 and 1
        tp = (preds_oh * target_oh == 1).sum(dim=1)
        # at least one match between prediction and target
        tp.clip_(max=1)
        # ignore instances where no targets are defined
        mask = target_oh.sum(dim=1) > 0
        tp = tp[mask]
        self.tp.append(tp)  # type: ignore

    def compute(self) -> Tensor:
        """Return the fraction of counted samples with at least one hit."""
        tp = dim_zero_cat(self.tp)  # type: ignore
        return tp.float().mean()
|
||||
@@ -0,0 +1,76 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
|
||||
from dinov2.models import build_model_from_cfg
|
||||
from dinov2.utils.config import setup
|
||||
import dinov2.utils.utils as dinov2_utils
|
||||
|
||||
|
||||
def get_args_parser(
    description: Optional[str] = None,
    parents: Optional[List[argparse.ArgumentParser]] = None,
    add_help: bool = True,
):
    """Create the parser holding the options shared by the evaluation scripts.

    Args:
        description: Text shown at the top of ``--help``.
        parents: Parent parsers whose options are inherited.
        add_help: Whether to add the ``-h/--help`` option.

    Returns:
        Parser with ``--config-file``, ``--pretrained-weights``,
        ``--output-dir`` and ``--opts``.
    """
    parent_parsers = parents or []
    parser = argparse.ArgumentParser(description=description, parents=parent_parsers, add_help=add_help)

    # Plain string options without an explicit default.
    string_options = (
        ("--config-file", "Model configuration file"),
        ("--pretrained-weights", "Pretrained model weights"),
    )
    for flag, help_text in string_options:
        parser.add_argument(flag, type=str, help=help_text)

    parser.add_argument(
        "--output-dir",
        default="",
        type=str,
        help="Output directory to write results and logs",
    )
    # Free-form overrides: one or more values after the flag.
    parser.add_argument(
        "--opts",
        help="Extra configuration options",
        default=[],
        nargs="+",
    )
    return parser
|
||||
|
||||
|
||||
def get_autocast_dtype(config):
    """Map the teacher backbone's mixed-precision setting to a torch dtype.

    Args:
        config: Config object exposing
            ``compute_precision.teacher.backbone.mixed_precision.param_dtype``.

    Returns:
        ``torch.half`` for ``"fp16"``, ``torch.bfloat16`` for ``"bf16"`` and
        ``torch.float`` for any other value.
    """
    precision = config.compute_precision.teacher.backbone.mixed_precision.param_dtype
    if precision == "fp16":
        return torch.half
    # Anything that is not a recognised reduced precision falls back to float32.
    return torch.bfloat16 if precision == "bf16" else torch.float
|
||||
|
||||
|
||||
def build_model_for_eval(config, pretrained_weights):
    """Build the teacher model, load its weights and prepare it for evaluation.

    Args:
        config: Model configuration given to ``build_model_from_cfg``.
        pretrained_weights: Weights location given to ``load_pretrained_weights``.

    Returns:
        The model in eval mode, moved to CUDA.
    """
    # Only the first element of the returned pair (the model) is needed here.
    eval_model = build_model_from_cfg(config, only_teacher=True)[0]
    dinov2_utils.load_pretrained_weights(eval_model, pretrained_weights, "teacher")
    eval_model.eval()
    eval_model.cuda()
    return eval_model
|
||||
|
||||
|
||||
def setup_and_build_model(args) -> Tuple[Any, torch.dtype]:
    """Run the config setup and build the evaluation model.

    Enables ``cudnn.benchmark`` as a side effect.

    Args:
        args: Parsed command-line namespace; ``args.pretrained_weights`` is
            used to load the model weights.

    Returns:
        Tuple ``(model, autocast_dtype)``.
    """
    cudnn.benchmark = True
    cfg = setup(args)
    eval_model = build_model_for_eval(cfg, args.pretrained_weights)
    return eval_model, get_autocast_dtype(cfg)
|
||||
@@ -0,0 +1,147 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchmetrics import MetricCollection
|
||||
|
||||
from dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader
|
||||
import dinov2.distributed as distributed
|
||||
from dinov2.logging import MetricLogger
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
class ModelWithNormalize(torch.nn.Module):
    """Wrap a model so its output is L2-normalised along dimension 1."""

    def __init__(self, model):
        """Store the wrapped ``model``."""
        super().__init__()
        self.model = model

    def forward(self, samples):
        """Return ``model(samples)`` scaled to unit L2 norm along dim 1."""
        raw_output = self.model(samples)
        return nn.functional.normalize(raw_output, p=2, dim=1)
|
||||
|
||||
|
||||
class ModelWithIntermediateLayers(nn.Module):
    """Expose the last ``n_last_blocks`` intermediate outputs of a feature model.

    The wrapped model is switched to eval mode on construction and is always
    run in inference mode inside the supplied autocast context.
    """

    def __init__(self, feature_model, n_last_blocks, autocast_ctx):
        """Store the wrapped model and its run settings.

        Args:
            feature_model: Model providing ``get_intermediate_layers``.
            n_last_blocks: Number of last blocks to request.
            autocast_ctx: Zero-argument callable returning a context manager.
        """
        super().__init__()
        feature_model.eval()
        self.feature_model = feature_model
        self.n_last_blocks = n_last_blocks
        self.autocast_ctx = autocast_ctx

    def forward(self, images):
        """Return ``get_intermediate_layers(images, n_last_blocks, return_class_token=True)``."""
        with torch.inference_mode(), self.autocast_ctx():
            return self.feature_model.get_intermediate_layers(
                images, self.n_last_blocks, return_class_token=True
            )
|
||||
|
||||
|
||||
@torch.inference_mode()
def evaluate(
    model: nn.Module,
    data_loader,
    postprocessors: Dict[str, nn.Module],
    metrics: Dict[str, MetricCollection],
    device: torch.device,
    criterion: Optional[nn.Module] = None,
):
    """Run ``model`` over ``data_loader`` and compute every metric.

    Args:
        model: Model to evaluate; switched to eval mode.
        data_loader: Iterable yielding ``(samples, targets, ...)`` batches.
        postprocessors: Per-metric modules turning ``(outputs, targets)`` into
            the keyword arguments of the metric's ``update``; keyed like
            ``metrics``.
        metrics: Metrics to update, keyed by name.
        device: Device for the samples, targets and metrics.
        criterion: Optional loss module; when given, the loss is logged.

    Returns:
        Tuple ``(metric_logger_stats, stats)``: the global averages of the
        logger's meters and the computed value of each metric.
    """
    model.eval()
    if criterion is not None:
        criterion.eval()

    for metric in metrics.values():
        metric.to(device)

    metric_logger = MetricLogger(delimiter=" ")
    batches = metric_logger.log_every(data_loader, 10, "Test:")

    # Extra batch elements beyond (samples, targets) are ignored.
    for samples, targets, *_ in batches:
        outputs = model(samples.to(device))
        targets = targets.to(device)

        if criterion is not None:
            batch_loss = criterion(outputs, targets)
            metric_logger.update(loss=batch_loss.item())

        for name, metric in metrics.items():
            metric.update(**postprocessors[name](outputs, targets))

    metric_logger.synchronize_between_processes()
    logger.info(f"Averaged stats: {metric_logger}")

    stats = {name: metric.compute() for name, metric in metrics.items()}
    metric_logger_stats = {name: meter.global_avg for name, meter in metric_logger.meters.items()}
    return metric_logger_stats, stats
|
||||
|
||||
|
||||
def all_gather_and_flatten(tensor_rank):
    """All-gather ``tensor_rank`` from every rank and merge the rank dimension.

    Args:
        tensor_rank: This rank's tensor.

    Returns:
        Tensor holding every rank's contribution, with the leading rank
        dimension flattened into the tensor's first dimension.
    """
    world_size = distributed.get_global_size()
    gathered = torch.empty(
        (world_size, *tensor_rank.shape),
        dtype=tensor_rank.dtype,
        device=tensor_rank.device,
    )
    # The per-rank slices are views, so all_gather fills `gathered` in place.
    per_rank_views = list(gathered.unbind(0))
    torch.distributed.all_gather(per_rank_views, tensor_rank.contiguous())
    return gathered.flatten(end_dim=1)
|
||||
|
||||
|
||||
def extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=False):
    """Extract features for every sample of ``dataset``.

    The dataset is wrapped in ``DatasetWithEnumeratedTargets`` and served by a
    non-shuffled, non-dropping loader using the distributed sampler; the
    actual extraction is delegated to ``extract_features_with_dataloader``,
    whose result is returned unchanged.
    """
    enumerated = DatasetWithEnumeratedTargets(dataset)
    loader_options = dict(
        dataset=enumerated,
        batch_size=batch_size,
        num_workers=num_workers,
        sampler_type=SamplerType.DISTRIBUTED,
        drop_last=False,
        shuffle=False,
    )
    # The sample count is taken from the wrapped dataset before the loader
    # is built, as in the original statement order.
    total = len(enumerated)
    loader = make_data_loader(**loader_options)
    return extract_features_with_dataloader(model, loader, total, gather_on_cpu)
|
||||
|
||||
|
||||
@torch.inference_mode()
def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False):
    """Compute features for every sample served by ``data_loader``.

    Each batch is moved to CUDA and passed through ``model``; the resulting
    features, labels and sample indices are gathered from all ranks with
    ``all_gather_and_flatten`` and written into full-size buffers at the
    positions given by the gathered indices.

    Args:
        model: Callable producing one feature tensor per batch.
        data_loader: Iterable yielding ``(samples, (index, labels))`` batches.
            # NOTE(review): presumably built over DatasetWithEnumeratedTargets
            # so that ``index`` is the sample's position — confirm with caller.
        sample_count: Number of rows to allocate in the output buffers.
        gather_on_cpu: Keep the output buffers on CPU instead of CUDA.

    Returns:
        ``(features, all_labels)`` with ``sample_count`` rows each.

    Raises:
        RuntimeError: If ``data_loader`` yields no batch, so that no buffer
            could be allocated.
        AssertionError: If any label entry is still ``-1`` (or lower) after
            the loop.
    """
    gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda")
    metric_logger = MetricLogger(delimiter=" ")
    features, all_labels = None, None

    for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10):
        samples = samples.cuda(non_blocking=True)
        labels_rank = labels_rank.cuda(non_blocking=True)
        index = index.cuda(non_blocking=True)
        features_rank = model(samples).float()

        # Buffers are allocated lazily: the feature width and label shape are
        # only known once the first batch has been processed.
        if features is None:
            features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device)
            labels_shape = list(labels_rank.shape)
            labels_shape[0] = sample_count
            # Match the incoming label dtype: index_copy_ below requires the
            # destination and source dtypes to agree, and relying on the
            # dtype inferred from the integer fill value breaks for labels
            # that are not int64.
            all_labels = torch.full(
                labels_shape,
                fill_value=-1,
                dtype=labels_rank.dtype,
                device=gather_device,
            )
            logger.info(f"Storing features into tensor of shape {features.shape}")

        # share indexes, features and labels between processes
        index_all = all_gather_and_flatten(index).to(gather_device)
        features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device)
        labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device)

        # update storage feature matrix
        if len(index_all) > 0:
            features.index_copy_(0, index_all, features_all_ranks)
            all_labels.index_copy_(0, index_all, labels_all_ranks)

    if features is None:
        # Without this guard the code below fails on ``None.shape`` with an
        # error that does not point at the real cause.
        raise RuntimeError("extract_features_with_dataloader: data_loader yielded no batches")

    logger.info(f"Features shape: {tuple(features.shape)}")
    logger.info(f"Labels shape: {tuple(all_labels.shape)}")

    # -1 is the fill value of the label buffer, so this checks that every
    # entry has been overwritten (labels are assumed to be non-negative).
    assert torch.all(all_labels > -1)

    return features, all_labels
|
||||
Reference in New Issue
Block a user