diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 524d6c49..e6ed8331 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -471,6 +471,26 @@ cargo package --list --allow-dirty If you want to run benchmarks against third-party implementations, check out the [`ashvardanian/memchr_vs_stringzilla`](https://github.com/ashvardanian/memchr_vs_stringzilla/) repository. +## Contributing in Go + +```bash +export GO111MODULE="off" +go run scripts/test.go +go run scripts/bench.go +``` + +To run locally import with a relative path + +```bash + sz "../StringZilla/go/stringzilla" +``` + +And turn off GO111MODULE + +```bash +export GO111MODULE="off" +``` + ## General Recommendations ### Operations Not Worth Optimizing diff --git a/go/stringzilla/main.go b/go/stringzilla/main.go new file mode 100644 index 00000000..b9b566c3 --- /dev/null +++ b/go/stringzilla/main.go @@ -0,0 +1,127 @@ +package sz + +// #cgo CFLAGS: -g -mavx2 +// #include +// #include <../../include/stringzilla/stringzilla.h> +import "C" + +// -Wall -O3 + +import ( + "unsafe" +) + +/* +// Passing a C function pointer around in go isn't working +//type searchFunc func(*C.char, C.ulong, *C.char, C.ulong)C.sz_cptr_t +//func _search( str string, pat string, searchFunc func(*C.char, C.ulong, *C.char, C.ulong)C.sz_cptr_t) uintptr { +func _search( str string, pat string, searchFunc C.sz_find_t ) uintptr { + cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) + strlen := len(str) + patlen := len(pat) + ret := unsafe.Pointer( searchFunc(cstr, C.ulong(strlen), cpat, C.ulong(patlen)) ) + return ret +} +*/ + +func Contains(str string, pat string) bool { + cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) + strlen := len(str) + patlen := len(pat) + ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + //ret := _search( str, pat, C.sz_find_t(C.sz_find) ) + return ret != nil +} + +func Index(str string, pat string) int64 { + cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) + strlen := len(str) + patlen := len(pat) + ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + if ret == nil { + return 0 + } + return int64(uintptr(ret) - uintptr(unsafe.Pointer(cstr))) +} + +func Find(str string, pat string) int64 { + cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) + strlen := len(str) + patlen := len(pat) + ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + if ret == nil { + return -1 + } + return int64(uintptr(ret) - uintptr(unsafe.Pointer(cstr))) +} + +func LastIndex(str string, pat string) int64 { + cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) + strlen := len(str) + patlen := len(pat) + ret := unsafe.Pointer(C.sz_rfind(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + if ret == nil { + return -1 + } + return int64(uintptr(ret) - uintptr(unsafe.Pointer(cstr))) +} +func RFind(str string, pat string) int64 { + return LastIndex(str, pat) +} + +func IndexAny(str string, charset string) int64 { + cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(charset))) + strlen := len(str) + patlen := len(charset) + ret := unsafe.Pointer(C.sz_find_char_from(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + if ret == nil { + return -1 + } + return int64(uintptr(ret) - uintptr(unsafe.Pointer(cstr))) +} +func FindCharFrom(str string, charset string) int64 { + return IndexAny(str, charset) +} + +func Count(str string, pat string, overlap bool) int64 { + cstr := (*C.char)(unsafe.Pointer(unsafe.StringData(str))) + cpat := (*C.char)(unsafe.Pointer(unsafe.StringData(pat))) + strlen := int64(len(str)) + patlen := int64(len(pat)) + + if strlen == 0 || patlen == 0 || strlen < patlen { + return 0 + } + + count := int64(0) + if overlap == true { + for strlen > 0 { + ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + if ret == nil { + break + } + count += 1 + strlen -= (1 + int64(uintptr(ret)-uintptr(unsafe.Pointer(cstr)))) + cstr = (*C.char)(unsafe.Add(ret, 1)) + } + } else { + for strlen > 0 { + ret := unsafe.Pointer(C.sz_find(cstr, C.ulong(strlen), cpat, C.ulong(patlen))) + if ret == nil { + break + } + count += 1 + strlen -= (patlen + int64(uintptr(ret)-uintptr(unsafe.Pointer(cstr)))) + cstr = (*C.char)(unsafe.Add(ret, patlen)) + } + } + + return count + +} diff --git a/scripts/bench.go b/scripts/bench.go new file mode 100644 index 00000000..a63380ea --- /dev/null +++ b/scripts/bench.go @@ -0,0 +1,70 @@ +package main + +import ( + "fmt" + "strings" + "time" + + sz "../go/stringzilla" +) + +func main() { + + str := strings.Repeat("0123456789", 10000) + "something" + pat := "some" + + fmt.Println("Contains") + t := time.Now() + for i := 0; i < 1; i++ { + strings.Contains(str, pat) + } + fmt.Println(" ", time.Since(t), "\tstrings.Contains") + + t = time.Now() + for i := 0; i < 1; i++ { + sz.Contains(str, pat) + } + fmt.Println(" ", time.Since(t), "\tsz.Contains") + + fmt.Println("Index") + t = time.Now() + for i := 0; i < 1; i++ { + strings.Index(str, pat) + } + fmt.Println(" ", time.Since(t), "\tstrings.Index") + + t = time.Now() + for i := 0; i < 1; i++ { + sz.Index(str, pat) + } + fmt.Println(" ", time.Since(t), "\tsz.Index") + + fmt.Println("IndexAny") + t = time.Now() + for i := 0; i < 1; i++ { + strings.IndexAny(str, pat) + } + fmt.Println(" ", time.Since(t), "\tstrings.IndexAny") + + t = time.Now() + for i := 0; i < 1; i++ { + sz.IndexAny(str, pat) + } + fmt.Println(" ", time.Since(t), "\tsz.IndexAny") + + str = strings.Repeat("0123456789", 100000) + "something" + pat = "123456789" + fmt.Println("Count") + t = time.Now() + for i := 0; i < 1; i++ { + strings.Count(str, pat) + } + fmt.Println(" ", time.Since(t), "\tstrings.Count") + + t = time.Now() + for i := 0; i < 1; i++ { + sz.Count(str, pat, false) + } + fmt.Println(" ", time.Since(t), "\tsz.Count") + +} diff --git a/scripts/test.go b/scripts/test.go new file mode 100644 index 00000000..07faa6f8 --- /dev/null +++ b/scripts/test.go @@ -0,0 +1,59 @@ +package main + +import ( + "fmt" + "runtime" + "strings" + + sz "../go/stringzilla" +) + +func assertEqual[T comparable](act T, exp T) int { + if exp == act { + return 0 + } + _, _, line, _ := runtime.Caller(1) + fmt.Println("") + fmt.Println(" ERROR line ", line, " expected (", exp, ") is not equal to actual (", act, ")") + return 1 +} + +func main() { + + str := strings.Repeat("0123456789", 100000) + "something" + pat := "some" + ret := 0 + + fmt.Print("Contains ... ") + ret |= assertEqual(sz.Contains("", ""), true) + ret |= assertEqual(sz.Contains("test", ""), true) + ret |= assertEqual(sz.Contains("test", "s"), true) + ret |= assertEqual(sz.Contains("test", "test"), true) + ret |= assertEqual(sz.Contains("test", "zest"), false) + ret |= assertEqual(sz.Contains("test", "z"), false) + if ret == 0 { + fmt.Println("successful") + } + + fmt.Print("Index ... ") + assertEqual(strings.Index(str, pat), int(sz.Index(str, pat))) + assertEqual(sz.Index("", ""), 0) + assertEqual(sz.Index("test", ""), 0) + assertEqual(sz.Index("test", "t"), 0) + assertEqual(sz.Index("test", "s"), 2) + fmt.Println("successful") + + fmt.Print("IndexAny ... ") + assertEqual(strings.IndexAny(str, pat), int(sz.IndexAny(str, pat))) + assertEqual(sz.IndexAny("test", "st"), 0) + assertEqual(sz.IndexAny("west east", "ta"), 3) + fmt.Println("successful") + + fmt.Print("Count ... ") + //assertEqual( strings.Count( str, pat ), int(sz.Count( str,pat,false )) ) + assertEqual(sz.Count("aaaaa", "a", false), 5) + assertEqual(sz.Count("aaaaa", "aa", false), 2) + assertEqual(sz.Count("aaaaa", "aa", true), 4) + fmt.Println("successful") + +}