Skip to content

Commit 89943ec

Browse files
authored
Use new lazy_import function to lazily import third-party libraries (#331)
Uses the new lazy_import utility from: #330
1 parent 8435e05 commit 89943ec

File tree

10 files changed

+162
-21
lines changed

10 files changed

+162
-21
lines changed

redisvl/cli/index.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from redisvl.redis.utils import convert_bytes, make_dict
99
from redisvl.schema.schema import IndexSchema
1010
from redisvl.utils.log import get_logger
11+
from redisvl.utils.utils import lazy_import
1112

1213
logger = get_logger("[RedisVL]")
1314

@@ -125,6 +126,8 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex:
125126

126127

127128
def _display_in_table(index_info, output_format="rounded_outline"):
129+
tabulate = lazy_import("tabulate")
130+
128131
print("\n")
129132
attributes = index_info.get("attributes", [])
130133
definition = make_dict(index_info.get("index_definition"))

redisvl/cli/stats.py

+3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from redisvl.index import SearchIndex
77
from redisvl.schema.schema import IndexSchema
88
from redisvl.utils.log import get_logger
9+
from redisvl.utils.utils import lazy_import
910

1011
logger = get_logger("[RedisVL]")
1112

@@ -85,6 +86,8 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex:
8586

8687

8788
def _display_stats(index_info, output_format="rounded_outline"):
89+
tabulate = lazy_import("tabulate")
90+
8891
# Extracting the statistics
8992
stats_data = [(key, str(index_info.get(key))) for key in STATS_KEYS]
9093

redisvl/query/aggregate.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
from redisvl.query.filter import FilterExpression
66
from redisvl.redis.utils import array_to_buffer
77
from redisvl.utils.token_escaper import TokenEscaper
8+
from redisvl.utils.utils import lazy_import
9+
10+
nltk = lazy_import("nltk")
11+
nltk_stopwords = lazy_import("nltk.corpus.stopwords")
812

913

1014
class AggregationQuery(AggregateRequest):
@@ -162,17 +166,13 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
162166
if not stopwords:
163167
self._stopwords = set()
164168
elif isinstance(stopwords, str):
165-
# Lazy import because nltk is an optional dependency
166169
try:
167-
import nltk
168-
from nltk.corpus import stopwords as nltk_stopwords
170+
nltk.download("stopwords", quiet=True)
171+
self._stopwords = set(nltk_stopwords.words(stopwords))
169172
except ImportError:
170173
raise ValueError(
171174
f"Loading stopwords for {stopwords} failed: nltk is not installed."
172175
)
173-
try:
174-
nltk.download("stopwords", quiet=True)
175-
self._stopwords = set(nltk_stopwords.words(stopwords))
176176
except Exception as e:
177177
raise ValueError(f"Error trying to load {stopwords} from nltk. {e}")
178178
elif isinstance(stopwords, (Set, List, Tuple)) and all( # type: ignore

redisvl/query/query.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
from redisvl.query.filter import FilterExpression
77
from redisvl.redis.utils import array_to_buffer
88
from redisvl.utils.token_escaper import TokenEscaper
9-
from redisvl.utils.utils import denorm_cosine_distance
9+
from redisvl.utils.utils import denorm_cosine_distance, lazy_import
10+
11+
nltk = lazy_import("nltk")
12+
nltk_stopwords = lazy_import("nltk.corpus.stopwords")
1013

1114

1215
class BaseQuery(RedisQuery):
@@ -893,17 +896,13 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
893896
if not stopwords:
894897
self._stopwords = set()
895898
elif isinstance(stopwords, str):
896-
# Lazy import because nltk is an optional dependency
897899
try:
898-
import nltk
899-
from nltk.corpus import stopwords as nltk_stopwords
900+
nltk.download("stopwords", quiet=True)
901+
self._stopwords = set(nltk_stopwords.words(stopwords))
900902
except ImportError:
901903
raise ValueError(
902904
f"Loading stopwords for {stopwords} failed: nltk is not installed."
903905
)
904-
try:
905-
nltk.download("stopwords", quiet=True)
906-
self._stopwords = set(nltk_stopwords.words(stopwords))
907906
except Exception as e:
908907
raise ValueError(f"Error trying to load {stopwords} from nltk. {e}")
909908
elif isinstance(stopwords, (Set, List, Tuple)) and all( # type: ignore

redisvl/redis/utils.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import hashlib
22
from typing import Any, Dict, List, Optional
33

4-
import numpy as np
5-
from ml_dtypes import bfloat16
4+
from redisvl.utils.utils import lazy_import
5+
6+
# Lazy import numpy
7+
np = lazy_import("numpy")
68

79
from redisvl.schema.fields import VectorDataType
810

@@ -41,6 +43,13 @@ def array_to_buffer(array: List[float], dtype: str) -> bytes:
4143
raise ValueError(
4244
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
4345
)
46+
47+
# Special handling for bfloat16 which requires explicit import from ml_dtypes
48+
if dtype.lower() == "bfloat16":
49+
from ml_dtypes import bfloat16
50+
51+
return np.array(array, dtype=bfloat16).tobytes()
52+
4453
return np.array(array, dtype=dtype.lower()).tobytes()
4554

4655

@@ -52,6 +61,14 @@ def buffer_to_array(buffer: bytes, dtype: str) -> List[Any]:
5261
raise ValueError(
5362
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
5463
)
64+
65+
# Special handling for bfloat16 which requires explicit import from ml_dtypes
66+
# because otherwise the (lazily imported) numpy is unaware of the type
67+
if dtype.lower() == "bfloat16":
68+
from ml_dtypes import bfloat16
69+
70+
return np.frombuffer(buffer, dtype=bfloat16).tolist() # type: ignore[return-value]
71+
5572
return np.frombuffer(buffer, dtype=dtype.lower()).tolist() # type: ignore[return-value]
5673

5774

redisvl/utils/optimize/cache.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import Any, Callable, Dict, List
22

3-
import numpy as np
3+
from redisvl.utils.utils import lazy_import
4+
5+
np = lazy_import("numpy")
46
from ranx import Qrels, Run, evaluate
57

68
from redisvl.extensions.cache.llm.semantic import SemanticCache

redisvl/utils/optimize/router.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import random
22
from typing import Any, Callable, Dict, List
33

4-
import numpy as np
4+
from redisvl.utils.utils import lazy_import
5+
6+
np = lazy_import("numpy")
57
from ranx import Qrels, Run, evaluate
68

79
from redisvl.extensions.router.semantic import SemanticRouter

redisvl/utils/optimize/utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import List
22

3-
import numpy as np
3+
from redisvl.utils.utils import lazy_import
4+
5+
np = lazy_import("numpy")
46
from ranx import Qrels
57

68
from redisvl.utils.optimize.schema import LabeledData

redisvl/utils/utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -321,15 +321,17 @@ def __getattr__(self, name: str) -> Any:
321321
else:
322322
# This means we couldn't find the attribute in the module path
323323
raise AttributeError(
324-
f"{self._parts[0]} has no attribute '{self._parts[1]}'"
324+
f"module '{self._parts[0]}' has no attribute '{self._parts[1]}'"
325325
)
326326

327327
# If we have a module, get the requested attribute
328328
if hasattr(self._module, name):
329329
return getattr(self._module, name)
330330

331331
# If the attribute doesn't exist, raise AttributeError
332-
raise AttributeError(f"{self._module_path} has no attribute '{name}'")
332+
raise AttributeError(
333+
f"module '{self._module_path}' has no attribute '{name}'"
334+
)
333335

334336
def __call__(self, *args: Any, **kwargs: Any) -> Any:
335337
# Import the module if it hasn't been imported yet

tests/unit/test_utils.py

+112-1
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,15 @@ def test_empty_list_to_bytes():
165165
def test_conversion_with_various_dtypes(dtype):
166166
"""Test conversion of a list of floats to bytes with various dtypes"""
167167
array = [1.0, -2.0, 3.5]
168-
expected = np.array(array, dtype=dtype).tobytes()
168+
169+
# Special handling for bfloat16 which requires explicit import from ml_dtypes
170+
if dtype == "bfloat16":
171+
from ml_dtypes import bfloat16 as bf16
172+
173+
expected = np.array(array, dtype=bf16).tobytes()
174+
else:
175+
expected = np.array(array, dtype=dtype).tobytes()
176+
169177
assert array_to_buffer(array, dtype=dtype) == expected
170178

171179

@@ -541,6 +549,26 @@ def test_import_standard_library(self):
541549
assert "json" in sys.modules
542550
assert result == '{"key": "value"}'
543551

552+
def test_cached_module_import(self):
553+
"""Test that _import_module returns the cached module if it exists"""
554+
# Remove the module from sys.modules if it's already imported
555+
if "json" in sys.modules:
556+
del sys.modules["json"]
557+
558+
# Lazy import the module
559+
json = lazy_import("json")
560+
561+
# Access an attribute to trigger the import
562+
json.dumps
563+
564+
# The module should now be cached
565+
# We need to access the private _import_module method directly
566+
# to test the cached path
567+
module = json._import_module()
568+
569+
# Verify that the cached module was returned
570+
assert module is json._module
571+
544572
def test_import_already_imported_module(self):
545573
"""Test lazy importing of an already imported module"""
546574
# Make sure the module is imported
@@ -610,6 +638,17 @@ def test_import_nonexistent_module(self):
610638

611639
assert "Failed to lazily import nonexistent_module_xyz" in str(excinfo.value)
612640

641+
def test_call_nonexistent_module(self):
642+
"""Test calling a nonexistent module"""
643+
# Lazy import a nonexistent module
644+
nonexistent = lazy_import("nonexistent_module_xyz")
645+
646+
# Calling the nonexistent module should raise ImportError
647+
with pytest.raises(ImportError) as excinfo:
648+
nonexistent()
649+
650+
assert "Failed to lazily import nonexistent_module_xyz" in str(excinfo.value)
651+
613652
def test_import_nonexistent_attribute(self):
614653
"""Test lazy importing of a nonexistent attribute"""
615654
# Lazy import a nonexistent attribute
@@ -623,6 +662,19 @@ def test_import_nonexistent_attribute(self):
623662
excinfo.value
624663
)
625664

665+
def test_getattr_on_nonexistent_attribute_path(self):
666+
"""Test accessing an attribute on a nonexistent attribute path"""
667+
# Lazy import a nonexistent attribute path
668+
nonexistent_attr = lazy_import("math.nonexistent_attribute")
669+
670+
# Accessing an attribute on the nonexistent attribute should raise AttributeError
671+
with pytest.raises(AttributeError) as excinfo:
672+
nonexistent_attr.some_attribute
673+
674+
assert "module 'math' has no attribute 'nonexistent_attribute'" in str(
675+
excinfo.value
676+
)
677+
626678
def test_import_noncallable(self):
627679
"""Test calling a non-callable lazy imported object"""
628680
# Lazy import a non-callable attribute
@@ -646,3 +698,62 @@ def test_attribute_error(self):
646698
assert "module 'math' has no attribute 'nonexistent_attribute'" in str(
647699
excinfo.value
648700
)
701+
702+
def test_attribute_error_after_import(self):
703+
"""Test accessing a nonexistent attribute on a module after it's been imported"""
704+
# Create a simple module with a known attribute
705+
import types
706+
707+
test_module = types.ModuleType("test_module")
708+
test_module.existing_attr = "exists"
709+
710+
# Add it to sys.modules so lazy_import can find it
711+
sys.modules["test_module"] = test_module
712+
713+
try:
714+
# Lazy import the module
715+
lazy_mod = lazy_import("test_module")
716+
717+
# Access the existing attribute to trigger the import
718+
assert lazy_mod.existing_attr == "exists"
719+
720+
# Now access a nonexistent attribute
721+
with pytest.raises(AttributeError) as excinfo:
722+
lazy_mod.nonexistent_attribute
723+
724+
assert (
725+
"module 'test_module' has no attribute 'nonexistent_attribute'"
726+
in str(excinfo.value)
727+
)
728+
finally:
729+
# Clean up
730+
if "test_module" in sys.modules:
731+
del sys.modules["test_module"]
732+
733+
def test_attribute_error_with_direct_module_access(self):
734+
"""Test accessing a nonexistent attribute by directly setting the _module attribute"""
735+
# Get the lazy_import function
736+
from redisvl.utils.utils import lazy_import
737+
738+
# Create a lazy import for a module that doesn't exist yet
739+
lazy_mod = lazy_import("test_direct_module")
740+
741+
# Create a simple object with no __getattr__ method
742+
class SimpleObject:
743+
pass
744+
745+
obj = SimpleObject()
746+
747+
# Directly set the _module attribute to our simple object
748+
# This bypasses the normal import mechanism
749+
lazy_mod._module = obj
750+
751+
# Now access a nonexistent attribute
752+
# This should go through our LazyModule.__getattr__ and hit line 332
753+
with pytest.raises(AttributeError) as excinfo:
754+
lazy_mod.nonexistent_attribute
755+
756+
assert (
757+
"module 'test_direct_module' has no attribute 'nonexistent_attribute'"
758+
in str(excinfo.value)
759+
)

0 commit comments

Comments
 (0)