Address PR comments
manveerxyz committed Sep 29, 2024
1 parent 0d36e74 commit 7bdb1c7
Showing 6 changed files with 40 additions and 54 deletions.
1 change: 0 additions & 1 deletion configs/150M/3090.toml
@@ -4,7 +4,6 @@ run_id = "2c774d7c830b49e7855f4f9be6ea4d09"

[metric_logger]
type = "dummy"
base_url = "https://protocol-api.primeintellect.ai"

[train]
micro_bs = 16 # change this based on the gpu
22 changes: 0 additions & 22 deletions configs/debug/diloco_http_logger.toml

This file was deleted.

7 changes: 5 additions & 2 deletions pyproject.toml
@@ -11,14 +11,17 @@ dependencies = [
"transformers>=4.44.2",
"datasets>=3.0.0",
"pydantic_config @ git+https://github.com/samsja/pydantic_config.git@e529c9c",
"einops"
]

[project.optional-dependencies]
all = [
"wandb",
"einops",
"asyncio>=3.4.3",
"aiohttp>=3.10.5",
"requests>=2.32.3",
]

g
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
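With the HTTP-logging dependencies grouped under the `all` extra, a development install would typically pull them in with something like `pip install -e ".[all]"` (or `uv sync --extra all`, given the uv.lock in this repo), while a plain install keeps only the core dependencies; the exact invocation depends on local tooling, so treat these commands as illustrative.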
3 changes: 3 additions & 0 deletions src/zeroband/train.py
@@ -56,6 +56,9 @@ class TrainConfig(BaseConfig):

class MetricLogger(BaseConfig):
type: Literal["wandb", "dummy", "http"] = "http"
log_async: bool = True
batch_size: int = 10
# for http monitor
base_url: str | None = None
auth_token: str | None = None

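The MetricLogger block now carries batching and async-dispatch settings alongside the HTTP ones. A minimal sketch of how such a config object could be populated, using plain pydantic.BaseModel as a stand-in for the repo's BaseConfig (from pydantic_config) and placeholder values throughout:

from typing import Literal

from pydantic import BaseModel  # stand-in for the repo's BaseConfig


class MetricLoggerConfig(BaseModel):
    # Mirrors the fields of the MetricLogger block above.
    type: Literal["wandb", "dummy", "http"] = "http"
    log_async: bool = True  # send log batches without blocking the training loop
    batch_size: int = 10    # flush once this many entries have accumulated
    # for http monitor
    base_url: str | None = None
    auth_token: str | None = None


# Placeholder values; a real run would read these from its TOML config or environment.
cfg = MetricLoggerConfig(type="http", base_url="https://example.invalid", auth_token="dummy-token")
print(cfg.model_dump())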
35 changes: 17 additions & 18 deletions src/zeroband/utils/monitor.py
@@ -21,7 +21,7 @@ class HttpMonitor:

def __init__(self, config, *args, **kwargs):
self.data = []
self.batch_size = getattr(config.progress_logger, 'batch_size', 10)
self.batch_size = getattr(config.metric_logger, 'batch_size', 10)
self.run_id = config.get('run_id', 'default_run')
self.base_url = config['metric_logger']['base_url']
self.auth_token = config['metric_logger']['auth_token']
@@ -37,19 +37,19 @@ def _remove_duplicates(self):
self.data = unique_logs

def log(self, data: dict[str, Any]):
import asyncio

# Lowercase the keys in the data dictionary
lowercased_data = {k.lower(): v for k, v in data.items()}
self.data.append(lowercased_data)
if len(self.data) >= self.batch_size:
self._remove_duplicates() # Remove duplicates before sending
self._send_batch()
# do this in a separate thread to not affect training loop
asyncio.create_task(self._send_batch())

async def _send_batch(self):
import aiohttp

def _send_batch(self):
import requests
# Remove duplicates before sending
self._remove_duplicates()

# Send batch of logs to API endpoint
batch = self.data[:self.batch_size]
headers = {
"Content-Type": "application/json",
@@ -59,13 +59,15 @@ def _send_batch(self):
"logs": batch
}
api = f"{self.base_url}/training_runs/{self.run_id}/logs"
try:
response = requests.post(api, json=payload, headers=headers)
response.raise_for_status()
except requests.RequestException as e:
logger.debug(f"Failed to send batch of logs to http monitor: {e}")
return False


async with aiohttp.ClientSession() as session:
try:
async with session.post(api, json=payload, headers=headers) as response:
await response.raise_for_status()
except aiohttp.ClientError as e:
logger.debug(f"Failed to send batch of logs to http monitor: {e}")
return False

self.data = self.data[self.batch_size:]
return True

@@ -84,9 +86,6 @@ def _finish(self):
return False

def finish(self):
# Remove duplicates before sending any remaining logs
self._remove_duplicates()

# Send any remaining logs
while self.data:
self._send_batch()
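For context, a self-contained sketch of the batch-and-send pattern the updated HttpMonitor follows: log entries are buffered and posted in batches from an asyncio task so the training loop is not blocked. Class and attribute names here are illustrative, and the Bearer authorization scheme is an assumption (the headers are truncated in the diff above):

import asyncio
from typing import Any

import aiohttp


class BatchedHttpLogger:
    def __init__(self, base_url: str, run_id: str, auth_token: str, batch_size: int = 10):
        self.base_url = base_url
        self.run_id = run_id
        self.auth_token = auth_token
        self.batch_size = batch_size
        self.data: list[dict[str, Any]] = []

    def log(self, entry: dict[str, Any]) -> None:
        # Buffer lowercased entries; only schedule a network call once a batch is full.
        self.data.append({k.lower(): v for k, v in entry.items()})
        if len(self.data) >= self.batch_size:
            # Requires a running event loop in this thread; the send happens in the background.
            asyncio.create_task(self._send_batch())

    async def _send_batch(self) -> bool:
        # Take one batch off the front of the buffer and POST it.
        batch, self.data = self.data[: self.batch_size], self.data[self.batch_size :]
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.auth_token}",  # assumed auth scheme
        }
        url = f"{self.base_url}/training_runs/{self.run_id}/logs"
        async with aiohttp.ClientSession() as session:
            try:
                async with session.post(url, json={"logs": batch}, headers=headers) as resp:
                    resp.raise_for_status()  # synchronous in aiohttp, unlike requests
            except aiohttp.ClientError:
                return False
        return True

One trade-off mirrored here is that asyncio.create_task only works when an event loop is already running in the calling thread, which is presumably what the new log_async option is there to toggle.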
26 changes: 15 additions & 11 deletions uv.lock

Some generated files are not rendered by default.
