Skip to content

Commit

Permalink
feat: add option to retrieve offsets from tokenizer (#21)
Browse files Browse the repository at this point in the history
* feat: add option to retrieve offsets from tokenizer

* fix: avoid tuple struct in offsets

* fix: free offsets array

* fix: revert freeing of memory
  • Loading branch information
riccardopinosio authored Aug 9, 2024
1 parent d9aff87 commit 7bb47dd
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 2 deletions.
24 changes: 22 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub struct Buffer {
special_tokens_mask: *mut u32,
attention_mask: *mut u32,
tokens: *mut *mut libc::c_char,
offsets: *mut usize,
len: usize,
}

Expand Down Expand Up @@ -68,6 +69,7 @@ pub struct EncodeOptions {
return_tokens: bool,
return_special_tokens_mask: bool,
return_attention_mask: bool,
return_offsets: bool,
}

#[no_mangle]
Expand All @@ -79,7 +81,7 @@ pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, o
let message_cstr = unsafe { CStr::from_ptr(message) };
let message = message_cstr.to_str();
if message.is_err() {
return Buffer { ids: ptr::null_mut(), tokens: ptr::null_mut(), len: 0, type_ids: ptr::null_mut(), special_tokens_mask: ptr::null_mut(), attention_mask: ptr::null_mut() };
return Buffer { ids: ptr::null_mut(), tokens: ptr::null_mut(), len: 0, type_ids: ptr::null_mut(), special_tokens_mask: ptr::null_mut(), attention_mask: ptr::null_mut() , offsets: ptr::null_mut()};
}

let encoding = tokenizer.encode(message.unwrap(), options.add_special_tokens).expect("failed to encode input");
Expand Down Expand Up @@ -124,7 +126,20 @@ pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, o
std::mem::forget(vec_attention_mask);
}

Buffer { ids, type_ids, special_tokens_mask, attention_mask, tokens, len }
let mut offsets: *mut usize = ptr::null_mut();
if options.return_offsets {
let vec_offsets_tuples = encoding.get_offsets().to_vec();
let mut vec_offsets = Vec::with_capacity(vec_offsets_tuples.len() * 2);
for i in vec_offsets_tuples {
vec_offsets.push(i.0);
vec_offsets.push(i.1);
}
vec_offsets.shrink_to_fit();
offsets = vec_offsets.as_mut_ptr();
std::mem::forget(vec_offsets);
}

Buffer { ids, type_ids, special_tokens_mask, attention_mask, tokens, offsets, len }
}

