Initial media depth project backup
This commit is contained in:
@@ -0,0 +1,158 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import dinov2.distributed as distributed
|
||||
from functools import partial
|
||||
from fvcore.common.checkpoint import Checkpointer
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp import ShardingStrategy
|
||||
from torch.distributed.fsdp import MixedPrecision
|
||||
from torch.distributed.fsdp import StateDictType
|
||||
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
|
||||
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
|
||||
from torch.distributed.fsdp._runtime_utils import _reshard
|
||||
|
||||
|
||||
def get_fsdp_wrapper(model_cfg, modules_to_wrap=None):
    """Build a partially-applied FSDP constructor from a model config.

    Args:
        model_cfg: config object exposing ``sharding_strategy`` (one of
            ``"NO_SHARD"``, ``"SHARD_GRAD_OP"``, ``"FULL_SHARD"``) and
            ``mixed_precision`` with the fields ``param_dtype``,
            ``reduce_dtype`` and ``buffer_dtype`` (each one of ``"fp32"``,
            ``"fp16"``, ``"bf16"``).
        modules_to_wrap: module classes handed to ``ModuleWrapPolicy`` as the
            auto-wrap policy. ``None`` (the default) means an empty set.

    Returns:
        A ``functools.partial`` of ``FSDP`` with the sharding strategy, mixed
        precision settings, device id and wrap policy already bound.

    Raises:
        KeyError: if the config names a sharding strategy or dtype that is
            not listed in the lookup tables below.
    """
    # A fresh set per call avoids sharing one mutable default object between
    # every invocation of this function.
    if modules_to_wrap is None:
        modules_to_wrap = set()

    sharding_strategy_dict = {
        "NO_SHARD": ShardingStrategy.NO_SHARD,
        "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
        "FULL_SHARD": ShardingStrategy.FULL_SHARD,
    }

    dtype_dict = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    mixed_precision_config = MixedPrecision(
        param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype],
        reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype],
        buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype],
    )

    sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy]

    # The local rank is passed to FSDP as the device id.
    local_rank = distributed.get_local_rank()

    fsdp_wrapper = partial(
        FSDP,
        sharding_strategy=sharding_strategy_config,
        mixed_precision=mixed_precision_config,
        device_id=local_rank,
        sync_module_states=True,
        use_orig_params=True,
        auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap),
    )
    return fsdp_wrapper
|
||||
|
||||
|
||||
def is_fsdp(x):
    """Tell whether ``x`` is an instance of the ``FSDP`` wrapper class."""
    wrapped = isinstance(x, FSDP)
    return wrapped
|
||||
|
||||
|
||||
def is_sharded_fsdp(x):
    """Tell whether ``x`` is an FSDP module whose strategy is not ``NO_SHARD``."""
    if not is_fsdp(x):
        return False
    strategy = x.sharding_strategy
    return strategy is not ShardingStrategy.NO_SHARD
|
||||
|
||||
|
||||
def free_if_fsdp(x):
    """Call ``_reshard`` on ``x`` when it is a sharded FSDP module.

    Does nothing for any other input.

    NOTE(review): relies on the private ``_handles`` attribute and the private
    ``_reshard`` helper from ``torch.distributed.fsdp._runtime_utils``; both
    may differ between torch versions -- confirm against the pinned version.
    """
    if not is_sharded_fsdp(x):
        return
    flat_param_handles = x._handles
    # One True flag per handle; presumably requests freeing each unsharded
    # parameter -- verify against the ``_reshard`` signature in use.
    free_flags = [True for _ in flat_param_handles]
    _reshard(x, flat_param_handles, free_flags)
|
||||
|
||||
|
||||
def get_fsdp_modules(x):
    """Return whatever ``FSDP.fsdp_modules`` reports for the module ``x``."""
    found = FSDP.fsdp_modules(x)
    return found
|
||||
|
||||
|
||||
def reshard_fsdp_model(x):
    """Apply ``free_if_fsdp`` to every FSDP module found in ``x``."""
    for fsdp_module in get_fsdp_modules(x):
        free_if_fsdp(fsdp_module)
