
Commit

Implement HTTPMonitor to send node status and training progress to generic HTTP Server
manveerxyz committed Sep 28, 2024
1 parent 55e0b71 commit 0d36e74
Showing 6 changed files with 134 additions and 6 deletions.
7 changes: 6 additions & 1 deletion configs/150M/3090.toml
@@ -1,5 +1,10 @@
name_model = "150M"
project = "debug_150m_zero_band"
run_id = "2c774d7c830b49e7855f4f9be6ea4d09"

[metric_logger]
type = "dummy"
base_url = "https://protocol-api.primeintellect.ai"

[train]
micro_bs = 16 # change this based on the GPU
@@ -9,4 +14,4 @@ sharding_strategy = "SHARD_GRAD_OP"
batch_size = 512
warmup_steps = 1000
total_steps = 88_000
lr = 4e-4
lr = 4e-4
4 changes: 3 additions & 1 deletion configs/debug/diloco.toml
@@ -1,6 +1,8 @@
name_model = "debugmodel"
project = "/tmp/debug"
metric_logger_type = "dummy"

[metric_logger]
type = "dummy"

[train]
micro_bs = 8
22 changes: 22 additions & 0 deletions configs/debug/diloco_http_logger.toml
@@ -0,0 +1,22 @@
name_model = "debugmodel"
project = "/tmp/debug"

[metric_logger]
type = "http"
base_url = "https://protocol-api.primeintellect.ai"

[train]
micro_bs = 8
sharding_strategy = "FULL_SHARD"

[optim]
batch_size = 16
warmup_steps = 10
total_steps = 4

[data]
fake = true

[diloco]
inner_steps = 2

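As a rough sketch of how the new [metric_logger] table above is consumed, the snippet below loads this config file and validates that section against a stand-in for the MetricLogger model added to src/zeroband/train.py in this commit. Plain pydantic.BaseModel is used here instead of the repo's pydantic_config BaseConfig, and Python 3.11+ tomllib is assumed:

import tomllib
from typing import Literal
from pydantic import BaseModel  # stand-in for the repo's pydantic_config BaseConfig

class MetricLogger(BaseModel):
    type: Literal["wandb", "dummy", "http"] = "http"
    base_url: str | None = None
    auth_token: str | None = None

with open("configs/debug/diloco_http_logger.toml", "rb") as f:
    cfg = tomllib.load(f)

ml = MetricLogger(**cfg["metric_logger"])
print(ml.type, ml.base_url)  # http https://protocol-api.primeintellect.ai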
4 changes: 3 additions & 1 deletion configs/debug/normal.toml
@@ -1,6 +1,8 @@
name_model = "debugmodel"
project = "/tmp/debug"
metric_logger_type = "dummy"

[metric_logger]
type = "dummy"

[train]
micro_bs = 8
20 changes: 17 additions & 3 deletions src/zeroband/train.py
@@ -1,6 +1,7 @@
import os
from contextlib import nullcontext
from typing import Literal
import time

import torch
from pydantic_config import parse_argv, BaseConfig
@@ -21,7 +22,7 @@
from zeroband.diloco import Diloco, DilocoConfig, ElasticDeviceMesh

from zeroband.utils import PerfCounter, get_model_hash, get_sharding_strategy
from zeroband.utils.monitor import WandbMonitor, DummyMonitor
from zeroband.utils.monitor import WandbMonitor, DummyMonitor, HttpMonitor
from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
from zeroband.models.llama import get_model
from zeroband.utils.world_info import get_world_info
@@ -53,19 +54,26 @@ class TrainConfig(BaseConfig):
    log_model_hash: bool = False


class MetricLogger(BaseConfig):
    type: Literal["wandb", "dummy", "http"] = "http"
    base_url: str | None = None
    auth_token: str | None = None


class Config(BaseConfig):
    # main config
    name_model: Literal["debugmodel", "150M", "271M", "1B", "7B", "13B", "26B", "70B"] = "150M"
    type_model: Literal["llama2", "llama3"] = "llama2"

    project: str = "zeroband"
    metric_logger_type: Literal["wandb", "dummy"] = "wandb"
    run_id: str | None = None

    # sub config
    diloco: DilocoConfig | None = None
    data: DataConfig = DataConfig()
    optim: OptimConfig = OptimConfig()
    train: TrainConfig
    metric_logger: MetricLogger


def train(config: Config):
@@ -153,7 +161,12 @@ def train(config: Config):
    model.train()

    if world_info.rank == 0:
        logger_cls = WandbMonitor if config.metric_logger_type == "wandb" else DummyMonitor
        if config.metric_logger.type == "wandb":
            logger_cls = WandbMonitor
        elif config.metric_logger.type == "http":
            logger_cls = HttpMonitor
        else:
            logger_cls = DummyMonitor
        metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=False)

    train_dataloader_iterator = iter(train_dataloader)
@@ -209,6 +222,7 @@ def train(config: Config):
"inner_lr": inner_lr,
"Perplexity": torch.exp(loss_batch).item(),
"total_tokens": real_step * config.optim.batch_size * config.data.seq_length,
"time": time.time(),
}
log = f"step: {real_step}, loss: {loss_batch.item():.4f}"

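For context, here is a minimal sketch of the kind of per-step metrics dict the training loop now hands to the monitor, including the "time" field added in this commit. Only the keys visible in the hunk above are reproduced; the step count, loss, and seq_length values are hypothetical placeholders (batch_size is taken from configs/150M/3090.toml):

import time
import math

# Hypothetical per-step values; the real ones come from the training loop in train.py.
real_step = 1_000
batch_size = 512        # [optim] batch_size in configs/150M/3090.toml
seq_length = 1_024      # assumed; not part of this diff
inner_lr = 4e-4
loss = 2.31

metrics = {
    "inner_lr": inner_lr,
    "Perplexity": math.exp(loss),
    "total_tokens": real_step * batch_size * seq_length,  # 524_288_000 with these numbers
    "time": time.time(),
}
# When the metric logger is an HttpMonitor, log() lowercases these keys
# ("Perplexity" -> "perplexity") before batching them for the API.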
83 changes: 83 additions & 0 deletions src/zeroband/utils/monitor.py
@@ -1,6 +1,9 @@
import pickle
from typing import Any, Protocol
import importlib
from zeroband.utils.logging import get_logger

logger = get_logger()


class Monitor(Protocol):
@@ -11,6 +14,86 @@ def log(self, metrics: dict[str, Any]): ...
    def finish(self): ...


class HttpMonitor:
    """
    Logs node status and training progress to an HTTP API
    """

    def __init__(self, config, *args, **kwargs):
        # config is the plain dict produced by config.model_dump() in train.py,
        # so everything is read with dict access
        self.data = []
        self.batch_size = config.get('metric_logger', {}).get('batch_size', 10)
        self.run_id = config.get('run_id', 'default_run')
        self.base_url = config['metric_logger']['base_url']
        self.auth_token = config['metric_logger']['auth_token']

    def _remove_duplicates(self):
        seen = set()
        unique_logs = []
        for log in self.data:
            log_tuple = tuple(sorted(log.items()))
            if log_tuple not in seen:
                unique_logs.append(log)
                seen.add(log_tuple)
        self.data = unique_logs

    def log(self, data: dict[str, Any]):
        # Lowercase the keys in the data dictionary
        lowercased_data = {k.lower(): v for k, v in data.items()}
        self.data.append(lowercased_data)
        if len(self.data) >= self.batch_size:
            self._remove_duplicates()  # Remove duplicates before sending
            self._send_batch()

    def _send_batch(self):
        import requests
        # Remove duplicates before sending
        self._remove_duplicates()

        # Send batch of logs to API endpoint
        batch = self.data[:self.batch_size]
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.auth_token}"
        }
        payload = {
            "logs": batch
        }
        api = f"{self.base_url}/training_runs/{self.run_id}/logs"
        try:
            response = requests.post(api, json=payload, headers=headers)
            response.raise_for_status()
        except requests.RequestException as e:
            logger.debug(f"Failed to send batch of logs to http monitor: {e}")
            return False

        self.data = self.data[self.batch_size:]
        return True

    def _finish(self):
        import requests
        headers = {
            "Content-Type": "application/json"
        }
        api = f"{self.base_url}/training_runs/{self.run_id}/finish"
        try:
            response = requests.post(api, headers=headers)
            response.raise_for_status()
            return True
        except requests.RequestException as e:
            logger.debug(f"Failed to send finish signal to http monitor: {e}")
            return False

    def finish(self):
        # Remove duplicates before sending any remaining logs
        self._remove_duplicates()

        # Send any remaining logs; stop if the endpoint keeps failing so we
        # do not loop forever on an unreachable server
        while self.data:
            if not self._send_batch():
                break

        self._finish()


class WandbMonitor:
    def __init__(self, project, config, resume: bool):
        if importlib.util.find_spec("wandb") is None:
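Finally, a hedged usage sketch for the new HttpMonitor, assuming the dict shape that train.py passes in via config.model_dump(). The localhost base_url, auth token, and metric names are hypothetical, and a server has to be listening on that endpoint for the POSTs to succeed:

from zeroband.utils.monitor import HttpMonitor

config = {
    "run_id": "2c774d7c830b49e7855f4f9be6ea4d09",  # run_id from configs/150M/3090.toml
    "metric_logger": {
        "type": "http",
        "base_url": "http://localhost:8000",       # hypothetical test endpoint
        "auth_token": "test-token",                # hypothetical
    },
}

monitor = HttpMonitor(project="debug", config=config, resume=False)
for step in range(12):
    # keys are lowercased by log(); batches of 10 are POSTed to
    # {base_url}/training_runs/{run_id}/logs
    monitor.log({"Loss": 2.0 - 0.01 * step, "step": step})
monitor.finish()  # flushes remaining logs, then POSTs {base_url}/training_runs/{run_id}/finish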
