delete old checkpoints
distributedstatemachine committed Dec 3, 2024
1 parent dd1eb6b commit b4343ab
Showing 2 changed files with 87 additions and 62 deletions.
116 changes: 63 additions & 53 deletions neurons/miner.py
@@ -125,31 +125,31 @@ def __init__(self):
tplr.logger.error(f"Commitment error: {str(e)}")
tplr.commit(self.subtensor, self.wallet, self.config.netuid)

# # Init Wandb.
# # Ensure the wandb directory exists
# wandb_dir = os.path.join(os.getcwd(), 'wandb')
# os.makedirs(wandb_dir, exist_ok=True)

# # Define the run ID file path inside the wandb directory
# run_id_file = os.path.join(wandb_dir, f"wandb_run_id_M{self.uid}_{tplr.__version__}.txt")

# # Attempt to read the existing run ID
# if os.path.exists(run_id_file):
# with open(run_id_file, 'r') as f:
# run_id = f.read().strip()
# tplr.logger.info(f"Resuming WandB run with id {run_id}")
# else:
# run_id = None
# tplr.logger.info("Starting a new WandB run.")

# # Initialize WandB
# self.wandb = tplr.initialize_wandb(
# run_prefix='M',
# uid=self.uid,
# config=self.config,
# group='miner',
# job_type='training'
# )
# Init Wandb.
# Ensure the wandb directory exists
wandb_dir = os.path.join(os.getcwd(), 'wandb')
os.makedirs(wandb_dir, exist_ok=True)

# Define the run ID file path inside the wandb directory
run_id_file = os.path.join(wandb_dir, f"wandb_run_id_M{self.uid}_{tplr.__version__}.txt")

# Attempt to read the existing run ID
if os.path.exists(run_id_file):
with open(run_id_file, 'r') as f:
run_id = f.read().strip()
tplr.logger.info(f"Resuming WandB run with id {run_id}")
else:
run_id = None
tplr.logger.info("Starting a new WandB run.")

# Initialize WandB
self.wandb = tplr.initialize_wandb(
run_prefix='M',
uid=self.uid,
config=self.config,
group='miner',
job_type='training'
)
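# Aside (not part of the diff): the file-persisted run_id above is what makes
# the W&B run resumable across process restarts. Against the public wandb API,
# the call that tplr.initialize_wandb presumably wraps would look roughly like:
#
#     run = wandb.init(project="...", id=run_id, resume="allow")
#     with open(run_id_file, "w") as f:
#         f.write(run.id)  # persist the id for the next restart
#
# wandb.init reuses the existing run when run_id is set and creates a fresh
# one (with a new id) when it is None. This is a sketch, not the wrapper's
# actual implementation.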

# Init model.
tplr.logger.info('\n' + '-' * 40 + ' Hparams ' + '-' * 40)
@@ -167,34 +167,44 @@ def __init__(self):
betas=(self.hparams.optimizer_beta1, self.hparams.optimizer_beta2), # B1 and B2
weight_decay=self.hparams.optimizer_weight_decay, # Weight decay
foreach=True, # more memory usage, but faster
)

# # Load checkpoint if it exists
# self.checkpoint_path = f"checkpoint-V1.pth" if self.config.checkpoint_path is None else self.config.checkpoint_path
# if os.path.exists(self.checkpoint_path):
# tplr.logger.info(f"Loading checkpoint from {self.checkpoint_path}")
# global_step, _ = asyncio.run(tplr.load_checkpoint(
# filename=self.checkpoint_path,
# model=self.model,
# optimizer=self.optimizer,
# scheduler=None,
# device=self.config.device
# ))

# self.global_step = global_step
# if global_step is None:
# tplr.logger.warning(f"Corrupt checkpoint detected at {self.checkpoint_path}. Removing file and starting fresh.")
# try:
# os.remove(self.checkpoint_path)
# tplr.logger.info(f"Removed corrupt checkpoint: {self.checkpoint_path}")
# except OSError as e:
# tplr.logger.error(f"Failed to remove corrupt checkpoint: {e}")
# global_step = 0
# else:
# tplr.logger.info(f"Resumed from global step {self.global_step}")
# else:
# tplr.logger.info("No checkpoint file found. Starting from scratch.")
# self.global_step = 0
)

# Delete old checkpoints
for filename in os.listdir(os.getcwd()):
if filename.startswith("checkpoint") and filename.endswith(".pth"):
file_path = os.path.join(os.getcwd(), filename)
try:
os.remove(file_path)
tplr.logger.info(f"Deleted checkpoint file {file_path}")
except OSError as e:
tplr.logger.error(f"Failed to delete {file_path}: {e}")

# Load checkpoint if it exists
self.checkpoint_path = f"checkpoint-V1.pth" if self.config.checkpoint_path is None else self.config.checkpoint_path

[GitHub Actions / ruff] Check failure on line 183 in neurons/miner.py — Ruff (F541): neurons/miner.py:183:32: f-string without any placeholders
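(F541 fires because the f prefix on "checkpoint-V1.pth" interpolates nothing; the straightforward fix, not part of this commit, is to drop the prefix and assign the plain string literal. The validator's counterpart below keeps its f-string legitimately, since it interpolates self.uid.)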
if os.path.exists(self.checkpoint_path):
tplr.logger.info(f"Loading checkpoint from {self.checkpoint_path}")
global_step, _ = asyncio.run(tplr.load_checkpoint(
filename=self.checkpoint_path,
model=self.model,
optimizer=self.optimizer,
scheduler=None,
device=self.config.device
))

self.global_step = global_step
if global_step is None:
tplr.logger.warning(f"Corrupt checkpoint detected at {self.checkpoint_path}. Removing file and starting fresh.")
try:
os.remove(self.checkpoint_path)
tplr.logger.info(f"Removed corrupt checkpoint: {self.checkpoint_path}")
except OSError as e:
tplr.logger.error(f"Failed to remove corrupt checkpoint: {e}")
global_step = 0
else:
tplr.logger.info(f"Resumed from global step {self.global_step}")
else:
tplr.logger.info("No checkpoint file found. Starting from scratch.")
self.global_step = 0

# Initialize learning rate scheduler
self.scheduler = tplr.get_wsd_scheduler(
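Taken together, the miner's startup policy after this commit is: delete every checkpoint*.pth file in the working directory, then attempt to load the configured checkpoint, falling back to a fresh start (global_step = 0) when the file is missing or corrupt. A minimal standalone sketch of that policy, using only the standard library — delete_stale_checkpoints, load_or_start_fresh, and the load_fn callback are illustrative names (load_fn stands in for tplr.load_checkpoint), not part of the templar codebase:

import glob
import logging
import os

logger = logging.getLogger("checkpoints")

def delete_stale_checkpoints(directory: str = ".") -> None:
    # Mirror the commit's cleanup loop: remove every checkpoint*.pth file.
    for path in glob.glob(os.path.join(directory, "checkpoint*.pth")):
        try:
            os.remove(path)
            logger.info("Deleted checkpoint file %s", path)
        except OSError as exc:
            logger.error("Failed to delete %s: %s", path, exc)

def load_or_start_fresh(path: str, load_fn) -> int:
    # load_fn should return the saved global step, or None for a corrupt file.
    if not os.path.exists(path):
        logger.info("No checkpoint file found. Starting from scratch.")
        return 0
    step = load_fn(path)
    if step is None:
        logger.warning("Corrupt checkpoint at %s; removing and starting fresh.", path)
        try:
            os.remove(path)
        except OSError as exc:
            logger.error("Failed to remove corrupt checkpoint: %s", exc)
        return 0
    logger.info("Resumed from global step %d", step)
    return step

Note one consequence of the ordering: because the cleanup scans the working directory first, the default checkpoint-V1.pth (a relative path in that same directory) will already have been deleted by the time the load runs, so with the default config the miner now effectively always starts from scratch; only a checkpoint_path pointing outside the working directory would survive.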
33 changes: 24 additions & 9 deletions neurons/validator.py
@@ -103,20 +103,25 @@ def __init__(self):

# Init bucket.
try:
tplr.logger.info(f'bucket_name: {tplr.config.BUCKET_SECRETS["bucket_name"]}')
tplr.logger.debug(f'bucket_name: {tplr.config.BUCKET_SECRETS["bucket_name"]}')
commitment = self.chain_manager.get_commitment(self.uid)

# Convert Bucket object back to concatenated string format for comparison
commitment_str = commitment.name + commitment.access_key_id + commitment.secret_access_key

current_bucket = (
tplr.config.BUCKET_SECRETS["account_id"] +
tplr.config.BUCKET_SECRETS["access_key_id"] +
tplr.config.BUCKET_SECRETS["secret_access_key"]
tplr.config.BUCKET_SECRETS["bucket_name"] +
tplr.config.BUCKET_SECRETS["read"]["access_key_id"] +
tplr.config.BUCKET_SECRETS["read"]["secret_access_key"]
)
if current_bucket != commitment:
# TODO: Handle mismatched commitments
tplr.logger.debug(f'Comparing:\nCommitment: {commitment_str}\nCurrent: {current_bucket}')

if current_bucket != commitment_str:
raise ValueError("Bucket commitment data does not match.")
raise ValueError('')
except Exception:

except Exception as e:
tplr.logger.error(f"Commitment error: {str(e)}")
tplr.commit(self.subtensor, self.wallet, self.config.netuid)
tplr.logger.info('Bucket:' + tplr.config.BUCKET_SECRETS["bucket_name"])

# Init Wandb.
# Ensure the wandb directory exists
@@ -154,6 +159,16 @@ def __init__(self):
self.model.to(self.config.device)
self.model.eval()

# Delete old checkpoints
for filename in os.listdir(os.getcwd()):
if filename.startswith("checkpoint") and filename.endswith(".pth"):
file_path = os.path.join(os.getcwd(), filename)
try:
os.remove(file_path)
tplr.logger.info(f"Deleted checkpoint file {file_path}")
except OSError as e:
tplr.logger.error(f"Failed to delete {file_path}: {e}")

# Set checkpoint path
self.checkpoint_path = f"checkpoint-V{self.uid}.pth" if self.config.checkpoint_path is None else self.config.checkpoint_path

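The validator-side behavioral change is in the commitment check: the Bucket object read back from the chain is flattened to name + access_key_id + secret_access_key, and the local config is now flattened the same way — using bucket_name and the nested read credentials rather than account_id and the top-level keys as before. A rough restatement of that comparison as a pure function, assuming the BUCKET_SECRETS layout visible in the diff (the Bucket dataclass below is an illustrative stand-in for the tplr type):

from dataclasses import dataclass

@dataclass
class Bucket:
    # Stand-in for the commitment object returned by chain_manager.
    name: str
    access_key_id: str
    secret_access_key: str

def commitment_matches(commitment: Bucket, bucket_secrets: dict) -> bool:
    # Flatten the on-chain commitment into its concatenated string form.
    commitment_str = (
        commitment.name
        + commitment.access_key_id
        + commitment.secret_access_key
    )
    # Flatten the local config the same way, using the read credentials.
    current_bucket = (
        bucket_secrets["bucket_name"]
        + bucket_secrets["read"]["access_key_id"]
        + bucket_secrets["read"]["secret_access_key"]
    )
    return current_bucket == commitment_str

On a mismatch the validator still raises — though with this commit the ValueError message is now empty — and the except branch falls through to tplr.commit, re-committing the local bucket to the chain.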
