Skip to content

Commit

Permalink
fixing custom model and changing it to base BirdAVES model
Browse files Browse the repository at this point in the history
  • Loading branch information
Ludwigvsch committed Aug 31, 2024
1 parent 4401ffe commit 4f96a75
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 30 deletions.
53 changes: 53 additions & 0 deletions birdaves-biox-base.torchaudio.model_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
{
"extractor_mode": "group_norm",
"extractor_conv_layer_config": [
[
512,
10,
5
],
[
512,
3,
2
],
[
512,
3,
2
],
[
512,
3,
2
],
[
512,
3,
2
],
[
512,
2,
2
],
[
512,
2,
2
]
],
"extractor_conv_bias": false,
"encoder_embed_dim": 768,
"encoder_projection_dropout": 0.1,
"encoder_pos_conv_kernel": 128,
"encoder_pos_conv_groups": 16,
"encoder_num_layers": 12,
"encoder_num_heads": 12,
"encoder_attention_dropout": 0.1,
"encoder_ff_interm_features": 3072,
"encoder_ff_interm_dropout": 0.0,
"encoder_dropout": 0.1,
"encoder_layer_norm_first": false,
"encoder_layer_drop": 0.05
}
75 changes: 53 additions & 22 deletions pyha_analyzer/models/CustomModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,44 +14,75 @@


class CustomModel(nn.Module):
def __init__(self, config_path, model_path, num_classes, trainable, embedding_dim=768):
super(CustomModel, self).__init__()
""" Uses AVES Hubert to embed sounds and classify """
def __init__(self, cfg, num_classes, model_path, trainable, config_path, embedding_dim=768):
super().__init__()
# reference: https://pytorch.org/audio/stable/_modules/torchaudio/models/wav2vec2/utils/import_fairseq.html
self.cfg = cfg
self.trainable = trainable
self.config = self.load_config(config_path)
self.model = wav2vec2_model(**self.config, aux_num_out=None)
self.model.load_state_dict(torch.load(model_path))
self.trainable = trainable
# Freeze the AVES network
self.freeze_embedding_weights(self.model, trainable)
self.classifier_head = nn.Linear(in_features=embedding_dim, out_features=num_classes)
self.loss_fn = None
# Add a linear layer to match the embedding dimensions
self.embedding_transform = nn.Linear(768, num_classes) #TODO: change this when switching models
# We will only train the classifier head
#self.classifier_head = nn.Linear(in_features=embedding_dim, out_features=num_classes)
self.audio_sr = cfg.sample_rate

