main.py
"""
main module for llm service
This script sets up a Flask-based API that loads a large language
model (LLM) using Hugging Face's transformers library. It includes
functions to load a pre-trained model and tokenizer from disk,
handle requests, and generate responses from the model.
"""
import os
import sys
import logging

import torch
from flask import Flask, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import logging as hf_logging
# Configure Hugging Face transformers logging
hf_logging.set_verbosity(hf_logging.CRITICAL)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Suppress tokenizer logs from tokenization_utils_base
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)

# Get Hugging Face token
token = os.getenv('LLAMA_TOKEN')
if not token:
    logger.error("Hugging Face token (LLAMA_TOKEN) is not set in the environment.")
    sys.exit(1)

# Define model name and paths
model_name = "google/gemma-2-2b"
model_dir = "/app/model"
model_path = "/app/model/models--google--gemma-2-2b"
offload_folder = "/app/offload"

# Ensure model and offload directories exist
os.makedirs(model_dir, exist_ok=True)
os.makedirs(offload_folder, exist_ok=True)

# Initialize Flask app
app = Flask(__name__)

def get_model(model_path):
    """
    Load the model and tokenizer from the local directory if available.
    If not available, download them from Hugging Face.

    :param model_path: Path to the directory containing model files
    :type model_path: str
    :return: Loaded model and tokenizer, or (None, None) on failure
    :rtype: tuple
    """
    try:
        logger.info(f"Looking for model in: {model_path}")
        # Check if model and tokenizer files exist
        if os.path.exists(os.path.join(model_path, 'model.safetensors.index.json')) and \
                os.path.exists(os.path.join(model_path, 'tokenizer_config.json')):
            logger.info("Loading model and tokenizer from cached safetensors files.")
            # Load tokenizer from local path
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            # Load model with device map for inference
            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="auto",             # Automatically map model to devices
                offload_folder=offload_folder  # Offload to disk if needed
            )
            logger.info("Model loaded with device map and offloading.")
        else:
            # If the model is not found locally, download it from Hugging Face
            logger.info("Model not found locally, downloading from Hugging Face...")
            tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                use_auth_token=token,
                device_map="auto",
                offload_folder=offload_folder
            )
            # Save the model and tokenizer locally for future use
            tokenizer.save_pretrained(model_path)
            model.save_pretrained(model_path)
            logger.info(f"Model and tokenizer downloaded and saved to {model_path}")
        return model, tokenizer
    except Exception:
        logger.exception("Error in get_model() function")
        return None, None

# Load model and tokenizer
model, tokenizer = get_model(model_path)

# Ensure the model is loaded before starting the server
if model is None or tokenizer is None:
    logger.error("Failed to load model or tokenizer. Exiting application.")
    sys.exit(1)

@app.route('/test/', methods=['GET'])
def test_llm():
    """
    Test endpoint that sends a test query to the LLM and returns its response.

    :return: JSON response from the model
    :rtype: flask.Response
    """
    test_message = "Hi, are you awake?"
    try:
        # Prepare inputs
        inputs = tokenizer(test_message, return_tensors='pt')
        # Get the device of the model's parameters
        input_device = next(model.parameters()).device
        # Move inputs to the correct device
        inputs = {key: value.to(input_device) for key, value in inputs.items()}
        logger.info("Generating response from LLM...")
        # Generate output
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=10)
        response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        return jsonify({'response': response_text})
    except Exception:
        logger.exception("Error during generation")
        return jsonify({'error': 'An internal error occurred'}), 500
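
# Example interaction with the test endpoint (a sketch, assuming the app is
# served at Flask's default address of http://127.0.0.1:5000; the generated
# text will vary with the model and decoding settings):
#
#   $ curl http://127.0.0.1:5000/test/
#   {"response": "Hi, are you awake? ..."}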