#%%
# Load the CodeGen2 tokenizer and model; trust_remote_code is required because
# the checkpoint ships its own modeling code.
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen2-1B")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen2-1B", trust_remote_code=True, revision="main")
#%%
# Move the model to the GPU for generation.
model = model.to(device="cuda")
#%%
# Causal (left-to-right) completion: prompt with the opening of a function and
# let the model generate the rest.
text = """
import os
def post_to_pastebin"""
input_ids = tokenizer(text, return_tensors="pt").to("cuda").input_ids
generated_ids = model.generate(input_ids, max_length=512)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
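# %%
# Not in the original file: a sketch of sampling-based decoding for comparison
# with the greedy default above. do_sample, temperature, and top_p are standard
# transformers generate() arguments; the values here are arbitrary choices.
generated_ids = model.generate(
    input_ids,
    max_length=512,
    do_sample=True,   # sample from the distribution instead of greedy argmax
    temperature=0.2,  # low temperature keeps code completions conservative
    top_p=0.95,       # nucleus-sampling cutoff
)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))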
# %%
# Infill sampling: CodeGen2 fills the <mask_1> span between a prefix and a
# suffix. The expected input layout is
#   prefix <mask_1> suffix <|endoftext|> <sep> <mask_1>
def format_model_input(prefix, suffix):
    return prefix + "<mask_1>" + suffix + "<|endoftext|>" + "<sep>" + "<mask_1>"

prefix = """
import os
def post_to_pastebin"""
suffix = "result = post_to_pastebin(content)"
text = format_model_input(prefix, suffix)
input_ids = tokenizer(text, return_tensors="pt").to("cuda").input_ids
generated_ids = model.generate(input_ids, max_length=128)
# Keep special tokens visible so the markers around the infill are inspectable.
print(tokenizer.decode(generated_ids[0], skip_special_tokens=False))
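# %%
# Not in the original file: a sketch of post-processing for the infill output.
# The Salesforce/codegen2-1B model card suggests truncating the completion at
# the <eom> token; the generated span is whatever follows the final <sep>.
infill = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
infill = infill.split("<sep>")[-1]       # keep only the generated part
infill = infill.split("<eom>")[0]        # drop everything after end-of-mask
infill = infill.replace("<mask_1>", "")  # strip the leading mask sentinel
print(prefix + infill + suffix)          # reassemble the completed snippet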
# %%
# Same completion task with a longer, script-shaped prompt: the string mimics
# the tail of a script (a final print, a __main__ guard, then a fresh cell)
# before the function to complete.
def main():
    text = """
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

if __name__ == '__main__':
    main()
    print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

# %%
import os
def post_to_pastebin"""
    input_ids = tokenizer(text, return_tensors="pt").to("cuda").input_ids
    generated_ids = model.generate(input_ids, max_length=512)
    print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
# %%
# Standalone helper mirroring the pipeline above: generate a completion for
# arbitrary content and return the decoded text.
def post_to_pastebin(content):
    input_ids = tokenizer(content, return_tensors="pt").to("cuda").input_ids
    generated_ids = model.generate(input_ids, max_length=512)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
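# %%
# Not in the original file: a smoke test for the helper above, mirroring the
# call shape used as the suffix in the infill cell. The prompt is arbitrary.
content = """
import os
def post_to_pastebin"""
result = post_to_pastebin(content)
print(result)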