Skip to content

Commit 3f178ad

Browse files
committed
🚧 fix(wip): cuda check
1 parent 529d9e1 commit 3f178ad

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

magnet/ize/memory.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from magnet.utils.globals import _f
33
from magnet.utils.milvus import *
44
from magnet.utils.data_classes import EmbeddingPayload
5+
from magnet.utils.globals import Utils
56

67
class Embedder:
78
"""
@@ -24,7 +25,7 @@ class Embedder:
2425

2526
def __init__(self, config, create=False, initialize=False):
2627
self.config = config
27-
self.model = SentenceTransformer(self.config['MODEL'])
28+
self.model = SentenceTransformer(self.config['MODEL'], device=Utils.check_cuda())
2829
_f('info', f'loading into {self.model.device}')
2930
self.db = MilvusDB(self.config)
3031
self.db.on()

magnet/utils/globals.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,10 @@ def check_cuda(self):
154154
_f(
155155
"info", f"GPU Name - {torch.cuda.get_device_name(0)}"
156156
) # 0 is the GPU index
157-
return True
157+
return 'cuda'
158158
else:
159159
_f("warn", "CUDA is not available on this machine.")
160+
return 'cpu'
160161

161162
def normalize_text(self, _):
162163
"""

0 commit comments

Comments
 (0)