Skip to content

Commit 39b0552

Browse files
committed
fix(python/sdk): Fix custom spelling property and add unit tests (#6763)
GitOrigin-RevId: c5d69d94605abad2bef7510b8ba61bfa1d079db3
1 parent c1f2ce4 commit 39b0552

File tree

2 files changed

+107
-10
lines changed

2 files changed

+107
-10
lines changed

assemblyai/types.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ class RawTranscriptionConfig(BaseModel):
524524
"Enable Topic Detection."
525525

526526
custom_spelling: Optional[List[Dict[str, Union[str, List[str]]]]] = None
527-
"Customize how words are spelled and formatted using to and from values"
527+
"Customize how words are spelled and formatted using to and from values."
528528

529529
disfluencies: Optional[bool] = None
530530
"Transcribe Filler Words, like 'umm', in your media file."
@@ -916,18 +916,21 @@ def iab_categories(self, enable: Optional[bool]) -> None:
916916

917917
@property
def custom_spelling(self) -> Optional[Dict[str, Union[str, List[str]]]]:
    """
    Returns the current set of custom spellings. For each key-value pair in the dictionary,
    the key is the 'to' field, and the value is the 'from' field.

    Returns:
        A dictionary mapping each 'to' value to its 'from' value, or `None` when
        no custom spellings are configured (the raw value is `None` or empty).

    Raises:
        ValueError: If the 'to' value of an entry is not a string.
    """

    raw_spellings = self._raw_transcription_config.custom_spelling
    if raw_spellings is None:
        return None

    custom_spellings = {}
    for entry in raw_spellings:
        _to = entry["to"]
        if not isinstance(_to, str):
            raise ValueError("`to` argument must be a string!")

        custom_spellings[_to] = entry["from"]

    # Check the accumulated mapping rather than the loop variable: with an
    # empty list the loop variable is never bound, which used to raise
    # UnboundLocalError instead of returning None.
    return custom_spellings or None
933936

@@ -1231,13 +1234,14 @@ def set_custom_spelling(
12311234
Customize how given words are being spelled or formatted in the transcription's text.
12321235
12331236
Args:
1234-
replacement: A dictionary that contains the replacement object (see below example)
1237+
replacement: A dictionary that contains the replacement object (see below example).
1238+
For each key-value pair, the key is the 'to' field, and the value is the 'from' field.
12351239
override: If `True`, the existing replacements are overridden by the given `replacement` argument; otherwise they are merged.
12361240
12371241
Example:
12381242
```
12391243
config.set_custom_spelling({
1240-
"AssemblyAI": "AssemblyAI",
1244+
"AssemblyAI": "assemblyAI",
12411245
"Kubernetes": ["k8s", "kubernetes"]
12421246
})
12431247
```
@@ -1619,7 +1623,7 @@ class BaseTranscript(BaseModel):
16191623
"Enable Topic Detection."
16201624

16211625
custom_spelling: Optional[List[Dict[str, Union[str, List[str]]]]] = None
1622-
"Customize how words are spelled and formatted using to and from values"
1626+
"Customize how words are spelled and formatted using to and from values."
16231627

16241628
disfluencies: Optional[bool] = None
16251629
"Transcribe Filler Words, like 'umm', in your media file."

tests/unit/test_custom_spelling.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import factory
2+
from pytest_httpx import HTTPXMock
3+
4+
import tests.unit.unit_test_utils as unit_test_utils
5+
import assemblyai as aai
6+
from tests.unit import factories
7+
8+
aai.settings.api_key = "test"
9+
10+
11+
class CustomSpellingFactory(factory.Factory):
    """Produces a single custom-spelling entry as a plain dict with 'from' and 'to' keys."""

    class Meta:
        # Build plain dictionaries instead of model objects.
        model = dict
        # `from` is a Python keyword, so the field is declared as `_from`
        # and mapped to the "from" key via the rename option.
        rename = {"_from": "from"}

    # 'from': a list holding one generated word.
    _from = factory.List([factory.Faker("word")])
    # 'to': a single generated word.
    to = factory.Faker("word")
18+
19+
20+
class CustomSpellingResponseFactory(factories.TranscriptCompletedResponseFactory):
    """Completed-transcript response factory whose `custom_spelling` holds one generated entry."""

    @factory.lazy_attribute
    def custom_spelling(self):
        # Wrap one freshly built entry in a list.
        entry = CustomSpellingFactory()
        return [entry]
24+
25+
26+
def test_custom_spelling_disabled_by_default(httpx_mock: HTTPXMock):
    """
    Verifies that a `TranscriptionConfig` on which `set_custom_spelling()` was
    never called leaves `custom_spelling` out of the request body, and that the
    mocked response carries no `custom_spelling` either.
    """
    build_response = factories.generate_dict_factory(
        factories.TranscriptCompletedResponseFactory
    )
    mock_response = build_response()
    default_config = aai.TranscriptionConfig()

    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=mock_response,
        config=default_config,
    )

    assert request_body.get("custom_spelling") is None
    assert transcript.json_response.get("custom_spelling") is None
40+
41+
42+
def test_custom_spelling_set_config_succeeds():
    """
    Verifies that values passed to `set_custom_spelling()` on a
    `TranscriptionConfig` are stored correctly and can be read back through
    the `custom_spelling` property.
    """
    config = aai.TranscriptionConfig()

    # A bare string value is read back wrapped in a list
    config.set_custom_spelling({"AssemblyAI": "assemblyAI"})
    expected_single = {"AssemblyAI": ["assemblyAI"]}
    assert config.custom_spelling == expected_single

    # Several pairs can be set in one call (here with override=True)
    replacements = {"AssemblyAI": "assemblyAI", "Kubernetes": ["k8s", "kubernetes"]}
    config.set_custom_spelling(replacements, override=True)
    expected_multiple = {
        "AssemblyAI": ["assemblyAI"],
        "Kubernetes": ["k8s", "kubernetes"],
    }
    assert config.custom_spelling == expected_multiple
62+
63+
64+
def test_custom_spelling_enabled(httpx_mock: HTTPXMock):
    """
    Verifies that a config prepared with `set_custom_spelling()` produces a
    `custom_spelling` entry in the request body, and that the response's
    `custom_spelling` field matches what was sent.
    """

    mock_response = factories.generate_dict_factory(CustomSpellingResponseFactory)()

    # Build the config from the mocked entry so that request and response agree
    mocked_entry = mock_response["custom_spelling"][0]
    source_words = mocked_entry["from"]
    target_word = mocked_entry["to"]

    config = aai.TranscriptionConfig().set_custom_spelling({target_word: source_words})

    request_body, transcript = unit_test_utils.submit_mock_transcription_request(
        httpx_mock,
        mock_response=mock_response,
        config=config,
    )

    # The request body must contain at least one entry with both keys
    sent_spelling = request_body["custom_spelling"]
    assert sent_spelling is not None and len(sent_spelling) > 0
    first_entry = sent_spelling[0]
    assert "from" in first_entry
    assert "to" in first_entry

    # The transcript must be error-free and echo the custom spelling that was sent
    assert transcript.error is None
    assert transcript.json_response["custom_spelling"] == sent_spelling

0 commit comments

Comments
 (0)