From 48449fc049255520a45d7ac300a4cba40643a09c Mon Sep 17 00:00:00 2001
From: MoustHolmes
Date: Fri, 28 Apr 2023 12:14:33 +0200
Subject: [PATCH] first working transformer training!!

---
 configs/data/upgrade_energy.yaml        |   6 +-
 lightning_no_config_test.py             | 146 ++++++++++++++++++++++++
 src/data/datapipe_icecube_datamodule.py |  64 ++++-------
 3 files changed, 174 insertions(+), 42 deletions(-)
 create mode 100644 lightning_no_config_test.py

diff --git a/configs/data/upgrade_energy.yaml b/configs/data/upgrade_energy.yaml
index bbb15bb..bf106ec 100644
--- a/configs/data/upgrade_energy.yaml
+++ b/configs/data/upgrade_energy.yaml
@@ -13,10 +13,12 @@ input_cols:
   - pmt_dir_x
   - pmt_dir_y
   - pmt_dir_z
-target_cols: energy
+target_cols:
+  - energy
 truth_table: truth
-max_token_count: 1024
+max_token_count: 16384
 num_workers: 16
+multi_processing_reading_service_num_workers: 4
 # data_dir: ${paths.data_dir}
 # batch_size: 256
 # # train_val_test_split: [879064,109884,109884] #[0.8, 0.1, 0.1]
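The config change makes target_cols a proper YAML list, matching the List[str] annotation in the datamodule below, raises max_token_count from 1024 to 16384, and adds a dedicated worker count for torchdata's reading service. A minimal sketch of reading the reworked keys back, assuming only that omegaconf is installed (the full Hydra wiring is not part of this patch):

# Sketch: verify the edited config keys load as intended.
# Only the keys visible in the diff above are referenced.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/data/upgrade_energy.yaml")
assert cfg.target_cols == ["energy"]  # now a list, not the bare string "energy"
assert cfg.max_token_count == 16384
assert cfg.multi_processing_reading_service_num_workers == 4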
diff --git a/lightning_no_config_test.py b/lightning_no_config_test.py
new file mode 100644
index 0000000..aba5446
--- /dev/null
+++ b/lightning_no_config_test.py
@@ -0,0 +1,146 @@
+import torch
+import torch.nn.functional as F
+import torchvision.datasets as datasets
+import torchvision.transforms as transforms
+from torch import nn, optim
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from torch.utils.data import random_split
+import pytorch_lightning as pl
+import torchmetrics
+from torchmetrics import Metric
+
+
+class MyAccuracy(Metric):
+    def __init__(self):
+        super().__init__()
+        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def update(self, preds, target):
+        preds = torch.argmax(preds, dim=1)
+        assert preds.shape == target.shape
+        self.correct += torch.sum(preds == target)
+        self.total += target.numel()
+
+    def compute(self):
+        return self.correct.float() / self.total.float()
+
+
+class NN(pl.LightningModule):
+    def __init__(self, input_size, num_classes):
+        super().__init__()
+        self.fc1 = nn.Linear(input_size, 50)
+        self.fc2 = nn.Linear(50, num_classes)
+        self.loss_fn = nn.CrossEntropyLoss()
+        self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
+        self.my_accuracy = MyAccuracy()
+        self.f1_score = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)
+
+    def forward(self, x):
+        x = F.relu(self.fc1(x))
+        x = self.fc2(x)
+        return x
+
+    def training_step(self, batch, batch_idx):
+        loss, scores, y = self._common_step(batch, batch_idx)
+        accuracy = self.my_accuracy(scores, y)
+        f1_score = self.f1_score(scores, y)
+        self.log_dict({'train_loss': loss, 'train_accuracy': accuracy, 'train_f1_score': f1_score},
+                      on_step=False, on_epoch=True, prog_bar=True)
+        return {'loss': loss, "scores": scores, "y": y}
+
+    def validation_step(self, batch, batch_idx):
+        loss, scores, y = self._common_step(batch, batch_idx)
+        self.log('val_loss', loss)
+        return loss
+
+    def test_step(self, batch, batch_idx):
+        loss, scores, y = self._common_step(batch, batch_idx)
+        self.log('test_loss', loss)
+        return loss
+
+    def _common_step(self, batch, batch_idx):
+        x, y = batch
+        x = x.reshape(x.size(0), -1)
+        scores = self.forward(x)
+        loss = self.loss_fn(scores, y)
+        return loss, scores, y
+
+    def predict_step(self, batch, batch_idx):
+        x, y = batch
+        x = x.reshape(x.size(0), -1)
+        scores = self.forward(x)
+        preds = torch.argmax(scores, dim=1)
+        return preds
+
+    def configure_optimizers(self):
+        return optim.Adam(self.parameters(), lr=0.001)
+
+
+class MnistDataModule(pl.LightningDataModule):
+    def __init__(self, data_dir, batch_size, num_workers):
+        super().__init__()
+        self.data_dir = data_dir
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+
+    def prepare_data(self):
+        datasets.MNIST(self.data_dir, train=True, download=True)
+        datasets.MNIST(self.data_dir, train=False, download=True)
+
+    def setup(self, stage):
+        entire_dataset = datasets.MNIST(
+            root=self.data_dir,
+            train=True,
+            transform=transforms.ToTensor(),
+            download=False,
+        )
+        self.train_ds, self.val_ds = random_split(entire_dataset, [50000, 10000])
+        self.test_ds = datasets.MNIST(
+            root=self.data_dir,
+            train=False,
+            transform=transforms.ToTensor(),
+            download=False,
+        )
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_ds,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=True,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_ds,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            self.test_ds,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=False,
+        )
+
+# Set device cuda for GPU if it's available otherwise run on the CPU
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# Hyperparameters
+input_size = 784
+num_classes = 10
+learning_rate = 0.001
+batch_size = 64
+num_epochs = 3
+
+model = NN(input_size=input_size, num_classes=num_classes)
+dm = MnistDataModule(data_dir="dataset/", batch_size=batch_size, num_workers=4)
+trainer = pl.Trainer(accelerator="gpu", devices=1, min_epochs=1, max_epochs=3, precision=16)
+trainer.fit(model, dm)
+trainer.validate(model, dm)
+trainer.test(model, dm)
\ No newline at end of file
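lightning_no_config_test.py is a self-contained MNIST sanity check for the Lightning stack, run directly as python lightning_no_config_test.py; it bypasses the Hydra configs entirely. The Trainer hard-codes accelerator="gpu", so on a CPU-only machine the construction would need a fallback along these lines (a sketch, not part of the patch):

# Sketch: device-agnostic variant of the Trainer call in the script above.
# precision=16 relies on CUDA AMP here, so full precision is kept on CPU.
import torch
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    min_epochs=1,
    max_epochs=3,
    precision=16 if torch.cuda.is_available() else 32,
)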
""" + # sampler = SequentialSampler() if not self.datapipe_train and not self.datapipe_val and not self.datapipe_test: self.datapipe_train, self.datapipe_val, self.datapipe_test = make_train_test_val_datapipe( - self.hparams.train_csv_file, - self.hparams.test_csv_file, - self.hparams.val_csv_file, - self.hparams.db_path, - self.hparams.input_cols, - self.hparams.pulsemap, - self.hparams.target_cols, - self.hparams.truth_table, - self.hparams.max_token_count, - upgrade_feature_transform, - truth_transform = None + train_csv_file = self.hparams.train_csv_file, + test_csv_file = self.hparams.test_csv_file, + val_csv_file = self.hparams.val_csv_file, + db_path = self.hparams.db_path, + input_cols = self.hparams.input_cols, + pulsemap = self.hparams.pulsemap, + target_cols = self.hparams.target_cols, + truth_table = self.hparams.truth_table, + max_token_count = self.hparams.max_token_count, + feature_transform = upgrade_feature_transform, + truth_transform = None, ) def train_dataloader(self): self.icecube_train_dataloader = DataLoader2( - self.datapipe_train, - # self.rs + datapipe = self.datapipe_train, + reading_service = self.rs, ) - print() - for i, batch in enumerate(self.datapipe_train): - - xx, y, pad, = batch - print(sum([len(x) for x in xx])) - print(len(y)) - print() - if i == 2: - break - - print() - print() - print(iter(self.icecube_train_dataloader)) - print() return self.icecube_train_dataloader def val_dataloader(self): self.icecube_val_dataloader = DataLoader2( - self.datapipe_val + datapipe = self.datapipe_val, + reading_service = self.rs, ) return self.icecube_val_dataloader def test_dataloader(self): self.icecube_test_dataloader = DataLoader2( - self.datapipe_val + datapipe = self.datapipe_val, + reading_service = self.rs, ) return self.icecube_test_dataloader - def teardown(self, stage: Optional[str] = None): """Clean up after fit or test."""