Skip to content

Commit c559eb2

Browse files
committed
More tests, consistently import bfloat16
1 parent e40c056 commit c559eb2

File tree

5 files changed

+124
-27
lines changed

5 files changed

+124
-27
lines changed

redisvl/query/aggregate.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
from redisvl.query.filter import FilterExpression
66
from redisvl.redis.utils import array_to_buffer
77
from redisvl.utils.token_escaper import TokenEscaper
8+
from redisvl.utils.utils import lazy_import
9+
10+
nltk = lazy_import("nltk")
11+
nltk_stopwords = lazy_import("nltk.corpus.stopwords")
812

913

1014
class AggregationQuery(AggregateRequest):
@@ -162,19 +166,13 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
162166
if not stopwords:
163167
self._stopwords = set()
164168
elif isinstance(stopwords, str):
165-
# Lazy import because nltk is an optional dependency
166169
try:
167-
from redisvl.utils.utils import lazy_import
168-
169-
nltk = lazy_import("nltk")
170-
nltk_stopwords = lazy_import("nltk.corpus.stopwords")
170+
nltk.download("stopwords", quiet=True)
171+
self._stopwords = set(nltk_stopwords.words(stopwords))
171172
except ImportError:
172173
raise ValueError(
173174
f"Loading stopwords for {stopwords} failed: nltk is not installed."
174175
)
175-
try:
176-
nltk.download("stopwords", quiet=True)
177-
self._stopwords = set(nltk_stopwords.words(stopwords))
178176
except Exception as e:
179177
raise ValueError(f"Error trying to load {stopwords} from nltk. {e}")
180178
elif isinstance(stopwords, (Set, List, Tuple)) and all( # type: ignore

redisvl/query/query.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from redisvl.query.filter import FilterExpression
77
from redisvl.redis.utils import array_to_buffer
88
from redisvl.utils.token_escaper import TokenEscaper
9-
from redisvl.utils.utils import denorm_cosine_distance
9+
from redisvl.utils.utils import denorm_cosine_distance, lazy_import
10+
11+
nltk = lazy_import("nltk")
12+
nltk_stopwords = lazy_import("nltk.corpus.stopwords")
1013

1114

1215
class BaseQuery(RedisQuery):
@@ -893,19 +896,13 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
893896
if not stopwords:
894897
self._stopwords = set()
895898
elif isinstance(stopwords, str):
896-
# Lazy import because nltk is an optional dependency
897899
try:
898-
from redisvl.utils.utils import lazy_import
899-
900-
nltk = lazy_import("nltk")
901-
nltk_stopwords = lazy_import("nltk.corpus.stopwords")
900+
nltk.download("stopwords", quiet=True)
901+
self._stopwords = set(nltk_stopwords.words(stopwords))
902902
except ImportError:
903903
raise ValueError(
904904
f"Loading stopwords for {stopwords} failed: nltk is not installed."
905905
)
906-
try:
907-
nltk.download("stopwords", quiet=True)
908-
self._stopwords = set(nltk_stopwords.words(stopwords))
909906
except Exception as e:
910907
raise ValueError(f"Error trying to load {stopwords} from nltk. {e}")
911908
elif isinstance(stopwords, (Set, List, Tuple)) and all( # type: ignore

redisvl/redis/utils.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
# Lazy import numpy
77
np = lazy_import("numpy")
8-
# Lazy import ml_dtypes
9-
bfloat16 = lazy_import("ml_dtypes.bfloat16")
108

119
from redisvl.schema.fields import VectorDataType
1210

@@ -48,10 +46,9 @@ def array_to_buffer(array: List[float], dtype: str) -> bytes:
4846

4947
# Special handling for bfloat16 which requires explicit import from ml_dtypes
5048
if dtype.lower() == "bfloat16":
51-
# Import ml_dtypes.bfloat16 directly to ensure it's registered with numpy
52-
from ml_dtypes import bfloat16 as bf16
49+
from ml_dtypes import bfloat16
5350

54-
return np.array(array, dtype=bf16).tobytes()
51+
return np.array(array, dtype=bfloat16).tobytes()
5552

5653
return np.array(array, dtype=dtype.lower()).tobytes()
5754

@@ -66,11 +63,11 @@ def buffer_to_array(buffer: bytes, dtype: str) -> List[Any]:
6663
)
6764

6865
# Special handling for bfloat16 which requires explicit import from ml_dtypes
66+
# because otherwise the (lazily imported) numpy is unaware of the type
6967
if dtype.lower() == "bfloat16":
70-
# Import ml_dtypes.bfloat16 directly to ensure it's registered with numpy
71-
from ml_dtypes import bfloat16 as bf16
68+
from ml_dtypes import bfloat16
7269

73-
return np.frombuffer(buffer, dtype=bf16).tolist() # type: ignore[return-value]
70+
return np.frombuffer(buffer, dtype=bfloat16).tolist() # type: ignore[return-value]
7471

7572
return np.frombuffer(buffer, dtype=dtype.lower()).tolist() # type: ignore[return-value]
7673

redisvl/utils/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,15 +321,17 @@ def __getattr__(self, name: str) -> Any:
321321
else:
322322
# This means we couldn't find the attribute in the module path
323323
raise AttributeError(
324-
f"{self._parts[0]} has no attribute '{self._parts[1]}'"
324+
f"module '{self._parts[0]}' has no attribute '{self._parts[1]}'"
325325
)
326326

327327
# If we have a module, get the requested attribute
328328
if hasattr(self._module, name):
329329
return getattr(self._module, name)
330330

331331
# If the attribute doesn't exist, raise AttributeError
332-
raise AttributeError(f"{self._module_path} has no attribute '{name}'")
332+
raise AttributeError(
333+
f"module '{self._module_path}' has no attribute '{name}'"
334+
)
333335

334336
def __call__(self, *args: Any, **kwargs: Any) -> Any:
335337
# Import the module if it hasn't been imported yet

tests/unit/test_utils.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,26 @@ def test_import_standard_library(self):
549549
assert "json" in sys.modules
550550
assert result == '{"key": "value"}'
551551

552+
def test_cached_module_import(self):
553+
"""Test that _import_module returns the cached module if it exists"""
554+
# Remove the module from sys.modules if it's already imported
555+
if "json" in sys.modules:
556+
del sys.modules["json"]
557+
558+
# Lazy import the module
559+
json = lazy_import("json")
560+
561+
# Access an attribute to trigger the import
562+
json.dumps
563+
564+
# The module should now be cached
565+
# We need to access the private _import_module method directly
566+
# to test the cached path
567+
module = json._import_module()
568+
569+
# Verify that the cached module was returned
570+
assert module is json._module
571+
552572
def test_import_already_imported_module(self):
553573
"""Test lazy importing of an already imported module"""
554574
# Make sure the module is imported
@@ -618,6 +638,17 @@ def test_import_nonexistent_module(self):
618638

619639
assert "Failed to lazily import nonexistent_module_xyz" in str(excinfo.value)
620640

641+
def test_call_nonexistent_module(self):
642+
"""Test calling a nonexistent module"""
643+
# Lazy import a nonexistent module
644+
nonexistent = lazy_import("nonexistent_module_xyz")
645+
646+
# Calling the nonexistent module should raise ImportError
647+
with pytest.raises(ImportError) as excinfo:
648+
nonexistent()
649+
650+
assert "Failed to lazily import nonexistent_module_xyz" in str(excinfo.value)
651+
621652
def test_import_nonexistent_attribute(self):
622653
"""Test lazy importing of a nonexistent attribute"""
623654
# Lazy import a nonexistent attribute
@@ -631,6 +662,19 @@ def test_import_nonexistent_attribute(self):
631662
excinfo.value
632663
)
633664

