Merge branch 'main' into seralize_waveforms
Sean1572 authored Jul 5, 2023
2 parents e6a406d + ad3af58 commit a2a6581
Showing 6 changed files with 21 additions and 97 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pylint.yml
@@ -7,7 +7,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.8", "3.9", "3.10"]
+        python-version: ["3.8"]
     steps:
     - uses: actions/checkout@v3
     - name: Set up Python ${{ matrix.python-version }}
3 changes: 2 additions & 1 deletion classification/config.py
@@ -1,11 +1,12 @@
+
 """ Stores default argument information for the argparser
 Methods:
     get_config: returns an ArgumentParser with the default arguments
 """
 import argparse
 
 def get_config():
-    """ Returns an ArgumentParser with the default arguments
+    """ Returns a config variable with the command line arguments or defaults
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('-e', '--epochs', default=10, type=int)
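For context on the docstring fix above: get_config evidently returns the parsed arguments rather than the parser itself. A minimal sketch of the presumed shape — the trailing parse_args() call and the usage lines are assumptions, not code from this commit:

    import argparse

    def get_config():
        """ Returns a config variable with the command line arguments or defaults
        """
        parser = argparse.ArgumentParser()
        parser.add_argument('-e', '--epochs', default=10, type=int)
        # ...more add_argument calls in the real file...
        return parser.parse_args()

    CONFIG = get_config()  # argparse.Namespace
    print(CONFIG.epochs)   # 10, or whatever was passed via -e/--epochs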
13 changes: 3 additions & 10 deletions classification/dataset.py
@@ -21,18 +21,15 @@
 import torchaudio
 from torchaudio import transforms as audtr
-
-
-
-
 
 # Math library imports
 import pandas as pd
 import numpy as np
-
 
 from utils import set_seed #print_verbose
 from config import get_config
 from tqdm import tqdm
+
+
 tqdm.pandas()
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
@@ -45,7 +42,7 @@ class PyhaDF_Dataset(Dataset):
     """
 
     # df, csv_file, train, and species decided outside of config, so those cannot be added in there
-    # pylint: disable-next=R0913
+    # pylint: disable-next=too-many-instance-attributes
     def __init__(self, df, csv_file="test.csv", CONFIG=None, train=True, species=None):
         self.config = CONFIG
         self.samples = df[~(df[self.config.file_path_col].isnull())]
@@ -203,8 +200,6 @@ def one_hot(x, num_classes, on_value=1., off_value=0.):
             raise RuntimeError("Bad Audio") from e
 
 
-        #print(path, "test.wav", annotation[self.config.duration_col], annotation[self.config.duration_col])
-
         #Assume audio is all mono and at target sample rate
         #assert audio.shape[0] == 1
         #assert sample_rate == self.target_sample_rate
@@ -268,8 +263,6 @@ def __getitem__(self, index): #-> Any:
             #try again with a diff annotation to avoid training breaking
             image, target = self[self.samples.sample(1).index[0]]
 
-            #print(image)
-            #print(target)
         return image, target
 
     def pad_audio(self, audio: torch.Tensor) -> torch.Tensor:
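The body of pad_audio is collapsed in the view above. As a rough sketch of what such a helper typically does — the target_num_samples parameter is an assumption, not the class's actual attribute — zero-padding with torch.nn.functional.pad:

    import torch

    def pad_audio(audio: torch.Tensor, target_num_samples: int) -> torch.Tensor:
        # Right-pad a waveform with zeros so every clip reaches a fixed length.
        missing = target_num_samples - audio.shape[-1]
        if missing > 0:
            audio = torch.nn.functional.pad(audio, (0, missing))
        return audio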
22 changes: 5 additions & 17 deletions classification/model.py
@@ -85,22 +85,10 @@ def create_loss_fn(self,train_dataset):
         """ Returns the loss function and sets self.loss_fn
         """
         if not self.config.imb: # normal loss
-            if self.config.pos_weight != 1:
-                self.loss_fn = nn.CrossEntropyLoss(
-                    pos_weight=torch.tensor([self.config.pos_weight] * self.config.num_classes).to(self.device)
-                )
-            else:
-                self.loss_fn = nn.CrossEntropyLoss()
+            self.loss_fn = nn.CrossEntropyLoss()
         else: # weighted loss
-            if self.config.pos_weight != 1:
-                self.loss_fn = nn.CrossEntropyLoss(
-                    pos_weight=torch.tensor([self.config.pos_weight] * self.config.num_classes).to(self.device),
-                    weight=torch.tensor([1 / p for p in train_dataset.class_id_to_num_samples.values()]).to(self.device)
-                )
-            else:
-                self.loss_fn = nn.CrossEntropyLoss(
-                    weight=torch.tensor(
-                        [1 / p for p in train_dataset.class_id_to_num_samples.values()]
-                    ).to(self.device)
-                )
+            self.loss_fn = nn.CrossEntropyLoss(
+                weight=torch.tensor(
+                    [1 / p for p in train_dataset.class_id_to_num_samples.values()]
+                ).to(self.device))
         return self.loss_fn
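One note on the deleted branches: torch.nn.CrossEntropyLoss accepts a per-class weight argument but has no pos_weight parameter (that keyword belongs to nn.BCEWithLogitsLoss), so the pos_weight != 1 paths would have raised a TypeError at construction; the surviving weighted path matches the documented API. A self-contained sketch of the same weighting scheme, with made-up class counts standing in for train_dataset.class_id_to_num_samples:

    import torch
    from torch import nn

    class_id_to_num_samples = {0: 500, 1: 50, 2: 10}  # hypothetical counts
    loss_fn = nn.CrossEntropyLoss(
        weight=torch.tensor([1 / p for p in class_id_to_num_samples.values()])
    )

    logits = torch.randn(4, 3)            # (batch, num_classes)
    targets = torch.tensor([0, 2, 1, 1])  # class indices
    loss = loss_fn(logits, targets)       # rare classes contribute more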
77 changes: 10 additions & 67 deletions classification/train.py
@@ -63,8 +63,6 @@
 tqdm.pandas()
 time_now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
 
-
-
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 print(device)
 wandb_run = None
@@ -81,7 +79,6 @@ def train(model: BirdCLEFModel,
     Returns:
         loss: the average loss over the epoch
         step: the current step
-        best_valid_cmap: the best validation mAP
     """
     print_verbose('size of data loader:', len(data_loader),verbose=CONFIG.verbose)
     model.train()
@@ -98,26 +95,17 @@
         labels = labels.to(device)
 
         outputs = model(mels)
-        # sigmoid multilabel predictions
-        # preds = torch.sigmoid(outputs) > 0.5
-
         loss = model.loss_fn(outputs, labels)
 
         loss.backward()
         optimizer.step()
 
-
         if scheduler is not None:
             scheduler.step()
 
         running_loss += loss.item()
         total += labels.size(0)
 
-        # index of highest predicted class
-        #pred_label = torch.argmax(outputs, dim=1)
-
-        # checking highest against true label
-
         correct += torch.all(torch.round(outputs).eq(labels), dim=-1).sum().item()
         log_loss += loss.item()
         log_n += 1
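The correct counter above computes exact-match (subset) accuracy for multi-label targets: a sample counts only if every rounded output equals its label. A tiny worked example of that line in isolation:

    import torch

    outputs = torch.tensor([[0.9, 0.1],   # rounds to [1, 0] == label: counts
                            [0.4, 0.8]])  # rounds to [0, 1] != [1, 1]: does not
    labels = torch.tensor([[1., 0.],
                           [1., 1.]])
    correct = torch.all(torch.round(outputs).eq(labels), dim=-1).sum().item()  # 1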
@@ -127,33 +115,18 @@
             wandb.log({
                 "train/loss": log_loss / log_n,
                 "train/accuracy": correct / total * 100.,
-                "custom_step": step,
             })
             print("Loss:", log_loss / log_n, "Accuracy:", correct / total * 100.)
             log_loss = 0
             log_n = 0
             correct = 0
             total = 0
 
-        # if step % CONFIG.valid_freq == 0 and step != 0:
-        #     del mels, labels, outputs # clear memory
-        #     valid_loss, valid_map = valid(model, val_dataloader, device, step, CONFIG)
-        #     print(f"Validation Loss:\t{valid_loss} \n Validation mAP:\t{valid_map}" )
-        #     if valid_map > best_valid_cmap:
-        #         print(f"Validation cmAP Improved - {best_valid_cmap} ---> {valid_map}")
-        #         best_valid_cmap = valid_map
-        #         torch.save(model.state_dict(), wandb_run.name + '.pt')
-        #         print(wandb_run.name + '.pt')
-        #     model.train()
-
-
         step += 1
+    return running_loss/len(data_loader)
-
-    return running_loss/len(data_loader), step, best_valid_cmap
 
 def valid(model: BirdCLEFModel,
           data_loader: PyhaDF_Dataset,
-          device: str,
           step: int,
           CONFIG) -> Tuple[float, float]:
     """
@@ -178,15 +151,14 @@ def valid(model: BirdCLEFModel,
 
         # argmax
         outputs = model(mels)
-        #_, preds = torch.max(outputs, 1)
 
         loss = model.loss_fn(outputs, labels)
 
         running_loss += loss.item()
 
         pred.append(outputs.cpu().detach())
         label.append(labels.cpu().detach())
-        # break
+
     pred = torch.cat(pred)
     label = torch.cat(label)
     if CONFIG.map_debug and CONFIG.model_checkpoint is not None:
@@ -195,25 +167,10 @@
 
     # softmax predictions
     pred = F.softmax(pred).to(device)
-    # pred = pred[:, unq_classes]
 
-    # # pad predictions and labels with `pad_n` true positives
-
     metric = MultilabelAveragePrecision(num_labels=CONFIG.num_classes, average="macro")
     valid_map = metric(pred.detach().cpu(), label.detach().cpu().long())
 
-    #valid_map = metric(padded_preds, padded_labels)
-    # calculate average precision
-    # valid_map = average_precision_score(
-    #     label.cpu().long(),
-    #     pred.detach().cpu(),
-    #     average='macro',
-    # )
-    # _, padded_preds = torch.max(padded_preds, 1)
-
-    # acc = (padded_preds == padded_labels).sum().item() / len(padded_preds)
-    # print("Validation Accuracy:", acc)
-
 
     print("Validation mAP:", valid_map)
 
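For reference, MultilabelAveragePrecision here is the torchmetrics class; a standalone sketch of the same macro-mAP computation with made-up shapes (8 clips, 5 labels). Incidentally, F.softmax(pred) without an explicit dim relies on deprecated implicit-dimension behavior and warns on recent PyTorch; dim=1 is the usual intent here.

    import torch
    from torchmetrics.classification import MultilabelAveragePrecision

    metric = MultilabelAveragePrecision(num_labels=5, average="macro")
    preds = torch.rand(8, 5)                   # per-label scores
    target = (torch.rand(8, 5) > 0.5).long()   # binary ground-truth matrix
    valid_map = metric(preds, target)          # macro-averaged AP over labels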
@@ -264,11 +221,7 @@ def init_wandb(CONFIG: Dict[str, Any]):
         f"-{CONFIG.n_fft}-{CONFIG.seed}-" +
         run.name.split('-')[-1]
     )
-    run.name = f"EFN-{CONFIG.epochs}-{CONFIG.train_batch_size}-"
-    run.name += f"{CONFIG.valid_batch_size}-{CONFIG.sample_rate}-"
-    run.name += f"{CONFIG.hop_length}-{CONFIG.max_time}-"
-    run.name += f"{CONFIG.n_mels}-{CONFIG.n_fft}-{CONFIG.seed}-"
-    run.name += run.name.split('-')[-1]
+
     return run
 
 def load_datasets(CONFIG: Dict[str, Any]) \
@@ -293,25 +246,19 @@
     return train_dataset, val_dataset, train_dataloader, val_dataloader
 
 def main():
-    """
-    Run training
+    """ Main function
     """
     torch.multiprocessing.set_start_method('spawn')
    CONFIG = get_config()
     print(CONFIG)
-    CONFIG.logging = CONFIG.logging == 'True'
-    CONFIG.verbose = CONFIG.verbose == 'True'
 
-    # Yes this could be better, out of scope of MVP
-    # pylint: disable=W0603
+    # Needed to redefine wandb_run as a global variable
+    # pylint: disable=global-statement
     global wandb_run
     wandb_run = init_wandb(CONFIG)
     set_seed(CONFIG.seed)
-
-    # Load in dataset
     print("Loading Dataset")
-    # we might need it in future
-    # pylint: disable-next=W0612
+    # pylint: disable=unused-variable
     train_dataset, val_dataset, train_dataloader, val_dataloader = load_datasets(CONFIG)
 
     print("Loading Model...")
@@ -332,20 +279,16 @@ def main():
     for epoch in range(CONFIG.epochs):
         print("Epoch " + str(epoch))
 
-        train_loss, step, best_valid_cmap = train(
+        _ = train(
             model_for_run,
             train_dataloader,
             optimizer,
             scheduler,
             device,
-            step,
-            best_valid_cmap,
             CONFIG
         )
-
-        print(f"Train Loss:\t{train_loss} ")
+        step += 1
 
-        valid_loss, valid_map = valid(model_for_run, val_dataloader, device, step, CONFIG)
+        valid_loss, valid_map = valid(model_for_run, val_dataloader, step, CONFIG)
         print(f"Validation Loss:\t{valid_loss} \n Validation mAP:\t{valid_map}" )
 
         if valid_map > best_valid_cmap:
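On the torch.multiprocessing.set_start_method('spawn') call in main(): the spawn start method is what keeps CUDA usable inside DataLoader worker processes, but it raises RuntimeError if a start method was already set. A common guard — a sketch, not this project's code:

    import torch.multiprocessing as mp

    if __name__ == "__main__":
        try:
            mp.set_start_method("spawn")
        except RuntimeError:
            pass  # start method already set; keep going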
1 change: 0 additions & 1 deletion classification/utils.py
@@ -5,7 +5,6 @@
 """
 
 from typing import Dict, Any
-import numpy as np
 import torch
 
