Skip to content

Commit

Permalink
test(sourcerer database functions): Add tests for db calls
Browse files Browse the repository at this point in the history
Figured out how to mock the db and added some initial test coverage for sourcerer.
  • Loading branch information
gwenwindflower committed Apr 18, 2024
1 parent 29b8fe8 commit 8b399ec
Show file tree
Hide file tree
Showing 15 changed files with 370 additions and 111 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,9 @@ jobs:
uses: actions/[email protected]
with:
go-version: ">=1.22.1"
- name: Run tests
- name: Run main tests
run: go test
- name: Run `sourcerer` tests
run: |
cd sourcerer
go test
66 changes: 33 additions & 33 deletions forms.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ type FormResponse struct {

func Forms() (formResponse FormResponse) {
formResponse = FormResponse{}
intro_form := huh.NewForm(
introForm := huh.NewForm(
huh.NewGroup(
huh.NewNote().
Title("🏁 Welcome to tbd! 🏎️✨").
Expand Down Expand Up @@ -88,19 +88,19 @@ You'll need:
Placeholder("stg"),
),
)
project_name_form := huh.NewForm(
projectNameForm := huh.NewForm(
huh.NewGroup(huh.NewInput().
Title("What is the name of your dbt project?").
Value(&formResponse.ProjectName).
Placeholder("gondor_patrol_analytics"),
))
profile_create_form := huh.NewForm(
profileCreateForm := huh.NewForm(
huh.NewGroup(
huh.NewConfirm().Affirmative("Yes, pls").Negative("No, thx").
Title("Would you like to generate a profiles.yml file from the info you provide next?").
Value(&formResponse.CreateProfile),
))
dbt_form := huh.NewForm(
dbtForm := huh.NewForm(
huh.NewGroup(
huh.NewInput().
Title("What is the dbt profile name you'd like to use?").
Expand All @@ -115,7 +115,7 @@ You'll need:
Value(&formResponse.Schema),
),
)
warehouse_form := huh.NewForm(
warehouseForm := huh.NewForm(
huh.NewGroup(
huh.NewSelect[string]().
Title("Choose your warehouse.").
Expand All @@ -127,7 +127,7 @@ You'll need:
Value(&formResponse.Warehouse),
),
)
snowflake_form := huh.NewForm(
snowflakeForm := huh.NewForm(
huh.NewGroup(
huh.NewInput().
Title("What is your username?").
Expand All @@ -146,15 +146,15 @@ You'll need:
Value(&formResponse.Database).Placeholder("gondor"),
),
)
bigquery_form := huh.NewForm(
bigqueryForm := huh.NewForm(
huh.NewGroup(
huh.NewInput().Title("What is your GCP project's id?").
Value(&formResponse.Project).Placeholder("legolas_inc"),
huh.NewInput().Title("What is the dataset you want to generate?").
Value(&formResponse.Dataset).Placeholder("mirkwood"),
),
)
duckdb_form := huh.NewForm(
duckdbForm := huh.NewForm(
huh.NewGroup(
huh.NewInput().Title(`What is the path to your DuckDB database?
Relative to pwd e.g. if db is in this dir -> cool_ducks.db`).
Expand All @@ -165,15 +165,15 @@ Relative to pwd e.g. if db is in this dir -> cool_ducks.db`).
Value(&formResponse.Schema).Placeholder("raw"),
),
)
llm_form := huh.NewForm(
llmForm := huh.NewForm(
huh.NewGroup(
huh.NewInput().
Title("What env var holds your Groq key?").
Placeholder("GROQ_API_KEY").
Value(&formResponse.GroqKeyEnvVar),
),
)
dir_form := huh.NewForm(
dirForm := huh.NewForm(
huh.NewGroup(
huh.NewNote().
Title("🚧🚨 Choose your build directory carefully! 🚨🚧").
Expand All @@ -188,81 +188,81 @@ tbd will _intentionally error out_.`),
Placeholder("build"),
),
)
confirm_form := huh.NewForm(
confirmForm := huh.NewForm(
huh.NewGroup(
huh.NewConfirm().Affirmative("Let's go!").Negative("Nevermind").
Title("🚦Are you ready to do this thing?🚦").
Value(&formResponse.Confirm),
),
)
intro_form.WithTheme(huh.ThemeCatppuccin())
profile_create_form.WithTheme(huh.ThemeCatppuccin())
project_name_form.WithTheme(huh.ThemeCatppuccin())
dbt_form.WithTheme(huh.ThemeCatppuccin())
warehouse_form.WithTheme(huh.ThemeCatppuccin())
snowflake_form.WithTheme(huh.ThemeCatppuccin())
bigquery_form.WithTheme(huh.ThemeCatppuccin())
duckdb_form.WithTheme(huh.ThemeCatppuccin())
llm_form.WithTheme(huh.ThemeCatppuccin())
dir_form.WithTheme(huh.ThemeCatppuccin())
confirm_form.WithTheme(huh.ThemeCatppuccin())
err := intro_form.Run()
introForm.WithTheme(huh.ThemeCatppuccin())
profileCreateForm.WithTheme(huh.ThemeCatppuccin())
projectNameForm.WithTheme(huh.ThemeCatppuccin())
dbtForm.WithTheme(huh.ThemeCatppuccin())
warehouseForm.WithTheme(huh.ThemeCatppuccin())
snowflakeForm.WithTheme(huh.ThemeCatppuccin())
bigqueryForm.WithTheme(huh.ThemeCatppuccin())
duckdbForm.WithTheme(huh.ThemeCatppuccin())
llmForm.WithTheme(huh.ThemeCatppuccin())
dirForm.WithTheme(huh.ThemeCatppuccin())
confirmForm.WithTheme(huh.ThemeCatppuccin())
err := introForm.Run()
if err != nil {
log.Fatalf("Error running intro form %v\n", err)
}
if formResponse.UseDbtProfile {
err = dbt_form.Run()
err = dbtForm.Run()
if err != nil {
log.Fatalf("Error running dbt form %v\n", err)
}
} else {
err = profile_create_form.Run()
err = profileCreateForm.Run()
if err != nil {
log.Fatalf("Error running profile create form %v\n", err)
}
if formResponse.ScaffoldProject {
err = project_name_form.Run()
err = projectNameForm.Run()
if err != nil {
log.Fatalf("Error running project name form %v\n", err)
}
}
err = warehouse_form.Run()
err = warehouseForm.Run()
if err != nil {
log.Fatalf("Error running warehouse form %v\n", err)
}
switch formResponse.Warehouse {
case "snowflake":
err = snowflake_form.Run()
err = snowflakeForm.Run()
if err != nil {
log.Fatalf("Error running snowflake form %v\n", err)
}
case "bigquery":
{
err = bigquery_form.Run()
err = bigqueryForm.Run()
if err != nil {
log.Fatalf("Error running bigquery form %v\n", err)
}
}
case "duckdb":
{
err = duckdb_form.Run()
err = duckdbForm.Run()
if err != nil {
log.Fatalf("Error running duckdb form %v\n", err)
}
}
}
}
if formResponse.GenerateDescriptions {
err = llm_form.Run()
err = llmForm.Run()
if err != nil {
log.Fatalf("Error running LLM features form %v\n", err)
}
}
err = dir_form.Run()
err = dirForm.Run()
if err != nil {
log.Fatalf("Error running build directory form %v\n", err)
}
err = confirm_form.Run()
err = confirmForm.Run()
if err != nil {
log.Fatalf("Error running confirmation form %v\n", err)
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0 // indirect
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c // indirect
github.com/apache/arrow/go/v14 v14.0.2 // indirect
github.com/apache/arrow/go/v15 v15.0.0 // indirect
Expand Down
3 changes: 3 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0/go.mod h1:2e8rMJtl2+
github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1 h1:BWe8a+f/t+7KY7zH2mqygeUD0t8hNFXe08p1Pb3/jKE=
github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1/go.mod h1:Vt9sXTKwMyGcOxSmLDMnGPgqsUg7m8pe215qMLrDXw4=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU=
github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk=
github.com/apache/arrow/go/v14 v14.0.2 h1:N8OkaJEOfI3mEZt07BIkvo4sC6XDbL+48MBPWO5IONw=
Expand Down Expand Up @@ -178,6 +180,7 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg=
github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
Expand Down
12 changes: 10 additions & 2 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,20 @@ func main() {

dbc, err := sourcerer.GetConn(cd)
if err != nil {
log.Fatalf("Error getting connection: %v\n", err)
log.Fatalf("Error getting database connection: %v\n", err)
}
ts, err := dbc.GetSources(ctx)
err = dbc.ConnectToDb(ctx)
if err != nil {
log.Fatalf("Error connecting to database: %v\n", err)
}
ts, err := dbc.GetSourceTables(ctx)
if err != nil {
log.Fatalf("Error getting sources: %v\n", err)
}
err = sourcerer.PutColumnsOnTables(ctx, ts, dbc)
if err != nil {
log.Fatalf("Error putting columns on tables: %v\n", err)
}

e.DbElapsed = time.Since(e.DbStart).Seconds()
// End of database interaction, start of processing
Expand Down
6 changes: 3 additions & 3 deletions sourcerer/connect_to_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
_ "github.com/snowflakedb/gosnowflake"
)

func (sfc *SfConn) ConnectToDB(ctx context.Context) (err error) {
func (sfc *SfConn) ConnectToDb(ctx context.Context) (err error) {
connStr := fmt.Sprintf(
"%s@%s/%s/%s?authenticator=externalbrowser",
sfc.Username,
Expand All @@ -31,7 +31,7 @@ func (sfc *SfConn) ConnectToDB(ctx context.Context) (err error) {
return err
}

func (bqc *BqConn) ConnectToDB(ctx context.Context) (err error) {
func (bqc *BqConn) ConnectToDb(ctx context.Context) (err error) {
_, bqc.Cancel = context.WithTimeout(ctx, 1*time.Minute)
defer bqc.Cancel()
bqc.Bq, err = bigquery.NewClient(ctx, bqc.Project)
Expand All @@ -41,7 +41,7 @@ func (bqc *BqConn) ConnectToDB(ctx context.Context) (err error) {
return err
}

func (dc *DuckConn) ConnectToDB(ctx context.Context) (err error) {
func (dc *DuckConn) ConnectToDb(ctx context.Context) (err error) {
_, dc.Cancel = context.WithTimeout(ctx, 1*time.Minute)
defer dc.Cancel()
if _, err := os.Stat(dc.Path); os.IsNotExist(err) {
Expand Down
39 changes: 39 additions & 0 deletions sourcerer/connect_to_db_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package sourcerer

import (
"testing"

"github.com/DATA-DOG/go-sqlmock"
"github.com/gwenwindflower/tbd/shared"
)

func TestConnectToDb(t *testing.T) {
cd := shared.ConnectionDetails{
ConnType: "snowflake",
Account: "dunedain.snowflakecomputing.com",
Username: "aragorn",
Database: "gondor",
Schema: "minas-tirith",
}
conn, err := GetConn(cd)
if err != nil {
t.Errorf("GetConn failed: %v", err)
}
SfConn, ok := conn.(*SfConn)
if !ok {
t.Errorf("conn not of type SfConn: %v", err)
}
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
SfConn.Db = db
defer SfConn.Db.Close()
mock.ExpectBegin()
if _, err := SfConn.Db.Begin(); err != nil {
t.Errorf("error '%s' was not expected, while pinging db", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
57 changes: 57 additions & 0 deletions sourcerer/get_columns_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package sourcerer

import (
"context"
"fmt"
"testing"

"github.com/DATA-DOG/go-sqlmock"
"github.com/gwenwindflower/tbd/shared"
)

func TestGetColumnsSnowflake(t *testing.T) {
t.SkipNow()
ctx := context.Background()
st := shared.SourceTable{
Name: "table1",
}
cd := shared.ConnectionDetails{
ConnType: "snowflake",
Account: "dunedain.snowflakecomputing.com",
Username: "aragorn",
Database: "gondor",
Schema: "minas-tirith",
}
conn, err := GetConn(cd)
if err != nil {
t.Errorf("GetConn failed: %v", err)
}
if conn == nil {
t.Errorf("GetConn failed: conn is nil")
}
SfConn, ok := conn.(*SfConn)
if !ok {
t.Errorf("GetConn failed: conn is not of type SfConn")
}
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
SfConn.Db = db
defer SfConn.Db.Close()
q := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%s' AND table_name = '%s'", SfConn.Schema, st.Name)
mock.ExpectQuery(q).WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type"}).AddRow("column1", "varchar").AddRow("column2", "varchar").AddRow("column3", "int"))
cols, err := SfConn.GetColumns(ctx, st)
if err != nil {
t.Errorf("GetColumns failed: %v", err)
}
if len(cols) != 1 {
t.Errorf("GetColumns failed: expected 1 column, got %d", len(cols))
}
if cols[0].Name != "column1" {
t.Errorf("GetColumns failed: expected column name %s, got %s", "column1", cols[0].Name)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("there were unfulfilled expectations: %s", err)
}
}
5 changes: 2 additions & 3 deletions sourcerer/get_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@ import (
)

type DbConn interface {
ConnectToDB(ctx context.Context) error
GetSources(ctx context.Context) (shared.SourceTables, error)
ConnectToDb(ctx context.Context) error
GetSourceTables(ctx context.Context) (shared.SourceTables, error)
GetColumns(ctx context.Context, t shared.SourceTable) ([]shared.Column, error)
PutColumnsOnTables(ctx context.Context, tables shared.SourceTables)
}

type SfConn struct {
Expand Down
Loading

0 comments on commit 8b399ec

Please sign in to comment.