Skip to content

Commit

Permalink
Merge branch 'seralize_waveforms' of github.com:UCSD-E4E/acoustic-mul…
Browse files Browse the repository at this point in the history
…ticlass-training into seralize_waveforms
  • Loading branch information
sprestrelski committed Jul 5, 2023
2 parents f576ab1 + ef19d35 commit 2da5ba9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
4 changes: 2 additions & 2 deletions classification/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,14 @@ def serialize_data(self):
right_on="IN FILE").dropna()

print("fixed size:", self.samples.shape)

self.samples["original_file_path"] = self.samples[self.config.file_path_col]

if "files" in self.samples.columns:
self.samples[self.config.file_path_col] = self.samples["files"].copy()
if "files_y" in self.samples.columns:
self.samples[self.config.file_path_col] = self.samples["files_y"].copy()

self.samples["original_file_path"] = self.samples[self.config.file_path_col]

self.formatted_csv_file = ".".join(self.csv_file.split(".")[:-1]) + "formatted.csv"
self.samples.to_csv(self.formatted_csv_file)

Expand Down
22 changes: 21 additions & 1 deletion classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def train(model: BirdCLEFModel,
outputs = model(mels)
# sigmoid multilabel predictions
# preds = torch.sigmoid(outputs) > 0.5


loss = model.loss_fn(outputs, labels)

Expand Down Expand Up @@ -216,6 +215,25 @@ def valid(model: BirdCLEFModel,
return running_loss/len(data_loader), valid_map


def test_loop(model: BirdCLEFModel,
data_loaders: PyhaDF_Dataset,
device: str,
step: int,
CONFIG):

model.eval()
for dl in data_loaders:
(mels, labels) = next(iter(dl))

print(mels.shape)
out = model(mels)

if (out.shape != labels.shape):
print(out.shape)
print(labels.shape)
raise RuntimeError("Shape diff between output of models and labels, see above and debug")


def init_wandb(CONFIG: Dict[str, Any]):
"""
Initialize the weights and biases logging
Expand Down Expand Up @@ -282,6 +300,8 @@ def main():
step = 0
best_valid_cmap = 0

test_loop(model_for_run, [train_dataloader, val_dataloader], device, step, CONFIG)

for epoch in range(CONFIG.epochs):
print("Epoch " + str(epoch))

Expand Down

0 comments on commit 2da5ba9

Please sign in to comment.