Skip to content

Commit

Permalink
fix: v2 API needs special handling
Browse files Browse the repository at this point in the history
- In the V2 client API, the embeddings are not returned as a list but as an EmbedByTypeResponseEmbeddings object
- The V2 client API also requires the embedding_types argument for embed()
  • Loading branch information
hemant-co committed Nov 23, 2024
1 parent cfa1fe6 commit d5d6665
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions src/cohere/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2251,9 +2251,17 @@ def classify(
)
"""
# TODO @coderham: Cache the LR model to speedup the classification process and save on API calls
example_embeds = self.embed(model=model, texts=[example.text for example in examples], input_type="classification")
example_embeddings = np.zeros((len(examples), len(example_embeds.embeddings[0])))
for i, embed in enumerate(example_embeds.embeddings):
example_embeds = self.embed(model=model, embedding_types=["float"], input_type="classification",
texts=[example.text for example in examples])

# in V2 client API the embeddings are not returned as a list but as EmbedByTypeResponseEmbeddings
if type(example_embeds.embeddings) is not list:
example_embedding_float = example_embeds.embeddings.float_
else:
example_embedding_float = example_embeds.embeddings

example_embeddings = np.zeros((len(examples), len(example_embedding_float[0])))
for i, embed in enumerate(example_embedding_float):
example_embeddings[i] = np.array(embed)
example_labels = np.array([example.label for example in examples])

Expand All @@ -2267,9 +2275,17 @@ def classify(
print(f"ConvergenceWarning: {convergence_warning}")

# Make predictions on the inputs
input_embeds = self.embed(model=model, texts=inputs, input_type="classification")
input_embeddings = np.zeros((len(inputs), len(input_embeds.embeddings[0])))
for i, embed in enumerate(input_embeds.embeddings):
input_embeds = self.embed(model=model, embedding_types=["float"], input_type="classification",
texts=inputs)

# in V2 client API the embeddings are not returned as a list but as EmbedByTypeResponseEmbeddings
if type(input_embeds.embeddings) is not list:
input_embedding_float = input_embeds.embeddings.float_
else:
input_embedding_float = input_embeds.embeddings

input_embeddings = np.zeros((len(inputs), len(input_embedding_float[0])))
for i, embed in enumerate(input_embedding_float):
input_embeddings[i] = np.array(embed)

class_probs = clf.predict_proba(input_embeddings)
Expand Down

0 comments on commit d5d6665

Please sign in to comment.