Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 8, 2023
1 parent 62c38ce commit 9a68343
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions emo_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def __init__(self, config):
self.init_weights()

def forward(
self,
input_values,
self,
input_values,
):
outputs = self.wav2vec2(input_values)
hidden_states = outputs[0]
Expand All @@ -55,24 +55,24 @@ def forward(


# load model from hub
device = 'cuda'
model_name = './emotion/wav2vec2-large-robust-12-ft-emotion-msp-dim'
device = "cuda"
model_name = "./emotion/wav2vec2-large-robust-12-ft-emotion-msp-dim"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = EmotionModel.from_pretrained(model_name).to(device)


def process_func(
x: np.ndarray,
sampling_rate: int,
embeddings: bool = False,
x: np.ndarray,
sampling_rate: int,
embeddings: bool = False,
) -> np.ndarray:
r"""Predict emotions or extract embeddings from raw audio signal."""

# run through processor to normalize signal
# always returns a batch, so we just get the first entry
# then we put it on the device
y = processor(x, sampling_rate=sampling_rate)
y = y['input_values'][0]
y = y["input_values"][0]
y = torch.from_numpy(y).to(device)

# run through model
Expand All @@ -83,18 +83,22 @@ def process_func(
y = y.detach().cpu().numpy()

return y



rootpath = "/home/ubuntu/CVAEJETS/dataset/nene"
embs = []
wavnames = []


def extract_dir(path):
rootpath = path
for idx, wavname in enumerate(os.listdir(rootpath)):
wav, sr =librosa.load(f"{rootpath}/{wavname}", 16000)
wav, sr = librosa.load(f"{rootpath}/{wavname}", 16000)
emb = process_func(np.expand_dims(wav, 0), sr, embeddings=True)
embs.append(emb)
wavnames.append(wavname)
np.save(f"{rootpath}/{wavname}.emo.npy", emb.squeeze(0))
print(idx, wavname)


extract_dir(rootpath)

0 comments on commit 9a68343

Please sign in to comment.