Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
SoloWayG committed Mar 26, 2024
1 parent 6a99a70 commit 38af6c9
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 10 deletions.
10 changes: 5 additions & 5 deletions gefest/surrogate_models/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from tqdm import tqdm
from utils.dataloader import create_single_dataloader
#from utils import log_images, dsc
from models import UNet
from models import AttU_Net, UNet
import pandas as pd


device = torch.device("cpu" if not torch.cuda.is_available() else 'cuda')
dataloader = create_single_dataloader(path_to_dir='data_from_comsol//test_gen_data',batch_size=10,shuffle=False)
model = UNet(in_channels=1,out_channels=1).to(device)
model.load_state_dict(torch.load(f'weights//unet_58_mask_ssim_100.pt',map_location=torch.device(device)))
CASE = 'ssim_plus_58'
model = AttU_Net(img_ch=1,output_ch=1).to(device)#UNet(in_channels=1,out_channels=1).to(device)
model.load_state_dict(torch.load(r'weights\unet_11_adam_Accum_2.pt',map_location=torch.device(device)))
CASE = 'att_11'
model.eval()
#predicts = []
truth = []
Expand All @@ -26,7 +26,7 @@
x, y_true = data
x = x.to(device)

y_pred = model(x,x).squeeze()
y_pred = model(x.float()).squeeze()
if i==0:
predicts = np.copy(y_pred.cpu().numpy())
truth = np.copy(y_true)
Expand Down
4 changes: 3 additions & 1 deletion gefest/surrogate_models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, in_channels=3, out_channels=1, init_features=32):
in_channels=features, out_channels=out_channels, kernel_size=1
)

def forward(self, x,mask):
def forward(self, x):
x = torch.unsqueeze(x,dim=1).float()
enc1 = self.encoder1(x)
enc2 = self.encoder2(self.pool1(enc1))
Expand Down Expand Up @@ -353,6 +353,7 @@ def __init__(self,img_ch=3,output_ch=1):

def forward(self,x):
# encoding path

x1 = self.Conv1(x)

x2 = self.Maxpool(x1)
Expand Down Expand Up @@ -495,6 +496,7 @@ def __init__(self,img_ch=3,output_ch=1):

def forward(self,x):
# encoding path
x = torch.unsqueeze(x,dim=1).float()
x1 = self.Conv1(x)

x2 = self.Maxpool(x1)
Expand Down
18 changes: 15 additions & 3 deletions gefest/surrogate_models/read_history.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import torch

step = 100
df = pd.read_csv('result_train_mask_ssim.csv')
Expand All @@ -11,11 +12,22 @@
plt.plot(bins,label='base_train')
#plt.plot([0.013666182630779233,0.00986096077757238,0.006160223454961358,0.00486703837721719,0.005280730882590731,0.0028021071640542045])

df = pd.read_csv('result_train_bi.csv')
df_test = pd.read_csv('result_test_bi.csv')
df = pd.read_csv(r'D:\Projects\GEFEST\GEFEST_surr\GEFEST\result_train_mask_ssim_100.csv')
df_test = pd.read_csv(r'result_test_mask_ssim_100.csv')
bins = [df['train_loss'][i:i+step].mean() for i in range(0,len((df['train_loss'])),step)]
plt.plot(df_test['test_loss'],label='ssim_test')
plt.plot(bins,label='ssim_test')


df = pd.read_csv(r'result_train_adam_Accum_2.csv')
df_test = pd.read_csv(r'result_test_adam_Accum_2.csv')
bins = [df['train_loss'][i:i+step].mean() for i in range(0,len((df['train_loss'])),step)]
plt.plot(df_test['test_loss'],label='att_test')
plt.plot(bins,label='att_train')



plt.grid()
plt.legend()
plt.show()
plt.show()

1 change: 1 addition & 0 deletions gefest/surrogate_models/train_sched.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from timm.scheduler.cosine_lr import CosineLRScheduler
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
from torch.utils.tensorboard import SummaryWriter
from lion_pytorch import Lion

def main(args):
makedirs(args)
Expand Down
140 changes: 140 additions & 0 deletions gefest/surrogate_models/train_sched_att_u.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import json
import os
import sys
sys.path.append(os.getcwd())
import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils.dataloader import create_dataloaders
#from utils import log_images, dsc
from models import AttU_Net, UNet,UNet_bi
import pandas as pd
from timm.scheduler.cosine_lr import CosineLRScheduler
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
from torch.utils.tensorboard import SummaryWriter
from lion_pytorch import Lion

def main(args,accum_gr):
makedirs(args)

writer = SummaryWriter('writer/Adam_acc')
#snapshotargs(args)
device = torch.device("cpu" if not torch.cuda.is_available() else args.device)