#[no_mangle]
Expand Down Expand Up @@ -183,6 +198,11 @@ pub extern "C" fn free_buffer(buf: Buffer) {
Vec::from_raw_parts(buf.attention_mask, buf.len, buf.len);
}
}
if !buf.offsets.is_null() {
unsafe {
Vec::from_raw_parts(buf.offsets, buf.len*2, buf.len*2);
}
}
if !buf.tokens.is_null() {
unsafe {
let strings = Vec::from_raw_parts(buf.tokens, buf.len, buf.len);
Expand Down
27 changes: 27 additions & 0 deletions tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,15 @@ func (t *Tokenizer) Close() error {
return nil
}

// Offset is the [start, end) span of a token within the original input
// string, as reported by the underlying tokenizer.
// NOTE(review): whether the indices are byte or character positions depends
// on the tokenizer backend — confirm against the Rust `get_offsets` output.
type Offset [2]uint

// Encoding holds the per-token outputs of a tokenization call.
// Slices other than IDs are populated only when the corresponding
// Return* option was requested (nil otherwise).
type Encoding struct {
	IDs               []uint32 // token ids
	TypeIDs           []uint32 // segment/type ids
	SpecialTokensMask []uint32 // 1 where the token is a special token
	AttentionMask     []uint32 // 1 for real tokens, 0 for padding
	Tokens            []string // string form of each token
	Offsets           []Offset // per-token spans in the input text
}

type encodeOpts struct {
Expand All @@ -88,6 +91,7 @@ type encodeOpts struct {
ReturnTokens C.bool
ReturnSpecialTokensMask C.bool
ReturnAttentionMask C.bool
ReturnOffsets C.bool
}

type EncodeOption func(eo *encodeOpts)
Expand All @@ -101,6 +105,18 @@ func uintVecToSlice(arrPtr *C.uint, len int) []uint32 {
return slice
}

// offsetVecToSlice converts the flat C array produced by the Rust side into
// a slice of Offset pairs. The buffer holds tokenLength*2 size_t values laid
// out as [start0, end0, start1, end1, ...]; element i of the result is
// built from the pair at positions 2i and 2i+1.
func offsetVecToSlice(arrPtr *C.size_t, tokenLength int) []Offset {
	flat := unsafe.Slice(arrPtr, tokenLength*2)
	offsets := make([]Offset, tokenLength)
	for i := range offsets {
		offsets[i] = Offset{uint(flat[2*i]), uint(flat[2*i+1])}
	}
	return offsets
}

func (t *Tokenizer) Encode(str string, addSpecialTokens bool) ([]uint32, []string) {
cStr := C.CString(str)
defer C.free(unsafe.Pointer(cStr))
Expand Down Expand Up @@ -133,6 +149,7 @@ func WithReturnAllAttributes() EncodeOption {
eo.ReturnSpecialTokensMask = C.bool(true)
eo.ReturnAttentionMask = C.bool(true)
eo.ReturnTokens = C.bool(true)
eo.ReturnOffsets = C.bool(true)
}
}

Expand Down Expand Up @@ -160,6 +177,12 @@ func WithReturnAttentionMask() EncodeOption {
}
}

// WithReturnOffsets requests that EncodeWithOptions populate the Offsets
// field of the returned Encoding.
func WithReturnOffsets() EncodeOption {
	return func(opts *encodeOpts) {
		opts.ReturnOffsets = C.bool(true)
	}
}

func (t *Tokenizer) EncodeWithOptions(str string, addSpecialTokens bool, opts ...EncodeOption) Encoding {
cStr := C.CString(str)
defer C.free(unsafe.Pointer(cStr))
Expand Down Expand Up @@ -201,6 +224,10 @@ func (t *Tokenizer) EncodeWithOptions(str string, addSpecialTokens bool, opts ..
encoding.AttentionMask = uintVecToSlice(res.attention_mask, len)
}

if encOptions.ReturnOffsets && res.offsets != nil {
encoding.Offsets = offsetVecToSlice(res.offsets, len)
}

return encoding
}

Expand Down
27 changes: 27 additions & 0 deletions tokenizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func TestEmbeddingConfig(t *testing.T) {
wantTokens []string
wantSpecialTokensMask []uint32
wantAttentionMask []uint32
wantOffsets []tokenizers.Offset
}{
{
name: "without special tokens",
Expand All @@ -45,6 +46,7 @@ func TestEmbeddingConfig(t *testing.T) {
wantTokens: []string{"brown", "fox", "jumps", "over", "the", "lazy", "dog"},
wantSpecialTokensMask: []uint32{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
wantAttentionMask: []uint32{0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1},
wantOffsets: []tokenizers.Offset{{0x0, 0x5}, {0x6, 0x9}, {0xa, 0xf}, {0x10, 0x14}, {0x15, 0x18}, {0x19, 0x1d}, {0x1e, 0x21}},
},
{
name: "with special tokens",
Expand All @@ -55,6 +57,7 @@ func TestEmbeddingConfig(t *testing.T) {
wantTokens: []string{"[CLS]", "brown", "fox", "jumps", "over", "the", "lazy", "dog", "[SEP]"},
wantSpecialTokensMask: []uint32{0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1},
wantAttentionMask: []uint32{0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1},
wantOffsets: []tokenizers.Offset{{0x0, 0x0}, {0x0, 0x5}, {0x6, 0x9}, {0xa, 0xf}, {0x10, 0x14}, {0x15, 0x18}, {0x19, 0x1d}, {0x1e, 0x21}, {0x0, 0x0}},
},
}
for _, tt := range tests {
Expand All @@ -65,6 +68,7 @@ func TestEmbeddingConfig(t *testing.T) {
assert.Equal(t, tt.wantTokens, encoding.Tokens, "wrong tokens")
assert.Equal(t, tt.wantSpecialTokensMask, encoding.SpecialTokensMask, "wrong special tokens mask")
assert.Equal(t, tt.wantAttentionMask, encoding.AttentionMask, "wrong attention mask")
assert.Equal(t, tt.wantOffsets, encoding.Offsets, "wrong offsets")

ids, tokens := tk.Encode(tt.str, tt.addSpecial)
assert.Equal(t, tt.wantIDs, ids, "wrong ids")
Expand All @@ -86,6 +90,7 @@ func TestEncodeWithAndWithoutOptions(t *testing.T) {
wantTokens []string
wantSpecialTokensMask []uint32
wantAttentionMask []uint32
wantOffsets []tokenizers.Offset
}{
{
name: "without special tokens",
Expand All @@ -96,6 +101,7 @@ func TestEncodeWithAndWithoutOptions(t *testing.T) {
wantTokens: []string{"brown", "fox", "jumps", "over", "the", "lazy", "dog"},
wantSpecialTokensMask: []uint32{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
wantAttentionMask: []uint32{0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1},
wantOffsets: []tokenizers.Offset{{0x0, 0x5}, {0x6, 0x9}, {0xa, 0xf}, {0x10, 0x14}, {0x15, 0x18}, {0x19, 0x1d}, {0x1e, 0x21}},
},
{
name: "with special tokens",
Expand All @@ -106,6 +112,7 @@ func TestEncodeWithAndWithoutOptions(t *testing.T) {
wantTokens: []string{"[CLS]", "brown", "fox", "jumps", "over", "the", "lazy", "dog", "[SEP]"},
wantSpecialTokensMask: []uint32{0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1},
wantAttentionMask: []uint32{0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1},
wantOffsets: []tokenizers.Offset{{0x0, 0x0}, {0x0, 0x5}, {0x6, 0x9}, {0xa, 0xf}, {0x10, 0x14}, {0x15, 0x18}, {0x19, 0x1d}, {0x1e, 0x21}, {0x0, 0x0}},
},
{
name: "empty string",
Expand All @@ -121,6 +128,7 @@ func TestEncodeWithAndWithoutOptions(t *testing.T) {
wantAttentionMask: []uint32{0x1, 0x1},
wantIDs: []uint32{101, 102},
wantTokens: []string{"[CLS]", "[SEP]"},
wantOffsets: []tokenizers.Offset{{0x0, 0x0}, {0x0, 0x0}},
},
{
name: "invalid utf8 string",
Expand All @@ -136,6 +144,7 @@ func TestEncodeWithAndWithoutOptions(t *testing.T) {
assert.Equal(t, tt.wantTokens, encoding.Tokens, "wrong tokens")
assert.Equal(t, tt.wantSpecialTokensMask, encoding.SpecialTokensMask, "wrong special tokens mask")
assert.Equal(t, tt.wantAttentionMask, encoding.AttentionMask, "wrong attention mask")
assert.Equal(t, tt.wantOffsets, encoding.Offsets, "wrong offsets mask")

ids, tokens := tk.Encode(tt.str, tt.addSpecial)
assert.Equal(t, tt.wantIDs, ids, "wrong ids")
Expand Down Expand Up @@ -174,6 +183,7 @@ func TestEncodeOptions(t *testing.T) {
wantTokens []string
wantSpecialTokensMask []uint32
wantAttentionMask []uint32
wantOffsets []tokenizers.Offset
}{
{
name: "without special tokens",
Expand All @@ -184,6 +194,7 @@ func TestEncodeOptions(t *testing.T) {
wantTokens: []string{"brown", "fox", "jumps", "over", "the", "lazy", "dog"},
wantSpecialTokensMask: []uint32{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
wantAttentionMask: []uint32{0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1},
wantOffsets: []tokenizers.Offset{{0x0, 0x5}, {0x6, 0x9}, {0xa, 0xf}, {0x10, 0x14}, {0x15, 0x18}, {0x19, 0x1d}, {0x1e, 0x21}},
},
}
for _, tt := range tests {
Expand All @@ -194,34 +205,47 @@ func TestEncodeOptions(t *testing.T) {
assert.Equal(t, []string(nil), encoding.Tokens, "wrong tokens")
assert.Equal(t, []uint32(nil), encoding.SpecialTokensMask, "wrong special tokens mask")
assert.Equal(t, []uint32(nil), encoding.AttentionMask, "wrong attention mask")
assert.Equal(t, []tokenizers.Offset(nil), encoding.Offsets, "wrong offsets")

encoding = tk.EncodeWithOptions(tt.str, tt.addSpecial, tokenizers.WithReturnTokens())
assert.Equal(t, tt.wantIDs, encoding.IDs, "wrong ids")
assert.Equal(t, []uint32(nil), encoding.TypeIDs, "wrong type ids")
assert.Equal(t, tt.wantTokens, encoding.Tokens, "wrong tokens")
assert.Equal(t, []uint32(nil), encoding.SpecialTokensMask, "wrong special tokens mask")
assert.Equal(t, []uint32(nil), encoding.AttentionMask, "wrong attention mask")
assert.Equal(t, []tokenizers.Offset(nil), encoding.Offsets, "wrong offsets")

encoding = tk.EncodeWithOptions(tt.str, tt.addSpecial, tokenizers.WithReturnTypeIDs())
assert.Equal(t, tt.wantIDs, encoding.IDs, "wrong ids")
assert.Equal(t, tt.wantTypeIDs, encoding.TypeIDs, "wrong type ids")
assert.Equal(t, []string(nil), encoding.Tokens, "wrong tokens")
assert.Equal(t, []uint32(nil), encoding.SpecialTokensMask, "wrong special tokens mask")
assert.Equal(t, []uint32(nil), encoding.AttentionMask, "wrong attention mask")
assert.Equal(t, []tokenizers.Offset(nil), encoding.Offsets, "wrong offsets")

encoding = tk.EncodeWithOptions(tt.str, tt.addSpecial, tokenizers.WithReturnSpecialTokensMask())
assert.Equal(t, tt.wantIDs, encoding.IDs, "wrong ids")
assert.Equal(t, []uint32(nil), encoding.TypeIDs, "wrong type ids")
assert.Equal(t, []string(nil), encoding.Tokens, "wrong tokens")
assert.Equal(t, tt.wantSpecialTokensMask, encoding.SpecialTokensMask, "wrong special tokens mask")
assert.Equal(t, []uint32(nil), encoding.AttentionMask, "wrong attention mask")
assert.Equal(t, []tokenizers.Offset(nil), encoding.Offsets, "wrong offsets")

encoding = tk.EncodeWithOptions(tt.str, tt.addSpecial, tokenizers.WithReturnAttentionMask())
assert.Equal(t, tt.wantIDs, encoding.IDs, "wrong ids")
assert.Equal(t, []uint32(nil), encoding.TypeIDs, "wrong type ids")
assert.Equal(t, []string(nil), encoding.Tokens, "wrong tokens")
assert.Equal(t, []uint32(nil), encoding.SpecialTokensMask, "wrong special tokens mask")
assert.Equal(t, tt.wantAttentionMask, encoding.AttentionMask, "wrong attention mask")
assert.Equal(t, []tokenizers.Offset(nil), encoding.Offsets, "wrong offsets")

encoding = tk.EncodeWithOptions(tt.str, tt.addSpecial, tokenizers.WithReturnOffsets())
assert.Equal(t, tt.wantIDs, encoding.IDs, "wrong ids")
assert.Equal(t, []uint32(nil), encoding.TypeIDs, "wrong type ids")
assert.Equal(t, []string(nil), encoding.Tokens, "wrong tokens")
assert.Equal(t, []uint32(nil), encoding.SpecialTokensMask, "wrong special tokens mask")
assert.Equal(t, []uint32(nil), encoding.AttentionMask, "wrong attention mask")
assert.Equal(t, tt.wantOffsets, encoding.Offsets, "wrong offsets")
})
}
}
Expand Down Expand Up @@ -300,6 +324,7 @@ func TestEncodeWithPadding(t *testing.T) {
wantTokens []string
wantSpecialTokensMask []uint32
wantAttentionMask []uint32
wantOffsets []tokenizers.Offset
}{
{
name: "sentence with padding",
Expand All @@ -310,6 +335,7 @@ func TestEncodeWithPadding(t *testing.T) {
wantTokens: []string{"this", "short", "sentence", "[PAD]", "[PAD]", "[PAD]", "[PAD]", "[PAD]"},
wantSpecialTokensMask: []uint32{0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x1},
wantAttentionMask: []uint32{0x1, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0},
wantOffsets: []tokenizers.Offset{{0x0, 0x4}, {0x5, 0xa}, {0xb, 0x13}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}, {0x0, 0x0}},
},
}
for _, tt := range tests {
Expand All @@ -320,6 +346,7 @@ func TestEncodeWithPadding(t *testing.T) {
assert.Equal(t, tt.wantTokens, encoding.Tokens, "wrong tokens")
assert.Equal(t, tt.wantSpecialTokensMask, encoding.SpecialTokensMask, "wrong special tokens mask")
assert.Equal(t, tt.wantAttentionMask, encoding.AttentionMask, "wrong attention mask")
assert.Equal(t, tt.wantOffsets, encoding.Offsets, "wrong offsets")

ids, tokens := tk.Encode(tt.str, tt.addSpecial)
assert.Equal(t, tt.wantIDs, ids, "wrong ids")
Expand Down
2 changes: 2 additions & 0 deletions tokenizers.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ struct EncodeOptions {
bool return_tokens;
bool return_special_tokens_mask;
bool return_attention_mask;
bool return_offsets;
};

struct TokenizerOptions {
Expand All @@ -19,6 +20,7 @@ struct Buffer {
uint32_t *special_tokens_mask;
uint32_t *attention_mask;
char *tokens;
size_t *offsets;
uint32_t len;
};

Expand Down

0 comments on commit 7bb47dd

Please sign in to comment.