From cdbd087b4cfec14f151c6134724f463a757babb1 Mon Sep 17 00:00:00 2001 From: knoriy Date: Fri, 13 Oct 2023 13:48:19 +0000 Subject: [PATCH] renamed to clara --- clara/clara.py | 4 ++-- clara/eval/test_linear_probe.py | 6 +++--- clara/eval/test_retrieval.py | 4 ++-- clara/eval/test_zeroshot.py | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/clara/clara.py b/clara/clara.py index c5ae59a..58739a3 100644 --- a/clara/clara.py +++ b/clara/clara.py @@ -10,7 +10,7 @@ from encoders.text_encoders import SimpleTransformer from encoders.audio_encoders import * from encoders.modules import PositionalEncoding, LayerNorm, MLPLayers -from loss import CLAPLoss, CLIPLoss +from loss import CLARALoss, CLIPLoss from scheduler import CosineAnnealingWarmupRestarts from utils.accuracy import Accuracy, accuracy @@ -202,7 +202,7 @@ def __init__( self, self.save_hyperparameters() self.model = CLARA(self.hparams) - self.loss_fn = CLAPLoss(cache_labels=True) + self.loss_fn = CLARALoss(cache_labels=True) self.acc_fn = Accuracy(cache_labels=True) def forward(self, texts:Optional[torch.Tensor], mels:Optional[torch.Tensor]): diff --git a/clara/eval/test_linear_probe.py b/clara/eval/test_linear_probe.py index cdd4e3f..52355ec 100644 --- a/clara/eval/test_linear_probe.py +++ b/clara/eval/test_linear_probe.py @@ -7,7 +7,7 @@ from torchmetrics import MetricCollection, Recall, Accuracy -from clasp import LinearProbeCLASP +from clara import LinearProbeCLARA from utils import calculate_average from eval.util import get_dataset @@ -50,7 +50,7 @@ def main(args): ############## # Model ############## - model = LinearProbeCLASP.load_from_checkpoint(args.model_path, clasp_map_location=args.device, clasp_checkpoint_path=args.clasp_path, map_location=args.device) + model = LinearProbeCLARA.load_from_checkpoint(args.model_path, clara_map_location=args.device, clara_checkpoint_path=args.clara_path, map_location=args.device) ############## # DataModule @@ -95,7 +95,7 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument('--model_path', type=str, help='Path to model with linear probe head') - parser.add_argument('--clasp_path', type=str, help='Path to pretrained CLASP model') + parser.add_argument('--clara_path', type=str, help='Path to pretrained CLARA model') parser.add_argument('--root_cfg_path', type=str, default='./config/', help='root path to config files') parser.add_argument('--task', type=str, choices=['texts', 'gender', 'emotion', 'age', 'sounds', 'speech'], help='Task to run') parser.add_argument('--dataset_name', type=str, required=True, help='if task is sounds or emotion, specify dataset name') diff --git a/clara/eval/test_retrieval.py b/clara/eval/test_retrieval.py index 7ce7d10..9b9f041 100644 --- a/clara/eval/test_retrieval.py +++ b/clara/eval/test_retrieval.py @@ -9,7 +9,7 @@ from torchmetrics import MetricCollection -from clasp import PLCLASP +from clara import PLCLARA from text.tokeniser import Tokeniser from eval.util import get_dataset @@ -115,7 +115,7 @@ def run(model, zeroshot_weights, dataloader, metric_fn:MetricCollection, limit_b args = parser.parse_args() - model = PLCLASP.load_from_checkpoint(args.model_path, map_location=device) + model = PLCLARA.load_from_checkpoint(args.model_path, map_location=device) ############## # DataModule diff --git a/clara/eval/test_zeroshot.py b/clara/eval/test_zeroshot.py index 588a92c..8e96d49 100644 --- a/clara/eval/test_zeroshot.py +++ b/clara/eval/test_zeroshot.py @@ -9,7 +9,7 @@ from torchmetrics import MetricCollection, Recall, Accuracy, Precision, AveragePrecision -from clasp import PLCLASP +from clara import PLCLARA from text.tokeniser import Tokeniser from utils import calculate_average from eval.util import get_dataset @@ -97,7 +97,7 @@ def main(args): ############## # Model ############## - model = PLCLASP.load_from_checkpoint(args.model_path).to(device) + model = PLCLARA.load_from_checkpoint(args.model_path).to(device) ############## # DataModule