diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 43a45d4..50a640b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,5 +15,9 @@ jobs: uses: actions/setup-go@v5.0.0 with: go-version: ">=1.22.1" - - name: Run tests + - name: Run main tests run: go test + - name: Run `sourcerer` tests + run: | + cd sourcerer + go test diff --git a/forms.go b/forms.go index 6aad49b..cef8b00 100644 --- a/forms.go +++ b/forms.go @@ -31,7 +31,7 @@ type FormResponse struct { func Forms() (formResponse FormResponse) { formResponse = FormResponse{} - intro_form := huh.NewForm( + introForm := huh.NewForm( huh.NewGroup( huh.NewNote(). Title("🏁 Welcome to tbd! 🏎️✨"). @@ -88,19 +88,19 @@ You'll need: Placeholder("stg"), ), ) - project_name_form := huh.NewForm( + projectNameForm := huh.NewForm( huh.NewGroup(huh.NewInput(). Title("What is the name of your dbt project?"). Value(&formResponse.ProjectName). Placeholder("gondor_patrol_analytics"), )) - profile_create_form := huh.NewForm( + profileCreateForm := huh.NewForm( huh.NewGroup( huh.NewConfirm().Affirmative("Yes, pls").Negative("No, thx"). Title("Would you like to generate a profiles.yml file from the info you provide next?"). Value(&formResponse.CreateProfile), )) - dbt_form := huh.NewForm( + dbtForm := huh.NewForm( huh.NewGroup( huh.NewInput(). Title("What is the dbt profile name you'd like to use?"). @@ -115,7 +115,7 @@ You'll need: Value(&formResponse.Schema), ), ) - warehouse_form := huh.NewForm( + warehouseForm := huh.NewForm( huh.NewGroup( huh.NewSelect[string](). Title("Choose your warehouse."). @@ -127,7 +127,7 @@ You'll need: Value(&formResponse.Warehouse), ), ) - snowflake_form := huh.NewForm( + snowflakeForm := huh.NewForm( huh.NewGroup( huh.NewInput(). Title("What is your username?"). @@ -146,7 +146,7 @@ You'll need: Value(&formResponse.Database).Placeholder("gondor"), ), ) - bigquery_form := huh.NewForm( + bigqueryForm := huh.NewForm( huh.NewGroup( huh.NewInput().Title("What is your GCP project's id?"). Value(&formResponse.Project).Placeholder("legolas_inc"), @@ -154,7 +154,7 @@ You'll need: Value(&formResponse.Dataset).Placeholder("mirkwood"), ), ) - duckdb_form := huh.NewForm( + duckdbForm := huh.NewForm( huh.NewGroup( huh.NewInput().Title(`What is the path to your DuckDB database? Relative to pwd e.g. if db is in this dir -> cool_ducks.db`). @@ -165,7 +165,7 @@ Relative to pwd e.g. if db is in this dir -> cool_ducks.db`). Value(&formResponse.Schema).Placeholder("raw"), ), ) - llm_form := huh.NewForm( + llmForm := huh.NewForm( huh.NewGroup( huh.NewInput(). Title("What env var holds your Groq key?"). @@ -173,7 +173,7 @@ Relative to pwd e.g. if db is in this dir -> cool_ducks.db`). Value(&formResponse.GroqKeyEnvVar), ), ) - dir_form := huh.NewForm( + dirForm := huh.NewForm( huh.NewGroup( huh.NewNote(). Title("🚧🚨 Choose your build directory carefully! 🚨🚧"). @@ -188,64 +188,64 @@ tbd will _intentionally error out_.`), Placeholder("build"), ), ) - confirm_form := huh.NewForm( + confirmForm := huh.NewForm( huh.NewGroup( huh.NewConfirm().Affirmative("Let's go!").Negative("Nevermind"). Title("🚦Are you ready to do this thing?🚦"). Value(&formResponse.Confirm), ), ) - intro_form.WithTheme(huh.ThemeCatppuccin()) - profile_create_form.WithTheme(huh.ThemeCatppuccin()) - project_name_form.WithTheme(huh.ThemeCatppuccin()) - dbt_form.WithTheme(huh.ThemeCatppuccin()) - warehouse_form.WithTheme(huh.ThemeCatppuccin()) - snowflake_form.WithTheme(huh.ThemeCatppuccin()) - bigquery_form.WithTheme(huh.ThemeCatppuccin()) - duckdb_form.WithTheme(huh.ThemeCatppuccin()) - llm_form.WithTheme(huh.ThemeCatppuccin()) - dir_form.WithTheme(huh.ThemeCatppuccin()) - confirm_form.WithTheme(huh.ThemeCatppuccin()) - err := intro_form.Run() + introForm.WithTheme(huh.ThemeCatppuccin()) + profileCreateForm.WithTheme(huh.ThemeCatppuccin()) + projectNameForm.WithTheme(huh.ThemeCatppuccin()) + dbtForm.WithTheme(huh.ThemeCatppuccin()) + warehouseForm.WithTheme(huh.ThemeCatppuccin()) + snowflakeForm.WithTheme(huh.ThemeCatppuccin()) + bigqueryForm.WithTheme(huh.ThemeCatppuccin()) + duckdbForm.WithTheme(huh.ThemeCatppuccin()) + llmForm.WithTheme(huh.ThemeCatppuccin()) + dirForm.WithTheme(huh.ThemeCatppuccin()) + confirmForm.WithTheme(huh.ThemeCatppuccin()) + err := introForm.Run() if err != nil { log.Fatalf("Error running intro form %v\n", err) } if formResponse.UseDbtProfile { - err = dbt_form.Run() + err = dbtForm.Run() if err != nil { log.Fatalf("Error running dbt form %v\n", err) } } else { - err = profile_create_form.Run() + err = profileCreateForm.Run() if err != nil { log.Fatalf("Error running profile create form %v\n", err) } if formResponse.ScaffoldProject { - err = project_name_form.Run() + err = projectNameForm.Run() if err != nil { log.Fatalf("Error running project name form %v\n", err) } } - err = warehouse_form.Run() + err = warehouseForm.Run() if err != nil { log.Fatalf("Error running warehouse form %v\n", err) } switch formResponse.Warehouse { case "snowflake": - err = snowflake_form.Run() + err = snowflakeForm.Run() if err != nil { log.Fatalf("Error running snowflake form %v\n", err) } case "bigquery": { - err = bigquery_form.Run() + err = bigqueryForm.Run() if err != nil { log.Fatalf("Error running bigquery form %v\n", err) } } case "duckdb": { - err = duckdb_form.Run() + err = duckdbForm.Run() if err != nil { log.Fatalf("Error running duckdb form %v\n", err) } @@ -253,16 +253,16 @@ tbd will _intentionally error out_.`), } } if formResponse.GenerateDescriptions { - err = llm_form.Run() + err = llmForm.Run() if err != nil { log.Fatalf("Error running LLM features form %v\n", err) } } - err = dir_form.Run() + err = dirForm.Run() if err != nil { log.Fatalf("Error running build directory form %v\n", err) } - err = confirm_form.Run() + err = confirmForm.Run() if err != nil { log.Fatalf("Error running confirmation form %v\n", err) } diff --git a/go.mod b/go.mod index 604f365..254f313 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/azcore v1.4.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.2 // indirect github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0 // indirect + github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c // indirect github.com/apache/arrow/go/v14 v14.0.2 // indirect github.com/apache/arrow/go/v15 v15.0.0 // indirect diff --git a/go.sum b/go.sum index 5af07b1..0a542af 100644 --- a/go.sum +++ b/go.sum @@ -30,6 +30,8 @@ github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.0.0/go.mod h1:2e8rMJtl2+ github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1 h1:BWe8a+f/t+7KY7zH2mqygeUD0t8hNFXe08p1Pb3/jKE= github.com/AzureAD/microsoft-authentication-library-for-go v0.5.1/go.mod h1:Vt9sXTKwMyGcOxSmLDMnGPgqsUg7m8pe215qMLrDXw4= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= github.com/apache/arrow/go/v14 v14.0.2 h1:N8OkaJEOfI3mEZt07BIkvo4sC6XDbL+48MBPWO5IONw= @@ -178,6 +180,7 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= diff --git a/main.go b/main.go index 97aa806..83b46e3 100644 --- a/main.go +++ b/main.go @@ -36,12 +36,20 @@ func main() { dbc, err := sourcerer.GetConn(cd) if err != nil { - log.Fatalf("Error getting connection: %v\n", err) + log.Fatalf("Error getting database connection: %v\n", err) } - ts, err := dbc.GetSources(ctx) + err = dbc.ConnectToDb(ctx) + if err != nil { + log.Fatalf("Error connecting to database: %v\n", err) + } + ts, err := dbc.GetSourceTables(ctx) if err != nil { log.Fatalf("Error getting sources: %v\n", err) } + err = sourcerer.PutColumnsOnTables(ctx, ts, dbc) + if err != nil { + log.Fatalf("Error putting columns on tables: %v\n", err) + } e.DbElapsed = time.Since(e.DbStart).Seconds() // End of database interaction, start of processing diff --git a/sourcerer/connect_to_db.go b/sourcerer/connect_to_db.go index cf08eb9..710e187 100644 --- a/sourcerer/connect_to_db.go +++ b/sourcerer/connect_to_db.go @@ -13,7 +13,7 @@ import ( _ "github.com/snowflakedb/gosnowflake" ) -func (sfc *SfConn) ConnectToDB(ctx context.Context) (err error) { +func (sfc *SfConn) ConnectToDb(ctx context.Context) (err error) { connStr := fmt.Sprintf( "%s@%s/%s/%s?authenticator=externalbrowser", sfc.Username, @@ -31,7 +31,7 @@ func (sfc *SfConn) ConnectToDB(ctx context.Context) (err error) { return err } -func (bqc *BqConn) ConnectToDB(ctx context.Context) (err error) { +func (bqc *BqConn) ConnectToDb(ctx context.Context) (err error) { _, bqc.Cancel = context.WithTimeout(ctx, 1*time.Minute) defer bqc.Cancel() bqc.Bq, err = bigquery.NewClient(ctx, bqc.Project) @@ -41,7 +41,7 @@ func (bqc *BqConn) ConnectToDB(ctx context.Context) (err error) { return err } -func (dc *DuckConn) ConnectToDB(ctx context.Context) (err error) { +func (dc *DuckConn) ConnectToDb(ctx context.Context) (err error) { _, dc.Cancel = context.WithTimeout(ctx, 1*time.Minute) defer dc.Cancel() if _, err := os.Stat(dc.Path); os.IsNotExist(err) { diff --git a/sourcerer/connect_to_db_test.go b/sourcerer/connect_to_db_test.go new file mode 100644 index 0000000..32ff39e --- /dev/null +++ b/sourcerer/connect_to_db_test.go @@ -0,0 +1,39 @@ +package sourcerer + +import ( + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gwenwindflower/tbd/shared" +) + +func TestConnectToDb(t *testing.T) { + cd := shared.ConnectionDetails{ + ConnType: "snowflake", + Account: "dunedain.snowflakecomputing.com", + Username: "aragorn", + Database: "gondor", + Schema: "minas-tirith", + } + conn, err := GetConn(cd) + if err != nil { + t.Errorf("GetConn failed: %v", err) + } + SfConn, ok := conn.(*SfConn) + if !ok { + t.Errorf("conn not of type SfConn: %v", err) + } + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + SfConn.Db = db + defer SfConn.Db.Close() + mock.ExpectBegin() + if _, err := SfConn.Db.Begin(); err != nil { + t.Errorf("error '%s' was not expected, while pinging db", err) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/sourcerer/get_columns_test.go b/sourcerer/get_columns_test.go new file mode 100644 index 0000000..ae9e446 --- /dev/null +++ b/sourcerer/get_columns_test.go @@ -0,0 +1,57 @@ +package sourcerer + +import ( + "context" + "fmt" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gwenwindflower/tbd/shared" +) + +func TestGetColumnsSnowflake(t *testing.T) { + t.SkipNow() + ctx := context.Background() + st := shared.SourceTable{ + Name: "table1", + } + cd := shared.ConnectionDetails{ + ConnType: "snowflake", + Account: "dunedain.snowflakecomputing.com", + Username: "aragorn", + Database: "gondor", + Schema: "minas-tirith", + } + conn, err := GetConn(cd) + if err != nil { + t.Errorf("GetConn failed: %v", err) + } + if conn == nil { + t.Errorf("GetConn failed: conn is nil") + } + SfConn, ok := conn.(*SfConn) + if !ok { + t.Errorf("GetConn failed: conn is not of type SfConn") + } + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + SfConn.Db = db + defer SfConn.Db.Close() + q := fmt.Sprintf("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = '%s' AND table_name = '%s'", SfConn.Schema, st.Name) + mock.ExpectQuery(q).WillReturnRows(sqlmock.NewRows([]string{"column_name", "data_type"}).AddRow("column1", "varchar").AddRow("column2", "varchar").AddRow("column3", "int")) + cols, err := SfConn.GetColumns(ctx, st) + if err != nil { + t.Errorf("GetColumns failed: %v", err) + } + if len(cols) != 1 { + t.Errorf("GetColumns failed: expected 1 column, got %d", len(cols)) + } + if cols[0].Name != "column1" { + t.Errorf("GetColumns failed: expected column name %s, got %s", "column1", cols[0].Name) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/sourcerer/get_conn.go b/sourcerer/get_conn.go index 162532f..3d22d01 100644 --- a/sourcerer/get_conn.go +++ b/sourcerer/get_conn.go @@ -12,10 +12,9 @@ import ( ) type DbConn interface { - ConnectToDB(ctx context.Context) error - GetSources(ctx context.Context) (shared.SourceTables, error) + ConnectToDb(ctx context.Context) error + GetSourceTables(ctx context.Context) (shared.SourceTables, error) GetColumns(ctx context.Context, t shared.SourceTable) ([]shared.Column, error) - PutColumnsOnTables(ctx context.Context, tables shared.SourceTables) } type SfConn struct { diff --git a/sourcerer/get_conn_test.go b/sourcerer/get_conn_test.go new file mode 100644 index 0000000..d2e9aec --- /dev/null +++ b/sourcerer/get_conn_test.go @@ -0,0 +1,77 @@ +package sourcerer + +import ( + "testing" + + "github.com/gwenwindflower/tbd/shared" +) + +func TestGetConnSnowflake(t *testing.T) { + t.SkipNow() + cd := shared.ConnectionDetails{ + ConnType: "snowflake", + Account: "dunedain.snowflakecomputing.com", + Username: "aragorn", + Database: "gondor", + Schema: "minas-tirith", + } + conn, err := GetConn(cd) + if err != nil { + t.Errorf("GetConn failed: %v", err) + } + if conn == nil { + t.Errorf("GetConn failed: conn is nil") + } + SfConn, ok := conn.(*SfConn) + if !ok { + t.Errorf("GetConn failed: conn is not of type SfConn") + } + if SfConn.Account != "DUNEDAIN.SNOWFLAKECOMPUTING.COM" { + t.Errorf("GetConn failed: Account is not correct") + } +} + +func TestGetConnBigQuery(t *testing.T) { + cd := shared.ConnectionDetails{ + ConnType: "bigquery", + Project: "mirkwood", + Dataset: "hall_of_thranduil", + } + conn, err := GetConn(cd) + if err != nil { + t.Errorf("GetConn failed: %v", err) + } + if conn == nil { + t.Errorf("GetConn failed: conn is nil") + } + BqConn, ok := conn.(*BqConn) + if !ok { + t.Errorf("GetConn failed: conn is not of type BqConn") + } + if BqConn.Dataset != "hall_of_thranduil" { + t.Errorf("GetConn failed: Account is not correct") + } +} + +func TestGetConnDuckDB(t *testing.T) { + cd := shared.ConnectionDetails{ + ConnType: "duckdb", + Path: "/path/to/duckdb.db", + Database: "lothlorien", + Schema: "mallorn_trees", + } + conn, err := GetConn(cd) + if err != nil { + t.Errorf("GetConn failed: %v", err) + } + if conn == nil { + t.Errorf("GetConn failed: conn is nil") + } + DuckConn, ok := conn.(*DuckConn) + if !ok { + t.Errorf("GetConn failed: conn is not of type DuckConn") + } + if DuckConn.Path != "/path/to/duckdb.db" { + t.Errorf("GetConn failed: Account is not correct") + } +} diff --git a/sourcerer/get_sources.go b/sourcerer/get_sources_tables.go similarity index 66% rename from sourcerer/get_sources.go rename to sourcerer/get_sources_tables.go index cc5caa8..c9d7e8b 100644 --- a/sourcerer/get_sources.go +++ b/sourcerer/get_sources_tables.go @@ -10,15 +10,10 @@ import ( "google.golang.org/api/iterator" ) -func (sfc *SfConn) GetSources(ctx context.Context) (shared.SourceTables, error) { +func (sfc *SfConn) GetSourceTables(ctx context.Context) (shared.SourceTables, error) { ts := shared.SourceTables{} - - err := sfc.ConnectToDB(ctx) defer sfc.Cancel() - if err != nil { - log.Fatalf("Couldn't connect to database: %v\n", err) - } - rows, err := sfc.Db.QueryContext(ctx, fmt.Sprintf("SELECT table_name FROM information_schema.tables where table_schema = '%s'", sfc.Schema)) + rows, err := sfc.Db.QueryContext(ctx, fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%s'", sfc.Schema)) if err != nil { log.Fatalf("Error fetching tables: %v\n", err) } @@ -30,18 +25,12 @@ func (sfc *SfConn) GetSources(ctx context.Context) (shared.SourceTables, error) } ts.SourceTables = append(ts.SourceTables, table) } - sfc.PutColumnsOnTables(ctx, ts) - return ts, nil } -func (bqc *BqConn) GetSources(ctx context.Context) (shared.SourceTables, error) { +func (bqc *BqConn) GetSourceTables(ctx context.Context) (shared.SourceTables, error) { ts := shared.SourceTables{} - err := bqc.ConnectToDB(ctx) defer bqc.Cancel() - if err != nil { - log.Fatalf("Couldn't connect to database: %v\n", err) - } bqDataset := bqc.Bq.Dataset(bqc.Dataset) tableIter := bqDataset.Tables(ctx) for { @@ -54,17 +43,12 @@ func (bqc *BqConn) GetSources(ctx context.Context) (shared.SourceTables, error) } ts.SourceTables = append(ts.SourceTables, shared.SourceTable{Name: table.TableID}) } - bqc.PutColumnsOnTables(ctx, ts) return ts, nil } -func (dc *DuckConn) GetSources(ctx context.Context) (shared.SourceTables, error) { +func (dc *DuckConn) GetSourceTables(ctx context.Context) (shared.SourceTables, error) { ts := shared.SourceTables{} - err := dc.ConnectToDB(ctx) defer dc.Cancel() - if err != nil { - log.Fatalf("Couldn't connect to database: %v\n", err) - } q := fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%s'", dc.Schema) rows, err := dc.Db.QueryContext(ctx, q) if err != nil { @@ -78,6 +62,5 @@ func (dc *DuckConn) GetSources(ctx context.Context) (shared.SourceTables, error) } ts.SourceTables = append(ts.SourceTables, table) } - dc.PutColumnsOnTables(ctx, ts) return ts, nil } diff --git a/sourcerer/get_sources_tables_test.go b/sourcerer/get_sources_tables_test.go new file mode 100644 index 0000000..1a83fc9 --- /dev/null +++ b/sourcerer/get_sources_tables_test.go @@ -0,0 +1,54 @@ +package sourcerer + +import ( + "context" + "fmt" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gwenwindflower/tbd/shared" +) + +func TestGetSourceTablesSnowflake(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cd := shared.ConnectionDetails{ + ConnType: "snowflake", + Account: "dunedain.snowflakecomputing.com", + Username: "aragorn", + Database: "gondor", + Schema: "minas-tirith", + } + conn, err := GetConn(cd) + if err != nil { + t.Errorf("GetConn failed: %v", err) + } + if conn == nil { + t.Errorf("GetConn failed: conn is nil") + } + SfConn, ok := conn.(*SfConn) + if !ok { + t.Errorf("GetConn failed: conn is not of type SfConn") + } + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + SfConn.Db = db + SfConn.Cancel = cancel + defer SfConn.Db.Close() + q := fmt.Sprintf("SELECT table_name FROM information_schema.tables WHERE table_schema = '%s'", SfConn.Schema) + mock.ExpectQuery(q).WillReturnRows(sqlmock.NewRows([]string{"table_name"}).AddRow("table1").AddRow("table2")) + ts, err := SfConn.GetSourceTables(ctx) + if err != nil { + t.Errorf("GetSources failed: %v", err) + } + if len(ts.SourceTables) != 2 { + t.Errorf("GetSources failed: expected 2 sources, got %d", len(ts.SourceTables)) + } + if ts.SourceTables[0].Name != "table1" { + t.Errorf("GetSources failed: expected source name %s, got %s", "table1", ts.SourceTables[0].Name) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } +} diff --git a/sourcerer/put_columns_on_tables.go b/sourcerer/put_columns_on_tables.go index 68d9e55..53739e2 100644 --- a/sourcerer/put_columns_on_tables.go +++ b/sourcerer/put_columns_on_tables.go @@ -3,79 +3,51 @@ package sourcerer import ( "context" "fmt" - "log" "regexp" "sync" "github.com/gwenwindflower/tbd/shared" ) -func (sfc *SfConn) PutColumnsOnTables(ctx context.Context, tables shared.SourceTables) { +func PutColumnsOnTables(ctx context.Context, ts shared.SourceTables, dbc DbConn) error { dataTypeGroupMap := map[string]string{ - "(text|char)": "text", - "(float|int|num)": "numbers", - "(bool|bit)": "booleans", - "json": "json", - "date": "datetimes", - "timestamp": "timestamps", + "(text|char|varchar)": "text", + "(float|int|num|number|bigint|float32|float64|int8)": "numbers", + "(bool|boolean|bit)": "booleans", + "(json|struct)": "json", + "(date|datetime)": "datetimes", + "(timestamp|timestamptz|timestampntz|timestampltz)": "timestamps", } - columnPutter(ctx, tables, sfc, dataTypeGroupMap) -} - -func (bqc *BqConn) PutColumnsOnTables(ctx context.Context, tables shared.SourceTables) { - dataTypeGroupMap := map[string]string{ - "(string)": "text", - "(float|int)": "numbers", - "(bool)": "booleans", - "(json)": "json", - "(date)": "datetimes", - "(timestamp)": "timestamps", - } - columnPutter(ctx, tables, bqc, dataTypeGroupMap) -} - -func (dc *DuckConn) PutColumnsOnTables(ctx context.Context, tables shared.SourceTables) { - dataTypeGroupMap := map[string]string{ - "(string|varchar)": "text", - "(float|int)": "numbers", - "(bool)": "booleans", - "(json)": "json", - "(date)": "datetimes", - "(timestamp)": "timestamps", - } - columnPutter(ctx, tables, dc, dataTypeGroupMap) -} - -func columnPutter(ctx context.Context, tables shared.SourceTables, conn DbConn, dataTypeGroupMap map[string]string) { mutex := sync.Mutex{} var wg sync.WaitGroup - wg.Add(len(tables.SourceTables)) - for i := range tables.SourceTables { - go func(i int) { + wg.Add(len(ts.SourceTables)) + for i := range ts.SourceTables { + go func(i int) error { defer wg.Done() - columns, err := conn.GetColumns(ctx, tables.SourceTables[i]) + columns, err := dbc.GetColumns(ctx, ts.SourceTables[i]) if err != nil { - log.Fatalf("Error fetching columns for table %s: %v\n", tables.SourceTables[i].Name, err) - return + return err } mutex.Lock() - tables.SourceTables[i].Columns = columns - tables.SourceTables[i].DataTypeGroups = make(map[string][]shared.Column) + ts.SourceTables[i].Columns = columns + ts.SourceTables[i].DataTypeGroups = make(map[string][]shared.Column) // Create a map of data types groups to hold column slices by data type // This lets us group columns by their data type e.g. in templates - for j := range tables.SourceTables[i].Columns { + for j := range ts.SourceTables[i].Columns { for k, v := range dataTypeGroupMap { r, _ := regexp.Compile(fmt.Sprintf(`(?i).*%s.*`, k)) - if r.MatchString(tables.SourceTables[i].Columns[j].DataType) { - tables.SourceTables[i].DataTypeGroups[v] = append(tables.SourceTables[i].DataTypeGroups[v], tables.SourceTables[i].Columns[j]) + if r.MatchString(ts.SourceTables[i].Columns[j].DataType) { + ts.SourceTables[i].DataTypeGroups[v] = append(ts.SourceTables[i].DataTypeGroups[v], ts.SourceTables[i].Columns[j]) } } } mutex.Unlock() + return nil }(i) } wg.Wait() + return nil } diff --git a/sourcerer/put_columns_on_tables_test.go b/sourcerer/put_columns_on_tables_test.go new file mode 100644 index 0000000..1b7e4a3 --- /dev/null +++ b/sourcerer/put_columns_on_tables_test.go @@ -0,0 +1,62 @@ +package sourcerer + +import ( + "context" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gwenwindflower/tbd/shared" +) + +func TestPutColumnsOnTables(t *testing.T) { + ctx := context.Background() + ts := shared.SourceTables{ + SourceTables: []shared.SourceTable{ + { + Name: "table1", + }, + }, + } + cd := shared.ConnectionDetails{ + ConnType: "snowflake", + Account: "dunedain.snowflakecomputing.com", + Username: "aragorn", + Database: "gondor", + Schema: "minas-tirith", + } + conn, err := GetConn(cd) + if err != nil { + t.Errorf("GetConn failed: %v", err) + } + if conn == nil { + t.Errorf("GetConn failed: conn is nil") + } + SfConn, ok := conn.(*SfConn) + if !ok { + t.Errorf("GetConn failed: conn is not of type SfConn") + } + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + defer db.Close() + SfConn.Db = db + rows := sqlmock.NewRows([]string{"column_name", "data_type"}).AddRow("column1", "text").AddRow("column2", "char").AddRow("COLUMN3", "int") + mock.ExpectQuery("SELECT column_name, data_type FROM information_schema.columns WHERE table_schema = 'MINAS-TIRITH' AND table_name = 'table1'").WillReturnRows(rows) + err = PutColumnsOnTables(ctx, ts, SfConn) + if err != nil { + t.Errorf("PutColumnsOnTables failed: %v", err) + } + if len(ts.SourceTables[0].Columns) != 3 { + t.Errorf("PutColumnsOnTables failed: expected 3 columns, got %d", len(ts.SourceTables[0].Columns)) + } + if ts.SourceTables[0].Columns[0].Name != "column1" { + t.Errorf("PutColumnsOnTables failed: expected column name column1, got %s", ts.SourceTables[0].Columns[0].Name) + } + if ts.SourceTables[0].Columns[0].DataType != "text" { + t.Errorf("PutColumnsOnTables failed: expected column data type text, got %s", ts.SourceTables[0].Columns[0].DataType) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("Expectations were not met: %v", err) + } +} diff --git a/version.go b/version.go index 2e93f5e..f19cdae 100644 --- a/version.go +++ b/version.go @@ -1,3 +1,3 @@ package main -const Version = "0.0.11" +const Version = "0.0.16"