
Better ElasticDeviceMesh #9

Merged: 33 commits, merged Sep 28, 2024
Commits
0ee6612
refactor: move pg concerns into edm
Jackmin801 Sep 25, 2024
9c32225
working but only rank 0 syncs
Jackmin801 Sep 26, 2024
6da6f10
use fake pg instead of None
Jackmin801 Sep 26, 2024
40a3e2a
testing utils
Jackmin801 Sep 26, 2024
22db84a
syncing correctly but ugly
Jackmin801 Sep 26, 2024
84b41f7
make cpu offload use mmaped file
Jackmin801 Sep 27, 2024
9e79c04
fix: allow none diloco to work with fake pg
Jackmin801 Sep 27, 2024
fa56980
simulate multi node diloco script
Jackmin801 Sep 27, 2024
3077262
docs: update docs
Jackmin801 Sep 27, 2024
e8332c4
remove prints
Jackmin801 Sep 27, 2024
4938bb4
ruff lint
Jackmin801 Sep 27, 2024
52f3903
Merge branch 'main' into feat-edm
Jackmin801 Sep 27, 2024
bcc1d7a
move global info to world info and fix unique id
Jackmin801 Sep 27, 2024
70a2a82
fixes from merge
Jackmin801 Sep 27, 2024
970e1f5
move unique id to world info
Jackmin801 Sep 27, 2024
9c22e20
update command in readme
Jackmin801 Sep 27, 2024
84b7297
remove broadcasts at init
Jackmin801 Sep 27, 2024
7345420
move summon full params to diloco class
Jackmin801 Sep 27, 2024
a02c724
fix data split
Jackmin801 Sep 27, 2024
ccddfe4
move testing to utils
Jackmin801 Sep 27, 2024
ed8d6e8
document offloading logic
Jackmin801 Sep 27, 2024
d5b6e2f
add envs to readme
Jackmin801 Sep 27, 2024
a890e0d
repre for worldinfo
Jackmin801 Sep 28, 2024
63bc0f6
revert to global pg
Jackmin801 Sep 28, 2024
fca8165
set unique id in tests
Jackmin801 Sep 28, 2024
ceb96fc
fix: nccl cannot all reduce same device
Jackmin801 Sep 28, 2024
4b45471
use get module signature instead of model hash
Jackmin801 Sep 28, 2024
fa8d3dd
change default global unique id to none
Jackmin801 Sep 28, 2024
73800d9
revert data changes
Jackmin801 Sep 28, 2024
e64eb2d
make /dev/shm/zeroband a constant and some fixes
Jackmin801 Sep 28, 2024
d41de80
revert shm offload
Jackmin801 Sep 28, 2024
9c39401
fix: non zero rank need to reduce too
Jackmin801 Sep 28, 2024
e21a048
remove testing
Jackmin801 Sep 28, 2024
21 changes: 14 additions & 7 deletions README.md
@@ -2,7 +2,7 @@
ZeroBand is a production ready codebase for decentralized training of LLM


## developlment
## Development

install uv

@@ -40,28 +40,28 @@ run your code using
uv run ...
```

## quick check
## Quick check

To check that everything is working you can do

```bash
ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/normal.toml
ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/normal.toml
```

## run diloco
## Run diloco

To run diloco locally you can use the helper script `scripts/simulate_multi_node_diloco.sh`

:note: you need 4 gpus to run the following command

```bash
ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml
ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node_diloco.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml
```

if you have only two gpus

```bash
ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node.sh 2 1 src/zeroband/train.py @configs/debug/diloco.toml
ZERO_BAND_LOG_LEVEL=DEBUG ./scripts/simulate_multi_node_diloco.sh 2 1 src/zeroband/train.py @configs/debug/diloco.toml
```

One GPU is not supported at the moment because of an FSDP bug in our implementation.
@@ -71,8 +71,15 @@ One GPU is not supported at the moment because of an FSDP bug in our implementation.
You need a machine with at least two GPUs to run the full test suite.

Some tests must be run from the root directory.

```bash
uv run pytest
```

## Environment variables
| Environment Variable | Description | Default Value |
|-----------------------|--------------------------------------------------|---------------|
| `GLOBAL_UNIQUE_ID`    | Unique identifier of the worker in the global store. | `None` |
| `GLOBAL_ADDR`         | IP address of the global store.                   | `None` |
| `GLOBAL_PORT` | Port number of the global store. | `None` |
| `GLOBAL_WORLD_SIZE` | The size of the global process group. | `1` |
| `GLOBAL_RANK` | Rank of the process in the global process group. | `0` |
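
For illustration, a hypothetical second worker in a two-node diloco run could export these variables by hand before launching (the helper script above sets them for you; the address, port, and IDs below are made-up example values, not prescribed ones):

```bash
# Hypothetical manual setup for the second of two diloco workers.
# GLOBAL_ADDR/GLOBAL_PORT point at the node hosting the global TCPStore;
# the concrete values here are illustrative only.
export GLOBAL_UNIQUE_ID=1
export GLOBAL_ADDR=192.168.0.10
export GLOBAL_PORT=10000
export GLOBAL_WORLD_SIZE=2
export GLOBAL_RANK=1

# Then launch this worker's local training, e.g. the debug diloco config:
ZERO_BAND_LOG_LEVEL=DEBUG torchrun --nproc_per_node=2 src/zeroband/train.py @configs/debug/diloco.toml
```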
69 changes: 69 additions & 0 deletions scripts/simulate_multi_node_diloco.sh
@@ -0,0 +1,69 @@
#!/bin/bash

#
# Simulate multiple nodes on one machine: starts N torchrun launchers, each using NUM_GPU local GPUs.
# Example: ./scripts/simulate_multi_node_diloco.sh 2 1 src/zeroband/train.py @configs/debug/normal.toml

# Function to get CUDA devices based on the number of GPUs and index
function get_cuda_devices() {
local num_gpu=$1
local index=$2
local start_gpu=$((num_gpu * index))
local end_gpu=$((start_gpu + num_gpu - 1))

if [ "$num_gpu" -eq 1 ]; then
echo $start_gpu
else
echo $(seq -s ',' $start_gpu $end_gpu)
fi
}

# Array to store PIDs of child processes
child_pids=()

# Function to kill all child processes
cleanup() {
echo "Cleaning up child processes..."
local killed=0
for pid in "${child_pids[@]}"; do
if kill -TERM "$pid" 2>/dev/null; then
((killed++))
fi
done
wait
echo "All child processes terminated. Killed $killed processes."
exit
}

# Check if at least three arguments were passed
if [ "$#" -lt 3 ]; then
echo "Usage: $0 <N> <initial_peer> <num_gpu> [additional_python_args]"
exit 1
fi


N=$1 # Number of nodes to simulate
NUM_GPU=$2 # Number of GPUs per simulated node
shift 2 # Remove the first two arguments so $@ contains the training script and its arguments

# Register the cleanup function to be called on SIGINT (Ctrl+C)
trap cleanup SIGINT


mkdir -p logs

export GLOBAL_ADDR=localhost
export GLOBAL_PORT=10000
export GLOBAL_WORLD_SIZE=$N

for i in $(seq 0 $(($N - 1 )))
do
> logs/log$i
GLOBAL_UNIQUE_ID=$i GLOBAL_RANK=$i CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank 0 --rdzv-endpoint localhost:$((10001 + $i)) --nnodes=1 $@ > logs/log$i 2>&1 &
child_pids+=($!)
done

tail -f logs/log0 &
child_pids+=($!)

wait
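
For example, invoking `./scripts/simulate_multi_node_diloco.sh 2 2 src/zeroband/train.py @configs/debug/diloco.toml` exports `GLOBAL_ADDR=localhost`, `GLOBAL_PORT=10000`, and `GLOBAL_WORLD_SIZE=2`, and then effectively launches something like the following two background commands (an approximate expansion of the loop above; the training arguments are taken from the README example):

```bash
# Approximate expansion for N=2, NUM_GPU=2 (illustrative, not script output).
GLOBAL_UNIQUE_ID=0 GLOBAL_RANK=0 CUDA_VISIBLE_DEVICES=0,1 uv run torchrun \
  --nproc_per_node=2 --node-rank 0 --rdzv-endpoint localhost:10001 --nnodes=1 \
  src/zeroband/train.py @configs/debug/diloco.toml > logs/log0 2>&1 &
GLOBAL_UNIQUE_ID=1 GLOBAL_RANK=1 CUDA_VISIBLE_DEVICES=2,3 uv run torchrun \
  --nproc_per_node=2 --node-rank 0 --rdzv-endpoint localhost:10002 --nnodes=1 \
  src/zeroband/train.py @configs/debug/diloco.toml > logs/log1 2>&1 &
```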
230 changes: 230 additions & 0 deletions src/zeroband/comms.py
@@ -0,0 +1,230 @@
from torch.distributed.device_mesh import init_device_mesh
from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger
import torch.distributed as dist
from datetime import timedelta
import time
from typing import List, Tuple, Optional
from torch.testing._internal.distributed.fake_pg import FakeProcessGroup


TCPSTORE_TIMEOUT = timedelta(seconds=10)
MAX_JOINERS = 100 # Maximum number of nodes that can join in a single reinit
MAX_LEAVERS = 100 # Maximum number of nodes that can leave in a single reinit


def _wait_for_status(store: dist.Store, status: Optional[str] = None) -> str:
while True:
try:
ret = store.get("status").decode("utf-8")
if status is None or ret == status:
return ret
time.sleep(0.1)
except dist.DistStoreError as e:
if status is not None:
raise e
time.sleep(0.1)


def _queue_join(store: dist.Store, unique_id: str):
for i in range(MAX_JOINERS):
joiner_id = store.get(f"joiner_{i}").decode("utf-8")
if joiner_id == "null":
store.set(f"joiner_{i}", unique_id)
store.set(f"joiner_{i + 1}", "null")
break
else:
raise RuntimeError("Too many joiners")


def _queue_leave(store: dist.Store, unique_id: str):
for i in range(MAX_LEAVERS):
leaver_id = store.get(f"leaver_{i}").decode("utf-8")
if leaver_id == "null":
store.set(f"leaver_{i}", unique_id)
store.set(f"leaver_{i + 1}", "null")
break
else:
raise RuntimeError("Too many leavers")


def _get_joiners_and_leavers(store: dist.Store) -> Tuple[List[str], List[str]]:
joiners = []
leavers = []
for i in range(MAX_JOINERS):
joiner_id = store.get(f"joiner_{i}").decode("utf-8")
if joiner_id == "null":
break
joiners.append(joiner_id)
for i in range(MAX_LEAVERS):
leaver_id = store.get(f"leaver_{i}").decode("utf-8")
if leaver_id == "null":
break
leavers.append(leaver_id)
print(f"Joiners: {joiners}, Leavers: {leavers}")
return joiners, leavers


def _clear_joiners_and_leavers(store: dist.Store):
store.set("joiner_0", "null")
store.set("leaver_0", "null")


class ElasticDeviceMesh:
"""A class to manage the process groups for elastic training without restarts.

The way it works is rank 0 coordinates the joining and leaving of nodes.
Rank 0 manages the status to coordinate the creation and recreation of the process groups.
When a node wants to join, rank 0 will set up the store so that all nodes know the new world size and their respective ranks.

Store keys used:
- status: "init", "running", "reinit"
- world_size: The current world size
- mesh_count: The version of the mesh
- rank_{uuid}: The rank of the node with the given uuid
- rank_map_{rank}: The new rank of the node with the given rank. Used to remap ranks when nodes leave.
- joiner_{i}: The uuid of the ith joiner. It's a KV implementation of a queue.
- leaver_{i}: The uuid of the ith leaver. It's a KV implementation of a queue.
"""

local_pg: dist.ProcessGroup
global_pg: dist.ProcessGroup

def __init__(self):
self._logger = get_logger()
self.world_info = get_world_info()

# Initialize global process group
self.global_pg = FakeProcessGroup(self.world_info.rank, 1)
if self.world_info.global_world_size > 1:
self.global_pg = self._init_global_pg()

# Initialize local process group
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
self._device_mesh = init_device_mesh(
"cuda",
(self.world_info.nnodes, self.world_info.local_world_size),
mesh_dim_names=("internode", "intranode"),
)
self.local_pg = self._device_mesh.get_group("intranode")

if self.world_info.rank == 0:
self._logger.debug(f"global pg world : {self.global_pg.size()}, local pg: {self.local_pg.size()}")
else:
self._logger.debug(f"local pg world : {self.local_pg.size()}")

def __del__(self):
dist.destroy_process_group()

def _init_global_pg(self) -> dist.ProcessGroup:
store = dist.TCPStore(
host_name=self.world_info.global_addr,
port=self.world_info.global_port + self.world_info.rank,
timeout=TCPSTORE_TIMEOUT,
is_master=(self.world_info.global_rank == 0),
)

# Initialize store
if self.world_info.global_rank == 0:
store.set("mesh_count", "0")
store.set("joiner_0", "null")
store.set("leaver_0", "null")
store.set("status", "init")
status = "init"
else:
status = _wait_for_status(store)

if status == "init":
# First time initialization
self.mesh_count = 0
self.prefix_store = dist.PrefixStore("mesh_0", store)
pg = dist.ProcessGroupGloo(
self.prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT
)
if self.world_info.global_rank == 0:
store.set("status", "running")
store.set(f"rank_{self.world_info.global_unique_id}", str(self.world_info.global_rank))
elif status == "running":
# Node wants to join
_queue_join(store, self.world_info.global_unique_id)
_wait_for_status(store, "reinit")
# Get assigned rank
self.world_info.global_rank = int(store.get(f"rank_{self.world_info.global_unique_id}").decode("utf-8"))
# Get updated world_size
self.world_info.global_world_size = int(store.get("world_size").decode("utf-8"))
self.mesh_count = int(store.get("mesh_count").decode("utf-8"))
self.prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", store)
pg = dist.ProcessGroupGloo(
self.prefix_store, self.world_info.global_rank, self.world_info.global_world_size, TCPSTORE_TIMEOUT
)
else:
# TODO: Could be in "reinit" status
raise RuntimeError(f"Unknown status {status}")

# Setting instance variables
self.global_store = store
self.leaving = False
return pg

def _resolve_world(self):
"""Set the new world size and ranks for all nodes."""
# Find joiners and leavers
joiners, leavers = _get_joiners_and_leavers(self.global_store)
# If no joiners or leavers, no resolution needed
if len(joiners) == 0 and len(leavers) == 0:
return

# Remap live ranks to smaller world_size caused by leavers
leaving_ranks = {int(self.global_store.get(f"rank_{leaver_id}").decode("utf-8")) for leaver_id in leavers}
live_ranks = [i for i in range(0, self.world_size, self.local_world_size) if i not in leaving_ranks]
for i, rank in enumerate(live_ranks):
self.global_store.set(f"rank_map_{rank}", str(i * self.local_world_size))
new_world_size = len(live_ranks) * self.local_world_size

# Give joiners new ranks
for joiner_id in joiners:
self.global_store.set(f"rank_{joiner_id}", str(new_world_size))
new_world_size += self.local_world_size

# Update world_size
self.global_store.set("world_size", str(new_world_size))
self.global_store.set("mesh_count", str(self.mesh_count + 1))
# Set status to "reinit"
self.global_store.set("status", "reinit")

def maybe_reinit_device_mesh(self):
"""Reinitialize the device mesh if there are joiners or leavers."""
if self.rank == 0:
self._resolve_world()
dist.barrier()
status = self.global_store.get("status").decode("utf-8")
if status == "running":
return

print("Reinitializing device mesh")
dist.destroy_process_group()
print("Destroyed process group")
if self.leaving:
print("Leaving")
return

# Check if we got remapped
prev_uuid_rank = int(self.global_store.get(f"rank_{self.world_info.global_unique_id}").decode("utf-8"))
new_uuid_rank = int(self.global_store.get(f"rank_map_{prev_uuid_rank}").decode("utf-8"))
self.rank = new_uuid_rank + self.local_rank

self.world_size = int(self.global_store.get("world_size").decode("utf-8"))
self.mesh_count = int(self.global_store.get("mesh_count").decode("utf-8"))
self.prefix_store = dist.PrefixStore(f"mesh_{self.mesh_count}", self.global_store)
dist.init_process_group(
backend="cpu:gloo,cuda:nccl", store=self.prefix_store, rank=self.rank, world_size=self.world_size
)

if self.rank == 0:
_clear_joiners_and_leavers(self.global_store)
self.global_store.set("status", "running")
# Update rank if needed (otherwise, the next remap will do the lookup incorrectly)
if self.local_rank == 0 and new_uuid_rank != prev_uuid_rank:
self.global_store.set(f"rank_{self.world_info.global_unique_id}", str(new_uuid_rank))
# Reinitialize sub process groups
self.world_rank = self.rank // self.local_world_size
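
To make the intended usage concrete, here is a minimal sketch based only on the surface shown in this diff (`ElasticDeviceMesh()`, its `global_pg` and `local_pg` attributes, `world_info`, and `maybe_reinit_device_mesh()`); the loop and the dummy all-reduce are illustrative assumptions, not code from this PR:

```python
# Minimal usage sketch (assumptions: launched via torchrun on a GPU machine,
# with the GLOBAL_* environment variables from the README already exported).
import torch
import torch.distributed as dist

from zeroband.comms import ElasticDeviceMesh

# The constructor initializes the local (intranode) process group and, when
# GLOBAL_WORLD_SIZE > 1, a gloo-backed global process group over a TCPStore.
edm = ElasticDeviceMesh()

for outer_step in range(10):  # stand-in for a diloco-style outer loop
    # ... inner training steps using edm.local_pg would go here ...

    # Illustrative outer step: average a dummy tensor across the global group.
    tensor = torch.ones(4)
    if edm.world_info.global_world_size > 1:
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=edm.global_pg)
        tensor /= edm.world_info.global_world_size

        # Let rank 0 resolve any joiners/leavers and rebuild the mesh if needed.
        edm.maybe_reinit_device_mesh()
```

Treat this as a sketch of the API surface rather than the PR's actual training integration.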