Commit e1b1b63

Committed on Dec 5, 2021
Merge branch 'encoder' into main
2 parents 4abd53f + 97fe774 commit e1b1b63

7 files changed, with 288 additions and 56 deletions

.gitignore

+6
@@ -127,3 +127,9 @@ venv.bak/
 # seed project
 lightning_logs/
 .DS_Store
+
+
+# WandB stuff
+ParaPhrasegen/
+wandb/
+*.ckpt

eval.py

+86
@@ -0,0 +1,86 @@
+from typing import List
+
+import torch
+
+from transformers import AutoTokenizer
+
+
+from paraphrasegen.constants import PATH_BASE_MODELS
+from paraphrasegen.model import Encoder
+from paraphrasegen.loss import Similarity
+
+
+device = (
+    "cpu"  # torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+)
+
+
+def tokenize_text(model_name, sentences: List[str]):
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name, use_fast=True, cache_dir=PATH_BASE_MODELS
+    )
+
+    tokenized = tokenizer(
+        sentences,
+        max_length=32,
+        padding="max_length",
+        truncation=True,
+        return_tensors="pt",
+    )
+    return tokenized
+
+
+def eval(encoder):
+    anchor = "A Washington County man may have the countys first human case of West Nile virus , the health department said Friday ."
+    target = "A Hyderabadi man may have the citys first human case of West Nile virus , the health ministry said Friday ."
+    # target = "The countys first and only human case of West Nile this year was confirmed by health officials on Sept . 8 ."
+    negative = "What the fuck is the County Virus"
+
+    print("Tokenizing Text... ", sep="")
+    tokenized = tokenize_text(
+        encoder.hparams.model_name_or_path, [anchor, target, negative]
+    )
+
+    print("Tokenized!")
+
+    print("Generating Embeddings... ", sep="")
+    embeddings = encoder(
+        tokenized["input_ids"],
+        tokenized["attention_mask"],
+        do_mlm=False,
+    )
+
+    anchor_embedddings = embeddings[0, ...]
+    target_embedddings = embeddings[1, ...]
+    negative_embeddings = embeddings[2, ...]
+
+    print("Generated!")
+
+    # print(f"|Anchor|: {torch.norm(anchor_embedddings)}")
+    diff = target_embedddings - anchor_embedddings
+    print(
+        f"|target_embedddings - anchor_embedddings|: {torch.norm(diff)}, %age: {100 * torch.mean(diff / anchor_embedddings)}"
+    )
+
+    diff = negative_embeddings - anchor_embedddings
+    print(
+        f"|negative_embeddings - anchor_embedddings|: {torch.norm(diff)}, %age: {100 * torch.mean(diff / anchor_embedddings)}"
+    )
+
+    sim = Similarity(temp=1)
+    print(
+        f"Similarity between anchor and target: {sim(anchor_embedddings, target_embedddings)}"
+    )
+
+    print(
+        f"Similarity between anchor and negative: {sim(anchor_embedddings, negative_embeddings)}"
+    )
+
+
+if __name__ == "__main__":
+    path_to_checkpoint = "runs/default/version_7/checkpoints/last.ckpt"  # input(">>> Enter Model Checkpoint Path: ")
+    print("Loading Model... ", sep="")
+    encoder = Encoder.load_from_checkpoint(path_to_checkpoint)
+    print("Finished")
+
+    eval(encoder)
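Note: eval.py (and model.py below) imports Similarity from paraphrasegen/loss.py, which this merge does not touch. Purely as a point of reference, a minimal sketch of a compatible module, assuming the usual SimCSE-style convention of cosine similarity scaled by 1/temperature, could look like the following (the class body is an assumption, not code from this repository):

import torch
import torch.nn as nn


class Similarity(nn.Module):
    """Cosine similarity divided by a temperature (assumed interface).

    With temp=1, as used in eval.py, this reduces to plain cosine similarity;
    model.py constructs it with temp=self.hparams.temp instead.
    """

    def __init__(self, temp: float):
        super().__init__()
        self.temp = temp
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return self.cos(x, y) / self.temp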

paraphrasegen/constants.py

+1 -1
@@ -1,7 +1,7 @@
 import os
 import torch

-EPOCHS = 3
+MAX_EPOCHS = 5
 PATH_DATASETS = os.environ.get("PATH_DATASETS", "./datasets")
 PATH_BASE_MODELS = os.environ.get("PATH_BASE_MODELS", "./base_models")
 AVAIL_GPUS = min(1, torch.cuda.device_count())

paraphrasegen/dataset.py

+13 -11
@@ -65,20 +65,22 @@ def prepare_data(self) -> None:
     def setup(self, stage: Optional[str] = None) -> None:
         self.dataset = load_dataset(*self.dataset_args, cache_dir=PATH_DATASETS)

-        if self.task_name == "qqp":
-            self.dataset["train"] = self.dataset["train"].filter(
-                lambda el: el["label"] == 1
-            )
-            self.dataset["validation"] = self.dataset["validation"].filter(
-                lambda el: el["label"] == 1
-            )
-
-        else:
-            self.dataset = self.dataset.filter(lambda el: el["label"] == 1)
+        self.dataset["train"] = self.dataset["train"].filter(
+            lambda el: el["label"] == 1
+        )
+        # self.dataset["validation"] = self.dataset["validation"].filter(
+        #     lambda el: el["label"] == 1
+        # )
+
         self.dataset = self.dataset.map(
             self.convert_to_features,
             batched=True,
-            remove_columns=(["label",] + self.text_fields),
+            remove_columns=(
+                [
+                    "label",
+                ]
+                + self.text_fields
+            ),
             num_proc=NUM_WORKERS,
         )

paraphrasegen/model.py

+72 -30
@@ -1,5 +1,5 @@
 from os import stat_result
-from typing import Optional
+from typing import List, Optional
 import time

 import torch
@@ -12,7 +12,7 @@

 from transformers import AutoConfig, AutoModel, AdamW

-from paraphrasegen.loss import ContrastiveLoss
+from paraphrasegen.loss import ContrastiveLoss, Similarity
 from paraphrasegen.constants import (
     AVAIL_GPUS,
     BATCH_SIZE,
@@ -72,25 +72,48 @@ def forward(self, attention_mask, outputs):


 class MLPLayer(nn.Module):
-    def __init__(self, in_dims: int = 768, hidden_dims: int = 768):
+    def __init__(
+        self, in_dims: int = 768, hidden_dims: List[int] = 768, activation: str = "GELU"
+    ):
         super(MLPLayer, self).__init__()
-        self.fc1 = nn.Linear(in_dims, hidden_dims)
-        self.layer_norm = nn.LayerNorm(hidden_dims)
-        self.activation = nn.Tanh()
+
+        if activation == "GELU":
+            activation_fn = nn.GELU()
+        elif activation == "ReLU":
+            activation_fn = nn.ReLU()
+        elif activation == "mish":
+            activation_fn = nn.Mish()
+        elif activation == "leaky_relu":
+            activation_fn = nn.LeakyReLU()
+
+        layers = [
+            nn.Linear(in_dims, hidden_dims[0]),
+            nn.LayerNorm(hidden_dims[0]),
+            activation_fn,
+        ]
+
+        for i in range(1, len(hidden_dims)):
+            layers += [
+                nn.Linear(hidden_dims[i - 1], hidden_dims[i]),
+                nn.LayerNorm(hidden_dims[i]),
+                activation_fn,
+            ]
+
+        self.net = nn.Sequential(*layers)

     def forward(self, x: torch.Tensor):
-        out = self.fc1(x)
-        out = self.layer_norm(out)
-        return self.activation(out)
+        return self.net(x)


 class Encoder(pl.LightningModule):
     def __init__(
         self,
         model_name_or_path: str,
         input_mask_rate: float = 0.1,
-        embedding_from: str = "single",
         pooler_type: str = "cls",
+        mlp_layers: List[int] = [768],
+        temp: float = 0.05,
+        hard_negative_weight: float = 0,
         learning_rate: float = 3e-5,
         weight_decay: float = 0,
     ) -> None:
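A quick standalone check of the reworked MLPLayer, which now stacks one Linear/LayerNorm/activation block per entry of hidden_dims instead of a single Tanh-activated layer (the [768, 512] widths below are illustrative, not taken from the commit):

import torch

from paraphrasegen.model import MLPLayer  # assumes the package is importable

# Two widths -> blocks of 768->768 and 768->512, each followed by LayerNorm + GELU.
# Note: hidden_dims must be a list; Encoder passes mlp_layers (default [768]),
# while the declared default of 768 (a bare int) would fail at hidden_dims[0].
mlp = MLPLayer(in_dims=768, hidden_dims=[768, 512], activation="GELU")
x = torch.randn(4, 768)  # toy batch of pooled sentence embeddings
print(mlp(x).shape)  # torch.Size([4, 512])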
@@ -99,17 +122,19 @@ def __init__(
         self.save_hyperparameters()
         self.config = AutoConfig.from_pretrained(model_name_or_path)
         self.input_mask_rate = input_mask_rate
-        self.embedding_from = embedding_from
         self.bert_model = AutoModel.from_pretrained(
             model_name_or_path, config=self.config, cache_dir=PATH_BASE_MODELS
         )

         self.pooler_type = pooler_type
         self.pooler = Pooler(pooler_type)

-        self.net = MLPLayer()
+        self.net = MLPLayer(in_dims=768, hidden_dims=mlp_layers)

-        self.loss_fn = ContrastiveLoss()
+        self.loss_fn = ContrastiveLoss(
+            temp=self.hparams.temp,
+            hard_negative_weight=self.hparams.hard_negative_weight,
+        )

     def forward(
         self, input_ids: torch.Tensor, attention_mask: torch.Tensor, do_mlm: bool = True
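ContrastiveLoss also lives in paraphrasegen/loss.py, outside this diff, but the constructor arguments above and the three-tensor call added to training_step below suggest a SimCSE-style InfoNCE loss with weighted hard negatives. A hedged sketch of such a loss, written purely as an assumption about the interface and not as the repository's implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ContrastiveLossSketch(nn.Module):
    """Assumed SimCSE-style loss: in-batch InfoNCE with optional hard negatives."""

    def __init__(self, temp: float = 0.05, hard_negative_weight: float = 0.0):
        super().__init__()
        self.temp = temp
        self.hard_negative_weight = hard_negative_weight
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, anchors, targets, negatives=None):
        # (batch, batch) matrix of anchor/target similarities, scaled by 1/temp
        sim = self.cos(anchors.unsqueeze(1), targets.unsqueeze(0)) / self.temp
        if negatives is not None:
            neg_sim = self.cos(anchors.unsqueeze(1), negatives.unsqueeze(0)) / self.temp
            sim = torch.cat([sim, neg_sim], dim=1)
            # bias the hard-negative logits, as in the SimCSE reference code
            weights = torch.zeros_like(sim)
            weights[:, anchors.size(0):] = self.hard_negative_weight
            sim = sim + weights
        labels = torch.arange(anchors.size(0), device=anchors.device)
        # the i-th anchor should be most similar to the i-th target
        return F.cross_entropy(sim, labels)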
@@ -167,42 +192,59 @@ def training_step(self, batch, batch_idx):
             attention_mask=batch["target_attention_mask"],
         )

-        loss = self.loss_fn(anchor_outputs, target_outputs)
+        negative_index = torch.randperm(batch["anchor_input_ids"].size(0))
+
+        negative_outputs = self(
+            input_ids=batch["anchor_input_ids"][negative_index],
+            attention_mask=batch["anchor_attention_mask"][negative_index],
+        )
+
+        loss = self.loss_fn(anchor_outputs, target_outputs, negative_outputs)
         self.log("loss/train", loss)

         return loss

-    def validation_step(self, batch, batch_idx, dataloader_idx=0):
+    def _evaluate(self, batch):
         anchor_outputs = self(
             input_ids=batch["anchor_input_ids"],
             attention_mask=batch["anchor_attention_mask"],
+            do_mlm=False,
         )

         target_outputs = self(
             input_ids=batch["target_input_ids"],
             attention_mask=batch["target_attention_mask"],
+            do_mlm=False,
         )

-        loss = self.loss_fn(anchor_outputs, target_outputs)
+        pos_anchor_emb = anchor_outputs[batch["labels"] == 1]
+        pos_target_emb = target_outputs[batch["labels"] == 1]

-        self.log("loss/val", loss, prog_bar=True)
-        self.log("hp_metric", loss)
+        neg_anchor_emb = anchor_outputs[batch["labels"] == 0]
+        neg_target_emb = target_outputs[batch["labels"] == 0]

-    def test_step(self, batch, batch_idx):
-        anchor_outputs = self(
-            input_ids=batch["anchor_input_ids"],
-            attention_mask=batch["anchor_attention_mask"],
-        )
+        pos_diff = torch.norm(pos_anchor_emb - pos_target_emb).mean()
+        neg_diff = torch.norm(neg_anchor_emb - neg_target_emb).mean()

-        target_outputs = self(
-            input_ids=batch["target_input_ids"],
-            attention_mask=batch["target_attention_mask"],
+        sim = Similarity(temp=self.hparams.temp)
+        pos_sim = sim(pos_anchor_emb, pos_target_emb).mean()
+        neg_sim = sim(neg_anchor_emb, neg_target_emb).mean()
+
+        self.log_dict(
+            {
+                "diff/pos": pos_diff,
+                "diff/neg": neg_diff,
+                "sim/pos": pos_sim,
+                "sim/neg": neg_sim,
+            }
         )
+        self.log("hp_metric", pos_sim - neg_sim)

-        loss = self.loss_fn(anchor_outputs, target_outputs)
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
+        self._evaluate(batch)

-        self.log("loss/test", loss, prog_bar=True)
-        self.log("hp_metric", loss)
+    def test_step(self, batch, batch_idx):
+        self._evaluate(batch)

     def configure_optimizers(self):
         """Prepare optimizer and schedule (linear warmup and decay)"""
@@ -256,7 +298,7 @@ def configure_optimizers(self):
     trainer = Trainer(
         max_epochs=1,
         gpus=AVAIL_GPUS,
-        log_every_n_steps=10,
+        log_every_n_steps=2,
         precision=16,
         stochastic_weight_avg=True,
         logger=TensorBoardLogger("runs/"),

run_exp.sh

+10
@@ -0,0 +1,10 @@
+#! /bin/sh
+
+
+for p in "cls" "avg"; do
+    for imr in "0.1" "0.15" "0.2" "0.25"; do
+        python train.py \
+            --pooler $p \
+            -i $imr
+    done
+done
