Commit d0830b9

feat(mistral): add mistral support (#43)

Authored by jhazenaai, ploeber, Swimburger, RobMcH, and Robert McHardy

Co-authored-by: Patrick Loeber <[email protected]>
Co-authored-by: Niels Swimberghe <[email protected]>
Co-authored-by: Robert McHardy <[email protected]>
Co-authored-by: Robert McHardy <[email protected]>
Co-authored-by: Martin Schweiger <[email protected]>

1 parent e4d3379 commit d0830b9

File tree

6 files changed: +63 -12 lines changed

assemblyai/lemur.py

Lines changed: 9 additions & 6 deletions
@@ -1,8 +1,7 @@
 from typing import Any, Dict, List, Optional, Union

-from . import api
+from . import api, types
 from . import client as _client
-from . import types


 class _LemurImpl:
@@ -173,10 +172,11 @@ def question(
         Args:
             questions: One or a list of questions to ask.
             context: The context which is shared among all questions. This can be a string or a dictionary.
-            final_model: The model that is used for the final prompt after compression is performed (options: "basic" and "default").
+            final_model: The model that is used for the final prompt after compression is performed (options: "basic", "default", and "assemblyai/mistral-7b").
             max_output_size: Max output size in tokens
             timeout: The timeout in seconds to wait for the answer(s).
             temperature: Change how deterministic the response is, with 0 being the most deterministic and 1 being the least deterministic.
+            input_text: Custom formatted transcript data. Use instead of transcript_ids.

         Returns: One or a list of answer objects.
         """
@@ -214,10 +214,11 @@ def summarize(
         Args:
             context: An optional context on the transcript.
             answer_format: The format on how the summary shall be summarized.
-            final_model: The model that is used for the final prompt after compression is performed (options: "basic" and "default").
+            final_model: The model that is used for the final prompt after compression is performed (options: "basic", "default", and "assemblyai/mistral-7b").
             max_output_size: Max output size in tokens
             timeout: The timeout in seconds to wait for the summary.
             temperature: Change how deterministic the response is, with 0 being the most deterministic and 1 being the least deterministic.
+            input_text: Custom formatted transcript data. Use instead of transcript_ids.

         Returns: The summary as a string.
         """
@@ -253,10 +254,11 @@ def action_items(
         Args:
             context: An optional context on the transcript.
             answer_format: The preferred format for the result action items.
-            final_model: The model that is used for the final prompt after compression is performed (options: "basic" and "default").
+            final_model: The model that is used for the final prompt after compression is performed (options: "basic", "default", and "assemblyai/mistral-7b").
             max_output_size: Max output size in tokens
             timeout: The timeout in seconds to wait for the action items response.
             temperature: Change how deterministic the response is, with 0 being the most deterministic and 1 being the least deterministic.
+            input_text: Custom formatted transcript data. Use instead of transcript_ids.

         Returns: The action items as a string.
         """
@@ -287,10 +289,11 @@ def task(

         Args:
             prompt: The prompt to use for this task.
-            final_model: The model that is used for the final prompt after compression is performed (options: "basic" and "default").
+            final_model: The model that is used for the final prompt after compression is performed (options: "basic", "default", and "assemblyai/mistral-7b").
             max_output_size: Max output size in tokens
             timeout: The timeout in seconds to wait for the task.
             temperature: Change how deterministic the response is, with 0 being the most deterministic and 1 being the least deterministic.
+            input_text: Custom formatted transcript data. Use instead of transcript_ids.

         Returns: A response to a question or task submitted via custom prompt (with source transcripts or other sources taken into the context)
         """

assemblyai/transcriber.py

Lines changed: 1 addition & 2 deletions
@@ -27,9 +27,8 @@
 from typing_extensions import Self
 from websockets.sync.client import connect as websocket_connect

-from . import api
+from . import api, lemur, types
 from . import client as _client
-from . import lemur, types


 class _TranscriptImpl:

assemblyai/types.py

Lines changed: 16 additions & 2 deletions
@@ -1247,6 +1247,7 @@ class Word(BaseModel):
     start: int
     end: int
     confidence: float
+    speaker: Optional[str]


 class UtteranceWord(Word):
@@ -1382,6 +1383,10 @@ class RedactedAudioResponse(BaseModel):

 class Sentence(Word):
     words: List[Word]
+    start: int
+    end: int
+    confidence: int
+    speaker: Optional[str]


 class SentencesResponse(BaseModel):
@@ -1392,6 +1397,10 @@ class SentencesResponse(BaseModel):

 class Paragraph(Word):
     words: List[Word]
+    start: int
+    end: int
+    confidence: int
+    text: str


 class ParagraphsResponse(BaseModel):
@@ -1695,7 +1704,7 @@ def from_lemur_source(cls, source: LemurSource) -> Self:

 class LemurModel(str, Enum):
     """
-    LeMUR features two model modes, Basic and Default, that allow you to configure your request
+    LeMUR features three model modes, Basic, Default and Mistral 7B, that allow you to configure your request
     to suit your needs. These options tell LeMUR whether to use the more advanced Default model or
     the cheaper, faster, but simplified Basic model. The implicit setting is Default when no option
     is explicitly passed in.
@@ -1720,6 +1729,11 @@ class LemurModel(str, Enum):
     for complex/subjective tasks where answers require more nuance to be effective.
     """

+    mistral7b = "assemblyai/mistral-7b"
+    """
+    Mistral 7B is an open source model that works well for summarization and answering questions.
+    """
+

 class LemurQuestionAnswer(BaseModel):
     """
@@ -1921,7 +1935,7 @@ class RealtimeTranscript(BaseModel):
     text: str
     "The transcript for your audio"

-    words: List[Word]
+    words: List[RealtimeWord]
     """
     An array of objects, with the information for each word in the transcription text.
     Will include the `start`/`end` time (in milliseconds) of the word, the `confidence` score of the word,
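
The new speaker field on Word (inherited by Sentence and UtteranceWord) exposes speaker labels at the word level. A minimal sketch of reading it, assuming a transcript requested with speaker_labels=True and a placeholder audio URL:

    import assemblyai as aai

    transcript = aai.Transcriber().transcribe(
        "https://example.com/meeting.mp3",  # placeholder audio URL
        config=aai.TranscriptionConfig(speaker_labels=True),
    )

    # speaker is Optional[str]; expect None when speaker labels are not enabled
    for word in transcript.words:
        print(word.speaker, word.text)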

setup.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@

 setup(
     name="assemblyai",
-    version="0.20.0",
+    version="0.20.1",
     description="AssemblyAI Python SDK",
     author="AssemblyAI",
     author_email="[email protected]",
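
Since this ships as patch release 0.20.1, picking up Mistral support should be a routine package upgrade, e.g. pip install --upgrade assemblyai.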

tests/unit/test_lemur.py

Lines changed: 34 additions & 0 deletions
@@ -514,6 +514,40 @@ def test_lemur_task_succeeds_input_text(httpx_mock: HTTPXMock):
     assert len(httpx_mock.get_requests()) == 1


+def test_lemur_task_succeeds_mistral(httpx_mock: HTTPXMock):
+    """
+    Tests whether creating a task request succeeds with Mistral.
+    """
+
+    # create a mock response of a LemurTaskResponse
+    mock_lemur_task_response = factories.generate_dict_factory(
+        factories.LemurTaskResponse
+    )()
+
+    # mock the specific endpoint
+    httpx_mock.add_response(
+        url=f"{aai.settings.base_url}{ENDPOINT_LEMUR}/task",
+        status_code=httpx.codes.OK,
+        method="POST",
+        json=mock_lemur_task_response,
+    )
+    # run a task with the Mistral 7B final model and custom input_text
+    lemur = aai.Lemur()
+    result = lemur.task(
+        final_model=aai.LemurModel.mistral7b,
+        prompt="Create action items of the meeting",
+        input_text="Test test",
+    )
+
+    # check the response
+    assert isinstance(result, aai.LemurTaskResponse)
+
+    assert result.response == mock_lemur_task_response["response"]
+
+    # check whether we mocked everything
+    assert len(httpx_mock.get_requests()) == 1
+
+
 def test_lemur_ask_coach_fails(httpx_mock: HTTPXMock):
     """
     Tests whether creating a task request fails.
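
To run just the new test locally, the usual pytest filter should work, assuming the repo's standard dev dependencies are installed: python -m pytest tests/unit/test_lemur.py -k mistral.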

tests/unit/test_summarization.py

Lines changed: 2 additions & 1 deletion
@@ -26,7 +26,8 @@ def test_summarization_fails_without_required_field(
         httpx_mock,
         {},
         config=aai.TranscriptionConfig(
-            summarization=True, **{required_field: False}  # type: ignore
+            summarization=True,
+            **{required_field: False},  # type: ignore
         ),
     )
