Skip to content

Configurable support for rewriting consistency levels on writes #142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Apr 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ type UseStatement struct {

func (u UseStatement) isStatement() {}

var defaultSelectStatement = &SelectStatement{}

// IsQueryHandled parses the query string and determines if the query is handled by the proxy
func IsQueryHandled(keyspace Identifier, query string) (handled bool, stmt Statement, err error) {
var l lexer
Expand All @@ -164,11 +166,15 @@ func IsQueryHandled(keyspace Identifier, query string) (handled bool, stmt State
t := l.next()
switch t {
case tkSelect:
return isHandledSelectStmt(&l, keyspace)
handled, stmt, err = isHandledSelectStmt(&l, keyspace)
if !handled {
stmt = defaultSelectStatement
}
return
case tkUse:
return isHandledUseStmt(&l)
}
return false, nil, nil
return
}

// IsQueryIdempotent parses the query string and determines if the query is idempotent
Expand Down
24 changes: 13 additions & 11 deletions parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,12 @@ func TestParser(t *testing.T) {
Keyspace: "system",
}, false},
// Reads from tables named similarly to system tables (not handled)
{"", "SELECT count(*) FROM local", false, true, nil, false},
{"", "SELECT count(*) FROM peers", false, true, nil, false},
{"", "SELECT count(*) FROM peers_v2", false, true, nil, false},
{"", "SELECT count(*) FROM local", false, true, defaultSelectStatement, false},
{"", "SELECT count(*) FROM peers", false, true, defaultSelectStatement, false},
{"", "SELECT count(*) FROM peers_v2", false, true, defaultSelectStatement, false},

// Semicolon at the end
{"", "SELECT count(*) FROM table;", false, true, nil, false},
{"", "SELECT count(*) FROM table;", false, true, defaultSelectStatement, false},

// Mutations to system tables (not handled)
{"", "INSERT INTO system.local (key, rpc_address) VALUES ('local1', '127.0.0.1')", false, true, nil, false},
Expand Down Expand Up @@ -169,15 +169,17 @@ func TestParser(t *testing.T) {
}

for _, tt := range tests {
handled, stmt, err := IsQueryHandled(IdentifierFromString(tt.keyspace), tt.query)
assert.True(t, (err != nil) == tt.hasError, tt.query)
t.Run(tt.query, func(t *testing.T) {
handled, stmt, err := IsQueryHandled(IdentifierFromString(tt.keyspace), tt.query)
assert.True(t, (err != nil) == tt.hasError, tt.query)

idempotent, err := IsQueryIdempotent(tt.query)
assert.Nil(t, err, tt.query)
idempotent, err := IsQueryIdempotent(tt.query)
assert.Nil(t, err, tt.query)

assert.Equal(t, tt.handled, handled, "invalid handled", tt.query)
assert.Equal(t, tt.idempotent, idempotent, "invalid idempotency", tt.query)
assert.Equal(t, tt.stmt, stmt, "invalid parsed statement", tt.query)
assert.Equal(t, tt.handled, handled, "invalid handled", tt.query)
assert.Equal(t, tt.idempotent, idempotent, "invalid idempotency", tt.query)
assert.Equal(t, tt.stmt, stmt, "invalid parsed statement", tt.query)
})
}
}

