
Commit 6b069d2: "update"

1 parent: 6046265

File tree: 4 files changed, +27 −11 lines


DATASETS.md (+1, −1)

```diff
@@ -63,7 +63,7 @@ Each CSV file includes the following columns:
 <hr><hr>
 <br><br>
 
-We have uploaded all datasets on [Huggingface Datasets](https://huggingface.co/docs/datasets/en/index). Following are the python commands to download datasets. Make sure to provide valid destination dataset path ending with 'Audio-Datasets' folder and install `huggingface_hub` package. We have also provided a [Jupyter Notebook](/utils/DownloadAudioDatasets.ipynb) to download all datasets in one go. It might take some time to download all datasets, so we recommend running the notebook on a cloud instance or a machine with good internet speed.<br><br>
+We have uploaded all datasets on [Huggingface Datasets](https://huggingface.co/docs/datasets/en/index). Following are the python commands to download datasets. Make sure to provide valid destination dataset path ending with 'Audio-Datasets' folder and install `huggingface_hub` package. We have also provided a [Jupyter Notebook](/media/DownloadAudioDatasets.ipynb) to download all datasets in one go. It might take some time to download all datasets, so we recommend running the notebook on a cloud instance or a machine with good internet speed.<br><br>
 
 `pip install huggingface-hub==0.25.1`
 
 <br>
```
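The DATASETS.md text above requires the destination path to end with an 'Audio-Datasets' folder before running the download commands. A minimal sketch of that check, with a hypothetical helper name (not part of the repository); the commented download call uses `huggingface_hub.snapshot_download` with an illustrative repo id:

```python
from pathlib import Path

def is_valid_dataset_root(path_str: str) -> bool:
    """Check the convention from DATASETS.md: the destination
    path must end with an 'Audio-Datasets' folder.
    (Hypothetical helper, for illustration only.)"""
    return Path(path_str).name == "Audio-Datasets"

# A guarded download would then look roughly like this
# (repo id is illustrative; see DATASETS.md for the real ones):
#
# from huggingface_hub import snapshot_download
# if is_valid_dataset_root(dest):
#     snapshot_download(repo_id="<org>/<dataset>", repo_type="dataset",
#                       local_dir=dest)
```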

README.md (+1, −1)

```diff
@@ -94,7 +94,7 @@ wget https://zenodo.org/records/8387083/files/base.pth
 
 ## Datasets :page_with_curl:
 
-We have performed experiments on 11 audio classification datasets. Instructions for downloading/processing datasets used by our method have been provided in the [DATASETS.md](DATASETS.md). All of the datasets have been uploaded on HuggingFace Datasets Hub :hugs: for easy access. We have also provided a [Jupyter Notebook](/utils/DownloadAudioDatasets.ipynb) to download all datasets in one go. It might take some time to download all datasets, so we recommend running the notebook on a cloud instance or a machine with good internet speed.
+We have performed experiments on 11 audio classification datasets. Instructions for downloading/processing datasets used by our method have been provided in the [DATASETS.md](DATASETS.md). All of the datasets have been uploaded on HuggingFace Datasets Hub :hugs: for easy access. We have also provided a [Jupyter Notebook](/media/DownloadAudioDatasets.ipynb) to download all datasets in one go. It might take some time to download all datasets, so we recommend running the notebook on a cloud instance or a machine with good internet speed.
 
 | Dataset | Type | Classes | Size | Link |
 |:-- |:-- |:--: |--: |:-- |
```

DownloadAudioDatasets.ipynb: file renamed without changes (utils/ → media/).

utils/utils.py (+25, −9)

```diff
@@ -41,14 +41,28 @@ def get_dataloaders(args):
 
 
 def save_model(args, model, save_model_path):
-    print(f"Saving Context Weights for Method: '{args.model_name.upper()}'\n")
-    if args.model_name in ['coop', 'cocoop', 'palm']:
-        torch.save(model.prompt_learner.state_dict(), save_model_path)
-    else:
-        raise ValueError(f"Model '{args.model_name}' is not supported. Choose from: [{', '.join(METHODS)}]")
-
+    print(f"Saving Context Weights for Method: '{args.model_name.upper()}'\n")
+    if args.model_name in ['coop', 'cocoop', 'palm']:
+        checkpoint = {'prompt_learner': model.prompt_learner.state_dict()}
+        checkpoint['pengi_bn0_buffer'] = {'running_mean': model.audio_encoder.base.htsat.bn0.running_mean.clone(),
+                                          'running_var': model.audio_encoder.base.htsat.bn0.running_var.clone(),
+                                          'num_batches_tracked': model.audio_encoder.base.htsat.bn0.num_batches_tracked.clone()}
+
+        torch.save(checkpoint, save_model_path)
+    else:
+        raise ValueError(f"Model '{args.model_name}' is not supported. Choose from: [{', '.join(METHODS)}]")
+
+
+
 def load_model(args, model):
-    raise NotImplementedError("\n\nLoading model is not implemented yet.\n\n")
+    load_model_path = get_load_model_path(args)
+    checkpoint = torch.load(load_model_path)
+    model.prompt_learner.load_state_dict(checkpoint['prompt_learner'])
+    model.audio_encoder.base.htsat.bn0.running_mean.copy_(checkpoint['pengi_bn0_buffer']['running_mean'])
+    model.audio_encoder.base.htsat.bn0.running_var.copy_(checkpoint['pengi_bn0_buffer']['running_var'])
+    model.audio_encoder.base.htsat.bn0.num_batches_tracked.copy_(checkpoint['pengi_bn0_buffer']['num_batches_tracked'])
+    # raise NotImplementedError("\n\nLoading model is not implemented yet.\n\n")
+
 
 
 def get_save_model_path(args):
@@ -105,8 +119,10 @@ def get_args():
         raise ValueError(f"\n\nDirectory '{args.dataset_root}' does not exist. Specify the correct path to the dataset.\n\n")
     if args.save_model and not os.path.exists(args.save_model_path):
         raise ValueError(f"\n\nDirectory '{args.save_model_path}' does not exist. Create or specify the correct the directory to save the trained model.\n\n")
-    if args.eval_only and not os.path.exists(args.load_model_path):
-        raise ValueError(f"\n\nEvaluation Mode: Model file '{args.load_model_path}' does not exist. Specify the correct path to the model file.\n\n")
+    if args.eval_only:
+        load_model_path = get_load_model_path(args)
+        if not os.path.exists(load_model_path): raise ValueError(f"\n\nEvaluation Mode: Model file '{load_model_path}' does not exist. Specify the correct path to the model file.\n\n")
+
 
     if args.model_name == 'zeroshot': args.eval_only = True
 
     return args
```
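The utils.py change above saves the running statistics of the HTSAT `bn0` BatchNorm layer alongside the prompt-learner weights, then restores them with `copy_` on load, so evaluation sees the same normalization statistics as training. A minimal pure-Python sketch of that save/restore round-trip, with lists standing in for tensors and hypothetical helper names (the real code uses `torch.save`/`torch.load` and tensor `.clone()`/`.copy_()`):

```python
def save_bn_buffers(bn_stats):
    """Snapshot BatchNorm buffers (running_mean, running_var,
    num_batches_tracked) into a checkpoint-style dict.
    list(...) plays the role of tensor.clone() here."""
    return {
        'running_mean': list(bn_stats['running_mean']),
        'running_var': list(bn_stats['running_var']),
        'num_batches_tracked': bn_stats['num_batches_tracked'],
    }

def load_bn_buffers(bn_stats, saved):
    """Restore the snapshot into an existing stats dict in place,
    mirroring tensor.copy_() in the real load_model."""
    bn_stats['running_mean'][:] = saved['running_mean']
    bn_stats['running_var'][:] = saved['running_var']
    bn_stats['num_batches_tracked'] = saved['num_batches_tracked']
    return bn_stats
```

The snapshot copies the buffers rather than referencing them, so statistics updated by later forward passes do not silently overwrite what was checkpointed.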

0 commit comments