Skip to content

Commit

Permalink
Make fixer diagnostic codes unique (#3582)
Browse files Browse the repository at this point in the history
## Changes
Make fixer diagnostic codes unique so that the right fixer can be found
for code migration/fixing.

### Linked issues

Progresses #3514
Breaks up #3520

### Functionality

- [x] modified existing command: `databricks labs ucx
migrate-local-code`

### Tests

- [ ] manually tested
- [x] modified and added unit tests
- [x] modified and added integration tests
  • Loading branch information
JCZuurmond authored Feb 3, 2025
1 parent 310d9ff commit f8bf94c
Show file tree
Hide file tree
Showing 30 changed files with 334 additions and 202 deletions.
4 changes: 3 additions & 1 deletion docs/ucx/docs/dev/contributing.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,9 @@ rdd-in-shared-clusters
spark-logging-in-shared-clusters
sql-parse-error
sys-path-cannot-compute-value
table-migrated-to-uc
table-migrated-to-uc-python
table-migrated-to-uc-python-sql
table-migrated-to-uc-sql
to-json-in-shared-clusters
unsupported-magic-line
```
Expand Down
14 changes: 9 additions & 5 deletions docs/ucx/docs/reference/linter_codes.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ spark.table(f"foo_{some_table_name}")

We even detect string constants when coming either from `dbutils.widgets.get` (via job named parameters) or through
loop variables. If `old.things` table is migrated to `brand.new.stuff` in Unity Catalog, the following code will
trigger two messages: [`table-migrated-to-uc`](#table-migrated-to-uc) for the first query, as the contents are clearly
analysable, and `cannot-autofix-table-reference` for the second query.
trigger two messages: [`table-migrated-to-uc-{sql,python,python-sql}`](#table-migrated-to-uc-sqlpythonpython-sql) for
the first query, as the contents are clearly analysable, and `cannot-autofix-table-reference` for the second query.

```python
# ucx[table-migrated-to-uc:+4:4:+4:20] Table old.things is migrated to brand.new.stuff in Unity Catalog
# ucx[table-migrated-to-uc-python-sql:+4:4:+4:20] Table old.things is migrated to brand.new.stuff in Unity Catalog
# ucx[cannot-autofix-table-reference:+3:4:+3:20] Can't migrate table_name argument in 'spark.sql(query)' because its value cannot be computed
table_name = f"table_{index}"
for query in ["SELECT * FROM old.things", f"SELECT * FROM {table_name}"]:
Expand Down Expand Up @@ -247,12 +247,16 @@ analysis where the path is located.



## `table-migrated-to-uc`
## `table-migrated-to-uc-{sql,python,python-sql}`

This message indicates that the linter has found a table that has been migrated to Unity Catalog. The user must ensure
that the table is available in Unity Catalog.


| Postfix | Explanation |
|------------|-------------------------------------------------|
| sql | Table reference in SparkSQL |
| python | Table reference in PySpark |
| python-sql | Table reference in SparkSQL called from PySpark |

## `to-json-in-shared-clusters`

Expand Down
7 changes: 6 additions & 1 deletion src/databricks/labs/ucx/source_code/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,12 @@ class Fixer(ABC):

@property
@abstractmethod
def name(self) -> str: ...
def diagnostic_code(self) -> str:
"""The diagnostic code that this fixer fixes."""

def is_supported(self, diagnostic_code: str) -> bool:
"""Indicate if the diagnostic code is supported by this fixer."""
return self.diagnostic_code is not None and diagnostic_code == self.diagnostic_code

@abstractmethod
def apply(self, code: str) -> str: ...
Expand Down
21 changes: 14 additions & 7 deletions src/databricks/labs/ucx/source_code/linters/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from databricks.labs.ucx.source_code.linters.imports import DbutilsPyLinter

from databricks.labs.ucx.source_code.linters.pyspark import (
SparkSqlPyLinter,
DirectFsAccessSqlPylinter,
FromTableSqlPyLinter,
SparkTableNamePyLinter,
SparkSqlTablePyCollector,
)
Expand Down Expand Up @@ -57,7 +58,7 @@ def __init__(
sql_linters.append(from_table)
sql_fixers.append(from_table)
sql_table_collectors.append(from_table)
spark_sql = SparkSqlPyLinter(from_table, from_table)
spark_sql = FromTableSqlPyLinter(from_table)
python_linters.append(spark_sql)
python_fixers.append(spark_sql)
python_table_collectors.append(SparkSqlTablePyCollector(from_table))
Expand All @@ -75,7 +76,7 @@ def __init__(
DBRv8d0PyLinter(dbr_version=session_state.dbr_version),
SparkConnectPyLinter(session_state),
DbutilsPyLinter(session_state),
SparkSqlPyLinter(sql_direct_fs, None),
DirectFsAccessSqlPylinter(sql_direct_fs),
]

python_dfsa_collectors += [DirectFsAccessPyLinter(session_state, prevent_spark_duplicates=False)]
Expand Down Expand Up @@ -112,10 +113,16 @@ def linter(self, language: Language) -> Linter:
raise ValueError(f"Unsupported language: {language}")

def fixer(self, language: Language, diagnostic_code: str) -> Fixer | None:
    """Get the fixer for a language that matches the diagnostic code.

    The first fixer whose diagnostic code matches is returned; this logic
    assumes the fixers have unique diagnostic codes.

    Returns:
        Fixer | None : The fixer if a match is found, otherwise None.
    """
    # A language with no registered fixers yields an empty list, i.e. no match.
    for fixer in self._fixers.get(language, []):
        if fixer.is_supported(diagnostic_code):
            return fixer
    return None

Expand Down
7 changes: 4 additions & 3 deletions src/databricks/labs/ucx/source_code/linters/from_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ def __init__(self, index: TableMigrationIndex, session_state: CurrentSessionStat
self._session_state: CurrentSessionState = session_state

@property
def name(self) -> str:
return 'table-migrate'
def diagnostic_code(self) -> str:
"""The diagnostic codes that this fixer fixes."""
return "table-migrated-to-uc-sql"

@property
def schema(self) -> str:
Expand All @@ -58,7 +59,7 @@ def lint_expression(self, expression: Expression) -> Iterable[Deprecation]:
if not dst:
return
yield Deprecation(
code='table-migrated-to-uc',
code="table-migrated-to-uc-sql",
message=f"Table {info.schema_name}.{info.table_name} is migrated to {dst.destination()} in Unity Catalog",
# SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
start_line=0,
Expand Down
84 changes: 70 additions & 14 deletions src/databricks/labs/ucx/source_code/linters/pyspark.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import logging
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator
Expand All @@ -18,7 +19,11 @@
TableSqlCollector,
DfsaSqlCollector,
)
from databricks.labs.ucx.source_code.linters.directfs import DIRECT_FS_ACCESS_PATTERNS, DirectFsAccessNode
from databricks.labs.ucx.source_code.linters.directfs import (
DIRECT_FS_ACCESS_PATTERNS,
DirectFsAccessNode,
DirectFsAccessSqlLinter,
)
from databricks.labs.ucx.source_code.python.python_infer import InferredValue
from databricks.labs.ucx.source_code.linters.from_table import FromTableSqlLinter
from databricks.labs.ucx.source_code.python.python_ast import (
Expand Down Expand Up @@ -155,7 +160,7 @@ def lint(
if dst is None:
continue
yield Deprecation.from_node(
code='table-migrated-to-uc',
code='table-migrated-to-uc-python',
message=f"Table {used_table[0]} is migrated to {dst.destination()} in Unity Catalog",
# SQLGlot does not propagate tokens yet. See https://github.com/tobymao/sqlglot/issues/3159
node=node,
Expand Down Expand Up @@ -387,6 +392,14 @@ def matchers(self) -> dict[str, _TableNameMatcher]:


class SparkTableNamePyLinter(PythonLinter, Fixer, TablePyCollector):
"""Linter for table name references in PySpark
Examples:
1. Find table name references
``` python
spark.read.table("hive_metastore.schema.table")
```
"""

def __init__(
self,
Expand All @@ -400,9 +413,9 @@ def __init__(
self._spark_matchers = SparkTableNameMatchers(False).matchers

@property
def name(self) -> str:
# this is the same fixer, just in a different language context
return self._from_table.name
def diagnostic_code(self) -> str:
"""The diagnostic codes that this fixer fixes."""
return "table-migrated-to-uc-python"

def lint_tree(self, tree: Tree) -> Iterable[Advice]:
for node in tree.walk():
Expand Down Expand Up @@ -461,28 +474,32 @@ def _visit_call_nodes(cls, tree: Tree) -> Iterable[tuple[Call, NodeNG]]:
yield call_node, query


class SparkSqlPyLinter(_SparkSqlAnalyzer, PythonLinter, Fixer):
class _SparkSqlPyLinter(_SparkSqlAnalyzer, PythonLinter, Fixer):
"""Linter for SparkSQL used within PySpark."""

def __init__(self, sql_linter: SqlLinter, sql_fixer: Fixer | None):
self._sql_linter = sql_linter
self._sql_fixer = sql_fixer

@property
def name(self) -> str:
return "<none>" if self._sql_fixer is None else self._sql_fixer.name

def lint_tree(self, tree: Tree) -> Iterable[Advice]:
inferable_values = []
for call_node, query in self._visit_call_nodes(tree):
for value in InferredValue.infer_from_node(query):
if not value.is_inferred():
if value.is_inferred():
inferable_values.append((call_node, value))
else:
yield Advisory.from_node(
code="cannot-autofix-table-reference",
message=f"Can't migrate table_name argument in '{query.as_string()}' because its value cannot be computed",
node=call_node,
)
continue
for advice in self._sql_linter.lint(value.as_string()):
yield advice.replace_from_node(call_node)
for call_node, value in inferable_values:
for advice in self._sql_linter.lint(value.as_string()):
# Replacing the fixer code to indicate that the SparkSQL fixer is wrapped with PySpark
code = advice.code
if self._sql_fixer and code == self._sql_fixer.diagnostic_code:
code = self.diagnostic_code
yield dataclasses.replace(advice.replace_from_node(call_node), code=code)

def apply(self, code: str) -> str:
if not self._sql_fixer:
Expand All @@ -503,6 +520,45 @@ def apply(self, code: str) -> str:
return tree.node.as_string()


class FromTableSqlPyLinter(_SparkSqlPyLinter):
    """Lint tables and views in Spark SQL wrapped by PySpark code.

    Examples:
    1. Find table name reference in SparkSQL:
    ``` python
    spark.sql("SELECT * FROM hive_metastore.schema.table").collect()
    ```
    """

    def __init__(self, sql_linter: FromTableSqlLinter):
        # The SQL linter doubles as the fixer: it both detects and rewrites
        # migrated table references.
        super().__init__(sql_linter, sql_linter)

    @property
    def diagnostic_code(self) -> str:
        """The diagnostic code that this fixer fixes."""
        return "table-migrated-to-uc-python-sql"


class DirectFsAccessSqlPylinter(_SparkSqlPyLinter):
    """Lint direct file system access in Spark SQL wrapped by PySpark code.

    Examples:
    1. Find a direct filesystem access in SparkSQL:
    ``` python
    spark.sql("SELECT * FROM parquet.`/dbfs/path/to/table`").collect()
    ```
    """

    def __init__(self, sql_linter: DirectFsAccessSqlLinter):
        # TODO: Implement fixer for direct filesystem access (https://github.com/databrickslabs/ucx/issues/2021)
        super().__init__(sql_linter, None)

    @property
    def diagnostic_code(self) -> str:
        """The diagnostic code that this fixer fixes."""
        return "direct-filesystem-access-python-sql"


class SparkSqlDfsaPyCollector(_SparkSqlAnalyzer, DfsaPyCollector):

def __init__(self, sql_collector: DfsaSqlCollector):
Expand Down
66 changes: 47 additions & 19 deletions tests/integration/source_code/message_codes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from collections.abc import Iterable
from pathlib import Path

import astroid # type: ignore[import-untyped]
from databricks.labs.blueprint.wheels import ProductInfo

Expand All @@ -6,30 +9,55 @@


def main():
    """Walk the UCX code base to find all diagnostic linting codes."""
    codes = set[str]()
    product_info = ProductInfo.from_class(Advice)
    # The directory holding the version file is the package root; scan every
    # Python module below it.
    source_code = product_info.version_file().parent
    for path in source_code.glob("**/*.py"):
        codes.update(_find_diagnostic_codes(path))
    # Sorted output keeps the listing stable across runs.
    for code in sorted(codes):
        print(code)


def _find_diagnostic_codes(file: Path) -> Iterable[str]:
    """Yield every diagnostic code declared in the given Python source file."""
    maybe_tree = MaybeTree.from_source_code(file.read_text())
    tree = maybe_tree.tree
    if not tree:  # unparsable source: nothing to report
        return
    for node in tree.walk():
        # Codes appear either as a `diagnostic_code` method on a fixer class
        # or as a `code=` keyword argument in an Advice-style call.
        if isinstance(node, astroid.ClassDef):
            code = _find_diagnostic_code_in_class_def(node)
        elif isinstance(node, astroid.Call):
            code = _find_diagnostic_code_in_call(node)
        else:
            code = None
        if code:
            yield code


def _find_diagnostic_code_in_call(node: astroid.Call) -> str | None:
    """Return the constant ``code=`` keyword argument of a call, if any."""
    for keyword in node.keywords:
        if keyword.arg != "code":
            continue
        if isinstance(keyword.value, astroid.Const):
            return keyword.value.value
    return None


def _find_diagnostic_code_in_class_def(node: astroid.ClassDef) -> str | None:
    """Find the diagnostic code returned by a class's ``diagnostic_code`` method.

    Inspects the first ``diagnostic_code`` method defined directly on the class
    and returns the string constant its first statement returns, or None when
    there is no such method or it does not return a plain constant.
    """
    # NOTE(review): assumes astroid strips docstrings out of FunctionDef.body,
    # so body[0] is the `return` statement itself — confirm against astroid docs.
    for child_node in node.body:
        if not (isinstance(child_node, astroid.FunctionDef) and child_node.name == "diagnostic_code"):
            continue
        if not child_node.body:  # e.g. an abstract declaration with no body
            return None
        statement = child_node.body[0]
        # Guard against methods whose first statement is not `return "<const>"`;
        # the unguarded `statement.value.value` would raise AttributeError.
        if isinstance(statement, astroid.Return) and isinstance(statement.value, astroid.Const):
            return statement.value.value
        return None
    return None


if __name__ == "__main__":
main()


def test_main():
main()
Loading

0 comments on commit f8bf94c

Please sign in to comment.