Skip to content

Commit

Permalink
Merge branch 'seralize_waveforms' of github.com:UCSD-E4E/acoustic-mul…
Browse files Browse the repository at this point in the history
…ticlass-training into seralize_waveforms
  • Loading branch information
benjamin-cates committed Jul 5, 2023
2 parents 43df354 + d1c4b8c commit ef19d35
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def train(model: BirdCLEFModel,
outputs = model(mels)
# sigmoid multilabel predictions
# preds = torch.sigmoid(outputs) > 0.5


loss = model.loss_fn(outputs, labels)

Expand Down Expand Up @@ -217,6 +216,25 @@ def valid(model: BirdCLEFModel,
return running_loss/len(data_loader), valid_map


def test_loop(model: BirdCLEFModel,
data_loaders: PyhaDF_Dataset,
device: str,
step: int,
CONFIG):

model.eval()
for dl in data_loaders:
(mels, labels) = next(iter(dl))

print(mels.shape)
out = model(mels)

if (out.shape != labels.shape):
print(out.shape)
print(labels.shape)
raise RuntimeError("Shape diff between output of models and labels, see above and debug")


def init_wandb(CONFIG: Dict[str, Any]):
"""
Initialize the weights and biases logging
Expand Down Expand Up @@ -285,6 +303,8 @@ def main():
step = 0
best_valid_cmap = 0

test_loop(model_for_run, [train_dataloader, val_dataloader], device, step, CONFIG)

for epoch in range(CONFIG.epochs):
print("Epoch " + str(epoch))

Expand Down

0 comments on commit ef19d35

Please sign in to comment.