diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml
index 383e65cd..622510a9 100644
--- a/.github/workflows/pylint.yml
+++ b/.github/workflows/pylint.yml
@@ -7,7 +7,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.8", "3.9", "3.10"]
+        python-version: ["3.8"]
     steps:
     - uses: actions/checkout@v3
     - name: Set up Python ${{ matrix.python-version }}
diff --git a/.pylintrc b/.pylintrc
index 0b5eb267..8367a29d 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -64,11 +64,11 @@ ignore-patterns=^\.#
 # (useful for modules/projects where namespaces are manipulated during runtime
 # and thus existing member attributes cannot be deduced by static analysis). It
 # supports qualified module names, as well as Unix pattern matching.
-ignored-modules=ps
+ignored-modules=pandas, numpy, pydub, torch, torchaudio, timm, tqdm, wandb, torchmetrics

 # Python code to execute, usually for sys.path manipulation such as
 # pygtk.require().
-#init-hook=
+init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"

 # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
 # number of processors available to use, and will cap the count on Windows to
@@ -319,8 +319,8 @@ min-public-methods=2
 [EXCEPTIONS]

 # Exceptions that will emit a warning when caught.
-overgeneral-exceptions=BaseException,
-                       Exception
+overgeneral-exceptions=builtins.BaseException,
+                       builtins.Exception


 [FORMAT]
@@ -357,7 +357,7 @@ single-line-if-stmt=no

 # List of modules that can be imported at any level, not just the top level
 # one.
-allow-any-import-level=
+allow-any-import-level=pandas, numpy, pydub, torch, torchaudio, timm, tqdm, wandb

 # Allow wildcard imports from modules that define __all__.
 allow-wildcard-with-all=no
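Note on the init-hook added above: for readability, the one-line value is shown expanded below. This is only an illustration of what pylint executes at startup, not part of the patch; it assumes the pylint version used in CI still ships pylint.config.find_pylintrc (newer releases may deprecate that helper). The effect is to put the directory containing .pylintrc on sys.path so that local imports such as `from config import get_config` resolve during linting.

# Expanded form of the init-hook value (illustrative only; .pylintrc keeps it on one line)
import os
import sys
from pylint.config import find_pylintrc  # assumed available in the pinned pylint version

# add the repository root (the directory holding .pylintrc) to the import path
sys.path.append(os.path.dirname(find_pylintrc()))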
diff --git a/classification/default_parser.py b/classification/config.py
similarity index 87%
rename from classification/default_parser.py
rename to classification/config.py
index 5cdfe9e1..cc1f5305 100644
--- a/classification/default_parser.py
+++ b/classification/config.py
@@ -1,11 +1,11 @@
-""" Stores default argument information for the argparser
+""" Stores default argument information for the CONFIG variable
     Methods:
-        create_parser: returns an ArgumentParser with the default arguments
+        get_config: returns config data that's parsed from the command line
 """
 import argparse

-def create_parser():
-    """ Returns an ArgumentParser with the default arguments
+def get_config():
+    """ Returns a config variable with the command line arguments or defaults
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('-e', '--epochs', default=10, type=int)
@@ -44,4 +44,11 @@ def create_parser():
     parser.add_argument('-et', '--duration_col', default='DURATION', type=str)
     parser.add_argument('-fp', '--file_path_col', default='IN FILE', type=str)
     parser.add_argument('-mi', '--manual_id_col', default='SCIENTIFIC', type=str)
-    return parser
+
+    CONFIG = parser.parse_args()
+
+    # Convert string arguments to boolean
+    CONFIG.logging = CONFIG.logging == 'True'
+    CONFIG.verbose = CONFIG.verbose == 'True'
+
+    return CONFIG
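Why get_config() compares the raw strings against 'True' rather than declaring the flags with type=bool: argparse would pass the command-line token through bool(), and bool('False') is still truthy. A small illustration follows; the --logging and --verbose flag definitions are assumed to mirror the ones in config.py, since only part of the argument list appears in this hunk.

# Illustration only: the manual string-to-boolean conversion done in get_config()
import argparse

parser = argparse.ArgumentParser()
# assumed flags; the real definitions live in config.py
parser.add_argument('-l', '--logging', default='True', type=str)
parser.add_argument('-v', '--verbose', default='False', type=str)

CONFIG = parser.parse_args(['--logging', 'False'])

# bool('False') would be True, so compare against the literal string instead
CONFIG.logging = CONFIG.logging == 'True'
CONFIG.verbose = CONFIG.verbose == 'True'
print(CONFIG.logging, CONFIG.verbose)  # False False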
diff --git a/classification/dataset.py b/classification/dataset.py
index 3b25beeb..9413c6f1 100644
--- a/classification/dataset.py
+++ b/classification/dataset.py
@@ -8,32 +8,34 @@

 """

-import torch
-import torchaudio
-import torch.nn.functional as F
-from torchaudio import transforms as audtr
-from torch.utils.data import Dataset
-
+# Standard library imports
 from typing import Dict, List, Tuple
 import os

+# Math library imports
 import pandas as pd
 import numpy as np
-from utils import print_verbose, set_seed
-from default_parser import create_parser
-parser = create_parser()
+# Torch imports
+import torch
+from torch.utils.data import Dataset
+import torch.nn.functional as F
+import torchaudio
+from torchaudio import transforms as audtr
+
+# Local imports
+from utils import print_verbose, set_seed
+from config import get_config

 device = 'cuda' if torch.cuda.is_available() else 'cpu'

-#https://www.kaggle.com/code/debarshichanda/pytorch-w-b-birdclef-22-starter
-
+# pylint: disable=too-many-instance-attributes
 class PyhaDF_Dataset(Dataset): #datasets.DatasetFolder
     """ A dataset that loads audio files and converts them to mel spectrograms
     """
-    def __init__(self, csv_file, loader=None, CONFIG=None, max_time=5, train=True, species=None, ignore_bad=True):
-        super()#.__init__(root, loader, extensions='wav')
+    # pylint: disable=too-many-arguments
+    def __init__(self, csv_file, CONFIG=None, max_time=5, train=True, species=None, ignore_bad=True):
+        super()
         if isinstance(csv_file,str):
             self.samples = pd.read_csv(csv_file, index_col=0)
         elif isinstance(csv_file,pd.DataFrame):
@@ -41,15 +43,12 @@ def __init__(self, csv_file, loader=None, CONFIG=None, max_time=5, train=True, s
             self.csv_file = f"data_train-{train}.csv"
         else:
             raise RuntimeError("csv_file must be a str or dataframe!")
-
-
+
         self.formatted_csv_file = "not yet formatted"
-        #print(self.samples)
         self.config = CONFIG
         self.ignore_bad = ignore_bad
-        target_sample_rate = CONFIG.sample_rate
-        self.target_sample_rate = target_sample_rate
-        num_samples = target_sample_rate * max_time
+        self.target_sample_rate = CONFIG.sample_rate
+        num_samples = CONFIG.sample_rate * max_time
         self.num_samples = num_samples
         self.mel_spectogram = audtr.MelSpectrogram(sample_rate=self.target_sample_rate,
                                         n_mels=self.config.n_mels,
@@ -60,13 +59,13 @@
         self.time_mask = audtr.TimeMasking(time_mask_param=self.config.time_mask_param)

         if species is not None:
+            # pylint: disable=fixme
             #TODO FIX REPLICATION CODE
             self.classes, self.class_to_idx = species
         else:
             self.classes = self.samples[self.config.manual_id_col].unique()
             class_idx = np.arange(len(self.classes))
             self.class_to_idx = dict(zip(self.classes, class_idx))
-        #print(self.class_to_idx)
         self.num_classes = len(self.classes)
         self.verify_audio_files()

@@ -95,8 +94,7 @@ def verify_audio_files(self) -> bool:
         #Run the data getting code and check to make sure preprocessing did not break code
         #poor files may contain null values, or sections of files might contain null files
         bad_files = []
-        for i in range(len(self)):
-            spectrogram, _ = self[i]
+        for i, (spectrogram, _) in enumerate(self):
             if spectrogram.isnan().any():
                 bad_files.append(i)

@@ -219,8 +217,6 @@ def one_hot(x, num_classes, on_value=1., off_value=0.):
             frame_offset=frame_offset,
             num_frames=num_frames)

-        #print(path, "test.wav", annotation[self.config.duration_col], annotation[self.config.duration_col])
-
         #Assume audio is all mono and at target sample rate
         assert audio.shape[0] == 1
         assert sample_rate == self.target_sample_rate
@@ -278,8 +274,6 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
         if image.isnan().any():
             print("ERROR IN ANNOTATION #", index)
             raise RuntimeError("NANS IN INPUT FOUND")
-        #print(image)
-        #print(target)
         return image, target

     def pad_audio(self, audio: torch.Tensor) -> torch.Tensor:
@@ -312,39 +306,23 @@ def get_datasets(path="testformatted.csv", CONFIG=None):
     train = data.sample(frac=1/2)
     valid = data[~data.index.isin(train.index)]
     return PyhaDF_Dataset(csv_file=train, CONFIG=CONFIG), PyhaDF_Dataset(csv_file=valid,train=False, CONFIG=CONFIG)
-    #data = BirdCLEFDataset(root="/share/acoustic_species_id/BirdCLEF2023_train_audio_chunks", CONFIG=CONFIG)
-    #no_bird_data = BirdCLEFDataset(root="/share/acoustic_species_id/no_bird_10_000_audio_chunks", CONFIG=CONFIG)
-    #data = torch.utils.data.ConcatDataset([data, no_bird_data])
-    #train_data, val_data = torch.utils.data.random_split(data, [0.8, 0.2])
-    #return train_data, val_data

 def main():
+    """ Main function
+    """
     torch.multiprocessing.set_start_method('spawn')
-    CONFIG = parser.parse_args()
-    CONFIG.logging = CONFIG.logging == 'True'
+    CONFIG = get_config()
     set_seed(CONFIG.seed)
     train_dataset, val_dataset = get_datasets(CONFIG=CONFIG)
-    print(train_dataset.get_classes()[1])
-    print(train_dataset.__getitem__(0))
-    input()
-    #train_dataset = get_datasets(CONFIG=CONFIG)
-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
-        1,
-        shuffle=True,
-        num_workers=CONFIG.jobs,
-    )
-    val_dataloader = torch.utils.data.DataLoader(
-        val_dataset,
-        CONFIG.valid_batch_size,
-        shuffle=False,
-        num_workers=CONFIG.jobs,
-    )
-
-    for i in range(train_dataset.__len__()):
-        print("entry", i)
-        train_dataset.__getitem__(i)
-        input()
+
+    # note: this calls __getitem__ on the dataset and discards the result
+    for i, _ in enumerate(train_dataset):
+        print_verbose("train entry", i, verbose=CONFIG.verbose)
+    print("Loaded all training data")
+
+    for i, _ in enumerate(val_dataset):
+        print_verbose("validation entry", i, verbose=CONFIG.verbose)
+    print("Loaded all validation data")

 if __name__ == '__main__':
     main()
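As a reading aid for the dataset changes above, here is a rough, self-contained sketch of the torchaudio transforms PyhaDF_Dataset.__init__ wires up (a MelSpectrogram plus TimeMasking for augmentation). The constants are placeholders standing in for CONFIG values such as sample_rate, n_mels and time_mask_param; this is not the dataset code itself.

# Sketch of the spectrogram pipeline configured in __init__ (illustrative values)
import torch
from torchaudio import transforms as audtr

sample_rate = 16000                         # stand-in for CONFIG.sample_rate
max_time = 5
num_samples = sample_rate * max_time        # clip length the dataset pads/crops to

mel_spectrogram = audtr.MelSpectrogram(sample_rate=sample_rate, n_mels=128)
time_mask = audtr.TimeMasking(time_mask_param=20)

waveform = torch.randn(1, num_samples)      # stands in for a padded mono clip
spec = mel_spectrogram(waveform)            # shape: (1, n_mels, time_frames)
augmented = time_mask(spec)                 # masking applied as training augmentation
print(spec.shape)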
diff --git a/classification/model.py b/classification/model.py
index 8a94a05b..88d4ffcf 100644
--- a/classification/model.py
+++ b/classification/model.py
@@ -65,6 +65,8 @@ def __init__(self,
         self.pooling = GeM()
         self.embedding = nn.Linear(in_features, embedding_size)
         self.fc = nn.Linear(embedding_size, CONFIG.num_classes)
+        # Must call create_loss_fn after initializing the model to give it the weights
+        self.loss_fn = None

     def forward(self, images):
         """ Forward pass of the model
@@ -79,20 +81,10 @@ def create_loss_fn(self,train_dataset):
         """ Returns the loss function and sets self.loss_fn
         """
         if not self.config.imb: # normal loss
-            if self.config.pos_weight != 1:
-                self.loss_fn = nn.CrossEntropyLoss(pos_weight=torch.tensor([self.config.pos_weight] * self.config.num_classes).to(self.device))
-            else:
-                self.loss_fn = nn.CrossEntropyLoss()
+            self.loss_fn = nn.CrossEntropyLoss()
         else: # weighted loss
-            if self.config.pos_weight != 1:
-                self.loss_fn = nn.CrossEntropyLoss(
-                    pos_weight=torch.tensor([self.config.pos_weight] * self.config.num_classes).to(self.device),
-                    weight=torch.tensor([1 / p for p in train_dataset.class_id_to_num_samples.values()]).to(self.device)
-                )
-            else:
-                self.loss_fn = nn.CrossEntropyLoss(
-                    weight=torch.tensor(
-                        [1 / p for p in train_dataset.class_id_to_num_samples.values()]
-                    ).to(self.device)
-                )
+            self.loss_fn = nn.CrossEntropyLoss(
+                weight=torch.tensor(
+                    [1 / p for p in train_dataset.class_id_to_num_samples.values()]
+                ).to(self.device))
         return self.loss_fn
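The simplified create_loss_fn keeps only two paths: a plain CrossEntropyLoss, or, when CONFIG.imb is set, an inverse-frequency class weighting so under-represented classes contribute more per example. A toy sketch of that weighting follows, with made-up counts standing in for train_dataset.class_id_to_num_samples.

# Toy sketch of the inverse-frequency class weighting used when CONFIG.imb is set
import torch
from torch import nn

class_id_to_num_samples = {0: 500, 1: 50, 2: 5}   # hypothetical per-class counts
weight = torch.tensor([1 / n for n in class_id_to_num_samples.values()])
loss_fn = nn.CrossEntropyLoss(weight=weight)

logits = torch.randn(4, 3)                        # (batch, num_classes)
targets = torch.tensor([0, 1, 2, 1])              # class indices for the sketch
print(loss_fn(logits, targets))                   # rare classes weigh more in the average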
diff --git a/classification/train.py b/classification/train.py
index 61731a03..45968560 100644
--- a/classification/train.py
+++ b/classification/train.py
@@ -9,64 +9,41 @@

 """

-# other files
+# Standard library imports
+import datetime
+from typing import Dict, Any, Tuple
+
+# Local imports
 from dataset import PyhaDF_Dataset, get_datasets
-from model import BirdCLEFModel, GeM
-from tqdm import tqdm
+from model import BirdCLEFModel
 from utils import set_seed, print_verbose
-from default_parser import create_parser
+from config import get_config

-# pytorch training
+# Torch imports
 import torch
-from torch import nn
-from torch.optim import Adam
 import torch.nn.functional as F
 from torch.optim import Adam
-
-# general
-import numpy as np
-from typing import Dict, Any, Tuple
-
-
-
-# other files
-from model import BirdCLEFModel #, GeM
-from tqdm import tqdm
 from torchmetrics.classification import MultilabelAveragePrecision
-
-# pytorch training
-import torch
-import torch.nn as nn
-
-
-#https://www.kaggle.com/code/imvision12/birdclef-2023-efficientnet-training
-
-# logging
+# Other imports
+from tqdm import tqdm
 import wandb
-import datetime
-time_now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
-from torchmetrics.classification import MultilabelAveragePrecision

+wandb_run = None
+time_now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
-parser = create_parser()

 def train(model: BirdCLEFModel,
           data_loader: PyhaDF_Dataset,
           optimizer: torch.optim.Optimizer,
           scheduler,
-          device: str,
-          step: int,
-          best_valid_cmap: float,
-          epoch: int,
           CONFIG) -> Tuple[float, int, float]:
     """ Trains the model
         Returns:
            loss: the average loss over the epoch
            step: the current step
-           best_valid_cmap: the best validation mAP
     """
     print_verbose('size of data loader:', len(data_loader),verbose=CONFIG.verbose)
     model.train()
@@ -83,26 +60,18 @@ def train(model: BirdCLEFModel,
         labels = labels.to(device)

         outputs = model(mels)
-        # sigmoid multilabel predictions
-        preds = torch.sigmoid(outputs) > 0.5
         loss = model.loss_fn(outputs, labels)
         loss.backward()
         optimizer.step()
-
         if scheduler is not None:
             scheduler.step()

         running_loss += loss.item()
         total += labels.size(0)
-        # index of highest predicted class
-        #pred_label = torch.argmax(outputs, dim=1)
-
-        # checking highest against true label
-
         correct += torch.all(torch.round(outputs).eq(labels), dim=-1).sum().item()

         log_loss += loss.item()
         log_n += 1
@@ -112,33 +81,16 @@ def train(model: BirdCLEFModel,
             wandb.log({
                 "train/loss": log_loss / log_n,
                 "train/accuracy": correct / total * 100.,
-                "custom_step": step,
             })
             print("Loss:", log_loss / log_n, "Accuracy:", correct / total * 100.)
             log_loss = 0
             log_n = 0
             correct = 0
             total = 0
-
-        #if step % CONFIG.valid_freq == 0 and step != 0:
-        #    del mels, labels, outputs, preds # clear memory
-        #    valid_loss, valid_map = valid(model, val_dataloader, device, step)
-        #    print(f"Validation Loss:\t{valid_loss} \n Validation mAP:\t{valid_map}" )
-        #    if valid_map > best_valid_cmap:
-        #        print(f"Validation cmAP Improved - {best_valid_cmap} ---> {valid_map}")
-        #        best_valid_cmap = valid_map
-        #        torch.save(model.state_dict(), wandb_run.name + '.pt')
-        #        print(wandb_run.name + '.pt')
-        #    model.train()
-
-
-        step += 1
-
-    return running_loss/len(data_loader), step, best_valid_cmap
+    return running_loss/len(data_loader)

 def valid(model: BirdCLEFModel,
           data_loader: PyhaDF_Dataset,
-          device: str,
           step: int,
           CONFIG) -> Tuple[float, float]:
     """
@@ -163,7 +115,6 @@ def valid(model: BirdCLEFModel,

             # argmax
             outputs = model(mels)
-            #_, preds = torch.max(outputs, 1)

             loss = model.loss_fn(outputs, labels)

@@ -171,7 +122,7 @@ def valid(model: BirdCLEFModel,

             pred.append(outputs.cpu().detach())
             label.append(labels.cpu().detach())
-            # break
+
     pred = torch.cat(pred)
     label = torch.cat(label)
     if CONFIG.map_debug and CONFIG.model_checkpoint is not None:
@@ -180,25 +131,10 @@ def valid(model: BirdCLEFModel,
     # softmax predictions
     pred = F.softmax(pred).to(device)
-    # pred = pred[:, unq_classes]
-
-    # # pad predictions and labels with `pad_n` true positives

     metric = MultilabelAveragePrecision(num_labels=CONFIG.num_classes, average="macro")
     valid_map = metric(pred.detach().cpu(), label.detach().cpu().long())
-    #valid_map = metric(padded_preds, padded_labels)
-    # calculate average precision
-    # valid_map = average_precision_score(
-    #     label.cpu().long(),
-    #     pred.detach().cpu(),
-    #     average='macro',
-    # )
-    # _, padded_preds = torch.max(padded_preds, 1)
-
-    # acc = (padded_preds == padded_labels).sum().item() / len(padded_preds)
-    # print("Validation Accuracy:", acc)
-
     print("Validation mAP:", valid_map)
@@ -230,7 +166,6 @@ def init_wandb(CONFIG: Dict[str, Any]):
         f"-{CONFIG.n_fft}-{CONFIG.seed}-" +
         run.name.split('-')[-1]
     )
-    run.name = f"EFN-{CONFIG.epochs}-{CONFIG.train_batch_size}-{CONFIG.valid_batch_size}-{CONFIG.sample_rate}-{CONFIG.hop_length}-{CONFIG.max_time}-{CONFIG.n_mels}-{CONFIG.n_fft}-{CONFIG.seed}-" + run.name.split('-')[-1]
     return run

 def load_datasets(CONFIG: Dict[str, Any]) \
@@ -255,17 +190,19 @@
     return train_dataset, val_dataset, train_dataloader, val_dataloader

 def main():
+    """ Main function
+    """
     torch.multiprocessing.set_start_method('spawn')
-    CONFIG = parser.parse_args()
-    print(CONFIG)
-    CONFIG.logging = CONFIG.logging == 'True'
-    CONFIG.verbose = CONFIG.verbose == 'True'
+    CONFIG = get_config()
+    # Needed to redefine wandb_run as a global variable
+    # pylint: disable=global-statement
     global wandb_run
     wandb_run = init_wandb(CONFIG)
     set_seed(CONFIG.seed)

     # Load in dataset
     print("Loading Dataset")
+    # pylint: disable=unused-variable
     train_dataset, val_dataset, train_dataloader, val_dataloader = load_datasets(CONFIG)

     print("Loading Model...")
@@ -284,19 +221,16 @@ def main():

     for epoch in range(CONFIG.epochs):
         print("Epoch " + str(epoch))
-        train_loss, step, best_valid_cmap = train(
+        _ = train(
             model_for_run,
             train_dataloader,
             optimizer,
             scheduler,
-            device,
-            step,
-            best_valid_cmap,
-            epoch,
             CONFIG
         )
+        step += 1

-        valid_loss, valid_map = valid(model_for_run, val_dataloader, device, step, CONFIG)
+        valid_loss, valid_map = valid(model_for_run, val_dataloader, step, CONFIG)
         print(f"Validation Loss:\t{valid_loss} \n Validation mAP:\t{valid_map}" )
         if valid_map > best_valid_cmap:
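The validation path now reports a single metric, macro-averaged multilabel average precision over the accumulated predictions. A minimal sketch of that computation on toy tensors (not the project's data) is below; it mirrors the metric(pred, label) call in valid().

# Minimal sketch of the validation metric on toy tensors
import torch
from torchmetrics.classification import MultilabelAveragePrecision

num_classes = 3
metric = MultilabelAveragePrecision(num_labels=num_classes, average="macro")

pred = torch.tensor([[0.9, 0.2, 0.4],
                     [0.1, 0.8, 0.6]])   # probabilities, shape (N, num_classes)
label = torch.tensor([[1, 0, 0],
                      [0, 1, 1]])        # integer multi-hot targets

valid_map = metric(pred, label)
print(valid_map)                          # tensor(1.) for this toy example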
diff --git a/classification/utils.py b/classification/utils.py
index 9ed5085c..50041379 100644
--- a/classification/utils.py
+++ b/classification/utils.py
@@ -5,7 +5,6 @@

 """

-from typing import Dict, Any
 import numpy as np
 import torch
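utils.py itself is untouched apart from the unused import, so its helpers are not shown in this diff. For orientation only, a minimal sketch consistent with how they are called elsewhere (set_seed(CONFIG.seed); print_verbose(..., verbose=CONFIG.verbose)) might look like the following; the real implementations may differ.

# Assumed shape of the utils.py helpers, inferred purely from their call sites
import random
import numpy as np
import torch

def set_seed(seed: int):
    # seed every RNG the training code touches
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

def print_verbose(*args, verbose: bool = False):
    # print only when verbose output was requested on the command line
    if verbose:
        print(*args)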