diff --git a/conn.go b/conn.go index b131e4c..176ff6a 100644 --- a/conn.go +++ b/conn.go @@ -13,6 +13,7 @@ import ( type conn struct { athena athenaAPI + catalog string db string OutputLocation string @@ -57,11 +58,15 @@ func (c *conn) runQuery(ctx context.Context, query string) (driver.Rows, error) // startQuery starts an Athena query and returns its ID. func (c *conn) startQuery(ctx context.Context, query string) (string, error) { + queryCtx := &types.QueryExecutionContext{ + Database: aws.String(c.db), + } + if c.catalog != "" { + queryCtx.Catalog = aws.String(c.catalog) + } resp, err := c.athena.StartQueryExecution(ctx, &athena.StartQueryExecutionInput{ - QueryString: aws.String(query), - QueryExecutionContext: &types.QueryExecutionContext{ - Database: aws.String(c.db), - }, + QueryString: aws.String(query), + QueryExecutionContext: queryCtx, ResultConfiguration: &types.ResultConfiguration{ OutputLocation: aws.String(c.OutputLocation), }, diff --git a/db_test.go b/db_test.go index a784853..48ade7d 100644 --- a/db_test.go +++ b/db_test.go @@ -119,7 +119,7 @@ func TestOpen(t *testing.T) { awsConfig, err := config.LoadDefaultConfig(context.Background()) require.NoError(t, err, "LoadDefaultConfig") db, err := Open(DriverConfig{ - Config: &awsConfig, + Config: &awsConfig, Database: AthenaDatabase, OutputLocation: fmt.Sprintf("s3://%s/noop", S3Bucket), }) @@ -129,6 +129,28 @@ func TestOpen(t *testing.T) { require.NoError(t, err, "Query") } +func TestDriverWithDBCatalog(t *testing.T) { + ctx := context.Background() + catalogName := os.Getenv("ATHENA_CATALOG") + if catalogName == "" { + t.Skip("ATHENA_CATALOG not set") + } + + tableName := os.Getenv("ATHENA_TABLE") + if tableName == "" { + tableName = "catalog_test_table" + } + connStr := fmt.Sprintf("catalog=%s&db=%s&output_location=s3://%s/output", catalogName, AthenaDatabase, S3Bucket) + db, err := sql.Open("athena", connStr) + require.NoError(t, err, "Open") + defer db.Close() + + harness := &athenaHarness{t: t, db: db, table: tableName} + defer harness.teardown(ctx) + harness.mustExec(ctx, `CREATE TABLE %s ( value string )`, tableName) + harness.mustExec(ctx, `INSERT INTO %s VALUES ('foo')`, tableName) +} + type dummyRow struct { NullValue *struct{} `json:"nullValue"` SmallintType int `json:"smallintType"` diff --git a/driver.go b/driver.go index 03dceca..c311117 100644 --- a/driver.go +++ b/driver.go @@ -80,6 +80,7 @@ func (d *Driver) Open(connStr string) (driver.Conn, error) { return &conn{ athena: athena.NewFromConfig(*cfg.Config), db: cfg.Database, + catalog: cfg.Catalog, OutputLocation: cfg.OutputLocation, pollFrequency: cfg.PollFrequency, }, nil @@ -116,6 +117,7 @@ func Open(cfg DriverConfig) (*sql.DB, error) { type DriverConfig struct { Config *aws.Config Database string + Catalog string OutputLocation string PollFrequency time.Duration @@ -139,6 +141,7 @@ func configFromConnectionString(ctx context.Context, connStr string) (*DriverCon cfg.Config = &awsConfig cfg.Database = args.Get("db") + cfg.Catalog = args.Get("catalog") cfg.OutputLocation = args.Get("output_location") frequencyStr := args.Get("poll_frequency")