Skip to content

Commit

Permalink
fix: Client's handling of Enum Values, to support Python >=3.11 enum …
Browse files Browse the repository at this point in the history
…handling. (#4672)

# Description

The argilla client for v0 Search and Metrics appears to be broken in
Python 3.11 and greater. This comes from a simple mis-resolution of enum
values, due to some of the poorly documented python 3.11. In short,
given this definition of the TaskType enum:

```
class TaskType(str, Enum):
    text_classification = "TextClassification"
    token_classification = "TokenClassification"
    text2text = "Text2Text"
```

in Python 3.10 and before, format strings and the format method will
return the value of the target enum:

```
>>> f'{TaskType.text_classification}'
'TextClassification'
```

However, in python 3.11 and later, format strings will produce a more
verbose description:

```
>>> f'{TaskType.text_classification}'
'TaskType.text_classification'
```

The client's Search and Metrics code both use the same pattern,
producing an incorrect _API_URL and thereby running into 404's:

```
class Search(AbstractApi):
    _API_URL_PATTERN = "/api/datasets/{name}/{task}:search"

    def search_records(
        self,
        name: str,
        task: TaskType,
        size: Optional[int] = None,
        **query,
    ):
    ...
    url = Search._API_URL_PATTERN.format(name=name, task=task)
    ...
    response = self.http_client.post(
            path=url,
            json={"query": query} if query else None,
        )
    ...
```

This PR just updates the way task is handled by the url format function
in both APIs, a method that produces the correct value in 3.8, 3.10,
3.11, and 3.12.0a7:

```
>>> f'{TaskType.text_classification.value}'                                                                                            
'TextClassification'
```

```
    url = Search._API_URL_PATTERN.format(name=name, task=task.value)
```

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [x] Bug fix (non-breaking change which fixes an issue)

**How Has This Been Tested**

This has been tested locally using pyenv to swap between python 3.8,
3.10, 3.11, and 3.12.0a7 to confirm 3.11 is the change point. Also,
tested this in-situ by trying to run AutoPrompt (which only supports
3.10 and below) with this change, and it appears this change is the last
thing stopping autoprompt from supporting 3.11!

**Checklist**

- [x] I followed the style guidelines of this project
- [x] I did a self-review of my code
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [ ] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)

---------

Co-authored-by: admin <[email protected]>
  • Loading branch information
tim-win and admin authored Mar 20, 2024
1 parent 54f2c1e commit 7dd6d57
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ These are the section headers that we use:
- Fixed parsing ranking values in suggestions from HF datasets. ([#4629](https://github.com/argilla-io/argilla/pull/4629))
- Fixed reading description from API response payload. ([#4632](https://github.com/argilla-io/argilla/pull/4632))
- Fixed pulling (n*chunk_size)+1 records when using `ds.pull` or iterating over the dataset. ([#4662](https://github.com/argilla-io/argilla/pull/4662))
- Fixed client's resolution of enum values when calling the Search and Metrics api, to support Python >=3.11 enum handling. ([#4672](https://github.com/argilla-io/argilla/pull/4672))

## [1.25.0](https://github.com/argilla-io/argilla/compare/v1.24.0...v1.25.0)

Expand Down
2 changes: 1 addition & 1 deletion src/argilla/client/apis/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def metric_summary(
query: Optional[str] = None,
**metric_params,
):
url = self._API_URL_PATTERN.format(task=task, name=name, metric=metric)
url = self._API_URL_PATTERN.format(task=task.value, name=name, metric=metric)
metric_params = metric_params or {}
query_params = {k: v for k, v in metric_params.items() if v is not None}
if query_params:
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/client/apis/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def search_records(
else:
raise ValueError(f"Task {task} not supported")

url = self._API_URL_PATTERN.format(name=name, task=task)
url = Search._API_URL_PATTERN.format(name=name, task=task.value)
if size:
url += f"?limit={size}"

Expand Down

0 comments on commit 7dd6d57

Please sign in to comment.