Initial media depth project backup
This commit is contained in:
@@ -0,0 +1,271 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import socket
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
# Per-machine rank and per-machine process count of the current process.
# Both stay at -1 (meaning "not set") until enable() stores the real values.
_LOCAL_RANK: int = -1
_LOCAL_WORLD_SIZE: int = -1
|
||||
|
||||
|
||||
def is_enabled() -> bool:
    """
    Returns:
        True if distributed training is enabled, i.e. torch.distributed is
        available and its process group has been initialized.
    """
    # Check availability first: is_initialized() is only consulted when the
    # distributed package is usable at all.
    if not dist.is_available():
        return False
    return dist.is_initialized()
|
||||
|
||||
|
||||
def get_global_size() -> int:
    """
    Returns:
        The number of processes in the process group, or 1 when distributed
        mode is not enabled.
    """
    if not is_enabled():
        return 1
    return dist.get_world_size()
|
||||
|
||||
|
||||
def get_global_rank() -> int:
    """
    Returns:
        The rank of the current process within the global process group, or 0
        when distributed mode is not enabled.
    """
    if not is_enabled():
        return 0
    return dist.get_rank()
|
||||
|
||||
|
||||
def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process
        group, or 0 when distributed mode is not enabled.
    """
    if is_enabled():
        # The module-level state must hold a consistent rank/size pair.
        assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
        return _LOCAL_RANK
    return 0
|
||||
|
||||
|
||||
def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group, i.e. the number of
        processes per machine, or 1 when distributed mode is not enabled.
    """
    if is_enabled():
        # The module-level state must hold a consistent rank/size pair.
        assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
        return _LOCAL_WORLD_SIZE
    return 1
|
||||
|
||||
|
||||
def is_main_process() -> bool:
    """
    Returns:
        True if the current process is the main one, i.e. its global rank is 0.
    """
    global_rank = get_global_rank()
    return global_rank == 0