665+
def test_getattr_on_nonexistent_attribute_path(self):
666+
"""Test accessing an attribute on a nonexistent attribute path"""
667+
# Lazy import a nonexistent attribute path
668+
nonexistent_attr = lazy_import("math.nonexistent_attribute")
669+
670+
# Accessing any attribute on the lazy object whose target attribute path cannot be resolved should raise AttributeError
671+
with pytest.raises(AttributeError) as excinfo:
672+
nonexistent_attr.some_attribute
673+
674+
assert "module 'math' has no attribute 'nonexistent_attribute'" in str(
675+
excinfo.value
676+
)
677+
634678
def test_import_noncallable(self):
635679
"""Test calling a non-callable lazy imported object"""
636680
# Lazy import a non-callable attribute
@@ -654,3 +698,62 @@ def test_attribute_error(self):
654698
assert "module 'math' has no attribute 'nonexistent_attribute'" in str(
655699
excinfo.value
656700
)
701+
702+
def test_attribute_error_after_import(self):
703+
"""Test accessing a nonexistent attribute on a module after it's been imported"""
704+
# Create a simple module with a known attribute
705+
import types
706+
707+
test_module = types.ModuleType("test_module")
708+
test_module.existing_attr = "exists"
709+
710+
# Add it to sys.modules so lazy_import can find it
711+
sys.modules["test_module"] = test_module
712+
713+
try:
714+
# Lazy import the module
715+
lazy_mod = lazy_import("test_module")
716+
717+
# Access the existing attribute to trigger the import
718+
assert lazy_mod.existing_attr == "exists"
719+
720+
# Now access a nonexistent attribute
721+
with pytest.raises(AttributeError) as excinfo:
722+
lazy_mod.nonexistent_attribute
723+
724+
assert (
725+
"module 'test_module' has no attribute 'nonexistent_attribute'"
726+
in str(excinfo.value)
727+
)
728+
finally:
729+
# Clean up
730+
if "test_module" in sys.modules:
731+
del sys.modules["test_module"]
732+
733+
def test_attribute_error_with_direct_module_access(self):
734+
"""Test accessing a nonexistent attribute by directly setting the _module attribute"""
735+
# Get the lazy_import function
736+
from redisvl.utils.utils import lazy_import
737+
738+
# Create a lazy import for a module that doesn't exist yet
739+
lazy_mod = lazy_import("test_direct_module")
740+
741+
# Create a simple object with no __getattr__ method
742+
class SimpleObject:
743+
pass
744+
745+
obj = SimpleObject()
746+
747+
# Directly set the _module attribute to our simple object
748+
# This bypasses the normal import mechanism
749+
lazy_mod._module = obj
750+
751+
# Now access a nonexistent attribute
752+
# This should go through our LazyModule.__getattr__ and reach its final fallback, which raises AttributeError for an attribute missing from the loaded module
753+
with pytest.raises(AttributeError) as excinfo:
754+
lazy_mod.nonexistent_attribute
755+
756+
assert (
757+
"module 'test_direct_module' has no attribute 'nonexistent_attribute'"
758+
in str(excinfo.value)
759+
)

0 commit comments

Comments
 (0)