Initial media depth project backup
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
@@ -0,0 +1,96 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
class ClusterType(Enum):
    """Clusters known to the helpers in this module.

    Each member's value is the lowercase string form of its name.
    """

    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
|
||||
|
||||
|
||||
def _guess_cluster_type() -> ClusterType:
    """Infer the cluster from ``os.uname()``.

    Returns:
        ``ClusterType.AWS`` when a Linux kernel release ends with ``"-aws"``,
        ``ClusterType.RSC`` when a Linux host name starts with ``"rsc"``,
        and ``ClusterType.FAIR`` in every other case (including non-Linux).
    """
    host = os.uname()
    if host.sysname != "Linux":
        return ClusterType.FAIR

    # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws".
    if host.release.endswith("-aws"):
        return ClusterType.AWS

    # RSC instances run standard kernels; they are recognised by a host name
    # that starts with "rsc".
    if host.nodename.startswith("rsc"):
        return ClusterType.RSC

    return ClusterType.FAIR
|
||||
|
||||
|
||||
def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    """Return ``cluster_type`` unchanged, or a guessed one when it is ``None``.

    Args:
        cluster_type: explicit cluster to use; ``None`` triggers
            ``_guess_cluster_type()``.
    """
    return _guess_cluster_type() if cluster_type is None else cluster_type
|
||||
|
||||
|
||||
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the root checkpoint directory for a cluster.

    Args:
        cluster_type: cluster to look up; resolved through ``get_cluster_type``.

    Returns:
        An absolute ``Path`` rooted at ``/``, or ``None`` when no cluster type
        could be resolved.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    # Directory name (relative to the filesystem root) for each cluster.
    dirname_by_cluster = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    return Path("/", dirname_by_cluster[resolved])