|
||||
|
||||
|
||||
def _restrict_print_to_main_process() -> None:
    """
    This function disables printing when not in the main process.

    It replaces builtins.print with a wrapper that forwards to the previous
    print only from the main process, or from any process when the caller
    passes force=True.
    """
    import builtins

    original_print = builtins.print

    def print(*args, **kwargs):
        # "force" is consumed here so it never reaches the wrapped print.
        forced = kwargs.pop("force", False)
        if not (is_main_process() or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = print
|
||||
|
||||
|
||||
def _get_master_port(seed: int = 0) -> int:
    """Return the master port to use.

    Args:
        seed: Seed of the random generator used when MASTER_PORT is not set,
            so that the same seed always yields the same port.

    Returns:
        The integer value of the MASTER_PORT environment variable when it is
        set, else a seeded pseudo-random port in [20000, 60000].
    """
    preset_port = os.environ.get("MASTER_PORT")
    if preset_port is not None:
        return int(preset_port)

    lowest_port, highest_port = 20_000, 60_000
    return random.Random(seed).randint(lowest_port, highest_port)
|
||||
|
||||
|
||||
def _get_available_port() -> int:
    """Return a port number picked by the operating system.

    The port comes from binding a temporary TCP socket to port 0; the socket
    is closed before returning.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        # A "" host address means INADDR_ANY i.e. binding to all interfaces.
        # Note this is not compatible with IPv6.
        probe.bind(("", 0))
        _, port = probe.getsockname()
    return port
|
||||
|
||||
|
||||
# Environment variables that together describe a preset torch.distributed
# setup. The order is significant: it is the order in which collected
# variables are listed in error messages.
_TORCH_DISTRIBUTED_ENV_VARS = (
    "MASTER_ADDR", "MASTER_PORT",  # rendezvous endpoint
    "RANK", "WORLD_SIZE",  # global process group
    "LOCAL_RANK", "LOCAL_WORLD_SIZE",  # per-machine process group
)
|
||||
|
||||
|
||||
def _collect_env_vars() -> Dict[str, str]:
    """Return the torch.distributed environment variables that are currently set.

    Returns:
        A mapping from variable name to value holding only the names of
        _TORCH_DISTRIBUTED_ENV_VARS present in os.environ, in that tuple's order.
    """
    found: Dict[str, str] = {}
    for name in _TORCH_DISTRIBUTED_ENV_VARS:
        value = os.environ.get(name)
        if value is not None:
            found[name] = value
    return found
|
||||
|
||||
|
||||
def _is_slurm_job_process() -> bool:
    """Return True when the SLURM_JOB_ID environment variable is set."""
    slurm_job_id = os.environ.get("SLURM_JOB_ID")
    return slurm_job_id is not None
|
||||
|
||||
|
||||
def _parse_slurm_node_list(s: str) -> List[str]:
    """Expand a compact Slurm node list into individual host names.

    Handles comma-separated entries where each entry is either a plain host
    name ("node7") or a prefix followed by a bracketed, comma-separated list
    of suffixes and inclusive ranges ("node[1-2,5,08-10]"). Zero padding of
    the first bound of a range is preserved in the generated names.

    Args:
        s: The node list, e.g. the value of SLURM_JOB_NODELIST.

    Returns:
        The host names in the order they appear in ``s``. An empty string
        yields an empty list.

    Raises:
        ValueError: If a range bound is not an integer.
    """
    nodes: List[str] = []
    # One entry is a prefix free of commas and brackets, optionally followed
    # by "[...]". Excluding "," from the prefix is what splits plain host
    # names such as "node1,node2" into separate entries.
    entry_pattern = re.compile(r"([^,\[\]]+)(?:\[([^\]]+)\])?")
    for m in entry_pattern.finditer(s):
        prefix, suffixes = m.group(1), m.group(2)
        if suffixes is None:
            # Plain host name without a bracket expression.
            nodes.append(prefix)
            continue
        for suffix in suffixes.split(","):
            span = suffix.split("-")
            if len(span) == 1:
                nodes.append(prefix + suffix)
            else:
                # Keep the zero padding of the first bound ("08-10" -> 08, 09, 10).
                width = len(span[0])
                start, end = int(span[0]), int(span[1]) + 1
                nodes.extend(f"{prefix}{i:0{width}}" for i in range(start, end))
    return nodes
|
||||
|
||||
|
||||
def _check_env_variable(key: str, new_value: str):
    """Fail if ``key`` is already set in the environment to a different value.

    Args:
        key: Name of the environment variable about to be exported.
        new_value: Value that would be exported.

    Raises:
        RuntimeError: If ``key`` is present in os.environ with a value other
            than ``new_value``.
    """
    # Only check for difference with preset environment variables
    current_value = os.environ.get(key)
    if current_value is None or current_value == new_value:
        return
    raise RuntimeError(f"Cannot export environment variables as {key} is already set")
|
||||
|
||||
|
||||
class _TorchDistributedEnvironment:
    """Settings required by torch.distributed's env:// initialization method.

    The constructor fills the attributes from the first source that applies:

    1. a Slurm job, when SLURM_JOB_ID is set;
    2. a preset environment, when every variable of
       _TORCH_DISTRIBUTED_ENV_VARS is set;
    3. a local single-process setup, when no such variable is set and at
       least one CUDA device is visible.

    Attributes:
        master_addr: Value exported as MASTER_ADDR.
        master_port: Value exported as MASTER_PORT, stored as an int.
        rank: Value exported as RANK.
        world_size: Value exported as WORLD_SIZE.
        local_rank: Value exported as LOCAL_RANK.
        local_world_size: Value exported as LOCAL_WORLD_SIZE.

    Raises:
        RuntimeError: If only some of the torch.distributed environment
            variables are set, or if none of the sources above applies.
    """

    def __init__(self):
        # Placeholder values, overwritten by one of the _set_from_* methods.
        self.master_addr = "127.0.0.1"
        self.master_port = 0
        self.rank = -1
        self.world_size = -1
        self.local_rank = -1
        self.local_world_size = -1

        if _is_slurm_job_process():
            self._set_from_slurm_env()
            return

        env_vars = _collect_env_vars()
        if env_vars:
            if len(env_vars) != len(_TORCH_DISTRIBUTED_ENV_VARS):
                # Environment is partially set
                collected_env_vars = ", ".join(env_vars.keys())
                raise RuntimeError(f"Partially set environment: {collected_env_vars}")
            # Environment is fully set
            self._set_from_preset_env()
            return

        # Environment is not set
        if torch.cuda.device_count() > 0:
            self._set_from_local()
            return

        raise RuntimeError("Can't initialize PyTorch distributed environment")

    # Slurm job created with sbatch, submitit, etc...
    def _set_from_slurm_env(self):
        """Fill the attributes from the SLURM_* environment variables."""
        job_id = int(os.environ["SLURM_JOB_ID"])
        node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
        nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
        assert len(nodes) == node_count

        # The first listed node hosts the master; the port is derived from
        # the job id unless MASTER_PORT is set (see _get_master_port).
        self.master_addr = nodes[0]
        self.master_port = _get_master_port(seed=job_id)
        self.rank = int(os.environ["SLURM_PROCID"])
        self.world_size = int(os.environ["SLURM_NTASKS"])
        assert self.rank < self.world_size
        self.local_rank = int(os.environ["SLURM_LOCALID"])
        # Derived value: the task count divided evenly across the nodes.
        self.local_world_size = self.world_size // node_count
        assert self.local_rank < self.local_world_size

    # Single node job with preset environment (i.e. torchrun)
    def _set_from_preset_env(self):
        """Fill the attributes from the torch.distributed environment variables.

        Raises:
            KeyError: If one of the variables is missing.
            ValueError: If a numeric variable, MASTER_PORT included, does not
                hold an integer.
        """
        self.master_addr = os.environ["MASTER_ADDR"]
        # Convert to int so master_port has the same type whichever source
        # initialized the instance (the Slurm and local paths store ints).
        self.master_port = int(os.environ["MASTER_PORT"])
        self.rank = int(os.environ["RANK"])
        self.world_size = int(os.environ["WORLD_SIZE"])
        assert self.rank < self.world_size
        self.local_rank = int(os.environ["LOCAL_RANK"])
        self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
        assert self.local_rank < self.local_world_size

    # Single node and GPU job (i.e. local script run)
    def _set_from_local(self):
        """Fill the attributes for a single process on the local machine."""
        self.master_addr = "127.0.0.1"
        self.master_port = _get_available_port()
        self.rank = 0
        self.world_size = 1
        self.local_rank = 0
        self.local_world_size = 1

    def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment":
        """Write the settings to os.environ.

        Args:
            overwrite: If True, replace variables that are already set. If
                False, raise when a variable is already set to a different
                value.

        Returns:
            This instance, so that calls can be chained.

        Raises:
            RuntimeError: If ``overwrite`` is False and a variable is already
                set to a different value; nothing is exported in that case.
        """
        # See the "Environment variable initialization" section from
        # https://pytorch.org/docs/stable/distributed.html for the complete list of
        # environment variables required for the env:// initialization method.
        env_vars = {
            "MASTER_ADDR": self.master_addr,
            "MASTER_PORT": str(self.master_port),
            "RANK": str(self.rank),
            "WORLD_SIZE": str(self.world_size),
            "LOCAL_RANK": str(self.local_rank),
            "LOCAL_WORLD_SIZE": str(self.local_world_size),
        }
        if not overwrite:
            # Validate every variable before touching the environment.
            for k, v in env_vars.items():
                _check_env_variable(k, v)

        os.environ.update(env_vars)
        return self
|
||||
|
||||
|
||||
def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False):
    """Enable distributed mode

    Args:
        set_cuda_current_device: If True, call torch.cuda.set_device() to set the
            current PyTorch CUDA device to the one matching the local rank.
        overwrite: If True, overwrites already set variables. Else fails.
        allow_nccl_timeout: If True, set NCCL_ASYNC_ERROR_HANDLING=1 in the
            environment before the process group is created.

    Raises:
        RuntimeError: If distributed mode has already been enabled, or if
            ``overwrite`` is False and a variable to export is already set
            to a different value.
    """
    global _LOCAL_RANK, _LOCAL_WORLD_SIZE

    already_enabled = _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0
    if already_enabled:
        raise RuntimeError("Distributed mode has already been enabled")

    # export() returns the environment object itself.
    env = _TorchDistributedEnvironment().export(overwrite=overwrite)

    if set_cuda_current_device:
        torch.cuda.set_device(env.local_rank)

    if allow_nccl_timeout:
        # This allows to use torch distributed timeout in a NCCL backend
        nccl_key = "NCCL_ASYNC_ERROR_HANDLING"
        nccl_value = "1"
        if not overwrite:
            _check_env_variable(nccl_key, nccl_value)
        os.environ[nccl_key] = nccl_value

    dist.init_process_group(backend="nccl")
    dist.barrier()

    # Finalize setup: the module state is only recorded once the process
    # group is up.
    _LOCAL_RANK = env.local_rank
    _LOCAL_WORLD_SIZE = env.local_world_size
    _restrict_print_to_main_process()
|
||||
Reference in New Issue
Block a user