Merge pull request #32 from RaoFoundation/feat/tegriddy
feat: make tegriddy happy
distributedstatemachine authored Nov 30, 2024
2 parents bbecd91 + 56a8526 commit 5ce9b26
Showing 3 changed files with 78 additions and 29 deletions.
28 changes: 21 additions & 7 deletions neurons/miner.py
@@ -60,6 +60,7 @@ def config():
parser.add_argument('--no_autoupdate', action='store_true', help='Disable automatic updates')
parser.add_argument("--process_name", type=str, help="The name of the PM2 process")
parser.add_argument('--checkpoint_path', type=str, default=None, help='Path to save/load the checkpoint. If None, the path is set to checkpoint-M<UID>.pth.')
parser.add_argument('--save-location', type=str, default=None, help='Directory to save/load slice files')
bt.wallet.add_args(parser)
bt.subtensor.add_args(parser)
config = bt.config(parser)
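
Note: the new flag is spelled `--save-location`, but it is read back later as `config.save_location` because argparse converts dashes in long option names to underscores when it derives attribute names. A minimal sketch of that behavior (the stand-alone parser and the path below are illustrative only, not the repo's):

```python
import argparse

# Illustrative stand-alone parser; the real one is built in config() above.
parser = argparse.ArgumentParser()
parser.add_argument('--save-location', type=str, default=None,
                    help='Directory to save/load slice files')
args = parser.parse_args(['--save-location', '/var/tmp/templar-slices'])

# argparse exposes '--save-location' as the attribute 'save_location'.
print(args.save_location)  # /var/tmp/templar-slices
```
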
@@ -222,7 +223,13 @@ def __init__(self):
self.new_window_event = asyncio.Event()
self.stop_event = asyncio.Event()
self.last_full_steps = self.hparams.desired_batch_size // self.config.actual_batch_size
bt.logging.off
bt.logging.off
self.save_location = self.config.save_location
if self.save_location is None:
import tempfile
self.save_location = tempfile.gettempdir()
else:
os.makedirs(self.save_location, exist_ok=True)
print ( self.hparams )

async def update(self):
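
The same default applies in the validator below: when `--save-location` is omitted, slice files land in the system temporary directory, and an explicitly supplied directory is created if it does not exist. A standalone sketch of that resolution logic (`resolve_save_location` is a hypothetical helper, not a function in this repo):

```python
import os
import tempfile
from typing import Optional

def resolve_save_location(save_location: Optional[str]) -> str:
    # Mirrors the __init__ block above: fall back to the system temp dir,
    # otherwise ensure the requested directory exists.
    if save_location is None:
        return tempfile.gettempdir()
    os.makedirs(save_location, exist_ok=True)
    return save_location

print(resolve_save_location(None))            # e.g. /tmp
print(resolve_save_location('slice-cache'))   # created relative to cwd if missing
```
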
@@ -282,6 +289,7 @@ async def run(self):
window = window,
seed = window,
compression = self.hparams.compression,
save_location=self.save_location,
key = 'state'
)
if max_global_step is not None:
@@ -326,7 +334,8 @@ async def run(self):
state_slices = await tplr.download_slices_for_buckets_and_windows(
buckets=valid_buckets,
windows=[window],
key='state'
key='state',
save_location=self.save_location
)

n_state_slices = len(state_slices[window]) if window in state_slices else 0
@@ -337,7 +346,8 @@
delta_slices = await tplr.download_slices_for_buckets_and_windows(
buckets = self.buckets,
windows = [ window - 1 ],
key = 'delta'
key = 'delta',
save_location=self.save_location
)
n_slices = len(delta_slices[ window - 1 ]) if window - 1 in delta_slices else 0
tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Download {n_slices} window deltas.")
@@ -349,6 +359,7 @@ async def run(self):
window=window,
seed=window,
compression=self.hparams.compression,
save_location=self.save_location,
key='state'
)
if max_global_step is not None:
@@ -424,8 +435,9 @@ async def run(self):
seed = window,
wallet = self.wallet,
compression = self.hparams.compression,
save_location = self.save_location,
key = 'delta',
global_step=self.global_step
global_step = self.global_step
)
tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Uploaded the delta.")

@@ -436,6 +448,7 @@ async def run(self):
window=window - 1,
seed=window - 1,
compression=self.hparams.compression,
save_location=self.save_location,
key='delta'
)
if max_global_step is not None:
@@ -452,15 +465,16 @@ async def run(self):
seed = window + 1,
wallet = self.wallet,
compression = self.hparams.compression,
save_location = self.save_location,
key = 'state',
global_step=self.global_step
global_step = self.global_step
)
tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Uploaded the state.")

# Clean file history.
st = tplr.T()
await tplr.delete_files_before_window( window_max = window - self.hparams.max_history, key = 'state')
await tplr.delete_files_before_window( window_max = window - self.hparams.max_history, key = 'delta')
await tplr.delete_files_before_window(window_max=window - self.hparams.max_history, save_location=self.save_location, key='state')
await tplr.delete_files_before_window(window_max=window - self.hparams.max_history, save_location=self.save_location, key='delta')
await tplr.delete_files_from_bucket_before_window( bucket = tplr.config.BUCKET_SECRETS["bucket_name"], window_max = window - self.hparams.max_history, key = 'state' )
await tplr.delete_files_from_bucket_before_window( bucket = tplr.config.BUCKET_SECRETS["bucket_name"], window_max = window - self.hparams.max_history, key = 'delta' )
tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Cleaned file history.")
23 changes: 18 additions & 5 deletions neurons/validator.py
@@ -59,6 +59,7 @@ def config():
parser.add_argument('--no_autoupdate', action='store_true', help='Disable automatic updates')
parser.add_argument("--process_name", type=str, help="The name of the PM2 process")
parser.add_argument('--checkpoint_path', type=str, default=None, help='Path to save/load the checkpoint. If None, the path is set to checkpoint-V<UID>.pth.')
parser.add_argument('--save-location', type=str, default=None, help='Directory to save/load slice files')
bt.wallet.add_args(parser)
bt.subtensor.add_args(parser)
config = bt.config(parser)
@@ -208,6 +209,12 @@ def __init__(self):
self.scores = torch.zeros( 256, dtype = torch.float32 )
self.weights = torch.zeros( 256, dtype = torch.float32 )
self.sample_rate = 1.0
self.save_location = self.config.save_location
if self.save_location is None:
import tempfile
self.save_location = tempfile.gettempdir()
else:
os.makedirs(self.save_location, exist_ok=True)
print ( self.hparams )

async def update(self):
@@ -256,14 +263,16 @@ async def run(self):
state_slices = await tplr.download_slices_for_buckets_and_windows(
buckets=[b for b in self.buckets if b is not None],
windows = history_windows,
key = 'state'
key = 'state',
save_location=self.save_location
)
for window in tqdm(history_windows, desc="Syncing state"):
max_global_step = await tplr.apply_slices_to_model(
model=self.model,
window=window,
seed=window,
compression=self.hparams.compression,
save_location=self.save_location,
key='state',
)
if max_global_step is not None:
@@ -309,7 +318,8 @@ async def run(self):
state_slices = await tplr.download_slices_for_buckets_and_windows(
buckets=valid_buckets,
windows=[window],
key='state'
key='state',
save_location=self.save_location
)
n_state_slices = len(state_slices[window]) if window in state_slices else 0
tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Downloaded {n_state_slices} window states.")
@@ -319,7 +329,8 @@ async def run(self):
eval_slices = await tplr.download_slices_for_buckets_and_windows(
buckets = self.buckets,
windows = [ window ],
key = 'delta'
key = 'delta',
save_location=self.save_location
)
n_eval_slices = len(eval_slices[ window ]) if window in eval_slices else 0
tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Downloaded {n_eval_slices} window deltas.")
@@ -346,6 +357,7 @@ async def run(self):
window=window,
seed=window,
compression=self.hparams.compression,
save_location=self.save_location,
key='state',
)
if max_global_step is not None:
@@ -498,6 +510,7 @@ async def run(self):
window=window,
seed=window,
compression=self.hparams.compression,
save_location=self.save_location,
key='delta',
)
if max_global_step is not None:
@@ -506,8 +519,8 @@ async def run(self):

# Clean local and remote space from old slices.
st = tplr.T()
await tplr.delete_files_before_window( window_max = window - self.hparams.max_history, key = 'state')
await tplr.delete_files_before_window( window_max = window - self.hparams.max_history, key = 'delta')
await tplr.delete_files_before_window(window_max=window - self.hparams.max_history, save_location=self.save_location, key='state')
await tplr.delete_files_before_window(window_max=window - self.hparams.max_history, save_location=self.save_location, key='delta')
await tplr.delete_files_from_bucket_before_window( bucket = tplr.config.BUCKET_SECRETS["bucket_name"], window_max = window - self.hparams.max_history, key = 'state' )
await tplr.delete_files_from_bucket_before_window( bucket = tplr.config.BUCKET_SECRETS["bucket_name"], window_max = window - self.hparams.max_history, key = 'delta' )
tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Cleaned file history.")
56 changes: 39 additions & 17 deletions src/templar/comms.py
@@ -127,7 +127,12 @@ async def get_slices(filename: str, device: str) -> Dict[str, torch.Tensor]:


async def apply_slices_to_model(
model: torch.nn.Module, window: int, seed: str, compression: int, key: str = "slice"
model: torch.nn.Module,
window: int,
seed: str,
compression: int,
save_location: str,
key: str = "slice",
) -> int:
"""
Applies downloaded model parameter slices to a model for a specific window.
@@ -169,7 +174,9 @@ async def apply_slices_to_model(
"""
max_global_step = 0
indices_dict = await get_indices_for_window(model, seed, compression)
slice_files = await load_files_for_window(window=window, key=key)
slice_files = await load_files_for_window(
window=window, save_location=save_location, key=key
)

slices_per_param = {name: 0 for name, _ in model.named_parameters()}
param_sums = {
@@ -238,6 +245,7 @@ async def upload_slice_for_window(
seed: str,
wallet: "bt.wallet",
compression: int,
save_location: str,
key: str = "slice",
global_step: int = 0,
):
@@ -280,10 +288,9 @@
for name, param in model.named_parameters():
slice_data[name] = param.data.view(-1)[indices[name].to(model.device)].cpu()

# Create a temporary file and write the sliced model state dictionary to it
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
torch.save(slice_data, temp_file)
temp_file_name = temp_file.name # Store the temporary file name
# Use save_location for temporary file
temp_file_name = os.path.join(save_location, filename)
torch.save(slice_data, temp_file_name)

# Upload the file to S3
session = get_session()
@@ -375,7 +382,9 @@ async def get_indices_for_window(
return result


async def download_file(s3_client, bucket: str, filename: str) -> str:
async def download_file(
s3_client, bucket: str, filename: str, save_location: str
) -> str:
"""
Downloads a file from S3, using parallel downloads for large files.
@@ -388,7 +397,7 @@ async def download_file(s3_client, bucket: str, filename: str) -> str:
str: The path to the downloaded file in the temporary directory.
"""
async with semaphore:
temp_file = os.path.join(tempfile.gettempdir(), filename)
temp_file = os.path.join(save_location, filename)
# Check if the file exists.
if os.path.exists(temp_file):
logger.debug(f"File {temp_file} already exists, skipping download.")
@@ -429,7 +438,13 @@ async def download_file(s3_client, bucket: str, filename: str) -> str:


async def handle_file(
s3_client, bucket: str, filename: str, hotkey: str, window: int, version: str
s3_client,
bucket: str,
filename: str,
hotkey: str,
window: int,
version: str,
save_location: str,
):
"""
Handles downloading a single file from S3.
@@ -449,7 +464,7 @@
logger.debug(
f"Handling file '{filename}' for window {window} and hotkey '{hotkey}'"
)
temp_file = await download_file(s3_client, bucket, filename)
temp_file = await download_file(s3_client, bucket, filename, save_location)
if temp_file:
return SimpleNamespace(
bucket=bucket,
Expand All @@ -463,7 +478,7 @@ async def handle_file(


async def process_bucket(
s3_client, bucket: str, windows: List[int], key: str = "slice"
s3_client, bucket: str, windows: List[int], key: str, save_location: str
):
"""
Processes a single S3 bucket to download files for specified windows.
@@ -550,6 +565,7 @@ async def process_bucket(
slice_hotkey,
window,
slice_version,
save_location,
)
)
except ValueError:
@@ -586,7 +602,7 @@


async def download_slices_for_buckets_and_windows(
buckets: List[Bucket], windows: List[int], key: str = "slice"
buckets: List[Bucket], windows: List[int], key: str, save_location: str
) -> Dict[int, List[SimpleNamespace]]:
"""Downloads model slices from multiple S3 buckets for specified windows.
@@ -642,7 +658,9 @@
aws_secret_access_key=bucket.secret_access_key,
) as s3_client:
logger.debug(f"Processing bucket: {bucket.name}")
tasks.append(process_bucket(s3_client, bucket.name, windows, key))
tasks.append(
process_bucket(s3_client, bucket.name, windows, key, save_location)
)

results = await asyncio.gather(*tasks)
# Combine results into a dictionary mapping window IDs to lists of slices
@@ -653,7 +671,9 @@
return slices
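
Since `key` and `save_location` no longer have defaults here (or in `process_bucket`), every caller now has to pass both explicitly. A hedged sketch of the updated call shape, mirroring the miner and validator call sites above; `self` stands in for a neuron instance and the package is assumed importable as `templar`:

```python
import templar as tplr

# Sketch only: `self` is a miner/validator with buckets and save_location set.
async def fetch_state_slices(self, window):
    state_slices = await tplr.download_slices_for_buckets_and_windows(
        buckets=[b for b in self.buckets if b is not None],
        windows=[window],
        key='state',                       # no longer defaulted to "slice"
        save_location=self.save_location,  # local directory the slice files land in
    )
    return len(state_slices.get(window, []))
```
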


async def load_files_for_window(window: int, key: str = "slice") -> List[str]:
async def load_files_for_window(
window: int, save_location: str, key: str = "slice"
) -> List[str]:
"""
Loads files for a specific window from the temporary directory.
@@ -676,7 +696,7 @@ async def load_files_for_window(window: int, key: str = "slice") -> List[str]:
"""

logger.debug(f"Retrieving files for window {window} from temporary directory")
temp_dir = tempfile.gettempdir()
temp_dir = save_location
window_files = []
pattern = re.compile(rf"^{key}-{window}-.+-v{__version__}\.pt$")
for filename in os.listdir(temp_dir):
@@ -686,7 +706,9 @@ async def load_files_for_window(window: int, key: str = "slice") -> List[str]:
return window_files
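
For reference, the filenames this pattern accepts look like `<key>-<window>-<middle segment>-v<version>.pt`, where the version comes from the package's `__version__` string and the middle segment is whatever the uploader placed between the window id and the version tag. A small illustration with made-up values:

```python
import re

# Made-up values purely for illustration.
key, window, version = "state", 42, "0.1.0"
pattern = re.compile(rf"^{key}-{window}-.+-v{re.escape(version)}\.pt$")

print(bool(pattern.match("state-42-5F3sa9hotkey-v0.1.0.pt")))  # True
print(bool(pattern.match("state-41-5F3sa9hotkey-v0.1.0.pt")))  # False: other window
print(bool(pattern.match("delta-42-5F3sa9hotkey-v0.1.0.pt")))  # False: other key
```
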


async def delete_files_before_window(window_max: int, key: str = "slice"):
async def delete_files_before_window(
window_max: int, save_location: str, key: str = "slice"
):
"""
Deletes temporary files with window IDs less than the specified maximum.
@@ -706,7 +728,7 @@ async def delete_files_before_window(window_max: int, key: str = "slice"):
"""

logger.debug(f"Deleting files with window id before {window_max}")
temp_dir = tempfile.gettempdir()
temp_dir = save_location
pattern = re.compile(rf"^{re.escape(key)}-(\d+)-.+-v{__version__}\.(pt|pt\.lock)$")
for filename in os.listdir(temp_dir):
match = pattern.match(filename)
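
Taken together, the new calling convention roots every local slice file at an explicit `save_location` instead of the process-wide `tempfile.gettempdir()`. A hedged end-to-end sketch of how the argument is threaded through download, apply, and cleanup for one window (`sync_window` is a hypothetical wrapper; `model`, `buckets`, and `hparams` are assumed to exist in the caller, and the package is assumed importable as `templar`):

```python
import templar as tplr

async def sync_window(model, buckets, hparams, window, save_location):
    # Hypothetical wrapper: fetch peers' state slices into save_location,
    # apply them to the model, then prune files older than max_history.
    await tplr.download_slices_for_buckets_and_windows(
        buckets=buckets, windows=[window], key='state', save_location=save_location,
    )
    await tplr.apply_slices_to_model(
        model=model, window=window, seed=window,
        compression=hparams.compression, save_location=save_location, key='state',
    )
    await tplr.delete_files_before_window(
        window_max=window - hparams.max_history, save_location=save_location, key='state',
    )
```
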