|
||||
|
||||
|
||||
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the per-user checkpoint directory for a cluster.

    The user name is read from the ``USER`` environment variable and appended
    to the path returned by ``get_checkpoint_path``.

    Args:
        cluster_type: cluster to look up; forwarded to ``get_checkpoint_path``.

    Returns:
        ``<checkpoint root>/<USER>``, or ``None`` when ``get_checkpoint_path``
        returns ``None``.

    Raises:
        AssertionError: if the ``USER`` environment variable is not set.
    """
    checkpoint_path = get_checkpoint_path(cluster_type)
    if checkpoint_path is None:
        return None

    username = os.environ.get("USER")
    if username is None:
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped under `python -O`; the exception type is kept for callers.
        raise AssertionError("USER environment variable is not set")
    return checkpoint_path / username
|
||||
|
||||
|
||||
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the SLURM partition name used on a cluster.

    Args:
        cluster_type: cluster to look up; resolved through ``get_cluster_type``.

    Returns:
        The partition name, or ``None`` when no cluster type could be resolved.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    partition_by_cluster = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return partition_by_cluster[resolved]
|
||||
|
||||
|
||||
def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    """Build the SLURM executor parameter dictionary for a job.

    Args:
        nodes: number of nodes to request.
        num_gpus_per_node: GPUs per node; also used as the number of tasks per
            node (one task per GPU).
        cluster_type: target cluster; resolved through ``get_cluster_type``
            when ``None``.
        **kwargs: extra parameters, applied last so they override both the
            defaults and the cluster-specific adjustments.

    Returns:
        A new dictionary of executor parameters.
    """
    # Resolve the cluster once so the partition lookup and the adjustments
    # below are guaranteed to use the same value.
    cluster_type = get_cluster_type(cluster_type)

    # Default parameters.
    params: Dict[str, Any] = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }

    # Cluster-specific adjustments.
    if cluster_type == ClusterType.AWS:
        params["cpus_per_task"] = 12
        del params["mem_gb"]
    elif cluster_type == ClusterType.RSC:
        params["cpus_per_task"] = 12

    # Additional parameters / caller overrides.
    params.update(kwargs)
    return params
|
||||
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import logging
|
||||
import os
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import dinov2.distributed as distributed
|
||||
from dinov2.logging import setup_logging
|
||||
from dinov2.utils import utils
|
||||
from dinov2.configs import dinov2_default_config
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def apply_scaling_rules_to_cfg(cfg):
    """Derive ``cfg.optim.lr`` from ``cfg.optim.base_lr`` using the configured rule.

    Only the ``"sqrt_wrt_1024"`` rule is implemented: the base learning rate is
    multiplied by ``sqrt(global_batch_size / 1024)``, where the global batch
    size is ``cfg.train.batch_size_per_gpu * distributed.get_global_size()``.

    NOTE(review): the original author tagged this function "to fix" without
    saying what needs fixing.

    Args:
        cfg: configuration exposing ``optim.scaling_rule``, ``optim.base_lr``,
            ``optim.lr`` and ``train.batch_size_per_gpu``; modified in place.

    Returns:
        The same ``cfg`` object, with ``optim.lr`` updated.

    Raises:
        NotImplementedError: if ``cfg.optim.scaling_rule`` is any other value.
    """
    scaling_rule = cfg.optim.scaling_rule
    if scaling_rule != "sqrt_wrt_1024":
        raise NotImplementedError(f"Unsupported optim.scaling_rule: {scaling_rule!r}")

    base_lr = cfg.optim.base_lr
    global_batch_size = cfg.train.batch_size_per_gpu * distributed.get_global_size()
    cfg.optim.lr = base_lr * math.sqrt(global_batch_size / 1024.0)
    logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
    return cfg
|
||||
|
||||
|
||||
def write_config(cfg, output_dir, name="config.yaml"):
    """Log the configuration as YAML and save it under ``output_dir``.

    Args:
        cfg: configuration object accepted by ``OmegaConf.to_yaml`` / ``OmegaConf.save``.
        output_dir: directory receiving the file.
        name: file name inside ``output_dir``.

    Returns:
        The path of the written file.
    """
    logger.info(OmegaConf.to_yaml(cfg))

    target_path = os.path.join(output_dir, name)
    with open(target_path, "w") as stream:
        OmegaConf.save(config=cfg, f=stream)

    return target_path
|
||||
|
||||
|
||||
def get_cfg_from_args(args):
    """Build the run configuration from parsed command-line arguments.

    Side effects on ``args``: ``output_dir`` is replaced by its absolute path
    and a ``train.output_dir=...`` override is appended to ``opts``.

    Merge order (later wins): ``dinov2_default_config``, the file at
    ``args.config_file``, then the dot-list overrides in ``args.opts``.

    Returns:
        The merged OmegaConf configuration.
    """
    args.output_dir = os.path.abspath(args.output_dir)
    args.opts += [f"train.output_dir={args.output_dir}"]

    defaults = OmegaConf.create(dinov2_default_config)
    from_file = OmegaConf.load(args.config_file)
    overrides = OmegaConf.from_cli(args.opts)
    return OmegaConf.merge(defaults, from_file, overrides)
|
||||
|
||||
|
||||
def default_setup(args):
    """Enable distributed mode, configure logging, seed the RNGs and log the run context.

    The seed applied is ``args.seed`` (0 when the attribute is absent) plus the
    global rank. The module-level ``logger`` is rebound after logging has been
    set up.
    """
    global logger

    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_global_rank()

    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    utils.fix_random_seeds(seed + rank)

    logger.info("git:\n {}\n".format(utils.get_sha()))
    # One "name: value" line per argument, sorted by argument name.
    arg_lines = [f"{key!s}: {value!s}" for key, value in sorted(vars(args).items())]
    logger.info("\n".join(arg_lines))
|
||||
|
||||
|
||||
def setup(args):
    """Create the configuration and perform basic setup.

    Steps, in order: build the config from ``args``, create
    ``args.output_dir``, run ``default_setup``, apply the learning-rate
    scaling rule, and write the config into the output directory.

    Returns:
        The configuration object.
    """
    config = get_cfg_from_args(args)

    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)

    # Updates `config` in place; the return value is the same object.
    apply_scaling_rules_to_cfg(config)
    write_config(config, args.output_dir)

    return config
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
# Anything accepted as a dtype specification: a dtype name, a NumPy dtype or a torch dtype.
TypeSpec = Union[str, np.dtype, torch.dtype]


# Each listed NumPy dtype maps to the torch dtype attribute of the same name.
_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype(dtype_name): getattr(torch, dtype_name)
    for dtype_name in (
        "bool",
        "uint8",
        "int8",
        "int16",
        "int32",
        "int64",
        "float16",
        "float32",
        "float64",
        "complex64",
        "complex128",
    )
}
|
||||
|
||||
|
||||
def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    """Convert a dtype specification to a ``torch.dtype``.

    Args:
        dtype: a ``torch.dtype`` (returned unchanged), a dtype name understood
            by ``np.dtype``, or a ``np.dtype``.

    Returns:
        The torch dtype mapped to the NumPy dtype in ``_NUMPY_TO_TORCH_DTYPE``.

    Raises:
        AssertionError: if ``dtype`` is none of the accepted types.
        KeyError: if the NumPy dtype has no entry in ``_NUMPY_TO_TORCH_DTYPE``.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    if not isinstance(dtype, np.dtype):
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped under `python -O`; the exception type is kept for callers.
        raise AssertionError(f"Expected an instance of numpy dtype, got {type(dtype)}")
    return _NUMPY_TO_TORCH_DTYPE[dtype]
