diff --git a/dbx/api/config_reader.py b/dbx/api/config_reader.py index 344740a8..b26bbaee 100644 --- a/dbx/api/config_reader.py +++ b/dbx/api/config_reader.py @@ -58,9 +58,10 @@ def read_content(content: Dict[str, Any]) -> DeploymentConfig: class Jinja2ConfigReader(_AbstractConfigReader): - def __init__(self, path: Path, ext: str, jinja_vars_file: Optional[Path]): + def __init__(self, path: Path, ext: str, jinja_vars_file: Optional[Path], jinja_wd: Optional[bool]): self._ext = ext self._jinja_vars_file = jinja_vars_file + self._jinja_wd = jinja_wd super().__init__(path) @staticmethod @@ -68,13 +69,17 @@ def _read_vars_file(file_path: Path) -> Dict[str, Any]: return yaml.load(file_path.read_text(encoding="utf-8"), yaml.SafeLoader) @classmethod - def _render_content(cls, file_path: Path, _var: Dict[str, Any]) -> str: - absolute_parent_path = file_path.absolute().parent - file_name = file_path.name - - dbx_echo(f"The following path will be used for the jinja loader: {absolute_parent_path} with file {file_name}") + def _render_content(cls, file_path: Path, _var: Dict[str, Any], _wd: bool) -> str: + if _wd: + working_dir = os.path.abspath(os.curdir) + file_name = str(file_path.relative_to(working_dir)) + else: + working_dir = file_path.absolute().parent + file_name = file_path.name + + dbx_echo(f"The following path will be used for the jinja loader: {working_dir} with file {file_name}") - env = jinja2.Environment(loader=jinja2.FileSystemLoader(absolute_parent_path)) + env = jinja2.Environment(loader=jinja2.FileSystemLoader(working_dir)) template = env.get_template(file_name) template.globals["dbx"] = dbx_jinja @@ -92,7 +97,8 @@ def add_custom_functions(cls, template: jinja2.Template): def _read_file(self) -> DeploymentConfig: _var = {} if not self._jinja_vars_file else self._read_vars_file(self._jinja_vars_file) - rendered = self._render_content(self._path, _var) + _wd = self._jinja_wd if self._jinja_wd is not None else False + rendered = self._render_content(self._path, _var, _wd) if self._ext == ".json": _content = json.loads(rendered) @@ -116,8 +122,9 @@ class ConfigReader: If a new reader is introduced, it shall be used via the :code:`_define_reader` method. """ - def __init__(self, path: Path, jinja_vars_file: Optional[Path] = None): + def __init__(self, path: Path, jinja_vars_file: Optional[Path] = None, jinja_wd: Optional[bool] = False): self._jinja_vars_file = jinja_vars_file + self._jinja_wd = jinja_wd self._path = path self._reader = self._define_reader() self._build_properties = BuildProperties() @@ -135,9 +142,9 @@ def _define_reader(self) -> _AbstractConfigReader: you can also configure your project to support in-place Jinja by running: [code]dbx configure --enable-inplace-jinja-support[/code][/bright_magenta bold]""" ) - return Jinja2ConfigReader(self._path, ext=self._path.suffixes[0], jinja_vars_file=self._jinja_vars_file) + return Jinja2ConfigReader(self._path, ext=self._path.suffixes[0], jinja_vars_file=self._jinja_vars_file, jinja_wd=self._jinja_wd) elif ProjectConfigurationManager().get_jinja_support(): - return Jinja2ConfigReader(self._path, ext=self._path.suffixes[0], jinja_vars_file=self._jinja_vars_file) + return Jinja2ConfigReader(self._path, ext=self._path.suffixes[0], jinja_vars_file=self._jinja_vars_file, jinja_wd=self._jinja_wd) else: if self._jinja_vars_file: raise Exception( diff --git a/dbx/commands/deploy.py b/dbx/commands/deploy.py index f2a3434b..976be37e 100644 --- a/dbx/commands/deploy.py +++ b/dbx/commands/deploy.py @@ -18,6 +18,7 @@ DEPLOYMENT_FILE_OPTION, ENVIRONMENT_OPTION, HEADERS_OPTION, + JINJA_TEMPLATES_WORKING_DIR, JINJA_VARIABLES_FILE_OPTION, NO_PACKAGE_OPTION, NO_REBUILD_OPTION, @@ -88,6 +89,7 @@ def deploy( ), headers: Optional[List[str]] = HEADERS_OPTION, branch_name: Optional[str] = BRANCH_NAME_OPTION, + jinja_working_directory: Optional[bool] = JINJA_TEMPLATES_WORKING_DIR, jinja_variables_file: Optional[Path] = JINJA_VARIABLES_FILE_OPTION, debug: Optional[bool] = DEBUG_OPTION, # noqa ): @@ -104,7 +106,7 @@ def deploy( if not branch_name: branch_name = get_current_branch_name() - config_reader = ConfigReader(deployment_file, jinja_variables_file) + config_reader = ConfigReader(deployment_file, jinja_variables_file, jinja_working_directory) config = config_reader.with_build_properties( BuildProperties(potential_build=True, no_rebuild=no_rebuild) ).get_config() diff --git a/dbx/commands/destroy.py b/dbx/commands/destroy.py index 07d8728a..b124dec5 100644 --- a/dbx/commands/destroy.py +++ b/dbx/commands/destroy.py @@ -14,6 +14,7 @@ DEPLOYMENT_FILE_OPTION, ENVIRONMENT_OPTION, HEADERS_OPTION, + JINJA_TEMPLATES_WORKING_DIR, JINJA_VARIABLES_FILE_OPTION, WORKFLOW_ARGUMENT, ) @@ -28,6 +29,7 @@ def destroy( ), deployment_file: Optional[Path] = DEPLOYMENT_FILE_OPTION, environment_name: str = ENVIRONMENT_OPTION, + jinja_working_directory: Optional[bool] = JINJA_TEMPLATES_WORKING_DIR, jinja_variables_file: Optional[Path] = JINJA_VARIABLES_FILE_OPTION, deletion_mode: DeletionMode = typer.Option( DeletionMode.all, @@ -60,7 +62,7 @@ def destroy( workflow_names = workflow_names.split(",") if workflow_names else [] - global_config = ConfigReader(deployment_file, jinja_variables_file).get_config() + global_config = ConfigReader(deployment_file, jinja_variables_file, jinja_working_directory).get_config() env_config = global_config.get_environment(environment_name, raise_if_not_found=True) relevant_workflows = env_config.payload.select_relevant_or_all_workflows(workflow_name, workflow_names) diff --git a/dbx/commands/execute.py b/dbx/commands/execute.py index 15747e77..12ae689a 100644 --- a/dbx/commands/execute.py +++ b/dbx/commands/execute.py @@ -17,6 +17,7 @@ ENVIRONMENT_OPTION, EXECUTE_PARAMETERS_OPTION, HEADERS_OPTION, + JINJA_TEMPLATES_WORKING_DIR, JINJA_VARIABLES_FILE_OPTION, NO_PACKAGE_OPTION, NO_REBUILD_OPTION, @@ -66,6 +67,7 @@ def execute( Useful when core package has extras section and installation of these extras is required.""", ), headers: Optional[List[str]] = HEADERS_OPTION, + jinja_working_directory: Optional[bool] = JINJA_TEMPLATES_WORKING_DIR, jinja_variables_file: Optional[Path] = JINJA_VARIABLES_FILE_OPTION, parameters: Optional[str] = EXECUTE_PARAMETERS_OPTION, debug: Optional[bool] = DEBUG_OPTION, # noqa @@ -86,7 +88,7 @@ def execute( f"on cluster {cluster_name} (id: {cluster_id})" ) - config_reader = ConfigReader(deployment_file, jinja_variables_file) + config_reader = ConfigReader(deployment_file, jinja_variables_file, jinja_working_directory) config = config_reader.with_build_properties( BuildProperties(potential_build=True, no_rebuild=no_rebuild) diff --git a/dbx/options.py b/dbx/options.py index a0938988..eb966330 100644 --- a/dbx/options.py +++ b/dbx/options.py @@ -43,6 +43,15 @@ show_default=False, ) +JINJA_TEMPLATES_WORKING_DIR = typer.Option( + False, + "--jinja-working-dir", + is_flag=True, + help="""Use working directory as base path for Jinja Template discovery. + + If not provided, default behavior is to use the deployment file parent directory.""", +) + JINJA_VARIABLES_FILE_OPTION = typer.Option( None, "--jinja-variables-file", diff --git a/tests/deployment-configs/nested-configs/09-jinja-include.json.j2 b/tests/deployment-configs/nested-configs/09-jinja-include.json.j2 index 1e5abd4f..80dfc113 100644 --- a/tests/deployment-configs/nested-configs/09-jinja-include.json.j2 +++ b/tests/deployment-configs/nested-configs/09-jinja-include.json.j2 @@ -1,14 +1,19 @@ { "default": { - "jobs": [ + "workflows": [ { "name": "your-job-name", - "new_cluster": {% include 'includes/cluster-test.json.j2' %}, - "libraries": [], - "max_retries": 0, - "spark_python_task": { - "python_file": "file://placeholder_1.py" - } + "tasks": [ + { + "name": "main", + "new_cluster": {% include 'conf/includes/cluster-test.json.j2' %}, + "max_retries": 0, + "notebook_task": { + "notebook_path": "/Repos/some/notebook" + } + } + ], + "workflow_type": "pipeline" } ] } diff --git a/tests/unit/commands/test_deploy.py b/tests/unit/commands/test_deploy.py index 64bb9bad..a047c9bc 100644 --- a/tests/unit/commands/test_deploy.py +++ b/tests/unit/commands/test_deploy.py @@ -6,6 +6,7 @@ import pytest import yaml from pytest_mock import MockerFixture +from jinja2.exceptions import TemplateNotFound from dbx.api.config_reader import ConfigReader from dbx.api.configure import EnvironmentInfo, ProjectConfigurationManager @@ -301,3 +302,38 @@ def test_deploy_additional_headers(mocker: MockerFixture): assert deploy_result.exit_code == 0 header_parse_mock.assert_any_call(kwargs) env_mock.assert_called_once_with("default", expected_headers) + +def test_jinja_working_dir(mlflow_file_uploader, mock_storage_io, mock_api_v2_client, temp_project: Path): + project_config_dir = temp_project / "conf" + workflows_dir = temp_project / "conf" / "workflows" + includes_dir = temp_project / "conf" / "includes" + + deployment_file_name = "09-jinja-include.json.j2" + src_deployment_file = get_path_with_relation_to_current_file( + f"../deployment-configs/nested-configs/{deployment_file_name}" + ) + dst_deployment_file = project_config_dir / "workflows" / deployment_file_name + + include_file = "cluster-test.json.j2" + src_include_file = get_path_with_relation_to_current_file( + f"../deployment-configs/nested-configs/includes/{include_file}" + ) + dst_include_file = project_config_dir / "includes" / include_file + + + shutil.rmtree(project_config_dir) + project_config_dir.mkdir() + workflows_dir.mkdir() + includes_dir.mkdir() + shutil.copy(src_deployment_file, dst_deployment_file) + shutil.copy(src_include_file, dst_include_file) + + with pytest.raises(TemplateNotFound): + deploy_result = invoke_cli_runner( + ["deploy", f"--deployment-file", str(dst_deployment_file)] + ) + + deploy_result = invoke_cli_runner( + ["deploy", f"--deployment-file", str(dst_deployment_file), "--jinja-working-dir"] + ) + assert deploy_result.exit_code == 0 \ No newline at end of file