-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
74 lines (59 loc) · 1.93 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from pathlib import Path
import time
from typing import NamedTuple
from transformers import (
ByT5Tokenizer,
T5ForConditionalGeneration,
PreTrainedTokenizerFast,
)
# Locate the model to test: prefer weights saved directly in the training
# output directory, otherwise the newest Trainer checkpoint, and as a last
# resort whatever entry the directory contains.
path = Path("./text-normalization-ru-terrible")
# Weight file names written by `save_pretrained` (safetensors is the default
# in recent transformers releases, .bin in older ones).
_WEIGHT_FILES = ("pytorch_model.bin", "model.safetensors")


def _checkpoint_step(checkpoint: Path) -> int:
    """Return the step number of a ``checkpoint-<step>`` path, or -1 if it has none."""
    _, _, step = checkpoint.name.rpartition("-")
    return int(step) if step.isdecimal() else -1


if not any((path / name).exists() for name in _WEIGHT_FILES):
    try:
        # Compare by numeric step: a plain max() over paths is lexicographic,
        # which would rank "checkpoint-900" above "checkpoint-1000".
        path = max(path.glob("checkpoint-*"), key=_checkpoint_step)
    except ValueError:  # max() of an empty sequence: no checkpoint-* entries
        fallback = next(path.glob("*"), None)
        if fallback is None:
            raise FileNotFoundError(
                f"No model weights or checkpoints found in {path}"
            ) from None
        path = fallback

try:
    tokenizer = PreTrainedTokenizerFast.from_pretrained(path)
except ValueError:
    # NOTE(review): presumably raised when no fast-tokenizer files sit next to
    # the weights; the model is then assumed to be byte-level (ByT5) — confirm.
    tokenizer = ByT5Tokenizer()

load_start = time.perf_counter()
model = T5ForConditionalGeneration.from_pretrained(path)
print(f"[{time.perf_counter() - load_start:.3f}s.] Loaded {path}")
# gpu = torch.device("cuda")
# model.to(gpu, torch.float32)
class Generation(NamedTuple):
    """Outcome of a single ``generate`` call: the decoded text and stage timings."""

    # Text decoded from the model's output ids.
    out: str
    # Seconds spent in the tokenizer (encode + decode).
    time_tokenizer: float
    # Seconds spent inside the model's generation call.
    time_model: float
def generate(text, *, max_new_tokens=128, skip_special_tokens=True):
    """Run ``text`` through the module-level ``tokenizer`` and ``model``, timing each stage.

    Args:
        text: Input string to normalize.
        max_new_tokens: Upper bound on the number of tokens the model may generate.
        skip_special_tokens: Strip special-token markers (e.g. ``<pad>``, ``</s>``)
            from the decoded text so only the normalized string is returned.

    Returns:
        A ``Generation`` holding the decoded text of the first output sequence,
        the seconds spent in the tokenizer (encode plus decode), and the seconds
        spent in ``model.generate``.
    """
    start = time.perf_counter()
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    encoded_at = time.perf_counter()

    output_ids = model.generate(input_ids, max_new_tokens=max_new_tokens)
    generated_at = time.perf_counter()

    # Only one input is encoded, so the batch holds a single sequence.
    out = tokenizer.batch_decode(
        output_ids, skip_special_tokens=skip_special_tokens
    )[0]
    decoded_at = time.perf_counter()

    # Tokenizer time accounts for both the encode and the decode step.
    time_tokenizer = (encoded_at - start) + (decoded_at - generated_at)
    time_model = generated_at - encoded_at
    return Generation(out, time_tokenizer, time_model)
def _report(result):
    """Print a Generation as ``[tok: …s. model: …s.] <text>``."""
    print(
        f"[tok: {result.time_tokenizer:.3f}s. model: {result.time_model:.3f}s.] {result.out}"
    )


# Sample run on a fixed sentence, then an interactive prompt.
out = generate("РСФСР, 24 июля, Aternos.")
_report(out)
while True:
    try:
        inp = input("> ")
    except (EOFError, KeyboardInterrupt):
        # Ctrl-D / Ctrl-C: leave the prompt cleanly instead of with a traceback.
        print()
        break
    if not inp.strip():
        # Nothing to normalize; don't spend a model call on a blank line.
        continue
    out = generate(inp)
    _report(out)