diff --git a/gefest/surrogate_models/inference.py b/gefest/surrogate_models/inference.py index c9a6b56dd..378978f93 100644 --- a/gefest/surrogate_models/inference.py +++ b/gefest/surrogate_models/inference.py @@ -8,15 +8,15 @@ from tqdm import tqdm from utils.dataloader import create_single_dataloader #from utils import log_images, dsc -from models import UNet +from models import AttU_Net, UNet import pandas as pd device = torch.device("cpu" if not torch.cuda.is_available() else 'cuda') dataloader = create_single_dataloader(path_to_dir='data_from_comsol//test_gen_data',batch_size=10,shuffle=False) -model = UNet(in_channels=1,out_channels=1).to(device) -model.load_state_dict(torch.load(f'weights//unet_58_mask_ssim_100.pt',map_location=torch.device(device))) -CASE = 'ssim_plus_58' +model = AttU_Net(img_ch=1,output_ch=1).to(device)#UNet(in_channels=1,out_channels=1).to(device) +model.load_state_dict(torch.load(r'weights\unet_11_adam_Accum_2.pt',map_location=torch.device(device))) +CASE = 'att_11' model.eval() #predicts = [] truth = [] @@ -26,7 +26,7 @@ x, y_true = data x = x.to(device) - y_pred = model(x,x).squeeze() + y_pred = model(x.float()).squeeze() if i==0: predicts = np.copy(y_pred.cpu().numpy()) truth = np.copy(y_true) diff --git a/gefest/surrogate_models/models.py b/gefest/surrogate_models/models.py index 4756f8c69..34c9d54b6 100644 --- a/gefest/surrogate_models/models.py +++ b/gefest/surrogate_models/models.py @@ -46,7 +46,7 @@ def __init__(self, in_channels=3, out_channels=1, init_features=32): in_channels=features, out_channels=out_channels, kernel_size=1 ) - def forward(self, x,mask): + def forward(self, x): x = torch.unsqueeze(x,dim=1).float() enc1 = self.encoder1(x) enc2 = self.encoder2(self.pool1(enc1)) @@ -353,6 +353,7 @@ def __init__(self,img_ch=3,output_ch=1): def forward(self,x): # encoding path + x1 = self.Conv1(x) x2 = self.Maxpool(x1) @@ -495,6 +496,7 @@ def __init__(self,img_ch=3,output_ch=1): def forward(self,x): # encoding path + x = 
torch.unsqueeze(x,dim=1).float() x1 = self.Conv1(x) x2 = self.Maxpool(x1) diff --git a/gefest/surrogate_models/read_history.py b/gefest/surrogate_models/read_history.py index f22d3b006..f084b3a9e 100644 --- a/gefest/surrogate_models/read_history.py +++ b/gefest/surrogate_models/read_history.py @@ -1,6 +1,7 @@ import pandas as pd import numpy as np from matplotlib import pyplot as plt +import torch step = 100 df = pd.read_csv('result_train_mask_ssim.csv') @@ -11,11 +12,22 @@ plt.plot(bins,label='base_train') #plt.plot([0.013666182630779233,0.00986096077757238,0.006160223454961358,0.00486703837721719,0.005280730882590731,0.0028021071640542045]) -df = pd.read_csv('result_train_bi.csv') -df_test = pd.read_csv('result_test_bi.csv') +df = pd.read_csv(r'D:\Projects\GEFEST\GEFEST_surr\GEFEST\result_train_mask_ssim_100.csv') +df_test = pd.read_csv(r'result_test_mask_ssim_100.csv') bins = [df['train_loss'][i:i+step].mean() for i in range(0,len((df['train_loss'])),step)] plt.plot(df_test['test_loss'],label='ssim_test') plt.plot(bins,label='ssim_test') + + +df = pd.read_csv(r'result_train_adam_Accum_2.csv') +df_test = pd.read_csv(r'result_test_adam_Accum_2.csv') +bins = [df['train_loss'][i:i+step].mean() for i in range(0,len((df['train_loss'])),step)] +plt.plot(df_test['test_loss'],label='att_test') +plt.plot(bins,label='att_train') + + + plt.grid() plt.legend() -plt.show() \ No newline at end of file +plt.show() + diff --git a/gefest/surrogate_models/train_sched.py b/gefest/surrogate_models/train_sched.py index b456ca9a8..21e990cb4 100644 --- a/gefest/surrogate_models/train_sched.py +++ b/gefest/surrogate_models/train_sched.py @@ -14,6 +14,7 @@ from timm.scheduler.cosine_lr import CosineLRScheduler from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM from torch.utils.tensorboard import SummaryWriter +from lion_pytorch import Lion def main(args): makedirs(args) diff --git a/gefest/surrogate_models/train_sched_att_u.py b/gefest/surrogate_models/train_sched_att_u.py new file 
mode 100644
index 000000000..d39e40457
--- /dev/null
+++ b/gefest/surrogate_models/train_sched_att_u.py
@@ -0,0 +1,151 @@
+import json
+import os
+import sys
+sys.path.append(os.getcwd())
+import numpy as np
+import torch
+import torch.optim as optim
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from utils.dataloader import create_dataloaders
+#from utils import log_images, dsc
+from models import AttU_Net, UNet,UNet_bi
+import pandas as pd
+from timm.scheduler.cosine_lr import CosineLRScheduler
+from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
+from torch.utils.tensorboard import SummaryWriter
+from lion_pytorch import Lion
+
+
+def _apply_accumulated(writer, optimizer, accum_loss, step):
+    """Log the current accumulation group, then apply and clear its gradients."""
+    writer.add_scalar('Loss/train', accum_loss[-1], step)
+    writer.add_scalar('Loss/accum_train', sum(accum_loss) / len(accum_loss), step)
+    optimizer.step()
+    optimizer.zero_grad()
+
+
+def main(args, accum_gr, run_name="adam_Accum_2"):
+    """Train AttU_Net with a combined SSIM + L1 loss and a cosine LR schedule.
+
+    Args:
+        args: Parsed options; this function reads ``device``, ``path_to_data``,
+            ``batch_size``, ``lr``, ``epochs``, ``weights`` and ``logs``.
+        accum_gr: Number of batches to accumulate gradients over before each
+            optimizer step, or ``None`` to step after every batch.
+        run_name: Suffix used in the checkpoint and result CSV file names.
+
+    Raises:
+        ValueError: If ``accum_gr`` is not ``None`` and is smaller than 1.
+    """
+    if accum_gr is not None and accum_gr < 1:
+        raise ValueError(f"accum_gr must be None or >= 1, got {accum_gr}")
+
+    makedirs(args)
+
+    writer = SummaryWriter('writer/Adam_acc')
+    device = torch.device("cpu" if not torch.cuda.is_available() else args.device)
+
+    loader_train, loader_valid = create_dataloaders(args.path_to_data,
+                                                    batch_size=args.batch_size,
+                                                    validation_split=0.2,
+                                                    shuffle_dataset=True,
+                                                    random_seed=42,
+                                                    end=None)
+    loaders = {"train": loader_train, "valid": loader_valid}
+
+    unet = AttU_Net(img_ch=1, output_ch=1)
+    unet.to(device)
+
+    ssim_loss = SSIM(data_range=2, size_average=True, channel=1)
+    mae_loss = torch.nn.L1Loss()
+
+    optimizer = optim.AdamW(unet.parameters(), lr=args.lr)
+    sched = CosineLRScheduler(optimizer, t_initial=8, lr_min=0.00001,
+                              cycle_mul=1.0, cycle_decay=1.0, cycle_limit=1,
+                              warmup_t=0, warmup_lr_init=0.00001, warmup_prefix=False, t_in_epochs=True,
+                              noise_range_t=None, noise_pct=0.67, noise_std=1.0,
+                              noise_seed=42, k_decay=1.0, initialize=True)
+
+    mean_loss_train = []
+    mean_loss_test = []
+    step = 0  # running count of training batches; x-axis for the loss scalars
+
+    for epoch in tqdm(range(args.epochs), total=args.epochs):
+        # Update the learning rate for this epoch before any batch is processed.
+        sched.step(epoch)
+        writer.add_scalar('Lr', optimizer.param_groups[0]['lr'], epoch)
+
+        for phase in ["train", "valid"]:
+            training = phase == "train"
+            if training:
+                unet.train()
+                # Start every epoch from clean gradients.
+                optimizer.zero_grad()
+            else:
+                unet.eval()
+
+            phase_losses = []
+            accum_loss = []
+
+            for data in tqdm(loaders[phase]):
+                x, y_true = data
+                x, y_true = x.to(device), y_true.to(device)
+
+                with torch.set_grad_enabled(training):
+                    y_pred = unet(x.float())
+                    # squeeze(1) removes only the channel axis, so a final
+                    # batch of size 1 keeps its batch axis.
+                    loss = (0.5 * (1 - ssim_loss(y_pred, y_true.unsqueeze(1).float()))
+                            + mae_loss(y_pred.squeeze(1), y_true))
+
+                phase_losses.append(loss.item())
+                if not training:
+                    continue
+
+                step += 1
+                if accum_gr is None:
+                    writer.add_scalar('Loss/train', phase_losses[-1], step)
+                    loss.backward()
+                    optimizer.step()
+                    optimizer.zero_grad()
+                    continue
+
+                accum_loss.append(phase_losses[-1])
+                # Scale so a full group contributes the mean of its gradients.
+                (loss / accum_gr).backward()
+                if len(accum_loss) == accum_gr:
+                    _apply_accumulated(writer, optimizer, accum_loss, step)
+                    accum_loss = []
+
+            if training:
+                if accum_loss:
+                    # The batch count was not a multiple of accum_gr: apply the
+                    # leftover gradients rather than leak them into the next epoch.
+                    _apply_accumulated(writer, optimizer, accum_loss, step)
+                mean_loss_train.append(np.mean(phase_losses))
+                print('train_loss', mean_loss_train[-1])
+            else:
+                mean_loss_test.append(np.mean(phase_losses))
+                writer.add_scalar('Loss/test', mean_loss_test[-1], step)
+                print('test_loss', mean_loss_test[-1])
+                torch.save(unet.state_dict(), os.path.join(args.weights, f"unet_{epoch}_{run_name}.pt"))
+
+    writer.close()
+    result_train = pd.DataFrame(data={'train_loss': mean_loss_train})
+    result_test = pd.DataFrame(data={'test_loss': mean_loss_test})
+    result_test.to_csv(f'result_test_{run_name}.csv')
+    result_train.to_csv(f'result_train_{run_name}.csv')
+
+
+def makedirs(args):
+    """Create the checkpoint (``args.weights``) and log (``args.logs``) directories."""
+    os.makedirs(args.weights, exist_ok=True)
+    os.makedirs(args.logs, exist_ok=True)
+
+
+if __name__ == "__main__":
+    from config import crete_parser
+    args = crete_parser(batch_size=7, epochs=70, lr=0.01)
+    accumulation_steps = 10
+    main(args=args, accum_gr=accumulation_steps)
diff --git a/gefest/surrogate_models/utils/animation.py b/gefest/surrogate_models/utils/animation.py
index 9af30c814..de351f3ef 100644
--- a/gefest/surrogate_models/utils/animation.py
+++ b/gefest/surrogate_models/utils/animation.py
@@ -91,4 +91,4 @@ def animation_npz(path_to_dir):
 #animation_npz(path_to_dir='gefest\surrogate_models\gendata/ssim_23')
 #animation_npz(path_to_dir='gefest\surrogate_models\gendata/ssim_plus_57')
 #animation_data_npz(path_to_dir='data_from_comsol/gen_data_extend')
-animation_npz(path_to_dir='gefest\surrogate_models\gendata/ssim_plus_58')
\ No newline at end of file
+animation_npz(path_to_dir=r'gefest\surrogate_models\gendata/att_11')
\ No newline at end of file