Skip to content
This repository has been archived by the owner on May 17, 2024. It is now read-only.

Commit

Permalink
Merge pull request #498 from dlawin/issue_447
Browse files (browse the repository at this point in the history)
handle all custom schemas scenarios
  • Loading branch information
dlawin authored Apr 14, 2023
2 parents 25692cb + b018eb2 commit 5632150
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 27 deletions.
38 changes: 27 additions & 11 deletions data_diff/dbt.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -85,14 +85,17 @@ def dbt_diff(
datadiff_variables = dbt_parser.get_datadiff_variables()
config_prod_database = datadiff_variables.get("prod_database")
config_prod_schema = datadiff_variables.get("prod_schema")
config_prod_custom_schema = datadiff_variables.get("prod_custom_schema")
datasource_id = datadiff_variables.get("datasource_id")
custom_schemas = datadiff_variables.get("custom_schemas")
# custom schemas is default dbt behavior, so default to True if the var doesn't exist
custom_schemas = True if custom_schemas is None else custom_schemas
set_dbt_user_id(dbt_parser.dbt_user_id)
set_dbt_version(dbt_parser.dbt_version)
set_dbt_project_id(dbt_parser.dbt_project_id)

if datadiff_variables.get("custom_schemas") is not None:
logger.warning(
"vars: data_diff: custom_schemas: is no longer used and can be removed.\nTo utilize custom schemas, see the documentation here: https://docs.datafold.com/development_testing/open_source"
)

if is_cloud:
api = _initialize_api()
# exit so the user can set the key
Expand Down Expand Up @@ -125,7 +128,9 @@ def dbt_diff(
)

for model in models:
diff_vars = _get_diff_vars(dbt_parser, config_prod_database, config_prod_schema, model, custom_schemas)
diff_vars = _get_diff_vars(
dbt_parser, config_prod_database, config_prod_schema, config_prod_custom_schema, model
)

if diff_vars.primary_keys:
if is_cloud:
Expand All @@ -149,22 +154,33 @@ def _get_diff_vars(
dbt_parser: "DbtParser",
config_prod_database: Optional[str],
config_prod_schema: Optional[str],
config_prod_custom_schema: Optional[str],
model,
custom_schemas: bool,
) -> DiffVars:
dev_database = model.database
dev_schema = model.schema_

primary_keys = dbt_parser.get_pk_from_model(model, dbt_parser.unique_columns, "primary-key")

prod_database = config_prod_database if config_prod_database else dev_database
prod_schema = config_prod_schema if config_prod_schema else dev_schema

# if project has custom schemas (default)
# need to construct the prod schema as <prod_target_schema>_<custom_schema>
# https://docs.getdbt.com/docs/build/custom-schemas
if custom_schemas and model.config.schema_:
prod_schema = prod_schema + "_" + model.config.schema_
# prod schema name differs from dev schema name
if config_prod_schema:
custom_schema = model.config.schema_

# the model has a custom schema config(schema='some_schema')
if custom_schema:
if not config_prod_custom_schema:
raise ValueError(
f"Found a custom schema on model {model.name}, but no value for\nvars:\n data_diff:\n prod_custom_schema:\nPlease set a value!\n"
+ "For more details see: https://docs.datafold.com/development_testing/open_source"
)
prod_schema = config_prod_custom_schema.replace("<custom_schema>", custom_schema)
# no custom schema, use the default
else:
prod_schema = config_prod_schema
else:
prod_schema = dev_schema

if dbt_parser.requires_upper:
dev_qualified_list = [x.upper() for x in [dev_database, dev_schema, model.alias]]
Expand Down
72 changes: 56 additions & 16 deletions tests/test_dbt.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -740,9 +740,6 @@ def test_diff_not_is_cloud_no_pks(
"prod_schema": "prod_schema",
"datasource_id": 1,
}
host = "a_host"
url = "a_url"
api_key = "a_api_key"

mock_dbt_parser_inst.get_models.return_value = [mock_model]
mock_dbt_parser_inst.get_datadiff_variables.return_value = expected_dbt_vars_dict
Expand All @@ -756,42 +753,67 @@ def test_diff_not_is_cloud_no_pks(
mock_local_diff.assert_not_called()
self.assertEqual(mock_print.call_count, 1)

def test_get_diff_vars_custom_schemas_prod_db_and_schema(self):
def test_get_diff_vars_replace_custom_schema(self):
mock_model = Mock()
prod_database = "a_prod_db"
prod_schema = "a_prod_schema"
primary_keys = ["a_primary_key"]
mock_model.database = "a_dev_db"
mock_model.schema_ = "a_custom_dev_schema"
mock_model.schema_ = "a_custom_schema"
mock_model.config.schema_ = mock_model.schema_
mock_model.alias = "a_model_name"
mock_dbt_parser = Mock()
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
mock_dbt_parser.requires_upper = False

diff_vars = _get_diff_vars(mock_dbt_parser, "a_prod_db", "a_prod_schema", mock_model, custom_schemas=True)
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod_<custom_schema>", mock_model)

assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
assert diff_vars.prod_path == [prod_database, prod_schema + "_" + mock_model.schema_, mock_model.alias]
assert diff_vars.prod_path == [prod_database, "prod_" + mock_model.schema_, mock_model.alias]
assert diff_vars.primary_keys == primary_keys
assert diff_vars.connection == mock_dbt_parser.connection
assert diff_vars.threads == mock_dbt_parser.threads
assert prod_schema not in diff_vars.prod_path

mock_dbt_parser.get_pk_from_model.assert_called_once()

def test_get_diff_vars_false_custom_schemas_prod_db_and_schema(self):
def test_get_diff_vars_static_custom_schema(self):
mock_model = Mock()
prod_database = "a_prod_db"
prod_schema = "a_prod_schema"
primary_keys = ["a_primary_key"]
mock_model.database = "a_dev_db"
mock_model.schema_ = "a_custom_dev_schema"
mock_model.schema_ = "a_custom_schema"
mock_model.config.schema_ = mock_model.schema_
mock_model.alias = "a_model_name"
mock_dbt_parser = Mock()
mock_dbt_parser.get_pk_from_model.return_value = ["a_primary_key"]
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
mock_dbt_parser.requires_upper = False

diff_vars = _get_diff_vars(mock_dbt_parser, "a_prod_db", "a_prod_schema", mock_model, custom_schemas=False)
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod", mock_model)

assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
assert diff_vars.prod_path == [prod_database, "prod", mock_model.alias]
assert diff_vars.primary_keys == primary_keys
assert diff_vars.connection == mock_dbt_parser.connection
assert diff_vars.threads == mock_dbt_parser.threads
assert prod_schema not in diff_vars.prod_path
mock_dbt_parser.get_pk_from_model.assert_called_once()

def test_get_diff_vars_no_custom_schema_on_model(self):
mock_model = Mock()
prod_database = "a_prod_db"
prod_schema = "a_prod_schema"
primary_keys = ["a_primary_key"]
mock_model.database = "a_dev_db"
mock_model.schema_ = "a_custom_schema"
mock_model.config.schema_ = None
mock_model.alias = "a_model_name"
mock_dbt_parser = Mock()
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
mock_dbt_parser.requires_upper = False

diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, prod_schema, "prod", mock_model)

assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
assert diff_vars.prod_path == [prod_database, prod_schema, mock_model.alias]
Expand All @@ -800,23 +822,41 @@ def test_get_diff_vars_false_custom_schemas_prod_db_and_schema(self):
assert diff_vars.threads == mock_dbt_parser.threads
mock_dbt_parser.get_pk_from_model.assert_called_once()

def test_get_diff_vars_false_custom_schemas_prod_db(self):
def test_get_diff_vars_match_dev_schema(self):
mock_model = Mock()
prod_database = "a_prod_db"
primary_keys = ["a_primary_key"]
mock_model.database = "a_dev_db"
mock_model.schema_ = "a_custom_dev_schema"
mock_model.config.schema_ = mock_model.schema_
mock_model.schema_ = "a_schema"
mock_model.config.schema_ = None
mock_model.alias = "a_model_name"
mock_dbt_parser = Mock()
mock_dbt_parser.get_pk_from_model.return_value = ["a_primary_key"]
mock_dbt_parser.get_pk_from_model.return_value = primary_keys
mock_dbt_parser.requires_upper = False

diff_vars = _get_diff_vars(mock_dbt_parser, "a_prod_db", None, mock_model, custom_schemas=False)
diff_vars = _get_diff_vars(mock_dbt_parser, prod_database, None, None, mock_model)

assert diff_vars.dev_path == [mock_model.database, mock_model.schema_, mock_model.alias]
assert diff_vars.prod_path == [prod_database, mock_model.schema_, mock_model.alias]
assert diff_vars.primary_keys == primary_keys
assert diff_vars.connection == mock_dbt_parser.connection
assert diff_vars.threads == mock_dbt_parser.threads
mock_dbt_parser.get_pk_from_model.assert_called_once()

def test_get_diff_custom_schema_no_config_exception(self):
    """A custom schema on the model without a prod custom schema value raises ValueError."""
    pk_columns = ["a_primary_key"]

    # Model whose config declares a custom schema distinct from its dev schema.
    model = Mock()
    model.database = "a_dev_db"
    model.schema_ = "a_schema"
    model.config.schema_ = "a_custom_schema"
    model.alias = "a_model_name"

    parser = Mock()
    parser.get_pk_from_model.return_value = pk_columns
    parser.requires_upper = False

    # A prod schema is supplied, but the prod custom schema argument is None.
    with self.assertRaises(ValueError):
        _get_diff_vars(parser, "a_prod_db", "a_prod_schema", None, model)

    # The primary-key lookup is still expected to have run exactly once.
    parser.get_pk_from_model.assert_called_once()

0 comments on commit 5632150

Please sign in to comment.