Support for models with token shapes
sdadas committed Oct 4, 2020
1 parent f5d4206 commit 1ed84e2
Showing 3 changed files with 21 additions and 14 deletions.
7 changes: 6 additions & 1 deletion preprocess/processor.py
@@ -12,12 +12,14 @@

 class TaskProcessor(object):

-    def __init__(self, task: BaseTask, data_path: str, output_path: str, model_path: str, resample: str):
+    def __init__(self, task: BaseTask, data_path: str, output_path: str, model_path: str, resample: str,
+                 token_shapes: bool=False):
         self.task: BaseTask = task
         self.data_path: str = data_path
         self.model_path = model_path
         self.output_path = output_path
         self.task_output_path = os.path.join(self.output_path, task.spec().output_path())
+        self.token_shapes = token_shapes
         self.resample = self._parse_resample_string(resample)
         if not os.path.exists(self.task_output_path):
             os.makedirs(self.task_output_path, exist_ok=True)
@@ -142,5 +144,8 @@ def _run_fairseq_preprocess(self, input_name: str, destdir: str):
         dict_path: str = os.path.join(self.model_path, "dict.txt")
         cmd.append("--srcdict")
         cmd.append(dict_path)
+        if self.token_shapes:
+            cmd.append("--task")
+            cmd.append("masked_lm_with_token_shapes")
         logging.info("running %s", cmd.__repr__())
         subprocess.run(cmd)
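
For context, a minimal usage sketch of the new flag (hypothetical: the task object and directory names are placeholders; only the token_shapes parameter comes from this commit). With token_shapes=True, TaskProcessor appends "--task masked_lm_with_token_shapes" to the fairseq-preprocess command it assembles:

# Hypothetical usage sketch; my_task and the paths are placeholders.
from preprocess.processor import TaskProcessor

processor = TaskProcessor(
    task=my_task,              # any BaseTask instance
    data_path="data",
    output_path="data_processed",
    model_path="model_dir",    # must contain dict.txt, used for --srcdict
    resample=None,
    token_shapes=True,         # appends --task masked_lm_with_token_shapes
)
processor.prepare()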
15 changes: 8 additions & 7 deletions run_tasks.py
@@ -63,13 +63,14 @@ def __init__(self, task: BaseTask, task_id: str, input_dir: str, output_dir: str
         self.model_name: str = os.path.basename(model_dir)
         self.task_output_dir: str = os.path.join(self.output_dir, f"{task.spec().output_path()}-bin")

-    def prepare_task(self, resample: str):
-        processor = TaskProcessor(self.task, self.input_dir, self.output_dir, self.model_dir, resample)
+    def prepare_task(self, resample: str, token_shapes: bool):
+        processor = TaskProcessor(self.task, self.input_dir, self.output_dir, self.model_dir, resample, token_shapes)
         processor.prepare()

-    def train_task(self, train_epochs: int, fp16: bool, max_sentences: int, update_freq: int):
+    def train_task(self, train_epochs: int, fp16: bool, max_sentences: int, update_freq: int, token_shapes: bool):
         train_size = self._count_train()
-        trainer = TaskTrainer(self.task, self.output_dir, self.model_dir, train_size, arch=self.arch, fp16=fp16)
+        trainer = TaskTrainer(self.task, self.output_dir, self.model_dir, train_size,
+                              arch=self.arch, fp16=fp16, token_shapes=token_shapes)
         trainer.train(train_epochs=train_epochs, max_sentences=max_sentences, update_freq=update_freq)

     def evaluate_task(self):
@@ -109,7 +110,7 @@ def log_score(self, task_name: str, task_id: str, params: Dict, scores: Dict):

 def run_tasks(arch: str, model_dir: str, input_dir: str="data", output_dir: str="data_processed",
               tasks: str=None, train_epochs: int=10, fp16: bool=False, max_sentences: int=1, update_freq: int=16,
-              evaluation_only: bool=False, resample: str=None, seed: int=None):
+              evaluation_only: bool=False, resample: str=None, token_shapes: bool=False, seed: int=None):
     assert arch in ("roberta_base", "roberta_large", "bart_base", "bart_large")
     params = locals()
     if tasks is None:
@@ -127,8 +128,8 @@ def run_tasks(arch: str, model_dir: str, input_dir: str="data", output_dir: str=
         task = task_class()
         runner: TaskRunner = TaskRunner(task, task_id, input_dir, output_dir, model_dir, arch, seed)
         if not evaluation_only:
-            runner.prepare_task(resample)
-            runner.train_task(train_epochs, fp16, max_sentences, update_freq)
+            runner.prepare_task(resample, token_shapes)
+            runner.train_task(train_epochs, fp16, max_sentences, update_freq, token_shapes)
         score = runner.evaluate_task()
         runner.log_score(task_name, task_id, params, score)

13 changes: 7 additions & 6 deletions train/trainer.py
@@ -1,4 +1,3 @@
-import importlib
 import logging
 import os
 import random
@@ -12,8 +11,8 @@

 class TaskTrainer(object):

-    def __init__(self, task: BaseTask, data_path: str, model_path: str, train_size: int,
-                 checkpoint: str="model.pt", arch: str="roberta_large", fp16: bool=False):
+    def __init__(self, task: BaseTask, data_path: str, model_path: str, train_size: int, checkpoint: str="model.pt",
+                 arch: str="roberta_large", fp16: bool=False, token_shapes: bool=False):
         self.task: BaseTask = task
         self.train_size: int = train_size
         self.data_path: str = data_path
@@ -24,6 +23,7 @@ def __init__(self, task: BaseTask, data_path: str, model_path: str, train_size:
         self.arch: str = arch
         self.learning_rate = "1e-5"
         self.fp16 = fp16
+        self.token_shapes = token_shapes

     def train(self, max_sentences: int=1, update_freq: int=16, train_epochs: int=10, seed: int=None):
         self._run_fairseq_train(seed, max_sentences=max_sentences, update_freq=update_freq, max_epoch=train_epochs)
@@ -42,6 +42,7 @@ def _run_fairseq_train(self, seed: int, max_sentences: int=16, update_freq: int=
         restore_file = os.path.join(self.model_path, self.checkpoint)
         assert os.path.exists(restore_file)
         checkpoint_path = os.path.join("checkpoints", self.model_name, self.task.spec().output_path())
+        task = "sentence_prediction" if not self.token_shapes else "sentence_prediction_with_token_shapes"
         self._remove_previous_checkpoints(checkpoint_path)
         cmd = [
             self.task_data_path,
@@ -50,7 +51,7 @@ def _run_fairseq_train(self, seed: int, max_sentences: int=16, update_freq: int=
             "--max-sentences", str(max_sentences),
             "--update-freq", str(update_freq),
             "--max-tokens", "4400",
-            "--task", "sentence_prediction",
+            "--task", task,
             "--reset-optimizer",
             "--reset-dataloader",
             "--reset-meters",
@@ -107,12 +108,12 @@ def _run_fairseq_train(self, seed: int, max_sentences: int=16, update_freq: int=

     def _run_training(self, cmd: List[str]):
         try:
-            from fairseq_cli.train import cli_main_helper
+            from fairseq_cli.train import main
             parser = options.get_training_parser()
             if self.arch.startswith("bart"):
                 parser.add_argument("--max-positions", type=int)
             args = options.parse_args_and_arch(parser, input_args=cmd)
-            cli_main_helper(args)
+            main(args)
         except ImportError:
             cmd.insert(0, "fairseq-train")
             subprocess.run(cmd)
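
The changed import appears to track a rename in fairseq's CLI internals, which expose fairseq_cli.train.main rather than cli_main_helper in the versions this commit targets; the surrounding try/except keeps a graceful fallback to the fairseq-train console script when the in-process import fails. A standalone sketch of that pattern (the fairseq calls mirror the diff; the wrapper function and cmd contents are illustrative):

# Sketch of the in-process-with-CLI-fallback pattern used above.
import subprocess

def run_fairseq_train(cmd):
    try:
        from fairseq import options
        from fairseq_cli.train import main
        parser = options.get_training_parser()
        args = options.parse_args_and_arch(parser, input_args=cmd)
        main(args)                               # run training in-process
    except ImportError:
        subprocess.run(["fairseq-train"] + cmd)  # fall back to the console script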
