Stop generation with Continuation when a specific string was generated #187

Merged · 1 commit · Jul 15, 2023
75 changes: 64 additions & 11 deletions outlines/text/generate/continuation.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, Union

import torch

@@ -17,36 +17,89 @@ class Continuation(Sequence):

"""

-    def __init__(self, model, max_tokens: Optional[int]):
+    def __init__(
+        self, model, max_tokens: Optional[int] = None, stop: Union[str, List[str]] = []
+    ):
        super().__init__(model, max_tokens)
        self.eos_token_id = torch.tensor(
            [self.model.tokenizer.eos_token_id], device=self.device
        )

+        if isinstance(stop, str):
+            stop = [stop]
+
+        self.stop_sequences = stop

    def is_finished(self, token_ids: torch.LongTensor) -> torch.BoolTensor:
        """Determine whether the sequences reached maximum length or end with
        an EOS token.

-        In practice, `Sequence`'s `__call__` methods only passed the `token_ids`
-        of the sequences that haven't been marked as finished already, which is
-        why we only need to look for the EOS token in the last element rather
-        than in the whole sequence.
+        We only need to look for the EOS token in the last element rather than
+        in the whole sequence. Indeed, (1) EOS is a single token and (2)
+        `Sequence`'s `__call__` method only passes the `token_ids` of the
+        sequences that haven't been marked as finished already.

        Parameters
        ----------
        token_ids
            The input sequences.

        """
-        return token_ids[:, -1] == self.model.tokenizer.eos_token_id
+        sequences = self.model.tokenizer.decode(token_ids)
+        contains_stop_sequence = []
+        for sequence in sequences:
+            found = False
+            for stop_str in self.stop_sequences:
+                if stop_str in sequence:
+                    found = True
+
+            contains_stop_sequence.append(found)
+
+        contains_stop_sequence = torch.tensor(contains_stop_sequence, dtype=torch.bool)
+        contains_eos = token_ids[:, -1] == self.model.tokenizer.eos_token_id
+
+        return torch.logical_or(contains_eos, contains_stop_sequence)

    def postprocess_completions(self, completions: List[str]) -> List[str]:
-        """Remove the EOS token from the completion."""
-        return [
+        """Remove the EOS token from the completion.
+
+        Sequences in `stop` take precedence over EOS. For instance, if
+        `stop=["\n"]` and the generated sequence is `One\nTwo<EOS>`,
+        `Continuation.postprocess_completions` will return `One`.
+
+        """
+        completions_without_eos = [
            completion.replace(self.model.tokenizer.eos_token, "")
            for completion in completions
        ]

+        completions_without_stop = []
+        for completion in completions_without_eos:
+            for stop_str in self.stop_sequences:
+                idx = completion.rfind(stop_str)  # ignore the prompt
+                if idx > 0:
+                    completion = completion[:idx]
+
+            completions_without_stop.append(completion)
+
+        return completions_without_stop


-def continuation(model, max_tokens: Optional[int] = None):
-    return Continuation(model, max_tokens)
+def continuation(
+    model, max_tokens: Optional[int] = None, *, stop: Union[str, List[str]] = []
+):
+    """Generate text sequences.
+
+    Parameters
+    ----------
+    model
+        The model to use to compute the next-token logits.
+    max_tokens
+        The maximum number of tokens to generate.
+    stop
+        A string or list of strings which, when generated, stops
+        the generation for this sequence.
+
+    """
+    return Continuation(model, max_tokens, stop)
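
For readers skimming the diff, a minimal usage sketch of the new `stop` argument, mirroring the integration test further down (the tiny test model and the prompt are illustrative only):

```python
import torch

import outlines.models as models
import outlines.text.generate as generate

model = models.transformers("hf-internal-testing/tiny-random-GPTJForCausalLM", device="cpu")

rng = torch.Generator(device="cpu")
rng.manual_seed(0)

# Generation ends as soon as "\n" appears in the completion,
# on EOS, or after `max_tokens` tokens, whichever comes first.
sequence = generate.continuation(model, max_tokens=32, stop="\n")(
    "Write a short sentence ", rng=rng
)
```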
30 changes: 30 additions & 0 deletions outlines/text/generate/regex.py
@@ -135,6 +135,18 @@ def create_proposal(


def regex(model, regex_string: str, max_tokens: Optional[int] = None):
+    """Generate text sequences that match the input regex.
+
+    Parameters
+    ----------
+    model
+        The model to use to compute the next-token logits.
+    regex_string
+        The regular expression generated expressions must match.
+    max_tokens
+        The maximum number of tokens to generate.
+
+    """
    return Regex(model, regex_string, max_tokens)


@@ -145,6 +157,15 @@ def integer(model, max_tokens: Optional[int] = None):
    signs and forbids leading zeros (even if the `int` function in Python allows
    them).

+    Parameters
+    ----------
+    model
+        The model to use to compute the next-token logits.
+    max_tokens
+        The maximum number of tokens to generate.

    """
    return Regex(model, r"[-+]?\d+", max_tokens)

@@ -156,5 +177,14 @@ def float(model, max_tokens: Optional[int] = None):
    signs, and forbids leading zeros (even if the `float` function in Python
    allows them).

+    Parameters
+    ----------
+    model
+        The model to use to compute the next-token logits.
+    max_tokens
+        The maximum number of tokens to generate.

    """
    return Regex(model, r"([+-]?((0|[1-9]+)([.][0-9]*)?)|([.][0-9]+))", max_tokens)
4 changes: 3 additions & 1 deletion outlines/text/generate/sequence.py
@@ -229,7 +229,9 @@ def __call__(
            )
            token_ids = self.update_token_ids(is_finished, token_ids, updated_token_ids)
            attention_mask = self.expand_attention_mask(attention_mask)
-            is_finished[~is_finished] = self.is_finished(updated_token_ids).flatten()
+            is_finished[~is_finished] = self.is_finished(
+                updated_token_ids[:, num_prompt_tokens:]
+            ).flatten()

        result = self.model.tokenizer.decode(token_ids)
        result = self.postprocess_completions(result)
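
Why the new slice matters: `is_finished` now searches the decoded text for stop strings, so it must only see the freshly generated tokens. A toy illustration in plain Python (no model involved):

```python
prompt = "First line\nSecond line: "
completion = "hello"

# With stop=["\n"], searching prompt + completion would find the prompt's own
# newline and mark the sequence finished before anything was generated.
assert "\n" in prompt + completion

# Searching only the completion (the tokens after `num_prompt_tokens`) does not.
assert "\n" not in completion
```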
75 changes: 63 additions & 12 deletions tests/text/generate/test_continuation.py
@@ -1,5 +1,4 @@
-import numpy as np
-from numpy.testing import assert_array_equal
+import torch

from outlines.text.generate.continuation import Continuation, continuation

@@ -9,35 +8,87 @@ class Tokenizer:
    eos_token_id = 0
    pad_token_id = -1

+    def decode(self, token_ids):
+        return ["Test"] * token_ids.shape[0]


class Model:
    tokenizer = Tokenizer()
    device = "cpu"


-def test_continuation_is_finished():
-    model = continuation(Model(), 10)
+def test_continuation_eos_is_finished():
+    model = continuation(Model())
    assert isinstance(model, Continuation)

-    token_ids = np.array([[3, 2]])
+    token_ids = torch.tensor([[3, 2]])
    result = model.is_finished(token_ids)
-    assert_array_equal(result, [False])
+    assert torch.equal(result, torch.tensor([False]))

-    token_ids = np.array([[3, 2, 0]])
+    token_ids = torch.tensor([[3, 2, 0]])
    result = model.is_finished(token_ids)
-    assert_array_equal(result, [True])
+    assert torch.equal(result, torch.tensor([True]))

-    token_ids = np.array([[3, 2, 1], [3, 2, 0]])
+    token_ids = torch.tensor([[3, 2, 1], [3, 2, 0]])
    result = model.is_finished(token_ids)
-    assert_array_equal(result, [False, True])
+    assert torch.equal(result, torch.tensor([False, True]))

-    token_ids = np.array([[3, 2, 1, 0], [3, 2, 0, -1]])
+    token_ids = torch.tensor([[3, 2, 1, 0], [3, 2, 0, -1]])
    result = model.is_finished(token_ids)
-    assert_array_equal(result, [True, False])
+    assert torch.equal(result, torch.tensor([True, False]))


+def test_continuation_postprocess():
+    model = continuation(Model())
+    result = model.postprocess_completions(["Here<EOS>"])
+    assert len(result) == 1
+    assert result[0] == "Here"


+def test_continuation_stop_is_finished():
+    tokenizer = Tokenizer()
+    tokenizer.decode = lambda x: ["finished \n", "not_finished"]
+    model = Model()
+    model.tokenizer = tokenizer
+
+    model = continuation(model, stop=["\n"])
+
+    token_ids = torch.tensor([[2, 3]])
+    result = model.is_finished(token_ids)
+    assert torch.equal(result, torch.tensor([True, False]))


+def test_continuation_stop_postprocess():
+    model = Continuation(Model(), stop="\n")
+    result = model.postprocess_completions(["Stop\n"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    model = Continuation(Model(), stop=["\n", ","])
+    result = model.postprocess_completions(["Stop"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop\n"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop\naaa"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop,aa\naaa"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop\naa,a"])
+    assert len(result) == 1
+    assert result[0] == "Stop"
+
+    result = model.postprocess_completions(["Stop\n", "Nonstop"])
+    assert len(result) == 2
+    assert result == ["Stop", "Nonstop"]
+
+    result = model.postprocess_completions(["StopHere\nNoHere<EOS>"])
+    assert len(result) == 1
+    assert result[0] == "StopHere"
10 changes: 7 additions & 3 deletions tests/text/generate/test_integration_transfomers.py
@@ -13,21 +13,25 @@ def test_transformers_integration_continuation():

    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
    model = models.transformers(model_name, device="cpu")
-    sequence = generate.continuation(model)("Write a short sentence", rng=rng)
+    sequence = generate.continuation(model)("Write a short sentence ", rng=rng)
    assert isinstance(sequence, str)
    assert model.tokenizer.eos_token not in sequence

    sequence = generate.continuation(model, max_tokens=10)(
-        "Write a short sentence", rng=rng
+        "Write a short sentence ", rng=rng
    )
    assert isinstance(sequence, str)

-    prompts = ["Write a short sentence", "And another one"]
+    prompts = ["Write a short sentence ", "And another one "]
Contributor:
I am just curious why you added whitespace padding? According to guidance, it's preferable to terminate prompts without any new space or line, because frequent tokens already come with a space before the word.
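
For context, the leading-space behavior is easy to see with the `transformers` GPT-2 tokenizer (a quick sketch; the exact splits depend on the vocabulary):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# In GPT-2's BPE vocabulary, most mid-sentence words are single tokens that
# include the leading space, rendered as "Ġ".
print(tokenizer.tokenize("a short sentence"))  # ['a', 'Ġshort', 'Ġsentence']
print(tokenizer.tokenize(" art"))              # ['Ġart']
print(tokenizer.tokenize("art"))               # ['art']
```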

Member Author:
Good question, that came intuitively with GPT2 and numbers. I'll look at the vocabulary directly to see if that's actually the right thing to do.

Member Author (@rlouf, Jul 17, 2023):
I came back to this and looked at the vocabulary of the GPT2 tokenizer. It is true that most of the tokens begin with a space.

Your point highlights something that we need to be very careful about, and which might be incorrectly implemented in outlines.

It should not affect this PR since we're partially matching on text, so "\n" will match " \n". However, the regex [a-z]{3} will allow "art" to be generated, but not " art". This could make it impossible to generate what would otherwise be the most probable completion.

I need to dig more into this. I opened #193 to keep track of my thinking on this.

Token healing (tracked by #161) should ensure that this kind of quirk doesn't affect generation. Users shouldn't have to worry about the effects of tokenization.
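
The difference between the two matching modes, in plain Python (the `[a-z]{3}` pattern is the hypothetical example from the comment above):

```python
import re

# Stop strings are matched as substrings of the decoded text, so a token's
# leading space is harmless:
assert "\n" in " \n"

# Regex-constrained generation must match the pattern itself, so a completion
# that begins with a space can never satisfy it:
assert re.fullmatch(r"[a-z]{3}", "art") is not None
assert re.fullmatch(r"[a-z]{3}", " art") is None
```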

Contributor:
You are right that token healing should be able to correct all these nuances.

    sequence = generate.continuation(model, max_tokens=10)(prompts, rng=rng)
    assert isinstance(sequence, list)
    assert len(sequence) == 2
    assert isinstance(sequence[0], str)

+    prompt = "Write a short sentence "
+    sequence = generate.continuation(model, stop="a")(prompt, rng=rng)
+    assert sequence[len(prompt) :].find("a") == -1


@pytest.mark.xfail
def test_transformers_integration_continuation_array_samples():