Skip to content

Commit

Permalink
updated training loops for logging
Browse files: browse the repository at this point in the history
  • Loading branch information
mat10d committed Sep 2, 2024
1 parent af464cf commit fefe44b
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 26 deletions.
24 changes: 15 additions & 9 deletions scripts/training_loop_basic_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ def read_config(yaml_path):
device = torch.device("cpu")

# Basic values for logging
parent_dir = '/mnt/efs/dlmbl/S-md/'
output_path = parent_dir + 'training_logs/'
model_name = "static_vanilla_vae_md_10"
run_name= "initial_params"
model_name = "static_basic_vae_md"
find_port = True

# Function to find an available port
Expand Down Expand Up @@ -160,7 +157,7 @@ def train(

recon_batch, mu, logvar = vae(data)
MSE, KLD = loss_function(recon_batch, data, mu, logvar)
loss = MSE + KLD * 1e-8
loss = MSE + KLD * 1e-5

loss.backward()
train_loss += loss.item()
Expand Down Expand Up @@ -255,10 +252,19 @@ def train(
epoch, train_loss / len(dataloader.dataset)))

# Training loop
output_dir = '/mnt/efs/dlmbl/G-et/'
run_name= "basic_test"

folder_suffix = datetime.now().strftime("%Y%m%d_%H%M_") + run_name
checkpoint_path = output_path + "checkpoints/static/" + folder_suffix + "/"
log_path = output_path + "logs/static/"+ folder_suffix + "/"
for epoch in range(1, 10):
log_path = output_dir + "logs/static/Matteo/"+ folder_suffix + "/"
checkpoint_path = output_dir + "checkpoints/static/Matteo/" + folder_suffix + "/"

if not os.path.exists(log_path):
os.makedirs(log_path)
if not os.path.exists(checkpoint_path):
os.makedirs(checkpoint_path)

for epoch in range(1, 100):
train(epoch, log_interval=100, log_image_interval=20, tb_logger=logger)
filename_suffix = datetime.now().strftime("%Y%m%d_%H%M%S_") + "epoch_"+str(epoch) + "_"
training_logDF = pd.DataFrame(training_log)
Expand All @@ -273,4 +279,4 @@ def train(
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss_per_epoch
}
torch.save(checkpoint, output_path + filename_suffix + str(epoch) + "checkpoint.pth")
torch.save(checkpoint, checkpoint_path + filename_suffix + str(epoch) + "checkpoint.pth")
23 changes: 15 additions & 8 deletions scripts/training_loop_resnet18_linear_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ def read_config(yaml_path):
device = torch.device("cpu")

# Basic values for logging
parent_dir = '/mnt/efs/dlmbl/S-md/'
output_path = parent_dir + 'training_logs/'
model_name = "static_resnet_linear_vae_md_10"
run_name= "initial_params"
model_name = "static_resnet_linear_vae_md"
find_port = True

# Function to find an available port
Expand Down Expand Up @@ -74,6 +71,7 @@ def launch_tensorboard(log_dir):
logger = SummaryWriter(f"embed_time_static_runs/{model_name}")

# Define variables for the dataset read in
parent_dir = '/mnt/efs/dlmbl/S-md/'
csv_file = '/mnt/efs/dlmbl/G-et/csv/dataset_split_2.csv'
split = 'train'
channels = [0, 1, 2, 3]
Expand Down Expand Up @@ -149,7 +147,7 @@ def train(

recon_batch, z, mu, logvar = vae(data)
MSE, KLD = loss_function(recon_batch, data, mu, logvar)
loss = MSE + KLD * 1e-8
loss = MSE + KLD * 1e-5

loss.backward()
train_loss += loss.item()
Expand Down Expand Up @@ -244,9 +242,18 @@ def train(
epoch, train_loss / len(dataloader.dataset)))

# Training loop
output_dir = '/mnt/efs/dlmbl/G-et/'
run_name= "resnet_linear_test"

folder_suffix = datetime.now().strftime("%Y%m%d_%H%M_") + run_name
checkpoint_path = output_path + "checkpoints/static/" + folder_suffix + "/"
log_path = output_path + "logs/static/"+ folder_suffix + "/"
log_path = output_dir + "logs/static/Matteo/"+ folder_suffix + "/"
checkpoint_path = output_dir + "checkpoints/static/Matteo/" + folder_suffix + "/"

if not os.path.exists(log_path):
os.makedirs(log_path)
if not os.path.exists(checkpoint_path):
os.makedirs(checkpoint_path)

for epoch in range(1, 100):
train(epoch, log_interval=100, log_image_interval=20, tb_logger=logger)
filename_suffix = datetime.now().strftime("%Y%m%d_%H%M%S_") + "epoch_"+str(epoch) + "_"
Expand All @@ -262,4 +269,4 @@ def train(
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss_per_epoch
}
torch.save(checkpoint, output_path + filename_suffix + str(epoch) + "checkpoint.pth")
torch.save(checkpoint, checkpoint_path + filename_suffix + str(epoch) + "checkpoint.pth")
24 changes: 15 additions & 9 deletions scripts/training_loop_resnet18_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ def read_config(yaml_path):
device = torch.device("cpu")

# Basic values for logging
parent_dir = '/mnt/efs/dlmbl/S-md/'
output_path = parent_dir + 'training_logs/'
model_name = "static_resnet_vae_md_10"
run_name= "initial_params"
model_name = "static_resnet_vae_md"
find_port = True

# Function to find an available port
Expand Down Expand Up @@ -149,7 +146,7 @@ def train(

recon_batch, mu, logvar = vae(data)
MSE, KLD = loss_function(recon_batch, data, mu, logvar)
loss = MSE + KLD * 1e-8
loss = MSE + KLD * 1e-5

loss.backward()
train_loss += loss.item()
Expand Down Expand Up @@ -244,10 +241,19 @@ def train(
epoch, train_loss / len(dataloader.dataset)))

# Training loop
output_dir = '/mnt/efs/dlmbl/G-et/'
run_name= "resnet_test"

folder_suffix = datetime.now().strftime("%Y%m%d_%H%M_") + run_name
checkpoint_path = output_path + "checkpoints/static/" + folder_suffix + "/"
log_path = output_path + "logs/static/"+ folder_suffix + "/"
for epoch in range(1, 10):
log_path = output_dir + "logs/static/Matteo/"+ folder_suffix + "/"
checkpoint_path = output_dir + "checkpoints/static/Matteo/" + folder_suffix + "/"

if not os.path.exists(log_path):
os.makedirs(log_path)
if not os.path.exists(checkpoint_path):
os.makedirs(checkpoint_path)

for epoch in range(1, 100):
train(epoch, log_interval=100, log_image_interval=20, tb_logger=logger)
filename_suffix = datetime.now().strftime("%Y%m%d_%H%M%S_") + "epoch_"+str(epoch) + "_"
training_logDF = pd.DataFrame(training_log)
Expand All @@ -262,4 +268,4 @@ def train(
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss_per_epoch
}
torch.save(checkpoint, output_path + filename_suffix + str(epoch) + "checkpoint.pth")
torch.save(checkpoint, checkpoint_path + filename_suffix + str(epoch) + "checkpoint.pth")

0 comments on commit fefe44b

Please sign in to comment.