Skip to content

Commit daa4dd2

Browse files
authored
Making the regex in ByteLevel optional. (#939)
* Making the regex in ByteLevel optional. * Changed the stub. * Better stub. * Typo fix. * Remove bad comments.
1 parent cdabef1 commit daa4dd2

File tree

4 files changed

+84
-7
lines changed

4 files changed

+84
-7
lines changed

bindings/node/native/src/pre_tokenizers.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,12 +125,15 @@ declare_types! {
125125
}
126126
}
127127

128-
/// byte_level(addPrefixSpace: bool = true)
128+
/// byte_level(addPrefixSpace: bool = true, useRegex: bool = true)
129129
fn byte_level(mut cx: FunctionContext) -> JsResult<JsPreTokenizer> {
130130
let mut byte_level = tk::pre_tokenizers::byte_level::ByteLevel::default();
131131
if let Some(add_prefix_space) = cx.extract_opt::<bool>(0)? {
132132
byte_level = byte_level.add_prefix_space(add_prefix_space);
133133
}
134+
if let Some(use_regex) = cx.extract_opt::<bool>(1)? {
135+
byte_level = byte_level.use_regex(use_regex);
136+
}
134137

135138
let mut pretok = JsPreTokenizer::new::<_, JsPreTokenizer, _>(&mut cx, vec![])?;
136139
let guard = cx.lock();

bindings/python/py_src/tokenizers/pre_tokenizers/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ class ByteLevel(PreTokenizer):
102102
lets us treat `hello` exactly like `say hello`.
103103
"""
104104

105-
def __init__(self, add_prefix_space=True):
105+
def __init__(self, add_prefix_space=True, use_regex=True):
106106
pass
107107
@staticmethod
108108
def alphabet():

bindings/python/src/pre_tokenizers.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ macro_rules! setter {
229229
/// Whether to add a space to the first word if there isn't already one. This
230230
/// lets us treat `hello` exactly like `say hello`.
231231
#[pyclass(extends=PyPreTokenizer, module = "tokenizers.pre_tokenizers", name=ByteLevel)]
232-
#[text_signature = "(self, add_prefix_space=True)"]
232+
#[text_signature = "(self, add_prefix_space=True, use_regex=True)"]
233233
pub struct PyByteLevel {}
234234
#[pymethods]
235235
impl PyByteLevel {
@@ -243,13 +243,28 @@ impl PyByteLevel {
243243
setter!(self_, ByteLevel, add_prefix_space, add_prefix_space);
244244
}
245245

246+
#[getter]
247+
fn get_use_regex(self_: PyRef<Self>) -> bool {
248+
getter!(self_, ByteLevel, use_regex)
249+
}
250+
251+
#[setter]
252+
fn set_use_regex(self_: PyRef<Self>, use_regex: bool) {
253+
setter!(self_, ByteLevel, use_regex, use_regex);
254+
}
255+
246256
#[new]
247-
#[args(add_prefix_space = "true", _kwargs = "**")]
248-
fn new(add_prefix_space: bool, _kwargs: Option<&PyDict>) -> (Self, PyPreTokenizer) {
257+
#[args(add_prefix_space = "true", use_regex = "true", _kwargs = "**")]
258+
fn new(
259+
add_prefix_space: bool,
260+
use_regex: bool,
261+
_kwargs: Option<&PyDict>,
262+
) -> (Self, PyPreTokenizer) {
249263
(
250264
PyByteLevel {},
251265
ByteLevel::default()
252266
.add_prefix_space(add_prefix_space)
267+
.use_regex(use_regex)
253268
.into(),
254269
)
255270
}

tokenizers/src/pre_tokenizers/byte_level.rs

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,22 +53,33 @@ pub struct ByteLevel {
5353
pub add_prefix_space: bool,
5454
/// Whether the post processing step should trim offsets to avoid including whitespaces.
5555
pub trim_offsets: bool,
56+
57+
/// Whether to use the standard GPT2 regex for whitespace splitting
58+
/// Set it to False if you want to use your own splitting.
59+
#[serde(default = "default_true")]
60+
pub use_regex: bool,
61+
}
62+
63+
fn default_true() -> bool {
64+
true
5665
}
5766

5867
impl Default for ByteLevel {
5968
fn default() -> Self {
6069
Self {
6170
add_prefix_space: true,
6271
trim_offsets: true,
72+
use_regex: true,
6373
}
6474
}
6575
}
6676

6777
impl ByteLevel {
68-
pub fn new(add_prefix_space: bool, trim_offsets: bool) -> Self {
78+
pub fn new(add_prefix_space: bool, trim_offsets: bool, use_regex: bool) -> Self {
6979
Self {
7080
add_prefix_space,
7181
trim_offsets,
82+
use_regex,
7283
}
7384
}
7485

@@ -87,6 +98,12 @@ impl ByteLevel {
8798
self.trim_offsets = v;
8899
self
89100
}
101+
102+
#[must_use]
103+
pub fn use_regex(mut self, v: bool) -> Self {
104+
self.use_regex = v;
105+
self
106+
}
90107
}
91108

92109
/// As a `PreTokenizer`, `ByteLevel` is in charge of transforming all the unicode characters into
@@ -99,7 +116,11 @@ impl PreTokenizer for ByteLevel {
99116
if self.add_prefix_space && !normalized.get().starts_with(' ') {
100117
normalized.prepend(" ");
101118
}
102-
normalized.split(re_ref, SplitDelimiterBehavior::Isolated)
119+
if self.use_regex {
120+
normalized.split(re_ref, SplitDelimiterBehavior::Isolated)
121+
} else {
122+
Ok(vec![normalized])
123+
}
103124
})?;
104125
pretokenized.normalize(|normalized| {
105126
let s = normalized.get();
@@ -247,6 +268,21 @@ mod tests {
247268
);
248269
}
249270

271+
#[test]
272+
fn pre_tokenization_no_regex() {
273+
let bytelevel = ByteLevel::default().use_regex(false);
274+
let mut pretokenized: PreTokenizedString = "Hello my friend, how is your day going?".into();
275+
bytelevel.pre_tokenize(&mut pretokenized).unwrap();
276+
assert_eq!(
277+
pretokenized
278+
.get_splits(OffsetReferential::Original, OffsetType::Byte)
279+
.into_iter()
280+
.map(|(s, o, _)| (s, o))
281+
.collect::<Vec<_>>(),
282+
vec![("ĠHelloĠmyĠfriend,ĠhowĠisĠyourĠdayĠgoing?", (0, 39))]
283+
);
284+
}
285+
250286
#[test]
251287
fn decoding() {
252288
let bytelevel = ByteLevel::default().add_prefix_space(false);
@@ -513,4 +549,27 @@ mod tests {
513549
vec!["Hello there dear friend! [PA D]"]
514550
);
515551
}
552+
553+
#[test]
554+
fn deserialization() {
555+
// Before use_regex
556+
let byte_level: ByteLevel = serde_json::from_str(
557+
r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false}"#,
558+
)
559+
.unwrap();
560+
assert!(byte_level.use_regex);
561+
562+
// Loading works, new future BC test.
563+
let byte_level: ByteLevel = serde_json::from_str(
564+
r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false, "use_regex": true}"#,
565+
)
566+
.unwrap();
567+
assert!(byte_level.use_regex);
568+
569+
let byte_level: ByteLevel = serde_json::from_str(
570+
r#"{"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": false, "use_regex": false}"#,
571+
)
572+
.unwrap();
573+
assert!(!byte_level.use_regex);
574+
}
516575
}

0 commit comments

Comments (0)