def load_config(self, config_path):
    """Read the model configuration from a JSON file.

    Args:
        config_path: Path to the ``*.model_config.json`` file.

    Returns:
        The decoded JSON document. ``__init__`` unpacks it as keyword
        arguments for ``wav2vec2_model``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # A single read path: the previous text carried two back-to-back
    # ``with`` bodies, the second of which could never run.
    # JSON is UTF-8 by specification, so do not depend on the locale default.
    with open(config_path, 'r', encoding='utf-8') as config_file:
        return json.load(config_file)

def forward(self, sig):
    """Embed a batch of waveforms and score every class.

    Input
        sig (Tensor): (batch, time)
    Returns
        mean_embedding (Tensor): (batch, output_dim)
        logits (Tensor): (batch, n_classes)
    """
    # The torchaudio port returns the output of every transformer layer as
    # the first element of the result; only the last layer is used here.
    layer_outputs = self.model.extract_features(sig)[0]
    final_layer = layer_outputs[-1]
    # Collapse the time axis so each clip becomes one fixed-size vector.
    mean_embedding = final_layer.mean(dim=1)
    # Linear projection of the pooled embedding onto the class scores.
    return mean_embedding, self.embedding_transform(mean_embedding)

def freeze_embedding_weights(self, model, trainable):
    """Freeze AVES weights so that only the chosen parts are trained.

    Args:
        model: Network exposing ``feature_extractor`` and ``encoder``.
        trainable: Whether the encoder parameters should receive gradients.
    """
    # Convolutional front end: always frozen and always in eval mode.
    conv_frontend = model.feature_extractor
    conv_frontend.requires_grad_(False)
    conv_frontend.eval()

    # Transformer encoder: gradients follow the ``trainable`` flag.
    transformer = model.encoder
    for weight in transformer.parameters():
        weight.requires_grad = trainable

    if trainable:
        return
    # When frozen, parameter-free layers such as dropout must stay fixed too.
    transformer.eval()

def set_eval_aves(model):
    """Put the AVES-based classifier into evaluation mode.

    Switches the classification head and the transformer encoder to eval
    mode. The encoder is switched unconditionally; this function does not
    check whether the encoder is being trained.

    Args:
        model: Classifier exposing ``model.encoder`` and a linear head stored
            as ``embedding_transform`` and/or ``classifier_head``.
    """
    # forward() scores with ``embedding_transform``; ``classifier_head`` is
    # the earlier attribute name. Handle whichever is present so that this
    # does not fail with AttributeError on the current head layout.
    for head_name in ("embedding_transform", "classifier_head"):
        head = getattr(model, head_name, None)
        if head is not None:
            head.eval()
    model.model.encoder.eval()



def forward(self, sig):
out = self.model.extract_features(sig)[0][-1]
mean_embedding = out.mean(dim=1)
logits = self.classifier_head(mean_embedding)
return mean_embedding, logits

def create_loss_fn(self, cfg, train_dataset):
loss_desc = cfg.loss_fnc
def create_loss_fn(self, train_dataset):
loss_desc = self.cfg.loss_fnc
if loss_desc == "CE":
self.loss_fn = nn.CrossEntropyLoss()
elif loss_desc == "BCE":
self.loss_fn = nn.BCEWithLogitsLoss()
else:
raise RuntimeError(f"Unsupported loss function: {loss_desc}")
return cross_entropy_loss_fn(self, train_dataset)
if loss_desc == "BCE":
return bce_loss_fn(self, without_logits=True)
if loss_desc == "BCEWL":
return bce_loss_fn(self, without_logits=False)
if loss_desc == "FL":
return focal_loss_fn(self, self.without_logits)
raise RuntimeError("Unsupported loss function")

def download_model_files():
import os
os.system("wget https://storage.googleapis.com/esp-public-files/ported_aves/aves-base-bio.torchaudio.pt")
os.system("wget https://storage.googleapis.com/esp-public-files/ported_aves/aves-base-bio.torchaudio.model_config.json")
os.system("wget https://storage.googleapis.com/esp-public-files/ported_aves/birdaves-biox-base.torchaudio.pt")
os.system("wget https://storage.googleapis.com/esp-public-files/ported_aves/birdaves-biox-base.torchaudio.model_config.json")
25 changes: 25 additions & 0 deletions pyha_analyzer/models/inference_M_dataset.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 1 addition & 1 deletion pyha_analyzer/models/loss_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def bce_loss_fn(self, without_logits=False):
BCEwithLogitsLoss
"""
if not without_logits:
self.loss_fn = nn.BCEWithLogitsLoss(reduction='sum')
self.loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
else:
self.loss_fn = nn.BCELoss(reduction='mean')
return self.loss_fn
Expand Down
17 changes: 10 additions & 7 deletions pyha_analyzer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,10 @@ def download_model_files():
import urllib.request

urls = [
#"https://storage.googleapis.com/esp-public-files/birdaves/birdaves-biox-large.torchaudio.pt",
#"https://storage.googleapis.com/esp-public-files/birdaves/birdaves-biox-large.torchaudio.model_config.json"
"https://storage.googleapis.com/esp-public-files/ported_aves/aves-base-bio.torchaudio.pt",
"https://storage.googleapis.com/esp-public-files/ported_aves/aves-base-bio.torchaudio.model_config.json"
"https://storage.googleapis.com/esp-public-files/birdaves/birdaves-biox-base.torchaudio.pt",
"https://storage.googleapis.com/esp-public-files/birdaves/birdaves-biox-base.torchaudio.model_config.json"
#"https://storage.googleapis.com/esp-public-files/ported_aves/aves-base-bio.torchaudio.pt",
#"https://storage.googleapis.com/esp-public-files/ported_aves/aves-base-bio.torchaudio.model_config.json"
]
for url in urls:
filename = url.split("/")[-1]
Expand Down Expand Up @@ -368,12 +368,15 @@ def main(in_sweep=True) -> None:
logger.info("Loading Model...")
download_model_files()
model_for_run = CustomModel(
config_path= "aves-base-bio.torchaudio.model_config.json", #aves-base-bio.torchaudio.model_config.json",
model_path= "aves-base-bio.torchaudio.pt", #"aves-base-bio.torchaudio.pt",
config_path="birdaves-biox-base.torchaudio.model_config.json",
model_path="birdaves-biox-base.torchaudio.pt",
cfg=cfg,
#config_path= "aves-base-bio.torchaudio.model_config.json", #aves-base-bio.torchaudio.model_config.json",
#model_path= "aves-base-bio.torchaudio.pt", #"aves-base-bio.torchaudio.pt",
num_classes=train_dataset.num_classes,
trainable=cfg.trainable,
).to(cfg.device)
model_for_run.create_loss_fn(cfg, train_dataset)
model_for_run.create_loss_fn(train_dataset)

optimizer = Adam(model_for_run.parameters(), lr=cfg.learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, eta_min=1e-5, T_max=10)
Expand Down

0 comments on commit 4f96a75

Please sign in to comment.