Skip to content

Commit

Permalink
Cleanup download of huggingface
Browse files Browse the repository at this point in the history
Also change get_symlink_path to symlink_path

Signed-off-by: Daniel J Walsh <[email protected]>
  • Loading branch information
rhatdan committed Oct 8, 2024
1 parent 8f1bf21 commit e4d0910
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 53 deletions.
114 changes: 66 additions & 48 deletions ramalama/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,17 @@
"""


def download(store, model, directory, filename):
return run_cmd(
[
"huggingface-cli",
"download",
directory,
filename,
"--cache-dir",
store + "/repos/huggingface/.cache",
"--local-dir",
store + "/repos/huggingface/" + directory,
]
)


def try_download(store, model, directory, filename):
try:
proc = download(store, model, directory, filename)
return proc.stdout.decode("utf-8")
except FileNotFoundError as e:
raise NotImplementedError(
"""\
%s
%s"""
% (str(e).strip("'"), missing_huggingface)
)


class Huggingface(Model):
def __init__(self, model):
super().__init__(model.removeprefix("huggingface://"))
self.type = "HuggingFace"
split = self.model.rsplit("/", 1)
self.directory = ""
if len(split) > 1:
self.directory = split[0]
self.filename = split[1]
else:
self.filename = split[0]

def login(self, args):
conman_args = ["huggingface-cli", "login"]
Expand All @@ -64,20 +43,19 @@ def logout(self, args):
conman_args.extend(args)
self.exec(conman_args)


def path(self, args):
return self.symlink_path(args)

def pull(self, args):
split = self.model.rsplit("/", 1)
directory = ""
if len(split) > 1:
directory = split[0]
filename = split[1]
else:
filename = split[0]
relative_target_path=""
symlink_path = self.symlink_path(args)

gguf_path = try_download(args.store, self.model, directory, filename)
directory = f"{args.store}/models/huggingface/{directory}"
os.makedirs(directory, exist_ok=True)
symlink_path = f"{directory}/{filename}"
gguf_path = self.download(args.store)
relative_target_path = os.path.relpath(gguf_path.rstrip(), start=os.path.dirname(symlink_path))
directory = f"{args.store}/models/huggingface/{self.directory}"
os.makedirs(directory, exist_ok=True)

if os.path.exists(symlink_path) and os.readlink(symlink_path) == relative_target_path:
# Symlink is already correct, no need to update it
return symlink_path
Expand All @@ -86,16 +64,33 @@ def pull(self, args):

return symlink_path

def get_symlink_path(self, args):
split = self.model.rsplit("/", 1)
directory = ""
if len(split) > 1:
directory = split[0]
filename = split[1]
else:
filename = split[0]
def push(self, source, args):
try:
proc = run_cmd(
[
"huggingface-cli",
"upload",
"--repo-type",
"model",
self.directory,
self.filename,
"--cache-dir",
store + "/repos/huggingface/.cache",
"--local-dir",
store + "/repos/huggingface/" + self.directory,
]
)
return proc.stdout.decode("utf-8")
except FileNotFoundError as e:
raise NotImplementedError(
"""\
%s
%s"""
% (str(e).strip("'"), missing_huggingface)
)

return f"{args.store}/models/huggingface/{directory}/{filename}"
def symlink_path(self, args):
return f"{args.store}/models/huggingface/{self.directory}/{self.filename}"

def exec(self, args):
try:
Expand All @@ -110,3 +105,26 @@ def exec(self, args):
% str(e).strip("'"),
missing_huggingface,
)

def download(self, store):
try:
proc = run_cmd(
[
"huggingface-cli",
"download",
self.directory,
self.filename,
"--cache-dir",
store + "/repos/huggingface/.cache",
"--local-dir",
store + "/repos/huggingface/" + self.directory,
]
)
return proc.stdout.decode("utf-8")
except FileNotFoundError as e:
raise NotImplementedError(
"""\
%s
%s"""
% (str(e).strip("'"), missing_huggingface)
)
6 changes: 3 additions & 3 deletions ramalama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def garbage_collection(self, args):
print(f"Deleted: {file_path}")

def remove(self, args):
symlink_path = self.get_symlink_path(args)
symlink_path = self.symlink_path(args)
if os.path.exists(symlink_path):
try:
os.remove(symlink_path)
Expand All @@ -73,8 +73,8 @@ def remove(self, args):

self.garbage_collection(args)

def get_symlink_path(self, args):
raise NotImplementedError(f"get_symlink_path for {self.type} not implemented")
def symlink_path(self, args):
raise NotImplementedError(f"symlink_path for {self.type} not implemented")

def run(self, args):
prompt = "You are a helpful assistant"
Expand Down
2 changes: 1 addition & 1 deletion ramalama/oci.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def pull(self, args):

return symlink_path

def get_symlink_path(self, args):
def symlink_path(self, args):
registry, reference = self.model.split("/", 1)
reference_dir = reference.replace(":", "/")
path = f"{args.store}/models/oci/{registry}/{reference_dir}"
Expand Down
2 changes: 1 addition & 1 deletion ramalama/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def pull(self, args):
repos, manifests, accept, registry_head, model_name, model_tag, models, symlink_path, self.model
)

def get_symlink_path(self, args):
def symlink_path(self, args):
models = args.store + "/models/ollama"
if "/" in self.model:
model_full = self.model
Expand Down

0 comments on commit e4d0910

Please sign in to comment.