diff --git a/configs/150M/3090.toml b/configs/150M/3090.toml index 1364218e..df111082 100644 --- a/configs/150M/3090.toml +++ b/configs/150M/3090.toml @@ -4,7 +4,6 @@ run_id = "2c774d7c830b49e7855f4f9be6ea4d09" [metric_logger] type = "dummy" -base_url = "https://protocol-api.primeintellect.ai" [train] micro_bs = 16 # change this base on the gpu diff --git a/configs/debug/diloco_http_logger.toml b/configs/debug/diloco_http_logger.toml deleted file mode 100644 index bbba825d..00000000 --- a/configs/debug/diloco_http_logger.toml +++ /dev/null @@ -1,22 +0,0 @@ -name_model = "debugmodel" -project = "/tmp/debug" - -[metric_logger] -type = "http" -base_url = "https://protocol-api.primeintellect.ai" - -[train] -micro_bs = 8 -sharding_strategy = "FULL_SHARD" - -[optim] -batch_size = 16 -warmup_steps = 10 -total_steps = 4 - -[data] -fake = true - -[diloco] -inner_steps = 2 - diff --git a/pyproject.toml b/pyproject.toml index f5b1a711..b691534c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,14 +11,17 @@ dependencies = [ "transformers>=4.44.2", "datasets>=3.0.0", "pydantic_config @ git+https://github.com/samsja/pydantic_config.git@e529c9c", - "einops" ] [project.optional-dependencies] all = [ "wandb", + "einops", + "asyncio>=3.4.3", + "aiohttp>=3.10.5", + "requests>=2.32.3", ] - +g [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 3f2c20cf..85ba15e3 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -56,6 +56,9 @@ class TrainConfig(BaseConfig): class MetricLogger(BaseConfig): type: Literal["wandb", "dummy", "http"] = "http" + log_async: bool = True + batch_size: int = 10 + # for http monitor base_url: str | None = None auth_token: str | None = None diff --git a/src/zeroband/utils/monitor.py b/src/zeroband/utils/monitor.py index 23d48222..050fd7c6 100644 --- a/src/zeroband/utils/monitor.py +++ b/src/zeroband/utils/monitor.py @@ -21,7 +21,7 @@ class HttpMonitor: def __init__(self, config, *args, **kwargs): self.data = [] - self.batch_size = getattr(config.progress_logger, 'batch_size', 10) + self.batch_size = getattr(config.metric_logger, 'batch_size', 10) self.run_id = config.get('run_id', 'default_run') self.base_url = config['metric_logger']['base_url'] self.auth_token = config['metric_logger']['auth_token'] @@ -37,19 +37,19 @@ def _remove_duplicates(self): self.data = unique_logs def log(self, data: dict[str, Any]): + import asyncio + # Lowercase the keys in the data dictionary lowercased_data = {k.lower(): v for k, v in data.items()} self.data.append(lowercased_data) if len(self.data) >= self.batch_size: - self._remove_duplicates() # Remove duplicates before sending - self._send_batch() + # do this in a separate thread to not affect training loop + asyncio.create_task(self._send_batch()) + + async def _send_batch(self): + import aiohttp - def _send_batch(self): - import requests - # Remove duplicates before sending self._remove_duplicates() - - # Send batch of logs to API endpoint batch = self.data[:self.batch_size] headers = { "Content-Type": "application/json", @@ -59,13 +59,15 @@ def _send_batch(self): "logs": batch } api = f"{self.base_url}/training_runs/{self.run_id}/logs" - try: - response = requests.post(api, json=payload, headers=headers) - response.raise_for_status() - except requests.RequestException as e: - logger.debug(f"Failed to send batch of logs to http monitor: {e}") - return False - + + async with aiohttp.ClientSession() as session: + try: + async with session.post(api, json=payload, headers=headers) as response: + await response.raise_for_status() + except aiohttp.ClientError as e: + logger.debug(f"Failed to send batch of logs to http monitor: {e}") + return False + self.data = self.data[self.batch_size:] return True @@ -84,9 +86,6 @@ def _finish(self): return False def finish(self): - # Remove duplicates before sending any remaining logs - self._remove_duplicates() - # Send any remaining logs while self.data: self._send_batch() diff --git a/uv.lock b/uv.lock index f0ec766e..33907213 100644 --- a/uv.lock +++ b/uv.lock @@ -131,6 +131,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/fa/e01228c2938de91d47b307831c62ab9e4001e747789d0b05baf779a6488c/async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028", size = 5721 }, ] +[[package]] +name = "asyncio" +version = "3.4.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/da/54/054bafaf2c0fb8473d423743e191fcdf49b2c1fd5e9af3524efbe097bafd/asyncio-3.4.3.tar.gz", hash = "sha256:83360ff8bc97980e4ff25c964c7bd3923d333d177aa4f7fb736b019f26c7cb41", size = 204411 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/74/07679c5b9f98a7cb0fc147b1ef1cc1853bc07a4eb9cb5731e24732c5f773/asyncio-3.4.3-py3-none-any.whl", hash = "sha256:c4d18b22701821de07bd6aea8b53d21449ec0ec5680645e5317062ea21817d2d", size = 101767 }, +] + [[package]] name = "attrs" version = "24.2.0" @@ -704,7 +713,6 @@ version = "12.1.3.1" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/37/6d/121efd7382d5b0284239f4ab1fc1590d86d34ed4a4a2fdb13b30ca8e5740/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728", size = 410594774 }, - { url = "https://files.pythonhosted.org/packages/c5/ef/32a375b74bea706c93deea5613552f7c9104f961b21df423f5887eca713b/nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906", size = 439918445 }, ] [[package]] @@ -713,7 +721,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/7e/00/6b218edd739ecfc60524e585ba8e6b00554dd908de2c9c66c1af3e44e18d/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e", size = 14109015 }, - { url = "https://files.pythonhosted.org/packages/d0/56/0021e32ea2848c24242f6b56790bd0ccc8bf99f973ca790569c6ca028107/nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4", size = 10154340 }, ] [[package]] @@ -722,7 +729,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2", size = 23671734 }, - { url = "https://files.pythonhosted.org/packages/ad/1d/f76987c4f454eb86e0b9a0e4f57c3bf1ac1d13ad13cd1a4da4eb0e0c0ce9/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed", size = 19331863 }, ] [[package]] @@ -731,7 +737,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/eb/d5/c68b1d2cdfcc59e72e8a5949a37ddb22ae6cade80cd4a57a84d4c8b55472/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40", size = 823596 }, - { url = "https://files.pythonhosted.org/packages/9f/e2/7a2b4b5064af56ea8ea2d8b2776c0f2960d95c88716138806121ae52a9c9/nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344", size = 821226 }, ] [[package]] @@ -743,7 +748,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, - { url = "https://files.pythonhosted.org/packages/3f/d0/f90ee6956a628f9f04bf467932c0a25e5a7e706a684b896593c06c82f460/nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a", size = 679925892 }, ] [[package]] @@ -752,7 +756,6 @@ version = "11.0.2.54" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/86/94/eb540db023ce1d162e7bea9f8f5aa781d57c65aed513c33ee9a5123ead4d/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56", size = 121635161 }, - { url = "https://files.pythonhosted.org/packages/f7/57/7927a3aa0e19927dfed30256d1c854caf991655d847a4e7c01fe87e3d4ac/nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253", size = 121344196 }, ] [[package]] @@ -761,7 +764,6 @@ version = "10.3.2.106" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/44/31/4890b1c9abc496303412947fc7dcea3d14861720642b49e8ceed89636705/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0", size = 56467784 }, - { url = "https://files.pythonhosted.org/packages/5c/97/4c9c7c79efcdf5b70374241d48cf03b94ef6707fd18ea0c0f53684931d0b/nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a", size = 55995813 }, ] [[package]] @@ -775,7 +777,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, - { url = "https://files.pythonhosted.org/packages/b8/80/8fca0bf819122a631c3976b6fc517c1b10741b643b94046bd8dd451522c5/nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5", size = 121643081 }, ] [[package]] @@ -787,7 +788,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, - { url = "https://files.pythonhosted.org/packages/0f/95/48fdbba24c93614d1ecd35bc6bdc6087bd17cbacc3abc4b05a9c2a1ca232/nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a", size = 195414588 }, ] [[package]] @@ -806,7 +806,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/58/8c/69c9e39cd6bfa813852a94e9bd3c075045e2707d163e9dc2326c82d2c330/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b3fd0779845f68b92063ab1393abab1ed0a23412fc520df79a8190d098b5cd6b", size = 19253287 }, { url = "https://files.pythonhosted.org/packages/a8/48/a9775d377cb95585fb188b469387f58ba6738e268de22eae2ad4cedb2c41/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl", hash = "sha256:125a6c2a44e96386dda634e13d944e60b07a0402d391a070e8fb4104b34ea1ab", size = 19725597 }, - { url = "https://files.pythonhosted.org/packages/00/d5/02af3b39427ed71e8c40b6912271499ec186a72405bcb7e4ca26ff70678c/nvidia_nvjitlink_cu12-12.6.68-py3-none-win_amd64.whl", hash = "sha256:a55744c98d70317c5e23db14866a8cc2b733f7324509e941fc96276f9f37801d", size = 161730369 }, ] [[package]] @@ -815,7 +814,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5", size = 99138 }, - { url = "https://files.pythonhosted.org/packages/b8/d7/bd7cb2d95ac6ac6e8d05bfa96cdce69619f1ef2808e072919044c2d47a8c/nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82", size = 66307 }, ] [[package]] @@ -1791,10 +1789,13 @@ name = "zeroband" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "aiohttp" }, + { name = "asyncio" }, { name = "datasets" }, { name = "einops" }, { name = "numpy" }, { name = "pydantic-config" }, + { name = "requests" }, { name = "setuptools" }, { name = "torch" }, { name = "transformers" }, @@ -1814,10 +1815,13 @@ dev = [ [package.metadata] requires-dist = [ + { name = "aiohttp", specifier = ">=3.10.5" }, + { name = "asyncio", specifier = ">=3.4.3" }, { name = "datasets", specifier = ">=3.0.0" }, { name = "einops" }, { name = "numpy" }, { name = "pydantic-config", git = "https://github.com/samsja/pydantic_config.git?rev=e529c9c" }, + { name = "requests", specifier = ">=2.32.3" }, { name = "setuptools" }, { name = "torch", specifier = "==2.4.1" }, { name = "transformers", specifier = ">=4.44.2" },