Skip to content

Commit

Permalink
[REFACTOR] Terminology download=>download_cache (#2425)
Browse files Browse the repository at this point in the history
This PR renames download to download_cache for better clarity.
  • Loading branch information
tqchen authored May 26, 2024
1 parent 8b38a4b commit ff91749
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 20 deletions.
2 changes: 1 addition & 1 deletion python/mlc_llm/chat_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def _get_model_path(model: str) -> Tuple[str, str]:
FileNotFoundError: if we cannot find a valid `model_path`.
"""
if model.startswith("HF://"):
from mlc_llm.support.download import ( # pylint: disable=import-outside-toplevel
from mlc_llm.support.download_cache import ( # pylint: disable=import-outside-toplevel
download_and_cache_mlc_weights,
)

Expand Down
2 changes: 1 addition & 1 deletion python/mlc_llm/cli/delivery.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from mlc_llm.support import logging
from mlc_llm.support.argparse import ArgumentParser
from mlc_llm.support.constants import MLC_TEMP_DIR
from mlc_llm.support.download import git_clone
from mlc_llm.support.download_cache import git_clone
from mlc_llm.support.style import bold, green, red

logging.enable_logging()
Expand Down
4 changes: 2 additions & 2 deletions python/mlc_llm/interface/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any, Dict, List, Literal

from mlc_llm.interface import jit
from mlc_llm.support import download, logging, style
from mlc_llm.support import download_cache, logging, style

logging.enable_logging()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -70,7 +70,7 @@ def build_model_library( # pylint: disable=too-many-branches,too-many-locals,to
raise ValueError('The value of "model_lib" in "model_list" is expected to be string.')

# - Load model config. Download happens when needed.
model_path = download.get_or_download_model(model)
model_path = download_cache.get_or_download_model(model)

# - Jit compile if the model lib path is not specified.
model_lib_path = (
Expand Down
4 changes: 2 additions & 2 deletions python/mlc_llm/serve/engine_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from mlc_llm.serve.config import EngineConfig, GenerationConfig
from mlc_llm.serve.event_trace_recorder import EventTraceRecorder
from mlc_llm.streamer import TextStreamer
from mlc_llm.support import download, logging
from mlc_llm.support import download_cache, logging
from mlc_llm.support.auto_device import detect_device
from mlc_llm.support.style import green
from mlc_llm.tokenizer import Tokenizer
Expand Down Expand Up @@ -120,7 +120,7 @@ def _process_model_args(
def _convert_model_info(model: ModelInfo) -> Tuple[str, str]:
nonlocal conversation

model_path = download.get_or_download_model(model.model)
model_path = download_cache.get_or_download_model(model.model)
mlc_config_path = model_path / "mlc-chat-config.json"
config_file_paths.append(str(mlc_config_path))

Expand Down
2 changes: 1 addition & 1 deletion python/mlc_llm/support/auto_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def detect_mlc_chat_config(mlc_chat_config: str) -> Path:
# pylint: disable=import-outside-toplevel
from mlc_llm.model import MODEL_PRESETS

from .download import download_and_cache_mlc_weights
from .download_cache import download_and_cache_mlc_weights

# pylint: enable=import-outside-toplevel

Expand Down
6 changes: 3 additions & 3 deletions python/mlc_llm/support/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ def _check():
f"but got {MLC_JIT_POLICY}."
)

if MLC_DOWNLOAD_POLICY not in ["ON", "OFF", "REDO", "READONLY"]:
if MLC_DOWNLOAD_CACHE_POLICY not in ["ON", "OFF", "REDO", "READONLY"]:
raise ValueError(
"Invalid MLC_AUTO_DOWNLOAD_POLICY. "
'It has to be one of "ON", "OFF", "REDO", "READONLY"'
f"but got {MLC_DOWNLOAD_POLICY}."
f"but got {MLC_DOWNLOAD_CACHE_POLICY}."
)


Expand Down Expand Up @@ -80,7 +80,7 @@ def _get_read_only_weight_caches() -> List[Path]:
MLC_DSO_SUFFIX = _get_dso_suffix()
MLC_TEST_MODEL_PATH: List[Path] = _get_test_model_path()

MLC_DOWNLOAD_POLICY = os.environ.get("MLC_DOWNLOAD_POLICY", "ON")
MLC_DOWNLOAD_CACHE_POLICY = os.environ.get("MLC_DOWNLOAD_CACHE_POLICY", "ON")
MLC_LLM_HOME: Path = _get_cache_dir()
MLC_LLM_READONLY_WEIGHT_CACHE = _get_read_only_weight_caches()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from . import logging, tqdm
from .constants import (
MLC_DOWNLOAD_POLICY,
MLC_DOWNLOAD_CACHE_POLICY,
MLC_LLM_HOME,
MLC_LLM_READONLY_WEIGHT_CACHE,
MLC_TEMP_DIR,
Expand All @@ -24,12 +24,12 @@
logger = logging.getLogger(__name__)


def log_download_policy():
def log_download_cache_policy():
"""log current download policy"""
logger.info(
"%s = %s. Can be one of: ON, OFF, REDO, READONLY",
bold("MLC_DOWNLOAD_POLICY"),
MLC_DOWNLOAD_POLICY,
bold("MLC_DOWNLOAD_CACHE_POLICY"),
MLC_DOWNLOAD_CACHE_POLICY,
)


Expand Down Expand Up @@ -130,9 +130,9 @@ def download_and_cache_mlc_weights( # pylint: disable=too-many-locals
force_redo: Optional[bool] = None,
) -> Path:
"""Download weights for a model from the HuggingFace Git LFS repo."""
log_download_policy()
if MLC_DOWNLOAD_POLICY == "OFF":
raise RuntimeError(f"Cannot download {model_url} as MLC_DOWNLOAD_POLICY=OFF")
log_download_cache_policy()
if MLC_DOWNLOAD_CACHE_POLICY == "OFF":
raise RuntimeError(f"Cannot download {model_url} as MLC_DOWNLOAD_CACHE_POLICY=OFF")

prefixes, mlc_prefix = ["HF://", "https://huggingface.co/"], ""
mlc_prefix = next(p for p in prefixes if model_url.startswith(p))
Expand All @@ -155,7 +155,7 @@ def download_and_cache_mlc_weights( # pylint: disable=too-many-locals
return cache_dir

if force_redo is None:
force_redo = MLC_DOWNLOAD_POLICY == "REDO"
force_redo = MLC_DOWNLOAD_CACHE_POLICY == "REDO"

git_dir = MLC_LLM_HOME / "model_weights" / domain / user / repo
readonly_cache_dirs.append(str(git_dir))
Expand All @@ -166,10 +166,10 @@ def download_and_cache_mlc_weights( # pylint: disable=too-many-locals
logger.info("Weights already downloaded: %s", bold(str(git_dir)))
return git_dir

if MLC_DOWNLOAD_POLICY == "READONLY":
if MLC_DOWNLOAD_CACHE_POLICY == "READONLY":
raise RuntimeError(
f"Cannot find cache for {model_url}, "
"cannot proceed to download as MLC_DOWNLOAD_POLICY=READONLY, "
"cannot proceed to download as MLC_DOWNLOAD_CACHE_POLICY=READONLY, "
"please check settings MLC_LLM_READONLY_WEIGHT_CACHE, "
f"local path candidates: {readonly_cache_dirs}"
)
Expand Down

0 comments on commit ff91749

Please sign in to comment.