|
||||
@@ -0,0 +1,94 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
    """
    Calculate lr decay rate for different ViT blocks.

    The parameter is assigned a layer id: 0 for embedding-like parameters
    (``pos_embed``, ``patch_embed``, ``mask_token``, ``cls_token``), ``i + 1``
    for a parameter of block ``i``, and ``num_layers + 1`` otherwise. The
    result is ``lr_decay_rate ** (num_layers + 1 - layer_id)``.

    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.
        force_is_backbone (bool): treat the parameter as a backbone parameter
            even if its name does not start with ``"backbone"``.
        chunked_blocks (bool): for names without a ``".blocks."`` component,
            read the block index from the second field after ``"blocks."``
            instead of the first.
    Returns:
        lr decay rate for the given parameter.
    """
    embed_markers = ("pos_embed", "patch_embed", "mask_token", "cls_token")
    layer_id = num_layers + 1  # default: exponent 0, i.e. no decay

    if force_is_backbone or name.startswith("backbone"):
        dotted_embed = any("." + marker in name for marker in embed_markers)
        bare_embed = force_is_backbone and any(marker in name for marker in embed_markers)

        if dotted_embed or bare_embed:
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            # ".blocks.<i>..." splits into ["", "blocks", "<i>", ...].
            tail = name[name.find(".blocks.") :]
            layer_id = int(tail.split(".")[2]) + 1
        elif "blocks." in name and "residual." not in name:
            # "blocks.<a>.<b>..." splits into ["blocks", "<a>", "<b>", ...].
            tail = name[name.find("blocks.") :]
            index_field = 2 if chunked_blocks else 1
            layer_id = int(tail.split(".")[index_field]) + 1

    exponent = num_layers + 1 - layer_id
    return lr_decay_rate**exponent
|
||||
|
||||
|
||||
def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
    """Build one optimizer parameter group per trainable parameter of ``model``.

    The number of blocks is read from ``model.n_blocks``, ``model.blocks`` or
    ``model.backbone.blocks`` (first one present, in that order) and is 0 when
    none exists.

    Args:
        model: object exposing ``named_parameters()``.
        lr_decay_rate: base layer-wise decay rate passed to ``get_vit_lr_decay_rate``.
        patch_embed_lr_mult: extra factor applied to the lr multiplier of
            parameters whose name contains ``"patch_embed"``.

    Returns:
        A list of dicts with keys ``params``, ``is_last_layer``,
        ``lr_multiplier``, ``wd_multiplier`` and ``name``.
    """
    chunked_blocks = False
    if hasattr(model, "n_blocks"):
        logger.info("chunked fsdp")
        n_blocks = model.n_blocks
        chunked_blocks = model.chunked_blocks
    elif hasattr(model, "blocks"):
        logger.info("first code branch")
        n_blocks = len(model.blocks)
    elif hasattr(model, "backbone"):
        logger.info("second code branch")
        n_blocks = len(model.backbone.blocks)
    else:
        logger.info("else code branch")
        n_blocks = 0

    all_param_groups = []
    for raw_name, param in model.named_parameters():
        name = raw_name.replace("_fsdp_wrapped_module.", "")
        if not param.requires_grad:
            continue

        lr_multiplier = get_vit_lr_decay_rate(
            name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
        )
        if "patch_embed" in name:
            lr_multiplier = lr_multiplier * patch_embed_lr_mult

        # Biases and parameters with "norm" or "gamma" in their name get no weight decay.
        skip_weight_decay = name.endswith(".bias") or "norm" in name or "gamma" in name
        wd_multiplier = 0.0 if skip_weight_decay else 1.0

        group = {
            "params": param,
            "is_last_layer": "last_layer" in name,
            "lr_multiplier": lr_multiplier,
            "wd_multiplier": wd_multiplier,
            "name": name,
        }
        all_param_groups.append(group)
        logger.info(f"{name}: lr_multiplier: {lr_multiplier}, wd_multiplier: {wd_multiplier}")

    return all_param_groups
