Skip to content

Commit

Permalink
Merge pull request #15 from UCSD-E4E/linting
Browse files Browse the repository at this point in the history
pylint fully passes :3
  • Loading branch information
sprestrelski authored Jul 5, 2023
2 parents 95134b0 + 62546ca commit ad3af58
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 170 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pylint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10"]
python-version: ["3.8"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
10 changes: 5 additions & 5 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ ignore-patterns=^\.#
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis). It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=ps
ignored-modules=pandas, numpy, pydub, torch, torchaudio, timm, tqdm, wandb, torchmetrics

# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"

# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to
Expand Down Expand Up @@ -319,8 +319,8 @@ min-public-methods=2
[EXCEPTIONS]

# Exceptions that will emit a warning when caught.
overgeneral-exceptions=BaseException,
Exception
overgeneral-exceptions=builtins.BaseException,
builtins.Exception


[FORMAT]
Expand Down Expand Up @@ -357,7 +357,7 @@ single-line-if-stmt=no

# List of modules that can be imported at any level, not just the top level
# one.
allow-any-import-level=
allow-any-import-level=pandas, numpy, pydub, torch, torchaudio, timm, tqdm, wandb

# Allow wildcard imports from modules that define __all__.
allow-wildcard-with-all=no
Expand Down
17 changes: 12 additions & 5 deletions classification/default_parser.py → classification/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
""" Stores default argument information for the argparser
""" Stores default argument information for CONFIG variable
Methods:
create_parser: returns an ArgumentParser with the default arguments
get_config: returns config data that's parsed from the command line
"""
import argparse

def create_parser():
""" Returns an ArgumentParser with the default arguments
def get_config():
""" Returns a config variable with the command line arguments or defaults
"""
parser = argparse.ArgumentParser()
parser.add_argument('-e', '--epochs', default=10, type=int)
Expand Down Expand Up @@ -44,4 +44,11 @@ def create_parser():
parser.add_argument('-et', '--duration_col', default='DURATION', type=str)
parser.add_argument('-fp', '--file_path_col', default='IN FILE', type=str)
parser.add_argument('-mi', '--manual_id_col', default='SCIENTIFIC', type=str)
return parser

CONFIG = parser.parse_args()

# Convert string arguments to boolean
CONFIG.logging = CONFIG.logging == 'True'
CONFIG.verbose = CONFIG.verbose == 'True'

