6
6
import json
7
7
import os
8
8
import tempfile
9
+ import time
9
10
from collections import defaultdict
10
11
from typing import Any , Callable , Dict , Generator , List , Optional , Tuple , Union
11
12
14
15
import huggingface_hub .constants
15
16
import numpy as np
16
17
import torch
17
- from huggingface_hub import HfFileSystem , hf_hub_download , snapshot_download
18
+ from huggingface_hub import (HfFileSystem , hf_hub_download , scan_cache_dir ,
19
+ snapshot_download )
18
20
from safetensors .torch import load_file , safe_open , save_file
19
21
from tqdm .auto import tqdm
20
22
@@ -253,6 +255,8 @@ def download_weights_from_hf(
253
255
# Use file lock to prevent multiple processes from
254
256
# downloading the same model weights at the same time.
255
257
with get_lock (model_name_or_path , cache_dir ):
258
+ start_size = scan_cache_dir ().size_on_disk
259
+ start_time = time .perf_counter ()
256
260
hf_folder = snapshot_download (
257
261
model_name_or_path ,
258
262
allow_patterns = allow_patterns ,
@@ -262,6 +266,11 @@ def download_weights_from_hf(
262
266
revision = revision ,
263
267
local_files_only = huggingface_hub .constants .HF_HUB_OFFLINE ,
264
268
)
269
+ end_time = time .perf_counter ()
270
+ end_size = scan_cache_dir ().size_on_disk
271
+ if end_size != start_size :
272
+ logger .info ("Time took to download weights for %s: %.6f seconds" ,
273
+ model_name_or_path , end_time - start_time )
265
274
return hf_folder
266
275
267
276
0 commit comments