From 699a8d55d98c7cf20f5097312c4fd712ccd5b04a Mon Sep 17 00:00:00 2001 From: sjoerdk Date: Tue, 17 Sep 2024 16:09:19 +0200 Subject: [PATCH] Releases pydantic pin, upgrades to pydantic v2 Refs #54 --- dicomtrolley/core.py | 12 ++++++------ dicomtrolley/mint.py | 24 +++++++++++++---------- dicomtrolley/qido_rs.py | 40 +++++++++++++++++++-------------------- pyproject.toml | 2 +- tests/test_integration.py | 2 +- tests/test_rad69.py | 10 +++++----- 6 files changed, 46 insertions(+), 44 deletions(-) diff --git a/dicomtrolley/core.py b/dicomtrolley/core.py index 0e97355..b45d793 100644 --- a/dicomtrolley/core.py +++ b/dicomtrolley/core.py @@ -13,8 +13,7 @@ TypeVar, ) -from pydantic import Field, ValidationError -from pydantic.class_validators import validator +from pydantic import Field, ValidationError, field_validator from pydantic.main import BaseModel from pydicom.datadict import tag_for_keyword from pydicom.dataset import Dataset @@ -608,8 +607,8 @@ class Query(BaseModel): query_level: QueryLevels = ( QueryLevels.STUDY ) # to which depth to return results - max_study_date: Optional[datetime] - min_study_date: Optional[datetime] + max_study_date: Optional[datetime] = None + min_study_date: Optional[datetime] = None include_fields: List[str] = Field([]) # class Config: @@ -657,7 +656,8 @@ def validate_keyword(keyword): if not tag_for_keyword(keyword): raise ValueError(f"{keyword} is not a valid DICOM keyword") - @validator("include_fields") + @field_validator("include_fields") + @classmethod def include_fields_check(cls, include_fields, values): # noqa: B902, N805 """Include fields should be valid dicom tag names""" for field in include_fields: @@ -684,7 +684,7 @@ class ExtendedQuery(Query): StudyDescription: str = "" SeriesDescription: str = "" InstitutionalDepartmentName: str = "" - PatientBirthDate: Optional[date] + PatientBirthDate: Optional[date] = None class Searcher: diff --git a/dicomtrolley/mint.py b/dicomtrolley/mint.py index cfa511f..a89f014 100644 --- a/dicomtrolley/mint.py +++ b/dicomtrolley/mint.py @@ -6,7 +6,7 @@ from xml.etree import ElementTree from xml.etree.ElementTree import ParseError -from pydantic.class_validators import root_validator +from pydantic import model_validator from pydicom.dataelem import DataElement from pydicom.dataset import Dataset @@ -159,11 +159,11 @@ class MintQuery(ExtendedQuery): limit: int = 0 # how many results to return. 0 = all - @root_validator() + @model_validator(mode="after") def min_max_study_date_xor(cls, values): # noqa: B902, N805 """Min and max should both be given or both be empty""" - min_date = values.get("min_study_date") - max_date = values.get("max_study_date") + min_date = values.min_study_date + max_date = values.max_study_date if min_date and not max_date: raise ValueError( f"min_study_date parameter was passed" @@ -177,14 +177,18 @@ def min_max_study_date_xor(cls, values): # noqa: B902, N805 ) return values - @root_validator() - def include_fields_check(cls, values): # noqa: B902, N805 + @model_validator(mode="after") + def include_fields_check(self): """Include fields should match query level""" - include_fields = values.get("include_fields") + if isinstance(self, list): + # Interplay with base Query field_validator for include fields + return self # don't check + else: + include_fields = self.include_fields if not include_fields: - return values # May not exist if include_fields is invalid type + return self # May not exist if include_fields is invalid type - query_level = values.get("query_level") + query_level = self.query_level if query_level: # May be None for child classes valid_fields = get_valid_fields(query_level=query_level) for field in include_fields: @@ -193,7 +197,7 @@ def include_fields_check(cls, values): # noqa: B902, N805 f'"{field}" is not a valid include field for query ' f"level {query_level}. Valid fields: {valid_fields}" ) - return values + return self def __str__(self): return str(self.as_parameters()) diff --git a/dicomtrolley/qido_rs.py b/dicomtrolley/qido_rs.py index 2d1ee64..ef8ad28 100644 --- a/dicomtrolley/qido_rs.py +++ b/dicomtrolley/qido_rs.py @@ -11,7 +11,7 @@ from datetime import datetime from typing import Dict, List, Optional, Sequence, Union -from pydantic import root_validator +from pydantic import model_validator from pydicom import Dataset from requests import Response @@ -39,11 +39,11 @@ class QidoRSQueryBase(Query): limit: int = 0 # How many results to return. 0 = all offset: int = 0 # Number of skipped results - @root_validator() # type: ignore - def min_max_study_date_xor(cls, values): # noqa: B902, N805 + @model_validator(mode="after") + def min_max_study_date_xor(self): # noqa: B902, N805 """Min and max should both be given or both be empty""" - min_date = values.get("min_study_date") - max_date = values.get("max_study_date") + min_date = self.min_study_date + max_date = self.max_study_date if min_date and not max_date: raise ValueError( f"min_study_date parameter was passed" @@ -55,7 +55,7 @@ def min_max_study_date_xor(cls, values): # noqa: B902, N805 f"max_study_date parameter was passed ({max_date}), " f"but min_study_date was not. Both need to be given" ) - return values + return self @staticmethod def date_to_str(date_in: Optional[datetime]) -> str: @@ -158,8 +158,8 @@ class HierarchicalQuery(QidoRSQueryBase): Faster than relationalQuery, but requires more information """ - @root_validator() # type: ignore - def uids_should_be_hierarchical(cls, values): # noqa: B902, N805 + @model_validator(mode="after") + def uids_should_be_hierarchical(self): """Any object uids passed should conform to study->series->instance""" order = ["StudyInstanceUID", "SeriesInstanceUID", "SOPInstanceUID"] @@ -182,14 +182,13 @@ def assert_parents_filled(a_hierarchy, value_dict): else: return assert_parents_filled(a_hierarchy, value_dict) - assert_parents_filled(order, values) + assert_parents_filled(order, self.dict()) + return self - return values - - @root_validator() # type: ignore - def uids_should_match_query_level(cls, values): # noqa: B902, N805 + @model_validator(mode="after") + def uids_should_match_query_level(self): """If a query is for instance level, there should be study and series UIDs""" - query_level = values["query_level"] + query_level = self.query_level def assert_key_exists(values_in, query_level_in, missing_key_in): if not values_in.get(missing_key_in): @@ -199,6 +198,7 @@ def assert_key_exists(values_in, query_level_in, missing_key_in): f"a QIDO-RS relational query" ) + values = self.dict() if query_level == QueryLevels.STUDY: pass # Fine. you can always look for some studies elif query_level == QueryLevels.SERIES: @@ -207,7 +207,7 @@ def assert_key_exists(values_in, query_level_in, missing_key_in): assert_key_exists(values, query_level, "SeriesInstanceUID") assert_key_exists(values, query_level, "StudyInstanceUID") - return values + return self def uri_base(self) -> str: """WADO-RS url to call when performing this query. Full URI also needs @@ -294,17 +294,15 @@ class RelationalQuery(QidoRSQueryBase): Allows broader searches than HierarchicalQuery, but can be slower """ - @root_validator() # type: ignore - def query_level_should_be_series_or_instance( - cls, values # noqa: B902, N805 - ): + @model_validator(mode="after") + def query_level_should_be_series_or_instance(self): """A relational query only makes sense for the instance and series levels. If you want to look for studies, us a hierarchical query """ - if values.get("query_level") == QueryLevels.STUDY: + if self.query_level == QueryLevels.STUDY: raise ValueError(STUDY_VALUE_ERROR_TEXT) - return values + return self def uri_base(self) -> str: """WADO-RS url to call when performing this query. Full URI also needs diff --git a/pyproject.toml b/pyproject.toml index fa77630..89fce06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ requests-futures = "^1.0.0" pynetdicom = "^1.5.6" Jinja2 = "^3.0.3" requests-toolbelt = "^1.0.0" -pydantic = "1.8.2" +pydantic = "^2.9.1" [tool.poetry.dev-dependencies] pytest = "^7.4.0" diff --git a/tests/test_integration.py b/tests/test_integration.py index b1df17e..fcb9dfc 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -39,7 +39,7 @@ def test_dicom_query_mint_cast(requests_mock, a_mint): set_mock_response(requests_mock, MINT_SEARCH_INSTANCE_LEVEL_ANY) with pytest.raises(DICOMTrolleyError): # should fail, casting to mint would lose unsupported StudyID parameter - a_mint.find_studies(DICOMQuery(StudyID=123)) + a_mint.find_studies(DICOMQuery(StudyID="123")) def test_from_query(): diff --git a/tests/test_rad69.py b/tests/test_rad69.py index ed90ddd..26fc155 100644 --- a/tests/test_rad69.py +++ b/tests/test_rad69.py @@ -85,9 +85,9 @@ def test_rad69_error_from_server( with pytest.raises(DICOMTrolleyError) as e: a_rad69.get_dataset( InstanceReference( - study_uid=1, - series_uid=2, - instance_uid=3, + study_uid="1", + series_uid="2", + instance_uid="3", ) ) assert re.match(error_contains, str(e)) @@ -244,8 +244,8 @@ def test_wado_datasets_async(a_rad69, requests_mock): ) instances = [ - InstanceReference(study_uid=1, series_uid=2, instance_uid=3), - InstanceReference(study_uid=4, series_uid=5, instance_uid=6), + InstanceReference(study_uid="1", series_uid="2", instance_uid="3"), + InstanceReference(study_uid="4", series_uid="5", instance_uid="6"), ] a_rad69.use_async = True datasets = [x for x in a_rad69.datasets(instances)]