diff --git a/snrv/snrv.py b/snrv/snrv.py
index 195eee1..2bba54c 100644
--- a/snrv/snrv.py
+++ b/snrv/snrv.py
@@ -2,13 +2,14 @@
 # Do something if this file is invoked on its own
 print("0")
 
+from warnings import warn
 import numpy as np
 import os
 import torch
 import torch.optim as optim
 import torch.nn as nn
 from torch.utils.data.dataset import random_split
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 
@@ -152,7 +153,7 @@ def __init__(
 
         super().__init__()
 
-        assert 0.0 < val_frac < 1.0
+        assert 0.0 <= val_frac < 1.0
         assert isinstance(VAMPdegree, int) and VAMPdegree > 0
 
         self.input_size = input_size
@@ -513,53 +514,179 @@ def _apply_Koopman_reweighting(self, C00, C01, z_t0, z_tt, pathweight):
 
     def _create_dataset(self, data, ln_dynamical_weight, thermo_weight):
         """
-        create training and validation data loader
+        create training and validation data loaders. Since the input data is a time series, the split cannot
+        be random: if `data` is a torch.tensor, the final val_frac of the frames becomes the validation set;
+        if `data` is a list of torch.tensors, the last ceil(val_frac * len(data)) trajectories are held out;
+        if the trajectories are not all the same length, the single shortest trajectory is used for validation.
 
         Parameters
         ----------
         data : torch.tensor, n x dim, n = observations, dim = dimensionality of trajectory featurization
             trajectory
 
-        ln_pathweight : torch.tensor, n, n = observations
+        ln_dynamical_weight : torch.tensor, n, n = observations, default = None
             accumulated sum of the log Girsanov path weights between frames in the trajectory;
-            Girsanov theorem measure of the probability of the observed sample path under a target potential
-            relative to that which was actually observed under the simulation potential;
-            identically unity (no reweighting rqd) for target potential == simulation potential and code as None;
-            Ref.: Kieninger and Keller J. Chem. Phys 154 094102 (2021) https://doi.org/10.1063/5.0038408
+
+        thermo_weight : torch.tensor, n, n = observations, default = None
+            thermodynamic weights for each trajectory frame
 
         Return
         ------
-        self._train_loader : torch DataLoader
-            training data loader
-
-        self._val_loader : torch DataLoader
-            validation data loader
+        None
         """
-        dataset = DatasetSnrv(
-            data,
-            self.lag,
-            ln_dynamical_weight,
-            thermo_weight,
-        )
-
-        n = len(dataset)
-        train_size = int((1.0 - self.val_frac) * n)
-        val_size = n - train_size
-        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+        def _create_dataloaders(train_data, train_ln_w, train_thermo_w, val_data, val_ln_w, val_thermo_w):
+            # weights arrive pre-split so that each DatasetSnrv receives weights
+            # aligned with its own frames
+            train_dataset = DatasetSnrv(
+                train_data,
+                self.lag,
+                train_ln_w,
+                train_thermo_w,
+            )
+            self._train_loader = DataLoader(
+                dataset=train_dataset,
+                batch_size=self.batch_size,
+                shuffle=True,
+                num_workers=self.num_workers,
+            )
+            val_dataset = DatasetSnrv(
+                val_data,
+                self.lag,
+                val_ln_w,
+                val_thermo_w,
+            )
+            self._val_loader = DataLoader(
+                dataset=val_dataset,
+                batch_size=self.batch_size,
+                shuffle=False,
+                num_workers=self.num_workers,
+            )
+            return None
+
+        def _split_weights(w, i):
+            # split an optional per-frame weight tensor (or, for a list of
+            # trajectories, a parallel list of weight tensors) at index i
+            return (None, None) if w is None else (w[:i], w[i:])
+
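+        # dispatch on the requested split: no validation split at all, a single
+        # trajectory tensor, or a list of trajectories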
+        if self.val_frac == 0.0:  # no validation split requested
+            # all data is used for training
+            train_dataset = DatasetSnrv(
+                data,
+                self.lag,
+                ln_dynamical_weight,
+                thermo_weight,
+            )
+            self._train_loader = DataLoader(
+                dataset=train_dataset,
+                batch_size=self.batch_size,
+                shuffle=True,
+                num_workers=self.num_workers,
+            )
+            self._val_loader = None
+            return None
+
+        if isinstance(data, torch.Tensor):  # data is a single trajectory
+            n = data.size(0)
+            # hold out the final val_frac of the time series for validation
+            train_size = int((1.0 - self.val_frac) * n)
+            train_data, val_data = data[:train_size], data[train_size:]
+            ln_t, ln_v = _split_weights(ln_dynamical_weight, train_size)
+            th_t, th_v = _split_weights(thermo_weight, train_size)
+            _create_dataloaders(train_data, ln_t, th_t, val_data, ln_v, th_v)
+
+        elif isinstance(data, list):  # data is a list of trajectories
+            if len(data) == 1:  # same as the tensor case
+                data = data[0]
+                # weights supplied with a list of trajectories are assumed to be
+                # parallel lists of per-trajectory tensors
+                if ln_dynamical_weight is not None:
+                    ln_dynamical_weight = ln_dynamical_weight[0]
+                if thermo_weight is not None:
+                    thermo_weight = thermo_weight[0]
+                n = data.size(0)
+                train_size = int((1.0 - self.val_frac) * n)
+                train_data, val_data = data[:train_size], data[train_size:]
+                ln_t, ln_v = _split_weights(ln_dynamical_weight, train_size)
+                th_t, th_v = _split_weights(thermo_weight, train_size)
+                _create_dataloaders(train_data, ln_t, th_t, val_data, ln_v, th_v)
+            else:
+                len_list = [x.size(0) for x in data]  # length of each trajectory
+                is_homogeneous = all(x == len_list[0] for x in len_list)
+                if is_homogeneous:
+                    # use whole trajectories for validation rather than breaking
+                    # any single trajectory into smaller chunks
+                    val_size = self.val_frac * len(data)
+                    if val_size != np.ceil(val_size):
+                        val_size = int(np.ceil(val_size))  # round up to a whole number of trajectories
+                        warn(
+                            f"val_frac does not correspond to a whole number of trajectories. The validation "
+                            f"fraction will be changed to {val_size / len(data) * 100:.1f}%. You can change this "
+                            "behavior by reinitializing the model with `val_frac=0` and supplying validation "
+                            "data through the `val_data` argument of `fit`.",
+                            stacklevel=2,
+                        )
+                    else:
+                        val_size = int(val_size)
+                    ln_t, ln_v = _split_weights(ln_dynamical_weight, -val_size)
+                    th_t, th_v = _split_weights(thermo_weight, -val_size)
+                    _create_dataloaders(data[:-val_size], ln_t, th_t, data[-val_size:], ln_v, th_v)
+                else:  # trajectories are not all the same length
+                    order = np.argsort(len_list)[::-1]  # indices by descending trajectory length
+                    data = [data[i] for i in order]
+                    if ln_dynamical_weight is not None:
+                        ln_dynamical_weight = [ln_dynamical_weight[i] for i in order]
+                    if thermo_weight is not None:
+                        thermo_weight = [thermo_weight[i] for i in order]
+                    warn(
+                        f"Trajectories in the provided list are not all the same length. The shortest "
+                        "trajectory will be used for validation (validation fraction = "
+                        f"{data[-1].size(0) / sum(len_list) * 100:.1f}%). You can change this behavior "
+                        "by reinitializing the model with `val_frac=0` and supplying validation data "
+                        "through the `val_data` argument of `fit`.",
+                        stacklevel=2,
+                    )
+                    ln_t, ln_v = _split_weights(ln_dynamical_weight, -1)
+                    th_t, th_v = _split_weights(thermo_weight, -1)
+                    _create_dataloaders(data[:-1], ln_t, th_t, data[-1:], ln_v, th_v)  # val kept as a one-trajectory list
+        return None
+
+    def _set_val_loader(self, val_data, ln_dynamical_weight=None, thermo_weight=None):
+        """
+        manually set the validation data loader.
+
+        Parameters
+        ----------
+        val_data : torch.tensor, n x dim, n = observations, dim = dimensionality of trajectory featurization
+            trajectory
+
+        ln_dynamical_weight : torch.tensor, n, n = observations, default = None
+            accumulated sum of the log Girsanov path weights between frames in the trajectory
+
+        thermo_weight : torch.tensor, n, n = observations, default = None
+            thermodynamic weights for each trajectory frame
+
+        Return
+        ------
+        None
+        """
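+        # only a single trajectory tensor is accepted as validation data here;
+        # list inputs are handled by the automatic split in _create_dataset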
+        assert isinstance(val_data, torch.Tensor)
 
-        self._train_loader = DataLoader(
-            dataset=train_dataset,
-            batch_size=self.batch_size,
-            shuffle=True,
-            num_workers=self.num_workers,
+        val_dataset = DatasetSnrv(
+            val_data,
+            self.lag,
+            ln_dynamical_weight,
+            thermo_weight,
         )
         self._val_loader = DataLoader(
-            dataset=val_dataset,
-            batch_size=self.batch_size,
-            shuffle=False,
-            num_workers=self.num_workers,
-        )
-
+            dataset=val_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+        )
         return None
 
     def _train_step(self, x_t0, x_tt, pathweight):
@@ -599,6 +726,7 @@ def fit(
         standardize=False,
         scheduler=False,
         noise_scheduler=None,
+        val_data=None,
     ):
         """
         fit SNRV model to data
@@ -642,6 +770,10 @@ def fit(
             noise magnitude will be sampled from `(0.1 / 2) * N(0, 1)`, and in the third epoch will be
             `(0.1 / 3) * N(0, 1)`, etc...
 
+        val_data : torch.tensor, n x dim, n = observations, dim = dimensionality of trajectory featurization
+            validation data to use during training; requires the model to have been initialized with
+            val_frac=0. If None (default), the validation set is split from `data` according to val_frac
+
         Return
         ------
         self.lag : int
@@ -676,6 +808,14 @@ def fit(
         self.lag = lag
         self.standardize = standardize
         self._create_dataset(data, ln_dynamical_weight, thermo_weight)
+        if val_data is not None:  # validation data provided manually
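+            # manually supplied validation data is only consistent with val_frac=0,
+            # since _create_dataset has then used all of `data` for training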
+            if self.val_frac != 0.0:
+                raise ValueError(
+                    "val_data was provided but val_frac is not 0. Reinitialize the model with "
+                    "val_frac=0 so that all of `data` is used for training, or omit val_data."
+                )
+            self._set_val_loader(val_data, ln_dynamical_weight, thermo_weight)
         if self.standardize:
             # calculate data mean and std from x_0
             for i, (x_batch, _, _) in tqdm(
@@ -731,9 +871,14 @@ def fit(
             if hasattr(self, "scheduler"):
                 self.scheduler.step()
 
-        self.eval()
-        with torch.no_grad():
-            val_losses = []
+            self.eval()
+            with torch.no_grad():
+                val_losses = []
+                if self._val_loader is None:
+                    # no validation split: record NaN to keep per-epoch loss lists aligned
+                    validation_loss = np.nan
+                    validation_losses.append(validation_loss)
+                else:
                     for x_t0_batch, x_tt_batch, pathweight_batch in self._val_loader:
                         x_t0_batch = x_t0_batch.to(self.device)
                         x_tt_batch = x_tt_batch.to(self.device)
@@ -747,10 +892,10 @@ def fit(
                     validation_loss = float(np.mean(val_losses))
                     validation_losses.append(validation_loss)
 
-        print(
-            "[Epoch %d]\t training loss = %.3f\t validation loss = %.3f"
-            % (epoch, training_loss, validation_loss)
-        )
+            print(
+                "[Epoch %d]\t training loss = %.3f\t validation loss = %.3f"
+                % (epoch, training_loss, validation_loss)
+            )
 
         self.training_losses = training_losses
         self.validation_losses = validation_losses
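
Usage sketch (illustrative, not applied by this patch): the two validation workflows this change enables. The class name `Snrv` and all constructor arguments except `val_frac` are assumptions based on the surrounding code, not confirmed by it.

    import torch
    from snrv.snrv import Snrv

    traj = torch.randn(10000, 30)  # featurized trajectory, n frames x dim features

    # 1) automatic split: the final 10% of the time series is held out,
    #    preserving temporal order (no random split)
    model = Snrv(input_size=30, val_frac=0.1)  # other required args omitted
    model.fit(traj, lag=10)

    # 2) manual validation set: initialize with val_frac=0 and pass val_data to fit
    model = Snrv(input_size=30, val_frac=0.0)
    model.fit(traj[:9000], lag=10, val_data=traj[9000:])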