diff --git a/aho_corasick.go b/aho_corasick.go
index 0d0cd2e..367ebc7 100644
--- a/aho_corasick.go
+++ b/aho_corasick.go
@@ -140,19 +140,19 @@ type Finder interface {
 
 // FindAll returns the matches found in the haystack
 func (ac AhoCorasick) FindAll(haystack string) []Match {
-	iter := ac.Iter(haystack)
-	matches := make([]Match, 0)
+	return ac.FindN(haystack, -1)
+}
 
-	for {
-		next := iter.Next()
-		if next == nil {
-			break
-		}
+// FindN returns the matches found in the haystack, up to n matches.
+// A negative n places no limit on the number of matches returned.
+func (ac AhoCorasick) FindN(haystack string, n int) []Match {
+	ac.abi.startOperation(4)
+	defer ac.abi.endOperation()
 
-		matches = append(matches, *next)
-	}
+	cs := ac.abi.newOwnedCString(haystack)
+	defer ac.abi.freeOwnedCStringPtr(cs.ptr)
 
-	return matches
+	return ac.abi.findN(ac.ptr, haystack, cs, n, ac.matchOnlyWholeWords)
 }
 
 // Opts defines a set of options applied before the patterns are built
@@ -241,10 +241,7 @@ func (f *findIter) Next() *Match {
 	}
 
 	if f.matchOnlyWholeWords {
-		if result.Start()-1 >= 0 && (unicode.IsLetter(rune(f.haystack[result.Start()-1])) || unicode.IsDigit(rune(f.haystack[result.Start()-1]))) {
-			return f.Next()
-		}
-		if result.end < len(f.haystack) && (unicode.IsLetter(rune(f.haystack[result.end])) || unicode.IsDigit(rune(f.haystack[result.end]))) {
+		if isNotWholeWord(f.haystack, result.Start(), result.End()) {
 			return f.Next()
 		}
 	}
@@ -358,3 +355,17 @@ func (m *Match) End() int {
 func (m *Match) Start() int {
 	return m.start
 }
+
+// isNotWholeWord reports whether the match s[start:end] is directly preceded
+// or followed by a letter or digit, i.e. it is embedded in a larger word.
+// Neighbours are inspected as single bytes, not as decoded UTF-8 runes.
+func isNotWholeWord(s string, start int, end int) bool {
+	if start-1 >= 0 && (unicode.IsLetter(rune(s[start-1])) || unicode.IsDigit(rune(s[start-1]))) {
+		return true
+	}
+	if end < len(s) && (unicode.IsLetter(rune(s[end])) || unicode.IsDigit(rune(s[end]))) {
+		return true
+	}
+
+	return false
+}
diff --git a/aho_corasick_test.go b/aho_corasick_test.go
index 33d59b0..c4d905f 100644
--- a/aho_corasick_test.go
+++ b/aho_corasick_test.go
@@ -253,7 +253,45 @@ func TestAhoCorasick_LeftmostInsensitiveWholeWord(t *testing.T) {
 			matches := ac.FindAll(t2.haystack)
 
 			if len(matches) != len(t2.matches) {
-				t.Errorf("test %v expected %v matches got %v", i, len(matches), len(t2.matches))
+				t.Errorf("test %v expected %v matches got %v", i, len(t2.matches), len(matches))
+				// Lengths differ, so indexing t2.matches below could go out of range.
+				continue
+			}
+			for j, m := range matches {
+				if m != t2.matches[j] {
+					t.Errorf("test %v expected %v match got %v", i, t2.matches[j], m)
+				}
+			}
+		}
+	}
+}
+
+// TestAhoCorasick_LeftmostInsensitiveWholeWord_N checks that FindN with n=1
+// returns only the first expected whole-word match.
+func TestAhoCorasick_LeftmostInsensitiveWholeWord_N(t *testing.T) {
+	for i, t2 := range leftmostInsensitiveWholeWordTestCases {
+		builders := []*AhoCorasickBuilder{NewAhoCorasickBuilder(Opts{
+			AsciiCaseInsensitive: true,
+			MatchOnlyWholeWords:  true,
+			MatchKind:            LeftMostLongestMatch,
+		}), NewAhoCorasickBuilder(Opts{
+			AsciiCaseInsensitive: true,
+			MatchOnlyWholeWords:  true,
+			MatchKind:            LeftMostLongestMatch,
+			DFA:                  true,
+		})}
+
+		for _, builder := range builders {
+			ac := builder.Build(t2.patterns)
+			matches := ac.FindN(t2.haystack, 1)
+
+			// A case with no expected matches must yield none, not one.
+			want := 1
+			if len(t2.matches) < want {
+				want = len(t2.matches)
+			}
+			if len(matches) != want {
+				t.Errorf("test %v expected %v matches got %v", i, want, len(matches))
 			}
 			for i, m := range matches {
 				if m != t2.matches[i] {
diff --git a/aho_corasick_tinygowasm.go b/aho_corasick_tinygowasm.go
index 47aa42e..6198ff8 100644
--- a/aho_corasick_tinygowasm.go
+++ b/aho_corasick_tinygowasm.go
@@ -15,6 +15,9 @@ void find_iter_delete(void* iter);
 void* overlapping_iter(void* ac, void* value, int value_len);
 int overlapping_iter_next(void* iter, size_t* patternOut, size_t* startOut, size_t* endOut);
 void overlapping_iter_delete(void* iter);
+
+void* matches(void* ac, void* value, size_t value_len, size_t limit, size_t* numOut);
+void matches_delete(void* matches, size_t num);
 */
 import "C"
 
@@ -103,6 +106,42 @@ func (abi *ahoCorasickABI) overlappingIterDelete(iterPtr uintptr) {
 	C.overlapping_iter_delete(unsafe.Pointer(iterPtr))
 }
 
+// findN returns up to n matches of the automaton at iter in value; a negative
+// n means no limit. With matchWholeWords set, matches directly adjacent to a
+// letter or digit in valueStr are dropped and do not count towards n.
+func (abi *ahoCorasickABI) findN(iter uintptr, valueStr string, value cString, n int, matchWholeWords bool) []Match {
+	// Whole-word filtering happens below, after the native call, so a native
+	// limit would also count matches that get dropped. Fetch everything in
+	// that case and enforce n while filtering.
+	limit := n
+	if matchWholeWords {
+		limit = -1
+	}
+
+	var resLen C.size_t
+	matchesPtr := C.matches(unsafe.Pointer(iter), unsafe.Pointer(value.ptr), C.size_t(value.length), C.size_t(limit), &resLen)
+	defer C.matches_delete(matchesPtr, resLen)
+
+	// The result is a flat array of resLen words: (pattern, start, end) per match.
+	res := unsafe.Slice((*uintptr)(unsafe.Pointer(matchesPtr)), resLen)
+
+	num := int(resLen) / 3
+	matches := make([]Match, 0, num)
+	for i := 0; i < num; i++ {
+		if n >= 0 && len(matches) >= n {
+			break
+		}
+		start := int(res[i*3+1])
+		end := int(res[i*3+2])
+		if matchWholeWords && isNotWholeWord(valueStr, start, end) {
+			continue
+		}
+		matches = append(matches, Match{pattern: int(res[i*3]), start: start, end: end})
+	}
+
+	return matches
+}
+
 type cString struct {
 	ptr    uintptr
 	length int
diff --git a/aho_corasick_wazero.go b/aho_corasick_wazero.go
index eb965eb..f5e9f65 100644
--- a/aho_corasick_wazero.go
+++ b/aho_corasick_wazero.go
@@ -5,6 +5,7 @@ package aho_corasick
 import (
 	"context"
 	_ "embed"
+	"encoding/binary"
 	"errors"
 	"sync"
 
@@ -35,6 +36,8 @@ type ahoCorasickABI struct {
 	overlapping_iter        api.Function
 	overlapping_iter_next   api.Function
 	overlapping_iter_delete api.Function
+	matches                 api.Function
+	matches_delete          api.Function
 
 	malloc api.Function
 	free   api.Function
@@ -81,6 +84,8 @@ func newABI() *ahoCorasickABI {
 		overlapping_iter:        mod.ExportedFunction("overlapping_iter"),
 		overlapping_iter_next:   mod.ExportedFunction("overlapping_iter_next"),
 		overlapping_iter_delete: mod.ExportedFunction("overlapping_iter_delete"),
+		matches:                 mod.ExportedFunction("matches"),
+		matches_delete:          mod.ExportedFunction("matches_delete"),
 
 		malloc: mod.ExportedFunction("malloc"),
 		free:   mod.ExportedFunction("free"),
@@ -246,6 +251,69 @@ func (abi *ahoCorasickABI) overlappingIterDelete(iter uintptr) {
 	}
 }
 
+// findN returns up to n matches of the automaton at iter in value; a negative
+// n means no limit. With matchWholeWords set, matches directly adjacent to a
+// letter or digit in valueStr are dropped and do not count towards n.
+func (abi *ahoCorasickABI) findN(iter uintptr, valueStr string, value cString, n int, matchWholeWords bool) []Match {
+	// Whole-word filtering happens below, after the native call, so a native
+	// limit would also count matches that get dropped. Fetch everything in
+	// that case and enforce n while filtering.
+	limit := n
+	if matchWholeWords {
+		limit = -1
+	}
+
+	lenPtr := abi.memory.allocate(4)
+
+	callStack := abi.callStack
+	callStack[0] = uint64(iter)
+	callStack[1] = uint64(value.ptr)
+	callStack[2] = uint64(value.length)
+	callStack[3] = uint64(uint32(limit)) // i32 parameter: keep only the low 32 bits
+	callStack[4] = uint64(lenPtr)
+	if err := abi.matches.CallWithStack(context.Background(), callStack); err != nil {
+		panic(err)
+	}
+
+	resLen, ok := abi.wasmMemory.ReadUint32Le(uint32(lenPtr))
+	if !ok {
+		panic(errFailedRead)
+	}
+
+	// Hand the result buffer back to the module once the matches are copied out.
+	resPtr := callStack[0]
+	defer func() {
+		callStack[0] = uint64(resPtr)
+		callStack[1] = uint64(resLen)
+		if err := abi.matches_delete.CallWithStack(context.Background(), callStack); err != nil {
+			panic(err)
+		}
+	}()
+
+	// resLen counts 32-bit words: (pattern, start, end) per match.
+	res, ok := abi.wasmMemory.Read(uint32(resPtr), resLen*4)
+	if !ok {
+		panic(errFailedRead)
+	}
+
+	num := int(resLen / 3)
+	matches := make([]Match, 0, num)
+	for i := 0; i < num; i++ {
+		if n >= 0 && len(matches) >= n {
+			break
+		}
+		start := int(binary.LittleEndian.Uint32(res[i*12+4:]))
+		end := int(binary.LittleEndian.Uint32(res[i*12+8:]))
+		if matchWholeWords && isNotWholeWord(valueStr, start, end) {
+			continue
+		}
+		pattern := int(binary.LittleEndian.Uint32(res[i*12:]))
+		matches = append(matches, Match{pattern: pattern, start: start, end: end})
+	}
+
+	return matches
+}
+
 type sharedMemory struct {
 	size   uint32
 	bufPtr uint32
diff --git a/benchmarks2_test.go b/benchmarks2_test.go
index 9204560..f15e436 100644
--- a/benchmarks2_test.go
+++ b/benchmarks2_test.go
@@ -327,15 +327,23 @@ func BenchmarkBurntSushi(b *testing.B) {
 					DFA: dfa,
 				}).Build(tt.patterns)
 				b.Run(dfaStr, func(b *testing.B) {
-					for i := 0; i < b.N; i++ {
-						cnt := 0
-						iter := ac.Iter(tt.corpus)
-						for iter.Next() != nil {
-							cnt++
-						}
-						if cnt != tt.count {
-							b.Errorf("expected %d matches, got %d", tt.count, cnt)
-						}
+					for _, iterate := range []bool{false, true} {
+						b.Run(fmt.Sprintf("iterate=%v", iterate), func(b *testing.B) {
+							for i := 0; i < b.N; i++ {
+								var cnt int
+								if iterate {
+									iter := ac.Iter(tt.corpus)
+									for iter.Next() != nil {
+										cnt++
+									}
+								} else {
+									cnt = len(ac.FindAll(tt.corpus))
+								}
+								if cnt != tt.count {
+									b.Errorf("expected %d matches, got %d", tt.count, cnt)
+								}
+							}
+						})
 					}
 				})
 			}
diff --git a/buildtools/aho-corasick/src/lib.rs b/buildtools/aho-corasick/src/lib.rs
index 771c66d..af4243d 100644
--- a/buildtools/aho-corasick/src/lib.rs
+++ b/buildtools/aho-corasick/src/lib.rs
@@ -3,9 +3,6 @@
 
 extern crate aho_corasick;
 
-use std::ffi::CStr;
-use std::mem::MaybeUninit;
-use std::os::raw::c_char;
 use std::slice;
 use std::str;
 use aho_corasick::{AhoCorasick, AhoCorasickBuilder, AhoCorasickKind, FindIter, FindOverlappingIter, MatchKind};
@@ -84,23 +81,43 @@ pub extern "C" fn overlapping_iter_delete(_iter: Box) {
 }
 
+/// Collects up to `limit` matches of `ac` in the haystack as a flat array of
+/// `(pattern, start, end)` triples. Returns a pointer to the array and writes
+/// its length in `usize` elements to `num`; free it with `matches_delete`.
 #[no_mangle]
-pub extern "C" fn matches(ac: &mut AhoCorasick, value_ptr: usize, value_len: usize, n: usize, matches: *mut usize) -> usize {
+pub extern "C" fn matches(ac: &mut AhoCorasick, value_ptr: usize, value_len: usize, limit: usize, num: &mut usize) -> *const usize {
+    let mut matches = Vec::new();
     let value = ptr_to_string(value_ptr, value_len);
-    std::mem::forget(&value);
 
-    let mut num = 0;
+    let mut count = 0;
     for value in ac.find_iter(value.as_bytes()) {
-        if num == n {
+        if count == limit {
             break;
         }
-        unsafe {
-            *matches.offset(2*num as isize) = value.start();
-            *matches.offset((2*num+1) as isize) = value.end();
-        }
-        num += 1;
+
+        matches.push(value.pattern().as_usize());
+        matches.push(value.start());
+        matches.push(value.end());
+
+        count += 1;
     }
 
-    return num
+    // into_boxed_slice() drops spare capacity, so matches_delete can rebuild
+    // the allocation from the pointer and element count alone.
+    let b = matches.into_boxed_slice();
+    let ptr = b.as_ptr();
+    let len = b.len(); // Three usize values per match, i.e. 3 * count
+    std::mem::forget(b);
+    *num = len;
+    return ptr;
+}
+
+/// Frees an array returned by `matches`. `len` must be the element count that
+/// `matches` wrote to its `num` output.
+#[no_mangle]
+pub extern "C" fn matches_delete(ptr: *const usize, len: usize) {
+    unsafe {
+        let _ = Vec::from_raw_parts(ptr as *mut usize, len, len);
+    }
 }
 
 extern "C" {
diff --git a/wasm/aho_corasick.wasm b/wasm/aho_corasick.wasm
index 1054f0c..f2ea148 100755
Binary files a/wasm/aho_corasick.wasm and b/wasm/aho_corasick.wasm differ
diff --git a/wasm/libaho_corasick.a b/wasm/libaho_corasick.a
index 27445fa..644925d 100644
Binary files a/wasm/libaho_corasick.a and b/wasm/libaho_corasick.a differ