Added pytest pre-commit and workflow. Some minor refactoring.
mattjhawken committed Jan 26, 2025
1 parent 5aca42d · commit 0435125
Showing 5 changed files with 134 additions and 125 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/pytest.yaml
@@ -0,0 +1,27 @@
name: PyTest for Node and Distributed ML

on: [pull_request]

jobs:
  test:
    runs-on: ubuntu-latest

    strategy:
      matrix:
        python-version: ["3.11", "3.12"]

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt pytest
      - name: Run Unit Tests
        run: |
          pytest --maxfail=5 --disable-warnings --cov=. --cov-report=term
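
Note: the `--cov` and `--cov-report` flags above come from the pytest-cov plugin, which the install step never names; the job only works if requirements.txt happens to pull pytest-cov in. A more self-contained install step might look like this (a sketch, assuming pytest-cov is not already pinned in requirements.txt):

```yaml
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          # pytest-cov supplies the --cov/--cov-report flags used by the test step
          pip install -r requirements.txt pytest pytest-cov
```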
10 changes: 9 additions & 1 deletion .pre-commit-config.yaml
@@ -8,7 +8,15 @@ repos:
     rev: 24.10.0
     hooks:
       - id: black
-  #
+
+  - repo: local
+    hooks:
+      - id: pytest
+        name: Run fast tests
+        entry: pytest -m "not slow"
+        language: system
+        pass_filenames: false
+
   # - repo: https://github.com/pre-commit/mirrors-mypy
   #   rev: v1.14.1
   #   hooks:
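
The `-m "not slow"` filter in the local hook above only bites if long-running tests are actually tagged. A minimal sketch of how a test would opt out of the pre-commit run (the `slow` marker name is what the hook assumes; the test name is hypothetical):

```python
import pytest


@pytest.mark.slow  # excluded by the hook's `pytest -m "not slow"`; still runs in CI
def test_full_multi_node_training():
    ...
```

To avoid PytestUnknownMarkWarning, the marker would also need to be registered, e.g. `markers = ["slow: long-running tests"]` under `[tool.pytest.ini_options]` in pyproject.toml.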
11 changes: 5 additions & 6 deletions examples/distributed_workflow.py
@@ -51,17 +51,16 @@
 validator = ValidatorNode(
     upnp=UPNP, off_chain_test=LOCAL, local_test=LOCAL, print_level=logging.DEBUG
 )
-time.sleep(
-    2
-)  # Temporary sleep for preventing two nodes from starting on the same port and conflicting
+# Temporary sleep for preventing two nodes from starting on the same port and conflicting
+time.sleep(1)
 user = UserNode(
     upnp=UPNP, off_chain_test=LOCAL, local_test=LOCAL, print_level=logging.DEBUG
 )
-time.sleep(2)
+time.sleep(1)
 worker = WorkerNode(
     upnp=UPNP, off_chain_test=LOCAL, local_test=LOCAL, print_level=logging.DEBUG
 )
-time.sleep(2)
+time.sleep(1)
 
 # Get validator node information for connecting
 val_key, val_host, val_port = validator.send_request("info", None)
@@ -91,7 +90,7 @@
 distributed_model.train()
 for _ in range(5):
     distributed_optimizer.zero_grad()  # Distributed optimizer calls relay to worker nodes
-    x = torch.zeros((1, 1), dtype=torch.long)
+    x = torch.zeros((1, 1))
     outputs = distributed_model(x)
     outputs = outputs.logits
     loss = mse_loss(outputs, outputs)
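
One caveat on the change above: dropping `dtype=torch.long` is only safe if the distributed model no longer expects token indices. An embedding-based model such as the BERT variant used elsewhere in this repo rejects float inputs, so under that assumption the stricter form would be (a sketch):

```python
import torch

# Embedding lookups require integer (long) indices, not floats.
x = torch.zeros((1, 1), dtype=torch.long)
```

Note also that `mse_loss(outputs, outputs)` is identically zero, so the loop exercises the distributed forward/backward plumbing rather than doing any real learning.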
193 changes: 87 additions & 106 deletions tensorlink/p2p/connection.py
@@ -1,22 +1,17 @@
+import base64
+import gc
+import hashlib
+from datetime import datetime
+from typing import Union
 import threading
 import logging
 import os
-import hashlib
 import socket
-import threading
-import base64
 import time
 import zlib
-from datetime import datetime
-from typing import Union
 
-CHUNK_SIZE = 2048
-import os
-import gc
 
 
+def join_writing_threads(writing_threads):
+    """Joins all writing threads."""
+    for t in writing_threads:
+        t.join()
+    writing_threads.clear()
+
+
+CHUNK_SIZE = 2048
 
 
 class Connection(threading.Thread):
@@ -64,104 +59,93 @@ def run(self):
         writing_threads = []
 
         while not self.terminate_flag.is_set():
+            chunk = self.receive_chunk()
+            if chunk is None:  # Indicates a disconnection or error
-            file_name = f"tmp/streamed_data_{self.host}_{self.port}_{self.main_node.host}_{self.main_node.port}"
-            try:
-                chunk = self.sock.recv(
-                    self.chunk_size,
-                )
-
-            except socket.timeout:
-                # self.main_node.debug_print("connection timeout")
                 continue
-
-            except (ConnectionResetError, ConnectionAbortedError) as e:
-                # Handle disconnections
-                self.terminate_flag.set()
-                self.main_node.debug_print(
-                    f"Connection -> Connection lost: {e}",
-                    colour="bright_red",
-                    level=logging.ERROR,
-                )
-                self.main_node.disconnect_node(self.node_id)
-                break
 
+            self.update_last_seen()
-            except Exception as e:
-                self.terminate_flag.set()
-                self.main_node.debug_print(
-                    f"Connection -> unexpected error: {e}",
-                    colour="bright_red",
-                    level=logging.ERROR,
-                )
-                self.main_node.disconnect_node(self.node_id)
-                break
 
             if chunk:
+                buffer, prefix = self.process_chunk(
+                    chunk, buffer, prefix, writing_threads
+                )
-                self.last_seen = datetime.now()
-
-                if b"MODULE" == chunk[:6]:
-                    prefix = chunk[:70]  # MODULE + module_id
-                    buffer += chunk[70:]
-                elif b"PARAMETERS" == chunk[:10]:
-                    prefix = chunk[:74]  # PARAMETERS + module_id
-                    buffer += chunk[74:]
-                else:
-                    buffer += chunk
-
-                eot_pos = buffer.find(self.EOT_CHAR)
-
-                if eot_pos >= 0:
-                    packet = buffer[:eot_pos]
-                    try:
-                        with open(file_name, "ab") as f:
-                            f.write(packet)
-                    except Exception as e:
-                        self.main_node.debug_print(
-                            f"Connection -> file writing error: {e}",
-                            colour="bright_red",
-                            level=logging.ERROR,
-                        )
-
-                    gc.collect()
-                    buffer = buffer[eot_pos + len(self.EOT_CHAR) :]
-                    self.main_node.handle_message(self, b"DONE STREAM" + prefix)
-                    prefix = b""
 
+        self.cleanup()
         for t in writing_threads:
             t.join()
 
+    def receive_chunk(self):
+        """Handles receiving a chunk of data from the socket."""
+        try:
+            return self.sock.recv(self.chunk_size)
+        except socket.timeout:
+            return None
+        except (ConnectionResetError, ConnectionAbortedError) as e:
+            self.handle_disconnection(f"Connection lost: {e}")
+            return None
+        except Exception as e:
+            self.handle_unexpected_error(e)
+            return None
+
+    def process_chunk(self, chunk, buffer, prefix, writing_threads):
+        """Processes a received chunk and updates the buffer and prefix."""
+        if chunk.startswith(b"MODULE"):
+            prefix = chunk[:70]  # MODULE + module_id
+            buffer += chunk[70:]
+        elif chunk.startswith(b"PARAMETERS"):
+            prefix = chunk[:74]  # PARAMETERS + module_id
+            buffer += chunk[74:]
+        else:
+            buffer += chunk
+
+        buffer = self.handle_packet(buffer, prefix, writing_threads)
+        return buffer, prefix
+
+    def handle_packet(self, buffer, prefix, writing_threads):
+        """Handles a complete packet if the EOT_CHAR is found."""
+        eot_pos = buffer.find(self.EOT_CHAR)
+        if eot_pos >= 0:
+            packet = buffer[:eot_pos]
+            self.write_packet_to_file(packet)
+            buffer = buffer[eot_pos + len(self.EOT_CHAR) :]
+            self.main_node.handle_message(self, b"DONE STREAM" + prefix)
+            join_writing_threads(writing_threads)
+        elif len(buffer) > 20_000_000:
+            self.start_writing_thread(buffer, writing_threads)
+            buffer = b""
+        return buffer
+
+    def write_packet_to_file(self, packet):
+        """Writes a complete packet to the file."""
+        file_name = self.get_file_name()
+        try:
+            with open(file_name, "ab") as f:
+                f.write(packet)
+        except Exception as e:
+            self.main_node.debug_print(
+                f"Connection -> file writing error: {e}",
+                colour="bright_red",
+                level=logging.ERROR,
+            )
-                    writing_threads = []
 
+    def start_writing_thread(self, buffer, writing_threads):
+        """Starts a new thread to write buffer data to a file."""
+        file_name = self.get_file_name()
+        t = threading.Thread(target=self.write_to_file, args=(file_name, buffer))
+        writing_threads.append(t)
+        t.start()
-                elif len(buffer) > 20_000_000:
-                    t = threading.Thread(
-                        target=self.write_to_file, args=(file_name, buffer)
-                    )
-                    writing_threads.append(t)
-                    t.start()
 
+    def handle_disconnection(self, message):
+        """Handles disconnection scenarios."""
+        self.terminate_flag.set()
+        self.main_node.debug_print(message, colour="bright_red", level=logging.ERROR)
+        self.main_node.disconnect_node(self.node_id)
-                    buffer = b""
 
-            elif chunk == b"":
-                self.terminate_flag.set()
-                self.main_node.debug_print(
-                    f"Connection -> Connection lost.",
-                    colour="bright_red",
-                    level=logging.ERROR,
-                )
-                self.main_node.disconnect_node(self.node_id)
-                break
 
-            gc.collect()
 
+    def handle_unexpected_error(self, error):
+        """Handles unexpected exceptions."""
+        self.terminate_flag.set()
+        self.main_node.debug_print(
+            f"Connection -> unexpected error: {error}",
+            colour="bright_red",
+            level=logging.ERROR,
+        )
+        self.main_node.disconnect_node(self.node_id)
+
+    def update_last_seen(self):
+        """Updates the last seen timestamp."""
+        self.last_seen = datetime.now()
+
     def cleanup(self):
         """Closes the socket and performs cleanup."""
         try:
             self.sock.shutdown(socket.SHUT_RDWR)
         except OSError:
@@ -170,10 +154,6 @@ def cleanup(self):
         self.sock.settimeout(None)
         self.sock.close()
 
-    def get_file_name(self):
-        """Generates the file name for storing streamed data."""
-        return f"tmp/streamed_data_{self.host}_{self.port}_{self.main_node.host}_{self.main_node.port}"

     def send(self, data: bytes, compression: bool = False):
         try:
             if compression:
@@ -220,6 +200,7 @@ def send_from_file(self, file_name: str, tag: bytes):
with open(file_name, "rb") as file:
self.sock.sendall(tag)

start_time = time.time()
# Get the total file size
total_size = os.fstat(file.fileno()).st_size
chunk_size = self.chunk_size
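
The framing these helpers parse is a simple tag-plus-terminator protocol: an optional fixed-width ASCII tag precedes the payload (`MODULE` plus a module_id makes the 70-byte prefix, `PARAMETERS` plus a module_id makes 74 bytes, implying a 64-character id), and `EOT_CHAR` terminates each message. A sketch of what a conforming sender would emit (the EOT byte and the hex-digest id format are assumptions inferred from those offsets):

```python
import hashlib

EOT_CHAR = b"\x04"  # assumed sentinel; the class defines the real self.EOT_CHAR

# A sha256 hexdigest is 64 ASCII bytes, matching the 70 - len(b"MODULE") offset.
module_id = hashlib.sha256(b"example-module").hexdigest().encode()
message = b"MODULE" + module_id + b"<serialized payload>" + EOT_CHAR
```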
18 changes: 6 additions & 12 deletions tests/ml/test_local_dist_train.py
@@ -91,32 +91,26 @@ def test_node_connectivity(nodes):
     assert val_port is not None, "Validator port is None."
 
 
-def test_distributed_training(nodes):
+def test_distributed_setup(nodes):
     """Test distributed training with a simple model."""
     _, user, _ = nodes
 
     # Create model and tokenizer
     import torch
-    from torch.nn.functional import mse_loss
-    from transformers import BertForSequenceClassification, BertTokenizer
+    import torch.nn as nn
 
-    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-    model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
+    model = nn.ModuleList([nn.Linear(10, 10)])
 
     distributed_model, distributed_optimizer = user.create_distributed_model(
         model=model, training=True, optimizer_type=torch.optim.Adam
     )
-    del model  # Free up memory
+    del model
     distributed_optimizer = distributed_optimizer(lr=0.001, weight_decay=0.01)
 
-    # Training loop
     distributed_model.train()
-    for _ in range(5):
-        distributed_optimizer.zero_grad()
-        x = torch.zeros((1, 1), dtype=torch.long)
-        outputs = distributed_model(x)
-        loss = mse_loss(outputs.logits, outputs.logits)
-        loss.backward()
-        distributed_optimizer.step()
+    distributed_optimizer.zero_grad()
 
     assert distributed_model is not None, "Distributed model is None."
     assert distributed_optimizer is not None, "Distributed optimizer is None."