Skip to content

Commit

Permalink
add initial licenses
Browse files Browse the repository at this point in the history
  • Loading branch information
awfderry committed Jun 4, 2021
1 parent f8b2a35 commit 50b9b6b
Show file tree
Hide file tree
Showing 7 changed files with 1,650 additions and 21 deletions.
23 changes: 12 additions & 11 deletions examples/lep/gnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
Expand Down Expand Up @@ -83,7 +84,7 @@ def test(gcn_model, ff_model, loader, criterion, device):
auroc = roc_auc_score(y_true, y_pred)
auprc = average_precision_score(y_true, y_pred)

return np.mean(losses), auroc, auprc
return np.mean(losses), auroc, auprc, y_true, y_pred

def plot_corr(y_true, y_pred, plot_dir):
plt.clf()
Expand All @@ -95,7 +96,7 @@ def plot_corr(y_true, y_pred, plot_dir):
def save_weights(model, weight_dir):
    """Write the model's ``state_dict()`` to ``weight_dir`` using ``torch.save``.

    ``weight_dir`` is handed to ``torch.save`` unchanged as the save target.
    """
    state = model.state_dict()
    torch.save(state, weight_dir)

def train(args, device, log_dir, seed=None, test_mode=False):
def train(args, device, log_dir, rep=None, test_mode=False):
# logger = logging.getLogger('lba')
# logger.basicConfig(filename=os.path.join(log_dir, f'train_{split}_cv{fold}.log'),level=logging.INFO)
transform = PairedGraphTransform('atoms_active', 'atoms_inactive', label_key='label')
Expand Down Expand Up @@ -129,7 +130,7 @@ def train(args, device, log_dir, seed=None, test_mode=False):
start = time.time()
train_loss = train_loop(epoch, gcn_model, ff_model, train_loader, criterion, optimizer, device)
print('validating...')
val_loss, auroc, auprc = test(gcn_model, ff_model, val_loader, criterion, device)
val_loss, auroc, auprc, _, _ = test(gcn_model, ff_model, val_loader, criterion, device)
if auroc > best_val_auroc:
torch.save({
'epoch': epoch,
Expand All @@ -144,15 +145,15 @@ def train(args, device, log_dir, seed=None, test_mode=False):
print(f'\tTrain loss {train_loss}, Val loss {val_loss}, Val AUROC {auroc}, Val auprc {auprc}')

if test_mode:
test_file = os.path.join(log_dir, f'test_results.txt')
test_file = os.path.join(log_dir, f'lep_rep{rep}.csv')
cpt = torch.load(os.path.join(log_dir, f'best_weights.pt'))
gcn_model.load_state_dict(cpt['gcn_state_dict'])
ff_model.load_state_dict(cpt['ff_state_dict'])
test_loss, auroc, auprc = test(gcn_model, ff_model, test_loader, criterion, device)
test_loss, auroc, auprc, y_true, y_pred = test(gcn_model, ff_model, test_loader, criterion, device)
print(f'\tTest loss {test_loss}, Test AUROC {auroc}, Test auprc {auprc}')
with open(test_file, 'w') as f:
f.write(f'test_loss\tAUROC\n')
f.write(f'{test_loss}\t{auroc}\n')
res_df = pd.DataFrame(y_true, y_pred, columns=['true', 'pred'])
res_df.to_csv(test_file, index=False)

return test_loss, auroc, auprc


Expand All @@ -166,7 +167,7 @@ def train(args, device, log_dir, seed=None, test_mode=False):
parser.add_argument('--mode', type=str, default='train')
parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--hidden_dim', type=int, default=64)
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--num_epochs', type=int, default=50)
parser.add_argument('--learning_rate', type=float, default=1e-4)
parser.add_argument('--log_dir', type=str, default=None)
args = parser.parse_args()
Expand All @@ -186,9 +187,9 @@ def train(args, device, log_dir, seed=None, test_mode=False):
train(args, device, log_dir)

elif args.mode == 'test':
for seed in np.random.randint(0, 1000, size=3):
for rep, seed in enumerate(np.random.randint(0, 1000, size=3)):
print('seed:', seed)
log_dir = os.path.join('logs', f'test_{seed}')
log_dir = os.path.join('logs', f'test_rep{rep}')
if not os.path.exists(log_dir):
os.makedirs(log_dir)
np.random.seed(seed)
Expand Down
20 changes: 14 additions & 6 deletions examples/psr/gnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import time
import datetime
import wandb

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -70,11 +71,12 @@ def train_loop(model, loader, optimizer, device):
loss_all += loss.item() * data.num_graphs
total += data.num_graphs
optimizer.step()
wandb.log({'train_loss': loss})
return np.sqrt(loss_all / total)


@torch.no_grad()
def test(model, loader, device):
def test(model, loader, device, log=True):
model.eval()

losses = []
Expand Down Expand Up @@ -107,6 +109,8 @@ def test(model, loader, device):
)

res = compute_correlations(test_df)
if log:
wandb.log({'val_loss': np.mean(losses), 'pearson': res['all_pearson'], 'kendall': res['all_kendall'], 'spearman': res['all_spearman']})

return np.mean(losses), res, test_df

Expand Down Expand Up @@ -169,12 +173,13 @@ def train(args, device, log_dir, seed=None, test_mode=False):
train_loss, val_loss, corrs['per_target_spearman'], corrs['all_spearman']))

