Skip to content

Commit

Permalink
Add SQLFluff /format endpoint to the server (#4)
Browse files Browse the repository at this point in the history
* Add SQLFluff /format endpoint to the server

* Add an automated test for /format

* Bug fix

* Bug fix

* Update "/format" to also support linting a string

* Add some SQLFluff logging and enable server logging

* Fix logging issue where "Before fixing" modified time is incorrect
  • Loading branch information
barrywhart authored Feb 7, 2024
1 parent c1fcf62 commit f55bec1
Show file tree
Hide file tree
Showing 3 changed files with 306 additions and 8 deletions.
85 changes: 83 additions & 2 deletions src/dbt_core_interface/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,10 @@

try:
import dbt_core_interface.state as dci_state
from dbt_core_interface.sqlfluff_util import lint_command
from dbt_core_interface.sqlfluff_util import format_command, lint_command
except ImportError:
dci_state = None
format_command = None
lint_command = None

from agate import Table, Number, Text, Column
Expand Down Expand Up @@ -6261,6 +6262,80 @@ def lint_sql(
lint_result = {"result": [error for error in result]}
return lint_result

if format_command:

    @route("/format", method="POST")
    def format_sql(
        runners: DbtProjectContainer,
    ):
        """Format SQL with SQLFluff, given either a file path or a SQL string.

        The project is looked up from the ``X-dbt-Project`` request header,
        falling back to the container's default project. When the
        ``sql_path`` query parameter is present it is treated as the path of
        the file to format; otherwise the request body is decoded as UTF-8
        and formatted as a string. An optional ``extra_config_path`` query
        parameter is forwarded to ``format_command``.

        Returns:
            On success, ``{"result": ..., "sql": ...}`` holding the two
            values returned by ``format_command``. If no project can be
            resolved, or no SQL is supplied, the response status is set to
            400 and an error payload is returned. If ``format_command``
            raises, the status is set to 500 and the payload carries the
            exception text.
        """
        LOGGER.info("format_sql()")
        # Project Support
        project_runner = (
            runners.get_project(request.get_header("X-dbt-Project"))
            or runners.get_default_project()
        )
        LOGGER.info("got project: %s", project_runner)
        if not project_runner:
            response.status = 400
            return asdict(
                ServerErrorContainer(
                    error=ServerError(
                        code=ServerErrorCode.ProjectNotRegistered,
                        message=(
                            "Project is not registered. Make a POST request to the /register"
                            " endpoint first to register a runner"
                        ),
                        data={"registered_projects": runners.registered_projects()},
                    )
                )
            )

        sql_path = request.query.get("sql_path")
        LOGGER.info("sql_path: %s", sql_path)
        if sql_path:
            # Format a file: hand format_command a Path so it takes its
            # file-based branch.
            LOGGER.info("formatting file: %s", sql_path)
            sql = Path(sql_path)
        else:
            # Format a string taken from the request body.
            LOGGER.info("formatting string")
            sql = request.body.getvalue().decode("utf-8")
            if not sql:
                response.status = 400
                return {
                    "error": {
                        "data": {},
                        "message": (
                            "No SQL provided. Either provide a SQL file path or a SQL string to format."
                        ),
                    }
                }

        # Read the optional query parameter once instead of twice.
        extra_config_path = request.query.get("extra_config_path")
        try:
            LOGGER.info("Calling format_command()")
            temp_result, formatted_sql = format_command(
                Path(project_runner.config.project_root),
                sql=sql,
                extra_config_path=Path(extra_config_path) if extra_config_path else None,
            )
        except Exception as format_err:
            # Endpoint boundary: report any failure to the client as a 500.
            # Log through the module logger, like the rest of this handler.
            LOGGER.exception("Formatting failed")
            response.status = 500
            return {
                "error": {
                    "data": {},
                    "message": str(format_err),
                }
            }
        else:
            LOGGER.info("Formatting succeeded")
            return {"result": temp_result, "sql": formatted_sql}


def run_server(runner: Optional[DbtProject] = None, host="localhost", port=8581):
"""Run the dbt core interface server.
Expand All @@ -6286,7 +6361,13 @@ def run_server(runner: Optional[DbtProject] = None, host="localhost", port=8581)
if __name__ == "__main__":
import argparse

# logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.ERROR)

# Configure logging for 'dbt_core_interface' and 'dbt_core_interface.sqlfluff_util'
for logger_name in ['dbt_core_interface', 'dbt_core_interface.sqlfluff_util']:
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)

parser = argparse.ArgumentParser(
description="Run the dbt interface server. Defaults to the WSGIRefServer"
)
Expand Down
154 changes: 150 additions & 4 deletions src/dbt_core_interface/sqlfluff_util.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
import atexit
import logging
import os
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, Optional, Tuple, Union

from sqlfluff.cli.outputstream import FileOutput
from sqlfluff.core import SQLLintError, SQLTemplaterError
from sqlfluff.core.config import ConfigLoader, FluffConfig


LOGGER = logging.getLogger(__name__)


# Cache linters (up to 50 though its arbitrary)
@lru_cache(maxsize=50)
def get_linter(
Expand All @@ -17,7 +22,7 @@ def get_linter(
):
"""Get linter."""
from sqlfluff.cli.commands import get_linter_and_formatter
return get_linter_and_formatter(config, stream)[0]
return get_linter_and_formatter(config, stream)

# Cache config to prevent wasted frames
@lru_cache(maxsize=50)
Expand Down Expand Up @@ -81,7 +86,7 @@ def lint_command(
but for now this should provide maximum compatibility with the command-line
tool. We can also propose changes to SQLFluff to make this easier.
"""
lnt = get_linter(
lnt, formatter = get_linter(
*get_config(
project_root,
extra_config_path,
Expand All @@ -104,6 +109,109 @@ def lint_command(
return records[0] if records else None


def format_command(
    project_root: Path,
    sql: Union[Path, str],
    extra_config_path: Optional[Path] = None,
    ignore_local_config: bool = False,
) -> Tuple[bool, Optional[str]]:
    """Format specified file or SQL string.

    This is essentially a streamlined version of the SQLFluff command-line
    format function, sqlfluff.cli.commands.cli_format().

    This function uses a few SQLFluff internals, but it should be relatively
    stable. The initial plan was to use the public API, but that was not
    behaving well initially. Small details about how SQLFluff handles .sqlfluff
    and dbt_project.yaml file locations and overrides generate lots of support
    questions, so it seems better to use this approach for now.

    Eventually, we can look at using SQLFluff's public, high-level APIs,
    but for now this should provide maximum compatibility with the command-line
    tool. We can also propose changes to SQLFluff to make this easier.

    Args:
        project_root: Root directory of the project, passed to get_config().
        sql: A ``str`` holding SQL to format, or a ``Path`` to a SQL file
            that is formatted in place.
        extra_config_path: Optional extra config file, passed to get_config().
        ignore_local_config: Passed through to get_config().

    Returns:
        A ``(success, result_sql)`` tuple. ``success`` is False when
        ``count_tmp_prs_errors()`` reports filtered errors, or (file input
        only) when persisting the fixes fails. ``result_sql`` is the formatted
        text for string input and ``None`` for file input.
    """
    LOGGER.info(f"""format_command(
{project_root},
{str(sql)[:100]},
{extra_config_path},
{ignore_local_config})
""")
    lnt, formatter = get_linter(
        *get_config(
            project_root,
            extra_config_path,
            ignore_local_config,
            require_dialect=False,
            nocolor=True,
            rules=(
                # All of the capitalisation rules
                "capitalisation,"
                # All of the layout rules
                "layout,"
                # Safe rules from other groups
                "ambiguous.union,"
                "convention.not_equal,"
                "convention.coalesce,"
                "convention.select_trailing_comma,"
                "convention.is_null,"
                "jinja.padding,"
                "structure.distinct,"
            ),
        )
    )

    if isinstance(sql, str):
        success, result_sql = _format_sql_string(lnt, sql)
    else:
        # File input is rewritten on disk, so there is no text to hand back.
        success = _format_sql_file(lnt, formatter, sql)
        result_sql = None
    LOGGER.info(f"format_command returning success={success}, result_sql={result_sql[:100] if result_sql is not None else 'n/a'}")
    return success, result_sql


def _format_sql_string(lnt, sql: str) -> Tuple[bool, str]:
    """Format SQL passed in as a string; return (success, resulting SQL)."""
    LOGGER.info(f"Formatting SQL string: {sql[:100]}")
    result = lnt.lint_string_wrapped(sql, fname="stdin", fix=True)
    # Only the filtered-error count is needed to decide success.
    _, num_filtered_errors = result.count_tmp_prs_errors()
    result.discard_fixes_for_lint_errors_in_files_with_tmp_or_prs_errors()
    success = not num_filtered_errors
    num_fixable = result.num_violations(types=SQLLintError, fixable=True)
    if num_fixable > 0:
        LOGGER.info("Fixing %s errors in SQL string", num_fixable)
        result_sql = result.paths[0].files[0].fix_string()[0]
        LOGGER.info("Result string has changes? %s", result_sql != sql)
    else:
        LOGGER.info("No fixable errors in SQL string")
        result_sql = sql
    return success, result_sql


def _format_sql_file(lnt, formatter, sql: Path) -> bool:
    """Format the SQL file at ``sql`` in place; return whether it succeeded."""
    timestamp_format = "%Y-%m-%d %H:%M:%S"
    LOGGER.info(f"Formatting SQL file: {sql}")
    stat_before = sql.stat()
    before_modified = datetime.fromtimestamp(stat_before.st_mtime).strftime(timestamp_format)
    LOGGER.info(f"Before fixing, modified: {before_modified}")
    lint_result = lnt.lint_paths(
        paths=[str(sql)],
        fix=True,
        ignore_non_existent_files=False,
        # Apply the changes as we go rather than waiting until the end.
        # NOTE(review): fixes are requested here and again through
        # persist_changes() below -- confirm both are required.
        apply_fixes=True,
        fix_even_unparsable=False,
    )
    _, num_filtered_errors = lint_result.count_tmp_prs_errors()
    lint_result.discard_fixes_for_lint_errors_in_files_with_tmp_or_prs_errors()
    if num_filtered_errors:
        return False

    num_fixable = lint_result.num_violations(types=SQLLintError, fixable=True)
    if num_fixable <= 0:
        LOGGER.info("No fixable errors in SQL file")
        return True

    LOGGER.info("Fixing %s errors in SQL file", num_fixable)
    res = lint_result.persist_changes(formatter=formatter, fixed_file_suffix="")
    stat_after = sql.stat()
    after_modified = datetime.fromtimestamp(stat_after.st_mtime).strftime(timestamp_format)
    LOGGER.info(f"After fixing, modified: {after_modified}")
    # Compare raw nanosecond timestamps: the formatted strings only have
    # one-second resolution and would hide a rewrite within the same second.
    LOGGER.info(
        "File modification time has changes? %s",
        stat_before.st_mtime_ns != stat_after.st_mtime_ns,
    )
    return all(res.values())


def test_lint_command():
"""Quick and dirty functional test for lint_command().
Expand Down Expand Up @@ -138,5 +246,43 @@ def test_lint_command():
print(f"{'*'*40} Lint result {'*'*40}")


def test_format_command():
    """Quick and dirty functional test for format_command().

    Handy for seeing SQLFluff logs if something goes wrong. The automated tests
    make it difficult to see the logs.

    The file-based run formats a sibling copy of the fixture (removed
    afterwards) so the fixture itself is not rewritten.
    """
    import shutil

    logging.basicConfig(level=logging.DEBUG)
    from dbt_core_interface.project import DbtProjectContainer

    dbt = DbtProjectContainer()
    dbt.add_project(
        name_override="dbt_project",
        project_dir="tests/sqlfluff_templater/fixtures/dbt/dbt_project/",
        profiles_dir="tests/sqlfluff_templater/fixtures/dbt/profiles_yml/",
        target="dev",
    )
    project_root = Path("tests/sqlfluff_templater/fixtures/dbt/dbt_project")
    sql_path = Path(
        "tests/sqlfluff_templater/fixtures/dbt/dbt_project/models/my_new_project/issue_1608.sql"
    )
    banner = f"{'*'*40} Formatting result {'*'*40}"

    # Test formatting a string
    success, result_sql = format_command(
        project_root,
        sql=sql_path.read_text(),
    )
    print(banner)
    print(success, result_sql)

    # Test formatting a file. Work on a copy placed next to the fixture so
    # the original file is left untouched.
    scratch_path = sql_path.parent / f"{sql_path.stem}_new{sql_path.suffix}"
    shutil.copy(str(sql_path), str(scratch_path))
    try:
        result = format_command(
            project_root,
            sql=scratch_path,
        )
    finally:
        scratch_path.unlink()
    print(banner)
    print(result)
    print(banner)


if __name__ == "__main__":
    # Manual smoke test. Call test_lint_command() here instead to exercise
    # the lint path.
    test_format_command()
75 changes: 73 additions & 2 deletions tests/sqlfluff_templater/test_server_v2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import difflib
import os
import shutil
import urllib.parse
from pathlib import Path

Expand All @@ -24,8 +27,8 @@
@pytest.mark.parametrize(
"param_name, param_value",
[
("sql_path", SQL_PATH),
(None, SQL_PATH.read_text()),
pytest.param("sql_path", SQL_PATH, id="sql_file"),
pytest.param(None, SQL_PATH.read_text(), id="sql_string"),
],
)
def test_lint(param_name, param_value, profiles_dir, project_dir, sqlfluff_config_path, caplog):
Expand Down Expand Up @@ -94,6 +97,74 @@ def test_lint_parse_failure(profiles_dir, project_dir, sqlfluff_config_path, cap
assert response_json == {"result": []}


@pytest.mark.parametrize(
    "param_name, param_value",
    [
        pytest.param("sql_path", SQL_PATH, id="sql_file"),
        pytest.param(None, SQL_PATH.read_text(), id="sql_string"),
    ],
)
def test_format(param_name, param_value, profiles_dir, project_dir, sqlfluff_config_path, caplog):
    """POST to /format with a file path or a SQL string and check the diff.

    For the file case the fixture is copied first and the copy is formatted,
    so the file tracked in git is never modified. The copy is removed even if
    the request or an assertion fails. The fixture arguments are not used
    directly in the body.
    """
    destination_path = None
    if param_name:
        # Make a copy of the file and format the copy so we don't modify a
        # file in git.
        destination_path = param_value.parent / f"{param_value.stem + '_new'}{param_value.suffix}"
        shutil.copy(str(param_value), str(destination_path))
        param_value = destination_path

    # Everything after the copy runs under try/finally so the copy cannot be
    # left behind in the fixtures directory.
    try:
        params = {}
        data = ''
        if param_name:
            # Formatting a file
            params[param_name] = param_value
            original_lines = param_value.read_text().splitlines()
        else:
            # Formatting a string sent as the request body
            data = param_value
            original_lines = param_value.splitlines()
        response = client.post(
            f"/format?{urllib.parse.urlencode(params)}",
            data,
            headers={"X-dbt-Project": "dbt_project"},
        )
        assert response.status_code == 200

        # Compare "before and after" SQL and verify the expected changes were made.
        if param_name:
            formatted_lines = destination_path.read_text().splitlines()
        else:
            formatted_lines = response.json["sql"].splitlines()
        differ = difflib.Differ()
        diff = list(differ.compare(original_lines, formatted_lines))
        # NOTE(review): difflib.Differ prefixes every line with a
        # two-character code, so the exact whitespace inside these literals
        # matters -- confirm it against real output.
        assert diff == [
            " {{ config(materialized='view') }}",
            " ",
            " with cte_example as (",
            "- select 1 as col_name",
            "? -\n",
            "+ select 1 as col_name",
            " ),",
            " ",
            "- final as",
            "+ final as (",
            "? ++\n",
            "- (",
            " select",
            " col_name,",
            " {{- echo('col_name') -}} as col_name2",
            " from",
            " cte_example",
            " )",
            " ",
            " select * from final",
        ]
    finally:
        if destination_path is not None:
            os.unlink(destination_path)


@pytest.mark.parametrize(
"param_name, param_value, clients, sample",
[
Expand Down

0 comments on commit f55bec1

Please sign in to comment.