-
-
Notifications
You must be signed in to change notification settings - Fork 124
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
60 changed files
with
825 additions
and
204 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,13 +13,19 @@ | |
# Created by Wenjie Du <[email protected]> | ||
# License: BSD-3-Clause | ||
|
||
import os | ||
from typing import Union, Optional | ||
|
||
import h5py | ||
import numpy as np | ||
import torch | ||
from torch.utils.data import DataLoader | ||
|
||
try: | ||
import nni | ||
except ImportError: | ||
pass | ||
|
||
from .data import DatasetForCSDI | ||
from .modules import _CSDI | ||
from ..base import BaseNNImputer | ||
|
@@ -205,6 +211,113 @@ def _assemble_input_for_validating(self, data) -> dict: | |
def _assemble_input_for_testing(self, data) -> dict:
    """Assemble the model input for the testing stage.

    Testing reuses the validating-stage assembly unchanged, so this simply
    delegates to ``_assemble_input_for_validating``.

    Parameters
    ----------
    data :
        A batch as yielded by the data loader.

    Returns
    -------
    dict
        Whatever ``_assemble_input_for_validating`` builds for ``data``.
    """
    assembled_inputs = self._assemble_input_for_validating(data)
    return assembled_inputs
|
||
def _train_model(
    self,
    training_loader: DataLoader,
    val_loader: Optional[DataLoader] = None,
) -> None:
    """Run the training loop and keep track of the best model weights.

    For every epoch the model is optimised on ``training_loader``. When
    ``val_loader`` is given, the mean validating loss is used to select the
    best epoch; otherwise the mean training loss is used. Training stops
    early once ``self.patience`` reaches 0.

    Parameters
    ----------
    training_loader :
        Loader yielding the training batches.

    val_loader :
        Optional loader yielding the validating batches.

    Raises
    ------
    RuntimeError
        If an exception interrupts training before any best checkpoint
        has been recorded.

    ValueError
        If ``self.best_loss`` is NaN after training.
    """
    # Function-scope imports keep this method self-contained without
    # touching the module-level import block.
    import copy
    import warnings

    # each training starts from the very beginning, so reset the loss and model dict here
    self.best_loss = float("inf")
    self.best_model_dict = None

    try:
        training_step = 0
        for epoch in range(self.epochs):
            self.model.train()
            epoch_train_loss_collector = []
            for data in training_loader:
                training_step += 1
                inputs = self._assemble_input_for_training(data)
                self.optimizer.zero_grad()
                results = self.model.forward(inputs)
                # use sum() before backward() in case of multi-gpu training
                results["loss"].sum().backward()
                self.optimizer.step()
                epoch_train_loss_collector.append(results["loss"].sum().item())

                # save training loss logs into the tensorboard file for every step if in need
                if self.summary_writer is not None:
                    self._save_log_into_tb_file(training_step, "training", results)

            # mean training loss of the current epoch
            mean_train_loss = np.mean(epoch_train_loss_collector)

            if val_loader is not None:
                self.model.eval()
                val_loss_collector = []
                with torch.no_grad():
                    for data in val_loader:
                        inputs = self._assemble_input_for_validating(data)
                        results = self.model.forward(
                            inputs, training=False, n_sampling_times=0
                        )
                        # sum() first, as in the training path, so that a
                        # multi-element loss does not break .item()
                        val_loss_collector.append(results["loss"].sum().item())

                mean_val_loss = np.asarray(val_loss_collector).mean()

                # save validating loss logs into the tensorboard file for every epoch if in need
                if self.summary_writer is not None:
                    val_loss_dict = {
                        "imputation_loss": mean_val_loss,
                    }
                    self._save_log_into_tb_file(epoch, "validating", val_loss_dict)

                logger.info(
                    f"Epoch {epoch} - "
                    f"training loss: {mean_train_loss:.4f}, "
                    f"validating loss: {mean_val_loss:.4f}"
                )
                mean_loss = mean_val_loss
            else:
                logger.info(f"Epoch {epoch} - training loss: {mean_train_loss:.4f}")
                mean_loss = mean_train_loss

            if np.isnan(mean_loss):
                logger.warning(
                    f"‼️ Attention: got NaN loss in Epoch {epoch}. This may lead to unexpected errors."
                )

            if mean_loss < self.best_loss:
                self.best_loss = mean_loss
                # state_dict() hands back references to the live parameter
                # tensors, so take a deep copy; otherwise later optimizer
                # steps would silently overwrite the "best" snapshot.
                self.best_model_dict = copy.deepcopy(self.model.state_dict())
                self.patience = self.original_patience
                # save the model if necessary
                self._auto_save_model_if_necessary(
                    training_finished=False,
                    saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}",
                )
            else:
                self.patience -= 1

            if os.getenv("enable_tuning", False):
                nni.report_intermediate_result(mean_loss)
                if epoch == self.epochs - 1 or self.patience == 0:
                    nni.report_final_result(self.best_loss)

            if self.patience == 0:
                logger.info(
                    "Exceeded the training patience. Terminating the training procedure..."
                )
                break

    except Exception as e:
        logger.error(f"Exception: {e}")
        if self.best_model_dict is None:
            # chain the original exception so its traceback is not lost
            raise RuntimeError(
                "Training got interrupted. Model was not trained. Please investigate the error printed above."
            ) from e
        else:
            # actually emit the warning; merely constructing a
            # RuntimeWarning instance has no effect
            warnings.warn(
                "Training got interrupted. Please investigate the error printed above.\n"
                "Model got trained and will load the best checkpoint so far for testing.\n"
                "If you don't want it, please try fit() again.",
                RuntimeWarning,
            )

    if np.isnan(self.best_loss):
        raise ValueError("Something is wrong. best_loss is Nan after training.")

    logger.info("Finished training.")
|
||
def fit( | ||
self, | ||
train_set: Union[dict, str], | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.