Skip to content

Commit

Permalink
Reduce FFI for FindAll and add FindN (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
anuraaga authored Apr 30, 2024
1 parent 13f9283 commit 2592912
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 37 deletions.
35 changes: 21 additions & 14 deletions aho_corasick.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,18 @@ type Finder interface {

// FindAll returns the matches found in the haystack
func (ac AhoCorasick) FindAll(haystack string) []Match {
iter := ac.Iter(haystack)
matches := make([]Match, 0)
return ac.FindN(haystack, -1)
}

for {
next := iter.Next()
if next == nil {
break
}
// FindN returns the matches found in the haystack, up to n matches.
func (ac AhoCorasick) FindN(haystack string, n int) []Match {
ac.abi.startOperation(4)
defer ac.abi.endOperation()

matches = append(matches, *next)
}
cs := ac.abi.newOwnedCString(haystack)
defer ac.abi.freeOwnedCStringPtr(cs.ptr)

return matches
return ac.abi.findN(ac.ptr, haystack, cs, n, ac.matchOnlyWholeWords)
}

// Opts defines a set of options applied before the patterns are built
Expand Down Expand Up @@ -241,10 +240,7 @@ func (f *findIter) Next() *Match {
}

if f.matchOnlyWholeWords {
if result.Start()-1 >= 0 && (unicode.IsLetter(rune(f.haystack[result.Start()-1])) || unicode.IsDigit(rune(f.haystack[result.Start()-1]))) {
return f.Next()
}
if result.end < len(f.haystack) && (unicode.IsLetter(rune(f.haystack[result.end])) || unicode.IsDigit(rune(f.haystack[result.end]))) {
if isNotWholeWord(f.haystack, result.Start(), result.End()) {
return f.Next()
}
}
Expand Down Expand Up @@ -358,3 +354,14 @@ func (m *Match) End() int {
func (m *Match) Start() int {
return m.start
}

func isNotWholeWord(s string, start int, end int) bool {
if start-1 >= 0 && (unicode.IsLetter(rune(s[start-1])) || unicode.IsDigit(rune(s[start-1]))) {
return true
}
if end < len(s) && (unicode.IsLetter(rune(s[end])) || unicode.IsDigit(rune(s[end]))) {
return true
}

return false
}
31 changes: 30 additions & 1 deletion aho_corasick_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,36 @@ func TestAhoCorasick_LeftmostInsensitiveWholeWord(t *testing.T) {
matches := ac.FindAll(t2.haystack)

if len(matches) != len(t2.matches) {
t.Errorf("test %v expected %v matches got %v", i, len(matches), len(t2.matches))
t.Errorf("test %v expected %v matches got %v", i, len(t2.matches), len(matches))
}
for i, m := range matches {
if m != t2.matches[i] {
t.Errorf("test %v expected %v matche got %v", i, m, t2.matches[i])
}
}
}
}
}

func TestAhoCorasick_LeftmostInsensitiveWholeWord_N(t *testing.T) {
for i, t2 := range leftmostInsensitiveWholeWordTestCases {
builders := []*AhoCorasickBuilder{NewAhoCorasickBuilder(Opts{
AsciiCaseInsensitive: true,
MatchOnlyWholeWords: true,
MatchKind: LeftMostLongestMatch,
}), NewAhoCorasickBuilder(Opts{
AsciiCaseInsensitive: true,
MatchOnlyWholeWords: true,
MatchKind: LeftMostLongestMatch,
DFA: true,
})}

for _, builder := range builders {
ac := builder.Build(t2.patterns)
matches := ac.FindN(t2.haystack, 1)

if len(matches) != 1 {
t.Errorf("test %v expected %v matches got %v", i, 1, len(matches))
}
for i, m := range matches {
if m != t2.matches[i] {
Expand Down
28 changes: 28 additions & 0 deletions aho_corasick_tinygowasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ void find_iter_delete(void* iter);
void* overlapping_iter(void* ac, void* value, int value_len);
int overlapping_iter_next(void* iter, size_t* patternOut, size_t* startOut, size_t* endOut);
void overlapping_iter_delete(void* iter);
void* matches(void* ac, void* value, size_t value_len, size_t limit, size_t* numOut);
void matches_delete(void* matches, size_t num);
*/
import "C"

Expand Down Expand Up @@ -103,6 +106,31 @@ func (abi *ahoCorasickABI) overlappingIterDelete(iterPtr uintptr) {
C.overlapping_iter_delete(unsafe.Pointer(iterPtr))
}

func (abi ahoCorasickABI) findN(iter uintptr, valueStr string, value cString, n int, matchWholeWords bool) []Match {
var resLen C.size_t
matchesPtr := C.matches(unsafe.Pointer(iter), unsafe.Pointer(value.ptr), C.size_t(value.length), C.size_t(n), &resLen)
defer C.matches_delete(matchesPtr, resLen)

res := unsafe.Slice((*uintptr)(unsafe.Pointer(matchesPtr)), resLen)

num := int(resLen) / 3
matches := make([]Match, 0, num)
for i := 0; i < num; i++ {
start := int(res[i*3+1])
end := int(res[i*3+2])
if matchWholeWords && isNotWholeWord(valueStr, start, end) {
continue
}
var m Match
m.pattern = int(res[i*3])
m.start = start
m.end = end
matches = append(matches, m)
}

return matches
}

type cString struct {
ptr uintptr
length int
Expand Down
55 changes: 55 additions & 0 deletions aho_corasick_wazero.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package aho_corasick
import (
"context"
_ "embed"
"encoding/binary"
"errors"
"sync"

Expand Down Expand Up @@ -35,6 +36,8 @@ type ahoCorasickABI struct {
overlapping_iter api.Function
overlapping_iter_next api.Function
overlapping_iter_delete api.Function
matches api.Function
matches_delete api.Function

malloc api.Function
free api.Function
Expand Down Expand Up @@ -81,6 +84,8 @@ func newABI() *ahoCorasickABI {
overlapping_iter: mod.ExportedFunction("overlapping_iter"),
overlapping_iter_next: mod.ExportedFunction("overlapping_iter_next"),
overlapping_iter_delete: mod.ExportedFunction("overlapping_iter_delete"),
matches: mod.ExportedFunction("matches"),
matches_delete: mod.ExportedFunction("matches_delete"),

malloc: mod.ExportedFunction("malloc"),
free: mod.ExportedFunction("free"),
Expand Down Expand Up @@ -246,6 +251,56 @@ func (abi *ahoCorasickABI) overlappingIterDelete(iter uintptr) {
}
}

func (abi *ahoCorasickABI) findN(iter uintptr, valueStr string, value cString, n int, matchWholeWords bool) []Match {
lenPtr := abi.memory.allocate(4)

callStack := abi.callStack
callStack[0] = uint64(iter)
callStack[1] = uint64(value.ptr)
callStack[2] = uint64(value.length)
callStack[3] = uint64(n)
callStack[4] = uint64(lenPtr)
if err := abi.matches.CallWithStack(context.Background(), callStack); err != nil {
panic(err)
}

resLen, ok := abi.wasmMemory.ReadUint32Le(uint32(lenPtr))
if !ok {
panic(errFailedRead)
}

resPtr := callStack[0]
defer func() {
callStack[0] = uint64(resPtr)
callStack[1] = uint64(resLen)
if err := abi.matches_delete.CallWithStack(context.Background(), callStack); err != nil {
panic(err)
}
}()

res, ok := abi.wasmMemory.Read(uint32(resPtr), resLen*4)
if !ok {
panic(errFailedRead)
}

num := resLen / 3
matches := make([]Match, 0, num)
for i := 0; i < int(num); i++ {
start := int(binary.LittleEndian.Uint32(res[i*12+4:]))
end := int(binary.LittleEndian.Uint32(res[i*12+8:]))
if matchWholeWords && isNotWholeWord(valueStr, start, end) {
continue
}
var m Match
m.pattern = int(binary.LittleEndian.Uint32(res[i*12:]))
m.start = start
m.end = end
matches = append(matches, m)
}

return matches
}

type sharedMemory struct {
size uint32
bufPtr uint32
Expand Down
26 changes: 17 additions & 9 deletions benchmarks2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,15 +327,23 @@ func BenchmarkBurntSushi(b *testing.B) {
DFA: dfa,
}).Build(tt.patterns)
b.Run(dfaStr, func(b *testing.B) {
for i := 0; i < b.N; i++ {
cnt := 0
iter := ac.Iter(tt.corpus)
for iter.Next() != nil {
cnt++
}
if cnt != tt.count {
b.Errorf("expected %d matches, got %d", tt.count, cnt)
}
for _, iterate := range []bool{false, true} {
b.Run(fmt.Sprintf("iterate=%v", iterate), func(b *testing.B) {
for i := 0; i < b.N; i++ {
var cnt int
if iterate {
iter := ac.Iter(tt.corpus)
for iter.Next() != nil {
cnt++
}
} else {
cnt = len(ac.FindAll(tt.corpus))
}
if cnt != tt.count {
b.Errorf("expected %d matches, got %d", tt.count, cnt)
}
}
})
}
})
}
Expand Down
36 changes: 23 additions & 13 deletions buildtools/aho-corasick/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@

extern crate aho_corasick;

use std::ffi::CStr;
use std::mem::MaybeUninit;
use std::os::raw::c_char;
use std::slice;
use std::str;
use aho_corasick::{AhoCorasick, AhoCorasickBuilder, AhoCorasickKind, FindIter, FindOverlappingIter, MatchKind};
Expand Down Expand Up @@ -84,23 +81,36 @@ pub extern "C" fn overlapping_iter_delete(_iter: Box<FindOverlappingIter>) {
}

#[no_mangle]
pub extern "C" fn matches(ac: &mut AhoCorasick, value_ptr: usize, value_len: usize, n: usize, matches: *mut usize) -> usize {
pub extern "C" fn matches(ac: &mut AhoCorasick, value_ptr: usize, value_len: usize, limit: usize, num: &mut usize) -> *const usize {
let mut matches = Vec::new();
let value = ptr_to_string(value_ptr, value_len);
std::mem::forget(&value);

let mut num = 0;
let mut count = 0;
for value in ac.find_iter(value.as_bytes()) {
if num == n {
if count == limit {
break;
}
unsafe {
*matches.offset(2*num as isize) = value.start();
*matches.offset((2*num+1) as isize) = value.end();
}
num += 1;

matches.push(value.pattern().as_usize());
matches.push(value.start());
matches.push(value.end());

count += 1;
}

return num
let b = matches.into_boxed_slice();
let ptr = b.as_ptr();
let len = b.len(); // Same as count since into_boxed_slice() truncates
std::mem::forget(b);
*num = len;
return ptr;
}

#[no_mangle]
pub extern "C" fn matches_delete(ptr: *const usize, len: usize) {
unsafe {
let _ = Vec::from_raw_parts(ptr as *mut usize, len, len);
}
}

extern "C" {
Expand Down
Binary file modified wasm/aho_corasick.wasm
Binary file not shown.
Binary file modified wasm/libaho_corasick.a
Binary file not shown.

0 comments on commit 2592912

Please sign in to comment.