diff --git a/parser/parser.go b/parser/parser.go index b815259..d50c559 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -156,6 +156,8 @@ type UseStatement struct { func (u UseStatement) isStatement() {} +var defaultSelectStatement = &SelectStatement{} + // IsQueryHandled parses the query string and determines if the query is handled by the proxy func IsQueryHandled(keyspace Identifier, query string) (handled bool, stmt Statement, err error) { var l lexer @@ -164,11 +166,15 @@ func IsQueryHandled(keyspace Identifier, query string) (handled bool, stmt State t := l.next() switch t { case tkSelect: - return isHandledSelectStmt(&l, keyspace) + handled, stmt, err = isHandledSelectStmt(&l, keyspace) + if !handled { + stmt = defaultSelectStatement + } + return case tkUse: return isHandledUseStmt(&l) } - return false, nil, nil + return } // IsQueryIdempotent parses the query string and determines if the query is idempotent diff --git a/parser/parser_test.go b/parser/parser_test.go index d0a9f9a..b0d5342 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -98,12 +98,12 @@ func TestParser(t *testing.T) { Keyspace: "system", }, false}, // Reads from tables named similarly to system tables (not handled) - {"", "SELECT count(*) FROM local", false, true, nil, false}, - {"", "SELECT count(*) FROM peers", false, true, nil, false}, - {"", "SELECT count(*) FROM peers_v2", false, true, nil, false}, + {"", "SELECT count(*) FROM local", false, true, defaultSelectStatement, false}, + {"", "SELECT count(*) FROM peers", false, true, defaultSelectStatement, false}, + {"", "SELECT count(*) FROM peers_v2", false, true, defaultSelectStatement, false}, // Semicolon at the end - {"", "SELECT count(*) FROM table;", false, true, nil, false}, + {"", "SELECT count(*) FROM table;", false, true, defaultSelectStatement, false}, // Mutations to system tables (not handled) {"", "INSERT INTO system.local (key, rpc_address) VALUES ('local1', '127.0.0.1')", false, true, nil, false}, @@ -169,15 +169,17 @@ func TestParser(t *testing.T) { } for _, tt := range tests { - handled, stmt, err := IsQueryHandled(IdentifierFromString(tt.keyspace), tt.query) - assert.True(t, (err != nil) == tt.hasError, tt.query) + t.Run(tt.query, func(t *testing.T) { + handled, stmt, err := IsQueryHandled(IdentifierFromString(tt.keyspace), tt.query) + assert.True(t, (err != nil) == tt.hasError, tt.query) - idempotent, err := IsQueryIdempotent(tt.query) - assert.Nil(t, err, tt.query) + idempotent, err := IsQueryIdempotent(tt.query) + assert.Nil(t, err, tt.query) - assert.Equal(t, tt.handled, handled, "invalid handled", tt.query) - assert.Equal(t, tt.idempotent, idempotent, "invalid idempotency", tt.query) - assert.Equal(t, tt.stmt, stmt, "invalid parsed statement", tt.query) + assert.Equal(t, tt.handled, handled, "invalid handled", tt.query) + assert.Equal(t, tt.idempotent, idempotent, "invalid idempotency", tt.query) + assert.Equal(t, tt.stmt, stmt, "invalid parsed statement", tt.query) + }) } } diff --git a/proxy.go b/proxy.go index 04975b9..8a64d45 100644 --- a/proxy.go +++ b/proxy.go @@ -24,6 +24,7 @@ import ( func main() { ctx, cancel := signalContext(context.Background(), os.Interrupt, os.Kill) + defer cancel() os.Exit(proxy.Run(ctx, os.Args[1:])) diff --git a/proxy/codecs.go b/proxy/codecs.go index bdab69f..89386ac 100644 --- a/proxy/codecs.go +++ b/proxy/codecs.go @@ -38,12 +38,21 @@ func (c *partialQueryCodec) EncodedLength(_ message.Message, _ primitive.Protoco panic("not implemented") } -func (c *partialQueryCodec) Decode(source io.Reader, _ primitive.ProtocolVersion) (message.Message, error) { - if query, err := primitive.ReadLongString(source); err != nil { +func (c *partialQueryCodec) Decode(source io.Reader, _ primitive.ProtocolVersion) (msg message.Message, err error) { + var ( + query string + consistency uint16 + ) + + if query, err = primitive.ReadLongString(source); err != nil { return nil, err - } else { - return &partialQuery{query}, nil } + + if consistency, err = primitive.ReadShort(source); err != nil { + return nil, fmt.Errorf("cannot read QUERY consistency level: %w", err) + } + + return &partialQuery{query, primitive.ConsistencyLevel(consistency)}, nil } func (c *partialQueryCodec) GetOpCode() primitive.OpCode { @@ -51,7 +60,8 @@ func (c *partialQueryCodec) GetOpCode() primitive.OpCode { } type partialQuery struct { - query string + query string + consistency primitive.ConsistencyLevel } func (p *partialQuery) IsResponse() bool { @@ -63,11 +73,16 @@ func (p *partialQuery) GetOpCode() primitive.OpCode { } func (p *partialQuery) DeepCopyMessage() message.Message { - return &partialQuery{p.query} + return &partialQuery{p.query, p.consistency} +} + +func (m *partialQuery) String() string { + return "QUERY " + m.query } type partialExecute struct { - queryId []byte + queryId []byte + consistency primitive.ConsistencyLevel } func (m *partialExecute) IsResponse() bool { @@ -81,7 +96,7 @@ func (m *partialExecute) GetOpCode() primitive.OpCode { func (m *partialExecute) DeepCopyMessage() message.Message { queryId := make([]byte, len(m.queryId)) copy(queryId, m.queryId) - return &partialExecute{queryId} + return &partialExecute{queryId, m.consistency} } func (m *partialExecute) String() string { @@ -99,13 +114,22 @@ func (c *partialExecuteCodec) EncodedLength(_ message.Message, _ primitive.Proto } func (c *partialExecuteCodec) Decode(source io.Reader, _ primitive.ProtocolVersion) (msg message.Message, err error) { - execute := &partialExecute{} - if execute.queryId, err = primitive.ReadShortBytes(source); err != nil { + var ( + queryID []byte + consistency uint16 + ) + + if queryID, err = primitive.ReadShortBytes(source); err != nil { return nil, fmt.Errorf("cannot read EXECUTE query id: %w", err) - } else if len(execute.queryId) == 0 { + } else if len(queryID) == 0 { return nil, errors.New("EXECUTE missing query id") } - return execute, nil + + if consistency, err = primitive.ReadShort(source); err != nil { + return nil, fmt.Errorf("cannot read EXECUTE consistency level: %w", err) + } + + return &partialExecute{queryID, primitive.ConsistencyLevel(consistency)}, nil } func (c *partialExecuteCodec) GetOpCode() primitive.OpCode { @@ -113,7 +137,8 @@ func (c *partialExecuteCodec) GetOpCode() primitive.OpCode { } type partialBatch struct { - queryOrIds []interface{} + queryOrIds []interface{} + consistency primitive.ConsistencyLevel } func (p partialBatch) IsResponse() bool { @@ -127,7 +152,11 @@ func (p partialBatch) GetOpCode() primitive.OpCode { func (p partialBatch) DeepCopyMessage() message.Message { queryOrIds := make([]interface{}, len(p.queryOrIds)) copy(queryOrIds, p.queryOrIds) - return &partialBatch{queryOrIds} + return &partialBatch{queryOrIds, p.consistency} +} + +func (p partialBatch) String() string { + return fmt.Sprintf("BATCH (%d statements)", len(p.queryOrIds)) } type partialBatchCodec struct{} @@ -142,6 +171,7 @@ func (p partialBatchCodec) EncodedLength(msg message.Message, version primitive. func (p partialBatchCodec) Decode(source io.Reader, version primitive.ProtocolVersion) (msg message.Message, err error) { var queryOrIds []interface{} + var consistency uint16 var typ uint8 if typ, err = primitive.ReadByte(source); err != nil { return nil, fmt.Errorf("cannot read BATCH type: %w", err) @@ -177,7 +207,12 @@ func (p partialBatchCodec) Decode(source io.Reader, version primitive.ProtocolVe } queryOrIds[i] = queryOrId } - return &partialBatch{queryOrIds}, nil + + if consistency, err = primitive.ReadShort(source); err != nil { + return nil, fmt.Errorf("cannot read BATCH consistency level: %w", err) + } + + return &partialBatch{queryOrIds, primitive.ConsistencyLevel(consistency)}, nil } func (p partialBatchCodec) GetOpCode() primitive.OpCode { diff --git a/proxy/frame_patch.go b/proxy/frame_patch.go new file mode 100644 index 0000000..4a178c2 --- /dev/null +++ b/proxy/frame_patch.go @@ -0,0 +1,156 @@ +package proxy + +import ( + "encoding/binary" + "fmt" + + "github.com/datastax/go-cassandra-native-protocol/primitive" +) + +// patchQueryConsistency modifies the consistency level of a QUERY message in-place +// by locating the consistency field directly in the frame body. +// +// Layout based on the CQL native protocol v4 spec: +// /* ... */ +func patchQueryConsistency(body []byte, newConsistency primitive.ConsistencyLevel) error { + if len(body) < 6 { + return fmt.Errorf("body too short for QUERY") + } + + queryLen := binary.BigEndian.Uint32(body[0:4]) + offset := 4 + int(queryLen) + + if len(body) < offset+2 { + return fmt.Errorf("not enough bytes to patch QUERY consistency") + } + + // Modify the batch consistency field + binary.BigEndian.PutUint16(body[offset:offset+2], uint16(newConsistency)) + + return nil +} + +// patchExecuteConsistency modifies the consistency level of an EXECUTE message in-place +// by locating the consistency field directly after the prepared statement ID. +// +// Layout based on the CQL native protocol v4 spec: +// /* ... */ +func patchExecuteConsistency(body []byte, newConsistency primitive.ConsistencyLevel) error { + if len(body) < 2 { + return fmt.Errorf("body too short for EXECUTE") + } + + idLen := int(binary.BigEndian.Uint16(body[0:2])) + offset := 2 + idLen + + if len(body) < offset+2 { + return fmt.Errorf("not enough bytes to patch EXECUTE consistency") + } + + // Modify the batch consistency field + binary.BigEndian.PutUint16(body[offset:offset+2], uint16(newConsistency)) + + return nil +} + +// patchBatchConsistency modifies the consistency level of a BATCH message in-place +// by locating and modifying the consistency field of the batch, which applies to all queries in the batch. +// +// Layout based on the CQL native protocol v4 spec: +// /* +// +// (number of queries) +// (queries themselves) +// +// +// [] +// [] */ +func patchBatchConsistency(body []byte, newConsistency primitive.ConsistencyLevel) error { + if len(body) < 7 { + // Not enough bytes for even the basic batch layout (at least 2 bytes for n + 2 bytes for consistency + 1 byte for flags) + return fmt.Errorf("invalid batch body: too short") + } + + offset := 3 // [byte] [short] (3 bytes) + numQueries := binary.BigEndian.Uint16(body[1:3]) + + // Process the queries + for i := uint16(0); i < numQueries; i++ { + if len(body) <= offset { + return fmt.Errorf("query #%d exceeds body length", i) + } + + queryType := body[offset] + offset++ // Move past the query [byte] + + switch primitive.BatchChildType(queryType) { + case primitive.BatchChildTypeQueryString: + queryLength := binary.BigEndian.Uint32(body[offset : offset+4]) + offset += 4 // Move past the length + _ = body[offset : offset+int(queryLength)] + offset += int(queryLength) + case primitive.BatchChildTypePreparedId: + stmtIDLength := binary.BigEndian.Uint16(body[offset : offset+2]) + offset += 2 // Move past the length + _ = body[offset : offset+int(stmtIDLength)] + offset += int(stmtIDLength) + default: + return fmt.Errorf("unsupported BATCH child type for query #%d: %v", i, queryType) + } + + // Skip positional values + if err := skipPositionalValuesByteSlice(body, &offset); err != nil { + return fmt.Errorf("cannot skip positional values for query #%d: %w", i, err) + } + } + + // Patch consistency at the right spot + if len(body) < offset+2 { + return fmt.Errorf("not enough bytes to patch consistency") + } + binary.BigEndian.PutUint16(body[offset:], uint16(newConsistency)) + + return nil +} + +// skipPositionalValuesByteSlice skips the positional values in the byte slice +// It reads the length of positional values and skips them based on the byte slice offset. +func skipPositionalValuesByteSlice(body []byte, offset *int) error { + if len(body) <= *offset+2 { + return fmt.Errorf("insufficient bytes to read positional values length") + } + length := binary.BigEndian.Uint16(body[*offset : *offset+2]) + *offset += 2 // Move the offset past the length + + for i := uint16(0); i < length; i++ { + if err := skipValueByteSlice(body, offset); err != nil { + return fmt.Errorf("cannot skip positional value %d: %w", i, err) + } + } + return nil +} + +// skipValueByteSlice skips a single positional value based on its length in the byte slice. +func skipValueByteSlice(body []byte, offset *int) error { + if len(body) <= *offset+4 { + return fmt.Errorf("insufficient bytes to read value length") + } + length := int32(binary.BigEndian.Uint32(body[*offset : *offset+4])) + *offset += 4 // Move the offset past the length + + if length == -1 || length == -2 { + // It's a null or unset, nothing to skip + return nil + } + if length < 0 { + return fmt.Errorf("invalid negative length: %d", length) + } + + if length > 0 { + if len(body) < *offset+int(length) { + return fmt.Errorf("insufficient bytes to skip value content") + } + *offset += int(length) // Move the offset past the value content + } + return nil +} diff --git a/proxy/frame_patch_test.go b/proxy/frame_patch_test.go new file mode 100644 index 0000000..14e5783 --- /dev/null +++ b/proxy/frame_patch_test.go @@ -0,0 +1,182 @@ +package proxy + +import ( + "bytes" + "testing" + + "github.com/datastax/go-cassandra-native-protocol/message" + "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const version = primitive.ProtocolVersion4 + +func TestPatchQueryConsistency(t *testing.T) { + var queryCodec message.Codec + + for _, c := range message.DefaultMessageCodecs { + if c.GetOpCode() == primitive.OpCodeQuery { + queryCodec = c + } + } + assert.NotNil(t, queryCodec) + + t.Run("patch query consistency for valid query frame", func(t *testing.T) { + var buf bytes.Buffer + err := queryCodec.Encode(&message.Query{ + Query: "SELECT * FROM test", + Options: &message.QueryOptions{ + Consistency: primitive.ConsistencyLevelOne, + }, + }, &buf, version) + assert.NoError(t, err) + + body := buf.Bytes() + err = patchQueryConsistency(body, primitive.ConsistencyLevelQuorum) + assert.NoError(t, err) + + msg, err := queryCodec.Decode(bytes.NewBuffer(body), version) + require.NoError(t, err) + + assert.Equal(t, primitive.ConsistencyLevelQuorum, msg.(*message.Query).Options.Consistency) + }) +} + +func TestPatchExecuteConsistency(t *testing.T) { + + localSerialConsistency := primitive.ConsistencyLevelLocalSerial + + var queryCodec message.Codec + + for _, c := range message.DefaultMessageCodecs { + if c.GetOpCode() == primitive.OpCodeExecute { + queryCodec = c + } + } + assert.NotNil(t, queryCodec) + + t.Run("patch execute consistency for valid execute frame", func(t *testing.T) { + var buf bytes.Buffer + + msg := &message.Execute{ + QueryId: []byte{0x0a, 0x0b, 0x0c}, + ResultMetadataId: []byte{0x0d, 0x0e, 0x0f}, + Options: &message.QueryOptions{ + Consistency: primitive.ConsistencyLevelOne, + SerialConsistency: &localSerialConsistency, + }, + } + + err := queryCodec.Encode(msg, &buf, version) + assert.NoError(t, err) + + body := buf.Bytes() + err = patchExecuteConsistency(body, primitive.ConsistencyLevelQuorum) + assert.NoError(t, err) + + mesg, err := queryCodec.Decode(bytes.NewBuffer(body), version) + require.NoError(t, err) + + assert.Equal(t, primitive.ConsistencyLevelQuorum, mesg.(*message.Execute).Options.Consistency) + }) + +} + +func TestPatchBatchConsistency(t *testing.T) { + localSerialConsistency := primitive.ConsistencyLevelLocalSerial + timestamp := int64(1234567890) + + var queryCodec message.Codec + + for _, c := range message.DefaultMessageCodecs { + if c.GetOpCode() == primitive.OpCodeBatch { + queryCodec = c + } + } + assert.NotNil(t, queryCodec) + + t.Run("patch batch consistency for valid batch frame with values", func(t *testing.T) { + var buf bytes.Buffer + + msgWithFlags := &message.Batch{ + Type: primitive.BatchTypeLogged, + Children: []*message.BatchChild{ + { + Id: []byte{0x01, 0x02, 0x03}, + Values: []*primitive.Value{ + {Type: primitive.ValueTypeNull}, + }, + }, + { + Query: "SELECT * FROM table WHERE id = ?", + Values: []*primitive.Value{ + {Type: primitive.ValueTypeRegular, Contents: []byte{0x01, 0x02, 0x03}}, + }, + }, + }, + Consistency: primitive.ConsistencyLevelOne, + SerialConsistency: &localSerialConsistency, + DefaultTimestamp: ×tamp, + } + + err := queryCodec.Encode(msgWithFlags, &buf, version) + assert.NoError(t, err) + + body := buf.Bytes() + err = patchBatchConsistency(body, primitive.ConsistencyLevelQuorum) + assert.NoError(t, err) + + mesg, err := queryCodec.Decode(bytes.NewBuffer(body), version) + require.NoError(t, err) + + assert.Equal(t, primitive.ConsistencyLevelQuorum, mesg.(*message.Batch).Consistency) + }) + + t.Run("patch batch consistency for valid batch frame without values", func(t *testing.T) { + var buf bytes.Buffer + + msgWithFlags := &message.Batch{ + Type: primitive.BatchTypeLogged, + Children: []*message.BatchChild{ + { + Query: "SELECT * FROM table", + }, + { + Query: "SELECT * FROM table WHERE id = ?", + Values: []*primitive.Value{ + { + Type: primitive.ValueTypeRegular, + Contents: []byte{0x1, 0x2, 0x3}, + }, + }, + }, + { + Id: []byte{0x01, 0x02, 0x03}, + Values: []*primitive.Value{ + { + Type: primitive.ValueTypeRegular, + Contents: []byte{0x4, 0x5, 0x6}, + }, + }, + }, + }, + Consistency: primitive.ConsistencyLevelOne, + SerialConsistency: &localSerialConsistency, + DefaultTimestamp: ×tamp, + } + + err := queryCodec.Encode(msgWithFlags, &buf, version) + assert.NoError(t, err) + + body := buf.Bytes() + err = patchBatchConsistency(body, primitive.ConsistencyLevelQuorum) + assert.NoError(t, err) + + mesg, err := queryCodec.Decode(bytes.NewBuffer(body), version) + require.NoError(t, err) + + assert.Equal(t, primitive.ConsistencyLevelQuorum, mesg.(*message.Batch).Consistency) + }) + +} diff --git a/proxy/proxy.go b/proxy/proxy.go index fc585f4..fe34ed8 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -61,47 +61,54 @@ type PeerConfig struct { } type Config struct { - Version primitive.ProtocolVersion - MaxVersion primitive.ProtocolVersion - Auth proxycore.Authenticator - Resolver proxycore.EndpointResolver - ReconnectPolicy proxycore.ReconnectPolicy - RetryPolicy RetryPolicy - IdempotentGraph bool - NumConns int - Logger *zap.Logger - HeartBeatInterval time.Duration - ConnectTimeout time.Duration - IdleTimeout time.Duration - RPCAddr string - DC string - Tokens []string - Peers []PeerConfig + Version primitive.ProtocolVersion + MaxVersion primitive.ProtocolVersion + Auth proxycore.Authenticator + Resolver proxycore.EndpointResolver + ReconnectPolicy proxycore.ReconnectPolicy + RetryPolicy RetryPolicy + IdempotentGraph bool + NumConns int + Logger *zap.Logger + HeartBeatInterval time.Duration + ConnectTimeout time.Duration + IdleTimeout time.Duration + RPCAddr string + DC string + Tokens []string + Peers []PeerConfig + UnsupportedWriteConsistencies []clWrapper + UnsupportedWriteConsistencyOverride clWrapper // PreparedCache a cache that stores prepared queries. If not set it uses the default implementation with a max // capacity of ~100MB. PreparedCache proxycore.PreparedCache } type Proxy struct { - ctx context.Context - config Config - logger *zap.Logger - cluster *proxycore.Cluster - sessions [primitive.ProtocolVersionDse2 + 1]sync.Map // Cache sessions per protocol version - mu sync.Mutex - isConnected bool - isClosing bool - clients map[*client]struct{} - listeners map[*net.Listener]struct{} - eventClients sync.Map - preparedCache proxycore.PreparedCache - preparedIdempotence sync.Map - lb proxycore.LoadBalancer - systemLocalValues map[string]message.Column - closed chan struct{} - localNode *node - nodes []*node - onceUsingGraphLog sync.Once + ctx context.Context + config Config + logger *zap.Logger + cluster *proxycore.Cluster + sessions [primitive.ProtocolVersionDse2 + 1]sync.Map // Cache sessions per protocol version + mu sync.Mutex + isConnected bool + isClosing bool + clients map[*client]struct{} + listeners map[*net.Listener]struct{} + eventClients sync.Map + preparedCache proxycore.PreparedCache + preparedMetadata sync.Map + lb proxycore.LoadBalancer + systemLocalValues map[string]message.Column + closed chan struct{} + localNode *node + nodes []*node + onceUsingGraphLog sync.Once +} + +type preparedMetadata struct { + idempotent bool + isSelect bool } type node struct { @@ -469,31 +476,48 @@ func (p *Proxy) encodeTypeFatal(dt datatype.DataType, val interface{}) []byte { // isIdempotent checks whether a prepared ID is idempotent. // If the proxy receives a query that it's never prepared then this will also return false. func (p *Proxy) isIdempotent(id []byte) bool { - if val, ok := p.preparedIdempotence.Load(preparedIdKey(id)); !ok { + if val, ok := p.preparedMetadata.Load(preparedIdKey(id)); !ok { // This should only happen if the proxy has never had a "PREPARE" request for this query ID. p.logger.Error("unable to determine if prepared statement is idempotent", zap.String("preparedID", hex.EncodeToString(id))) return false } else { - return val.(bool) + return val.(preparedMetadata).idempotent + } +} + +func (p *Proxy) isSelect(id [16]byte) bool { + if val, ok := p.preparedMetadata.Load(id); !ok { + // This should only happen if the proxy has never had a "PREPARE" request for this query ID. + p.logger.Error("unable to determine if prepared statement is idempotent", + zap.String("preparedID", hex.EncodeToString(id[:]))) + return false + } else { + return val.(preparedMetadata).isSelect } } -// maybeStorePreparedIdempotence stores the idempotence of a "PREPARE" request's query. +// maybeStorePreparedMetadata stores the idempotence of a "PREPARE" request's query. // This information is used by future "EXECUTE" requests when they need to be retried. -func (p *Proxy) maybeStorePreparedIdempotence(raw *frame.RawFrame, msg message.Message) { +func (p *Proxy) maybeStorePreparedMetadata(raw *frame.RawFrame, isSelect bool, msg message.Message) { if prepareMsg, ok := msg.(*message.Prepare); ok && raw.Header.OpCode == primitive.OpCodeResult { // Prepared result frm, err := codec.ConvertFromRawFrame(raw) if err != nil { p.logger.Error("error attempting to decode prepared result message") - } else if _, ok = frm.Body.Message.(*message.PreparedResult); !ok { // TODO: Use prepared type data to disambiguate idempotency + } else if preparedResultMsg, ok := frm.Body.Message.(*message.PreparedResult); !ok { // TODO: Use prepared type data to disambiguate idempotency p.logger.Error("expected prepared result message, but got something else") } else { + p.logger.Debug("prepared request", + zap.Stringer("request", prepareMsg), + zap.Stringer("response", preparedResultMsg)) idempotent, err := parser.IsQueryIdempotent(prepareMsg.Query) if err != nil { p.logger.Error("error parsing query for idempotence", zap.Error(err)) } else if result, ok := frm.Body.Message.(*message.PreparedResult); ok { - p.preparedIdempotence.Store(preparedIdKey(result.PreparedQueryId), idempotent) + p.preparedMetadata.Store(preparedIdKey(result.PreparedQueryId), preparedMetadata{ + idempotent: idempotent, + isSelect: isSelect, + }) } else { p.logger.Error("expected prepared result, but got some other type of message", zap.Stringer("type", reflect.TypeOf(frm.Body.Message))) @@ -533,6 +557,7 @@ type client struct { conn *proxycore.Conn keyspace string preparedSystemQuery map[[16]byte]interface{} + preparedSelectQuery map[[16]byte]interface{} } func (c *client) Receive(reader io.Reader) error { @@ -579,7 +604,8 @@ func (c *client) Receive(reader io.Reader) error { case *partialQuery: c.handleQuery(raw, msg, body.CustomPayload) case *partialBatch: - c.execute(raw, notDetermined, c.keyspace, msg) + c.maybeOverrideUnsupportedWriteConsistency(false, raw, msg) + c.execute(raw, notDetermined, false, c.keyspace, msg) default: c.send(raw.Header, &message.ProtocolError{ErrorMessage: "Unsupported operation"}) } @@ -587,7 +613,7 @@ func (c *client) Receive(reader io.Reader) error { return nil } -func (c *client) execute(raw *frame.RawFrame, state idempotentState, keyspace string, msg message.Message) { +func (c *client) execute(raw *frame.RawFrame, state idempotentState, isSelect bool, keyspace string, msg message.Message) { if sess, err := c.proxy.findSession(raw.Header.Version, c.keyspace); err == nil { req := &request{ client: c, @@ -599,6 +625,7 @@ func (c *client) execute(raw *frame.RawFrame, state idempotentState, keyspace st stream: raw.Header.StreamId, qp: c.proxy.newQueryPlan(), raw: raw, + isSelect: isSelect, } req.Execute(true) } else { @@ -653,7 +680,8 @@ func (c *client) handlePrepare(raw *frame.RawFrame, msg *message.Prepare) { } } else { - c.execute(raw, isIdempotent, keyspace, msg) // Prepared statements can be retried themselves + _, isSelect := stmt.(*parser.SelectStatement) + c.execute(raw, isIdempotent, isSelect, keyspace, msg) // Prepared statements can be retried themselves } } @@ -662,7 +690,9 @@ func (c *client) handleExecute(raw *frame.RawFrame, msg *partialExecute, customP if stmt, ok := c.preparedSystemQuery[id]; ok { c.interceptSystemQuery(raw.Header, stmt) } else { - c.execute(raw, c.getDefaultIdempotency(customPayload), "", msg) + isSelect := c.proxy.isSelect(id) + c.maybeOverrideUnsupportedWriteConsistency(isSelect, raw, msg) + c.execute(raw, c.getDefaultIdempotency(customPayload), isSelect, "", msg) } } @@ -678,7 +708,9 @@ func (c *client) handleQuery(raw *frame.RawFrame, msg *partialQuery, customPaylo } } else { c.proxy.logger.Debug("Query not handled by proxy, forwarding", zap.String("query", msg.query), zap.Int16("stream", raw.Header.StreamId)) - c.execute(raw, c.getDefaultIdempotency(customPayload), c.keyspace, msg) + _, isSelect := stmt.(*parser.SelectStatement) + c.maybeOverrideUnsupportedWriteConsistency(isSelect, raw, msg) + c.execute(raw, c.getDefaultIdempotency(customPayload), isSelect, c.keyspace, msg) } } @@ -844,6 +876,75 @@ func (c *client) Closing(_ error) { c.proxy.removeClient(c) } +func (c *client) maybeOverrideUnsupportedWriteConsistency(isSelect bool, raw *frame.RawFrame, msg message.Message) { + if !isSelect { + overrideConsistency := c.proxy.config.UnsupportedWriteConsistencyOverride.ConsistencyLevel + + switch m := msg.(type) { + case *partialExecute: + if c.isUnsupportedWriteConsistency(m.consistency) { + c.proxy.logger.Debug("overriding unsupported write consistency for execute", + zap.Stringer("request", m), + zap.Stringer("unsupported", m.consistency), + zap.Stringer("override", overrideConsistency)) + err := patchExecuteConsistency(raw.Body, overrideConsistency) + if err != nil { + c.proxy.logger.Error("unable to override write consistency for execute", + zap.Stringer("request", m), + zap.Error(err)) + } + } else { + c.proxy.logger.Debug("no override required for execute write consistency", + zap.Stringer("request", m), + zap.Stringer("consistency", m.consistency)) + } + case *partialQuery: + if c.isUnsupportedWriteConsistency(m.consistency) { + c.proxy.logger.Debug("overriding unsupported write consistency for query", + zap.Stringer("request", m), + zap.Stringer("unsupported", m.consistency), + zap.Stringer("override", overrideConsistency)) + err := patchQueryConsistency(raw.Body, overrideConsistency) + if err != nil { + c.proxy.logger.Error("unable to override write consistency for query", + zap.Stringer("request", m), + zap.Error(err)) + } + } else { + c.proxy.logger.Debug("no override required for query write consistency", + zap.Stringer("request", m), + zap.Stringer("consistency", m.consistency)) + } + case *partialBatch: + if c.isUnsupportedWriteConsistency(m.consistency) { + c.proxy.logger.Debug("overriding unsupported write consistency for batch", + zap.Stringer("request", m), + zap.Stringer("unsupported", m.consistency), + zap.Stringer("override", overrideConsistency)) + err := patchBatchConsistency(raw.Body, overrideConsistency) + if err != nil { + c.proxy.logger.Error("unable to override write consistency for batch", + zap.Stringer("request", m), + zap.Error(err)) + } + } else { + c.proxy.logger.Debug("no override required for batch write consistency", + zap.Stringer("request", m), + zap.Stringer("consistency", m.consistency)) + } + } + } +} + +func (c *client) isUnsupportedWriteConsistency(consistency primitive.ConsistencyLevel) bool { + for _, unsupported := range c.proxy.config.UnsupportedWriteConsistencies { + if unsupported.ConsistencyLevel == consistency { + return true + } + } + return false +} + func getOrCreateDefaultPreparedCache(cache proxycore.PreparedCache) (proxycore.PreparedCache, error) { if cache == nil { return NewDefaultPreparedCache(1e8 / 256) // ~100MB with an average query size of 256 bytes diff --git a/proxy/request.go b/proxy/request.go index 9114b5f..e48b67b 100644 --- a/proxy/request.go +++ b/proxy/request.go @@ -49,6 +49,7 @@ type request struct { stream int16 qp proxycore.QueryPlan raw *frame.RawFrame + isSelect bool // Only used for prepared statements currently mu sync.Mutex } @@ -146,7 +147,7 @@ func (r *request) OnResult(raw *frame.RawFrame) { if !r.done { if raw.Header.OpCode != primitive.OpCodeError || !r.handleErrorResult(raw) { // If the error result is retried then we don't send back this response - r.client.proxy.maybeStorePreparedIdempotence(raw, r.msg) + r.client.proxy.maybeStorePreparedMetadata(raw, r.isSelect, r.msg) r.done = true r.sendRaw(raw) } diff --git a/proxy/run.go b/proxy/run.go index 9ad5111..2678cd6 100644 --- a/proxy/run.go +++ b/proxy/run.go @@ -40,34 +40,71 @@ const livenessPath = "/liveness" const readinessPath = "/readiness" type runConfig struct { - AstraBundle string `yaml:"astra-bundle" help:"Path to secure connect bundle for an Astra database. Requires '--username' and '--password'. Ignored if using the token or contact points option." short:"b" env:"ASTRA_BUNDLE"` - AstraToken string `yaml:"astra-token" help:"Token used to authenticate to an Astra database. Requires '--astra-database-id'. Ignored if using the bundle path or contact points option." short:"t" env:"ASTRA_TOKEN"` - AstraDatabaseID string `yaml:"astra-database-id" help:"Database ID of the Astra database. Requires '--astra-token'" short:"i" env:"ASTRA_DATABASE_ID"` - AstraApiURL string `yaml:"astra-api-url" help:"URL for the Astra API" default:"https://api.astra.datastax.com" env:"ASTRA_API_URL"` - AstraTimeout time.Duration `yaml:"astra-timeout" help:"Timeout for contacting Astra when retrieving the bundle and metadata" default:"10s" env:"ASTRA_TIMEOUT"` - ContactPoints []string `yaml:"contact-points" help:"Contact points for cluster. Ignored if using the bundle path or token option." short:"c" env:"CONTACT_POINTS"` - Username string `yaml:"username" help:"Username to use for authentication" short:"u" env:"USERNAME"` - Password string `yaml:"password" help:"Password to use for authentication" short:"p" env:"PASSWORD"` - Port int `yaml:"port" help:"Default port to use when connecting to cluster" default:"9042" short:"r" env:"PORT"` - ProtocolVersion string `yaml:"protocol-version" help:"Initial protocol version to use when connecting to the backend cluster (default: v4, options: v3, v4, v5, DSEv1, DSEv2)" default:"v4" short:"n" env:"PROTOCOL_VERSION"` - MaxProtocolVersion string `yaml:"max-protocol-version" help:"Max protocol version supported by the backend cluster (default: v4, options: v3, v4, v5, DSEv1, DSEv2)" default:"v4" short:"m" env:"MAX_PROTOCOL_VERSION"` - Bind string `yaml:"bind" help:"Address to use to bind server" short:"a" default:":9042" env:"BIND"` - Config *os.File `yaml:"-" help:"YAML configuration file" short:"f" env:"CONFIG_FILE"` // Not available in the configuration file - Debug bool `yaml:"debug" help:"Show debug logging" default:"false" env:"DEBUG"` - HealthCheck bool `yaml:"health-check" help:"Enable liveness and readiness checks" default:"false" env:"HEALTH_CHECK"` - HttpBind string `yaml:"http-bind" help:"Address to use to bind HTTP server used for health checks" default:":8000" env:"HTTP_BIND"` - HeartbeatInterval time.Duration `yaml:"heartbeat-interval" help:"Interval between performing heartbeats to the cluster" default:"30s" env:"HEARTBEAT_INTERVAL"` - ConnectTimeout time.Duration `yaml:"connect-timeout" help:"Duration before an attempt to connect to a cluster is considered timed out" default:"10s" env:"CONNECT_TIMEOUT"` - IdleTimeout time.Duration `yaml:"idle-timeout" help:"Duration between successful heartbeats before a connection to the cluster is considered unresponsive and closed" default:"60s" env:"IDLE_TIMEOUT"` - ReadinessTimeout time.Duration `yaml:"readiness-timeout" help:"Duration the proxy is unable to connect to the backend cluster before it is considered not ready" default:"30s" env:"READINESS_TIMEOUT"` - IdempotentGraph bool `yaml:"idempotent-graph" help:"If true it will treat all graph queries as idempotent by default and retry them automatically. It may be dangerous to retry some graph queries -- use with caution." default:"false" env:"IDEMPOTENT_GRAPH"` - NumConns int `yaml:"num-conns" help:"Number of connection to create to each node of the backend cluster" default:"1" env:"NUM_CONNS"` - ProxyCertFile string `yaml:"proxy-cert-file" help:"Path to a PEM encoded certificate file with its intermediate certificate chain. This is used to encrypt traffic for proxy clients" env:"PROXY_CERT_FILE"` - ProxyKeyFile string `yaml:"proxy-key-file" help:"Path to a PEM encoded private key file. This is used to encrypt traffic for proxy clients" env:"PROXY_KEY_FILE"` - RpcAddress string `yaml:"rpc-address" help:"Address to advertise in the 'system.local' table for 'rpc_address'. It must be set if configuring peer proxies" env:"RPC_ADDRESS"` - DataCenter string `yaml:"data-center" help:"Data center to use in system tables" env:"DATA_CENTER"` - Tokens []string `yaml:"tokens" help:"Tokens to use in the system tables. It's not recommended" env:"TOKENS"` - Peers []PeerConfig `yaml:"peers" kong:"-"` // Not available as a CLI flag + AstraBundle string `yaml:"astra-bundle" help:"Path to secure connect bundle for an Astra database. Requires '--username' and '--password'. Ignored if using the token or contact points option." short:"b" env:"ASTRA_BUNDLE"` + AstraToken string `yaml:"astra-token" help:"Token used to authenticate to an Astra database. Requires '--astra-database-id'. Ignored if using the bundle path or contact points option." short:"t" env:"ASTRA_TOKEN"` + AstraDatabaseID string `yaml:"astra-database-id" help:"Database ID of the Astra database. Requires '--astra-token'" short:"i" env:"ASTRA_DATABASE_ID"` + AstraApiURL string `yaml:"astra-api-url" help:"URL for the Astra API" default:"https://api.astra.datastax.com" env:"ASTRA_API_URL"` + AstraTimeout time.Duration `yaml:"astra-timeout" help:"Timeout for contacting Astra when retrieving the bundle and metadata" default:"10s" env:"ASTRA_TIMEOUT"` + ContactPoints []string `yaml:"contact-points" help:"Contact points for cluster. Ignored if using the bundle path or token option." short:"c" env:"CONTACT_POINTS"` + Username string `yaml:"username" help:"Username to use for authentication" short:"u" env:"USERNAME"` + Password string `yaml:"password" help:"Password to use for authentication" short:"p" env:"PASSWORD"` + Port int `yaml:"port" help:"Default port to use when connecting to cluster" default:"9042" short:"r" env:"PORT"` + ProtocolVersion string `yaml:"protocol-version" help:"Initial protocol version to use when connecting to the backend cluster (default: v4, options: v3, v4, v5, DSEv1, DSEv2)" default:"v4" short:"n" env:"PROTOCOL_VERSION"` + MaxProtocolVersion string `yaml:"max-protocol-version" help:"Max protocol version supported by the backend cluster (default: v4, options: v3, v4, v5, DSEv1, DSEv2)" default:"v4" short:"m" env:"MAX_PROTOCOL_VERSION"` + Bind string `yaml:"bind" help:"Address to use to bind server" short:"a" default:":9042" env:"BIND"` + Config *os.File `yaml:"-" help:"YAML configuration file" short:"f" env:"CONFIG_FILE"` // Not available in the configuration file + Debug bool `yaml:"debug" help:"Show debug logging" default:"false" env:"DEBUG"` + HealthCheck bool `yaml:"health-check" help:"Enable liveness and readiness checks" default:"false" env:"HEALTH_CHECK"` + HttpBind string `yaml:"http-bind" help:"Address to use to bind HTTP server used for health checks" default:":8000" env:"HTTP_BIND"` + HeartbeatInterval time.Duration `yaml:"heartbeat-interval" help:"Interval between performing heartbeats to the cluster" default:"30s" env:"HEARTBEAT_INTERVAL"` + ConnectTimeout time.Duration `yaml:"connect-timeout" help:"Duration before an attempt to connect to a cluster is considered timed out" default:"10s" env:"CONNECT_TIMEOUT"` + IdleTimeout time.Duration `yaml:"idle-timeout" help:"Duration between successful heartbeats before a connection to the cluster is considered unresponsive and closed" default:"60s" env:"IDLE_TIMEOUT"` + ReadinessTimeout time.Duration `yaml:"readiness-timeout" help:"Duration the proxy is unable to connect to the backend cluster before it is considered not ready" default:"30s" env:"READINESS_TIMEOUT"` + IdempotentGraph bool `yaml:"idempotent-graph" help:"If true it will treat all graph queries as idempotent by default and retry them automatically. It may be dangerous to retry some graph queries -- use with caution." default:"false" env:"IDEMPOTENT_GRAPH"` + NumConns int `yaml:"num-conns" help:"Number of connection to create to each node of the backend cluster" default:"1" env:"NUM_CONNS"` + ProxyCertFile string `yaml:"proxy-cert-file" help:"Path to a PEM encoded certificate file with its intermediate certificate chain. This is used to encrypt traffic for proxy clients" env:"PROXY_CERT_FILE"` + ProxyKeyFile string `yaml:"proxy-key-file" help:"Path to a PEM encoded private key file. This is used to encrypt traffic for proxy clients" env:"PROXY_KEY_FILE"` + RpcAddress string `yaml:"rpc-address" help:"Address to advertise in the 'system.local' table for 'rpc_address'. It must be set if configuring peer proxies" env:"RPC_ADDRESS"` + DataCenter string `yaml:"data-center" help:"Data center to use in system tables" env:"DATA_CENTER"` + Tokens []string `yaml:"tokens" help:"Tokens to use in the system tables. It's not recommended" env:"TOKENS"` + Peers []PeerConfig `yaml:"peers" kong:"-"` // Not available as a CLI flag + UnsupportedWriteConsistencies []clWrapper `yaml:"unsupported-write-consistencies" help:"A list of unsupported write consistency levels. The unsupported write consistency override setting will be used inplace of the unsupported level" env:"UNSUPPORTED_WRITE_CONSISTENCIES"` + UnsupportedWriteConsistencyOverride clWrapper `yaml:"unsupported-write-consistency-override" help:"A consistency level use to override unsupported write consistency levels" env:"" default:"LOCAL_QUORUM"` +} + +type clWrapper struct { + primitive.ConsistencyLevel +} + +func (c *clWrapper) UnmarshalText(text []byte) error { + switch strings.ToLower(string(text)) { + case "any": + c.ConsistencyLevel = primitive.ConsistencyLevelAny + case "one": + c.ConsistencyLevel = primitive.ConsistencyLevelOne + case "two": + c.ConsistencyLevel = primitive.ConsistencyLevelTwo + case "three": + c.ConsistencyLevel = primitive.ConsistencyLevelThree + case "quorum": + c.ConsistencyLevel = primitive.ConsistencyLevelQuorum + case "all": + c.ConsistencyLevel = primitive.ConsistencyLevelAll + case "local_quorum": + c.ConsistencyLevel = primitive.ConsistencyLevelLocalQuorum + case "each_quorum": + c.ConsistencyLevel = primitive.ConsistencyLevelEachQuorum + case "serial": + c.ConsistencyLevel = primitive.ConsistencyLevelSerial + case "local_serial": + c.ConsistencyLevel = primitive.ConsistencyLevelLocalSerial + case "local_one": + c.ConsistencyLevel = primitive.ConsistencyLevelLocalOne + default: + return fmt.Errorf("invalid consistency level: %s", string(text)) + } + + return nil } // Run starts the proxy command. 'args' shouldn't include the executable (i.e. os.Args[1:]). It returns the exit code @@ -171,21 +208,23 @@ func Run(ctx context.Context, args []string) int { } p := NewProxy(ctx, Config{ - Version: version, - MaxVersion: maxVersion, - Resolver: resolver, - ReconnectPolicy: proxycore.NewReconnectPolicy(), - NumConns: cfg.NumConns, - Auth: auth, - Logger: logger, - HeartBeatInterval: cfg.HeartbeatInterval, - ConnectTimeout: cfg.ConnectTimeout, - IdleTimeout: cfg.IdleTimeout, - RPCAddr: cfg.RpcAddress, - DC: cfg.DataCenter, - Tokens: cfg.Tokens, - Peers: cfg.Peers, - IdempotentGraph: cfg.IdempotentGraph, + Version: version, + MaxVersion: maxVersion, + Resolver: resolver, + ReconnectPolicy: proxycore.NewReconnectPolicy(), + NumConns: cfg.NumConns, + Auth: auth, + Logger: logger, + HeartBeatInterval: cfg.HeartbeatInterval, + ConnectTimeout: cfg.ConnectTimeout, + IdleTimeout: cfg.IdleTimeout, + RPCAddr: cfg.RpcAddress, + DC: cfg.DataCenter, + Tokens: cfg.Tokens, + Peers: cfg.Peers, + IdempotentGraph: cfg.IdempotentGraph, + UnsupportedWriteConsistencies: cfg.UnsupportedWriteConsistencies, + UnsupportedWriteConsistencyOverride: cfg.UnsupportedWriteConsistencyOverride, }) cfg.Bind = maybeAddPort(cfg.Bind, "9042") diff --git a/proxy/run_test.go b/proxy/run_test.go index 2056b6e..2e4f698 100644 --- a/proxy/run_test.go +++ b/proxy/run_test.go @@ -16,6 +16,7 @@ package proxy import ( "context" + "crypto/md5" "crypto/tls" "crypto/x509" "encoding/json" @@ -27,11 +28,13 @@ import ( "path" "runtime" "strconv" + "strings" "sync" "testing" "time" "github.com/datastax/cql-proxy/proxycore" + "github.com/datastax/go-cassandra-native-protocol/frame" "github.com/datastax/go-cassandra-native-protocol/message" "github.com/datastax/go-cassandra-native-protocol/primitive" "github.com/stretchr/testify/assert" @@ -479,6 +482,194 @@ func TestRun_ProxyTLS(t *testing.T) { require.Equal(t, rs.RowCount(), 1) } +func TestRun_UnsupportedWriteConsistency(t *testing.T) { + unsupportedConsistencies := []primitive.ConsistencyLevel{ + primitive.ConsistencyLevelAny, + primitive.ConsistencyLevelOne, + primitive.ConsistencyLevelLocalOne, + } + + selectConsistency := unsupportedConsistencies[0] + + checkMutationConsistency := func(consistency primitive.ConsistencyLevel) { + for _, unsupported := range unsupportedConsistencies { + assert.NotEqual(t, unsupported, consistency, "received unsupported consistency") + } + assert.Equal(t, primitive.ConsistencyLevelLocalQuorum, consistency) + } + + checkConsistency := func(query string, consistency primitive.ConsistencyLevel) { + if strings.Contains(query, "INSERT") { + checkMutationConsistency(consistency) + } else if strings.Contains(query, "SELECT") { + assert.Equal(t, selectConsistency, consistency) + } else { + assert.Fail(t, "received invalid query") + } + } + + ctx, cancel := context.WithCancel(context.Background()) + + clusterPort, clusterAddr, proxyBindAddr, httpBindAddr := generateTestAddrs(testAddr) + + cluster := proxycore.NewMockCluster(net.ParseIP(testStartAddr), clusterPort) + + var prepared sync.Map + + cluster.Handlers = proxycore.NewMockRequestHandlers(proxycore.MockRequestHandlers{ + primitive.OpCodeQuery: func(cl *proxycore.MockClient, frm *frame.Frame) message.Message { + if msg := cl.InterceptQuery(frm.Header, frm.Body.Message.(*message.Query)); msg != nil { + return msg + } else { + query := frm.Body.Message.(*message.Query) + checkConsistency(query.Query, query.Options.Consistency) + return &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: 0, + }, + Data: message.RowSet{}, + } + } + }, + primitive.OpCodePrepare: func(client *proxycore.MockClient, frm *frame.Frame) message.Message { + prepare := frm.Body.Message.(*message.Prepare) + preparedId := md5.Sum([]byte(prepare.Query)) + prepared.Store(preparedId, prepare.Query) + return &message.PreparedResult{ + PreparedQueryId: preparedId[:], + } + }, + primitive.OpCodeExecute: func(cl *proxycore.MockClient, frm *frame.Frame) message.Message { + execute := frm.Body.Message.(*message.Execute) + preparedId := preparedIdKey(execute.QueryId) + query, ok := prepared.Load(preparedId) + assert.True(t, ok, "unable to find prepared ID") + checkConsistency(query.(string), execute.Options.Consistency) + return &message.RowsResult{ + Metadata: &message.RowsMetadata{ + ColumnCount: 0, + }, + Data: message.RowSet{}, + } + }, + primitive.OpCodeBatch: func(cl *proxycore.MockClient, frm *frame.Frame) message.Message { + batch := frm.Body.Message.(*message.Batch) + checkMutationConsistency(batch.Consistency) + return &message.VoidResult{} + }, + }) + + defer cluster.Shutdown() + err := cluster.Add(ctx, 1) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + rc := Run(ctx, []string{ + "--bind", proxyBindAddr, + "--contact-points", clusterAddr, + "--port", strconv.Itoa(clusterPort), + "--health-check", + "--http-bind", httpBindAddr, + "--readiness-timeout", "200ms", // Use short timeout for the test + "--unsupported-write-consistencies", "any,one,local_one", + }) + assert.Equal(t, 0, rc) + wg.Done() + }() + + defer func() { + cancel() + wg.Wait() + }() + + require.True(t, waitUntil(10*time.Second, func() bool { + return checkLiveness(httpBindAddr) + })) + + cl, err := proxycore.ConnectClient(ctx, proxycore.NewEndpoint(proxyBindAddr), proxycore.ClientConnConfig{}) + defer cl.Close() + require.NoError(t, err) + + version, err := cl.Handshake(ctx, primitive.ProtocolVersion4, nil) + require.NoError(t, err) + assert.Equal(t, primitive.ProtocolVersion4, version) + + insertPrepareResp, err := cl.SendAndReceive( + ctx, + frame.NewFrame(version, 0, &message.Prepare{Query: "INSERT INTO test (k, v) VALUES ('k1', 'v1')"}), + ) + require.NoError(t, err) + require.Equal(t, primitive.OpCodeResult, insertPrepareResp.Header.OpCode) + insertPrepareResult, ok := insertPrepareResp.Body.Message.(*message.PreparedResult) + assert.True(t, ok, "expected prepared result") + + selectPrepareResp, err := cl.SendAndReceive( + ctx, + frame.NewFrame(version, 0, &message.Prepare{Query: "SELECT * FROM test.test"}), + ) + require.NoError(t, err) + require.Equal(t, primitive.OpCodeResult, selectPrepareResp.Header.OpCode) + selectPrepareResult, ok := selectPrepareResp.Body.Message.(*message.PreparedResult) + assert.True(t, ok, "expected prepared result") + + t.Run("simple queries", func(t *testing.T) { + for _, unsupported := range unsupportedConsistencies { + _, err = cl.Query(ctx, primitive.ProtocolVersion4, &message.Query{ + Query: "INSERT INTO test (k, v) VALUES ('k1', 'v1')", + Options: &message.QueryOptions{ + Consistency: unsupported, + }, + }) + assert.NoError(t, err) + } + + _, err = cl.Query(ctx, primitive.ProtocolVersion4, &message.Query{ + Query: "SELECT * FROM test", + Options: &message.QueryOptions{ + Consistency: selectConsistency, + }, + }) + assert.NoError(t, err) + }) + + t.Run("prepared queries", func(t *testing.T) { + for _, unsupported := range unsupportedConsistencies { + _, err = cl.Query(ctx, primitive.ProtocolVersion4, &message.Execute{ + QueryId: insertPrepareResult.PreparedQueryId, + Options: &message.QueryOptions{ + Consistency: unsupported, + }, + }) + assert.NoError(t, err) + } + + _, err = cl.Query(ctx, primitive.ProtocolVersion4, &message.Execute{ + QueryId: selectPrepareResult.PreparedQueryId, + Options: &message.QueryOptions{ + Consistency: selectConsistency, + }, + }) + assert.NoError(t, err) + }) + + t.Run("batch", func(t *testing.T) { + for _, unsupported := range unsupportedConsistencies { + _, err = cl.Query(ctx, primitive.ProtocolVersion4, &message.Batch{ + Children: []*message.BatchChild{ + {Query: "INSERT INTO test (k, v) VALUES ('k1', 'v1')"}, + {Id: insertPrepareResp.Body.Message.(*message.PreparedResult).PreparedQueryId}, + }, + Consistency: unsupported, + }) + assert.NoError(t, err) + } + }) + +} + func writeTempYaml(o interface{}) (name string, err error) { bytes, err := yaml.Marshal(o) if err != nil { diff --git a/proxycore/endpoint.go b/proxycore/endpoint.go index aff0e1d..d606ffa 100644 --- a/proxycore/endpoint.go +++ b/proxycore/endpoint.go @@ -141,7 +141,7 @@ func LookupEndpoint(endpoint Endpoint) (string, error) { } else { host, port, err := net.SplitHostPort(endpoint.Addr()) if err != nil { - return "'", err + return "", err } addrs, err := net.LookupHost(host) if err != nil { diff --git a/proxycore/endpoint_test.go b/proxycore/endpoint_test.go index 5f04523..2ea5325 100644 --- a/proxycore/endpoint_test.go +++ b/proxycore/endpoint_test.go @@ -42,7 +42,7 @@ func TestLookupEndpoint_Invalid(t *testing.T) { err string }{ {"localhost", "missing port in address"}, - {"dne:1234", ""}, // Errors for DNS can vary per system + {"test:1234", ""}, // Errors for DNS can vary per system } for _, tt := range tests {