-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdump_RL.py
40 lines (27 loc) · 1.28 KB
/
dump_RL.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# Original code from voidful/TextRL/textrl/dump.py
import argparse
import sys
from transformers import AutoTokenizer, AutoModelWithLMHead
from textrl import TextRLEnv, TextRLActor
def parse_dump_args(args=None):
    """Parse the command-line options of the RL model dump script.

    Args:
        args: Sequence of argument strings to parse (without the program
            name). When ``None``, argparse falls back to ``sys.argv[1:]``.

    Returns:
        dict: The parsed options keyed by ``model``, ``tokenizer``, ``rl``
        and ``dumpdir``; every value is a string.

    Raises:
        SystemExit: Raised by argparse when a required option is missing
            or an unknown option is supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, type=str, help="model before rl training")
    parser.add_argument("--tokenizer", required=True, type=str, help="tokenizer before rl training")
    parser.add_argument("--rl", required=True, type=str, help="rl model dir")
    parser.add_argument("--dumpdir", required=True, type=str, help="output path")
    # vars() turns the argparse Namespace into a plain dict for the caller.
    return vars(parser.parse_args(args))
def main(arg=None):
    """Merge an RL-trained checkpoint into the base model and save it.

    Args:
        arg: Optional list of command-line style argument strings. When
            ``None``, the arguments are taken from ``sys.argv[1:]``.

    The function loads the base model and tokenizer, builds a TextRL
    environment and actor around them, loads the RL checkpoint into the PPO
    agent, replaces the model's ``lm_head`` with the actor's ``converter``,
    and writes model and tokenizer to the ``--dumpdir`` directory.
    """
    argv = sys.argv[1:] if arg is None else arg
    options = parse_dump_args(argv)

    model = AutoModelWithLMHead.from_pretrained(options.get('model'))
    tokenizer = AutoTokenizer.from_pretrained(options.get('tokenizer'))

    # The environment only exists so the actor/agent can be constructed;
    # a single placeholder observation is passed in.
    env = TextRLEnv(model, tokenizer, observation_input=[{'input':'dummy'}])
    actor = TextRLActor(env, model, tokenizer)
    agent = actor.agent_ppo()
    agent.load(options.get('rl'))

    # NOTE(review): presumably agent.load() restores the trained weights into
    # actor.converter — confirm against the TextRL implementation.
    model.lm_head = actor.converter

    dump_dir = options.get('dumpdir')
    model.save_pretrained(dump_dir)
    tokenizer.save_pretrained(dump_dir)
    print('==================')
    print("Finish model dump.")
if __name__ == "__main__":
    # Run the dump only when executed as a script, not when imported.
    main()