
Reworked validation data selection #1


Open
wants to merge 7 commits into main
186 changes: 145 additions & 41 deletions snrv/snrv.py
@@ -2,13 +2,14 @@
# Do something if this file is invoked on its own
print("0")

from warnings import warn
import numpy as np
import os

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data.dataset import random_split, Subset
from torch.utils.data import DataLoader

from tqdm import tqdm
@@ -152,7 +153,7 @@ def __init__(

super().__init__()

assert 0.0 <= val_frac < 1.0
assert isinstance(VAMPdegree, int) and VAMPdegree > 0

self.input_size = input_size
@@ -513,53 +514,143 @@ def _apply_Koopman_reweighting(self, C00, C01, z_t0, z_tt, pathweight):

def _create_dataset(self, data, ln_dynamical_weight, thermo_weight):
"""
create training and validation data loaders. Since the input data is a time series, the split cannot
be random. If `data` is a torch.Tensor, the final `val_frac` fraction of the frames is used as the
validation set. If `data` is a list of torch.Tensors of equal length, the last ceil(val_frac * len(data))
trajectories are used for validation. If the trajectories are not all the same length, the shortest
trajectory is used for validation.

Parameters
----------
data : torch.tensor, n x dim, n = observations, dim = dimensionality of trajectory featurization
trajectory

ln_dynamical_weight : torch.tensor, n, n = observations, default = None
accumulated sum of the log Girsanov path weights between frames in the trajectory;
Girsanov theorem measure of the probability of the observed sample path under a target potential
relative to that which was actually observed under the simulation potential;
identically unity (no reweighting required) for target potential == simulation potential and coded as None;
Ref.: Kieninger and Keller J. Chem. Phys 154 094102 (2021) https://doi.org/10.1063/5.0038408

thermo_weight : torch.tensor, n, n = observations, default = None
thermodynamic weights for each trajectory frame

Return
------
None; the training data loader is stored as self._train_loader and the validation data loader (or None) as self._val_loader
"""
def _create_dataloaders(train_data, val_data, ln_dynamical_weight, thermo_weight):
train_dataset = DatasetSnrv(
train_data,
self.lag,
ln_dynamical_weight,
thermo_weight,
)
self._train_loader = DataLoader(
dataset=train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
)
val_dataset = DatasetSnrv(
val_data,
self.lag,
ln_dynamical_weight,
thermo_weight,
)
self._val_loader = DataLoader(
dataset=val_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
return None

if self.val_frac == 0.0: # if no validation data is to be used
# all data is used for training
train_dataset = DatasetSnrv(
data,
self.lag,
ln_dynamical_weight,
thermo_weight,
)
self._train_loader = DataLoader(
dataset=train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
)
self._val_loader = None
return None

if isinstance(data, torch.Tensor): # if data is a single trajectory file
n = data.size(0)
# split data into training and validation sets based on val_frac
train_size = int((1.0 - self.val_frac) * n)
train_data = data[0:train_size]
val_data = data[train_size:]
_create_dataloaders(train_data, val_data, ln_dynamical_weight, thermo_weight)

elif isinstance(data, list): # if data is a list of trajectories
if len(data) == 1: # same as tensor case
data = data[0]
n = data.size(0)
# split data into training and validation sets based on val_frac
train_size = int((1.0 - self.val_frac) * n)
train_data = data[0:train_size]
val_data = data[train_size:]
_create_dataloaders(train_data, val_data, ln_dynamical_weight, thermo_weight)
else:
len_list = [x.size(0) for x in data] # length of each trajectory
is_homogeneous = all(x == len_list[0] for x in len_list) # check whether all trajectories have the same length
if is_homogeneous:
# avoid splitting any trajectory into smaller chunks by holding out whole trajectories for validation
val_size = self.val_frac * len(data)
if val_size != np.ceil(val_size):
val_size = int(np.ceil(val_size)) # round val_frac up to a whole number of trajectories
warn(f"""The requested validation fraction does not correspond to a whole number of trajectories. The validation percentage will be changed to {val_size / len(data) * 100:.1f}%.
You can change this behavior by reinitializing the model with `val_frac=0` and passing the validation data via the `val_data` argument of `fit`.""", stacklevel=2)
else:
val_size = int(val_size) # val_frac already corresponds to a whole number of trajectories
_create_dataloaders(data[:-val_size], data[-val_size:], ln_dynamical_weight, thermo_weight)

else: # trajectories are not all the same length
sort_idx = np.argsort(len_list)[::-1] # indices sorting trajectories by length, longest first
data = [data[i] for i in sort_idx] # reorder data so the shortest trajectory is last
warn(f"""Trajectories in the provided list are not all the same length. The shortest trajectory will be used for validation (validation percentage = {data[-1].size(0) / sum(len_list) * 100:.1f}%).
You can change this behavior by reinitializing the model with `val_frac=0` and passing the validation data via the `val_data` argument of `fit`.""", stacklevel=2)
_create_dataloaders(data[:-1], data[-1], ln_dynamical_weight, thermo_weight)
return None
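A minimal sketch (not part of the diff) of the selection rule implemented above, reproduced outside the class; the tensor shapes and val_frac value below are illustrative assumptions.

import numpy as np
import torch

val_frac = 0.2

# Single trajectory: the first (1 - val_frac) fraction of frames trains, the final val_frac fraction validates.
traj = torch.randn(1000, 3)
train_size = int((1.0 - val_frac) * traj.size(0))  # 800 frames
train_part, val_part = traj[:train_size], traj[train_size:]

# List of equal-length trajectories: the last ceil(val_frac * len(data)) whole trajectories
# are held out, so no single trajectory is split across training and validation.
trajs = [torch.randn(500, 3) for _ in range(7)]
val_size = int(np.ceil(val_frac * len(trajs)))  # ceil(1.4) = 2 trajectories
train_trajs, val_trajs = trajs[:-val_size], trajs[-val_size:]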

def _set_val_loader(self, val_data, ln_dynamical_weight=None, thermo_weight=None):
"""
manually set the validation data loader.

Parameters
----------
val_data : torch.tensor, n x dim, n = observations, dim = dimensionality of trajectory featurization
trajectory

ln_dynamical_weight : torch.tensor, n, n = observations, default = None
accumulated sum of the log Girsanov path weights between frames in the trajectory;

thermo_weight : torch.tensor, n, n = observations, default = None
thermodynamic weights for each trajectory frame

Return
------
None
"""
assert isinstance(val_data, torch.Tensor)

val_dataset = DatasetSnrv(
val_data,
self.lag,
ln_dynamical_weight,
thermo_weight,
)
self._val_loader = DataLoader(
dataset=val_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
return None
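As a self-contained illustration (not part of the diff) of what _set_val_loader builds from a manually supplied tensor, the sketch below swaps DatasetSnrv for torch's TensorDataset; the tensor shape and batch size are placeholder assumptions.

import torch
from torch.utils.data import DataLoader, TensorDataset

held_out = torch.randn(200, 3)  # hypothetical validation trajectory
val_loader = DataLoader(
    dataset=TensorDataset(held_out),
    batch_size=32,
    shuffle=False,  # validation batches keep their time ordering; no shuffling
    num_workers=0,
)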

def _train_step(self, x_t0, x_tt, pathweight):
@@ -599,6 +690,7 @@ def fit(
standardize=False,
scheduler=False,
noise_scheduler=None,
val_data=None,
):
"""
fit SNRV model to data
@@ -642,6 +734,10 @@ def fit(
noise magnitude will be sampled from `(0.1 / 2) * N(0, 1)`, and in the third epoch will be
`(0.1 / 3) * N(0, 1)`, etc...

val_data : torch.tensor, n x dim, n = observations, dim = dimensionality of trajectory featurization
validation data to use during training. If None (default), then the validation data is taken from the
data provided in the `data` argument according to the `val_frac` parameter set during model initialization.

Return
------
self.lag : int
@@ -676,6 +772,10 @@
self.lag = lag
self.standardize = standardize
self._create_dataset(data, ln_dynamical_weight, thermo_weight)
if val_data is not None: # validation data is provided manually
if self.val_frac != 0.0:
raise ValueError("Validation data was provided but val_frac is not 0. Reinitialize the model with val_frac=0 and pass the training data via the `data` argument, or remove the `val_data` argument.")
self._set_val_loader(val_data, ln_dynamical_weight, thermo_weight)
if self.standardize:
# calculate data mean and std from x_0
for i, (x_batch, _, _) in tqdm(
@@ -731,9 +831,13 @@
if hasattr(self, "scheduler"):
self.scheduler.step()

self.eval()
with torch.no_grad():
val_losses = []
if self._val_loader is None:
validation_loss = np.nan
validation_losses.append(validation_loss)
else:
for x_t0_batch, x_tt_batch, pathweight_batch in self._val_loader:
x_t0_batch = x_t0_batch.to(self.device)
x_tt_batch = x_tt_batch.to(self.device)
@@ -747,10 +851,10 @@
validation_loss = float(np.mean(val_losses))
validation_losses.append(validation_loss)

print(
"[Epoch %d]\t training loss = %.3f\t validation loss = %.3f"
% (epoch, training_loss, validation_loss)
)

self.training_losses = training_losses
self.validation_losses = validation_losses
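A small usage note (assumed, not part of the diff): when no validation loader exists, the per-epoch validation loss is recorded as NaN, so downstream code reading model.validation_losses should guard for that case.

import numpy as np

recorded = [np.nan, np.nan, np.nan]  # e.g., validation_losses after fitting with val_frac=0 and no val_data
has_validation = not all(np.isnan(x) for x in recorded)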