Expand Down
1 change: 1 addition & 0 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (

func main() {
ctx, cancel := signalContext(context.Background(), os.Interrupt, os.Kill)

defer cancel()

os.Exit(proxy.Run(ctx, os.Args[1:]))
Expand Down
65 changes: 50 additions & 15 deletions proxy/codecs.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,30 @@ func (c *partialQueryCodec) EncodedLength(_ message.Message, _ primitive.Protoco
panic("not implemented")
}

func (c *partialQueryCodec) Decode(source io.Reader, _ primitive.ProtocolVersion) (message.Message, error) {
if query, err := primitive.ReadLongString(source); err != nil {
func (c *partialQueryCodec) Decode(source io.Reader, _ primitive.ProtocolVersion) (msg message.Message, err error) {
var (
query string
consistency uint16
)

if query, err = primitive.ReadLongString(source); err != nil {
return nil, err
} else {
return &partialQuery{query}, nil
}

if consistency, err = primitive.ReadShort(source); err != nil {
return nil, fmt.Errorf("cannot read QUERY consistency level: %w", err)
}

return &partialQuery{query, primitive.ConsistencyLevel(consistency)}, nil
}

func (c *partialQueryCodec) GetOpCode() primitive.OpCode {
return primitive.OpCodeQuery
}

type partialQuery struct {
query string
query string
consistency primitive.ConsistencyLevel
}

func (p *partialQuery) IsResponse() bool {
Expand All @@ -63,11 +73,16 @@ func (p *partialQuery) GetOpCode() primitive.OpCode {
}

func (p *partialQuery) DeepCopyMessage() message.Message {
return &partialQuery{p.query}
return &partialQuery{p.query, p.consistency}
}

func (m *partialQuery) String() string {
return "QUERY " + m.query
}

type partialExecute struct {
queryId []byte
queryId []byte
consistency primitive.ConsistencyLevel
}

func (m *partialExecute) IsResponse() bool {
Expand All @@ -81,7 +96,7 @@ func (m *partialExecute) GetOpCode() primitive.OpCode {
func (m *partialExecute) DeepCopyMessage() message.Message {
queryId := make([]byte, len(m.queryId))
copy(queryId, m.queryId)
return &partialExecute{queryId}
return &partialExecute{queryId, m.consistency}
}

func (m *partialExecute) String() string {
Expand All @@ -99,21 +114,31 @@ func (c *partialExecuteCodec) EncodedLength(_ message.Message, _ primitive.Proto
}

func (c *partialExecuteCodec) Decode(source io.Reader, _ primitive.ProtocolVersion) (msg message.Message, err error) {
execute := &partialExecute{}
if execute.queryId, err = primitive.ReadShortBytes(source); err != nil {
var (
queryID []byte
consistency uint16
)

if queryID, err = primitive.ReadShortBytes(source); err != nil {
return nil, fmt.Errorf("cannot read EXECUTE query id: %w", err)
} else if len(execute.queryId) == 0 {
} else if len(queryID) == 0 {
return nil, errors.New("EXECUTE missing query id")
}
return execute, nil

if consistency, err = primitive.ReadShort(source); err != nil {
return nil, fmt.Errorf("cannot read EXECUTE consistency level: %w", err)
}

return &partialExecute{queryID, primitive.ConsistencyLevel(consistency)}, nil
}

func (c *partialExecuteCodec) GetOpCode() primitive.OpCode {
return primitive.OpCodeExecute
}

type partialBatch struct {
queryOrIds []interface{}
queryOrIds []interface{}
consistency primitive.ConsistencyLevel
}

func (p partialBatch) IsResponse() bool {
Expand All @@ -127,7 +152,11 @@ func (p partialBatch) GetOpCode() primitive.OpCode {
func (p partialBatch) DeepCopyMessage() message.Message {
queryOrIds := make([]interface{}, len(p.queryOrIds))
copy(queryOrIds, p.queryOrIds)
return &partialBatch{queryOrIds}
return &partialBatch{queryOrIds, p.consistency}
}

func (p partialBatch) String() string {
return fmt.Sprintf("BATCH (%d statements)", len(p.queryOrIds))
}

type partialBatchCodec struct{}
Expand All @@ -142,6 +171,7 @@ func (p partialBatchCodec) EncodedLength(msg message.Message, version primitive.

func (p partialBatchCodec) Decode(source io.Reader, version primitive.ProtocolVersion) (msg message.Message, err error) {
var queryOrIds []interface{}
var consistency uint16
var typ uint8
if typ, err = primitive.ReadByte(source); err != nil {
return nil, fmt.Errorf("cannot read BATCH type: %w", err)
Expand Down Expand Up @@ -177,7 +207,12 @@ func (p partialBatchCodec) Decode(source io.Reader, version primitive.ProtocolVe
}
queryOrIds[i] = queryOrId
}
return &partialBatch{queryOrIds}, nil

if consistency, err = primitive.ReadShort(source); err != nil {
return nil, fmt.Errorf("cannot read BATCH consistency level: %w", err)
}

return &partialBatch{queryOrIds, primitive.ConsistencyLevel(consistency)}, nil
}

func (p partialBatchCodec) GetOpCode() primitive.OpCode {
Expand Down
156 changes: 156 additions & 0 deletions proxy/frame_patch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package proxy

import (
"encoding/binary"
"fmt"

"github.com/datastax/go-cassandra-native-protocol/primitive"
)

// patchQueryConsistency overwrites, in place, the consistency level of a QUERY
// message body.
//
// Layout assumed (CQL native protocol v4 spec):
//
//	<query: long string><consistency: short><flags: byte>...
//
// body is the frame body starting at the long-string length prefix;
// newConsistency is written big-endian over the two bytes that follow the
// query text. An error is returned, and body is left untouched, when body is
// too short to contain the length prefix, the query text and the consistency
// field.
func patchQueryConsistency(body []byte, newConsistency primitive.ConsistencyLevel) error {
	const (
		lengthSize      = 4 // <query> length prefix
		consistencySize = 2 // <consistency: short>
	)

	if len(body) < lengthSize+consistencySize {
		return fmt.Errorf("body too short for QUERY")
	}

	queryLen := binary.BigEndian.Uint32(body[:lengthSize])

	// Compare in uint64 against the bytes actually available: converting the
	// wire length to int first could wrap negative on 32-bit platforms and
	// slip past the bounds check, turning a malformed frame into a panic.
	if uint64(queryLen) > uint64(len(body)-lengthSize-consistencySize) {
		return fmt.Errorf("not enough bytes to patch QUERY consistency")
	}

	// Overwrite the QUERY consistency field that directly follows the query text.
	offset := lengthSize + int(queryLen)
	binary.BigEndian.PutUint16(body[offset:offset+consistencySize], uint16(newConsistency))

	return nil
}

// patchExecuteConsistency overwrites, in place, the consistency level of an
// EXECUTE message body.
//
// Layout assumed (CQL native protocol v4 spec):
//
//	<id: short bytes><consistency: short><flags: byte>...
//
// The two bytes immediately after the prepared statement id are replaced with
// newConsistency in big-endian order. An error is returned, and body is left
// untouched, when body cannot hold the id length prefix, the id itself and the
// consistency field.
//
// NOTE(review): protocol v5 places a result_metadata_id between the id and the
// query parameters; this function has no version argument, so confirm callers
// only use it for v4-and-earlier frames.
func patchExecuteConsistency(body []byte, newConsistency primitive.ConsistencyLevel) error {
	const idLengthSize = 2 // <id> length prefix

	if len(body) < idLengthSize {
		return fmt.Errorf("body too short for EXECUTE")
	}

	// The consistency field starts right after the variable-length id.
	start := idLengthSize + int(binary.BigEndian.Uint16(body[:idLengthSize]))
	end := start + 2
	if end > len(body) {
		return fmt.Errorf("not enough bytes to patch EXECUTE consistency")
	}

	binary.BigEndian.PutUint16(body[start:end], uint16(newConsistency))
	return nil
}

// patchBatchConsistency overwrites, in place, the consistency level of a BATCH
// message body. The batch carries a single consistency field that follows all
// of its child queries, so every child has to be walked to find it.
//
// Layout assumed (CQL native protocol v4 spec):
//
//	<type: byte>
//	<n: short>                     number of child queries
//	<query_1>...<query_n>          each: <kind: byte><string or id><values>
//	<consistency: short>
//	<flags: byte>
//	[<serial_consistency: short>]
//	[<timestamp: long>]
//
// An error is returned when the body is truncated, a child has an unknown
// kind, or a child's values cannot be skipped. Malformed input never panics.
// body is only modified on success.
func patchBatchConsistency(body []byte, newConsistency primitive.ConsistencyLevel) error {
	// Smallest well-formed body is an empty batch with no optional fields:
	// <type: 1> + <n: 2> + <consistency: 2> + <flags: 1> = 6 bytes.
	if len(body) < 6 {
		return fmt.Errorf("invalid batch body: too short")
	}

	offset := 3 // past <type: byte> and <n: short>
	numQueries := binary.BigEndian.Uint16(body[1:3])

	for i := uint16(0); i < numQueries; i++ {
		if len(body) <= offset {
			return fmt.Errorf("query #%d exceeds body length", i)
		}

		queryType := body[offset]
		offset++ // past the child's <kind: byte>

		switch primitive.BatchChildType(queryType) {
		case primitive.BatchChildTypeQueryString:
			// <query: long string> = 4-byte length followed by the text.
			if len(body)-offset < 4 {
				return fmt.Errorf("query #%d exceeds body length", i)
			}
			queryLength := binary.BigEndian.Uint32(body[offset : offset+4])
			offset += 4
			// Compare in uint64 so an oversized wire length cannot wrap
			// negative when converted to int on 32-bit platforms.
			if uint64(queryLength) > uint64(len(body)-offset) {
				return fmt.Errorf("query #%d exceeds body length", i)
			}
			offset += int(queryLength)
		case primitive.BatchChildTypePreparedId:
			// <id: short bytes> = 2-byte length followed by the id.
			if len(body)-offset < 2 {
				return fmt.Errorf("query #%d exceeds body length", i)
			}
			stmtIDLength := int(binary.BigEndian.Uint16(body[offset : offset+2]))
			offset += 2
			if stmtIDLength > len(body)-offset {
				return fmt.Errorf("query #%d exceeds body length", i)
			}
			offset += stmtIDLength
		default:
			return fmt.Errorf("unsupported BATCH child type for query #%d: %v", i, queryType)
		}

		// Each child is followed by its own list of bound values.
		if err := skipPositionalValuesByteSlice(body, &offset); err != nil {
			return fmt.Errorf("cannot skip positional values for query #%d: %w", i, err)
		}
	}

	// offset now points at the batch-level <consistency: short>.
	if len(body) < offset+2 {
		return fmt.Errorf("not enough bytes to patch consistency")
	}
	binary.BigEndian.PutUint16(body[offset:offset+2], uint16(newConsistency))

	return nil
}

// skipPositionalValuesByteSlice advances *offset past a list of positional
// values encoded as <n: short><value_1>...<value_n>.
//
// body is the buffer being walked and *offset the position of the <n> count.
// On success *offset points at the first byte after the last value. An error
// is returned when the count or any value does not fit in body; *offset may
// have been partially advanced in that case and should not be reused.
func skipPositionalValuesByteSlice(body []byte, offset *int) error {
	// Exactly two remaining bytes are enough to read the count, so only fewer
	// than two is an error.
	if len(body)-*offset < 2 {
		return fmt.Errorf("insufficient bytes to read positional values length")
	}
	count := binary.BigEndian.Uint16(body[*offset : *offset+2])
	*offset += 2 // past the <n: short> count

	for i := uint16(0); i < count; i++ {
		if err := skipValueByteSlice(body, offset); err != nil {
			return fmt.Errorf("cannot skip positional value %d: %w", i, err)
		}
	}
	return nil
}

// skipValueByteSlice advances *offset past a single value encoded as
// <length: int><content>, where the 4-byte big-endian length is signed.
//
// A length of -1 or -2 has no content, so only the length prefix is skipped
// (the code treats these as null and unset respectively). Any other negative
// length is rejected. An error is also returned when the length prefix or the
// content does not fit in body. On every path that gets as far as reading the
// prefix, *offset has been moved past it.
func skipValueByteSlice(body []byte, offset *int) error {
	// Exactly four remaining bytes are enough for the prefix (e.g. a trailing
	// null value), so only fewer than four is an error.
	if len(body)-*offset < 4 {
		return fmt.Errorf("insufficient bytes to read value length")
	}
	length := int32(binary.BigEndian.Uint32(body[*offset : *offset+4]))
	*offset += 4 // past the <length: int> prefix

	switch {
	case length == -1 || length == -2:
		// Null or unset: there is no content to skip.
		return nil
	case length < 0:
		return fmt.Errorf("invalid negative length: %d", length)
	case int(length) > len(body)-*offset:
		// Compared against the remaining byte count rather than
		// *offset+length so the sum cannot overflow int on 32-bit platforms.
		return fmt.Errorf("insufficient bytes to skip value content")
	}

	*offset += int(length) // past the value content (no-op for empty values)
	return nil
}
Loading