|
||||
|
||||
|
||||
def rankstr():
    """Return the string ``rank_<N>`` where ``<N>`` is the global rank."""
    global_rank = distributed.get_global_rank()
    return f"rank_{global_rank}"
|
||||
|
||||
|
||||
class FSDPCheckpointer(Checkpointer):
    """Checkpointer whose checkpoint and tag files are suffixed per rank.

    The model state is read and restored with FSDP's ``LOCAL_STATE_DICT``
    state-dict type active, and every file name carries ``rankstr()``.
    """

    def _last_checkpoint_path(self) -> str:
        """Return the path of this rank's ``last_checkpoint`` tag file."""
        return os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}")

    def save(self, name: str, **kwargs: Any) -> None:
        """
        Write the model, the checkpointables and extra data to a per-rank file.

        Nothing is written when ``save_dir`` is unset or ``save_to_disk`` is
        false.

        Args:
            name (str): checkpoint name; the file written is
                ``{name}.{rankstr()}.pth`` inside ``save_dir``.
            kwargs (dict): extra arbitrary data to save.
        """
        if not (self.save_dir and self.save_to_disk):
            return

        # Capture the model state with the local state-dict type active.
        with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT):
            model_state = self.model.state_dict()

        # Insertion order matches precedence: model, checkpointables, kwargs.
        payload = {"model": model_state}
        payload.update((key, obj.state_dict()) for key, obj in self.checkpointables.items())
        payload.update(kwargs)

        basename = f"{name}.{rankstr()}.pth"
        target = os.path.join(self.save_dir, basename)
        # The joined path must end in exactly the basename built above.
        assert os.path.basename(target) == basename, basename
        self.logger.info("Saving checkpoint to {}".format(target))
        with self.path_manager.open(target, "wb") as stream:
            torch.save(payload, stream)
        self.tag_last_checkpoint(basename)

    def load(self, *args, **kwargs):
        """
        Run the parent ``load`` with the local state-dict type active.

        All arguments are forwarded unchanged and the parent's return value
        is returned as-is.
        """
        local_state_ctx = FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT)
        with local_state_ctx:
            return super().load(*args, **kwargs)

    def has_checkpoint(self) -> bool:
        """
        Returns:
            bool: whether this rank's ``last_checkpoint`` tag file exists in
            the target directory.
        """
        return self.path_manager.exists(self._last_checkpoint_path())

    def get_checkpoint_file(self) -> str:
        """
        Returns:
            str: path of the checkpoint named in this rank's tag file, or an
            empty string when the tag file cannot be read.
        """
        tag_path = self._last_checkpoint_path()
        try:
            with self.path_manager.open(tag_path, "r") as stream:
                newest = stream.read().strip()
        except IOError:
            # The tag file may be missing, e.g. because a separate process
            # has just deleted it.
            return ""
        # pyre-fixme[6]: For 2nd param expected `Union[PathLike[str], str]` but got
        #  `Union[bytes, str]`.
        return os.path.join(self.save_dir, newest)

    def tag_last_checkpoint(self, last_filename_basename: str) -> None:
        """
        Record the given basename in this rank's ``last_checkpoint`` tag file.

        Args:
            last_filename_basename (str): the basename of the last filename.
        """
        # When distributed mode is enabled, synchronise all ranks before the
        # tag file is written.
        if distributed.is_enabled():
            torch.distributed.barrier()
        with self.path_manager.open(self._last_checkpoint_path(), "w") as stream:
            stream.write(last_filename_basename)  # pyre-ignore
|
||||
|
||||
|
||||
# Rebinds the imported name to itself, which is a no-op at runtime.
# NOTE(review): presumably kept so ShardedGradScaler counts as an explicit
# re-export of this module (and the import is not flagged as unused) -- confirm.
ShardedGradScaler = ShardedGradScaler
|
||||
Reference in New Issue
Block a user