From 1ab6fdd5cd6b4d0f646265f6a4e6eba0ed98ce46 Mon Sep 17 00:00:00 2001 From: distributedstatemachine! Date: Fri, 29 Nov 2024 23:08:09 +0000 Subject: [PATCH 1/3] feat: make tegriddy happy --- neurons/miner.py | 22 +++++++++++++++++----- neurons/validator.py | 22 +++++++++++++++++----- src/templar/comms.py | 35 ++++++++++++++++++----------------- 3 files changed, 52 insertions(+), 27 deletions(-) diff --git a/neurons/miner.py b/neurons/miner.py index 38e573b..0fa7496 100644 --- a/neurons/miner.py +++ b/neurons/miner.py @@ -60,6 +60,7 @@ def config(): parser.add_argument('--no_autoupdate', action='store_true', help='Disable automatic updates') parser.add_argument("--process_name", type=str, help="The name of the PM2 process") parser.add_argument('--checkpoint_path', type=str, default=None, help='Path to save/load the checkpoint. If None, the path is set to checkpoint-M.pth.') + parser.add_argument('--save-location', type=str, default=None, help='Directory to save/load slice files') bt.wallet.add_args(parser) bt.subtensor.add_args(parser) config = bt.config(parser) @@ -222,7 +223,13 @@ def __init__(self): self.new_window_event = asyncio.Event() self.stop_event = asyncio.Event() self.last_full_steps = self.hparams.desired_batch_size // self.config.actual_batch_size - bt.logging.off + bt.logging.off + self.save_location = self.config.save_location + if self.save_location is None: + import tempfile + self.save_location = tempfile.gettempdir() + else: + os.makedirs(self.save_location, exist_ok=True) print ( self.hparams ) async def update(self): @@ -282,6 +289,7 @@ async def run(self): window = window, seed = window, compression = self.hparams.compression, + save_location=self.save_location, key = 'state' ) if max_global_step is not None: @@ -326,7 +334,8 @@ async def run(self): state_slices = await tplr.download_slices_for_buckets_and_windows( buckets=valid_buckets, windows=[window], - key='state' + key='state', + save_location=self.save_location ) n_state_slices = len(state_slices[window]) if window in state_slices else 0 @@ -337,7 +346,8 @@ async def run(self): delta_slices = await tplr.download_slices_for_buckets_and_windows( buckets = self.buckets, windows = [ window - 1 ], - key = 'delta' + key = 'delta', + save_location=self.save_location ) n_slices = len(delta_slices[ window - 1 ]) if window - 1 in delta_slices else 0 tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Download {n_slices} window deltas.") @@ -349,6 +359,7 @@ async def run(self): window=window, seed=window, compression=self.hparams.compression, + save_location=self.save_location, key='state' ) if max_global_step is not None: @@ -436,6 +447,7 @@ async def run(self): window=window - 1, seed=window - 1, compression=self.hparams.compression, + save_location=self.save_location, key='delta' ) if max_global_step is not None: @@ -459,8 +471,8 @@ async def run(self): # Clean file history. 
st = tplr.T() - await tplr.delete_files_before_window( window_max = window - self.hparams.max_history, key = 'state') - await tplr.delete_files_before_window( window_max = window - self.hparams.max_history, key = 'delta') + await tplr.delete_files_before_window(window_max=window - self.hparams.max_history, save_location=self.save_location, key='state') + await tplr.delete_files_before_window(window_max=window - self.hparams.max_history, save_location=self.save_location, key='delta') await tplr.delete_files_from_bucket_before_window( bucket = tplr.config.BUCKET_SECRETS["bucket_name"], window_max = window - self.hparams.max_history, key = 'state' ) await tplr.delete_files_from_bucket_before_window( bucket = tplr.config.BUCKET_SECRETS["bucket_name"], window_max = window - self.hparams.max_history, key = 'delta' ) tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Cleaned file history.") diff --git a/neurons/validator.py b/neurons/validator.py index 63c5a3b..be36020 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -208,6 +208,12 @@ def __init__(self): self.scores = torch.zeros( 256, dtype = torch.float32 ) self.weights = torch.zeros( 256, dtype = torch.float32 ) self.sample_rate = 1.0 + self.save_location = self.config.save_location + if self.save_location is None: + import tempfile + self.save_location = tempfile.gettempdir() + else: + os.makedirs(self.save_location, exist_ok=True) print ( self.hparams ) async def update(self): @@ -256,7 +262,8 @@ async def run(self): state_slices = await tplr.download_slices_for_buckets_and_windows( buckets=[b for b in self.buckets if b is not None], windows = history_windows, - key = 'state' + key = 'state', + save_location=self.save_location ) for window in tqdm(history_windows, desc="Syncing state"): max_global_step = await tplr.apply_slices_to_model( @@ -264,6 +271,7 @@ async def run(self): window=window, seed=window, compression=self.hparams.compression, + save_location=self.save_location, key='state', ) if max_global_step is not None: @@ -309,7 +317,8 @@ async def run(self): state_slices = await tplr.download_slices_for_buckets_and_windows( buckets=valid_buckets, windows=[window], - key='state' + key='state', + save_location=self.save_location ) n_state_slices = len(state_slices[window]) if window in state_slices else 0 tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Downloaded {n_state_slices} window states.") @@ -319,7 +328,8 @@ async def run(self): eval_slices = await tplr.download_slices_for_buckets_and_windows( buckets = self.buckets, windows = [ window ], - key = 'delta' + key = 'delta', + save_location=self.save_location ) n_eval_slices = len(eval_slices[ window ]) if window in eval_slices else 0 tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Downloaded {n_eval_slices} window deltas.") @@ -346,6 +356,7 @@ async def run(self): window=window, seed=window, compression=self.hparams.compression, + save_location=self.save_location, key='state', ) if max_global_step is not None: @@ -498,6 +509,7 @@ async def run(self): window=window, seed=window, compression=self.hparams.compression, + save_location=self.save_location, key='delta', ) if max_global_step is not None: @@ -506,8 +518,8 @@ async def run(self): # Clean local and remote space from old slices. 
st = tplr.T() - await tplr.delete_files_before_window( window_max = window - self.hparams.max_history, key = 'state') - await tplr.delete_files_before_window( window_max = window - self.hparams.max_history, key = 'delta') + await tplr.delete_files_before_window(window_max=window - self.hparams.max_history, save_location=self.save_location, key='state') + await tplr.delete_files_before_window(window_max=window - self.hparams.max_history, save_location=self.save_location, key='delta') await tplr.delete_files_from_bucket_before_window( bucket = tplr.config.BUCKET_SECRETS["bucket_name"], window_max = window - self.hparams.max_history, key = 'state' ) await tplr.delete_files_from_bucket_before_window( bucket = tplr.config.BUCKET_SECRETS["bucket_name"], window_max = window - self.hparams.max_history, key = 'delta' ) tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Cleaned file history.") diff --git a/src/templar/comms.py b/src/templar/comms.py index 0f4766e..138df03 100644 --- a/src/templar/comms.py +++ b/src/templar/comms.py @@ -127,7 +127,7 @@ async def get_slices(filename: str, device: str) -> Dict[str, torch.Tensor]: async def apply_slices_to_model( - model: torch.nn.Module, window: int, seed: str, compression: int, key: str = "slice" + model: torch.nn.Module, window: int, seed: str, compression: int, save_location: str, key: str = "slice" ) -> int: """ Applies downloaded model parameter slices to a model for a specific window. @@ -169,7 +169,7 @@ async def apply_slices_to_model( """ max_global_step = 0 indices_dict = await get_indices_for_window(model, seed, compression) - slice_files = await load_files_for_window(window=window, key=key) + slice_files = await load_files_for_window(window=window, save_location=save_location, key=key) slices_per_param = {name: 0 for name, _ in model.named_parameters()} param_sums = { @@ -238,6 +238,7 @@ async def upload_slice_for_window( seed: str, wallet: "bt.wallet", compression: int, + save_location: str, key: str = "slice", global_step: int = 0, ): @@ -280,10 +281,9 @@ async def upload_slice_for_window( for name, param in model.named_parameters(): slice_data[name] = param.data.view(-1)[indices[name].to(model.device)].cpu() - # Create a temporary file and write the sliced model state dictionary to it - with tempfile.NamedTemporaryFile(delete=False) as temp_file: - torch.save(slice_data, temp_file) - temp_file_name = temp_file.name # Store the temporary file name + # Use save_location for temporary file + temp_file_name = os.path.join(save_location, filename) + torch.save(slice_data, temp_file_name) # Upload the file to S3 session = get_session() @@ -375,7 +375,7 @@ async def get_indices_for_window( return result -async def download_file(s3_client, bucket: str, filename: str) -> str: +async def download_file(s3_client, bucket: str, filename: str, save_location: str) -> str: """ Downloads a file from S3, using parallel downloads for large files. @@ -388,7 +388,7 @@ async def download_file(s3_client, bucket: str, filename: str) -> str: str: The path to the downloaded file in the temporary directory. """ async with semaphore: - temp_file = os.path.join(tempfile.gettempdir(), filename) + temp_file = os.path.join(save_location, filename) # Check if the file exists. 
if os.path.exists(temp_file): logger.debug(f"File {temp_file} already exists, skipping download.") @@ -429,7 +429,7 @@ async def download_file(s3_client, bucket: str, filename: str) -> str: async def handle_file( - s3_client, bucket: str, filename: str, hotkey: str, window: int, version: str + s3_client, bucket: str, filename: str, hotkey: str, window: int, version: str, save_location: str ): """ Handles downloading a single file from S3. @@ -449,7 +449,7 @@ async def handle_file( logger.debug( f"Handling file '{filename}' for window {window} and hotkey '{hotkey}'" ) - temp_file = await download_file(s3_client, bucket, filename) + temp_file = await download_file(s3_client, bucket, filename, save_location) if temp_file: return SimpleNamespace( bucket=bucket, @@ -463,7 +463,7 @@ async def handle_file( async def process_bucket( - s3_client, bucket: str, windows: List[int], key: str = "slice" + s3_client, bucket: str, windows: List[int], key: str, save_location: str ): """ Processes a single S3 bucket to download files for specified windows. @@ -550,6 +550,7 @@ async def process_bucket( slice_hotkey, window, slice_version, + save_location ) ) except ValueError: @@ -586,7 +587,7 @@ async def process_bucket( async def download_slices_for_buckets_and_windows( - buckets: List[Bucket], windows: List[int], key: str = "slice" + buckets: List[Bucket], windows: List[int], key: str, save_location: str ) -> Dict[int, List[SimpleNamespace]]: """Downloads model slices from multiple S3 buckets for specified windows. @@ -642,7 +643,7 @@ async def download_slices_for_buckets_and_windows( aws_secret_access_key=bucket.secret_access_key, ) as s3_client: logger.debug(f"Processing bucket: {bucket.name}") - tasks.append(process_bucket(s3_client, bucket.name, windows, key)) + tasks.append(process_bucket(s3_client, bucket.name, windows, key, save_location)) results = await asyncio.gather(*tasks) # Combine results into a dictionary mapping window IDs to lists of slices @@ -653,7 +654,7 @@ async def download_slices_for_buckets_and_windows( return slices -async def load_files_for_window(window: int, key: str = "slice") -> List[str]: +async def load_files_for_window(window: int, save_location: str, key: str = "slice") -> List[str]: """ Loads files for a specific window from the temporary directory. @@ -676,7 +677,7 @@ async def load_files_for_window(window: int, key: str = "slice") -> List[str]: """ logger.debug(f"Retrieving files for window {window} from temporary directory") - temp_dir = tempfile.gettempdir() + temp_dir = save_location window_files = [] pattern = re.compile(rf"^{key}-{window}-.+-v{__version__}\.pt$") for filename in os.listdir(temp_dir): @@ -686,7 +687,7 @@ async def load_files_for_window(window: int, key: str = "slice") -> List[str]: return window_files -async def delete_files_before_window(window_max: int, key: str = "slice"): +async def delete_files_before_window(window_max: int, save_location: str, key: str = "slice"): """ Deletes temporary files with window IDs less than the specified maximum. @@ -706,7 +707,7 @@ async def delete_files_before_window(window_max: int, key: str = "slice"): """ logger.debug(f"Deleting files with window id before {window_max}") - temp_dir = tempfile.gettempdir() + temp_dir = save_location pattern = re.compile(rf"^{re.escape(key)}-(\d+)-.+-v{__version__}\.(pt|pt\.lock)$") for filename in os.listdir(temp_dir): match = pattern.match(filename) From 0d920871eda42d0773e25fee3e6f994bf8c01116 Mon Sep 17 00:00:00 2001 From: distributedstatemachine! 
Date: Fri, 29 Nov 2024 23:15:29 +0000 Subject: [PATCH 2/3] chore: add validator arg , ruff --- neurons/validator.py | 1 + src/templar/comms.py | 37 +++++++++++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/neurons/validator.py b/neurons/validator.py index be36020..37ab7f6 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -59,6 +59,7 @@ def config(): parser.add_argument('--no_autoupdate', action='store_true', help='Disable automatic updates') parser.add_argument("--process_name", type=str, help="The name of the PM2 process") parser.add_argument('--checkpoint_path', type=str, default=None, help='Path to save/load the checkpoint. If None, the path is set to checkpoint-V.pth.') + parser.add_argument('--save-location', type=str, default=None, help='Directory to save/load slice files') bt.wallet.add_args(parser) bt.subtensor.add_args(parser) config = bt.config(parser) diff --git a/src/templar/comms.py b/src/templar/comms.py index 138df03..e70ab65 100644 --- a/src/templar/comms.py +++ b/src/templar/comms.py @@ -127,7 +127,12 @@ async def get_slices(filename: str, device: str) -> Dict[str, torch.Tensor]: async def apply_slices_to_model( - model: torch.nn.Module, window: int, seed: str, compression: int, save_location: str, key: str = "slice" + model: torch.nn.Module, + window: int, + seed: str, + compression: int, + save_location: str, + key: str = "slice", ) -> int: """ Applies downloaded model parameter slices to a model for a specific window. @@ -169,7 +174,9 @@ async def apply_slices_to_model( """ max_global_step = 0 indices_dict = await get_indices_for_window(model, seed, compression) - slice_files = await load_files_for_window(window=window, save_location=save_location, key=key) + slice_files = await load_files_for_window( + window=window, save_location=save_location, key=key + ) slices_per_param = {name: 0 for name, _ in model.named_parameters()} param_sums = { @@ -375,7 +382,9 @@ async def get_indices_for_window( return result -async def download_file(s3_client, bucket: str, filename: str, save_location: str) -> str: +async def download_file( + s3_client, bucket: str, filename: str, save_location: str +) -> str: """ Downloads a file from S3, using parallel downloads for large files. @@ -429,7 +438,13 @@ async def download_file(s3_client, bucket: str, filename: str, save_location: st async def handle_file( - s3_client, bucket: str, filename: str, hotkey: str, window: int, version: str, save_location: str + s3_client, + bucket: str, + filename: str, + hotkey: str, + window: int, + version: str, + save_location: str, ): """ Handles downloading a single file from S3. 
@@ -550,7 +565,7 @@ async def process_bucket( slice_hotkey, window, slice_version, - save_location + save_location, ) ) except ValueError: @@ -643,7 +658,9 @@ async def download_slices_for_buckets_and_windows( aws_secret_access_key=bucket.secret_access_key, ) as s3_client: logger.debug(f"Processing bucket: {bucket.name}") - tasks.append(process_bucket(s3_client, bucket.name, windows, key, save_location)) + tasks.append( + process_bucket(s3_client, bucket.name, windows, key, save_location) + ) results = await asyncio.gather(*tasks) # Combine results into a dictionary mapping window IDs to lists of slices @@ -654,7 +671,9 @@ async def download_slices_for_buckets_and_windows( return slices -async def load_files_for_window(window: int, save_location: str, key: str = "slice") -> List[str]: +async def load_files_for_window( + window: int, save_location: str, key: str = "slice" +) -> List[str]: """ Loads files for a specific window from the temporary directory. @@ -687,7 +706,9 @@ async def load_files_for_window(window: int, save_location: str, key: str = "sli return window_files -async def delete_files_before_window(window_max: int, save_location: str, key: str = "slice"): +async def delete_files_before_window( + window_max: int, save_location: str, key: str = "slice" +): """ Deletes temporary files with window IDs less than the specified maximum. From 56a8526dd83e5733c023c7d4935ffbd937675efc Mon Sep 17 00:00:00 2001 From: distributedstatemachine! Date: Fri, 29 Nov 2024 23:42:04 +0000 Subject: [PATCH 3/3] fix: upload file --- neurons/miner.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/neurons/miner.py b/neurons/miner.py index 0fa7496..8ece9b6 100644 --- a/neurons/miner.py +++ b/neurons/miner.py @@ -435,8 +435,9 @@ async def run(self): seed = window, wallet = self.wallet, compression = self.hparams.compression, + save_location = self.save_location, key = 'delta', - global_step=self.global_step + global_step = self.global_step ) tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Uploaded the delta.") @@ -464,8 +465,9 @@ async def run(self): seed = window + 1, wallet = self.wallet, compression = self.hparams.compression, + save_location = self.save_location, key = 'state', - global_step=self.global_step + global_step = self.global_step ) tplr.logger.info(f"{tplr.P(window, tplr.T() - st)}: Uploaded the state.")
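
Usage note: with this series applied, the comms helpers no longer default to the system temp directory; save_location must be passed explicitly, and the miner/validator resolve it from the new --save-location flag, falling back to tempfile.gettempdir() when the flag is omitted. Below is a minimal sketch of the resulting call pattern. The surrounding names (config, model, buckets, window, hparams) and the import alias are illustrative placeholders inferred from the diffs above, not exact code from the repository:

    import os
    import tempfile

    import templar as tplr  # alias assumed to match the `tplr` used in neurons/*.py

    async def sync_state_for_window(config, model, buckets, window, hparams):
        # Resolve the slice directory the same way miner.py / validator.py now do:
        # use --save-location if given (creating it if needed), else the temp dir.
        save_location = config.save_location
        if save_location is None:
            save_location = tempfile.gettempdir()
        else:
            os.makedirs(save_location, exist_ok=True)

        # Both helpers now take save_location as an explicit keyword argument.
        await tplr.download_slices_for_buckets_and_windows(
            buckets=buckets,
            windows=[window],
            key='state',
            save_location=save_location,
        )
        return await tplr.apply_slices_to_model(
            model=model,
            window=window,
            seed=window,
            compression=hparams.compression,
            save_location=save_location,
            key='state',
        )

On the command line this corresponds to something like `python neurons/miner.py --save-location /var/tmp/templar ...` (path illustrative); omitting the flag keeps the previous behaviour of reading and writing slice files in the system temporary directory.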