forked from anarchy-ai/LLM-VM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathweb_playground
35 lines (26 loc) · 1.13 KB
/
web_playground
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Import the required libraries
from flask import Flask, render_template, request
from transformers import GPT2LMHeadModel, GPT2Tokenizer
app = Flask(__name__)

# Load the pre-trained GPT-2 model and its tokenizer once, at import time, so
# that every request handled by this process reuses the same objects.
# NOTE: the weights must be loaded with from_pretrained(); calling
# GPT2LMHeadModel("gpt2") directly hands a string to a constructor that
# expects a GPT2Config and does not load any pretrained weights.
model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Define the routes
@app.route("/")
def index():
    """Serve the landing page by rendering the ``index.html`` template."""
    page = render_template("index.html")
    return page
@app.route("/generate", methods=["POST"])
def generate():
    """Generate text from the submitted prompt and render it on the page.

    Reads the ``prompt`` and ``max_length`` fields from the posted form,
    runs the module-level GPT-2 model on the tokenized prompt, and renders
    ``index.html`` with the prompt and the decoded output.

    Returns:
        The rendered ``index.html`` page with ``prompt`` and
        ``generated_text`` in its context, or a ``(message, 400)`` response
        when ``max_length`` is not a positive integer.
    """
    # The route is registered for POST only, so no request.method check is
    # needed here (and there is no code path that falls through to None).
    prompt = request.form["prompt"]

    # Reject malformed lengths with a 400 instead of letting int() raise an
    # unhandled ValueError.
    try:
        max_length = int(request.form["max_length"])
    except ValueError:
        return "max_length must be a positive integer", 400
    if max_length <= 0:
        return "max_length must be a positive integer", 400

    # Tokenize the prompt into a PyTorch tensor of token ids.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # Generate a single sequence; max_length is passed straight through to
    # model.generate().
    output = model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
    )

    # Decode the generated token ids back into text. This must be
    # tokenizer.decode(); calling the tokenizer object itself would try to
    # encode the tensor rather than decode it.
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

    return render_template(
        "index.html",
        prompt=prompt,
        generated_text=generated_text,
    )
if __name__ == "__main__":
    # Start Flask's built-in development server when run as a script.
    # NOTE(review): debug=True turns on the interactive debugger and the
    # reloader; it is meant for local development only.
    app.run(debug=True)