From f55bec13012463bf300d3ba588b0fe56e9feca5a Mon Sep 17 00:00:00 2001 From: Barry Hart Date: Wed, 7 Feb 2024 11:52:11 -0500 Subject: [PATCH] Add SQLFluff /format endpoint to the server (#4) * Add SQLFluff /format endpoint to the server * Add an automated test for /format * Bug fix * Bug fix * Update "/format" to also support linting a string * Add some SQLFluff logging and enable server logging * Fix logging issue where "Before fixing" modified time is incorrect --- src/dbt_core_interface/project.py | 85 +++++++++++- src/dbt_core_interface/sqlfluff_util.py | 154 ++++++++++++++++++++- tests/sqlfluff_templater/test_server_v2.py | 75 +++++++++- 3 files changed, 306 insertions(+), 8 deletions(-) diff --git a/src/dbt_core_interface/project.py b/src/dbt_core_interface/project.py index 375dbb9..43e92be 100644 --- a/src/dbt_core_interface/project.py +++ b/src/dbt_core_interface/project.py @@ -119,9 +119,10 @@ try: import dbt_core_interface.state as dci_state - from dbt_core_interface.sqlfluff_util import lint_command + from dbt_core_interface.sqlfluff_util import format_command, lint_command except ImportError: dci_state = None + format_command = None lint_command = None from agate import Table, Number, Text, Column @@ -6261,6 +6262,80 @@ def lint_sql( lint_result = {"result": [error for error in result]} return lint_result +if format_command: + + @route("/format", method="POST") + def format_sql( + runners: DbtProjectContainer, + ): + LOGGER.info(f"format_sql()") + # Project Support + project_runner = ( + runners.get_project(request.get_header("X-dbt-Project")) + or runners.get_default_project() + ) + LOGGER.info(f"got project: {project_runner}") + if not project_runner: + response.status = 400 + return asdict( + ServerErrorContainer( + error=ServerError( + code=ServerErrorCode.ProjectNotRegistered, + message=( + "Project is not registered. 
Make a POST request to the /register" + " endpoint first to register a runner" + ), + data={"registered_projects": runners.registered_projects()}, + ) + ) + ) + + sql_path = request.query.get("sql_path") + LOGGER.info(f"sql_path: {sql_path}") + if sql_path: + # Format a file + # NOTE: The file is fixed on disk; no formatted SQL is returned for it. + LOGGER.info(f"formatting file: {sql_path}") + sql = Path(sql_path) + else: + # Format a string + LOGGER.info(f"formatting string") + sql = request.body.getvalue().decode("utf-8") + if not sql: + response.status = 400 + return { + "error": { + "data": {}, + "message": ( + "No SQL provided. Either provide a SQL file path or a SQL string to lint." + ), + } + } + try: + LOGGER.info(f"Calling format_command()") + temp_result, formatted_sql = format_command( + Path(project_runner.config.project_root), + sql=sql, + extra_config_path=( + Path(request.query.get("extra_config_path")) + if request.query.get("extra_config_path") + else None + ), + ) + except Exception as format_err: + logging.exception("Formatting failed") + response.status = 500 + return { + "error": { + "data": {}, + "message": str(format_err), + } + } + else: + LOGGER.info(f"Formatting succeeded") + format_result = {"result": temp_result, "sql": formatted_sql} + return format_result + def run_server(runner: Optional[DbtProject] = None, host="localhost", port=8581): """Run the dbt core interface server. @@ -6286,7 +6361,13 @@ def run_server(runner: Optional[DbtProject] = None, host="localhost", port=8581) if __name__ == "__main__": import argparse - # logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.ERROR) + + # Configure logging for 'dbt_core_interface' and 'dbt_core_interface.sqlfluff_util' + for logger_name in ['dbt_core_interface', 'dbt_core_interface.sqlfluff_util']: + logger = logging.getLogger(logger_name) + logger.setLevel(logging.INFO) + parser = argparse.ArgumentParser( description="Run the dbt interface server. 
Defaults to the WSGIRefServer" ) diff --git a/src/dbt_core_interface/sqlfluff_util.py b/src/dbt_core_interface/sqlfluff_util.py index fd5b76b..6577f43 100644 --- a/src/dbt_core_interface/sqlfluff_util.py +++ b/src/dbt_core_interface/sqlfluff_util.py @@ -1,14 +1,19 @@ import atexit import logging import os +from datetime import datetime from functools import lru_cache from pathlib import Path -from typing import Dict, Optional, Union +from typing import Dict, Optional, Tuple, Union from sqlfluff.cli.outputstream import FileOutput +from sqlfluff.core import SQLLintError, SQLTemplaterError from sqlfluff.core.config import ConfigLoader, FluffConfig +LOGGER = logging.getLogger(__name__) + + # Cache linters (up to 50 though its arbitrary) @lru_cache(maxsize=50) def get_linter( @@ -17,7 +22,7 @@ def get_linter( ): """Get linter.""" from sqlfluff.cli.commands import get_linter_and_formatter - return get_linter_and_formatter(config, stream)[0] + return get_linter_and_formatter(config, stream) # Cache config to prevent wasted frames @lru_cache(maxsize=50) @@ -81,7 +86,7 @@ def lint_command( but for now this should provide maximum compatibility with the command-line tool. We can also propose changes to SQLFluff to make this easier. """ - lnt = get_linter( + lnt, formatter = get_linter( *get_config( project_root, extra_config_path, @@ -104,6 +109,109 @@ def lint_command( return records[0] if records else None +def format_command( + project_root: Path, + sql: Union[Path, str], + extra_config_path: Optional[Path] = None, + ignore_local_config: bool = False, +) -> Tuple[bool, Optional[str]]: + """Format specified file or SQL string. + + This is essentially a streamlined version of the SQLFluff command-line + format function, sqlfluff.cli.commands.cli_format(). + + This function uses a few SQLFluff internals, but it should be relatively + stable. The initial plan was to use the public API, but that was not + behaving well initially. 
Small details about how SQLFluff handles .sqlfluff + and dbt_project.yml file locations and overrides generate lots of support + questions, so it seems better to use this approach for now. + + Eventually, we can look at using SQLFluff's public, high-level APIs, + but for now this should provide maximum compatibility with the command-line + tool. We can also propose changes to SQLFluff to make this easier. + """ + LOGGER.info(f"""format_command( + {project_root}, + {str(sql)[:100]}, + {extra_config_path}, + {ignore_local_config}) +""") + lnt, formatter = get_linter( + *get_config( + project_root, + extra_config_path, + ignore_local_config, + require_dialect=False, + nocolor=True, + rules=( + # All of the capitalisation rules + "capitalisation," + # All of the layout rules + "layout," + # Safe rules from other groups + "ambiguous.union," + "convention.not_equal," + "convention.coalesce," + "convention.select_trailing_comma," + "convention.is_null," + "jinja.padding," + "structure.distinct," + ) + ) + ) + + if isinstance(sql, str): + # Format SQL passed in as a string + LOGGER.info(f"Formatting SQL string: {sql[:100]}") + result = lnt.lint_string_wrapped(sql, fname="stdin", fix=True) + total_errors, num_filtered_errors = result.count_tmp_prs_errors() + result.discard_fixes_for_lint_errors_in_files_with_tmp_or_prs_errors() + success = not num_filtered_errors + num_fixable = result.num_violations(types=SQLLintError, fixable=True) + if num_fixable > 0: + LOGGER.info(f"Fixing {num_fixable} errors in SQL string") + result_sql = result.paths[0].files[0].fix_string()[0] + LOGGER.info(f"Result string has changes? 
{result_sql != sql}") + else: + LOGGER.info("No fixable errors in SQL string") + result_sql = sql + else: + # Format a SQL file + LOGGER.info(f"Formatting SQL file: {sql}") + before_modified = datetime.fromtimestamp(sql.stat().st_mtime).strftime('%Y-%m-%d %H:%M:%S') + LOGGER.info(f"Before fixing, modified: {before_modified}") + result_sql = None + lint_result = lnt.lint_paths( + paths=[str(sql)], + fix=True, + ignore_non_existent_files=False, + #processes=processes, + # If --force is set, then apply the changes as we go rather + # than waiting until the end. + apply_fixes=True, + #fixed_file_suffix=fixed_suffix, + fix_even_unparsable=False, + ) + total_errors, num_filtered_errors = lint_result.count_tmp_prs_errors() + lint_result.discard_fixes_for_lint_errors_in_files_with_tmp_or_prs_errors() + success = not num_filtered_errors + if success: + num_fixable = lint_result.num_violations(types=SQLLintError, fixable=True) + if num_fixable > 0: + LOGGER.info(f"Fixing {num_fixable} errors in SQL file") + res = lint_result.persist_changes( + formatter=formatter, fixed_file_suffix="" + ) + after_modified = datetime.fromtimestamp(sql.stat().st_mtime).strftime('%Y-%m-%d %H:%M:%S') + LOGGER.info(f"After fixing, modified: {after_modified}") + LOGGER.info(f"File modification time has changes? {before_modified != after_modified}") + success = all(res.values()) + else: + LOGGER.info("No fixable errors in SQL file") + LOGGER.info(f"format_command returning success={success}, result_sql={result_sql[:100] if result_sql is not None else 'n/a'}") + return success, result_sql + + def test_lint_command(): """Quick and dirty functional test for lint_command(). @@ -138,5 +246,43 @@ def test_lint_command(): print(f"{'*'*40} Lint result {'*'*40}") +def test_format_command(): + """Quick and dirty functional test for format_command(). + + Handy for seeing SQLFluff logs if something goes wrong. The automated tests + make it difficult to see the logs. 
+ """ + logging.basicConfig(level=logging.DEBUG) + from dbt_core_interface.project import DbtProjectContainer + dbt = DbtProjectContainer() + dbt.add_project( + name_override="dbt_project", + project_dir="tests/sqlfluff_templater/fixtures/dbt/dbt_project/", + profiles_dir="tests/sqlfluff_templater/fixtures/dbt/profiles_yml/", + target="dev", + ) + sql_path = Path( + "tests/sqlfluff_templater/fixtures/dbt/dbt_project/models/my_new_project/issue_1608.sql" + ) + + # Test formatting a string + success, result_sql = format_command( + Path("tests/sqlfluff_templater/fixtures/dbt/dbt_project"), + sql=sql_path.read_text(), + ) + print(f"{'*'*40} Formatting result {'*'*40}") + print(success, result_sql) + + # Test formatting a file + result = format_command( + Path("tests/sqlfluff_templater/fixtures/dbt/dbt_project"), + sql=sql_path, + ) + print(f"{'*'*40} Formatting result {'*'*40}") + print(result) + print(f"{'*'*40} Formatting result {'*'*40}") + + if __name__ == "__main__": - test_lint_command() + #test_lint_command() + test_format_command() diff --git a/tests/sqlfluff_templater/test_server_v2.py b/tests/sqlfluff_templater/test_server_v2.py index a5a5d5e..2925419 100644 --- a/tests/sqlfluff_templater/test_server_v2.py +++ b/tests/sqlfluff_templater/test_server_v2.py @@ -1,3 +1,6 @@ +import difflib +import os +import shutil import urllib.parse from pathlib import Path @@ -24,8 +27,8 @@ @pytest.mark.parametrize( "param_name, param_value", [ - ("sql_path", SQL_PATH), - (None, SQL_PATH.read_text()), + pytest.param("sql_path", SQL_PATH, id="sql_file"), + pytest.param(None, SQL_PATH.read_text(), id="sql_string"), ], ) def test_lint(param_name, param_value, profiles_dir, project_dir, sqlfluff_config_path, caplog): @@ -94,6 +97,74 @@ def test_lint_parse_failure(profiles_dir, project_dir, sqlfluff_config_path, cap assert response_json == {"result": []} +@pytest.mark.parametrize( + "param_name, param_value", + [ + pytest.param("sql_path", SQL_PATH, id="sql_file"), + 
pytest.param(None, SQL_PATH.read_text(), id="sql_string"), + ], +) +def test_format(param_name, param_value, profiles_dir, project_dir, sqlfluff_config_path, caplog): + if param_name: + # Make a copy of the file and format the copy so we don't modify a file in + # git. + destination_path = param_value.parent / f"{param_value.stem + '_new'}{param_value.suffix}" + shutil.copy(str(param_value), str(destination_path)) + param_value = destination_path + + params = {} + kwargs = {} + data = '' + if param_name: + # Formatting a file + params[param_name] = param_value + original_lines = param_value.read_text().splitlines() + else: + data = param_value + original_lines = param_value.splitlines() + response = client.post( + f"/format?{urllib.parse.urlencode(params)}", + data, + headers={"X-dbt-Project": "dbt_project"}, + **kwargs, + ) + try: + assert response.status_code == 200 + + # Compare "before and after" SQL and verify the expected changes were made. + if param_name: + formatted_lines = destination_path.read_text().splitlines() + else: + formatted_lines = response.json["sql"].splitlines() + differ = difflib.Differ() + diff = list(differ.compare(original_lines, formatted_lines)) + assert diff == [ + " {{ config(materialized='view') }}", + " ", + " with cte_example as (", + "- select 1 as col_name", + "? -\n", + "+ select 1 as col_name", + " ),", + " ", + "- final as", + "+ final as (", + "? ++\n", + "- (", + " select", + " col_name,", + " {{- echo('col_name') -}} as col_name2", + " from", + " cte_example", + " )", + " ", + " select * from final", + ] + finally: + if param_name: + os.unlink(destination_path) + + @pytest.mark.parametrize( "param_name, param_value, clients, sample", [