From b38d5b769188cb713d0bf3c3e554342a7adc3d4d Mon Sep 17 00:00:00 2001 From: Christoph Pirkl Date: Thu, 27 Jun 2024 11:27:05 +0200 Subject: [PATCH 01/12] Add integration tests --- Makefile | 4 +- itest/integration_test.go | 129 ++++++++++++++---- pkg/connection/prepared_stmt_converter.go | 11 +- .../prepared_stmt_converter_test.go | 48 +++++-- pkg/connection/result_set_test.go | 5 - 5 files changed, 155 insertions(+), 42 deletions(-) diff --git a/Makefile b/Makefile index 881cad8..71628e1 100644 --- a/Makefile +++ b/Makefile @@ -10,10 +10,10 @@ lint: golangci-lint run --print-issued-lines=false ./... test: - go test -v -coverprofile=coverage.out ./... + go test -count 1 -v -p 1 -coverprofile=coverage.out ./... testshort: - go test -v -short -coverprofile=coverage.out ./... + go test -count 1 -v -short -coverprofile=coverage.out ./... coverage: test go tool cover -html=coverage.out -o coverage.html diff --git a/itest/integration_test.go b/itest/integration_test.go index 6a0e6a8..1eb56d5 100644 --- a/itest/integration_test.go +++ b/itest/integration_test.go @@ -6,6 +6,7 @@ import ( "encoding/csv" "fmt" "log" + "math" "os" "os/user" "regexp" @@ -17,6 +18,7 @@ import ( "github.com/exasol/exasol-driver-go" "github.com/exasol/exasol-driver-go/pkg/dsn" "github.com/exasol/exasol-driver-go/pkg/integrationTesting" + "github.com/exasol/exasol-driver-go/pkg/logger" "github.com/stretchr/testify/assert" "go.uber.org/goleak" @@ -193,6 +195,20 @@ func (suite *IntegrationTestSuite) TestFetch() { suite.Equal(10000, len(result)) } +func (suite *IntegrationTestSuite) TestFetchLargeInteger() { + logger.EnableTraceLogger() + database := suite.openConnection(suite.createDefaultConfig()) + number := 100000000 + rows, err := database.Query(fmt.Sprintf("SELECT %d", number)) + suite.NoError(err) + suite.True(rows.Next()) + var result int64 + err = rows.Scan(&result) + suite.NoError(err) + defer rows.Close() + suite.Equal(number, result) +} + func (suite *IntegrationTestSuite) TestExecuteWithError() { database := suite.openConnection(suite.createDefaultConfig()) defer database.Close() @@ -225,7 +241,9 @@ func (suite *IntegrationTestSuite) TestPreparedStatement() { } var dereferenceString = func(v any) any { return *(v.(*string)) } +var dereferenceFloat32 = func(v any) any { return *(v.(*float32)) } var dereferenceFloat64 = func(v any) any { return *(v.(*float64)) } +var dereferenceInt32 = func(v any) any { return *(v.(*int32)) } var dereferenceInt64 = func(v any) any { return *(v.(*int64)) } var dereferenceInt = func(v any) any { return *(v.(*int)) } var dereferenceBool = func(v any) any { return *(v.(*bool)) } @@ -239,14 +257,25 @@ func (suite *IntegrationTestSuite) TestQueryDataTypesCast() { expectedValue any dereference func(any) any }{ + // DECIMAL {"decimal to int64", "1", "DECIMAL(18,0)", new(int64), int64(1), dereferenceInt64}, {"decimal to int", "1", "DECIMAL(18,0)", new(int), 1, dereferenceInt}, {"decimal to float", "1", "DECIMAL(18,0)", new(float64), 1.0, dereferenceFloat64}, {"decimal to string", "1", "DECIMAL(18,0)", new(string), "1", dereferenceString}, + {"max int64", fmt.Sprintf("%d", math.MaxInt64), "DECIMAL(36,0)", new(int64), int64(math.MaxInt64), dereferenceInt64}, + {"min int64", fmt.Sprintf("%d", math.MinInt64), "DECIMAL(36,0)", new(int64), int64(math.MinInt64), dereferenceInt64}, {"decimal to float64", "2.2", "DECIMAL(18,2)", new(float64), 2.2, dereferenceFloat64}, {"decimal to string", "2.2", "DECIMAL(18,2)", new(string), "2.2", dereferenceString}, + {"double to float64", "3.3", "DOUBLE PRECISION", new(float64), 3.3, dereferenceFloat64}, + {"double to float64", "-3.3", "DOUBLE PRECISION", new(float64), -3.3, dereferenceFloat64}, + {"double to float64", "1.7976e+308", "DOUBLE PRECISION", new(float64), 1.7975999999999999e+308, dereferenceFloat64}, + {"double to float64", "-1.7976e+308", "DOUBLE PRECISION", new(float64), -1.7975999999999999e+308, dereferenceFloat64}, + {"double to float64", fmt.Sprintf("%g", math.SmallestNonzeroFloat64), "DOUBLE PRECISION", new(float64), math.SmallestNonzeroFloat64, dereferenceFloat64}, + {"double to float32", fmt.Sprintf("%g", math.MaxFloat32), "DOUBLE PRECISION", new(float32), float32(3.4028235e+38), dereferenceFloat32}, + {"double to float32", fmt.Sprintf("%g", math.SmallestNonzeroFloat32), "DOUBLE PRECISION", new(float32), float32(1e-45), dereferenceFloat32}, {"double to string", "3.3", "DOUBLE PRECISION", new(string), "3.3", dereferenceString}, + {"varchar to string", "'text'", "VARCHAR(10)", new(string), "text", dereferenceString}, {"char to string", "'text'", "CHAR(10)", new(string), "text ", dereferenceString}, {"date to string", "'2024-06-18'", "DATE", new(string), "2024-06-18", dereferenceString}, @@ -274,34 +303,85 @@ func (suite *IntegrationTestSuite) TestQueryDataTypesCast() { } func (suite *IntegrationTestSuite) TestPreparedStatementArgsConverted() { - for i, testCase := range []struct { + type TestCase struct { sqlValue any sqlType string scanDest any expectedValue any dereference func(any) any - }{ - {1, "DECIMAL(18,0)", new(int64), int64(1), dereferenceInt64}, - {1.1, "DECIMAL(18,0)", new(int64), int64(1), dereferenceInt64}, - {1, "DECIMAL(18,0)", new(int), 1, dereferenceInt}, - {1, "DECIMAL(18,0)", new(float64), 1.0, dereferenceFloat64}, - {2.2, "DECIMAL(18,2)", new(float64), 2.2, dereferenceFloat64}, - {2, "DECIMAL(18,2)", new(float64), 2.0, dereferenceFloat64}, - {3.3, "DOUBLE PRECISION", new(float64), 3.3, dereferenceFloat64}, - {3, "DOUBLE PRECISION", new(float64), 3.0, dereferenceFloat64}, - {"text", "VARCHAR(10)", new(string), "text", dereferenceString}, - {"text", "CHAR(10)", new(string), "text ", dereferenceString}, - {"2024-06-18", "DATE", new(string), "2024-06-18", dereferenceString}, - {time.Date(2024, time.June, 18, 0, 0, 0, 0, time.UTC), "DATE", new(string), "2024-06-18", dereferenceString}, - {"2024-06-18 17:22:13.123456", "TIMESTAMP", new(string), "2024-06-18 17:22:13.123000", dereferenceString}, - {time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP", new(string), "2024-06-18 17:22:13.123000", dereferenceString}, - {"2024-06-18 17:22:13.123456", "TIMESTAMP WITH LOCAL TIME ZONE", new(string), "2024-06-18 17:22:13.123000", dereferenceString}, - {time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP WITH LOCAL TIME ZONE", new(string), "2024-06-18 17:22:13.123000", dereferenceString}, - {"point(1 2)", "GEOMETRY", new(string), "POINT (1 2)", dereferenceString}, - {"5-3", "INTERVAL YEAR TO MONTH", new(string), "+05-03", dereferenceString}, - {"2 12:50:10.123", "INTERVAL DAY TO SECOND", new(string), "+02 12:50:10.123", dereferenceString}, - {"550e8400-e29b-11d4-a716-446655440000", "HASHTYPE", new(string), "550e8400e29b11d4a716446655440000", dereferenceString}, - {true, "BOOLEAN", new(bool), true, dereferenceBool}, + //delta float64 + } + int64TestCase := func(sqlValue any, sqlType string, expectedValue int64) TestCase { + return TestCase{sqlValue: sqlValue, sqlType: sqlType, scanDest: new(int64), expectedValue: expectedValue, dereference: dereferenceInt64} + } + int32TestCase := func(sqlValue any, sqlType string, expectedValue int32) TestCase { + return TestCase{sqlValue: sqlValue, sqlType: sqlType, scanDest: new(int32), expectedValue: expectedValue, dereference: dereferenceInt32} + } + float64TestCase := func(sqlValue any, sqlType string, expectedValue float64) TestCase { + return TestCase{sqlValue: sqlValue, sqlType: sqlType, scanDest: new(float64), expectedValue: expectedValue, dereference: dereferenceFloat64} + } + float32TestCase := func(sqlValue any, sqlType string, expectedValue float32) TestCase { + return TestCase{sqlValue: sqlValue, sqlType: sqlType, scanDest: new(float32), expectedValue: expectedValue, dereference: dereferenceFloat32} + } + stringTestCase := func(sqlValue any, sqlType string, expectedValue string) TestCase { + return TestCase{sqlValue: sqlValue, sqlType: sqlType, scanDest: new(string), expectedValue: expectedValue, dereference: dereferenceString} + } + boolTestCase := func(sqlValue any, sqlType string, expectedValue bool) TestCase { + return TestCase{sqlValue: sqlValue, sqlType: sqlType, scanDest: new(bool), expectedValue: expectedValue, dereference: dereferenceBool} + } + + for i, testCase := range []TestCase{ + // DECIMAL + int64TestCase(1, "DECIMAL(18,0)", 1), + int64TestCase(-1, "DECIMAL(18,0)", -1), + int64TestCase(1.1, "DECIMAL(18,0)", 1), + int64TestCase(-1.1, "DECIMAL(18,0)", -1), + int64TestCase(math.MaxInt64, "DECIMAL(36,0)", math.MaxInt64), + int64TestCase(math.MinInt64, "DECIMAL(36,0)", math.MinInt64), + + int32TestCase(1, "DECIMAL(18,0)", 1), + int32TestCase(-1, "DECIMAL(18,0)", -1), + int32TestCase(1.1, "DECIMAL(18,0)", 1), + int32TestCase(-1.1, "DECIMAL(18,0)", -1), + int32TestCase(math.MaxInt32, "DECIMAL(36,0)", math.MaxInt32), + int32TestCase(math.MinInt32, "DECIMAL(36,0)", math.MinInt32), + + float64TestCase(1, "DECIMAL(18,0)", 1), + float64TestCase(-1, "DECIMAL(18,0)", -1), + float64TestCase(1.123, "DECIMAL(18,3)", 1.123), + float64TestCase(-1.123, "DECIMAL(18,3)", -1.123), + + float32TestCase(1, "DECIMAL(18,0)", 1), + float32TestCase(-1, "DECIMAL(18,0)", -1), + float32TestCase(1.123, "DECIMAL(18,3)", 1.123), + float32TestCase(-1.123, "DECIMAL(18,3)", -1.123), + + // DOUBLE + float64TestCase(3.3, "DOUBLE PRECISION", 3.3), + float64TestCase(-3.3, "DOUBLE PRECISION", -3.3), + float64TestCase(3, "DOUBLE PRECISION", 3.0), + float64TestCase(-3, "DOUBLE PRECISION", -3.0), + + float32TestCase(math.MaxFloat32, "DOUBLE PRECISION", math.MaxFloat32), + float32TestCase(math.SmallestNonzeroFloat32, "DOUBLE PRECISION", math.SmallestNonzeroFloat32), + float64TestCase(1.7976e+308, "DOUBLE PRECISION", 1.7975999999999999e+308), // math.MaxFloat64 causes error "data exception - numeric value out of range" + float64TestCase(math.SmallestNonzeroFloat64, "DOUBLE PRECISION", math.SmallestNonzeroFloat64), + + // VARCHAR + stringTestCase("text", "VARCHAR(10)", "text"), + stringTestCase("text", "CHAR(10)", "text "), + stringTestCase("2024-06-18", "DATE", "2024-06-18"), + stringTestCase(time.Date(2024, time.June, 18, 0, 0, 0, 0, time.UTC), "DATE", "2024-06-18"), + stringTestCase("2024-06-18 17:22:13.123456", "TIMESTAMP", "2024-06-18 17:22:13.123000"), + stringTestCase(time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP", "2024-06-18 17:22:13.123000"), + stringTestCase("2024-06-18 17:22:13.123456", "TIMESTAMP WITH LOCAL TIME ZONE", "2024-06-18 17:22:13.123000"), + stringTestCase(time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP WITH LOCAL TIME ZONE", "2024-06-18 17:22:13.123000"), + stringTestCase("point(1 2)", "GEOMETRY", "POINT (1 2)"), + stringTestCase("5-3", "INTERVAL YEAR TO MONTH", "+05-03"), + stringTestCase("2 12:50:10.123", "INTERVAL DAY TO SECOND", "+02 12:50:10.123"), + stringTestCase("550e8400-e29b-11d4-a716-446655440000", "HASHTYPE", "550e8400e29b11d4a716446655440000"), + boolTestCase(true, "BOOLEAN", true), + boolTestCase(false, "BOOLEAN", false), } { database := suite.openConnection(suite.createDefaultConfig().Autocommit(false)) schemaName := "DATATYPE_TEST" @@ -320,8 +400,9 @@ func (suite *IntegrationTestSuite) TestPreparedStatementArgsConverted() { rows, err := database.Query(fmt.Sprintf("select * from %s", tableName)) onError(err) defer rows.Close() - suite.True(rows.Next(), "should have one row") + suite.True(rows.Next(), "should have at least one row") onError(rows.Scan(testCase.scanDest)) + suite.False(rows.Next(), "should have at most one row") val := testCase.scanDest suite.Equal(testCase.expectedValue, testCase.dereference(val)) }) diff --git a/pkg/connection/prepared_stmt_converter.go b/pkg/connection/prepared_stmt_converter.go index b905f24..d155346 100644 --- a/pkg/connection/prepared_stmt_converter.go +++ b/pkg/connection/prepared_stmt_converter.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "strings" "time" "github.com/exasol/exasol-driver-go/pkg/errors" @@ -63,7 +64,15 @@ type jsonDoubleValueStruct struct { } func (j *jsonDoubleValueStruct) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf("%f", j.value)), nil + r, err := json.Marshal(j.value) + if err != nil { + return nil, err + } + formatted := string(r) + if !strings.Contains(formatted, ".") && !strings.Contains(strings.ToLower(formatted), "e") { + return []byte(formatted + ".0"), nil + } + return r, nil } func jsonTimestampValue(value time.Time) json.Marshaler { diff --git a/pkg/connection/prepared_stmt_converter_test.go b/pkg/connection/prepared_stmt_converter_test.go index 938bd5b..4b50384 100644 --- a/pkg/connection/prepared_stmt_converter_test.go +++ b/pkg/connection/prepared_stmt_converter_test.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "encoding/json" "fmt" + "math" "testing" "time" @@ -23,11 +24,19 @@ func TestConvertArgs(t *testing.T) { }{ {arg: "text", exasolType: "VARCHAR", expectedJson: `"text"`}, {arg: 123, exasolType: "VARCHAR", expectedJson: `123`}, + {arg: -123, exasolType: "VARCHAR", expectedJson: `-123`}, + {arg: math.MaxInt64, exasolType: "VARCHAR", expectedJson: `9223372036854775807`}, + {arg: math.MinInt64, exasolType: "VARCHAR", expectedJson: `-9223372036854775808`}, + {arg: math.MaxFloat64, exasolType: "VARCHAR", expectedJson: `1.7976931348623157e+308`}, + {arg: math.SmallestNonzeroFloat64, exasolType: "VARCHAR", expectedJson: `5e-324`}, {arg: 123.456, exasolType: "VARCHAR", expectedJson: `123.456`}, + {arg: -123.456, exasolType: "VARCHAR", expectedJson: `-123.456`}, {arg: "text", exasolType: "CHAR", expectedJson: `"text"`}, + // BOOLEAN {arg: true, exasolType: "BOOLEAN", expectedJson: `true`}, {arg: false, exasolType: "BOOLEAN", expectedJson: `false`}, + // DECIMAL {arg: 17, exasolType: "DECIMAL", expectedJson: `17`}, {arg: 123.456, exasolType: "DECIMAL", expectedJson: `123.456`}, @@ -38,17 +47,35 @@ func TestConvertArgs(t *testing.T) { {arg: float64(123), exasolType: "DECIMAL", expectedJson: `123`}, {arg: float32(123.456), exasolType: "DECIMAL", expectedJson: `123.456`}, {arg: float64(123.456), exasolType: "DECIMAL", expectedJson: `123.456`}, + {arg: math.MaxInt64, exasolType: "DECIMAL", expectedJson: `9223372036854775807`}, + {arg: math.MinInt64, exasolType: "DECIMAL", expectedJson: `-9223372036854775808`}, + {arg: math.MaxFloat64, exasolType: "DECIMAL", expectedJson: `1.7976931348623157e+308`}, + {arg: math.SmallestNonzeroFloat64, exasolType: "DECIMAL", expectedJson: `5e-324`}, {arg: "invalid", exasolType: "DECIMAL", expectedJson: `"invalid"`}, // No special handling for invalid values + // DOUBLE - {arg: 123.456, exasolType: "DOUBLE", expectedJson: `123.456000`}, - {arg: 123, exasolType: "DOUBLE", expectedJson: `123.000000`}, - {arg: int(123), exasolType: "DOUBLE", expectedJson: `123.000000`}, - {arg: int32(123), exasolType: "DOUBLE", expectedJson: `123.000000`}, - {arg: int64(123), exasolType: "DOUBLE", expectedJson: `123.000000`}, - {arg: float32(123), exasolType: "DOUBLE", expectedJson: `123.000000`}, - {arg: float64(123), exasolType: "DOUBLE", expectedJson: `123.000000`}, - {arg: float32(123.456), exasolType: "DOUBLE", expectedJson: `123.456001`}, // Float32 rounding error is OK - {arg: float64(123.456), exasolType: "DOUBLE", expectedJson: `123.456000`}, + {arg: 123, exasolType: "DOUBLE", expectedJson: `123.0`}, + {arg: -123, exasolType: "DOUBLE", expectedJson: `-123.0`}, + {arg: math.MinInt64, exasolType: "DOUBLE", expectedJson: `-9223372036854776000.0`}, // rounding error acceptable + {arg: math.MaxInt64, exasolType: "DOUBLE", expectedJson: `9223372036854776000.0`}, // rounding error acceptable + {arg: math.MaxFloat64, exasolType: "DOUBLE", expectedJson: `1.7976931348623157e+308`}, + {arg: math.SmallestNonzeroFloat64, exasolType: "DOUBLE", expectedJson: `5e-324`}, // rounding error acceptable + {arg: 123.456, exasolType: "DOUBLE", expectedJson: `123.456`}, + {arg: -123.456, exasolType: "DOUBLE", expectedJson: `-123.456`}, + {arg: int(123), exasolType: "DOUBLE", expectedJson: `123.0`}, + {arg: int(-123), exasolType: "DOUBLE", expectedJson: `-123.0`}, + {arg: int32(123), exasolType: "DOUBLE", expectedJson: `123.0`}, + {arg: int32(-123), exasolType: "DOUBLE", expectedJson: `-123.0`}, + {arg: int64(123), exasolType: "DOUBLE", expectedJson: `123.0`}, + {arg: int64(-123), exasolType: "DOUBLE", expectedJson: `-123.0`}, + {arg: float32(123), exasolType: "DOUBLE", expectedJson: `123.0`}, + {arg: float32(-123), exasolType: "DOUBLE", expectedJson: `-123.0`}, + {arg: float64(123), exasolType: "DOUBLE", expectedJson: `123.0`}, + {arg: float64(-123), exasolType: "DOUBLE", expectedJson: `-123.0`}, + {arg: float32(123.456), exasolType: "DOUBLE", expectedJson: `123.45600128173828`}, // Float32 rounding error is OK + {arg: float32(-123.456), exasolType: "DOUBLE", expectedJson: `-123.45600128173828`}, // Float32 rounding error is OK + {arg: float64(123.456), exasolType: "DOUBLE", expectedJson: `123.456`}, + {arg: float64(-123.456), exasolType: "DOUBLE", expectedJson: `-123.456`}, {arg: "invalid", exasolType: "DOUBLE", expectedError: "E-EGOD-30: cannot convert argument 'invalid' of type 'string' to 'DOUBLE' type"}, // TIMESTAMP {arg: "some string", exasolType: "TIMESTAMP", expectedJson: `"some string"`}, // We assume strings are already formatted @@ -82,9 +109,10 @@ func TestConvertArgs(t *testing.T) { t.Errorf("Error converting arg: %v", err) return } + actualJson, err := json.Marshal(converted) if err != nil { - t.Errorf("Error marshalling converted arg: %v", err) + t.Errorf("Error marshalling converted arg '%v' of type %T: %v", converted, converted, err) return } if string(actualJson) != testCase.expectedJson { diff --git a/pkg/connection/result_set_test.go b/pkg/connection/result_set_test.go index da3c61a..3b6ea7f 100644 --- a/pkg/connection/result_set_test.go +++ b/pkg/connection/result_set_test.go @@ -6,14 +6,11 @@ import ( "database/sql/driver" "errors" "fmt" - "log" - "os" "reflect" "testing" "github.com/exasol/exasol-driver-go/internal/config" "github.com/exasol/exasol-driver-go/pkg/connection/wsconn" - "github.com/exasol/exasol-driver-go/pkg/logger" "github.com/exasol/exasol-driver-go/pkg/types" "github.com/stretchr/testify/suite" @@ -30,11 +27,9 @@ func TestResultSetSuite(t *testing.T) { func (suite *ResultSetTestSuite) SetupTest() { suite.websocketMock = wsconn.CreateWebsocketConnectionMock() - logger.SetTraceLogger(log.New(os.Stderr, "[TestResultSetSuite] ", log.LstdFlags|log.Lshortfile)) } func (suite *ResultSetTestSuite) TearDownTest() { - logger.SetTraceLogger(nil) } func (suite *ResultSetTestSuite) TestColumnTypeDatabaseTypeName() { From e62c0580f1344aecbc89782e6629c8caab2fe9ab Mon Sep 17 00:00:00 2001 From: Christoph Pirkl Date: Thu, 27 Jun 2024 12:19:45 +0200 Subject: [PATCH 02/12] #113: Fix int conversion error --- itest/integration_test.go | 48 +++++++++++++++++++++--------------- pkg/connection/result_set.go | 19 ++++++++++++-- 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/itest/integration_test.go b/itest/integration_test.go index 1eb56d5..c0296ac 100644 --- a/itest/integration_test.go +++ b/itest/integration_test.go @@ -18,7 +18,6 @@ import ( "github.com/exasol/exasol-driver-go" "github.com/exasol/exasol-driver-go/pkg/dsn" "github.com/exasol/exasol-driver-go/pkg/integrationTesting" - "github.com/exasol/exasol-driver-go/pkg/logger" "github.com/stretchr/testify/assert" "go.uber.org/goleak" @@ -195,9 +194,10 @@ func (suite *IntegrationTestSuite) TestFetch() { suite.Equal(10000, len(result)) } +// https://github.com/exasol/exasol-driver-go/issues/113 func (suite *IntegrationTestSuite) TestFetchLargeInteger() { - logger.EnableTraceLogger() database := suite.openConnection(suite.createDefaultConfig()) + defer database.Close() number := 100000000 rows, err := database.Query(fmt.Sprintf("SELECT %d", number)) suite.NoError(err) @@ -206,7 +206,7 @@ func (suite *IntegrationTestSuite) TestFetchLargeInteger() { err = rows.Scan(&result) suite.NoError(err) defer rows.Close() - suite.Equal(number, result) + suite.Equal(int64(number), result) } func (suite *IntegrationTestSuite) TestExecuteWithError() { @@ -259,6 +259,8 @@ func (suite *IntegrationTestSuite) TestQueryDataTypesCast() { }{ // DECIMAL {"decimal to int64", "1", "DECIMAL(18,0)", new(int64), int64(1), dereferenceInt64}, + {"large decimal to int64", "100000000", "DECIMAL(18,0)", new(int64), int64(100000000), dereferenceInt64}, + {"large negative decimal to int64", "-100000000", "DECIMAL(18,0)", new(int64), int64(-100000000), dereferenceInt64}, {"decimal to int", "1", "DECIMAL(18,0)", new(int), 1, dereferenceInt}, {"decimal to float", "1", "DECIMAL(18,0)", new(float64), 1.0, dereferenceFloat64}, {"decimal to string", "1", "DECIMAL(18,0)", new(string), "1", dereferenceString}, @@ -336,6 +338,10 @@ func (suite *IntegrationTestSuite) TestPreparedStatementArgsConverted() { int64TestCase(-1, "DECIMAL(18,0)", -1), int64TestCase(1.1, "DECIMAL(18,0)", 1), int64TestCase(-1.1, "DECIMAL(18,0)", -1), + int64TestCase(100000000, "DECIMAL(18,0)", 100000000), + int64TestCase(-100000000, "DECIMAL(18,0)", -100000000), + int64TestCase(100000000, "DECIMAL(18,2)", 100000000), + int64TestCase(-100000000, "DECIMAL(18,2)", -100000000), int64TestCase(math.MaxInt64, "DECIMAL(36,0)", math.MaxInt64), int64TestCase(math.MinInt64, "DECIMAL(36,0)", math.MinInt64), @@ -350,6 +356,8 @@ func (suite *IntegrationTestSuite) TestPreparedStatementArgsConverted() { float64TestCase(-1, "DECIMAL(18,0)", -1), float64TestCase(1.123, "DECIMAL(18,3)", 1.123), float64TestCase(-1.123, "DECIMAL(18,3)", -1.123), + float64TestCase(100000000.12, "DECIMAL(18,2)", 100000000.12), + float64TestCase(-100000000.12, "DECIMAL(18,2)", -100000000.12), float32TestCase(1, "DECIMAL(18,0)", 1), float32TestCase(-1, "DECIMAL(18,0)", -1), @@ -582,9 +590,9 @@ func (suite *IntegrationTestSuite) TestSimpleImportStatement() { suite.assertTableResult(rows, []string{"A", "B"}, [][]interface{}{ - {float64(11), "test1"}, - {float64(12), "test2"}, - {float64(13), "test3"}, + {int64(11), "test1"}, + {int64(12), "test2"}, + {int64(13), "test3"}, }, ) } @@ -636,7 +644,7 @@ func (suite *IntegrationTestSuite) TestSimpleImportStatementBigFile() { suite.NoError(err, "count query should work") suite.assertTableResult(rows, []string{"COUNT(*)"}, [][]interface{}{ - {float64(20000)}, + {int64(20000)}, }, ) @@ -645,9 +653,9 @@ func (suite *IntegrationTestSuite) TestSimpleImportStatementBigFile() { suite.assertTableResult(rows, []string{"A", "B", "C", "D", "E", "F", "G"}, [][]interface{}{ - {float64(0), exampleData, exampleData, exampleData, exampleData, exampleData, exampleData}, - {float64(1), exampleData, exampleData, exampleData, exampleData, exampleData, exampleData}, - {float64(2), exampleData, exampleData, exampleData, exampleData, exampleData, exampleData}, + {int64(0), exampleData, exampleData, exampleData, exampleData, exampleData, exampleData}, + {int64(1), exampleData, exampleData, exampleData, exampleData, exampleData, exampleData}, + {int64(2), exampleData, exampleData, exampleData, exampleData, exampleData, exampleData}, }, ) } @@ -709,12 +717,12 @@ func (suite *IntegrationTestSuite) TestMultiImportStatement() { suite.assertTableResult(rows, []string{"A", "B"}, [][]interface{}{ - {float64(11), "test1"}, - {float64(12), "test2"}, - {float64(13), "test3"}, - {float64(21), "test4"}, - {float64(22), "test5"}, - {float64(23), "test6"}, + {int64(11), "test1"}, + {int64(12), "test2"}, + {int64(13), "test3"}, + {int64(21), "test4"}, + {int64(22), "test5"}, + {int64(23), "test6"}, }, ) } @@ -744,7 +752,7 @@ func (suite *IntegrationTestSuite) TestImportStatementWithCRFile() { tableName := "TEST_TABLE" _, _ = database.ExecContext(ctx, "CREATE SCHEMA "+schemaName) defer suite.cleanup(database, schemaName) - _, _ = database.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s.%s (a int , b VARCHAR(20))", schemaName, tableName)) + _, _ = database.ExecContext(ctx, fmt.Sprintf("CREATE TABLE %s.%s (a int, b VARCHAR(20))", schemaName, tableName)) result, err := database.ExecContext(ctx, fmt.Sprintf(`IMPORT INTO %s.%s FROM LOCAL CSV FILE '../testData/data_cr.csv' COLUMN SEPARATOR = ';' ENCODING = 'UTF-8' ROW SEPARATOR = 'CR'`, schemaName, tableName)) suite.NoError(err, "import should be successful") @@ -755,9 +763,9 @@ func (suite *IntegrationTestSuite) TestImportStatementWithCRFile() { suite.assertTableResult(rows, []string{"A", "B"}, [][]interface{}{ - {float64(11), "test1"}, - {float64(12), "test2"}, - {float64(13), "test3"}, + {int64(11), "test1"}, + {int64(12), "test2"}, + {int64(13), "test3"}, }, ) } diff --git a/pkg/connection/result_set.go b/pkg/connection/result_set.go index e79ecd2..0667ab7 100644 --- a/pkg/connection/result_set.go +++ b/pkg/connection/result_set.go @@ -90,8 +90,8 @@ func (results *QueryResults) Next(dest []driver.Value) error { } } - for i := range dest { - dest[i] = results.data.Data[i][results.rowPointer] + for columnIndex := range dest { + dest[columnIndex] = results.getColumnValue(columnIndex) } results.rowPointer = results.rowPointer + 1 @@ -100,6 +100,21 @@ func (results *QueryResults) Next(dest []driver.Value) error { return nil } +func (results *QueryResults) getColumnValue(columnIndex int) driver.Value { + value := results.data.Data[columnIndex][results.rowPointer] + columnType := results.data.Columns[columnIndex].DataType + return convertValue(value, columnType) +} + +func convertValue(value any, columnType types.SqlQueryColumnType) driver.Value { + if columnType.Type == "DECIMAL" && columnType.Scale != nil && *columnType.Scale == 0 { + if floatValue, ok := value.(float64); ok { + return int64(floatValue) + } + } + return value +} + func (results *QueryResults) fetchNextRowChunk() error { chunk := &types.SqlQueryResponseResultSetData{} err := results.con.Send(context.Background(), &types.FetchCommand{ From d4655f92f5d5814c816033aa0bc5d534b81af2b6 Mon Sep 17 00:00:00 2001 From: Christoph Pirkl Date: Thu, 27 Jun 2024 13:32:37 +0200 Subject: [PATCH 03/12] #113: Add unit tests --- pkg/connection/result_set.go | 9 +++++++- pkg/connection/result_set_test.go | 36 ++++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/pkg/connection/result_set.go b/pkg/connection/result_set.go index 0667ab7..a0d4f37 100644 --- a/pkg/connection/result_set.go +++ b/pkg/connection/result_set.go @@ -106,8 +106,11 @@ func (results *QueryResults) getColumnValue(columnIndex int) driver.Value { return convertValue(value, columnType) } +// Result set data contains values as float64 even for whole numbers. This causes an error when calling "Scan()" with an integer value. +// As a workaround we convert the float64 to int for DECIMAL columns with scale 0. +// See https://github.com/exasol/exasol-driver-go/issues/113 for details. func convertValue(value any, columnType types.SqlQueryColumnType) driver.Value { - if columnType.Type == "DECIMAL" && columnType.Scale != nil && *columnType.Scale == 0 { + if isIntegerColumn(columnType) { if floatValue, ok := value.(float64); ok { return int64(floatValue) } @@ -115,6 +118,10 @@ func convertValue(value any, columnType types.SqlQueryColumnType) driver.Value { return value } +func isIntegerColumn(columnType types.SqlQueryColumnType) bool { + return columnType.Type == "DECIMAL" && columnType.Scale != nil && *columnType.Scale == 0 +} + func (results *QueryResults) fetchNextRowChunk() error { chunk := &types.SqlQueryResponseResultSetData{} err := results.con.Send(context.Background(), &types.FetchCommand{ diff --git a/pkg/connection/result_set_test.go b/pkg/connection/result_set_test.go index 3b6ea7f..d8a7dbb 100644 --- a/pkg/connection/result_set_test.go +++ b/pkg/connection/result_set_test.go @@ -29,9 +29,6 @@ func (suite *ResultSetTestSuite) SetupTest() { suite.websocketMock = wsconn.CreateWebsocketConnectionMock() } -func (suite *ResultSetTestSuite) TearDownTest() { -} - func (suite *ResultSetTestSuite) TestColumnTypeDatabaseTypeName() { data := types.SqlQueryResponseResultSetData{Columns: []types.SqlQueryColumn{ {DataType: types.SqlQueryColumnType{Type: "boolean"}}, @@ -280,6 +277,39 @@ func (suite *ResultSetTestSuite) TestCloseSendsCloseResultSetCommand() { suite.NoError(queryResults.Close()) } +func (suite *ResultSetTestSuite) TestConvertValue() { + createType := func(dataType string, scale int64) types.SqlQueryColumnType { + return types.SqlQueryColumnType{Type: dataType, Scale: &scale} + } + createTypeWithoutScale := func(dataType string) types.SqlQueryColumnType { + return types.SqlQueryColumnType{Type: dataType, Scale: nil} + } + decimalTypeZeroScale := createType("DECIMAL", 0) + for i, testCase := range []struct { + value any + columnType types.SqlQueryColumnType + expectedValue driver.Value + }{ + {float64(1.1), decimalTypeZeroScale, int64(1)}, // Only this combination will convert the value + {1.1, decimalTypeZeroScale, int64(1)}, + {float32(1.1), decimalTypeZeroScale, float32(1.1)}, + {"string", decimalTypeZeroScale, "string"}, + {true, decimalTypeZeroScale, true}, + {1, decimalTypeZeroScale, 1}, + {int32(1), decimalTypeZeroScale, int32(1)}, + {int64(1), decimalTypeZeroScale, int64(1)}, + {float64(1.1), createType("DECIMAL", -1), float64(1.1)}, + {float64(1.1), createType("DECIMAL", 1), float64(1.1)}, + {float64(1.1), createType("OTHER", 0), float64(1.1)}, + {float64(1.1), createTypeWithoutScale("DECIMAL"), float64(1.1)}, + } { + suite.Run(fmt.Sprintf("TestConvertValue %d value %v type %v", i, testCase.value, testCase.columnType), func() { + result := convertValue(testCase.value, testCase.columnType) + suite.Equal(testCase.expectedValue, result) + }) + } +} + func (suite *ResultSetTestSuite) createResultSet() QueryResults { return QueryResults{ data: &types.SqlQueryResponseResultSetData{ From 2947f5857a0c2744fd6c09d92933fd1a432b59e8 Mon Sep 17 00:00:00 2001 From: Christoph Pirkl Date: Thu, 27 Jun 2024 13:35:56 +0200 Subject: [PATCH 04/12] Add changelog entry --- doc/changes/changes_1.0.9.md | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/doc/changes/changes_1.0.9.md b/doc/changes/changes_1.0.9.md index d7e5b33..57a9706 100644 --- a/doc/changes/changes_1.0.9.md +++ b/doc/changes/changes_1.0.9.md @@ -1,10 +1,15 @@ -# Exasol Driver go 1.0.9, released 2024-??-?? +# Exasol Driver go 1.0.9, released 2024-06-?? -Code name: +Code name: Fix reading int values ## Summary -## Features +This release fixes an issue when calling `rows.Scan(&result)` with an int value. This failed for large values like 100000000 with the following error: -* ISSUE_NUMBER: description +``` +sql: Scan error on column index 0, name "100000000": converting driver.Value type float64 ("1e+08") to a int64: invalid syntax +``` +## Bugfixes + +* #113: Fixed `Scan()` with large integer numbers From 41d920dd8e55bfbbf2f706f849d6d3bb5ecbbcfb Mon Sep 17 00:00:00 2001 From: Christoph Pirkl Date: Thu, 27 Jun 2024 13:42:40 +0200 Subject: [PATCH 05/12] Add comment --- pkg/connection/prepared_stmt_converter.go | 4 ++++ pkg/connection/prepared_stmt_converter_test.go | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pkg/connection/prepared_stmt_converter.go b/pkg/connection/prepared_stmt_converter.go index d155346..002c808 100644 --- a/pkg/connection/prepared_stmt_converter.go +++ b/pkg/connection/prepared_stmt_converter.go @@ -63,6 +63,10 @@ type jsonDoubleValueStruct struct { value float64 } +// MarshalJSON ensures that the double value is always formatted with a decimal point +// even if it's an integer. This is necessary because Exasol expects a decimal point +// for double values. +// See https://github.com/exasol/exasol-driver-go/issues/108 for details. func (j *jsonDoubleValueStruct) MarshalJSON() ([]byte, error) { r, err := json.Marshal(j.value) if err != nil { diff --git a/pkg/connection/prepared_stmt_converter_test.go b/pkg/connection/prepared_stmt_converter_test.go index 4b50384..23d42ee 100644 --- a/pkg/connection/prepared_stmt_converter_test.go +++ b/pkg/connection/prepared_stmt_converter_test.go @@ -109,7 +109,6 @@ func TestConvertArgs(t *testing.T) { t.Errorf("Error converting arg: %v", err) return } - actualJson, err := json.Marshal(converted) if err != nil { t.Errorf("Error marshalling converted arg '%v' of type %T: %v", converted, converted, err) From e6fd5932462b6146a9bb01215424b74bd0368bd5 Mon Sep 17 00:00:00 2001 From: Christoph Pirkl Date: Thu, 27 Jun 2024 13:54:14 +0200 Subject: [PATCH 06/12] Add note for already closed issue --- doc/changes/changes_1.0.9.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/changes/changes_1.0.9.md b/doc/changes/changes_1.0.9.md index 57a9706..910bb57 100644 --- a/doc/changes/changes_1.0.9.md +++ b/doc/changes/changes_1.0.9.md @@ -10,6 +10,9 @@ This release fixes an issue when calling `rows.Scan(&result)` with an int value. sql: Scan error on column index 0, name "100000000": converting driver.Value type float64 ("1e+08") to a int64: invalid syntax ``` +The release also now returns the correct error from `rows.Err()`. Before, this only returned `driver.ErrBadConn`. + ## Bugfixes * #113: Fixed `Scan()` with large integer numbers +* #111: Return correct error from `rows.Err()` From 71916de98d0a9b3d6c4fcbc213344b4b8d7ec074 Mon Sep 17 00:00:00 2001 From: Christoph Pirkl Date: Thu, 27 Jun 2024 13:58:18 +0200 Subject: [PATCH 07/12] Update release date --- doc/changes/changes_1.0.9.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/changes/changes_1.0.9.md b/doc/changes/changes_1.0.9.md index 910bb57..1e5a1ce 100644 --- a/doc/changes/changes_1.0.9.md +++ b/doc/changes/changes_1.0.9.md @@ -1,4 +1,4 @@ -# Exasol Driver go 1.0.9, released 2024-06-?? +# Exasol Driver go 1.0.9, released 2024-06-27 Code name: Fix reading int values From 78377e682582cd4ab55117eac476934203fe14dd Mon Sep 17 00:00:00 2001 From: Christoph Pirkl Date: Thu, 27 Jun 2024 14:54:04 +0200 Subject: [PATCH 08/12] Restrict conversion to whole numbers --- pkg/connection/result_set.go | 7 ++++++- pkg/connection/result_set_test.go | 6 ++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pkg/connection/result_set.go b/pkg/connection/result_set.go index a0d4f37..1956f22 100644 --- a/pkg/connection/result_set.go +++ b/pkg/connection/result_set.go @@ -5,6 +5,7 @@ import ( "database/sql" "database/sql/driver" "io" + "math" "reflect" "sync" @@ -111,7 +112,7 @@ func (results *QueryResults) getColumnValue(columnIndex int) driver.Value { // See https://github.com/exasol/exasol-driver-go/issues/113 for details. func convertValue(value any, columnType types.SqlQueryColumnType) driver.Value { if isIntegerColumn(columnType) { - if floatValue, ok := value.(float64); ok { + if floatValue, ok := value.(float64); ok && isIntegerValue(floatValue) { return int64(floatValue) } } @@ -122,6 +123,10 @@ func isIntegerColumn(columnType types.SqlQueryColumnType) bool { return columnType.Type == "DECIMAL" && columnType.Scale != nil && *columnType.Scale == 0 } +func isIntegerValue(value float64) bool { + return value == math.Trunc(value) +} + func (results *QueryResults) fetchNextRowChunk() error { chunk := &types.SqlQueryResponseResultSetData{} err := results.con.Send(context.Background(), &types.FetchCommand{ diff --git a/pkg/connection/result_set_test.go b/pkg/connection/result_set_test.go index d8a7dbb..26212e5 100644 --- a/pkg/connection/result_set_test.go +++ b/pkg/connection/result_set_test.go @@ -290,8 +290,10 @@ func (suite *ResultSetTestSuite) TestConvertValue() { columnType types.SqlQueryColumnType expectedValue driver.Value }{ - {float64(1.1), decimalTypeZeroScale, int64(1)}, // Only this combination will convert the value - {1.1, decimalTypeZeroScale, int64(1)}, + {float64(1), decimalTypeZeroScale, int64(1)}, // Only this combination will convert the value + {float64(-1), decimalTypeZeroScale, int64(-1)}, // Only this combination will convert the value + {float64(10000000000), decimalTypeZeroScale, int64(10000000000)}, // Only this combination will convert the value + {1.1, decimalTypeZeroScale, float64(1.1)}, {float32(1.1), decimalTypeZeroScale, float32(1.1)}, {"string", decimalTypeZeroScale, "string"}, {true, decimalTypeZeroScale, true}, From ddbaaa18a4d32727eb9e8505a2331079f2b9e156 Mon Sep 17 00:00:00 2001 From: Christoph Pirkl Date: Thu, 27 Jun 2024 15:05:48 +0200 Subject: [PATCH 09/12] Add integration tests --- doc/changes/changes_1.0.9.md | 8 +++++++- itest/integration_test.go | 20 +++++++++++--------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/doc/changes/changes_1.0.9.md b/doc/changes/changes_1.0.9.md index 1e5a1ce..98b0ed6 100644 --- a/doc/changes/changes_1.0.9.md +++ b/doc/changes/changes_1.0.9.md @@ -7,7 +7,13 @@ Code name: Fix reading int values This release fixes an issue when calling `rows.Scan(&result)` with an int value. This failed for large values like 100000000 with the following error: ``` -sql: Scan error on column index 0, name "100000000": converting driver.Value type float64 ("1e+08") to a int64: invalid syntax +sql: Scan error on column index 0, name "COL": converting driver.Value type float64 ("1e+08") to a int64: invalid syntax +``` + +Please note that reading non-integer numbers like `1.1` into a `int64` variable will still fail with the following error message: + +``` +sql: Scan error on column index 0, name "COL": converting driver.Value type string ("1.1") to a int64: invalid syntax ``` The release also now returns the correct error from `rows.Err()`. Before, this only returned `driver.ErrBadConn`. diff --git a/itest/integration_test.go b/itest/integration_test.go index c0296ac..f6d9fe7 100644 --- a/itest/integration_test.go +++ b/itest/integration_test.go @@ -338,6 +338,8 @@ func (suite *IntegrationTestSuite) TestPreparedStatementArgsConverted() { int64TestCase(-1, "DECIMAL(18,0)", -1), int64TestCase(1.1, "DECIMAL(18,0)", 1), int64TestCase(-1.1, "DECIMAL(18,0)", -1), + int64TestCase(1.1, "DECIMAL(18,2)", 1), + int64TestCase(-1.1, "DECIMAL(18,2)", -1), int64TestCase(100000000, "DECIMAL(18,0)", 100000000), int64TestCase(-100000000, "DECIMAL(18,0)", -100000000), int64TestCase(100000000, "DECIMAL(18,2)", 100000000), @@ -435,14 +437,14 @@ func (suite *IntegrationTestSuite) TestPreparedStatementArgsConversionFails() { func (suite *IntegrationTestSuite) TestScanTypeUnsupported() { for i, testCase := range []struct { - testDescription string - sqlValue any - sqlType string - scanDest any - expectedError string + sqlValue any + sqlType string + scanDest any + expectedError string }{ - {"timestamp", time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP", new(time.Time), `sql: Scan error on column index 0, name "COL": unsupported Scan, storing driver.Value type string into type *time.Time`}, - {"timestamp with local time zone", time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP WITH LOCAL TIME ZONE", new(time.Time), `sql: Scan error on column index 0, name "COL": unsupported Scan, storing driver.Value type string into type *time.Time`}, + {1.1, "DECIMAL(4,2)", new(int64), `converting driver.Value type string ("1.1") to a int64: invalid syntax`}, + {time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP", new(time.Time), `unsupported Scan, storing driver.Value type string into type *time.Time`}, + {time.Date(2024, time.June, 18, 17, 22, 13, 123456789, time.UTC), "TIMESTAMP WITH LOCAL TIME ZONE", new(time.Time), `unsupported Scan, storing driver.Value type string into type *time.Time`}, } { database := suite.openConnection(suite.createDefaultConfig().Autocommit(false)) schemaName := "DATATYPE_TEST" @@ -450,7 +452,7 @@ func (suite *IntegrationTestSuite) TestScanTypeUnsupported() { onError(err) defer suite.cleanup(database, schemaName) - suite.Run(fmt.Sprintf("Scan fails %02d %s: %s", i, testCase.testDescription, testCase.sqlType), func() { + suite.Run(fmt.Sprintf("Scan fails %02d %s", i, testCase.sqlType), func() { tableName := fmt.Sprintf("%s.TAB_%d", schemaName, i) _, err = database.Exec(fmt.Sprintf("CREATE TABLE %s (col %s)", tableName, testCase.sqlType)) onError(err) @@ -463,7 +465,7 @@ func (suite *IntegrationTestSuite) TestScanTypeUnsupported() { defer rows.Close() suite.True(rows.Next(), "should have one row") err = rows.Scan(testCase.scanDest) - suite.EqualError(err, testCase.expectedError) + suite.EqualError(err, `sql: Scan error on column index 0, name "COL": `+testCase.expectedError) }) } } From 0916b5c01c45aeda390bfa5261983323f4fcf7c3 Mon Sep 17 00:00:00 2001 From: Christoph Pirkl Date: Thu, 27 Jun 2024 19:36:59 +0200 Subject: [PATCH 10/12] Pass context correctly --- itest/integration_test.go | 4 ---- pkg/connection/connection.go | 18 +++++++++--------- pkg/connection/result_set.go | 5 +++-- pkg/connection/result_set_test.go | 1 + pkg/connection/statement.go | 15 ++++++++------- pkg/connection/util.go | 5 +++-- 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/itest/integration_test.go b/itest/integration_test.go index f6d9fe7..9001423 100644 --- a/itest/integration_test.go +++ b/itest/integration_test.go @@ -43,12 +43,8 @@ func TestIntegrationSuite(t *testing.T) { func (suite *IntegrationTestSuite) SetupSuite() { suite.ctx = context.Background() - var err error suite.exasol = integrationTesting.StartDbSetup(&suite.Suite) connectionInfo := suite.exasol.ConnectionInfo - if err != nil { - suite.FailNowf("setup failed", "failed to get connection info: %v", err) - } suite.port = connectionInfo.Port suite.host = connectionInfo.Host } diff --git a/pkg/connection/connection.go b/pkg/connection/connection.go index a50f6b8..7fa87cf 100644 --- a/pkg/connection/connection.go +++ b/pkg/connection/connection.go @@ -47,11 +47,11 @@ func (c *Connection) ExecContext(ctx context.Context, query string, args []drive } func (c *Connection) Exec(query string, args []driver.Value) (driver.Result, error) { - return c.exec(context.Background(), query, args) + return c.exec(c.Ctx, query, args) } func (c *Connection) Query(query string, args []driver.Value) (driver.Rows, error) { - return c.query(context.Background(), query, args) + return c.query(c.Ctx, query, args) } func (c *Connection) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { @@ -64,7 +64,7 @@ func (c *Connection) PrepareContext(ctx context.Context, query string) (driver.S if err != nil { return nil, err } - return c.createStatement(response), nil + return c.createStatement(ctx, response), nil } func (c *Connection) createPreparedStatement(ctx context.Context, query string) (*types.CreatePreparedStatementResponse, error) { @@ -82,16 +82,16 @@ func (c *Connection) createPreparedStatement(ctx context.Context, query string) return response, nil } -func (c *Connection) createStatement(result *types.CreatePreparedStatementResponse) *Statement { - return NewStatement(c, result) +func (c *Connection) createStatement(ctx context.Context, result *types.CreatePreparedStatementResponse) *Statement { + return NewStatement(ctx, c, result) } func (c *Connection) Prepare(query string) (driver.Stmt, error) { - return c.PrepareContext(context.Background(), query) + return c.PrepareContext(c.Ctx, query) } func (c *Connection) Close() error { - return c.close(context.Background()) + return c.close(c.Ctx) } func (c *Connection) Begin() (driver.Tx, error) { @@ -125,7 +125,7 @@ func (c *Connection) query(ctx context.Context, query string, args []driver.Valu if err != nil { return nil, err } - return ToRow(result, c) + return ToRow(ctx, result, c) } func (c *Connection) executeSimpleWithRows(ctx context.Context, query string) (driver.Rows, error) { @@ -133,7 +133,7 @@ func (c *Connection) executeSimpleWithRows(ctx context.Context, query string) (d if err != nil { return nil, err } - return ToRow(result, c) + return ToRow(ctx, result, c) } func (c *Connection) executePreparedStatement(ctx context.Context, s *types.CreatePreparedStatementResponse, args []driver.Value) (*types.SqlQueriesResponse, error) { diff --git a/pkg/connection/result_set.go b/pkg/connection/result_set.go index 1956f22..309bf04 100644 --- a/pkg/connection/result_set.go +++ b/pkg/connection/result_set.go @@ -15,6 +15,7 @@ import ( type QueryResults struct { sync.Mutex // guards following + ctx context.Context data *types.SqlQueryResponseResultSetData con *Connection fetchedRows int @@ -69,7 +70,7 @@ func (results *QueryResults) Close() error { if results.data.ResultSetHandle == 0 { return nil } - return results.con.Send(context.Background(), &types.CloseResultSetCommand{ + return results.con.Send(results.ctx, &types.CloseResultSetCommand{ Command: types.Command{Command: "closeResultSet"}, ResultSetHandles: []int{results.data.ResultSetHandle}, }, nil) @@ -129,7 +130,7 @@ func isIntegerValue(value float64) bool { func (results *QueryResults) fetchNextRowChunk() error { chunk := &types.SqlQueryResponseResultSetData{} - err := results.con.Send(context.Background(), &types.FetchCommand{ + err := results.con.Send(results.ctx, &types.FetchCommand{ Command: types.Command{Command: "fetch"}, ResultSetHandle: results.data.ResultSetHandle, StartPosition: results.totalRowPointer, diff --git a/pkg/connection/result_set_test.go b/pkg/connection/result_set_test.go index 26212e5..6e6ffd0 100644 --- a/pkg/connection/result_set_test.go +++ b/pkg/connection/result_set_test.go @@ -314,6 +314,7 @@ func (suite *ResultSetTestSuite) TestConvertValue() { func (suite *ResultSetTestSuite) createResultSet() QueryResults { return QueryResults{ + ctx: context.Background(), data: &types.SqlQueryResponseResultSetData{ ResultSetHandle: 1, NumRows: 2, NumRowsInMessage: 2, Columns: []types.SqlQueryColumn{{}, {}}, }, diff --git a/pkg/connection/statement.go b/pkg/connection/statement.go index 8549d8e..5d3bb34 100644 --- a/pkg/connection/statement.go +++ b/pkg/connection/statement.go @@ -11,14 +11,15 @@ import ( ) type Statement struct { + ctx context.Context connection *Connection statementHandle int columns []types.SqlQueryColumn numInput int } -func NewStatement(connection *Connection, response *types.CreatePreparedStatementResponse) *Statement { - return &Statement{connection: connection, statementHandle: response.StatementHandle, columns: response.ParameterData.Columns, numInput: response.ParameterData.NumColumns} +func NewStatement(ctx context.Context, connection *Connection, response *types.CreatePreparedStatementResponse) *Statement { + return &Statement{ctx: ctx, connection: connection, statementHandle: response.StatementHandle, columns: response.ParameterData.Columns, numInput: response.ParameterData.NumColumns} } func (s *Statement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { @@ -30,15 +31,15 @@ func (s *Statement) QueryContext(ctx context.Context, args []driver.NamedValue) if err != nil { return nil, err } - return ToRow(result, s.connection) + return ToRow(ctx, result, s.connection) } func (s *Statement) Query(args []driver.Value) (driver.Rows, error) { - result, err := s.executePreparedStatement(context.Background(), args) + result, err := s.executePreparedStatement(s.ctx, args) if err != nil { return nil, err } - return ToRow(result, s.connection) + return ToRow(s.ctx, result, s.connection) } func (s *Statement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { @@ -54,7 +55,7 @@ func (s *Statement) ExecContext(ctx context.Context, args []driver.NamedValue) ( } func (s *Statement) Exec(args []driver.Value) (driver.Result, error) { - result, err := s.executePreparedStatement(context.Background(), args) + result, err := s.executePreparedStatement(s.ctx, args) if err != nil { return nil, err } @@ -65,7 +66,7 @@ func (s *Statement) Close() error { if s.connection.IsClosed { return driver.ErrBadConn } - return s.connection.Send(context.Background(), &types.ClosePreparedStatementCommand{ + return s.connection.Send(s.ctx, &types.ClosePreparedStatementCommand{ Command: types.Command{Command: "closePreparedStatement"}, StatementHandle: s.statementHandle, }, nil) diff --git a/pkg/connection/util.go b/pkg/connection/util.go index 8c1b463..d66568b 100644 --- a/pkg/connection/util.go +++ b/pkg/connection/util.go @@ -1,20 +1,21 @@ package connection import ( + "context" "database/sql/driver" "encoding/json" "github.com/exasol/exasol-driver-go/pkg/types" ) -func ToRow(result *types.SqlQueriesResponse, con *Connection) (driver.Rows, error) { +func ToRow(ctx context.Context, result *types.SqlQueriesResponse, con *Connection) (driver.Rows, error) { resultSet := &types.SqlQueryResponseResultSet{} err := json.Unmarshal(result.Results[0], resultSet) if err != nil { return nil, err } - return &QueryResults{data: &resultSet.ResultSet, con: con}, nil + return &QueryResults{ctx: ctx, data: &resultSet.ResultSet, con: con}, nil } func ToResult(result *types.SqlQueriesResponse) (driver.Result, error) { From ad6bdad2fe28b70c5d8f41d7935c6a643aeb0cb7 Mon Sep 17 00:00:00 2001 From: Christoph Pirkl Date: Thu, 27 Jun 2024 19:37:12 +0200 Subject: [PATCH 11/12] Remove unused dependency --- go.mod | 1 - go.sum | 12 ++++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/go.mod b/go.mod index aeddbc4..1dcb5db 100644 --- a/go.mod +++ b/go.mod @@ -18,5 +18,4 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect github.com/stretchr/objx v0.5.2 // indirect - golang.org/x/net v0.26.0 // indirect ) diff --git a/go.sum b/go.sum index c16a1cc..2a7847d 100644 --- a/go.sum +++ b/go.sum @@ -1,21 +1,17 @@ -github.com/antchfx/xmlquery v1.3.18 h1:FSQ3wMuphnPPGJOFhvc+cRQ2CT/rUj4cyQXkJcjOwz0= -github.com/antchfx/xmlquery v1.3.18/go.mod h1:Afkq4JIeXut75taLSuI31ISJ/zeq+3jG7TunF7noreA= -github.com/antchfx/xpath v1.2.5 h1:hqZ+wtQ+KIOV/S3bGZcIhpgYC26um2bZYP2KVGcR7VY= -github.com/antchfx/xpath v1.2.5/go.mod h1:i54GszH55fYfBmoZXapTHN8T8tkcHfRgLyVwwqzXNcs= +github.com/antchfx/xmlquery v1.4.0 h1:xg2HkfcRK2TeTbdb0m1jxCYnvsPaGY/oeZWTGqX/0hA= +github.com/antchfx/xmlquery v1.4.0/go.mod h1:Ax2aeaeDjfIw3CwXKDQ0GkwZ6QlxoChlIBP+mGnDFjI= +github.com/antchfx/xpath v1.3.0 h1:nTMlzGAK3IJ0bPpME2urTuFL76o4A96iYvoKFHRXJgc= +github.com/antchfx/xpath v1.3.0/go.mod h1:i54GszH55fYfBmoZXapTHN8T8tkcHfRgLyVwwqzXNcs= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/exasol/error-reporting-go v0.2.0 h1:nKIe4zYiTHbYrKJRlSNJcmGjTJCZredDh5akVHfIbRs= github.com/exasol/error-reporting-go v0.2.0/go.mod h1:lUzRJqKLiSuYpqRUN2LVyj08WeHzhMEC/8Gmgtuqh1Y= -github.com/exasol/exasol-test-setup-abstraction-server/go-client v0.3.6 h1:zFDtIhX1M52fwGzwSXL4o+JGC86qdsMNY20MaumCPgQ= -github.com/exasol/exasol-test-setup-abstraction-server/go-client v0.3.6/go.mod h1:MpOSQf+M12fO2DoIN6/dcABVodAkPmoPAYMZXd2Oefo= github.com/exasol/exasol-test-setup-abstraction-server/go-client v0.3.9 h1:vkOiwqum2hOK1WHgBop3TZrRGiygDRTvet8lzxP7Gl4= github.com/exasol/exasol-test-setup-abstraction-server/go-client v0.3.9/go.mod h1:g0gO9UJh2LOYlwJIzrw7c1QZJqEBSvYnAaOycu7M5/c= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= -github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= From af423b394a73c2263225b833cd1bd7a0ceeb069d Mon Sep 17 00:00:00 2001 From: Christoph Pirkl Date: Fri, 28 Jun 2024 08:25:48 +0200 Subject: [PATCH 12/12] Pass context --- pkg/connection/connection.go | 8 +++- pkg/connection/connection_test.go | 56 +++++++++++++------------- pkg/connection/result_set_test.go | 5 ++- pkg/connection/transaction.go | 9 +++-- pkg/connection/transaction_test.go | 20 ++++++--- pkg/connection/websocket_test.go | 26 ++++++------ pkg/connection/wsconn/wsconn_i_test.go | 13 ++++-- 7 files changed, 81 insertions(+), 56 deletions(-) diff --git a/pkg/connection/connection.go b/pkg/connection/connection.go index 7fa87cf..56ac3b5 100644 --- a/pkg/connection/connection.go +++ b/pkg/connection/connection.go @@ -86,6 +86,12 @@ func (c *Connection) createStatement(ctx context.Context, result *types.CreatePr return NewStatement(ctx, c, result) } +func (c *Connection) Ping(ctx context.Context) error { + fmt.Printf("Ping\n") + // FIXME + return nil +} + func (c *Connection) Prepare(query string) (driver.Stmt, error) { return c.PrepareContext(c.Ctx, query) } @@ -102,7 +108,7 @@ func (c *Connection) Begin() (driver.Tx, error) { if c.Config.Autocommit { return nil, errors.ErrAutocommitEnabled } - return NewTransaction(c), nil + return NewTransaction(c.Ctx, c), nil } func (c *Connection) query(ctx context.Context, query string, args []driver.Value) (driver.Rows, error) { diff --git a/pkg/connection/connection_test.go b/pkg/connection/connection_test.go index 8b345fb..a62eb30 100644 --- a/pkg/connection/connection_test.go +++ b/pkg/connection/connection_test.go @@ -22,6 +22,7 @@ func mockExceptionError(exception types.Exception) string { type ConnectionTestSuite struct { suite.Suite websocketMock *wsconn.WebsocketConnectionMock + ctx context.Context } func TestConnectionSuite(t *testing.T) { @@ -30,12 +31,13 @@ func TestConnectionSuite(t *testing.T) { func (suite *ConnectionTestSuite) SetupTest() { suite.websocketMock = wsconn.CreateWebsocketConnectionMock() + suite.ctx = context.Background() } func (suite *ConnectionTestSuite) TestConnectFails() { conn := &Connection{ Config: &config.Config{Host: "invalid", Port: 12345}, - Ctx: context.Background(), + Ctx: suite.ctx, IsClosed: true, } err := conn.Connect() @@ -43,7 +45,7 @@ func (suite *ConnectionTestSuite) TestConnectFails() { } func (suite *ConnectionTestSuite) TestQueryContextNamedParametersNotSupported() { - rows, err := suite.createOpenConnection().QueryContext(context.Background(), "query", []driver.NamedValue{{Name: "arg", Ordinal: 1, Value: "value"}}) + rows, err := suite.createOpenConnection().QueryContext(suite.ctx, "query", []driver.NamedValue{{Name: "arg", Ordinal: 1, Value: "value"}}) suite.EqualError(err, "E-EGOD-7: named parameters not supported") suite.Nil(rows) } @@ -66,7 +68,7 @@ func (suite *ConnectionTestSuite) TestQueryContext() { types.SqlQueryResponseResultSet{ResultType: "resultType", ResultSet: types.SqlQueryResponseResultSetData{}}) suite.websocketMock.SimulateOKResponse(types.ClosePreparedStatementCommand{Command: types.Command{Command: "closePreparedStatement"}, StatementHandle: 0, Attributes: types.Attributes{}}, nil) - rows, err := suite.createOpenConnection().QueryContext(context.Background(), "query", []driver.NamedValue{{Ordinal: 1, Value: "value"}}) + rows, err := suite.createOpenConnection().QueryContext(suite.ctx, "query", []driver.NamedValue{{Ordinal: 1, Value: "value"}}) suite.NoError(err) suite.Equal([]string{}, rows.Columns()) } @@ -95,7 +97,7 @@ func (suite *ConnectionTestSuite) TestQuery() { } func (suite *ConnectionTestSuite) TestExecContextNamedParametersNotSupported() { - rows, err := suite.createOpenConnection().ExecContext(context.Background(), "query", []driver.NamedValue{{Name: "arg", Ordinal: 1, Value: "value"}}) + rows, err := suite.createOpenConnection().ExecContext(suite.ctx, "query", []driver.NamedValue{{Name: "arg", Ordinal: 1, Value: "value"}}) suite.EqualError(err, "E-EGOD-7: named parameters not supported") suite.Nil(rows) } @@ -118,7 +120,7 @@ func (suite *ConnectionTestSuite) TestExecContext() { types.SqlQueryResponseRowCount{ResultType: "resultType", RowCount: 42}) suite.websocketMock.SimulateOKResponse(types.ClosePreparedStatementCommand{Command: types.Command{Command: "closePreparedStatement"}, StatementHandle: 0, Attributes: types.Attributes{}}, nil) - rows, err := suite.createOpenConnection().ExecContext(context.Background(), "query", []driver.NamedValue{{Ordinal: 1, Value: "value"}}) + rows, err := suite.createOpenConnection().ExecContext(suite.ctx, "query", []driver.NamedValue{{Ordinal: 1, Value: "value"}}) suite.NoError(err) rowsAffected, err := rows.RowsAffected() suite.NoError(err) @@ -153,7 +155,7 @@ func (suite *ConnectionTestSuite) TestExec() { func (suite *ConnectionTestSuite) TestPrepareContextFailsClosed() { conn := suite.createOpenConnection() conn.IsClosed = true - stmt, err := conn.PrepareContext(context.Background(), "query") + stmt, err := conn.PrepareContext(suite.ctx, "query") suite.EqualError(err, driver.ErrBadConn.Error()) suite.Nil(stmt) } @@ -166,7 +168,7 @@ func (suite *ConnectionTestSuite) TestPrepareContextPreparedStatementFails() { Attributes: types.Attributes{}, }, mockException) - stmt, err := suite.createOpenConnection().PrepareContext(context.Background(), "query") + stmt, err := suite.createOpenConnection().PrepareContext(suite.ctx, "query") suite.EqualError(err, mockExceptionError(mockException)) suite.Nil(stmt) } @@ -180,7 +182,7 @@ func (suite *ConnectionTestSuite) TestPrepareContextSuccess() { }, types.CreatePreparedStatementResponse{ ParameterData: types.ParameterData{Columns: []types.SqlQueryColumn{{Name: "col", DataType: types.SqlQueryColumnType{Type: "type"}}}}}) - stmt, err := suite.createOpenConnection().PrepareContext(context.Background(), "query") + stmt, err := suite.createOpenConnection().PrepareContext(suite.ctx, "query") suite.NoError(err) suite.NotNil(stmt) } @@ -248,7 +250,7 @@ func (suite *ConnectionTestSuite) TestBeginFailsWithAutocommitEnabled() { func (suite *ConnectionTestSuite) TestQueryFailsConnectionClosed() { conn := suite.createOpenConnection() conn.IsClosed = true - rows, err := conn.query(context.Background(), "query", nil) + rows, err := conn.query(suite.ctx, "query", nil) suite.EqualError(err, driver.ErrBadConn.Error()) suite.Nil(rows) } @@ -257,7 +259,7 @@ func (suite *ConnectionTestSuite) TestQueryNoArgs() { suite.websocketMock.SimulateSQLQueriesResponse( types.SqlCommand{Command: types.Command{Command: "execute"}, SQLText: "query", Attributes: types.Attributes{}}, types.SqlQueryResponseResultSet{ResultType: "resultType", ResultSet: types.SqlQueryResponseResultSetData{}}) - rows, err := suite.createOpenConnection().query(context.Background(), "query", []driver.Value{}) + rows, err := suite.createOpenConnection().query(suite.ctx, "query", []driver.Value{}) suite.NoError(err) suite.NotNil(rows) } @@ -266,7 +268,7 @@ func (suite *ConnectionTestSuite) TestQueryNoArgsFails() { suite.websocketMock.SimulateErrorResponse( types.SqlCommand{Command: types.Command{Command: "execute"}, SQLText: "query", Attributes: types.Attributes{}}, mockException) - rows, err := suite.createOpenConnection().query(context.Background(), "query", []driver.Value{}) + rows, err := suite.createOpenConnection().query(suite.ctx, "query", []driver.Value{}) suite.EqualError(err, mockExceptionError(mockException)) suite.Nil(rows) } @@ -289,7 +291,7 @@ func (suite *ConnectionTestSuite) TestQueryWithArgs() { types.SqlQueryResponseResultSet{ResultType: "resultType", ResultSet: types.SqlQueryResponseResultSetData{}}) suite.websocketMock.SimulateOKResponse(types.ClosePreparedStatementCommand{Command: types.Command{Command: "closePreparedStatement"}, StatementHandle: 0, Attributes: types.Attributes{}}, nil) - rows, err := suite.createOpenConnection().query(context.Background(), "query", []driver.Value{"value"}) + rows, err := suite.createOpenConnection().query(suite.ctx, "query", []driver.Value{"value"}) suite.NoError(err) suite.NotNil(rows) } @@ -303,7 +305,7 @@ func (suite *ConnectionTestSuite) TestQueryWithArgsFailsInPrepare() { }, mockException) - rows, err := suite.createOpenConnection().query(context.Background(), "query", []driver.Value{"value"}) + rows, err := suite.createOpenConnection().query(suite.ctx, "query", []driver.Value{"value"}) suite.EqualError(err, mockExceptionError(mockException)) suite.Nil(rows) } @@ -325,7 +327,7 @@ func (suite *ConnectionTestSuite) TestQueryWithArgsFailsInExecute() { }, mockException) - rows, err := suite.createOpenConnection().query(context.Background(), "query", []driver.Value{"value"}) + rows, err := suite.createOpenConnection().query(suite.ctx, "query", []driver.Value{"value"}) suite.EqualError(err, mockExceptionError(mockException)) suite.Nil(rows) } @@ -334,7 +336,7 @@ func (suite *ConnectionTestSuite) TestPasswordLoginFailsInitialRequest() { suite.websocketMock.SimulateErrorResponse(types.LoginCommand{Command: types.Command{Command: "login"}, ProtocolVersion: 42}, mockException) conn := suite.createOpenConnection() - err := conn.Login(context.Background()) + err := conn.Login(suite.ctx) suite.EqualError(err, mockExceptionError(mockException)) } @@ -342,7 +344,7 @@ func (suite *ConnectionTestSuite) TestPasswordLoginFailsEncryptingPasswordReques suite.websocketMock.SimulateOKResponse(types.LoginCommand{Command: types.Command{Command: "login"}, ProtocolVersion: 42}, types.PublicKeyResponse{PublicKeyPem: "", PublicKeyModulus: "", PublicKeyExponent: ""}) conn := suite.createOpenConnection() - err := conn.Login(context.Background()) + err := conn.Login(suite.ctx) suite.EqualError(err, driver.ErrBadConn.Error()) } @@ -352,7 +354,7 @@ func (suite *ConnectionTestSuite) TestPasswordLoginSuccess() { conn.IsClosed = true suite.True(conn.IsClosed) - err := conn.Login(context.Background()) + err := conn.Login(suite.ctx) suite.False(conn.IsClosed) suite.NoError(err) } @@ -364,7 +366,7 @@ func (suite *ConnectionTestSuite) TestAccessTokenLoginSuccess() { conn.Config.AccessToken = "accessToken" suite.True(conn.IsClosed) - err := conn.Login(context.Background()) + err := conn.Login(suite.ctx) suite.False(conn.IsClosed) suite.NoError(err) } @@ -376,7 +378,7 @@ func (suite *ConnectionTestSuite) TestAccessTokenLoginPrepareFails() { conn.Config.AccessToken = "accessToken" suite.True(conn.IsClosed) - err := conn.Login(context.Background()) + err := conn.Login(suite.ctx) suite.True(conn.IsClosed) suite.EqualError(err, "access token login failed: E-EGOD-11: execution failed with SQL error code 'mock sql code' and message 'mock error'") } @@ -388,7 +390,7 @@ func (suite *ConnectionTestSuite) TestRefreshTokenLoginSuccess() { conn.Config.RefreshToken = "refreshToken" suite.True(conn.IsClosed) - err := conn.Login(context.Background()) + err := conn.Login(suite.ctx) suite.False(conn.IsClosed) suite.NoError(err) } @@ -400,7 +402,7 @@ func (suite *ConnectionTestSuite) TestRefreshTokenLoginPrepareFails() { conn.Config.RefreshToken = "refreshToken" suite.True(conn.IsClosed) - err := conn.Login(context.Background()) + err := conn.Login(suite.ctx) suite.True(conn.IsClosed) suite.EqualError(err, "refresh token login failed: E-EGOD-11: execution failed with SQL error code 'mock sql code' and message 'mock error'") } @@ -410,7 +412,7 @@ func (suite *ConnectionTestSuite) TestLoginRestoresCompressionToTrue() { conn := suite.createOpenConnection() conn.Config.Compression = true - err := conn.Login(context.Background()) + err := conn.Login(suite.ctx) suite.True(conn.Config.Compression) suite.NoError(err) } @@ -419,7 +421,7 @@ func (suite *ConnectionTestSuite) TestLoginRestoresCompressionToFalse() { conn := suite.createOpenConnection() conn.Config.Compression = false - err := conn.Login(context.Background()) + err := conn.Login(suite.ctx) suite.False(conn.Config.Compression) suite.NoError(err) } @@ -429,7 +431,7 @@ func (suite *ConnectionTestSuite) TestLoginFails() { conn := suite.createOpenConnection() conn.IsClosed = false - err := conn.Login(context.Background()) + err := conn.Login(suite.ctx) suite.True(conn.IsClosed) suite.EqualError(err, "failed to login: E-EGOD-11: execution failed with SQL error code 'mock sql code' and message 'mock error'") } @@ -439,7 +441,7 @@ func (suite *ConnectionTestSuite) TestLoginFailureRestoresCompressionToTrue() { conn := suite.createOpenConnection() conn.Config.Compression = true - conn.Login(context.Background()) + conn.Login(suite.ctx) suite.True(conn.Config.Compression) } @@ -448,7 +450,7 @@ func (suite *ConnectionTestSuite) TestLoginFailureRestoresCompressionToFalse() { conn := suite.createOpenConnection() conn.Config.Compression = false - conn.Login(context.Background()) + conn.Login(suite.ctx) suite.False(conn.Config.Compression) } @@ -494,7 +496,7 @@ uYIhswioGpmyPXr/wqz1NFkt5wMzm6sU3lFfCjD5SxU6arQ1zVY3AgMBAAE= func (suite *ConnectionTestSuite) createOpenConnection() *Connection { conn := &Connection{ Config: &config.Config{Host: "invalid", Port: 12345, User: "user", Password: "password", ApiVersion: 42}, - Ctx: context.Background(), + Ctx: suite.ctx, IsClosed: false, websocket: suite.websocketMock, } diff --git a/pkg/connection/result_set_test.go b/pkg/connection/result_set_test.go index 6e6ffd0..efce05a 100644 --- a/pkg/connection/result_set_test.go +++ b/pkg/connection/result_set_test.go @@ -313,13 +313,14 @@ func (suite *ResultSetTestSuite) TestConvertValue() { } func (suite *ResultSetTestSuite) createResultSet() QueryResults { + ctx := context.Background() return QueryResults{ - ctx: context.Background(), + ctx: ctx, data: &types.SqlQueryResponseResultSetData{ ResultSetHandle: 1, NumRows: 2, NumRowsInMessage: 2, Columns: []types.SqlQueryColumn{{}, {}}, }, con: &Connection{ - websocket: suite.websocketMock, Config: &config.Config{}, Ctx: context.Background(), IsClosed: false, + websocket: suite.websocketMock, Config: &config.Config{}, Ctx: ctx, IsClosed: false, }, fetchedRows: 0, totalRowPointer: 0, diff --git a/pkg/connection/transaction.go b/pkg/connection/transaction.go index a4ca1bd..2a01c07 100644 --- a/pkg/connection/transaction.go +++ b/pkg/connection/transaction.go @@ -9,11 +9,12 @@ import ( ) type Transaction struct { + ctx context.Context connection *Connection } -func NewTransaction(connection *Connection) *Transaction { - return &Transaction{connection: connection} +func NewTransaction(ctx context.Context, connection *Connection) *Transaction { + return &Transaction{ctx: ctx, connection: connection} } func (t *Transaction) Commit() error { @@ -24,7 +25,7 @@ func (t *Transaction) Commit() error { logger.ErrorLogger.Print(errors.ErrClosed) return driver.ErrBadConn } - _, err := t.connection.SimpleExec(context.Background(), "COMMIT") + _, err := t.connection.SimpleExec(t.ctx, "COMMIT") t.connection = nil return err } @@ -37,7 +38,7 @@ func (t *Transaction) Rollback() error { logger.ErrorLogger.Print(errors.ErrClosed) return driver.ErrBadConn } - _, err := t.connection.SimpleExec(context.Background(), "ROLLBACK") + _, err := t.connection.SimpleExec(t.ctx, "ROLLBACK") t.connection = nil return err } diff --git a/pkg/connection/transaction_test.go b/pkg/connection/transaction_test.go index 357dcc7..ced63d3 100644 --- a/pkg/connection/transaction_test.go +++ b/pkg/connection/transaction_test.go @@ -1,6 +1,7 @@ package connection import ( + "context" "database/sql/driver" "testing" @@ -16,23 +17,30 @@ func TestTransactionSuite(t *testing.T) { } func (suite *TransactionTestSuite) TestCommitWithEmptyConnection() { - transaction := Transaction{nil} + transaction := suite.createTransaction() + transaction.connection = nil suite.EqualError(transaction.Commit(), "E-EGOD-1: invalid connection") } func (suite *TransactionTestSuite) TestRollbackWithEmptyConnection() { - transaction := Transaction{nil} + transaction := suite.createTransaction() + transaction.connection = nil suite.EqualError(transaction.Rollback(), "E-EGOD-1: invalid connection") } func (suite *TransactionTestSuite) TestCommitWithClosedConnection() { - connection := Connection{IsClosed: true} - transaction := Transaction{connection: &connection} + transaction := suite.createTransaction() + transaction.connection.IsClosed = true suite.EqualError(transaction.Commit(), driver.ErrBadConn.Error()) } func (suite *TransactionTestSuite) TestRollbackWithClosedConnection() { - connection := Connection{IsClosed: true} - transaction := Transaction{connection: &connection} + transaction := suite.createTransaction() + transaction.connection.IsClosed = true suite.EqualError(transaction.Rollback(), driver.ErrBadConn.Error()) } + +func (suite *TransactionTestSuite) createTransaction() Transaction { + connection := Connection{IsClosed: true} + return Transaction{ctx: context.Background(), connection: &connection} +} diff --git a/pkg/connection/websocket_test.go b/pkg/connection/websocket_test.go index c7ebe88..f395599 100644 --- a/pkg/connection/websocket_test.go +++ b/pkg/connection/websocket_test.go @@ -16,6 +16,7 @@ import ( type WebsocketTestSuite struct { suite.Suite websocketMock *wsconn.WebsocketConnectionMock + ctx context.Context } func TestWebsocketSuite(t *testing.T) { @@ -24,6 +25,7 @@ func TestWebsocketSuite(t *testing.T) { func (suite *WebsocketTestSuite) SetupTest() { suite.websocketMock = wsconn.CreateWebsocketConnectionMock() + suite.ctx = context.Background() } func (suite *WebsocketTestSuite) TestSendSuccess() { @@ -31,7 +33,7 @@ func (suite *WebsocketTestSuite) TestSendSuccess() { response := &types.PublicKeyResponse{} suite.websocketMock.SimulateOKResponse(request, types.PublicKeyResponse{PublicKeyPem: "pem"}) - err := suite.createOpenConnection().Send(context.Background(), request, response) + err := suite.createOpenConnection().Send(suite.ctx, request, response) suite.NoError(err) suite.Equal("pem", response.PublicKeyPem) } @@ -44,7 +46,7 @@ func (suite *WebsocketTestSuite) TestSendSuccessWithCompression() { conn := suite.createOpenConnection() conn.Config.Compression = true - err := conn.Send(context.Background(), request, response) + err := conn.Send(suite.ctx, request, response) suite.NoError(err) suite.Equal("pem", response.PublicKeyPem) } @@ -57,7 +59,7 @@ func (suite *WebsocketTestSuite) TestSendWithCompressionFailsDuringUncompress() conn := suite.createOpenConnection() conn.Config.Compression = true - err := conn.Send(context.Background(), request, response) + err := conn.Send(suite.ctx, request, response) suite.EqualError(err, "W-EGOD-18: could not decode compressed data: 'zlib: invalid header'") suite.True(errors.Is(err, driver.ErrBadConn)) } @@ -66,7 +68,7 @@ func (suite *WebsocketTestSuite) TestSendSuccessNoResponse() { request := types.LoginCommand{Command: types.Command{Command: "login"}} suite.websocketMock.SimulateOKResponse(request, types.PublicKeyResponse{PublicKeyPem: "pem"}) - err := suite.createOpenConnection().Send(context.Background(), request, nil) + err := suite.createOpenConnection().Send(suite.ctx, request, nil) suite.NoError(err) } @@ -76,7 +78,7 @@ func (suite *WebsocketTestSuite) TestSendFailsNotConnected() { conn := suite.createOpenConnection() conn.websocket = nil - err := conn.Send(context.Background(), request, response) + err := conn.Send(suite.ctx, request, response) suite.EqualError(err, `E-EGOD-29: could not send request '{"command":"login","protocolVersion":0,"attributes":{}}': not connected to server`) } @@ -84,7 +86,7 @@ func (suite *WebsocketTestSuite) TestSendFailsAtWriteMessage() { request := types.LoginCommand{Command: types.Command{Command: "login"}} response := &types.PublicKeyResponse{} suite.websocketMock.OnWriteAnyMessage(fmt.Errorf("mock error")) - err := suite.createOpenConnection().Send(context.Background(), request, response) + err := suite.createOpenConnection().Send(suite.ctx, request, response) suite.EqualError(err, "W-EGOD-16: could not send request: 'mock error'") suite.True(errors.Is(err, driver.ErrBadConn)) } @@ -95,7 +97,7 @@ func (suite *WebsocketTestSuite) TestSendFailsAtReadMessage() { suite.websocketMock.OnWriteAnyMessage(nil) suite.websocketMock.OnReadTextMessage(nil, fmt.Errorf("mock error")) - err := suite.createOpenConnection().Send(context.Background(), request, response) + err := suite.createOpenConnection().Send(suite.ctx, request, response) suite.EqualError(err, "W-EGOD-17: could not receive data: 'mock error'") suite.True(errors.Is(err, driver.ErrBadConn)) } @@ -106,7 +108,7 @@ func (suite *WebsocketTestSuite) TestSendFailsAtDecodingResponse() { suite.websocketMock.OnWriteAnyMessage(nil) suite.websocketMock.OnReadTextMessage([]byte("invalid json"), nil) - err := suite.createOpenConnection().Send(context.Background(), request, response) + err := suite.createOpenConnection().Send(suite.ctx, request, response) suite.EqualError(err, "W-EGOD-19: could not decode json data 'invalid json': 'invalid character 'i' looking for beginning of value'") suite.True(errors.Is(err, driver.ErrBadConn)) } @@ -116,7 +118,7 @@ func (suite *WebsocketTestSuite) TestSendFailsAtNonOKStatusException() { response := &types.PublicKeyResponse{} suite.websocketMock.SimulateErrorResponse(request, mockException) - err := suite.createOpenConnection().Send(context.Background(), request, response) + err := suite.createOpenConnection().Send(suite.ctx, request, response) suite.EqualError(err, "E-EGOD-11: execution failed with SQL error code 'mock sql code' and message 'mock error'") } @@ -126,7 +128,7 @@ func (suite *WebsocketTestSuite) TestSendFailsAtNonOKStatusMissingException() { suite.websocketMock.OnWriteTextMessage(wsconn.JsonMarshall(request), nil) suite.websocketMock.OnReadTextMessage([]byte(`{"status": "notok"}`), nil) - err := suite.createOpenConnection().Send(context.Background(), request, response) + err := suite.createOpenConnection().Send(suite.ctx, request, response) suite.EqualError(err, `result status is not 'ok': "notok", expected exception in response &{notok [] }`) } @@ -136,7 +138,7 @@ func (suite *WebsocketTestSuite) TestSendFailsAtParsingResponseData() { suite.websocketMock.OnWriteTextMessage(wsconn.JsonMarshall(request), nil) suite.websocketMock.OnReadTextMessage([]byte(`{"status": "ok", "responseData": "invalid"}`), nil) - err := suite.createOpenConnection().Send(context.Background(), request, response) + err := suite.createOpenConnection().Send(suite.ctx, request, response) suite.EqualError(err, `failed to parse response data "\"invalid\"": json: cannot unmarshal string into Go value of type types.PublicKeyResponse`) } @@ -170,7 +172,7 @@ func (suite *WebsocketTestSuite) TestCreateURL() { func (suite *WebsocketTestSuite) createOpenConnection() *Connection { conn := &Connection{ Config: &config.Config{Host: "invalid", Port: 12345, User: "user", Password: "password", ApiVersion: 42}, - Ctx: context.Background(), + Ctx: suite.ctx, IsClosed: false, websocket: suite.websocketMock, } diff --git a/pkg/connection/wsconn/wsconn_i_test.go b/pkg/connection/wsconn/wsconn_i_test.go index 68c3c04..0a1309d 100644 --- a/pkg/connection/wsconn/wsconn_i_test.go +++ b/pkg/connection/wsconn/wsconn_i_test.go @@ -15,6 +15,7 @@ import ( type WebsocketITestSuite struct { suite.Suite exasol *integrationTesting.DbTestSetup + ctx context.Context } func TestIntegrationWebsocketSuite(t *testing.T) { @@ -29,21 +30,25 @@ func (suite *WebsocketITestSuite) TearDownSuite() { suite.exasol.StopDb() } +func (suite *WebsocketITestSuite) SetupTest() { + suite.ctx = context.Background() +} + func (suite *WebsocketITestSuite) TestCreateConnectionSuccess() { - conn, err := wsconn.CreateConnection(context.Background(), true, "", suite.exasol.GetUrl()) + conn, err := wsconn.CreateConnection(suite.ctx, true, "", suite.exasol.GetUrl()) suite.NoError(err) suite.NotNil(conn) conn.Close() } func (suite *WebsocketITestSuite) TestCreateConnectionFailed() { - conn, err := wsconn.CreateConnection(context.Background(), true, "", url.URL{Scheme: "wss", Host: "invalid:12345"}) + conn, err := wsconn.CreateConnection(suite.ctx, true, "", url.URL{Scheme: "wss", Host: "invalid:12345"}) suite.ErrorContains(err, `failed to connect to URL "wss://invalid:12345": dial tcp`) suite.Nil(conn) } func (suite *WebsocketITestSuite) TestCreateConnectionInvalidCertificate() { - conn, err := wsconn.CreateConnection(context.Background(), false, "invalid", suite.exasol.GetUrl()) + conn, err := wsconn.CreateConnection(suite.ctx, false, "invalid", suite.exasol.GetUrl()) suite.ErrorContains(err, fmt.Sprintf(`failed to connect to URL "wss://%s:%d": tls: failed to verify certificate`, suite.exasol.ConnectionInfo.Host, suite.exasol.ConnectionInfo.Port)) suite.Nil(conn) } @@ -64,7 +69,7 @@ func (suite *WebsocketITestSuite) TestRead() { } func (suite *WebsocketITestSuite) createConnection() wsconn.WebsocketConnection { - conn, err := wsconn.CreateConnection(context.Background(), true, "", suite.exasol.GetUrl()) + conn, err := wsconn.CreateConnection(suite.ctx, true, "", suite.exasol.GetUrl()) if err != nil { suite.FailNowf("connection failed: %v", err.Error()) }