Skip to content

Commit

Permalink
renamed to clara
Browse files Browse the repository at this point in the history
  • Loading branch information
knoriy committed Oct 13, 2023
1 parent 2c015ed commit cdbd087
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
4 changes: 2 additions & 2 deletions clara/clara.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from encoders.text_encoders import SimpleTransformer
from encoders.audio_encoders import *
from encoders.modules import PositionalEncoding, LayerNorm, MLPLayers
from loss import CLAPLoss, CLIPLoss
from loss import CLARALoss, CLIPLoss

from scheduler import CosineAnnealingWarmupRestarts
from utils.accuracy import Accuracy, accuracy
Expand Down Expand Up @@ -202,7 +202,7 @@ def __init__( self,
self.save_hyperparameters()

self.model = CLARA(self.hparams)
self.loss_fn = CLAPLoss(cache_labels=True)
self.loss_fn = CLARALoss(cache_labels=True)
self.acc_fn = Accuracy(cache_labels=True)

def forward(self, texts:Optional[torch.Tensor], mels:Optional[torch.Tensor]):
Expand Down
6 changes: 3 additions & 3 deletions clara/eval/test_linear_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from torchmetrics import MetricCollection, Recall, Accuracy

from clasp import LinearProbeCLASP
from clara import LinearProbeCLARA
from utils import calculate_average
from eval.util import get_dataset

Expand Down Expand Up @@ -50,7 +50,7 @@ def main(args):
##############
# Model
##############
model = LinearProbeCLASP.load_from_checkpoint(args.model_path, clasp_map_location=args.device, clasp_checkpoint_path=args.clasp_path, map_location=args.device)
model = LinearProbeCLARA.load_from_checkpoint(args.model_path, clara_map_location=args.device, clara_checkpoint_path=args.clara_path, map_location=args.device)

##############
# DataModule
Expand Down Expand Up @@ -95,7 +95,7 @@ def main(args):

parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='Path to model with linear probe head')
parser.add_argument('--clasp_path', type=str, help='Path to pretrained CLASP model')
parser.add_argument('--clara_path', type=str, help='Path to pretrained CLARA model')
parser.add_argument('--root_cfg_path', type=str, default='./config/', help='root path to config files')
parser.add_argument('--task', type=str, choices=['texts', 'gender', 'emotion', 'age', 'sounds', 'speech'], help='Task to run')
parser.add_argument('--dataset_name', type=str, required=True, help='if task is sounds or emotion, specify dataset name')
Expand Down
4 changes: 2 additions & 2 deletions clara/eval/test_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from torchmetrics import MetricCollection

from clasp import PLCLASP
from clara import PLCLARA
from text.tokeniser import Tokeniser
from eval.util import get_dataset

Expand Down Expand Up @@ -115,7 +115,7 @@ def run(model, zeroshot_weights, dataloader, metric_fn:MetricCollection, limit_b

args = parser.parse_args()

model = PLCLASP.load_from_checkpoint(args.model_path, map_location=device)
model = PLCLARA.load_from_checkpoint(args.model_path, map_location=device)

##############
# DataModule
Expand Down
4 changes: 2 additions & 2 deletions clara/eval/test_zeroshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from torchmetrics import MetricCollection, Recall, Accuracy, Precision, AveragePrecision

from clasp import PLCLASP
from clara import PLCLARA
from text.tokeniser import Tokeniser
from utils import calculate_average
from eval.util import get_dataset
Expand Down Expand Up @@ -97,7 +97,7 @@ def main(args):
##############
# Model
##############
model = PLCLASP.load_from_checkpoint(args.model_path).to(device)
model = PLCLARA.load_from_checkpoint(args.model_path).to(device)

##############
# DataModule
Expand Down

0 comments on commit cdbd087

Please sign in to comment.