loader_train, loader_valid = create_dataloaders(args.path_to_data,
batch_size = args.batch_size,
validation_split = 0.2,
shuffle_dataset = True,
random_seed = 42,
end=None)
loaders = {"train": loader_train, "valid": loader_valid}

unet = AttU_Net(img_ch=1,output_ch=1)#UNet_bi(in_channels=1, out_channels=1)
unet.to(device)

ssim_loss = SSIM(data_range=2, size_average=True, channel=1)
mae_loss = torch.nn.L1Loss()#torch.nn.MSELoss()
best_validation_dsc = 0.0

optimizer = optim.AdamW(unet.parameters(), lr=args.lr)##Lion(unet.parameters(), lr=args.lr, weight_decay=1e-2)
sched =CosineLRScheduler(optimizer, t_initial=8, lr_min=0.00001,
cycle_mul=1.0, cycle_decay=1.0, cycle_limit=1,
warmup_t=0, warmup_lr_init=0.00001, warmup_prefix=False, t_in_epochs=True,
noise_range_t=None, noise_pct=0.67, noise_std=1.0,
noise_seed=42, k_decay=1.0, initialize=True)
#logger = Logger(args.logs)
loss_train = []
loss_valid = []
mean_loss_train = []
mean_loss_test = []
accum_loss = []

step = 0
ep = 0
for epoch in tqdm(range(args.epochs), total=args.epochs):

sched.step(ep)
writer.add_scalar('Lr',optimizer.param_groups[0]['lr'],ep)
ep+=1
#print(optimizer.param_groups[0]['lr'])
# for param_group in optimizer.param_groups:
# current_lr = param_group['lr']

for phase in ["train", "valid"]:
if phase == "train":
unet.train()
else:
unet.eval()

validation_pred = []
validation_true = []

for i, data in enumerate(tqdm(loaders[phase])):
if phase == "train":
step += 1

x, y_true = data
x, y_true = x.to(device), y_true.to(device)



with torch.set_grad_enabled(phase == "train"):
y_pred = unet(x.float())

loss = 0.5*(1-ssim_loss( y_pred, y_true.unsqueeze(1).float())) + mae_loss(y_pred.squeeze(), y_true)#loss = dsc_loss(y_pred.squeeze(), y_true)

if phase == "valid":
loss_valid.append(loss.item())


if phase == "train":
if accum_gr is not None:
loss_train.append(loss.item())
accum_loss.append(loss.item())
loss = loss/accum_gr
loss.backward()
if (i + 1) % accum_gr == 0:
writer.add_scalar('Loss/train',loss_train[-1],step)

writer.add_scalar('Loss/accum_train',sum(accum_loss)/len(accum_loss),step)
optimizer.step()
optimizer.zero_grad()
accum_loss=[]
else:
loss_train.append(loss.item())
writer.add_scalar('Loss/train',loss_train[-1],step)
loss.backward()
optimizer.step()
optimizer.zero_grad()

if phase == "train":
mean_loss_train.append(np.mean(loss_train))
#mean_loss_test.append('')
print('train_loss',mean_loss_train[-1])
loss_train = []

if phase == "valid":
mean_loss_test.append(np.mean(loss_valid))
writer.add_scalar('Loss/test',mean_loss_test[-1],step)
print('test_loss',mean_loss_test[-1])
torch.save(unet.state_dict(), os.path.join(args.weights, f"unet_{epoch}_adam_Accum_2.pt"))
loss_valid = []
result_train = pd.DataFrame(data = {'train_loss':mean_loss_train})
result_test = pd.DataFrame(data = {'test_loss':mean_loss_test})
result_test.to_csv('result_test_adam_Accum_2.csv')
result_train.to_csv('result_train_adam_Accum_2.csv')
#print("Best validation mean DSC: {:4f}".format(best_validation_dsc))

def makedirs(args):
os.makedirs(args.weights, exist_ok=True)
os.makedirs(args.logs, exist_ok=True)



if __name__ == "__main__":
from config import crete_parser
args = crete_parser(batch_size=7,epochs=70,lr=0.01)
accumulatin_gradients = 10
main(args=args,accum_gr=accumulatin_gradients)
2 changes: 1 addition & 1 deletion gefest/surrogate_models/utils/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,4 +91,4 @@ def animation_npz(path_to_dir):
#animation_npz(path_to_dir='gefest\surrogate_models\gendata/ssim_23')
#animation_npz(path_to_dir='gefest\surrogate_models\gendata/ssim_plus_57')
#animation_data_npz(path_to_dir='data_from_comsol/gen_data_extend')
animation_npz(path_to_dir='gefest\surrogate_models\gendata/ssim_plus_58')
animation_npz(path_to_dir='gefest\surrogate_models\gendata/att_11')

0 comments on commit 38af6c9

Please sign in to comment.