|
||||
|
||||
|
||||
def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
    """Merge parameter groups that share the same values for ``keys``.

    Args:
        all_params_groups: iterable of dicts, each holding ``"params"`` and
            every entry of ``keys``.
        keys: names of the entries that must match for groups to be fused.

    Returns:
        A view over the fused groups; each one carries the ``keys`` entries
        and a ``"params"`` list collecting the ``"params"`` of its members.
    """
    fused = {}
    for group in all_params_groups:
        # Groups with identical values for every key produce the same identifier.
        identifier = "".join(key + str(group[key]) + "_" for key in keys)

        bucket = fused.setdefault(identifier, {"params": []})
        bucket.update((key, group[key]) for key in keys)
        bucket["params"].append(group["params"])

    return fused.values()
|
||||
@@ -0,0 +1,96 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
    """Load a checkpoint into ``model`` non-strictly and log the result.

    Args:
        model: object exposing ``load_state_dict``.
        pretrained_weights: checkpoint location; anything with a URL scheme is
            fetched with ``torch.hub.load_state_dict_from_url``, otherwise it
            is read with ``torch.load``. Both load onto the CPU.
        checkpoint_key: optional key selecting a sub-dict of the checkpoint;
            ignored when ``None`` or absent from the checkpoint.
    """
    looks_like_url = bool(urlparse(pretrained_weights).scheme)
    if looks_like_url:
        state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
    else:
        state_dict = torch.load(pretrained_weights, map_location="cpu")

    if checkpoint_key is not None and checkpoint_key in state_dict:
        logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]

    # Strip `module.` and then `backbone.` (the latter induced by the multicrop
    # wrapper) from the keys. `str.replace` drops every occurrence of the
    # fragment, not only a leading one.
    for fragment in ("module.", "backbone."):
        state_dict = {key.replace(fragment, ""): value for key, value in state_dict.items()}

    msg = model.load_state_dict(state_dict, strict=False)
    logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
|
||||
|
||||
|
||||
def fix_random_seeds(seed=31):
    """
    Seed the Python, NumPy and torch (CPU and all CUDA devices) random generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def get_sha():
    """Describe the git state of the directory containing this file.

    Best effort: any failure while running git is swallowed and the fields
    not yet determined keep their defaults (``"N/A"`` / ``"clean"``).

    Returns:
        A string of the form ``"sha: <sha>, status: <status>, branch: <branch>"``.
    """
    repo_dir = os.path.dirname(os.path.abspath(__file__))

    def _git_output(argv):
        raw = subprocess.check_output(argv, cwd=repo_dir)
        return raw.decode("ascii").strip()

    sha, diff, branch = "N/A", "clean", "N/A"
    try:
        sha = _git_output(["git", "rev-parse", "HEAD"])
        # Output is discarded; only a failure of this command matters here.
        subprocess.check_output(["git", "diff"], cwd=repo_dir)
        changed_entries = _git_output(["git", "diff-index", "HEAD"])
        diff = "has uncommitted changes" if changed_entries else "clean"
        branch = _git_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception:
        # Deliberately ignored: the report is informational only.
        pass

    return f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||
|
||||
|
||||
class CosineScheduler(object):
    """Precomputed per-iteration schedule: freeze, linear warmup, then cosine.

    The schedule is the concatenation of ``freeze_iters`` zeros, a linear ramp
    of ``warmup_iters`` values from ``start_warmup_value`` to ``base_value``,
    and a cosine curve from ``base_value`` towards ``final_value`` over the
    remaining iterations. Indexing at or past ``total_iters`` yields
    ``final_value``.
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters

        frozen_part = np.zeros((freeze_iters))
        warmup_part = np.linspace(start_warmup_value, base_value, warmup_iters)

        # Cosine segment over whatever is left after freeze and warmup.
        steps = np.arange(total_iters - warmup_iters - freeze_iters)
        phase = np.pi * steps / len(steps)
        cosine_part = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(phase))

        self.schedule = np.concatenate((frozen_part, warmup_part, cosine_part))

        assert len(self.schedule) == self.total_iters

    def __getitem__(self, it):
        """Return the scheduled value at iteration ``it``."""
        return self.final_value if it >= self.total_iters else self.schedule[it]
|
||||
|
||||
|
||||
def has_batchnorms(model):
    """Return True if any module of ``model`` is a (Sync)BatchNorm layer."""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(module, bn_types) for _, module in model.named_modules())
|
||||
Reference in New Issue
Block a user