diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index c9d6a95..fcd7677 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -457,6 +457,28 @@ cargo package --list --allow-dirty
 
 If you want to run benchmarks against third-party implementations, check out the [`ashvardanian/memchr_vs_stringzilla`](https://github.com/ashvardanian/memchr_vs_stringzilla/) repository.
 
+## Contributing in Go
+
+The Go scripts import the binding through a relative path, so turn Go modules off before running them:
+
+```bash
+export GO111MODULE="off"
+go run scripts/test.go
+go run scripts/bench.go
+```
+
+To use the binding from a local checkout in your own code, import it with a relative path:
+
+```go
+import sz "../StringZilla/go/stringzilla"
+```
+
+Relative imports are not supported in module mode, so turn off `GO111MODULE` there as well:
+
+```bash
+export GO111MODULE="off"
+```
+
 ## General Performance Observations
 
 ### Unaligned Loads
diff --git a/go/stringzilla/main.go b/go/stringzilla/main.go
new file mode 100644
index 0000000..fdaca00
--- /dev/null
+++ b/go/stringzilla/main.go
@@ -0,0 +1,121 @@
+// Package sz provides Go bindings for the StringZilla C library. Function
+// names follow the standard strings package, with StringZilla-style aliases.
+package sz
+
+// #cgo CFLAGS: -g -mavx2
+// #include <stdlib.h>
+// #include <../../include/stringzilla/stringzilla.h>
+import "C"
+
+// Compiler flags still to be evaluated: -Wall -O3
+
+import (
+	"unsafe"
+)
+
+// NOTE: the C search function is deliberately not passed around as a parameter
+// (for example as a C.sz_find_t): cgo cannot call through C function pointers,
+// so every wrapper below names the C function it calls.
+
+// cptr returns a C pointer to the first byte of s. It is only called with
+// non-empty strings, and the result is only used while s is still referenced.
+func cptr(s string) *C.char {
+	return (*C.char)(unsafe.Pointer(unsafe.StringData(s)))
+}
+
+// offset converts a match pointer returned by a StringZilla search into a byte
+// offset relative to base, or -1 when the search returned NULL (no match).
+func offset(match unsafe.Pointer, base *C.char) int64 {
+	if match == nil {
+		return -1
+	}
+	return int64(uintptr(match) - uintptr(unsafe.Pointer(base)))
+}
+
+// Contains reports whether pat occurs within str. As with strings.Contains,
+// an empty pat is contained in every string.
+func Contains(str string, pat string) bool {
+	return Index(str, pat) >= 0
+}
+
+// Index returns the byte offset of the first occurrence of pat in str, or -1
+// if pat is not present. As with strings.Index, an empty pat matches at 0.
+func Index(str string, pat string) int64 {
+	if len(pat) == 0 {
+		return 0
+	}
+	if len(str) < len(pat) {
+		return -1
+	}
+	cstr := cptr(str)
+	ret := unsafe.Pointer(C.sz_find(cstr, C.sz_size_t(len(str)), cptr(pat), C.sz_size_t(len(pat))))
+	return offset(ret, cstr)
+}
+
+// Find is the StringZilla-style alias of Index.
+func Find(str string, pat string) int64 {
+	return Index(str, pat)
+}
+
+// LastIndex returns the byte offset of the last occurrence of pat in str, or
+// -1 if pat is not present. As with strings.LastIndex, an empty pat matches at
+// len(str).
+func LastIndex(str string, pat string) int64 {
+	if len(pat) == 0 {
+		return int64(len(str))
+	}
+	if len(str) < len(pat) {
+		return -1
+	}
+	cstr := cptr(str)
+	ret := unsafe.Pointer(C.sz_rfind(cstr, C.sz_size_t(len(str)), cptr(pat), C.sz_size_t(len(pat))))
+	return offset(ret, cstr)
+}
+
+// RFind is the StringZilla-style alias of LastIndex.
+func RFind(str string, pat string) int64 {
+	return LastIndex(str, pat)
+}
+
+// IndexAny returns the offset of the first character of str that also appears
+// in charset, or -1 if there is none or either argument is empty.
+// NOTE(review): sz_find_char_from presumably matches individual bytes, whereas
+// strings.IndexAny matches Unicode code points - confirm for non-ASCII input.
+func IndexAny(str string, charset string) int64 {
+	if len(str) == 0 || len(charset) == 0 {
+		return -1
+	}
+	cstr := cptr(str)
+	ret := unsafe.Pointer(C.sz_find_char_from(cstr, C.sz_size_t(len(str)), cptr(charset), C.sz_size_t(len(charset))))
+	return offset(ret, cstr)
+}
+
+// FindCharFrom is the StringZilla-style alias of IndexAny.
+func FindCharFrom(str string, charset string) int64 {
+	return IndexAny(str, charset)
+}
+
+// Count returns the number of occurrences of pat in str. When overlap is true
+// the search resumes one byte after the start of each match, so overlapping
+// matches are counted; otherwise it resumes after the end of the match.
+// An empty pat always yields 0 (unlike strings.Count).
+func Count(str string, pat string, overlap bool) int64 {
+	if len(pat) == 0 {
+		return 0
+	}
+	step := len(pat)
+	if overlap {
+		step = 1
+	}
+	count := int64(0)
+	for {
+		i := Index(str, pat)
+		if i < 0 {
+			return count
+		}
+		count++
+		// i+step never exceeds len(str): the match ends at i+len(pat) and
+		// step is at most len(pat).
+		str = str[int(i)+step:]
+	}
+}
diff --git a/scripts/bench.go b/scripts/bench.go
new file mode 100644
index 0000000..680736b
--- /dev/null
+++ b/scripts/bench.go
@@ -0,0 +1,52 @@
+package main
+
+import (
+	"fmt"
+	"strings"
+	"time"
+
+	sz "../go/stringzilla"
+)
+
+// iterations is how many times every benchmarked call is repeated, so that
+// timer noise averages out instead of dominating a single measurement.
+const iterations = 100
+
+// Sinks keep the results of the benchmarked calls observable.
+var (
+	sinkBool bool
+	sinkInt  int64
+)
+
+// measure runs fn iterations times and prints the average duration of one call
+// followed by label.
+func measure(label string, fn func()) {
+	start := time.Now()
+	for i := 0; i < iterations; i++ {
+		fn()
+	}
+	fmt.Println(" ", time.Since(start)/iterations, "\t"+label)
+}
+
+func main() {
+	str := strings.Repeat("0123456789", 10000) + "something"
+	pat := "some"
+
+	fmt.Println("Contains")
+	measure("strings.Contains", func() { sinkBool = strings.Contains(str, pat) })
+	measure("sz.Contains", func() { sinkBool = sz.Contains(str, pat) })
+
+	fmt.Println("Index")
+	measure("strings.Index", func() { sinkInt = int64(strings.Index(str, pat)) })
+	measure("sz.Index", func() { sinkInt = sz.Index(str, pat) })
+
+	fmt.Println("IndexAny")
+	measure("strings.IndexAny", func() { sinkInt = int64(strings.IndexAny(str, pat)) })
+	measure("sz.IndexAny", func() { sinkInt = sz.IndexAny(str, pat) })
+
+	str = strings.Repeat("0123456789", 100000) + "something"
+	pat = "123456789"
+	fmt.Println("Count")
+	measure("strings.Count", func() { sinkInt = int64(strings.Count(str, pat)) })
+	measure("sz.Count", func() { sinkInt = sz.Count(str, pat, false) })
+}
diff --git a/scripts/test.go b/scripts/test.go
new file mode 100644
index 0000000..8c31de5
--- /dev/null
+++ b/scripts/test.go
@@ -0,0 +1,76 @@
+package main
+
+import (
+	"fmt"
+	"os"
+	"runtime"
+	"strings"
+
+	sz "../go/stringzilla"
+)
+
+// assertEqual compares the actual value act with the expected value exp. On a
+// mismatch it prints the caller's line number with both values and returns 1;
+// otherwise it returns 0, so results can be summed into a failure count.
+func assertEqual[T comparable](act T, exp T) int {
+	if exp == act {
+		return 0
+	}
+	_, _, line, _ := runtime.Caller(1)
+	fmt.Println("")
+	fmt.Println(" ERROR line ", line, " expected (", exp, ") is not equal to actual (", act, ")")
+	return 1
+}
+
+// report prints the verdict for a finished section and returns its failure
+// count unchanged so the caller can add it to the overall total.
+func report(failures int) int {
+	if failures == 0 {
+		fmt.Println("successful")
+	} else {
+		fmt.Println("FAILED")
+	}
+	return failures
+}
+
+func main() {
+	str := strings.Repeat("0123456789", 100000) + "something"
+	pat := "some"
+	total := 0
+
+	fmt.Print("Contains ... ")
+	failed := assertEqual(sz.Contains("", ""), true)
+	failed += assertEqual(sz.Contains("test", ""), true)
+	failed += assertEqual(sz.Contains("test", "s"), true)
+	failed += assertEqual(sz.Contains("test", "test"), true)
+	failed += assertEqual(sz.Contains("test", "zest"), false)
+	failed += assertEqual(sz.Contains("test", "z"), false)
+	total += report(failed)
+
+	fmt.Print("Index ... ")
+	failed = assertEqual(int(sz.Index(str, pat)), strings.Index(str, pat))
+	failed += assertEqual(sz.Index("", ""), 0)
+	failed += assertEqual(sz.Index("test", ""), 0)
+	failed += assertEqual(sz.Index("test", "t"), 0)
+	failed += assertEqual(sz.Index("test", "s"), 2)
+	failed += assertEqual(sz.Index("test", "z"), -1)
+	total += report(failed)
+
+	fmt.Print("IndexAny ... ")
+	failed = assertEqual(int(sz.IndexAny(str, pat)), strings.IndexAny(str, pat))
+	failed += assertEqual(sz.IndexAny("test", "st"), 0)
+	failed += assertEqual(sz.IndexAny("west east", "ta"), 3)
+	total += report(failed)
+
+	fmt.Print("Count ... ")
+	//assertEqual( strings.Count( str, pat ), int(sz.Count( str,pat,false )) )
+	failed = assertEqual(sz.Count("aaaaa", "a", false), 5)
+	failed += assertEqual(sz.Count("aaaaa", "aa", false), 2)
+	failed += assertEqual(sz.Count("aaaaa", "aa", true), 4)
+	total += report(failed)
+
+	// Exit non-zero so scripts and CI notice failing assertions.
+	if total != 0 {
+		os.Exit(1)
+	}
+}