diff --git a/go.mod b/go.mod index 92823267f..b79cd31dc 100644 --- a/go.mod +++ b/go.mod @@ -9,40 +9,40 @@ require ( github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc // indirect + github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect - github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect - github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/locafero v0.6.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect - github.com/spf13/cast v1.6.0 // indirect + github.com/spf13/cast v1.7.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect - go.uber.org/atomic v1.9.0 // indirect - go.uber.org/multierr v1.9.0 // indirect - golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect - golang.org/x/exp v0.0.0-20231006140011-7918f672742d // indirect - golang.org/x/sys v0.22.0 // indirect - golang.org/x/text v0.16.0 // indirect + go.uber.org/atomic v1.11.0 // indirect + go.uber.org/multierr v1.11.0 // indirect + golang.org/x/arch v0.10.0 // indirect + golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/text v0.18.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) require ( github.com/axiomhq/hyperloglog v0.2.0 - github.com/bytedance/sonic v1.12.1 - github.com/cespare/xxhash/v2 v2.2.0 + github.com/bytedance/sonic v1.12.3 + github.com/cespare/xxhash/v2 v2.3.0 github.com/cockroachdb/swiss v0.0.0-20240612210725-f4de07ae6964 - github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 + github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 github.com/google/btree v1.1.3 github.com/google/go-cmp v0.6.0 @@ -50,10 +50,10 @@ require ( github.com/ohler55/ojg v1.24.0 github.com/pelletier/go-toml/v2 v2.2.3 github.com/rs/xid v1.6.0 - github.com/rs/zerolog v1.30.0 + github.com/rs/zerolog v1.33.0 github.com/spf13/viper v1.19.0 github.com/stretchr/testify v1.9.0 github.com/twmb/murmur3 v1.1.8 github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 - golang.org/x/crypto v0.25.0 + golang.org/x/crypto v0.27.0 ) diff --git a/go.sum b/go.sum index ae7d60a9e..1a1d21646 100644 --- a/go.sum +++ b/go.sum @@ -8,11 +8,15 @@ github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/bytedance/sonic v1.12.1 h1:jWl5Qz1fy7X1ioY74WqO0KjAMtAGQs4sYnjiEBiyX24= github.com/bytedance/sonic v1.12.1/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= +github.com/bytedance/sonic v1.12.3 h1:W2MGa7RCU1QTeYRTPE3+88mVC0yXmsRQRChiyVocVjU= +github.com/bytedance/sonic v1.12.3/go.mod h1:B8Gt/XvtZ3Fqj+iSKMypzymZxw/FVwgIGKzMzT9r/rk= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/bytedance/sonic/loader v0.2.0 h1:zNprn+lsIP06C/IqCHs3gPQIvnvpKbbxyXQP1iU4kWM= github.com/bytedance/sonic/loader v0.2.0/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= @@ -26,8 +30,12 @@ github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y= github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da h1:aIftn67I1fkbMa512G+w+Pxci9hJPB8oMnkcP3iZF38= +github.com/dgryski/go-farm v0.0.0-20240924180020-3414d57e47da/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc h1:8WFBn63wegobsYAX0YjD+8suexZDga5CctH4CCTx2+8= github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw= +github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140 h1:y7y0Oa6UawqTFPCDw9JG6pdKt4F9pAhHv0B7FMGaGD0= +github.com/dgryski/go-metro v0.0.0-20211217172704-adc40b04c140/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dicedb/go-dice v0.0.0-20240820180649-d97f15fca831 h1:Cqyj9WCtoobN6++bFbDSe27q94SPwJD9Z0wmu+SDRuk= @@ -47,6 +55,8 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= +github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= @@ -61,10 +71,14 @@ github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27k github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/ohler55/ojg v1.24.0 h1:y2AVez6fPTszK/jPhaAYMCAzAoSleConMqSDD5wJKJg= github.com/ohler55/ojg v1.24.0/go.mod h1:gQhDVpQLqrmnd2eqGAvJtn+NfKoYJbe/A4Sj3/Vro4o= +github.com/ohler55/ojg v1.24.1 h1:PaVLelrNgT5/0ppPaUtey54tOVp245z33fkhL2jljjY= +github.com/ohler55/ojg v1.24.1/go.mod h1:gQhDVpQLqrmnd2eqGAvJtn+NfKoYJbe/A4Sj3/Vro4o= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -78,8 +92,12 @@ github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU= github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.30.0 h1:SymVODrcRsaRaSInD9yQtKbtWqwsfoPcRff/oRXLj4c= github.com/rs/zerolog v1.30.0/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= +github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= +github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/locafero v0.6.0 h1:ON7AQg37yzcRPU69mt7gwhFEBwxI6P9T4Qu3N51bwOk= +github.com/sagikazarmark/locafero v0.6.0/go.mod h1:77OmuIc6VTraTXKXIs/uvUxKGUXjE1GbemJYHqdNjX0= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= @@ -88,6 +106,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cast v1.7.0 h1:ntdiHjuueXFgm5nzDRdOS4yfT43P5Fnud6DH50rz/7w= +github.com/spf13/cast v1.7.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI= @@ -112,22 +132,38 @@ github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 h1:zzrxE1FKn5ryB github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2/go.mod h1:hzfGeIUDq/j97IG+FhNqkowIyEcD88LrW6fyU3K3WqY= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.10.0 h1:S3huipmSclq3PJMNe76NGwkBR504WFkQ5dhzWzP8ZW8= +golang.org/x/arch v0.10.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 h1:e66Fs6Z+fZTbFBAxKfP3PALWBtpfqks2bwGcexMxgtk= +golang.org/x/exp v0.0.0-20240909161429-701f63a606c0/go.mod h1:2TbTHSBQa924w8M6Xs1QcRcFwyucIwBGpK1p2f1YFFY= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= +golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/integration_tests/commands/async/bit_ops_string_int_test.go b/integration_tests/commands/async/bit_ops_string_int_test.go index 581bb4423..94c1c0e75 100644 --- a/integration_tests/commands/async/bit_ops_string_int_test.go +++ b/integration_tests/commands/async/bit_ops_string_int_test.go @@ -34,8 +34,8 @@ func TestBitOpsString(t *testing.T) { } testCases := []struct { - name string - cmds []string + name string + cmds []string expected []interface{} assertType []string }{ @@ -149,13 +149,13 @@ func TestBitOpsString(t *testing.T) { { name: "BITOP XOR of keys containing strings and get the destkey", cmds: []string{"MSET foo foobar baz abcdef", "BITOP XOR bazz foo baz", "GET bazz"}, - expected: []interface{}{"OK", int64(6), "\a\r\x0c\x06\x04\x14"}, + expected: []interface{}{"OK", int64(6), "\x07\x0d\x0c\x06\x04\x14"}, assertType: []string{"equal", "equal", "equal"}, }, { name: "BITOP XOR of keys containing strings and a bytearray and get the destkey", cmds: []string{"MSET foo foobar baz abcdef", "SETBIT bazz 8 1", "BITOP XOR bazzz foo baz bazz", "GET bazzz", "SETBIT bazz 8 0", "SETBIT bazz 49 1", "BITOP XOR bazzz foo baz bazz", "GET bazzz", "Setbit bazz 49 0", "bitop xor bazzz foo baz bazz", "get bazzz"}, - expected: []interface{}{"OK", int64(0), int64(6), "\a\x8d\x0c\x06\x04\x14", int64(1), int64(0), int64(7), "\a\r\x0c\x06\x04\x14@", int64(1), int64(7), "\a\r\x0c\x06\x04\x14\x00"}, + expected: []interface{}{"OK", int64(0), int64(6), "\x07\x8d\x0c\x06\x04\x14", int64(1), int64(0), int64(7), "\x07\r\x0c\x06\x04\x14@", int64(1), int64(7), "\x07\r\x0c\x06\x04\x14\x00"}, assertType: []string{"equal", "equal", "equal", "equal", "equal", "equal", "equal", "equal", "equal", "equal", "equal"}, }, { @@ -175,6 +175,7 @@ func TestBitOpsString(t *testing.T) { FireCommand(conn, "DEL bazzz") for i := 0; i < len(tc.cmds); i++ { res := FireCommand(conn, tc.cmds[i]) + switch tc.assertType[i] { case "equal": assert.Equal(t, res, tc.expected[i]) diff --git a/integration_tests/commands/http/qwatch_test.go b/integration_tests/commands/http/qwatch_test.go index 7755692ac..35aee9336 100644 --- a/integration_tests/commands/http/qwatch_test.go +++ b/integration_tests/commands/http/qwatch_test.go @@ -33,11 +33,10 @@ func TestQWatch(t *testing.T) { {Command: "QWATCH", Body: map[string]interface{}{"query": qWatchQuery}}, }, expected: []interface{}{ - []interface{}{ - "qwatch", - "SELECT $key, $value WHERE $key like 'match:100:*' and $value > 10 ORDER BY $value desc LIMIT 3", - // Empty array, as the initial result will be empty - []interface{}{}, + map[string]interface{}{ + "cmd": "qwatch", + "query": "SELECT $key, $value WHERE $key like 'match:100:*' and $value > 10 ORDER BY $value desc LIMIT 3", + "data": []interface{}{}, }, }, errorExpected: false, @@ -92,10 +91,10 @@ func TestQwatchWithSSE(t *testing.T) { decoder := json.NewDecoder(resp.Body) expectedResponses := []interface{}{ - []interface{}{ - "qwatch", - "SELECT $key, $value WHERE $key like 'match:100:*' and $value > 10 ORDER BY $value desc LIMIT 3", - []interface{}{}, + map[string]interface{}{ + "cmd": "qwatch", + "query": "SELECT $key, $value WHERE $key like 'match:100:*' and $value > 10 ORDER BY $value desc LIMIT 3", + "data": []interface{}{}, }, map[string]interface{}{ "cmd": "qwatch", diff --git a/integration_tests/commands/http/setup.go b/integration_tests/commands/http/setup.go index 7ada5d5ca..b0c20bf28 100644 --- a/integration_tests/commands/http/setup.go +++ b/integration_tests/commands/http/setup.go @@ -13,6 +13,8 @@ import ( "sync" "time" + "github.com/dicedb/dice/internal/server/utils" + "github.com/dicedb/dice/config" derrors "github.com/dicedb/dice/internal/errors" "github.com/dicedb/dice/internal/querymanager" @@ -73,6 +75,15 @@ func (e *HTTPCommandExecutor) FireCommand(cmd HTTPCommand) (interface{}, error) } defer resp.Body.Close() + if cmd.Command != "QWATCH" { + var result utils.HTTPResponse + err = json.NewDecoder(resp.Body).Decode(&result) + if err != nil { + return nil, err + } + + return result.Data, nil + } var result interface{} err = json.NewDecoder(resp.Body).Decode(&result) if err != nil { diff --git a/integration_tests/commands/resp/abort/server_abort_test.go b/integration_tests/commands/resp/abort/server_abort_test.go new file mode 100644 index 000000000..5d75443a6 --- /dev/null +++ b/integration_tests/commands/resp/abort/server_abort_test.go @@ -0,0 +1,130 @@ +package abort + +import ( + "context" + "fmt" + "github.com/dicedb/dice/integration_tests/commands/resp" + "net" + "sync" + "testing" + "time" + + "github.com/dicedb/dice/config" +) + +var testServerOptions = resp.TestServerOptions{ + Port: 8740, +} + +func TestAbortCommand(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer ctx.Done() + t.Cleanup(cancel) + + var wg sync.WaitGroup + resp.RunTestServer(&wg, testServerOptions) + + time.Sleep(2 * time.Second) + + // Test 1: Ensure the server is running + t.Run("ServerIsRunning", func(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", config.DiceConfig.Server.Port)) + if err != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + conn.Close() + }) + + //Test 2: Send ABORT command and check if the server shuts down + t.Run("AbortCommandShutdown", func(t *testing.T) { + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", config.DiceConfig.Server.Port)) + if err != nil { + t.Fatalf("Failed to connect to server: %v", err) + } + defer conn.Close() + + // Send ABORT command + result := resp.FireCommand(conn, "ABORT") + if result != "OK" { + t.Fatalf("Unexpected response to ABORT command: %v", result) + } + + // Wait for the server to shut down + time.Sleep(1 * time.Second) + + // Try to connect again, it should fail + _, err = net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", config.DiceConfig.Server.Port)) + if err == nil { + t.Fatal("Server did not shut down as expected") + } + }) + + // Test 3: Ensure the server port is released + t.Run("PortIsReleased", func(t *testing.T) { + // Try to bind to the same port + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", config.DiceConfig.Server.Port)) + if err != nil { + t.Fatalf("Port should be available after server shutdown: %v", err) + } + listener.Close() + }) + + wg.Wait() +} + +func TestServerRestartAfterAbort(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer ctx.Done() + t.Cleanup(cancel) + + // start test server. + var wg sync.WaitGroup + resp.RunTestServer(&wg, testServerOptions) + + time.Sleep(1 * time.Second) + + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", config.DiceConfig.Server.Port)) + if err != nil { + t.Fatalf("Server should be running after restart: %v", err) + } + + // Send ABORT command to shut down server + result := resp.FireCommand(conn, "ABORT") + if result != "OK" { + t.Fatalf("Unexpected response to ABORT command: %v", result) + } + conn.Close() + + // wait for the server to shutdown + time.Sleep(2 * time.Second) + + wg.Wait() + + // restart server + ctx2, cancel2 := context.WithCancel(context.Background()) + defer ctx2.Done() + t.Cleanup(cancel2) + + // start test server. + // use different waitgroups and contexts to avoid race conditions.; + var wg2 sync.WaitGroup + resp.RunTestServer(&wg2, testServerOptions) + + // wait for the server to start up + time.Sleep(2 * time.Second) + + // Check if the server is running + conn2, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", config.DiceConfig.Server.Port)) + if err != nil { + t.Fatalf("Server should be running after restart: %v", err) + } + + // Clean up + result = resp.FireCommand(conn2, "ABORT") + if result != "OK" { + t.Fatalf("Unexpected response to ABORT command: %v", result) + } + conn2.Close() + + wg2.Wait() +} diff --git a/integration_tests/commands/resp/command_getkeys_test.go b/integration_tests/commands/resp/command_getkeys_test.go new file mode 100644 index 000000000..e17251b50 --- /dev/null +++ b/integration_tests/commands/resp/command_getkeys_test.go @@ -0,0 +1,50 @@ +package resp + +import ( + "testing" + + "gotest.tools/v3/assert" +) + +var getKeysTestCases = []struct { + name string + inCmd string + expected interface{} +}{ + {"Set command", "set 1 2 3 4", []interface{}{"1"}}, + {"Get command", "get key", []interface{}{"key"}}, + {"TTL command", "ttl key", []interface{}{"key"}}, + {"Del command", "del 1 2 3 4 5 6", []interface{}{"1", "2", "3", "4", "5", "6"}}, + {"MSET command", "MSET key1 val1 key2 val2", []interface{}{"key1", "key2"}}, + {"Expire command", "expire key time extra", []interface{}{"key"}}, + {"BFINIT command", "BFINIT bloom some parameters", []interface{}{"bloom"}}, + {"Ping command", "ping", "ERR the command has no key arguments"}, + {"Invalid Get command", "get", "ERR invalid number of arguments specified for command"}, + {"Abort command", "abort", "ERR the command has no key arguments"}, + {"Invalid command", "NotValidCommand", "ERR invalid command specified"}, + {"Wrong number of arguments", "", "ERR wrong number of arguments for 'command|getkeys' command"}, +} + +func TestCommandGetKeys(t *testing.T) { + conn := getLocalConnection() + defer conn.Close() + + for _, tc := range getKeysTestCases { + t.Run(tc.name, func(t *testing.T) { + result := FireCommand(conn, "COMMAND GETKEYS "+tc.inCmd) + assert.DeepEqual(t, tc.expected, result) + }) + } +} + +func BenchmarkGetKeysMatch(b *testing.B) { + conn := getLocalConnection() + defer conn.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, tc := range getKeysTestCases { + FireCommand(conn, "COMMAND GETKEYS "+tc.inCmd) + } + } +} diff --git a/integration_tests/commands/resp/command_info_test.go b/integration_tests/commands/resp/command_info_test.go new file mode 100644 index 000000000..56a580388 --- /dev/null +++ b/integration_tests/commands/resp/command_info_test.go @@ -0,0 +1,50 @@ +package resp + +import ( + "testing" + + "gotest.tools/v3/assert" +) + +var getInfoTestCases = []struct { + name string + inCmd string + expected interface{} +}{ + {"Set command", "SET", []interface{}{[]interface{}{"SET", int64(-3), int64(1), int64(0), int64(0)}}}, + {"Get command", "GET", []interface{}{[]interface{}{"GET", int64(2), int64(1), int64(0), int64(0)}}}, + {"Ping command", "PING", []interface{}{[]interface{}{"PING", int64(-1), int64(0), int64(0), int64(0)}}}, + {"Invalid command", "INVALID_CMD", []interface{}{string("(nil)")}}, + {"Combination of valid and Invalid command", "SET INVALID_CMD", []interface{}{ + []interface{}{"SET", int64(-3), int64(1), int64(0), int64(0)}, + string("(nil)"), + }}, + {"Combination of multiple valid commands", "SET GET", []interface{}{ + []interface{}{"SET", int64(-3), int64(1), int64(0), int64(0)}, + []interface{}{"GET", int64(2), int64(1), int64(0), int64(0)}, + }}, +} + +func TestCommandInfo(t *testing.T) { + conn := getLocalConnection() + defer conn.Close() + + for _, tc := range getInfoTestCases { + t.Run(tc.name, func(t *testing.T) { + result := FireCommand(conn, "COMMAND INFO "+tc.inCmd) + assert.DeepEqual(t, tc.expected, result) + }) + } +} + +func BenchmarkCommandInfo(b *testing.B) { + conn := getLocalConnection() + defer conn.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, tc := range getKeysTestCases { + FireCommand(conn, "COMMAND INFO "+tc.inCmd) + } + } +} diff --git a/integration_tests/commands/resp/get_test.go b/integration_tests/commands/resp/get_test.go new file mode 100644 index 000000000..bb826a499 --- /dev/null +++ b/integration_tests/commands/resp/get_test.go @@ -0,0 +1,39 @@ +package resp + +import ( + "testing" + "time" + + "gotest.tools/v3/assert" +) + +func TestGet(t *testing.T) { + conn := getLocalConnection() + defer conn.Close() + + testCases := []struct { + name string + cmds []string + expect []interface{} + delays []time.Duration + }{ + { + name: "Get with expiration", + cmds: []string{"SET k v EX 4", "GET k", "GET k"}, + expect: []interface{}{"OK", "v", "(nil)"}, + delays: []time.Duration{0, 0, 5 * time.Second}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for i, cmd := range tc.cmds { + if tc.delays[i] > 0 { + time.Sleep(tc.delays[i]) + } + result := FireCommand(conn, cmd) + assert.Equal(t, tc.expect[i], result, "Value mismatch for cmd %s", cmd) + } + }) + } +} diff --git a/integration_tests/commands/resp/getset_test.go b/integration_tests/commands/resp/getset_test.go new file mode 100644 index 000000000..af0e5f585 --- /dev/null +++ b/integration_tests/commands/resp/getset_test.go @@ -0,0 +1,57 @@ +package resp + +import ( + "testing" + "time" + + "gotest.tools/v3/assert" +) + +func TestGetSet(t *testing.T) { + conn := getLocalConnection() + defer conn.Close() + + testCases := []struct { + name string + cmds []string + expect []interface{} + delays []time.Duration + }{ + { + name: "GETSET with INCR", + cmds: []string{"INCR mycounter", "GETSET mycounter \"0\"", "GET mycounter"}, + expect: []interface{}{int64(1), int64(1), int64(0)}, + delays: []time.Duration{0, 0, 0}, + }, + { + name: "GETSET with SET", + cmds: []string{"SET mykey \"Hello\"", "GETSET mykey \"world\"", "GET mykey"}, + expect: []interface{}{"OK", "Hello", "world"}, + delays: []time.Duration{0, 0, 0}, + }, + { + name: "GETSET with TTL", + cmds: []string{"SET k v EX 60", "GETSET k v1", "TTL k"}, + expect: []interface{}{"OK", "v", int64(-1)}, + delays: []time.Duration{0, 0, 0}, + }, + { + name: "GETSET error when key exists but does not hold a string value", + cmds: []string{"LPUSH k1 \"somevalue\"", "GETSET k1 \"v1\""}, + expect: []interface{}{"OK", "WRONGTYPE Operation against a key holding the wrong kind of value"}, + delays: []time.Duration{0, 0, 0}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for i, cmd := range tc.cmds { + if tc.delays[i] > 0 { + time.Sleep(tc.delays[i]) + } + result := FireCommand(conn, cmd) + assert.Equal(t, tc.expect[i], result, "Value mismatch for cmd %s", cmd) + } + }) + } +} diff --git a/integration_tests/commands/resp/main_test.go b/integration_tests/commands/resp/main_test.go new file mode 100644 index 000000000..ec3573417 --- /dev/null +++ b/integration_tests/commands/resp/main_test.go @@ -0,0 +1,47 @@ +package resp + +import ( + "log/slog" + "os" + "sync" + "testing" + "time" + + "github.com/dicedb/dice/internal/logger" +) + +func TestMain(m *testing.M) { + logger := logger.New(logger.Opts{WithTimestamp: false}) + slog.SetDefault(logger) + + var wg sync.WaitGroup + // Run the test server + // This is a synchronous method, because internally it + // checks for available port and then forks a goroutine + // to start the server + opts := TestServerOptions{ + Port: 8739, + Logger: logger, + } + RunTestServer(&wg, opts) + + // Wait for the server to start + time.Sleep(2 * time.Second) + + conn := getLocalConnection() + if conn == nil { + panic("Failed to connect to the test server") + } + defer conn.Close() + + // Run the test suite + exitCode := m.Run() + + result := FireCommand(conn, "ABORT") + if result != "OK" { + panic("Failed to abort the server") + } + + wg.Wait() + os.Exit(exitCode) +} diff --git a/integration_tests/commands/resp/set_test.go b/integration_tests/commands/resp/set_test.go new file mode 100644 index 000000000..e6fa927c3 --- /dev/null +++ b/integration_tests/commands/resp/set_test.go @@ -0,0 +1,200 @@ +package resp + +import ( + "strconv" + "testing" + "time" + + "gotest.tools/v3/assert" +) + +type TestCase struct { + name string + commands []string + expected []interface{} +} + +func TestSet(t *testing.T) { + conn := getLocalConnection() + defer conn.Close() + + testCases := []TestCase{ + { + name: "Set and Get Simple Value", + commands: []string{"SET k v", "GET k"}, + expected: []interface{}{"OK", "v"}, + }, + { + name: "Set and Get Integer Value", + commands: []string{"SET k 123456789", "GET k"}, + expected: []interface{}{"OK", int64(123456789)}, + }, + { + name: "Overwrite Existing Key", + commands: []string{"SET k v1", "SET k 5", "GET k"}, + expected: []interface{}{"OK", "OK", int64(5)}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // deleteTestKeys([]string{"k"}, store) + FireCommand(conn, "DEL k") + + for i, cmd := range tc.commands { + result := FireCommand(conn, cmd) + assert.DeepEqual(t, tc.expected[i], result) + } + }) + } +} + +func TestSetWithOptions(t *testing.T) { + conn := getLocalConnection() + expiryTime := strconv.FormatInt(time.Now().Add(1*time.Minute).UnixMilli(), 10) + defer conn.Close() + + testCases := []TestCase{ + { + name: "Set with EX option", + commands: []string{"SET k v EX 2", "GET k", "SLEEP 3", "GET k"}, + expected: []interface{}{"OK", "v", "OK", "(nil)"}, + }, + { + name: "Set with PX option", + commands: []string{"SET k v PX 2000", "GET k", "SLEEP 3", "GET k"}, + expected: []interface{}{"OK", "v", "OK", "(nil)"}, + }, + { + name: "Set with EX and PX option", + commands: []string{"SET k v EX 2 PX 2000"}, + expected: []interface{}{"ERR syntax error"}, + }, + { + name: "XX on non-existing key", + commands: []string{"DEL k", "SET k v XX", "GET k"}, + expected: []interface{}{int64(0), "(nil)", "(nil)"}, + }, + { + name: "NX on non-existing key", + commands: []string{"DEL k", "SET k v NX", "GET k"}, + expected: []interface{}{int64(0), "OK", "v"}, + }, + { + name: "NX on existing key", + commands: []string{"DEL k", "SET k v NX", "GET k", "SET k v NX"}, + expected: []interface{}{int64(0), "OK", "v", "(nil)"}, + }, + { + name: "PXAT option", + commands: []string{"SET k v PXAT " + expiryTime, "GET k"}, + expected: []interface{}{"OK", "v"}, + }, + { + name: "PXAT option with delete", + commands: []string{"SET k1 v1 PXAT " + expiryTime, "GET k1", "SLEEP 2", "DEL k1"}, + expected: []interface{}{"OK", "v1", "OK", int64(1)}, + }, + { + name: "PXAT option with invalid unix time ms", + commands: []string{"SET k2 v2 PXAT 123123", "GET k2"}, + expected: []interface{}{"OK", "(nil)"}, + }, + { + name: "XX on existing key", + commands: []string{"SET k v1", "SET k v2 XX", "GET k"}, + expected: []interface{}{"OK", "OK", "v2"}, + }, + { + name: "Multiple XX operations", + commands: []string{"SET k v1", "SET k v2 XX", "SET k v3 XX", "GET k"}, + expected: []interface{}{"OK", "OK", "OK", "v3"}, + }, + { + name: "EX option", + commands: []string{"SET k v EX 1", "GET k", "SLEEP 2", "GET k"}, + expected: []interface{}{"OK", "v", "OK", "(nil)"}, + }, + { + name: "XX option", + commands: []string{"SET k v XX EX 1", "GET k", "SLEEP 2", "GET k", "SET k v XX EX 1", "GET k"}, + expected: []interface{}{"(nil)", "(nil)", "OK", "(nil)", "(nil)", "(nil)"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // deleteTestKeys([]string{"k", "k1", "k2"}, store) + FireCommand(conn, "DEL k") + FireCommand(conn, "DEL k1") + FireCommand(conn, "DEL k2") + for i, cmd := range tc.commands { + result := FireCommand(conn, cmd) + assert.Equal(t, tc.expected[i], result) + } + }) + } +} + +func TestSetWithExat(t *testing.T) { + conn := getLocalConnection() + defer conn.Close() + Etime := strconv.FormatInt(time.Now().Unix()+5, 10) + BadTime := "123123" + + t.Run("SET with EXAT", + func(t *testing.T) { + // deleteTestKeys([]string{"k"}, store) + FireCommand(conn, "DEL k") + assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+Etime), "Value mismatch for cmd SET k v EXAT "+Etime) + assert.Equal(t, "v", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") + assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 5, "Value mismatch for cmd TTL k") + time.Sleep(3 * time.Second) + assert.Assert(t, FireCommand(conn, "TTL k").(int64) <= 3, "Value mismatch for cmd TTL k") + time.Sleep(3 * time.Second) + assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") + assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Value mismatch for cmd TTL k") + }) + + t.Run("SET with invalid EXAT expires key immediately", + func(t *testing.T) { + // deleteTestKeys([]string{"k"}, store) + FireCommand(conn, "DEL k") + assert.Equal(t, "OK", FireCommand(conn, "SET k v EXAT "+BadTime), "Value mismatch for cmd SET k v EXAT "+BadTime) + assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") + assert.Equal(t, int64(-2), FireCommand(conn, "TTL k"), "Value mismatch for cmd TTL k") + }) + + t.Run("SET with EXAT and PXAT returns syntax error", + func(t *testing.T) { + // deleteTestKeys([]string{"k"}, store) + FireCommand(conn, "DEL k") + assert.Equal(t, "ERR syntax error", FireCommand(conn, "SET k v PXAT "+Etime+" EXAT "+Etime), "Value mismatch for cmd SET k v PXAT "+Etime+" EXAT "+Etime) + assert.Equal(t, "(nil)", FireCommand(conn, "GET k"), "Value mismatch for cmd GET k") + }) +} + +func TestWithKeepTTLFlag(t *testing.T) { + conn := getLocalConnection() + defer conn.Close() + + for _, tcase := range []TestCase{ + { + commands: []string{"SET k v EX 2", "SET k vv KEEPTTL", "GET k", "SET kk vv", "SET kk vvv KEEPTTL", "GET kk"}, + expected: []interface{}{"OK", "OK", "vv", "OK", "OK", "vvv"}, + }, + } { + for i := 0; i < len(tcase.commands); i++ { + cmd := tcase.commands[i] + out := tcase.expected[i] + assert.Equal(t, out, FireCommand(conn, cmd), "Value mismatch for cmd %s\n.", cmd) + } + } + + time.Sleep(2 * time.Second) + + cmd := "GET k" + out := "(nil)" + + assert.Equal(t, out, FireCommand(conn, cmd), "Value mismatch for cmd %s\n.", cmd) +} diff --git a/integration_tests/commands/resp/setup.go b/integration_tests/commands/resp/setup.go new file mode 100644 index 000000000..8bd0f1cde --- /dev/null +++ b/integration_tests/commands/resp/setup.go @@ -0,0 +1,165 @@ +package resp + +import ( + "context" + "errors" + "fmt" + "io" + "log/slog" + "net" + "os" + "sync" + "time" + + "github.com/dicedb/dice/internal/server/resp" + "github.com/dicedb/dice/internal/worker" + + "github.com/dicedb/dice/config" + "github.com/dicedb/dice/internal/clientio" + derrors "github.com/dicedb/dice/internal/errors" + "github.com/dicedb/dice/internal/logger" + "github.com/dicedb/dice/internal/shard" + dstore "github.com/dicedb/dice/internal/store" + "github.com/dicedb/dice/testutils" + redis "github.com/dicedb/go-dice" +) + +type TestServerOptions struct { + Port int + Logger *slog.Logger +} + +//nolint:unused +func getLocalConnection() net.Conn { + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", config.DiceConfig.Server.Port)) + if err != nil { + panic(err) + } + return conn +} + +// deleteTestKeys is a utility to delete a list of keys before running a test +// +//nolint:unused +func deleteTestKeys(keysToDelete []string, store *dstore.Store) { + for _, key := range keysToDelete { + store.Del(key) + } +} + +//nolint:unused +func getLocalSdk() *redis.Client { + return redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf(":%d", config.DiceConfig.Server.Port), + + DialTimeout: 10 * time.Second, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + ContextTimeoutEnabled: true, + + MaxRetries: -1, + + PoolSize: 10, + PoolTimeout: 30 * time.Second, + ConnMaxIdleTime: time.Minute, + }) +} + +func FireCommand(conn net.Conn, cmd string) interface{} { + var err error + args := testutils.ParseCommand(cmd) + _, err = conn.Write(clientio.Encode(args, false)) + if err != nil { + slog.Error( + "error while firing command", + slog.Any("error", err), + slog.String("command", cmd), + ) + os.Exit(1) + } + + rp := clientio.NewRESPParser(conn) + v, err := rp.DecodeOne() + if err != nil { + if err == io.EOF { + return nil + } + slog.Error( + "error while firing command", + slog.Any("error", err), + slog.String("command", cmd), + ) + os.Exit(1) + } + + return v +} + +//nolint:unused +func fireCommandAndGetRESPParser(conn net.Conn, cmd string) *clientio.RESPParser { + args := testutils.ParseCommand(cmd) + _, err := conn.Write(clientio.Encode(args, false)) + if err != nil { + slog.Error( + "error while firing command", + slog.Any("error", err), + slog.String("command", cmd), + ) + os.Exit(1) + } + + return clientio.NewRESPParser(conn) +} + +func RunTestServer(wg *sync.WaitGroup, opt TestServerOptions) { + logr := logger.New(logger.Opts{WithTimestamp: true}) + slog.SetDefault(logr) + config.DiceConfig.Network.IOBufferLength = 16 + config.DiceConfig.Server.WriteAOFOnCleanup = false + if opt.Port != 0 { + config.DiceConfig.Server.Port = opt.Port + } else { + config.DiceConfig.Server.Port = 8739 + } + + watchChan := make(chan dstore.QueryWatchEvent, config.DiceConfig.Server.KeysLimit) + gec := make(chan error) + shardManager := shard.NewShardManager(1, watchChan, gec, logr) + workerManager := worker.NewWorkerManager(20000, shardManager) + // Initialize the REST Server + testServer := resp.NewServer(shardManager, workerManager, gec, logr) + + ctx, cancel := context.WithCancel(context.Background()) + fmt.Println("Starting the test server on port", config.DiceConfig.Server.Port) + + shardManagerCtx, cancelShardManager := context.WithCancel(ctx) + wg.Add(1) + go func() { + defer wg.Done() + shardManager.Run(shardManagerCtx) + }() + + // Start the server in a goroutine + wg.Add(1) + go func() { + defer wg.Done() + if err := testServer.Run(ctx); err != nil { + if errors.Is(err, derrors.ErrAborted) { + cancelShardManager() + return + } + opt.Logger.Error("Test server encountered an error", slog.Any("error", err)) + os.Exit(1) + } + }() + + go func() { + for err := range gec { + if err != nil && errors.Is(err, derrors.ErrAborted) { + // if either the AsyncServer/RESPServer or the HTTPServer received an abort command, + // cancel the context, helping gracefully exiting all servers + cancel() + } + } + }() +} diff --git a/integration_tests/server/max_conn_test.go b/integration_tests/server/max_conn_test.go index 72467d7ff..18c09ee62 100644 --- a/integration_tests/server/max_conn_test.go +++ b/integration_tests/server/max_conn_test.go @@ -52,10 +52,7 @@ func TestMaxConnection(t *testing.T) { t.Fatalf("unexpected error while getting connection %d: %v", i, err) } } - assert.Equal(t, maxConnLimit, len(connections), "should have reached the max connection limit") - - _, err := getConnection(maxConnTestOptions.Port) - assert.ErrorContains(t, err, "connect: connection refused") + assert.Equal(t, maxConnLimit, int32(len(connections)), "should have reached the max connection limit") result := commands.FireCommand(connections[0], "ABORT") if result != "OK" { diff --git a/internal/clientio/iohandler/iohandler.go b/internal/clientio/iohandler/iohandler.go index 73b6e3adc..3b887a5c8 100644 --- a/internal/clientio/iohandler/iohandler.go +++ b/internal/clientio/iohandler/iohandler.go @@ -6,6 +6,6 @@ import ( type IOHandler interface { Read(ctx context.Context) ([]byte, error) - Write(ctx context.Context, response []byte) error + Write(ctx context.Context, response interface{}) error Close() error } diff --git a/internal/clientio/iohandler/netconn/netconn.go b/internal/clientio/iohandler/netconn/netconn.go index 555af4982..21399cbde 100644 --- a/internal/clientio/iohandler/netconn/netconn.go +++ b/internal/clientio/iohandler/netconn/netconn.go @@ -12,7 +12,9 @@ import ( "syscall" "time" + "github.com/dicedb/dice/internal/clientio" "github.com/dicedb/dice/internal/clientio/iohandler" + "github.com/dicedb/dice/internal/eval" ) const ( @@ -144,11 +146,29 @@ func (h *IOHandler) Read(ctx context.Context) ([]byte, error) { } // WriteResponse writes the response back to the network connection -func (h *IOHandler) Write(ctx context.Context, response []byte) error { +func (h *IOHandler) Write(ctx context.Context, response interface{}) error { errChan := make(chan error, 1) + // Process the incoming response by calling the handleResponse function. + // This function checks the response against known RESP formatted values + // and returns the corresponding byte array representation. The result + // is assigned to the resp variable. + resp := HandlePredefinedResponse(response) + + // Check if the processed response (resp) is not nil. + // If it is not nil, this means incoming response was not + // matched to any predefined RESP responses, + // and we proceed to encode the original response using + // the clientio.Encode function. This function converts the + // response into the desired format based on the specified + // isBlkEnc encoding flag, which indicates whether the + // response should be encoded in a block format. + if resp == nil { + resp = clientio.Encode(response, true) + } + go func(errChan chan error) { - _, err := h.writer.Write(response) + _, err := h.writer.Write(resp) if err == nil { err = h.writer.Flush() } @@ -183,3 +203,43 @@ func (h *IOHandler) Close() error { h.logger.Info("Closing connection") return errors.Join(h.conn.Close(), h.file.Close()) } + +// handleResponse processes the incoming response from a client and returns the corresponding +// RESP (REdis Serialization Protocol) formatted byte array based on the response content. +// +// The function takes an interface{} as input, attempts to assert it as a byte slice. If successful, +// it checks the content of the byte slice against predefined RESP responses using the `bytes.Contains` +// function. If a match is found, it returns the associated byte array response. If no match is found +// or if the input cannot be converted to a byte slice, the function returns nil. +// +// This function is designed to handle various response scenarios, such as: +// - $-1: Represents a nil response. +// - +OK: Indicates a successful command execution. +// - +QUEUED: Signifies that a command has been queued. +// - :0, :1, :-1, :-2: Represents integer values in RESP format. +// - *0: Represents an empty array in RESP format. +// +// Note: The use of `bytes.Contains` is to check if the provided response matches any of the +// predefined RESP responses, making it flexible in handling responses that might include +// additional content beyond the expected response format. +func HandlePredefinedResponse(response interface{}) []byte { + // WARN: Do not change the ordering of the array elements + // It is strictly mapped to internal/eval/results.go enum. + respArr := [][]byte{ + clientio.RespNIL, // Represents a RESP Nil Bulk String, which indicates a null value. + clientio.RespOK, // Represents a RESP Simple String with value "OK". + clientio.RespQueued, // Represents a Simple String indicating that a command has been queued. + clientio.RespZero, // Represents a RESP Integer with value 0. + clientio.RespOne, // Represents a RESP Integer with value 1. + clientio.RespMinusOne, // Represents a RESP Integer with value -1. + clientio.RespMinusTwo, // Represents a RESP Integer with value -2. + clientio.RespEmptyArray, // Represents an empty RESP Array. + } + + switch val := response.(type) { + case eval.RespType: + return respArr[val] + default: + return nil + } +} diff --git a/internal/clientio/iohandler/netconn/netconn_resp_test.go b/internal/clientio/iohandler/netconn/netconn_resp_test.go index ba605810a..920dea768 100644 --- a/internal/clientio/iohandler/netconn/netconn_resp_test.go +++ b/internal/clientio/iohandler/netconn/netconn_resp_test.go @@ -4,12 +4,13 @@ import ( "bufio" "context" "errors" - "github.com/dicedb/dice/mocks" "log/slog" "strings" "testing" "time" + "github.com/dicedb/dice/mocks" + "github.com/stretchr/testify/assert" ) diff --git a/internal/clientio/iohandler/netconn/netconn_test.go b/internal/clientio/iohandler/netconn/netconn_test.go index 207d0afc0..a9a5b5b72 100644 --- a/internal/clientio/iohandler/netconn/netconn_test.go +++ b/internal/clientio/iohandler/netconn/netconn_test.go @@ -5,7 +5,6 @@ import ( "bytes" "context" "errors" - "github.com/dicedb/dice/mocks" "io" "log/slog" "net" @@ -14,6 +13,8 @@ import ( "testing" "time" + "github.com/dicedb/dice/mocks" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/internal/clientio/resp.go b/internal/clientio/resp.go index 2a6e53ff7..5823a96cc 100644 --- a/internal/clientio/resp.go +++ b/internal/clientio/resp.go @@ -158,7 +158,16 @@ func encodeBool(v bool) []byte { } func Encode(value interface{}, isSimple bool) []byte { + // Use a type switch to determine the type of the provided value and encode accordingly. switch v := value.(type) { + // Temporary case to maintain backwards compatibility. + // This case handles byte slices ([]byte) directly, allowing existing functionality + // that relies on byte slice inputs to continue working without modifications. + // It serves as a transitional measure and should be revisited for removal + // once all commands are migrated. + case []byte: + return v // Return the byte slice as-is. + case string: // encode as simple strings if isSimple || v == "[" || v == "{" { @@ -167,7 +176,9 @@ func Encode(value interface{}, isSimple bool) []byte { // encode as bulk strings return encodeString(v) case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - return []byte(fmt.Sprintf(":%d\r\n", v)) + return []byte(fmt.Sprintf(":%d\r\n", v)) // Prefix with ':' for RESP integers. + + // Handle floating-point types similarly to integers. case float32, float64: // In case the element being encoded was obtained after parsing a JSON value, // it is possible for integers to have been encoded as floats @@ -185,32 +196,40 @@ func Encode(value interface{}, isSimple bool) []byte { return encodeBool(v) case []string: var b []byte - buf := bytes.NewBuffer(b) + buf := bytes.NewBuffer(b) // Create a buffer to accumulate encoded strings. for _, b := range value.([]string) { - buf.Write(encodeString(b)) + buf.Write(encodeString(b)) // Encode each string and write to the buffer. } - return []byte(fmt.Sprintf("*%d\r\n%s", len(v), buf.Bytes())) + return []byte(fmt.Sprintf("*%d\r\n%s", len(v), buf.Bytes())) // Return the encoded response. + + // Handle slices of custom objects (Obj). case []*object.Obj: var b []byte - buf := bytes.NewBuffer(b) + buf := bytes.NewBuffer(b) // Create a buffer to accumulate encoded objects. for _, b := range value.([]*object.Obj) { - buf.Write(Encode(b.Value, false)) + buf.Write(Encode(b.Value, false)) // Encode each object’s value and write to the buffer. } - return []byte(fmt.Sprintf("*%d\r\n%s", len(v), buf.Bytes())) + return []byte(fmt.Sprintf("*%d\r\n%s", len(v), buf.Bytes())) // Return the encoded response. + + // Handle slices of interfaces. case []interface{}: var b []byte - buf := bytes.NewBuffer(b) + buf := bytes.NewBuffer(b) // Create a buffer for accumulating encoded values. for _, elem := range v { - buf.Write(Encode(elem, false)) + buf.Write(Encode(elem, false)) // Encode each element and write to the buffer. } - return []byte(fmt.Sprintf("*%d\r\n%s", len(v), buf.Bytes())) + return []byte(fmt.Sprintf("*%d\r\n%s", len(v), buf.Bytes())) // Return the encoded response. + + // Handle slices of int64. case []int64: var b []byte - buf := bytes.NewBuffer(b) + buf := bytes.NewBuffer(b) // Create a buffer for accumulating encoded values. for _, b := range value.([]int64) { - buf.Write(Encode(b, false)) + buf.Write(Encode(b, false)) // Encode each int64 and write to the buffer. } - return []byte(fmt.Sprintf("*%d\r\n%s", len(v), buf.Bytes())) + return []byte(fmt.Sprintf("*%d\r\n%s", len(v), buf.Bytes())) // Return the encoded response. + + // Handle error type by formatting it as a RESP error. case error: return []byte(fmt.Sprintf("-%s\r\n", v)) case dstore.QueryWatchEvent: @@ -222,16 +241,20 @@ func Encode(value interface{}, isSimple bool) []byte { return []byte(fmt.Sprintf("*2\r\n%s", buf.Bytes())) case []sql.QueryResultRow: var b []byte - buf := bytes.NewBuffer(b) + buf := bytes.NewBuffer(b) // Create a buffer for accumulating encoded rows. for _, row := range value.([]sql.QueryResultRow) { - buf.WriteString("*2\r\n") - buf.Write(Encode(row.Key, false)) - buf.Write(Encode(row.Value.Value, false)) + buf.WriteString("*2\r\n") // Start a new array for each row. + buf.Write(Encode(row.Key, false)) // Encode the row key. + buf.Write(Encode(row.Value.Value, false)) // Encode the row value. } - return []byte(fmt.Sprintf("*%d\r\n%s", len(v), buf.Bytes())) + return []byte(fmt.Sprintf("*%d\r\n%s", len(v), buf.Bytes())) // Return the encoded response. + + // Handle map[string]bool and return a nil response indicating unsupported types. case map[string]bool: - return RespNIL + return RespNIL // Return nil response for unsupported type. + + // For all other unsupported types, return a nil response. default: - return RespNIL + return RespNIL // Return nil response for unsupported types. } } diff --git a/internal/errors/errors.go b/internal/errors/errors.go index 05d7fb27c..17f0b7681 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -27,11 +27,6 @@ const ( InvalidIntErr = "-ERR value is not a valid integer" ) -var ( - ErrAborted = errors.New("server received ABORT command") - ErrEmptyCommand = errors.New("empty command") -) - type DiceError struct { message error } diff --git a/internal/errors/migrated_errors.go b/internal/errors/migrated_errors.go new file mode 100644 index 000000000..3a68f7871 --- /dev/null +++ b/internal/errors/migrated_errors.go @@ -0,0 +1,64 @@ +package errors + +import ( + "errors" + "fmt" +) + +// Package errors provides error definitions and utility functions for handling +// common Redis error scenarios within the application. This package centralizes +// error messages to ensure consistency and clarity when interacting with Redis +// commands and responses. + +// Standard error variables for various Redis-related error conditions. +var ( + ErrAuthFailed = errors.New("AUTH failed") // Indicates authentication failure. + ErrIntegerOutOfRange = errors.New("ERR value is not an integer or out of range") // Represents a value that is either not an integer or is out of allowed range. + ErrInvalidNumberFormat = errors.New("ERR value is not an integer or a float") // Signals that a value provided is not in a valid integer or float format. + ErrValueOutOfRange = errors.New("ERR value is out of range") // Indicates that a value is beyond the permissible range. + ErrOverflow = errors.New("ERR increment or decrement would overflow") // Signifies that an increment or decrement operation would exceed the limits. + ErrSyntax = errors.New("ERR syntax error") // Represents a syntax error in a Redis command. + ErrKeyNotFound = errors.New("ERR no such key") // Indicates that the specified key does not exist. + ErrWrongTypeOperation = errors.New("WRONGTYPE Operation against a key holding the wrong kind of value") // Signals an operation attempted on a key with an incompatible type. + ErrInvalidHyperLogLogKey = errors.New("WRONGTYPE Key is not a valid HyperLogLog string value") // Indicates that a key is not a valid HyperLogLog value. + ErrCorruptedHyperLogLogObject = errors.New("INVALIDOBJ Corrupted HLL object detected") // Signals detection of a corrupted HyperLogLog object. + ErrInvalidJSONPathType = errors.New("WRONGTYPE wrong type of path value - expected string but found integer") // Represents an invalid type for a JSON path. + ErrInvalidExpireTimeValue = errors.New("ERR invalid expire time") // Indicates that the provided expiration time is invalid. + ErrHashValueNotInteger = errors.New("ERR hash value is not an integer") // Signifies that a hash value is expected to be an integer. + ErrInternalServer = errors.New("ERR Internal server error, unable to process command") // Represents a generic internal server error. + ErrAuth = errors.New("AUTH called without any password configured for the default user. Are you sure your configuration is correct?") + ErrAborted = errors.New("server received ABORT command") + ErrEmptyCommand = errors.New("empty command") + + // Error generation functions for specific error messages with dynamic parameters. + ErrWrongArgumentCount = func(command string) error { + return fmt.Errorf("ERR wrong number of arguments for '%s' command", command) // Indicates an incorrect number of arguments for a given command. + } + ErrInvalidExpireTime = func(command string) error { + return fmt.Errorf("ERR invalid expire time in '%s' command", command) // Represents an invalid expiration time for a specific command. + } + + ErrInvalidElementPeekCount = func(max int) error { + return fmt.Errorf("ERR number of elements to peek should be a positive number less than %d", max) // Signals an invalid count for elements to peek. + } + + ErrGeneral = func(err string) error { + return fmt.Errorf("ERR %s", err) // General error format for various commands. + } + + ErrWorkerNotFound = func(workerID string) error { + return fmt.Errorf("ERR worker with ID %s not found", workerID) // Indicates that a worker with the specified ID does not exist. + } + + ErrJSONPathNotFound = func(path string) error { + return fmt.Errorf("ERR Path '%s' does not exist", path) // Represents an error where the specified JSON path cannot be found. + } + + ErrUnsupportedEncoding = func(encoding int) error { + return fmt.Errorf("ERR unsupported encoding: %d", encoding) // Indicates that an unsupported encoding type was provided. + } + + ErrUnexpectedType = func(expectedType string, actualType interface{}) error { + return fmt.Errorf("ERR expected %s but got another type: %s", expectedType, actualType) // Signals an unexpected type received when an integer was expected. + } +) diff --git a/internal/eval/commands.go b/internal/eval/commands.go index c3de1d6fd..6f7c3b734 100644 --- a/internal/eval/commands.go +++ b/internal/eval/commands.go @@ -24,7 +24,7 @@ type DiceCmdMeta struct { // instead of just raw bytes. Commands that have been migrated to this new model // will utilize this function for evaluation, allowing for better handling of // complex command execution scenarios and improved response consistency. - NewEval func([]string, *dstore.Store) EvalResponse + NewEval func([]string, *dstore.Store) *EvalResponse } type KeySpecs struct { diff --git a/internal/eval/eval_test.go b/internal/eval/eval_test.go index 54a6bfe85..ffbf22493 100644 --- a/internal/eval/eval_test.go +++ b/internal/eval/eval_test.go @@ -30,6 +30,7 @@ type evalTestCase struct { input []string output []byte validator func(output []byte) + newValidator func(output interface{}) migratedOutput EvalResponse } @@ -152,87 +153,87 @@ func testEvalSET(t *testing.T, store *dstore.Store) { { name: "nil value", input: nil, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR wrong number of arguments for 'set' command\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'SET' command")}, }, { name: "empty array", input: []string{}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR wrong number of arguments for 'set' command\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'SET' command")}, }, { name: "one value", input: []string{"KEY"}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR wrong number of arguments for 'set' command\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'SET' command")}, }, { name: "key val pair", input: []string{"KEY", "VAL"}, - migratedOutput: EvalResponse{Result: clientio.RespOK, Error: nil}, + migratedOutput: EvalResponse{Result: RespOK, Error: nil}, }, { name: "key val pair with int val", input: []string{"KEY", "123456"}, - migratedOutput: EvalResponse{Result: clientio.RespOK, Error: nil}, + migratedOutput: EvalResponse{Result: RespOK, Error: nil}, }, { name: "key val pair and expiry key", input: []string{"KEY", "VAL", Px}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR syntax error\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, }, { name: "key val pair and EX no val", input: []string{"KEY", "VAL", Ex}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR syntax error\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, }, { name: "key val pair and valid EX", input: []string{"KEY", "VAL", Ex, "2"}, - migratedOutput: EvalResponse{Result: clientio.RespOK, Error: nil}, + migratedOutput: EvalResponse{Result: RespOK, Error: nil}, }, { name: "key val pair and invalid EX", input: []string{"KEY", "VAL", Ex, "invalid_expiry_val"}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR value is not an integer or out of range\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR value is not an integer or out of range")}, }, { name: "key val pair and valid PX", input: []string{"KEY", "VAL", Px, "2000"}, - migratedOutput: EvalResponse{Result: clientio.RespOK, Error: nil}, + migratedOutput: EvalResponse{Result: RespOK, Error: nil}, }, { name: "key val pair and invalid PX", input: []string{"KEY", "VAL", Px, "invalid_expiry_val"}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR value is not an integer or out of range\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR value is not an integer or out of range")}, }, { name: "key val pair and both EX and PX", input: []string{"KEY", "VAL", Ex, "2", Px, "2000"}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR syntax error\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, }, { name: "key val pair and PXAT no val", input: []string{"KEY", "VAL", Pxat}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR syntax error\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR syntax error")}, }, { name: "key val pair and invalid PXAT", input: []string{"KEY", "VAL", Pxat, "invalid_expiry_val"}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR value is not an integer or out of range\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR value is not an integer or out of range")}, }, { name: "key val pair and expired PXAT", input: []string{"KEY", "VAL", Pxat, "2"}, - migratedOutput: EvalResponse{Result: clientio.RespOK, Error: nil}, + migratedOutput: EvalResponse{Result: RespOK, Error: nil}, }, { name: "key val pair and negative PXAT", input: []string{"KEY", "VAL", Pxat, "-123456"}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR invalid expire time in 'set' command\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR invalid expire time in 'SET' command")}, }, { name: "key val pair and valid PXAT", input: []string{"KEY", "VAL", Pxat, strconv.FormatInt(time.Now().Add(2*time.Minute).UnixMilli(), 10)}, - migratedOutput: EvalResponse{Result: clientio.RespOK, Error: nil}, + migratedOutput: EvalResponse{Result: RespOK, Error: nil}, }, } @@ -323,36 +324,36 @@ func testEvalGET(t *testing.T, store *dstore.Store) { { name: "nil value", input: nil, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR wrong number of arguments for 'get' command\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'GET' command")}, }, { name: "empty array", input: []string{}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR wrong number of arguments for 'get' command\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'GET' command")}, }, { name: "key does not exist", input: []string{"NONEXISTENT_KEY"}, - migratedOutput: EvalResponse{Result: clientio.RespNIL, Error: nil}, + migratedOutput: EvalResponse{Result: RespNIL, Error: nil}, }, { name: "multiple arguments", input: []string{"KEY1", "KEY2"}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR wrong number of arguments for 'get' command\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'GET' command")}, }, { name: "key exists", setup: func() { - key := "EXISTING_KEY" - value := "mock_value" + key := "diceKey" + value := "diceVal" obj := &object.Obj{ Value: value, LastAccessedAt: uint32(time.Now().Unix()), } store.Put(key, obj) }, - input: []string{"EXISTING_KEY"}, - migratedOutput: EvalResponse{Result: fmt.Sprintf("$%d\r\n%s\r\n", len("mock_value"), "mock_value"), Error: nil}, + input: []string{"diceKey"}, + migratedOutput: EvalResponse{Result: "diceVal", Error: nil}, }, { name: "key exists but expired", @@ -367,14 +368,21 @@ func testEvalGET(t *testing.T, store *dstore.Store) { store.SetExpiry(obj, int64(-2*time.Millisecond)) }, input: []string{"EXISTING_KEY"}, - migratedOutput: EvalResponse{Result: clientio.RespNIL, Error: nil}, + migratedOutput: EvalResponse{Result: RespNIL, Error: nil}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + // Setup the test store + if tt.setup != nil { + tt.setup() + } + response := evalGET(tt.input, store) + fmt.Printf("Response: %v | Expected: %v\n", *response, tt.migratedOutput.Result) + // Handle comparison for byte slices if b, ok := response.Result.([]byte); ok && tt.migratedOutput.Result != nil { if expectedBytes, ok := tt.migratedOutput.Result.([]byte); ok { @@ -398,17 +406,17 @@ func testEvalGETSET(t *testing.T, store *dstore.Store) { { name: "GETSET with 1 arg", input: []string{"HELLO"}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR wrong number of arguments for 'getset' command\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'GETSET' command")}, }, { name: "GETSET with 3 args", input: []string{"HELLO", "WORLD", "WORLD1"}, - migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR wrong number of arguments for 'getset' command\r\n")}, + migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'GETSET' command")}, }, { name: "GETSET key not exists", input: []string{"HELLO", "WORLD"}, - migratedOutput: EvalResponse{Result: clientio.RespNIL, Error: nil}, + migratedOutput: EvalResponse{Result: RespNIL, Error: nil}, }, { name: "GETSET key exists", @@ -422,7 +430,7 @@ func testEvalGETSET(t *testing.T, store *dstore.Store) { store.Put(key, obj) }, input: []string{"EXISTING_KEY", "WORLD"}, - migratedOutput: EvalResponse{Result: fmt.Sprintf("$%d\r\n%s\r\n", len("mock_value"), "mock_value"), Error: nil}, + migratedOutput: EvalResponse{Result: "mock_value", Error: nil}, }, { name: "GETSET key exists TTL should be reset", @@ -436,12 +444,17 @@ func testEvalGETSET(t *testing.T, store *dstore.Store) { store.Put(key, obj) }, input: []string{"EXISTING_KEY", "WORLD"}, - migratedOutput: EvalResponse{Result: fmt.Sprintf("$%d\r\n%s\r\n", len("mock_value"), "mock_value"), Error: nil}, + migratedOutput: EvalResponse{Result: "mock_value", Error: nil}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + // Setup the test store + if tt.setup != nil { + tt.setup() + } + response := evalGETSET(tt.input, store) // Handle comparison for byte slices @@ -4007,28 +4020,28 @@ func testEvalSETEX(t *testing.T, store *dstore.Store) { utils.CurrentTime = mockTime tests := map[string]evalTestCase{ - "nil value": {input: nil, migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR wrong number of arguments for 'setex' command\r\n")}}, - "empty array": {input: []string{}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR wrong number of arguments for 'setex' command\r\n")}}, - "one value": {input: []string{"KEY"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR wrong number of arguments for 'setex' command\r\n")}}, - "key val pair": {input: []string{"KEY", "VAL"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR wrong number of arguments for 'setex' command\r\n")}}, - "key exp pair": {input: []string{"KEY", "123456"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR wrong number of arguments for 'setex' command\r\n")}}, - "key exp value pair": {input: []string{"KEY", "123", "VAL"}, migratedOutput: EvalResponse{Result: clientio.RespOK, Error: nil}}, - "key exp value pair with extra args": {input: []string{"KEY", "123", "VAL", " "}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR wrong number of arguments for 'setex' command\r\n")}}, - "key exp value pair with invalid exp": {input: []string{"KEY", "0", "VAL"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR invalid expire time in 'setex' command\r\n")}}, - "key exp value pair with exp > maxexp": {input: []string{"KEY", "9223372036854776", "VAL"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR invalid expire time in 'setex' command\r\n")}}, - "key exp value pair with exp > maxint64": {input: []string{"KEY", "92233720368547760000000", "VAL"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR value is not an integer or out of range\r\n")}}, - "key exp value pair with negative exp": {input: []string{"KEY", "-23", "VAL"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR invalid expire time in 'setex' command\r\n")}}, - "key exp value pair with not-int exp": {input: []string{"KEY", "12a", "VAL"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("-ERR value is not an integer or out of range\r\n")}}, + "nil value": {input: nil, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'SETEX' command")}}, + "empty array": {input: []string{}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'SETEX' command")}}, + "one value": {input: []string{"KEY"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'SETEX' command")}}, + "key val pair": {input: []string{"KEY", "VAL"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'SETEX' command")}}, + "key exp pair": {input: []string{"KEY", "123456"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'SETEX' command")}}, + "key exp value pair": {input: []string{"KEY", "123", "VAL"}, migratedOutput: EvalResponse{Result: RespOK, Error: nil}}, + "key exp value pair with extra args": {input: []string{"KEY", "123", "VAL", " "}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR wrong number of arguments for 'SETEX' command")}}, + "key exp value pair with invalid exp": {input: []string{"KEY", "0", "VAL"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR invalid expire time in 'SETEX' command")}}, + "key exp value pair with exp > maxexp": {input: []string{"KEY", "9223372036854776", "VAL"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR invalid expire time in 'SETEX' command")}}, + "key exp value pair with exp > maxint64": {input: []string{"KEY", "92233720368547760000000", "VAL"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR value is not an integer or out of range")}}, + "key exp value pair with negative exp": {input: []string{"KEY", "-23", "VAL"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR invalid expire time in 'SETEX' command")}}, + "key exp value pair with not-int exp": {input: []string{"KEY", "12a", "VAL"}, migratedOutput: EvalResponse{Result: nil, Error: errors.New("ERR value is not an integer or out of range")}}, "set and get": { setup: func() {}, input: []string{"TEST_KEY", "5", "TEST_VALUE"}, - validator: func(output []byte) { - assert.Equal(t, string(clientio.RespOK), string(output)) + newValidator: func(output interface{}) { + assert.Equal(t, RespOK, output) // Check if the key was set correctly getValue := evalGET([]string{"TEST_KEY"}, store) - assert.Equal(t, string(clientio.Encode("TEST_VALUE", false)), string(getValue.Result.([]byte))) + assert.Equal(t, "TEST_VALUE", getValue.Result) // Check if the TTL is set correctly (should be 5 seconds or less) ttlValue := evalTTL([]string{"TEST_KEY"}, store) @@ -4041,7 +4054,7 @@ func testEvalSETEX(t *testing.T, store *dstore.Store) { // Check if the key has been deleted after expiry expiredValue := evalGET([]string{"TEST_KEY"}, store) - assert.Equal(t, string(clientio.RespNIL), string(expiredValue.Result.([]byte))) + assert.Equal(t, RespNIL, expiredValue.Result) }, }, "update existing key": { @@ -4049,12 +4062,12 @@ func testEvalSETEX(t *testing.T, store *dstore.Store) { evalSET([]string{"EXISTING_KEY", "OLD_VALUE"}, store) }, input: []string{"EXISTING_KEY", "10", "NEW_VALUE"}, - validator: func(output []byte) { - assert.Equal(t, string(clientio.RespOK), string(output)) + newValidator: func(output interface{}) { + assert.Equal(t, RespOK, output) // Check if the key was updated correctly getValue := evalGET([]string{"EXISTING_KEY"}, store) - assert.Equal(t, string(clientio.Encode("NEW_VALUE", false)), string(getValue.Result.([]byte))) + assert.Equal(t, "NEW_VALUE", getValue.Result) // Check if the TTL is set correctly ttlValue := evalTTL([]string{"EXISTING_KEY"}, store) @@ -4069,11 +4082,11 @@ func testEvalSETEX(t *testing.T, store *dstore.Store) { t.Run(tt.name, func(t *testing.T) { response := evalSETEX(tt.input, store) - if tt.validator != nil { + if tt.newValidator != nil { if tt.migratedOutput.Error != nil { - tt.validator([]byte(tt.migratedOutput.Error.Error())) + tt.newValidator(tt.migratedOutput.Error) } else { - tt.validator(response.Result.([]byte)) + tt.newValidator(response.Result) } } else { // Handle comparison for byte slices diff --git a/internal/eval/execute.go b/internal/eval/execute.go index 24b582e4b..14d9642af 100644 --- a/internal/eval/execute.go +++ b/internal/eval/execute.go @@ -11,10 +11,10 @@ import ( dstore "github.com/dicedb/dice/internal/store" ) -func ExecuteCommand(c *cmd.RedisCmd, client *comm.Client, store *dstore.Store, httpOp, websocketOp bool) EvalResponse { +func ExecuteCommand(c *cmd.RedisCmd, client *comm.Client, store *dstore.Store, httpOp, websocketOp bool) *EvalResponse { diceCmd, ok := DiceCmds[c.Cmd] if !ok { - return EvalResponse{Result: diceerrors.NewErrWithFormattedMessage("unknown command '%s', with args beginning with: %s", c.Cmd, strings.Join(c.Args, " ")), Error: nil} + return &EvalResponse{Result: diceerrors.NewErrWithFormattedMessage("unknown command '%s', with args beginning with: %s", c.Cmd, strings.Join(c.Args, " ")), Error: nil} } // Till the time we refactor to handle QWATCH differently for websocket @@ -23,7 +23,7 @@ func ExecuteCommand(c *cmd.RedisCmd, client *comm.Client, store *dstore.Store, h return diceCmd.NewEval(c.Args, store) } - return EvalResponse{Result: diceCmd.Eval(c.Args, store), Error: nil} + return &EvalResponse{Result: diceCmd.Eval(c.Args, store), Error: nil} } // Temporary logic till we move all commands to new eval logic. @@ -40,14 +40,14 @@ func ExecuteCommand(c *cmd.RedisCmd, client *comm.Client, store *dstore.Store, h // Old implementation kept as it is, but we will be moving // to the new implmentation soon for all commands case "SUBSCRIBE", "QWATCH": - return EvalResponse{Result: EvalQWATCH(c.Args, httpOp, client, store), Error: nil} + return &EvalResponse{Result: EvalQWATCH(c.Args, httpOp, client, store), Error: nil} case "UNSUBSCRIBE", "QUNWATCH": - return EvalResponse{Result: EvalQUNWATCH(c.Args, httpOp, client), Error: nil} + return &EvalResponse{Result: EvalQUNWATCH(c.Args, httpOp, client), Error: nil} case auth.Cmd: - return EvalResponse{Result: EvalAUTH(c.Args, client), Error: nil} + return &EvalResponse{Result: EvalAUTH(c.Args, client), Error: nil} case "ABORT": - return EvalResponse{Result: clientio.RespOK, Error: nil} + return &EvalResponse{Result: clientio.RespOK, Error: nil} default: - return EvalResponse{Result: diceCmd.Eval(c.Args, store), Error: nil} + return &EvalResponse{Result: diceCmd.Eval(c.Args, store), Error: nil} } } diff --git a/internal/eval/results.go b/internal/eval/results.go new file mode 100644 index 000000000..e3cb89a38 --- /dev/null +++ b/internal/eval/results.go @@ -0,0 +1,17 @@ +package eval + +type RespType int + +// WARN: Do not change the ordering of the enum elements +// It is strictly mapped to HandlePredefinedResponse func internal/clientio/iohandler/netconn/netconn.go + +const ( + RespNIL RespType = iota + RespOK // OK + RespQueued // []byte("+QUEUED\r\n") // Signifies that a command has been queued for execution. //nolint:unused + RespZero // []byte(":0\r\n") // Represents the integer zero in RESP format. //nolint:unused + RespOne // []byte(":1\r\n") // Represents the integer one in RESP format. //nolint:unused + RespMinusOne // []byte(":-1\r\n") // Represents the integer negative one in RESP format. //nolint:unused + RespMinusTwo // []byte(":-2\r\n") // Represents the integer negative two in RESP format. //nolint:unused + RespEmptyArray // []byte("*0\r\n") // Represents an empty array in RESP format. //nolint:unused +) diff --git a/internal/eval/store_eval.go b/internal/eval/store_eval.go index c9c02d4d8..8ec1af48b 100644 --- a/internal/eval/store_eval.go +++ b/internal/eval/store_eval.go @@ -1,12 +1,9 @@ package eval import ( - "errors" - "fmt" "strconv" "strings" - "github.com/dicedb/dice/internal/clientio" diceerrors "github.com/dicedb/dice/internal/errors" "github.com/dicedb/dice/internal/object" "github.com/dicedb/dice/internal/server/utils" @@ -14,7 +11,8 @@ import ( ) // evalSET puts a new pair in db as in the args -// args must contain key and value, can also contain multiple options - +// args must contain key and value. +// args can also contain multiple options - // // EX or ex which will set the expiry time(in secs) for the key // PX or px which will set the expiry time(in milliseconds) for the key @@ -27,15 +25,18 @@ import ( // Returns encoded error response if both PX and EX flags are present // Returns encoded OK RESP once new entry is added // If the key already exists then the value will be overwritten and expiry will be discarded -func evalSET(args []string, store *dstore.Store) EvalResponse { +func evalSET(args []string, store *dstore.Store) *EvalResponse { if len(args) <= 1 { - return EvalResponse{Result: nil, Error: errors.New(string(diceerrors.NewErrArity("SET")))} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrWrongArgumentCount("SET"), + } } var key, value string var exDurationMs int64 = -1 - var state = Uninitialized - var keepttl = false + var state exDurationState = Uninitialized + var keepttl bool = false key, value = args[0], args[1] oType, oEnc := deduceTypeEncoding(value) @@ -45,20 +46,32 @@ func evalSET(args []string, store *dstore.Store) EvalResponse { switch arg { case Ex, Px: if state != Uninitialized { - return EvalResponse{Result: nil, Error: errors.New(string(diceerrors.NewErrWithMessage(diceerrors.SyntaxErr)))} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrSyntax, + } } i++ if i == len(args) { - return EvalResponse{Result: nil, Error: errors.New(string(diceerrors.NewErrWithMessage(diceerrors.SyntaxErr)))} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrSyntax, + } } exDuration, err := strconv.ParseInt(args[i], 10, 64) if err != nil { - return EvalResponse{Result: nil, Error: errors.New(string(diceerrors.NewErrWithMessage(diceerrors.IntOrOutOfRangeErr)))} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrIntegerOutOfRange, + } } if exDuration <= 0 || exDuration >= maxExDuration { - return EvalResponse{Result: nil, Error: errors.New(string(diceerrors.NewErrExpireTime("SET")))} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrInvalidExpireTime("SET"), + } } // converting seconds to milliseconds @@ -70,19 +83,31 @@ func evalSET(args []string, store *dstore.Store) EvalResponse { case Pxat, Exat: if state != Uninitialized { - return EvalResponse{Result: nil, Error: errors.New(string(diceerrors.NewErrWithMessage(diceerrors.SyntaxErr)))} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrSyntax, + } } i++ if i == len(args) { - return EvalResponse{Result: nil, Error: errors.New(string(diceerrors.NewErrWithMessage(diceerrors.SyntaxErr)))} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrSyntax, + } } exDuration, err := strconv.ParseInt(args[i], 10, 64) if err != nil { - return EvalResponse{Result: nil, Error: errors.New(string(diceerrors.NewErrWithMessage(diceerrors.IntOrOutOfRangeErr)))} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrIntegerOutOfRange, + } } if exDuration < 0 { - return EvalResponse{Result: nil, Error: errors.New(string(diceerrors.NewErrExpireTime("SET")))} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrInvalidExpireTime("SET"), + } } if arg == Exat { @@ -102,17 +127,26 @@ func evalSET(args []string, store *dstore.Store) EvalResponse { // if key does not exist, return RESP encoded nil if obj == nil { - return EvalResponse{Result: clientio.RespNIL, Error: nil} + return &EvalResponse{ + Result: RespNIL, + Error: nil, + } } case NX: obj := store.Get(key) if obj != nil { - return EvalResponse{Result: clientio.RespNIL, Error: nil} + return &EvalResponse{ + Result: RespNIL, + Error: nil, + } } case KeepTTL: keepttl = true default: - return EvalResponse{Result: nil, Error: errors.New(string(diceerrors.NewErrWithMessage(diceerrors.SyntaxErr)))} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrSyntax, + } } } @@ -124,22 +158,31 @@ func evalSET(args []string, store *dstore.Store) EvalResponse { case object.ObjEncodingEmbStr, object.ObjEncodingRaw: storedValue = value default: - return EvalResponse{Result: nil, Error: fmt.Errorf("ERR unsupported encoding: %d", oEnc)} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrUnsupportedEncoding(int(oEnc)), + } } // putting the k and value in a Hash Table store.Put(key, store.NewObj(storedValue, exDurationMs, oType, oEnc), dstore.WithKeepTTL(keepttl)) - return EvalResponse{Result: clientio.RespOK, Error: nil} + return &EvalResponse{ + Result: RespOK, + Error: nil, + } } // evalGET returns the value for the queried key in args // The key should be the only param in args // The RESP value of the key is encoded and then returned // evalGET returns response.RespNIL if key is expired or it does not exist -func evalGET(args []string, store *dstore.Store) EvalResponse { +func evalGET(args []string, store *dstore.Store) *EvalResponse { if len(args) != 1 { - return EvalResponse{Result: nil, Error: errors.New(string(diceerrors.NewErrArity("GET")))} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrWrongArgumentCount("GET"), + } } key := args[0] @@ -148,7 +191,10 @@ func evalGET(args []string, store *dstore.Store) EvalResponse { // if key does not exist, return RESP encoded nil if obj == nil { - return EvalResponse{Result: clientio.RespNIL, Error: nil} + return &EvalResponse{ + Result: RespNIL, + Error: nil, + } } // Decode and return the value based on its encoding @@ -156,31 +202,49 @@ func evalGET(args []string, store *dstore.Store) EvalResponse { case object.ObjEncodingInt: // Value is stored as an int64, so use type assertion if val, ok := obj.Value.(int64); ok { - return EvalResponse{Result: clientio.Encode(val, false), Error: nil} + return &EvalResponse{ + Result: val, + Error: nil, + } + } + + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrUnexpectedType("int64", obj.Value), } - return EvalResponse{Result: nil, - Error: errors.New(string(diceerrors.NewErrWithFormattedMessage("expected int64 but got another type: %s", obj.Value)))} case object.ObjEncodingEmbStr, object.ObjEncodingRaw: // Value is stored as a string, use type assertion if val, ok := obj.Value.(string); ok { - return EvalResponse{Result: clientio.Encode(val, false), Error: nil} + return &EvalResponse{ + Result: val, + Error: nil, + } + } + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrUnexpectedType("string", obj.Value), } - return EvalResponse{Result: nil, - Error: errors.New(string(diceerrors.NewErrWithMessage("expected string but got another type")))} case object.ObjEncodingByteArray: // Value is stored as a bytearray, use type assertion if val, ok := obj.Value.(*ByteArray); ok { - return EvalResponse{Result: clientio.Encode(string(val.data), false), Error: nil} + return &EvalResponse{ + Result: string(val.data), + Error: nil, + } } - return EvalResponse{Result: nil, - Error: errors.New(string(diceerrors.NewErrWithMessage(diceerrors.WrongTypeErr)))} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrWrongTypeOperation, + } default: - return EvalResponse{Result: nil, - Error: errors.New(string(diceerrors.NewErrWithMessage(diceerrors.WrongTypeErr)))} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrWrongTypeOperation, + } } } @@ -192,9 +256,12 @@ func evalGET(args []string, store *dstore.Store) EvalResponse { // Returns: // Bulk string reply: the old value stored at the key. // Nil reply: if the key does not exist. -func evalGETSET(args []string, store *dstore.Store) EvalResponse { +func evalGETSET(args []string, store *dstore.Store) *EvalResponse { if len(args) != 2 { - return EvalResponse{Result: nil, Error: errors.New(string(diceerrors.NewErrArity("GETSET")))} + return &EvalResponse{ + Result: nil, + Error: diceerrors.ErrWrongArgumentCount("GETSET"), + } } key, value := args[0], args[1] @@ -215,16 +282,16 @@ func evalGETSET(args []string, store *dstore.Store) EvalResponse { } // evalSETEX puts a new pair in db as in the args -// args must contain only key, expiry and value +// args must contain only key , expiry and value // Returns encoded error response if is not part of args // Returns encoded error response if expiry time value in not integer // Returns encoded OK RESP once new entry is added // If the key already exists then the value and expiry will be overwritten -func evalSETEX(args []string, store *dstore.Store) EvalResponse { +func evalSETEX(args []string, store *dstore.Store) *EvalResponse { if len(args) != 3 { - return EvalResponse{ + return &EvalResponse{ Result: nil, - Error: errors.New(string(diceerrors.NewErrArity("SETEX"))), + Error: diceerrors.ErrWrongArgumentCount("SETEX"), } } @@ -233,15 +300,15 @@ func evalSETEX(args []string, store *dstore.Store) EvalResponse { exDuration, err := strconv.ParseInt(args[1], 10, 64) if err != nil { - return EvalResponse{ + return &EvalResponse{ Result: nil, - Error: errors.New(string(diceerrors.NewErrWithMessage(diceerrors.IntOrOutOfRangeErr))), + Error: diceerrors.ErrIntegerOutOfRange, } } if exDuration <= 0 || exDuration >= maxExDuration { - return EvalResponse{ + return &EvalResponse{ Result: nil, - Error: errors.New(string(diceerrors.NewErrExpireTime("SETEX"))), + Error: diceerrors.ErrInvalidExpireTime("SETEX"), } } newArgs := []string{key, value, Ex, args[1]} diff --git a/internal/ops/store_op.go b/internal/ops/store_op.go index d33dca45c..b2d7fb8ec 100644 --- a/internal/ops/store_op.go +++ b/internal/ops/store_op.go @@ -19,6 +19,6 @@ type StoreOp struct { // StoreResponse represents the response of a Store operation. type StoreResponse struct { - RequestID uint32 // RequestID that this StoreResponse belongs to - EvalResponse eval.EvalResponse // Result of the Store operation, for now the type is set to []byte, but this can change in the future. + RequestID uint32 // RequestID that this StoreResponse belongs to + EvalResponse *eval.EvalResponse // Result of the Store operation, for now the type is set to []byte, but this can change in the future. } diff --git a/internal/server/httpServer.go b/internal/server/httpServer.go index 1b95a707c..90ca59e47 100644 --- a/internal/server/httpServer.go +++ b/internal/server/httpServer.go @@ -17,6 +17,7 @@ import ( "github.com/dicedb/dice/internal/cmd" "github.com/dicedb/dice/internal/comm" derrors "github.com/dicedb/dice/internal/errors" + "github.com/dicedb/dice/internal/eval" "github.com/dicedb/dice/internal/ops" "github.com/dicedb/dice/internal/server/utils" "github.com/dicedb/dice/internal/shard" @@ -160,7 +161,7 @@ func (s *HTTPServer) DiceHTTPHandler(writer http.ResponseWriter, request *http.R // Wait for response resp := <-s.ioChan - s.writeResponse(writer, resp) + s.writeResponse(writer, resp, redisCmd) } func (s *HTTPServer) DiceHTTPQwatchHandler(writer http.ResponseWriter, request *http.Request) { @@ -215,7 +216,7 @@ func (s *HTTPServer) DiceHTTPQwatchHandler(writer http.ResponseWriter, request * // Wait for 1st sync response from server for QWATCH and flush it to client resp := <-s.ioChan - s.writeResponse(writer, resp) + s.writeQWatchResponse(writer, resp) flusher.Flush() // Keep listening for context cancellation (client disconnect) and continuous responses doneChan := request.Context().Done() @@ -238,23 +239,41 @@ func (s *HTTPServer) DiceHTTPQwatchHandler(writer http.ResponseWriter, request * storeOp.Cmd = unWatchCmd s.shardManager.GetShard(0).ReqChan <- storeOp resp := <-s.ioChan - s.writeResponse(writer, resp) + s.writeResponse(writer, resp, redisCmd) return } } } -func (s *HTTPServer) writeQWatchResponse(writer http.ResponseWriter, response comm.QwatchResponse) { +func (s *HTTPServer) writeQWatchResponse(writer http.ResponseWriter, response interface{}) { + var result interface{} + var err error + + // Use type assertion to handle both types of responses + switch resp := response.(type) { + case comm.QwatchResponse: + result = resp.Result + err = resp.Error + case *ops.StoreResponse: + result = resp.EvalResponse.Result + err = resp.EvalResponse.Error + default: + s.logger.Error("Unsupported response type") + http.Error(writer, "Internal Server Error", http.StatusInternalServerError) + return + } + var rp *clientio.RESPParser - if response.Error != nil { - rp = clientio.NewRESPParser(bytes.NewBuffer([]byte(response.Error.Error()))) + if err != nil { + rp = clientio.NewRESPParser(bytes.NewBuffer([]byte(err.Error()))) } else { - rp = clientio.NewRESPParser(bytes.NewBuffer(response.Result.([]byte))) + rp = clientio.NewRESPParser(bytes.NewBuffer(result.([]byte))) } val, err := rp.DecodeOne() if err != nil { s.logger.Error("Error decoding response: %v", slog.Any("error", err)) + http.Error(writer, "Internal Server Error", http.StatusInternalServerError) return } @@ -277,6 +296,7 @@ func (s *HTTPServer) writeQWatchResponse(writer http.ResponseWriter, response co if err != nil { s.logger.Error("Error marshaling QueryData to JSON: %v", slog.Any("error", err)) + http.Error(writer, "Internal Server Error", http.StatusInternalServerError) return } @@ -284,8 +304,10 @@ func (s *HTTPServer) writeQWatchResponse(writer http.ResponseWriter, response co _, err = writer.Write(responseJSON) if err != nil { s.logger.Error("Error writing SSE data: %v", slog.Any("error", err)) + http.Error(writer, "Internal Server Error", http.StatusInternalServerError) return } + flusher, ok := writer.(http.Flusher) if !ok { http.Error(writer, "Streaming unsupported", http.StatusInternalServerError) @@ -295,22 +317,56 @@ func (s *HTTPServer) writeQWatchResponse(writer http.ResponseWriter, response co flusher.Flush() // Flush the response to send it to the client } -func (s *HTTPServer) writeResponse(writer http.ResponseWriter, result *ops.StoreResponse) { +func (s *HTTPServer) writeResponse(writer http.ResponseWriter, result *ops.StoreResponse, redisCmd *cmd.RedisCmd) { + _, ok := WorkerCmdsMeta[redisCmd.Cmd] var rp *clientio.RESPParser - if result.EvalResponse.Error != nil { - rp = clientio.NewRESPParser(bytes.NewBuffer([]byte(result.EvalResponse.Error.Error()))) + + var responseValue interface{} + // TODO: Remove this conditional check and if (true) condition when all commands are migrated + if !ok { + var err error + if result.EvalResponse.Error != nil { + rp = clientio.NewRESPParser(bytes.NewBuffer([]byte(result.EvalResponse.Error.Error()))) + } else { + rp = clientio.NewRESPParser(bytes.NewBuffer(result.EvalResponse.Result.([]byte))) + } + + responseValue, err = rp.DecodeOne() + if err != nil { + s.logger.Error("Error decoding response", "error", err) + http.Error(writer, "Internal Server Error", http.StatusInternalServerError) + return + } } else { - rp = clientio.NewRESPParser(bytes.NewBuffer(result.EvalResponse.Result.([]byte))) + if result.EvalResponse.Error != nil { + responseValue = result.EvalResponse.Error.Error() + } else { + responseValue = result.EvalResponse.Result + } } - val, err := rp.DecodeOne() - if err != nil { - s.logger.Error("Error decoding response", "error", err) - http.Error(writer, "Internal Server Error", http.StatusInternalServerError) - return + // func HandlePredefinedResponse(response interface{}) []byte { + respArr := []string{ + "(nil)", // Represents a RESP Nil Bulk String, which indicates a null value. + "OK", // Represents a RESP Simple String with value "OK". + "QUEUED", // Represents a Simple String indicating that a command has been queued. + "0", // Represents a RESP Integer with value 0. + "1", // Represents a RESP Integer with value 1. + "-1", // Represents a RESP Integer with value -1. + "-2", // Represents a RESP Integer with value -2. + "*0", // Represents an empty RESP Array. + } + + if val, ok := responseValue.(eval.RespType); ok { + responseValue = respArr[val] + } + + if bt, ok := responseValue.([]byte); ok { + responseValue = string(bt) } + httpResponse := utils.HTTPResponse{Data: responseValue} - responseJSON, err := json.Marshal(val) + responseJSON, err := json.Marshal(httpResponse) if err != nil { s.logger.Error("Error marshaling response", "error", err) http.Error(writer, "Internal Server Error", http.StatusInternalServerError) diff --git a/internal/server/server.go b/internal/server/server.go index 764958499..48f837c02 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -16,6 +16,7 @@ import ( "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/auth" "github.com/dicedb/dice/internal/clientio" + "github.com/dicedb/dice/internal/clientio/iohandler/netconn" "github.com/dicedb/dice/internal/cmd" "github.com/dicedb/dice/internal/comm" diceerrors "github.com/dicedb/dice/internal/errors" @@ -157,7 +158,7 @@ func (s *AsyncServer) Run(ctx context.Context) error { s.shardManager.RegisterWorker("server", s.ioChan) - if err := syscall.Listen(s.serverFD, s.maxClients); err != nil { + if err := syscall.Listen(s.serverFD, int(s.maxClients)); err != nil { return err } @@ -218,10 +219,6 @@ func (s *AsyncServer) eventLoop(ctx context.Context) error { if event.Fd == s.serverFD { if err := s.acceptConnection(); err != nil { s.logger.Warn(err.Error()) - // Close the event FD on error - if closeErr := syscall.Close(event.Fd); closeErr != nil { - s.logger.Error("Failed to close event FD:", slog.Any("error", closeErr)) - } } } else { if err := s.handleClientEvent(event); err != nil { @@ -240,10 +237,6 @@ func (s *AsyncServer) eventLoop(ctx context.Context) error { // acceptConnection accepts a new client connection and subscribes to read events on the connection. func (s *AsyncServer) acceptConnection() error { - if len(s.connectedClients) > s.maxClients { - return errors.New("connection refused. Reached the max-connection limit") - } - fd, _, err := syscall.Accept(s.serverFD) if err != nil { return err @@ -284,6 +277,28 @@ func (s *AsyncServer) handleClientEvent(event iomultiplexer.Event) error { return nil } +func handleMigratedResp(resp interface{}, buf *bytes.Buffer) { + // Process the incoming response by calling the handleResponse function. + // This function checks the response against known RESP formatted values + // and returns the corresponding byte array representation. The result + // is assigned to the resp variable. + r := netconn.HandlePredefinedResponse(resp) + + // Check if the processed response (resp) is not nil. + // If it is not nil, this means incoming response was not + // matched to any predefined RESP responses, + // and we proceed to encode the original response using + // the clientio.Encode function. This function converts the + // response into the desired format based on the specified + // isBlkEnc encoding flag, which indicates whether the + // response should be encoded in a block format. + if r == nil { + r = clientio.Encode(resp, false) + } + + buf.Write(r) +} + func (s *AsyncServer) executeCommandToBuffer(redisCmd *cmd.RedisCmd, buf *bytes.Buffer, c *comm.Client) { s.shardManager.GetShard(0).ReqChan <- &ops.StoreOp{ Cmd: redisCmd, @@ -293,12 +308,24 @@ func (s *AsyncServer) executeCommandToBuffer(redisCmd *cmd.RedisCmd, buf *bytes. } resp := <-s.ioChan - if resp.EvalResponse.Error != nil { - buf.WriteString(resp.EvalResponse.Error.Error()) + + val, ok := WorkerCmdsMeta[redisCmd.Cmd] + // TODO: Remove this conditional check and if (true) condition when all commands are migrated + if !ok { + buf.Write(resp.EvalResponse.Result.([]byte)) + } else { + // If command type is Global then return the worker eval + if val.CmdType == Global { + buf.Write(val.RespNoShards(redisCmd.Args)) + return + } + // Handle error case independently + if resp.EvalResponse.Error != nil { + handleMigratedResp(resp.EvalResponse.Error, buf) + } + handleMigratedResp(resp.EvalResponse.Result, buf) return } - - buf.Write(resp.EvalResponse.Result.([]byte)) } func readCommands(c io.ReadWriter) (*cmd.RedisCmds, bool, error) { @@ -422,8 +449,8 @@ func (s *AsyncServer) executeTransaction(c *comm.Client, buf *bytes.Buffer) { return } - for _, redisCmd := range cmds { - s.executeCommandToBuffer(redisCmd, buf, c) + for _, cmd := range cmds { + s.executeCommandToBuffer(cmd, buf, c) } c.Cqueue.Cmds = make([]*cmd.RedisCmd, 0) diff --git a/internal/server/utils/httpResp.go b/internal/server/utils/httpResp.go new file mode 100644 index 000000000..bf909f0cf --- /dev/null +++ b/internal/server/utils/httpResp.go @@ -0,0 +1,5 @@ +package utils + +type HTTPResponse struct { + Data interface{} `json:"data"` +} diff --git a/internal/server/websocketServer.go b/internal/server/websocketServer.go index c355ad380..7eac5f091 100644 --- a/internal/server/websocketServer.go +++ b/internal/server/websocketServer.go @@ -11,6 +11,8 @@ import ( "sync" "time" + "github.com/dicedb/dice/internal/eval" + "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/clientio" diceerrors "github.com/dicedb/dice/internal/errors" @@ -163,20 +165,53 @@ func (s *WebsocketServer) WebsocketHandler(w http.ResponseWriter, r *http.Reques // Wait for response resp := <-s.ioChan + + _, ok := WorkerCmdsMeta[redisCmd.Cmd] + respArr := []string{ + "(nil)", // Represents a RESP Nil Bulk String, which indicates a null value. + "OK", // Represents a RESP Simple String with value "OK". + "QUEUED", // Represents a Simple String indicating that a command has been queued. + "0", // Represents a RESP Integer with value 0. + "1", // Represents a RESP Integer with value 1. + "-1", // Represents a RESP Integer with value -1. + "-2", // Represents a RESP Integer with value -2. + "*0", // Represents an empty RESP Array. + } var rp *clientio.RESPParser - if resp.EvalResponse.Error != nil { - rp = clientio.NewRESPParser(bytes.NewBuffer([]byte(resp.EvalResponse.Error.Error()))) + + var responseValue interface{} + // TODO: Remove this conditional check and if (true) condition when all commands are migrated + if !ok { + var err error + if resp.EvalResponse.Error != nil { + rp = clientio.NewRESPParser(bytes.NewBuffer([]byte(resp.EvalResponse.Error.Error()))) + } else { + rp = clientio.NewRESPParser(bytes.NewBuffer(resp.EvalResponse.Result.([]byte))) + } + + responseValue, err = rp.DecodeOne() + if err != nil { + s.logger.Error("Error decoding response", "error", err) + writeResponse(conn, []byte("error: Internal Server Error")) + return + } } else { - rp = clientio.NewRESPParser(bytes.NewBuffer(resp.EvalResponse.Result.([]byte))) + if resp.EvalResponse.Error != nil { + responseValue = resp.EvalResponse.Error.Error() + } else { + responseValue = resp.EvalResponse.Result + } } - val, err := rp.DecodeOne() - if err != nil { - writeResponse(conn, []byte("error: decoding response")) - continue + if val, ok := responseValue.(eval.RespType); ok { + responseValue = respArr[val] + } + + if bt, ok := responseValue.([]byte); ok { + responseValue = string(bt) } - respBytes, err := json.Marshal(val) + respBytes, err := json.Marshal(responseValue) if err != nil { writeResponse(conn, []byte("error: marshaling json response")) continue diff --git a/internal/worker/cmd_meta.go b/internal/worker/cmd_meta.go index acfb7064b..76ab8fe85 100644 --- a/internal/worker/cmd_meta.go +++ b/internal/worker/cmd_meta.go @@ -35,8 +35,15 @@ type CmdMeta struct { CmdType Cmd string WorkerCommandHandler func([]string) []byte - decomposeCommand func(redisCmd *cmd.RedisCmd) []*cmd.RedisCmd - composeResponse func(responses ...eval.EvalResponse) []byte + + // decomposeCommand is a function that takes a Redis command and breaks it down into smaller, + // manageable Redis commands for each shard processing. It returns a slice of Redis commands. + decomposeCommand func(redisCmd *cmd.RedisCmd) []*cmd.RedisCmd + + // composeResponse is a function that combines multiple responses from the execution of commands + // into a single response object. It accepts a variadic parameter of EvalResponse objects + // and returns a unified response interface. + composeResponse func(responses ...eval.EvalResponse) interface{} } var CommandsMeta = map[string]CmdMeta{ diff --git a/internal/worker/worker.go b/internal/worker/worker.go index 79e3149dc..15756061b 100644 --- a/internal/worker/worker.go +++ b/internal/worker/worker.go @@ -11,7 +11,6 @@ import ( "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/auth" - "github.com/dicedb/dice/internal/clientio" "github.com/dicedb/dice/internal/clientio/iohandler" "github.com/dicedb/dice/internal/clientio/requestparser" "github.com/dicedb/dice/internal/cmd" @@ -85,14 +84,14 @@ func (w *BaseWorker) Start(ctx context.Context) error { } cmds, err := w.parser.Parse(data) if err != nil { - err = w.ioHandler.Write(ctx, clientio.Encode(err, true)) + err = w.ioHandler.Write(ctx, err) if err != nil { w.logger.Debug("Write error, connection closed possibly", slog.String("workerID", w.id), slog.Any("error", err)) return err } } if len(cmds) == 0 { - err = w.ioHandler.Write(ctx, clientio.Encode("ERR: Invalid request", true)) + err = w.ioHandler.Write(ctx, fmt.Errorf("ERR: Invalid request")) if err != nil { w.logger.Debug("Write error, connection closed possibly", slog.String("workerID", w.id), slog.Any("error", err)) return err @@ -103,7 +102,7 @@ func (w *BaseWorker) Start(ctx context.Context) error { // DiceDB supports clients to send only one request at a time // We also need to ensure that the client is blocked until the response is received if len(cmds) > 1 { - err = w.ioHandler.Write(ctx, clientio.Encode("ERR: Multiple commands not supported", true)) + err = w.ioHandler.Write(ctx, fmt.Errorf("ERR: Multiple commands not supported")) if err != nil { w.logger.Debug("Write error, connection closed possibly", slog.String("workerID", w.id), slog.Any("error", err)) return err @@ -112,7 +111,7 @@ func (w *BaseWorker) Start(ctx context.Context) error { err = w.isAuthenticated(cmds[0]) if err != nil { - werr := w.ioHandler.Write(ctx, clientio.Encode(err, false)) + werr := w.ioHandler.Write(ctx, err) if werr != nil { w.logger.Debug("Write error, connection closed possibly", slog.Any("error", errors.Join(err, werr))) return errors.Join(err, werr) @@ -120,7 +119,7 @@ func (w *BaseWorker) Start(ctx context.Context) error { } // executeCommand executes the command and return the response back to the client func(errChan chan error) { - execctx, cancel := context.WithTimeout(ctx, 1*time.Second) // Timeout if + execctx, cancel := context.WithTimeout(ctx, 6*time.Second) // Timeout set to 6 seconds for integration tests defer cancel() err = w.executeCommand(execctx, cmds[0]) if err != nil { @@ -143,7 +142,7 @@ func (w *BaseWorker) executeCommand(ctx context.Context, redisCmd *cmd.RedisCmd) // Retrieve metadata for the command to determine if multisharding is supported. meta, ok := CommandsMeta[redisCmd.Cmd] if !ok { - // If no metadata exists, treat it as a single command. + // If no metadata exists, treat it as a single command and not migrated cmdList = append(cmdList, redisCmd) } else { // Depending on the command type, decide how to handle it. @@ -165,12 +164,18 @@ func (w *BaseWorker) executeCommand(ctx context.Context, redisCmd *cmd.RedisCmd) switch redisCmd.Cmd { case CmdAuth: err := w.ioHandler.Write(ctx, w.RespAuth(redisCmd.Args)) - w.logger.Error("Error sending auth response to worker", slog.String("workerID", w.id), slog.Any("error", err)) + if err != nil { + w.logger.Error("Error sending auth response to worker", slog.String("workerID", w.id), slog.Any("error", err)) + } return err case CmdAbort: + err := w.ioHandler.Write(ctx, eval.RespOK) + if err != nil { + w.logger.Error("Error sending abort response to worker", slog.String("workerID", w.id), slog.Any("error", err)) + } w.logger.Info("Received ABORT command, initiating server shutdown", slog.String("workerID", w.id)) w.globalErrorChan <- diceerrors.ErrAborted - return nil + return err default: cmdList = append(cmdList, redisCmd) } @@ -238,7 +243,7 @@ func (w *BaseWorker) gather(ctx context.Context, c string, numCmds int, ct CmdTy w.logger.Error("Timed out waiting for response from shards", slog.String("workerID", w.id), slog.Any("error", ctx.Err())) case resp, ok := <-w.respChan: if ok { - evalResp = append(evalResp, resp.EvalResponse) + evalResp = append(evalResp, *resp.EvalResponse) } numCmds-- continue @@ -276,7 +281,7 @@ func (w *BaseWorker) gather(ctx context.Context, c string, numCmds int, ct CmdTy switch ct { case SingleShard, Custom: if evalResp[0].Error != nil { - err := w.ioHandler.Write(ctx, []byte(evalResp[0].Error.Error())) + err := w.ioHandler.Write(ctx, evalResp[0].Error) if err != nil { w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) } @@ -284,7 +289,7 @@ func (w *BaseWorker) gather(ctx context.Context, c string, numCmds int, ct CmdTy return err } - err := w.ioHandler.Write(ctx, evalResp[0].Result.([]byte)) + err := w.ioHandler.Write(ctx, evalResp[0].Result) if err != nil { w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) return err @@ -299,7 +304,7 @@ func (w *BaseWorker) gather(ctx context.Context, c string, numCmds int, ct CmdTy default: w.logger.Error("Unknown command type", slog.String("workerID", w.id)) - err := w.ioHandler.Write(ctx, []byte(diceerrors.InternalServerError)) + err := w.ioHandler.Write(ctx, diceerrors.ErrInternalServer) if err != nil { w.logger.Debug("Error sending response to client", slog.String("workerID", w.id), slog.Any("error", err)) return err @@ -319,14 +324,14 @@ func (w *BaseWorker) isAuthenticated(redisCmd *cmd.RedisCmd) error { // RespAuth returns with an encoded "OK" if the user is authenticated // If the user is not authenticated, it returns with an encoded error message -func (w *BaseWorker) RespAuth(args []string) []byte { +func (w *BaseWorker) RespAuth(args []string) interface{} { // Check for incorrect number of arguments (arity error). if len(args) < 1 || len(args) > 2 { - return diceerrors.NewErrArity("AUTH") // Return an error if the number of arguments is not equal to 1. + return diceerrors.ErrWrongArgumentCount("AUTH") } if config.DiceConfig.Auth.Password == "" { - return diceerrors.NewErrWithMessage("AUTH called without any password configured for the default user. Are you sure your configuration is correct?") + return diceerrors.ErrAuth } username := config.DiceConfig.Auth.UserName @@ -339,10 +344,10 @@ func (w *BaseWorker) RespAuth(args []string) []byte { } if err := w.Session.Validate(username, password); err != nil { - return clientio.Encode(err, false) + return err } - return clientio.RespOK + return eval.RespOK } func (w *BaseWorker) Stop() error { diff --git a/internal/worker/workermanager.go b/internal/worker/workermanager.go index 96befcbcb..597f51603 100644 --- a/internal/worker/workermanager.go +++ b/internal/worker/workermanager.go @@ -3,14 +3,15 @@ package worker import ( "errors" "sync" + "sync/atomic" "github.com/dicedb/dice/internal/shard" ) type WorkerManager struct { connectedClients sync.Map - numWorkers int - maxClients int + numWorkers atomic.Int64 + maxClients int64 shardManager *shard.ShardManager mu sync.Mutex } @@ -22,7 +23,7 @@ var ( func NewWorkerManager(maxClients int, sm *shard.ShardManager) *WorkerManager { return &WorkerManager{ - maxClients: maxClients, + maxClients: int64(maxClients), shardManager: sm, } } @@ -41,12 +42,12 @@ func (wm *WorkerManager) RegisterWorker(worker Worker) error { wm.shardManager.RegisterWorker(worker.ID(), respChan) // TODO: Change respChan type to ShardResponse } - wm.numWorkers++ + wm.numWorkers.Add(1) return nil } -func (wm *WorkerManager) GetWorkerCount() int { - return wm.numWorkers +func (wm *WorkerManager) GetWorkerCount() int64 { + return wm.numWorkers.Load() } func (wm *WorkerManager) GetWorker(workerID string) (Worker, bool) { @@ -68,7 +69,7 @@ func (wm *WorkerManager) UnregisterWorker(workerID string) error { } wm.shardManager.UnregisterWorker(workerID) - wm.numWorkers++ + wm.numWorkers.Add(-1) return nil }