From 16aa26e02540d062f132867882b05b48be6e3d8e Mon Sep 17 00:00:00 2001
From: shekhar-rudder <shekhar@rudderstack.com>
Date: Mon, 6 Jan 2025 13:19:34 +0530
Subject: [PATCH] fix: implementation

---
 .../integrations/bigquery/bigquery_test.go    |   1 +
 warehouse/router/router.go                    |   2 +-
 warehouse/router/sync.go                      |  16 +--
 warehouse/router/sync_test.go                 |  12 +-
 warehouse/router/upload.go                    |   7 +-
 warehouse/router/upload_test.go               |   9 ++
 warehouse/schema/schema.go                    |  31 +++--
 warehouse/schema/schema_v2.go                 |  60 ++++++----
 warehouse/schema/schema_v2_test.go            | 109 ++++++++++++++++++
 9 files changed, 190 insertions(+), 57 deletions(-)
 create mode 100644 warehouse/schema/schema_v2_test.go

diff --git a/warehouse/integrations/bigquery/bigquery_test.go b/warehouse/integrations/bigquery/bigquery_test.go
index adb6364828..4aecc69738 100644
--- a/warehouse/integrations/bigquery/bigquery_test.go
+++ b/warehouse/integrations/bigquery/bigquery_test.go
@@ -607,6 +607,7 @@ func TestIntegration(t *testing.T) {
 				t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_ENABLE_DELETE_BY_JOBS", "true")
 				t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_MAX_PARALLEL_LOADS", "8")
 				t.Setenv("RSERVER_WAREHOUSE_BIGQUERY_SLOW_QUERY_THRESHOLD", "0s")
+				t.Setenv("RSERVER_WAREHOUSE_SYNC_SCHEMA_FREQUENCY", "5s")
 
 				whth.BootstrapSvc(t, workspaceConfig, httpPort, jobsDBPort)
 
diff --git a/warehouse/router/router.go b/warehouse/router/router.go
index ea224cef99..e66fd20189 100644
--- a/warehouse/router/router.go
+++ b/warehouse/router/router.go
@@ -721,7 +721,7 @@ func (r *Router) loadReloadableConfig(whName string) {
 	r.config.cronTrackerRetries = r.conf.GetReloadableInt64Var(5, 1, "Warehouse.cronTrackerRetries")
 	r.config.uploadBufferTimeInMin = r.conf.GetReloadableDurationVar(180, time.Minute, "Warehouse.uploadBufferTimeInMin")
 	r.config.syncSchemaFrequency = r.conf.GetDurationVar(12, time.Hour, "Warehouse.syncSchemaFrequency")
-    r.config.enableSyncSchema = r.conf.GetBoolVar(true, "Warehouse.enableSyncSchema")
+	r.config.enableSyncSchema = r.conf.GetBoolVar(true, "Warehouse.enableSyncSchema")
 }
 
 func (r *Router) loadStats() {
diff --git a/warehouse/router/sync.go b/warehouse/router/sync.go
index ef21e3b55b..7f3078517a 100644
--- a/warehouse/router/sync.go
+++ b/warehouse/router/sync.go
@@ -2,12 +2,12 @@ package router
 
 import (
 	"context"
-	"fmt"
 	"time"
 
 	obskit "github.com/rudderlabs/rudder-observability-kit/go/labels"
 	"github.com/rudderlabs/rudder-server/warehouse/integrations/manager"
 	"github.com/rudderlabs/rudder-server/warehouse/internal/model"
+	"github.com/rudderlabs/rudder-server/warehouse/internal/repo"
 	"github.com/rudderlabs/rudder-server/warehouse/schema"
 	warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils"
 )
@@ -18,20 +18,22 @@ func (r *Router) sync(ctx context.Context) error {
 		warehouses := append([]model.Warehouse{}, r.warehouses...)
 		r.configSubscriberLock.RUnlock()
 		execTime := time.Now()
-		whManager, err := manager.New(r.destType, r.conf, r.logger, r.statsFactory)
-		if err != nil {
-			return fmt.Errorf("failed to create warehouse manager: %w", err)
-		}
 		for _, warehouse := range warehouses {
-			err := whManager.Setup(ctx, warehouse, warehouseutils.NewNoOpUploader())
+			whManager, err := manager.New(r.destType, r.conf, r.logger, r.statsFactory)
+			if err != nil {
+				r.logger.Errorn("create warehouse manager: %w", obskit.Error(err))
+				continue
+			}
+			err = whManager.Setup(ctx, warehouse, warehouseutils.NewNoOpUploader())
 			if err != nil {
 				r.logger.Errorn("failed to setup WH Manager", obskit.Error(err))
 				continue
 			}
-			if err := schema.SyncSchema(ctx, whManager, warehouse, r.db, r.logger.Child("syncer")); err != nil {
+			if err := schema.FetchAndSaveSchema(ctx, whManager, warehouse, repo.NewWHSchemas(r.db), r.logger.Child("syncer")); err != nil {
 				r.logger.Errorn("failed to sync schema", obskit.Error(err))
 				continue
 			}
+			whManager.Cleanup(ctx)
 		}
 		nextExecTime := execTime.Add(r.config.syncSchemaFrequency)
 		select {
diff --git a/warehouse/router/sync_test.go b/warehouse/router/sync_test.go
index cb5666d9cc..f1b2d12a86 100644
--- a/warehouse/router/sync_test.go
+++ b/warehouse/router/sync_test.go
@@ -25,6 +25,12 @@ import (
 	warehouseutils "github.com/rudderlabs/rudder-server/warehouse/utils"
 )
 
+type mockFetchSchemaRepo struct{}
+
+func (m mockFetchSchemaRepo) FetchSchema(ctx context.Context) (model.Schema, error) {
+	return model.Schema{}, nil
+}
+
 func TestSync_SyncRemoteSchemaIntegration(t *testing.T) {
 	destinationType := warehouseutils.POSTGRES
 	bucket := "some-bucket"
@@ -124,16 +130,16 @@ func TestSync_SyncRemoteSchemaIntegration(t *testing.T) {
 
 		<-setupCh
 		r.conf.Set("Warehouse.enableSyncSchema", true)
-		sh, err := schema.New(
-			context.Background(),
+		sh := schema.New(
 			r.db,
 			warehouse,
 			r.conf,
 			r.logger.Child("syncer"),
 			r.statsFactory,
 		)
-		require.NoError(t, err)
 		require.Eventually(t, func() bool {
+			_, err := sh.SyncRemoteSchema(ctx, &mockFetchSchemaRepo{}, 0)
+			require.NoError(t, err)
 			schema, err := sh.GetLocalSchema(ctx)
 			require.NoError(t, err)
 			return reflect.DeepEqual(schema, model.Schema{
diff --git a/warehouse/router/upload.go b/warehouse/router/upload.go
index c611499929..401157e38d 100644
--- a/warehouse/router/upload.go
+++ b/warehouse/router/upload.go
@@ -148,18 +148,13 @@ func (f *UploadJobFactory) NewUploadJob(ctx context.Context, dto *model.UploadJo
 		logfield.UseRudderStorage, dto.Upload.UseRudderStorage,
 	)
 
-	schemaHandle, err := schema.New(
-		ctx,
+	schemaHandle := schema.New(
 		f.db,
 		dto.Warehouse,
 		f.conf,
 		f.logger.Child("warehouse"),
 		f.statsFactory,
 	)
-	if err != nil {
-		log.Errorw("failed to create schema handler", logfield.Error, err)
-		return nil
-	}
 
 	uj := &UploadJob{
 		ctx:                  ujCtx,
diff --git a/warehouse/router/upload_test.go b/warehouse/router/upload_test.go
index 03b2539162..a82c2cec0e 100644
--- a/warehouse/router/upload_test.go
+++ b/warehouse/router/upload_test.go
@@ -126,12 +126,21 @@ func TestColumnCountStat(t *testing.T) {
 		tc := tc
 
 		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
 			conf := config.New()
 			conf.Set(fmt.Sprintf("Warehouse.%s.columnCountLimit", strings.ToLower(warehouseutils.WHDestNameMap[tc.destinationType])), tc.columnCountLimit)
+
+			pool, err := dockertest.NewPool("")
+			require.NoError(t, err)
+
+			pgResource, err := postgres.Setup(pool, t)
+			require.NoError(t, err)
+
 			uploadJobFactory := &UploadJobFactory{
 				logger:       logger.NOP,
 				statsFactory: statsStore,
 				conf:         conf,
+				db:           sqlmiddleware.New(pgResource.DB),
 			}
 			rs := redshift.New(config.New(), logger.NOP, stats.NOP)
 			j := uploadJobFactory.NewUploadJob(context.Background(), &model.UploadJob{
diff --git a/warehouse/schema/schema.go b/warehouse/schema/schema.go
index ecc789fc7a..057748ef27 100644
--- a/warehouse/schema/schema.go
+++ b/warehouse/schema/schema.go
@@ -78,13 +78,12 @@ type schema struct {
 }
 
 func New(
-	ctx context.Context,
 	db *sqlquerywrapper.DB,
 	warehouse model.Warehouse,
 	conf *config.Config,
 	logger logger.Logger,
 	statsFactory stats.Stats,
-) (SchemaHandler, error) {
+) SchemaHandler {
 	schemaSize := statsFactory.NewTaggedStat("warehouse_schema_size", stats.HistogramType, stats.Tags{
 		"module":        "warehouse",
 		"workspaceId":   warehouse.WorkspaceID,
@@ -102,15 +101,10 @@ func New(
 		enableIDResolution:               conf.GetBool("Warehouse.enableIDResolution", false),
 	}
 	if conf.GetBoolVar(true, "Warehouse.enableSyncSchema") {
-		schemaHandler, err := newSchemaV2(ctx, schemaV1, warehouse, log)
-		if err != nil {
-			return nil, fmt.Errorf("creating schema handler: %w", err)
-		}
-		schemaHandler.stats.schemaSize = schemaSize
-		return schemaHandler, nil
+		return newSchemaV2(schemaV1, warehouse, log, schemaSize)
 	}
 	schemaV1.stats.schemaSize = schemaSize
-	return schemaV1, nil
+	return schemaV1
 }
 
 // ConsolidateStagingFilesUsingLocalSchema
@@ -293,13 +287,7 @@ func (sh *schema) updateLocalSchema(ctx context.Context, updatedSchema model.Sch
 	}
 	sh.stats.schemaSize.Observe(float64(len(updatedSchemaInBytes)))
 
-	_, err = sh.schemaRepo.Insert(ctx, &model.WHSchema{
-		SourceID:        sh.warehouse.Source.ID,
-		Namespace:       sh.warehouse.Namespace,
-		DestinationID:   sh.warehouse.Destination.ID,
-		DestinationType: sh.warehouse.Type,
-		Schema:          updatedSchema,
-	})
+	err = writeSchema(ctx, sh.schemaRepo, sh.warehouse, updatedSchema)
 	if err != nil {
 		return fmt.Errorf("updating local schema: %w", err)
 	}
@@ -481,3 +469,14 @@ func removeDeprecatedColumns(schema model.Schema, warehouse model.Warehouse, log
 		}
 	}
 }
+
+func writeSchema(ctx context.Context, schemaRepo schemaRepo, warehouse model.Warehouse, updatedSchema model.Schema) error {
+	_, err := schemaRepo.Insert(ctx, &model.WHSchema{
+		SourceID:        warehouse.Source.ID,
+		Namespace:       warehouse.Namespace,
+		DestinationID:   warehouse.Destination.ID,
+		DestinationType: warehouse.Type,
+		Schema:          updatedSchema,
+	})
+	return err
+}
diff --git a/warehouse/schema/schema_v2.go b/warehouse/schema/schema_v2.go
index 109d4a724c..e47f34ff75 100644
--- a/warehouse/schema/schema_v2.go
+++ b/warehouse/schema/schema_v2.go
@@ -9,7 +9,6 @@ import (
 
 	"github.com/rudderlabs/rudder-go-kit/logger"
 	"github.com/rudderlabs/rudder-go-kit/stats"
-	"github.com/rudderlabs/rudder-server/warehouse/integrations/middleware/sqlquerywrapper"
 	"github.com/rudderlabs/rudder-server/warehouse/internal/model"
 	"github.com/rudderlabs/rudder-server/warehouse/internal/repo"
 	whutils "github.com/rudderlabs/rudder-server/warehouse/utils"
@@ -19,6 +18,7 @@ type schemaV2 struct {
 	stats struct {
 		schemaSize stats.Histogram
 	}
+	// caches the schema present in the repository
 	cachedSchema model.Schema
 	warehouse    model.Warehouse
 	v1           *schema
@@ -26,29 +26,45 @@ type schemaV2 struct {
 	schemaMu     sync.RWMutex
 }
 
-func SyncSchema(ctx context.Context, fetchSchemaRepo fetchSchemaRepo, warehouse model.Warehouse, db *sqlquerywrapper.DB, log logger.Logger) error {
+func FetchAndSaveSchema(ctx context.Context, fetchSchemaRepo fetchSchemaRepo, warehouse model.Warehouse, schemaRepo schemaRepo, log logger.Logger) error {
 	warehouseSchema, err := fetchSchemaRepo.FetchSchema(ctx)
 	if err != nil {
 		return fmt.Errorf("fetching schema: %w", err)
 	}
 	removeDeprecatedColumns(warehouseSchema, warehouse, log)
-	schemaRepo := repo.NewWHSchemas(db)
 	return writeSchema(ctx, schemaRepo, warehouse, warehouseSchema)
 }
 
-func newSchemaV2(ctx context.Context, v1 *schema, warehouse model.Warehouse, log logger.Logger) (*schemaV2, error) {
-	v2 := &schemaV2{
-		v1:        v1,
-		warehouse: warehouse,
-		log:       log,
+func newSchemaV2(v1 *schema, warehouse model.Warehouse, log logger.Logger, schemaSize stats.Histogram) *schemaV2 {
+	return &schemaV2{
+		v1:           v1,
+		warehouse:    warehouse,
+		log:          log,
+		cachedSchema: model.Schema{},
+		stats: struct {
+			schemaSize stats.Histogram
+		}{
+			schemaSize: schemaSize,
+		},
 	}
-	var err error
-	v2.cachedSchema, err = v1.GetLocalSchema(ctx)
-	return v2, err
 }
 
 func (sh *schemaV2) SyncRemoteSchema(ctx context.Context, fetchSchemaRepo fetchSchemaRepo, uploadID int64) (bool, error) {
-	// no-op since syncing of local schema with warehouse schema is being done in the background
+	whSchema, err := sh.v1.schemaRepo.GetForNamespace(
+		ctx,
+		sh.warehouse.Source.ID,
+		sh.warehouse.Destination.ID,
+		sh.warehouse.Namespace,
+	)
+	if err != nil {
+		return false, fmt.Errorf("getting schema for namespace: %w", err)
+	}
+	if whSchema.Schema == nil {
+		return false, nil
+	}
+	sh.schemaMu.Lock()
+	defer sh.schemaMu.Unlock()
+	sh.cachedSchema = whSchema.Schema
 	return false, nil
 }
 
@@ -85,7 +101,14 @@ func (sh *schemaV2) UpdateLocalSchema(ctx context.Context, updatedSchema model.S
 }
 
 func (sh *schemaV2) UpdateWarehouseTableSchema(tableName string, tableSchema model.TableSchema) {
-	// no-op since there is no warehouse schema to update
+	sh.schemaMu.Lock()
+	defer sh.schemaMu.Unlock()
+	sh.cachedSchema[tableName] = tableSchema
+	err := writeSchema(context.TODO(), sh.v1.schemaRepo, sh.warehouse, sh.cachedSchema)
+	if err != nil {
+		// TODO - Return error to the caller
+		sh.log.Errorf("error updating warehouse schema: %v", err)
+	}
 }
 
 func (sh *schemaV2) GetColumnsCountInWarehouseSchema(tableName string) int {
@@ -130,14 +153,3 @@ func (sh *schemaV2) FetchSchemaFromWarehouse(ctx context.Context, repo fetchSche
 	// no-op since local schema and warehouse schema are supposed to be in sync
 	return nil
 }
-
-func writeSchema(ctx context.Context, schemaRepo schemaRepo, warehouse model.Warehouse, updatedSchema model.Schema) error {
-	_, err := schemaRepo.Insert(ctx, &model.WHSchema{
-		SourceID:        warehouse.Source.ID,
-		Namespace:       warehouse.Namespace,
-		DestinationID:   warehouse.Destination.ID,
-		DestinationType: warehouse.Type,
-		Schema:          updatedSchema,
-	})
-	return err
-}
diff --git a/warehouse/schema/schema_v2_test.go b/warehouse/schema/schema_v2_test.go
new file mode 100644
index 0000000000..2f2664787d
--- /dev/null
+++ b/warehouse/schema/schema_v2_test.go
@@ -0,0 +1,109 @@
+package schema
+
+import (
+	"context"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/rudderlabs/rudder-go-kit/stats"
+	backendconfig "github.com/rudderlabs/rudder-server/backend-config"
+	"github.com/rudderlabs/rudder-server/warehouse/internal/model"
+)
+
+type mFetchSchemaRepo struct{}
+
+var schema1 = model.Schema{
+	"identifies": model.TableSchema{
+		"id":      "string",
+		"user_id": "int",
+	},
+}
+
+var warehouse = model.Warehouse{
+	WorkspaceID: "workspaceID",
+	Source: backendconfig.SourceT{
+		ID: "sourceID",
+	},
+	Destination: backendconfig.DestinationT{
+		ID: "destinationID",
+	},
+	Namespace: "namespace",
+	Type:      "warehouseType",
+}
+
+func (m *mFetchSchemaRepo) FetchSchema(ctx context.Context) (model.Schema, error) {
+	return schema1, nil
+}
+
+type mSchemaRepo struct {
+	schemaMap map[string]model.WHSchema
+}
+
+func (m *mSchemaRepo) Insert(ctx context.Context, whSchema *model.WHSchema) (int64, error) {
+	if m.schemaMap == nil {
+		m.schemaMap = make(map[string]model.WHSchema)
+	}
+	m.schemaMap[whSchema.Namespace] = *whSchema
+	return 0, nil
+}
+
+func (m *mSchemaRepo) GetForNamespace(ctx context.Context, sourceID, destinationID, namespace string) (model.WHSchema, error) {
+	return m.schemaMap[namespace], nil
+}
+
+func TestFetchAndSaveSchema(t *testing.T) {
+	schemaRepo := &mSchemaRepo{}
+	v1 := &schema{
+		schemaRepo: schemaRepo,
+	}
+	v2 := newSchemaV2(v1, warehouse, nil, stats.NOP.NewStat("schema_size", "size of the schema"))
+	isEmpty := v2.IsWarehouseSchemaEmpty()
+	require.True(t, isEmpty)
+
+	ctx := context.Background()
+	fetchSchemaRepo := &mFetchSchemaRepo{}
+
+	err := FetchAndSaveSchema(ctx, fetchSchemaRepo, warehouse, schemaRepo, nil)
+	require.NoError(t, err)
+
+	schema, err := v2.GetLocalSchema(ctx)
+	require.NoError(t, err)
+	require.Equal(t, model.Schema{}, schema)
+
+	hasSchemaChanged, err := v2.SyncRemoteSchema(ctx, fetchSchemaRepo, 0)
+	require.NoError(t, err)
+	require.False(t, hasSchemaChanged)
+
+	schema, err = v2.GetLocalSchema(ctx)
+	require.NoError(t, err)
+	require.Equal(t, schema1, schema)
+}
+
+func TestUpdateLocalSchema(t *testing.T) {
+	schemaRepo := &mSchemaRepo{}
+	v1 := &schema{
+		schemaRepo: schemaRepo,
+	}
+	v2 := newSchemaV2(v1, warehouse, nil, stats.NOP.NewStat("schema_size", "size of the schema"))
+
+	ctx := context.Background()
+
+	schema, err := v2.GetLocalSchema(ctx)
+	require.NoError(t, err)
+	require.Equal(t, model.Schema{}, schema)
+
+	schema2 := model.Schema{
+		"users": model.TableSchema{
+			"anonymous_id": "string",
+			"received_at":  "datetime",
+		},
+	}
+
+	err = v2.UpdateLocalSchema(ctx, schema2)
+	require.NoError(t, err)
+
+	schema, err = v2.GetLocalSchema(ctx)
+	require.NoError(t, err)
+	require.Equal(t, schema2, schema)
+}