if test_mode:
test_file = os.path.join(log_dir, f'test_results.txt')
test_file = os.path.join(log_dir, f'psr_rep{rep}.csv')
model.load_state_dict(torch.load(os.path.join(log_dir, f'best_weights.pt')))
test_loss, corrs, test_df = test(model, val_loader, device)
print('Test RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'.format(
test_loss, corrs['per_target_spearman'], corrs['all_spearman']))
test_df.to_csv(test_file)
val_loss, corrs, results_df = test(model, test_loader, device, log=False)
# plot_corr(y_true, y_pred, os.path.join(log_dir, f'corr_{split}_test.png'))
print('\tTest RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'.format(
train_loss, val_loss, corrs['per_target_spearman'], corrs['all_spearman']))
pd.to_csv(results_df, test_file, index=False)



Expand All @@ -191,6 +196,9 @@ def train(args, device, log_dir, seed=None, test_mode=False):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
log_dir = args.log_dir

wandb.init(project="atom3d", name='PSR', config=vars(args)
)


if args.mode == 'train':
Expand Down
15 changes: 11 additions & 4 deletions examples/rsr/gnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import time
import datetime
import wandb

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -72,11 +73,12 @@ def train_loop(model, loader, optimizer, device):
loss_all += loss.item() * data.num_graphs
total += data.num_graphs
optimizer.step()
wandb.log({'train_loss': loss})
return np.sqrt(loss_all / total)


@torch.no_grad()
def test(model, loader, device):
def test(model, loader, device, log=True):
model.eval()

losses = []
Expand All @@ -101,8 +103,11 @@ def test(model, loader, device):
columns=['target', 'decoy', 'true', 'pred'],
)

corrs = compute_correlations(results_df)
return np.sqrt(np.mean(losses)), corrs, results_df
res = compute_correlations(results_df)
if log:
wandb.log({'val_loss': np.mean(losses), 'pearson': res['all_pearson'], 'kendall': res['all_kendall'], 'spearman': res['all_spearman']})

return np.sqrt(np.mean(losses)), res, results_df

def plot_corr(y_true, y_pred, plot_dir):
plt.clf()
Expand Down Expand Up @@ -159,7 +164,7 @@ def train(args, device, log_dir, rep=None, test_mode=False):
if test_mode:
test_file = os.path.join(log_dir, f'rsr_rep{rep}.csv')
model.load_state_dict(torch.load(os.path.join(log_dir, f'best_weights.pt')))
val_loss, corrs, results_df = test(model, test_loader, device)
val_loss, corrs, results_df = test(model, test_loader, device, log=False)
# plot_corr(y_true, y_pred, os.path.join(log_dir, f'corr_{split}_test.png'))
print('\tTest RMSE: {:.7f}, Per-target Spearman R: {:.7f}, Global Spearman R: {:.7f}'.format(
train_loss, val_loss, corrs['per_target_spearman'], corrs['all_spearman']))
Expand All @@ -181,6 +186,8 @@ def train(args, device, log_dir, rep=None, test_mode=False):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
log_dir = args.log_dir

wandb.init(project="atom3d", name='RSR', config=vars(args)
)

if args.mode == 'train':
if log_dir is None:
Expand Down
Loading

0 comments on commit 50b9b6b

Please sign in to comment.