Skip to content

Commit

Permalink
[feat] Major revisions to adjust with urartu v2
Browse files Browse the repository at this point in the history
  • Loading branch information
tamohannes committed Jun 27, 2024
1 parent b2c7c16 commit 628f917
Show file tree
Hide file tree
Showing 35 changed files with 54 additions and 360 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,25 @@ jobs:
pylint --disable=trailing-whitespace,missing-class-docstring,missing-final-newline,trailing-newlines \
--fail-under=9.0 \
$(git ls-files '*.py') || echo "::warning::Pylint check failed, but the workflow will continue."
python-build-n-publish:
name: Build and publish Python distribution
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install setuptools wheel twine
- name: Build and publish
env:
TWINE_USERNAME: __token__
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
git fetch --all --tags
python setup.py sdist bdist_wheel
twine upload --verbose dist/*
16 changes: 6 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,16 +121,12 @@ You're all set to dive into goal-oriented, persona-based, diverse, and multi-tur
Please use the following citation:
```
@InProceedings{smith:20xx:CONFERENCE_TITLE,
author = {Hovhannes Tamoyan},
title = {LLM Roleplay: Simulating Human-Chatbot Interaction},
booktitle = {Proceedings of the 20XX Conference on XXXX},
month = mmm,
year = {20xx},
address = {Gotham City, USA},
publisher = {Association for XXX},
pages = {XXXX--XXXX},
url = {http://xxxx.xxx}
% todo
@article{anonymous,
title={LLM Roleplay: Simulating Human-Chatbot Interaction},
author={Hovhannes Tamoyan, Hendrik Schuff, Iryna Gurevych},
journal={axiv},
year={2024}
}
```
Expand Down
Empty file.
14 changes: 0 additions & 14 deletions hydra_plugins/roleplay_plugin/roleplay_plugin.py

This file was deleted.

16 changes: 4 additions & 12 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,4 @@
aim>=3.18.1
datasets>=2.13.1
transformers>=4.36.0
accelerate>=0.21.0
scikit-learn>=1.3.0
sentencepiece==0.2.0
hydra-core==1.3.2
hydra-submitit-launcher==1.2.0
iopath==0.1.10
torch==2.1.2
jsonlines==4.0.0
protobuf==3.20.3
jsonlines
tiktoken
langchain
langchain-openai
2 changes: 1 addition & 1 deletion roleplay/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.0
2.0.0
35 changes: 0 additions & 35 deletions roleplay/__init__.py
Original file line number Diff line number Diff line change
@@ -1,35 +0,0 @@
import hydra
from aim import Run
from omegaconf import OmegaConf

from roleplay.utils.launcher import launch, launch_on_slurm
from roleplay.utils.slurm import is_submitit_available


@hydra.main(version_base=None, config_path="config", config_name="main")
def main(args):
cfg = OmegaConf.create(OmegaConf.to_container(args, resolve=True, enum_to_str=True))

aim_run = Run(
repo=cfg.aim.repo,
experiment=cfg.action_config.experiment_name,
)
aim_run.set("cfg", cfg, strict=False)

if cfg.slurm.use_slurm:
assert is_submitit_available(), "Please 'pip install submitit' to schedule jobs on SLURM"

launch_on_slurm(
action_name=cfg.action_name,
cfg=cfg,
aim_run=aim_run,
)
else:
launch(action_name=cfg.action_name, cfg=cfg, aim_run=aim_run)

if aim_run.active:
aim_run.close()


if __name__ == "__main__":
main()
9 changes: 5 additions & 4 deletions roleplay/actions/roleplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from omegaconf import DictConfig
from tqdm import tqdm

from roleplay.common.action import Action
from roleplay.common.dataset import Dataset
from urartu.common.action import Action
from urartu.common.dataset import Dataset

from roleplay.common.persona import Persona


Expand All @@ -24,7 +25,7 @@ def track(self, prompt, name, context=None):
context=context,
)

def run(self):
def main(self):
self.aim_run["num_no_prompts"] = 0
self.aim_run["num_multiple_prompts"] = 0
self.aim_run["num_non_coherent"] = 0
Expand Down Expand Up @@ -221,4 +222,4 @@ def run(self):

def main(cfg: DictConfig, aim_run: Run):
roleplay = Roleplay(cfg, aim_run)
roleplay.run()
roleplay.main()
9 changes: 0 additions & 9 deletions roleplay/common/action.py

This file was deleted.

29 changes: 0 additions & 29 deletions roleplay/common/dataset.py

This file was deleted.

4 changes: 0 additions & 4 deletions roleplay/common/device.py

This file was deleted.

9 changes: 0 additions & 9 deletions roleplay/common/metric.py

This file was deleted.

2 changes: 1 addition & 1 deletion roleplay/common/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import string
from typing import Any, Dict, List

from roleplay.common.device import DEVICE
from urartu.common.device import DEVICE


class Model:
Expand Down
Empty file removed roleplay/config/__init__.py
Empty file.
6 changes: 0 additions & 6 deletions roleplay/config/action_config/defaults.yaml

This file was deleted.

43 changes: 0 additions & 43 deletions roleplay/config/main.yaml

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
17 changes: 0 additions & 17 deletions roleplay/datasets/hf_datasets.py

This file was deleted.

7 changes: 4 additions & 3 deletions roleplay/models/causal_lm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from roleplay.common.device import AUTO_DEVICE
from urartu.common.device import DEVICE
from urartu.utils.dtype import eval_dtype
from roleplay.common.model import Model


Expand All @@ -21,8 +22,8 @@ def _load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
self.model = AutoModelForCausalLM.from_pretrained(
self.cfg.name,
cache_dir=self.cfg.cache_dir,
device_map=AUTO_DEVICE,
torch_dtype=eval(self.cfg.dtype),
device_map=DEVICE,
torch_dtype=eval_dtype(self.cfg.dtype),
token=self.cfg.api_token,
)

Expand Down
8 changes: 4 additions & 4 deletions roleplay/models/pipeline_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Tuple

import torch # NOQA
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from roleplay.common.device import DEVICE
from urartu.common.device import DEVICE
from urartu.utils.dtype import eval_dtype
from roleplay.common.model import Model


Expand All @@ -21,7 +21,7 @@ def _load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
self.cfg.name,
cache_dir=self.cfg.cache_dir,
device_map=DEVICE,
torch_dtype=eval(self.cfg.dtype),
torch_dtype=eval_dtype(self.cfg.dtype),
token=self.cfg.api_token,
)
self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.name)
Expand All @@ -30,7 +30,7 @@ def _load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
"text-generation",
model=model,
tokenizer=self.tokenizer,
torch_dtype=eval(self.cfg.dtype),
torch_dtype=eval_dtype(self.cfg.dtype),
device_map=DEVICE,
eos_token_id=self.tokenizer.eos_token_id,
)
Expand Down
17 changes: 0 additions & 17 deletions roleplay/utils/io.py

This file was deleted.

Loading

0 comments on commit 628f917

Please sign in to comment.