diff --git a/pfns/bar_distribution.py b/pfns/bar_distribution.py index 1fea353..d92aa53 100644 --- a/pfns/bar_distribution.py +++ b/pfns/bar_distribution.py @@ -386,8 +386,7 @@ def get_bucket_limits(num_outputs:int, full_range:tuple=None, ys:torch.Tensor=No full_range = (ys.min(), ys.max()) else: assert full_range[0] <= ys.min() and full_range[1] >= ys.max(), f'full_range {full_range} not in range of ys {ys.min(), ys.max()}' - # FIXME: this needs to be on same device - full_range = torch.tensor(full_range) + full_range = torch.tensor(full_range, device=ys.device) ys_sorted, ys_order = ys.sort(0) bucket_limits = (ys_sorted[ys_per_bucket-1::ys_per_bucket][:-1]+ys_sorted[ys_per_bucket::ys_per_bucket])/2 if verbose: diff --git a/pfns/train.py b/pfns/train.py index d55b9d6..5998a3a 100644 --- a/pfns/train.py +++ b/pfns/train.py @@ -10,6 +10,7 @@ import torch from torch import nn from torch.cuda.amp import autocast, GradScaler +import numpy as np from . import utils from .priors import prior @@ -281,6 +282,7 @@ def apply_batch_wise_criterion(i): try: total_loss, total_positional_losses, time_to_get_batch, forward_time, step_time, nan_share, ignore_share =\ train_epoch() + step_callback({"mean_loss": total_loss}) except Exception as e: print("Invalid epoch encountered, skipping...") print(e) @@ -293,10 +295,16 @@ def apply_batch_wise_criterion(i): val_score = None if verbose: + pos_losses_str = f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}" + if len(total_positional_losses) > 20: + actual_losses = np.array(total_positional_losses)[~np.isnan(total_positional_losses)] + pos_losses_str = f"pos losses std {np.std(actual_losses):5.2f} | " + \ + f"pos losses quantiles {np.quantile(actual_losses, [0.1, 0.25, 0.5, 0.75, 0.9])} | " + \ + f"losses/total {len(actual_losses)} / {len(total_positional_losses)}" print('-' * 89) print( - f'| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.2f} | ' - f"pos losses {','.join([f'{l:5.2f}' for l in total_positional_losses])}, lr {scheduler.get_last_lr()[0]}" + f'| end of epoch {epoch:3d} | time: {(time.time() - epoch_start_time):5.2f}s | mean loss {total_loss:5.4f} | ' + f' {pos_losses_str}, lr {scheduler.get_last_lr()[0]}' f' data time {time_to_get_batch:5.2f} step time {step_time:5.2f}' f' forward time {forward_time:5.2f}' f' nan share {nan_share:5.2f} ignore share (for classification tasks) {ignore_share:5.4f}'