Skip to content

Commit e40c056

Browse files
committed
Use new lazy_import function
1 parent 4984c89 commit e40c056

File tree

9 files changed

+54
-14
lines changed

9 files changed

+54
-14
lines changed

redisvl/cli/index.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@
22
import sys
33
from argparse import Namespace
44

5-
from tabulate import tabulate
6-
75
from redisvl.cli.utils import add_index_parsing_options, create_redis_url
86
from redisvl.index import SearchIndex
97
from redisvl.redis.connection import RedisConnectionFactory
108
from redisvl.redis.utils import convert_bytes, make_dict
119
from redisvl.schema.schema import IndexSchema
1210
from redisvl.utils.log import get_logger
11+
from redisvl.utils.utils import lazy_import
1312

1413
logger = get_logger("[RedisVL]")
1514

@@ -127,6 +126,8 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex:
127126

128127

129128
def _display_in_table(index_info, output_format="rounded_outline"):
129+
tabulate = lazy_import("tabulate")
130+
130131
print("\n")
131132
attributes = index_info.get("attributes", [])
132133
definition = make_dict(index_info.get("index_definition"))

redisvl/cli/stats.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
import sys
33
from argparse import Namespace
44

5-
from tabulate import tabulate
6-
75
from redisvl.cli.utils import add_index_parsing_options, create_redis_url
86
from redisvl.index import SearchIndex
97
from redisvl.schema.schema import IndexSchema
108
from redisvl.utils.log import get_logger
9+
from redisvl.utils.utils import lazy_import
1110

1211
logger = get_logger("[RedisVL]")
1312

@@ -87,6 +86,8 @@ def _connect_to_index(self, args: Namespace) -> SearchIndex:
8786

8887

8988
def _display_stats(index_info, output_format="rounded_outline"):
89+
tabulate = lazy_import("tabulate")
90+
9091
# Extracting the statistics
9192
stats_data = [(key, str(index_info.get(key))) for key in STATS_KEYS]
9293

redisvl/query/aggregate.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,10 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
164164
elif isinstance(stopwords, str):
165165
# Lazy import because nltk is an optional dependency
166166
try:
167-
import nltk
168-
from nltk.corpus import stopwords as nltk_stopwords
167+
from redisvl.utils.utils import lazy_import
168+
169+
nltk = lazy_import("nltk")
170+
nltk_stopwords = lazy_import("nltk.corpus.stopwords")
169171
except ImportError:
170172
raise ValueError(
171173
f"Loading stopwords for {stopwords} failed: nltk is not installed."

redisvl/query/query.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -895,8 +895,10 @@ def _set_stopwords(self, stopwords: Optional[Union[str, Set[str]]] = "english"):
895895
elif isinstance(stopwords, str):
896896
# Lazy import because nltk is an optional dependency
897897
try:
898-
import nltk
899-
from nltk.corpus import stopwords as nltk_stopwords
898+
from redisvl.utils.utils import lazy_import
899+
900+
nltk = lazy_import("nltk")
901+
nltk_stopwords = lazy_import("nltk.corpus.stopwords")
900902
except ImportError:
901903
raise ValueError(
902904
f"Loading stopwords for {stopwords} failed: nltk is not installed."

redisvl/redis/utils.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import hashlib
22
from typing import Any, Dict, List, Optional
33

4-
import numpy as np
5-
from ml_dtypes import bfloat16
4+
from redisvl.utils.utils import lazy_import
5+
6+
# Lazy import numpy
7+
np = lazy_import("numpy")
8+
# Lazy import ml_dtypes
9+
bfloat16 = lazy_import("ml_dtypes.bfloat16")
610

711
from redisvl.schema.fields import VectorDataType
812

@@ -41,6 +45,14 @@ def array_to_buffer(array: List[float], dtype: str) -> bytes:
4145
raise ValueError(
4246
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
4347
)
48+
49+
# Special handling for bfloat16 which requires explicit import from ml_dtypes
50+
if dtype.lower() == "bfloat16":
51+
# Import ml_dtypes.bfloat16 directly to ensure it's registered with numpy
52+
from ml_dtypes import bfloat16 as bf16
53+
54+
return np.array(array, dtype=bf16).tobytes()
55+
4456
return np.array(array, dtype=dtype.lower()).tobytes()
4557

4658

@@ -52,6 +64,14 @@ def buffer_to_array(buffer: bytes, dtype: str) -> List[Any]:
5264
raise ValueError(
5365
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
5466
)
67+
68+
# Special handling for bfloat16 which requires explicit import from ml_dtypes
69+
if dtype.lower() == "bfloat16":
70+
# Import ml_dtypes.bfloat16 directly to ensure it's registered with numpy
71+
from ml_dtypes import bfloat16 as bf16
72+
73+
return np.frombuffer(buffer, dtype=bf16).tolist() # type: ignore[return-value]
74+
5575
return np.frombuffer(buffer, dtype=dtype.lower()).tolist() # type: ignore[return-value]
5676

5777

redisvl/utils/optimize/cache.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import Any, Callable, Dict, List
22

3-
import numpy as np
3+
from redisvl.utils.utils import lazy_import
4+
5+
np = lazy_import("numpy")
46
from ranx import Qrels, Run, evaluate
57

68
from redisvl.extensions.cache.llm.semantic import SemanticCache

redisvl/utils/optimize/router.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import random
22
from typing import Any, Callable, Dict, List
33

4-
import numpy as np
4+
from redisvl.utils.utils import lazy_import
5+
6+
np = lazy_import("numpy")
57
from ranx import Qrels, Run, evaluate
68

79
from redisvl.extensions.router.semantic import SemanticRouter

redisvl/utils/optimize/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import List
22

3-
import numpy as np
3+
from redisvl.utils.utils import lazy_import
4+
5+
np = lazy_import("numpy")
46
from ranx import Qrels
57

68
from redisvl.utils.optimize.schema import LabeledData

tests/unit/test_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,15 @@ def test_empty_list_to_bytes():
165165
def test_conversion_with_various_dtypes(dtype):
166166
"""Test conversion of a list of floats to bytes with various dtypes"""
167167
array = [1.0, -2.0, 3.5]
168-
expected = np.array(array, dtype=dtype).tobytes()
168+
169+
# Special handling for bfloat16 which requires explicit import from ml_dtypes
170+
if dtype == "bfloat16":
171+
from ml_dtypes import bfloat16 as bf16
172+
173+
expected = np.array(array, dtype=bf16).tobytes()
174+
else:
175+
expected = np.array(array, dtype=dtype).tobytes()
176+
169177
assert array_to_buffer(array, dtype=dtype) == expected
170178

171179

0 commit comments

Comments
 (0)