Skip to content

Commit

Permalink
feat: client side few-shot classification
Browse files Browse the repository at this point in the history
- train LR head and classify at the sdk level with co.embed()
  • Loading branch information
hemant-co committed Nov 23, 2024
1 parent 756515a commit cfa1fe6
Showing 1 changed file with 45 additions and 138 deletions.
183 changes: 45 additions & 138 deletions src/cohere/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@
from .models.client import AsyncModelsClient
from .finetuning.client import AsyncFinetuningClient

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.exceptions import ConvergenceWarning
import warnings

# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)

Expand Down Expand Up @@ -2245,144 +2250,46 @@ def classify(
inputs=["inputs"],
)
"""
_response = self._client_wrapper.httpx_client.request(
"v1/classify",
method="POST",
json={
"inputs": inputs,
"examples": convert_and_respect_annotation_metadata(
object_=examples, annotation=typing.Sequence[ClassifyExample], direction="write"
),
"model": model,
"preset": preset,
"truncate": truncate,
},
request_options=request_options,
omit=OMIT,
)
try:
if 200 <= _response.status_code < 300:
return typing.cast(
ClassifyResponse,
construct_type(
type_=ClassifyResponse, # type: ignore
object_=_response.json(),
),
)
if _response.status_code == 400:
raise BadRequestError(
typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
)
)
if _response.status_code == 401:
raise UnauthorizedError(
typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
)
)
if _response.status_code == 403:
raise ForbiddenError(
typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
)
)
if _response.status_code == 404:
raise NotFoundError(
typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
)
)
if _response.status_code == 422:
raise UnprocessableEntityError(
typing.cast(
UnprocessableEntityErrorBody,
construct_type(
type_=UnprocessableEntityErrorBody, # type: ignore
object_=_response.json(),
),
)
)
if _response.status_code == 429:
raise TooManyRequestsError(
typing.cast(
TooManyRequestsErrorBody,
construct_type(
type_=TooManyRequestsErrorBody, # type: ignore
object_=_response.json(),
),
)
)
if _response.status_code == 499:
raise ClientClosedRequestError(
typing.cast(
ClientClosedRequestErrorBody,
construct_type(
type_=ClientClosedRequestErrorBody, # type: ignore
object_=_response.json(),
),
)
)
if _response.status_code == 500:
raise InternalServerError(
typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
)
)
if _response.status_code == 501:
raise NotImplementedError(
typing.cast(
NotImplementedErrorBody,
construct_type(
type_=NotImplementedErrorBody, # type: ignore
object_=_response.json(),
),
)
)
if _response.status_code == 503:
raise ServiceUnavailableError(
typing.cast(
typing.Optional[typing.Any],
construct_type(
type_=typing.Optional[typing.Any], # type: ignore
object_=_response.json(),
),
)
)
if _response.status_code == 504:
raise GatewayTimeoutError(
typing.cast(
GatewayTimeoutErrorBody,
construct_type(
type_=GatewayTimeoutErrorBody, # type: ignore
object_=_response.json(),
),
)
)
_response_json = _response.json()
except JSONDecodeError:
raise ApiError(status_code=_response.status_code, body=_response.text)
raise ApiError(status_code=_response.status_code, body=_response_json)
# TODO @coderham: Cache the LR model to speedup the classification process and save on API calls
example_embeds = self.embed(model=model, texts=[example.text for example in examples], input_type="classification")
example_embeddings = np.zeros((len(examples), len(example_embeds.embeddings[0])))
for i, embed in enumerate(example_embeds.embeddings):
example_embeddings[i] = np.array(embed)
example_labels = np.array([example.label for example in examples])

convergence_warning = ""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always", ConvergenceWarning)
clf = LogisticRegression(random_state=0).fit(example_embeddings, example_labels)
if any(issubclass(w_.category, ConvergenceWarning) for w_ in w):
convergence_warning = str(w[-1].message)
if convergence_warning:
print(f"ConvergenceWarning: {convergence_warning}")

# Make predictions on the inputs
input_embeds = self.embed(model=model, texts=inputs, input_type="classification")
input_embeddings = np.zeros((len(inputs), len(input_embeds.embeddings[0])))
for i, embed in enumerate(input_embeds.embeddings):
input_embeddings[i] = np.array(embed)

class_probs = clf.predict_proba(input_embeddings)
class_labels = clf.classes_
classifications = []
for i, class_prob in enumerate(class_probs):
class_idx = np.argmax(class_prob)
predicted_label = class_labels[class_idx]
classifications.append({
"id": input_embeds.id,
"classification_type":"single-label",
"input": inputs[i],
"prediction": predicted_label,
"predictions": [predicted_label],
"confidence": class_prob[class_idx],
"confidences": [class_prob[class_idx]],
"labels": {class_labels[j]: {"confidence": class_prob[j]} for j in range(len(class_labels))},
})

return ClassifyResponse(id=input_embeds.id, classifications=classifications)

def summarize(
self,
Expand Down

0 comments on commit cfa1fe6

Please sign in to comment.