return CONFIG
86 changes: 32 additions & 54 deletions classification/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,48 +8,47 @@
"""

import torch
import torchaudio
import torch.nn.functional as F
from torchaudio import transforms as audtr
from torch.utils.data import Dataset

# Standard library imports
from typing import Dict, List, Tuple
import os

# Math library imports
import pandas as pd
import numpy as np

from utils import print_verbose, set_seed
from default_parser import create_parser
parser = create_parser()
# Torch imports
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
import torchaudio
from torchaudio import transforms as audtr

# Local imports
from utils import print_verbose, set_seed
from config import get_config

device = 'cuda' if torch.cuda.is_available() else 'cpu'

#https://www.kaggle.com/code/debarshichanda/pytorch-w-b-birdclef-22-starter

# pylint: disable=too-many-instance-attributes
class PyhaDF_Dataset(Dataset): #datasets.DatasetFolder
""" A dataset that loads audio files and converts them to mel spectrograms
"""
def __init__(self, csv_file, loader=None, CONFIG=None, max_time=5, train=True, species=None, ignore_bad=True):
super()#.__init__(root, loader, extensions='wav')
# pylint: disable=too-many-arguments
def __init__(self, csv_file, CONFIG=None, max_time=5, train=True, species=None, ignore_bad=True):
super()
if isinstance(csv_file,str):
self.samples = pd.read_csv(csv_file, index_col=0)
elif isinstance(csv_file,pd.DataFrame):
self.samples = csv_file
self.csv_file = f"data_train-{train}.csv"
else:
raise RuntimeError("csv_file must be a str or dataframe!")



self.formatted_csv_file = "not yet formatted"
#print(self.samples)
self.config = CONFIG
self.ignore_bad = ignore_bad
target_sample_rate = CONFIG.sample_rate
self.target_sample_rate = target_sample_rate
num_samples = target_sample_rate * max_time
self.target_sample_rate = CONFIG.sample_rate
num_samples = CONFIG.sample_rate * max_time
self.num_samples = num_samples
self.mel_spectogram = audtr.MelSpectrogram(sample_rate=self.target_sample_rate,
n_mels=self.config.n_mels,
Expand All @@ -60,13 +59,13 @@ def __init__(self, csv_file, loader=None, CONFIG=None, max_time=5, train=True, s
self.time_mask = audtr.TimeMasking(time_mask_param=self.config.time_mask_param)

if species is not None:
# pylint: disable=fixme
#TODO FIX REPLICATION CODE
self.classes, self.class_to_idx = species
else:
self.classes = self.samples[self.config.manual_id_col].unique()
class_idx = np.arange(len(self.classes))
self.class_to_idx = dict(zip(self.classes, class_idx))
#print(self.class_to_idx)
self.num_classes = len(self.classes)

self.verify_audio_files()
Expand Down Expand Up @@ -95,8 +94,7 @@ def verify_audio_files(self) -> bool:
#Run the data getting code and check to make sure preprocessing did not break code
#poor files may contain null values, or sections of files might contain null files
bad_files = []
for i in range(len(self)):
spectrogram, _ = self[i]
for i, (spectrogram, _) in enumerate(self):
if spectrogram.isnan().any():
bad_files.append(i)

Expand Down Expand Up @@ -219,8 +217,6 @@ def one_hot(x, num_classes, on_value=1., off_value=0.):
frame_offset=frame_offset,
num_frames=num_frames)

#print(path, "test.wav", annotation[self.config.duration_col], annotation[self.config.duration_col])

#Assume audio is all mono and at target sample rate
assert audio.shape[0] == 1
assert sample_rate == self.target_sample_rate
Expand Down Expand Up @@ -278,8 +274,6 @@ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
if image.isnan().any():
print("ERROR IN ANNOTATION #", index)
raise RuntimeError("NANS IN INPUT FOUND")
#print(image)
#print(target)
return image, target

def pad_audio(self, audio: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -312,39 +306,23 @@ def get_datasets(path="testformatted.csv", CONFIG=None):
train = data.sample(frac=1/2)
valid = data[~data.index.isin(train.index)]
return PyhaDF_Dataset(csv_file=train, CONFIG=CONFIG), PyhaDF_Dataset(csv_file=valid,train=False, CONFIG=CONFIG)
#data = BirdCLEFDataset(root="/share/acoustic_species_id/BirdCLEF2023_train_audio_chunks", CONFIG=CONFIG)
#no_bird_data = BirdCLEFDataset(root="/share/acoustic_species_id/no_bird_10_000_audio_chunks", CONFIG=CONFIG)
#data = torch.utils.data.ConcatDataset([data, no_bird_data])
#train_data, val_data = torch.utils.data.random_split(data, [0.8, 0.2])
#return train_data, val_data

def main():
""" Main function
"""
torch.multiprocessing.set_start_method('spawn')
CONFIG = parser.parse_args()
CONFIG.logging = CONFIG.logging == 'True'
CONFIG = get_config()
set_seed(CONFIG.seed)
train_dataset, val_dataset = get_datasets(CONFIG=CONFIG)
print(train_dataset.get_classes()[1])
print(train_dataset.__getitem__(0))
input()
#train_dataset = get_datasets(CONFIG=CONFIG)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
1,
shuffle=True,
num_workers=CONFIG.jobs,
)
val_dataloader = torch.utils.data.DataLoader(
val_dataset,
CONFIG.valid_batch_size,
shuffle=False,
num_workers=CONFIG.jobs,
)

for i in range(train_dataset.__len__()):
print("entry", i)
train_dataset.__getitem__(i)
input()

# note: this calls __getitem__ on the dataset and discards the result
for i, _ in enumerate(train_dataset):
print_verbose("train entry", i,verbose=CONFIG.verbose)
print("Loaded all training data")

for i, _ in enumerate(val_dataset):
print_verbose("validation entry", i,verbose=CONFIG.verbose)
print("Loaded all validation data")

if __name__ == '__main__':
main()
22 changes: 7 additions & 15 deletions classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def __init__(self,
self.pooling = GeM()
self.embedding = nn.Linear(in_features, embedding_size)
self.fc = nn.Linear(embedding_size, CONFIG.num_classes)
# Must call create_loss_fn after initializing the model to give it the weights
self.loss_fn = None

def forward(self, images):
""" Forward pass of the model
Expand All @@ -79,20 +81,10 @@ def create_loss_fn(self,train_dataset):
""" Returns the loss function and sets self.loss_fn
"""
if not self.config.imb: # normal loss
if self.config.pos_weight != 1:
self.loss_fn = nn.CrossEntropyLoss(pos_weight=torch.tensor([self.config.pos_weight] * self.config.num_classes).to(self.device))
else:
self.loss_fn = nn.CrossEntropyLoss()
self.loss_fn = nn.CrossEntropyLoss()
else: # weighted loss
if self.config.pos_weight != 1:
self.loss_fn = nn.CrossEntropyLoss(
pos_weight=torch.tensor([self.config.pos_weight] * self.config.num_classes).to(self.device),
weight=torch.tensor([1 / p for p in train_dataset.class_id_to_num_samples.values()]).to(self.device)
)
else:
self.loss_fn = nn.CrossEntropyLoss(
weight=torch.tensor(
[1 / p for p in train_dataset.class_id_to_num_samples.values()]
).to(self.device)
)
self.loss_fn = nn.CrossEntropyLoss(
weight=torch.tensor(
[1 / p for p in train_dataset.class_id_to_num_samples.values()]
).to(self.device))
return self.loss_fn
Loading

0 comments on commit ad3af58

Please sign in to comment.