data_types.py

from typing import Dict, Tuple, List, Optional, Union, TypedDict, TypeVar

import torch

from constants import TokenType

T = TypeVar("T")  # Generic Type
OPT = Tuple[Union[str, TokenType], Union[str, int], Optional[List['OPT']]]  # Type, token, children
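# For illustration only: an OPT for "1 + 2" might look like the tuple below.
# The real type tags and token encodings come from constants.TokenType, so the
# "OP"/"NUM" strings here are placeholder assumptions, not the actual values:
#   ("OP", "+", [("NUM", "1", None), ("NUM", "2", None)])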


class Formula(TypedDict):
    opt: OPT
    tex: str


class Article(TypedDict):
    text: str
    formulas: Dict[str, Formula]
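
# A purely illustrative Article instance; the placeholder convention used in
# `text` to reference entries in `formulas` is an assumption, not the source's:
#   {"text": "Consider [0].", "formulas": {"0": {"opt": ..., "tex": "x^2"}}}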


class GenTaskSample(TypedDict):
    prompt: Article
    label: Article


class AnswerScoringSample(TypedDict):
    answer: Article
    problem_id: int
    problem_log_id: int
    grade: int


class FeedbackTaskSample(TypedDict):
    problem_id: str
    problem_code: str
    answer: Article
    feedback: Article


class ProblemSolvingTaskSample(TypedDict):
    problem: Article
    steps: Article
    answer: Article
    level: Optional[str]


class CTStep(TypedDict):
    step: Article
    input: Article
    feedback: Article
    action: str
    outcome: str


class CTTaskSample(TypedDict):
    student_id: str
    problem_id: str
    problem: Article
    steps: List[CTStep]


# total=False: each key applies to only one task type, and Sequence initializes
# meta to an empty dict, so all keys must be optional rather than required.
class SequenceMetaData(TypedDict, total=False):
    # For gen tasks
    prompt_length: Optional[int]
    # For classification
    label: Optional[int]
    # For answer scoring
    problem_id: Optional[int]
    problem_log_id: Optional[int]
    # For problem solving
    level: Optional[str]


class Sequence:
    def __init__(self, name: str):
        self.name = name
        self.token_ids: List[int] = []
        self.token_types: List[TokenType] = []
        self.pos_vecs: List[List[int]] = []
        self.pos_levels: List[int] = []
        self.gpt_tokens: List[List[int]] = []
        self.meta: SequenceMetaData = {}

    def split_at(self, split_point: int):
        """Split into two Sequences at the given token index.

        The per-token fields are sliced in parallel, but meta is not carried
        over; both halves start with empty metadata.
        """
        pre_split = Sequence(self.name)
        pre_split.token_ids = self.token_ids[:split_point]
        pre_split.token_types = self.token_types[:split_point]
        pre_split.pos_vecs = self.pos_vecs[:split_point]
        pre_split.pos_levels = self.pos_levels[:split_point]
        pre_split.gpt_tokens = self.gpt_tokens[:split_point]
        post_split = Sequence(self.name)
        post_split.token_ids = self.token_ids[split_point:]
        post_split.token_types = self.token_types[split_point:]
        post_split.pos_vecs = self.pos_vecs[split_point:]
        post_split.pos_levels = self.pos_levels[split_point:]
        post_split.gpt_tokens = self.gpt_tokens[split_point:]
        return pre_split, post_split

    def __add__(self, seq_2: 'Sequence'):
        """Concatenate two Sequences; keeps the left operand's name and, like
        split_at, does not merge meta."""
        new_seq = Sequence(self.name)
        new_seq.token_ids = self.token_ids + seq_2.token_ids
        new_seq.token_types = self.token_types + seq_2.token_types
        new_seq.pos_vecs = self.pos_vecs + seq_2.pos_vecs
        new_seq.pos_levels = self.pos_levels + seq_2.pos_levels
        new_seq.gpt_tokens = self.gpt_tokens + seq_2.gpt_tokens
        return new_seq

    def __len__(self):
        return len(self.token_ids)


class CollatedBatch(TypedDict):
    sources: List[str]
    token_ids: torch.Tensor
    token_types: torch.Tensor
    pos_vecs: torch.Tensor
    pos_levels: torch.Tensor
    pos_encodings: torch.Tensor
    gpt_tokens: Optional[torch.Tensor]
    use_shared_emb: Optional[torch.Tensor]
    attention_mask: torch.Tensor
    sequence_lengths: torch.Tensor
    prompt_lengths: Optional[torch.Tensor]
    gen_labels: Optional[torch.Tensor]
    cls_labels: Optional[torch.Tensor]
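

# A minimal smoke test (not part of the source module). It sketches the
# intended round-trip behavior of Sequence.split_at and Sequence.__add__
# using plain integer token ids; the values are made up for the demo, and
# the parallel per-token fields (token_types, pos_vecs, ...) are left empty.
if __name__ == "__main__":
    seq = Sequence("demo")
    seq.token_ids = [10, 11, 12, 13]

    pre, post = seq.split_at(2)
    assert pre.token_ids == [10, 11]
    assert post.token_ids == [12, 13]

    # Re-joining the halves reproduces the original sequence's tokens.
    assert (pre + post).token_ids == seq.token_ids
    assert len(pre + post) == len(seq)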