Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functions to read YouTokenToMe standard file #9

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
# go-YouTokenToMe

go-YouTokenToMe is a Go port of [YoutTokenToMe](https://github.com/VKCOM/YouTokenToMe) - a computationally efficient implementation of Byte Pair Encoding [[Sennrich et al.](https://www.aclweb.org/anthology/P16-1162/)]. Only inference is supported, no training.
go-YouTokenToMe is a Go port of [YouTokenToMe](https://github.com/VKCOM/YouTokenToMe) - a computationally efficient implementation of Byte Pair Encoding [[Sennrich et al.](https://www.aclweb.org/anthology/P16-1162/)]. Only inference is supported, no training.

## Usage example
```go
file, err := os.Open("data/yttm.model")
if err != nil {
fmt.Println(err)
return
}
defer file.Close()

r := io.Reader(file)

m, err := bpe.ReadModel(r)
if err != nil {
panic(err)
}
config := bpe.NewConfig(false, false, false)
fmt.Println(m.EncodeSentence("мама мыла раму", *config))
```
236 changes: 149 additions & 87 deletions bpe.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"container/heap"
"encoding/binary"
"errors"
"fmt"
"io"
"strconv"
"strings"
Expand Down Expand Up @@ -102,16 +103,38 @@ func (s specialTokens) toBinary() []byte {
return bytesArray
}

func binaryToSpecialTokens(bytesArray []byte) (specialTokens, error) {
func rowToSpecialTokens(row string) (specialTokens, error) {
var s specialTokens
if len(bytesArray) < 16 {
logrus.Error("Bytes array length is too small")
return s, errors.New("bytes array is too small")
rowSplitted := strings.Fields(row)
if len(rowSplitted) != 4 {
logrus.Errorf("String slice with len %d != 4", len(rowSplitted))
return s, errors.New("string slice is wrong")
}
s.unk = int32(binary.BigEndian.Uint32(bytesArray))
s.pad = int32(binary.BigEndian.Uint32(bytesArray[4:]))
s.bos = int32(binary.BigEndian.Uint32(bytesArray[8:]))
s.eos = int32(binary.BigEndian.Uint32(bytesArray[12:]))
unk, err := strconv.Atoi(rowSplitted[0])
if err != nil {
logrus.Error("Broken input:", err)
return s, err
}
pad, err := strconv.Atoi(rowSplitted[1])
if err != nil {
logrus.Error("Broken input:", err)
return s, err
}
bos, err := strconv.Atoi(rowSplitted[2])
if err != nil {
logrus.Error("Broken input:", err)
return s, err
}
eos, err := strconv.Atoi(rowSplitted[2])
if err != nil {
logrus.Error("Broken input:", err)
return s, err
}

s.unk = int32(unk)
s.pad = int32(pad)
s.bos = int32(bos)
s.eos = int32(eos)
return s, nil
}

Expand All @@ -123,100 +146,139 @@ func (r rule) toBinary() []byte {
return bytesArray
}

func binaryToRule(bytesArray []byte) (rule, error) {
func rowToRule(row string) (rule, error) {
rowSplitted := strings.Fields(row)
var r rule
if len(bytesArray) < 12 {
logrus.Error("Bytes array length is too small")
return r, errors.New("bytes array is too small")
if len(rowSplitted) != 3 {
logrus.Errorf("String slice with len %d != 3", len(rowSplitted))
return r, errors.New("string slice is wrong")
}
rLeft, err := strconv.Atoi(rowSplitted[0])
if err != nil {
logrus.Error("Broken input:", err)
return r, err
}
rRight, err := strconv.Atoi(rowSplitted[1])
if err != nil {
logrus.Error("Broken input:", err)
return r, err
}
rRes, err := strconv.Atoi(rowSplitted[2])
if err != nil {
logrus.Error("Broken input:", err)
return r, err
}
r.left = TokenID(binary.BigEndian.Uint32(bytesArray))
r.right = TokenID(binary.BigEndian.Uint32(bytesArray[4:]))
r.result = TokenID(binary.BigEndian.Uint32(bytesArray[8:]))

r.left = TokenID(rLeft)
r.right = TokenID(rRight)
r.result = TokenID(rRes)
return r, nil
}

// ReadModel loads the BPE model from the binary dump
func ReadModel(reader io.Reader) (*Model, error) {
buf := make([]byte, 4)

scanner := bufio.NewScanner(reader)
var nChars, nRules int
if _, err := io.ReadFull(reader, buf); err != nil {
logrus.Error("Broken input: ", err)
return &Model{}, err
}
nChars = int(binary.BigEndian.Uint32(buf))
if _, err := io.ReadFull(reader, buf); err != nil {
logrus.Error("Broken input: ", err)
return &Model{}, err
}
nRules = int(binary.BigEndian.Uint32(buf))
var char rune
var charID TokenID
var row string
var err error

model := newModel(nRules)
model := &Model{}
minCharID := TokenID(0)
for i := 0; i < nChars; i++ {
var char rune
var charID TokenID
if _, err := io.ReadFull(reader, buf); err != nil {
logrus.Error("Broken input: ", err)
return &Model{}, err
}
char = rune(binary.BigEndian.Uint32(buf))
if _, err := io.ReadFull(reader, buf); err != nil {
logrus.Error("Broken input: ", err)
return &Model{}, err
}
charID = TokenID(binary.BigEndian.Uint32(buf))
model.char2id[char] = charID
model.id2char[charID] = char
model.recipe[charID] = EncodedString{charID}
model.revRecipe[string(char)] = charID
if charID < minCharID || minCharID == 0 {
minCharID = charID
model.spaceID = charID
}
}
ruleBuf := make([]byte, 12)
for i := 0; i < nRules; i++ {
if _, err := io.ReadFull(reader, ruleBuf); err != nil {
logrus.Error("Broken input: ", err)
return &Model{}, err
}
rule, err := binaryToRule(ruleBuf)
if err != nil {
return model, err

i := 0
j := 0

for scanner.Scan() {
row = scanner.Text()
if i == 0 {
nChars, err = strconv.Atoi(strings.Fields(row)[0])
if err != nil {
logrus.Error("Broken input:", err)
return &Model{}, err
}

nRules, err = strconv.Atoi(strings.Fields(row)[1])
if err != nil {
logrus.Error("Broken input:", err)
return &Model{}, err
}
logrus.Println("Reading bpe model file with number of")
logrus.Println("Characters:", nChars)
logrus.Println("Rules of merge:", nRules)

model = newModel(nRules)
}
if _, ok := model.recipe[rule.left]; !ok {
logrus.Errorf("%d: token id not described before", rule.left)
return model, errors.New("token id is impossible")
if i < nChars+1 && i != 0 {
row = scanner.Text()
unicodeChar, err := strconv.Atoi(strings.Fields(row)[0])
if err != nil {
logrus.Error("Broken input:", err)
return &Model{}, err
}
tokenId, err := strconv.Atoi(strings.Fields(row)[1])
if err != nil {
logrus.Error("Broken input:", err)
return &Model{}, err
}

char = rune(unicodeChar)
charID = TokenID(tokenId)
model.char2id[char] = charID
model.id2char[charID] = char
model.recipe[charID] = EncodedString{charID}
model.revRecipe[string(char)] = charID
if charID < minCharID || minCharID == 0 {
minCharID = charID
model.spaceID = charID
}
}
if _, ok := model.recipe[rule.right]; !ok {
logrus.Errorf("%d: token id not described before", rule.right)
return model, errors.New("token id is impossible")
if i < nChars+nRules+1 && i >= nChars+1 {
fmt.Println(j)
row = scanner.Text()

rule, err := rowToRule(row)
if err != nil {
return model, err
}
if _, ok := model.recipe[rule.left]; !ok {
logrus.Errorf("%d: token id not described before", rule.left)
return model, errors.New("token id is impossible")
}
if _, ok := model.recipe[rule.right]; !ok {
logrus.Errorf("%d: token id not described before", rule.right)
return model, errors.New("token id is impossible")
}
model.rules[j] = rule
model.rule2id[newTokenIDPair(rule.left, rule.right)] = j
model.recipe[rule.result] = append(model.recipe[rule.left], model.recipe[rule.right]...)
resultString, err := DecodeToken(model.recipe[rule.result], model.id2char)
if err != nil {
logrus.Error("Unexpected token id inside the rules: ", err)
return model, err
}
model.revRecipe[resultString] = rule.result
j++
}
model.rules[i] = rule
model.rule2id[newTokenIDPair(rule.left, rule.right)] = i
model.recipe[rule.result] = append(model.recipe[rule.left], model.recipe[rule.right]...)
resultString, err := DecodeToken(model.recipe[rule.result], model.id2char)
if err != nil {
logrus.Error("Unexpected token id inside the rules: ", err)
return model, err

if i == nChars+nRules+1 {
row = scanner.Text()
specials, err := rowToSpecialTokens(row)
if err != nil {
return model, err
}
model.specialTokens = specials
model.revRecipe[bosToken] = TokenID(specials.bos)
model.revRecipe[eosToken] = TokenID(specials.eos)
model.revRecipe[unkToken] = TokenID(specials.unk)
model.revRecipe[padToken] = TokenID(specials.pad)
}
model.revRecipe[resultString] = rule.result
}
specialTokensBuf := make([]byte, 16)
if _, err := io.ReadFull(reader, specialTokensBuf); err != nil {
logrus.Error("Broken input: ", err)
return &Model{}, err
}
specials, err := binaryToSpecialTokens(specialTokensBuf)
if err != nil {
return model, err

i++
}
model.specialTokens = specials
model.revRecipe[bosToken] = TokenID(specials.bos)
model.revRecipe[eosToken] = TokenID(specials.eos)
model.revRecipe[unkToken] = TokenID(specials.unk)
model.revRecipe[padToken] = TokenID(specials.pad)
return model, err
return model, nil
}

// IDToToken returns string token corresponding to the given token id.
Expand Down
30 changes: 0 additions & 30 deletions bpe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,42 +45,12 @@ func TestSpecialTokens_ToBinary(t *testing.T) {
require.Equal(t, bytesArray, specials.toBinary())
}

func TestBinaryToSpecialTokens(t *testing.T) {
req := require.New(t)
bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0}
expected := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127}
specials, err := binaryToSpecialTokens(bytesArray)
req.NoError(err)
req.Equal(expected, specials)
bytesArray = []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0}
specials, err = binaryToSpecialTokens(bytesArray)
req.Error(err)
bytesArray = []byte{}
specials, err = binaryToSpecialTokens(bytesArray)
req.Error(err)
}

// TestRule_ToBinary checks that a rule's three token ids are serialized as
// three consecutive 4-byte big-endian values (257 -> {0, 0, 1, 1}).
func TestRule_ToBinary(t *testing.T) {
	rule := rule{1, 2, 257}
	bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1}
	require.Equal(t, bytesArray, rule.toBinary())
}

func TestBinaryToRule(t *testing.T) {
req := require.New(t)
expected := rule{1, 2, 257}
bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1}
rule, err := binaryToRule(bytesArray)
req.NoError(err)
req.Equal(expected, rule)
bytesArray = []byte{0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 1}
rule, err = binaryToRule(bytesArray)
req.Error(err)
bytesArray = []byte{}
rule, err = binaryToRule(bytesArray)
req.Error(err)
}

func TestReadModel(t *testing.T) {
req := require.New(t)
reader := bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 6,
Expand Down
12 changes: 12 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=