From b9c497e6212dd59c0543a2f2c84fedc942c8e699 Mon Sep 17 00:00:00 2001 From: Hiroyuki Nakazato Date: Fri, 25 Apr 2025 18:34:47 +0900 Subject: [PATCH] Add sql2dbx: LLM-powered SQL to Databricks notebook converter --- sql2dbx/.gitignore | 173 +++++ sql2dbx/README.md | 24 + ...xample1_multi_statement_transformation.sql | 51 ++ .../input/mysql_example2_stored_procedure.sql | 54 ++ ...example1_multi_statement_transformation.py | 89 +++ .../output/mysql_example2_stored_procedure.py | 92 +++ ...xample1_multi_statement_transformation.sql | 53 ++ .../netezza_example2_stored_procedure.sql | 60 ++ ...example1_multi_statement_transformation.py | 87 +++ .../netezza_example2_stored_procedure.py | 79 ++ ...xample1_multi_statement_transformation.sql | 59 ++ .../oracle_example2_stored_procedure.sql | 76 ++ ...example1_multi_statement_transformation.py | 92 +++ .../oracle_example2_stored_procedure.py | 91 +++ ...xample1_multi_statement_transformation.sql | 53 ++ .../postgresql_example2_stored_procedure.sql | 89 +++ ...example1_multi_statement_transformation.py | 88 +++ .../postgresql_example2_stored_procedure.py | 92 +++ ...xample1_multi_statement_transformation.sql | 55 ++ .../redshift_example2_stored_procedure.sql | 57 ++ ...example1_multi_statement_transformation.py | 89 +++ .../redshift_example2_stored_procedure.py | 85 +++ ...xample1_multi_statement_transformation.sql | 57 ++ ...nowflake_example2_javascript_procedure.sql | 111 +++ ...example1_multi_statement_transformation.py | 111 +++ ...snowflake_example2_javascript_procedure.py | 112 +++ ...xample1_multi_statement_transformation.sql | 58 ++ .../teradata_example2_stored_procedure.sql | 53 ++ ...example1_multi_statement_transformation.py | 109 +++ .../teradata_example2_stored_procedure.py | 98 +++ ...xample1_multi_statement_transformation.sql | 52 ++ .../input/tsql_example2_stored_procedure.sql | 111 +++ ...example1_multi_statement_transformation.py | 96 +++ .../output/tsql_example2_stored_procedure.py | 135 ++++ sql2dbx/notebooks/00_main.py | 706 +++++++++++++++++ sql2dbx/notebooks/00_main_ja.py | 703 +++++++++++++++++ sql2dbx/notebooks/01_analyze_input_files.py | 237 ++++++ .../notebooks/02_convert_sql_to_databricks.py | 248 ++++++ .../notebooks/03_01_static_syntax_check.py | 166 ++++ sql2dbx/notebooks/03_02_fix_syntax_error.py | 258 +++++++ sql2dbx/notebooks/04_split_cells.py | 114 +++ .../05_export_to_databricks_notebooks.py | 134 ++++ .../notebooks/11_adjust_conversion_targets.py | 89 +++ sql2dbx/notebooks/__init__.py | 0 .../external_model_amazon_bedrock.py | 95 +++ .../external_model_azure_openai.py | 99 +++ .../external_model/helper_external_model.py | 97 +++ sql2dbx/notebooks/notebook_utils.py | 254 +++++++ sql2dbx/notebooks/pyscripts/__init__.py | 0 .../pyscripts/batch_inference_helper.py | 718 ++++++++++++++++++ .../notebooks/pyscripts/cell_split_helper.py | 311 ++++++++ .../pyscripts/conversion_prompt_helper.py | 224 ++++++ ...l_to_databricks_notebook_common_python.yml | 95 +++ .../mysql_to_databricks_notebook.yml | 240 ++++++ .../netezza_to_databricks_notebook.yml | 239 ++++++ .../oracle_to_databricks_notebook.yml | 238 ++++++ .../postgresql_to_databricks_notebook.yml | 233 ++++++ .../redshift_to_databricks_notebook.yml | 233 ++++++ .../snowflake_to_databricks_notebook.yml | 257 +++++++ .../teradata_to_databricks_notebook.yml | 224 ++++++ .../tsql_to_databricks_notebook.yml | 251 ++++++ .../conversion_result_clean_helper.py | 53 ++ .../pyscripts/databricks_credentials.py | 77 ++ .../pyscripts/llm_token_count_helper.py | 107 +++ 
.../pyscripts/notebook_export_helper.py | 187 +++++ sql2dbx/notebooks/pyscripts/notebook_i18n.py | 233 ++++++ .../pyscripts/spark_sql_extract_helper.py | 167 ++++ sql2dbx/notebooks/pyscripts/token_utils.py | 134 ++++ sql2dbx/notebooks/pyscripts/utils.py | 98 +++ sql2dbx/notebooks/requirements.txt | 8 + sql2dbx/resources/sql2dbx_job_main.yml | 45 ++ sql2dbx/tests/__init__.py | 0 sql2dbx/tests/pyscripts/__init__.py | 0 .../tests/pyscripts/test_cell_split_helper.py | 147 ++++ .../test_conversion_prompt_helper.py | 174 +++++ .../test_conversion_result_clean_helper.py | 142 ++++ .../pyscripts/test_llm_token_count_helper.py | 96 +++ .../pyscripts/test_notebook_export_helper.py | 130 ++++ sql2dbx/tests/pyscripts/test_notebook_i18n.py | 106 +++ .../test_spark_sql_extract_helper.py | 169 +++++ sql2dbx/tests/pyscripts/test_token_utils.py | 114 +++ sql2dbx/tests/pyscripts/test_utils.py | 95 +++ 82 files changed, 11536 insertions(+) create mode 100644 sql2dbx/.gitignore create mode 100644 sql2dbx/README.md create mode 100644 sql2dbx/examples/mysql/input/mysql_example1_multi_statement_transformation.sql create mode 100644 sql2dbx/examples/mysql/input/mysql_example2_stored_procedure.sql create mode 100644 sql2dbx/examples/mysql/output/mysql_example1_multi_statement_transformation.py create mode 100644 sql2dbx/examples/mysql/output/mysql_example2_stored_procedure.py create mode 100644 sql2dbx/examples/netezza/input/netezza_example1_multi_statement_transformation.sql create mode 100644 sql2dbx/examples/netezza/input/netezza_example2_stored_procedure.sql create mode 100644 sql2dbx/examples/netezza/output/netezza_example1_multi_statement_transformation.py create mode 100644 sql2dbx/examples/netezza/output/netezza_example2_stored_procedure.py create mode 100644 sql2dbx/examples/oracle/input/oracle_example1_multi_statement_transformation.sql create mode 100644 sql2dbx/examples/oracle/input/oracle_example2_stored_procedure.sql create mode 100644 sql2dbx/examples/oracle/output/oracle_example1_multi_statement_transformation.py create mode 100644 sql2dbx/examples/oracle/output/oracle_example2_stored_procedure.py create mode 100644 sql2dbx/examples/postgresql/input/postgresql_example1_multi_statement_transformation.sql create mode 100644 sql2dbx/examples/postgresql/input/postgresql_example2_stored_procedure.sql create mode 100644 sql2dbx/examples/postgresql/output/postgresql_example1_multi_statement_transformation.py create mode 100644 sql2dbx/examples/postgresql/output/postgresql_example2_stored_procedure.py create mode 100644 sql2dbx/examples/redshift/input/redshift_example1_multi_statement_transformation.sql create mode 100644 sql2dbx/examples/redshift/input/redshift_example2_stored_procedure.sql create mode 100644 sql2dbx/examples/redshift/output/redshift_example1_multi_statement_transformation.py create mode 100644 sql2dbx/examples/redshift/output/redshift_example2_stored_procedure.py create mode 100644 sql2dbx/examples/snowflake/input/snowflake_example1_multi_statement_transformation.sql create mode 100644 sql2dbx/examples/snowflake/input/snowflake_example2_javascript_procedure.sql create mode 100644 sql2dbx/examples/snowflake/output/snowflake_example1_multi_statement_transformation.py create mode 100644 sql2dbx/examples/snowflake/output/snowflake_example2_javascript_procedure.py create mode 100644 sql2dbx/examples/teradata/input/teradata_example1_multi_statement_transformation.sql create mode 100644 sql2dbx/examples/teradata/input/teradata_example2_stored_procedure.sql create mode 100644 
sql2dbx/examples/teradata/output/teradata_example1_multi_statement_transformation.py create mode 100644 sql2dbx/examples/teradata/output/teradata_example2_stored_procedure.py create mode 100644 sql2dbx/examples/tsql/input/tsql_example1_multi_statement_transformation.sql create mode 100644 sql2dbx/examples/tsql/input/tsql_example2_stored_procedure.sql create mode 100644 sql2dbx/examples/tsql/output/tsql_example1_multi_statement_transformation.py create mode 100644 sql2dbx/examples/tsql/output/tsql_example2_stored_procedure.py create mode 100644 sql2dbx/notebooks/00_main.py create mode 100644 sql2dbx/notebooks/00_main_ja.py create mode 100644 sql2dbx/notebooks/01_analyze_input_files.py create mode 100644 sql2dbx/notebooks/02_convert_sql_to_databricks.py create mode 100644 sql2dbx/notebooks/03_01_static_syntax_check.py create mode 100644 sql2dbx/notebooks/03_02_fix_syntax_error.py create mode 100644 sql2dbx/notebooks/04_split_cells.py create mode 100644 sql2dbx/notebooks/05_export_to_databricks_notebooks.py create mode 100644 sql2dbx/notebooks/11_adjust_conversion_targets.py create mode 100644 sql2dbx/notebooks/__init__.py create mode 100644 sql2dbx/notebooks/external_model/external_model_amazon_bedrock.py create mode 100644 sql2dbx/notebooks/external_model/external_model_azure_openai.py create mode 100644 sql2dbx/notebooks/external_model/helper_external_model.py create mode 100644 sql2dbx/notebooks/notebook_utils.py create mode 100644 sql2dbx/notebooks/pyscripts/__init__.py create mode 100644 sql2dbx/notebooks/pyscripts/batch_inference_helper.py create mode 100644 sql2dbx/notebooks/pyscripts/cell_split_helper.py create mode 100644 sql2dbx/notebooks/pyscripts/conversion_prompt_helper.py create mode 100644 sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/common_instructions/sql_to_databricks_notebook_common_python.yml create mode 100644 sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/mysql_to_databricks_notebook.yml create mode 100644 sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/netezza_to_databricks_notebook.yml create mode 100644 sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/oracle_to_databricks_notebook.yml create mode 100644 sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/postgresql_to_databricks_notebook.yml create mode 100644 sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/redshift_to_databricks_notebook.yml create mode 100644 sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/snowflake_to_databricks_notebook.yml create mode 100644 sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/teradata_to_databricks_notebook.yml create mode 100644 sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/tsql_to_databricks_notebook.yml create mode 100644 sql2dbx/notebooks/pyscripts/conversion_result_clean_helper.py create mode 100644 sql2dbx/notebooks/pyscripts/databricks_credentials.py create mode 100644 sql2dbx/notebooks/pyscripts/llm_token_count_helper.py create mode 100644 sql2dbx/notebooks/pyscripts/notebook_export_helper.py create mode 100644 sql2dbx/notebooks/pyscripts/notebook_i18n.py create mode 100644 sql2dbx/notebooks/pyscripts/spark_sql_extract_helper.py create mode 100644 sql2dbx/notebooks/pyscripts/token_utils.py create mode 100644 sql2dbx/notebooks/pyscripts/utils.py create mode 100644 sql2dbx/notebooks/requirements.txt create mode 100644 sql2dbx/resources/sql2dbx_job_main.yml create mode 100644 sql2dbx/tests/__init__.py create mode 100644 sql2dbx/tests/pyscripts/__init__.py create mode 100644 sql2dbx/tests/pyscripts/test_cell_split_helper.py create mode 100644 
sql2dbx/tests/pyscripts/test_conversion_prompt_helper.py create mode 100644 sql2dbx/tests/pyscripts/test_conversion_result_clean_helper.py create mode 100644 sql2dbx/tests/pyscripts/test_llm_token_count_helper.py create mode 100644 sql2dbx/tests/pyscripts/test_notebook_export_helper.py create mode 100644 sql2dbx/tests/pyscripts/test_notebook_i18n.py create mode 100644 sql2dbx/tests/pyscripts/test_spark_sql_extract_helper.py create mode 100644 sql2dbx/tests/pyscripts/test_token_utils.py create mode 100644 sql2dbx/tests/pyscripts/test_utils.py diff --git a/sql2dbx/.gitignore b/sql2dbx/.gitignore new file mode 100644 index 00000000..a42c648b --- /dev/null +++ b/sql2dbx/.gitignore @@ -0,0 +1,173 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Databricks +.databricks/ + +# Project specific +.clinerules/ +.vscode/ +scratch/** +!scratch/README.md +test_output/ +databricks.yml diff --git a/sql2dbx/README.md b/sql2dbx/README.md new file mode 100644 index 00000000..c0d15207 --- /dev/null +++ b/sql2dbx/README.md @@ -0,0 +1,24 @@ +--- +title: "sql2dbx" +language: python +author: "Hiroyuki Nakazato" +date: 2025-4-25 + +tags: +- sql-migration-tool +- multi-dialect-sql +- llm +- automation +--- + +# sql2dbx +**sql2dbx** is an automation tool designed to convert SQL files into Databricks notebooks. It leverages Large Language Models (LLMs) to perform the conversion based on system prompts tailored for various SQL dialects. sql2dbx consists of a series of Databricks notebooks. + +## How to Execute +1. Clone the [databrickslabs/sandbox](https://github.com/databrickslabs/sandbox) repository. +2. Import the `sql2dbx` folder into your Databricks workspace. +3. Run either notebook as your entry point: + - `notebooks/00_main` (English) + - `notebooks/00_main_ja` (Japanese) + +These notebooks contain all instructions and documentation needed to use sql2dbx. 
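If you prefer to trigger the conversion from a job task or another notebook instead of opening the entry notebook interactively, a minimal sketch is shown below. The workspace path and the absence of arguments are assumptions for illustration only; the entry notebook's own widgets and inline documentation define the parameters it actually expects.

```python
# Minimal sketch: run the sql2dbx entry notebook programmatically from Databricks.
# The workspace path below is a hypothetical placeholder; adjust it to wherever
# you imported the sql2dbx folder. Widget arguments are omitted here because the
# entry notebook documents its own parameters.
entry_notebook = "/Workspace/Users/<your-user>/sql2dbx/notebooks/00_main"  # placeholder path

result = dbutils.notebook.run(entry_notebook, 0)  # timeout_seconds=0 means no timeout
print(result)
```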
diff --git a/sql2dbx/examples/mysql/input/mysql_example1_multi_statement_transformation.sql b/sql2dbx/examples/mysql/input/mysql_example1_multi_statement_transformation.sql new file mode 100644 index 00000000..bf4ce741 --- /dev/null +++ b/sql2dbx/examples/mysql/input/mysql_example1_multi_statement_transformation.sql @@ -0,0 +1,51 @@ +-- ========================================== +-- MySQL EXAMPLE #1: Multi-Statement Data Transformation +-- ========================================== + +-- Create a table for orders +CREATE TABLE Orders ( + OrderID INT, + CustomerName VARCHAR(100), + OrderDate DATETIME DEFAULT NOW(), + OrderTotal DECIMAL(10,2) +); + +-- Insert some sample orders +INSERT INTO Orders (OrderID, CustomerName, OrderTotal) +VALUES + (101, 'Alice', 200.00), + (102, 'Bob', 350.75), + (103, 'Charlie', 99.99); + +-- Create a temporary table for order statuses +CREATE TEMPORARY TABLE TempOrderStatus ( + OrderID INT, + Status VARCHAR(50) +); + +-- Insert statuses +INSERT INTO TempOrderStatus (OrderID, Status) +VALUES + (101, 'PROCESSING'), + (102, 'SHIPPED'), + (104, 'CANCELLED'); + +-- Update orders with a discount if they appear in the temporary status table +-- Demonstrates MySQL's UPDATE with JOIN syntax +UPDATE Orders AS o +JOIN TempOrderStatus AS t ON o.OrderID = t.OrderID +SET o.OrderTotal = o.OrderTotal * 0.90 -- 10% discount +WHERE t.Status = 'SHIPPED'; + +-- Delete any order older than 90 days if not referenced in TempOrderStatus +DELETE o +FROM Orders AS o +WHERE o.OrderDate < DATE_SUB(NOW(), INTERVAL 90 DAY) + AND o.OrderID NOT IN (SELECT OrderID FROM TempOrderStatus); + +-- Final check +SELECT * FROM Orders; + +-- Clean up +-- DROP TABLE IF EXISTS TempOrderStatus; -- Temp tables are automatically dropped at the end of the session +DROP TABLE IF EXISTS Orders; diff --git a/sql2dbx/examples/mysql/input/mysql_example2_stored_procedure.sql b/sql2dbx/examples/mysql/input/mysql_example2_stored_procedure.sql new file mode 100644 index 00000000..5fecdf04 --- /dev/null +++ b/sql2dbx/examples/mysql/input/mysql_example2_stored_procedure.sql @@ -0,0 +1,54 @@ +-- ========================================== +-- MySQL EXAMPLE #2: Stored Procedure with Threshold Checking +-- ========================================== + +DELIMITER $$ + +CREATE PROCEDURE DemoThresholdCheck( + IN p_table_name VARCHAR(64), + IN p_threshold DECIMAL(10,2), + OUT p_rows_updated INT +) +BEGIN + -- Declare a handler to catch any SQL errors, then roll back + DECLARE EXIT HANDLER FOR SQLEXCEPTION + BEGIN + ROLLBACK; + SET p_rows_updated = -1; + END; + + -- Start a transaction + START TRANSACTION; + + -- 1) Create a temporary table that captures rows above the threshold + -- We'll build this query dynamically based on p_table_name + SET @sql = CONCAT( + 'CREATE TEMPORARY TABLE TempData AS ', + 'SELECT id, metric ', + 'FROM ', p_table_name, ' ', + 'WHERE metric > ', p_threshold + ); + + PREPARE stmt FROM @sql; + EXECUTE stmt; + DEALLOCATE PREPARE stmt; + + -- 2) Update the original table to cap values at the threshold + SET @sql = CONCAT( + 'UPDATE ', p_table_name, ' ', + 'SET metric = ', p_threshold, ' ', + 'WHERE metric > ', p_threshold + ); + + PREPARE stmt FROM @sql; + EXECUTE stmt; + SET p_rows_updated = ROW_COUNT(); -- track how many rows changed + DEALLOCATE PREPARE stmt; + + -- DROP TEMPORARY TABLE IF EXISTS TempData; -- Temp tables are automatically dropped at the end of the session + + -- Commit the transaction + COMMIT; +END $$ + +DELIMITER ; diff --git 
a/sql2dbx/examples/mysql/output/mysql_example1_multi_statement_transformation.py b/sql2dbx/examples/mysql/output/mysql_example1_multi_statement_transformation.py new file mode 100644 index 00000000..27efef7b --- /dev/null +++ b/sql2dbx/examples/mysql/output/mysql_example1_multi_statement_transformation.py @@ -0,0 +1,89 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # mysql_example1_multi_statement_transformation +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. +# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/mysql/input/mysql_example1_multi_statement_transformation.sql` + +# COMMAND ---------- + +# Create the Orders table +spark.sql(""" +CREATE TABLE Orders ( + OrderID INT, + CustomerName STRING, + OrderDate TIMESTAMP DEFAULT current_timestamp(), + OrderTotal DECIMAL(10,2) +) +""") + +# COMMAND ---------- + +# Insert data into Orders table +spark.sql(""" +INSERT INTO Orders (OrderID, CustomerName, OrderTotal) +VALUES + (101, 'Alice', 200.00), + (102, 'Bob', 350.75), + (103, 'Charlie', 99.99) +""") + +# COMMAND ---------- + +# Create "temporary" table as Delta table +spark.sql(""" +CREATE OR REPLACE TABLE TempOrderStatus ( + OrderID INT, + Status STRING +) +""") + +# COMMAND ---------- + +# Insert data into temporary status table +spark.sql(""" +INSERT INTO TempOrderStatus (OrderID, Status) +VALUES + (101, 'PROCESSING'), + (102, 'SHIPPED'), + (104, 'CANCELLED') +""") + +# COMMAND ---------- + +# Update Orders using MERGE pattern since Databricks doesn't support JOIN in UPDATE +spark.sql(""" +MERGE INTO Orders o +USING (SELECT * FROM TempOrderStatus WHERE Status = 'SHIPPED') t +ON o.OrderID = t.OrderID +WHEN MATCHED THEN + UPDATE SET o.OrderTotal = o.OrderTotal * 0.90 +""") + +# COMMAND ---------- + +# Delete old orders not in the status table +spark.sql(""" +DELETE FROM Orders +WHERE OrderDate < date_sub(current_timestamp(), 90) +AND OrderID NOT IN (SELECT OrderID FROM TempOrderStatus) +""") + +# COMMAND ---------- + +# Query the results +orders_df = spark.sql("SELECT * FROM Orders") +display(orders_df) + +# COMMAND ---------- + +# Clean up tables +spark.sql("DROP TABLE IF EXISTS Orders") +spark.sql("DROP TABLE IF EXISTS TempOrderStatus") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors were detected during the static check. +# MAGIC However, please review the code carefully as some issues may only be detected during runtime. \ No newline at end of file diff --git a/sql2dbx/examples/mysql/output/mysql_example2_stored_procedure.py b/sql2dbx/examples/mysql/output/mysql_example2_stored_procedure.py new file mode 100644 index 00000000..6ca376bb --- /dev/null +++ b/sql2dbx/examples/mysql/output/mysql_example2_stored_procedure.py @@ -0,0 +1,92 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # mysql_example2_stored_procedure +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. 
+# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/mysql/input/mysql_example2_stored_procedure.sql` + +# COMMAND ---------- + +# Create widgets for procedure input parameters +dbutils.widgets.text("table_name", "") +dbutils.widgets.text("threshold", "") + +# COMMAND ---------- + +# Function to perform the threshold check and update +def demo_threshold_check(): + # Get parameters from widgets + table_name = dbutils.widgets.get("table_name") + + try: + threshold = float(dbutils.widgets.get("threshold")) + except ValueError: + print("Invalid threshold value") + return -1 + + if not table_name: + print("Table name must be provided") + return -1 + + # For rollback simulation with Delta tables + try: + hist = spark.sql(f"DESCRIBE HISTORY {table_name} LIMIT 1").collect()[0] + restore_timestamp = hist["timestamp"] + except Exception as e: + print(f"Warning: Could not get history for table {table_name}: {str(e)}") + restore_timestamp = None + + rows_updated = -1 + + try: + # Create Delta table instead of temporary table + spark.sql(f""" + CREATE OR REPLACE TABLE TempData AS + SELECT id, metric + FROM {table_name} + WHERE metric > {threshold} + """) + + # Update the table with the threshold + update_result = spark.sql(f""" + UPDATE {table_name} + SET metric = {threshold} + WHERE metric > {threshold} + """) + + # Get row count from the temp table + rows_updated = spark.sql("SELECT COUNT(*) AS cnt FROM TempData").collect()[0]['cnt'] + + # Clean up + spark.sql("DROP TABLE IF EXISTS TempData") + + return rows_updated + + except Exception as e: + # Simulate ROLLBACK + if restore_timestamp: + spark.sql(f"RESTORE TABLE {table_name} TO TIMESTAMP AS OF '{restore_timestamp}'") + print(f"Error occurred: {str(e)}. Table restored to previous version.") + else: + print(f"Error occurred: {str(e)}. Could not restore table.") + + return -1 + +# COMMAND ---------- + +# Execute the function and display result +rows_updated = demo_threshold_check() +print(f"Rows updated: {rows_updated}") + +# COMMAND ---------- + +# You can create an output DataFrame if you prefer to display the result that way +result_df = spark.createDataFrame([(rows_updated,)], ["rows_updated"]) +display(result_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors were detected during the static check. +# MAGIC However, please review the code carefully as some issues may only be detected during runtime. 
\ No newline at end of file diff --git a/sql2dbx/examples/netezza/input/netezza_example1_multi_statement_transformation.sql b/sql2dbx/examples/netezza/input/netezza_example1_multi_statement_transformation.sql new file mode 100644 index 00000000..13401c44 --- /dev/null +++ b/sql2dbx/examples/netezza/input/netezza_example1_multi_statement_transformation.sql @@ -0,0 +1,53 @@ +-- ========================================== +-- Netezza EXAMPLE #1: Multi-Statement Data Transformation +-- ========================================== + +-- Create a table for product data, with a distribution key on ProductID +CREATE TABLE PRODUCTS ( + PRODUCTID INT, + PRODUCTNAME VARCHAR(100), + PRICE FLOAT, + CREATEDAT TIMESTAMP DEFAULT CURRENT_TIMESTAMP +) +DISTRIBUTE ON (PRODUCTID); + +-- Insert some sample products +INSERT INTO PRODUCTS (PRODUCTID, PRODUCTNAME, PRICE) +VALUES + (1001, 'NZA Widget', 15.75), + (1002, 'NZA Gadget', 21.50), + (1003, 'NZA Gizmo', 42.99); + +-- Create a table for discount information +CREATE TABLE DISCOUNTS ( + PRODUCTID INT, + DISCOUNTRATE FLOAT +) +DISTRIBUTE ON (PRODUCTID); + +-- Insert discount rates +INSERT INTO DISCOUNTS (PRODUCTID, DISCOUNTRATE) +VALUES + (1001, 0.20), + (1003, 0.15); + +-- Update product prices using a join +-- Note: Netezza allows an UPDATE with a FROM clause +UPDATE PRODUCTS AS p +SET p.PRICE = p.PRICE * (1 - d.DISCOUNTRATE) +FROM DISCOUNTS AS d +WHERE p.PRODUCTID = d.PRODUCTID; + +-- Demonstrate a conditional DELETE: +-- Remove products older than 7 days that do not appear in DISCOUNTS +-- (Assume for example that items beyond 7 days with no discount are deprecated) +DELETE FROM PRODUCTS p +WHERE p.CREATEDAT < (CURRENT_TIMESTAMP - INTERVAL '7' DAY) + AND p.PRODUCTID NOT IN (SELECT PRODUCTID FROM DISCOUNTS); + +-- Final SELECT to confirm changes +SELECT * FROM PRODUCTS; + +-- Clean up +DROP TABLE IF EXISTS DISCOUNTS; +DROP TABLE IF EXISTS PRODUCTS; diff --git a/sql2dbx/examples/netezza/input/netezza_example2_stored_procedure.sql b/sql2dbx/examples/netezza/input/netezza_example2_stored_procedure.sql new file mode 100644 index 00000000..56d5075f --- /dev/null +++ b/sql2dbx/examples/netezza/input/netezza_example2_stored_procedure.sql @@ -0,0 +1,60 @@ +-- ========================================== +-- Netezza EXAMPLE #2: Stored Procedure with Transaction and Dynamic SQL +-- ========================================== + +CREATE PROCEDURE DEMO_ADJUST_THRESHOLD( + IN P_TABLENAME VARCHAR(128), + IN P_THRESHOLD FLOAT DEFAULT 100.0 +) +RETURNS INT +LANGUAGE NZPLSQL +AS +BEGIN_PROC + DECLARE V_COUNT INT := 0; -- Will store number of updated rows + DECLARE V_ERRORMSG VARCHAR(500); -- For capturing error details + DECLARE V_SQL VARCHAR(1000); -- For building dynamic SQL + +BEGIN + -- Begin a transaction for our updates + BEGIN; + + RAISE NOTICE 'Starting threshold adjustment in table: % with threshold: %', P_TABLENAME, P_THRESHOLD; + + -- 1) Build dynamic SQL to cap values at a certain threshold + -- Suppose the table has a column "METRICVALUE" that needs capping + V_SQL := 'UPDATE ' || P_TABLENAME || + ' SET METRICVALUE = ' || P_THRESHOLD || + ' WHERE METRICVALUE > ' || P_THRESHOLD || ';'; + + -- 2) Execute dynamic SQL + EXECUTE IMMEDIATE V_SQL; + + -- 3) Check how many rows were updated + V_SQL := 'SELECT COUNT(*) FROM ' || P_TABLENAME || + ' WHERE METRICVALUE = ' || P_THRESHOLD || ';'; + EXECUTE IMMEDIATE V_SQL INTO V_COUNT; + + RAISE NOTICE 'Number of rows updated to threshold: %', V_COUNT; + + -- Commit the transaction if all is well + COMMIT; + + RETURN V_COUNT; + 
+EXCEPTION + WHEN OTHERS THEN + -- Capture error message + GET STACKED DIAGNOSTICS V_ERRORMSG = MESSAGE_TEXT; + + RAISE NOTICE 'Error in procedure DEMO_ADJUST_THRESHOLD: %', V_ERRORMSG; + + -- Roll back any partial changes + IF GET_TRAN_COUNT() > 0 THEN + ROLLBACK; + END IF; + + -- Re-raise the error so that the caller is aware + RAISE; +END; + +END_PROC; diff --git a/sql2dbx/examples/netezza/output/netezza_example1_multi_statement_transformation.py b/sql2dbx/examples/netezza/output/netezza_example1_multi_statement_transformation.py new file mode 100644 index 00000000..14fc53f3 --- /dev/null +++ b/sql2dbx/examples/netezza/output/netezza_example1_multi_statement_transformation.py @@ -0,0 +1,87 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # netezza_example1_multi_statement_transformation +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. +# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/netezza/input/netezza_example1_multi_statement_transformation.sql` + +# COMMAND ---------- + +# Create PRODUCTS table (omitting DISTRIBUTE ON clause as it's not supported in Spark) +spark.sql(""" +CREATE TABLE IF NOT EXISTS PRODUCTS ( + PRODUCTID INT, + PRODUCTNAME STRING, + PRICE FLOAT, + CREATEDAT TIMESTAMP DEFAULT CURRENT_TIMESTAMP() +) +""") + +# COMMAND ---------- + +# Insert sample products +spark.sql(""" +INSERT INTO PRODUCTS (PRODUCTID, PRODUCTNAME, PRICE) VALUES + (1001, 'NZA Widget', 15.75), + (1002, 'NZA Gadget', 21.50), + (1003, 'NZA Gizmo', 42.99) +""") + +# COMMAND ---------- + +# Create DISCOUNTS table (omitting DISTRIBUTE ON clause) +spark.sql(""" +CREATE TABLE IF NOT EXISTS DISCOUNTS ( + PRODUCTID INT, + DISCOUNTRATE FLOAT +) +""") + +# COMMAND ---------- + +# Insert discount data +spark.sql(""" +INSERT INTO DISCOUNTS (PRODUCTID, DISCOUNTRATE) VALUES + (1001, 0.20), + (1003, 0.15) +""") + +# COMMAND ---------- + +# Update product prices with discounts +# Databricks doesn't support UPDATE with FROM clause, so we use MERGE instead +spark.sql(""" +MERGE INTO PRODUCTS t +USING DISCOUNTS s +ON t.PRODUCTID = s.PRODUCTID +WHEN MATCHED THEN + UPDATE SET t.PRICE = t.PRICE * (1 - s.DISCOUNTRATE) +""") + +# COMMAND ---------- + +# Delete old products not in discounts table +spark.sql(""" +DELETE FROM PRODUCTS +WHERE CREATEDAT < date_sub(CURRENT_TIMESTAMP(), 7) +AND PRODUCTID NOT IN (SELECT PRODUCTID FROM DISCOUNTS) +""") + +# COMMAND ---------- + +# Display current products +products_df = spark.sql("SELECT * FROM PRODUCTS") +display(products_df) + +# COMMAND ---------- + +# Clean up tables +spark.sql("DROP TABLE IF EXISTS DISCOUNTS") +spark.sql("DROP TABLE IF EXISTS PRODUCTS") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors were detected during the static check. +# MAGIC However, please review the code carefully as some issues may only be detected during runtime. \ No newline at end of file diff --git a/sql2dbx/examples/netezza/output/netezza_example2_stored_procedure.py b/sql2dbx/examples/netezza/output/netezza_example2_stored_procedure.py new file mode 100644 index 00000000..31ae8fb1 --- /dev/null +++ b/sql2dbx/examples/netezza/output/netezza_example2_stored_procedure.py @@ -0,0 +1,79 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # netezza_example2_stored_procedure +# MAGIC This notebook was automatically converted from the script below. 
It may contain errors, so use it as a starting point and make necessary corrections. +# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/netezza/input/netezza_example2_stored_procedure.sql` + +# COMMAND ---------- + +# Create input widgets for parameters +dbutils.widgets.text("TableName", "") +dbutils.widgets.text("Threshold", "100.0") + +# COMMAND ---------- + +# Function to implement the procedure +def demo_adjust_threshold(): + # Get parameter values + table_name = dbutils.widgets.get("TableName") + threshold = float(dbutils.widgets.get("Threshold")) + + # Validate table name + if not table_name: + print("ERROR: Table name cannot be empty") + return -1 + + # Get current version info for rollback if needed + try: + hist = spark.sql(f"DESCRIBE HISTORY {table_name} LIMIT 1").collect()[0] + restore_ts = hist["timestamp"] + except Exception as e: + print(f"Warning: Could not get history for {table_name}. Rollback may not be possible.") + restore_ts = None + + # Execute the update with transaction simulation + try: + print(f"Starting threshold adjustment in table: {table_name} with threshold: {threshold}") + + # Update query + spark.sql(f""" + UPDATE {table_name} + SET METRICVALUE = {threshold} + WHERE METRICVALUE > {threshold} + """) + + # Count affected rows + result_df = spark.sql(f""" + SELECT COUNT(*) AS count + FROM {table_name} + WHERE METRICVALUE = {threshold} + """) + + count = result_df.collect()[0]["count"] + print(f"Number of rows updated to threshold: {count}") + + return count + + except Exception as e: + print(f"Error in procedure DEMO_ADJUST_THRESHOLD: {str(e)}") + + # Rollback if possible + if restore_ts: + print(f"Attempting to restore {table_name} to previous version") + spark.sql(f"RESTORE TABLE {table_name} TO TIMESTAMP AS OF '{restore_ts}'") + + raise e + +# COMMAND ---------- + +# Execute procedure +result = demo_adjust_threshold() +print(f"Result: {result}") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors were detected during the static check. +# MAGIC However, please review the code carefully as some issues may only be detected during runtime. 
\ No newline at end of file diff --git a/sql2dbx/examples/oracle/input/oracle_example1_multi_statement_transformation.sql b/sql2dbx/examples/oracle/input/oracle_example1_multi_statement_transformation.sql new file mode 100644 index 00000000..0b5fcfc7 --- /dev/null +++ b/sql2dbx/examples/oracle/input/oracle_example1_multi_statement_transformation.sql @@ -0,0 +1,59 @@ +-- ========================================== +-- ORACLE EXAMPLE #1: Multi-Statement Data Transformation +-- ========================================== + +-- Create a table to hold product data +CREATE TABLE Products ( + ProductID NUMBER(10, 0), + ProductName VARCHAR2(100), + Price NUMBER(10, 2), + CreatedAt DATE DEFAULT SYSDATE +); + +-- Insert sample products +INSERT INTO Products (ProductID, ProductName, Price) +VALUES + (201, 'Gizmo X', 14.50), + (202, 'Gizmo Y', 21.99), + (203, 'Gizmo Z', 38.25); + +-- Create a Global Temporary Table (or a regular table if preferred) +-- to store discount information +CREATE GLOBAL TEMPORARY TABLE DiscountInfo ( + ProductID NUMBER(10, 0), + DiscountRate NUMBER(5, 2) +) ON COMMIT PRESERVE ROWS; + +-- Insert discount rates +INSERT INTO DiscountInfo (ProductID, DiscountRate) +VALUES + (201, 0.15), + (203, 0.30); + +-- Update product prices to apply discounts (use a subquery or MERGE in Oracle) +UPDATE Products p +SET p.Price = p.Price * (1 - ( + SELECT d.DiscountRate + FROM DiscountInfo d + WHERE d.ProductID = p.ProductID +)) +WHERE p.ProductID IN ( + SELECT d.ProductID + FROM DiscountInfo d +); + +-- Conditional DELETE: +-- Delete products older than 7 days, if they have no discount entry +DELETE FROM Products p +WHERE p.CreatedAt < (SYSDATE - 7) + AND p.ProductID NOT IN ( + SELECT d.ProductID + FROM DiscountInfo d + ); + +-- Final SELECT to confirm changes +SELECT * FROM Products; + +-- Clean up +-- DROP TABLE DiscountInfo; -- Global Temporary Tables are automatically dropped at the end of the session +DROP TABLE Products; diff --git a/sql2dbx/examples/oracle/input/oracle_example2_stored_procedure.sql b/sql2dbx/examples/oracle/input/oracle_example2_stored_procedure.sql new file mode 100644 index 00000000..21a00ab3 --- /dev/null +++ b/sql2dbx/examples/oracle/input/oracle_example2_stored_procedure.sql @@ -0,0 +1,76 @@ +-- ========================================== +-- ORACLE EXAMPLE #2: Stored Procedure with Outlier Capping +-- ========================================== + +CREATE OR REPLACE PROCEDURE AdjustSalesForecast( + p_multiplier IN NUMBER DEFAULT 1.20 +) AS + -- Variables + v_today DATE; + v_error_message VARCHAR2(4000); + v_proc_name VARCHAR2(128) := 'AdjustSalesForecast'; + + -- Temporary table to hold threshold data + -- (In Oracle, you can also use Global Temporary Tables if desired) +BEGIN + -- Example: create a temporary structure + EXECUTE IMMEDIATE ' + CREATE TABLE TempOutlierThreshold ( + LocationId VARCHAR2(10), + ThresholdValue NUMBER(10, 2) + ) + '; + + BEGIN + -- 1) Get the current date + SELECT SYSDATE INTO v_today FROM DUAL; + + -- 2) Insert threshold data (e.g., 99th percentile) for each location + -- We'll simulate it here with a fixed approach for brevity + INSERT INTO TempOutlierThreshold (LocationId, ThresholdValue) + SELECT + LocationId, + 1000 * p_multiplier -- pretend we computed a high threshold + FROM SalesHistory + WHERE SalesDate > ADD_MONTHS(v_today, -12); + + -- 3) Update ForecastTable to store original forecast if it exceeds threshold + UPDATE ForecastTable f + SET f.OriginalForecast = f.ForecastValue + WHERE f.ForecastValue > + (SELECT t.ThresholdValue 
FROM TempOutlierThreshold t + WHERE t.LocationId = f.LocationId) + * p_multiplier + AND TRUNC(f.ForecastDate) = TRUNC(v_today); + + -- 4) Cap those outliers at threshold * multiplier + UPDATE ForecastTable f + SET f.ForecastValue = + (SELECT t.ThresholdValue FROM TempOutlierThreshold t + WHERE t.LocationId = f.LocationId) + * p_multiplier + WHERE f.ForecastValue > + (SELECT t.ThresholdValue FROM TempOutlierThreshold t + WHERE t.LocationId = f.LocationId) + * p_multiplier + AND TRUNC(f.ForecastDate) = TRUNC(v_today); + + -- Commit transaction + COMMIT; + EXCEPTION + WHEN OTHERS THEN + -- Capture error info + v_error_message := SQLERRM; + DBMS_OUTPUT.put_line('Error in procedure ' || v_proc_name || ': ' || v_error_message); + + -- Roll back any changes + ROLLBACK; + + -- Optionally re-raise the exception + RAISE; + END; + + -- Cleanup + -- EXECUTE IMMEDIATE 'DROP TABLE TempOutlierThreshold'; -- Temp tables are automatically dropped at the end of the session +END; +/ diff --git a/sql2dbx/examples/oracle/output/oracle_example1_multi_statement_transformation.py b/sql2dbx/examples/oracle/output/oracle_example1_multi_statement_transformation.py new file mode 100644 index 00000000..6f785770 --- /dev/null +++ b/sql2dbx/examples/oracle/output/oracle_example1_multi_statement_transformation.py @@ -0,0 +1,92 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # oracle_example1_multi_statement_transformation +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. +# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/oracle/input/oracle_example1_multi_statement_transformation.sql` + +# COMMAND ---------- + +# Create Products table +spark.sql(""" +CREATE TABLE IF NOT EXISTS Products ( + ProductID INT, + ProductName STRING, + Price DECIMAL(10, 2), + CreatedAt DATE DEFAULT current_date() +) +""") + +# COMMAND ---------- + +# Insert product data +spark.sql(""" +INSERT INTO Products (ProductID, ProductName, Price) +VALUES + (201, 'Gizmo X', 14.50), + (202, 'Gizmo Y', 21.99), + (203, 'Gizmo Z', 38.25) +""") + +# COMMAND ---------- + +# Create a table to replace the global temporary table +spark.sql(""" +CREATE OR REPLACE TABLE DiscountInfo ( + ProductID INT, + DiscountRate DECIMAL(5, 2) +) +""") + +# COMMAND ---------- + +# Insert discount data +spark.sql(""" +INSERT INTO DiscountInfo (ProductID, DiscountRate) +VALUES + (201, 0.15), + (203, 0.30) +""") + +# COMMAND ---------- + +# Update product prices based on discount rates +# Using MERGE since Databricks may not support subquery in SET clause +spark.sql(""" +MERGE INTO Products p +USING ( + SELECT p.ProductID, p.Price * (1 - d.DiscountRate) AS NewPrice + FROM Products p + JOIN DiscountInfo d ON p.ProductID = d.ProductID +) src +ON p.ProductID = src.ProductID +WHEN MATCHED THEN UPDATE SET p.Price = src.NewPrice +""") + +# COMMAND ---------- + +# Delete old products that don't have discounts +spark.sql(""" +DELETE FROM Products +WHERE CreatedAt < date_sub(current_date(), 7) +AND ProductID NOT IN (SELECT ProductID FROM DiscountInfo) +""") + +# COMMAND ---------- + +# Select all products +result = spark.sql("SELECT * FROM Products") +display(result) + +# COMMAND ---------- + +# Clean up (drop tables) +spark.sql("DROP TABLE IF EXISTS Products") +spark.sql("DROP TABLE IF EXISTS DiscountInfo") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors 
were detected during the static check. +# MAGIC However, please review the code carefully as some issues may only be detected during runtime. \ No newline at end of file diff --git a/sql2dbx/examples/oracle/output/oracle_example2_stored_procedure.py b/sql2dbx/examples/oracle/output/oracle_example2_stored_procedure.py new file mode 100644 index 00000000..36cbbf9e --- /dev/null +++ b/sql2dbx/examples/oracle/output/oracle_example2_stored_procedure.py @@ -0,0 +1,91 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # oracle_example2_stored_procedure +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. +# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/oracle/input/oracle_example2_stored_procedure.sql` + +# COMMAND ---------- + +# Create a widget for the parameter with default value +dbutils.widgets.text("multiplier", "1.20", "Forecast Adjustment Multiplier") + +# COMMAND ---------- + +# Get the parameter value +p_multiplier = float(dbutils.widgets.get("multiplier")) + +# COMMAND ---------- + +# Create a table to hold threshold values (instead of a temp table) +spark.sql(""" +CREATE OR REPLACE TABLE TempOutlierThreshold ( + LocationId STRING, + ThresholdValue DECIMAL(10, 2) +) +""") + +# COMMAND ---------- + +try: + # Populate the threshold table + spark.sql(f""" + INSERT INTO TempOutlierThreshold (LocationId, ThresholdValue) + SELECT LocationId, 1000 * {p_multiplier} + FROM SalesHistory + WHERE SalesDate > add_months(current_date(), -12) + """) + + # Create a view for forecasts that need updating + spark.sql(f""" + CREATE OR REPLACE TEMPORARY VIEW ForecastsToUpdate AS + SELECT f.* + FROM ForecastTable f + JOIN TempOutlierThreshold t ON t.LocationId = f.LocationId + WHERE f.ForecastValue > t.ThresholdValue * {p_multiplier} + AND date_trunc('day', f.ForecastDate) = date_trunc('day', current_date()) + """) + + # First update - Save original forecast values + spark.sql(""" + MERGE INTO ForecastTable target + USING ForecastsToUpdate source + ON target.LocationId = source.LocationId AND target.ForecastDate = source.ForecastDate + WHEN MATCHED THEN UPDATE SET target.OriginalForecast = source.ForecastValue + """) + + # Create a view with calculated thresholds + spark.sql(f""" + CREATE OR REPLACE TEMPORARY VIEW ForecastWithThresholds AS + SELECT + f.*, + t.ThresholdValue * {p_multiplier} as CalculatedThreshold + FROM ForecastTable f + JOIN TempOutlierThreshold t ON t.LocationId = f.LocationId + WHERE f.ForecastValue > t.ThresholdValue * {p_multiplier} + AND date_trunc('day', f.ForecastDate) = date_trunc('day', current_date()) + """) + + # Second update - Cap forecast values to threshold + spark.sql(""" + MERGE INTO ForecastTable target + USING ForecastWithThresholds source + ON target.LocationId = source.LocationId AND target.ForecastDate = source.ForecastDate + WHEN MATCHED THEN UPDATE SET target.ForecastValue = source.CalculatedThreshold + """) + +except Exception as e: + print(f"Error in procedure AdjustSalesForecast: {str(e)}") + raise e +finally: + # Clean up temporary objects + spark.sql("DROP TABLE IF EXISTS TempOutlierThreshold") + spark.sql("DROP VIEW IF EXISTS ForecastsToUpdate") + spark.sql("DROP VIEW IF EXISTS ForecastWithThresholds") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors were detected during the static check. 
+# MAGIC However, please review the code carefully as some issues may only be detected during runtime. \ No newline at end of file diff --git a/sql2dbx/examples/postgresql/input/postgresql_example1_multi_statement_transformation.sql b/sql2dbx/examples/postgresql/input/postgresql_example1_multi_statement_transformation.sql new file mode 100644 index 00000000..1a23424e --- /dev/null +++ b/sql2dbx/examples/postgresql/input/postgresql_example1_multi_statement_transformation.sql @@ -0,0 +1,53 @@ +-- ========================================== +-- PostgreSQL EXAMPLE #1: Multi-Statement Data Transformation +-- ========================================== + +-- Create a table for product data +CREATE TABLE products ( + product_id INT, + product_name VARCHAR(100), + price NUMERIC(8,2), + created_at TIMESTAMP DEFAULT NOW() +); + +-- Insert some sample products +INSERT INTO products (product_id, product_name, price) +VALUES + (201, 'Gadget Alpha', 14.99), + (202, 'Gadget Beta', 25.50), + (203, 'Gadget Gamma', 32.00); + +-- Create a temporary table to capture discount information +CREATE TEMP TABLE temp_discounts ( + product_id INT, + discount_rate FLOAT +); + +-- Insert discount rates +INSERT INTO temp_discounts (product_id, discount_rate) +VALUES + (201, 0.15), + (203, 0.30); + +-- Update product prices where a discount applies +-- PostgreSQL supports an UPDATE with FROM syntax +UPDATE products p +SET price = price * (1 - d.discount_rate) +FROM temp_discounts d +WHERE p.product_id = d.product_id; + +-- Demonstrate a conditional DELETE +-- For example, remove products older than 7 days with no discount +DELETE FROM products p +USING temp_discounts d +WHERE p.created_at < (NOW() - INTERVAL '7 days') + AND p.product_id NOT IN ( + SELECT product_id FROM temp_discounts + ); + +-- Final SELECT to confirm changes +SELECT * FROM products; + +-- Clean up +-- DROP TABLE IF EXISTS temp_discounts; -- Temp tables are automatically dropped at the end of the session +DROP TABLE IF EXISTS products; diff --git a/sql2dbx/examples/postgresql/input/postgresql_example2_stored_procedure.sql b/sql2dbx/examples/postgresql/input/postgresql_example2_stored_procedure.sql new file mode 100644 index 00000000..2bbf6d4a --- /dev/null +++ b/sql2dbx/examples/postgresql/input/postgresql_example2_stored_procedure.sql @@ -0,0 +1,89 @@ +-- ========================================== +-- PostgreSQL EXAMPLE #2: Stored Procedure with Outlier Checking +-- ========================================== + +CREATE OR REPLACE PROCEDURE demo_forecast_outlier_check_update( + _schema_name TEXT, + _outlier_multiplier NUMERIC(5,2) DEFAULT 1.30 +) +LANGUAGE plpgsql +AS $$ +DECLARE + _errmsg TEXT; + _current_date DATE; +BEGIN + -- Create a temporary table to store outlier thresholds + CREATE TEMP TABLE temp_outlier_info ( + location_id VARCHAR(10), + outlier_threshold NUMERIC(8,2) + ); + + -- Begin a transaction + BEGIN + -- 1) Retrieve current date from a table in the given schema + EXECUTE format( + 'SELECT system_date FROM %I.system_date_table LIMIT 1', + _schema_name + ) + INTO _current_date; + + IF _current_date IS NULL THEN + RAISE NOTICE 'No valid date found in system_date_table; exiting.'; + RETURN; + END IF; + + -- 2) Insert outlier thresholds based on 99th percentile from "historical_data_table" + EXECUTE format( + 'INSERT INTO temp_outlier_info (location_id, outlier_threshold) + SELECT d.location_id, + PERCENTILE_CONT(0.99) WITHIN GROUP (ORDER BY d.metric_value) + FROM %I.historical_data_table d + WHERE CAST(d.target_date AS DATE) >= %L::DATE - 
INTERVAL ''1 YEAR'' + GROUP BY d.location_id', + _schema_name, _current_date + ); + + -- 3) Update "forecast_table" to store original forecast values that exceed outlier_threshold * multiplier + EXECUTE format( + 'UPDATE %I.forecast_table f + SET original_forecast_value = forecast_value + FROM temp_outlier_info t + WHERE f.location_id = t.location_id + AND CAST(f.forecast_date AS DATE) = %L::DATE + AND f.forecast_value > t.outlier_threshold * %s', + _schema_name, + _current_date, + _outlier_multiplier + ); + + -- 4) Cap outlier forecast values at threshold * multiplier + EXECUTE format( + 'UPDATE %I.forecast_table f + SET forecast_value = t.outlier_threshold * %s + FROM temp_outlier_info t + WHERE f.location_id = t.location_id + AND CAST(f.forecast_date AS DATE) = %L::DATE + AND f.forecast_value > t.outlier_threshold * %s', + _schema_name, + _outlier_multiplier, + _current_date, + _outlier_multiplier + ); + + -- Commit the transaction + COMMIT; + EXCEPTION + WHEN OTHERS THEN + _errmsg := SQLERRM; + RAISE NOTICE 'Error occurred: %', _errmsg; + + -- Always rollback on error + ROLLBACK; + + -- Re-raise the error + RAISE; + END; + + RETURN; +END; +$$; \ No newline at end of file diff --git a/sql2dbx/examples/postgresql/output/postgresql_example1_multi_statement_transformation.py b/sql2dbx/examples/postgresql/output/postgresql_example1_multi_statement_transformation.py new file mode 100644 index 00000000..5de299c1 --- /dev/null +++ b/sql2dbx/examples/postgresql/output/postgresql_example1_multi_statement_transformation.py @@ -0,0 +1,88 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # postgresql_example1_multi_statement_transformation +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. 
+# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/postgresql/input/postgresql_example1_multi_statement_transformation.sql` + +# COMMAND ---------- + +# Create products table +spark.sql(""" +CREATE TABLE IF NOT EXISTS products ( + product_id INT, + product_name STRING, + price DECIMAL(8,2), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP() +) +""") + +# COMMAND ---------- + +# Insert initial product data +spark.sql(""" +INSERT INTO products (product_id, product_name, price) +VALUES + (201, 'Gadget Alpha', 14.99), + (202, 'Gadget Beta', 25.50), + (203, 'Gadget Gamma', 32.00) +""") + +# COMMAND ---------- + +# Create a Delta table as a temporary substitute +spark.sql(""" +CREATE OR REPLACE TABLE temp_discounts ( + product_id INT, + discount_rate FLOAT +) +""") + +# COMMAND ---------- + +# Insert discount data +spark.sql(""" +INSERT INTO temp_discounts (product_id, discount_rate) +VALUES + (201, 0.15), + (203, 0.30) +""") + +# COMMAND ---------- + +# Update prices based on discounts (using MERGE instead of UPDATE FROM) +spark.sql(""" +MERGE INTO products p +USING temp_discounts d +ON p.product_id = d.product_id +WHEN MATCHED THEN + UPDATE SET price = price * (1 - d.discount_rate) +""") + +# COMMAND ---------- + +# DELETE with complex condition needs to be rewritten +spark.sql(""" +DELETE FROM products +WHERE created_at < date_sub(current_timestamp(), 7) +AND product_id NOT IN (SELECT product_id FROM temp_discounts) +""") + +# COMMAND ---------- + +# Display final products +products_df = spark.sql("SELECT * FROM products") +display(products_df) + +# COMMAND ---------- + +# Cleanup +spark.sql("DROP TABLE IF EXISTS products") +spark.sql("DROP TABLE IF EXISTS temp_discounts") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors were detected during the static check. +# MAGIC However, please review the code carefully as some issues may only be detected during runtime. \ No newline at end of file diff --git a/sql2dbx/examples/postgresql/output/postgresql_example2_stored_procedure.py b/sql2dbx/examples/postgresql/output/postgresql_example2_stored_procedure.py new file mode 100644 index 00000000..3092fff7 --- /dev/null +++ b/sql2dbx/examples/postgresql/output/postgresql_example2_stored_procedure.py @@ -0,0 +1,92 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # postgresql_example2_stored_procedure +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. 
+# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/postgresql/input/postgresql_example2_stored_procedure.sql` + +# COMMAND ---------- + +# Create widgets for the procedure parameters +dbutils.widgets.text("_schema_name", "") +dbutils.widgets.text("_outlier_multiplier", "1.30") + +# COMMAND ---------- + +# Get parameter values +schema_name = dbutils.widgets.get("_schema_name") +outlier_multiplier = float(dbutils.widgets.get("_outlier_multiplier")) + +# COMMAND ---------- + +# Create temporary Delta table for outlier info +spark.sql(""" +CREATE OR REPLACE TABLE temp_outlier_info ( + location_id STRING, + outlier_threshold DECIMAL(8,2) +) +""") + +# COMMAND ---------- + +# Try-except block for transaction-like behavior +try: + # Get current date from system date table + current_date_df = spark.sql(f"SELECT system_date FROM {schema_name}.system_date_table LIMIT 1") + + if current_date_df.count() == 0: + print("No valid date found in system_date_table; exiting.") + dbutils.notebook.exit("No valid date found") + + current_date = current_date_df.collect()[0][0] + + # Calculate outlier thresholds and insert into temp table + spark.sql(f""" + INSERT INTO temp_outlier_info (location_id, outlier_threshold) + SELECT + d.location_id, + percentile(d.metric_value, 0.99) as outlier_threshold + FROM {schema_name}.historical_data_table d + WHERE CAST(d.target_date AS DATE) >= date_sub('{current_date}', 365) + GROUP BY d.location_id + """) + + # Update forecast table: first backup original values + # Using MERGE instead of UPDATE FROM since Databricks doesn't support UPDATE FROM + spark.sql(f""" + MERGE INTO {schema_name}.forecast_table f + USING temp_outlier_info t + ON f.location_id = t.location_id + AND CAST(f.forecast_date AS DATE) = '{current_date}' + AND f.forecast_value > t.outlier_threshold * {outlier_multiplier} + WHEN MATCHED THEN + UPDATE SET original_forecast_value = f.forecast_value + """) + + # Update forecast table: cap values at threshold + spark.sql(f""" + MERGE INTO {schema_name}.forecast_table f + USING temp_outlier_info t + ON f.location_id = t.location_id + AND CAST(f.forecast_date AS DATE) = '{current_date}' + AND f.forecast_value > t.outlier_threshold * {outlier_multiplier} + WHEN MATCHED THEN + UPDATE SET forecast_value = t.outlier_threshold * {outlier_multiplier} + """) + +except Exception as e: + print(f"Error occurred: {str(e)}") + # In a real Delta Lake scenario, you might want to restore the table to a previous version + # spark.sql(f"RESTORE TABLE {schema_name}.forecast_table TO VERSION AS OF ") + raise e + +finally: + # Drop the temporary table + spark.sql("DROP TABLE IF EXISTS temp_outlier_info") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors were detected during the static check. +# MAGIC However, please review the code carefully as some issues may only be detected during runtime. 
\ No newline at end of file diff --git a/sql2dbx/examples/redshift/input/redshift_example1_multi_statement_transformation.sql b/sql2dbx/examples/redshift/input/redshift_example1_multi_statement_transformation.sql new file mode 100644 index 00000000..08d170c1 --- /dev/null +++ b/sql2dbx/examples/redshift/input/redshift_example1_multi_statement_transformation.sql @@ -0,0 +1,55 @@ +-------------------------------------------------------------------------------- +-- Redshift EXAMPLE #1: Multi-Statement Data Transformation in Amazon Redshift +-------------------------------------------------------------------------------- + +-- Create a table for product data with Redshift-specific attributes +CREATE TABLE products +( + product_id INT, + product_name VARCHAR(100), + price DECIMAL(10,2), + created_at TIMESTAMP DEFAULT SYSDATE +) +DISTKEY(product_id) +SORTKEY(product_id) +ENCODE ZSTD; + +-- Insert some sample products +INSERT INTO products (product_id, product_name, price) +VALUES + (201, 'Red Gadget', 15.50), + (202, 'Blue Widget', 24.99), + (203, 'Green Gizmo', 8.75); + +-- Create a temporary table for discount rates +CREATE TEMP TABLE temp_discounts +( + product_id INT, + discount_pct FLOAT +); + +-- Insert discount rates +INSERT INTO temp_discounts (product_id, discount_pct) +VALUES + (201, 0.05), + (203, 0.20); + +-- Update product prices to reflect applicable discounts +-- Redshift supports UPDATE with a FROM clause (similar to T-SQL) +UPDATE products AS p +SET price = price * (1 - d.discount_pct) +FROM temp_discounts d +WHERE p.product_id = d.product_id; + +-- Conditionally delete products older than 7 days with no discounts +DELETE FROM products p +USING temp_discounts d +WHERE p.created_at < (SYSDATE - INTERVAL '7' DAY) + AND p.product_id NOT IN (SELECT product_id FROM temp_discounts); + +-- Final SELECT to verify changes +SELECT * FROM products; + +-- Clean up +-- DROP TABLE IF EXISTS temp_discounts; -- Temp tables are automatically dropped at the end of the session +DROP TABLE IF EXISTS products; diff --git a/sql2dbx/examples/redshift/input/redshift_example2_stored_procedure.sql b/sql2dbx/examples/redshift/input/redshift_example2_stored_procedure.sql new file mode 100644 index 00000000..15a5ab79 --- /dev/null +++ b/sql2dbx/examples/redshift/input/redshift_example2_stored_procedure.sql @@ -0,0 +1,57 @@ +-------------------------------------------------------------------------------- +-- Redshift EXAMPLE #2: Stored Procedure with Basic Transaction Handling in Redshift +-------------------------------------------------------------------------------- + +-- Create a stored procedure to update outlier prices in a "sales_forecast" table +-- based on a multiplier threshold +CREATE OR REPLACE PROCEDURE public.update_outlier_prices( + threshold_multiplier DECIMAL(5,2) DEFAULT 1.25 +) +LANGUAGE plpgsql +AS $$ +DECLARE + row_count INT := 0; + temp_table_name TEXT := 'temp_outlier_thresholds'; +BEGIN + -- Create a temporary table to store threshold data + EXECUTE 'CREATE TEMP TABLE ' || temp_table_name || ' ( + product_id INT, + outlier_threshold DECIMAL(10,2) + )'; + + -- Begin transaction + BEGIN + -- Insert threshold data (example: fixed or from another table) + EXECUTE 'INSERT INTO ' || temp_table_name || ' + SELECT product_id, + PERCENTILE_CONT(0.99) + WITHIN GROUP (ORDER BY forecast_price) + OVER (PARTITION BY product_id) + FROM sales_forecast + WHERE forecast_date >= current_date - 365'; + + -- Update outlier forecast prices (cap them at threshold * multiplier) + EXECUTE 'UPDATE 
sales_forecast f + SET forecast_price = t.outlier_threshold * ' || threshold_multiplier || ' + FROM ' || temp_table_name || ' t + WHERE f.product_id = t.product_id + AND f.forecast_price > t.outlier_threshold * ' || threshold_multiplier; + + GET DIAGNOSTICS row_count = ROW_COUNT; -- Example of checking row count + RAISE NOTICE 'Rows updated: %', row_count; + + COMMIT; -- Commit if all goes well + EXCEPTION + WHEN OTHERS THEN + RAISE NOTICE 'Error occurred: %', SQLERRM; + ROLLBACK; + RETURN; + END; + + -- Clean up the temporary table + -- EXECUTE 'DROP TABLE IF EXISTS ' || temp_table_name || ';'; -- Temp tables are automatically dropped at the end of the session + + -- Optional return or message + RAISE NOTICE 'Procedure completed with multiplier %', threshold_multiplier; +END; +$$; diff --git a/sql2dbx/examples/redshift/output/redshift_example1_multi_statement_transformation.py b/sql2dbx/examples/redshift/output/redshift_example1_multi_statement_transformation.py new file mode 100644 index 00000000..8a85c06c --- /dev/null +++ b/sql2dbx/examples/redshift/output/redshift_example1_multi_statement_transformation.py @@ -0,0 +1,89 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # redshift_example1_multi_statement_transformation +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. +# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/redshift/input/redshift_example1_multi_statement_transformation.sql` + +# COMMAND ---------- + +# Create products table (removing Redshift-specific features) +spark.sql(""" +CREATE TABLE IF NOT EXISTS products ( + product_id INT, + product_name STRING, + price DECIMAL(10,2), + created_at TIMESTAMP DEFAULT current_timestamp() +) +""") + +# COMMAND ---------- + +# Note: Removed Redshift-specific DISTKEY, SORTKEY and ENCODE options which aren't supported in Databricks + +# Insert initial product data +spark.sql(""" +INSERT INTO products (product_id, product_name, price) +VALUES + (201, 'Red Gadget', 15.50), + (202, 'Blue Widget', 24.99), + (203, 'Green Gizmo', 8.75) +""") + +# COMMAND ---------- + +# Create temporary discount table +spark.sql(""" +CREATE OR REPLACE TABLE temp_discounts ( + product_id INT, + discount_pct FLOAT +) +""") + +# COMMAND ---------- + +# Insert discount data +spark.sql(""" +INSERT INTO temp_discounts (product_id, discount_pct) +VALUES + (201, 0.05), + (203, 0.20) +""") + +# COMMAND ---------- + +# Update prices using MERGE (Databricks doesn't support UPDATE...FROM) +spark.sql(""" +MERGE INTO products p +USING temp_discounts d +ON p.product_id = d.product_id +WHEN MATCHED THEN + UPDATE SET price = price * (1 - d.discount_pct) +""") + +# COMMAND ---------- + +# Delete older products (rewriting the logic for Databricks) +spark.sql(""" +DELETE FROM products +WHERE created_at < date_sub(current_timestamp(), 7) +AND product_id NOT IN (SELECT product_id FROM temp_discounts) +""") + +# COMMAND ---------- + +# Display final products +display(spark.sql("SELECT * FROM products")) + +# COMMAND ---------- + +# Clean up tables +spark.sql("DROP TABLE IF EXISTS products") +spark.sql("DROP TABLE IF EXISTS temp_discounts") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors were detected during the static check. +# MAGIC However, please review the code carefully as some issues may only be detected during runtime. 
\ No newline at end of file diff --git a/sql2dbx/examples/redshift/output/redshift_example2_stored_procedure.py b/sql2dbx/examples/redshift/output/redshift_example2_stored_procedure.py new file mode 100644 index 00000000..bb223944 --- /dev/null +++ b/sql2dbx/examples/redshift/output/redshift_example2_stored_procedure.py @@ -0,0 +1,85 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # redshift_example2_stored_procedure +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. +# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/redshift/input/redshift_example2_stored_procedure.sql` + +# COMMAND ---------- + +# Create widget for parameter with default value +dbutils.widgets.text("threshold_multiplier", "1.25") + +# COMMAND ---------- + +# Convert parameter +threshold_multiplier = float(dbutils.widgets.get("threshold_multiplier")) + +# COMMAND ---------- + +# Define temporary table name (we'll use Delta table) +temp_table_name = "temp_outlier_thresholds" + +# COMMAND ---------- + +# Track latest timestamp for potential rollback +try: + hist = spark.sql("DESCRIBE HISTORY sales_forecast LIMIT 1").collect()[0] + restore_ts = hist["timestamp"] +except: + restore_ts = None + +# COMMAND ---------- + +try: + # Create temporary table + spark.sql(f""" + CREATE OR REPLACE TABLE {temp_table_name} ( + product_id INT, + outlier_threshold DECIMAL(10,2) + ) + """) + + # Calculate the outlier thresholds using percentile_cont + spark.sql(f""" + INSERT INTO {temp_table_name} + SELECT + product_id, + percentile_cont(0.99) WITHIN GROUP (ORDER BY forecast_price) OVER (PARTITION BY product_id) AS outlier_threshold + FROM sales_forecast + WHERE forecast_date >= date_sub(current_date(), 365) + """) + + # Update the sales_forecast table with capped values + update_result = spark.sql(f""" + MERGE INTO sales_forecast f + USING {temp_table_name} t + ON f.product_id = t.product_id + WHEN MATCHED AND f.forecast_price > t.outlier_threshold * {threshold_multiplier} THEN + UPDATE SET forecast_price = t.outlier_threshold * {threshold_multiplier} + """) + + # Get count of updated rows + row_count = update_result.count() if hasattr(update_result, 'count') else 0 + print(f"Rows updated: {row_count}") + print(f"Procedure completed with multiplier {threshold_multiplier}") + +except Exception as e: + print(f"Error occurred: {str(e)}") + + # Attempt rollback if we have a timestamp + if restore_ts: + spark.sql(f"RESTORE TABLE sales_forecast TO TIMESTAMP AS OF '{restore_ts}'") + print("Rollback via table restore completed.") + raise e + +finally: + # Clean up temporary table + spark.sql(f"DROP TABLE IF EXISTS {temp_table_name}") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors were detected during the static check. +# MAGIC However, please review the code carefully as some issues may only be detected during runtime. 
\ No newline at end of file diff --git a/sql2dbx/examples/snowflake/input/snowflake_example1_multi_statement_transformation.sql b/sql2dbx/examples/snowflake/input/snowflake_example1_multi_statement_transformation.sql new file mode 100644 index 00000000..97a2361b --- /dev/null +++ b/sql2dbx/examples/snowflake/input/snowflake_example1_multi_statement_transformation.sql @@ -0,0 +1,57 @@ +-- ========================================== +-- EXAMPLE 1: Multi-Statement Data Transformation in Snowflake +-- ========================================== + +-- Optional session-level statement (will be commented out or handled specially in Databricks) +ALTER SESSION SET TIMEZONE = 'America/Los_Angeles'; + +-- Create a table for customer data +CREATE TABLE CUSTOMERS ( + CUSTOMER_ID NUMBER(10, 0), + FULL_NAME VARCHAR(100), + STATUS VARCHAR(10), + CREATED_AT TIMESTAMP_NTZ DEFAULT CURRENT_TIMESTAMP() +); + +-- Insert some initial data +INSERT INTO CUSTOMERS (CUSTOMER_ID, FULL_NAME, STATUS) +VALUES + (1, 'Alice Smith', 'ACTIVE'), + (2, 'Bob Jones', 'INACTIVE'), + (3, 'Charlie Brown', 'ACTIVE'); + +-- Create a temporary table to store an address list +CREATE TEMPORARY TABLE TEMP_ADDRESSES ( + ADDRESS_ID NUMBER(10, 0), + CUSTOMER_ID NUMBER(10, 0), + ADDRESS_LINE VARCHAR(200) +); + +-- Insert data into the temporary table +INSERT INTO TEMP_ADDRESSES (ADDRESS_ID, CUSTOMER_ID, ADDRESS_LINE) +VALUES + (100, 1, '123 Maple Street'), + (101, 2, '456 Oak Avenue'), + (102, 3, '789 Pine Road'); + +-- Illustrate a simple UPDATE using a JOIN-like syntax (Snowflake allows FROM in UPDATE) +UPDATE CUSTOMERS +SET STATUS = 'PENDING' +FROM TEMP_ADDRESSES +WHERE CUSTOMERS.CUSTOMER_ID = TEMP_ADDRESSES.CUSTOMER_ID + AND TEMP_ADDRESSES.ADDRESS_LINE LIKE '%Pine%'; + +-- Demonstrate a DELETE that references another table +-- (Snowflake can do: DELETE FROM USING WHERE ...) 
+DELETE FROM TEMP_ADDRESSES +USING CUSTOMERS +WHERE TEMP_ADDRESSES.CUSTOMER_ID = CUSTOMERS.CUSTOMER_ID + AND CUSTOMERS.STATUS = 'INACTIVE'; + +-- Final SELECT to confirm transformations +SELECT * FROM CUSTOMERS; +SELECT * FROM TEMP_ADDRESSES; + +-- Clean up +-- DROP TABLE IF EXISTS TEMP_ADDRESSES; -- Temp tables are automatically dropped at the end of the session +DROP TABLE IF EXISTS CUSTOMERS; diff --git a/sql2dbx/examples/snowflake/input/snowflake_example2_javascript_procedure.sql b/sql2dbx/examples/snowflake/input/snowflake_example2_javascript_procedure.sql new file mode 100644 index 00000000..15c30180 --- /dev/null +++ b/sql2dbx/examples/snowflake/input/snowflake_example2_javascript_procedure.sql @@ -0,0 +1,111 @@ +-- ========================================== +-- SNOWFLAKE EXAMPLE #2: Stored Procedure with Outlier Checking +-- ========================================== + +CREATE OR REPLACE PROCEDURE PUBLIC.DEMO_FORECAST_OUTLIER_CHECK_UPDATE( + SCHEMA_NAME VARCHAR, + OUTLIER_MULTIPLIER FLOAT DEFAULT 1.30 +) +RETURNS VARCHAR +LANGUAGE JAVASCRIPT +AS +$$ +// Variable declarations +var result = 0; +var errorMsg = null; +var currentDate = null; +var errorProcName = "DEMO_FORECAST_OUTLIER_CHECK_UPDATE"; + +try { + // Begin transaction + snowflake.execute({sqlText: 'BEGIN'}); + + // 1) Retrieve current date from a "SystemDateTable" in the given schema + var sqlGetDate = ` + SELECT SystemDate + FROM "${SCHEMA_NAME}".SystemDateTable; + `; + + var rsDate = snowflake.execute({sqlText: sqlGetDate}); + if (rsDate.next()) { + currentDate = rsDate.getColumnValue(1); + } else { + throw new Error("Failed to retrieve system date"); + } + + // 2) Create a temporary table to store outlier thresholds + snowflake.execute({sqlText: ` + CREATE TEMPORARY TABLE TEMP_OUTLIER_INFO ( + LocationId VARCHAR(10), + OutlierThreshold DECIMAL(8,2) + ); + `}); + + // 3) Insert outlier thresholds (99th percentile) from "HistoricalDataTable" + var sqlInsertThresholds = ` + INSERT INTO TEMP_OUTLIER_INFO (LocationId, OutlierThreshold) + SELECT + d.LocationId, + PERCENTILE_CONT(0.99) + WITHIN GROUP (ORDER BY d.MetricValue) + OVER (PARTITION BY d.LocationId) + FROM "${SCHEMA_NAME}".HistoricalDataTable d + WHERE TO_DATE(d.TargetDate) >= DATEADD(YEAR, -1, TO_DATE('${currentDate}')); + `; + + snowflake.execute({sqlText: sqlInsertThresholds}); + + // 4) Save original forecast values above threshold in "ForecastTable" + var sqlSaveOriginals = ` + UPDATE "${SCHEMA_NAME}".ForecastTable f + SET f.OriginalForecastValue = f.ForecastValue + FROM TEMP_OUTLIER_INFO t + WHERE f.LocationId = t.LocationId + AND TO_DATE(f.ForecastDate) = TO_DATE('${currentDate}') + AND f.ForecastValue > t.OutlierThreshold * ${OUTLIER_MULTIPLIER}; + `; + + snowflake.execute({sqlText: sqlSaveOriginals}); + + // 5) Update outlier values to cap them at threshold * multiplier + var sqlUpdateOutliers = ` + UPDATE "${SCHEMA_NAME}".ForecastTable f + SET f.ForecastValue = t.OutlierThreshold * ${OUTLIER_MULTIPLIER} + FROM TEMP_OUTLIER_INFO t + WHERE f.LocationId = t.LocationId + AND TO_DATE(f.ForecastDate) = TO_DATE('${currentDate}') + AND f.ForecastValue > t.OutlierThreshold * ${OUTLIER_MULTIPLIER}; + `; + + snowflake.execute({sqlText: sqlUpdateOutliers}); + + // Commit transaction + snowflake.execute({sqlText: 'COMMIT'}); + +} catch (err) { + result = 2; + errorMsg = err.message; + + // Rollback if transaction is still open + snowflake.execute({sqlText: 'ROLLBACK'}); + + // Log the error (simplified): + var logSql = ` + CALL LogError( + '${errorProcName}', + 
'${errorMsg.replace(/'/g, "''")}' + ); + `; + + try { + snowflake.execute({sqlText: logSql}); + } catch (logErr) { + // Silently handle logging errors + } + + // Re-throw the error with details + throw new Error("Error in procedure: " + errorMsg); +} + +return result.toString(); +$$; \ No newline at end of file diff --git a/sql2dbx/examples/snowflake/output/snowflake_example1_multi_statement_transformation.py b/sql2dbx/examples/snowflake/output/snowflake_example1_multi_statement_transformation.py new file mode 100644 index 00000000..ec0521ff --- /dev/null +++ b/sql2dbx/examples/snowflake/output/snowflake_example1_multi_statement_transformation.py @@ -0,0 +1,111 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # snowflake_example1_multi_statement_transformation +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. +# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/snowflake/input/snowflake_example1_multi_statement_transformation.sql` + +# COMMAND ---------- + +# Set timezone +# Note: In Databricks, timezone is configured differently than Snowflake +spark.sql("SET spark.sql.session.timeZone=America/Los_Angeles") + +# COMMAND ---------- + +# Create CUSTOMERS table +spark.sql(""" +CREATE TABLE IF NOT EXISTS CUSTOMERS ( + CUSTOMER_ID INT, + FULL_NAME STRING, + STATUS STRING, + CREATED_AT TIMESTAMP DEFAULT CURRENT_TIMESTAMP() +) +""") + +# COMMAND ---------- + +# Insert data into CUSTOMERS table +spark.sql(""" +INSERT INTO CUSTOMERS (CUSTOMER_ID, FULL_NAME, STATUS) +VALUES + (1, 'Alice Smith', 'ACTIVE'), + (2, 'Bob Jones', 'INACTIVE'), + (3, 'Charlie Brown', 'ACTIVE') +""") + +# COMMAND ---------- + +# Create a table for addresses (we don't use temporary tables in Databricks the same way) +spark.sql(""" +CREATE OR REPLACE TABLE TEMP_ADDRESSES ( + ADDRESS_ID INT, + CUSTOMER_ID INT, + ADDRESS_LINE STRING +) +""") + +# COMMAND ---------- + +# Insert data into TEMP_ADDRESSES +spark.sql(""" +INSERT INTO TEMP_ADDRESSES (ADDRESS_ID, CUSTOMER_ID, ADDRESS_LINE) +VALUES + (100, 1, '123 Maple Street'), + (101, 2, '456 Oak Avenue'), + (102, 3, '789 Pine Road') +""") + +# COMMAND ---------- + +# Update CUSTOMERS - Databricks doesn't support UPDATE FROM syntax +# Instead, we'll use MERGE INTO +spark.sql(""" +MERGE INTO CUSTOMERS c +USING ( + SELECT CUSTOMER_ID + FROM TEMP_ADDRESSES + WHERE ADDRESS_LINE LIKE '%Pine%' +) a +ON c.CUSTOMER_ID = a.CUSTOMER_ID +WHEN MATCHED THEN UPDATE SET STATUS = 'PENDING' +""") + +# COMMAND ---------- + +# Delete from TEMP_ADDRESSES - Databricks doesn't support DELETE USING syntax +# We'll identify records to keep and rewrite the table +spark.sql(""" +CREATE OR REPLACE TABLE TEMP_ADDRESSES AS +SELECT t.* +FROM TEMP_ADDRESSES t +LEFT JOIN CUSTOMERS c ON t.CUSTOMER_ID = c.CUSTOMER_ID AND c.STATUS = 'INACTIVE' +WHERE c.CUSTOMER_ID IS NULL +""") + +# COMMAND ---------- + +# Select all data from CUSTOMERS +display(spark.sql("SELECT * FROM CUSTOMERS")) + +# COMMAND ---------- + +# Select all data from TEMP_ADDRESSES +display(spark.sql("SELECT * FROM TEMP_ADDRESSES")) + +# COMMAND ---------- + +# Drop CUSTOMERS table +spark.sql("DROP TABLE IF EXISTS CUSTOMERS") + +# COMMAND ---------- + +# Also drop the TEMP_ADDRESSES table since it was meant to be temporary +spark.sql("DROP TABLE IF EXISTS TEMP_ADDRESSES") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors 
were detected during the static check. +# MAGIC However, please review the code carefully as some issues may only be detected during runtime. \ No newline at end of file diff --git a/sql2dbx/examples/snowflake/output/snowflake_example2_javascript_procedure.py b/sql2dbx/examples/snowflake/output/snowflake_example2_javascript_procedure.py new file mode 100644 index 00000000..5521fe4f --- /dev/null +++ b/sql2dbx/examples/snowflake/output/snowflake_example2_javascript_procedure.py @@ -0,0 +1,112 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # snowflake_example2_javascript_procedure +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. +# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/snowflake/input/snowflake_example2_javascript_procedure.sql` + +# COMMAND ---------- + +# Define widgets to accept parameters +dbutils.widgets.text("SCHEMA_NAME", "") +dbutils.widgets.text("OUTLIER_MULTIPLIER", "1.30") + +# COMMAND ---------- + +# Get widget values +schema_name = dbutils.widgets.get("SCHEMA_NAME") +outlier_multiplier = float(dbutils.widgets.get("OUTLIER_MULTIPLIER")) + +# COMMAND ---------- + +# For Delta table rollback if needed +try: + hist_forecast = spark.sql(f"DESCRIBE HISTORY `{schema_name}`.ForecastTable LIMIT 1").collect()[0] + forecast_restore_ts = hist_forecast["timestamp"] +except: + forecast_restore_ts = None + +# COMMAND ---------- + +# Main procedure logic +try: + # 1) Retrieve current date from SystemDateTable + date_df = spark.sql(f"SELECT SystemDate FROM `{schema_name}`.SystemDateTable") + if date_df.count() == 0: + raise ValueError("Failed to retrieve system date") + + current_date = date_df.collect()[0]["SystemDate"] + + # 2 & 3) Calculate outlier thresholds using window functions + # Instead of creating a temp table, we'll create a temporary view + spark.sql(f""" + CREATE OR REPLACE TEMPORARY VIEW TEMP_OUTLIER_INFO AS + SELECT + d.LocationId, + percentile_cont(0.99) WITHIN GROUP (ORDER BY d.MetricValue) + OVER (PARTITION BY d.LocationId) AS OutlierThreshold + FROM `{schema_name}`.HistoricalDataTable d + WHERE TO_DATE(d.TargetDate) >= DATE_SUB(TO_DATE('{current_date}'), 365) + """) + + # 4) Save original forecast values above threshold + spark.sql(f""" + UPDATE `{schema_name}`.ForecastTable f + SET OriginalForecastValue = ForecastValue + WHERE EXISTS ( + SELECT 1 + FROM TEMP_OUTLIER_INFO t + WHERE f.LocationId = t.LocationId + AND TO_DATE(f.ForecastDate) = TO_DATE('{current_date}') + AND f.ForecastValue > t.OutlierThreshold * {outlier_multiplier} + ) + """) + + # 5) Update outlier values to cap them at threshold * multiplier + spark.sql(f""" + MERGE INTO `{schema_name}`.ForecastTable f + USING TEMP_OUTLIER_INFO t + ON f.LocationId = t.LocationId + AND TO_DATE(f.ForecastDate) = TO_DATE('{current_date}') + AND f.ForecastValue > t.OutlierThreshold * {outlier_multiplier} + WHEN MATCHED THEN + UPDATE SET f.ForecastValue = t.OutlierThreshold * {outlier_multiplier} + """) + + # Success + result = "0" + +except Exception as e: + # Error handling - restore forecast table if we have a timestamp + if forecast_restore_ts is not None: + spark.sql(f"RESTORE TABLE `{schema_name}`.ForecastTable TO TIMESTAMP AS OF '{forecast_restore_ts}'") + + # Log error (simplified) + error_msg = str(e).replace("'", "''") + try: + spark.sql(f"CALL LogError('DEMO_FORECAST_OUTLIER_CHECK_UPDATE', '{error_msg}')") + except: + # 
Silently handle logging errors + pass + + # Re-raise the error + result = "2" + raise Exception(f"Error in procedure: {error_msg}") + +# COMMAND ---------- + +# Clean up the temporary view +spark.sql("DROP VIEW IF EXISTS TEMP_OUTLIER_INFO") + +# COMMAND ---------- + +# Return the result +print(f"Result: {result}") +dbutils.notebook.exit(result) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors were detected during the static check. +# MAGIC However, please review the code carefully as some issues may only be detected during runtime. \ No newline at end of file diff --git a/sql2dbx/examples/teradata/input/teradata_example1_multi_statement_transformation.sql b/sql2dbx/examples/teradata/input/teradata_example1_multi_statement_transformation.sql new file mode 100644 index 00000000..23d973e6 --- /dev/null +++ b/sql2dbx/examples/teradata/input/teradata_example1_multi_statement_transformation.sql @@ -0,0 +1,58 @@ +-- ========================================== +-- TERADATA EXAMPLE #1: Multi-Statement Data Transformation +-- ========================================== + +-- Create a MULTISET table to hold product data +CREATE MULTISET TABLE Products ( + ProductID INTEGER, + ProductName VARCHAR(100), + Price DECIMAL(9,2), + CreatedAt TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP(0) +) +PRIMARY INDEX (ProductID); + +-- Insert some sample products +INSERT INTO Products (ProductID, ProductName, Price) +VALUES (201, 'Gadget X', 15.00); + +INSERT INTO Products (ProductID, ProductName, Price) +VALUES (202, 'Gadget Y', 25.50); + +INSERT INTO Products (ProductID, ProductName, Price) +VALUES (203, 'Gadget Z', 40.00); + +-- Create a MULTISET table to capture special discount rates +CREATE MULTISET TABLE Discounts ( + ProductID INTEGER, + DiscountRate DECIMAL(4,3) +) +PRIMARY INDEX (ProductID); + +-- Insert discount data +INSERT INTO Discounts (ProductID, DiscountRate) +VALUES (201, 0.10); + +INSERT INTO Discounts (ProductID, DiscountRate) +VALUES (203, 0.25); + +-- Update products based on Discounts +-- Teradata allows the "FROM" clause in UPDATE syntax: +UPDATE Products p +FROM Discounts d +SET p.Price = p.Price * (1 - d.DiscountRate) +WHERE p.ProductID = d.ProductID; + +-- Demonstrate a conditional DELETE: +-- Example: remove items older than 7 days which are not discounted +DELETE FROM Products p +WHERE p.CreatedAt < (CURRENT_TIMESTAMP(0) - INTERVAL '7' DAY) + AND p.ProductID NOT IN ( + SELECT d.ProductID FROM Discounts d + ); + +-- Final SELECT to confirm changes +SEL * FROM Products; + +-- Clean up +DROP TABLE Discounts; +DROP TABLE Products; diff --git a/sql2dbx/examples/teradata/input/teradata_example2_stored_procedure.sql b/sql2dbx/examples/teradata/input/teradata_example2_stored_procedure.sql new file mode 100644 index 00000000..200bc841 --- /dev/null +++ b/sql2dbx/examples/teradata/input/teradata_example2_stored_procedure.sql @@ -0,0 +1,53 @@ +-- ========================================== +-- TERADATA EXAMPLE #2: Stored Procedure with Basic Error Handling +-- ========================================== + +REPLACE PROCEDURE CheckInventoryLevels ( + IN pLocationId VARCHAR(10), + IN pThreshold DECIMAL(8,2), + OUT pResultCode INTEGER +) +L1: BEGIN + DECLARE EXIT HANDLER FOR SQLEXCEPTION + BEGIN + SET pResultCode = 2; -- Indicate error + -- In a real scenario, you might log or handle the error differently + ROLLBACK; + LEAVE L1; + END; + + -- Default result code + SET pResultCode = 0; + + -- Begin a Teradata transaction + BT; + + -- Create a VOLATILE TABLE to 
store items that exceed the given threshold + CREATE VOLATILE TABLE VolatileHighStock ( + ItemId INTEGER, + LocationId VARCHAR(10), + StockLevel DECIMAL(8,2) + ) ON COMMIT PRESERVE ROWS; + + -- Insert data into the volatile table (simplified example). + -- We assume we have a permanent table Inventory with columns ItemId, LocationId, StockLevel + INSERT INTO VolatileHighStock (ItemId, LocationId, StockLevel) + SELECT i.ItemId, i.LocationId, i.StockLevel + FROM Inventory i + WHERE i.LocationId = pLocationId + AND i.StockLevel > pThreshold; + + -- Suppose we adjust stock values in the main Inventory table if they exceed threshold + UPDATE Inventory inv + FROM VolatileHighStock v + SET inv.StockLevel = pThreshold + WHERE inv.ItemId = v.ItemId + AND inv.LocationId = v.LocationId; + + -- Commit the transaction + ET; + + -- Drop the volatile table + DROP TABLE VolatileHighStock; + +END; diff --git a/sql2dbx/examples/teradata/output/teradata_example1_multi_statement_transformation.py b/sql2dbx/examples/teradata/output/teradata_example1_multi_statement_transformation.py new file mode 100644 index 00000000..a3f57274 --- /dev/null +++ b/sql2dbx/examples/teradata/output/teradata_example1_multi_statement_transformation.py @@ -0,0 +1,109 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # teradata_example1_multi_statement_transformation +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. +# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/teradata/input/teradata_example1_multi_statement_transformation.sql` + +# COMMAND ---------- + +# Create Products table in Delta format +spark.sql(""" +CREATE OR REPLACE TABLE Products ( + ProductID INT, + ProductName STRING, + Price DECIMAL(9,2), + CreatedAt TIMESTAMP DEFAULT CURRENT_TIMESTAMP() +) +""") + +# COMMAND ---------- + +# Note: PRIMARY INDEX is specific to Teradata and doesn't have a direct equivalent in Databricks + +# Insert initial product records +spark.sql(""" +INSERT INTO Products (ProductID, ProductName, Price) +VALUES (201, 'Gadget X', 15.00) +""") + +# COMMAND ---------- + +spark.sql(""" +INSERT INTO Products (ProductID, ProductName, Price) +VALUES (202, 'Gadget Y', 25.50) +""") + +# COMMAND ---------- + +spark.sql(""" +INSERT INTO Products (ProductID, ProductName, Price) +VALUES (203, 'Gadget Z', 40.00) +""") + +# COMMAND ---------- + +# Create Discounts table +spark.sql(""" +CREATE OR REPLACE TABLE Discounts ( + ProductID INT, + DiscountRate DECIMAL(4,3) +) +""") + +# COMMAND ---------- + +# Note: PRIMARY INDEX is specific to Teradata and ignored here + +# Insert discount records +spark.sql(""" +INSERT INTO Discounts (ProductID, DiscountRate) +VALUES (201, 0.10) +""") + +# COMMAND ---------- + +spark.sql(""" +INSERT INTO Discounts (ProductID, DiscountRate) +VALUES (203, 0.25) +""") + +# COMMAND ---------- + +# Update product prices based on discount rates +# Databricks doesn't support UPDATE with FROM, so we use MERGE instead +spark.sql(""" +MERGE INTO Products p +USING Discounts d +ON p.ProductID = d.ProductID +WHEN MATCHED THEN + UPDATE SET p.Price = p.Price * (1 - d.DiscountRate) +""") + +# COMMAND ---------- + +# Delete products older than 7 days that have no discount +spark.sql(""" +DELETE FROM Products +WHERE CreatedAt < date_sub(current_timestamp(), 7) +AND ProductID NOT IN (SELECT ProductID FROM Discounts) +""") + +# COMMAND ---------- + +# Display the contents 
of the Products table +products_df = spark.sql("SELECT * FROM Products") +display(products_df) + +# COMMAND ---------- + +# Drop the tables +spark.sql("DROP TABLE IF EXISTS Discounts") +spark.sql("DROP TABLE IF EXISTS Products") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors were detected during the static check. +# MAGIC However, please review the code carefully as some issues may only be detected during runtime. \ No newline at end of file diff --git a/sql2dbx/examples/teradata/output/teradata_example2_stored_procedure.py b/sql2dbx/examples/teradata/output/teradata_example2_stored_procedure.py new file mode 100644 index 00000000..c82e8acb --- /dev/null +++ b/sql2dbx/examples/teradata/output/teradata_example2_stored_procedure.py @@ -0,0 +1,98 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # teradata_example2_stored_procedure +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. +# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/teradata/input/teradata_example2_stored_procedure.sql` + +# COMMAND ---------- + +# Setup widgets for parameters +dbutils.widgets.text("pLocationId", "") +dbutils.widgets.text("pThreshold", "") +dbutils.widgets.text("pResultCode", "0") # Output parameter initialized to 0 + +# COMMAND ---------- + +# Get input parameter values +location_id = dbutils.widgets.get("pLocationId") +threshold = float(dbutils.widgets.get("pThreshold")) +result_code = 0 # Initialize result code + +# COMMAND ---------- + +# Get current table state for potential rollback +try: + hist = spark.sql("DESCRIBE HISTORY Inventory LIMIT 1").collect()[0] + restore_ts = hist["timestamp"] +except: + restore_ts = None + +# COMMAND ---------- + +# Main procedure logic with error handling +try: + # Create temporary table (equivalent to VOLATILE TABLE) + spark.sql(""" + CREATE OR REPLACE TABLE VolatileHighStock ( + ItemId INT, + LocationId STRING, + StockLevel DECIMAL(8,2) + ) + """) + + # Insert high stock items into temporary table + spark.sql(f""" + INSERT INTO VolatileHighStock (ItemId, LocationId, StockLevel) + SELECT + i.ItemId, + i.LocationId, + i.StockLevel + FROM Inventory i + WHERE + i.LocationId = '{location_id}' AND + i.StockLevel > {threshold} + """) + + # Update inventory levels to threshold + # Note: Databricks doesn't support UPDATE with FROM clause directly, using MERGE instead + spark.sql(f""" + MERGE INTO Inventory inv + USING VolatileHighStock v + ON inv.ItemId = v.ItemId AND inv.LocationId = v.LocationId + WHEN MATCHED THEN + UPDATE SET inv.StockLevel = {threshold} + """) + + # Set successful result code + result_code = 0 + +except Exception as e: + # Error handling - set error code and rollback + result_code = 2 + + # Attempt to restore table to prior state + if restore_ts is not None: + spark.sql(f"RESTORE TABLE Inventory TO TIMESTAMP AS OF '{restore_ts}'") + + print(f"Error executing procedure: {str(e)}") + +finally: + # Always clean up temporary table + spark.sql("DROP TABLE IF EXISTS VolatileHighStock") + + # Set output parameter + dbutils.widgets.remove("pResultCode") + dbutils.widgets.text("pResultCode", str(result_code)) + +# COMMAND ---------- + +# Return result code for notebook +dbutils.notebook.exit(str(result_code)) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors were detected during 
the static check. +# MAGIC However, please review the code carefully as some issues may only be detected during runtime. \ No newline at end of file diff --git a/sql2dbx/examples/tsql/input/tsql_example1_multi_statement_transformation.sql b/sql2dbx/examples/tsql/input/tsql_example1_multi_statement_transformation.sql new file mode 100644 index 00000000..b998a4f5 --- /dev/null +++ b/sql2dbx/examples/tsql/input/tsql_example1_multi_statement_transformation.sql @@ -0,0 +1,52 @@ +-- ========================================== +-- T-SQL EXAMPLE #1: Multi-Statement Data Transformation +-- ========================================== + +-- Create a table for product data +CREATE TABLE Products ( + ProductID INT, + ProductName VARCHAR(100), + Price MONEY, + CreatedAt DATETIME DEFAULT GETDATE() +); + +-- Insert some sample products +INSERT INTO Products (ProductID, ProductName, Price) +VALUES + (101, 'Widget A', 12.50), + (102, 'Widget B', 19.99), + (103, 'Widget C', 29.75); + +-- Create a temporary table to capture discounted items +CREATE TABLE #Discounts ( + ProductID INT, + DiscountRate FLOAT +); + +-- Insert discount rates +INSERT INTO #Discounts (ProductID, DiscountRate) +VALUES + (101, 0.10), + (103, 0.25); + +-- Update product prices where a discount is applicable +-- (T-SQL allows an UPDATE with a FROM clause) +UPDATE p +SET p.Price = p.Price * (1 - d.DiscountRate) +FROM Products AS p +INNER JOIN #Discounts AS d ON p.ProductID = d.ProductID; + +-- Demonstrate a conditional DELETE +-- Suppose we delete any product older than 7 days with no discounts +-- (Pretend we have some reason to prune old, non-discounted items) +DELETE p +FROM Products p +WHERE p.CreatedAt < DATEADD(DAY, -7, GETDATE()) + AND p.ProductID NOT IN (SELECT ProductID FROM #Discounts); + +-- Final SELECT to confirm changes +SELECT * FROM Products; + +-- Clean up permanent table only +-- DROP TABLE IF EXISTS #Discounts; -- Temp tables are automatically dropped at the end of the session +DROP TABLE IF EXISTS Products; diff --git a/sql2dbx/examples/tsql/input/tsql_example2_stored_procedure.sql b/sql2dbx/examples/tsql/input/tsql_example2_stored_procedure.sql new file mode 100644 index 00000000..708629fd --- /dev/null +++ b/sql2dbx/examples/tsql/input/tsql_example2_stored_procedure.sql @@ -0,0 +1,111 @@ +-- ========================================== +-- T-SQL EXAMPLE #2: Stored Procedure with Outlier Checking +-- ========================================== + +CREATE PROCEDURE [dbo].[DEMO_FORECAST_OUTLIER_CHECK_UPDATE] + @SchemaName NVARCHAR(128), + @OutlierMultiplier DECIMAL(5,2) = 1.30 +AS +BEGIN + -- Variable declarations + DECLARE @Result INT = 0; + DECLARE @ErrorMsg NVARCHAR(MAX); + DECLARE @CurrentDate DATE; + DECLARE @ErrorProcName NVARCHAR(128) = OBJECT_NAME(@@PROCID); + + -- Create a temporary table to store outlier thresholds + CREATE TABLE #TEMP_OUTLIER_INFO ( + LocationId NVARCHAR(10), + OutlierThreshold DECIMAL(8,2) + ); + + BEGIN TRY + BEGIN TRAN FORECAST_OUTLIER_CHECK; + + -- 1) Retrieve current date from a "SystemDateTable" in the given schema + DECLARE @SQL NVARCHAR(MAX); + SET @SQL = N' + SELECT @CurrentDateOut = SystemDate + FROM ' + QUOTENAME(@SchemaName) + N'.SystemDateTable; + '; + + EXEC sp_executesql + @SQL, + N'@CurrentDateOut DATE OUTPUT', + @CurrentDateOut = @CurrentDate OUTPUT; + + -- 2) Insert outlier thresholds (99th percentile) from "HistoricalDataTable" + SET @SQL = N' + INSERT INTO #TEMP_OUTLIER_INFO (LocationId, OutlierThreshold) + SELECT + d.LocationId, + CONVERT(DECIMAL(8,2), + PERCENTILE_CONT(0.99) + 
WITHIN GROUP (ORDER BY d.MetricValue) + OVER (PARTITION BY d.LocationId) + ) + FROM ' + QUOTENAME(@SchemaName) + N'.HistoricalDataTable d + WHERE CONVERT(DATE, d.TargetDate) >= DATEADD(YEAR, -1, @CurrentDateParam) + '; + + EXEC sp_executesql + @SQL, + N'@CurrentDateParam DATE', + @CurrentDateParam = @CurrentDate; + + -- 3) Save original forecast values above threshold in "ForecastTable" + SET @SQL = N' + UPDATE f + SET f.OriginalForecastValue = f.ForecastValue + FROM ' + QUOTENAME(@SchemaName) + N'.ForecastTable f + INNER JOIN #TEMP_OUTLIER_INFO t ON f.LocationId = t.LocationId + WHERE CONVERT(DATE, f.ForecastDate) = @CurrentDateParam + AND f.ForecastValue > t.OutlierThreshold * @Multiplier + '; + + EXEC sp_executesql + @SQL, + N'@CurrentDateParam DATE, @Multiplier DECIMAL(5,2)', + @CurrentDateParam = @CurrentDate, + @Multiplier = @OutlierMultiplier; + + -- 4) Update outlier values to cap them at threshold * multiplier + SET @SQL = N' + UPDATE f + SET f.ForecastValue = t.OutlierThreshold * @Multiplier + FROM ' + QUOTENAME(@SchemaName) + N'.ForecastTable f + INNER JOIN #TEMP_OUTLIER_INFO t ON f.LocationId = t.LocationId + WHERE CONVERT(DATE, f.ForecastDate) = @CurrentDateParam + AND f.ForecastValue > t.OutlierThreshold * @Multiplier + '; + + EXEC sp_executesql + @SQL, + N'@CurrentDateParam DATE, @Multiplier DECIMAL(5,2)', + @CurrentDateParam = @CurrentDate, + @Multiplier = @OutlierMultiplier; + + COMMIT TRAN FORECAST_OUTLIER_CHECK; + END TRY + BEGIN CATCH + SET @Result = 2; + SET @ErrorMsg = ERROR_MESSAGE(); + PRINT('Error: ' + @ErrorMsg); + + IF (@@TRANCOUNT > 0) + ROLLBACK TRAN FORECAST_OUTLIER_CHECK; + + -- Log the error (simplified): + EXEC [dbo].[LogError] + @ProcName = @ErrorProcName, + @ErrorMessage = @ErrorMsg; + + -- Re-throw the error + THROW; + END CATCH; + + -- DROP TABLE IF EXISTS #TEMP_OUTLIER_INFO; -- Temp tables are automatically dropped at the end of the session + + RETURN @Result; +END; +GO diff --git a/sql2dbx/examples/tsql/output/tsql_example1_multi_statement_transformation.py b/sql2dbx/examples/tsql/output/tsql_example1_multi_statement_transformation.py new file mode 100644 index 00000000..0d79e64d --- /dev/null +++ b/sql2dbx/examples/tsql/output/tsql_example1_multi_statement_transformation.py @@ -0,0 +1,96 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # tsql_example1_multi_statement_transformation +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. 
+# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/tsql/input/tsql_example1_multi_statement_transformation.sql` + +# COMMAND ---------- + +# Create Products table +spark.sql(""" +CREATE TABLE IF NOT EXISTS Products ( + ProductID INT, + ProductName STRING, + Price DECIMAL(19,4), -- MONEY type is mapped to DECIMAL in Databricks + CreatedAt TIMESTAMP DEFAULT current_timestamp() +) USING DELTA +""") + +# COMMAND ---------- + +# Insert data into Products +spark.sql(""" +INSERT INTO Products (ProductID, ProductName, Price) +VALUES + (101, 'Widget A', 12.50), + (102, 'Widget B', 19.99), + (103, 'Widget C', 29.75) +""") + +# COMMAND ---------- + +# Create a Discounts table (using a regular Delta table instead of a temp table) +spark.sql(""" +CREATE OR REPLACE TABLE Discounts ( + ProductID INT, + DiscountRate DOUBLE +) USING DELTA +""") + +# COMMAND ---------- + +# Insert discount data +spark.sql(""" +INSERT INTO Discounts (ProductID, DiscountRate) +VALUES + (101, 0.10), + (103, 0.25) +""") + +# COMMAND ---------- + +# Update product prices +# Using MERGE instead of UPDATE with JOIN since Databricks doesn't support the latter +spark.sql(""" +MERGE INTO Products p +USING Discounts d +ON p.ProductID = d.ProductID +WHEN MATCHED THEN + UPDATE SET Price = Price * (1 - d.DiscountRate) +""") + +# COMMAND ---------- + +# Delete old products that aren't in the discounts table +# Need to rewrite as Databricks doesn't support DELETE with alias +spark.sql(""" +DELETE FROM Products +WHERE CreatedAt < date_add(current_timestamp(), -7) +AND ProductID NOT IN (SELECT ProductID FROM Discounts) +""") + +# COMMAND ---------- + +# Display all products +display(spark.sql("SELECT * FROM Products")) + +# COMMAND ---------- + +# Clean up tables +spark.sql("DROP TABLE IF EXISTS Products") +spark.sql("DROP TABLE IF EXISTS Discounts") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC These are errors from static syntax checks. Manual corrections are required for these errors. +# MAGIC ### Spark SQL Syntax Errors +# MAGIC ``` +# MAGIC Error in query 0: [PARSE_SYNTAX_ERROR] Syntax error at or near end of input. SQLSTATE: 42601 (line 1, pos 229) +# MAGIC +# MAGIC == SQL == +# MAGIC EXPLAIN CREATE TABLE IF NOT EXISTS Products ( ProductID INT, ProductName STRING, Price DECIMAL(19,4), -- MONEY type is mapped to DECIMAL in Databricks CreatedAt TIMESTAMP DEFAULT current_timestamp() ) USING DELTA +# MAGIC -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------^^^ +# MAGIC ``` \ No newline at end of file diff --git a/sql2dbx/examples/tsql/output/tsql_example2_stored_procedure.py b/sql2dbx/examples/tsql/output/tsql_example2_stored_procedure.py new file mode 100644 index 00000000..5e4d71e0 --- /dev/null +++ b/sql2dbx/examples/tsql/output/tsql_example2_stored_procedure.py @@ -0,0 +1,135 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # tsql_example2_stored_procedure +# MAGIC This notebook was automatically converted from the script below. It may contain errors, so use it as a starting point and make necessary corrections. 
+# MAGIC +# MAGIC Source script: `/Workspace/Users/hiroyuki.nakazato@databricks.com/.bundle/sql2dbx/dev/files/examples/tsql/input/tsql_example2_stored_procedure.sql` + +# COMMAND ---------- + +# Create widgets for input parameters +dbutils.widgets.text("SchemaName", "") +dbutils.widgets.text("OutlierMultiplier", "1.30") + +# COMMAND ---------- + +# Get parameter values +schema_name = dbutils.widgets.get("SchemaName") +outlier_multiplier = float(dbutils.widgets.get("OutlierMultiplier")) + +# COMMAND ---------- + +# Create a temporary Delta table to replace #TEMP_OUTLIER_INFO +spark.sql("DROP TABLE IF EXISTS TEMP_OUTLIER_INFO") +spark.sql(""" +CREATE TABLE TEMP_OUTLIER_INFO ( + LocationId STRING, + OutlierThreshold DECIMAL(8,2) +) +""") + +# COMMAND ---------- + +# Set up variables +error_proc_name = "DEMO_FORECAST_OUTLIER_CHECK_UPDATE" +result = 0 + +# COMMAND ---------- + +# Capture table state for potential rollback +try: + forecast_hist = spark.sql(f"DESCRIBE HISTORY {schema_name}.ForecastTable LIMIT 1").collect()[0] + forecast_restore_ts = forecast_hist["timestamp"] +except: + forecast_restore_ts = None + +# COMMAND ---------- + +try: + # Get current date from system table + current_date_df = spark.sql(f"SELECT SystemDate FROM {schema_name}.SystemDateTable") + current_date = current_date_df.first()["SystemDate"] + + # Calculate outlier thresholds and insert into temporary table + # Note: Using percentile instead of PERCENTILE_CONT with WITHIN GROUP + spark.sql(f""" + INSERT INTO TEMP_OUTLIER_INFO (LocationId, OutlierThreshold) + SELECT + d.LocationId, + CAST(percentile(d.MetricValue, 0.99) AS DECIMAL(8,2)) AS OutlierThreshold + FROM {schema_name}.HistoricalDataTable d + WHERE CAST(d.TargetDate AS DATE) >= date_add('{current_date}', -365) + GROUP BY d.LocationId + """) + + # Update original forecast values + # Using MERGE instead of UPDATE with JOIN + spark.sql(f""" + MERGE INTO {schema_name}.ForecastTable f + USING ( + SELECT + f.LocationId, + f.ForecastValue, + t.OutlierThreshold + FROM {schema_name}.ForecastTable f + INNER JOIN TEMP_OUTLIER_INFO t ON f.LocationId = t.LocationId + WHERE CAST(f.ForecastDate AS DATE) = '{current_date}' + AND f.ForecastValue > t.OutlierThreshold * {outlier_multiplier} + ) src + ON f.LocationId = src.LocationId AND CAST(f.ForecastDate AS DATE) = '{current_date}' + WHEN MATCHED THEN + UPDATE SET f.OriginalForecastValue = f.ForecastValue + """) + + # Update forecast values to threshold + spark.sql(f""" + MERGE INTO {schema_name}.ForecastTable f + USING ( + SELECT + f.LocationId, + t.OutlierThreshold * {outlier_multiplier} AS AdjustedValue + FROM {schema_name}.ForecastTable f + INNER JOIN TEMP_OUTLIER_INFO t ON f.LocationId = t.LocationId + WHERE CAST(f.ForecastDate AS DATE) = '{current_date}' + AND f.ForecastValue > t.OutlierThreshold * {outlier_multiplier} + ) src + ON f.LocationId = src.LocationId AND CAST(f.ForecastDate AS DATE) = '{current_date}' + WHEN MATCHED THEN + UPDATE SET f.ForecastValue = src.AdjustedValue + """) + +except Exception as e: + # Error handling + result = 2 + error_msg = str(e) + print(f"Error: {error_msg}") + + # Rollback using Delta table restore + if forecast_restore_ts is not None: + spark.sql(f"RESTORE TABLE {schema_name}.ForecastTable TO TIMESTAMP AS OF '{forecast_restore_ts}'") + + # Log error - assuming LogError exists as another procedure + try: + spark.sql(f""" + CALL dbo.LogError('{error_proc_name}', '{error_msg.replace("'", "''")}') + """) + except: + print(f"Failed to log error. 
Original error: {error_msg}") + + raise e + +finally: + # Clean up temporary table + spark.sql("DROP TABLE IF EXISTS TEMP_OUTLIER_INFO") + +# COMMAND ---------- + +# Return result code +dbutils.notebook.exit(str(result)) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Static Syntax Check Results +# MAGIC No syntax errors were detected during the static check. +# MAGIC However, please review the code carefully as some issues may only be detected during runtime. \ No newline at end of file diff --git a/sql2dbx/notebooks/00_main.py b/sql2dbx/notebooks/00_main.py new file mode 100644 index 00000000..7490ad4d --- /dev/null +++ b/sql2dbx/notebooks/00_main.py @@ -0,0 +1,706 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # sql2dbx +# MAGIC sql2dbx is an automation tool designed to convert SQL files into Databricks notebooks. It leverages Large Language Models (LLMs) to perform the conversion based on system prompts tailored for various SQL dialects. sql2dbx consists of a series of Databricks notebooks. +# MAGIC +# MAGIC While the Databricks notebooks generated by sql2dbx may require manual adjustments, they serve as a valuable starting point for migrating SQL-based workflows to the Databricks environment. +# MAGIC +# MAGIC This main notebook functions as the entry point for sql2dbx's series of processes that convert SQL files into Databricks notebooks. + +# COMMAND ---------- + +# MAGIC +# MAGIC %md-sandbox +# MAGIC ## 💫 Conversion Flow Diagram +# MAGIC The diagram below illustrates the flow of the SQL to Databricks notebooks conversion process. +# MAGIC +# MAGIC +# MAGIC +# MAGIC
+# MAGIC flowchart TD +# MAGIC input[Input SQL Files] -->|Input| analyze[[01_analyze_input_files]] +# MAGIC analyze <-->|Read & Write| conversionTable[Conversion Result Table] +# MAGIC +# MAGIC conversionTable <-->|Read & Write| convert[[02_convert_sql_to_databricks]] +# MAGIC convert -.->|Use| endpoint[Model Serving Endpoint] +# MAGIC convert -.->|Refer| prompts["Conversion Prompt Yaml
(SQL Dialect Specific)"] +# MAGIC +# MAGIC conversionTable <-->|Read & Write| validate[[03_01_static_syntax_check]] +# MAGIC +# MAGIC conversionTable <-->|Read & Write| fixErrors[[03_02_fix_syntax_error]] +# MAGIC fixErrors -.->|Use| endpoint +# MAGIC +# MAGIC conversionTable <-->|Read & Write| splitCells[[04_split_cells]] +# MAGIC +# MAGIC conversionTable -->|Input| export[[05_export_to_databricks_notebooks]] +# MAGIC export -->|Output| notebooks[Converted Databricks Notebooks] +# MAGIC +# MAGIC conversionTable <-->|Read & Write| adjust[[11_adjust_conversion_targets]] +# MAGIC +# MAGIC %% Layout control with invisible lines +# MAGIC convert --- validate --- fixErrors --- splitCells --- export +# MAGIC +# MAGIC %% Styling +# MAGIC classDef process fill:#E6E6FA,stroke:#333,stroke-width:2px; +# MAGIC class analyze,convert,validate,fixErrors,splitCells,adjust,export process; +# MAGIC classDef data fill:#E0F0E0,stroke:#333,stroke-width:2px; +# MAGIC class input,conversionTable,notebooks data; +# MAGIC classDef external fill:#FFF0DB,stroke:#333,stroke-width:2px; +# MAGIC class endpoint,prompts external; +# MAGIC +# MAGIC %% Make layout control lines invisible +# MAGIC linkStyle 12 stroke-width:0px; +# MAGIC linkStyle 13 stroke-width:0px; +# MAGIC linkStyle 14 stroke-width:0px; +# MAGIC linkStyle 15 stroke-width:0px; +# MAGIC
+ +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 📝 Conversion Steps +# MAGIC This main notebook executes the following notebooks in sequence: +# MAGIC +# MAGIC | Notebook Name | Description | +# MAGIC |---|---| +# MAGIC | 01_analyze_input_files | Analyzes the input SQL files, calculates token counts, and saves the results to a Delta table. | +# MAGIC | 02_convert_sql_to_databricks | Converts the SQL code to a Python function that runs in a Databricks notebook using an LLM and updates the result table. | +# MAGIC | 03_01_static_syntax_check | Performs static syntax checks on Python functions and the Spark SQL contained within them, updating the result table with any errors found. | +# MAGIC | 03_02_fix_syntax_error | Fixes syntax errors in Python functions and SQL statements identified in the previous step using an LLM and updates the result table. | +# MAGIC | 04_split_cells | Splits the converted code into separate cells for better organization and readability in Databricks notebooks. | +# MAGIC | 05_export_to_databricks_notebooks | Exports the converted code to Databricks notebooks. | +# MAGIC | 11_adjust_conversion_targets | (Optional) Adjusts the conversion targets by setting the `is_conversion_target` field to `True` for specific files that need to be re-converted. This can be used to reprocess files that did not convert satisfactorily. | +# MAGIC +# MAGIC ## 🎯 Conversion Sources +# MAGIC sql2dbx provides conversion from various SQL dialects to Databricks notebooks through YAML-defined conversion prompts. The table below shows the default YAML file for each supported SQL dialect. When you select a specific dialect in the `sql_dialect` widget, the corresponding default YAML file is used automatically. If you specify a path in the `conversion_prompt_yaml` widget, it will take precedence over the default YAML for that dialect. +# MAGIC +# MAGIC | SQL Dialect | Source System Example | Default YAML file | +# MAGIC | --- | --- | --- | +# MAGIC | `mysql` | MySQL / MariaDB / Amazon Aurora MySQL | mysql_to_databricks_notebook.yml | +# MAGIC | `netezza` | IBM Netezza | netezza_to_databricks_notebook.yml | +# MAGIC | `oracle` | Oracle Database / Oracle Exadata | oracle_to_databricks_notebook.yml | +# MAGIC | `postgresql` | PostgreSQL / Amazon Aurora PostgreSQL | postgresql_to_databricks_notebook.yml | +# MAGIC | `redshift` | Amazon Redshift | redshift_to_databricks_notebook.yml | +# MAGIC | `snowflake` | Snowflake | snowflake_to_databricks_notebook.yml | +# MAGIC | `teradata` | Teradata | teradata_to_databricks_notebook.yml | +# MAGIC | `tsql` | Azure Synapse Analytics / Microsoft SQL Server / Azure SQL Database / Azure SQL Managed Instance | tsql_to_databricks_notebook.yml | +# MAGIC +# MAGIC ### Creating Custom Conversion Prompts (Optional) +# MAGIC If you want to create a custom conversion prompt, you can create a YAML file in a structured format, place it in your Databricks workspace, and specify its full path in the `conversion_prompt_yaml` widget. This allows the LLM to reference the specified YAML file for the conversion. Custom conversion prompts require the following two sections: +# MAGIC +# MAGIC 1. **`system_message`**: Required section that instructs the LLM on how to perform the conversion +# MAGIC 2. 
**`few_shots`**: Section showing specific input-output examples (optional but recommended) +# MAGIC +# MAGIC #### Tips for Custom Conversion Prompts +# MAGIC Here are key points for creating effective conversion prompts: +# MAGIC +# MAGIC **Elements to include in the `system_message`**: +# MAGIC - Clear explanation of the conversion purpose +# MAGIC - Definition of input and output formats +# MAGIC - Additional instructions needed for specific conversions +# MAGIC - (Optional but recommended) Comment language specification (`{comment_lang}` is automatically replaced with the specified language) +# MAGIC - (Optional but recommended) Reference to common instructions for SQL to Databricks Notebook (Python) conversion (`${common_python_instructions_and_guidelines}`) +# MAGIC +# MAGIC **Effective use of `few_shots`**: +# MAGIC - Include typical cases ranging from simple to complex examples +# MAGIC - Ensure each example demonstrates specific patterns that help the LLM's understanding +# MAGIC +# MAGIC #### Example of a Custom Conversion Prompt +# MAGIC Below is a basic example of a YAML file for a custom conversion prompt: +# MAGIC +# MAGIC ```yaml +# MAGIC system_message: | +# MAGIC Convert SQL code to Python code that runs on Databricks according to the following instructions: +# MAGIC +# MAGIC # Input and Output +# MAGIC - Input: A single SQL file containing one or multiple T-SQL statements (including but not limited to `CREATE OR ALTER PROCEDURE` statements). +# MAGIC - Output: Python code with Python comments (in {comment_lang}) explaining the code and any necessary context. +# MAGIC +# MAGIC ${common_python_instructions_and_guidelines} +# MAGIC +# MAGIC # Additional Instructions +# MAGIC 1. Convert SQL queries to spark.sql() format +# MAGIC 2. Add clear Python comments explaining the code +# MAGIC 3. Use DataFrame operations instead of loops when possible +# MAGIC 4. Handle errors using try-except blocks +# MAGIC +# MAGIC few_shots: +# MAGIC - role: user +# MAGIC content: | +# MAGIC SELECT name, age +# MAGIC FROM users +# MAGIC WHERE active = 1; +# MAGIC - role: assistant +# MAGIC content: | +# MAGIC # Get names and ages of active users +# MAGIC active_users = spark.sql(""" +# MAGIC SELECT name, age +# MAGIC FROM users +# MAGIC WHERE active = 1 +# MAGIC """) +# MAGIC display(active_users) +# MAGIC ``` +# MAGIC +# MAGIC ## 📢 Prerequisites +# MAGIC Before running the main notebook, ensure that a Databricks model serving endpoint is available. You can either: +# MAGIC +# MAGIC 1. Use the Databricks Foundation Model APIs (recommended for the simplest setup) +# MAGIC 2. Set up an external model serving endpoint. You can configure it manually or use one of the following automation notebooks: +# MAGIC - Notebook for Azure OpenAI Service Endpoint Setup +# MAGIC - Notebook for Amazon Bedrock Endpoint Setup +# MAGIC +# MAGIC ## ❗ Important Notes +# MAGIC The following points should be considered before running the main notebook: +# MAGIC +# MAGIC ### Model Compatibility +# MAGIC sql2dbx has been optimized for models with large context windows and strong SQL reasoning capabilities. The following models have been verified to produce highly accurate conversions: +# MAGIC +# MAGIC > **Note:** Model specifications are evolving rapidly. For the most current specifications, please check the official documentation for each model before implementation. 
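+# MAGIC
+# MAGIC The Prerequisites above assume a reachable model serving endpoint. Before starting a conversion run, you can sanity-check that the endpoint referenced by `endpoint_name` responds to a chat request. The snippet below is a minimal, illustrative sketch only: it assumes the default Foundation Model API endpoint name used elsewhere in this notebook and an environment with `mlflow>=2.9` available, so adapt the endpoint name and request parameters to your own setup.
+# MAGIC
+# MAGIC ```python
+# MAGIC from mlflow.deployments import get_deploy_client
+# MAGIC
+# MAGIC # Client for Databricks model serving endpoints
+# MAGIC client = get_deploy_client("databricks")
+# MAGIC
+# MAGIC # Send a trivial chat request to confirm the endpoint is reachable
+# MAGIC response = client.predict(
+# MAGIC     endpoint="databricks-claude-3-7-sonnet",  # assumed default endpoint name; replace with your own
+# MAGIC     inputs={
+# MAGIC         "messages": [{"role": "user", "content": "Reply with OK"}],
+# MAGIC         "max_tokens": 16,
+# MAGIC     },
+# MAGIC )
+# MAGIC print(response)
+# MAGIC ```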
+# MAGIC +# MAGIC #### Primary Recommendation +# MAGIC For optimal performance and minimal setup, we recommend using this model: +# MAGIC +# MAGIC | Model | API Model Version | Input Context | Max Output | Setup Requirement | Notes | +# MAGIC |---|---|---|---|---|---| +# MAGIC | [Claude 3.7 Sonnet](https://docs.anthropic.com/en/docs/about-claude/models/all-models) | claude-3-7-sonnet-20250219 | 200K tokens | Normal: 8,192 tokens
Extended thinking: 64,000 tokens | Ready to use via Foundation Model API | Best overall choice for complex SQL conversion |
+# MAGIC
+# MAGIC ##### Extended Thinking Mode
+# MAGIC Claude 3.7 Sonnet's extended thinking mode is optional for simple SQL queries but is recommended for handling complex SQL conversions. To enable this feature, configure the `request_params` notebook widget with thinking parameters as shown below:
+# MAGIC
+# MAGIC Example:
+# MAGIC ```json
+# MAGIC {"max_tokens": 64000, "thinking": {"type": "enabled", "budget_tokens": 16000}}
+# MAGIC ```
+# MAGIC
+# MAGIC > **Note:** Enabling extended thinking mode significantly improves the accuracy of complex SQL conversions but increases token usage and processing time. If the token count of the input SQL file (after removing SQL comments and extra whitespace) is 8,000 or less, extended thinking mode is expected to operate stably. For token counts exceeding this limit, errors may occur. When processing large SQL files, it is recommended to either split them into smaller files or perform the conversion without using extended thinking mode.
+# MAGIC
+# MAGIC #### Alternative Options for Azure Environments
+# MAGIC Some organizations may have specific requirements to use Azure OpenAI models due to corporate policies or existing Azure investments. For these users, the following models are compatible:
+# MAGIC
+# MAGIC | Model | API Model Version | Input Context | Max Output | Setup Requirement | Notes |
+# MAGIC |---|---|---|---|---|---|
+# MAGIC | [Azure OpenAI o1](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/reasoning) | 2024-12-17 | 200K tokens | 100K tokens | Requires setup of External model | Good option for Azure environments |
+# MAGIC | [Azure OpenAI o3-mini](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/reasoning) | 2025-01-31 | 200K tokens | 100K tokens | Requires setup of External model | Good option for Azure environments |
+# MAGIC
+# MAGIC ##### Notes for OpenAI O-series Models Parameters
+# MAGIC When specifying `request_params`, there are several important considerations:
+# MAGIC
+# MAGIC 1. `reasoning_effort` parameter:
+# MAGIC    - `reasoning_effort` controls the depth of reasoning in the inference process for O-series models.
+# MAGIC    - By specifying `{"reasoning_effort": "high"}` in `request_params`, the model performs deeper reasoning, enabling more accurate conversion of complex SQL queries.
+# MAGIC    - As a trade-off, token consumption and processing time will increase.
+# MAGIC 1. Differences in token limit parameters:
+# MAGIC    - For O-series models, the use of `max_tokens` is not recommended; instead, `max_completion_tokens` is used.
+# MAGIC    - It is recommended to run sql2dbx without specifying `max_completion_tokens` for more stable operation.
+# MAGIC 1. Unsupported parameters:
+# MAGIC    - Generation parameters such as `temperature` and `top_p` are not supported by O-series models.
+# MAGIC    - For more details, refer to the [official documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/reasoning).
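+# MAGIC
+# MAGIC For example, when targeting an O-series endpoint, a minimal `request_params` value could look like the following (shown for illustration only; token limits are left unset, as recommended above):
+# MAGIC
+# MAGIC ```json
+# MAGIC {"reasoning_effort": "high"}
+# MAGIC ```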
+# MAGIC
+# MAGIC #### Other Compatible Models
+# MAGIC The following models have also been verified to work with sql2dbx with varying performance characteristics:
+# MAGIC
+# MAGIC | Model | API Model Version | Input Context | Max Output | Setup Requirement | Notes |
+# MAGIC |---|---|---|---|---|---|
+# MAGIC | [Claude 3.5 Sonnet](https://docs.anthropic.com/en/docs/about-claude/models/all-models) | claude-3-5-sonnet-20241022 | 200K tokens | 8,192 tokens | Requires setup of External model | Verified compatible |
+# MAGIC | [Azure OpenAI GPT-4o](https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models) | 2024-05-13 | 128K tokens | 4,096 tokens | Requires setup of External model | Verified compatible |
+# MAGIC | [Meta Llama 3.3 70B Instruct](https://github.com/meta-llama/llama-models/blob/main/models/llama3_3/MODEL_CARD.md) | llama-3-3-70b-instruct | 128K tokens | 8,192 tokens | Ready to use via Foundation Model API | Verified compatible |
+# MAGIC
+# MAGIC ### Token Management for SQL Conversion
+# MAGIC The token count of SQL files directly impacts the conversion process. Two critical factors to consider:
+# MAGIC
+# MAGIC 1. The generated Databricks notebook typically contains more tokens than the input SQL file. This is because the input file's token count is measured after removing SQL comments and collapsing repeated whitespace, while the generated notebook adds comments and indentation.
+# MAGIC 2. For SQL files with extensive logic, the entire notebook may not fit into a single model response. In such cases, the process keeps the content generated so far and produces the remainder in additional passes.
+# MAGIC
+# MAGIC These factors make it essential to manage token usage efficiently. Files exceeding the model's effective processing capacity cannot be converted reliably, while staying within optimal thresholds ensures successful conversions.
+# MAGIC
+# MAGIC #### Recommended Token Thresholds for `token_count_threshold` Parameter
+# MAGIC The `token_count_threshold` parameter determines which SQL files will be processed based on their token count (not file size). SQL content is tokenized after removing SQL comments and extra whitespace.
+# MAGIC
+# MAGIC | Model | Recommended `token_count_threshold` |
+# MAGIC |---|---|
+# MAGIC | Claude 3.7 Sonnet (Normal Mode) | 20,000 tokens (default) |
+# MAGIC | Claude 3.7 Sonnet (Extended Thinking Mode) | 8,000 tokens |
+# MAGIC
+# MAGIC - The 20,000 token value for normal mode is based on actual testing results. While testing environments have successfully processed up to 60,000 tokens, stability decreases when exceeding 20,000 tokens. For the most stable operation, we have set 20,000 tokens as the default value.
+# MAGIC - Other models (such as o1, o3-mini, etc.) have also been confirmed to operate relatively stably up to around 20,000 tokens. Theoretically, larger values may be possible, but we recommend testing in your actual environment.
+# MAGIC - The 8,000 token limit for extended thinking mode is similarly derived from actual testing results. Exceeding this value may result in errors or no results being returned. When processing large SQL files, we recommend splitting them into smaller logical sections.
+# MAGIC
+# MAGIC #### Input File Token Count Process
+# MAGIC The 01_analyze_input_files notebook counts tokens in input files as follows:
+# MAGIC
+# MAGIC 1.
Determines the appropriate tokenizer based on the endpoint name or explicit tokenizer settings: +# MAGIC - For Claude models: Uses a character-based estimation (approx. 3.4 characters per token) based on [Anthropic's documentation](https://docs.anthropic.com/en/docs/about-claude/models/all-models) +# MAGIC - For OpenAI and other models: Uses the [openai/tiktoken](https://github.com/openai/tiktoken) library with appropriate encoding +# MAGIC 2. Measures token count of each SQL file using the selected tokenizer after removing SQL comments and extra whitespace +# MAGIC 3. Files with token count ≤ `token_count_threshold` are marked as conversion targets (`is_conversion_target = True`) +# MAGIC 4. Files exceeding the threshold are excluded from the conversion process +# MAGIC +# MAGIC ## 🔌 Parameters +# MAGIC The main notebook requires the following parameters to be set. For more granular parameter settings, please run individual specialized notebooks instead of this main notebook. The individual notebooks allow for more detailed customization for specific tasks. +# MAGIC +# MAGIC Index | Parameter Name | Required | Description | Default Value +# MAGIC --- | --- | --- | --- | --- +# MAGIC 1-1 | `input_dir` | Yes | The directory containing the SQL files to be converted. Supports locations accessible through Python `os` module (e.g., Unity Catalog Volume, Workspace, Repos, etc.). | +# MAGIC 1-2 | `endpoint_name` | Yes | The name of the Databricks Model Serving endpoint. You can find the endpoint name under the `Serving` tab. Example: If the endpoint URL is `https:///serving-endpoints/hinak-oneenvgpt4o/invocations`, specify `hinak-oneenvgpt4o`. | `databricks-claude-3-7-sonnet` +# MAGIC 1-3 | `result_catalog` | Yes | The existing catalog where the result table will be stored. | +# MAGIC 1-4 | `result_schema` | Yes | The existing schema under the specified catalog where the result table will reside. | +# MAGIC 1-5 | `token_count_threshold` | Yes | Specifies the maximum token count allowed without SQL comments for files to be included in the following conversion process. | `20000` +# MAGIC 1-x | `existing_result_table` | No | The existing result table to use for storing the analysis results. If specified, the table will be used instead of creating a new one. | +# MAGIC 2-1 | `sql_dialect` | Yes | The SQL dialect of the input files. This parameter is used to determine the appropriate conversion prompts for the SQL dialect. | `tsql` +# MAGIC 2-2 | `comment_lang` | Yes | The language for comments to be added to the converted Databricks notebooks. | `English` +# MAGIC 2-3 | `concurrency` | Yes | The number of concurrent requests sent to the model serving endpoint. | `4` +# MAGIC 2-4 | `log_level` | Yes | The logging level to use for the batch inference process. Options are `INFO` for standard logging or `DEBUG` for detailed debug information. | `INFO` +# MAGIC 2-x | `request_params` | No | Additional chat request parameters for chat in JSON format, such as `{"max_tokens": 8192}` (see [Databricks Foundation Model APIs](https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request) for reference). Empty value will use model's default parameters. | +# MAGIC 2-x | `conversion_prompt_yaml` | No | The path to the YAML file containing the conversion prompts. This file defines the system message and few-shot examples for the specific SQL dialect (e.g., T-SQL) to be converted. 
|
+# MAGIC 3-1 | `max_fix_attempts` | Yes | The maximum number of attempts to automatically fix syntax errors in the conversion results. | `1`
+# MAGIC 5-1 | `output_dir` | Yes | The directory where Databricks notebooks are saved. Supports a path in Workspace or Repos. |
+# MAGIC
+# MAGIC ## 📂 Input and Output
+# MAGIC The main inputs and outputs of the conversion process are as follows:
+# MAGIC
+# MAGIC ### Input SQL Files
+# MAGIC Store the input SQL files in the directory specified by the `input_dir` parameter. The 01_analyze_input_files notebook processes all files in the directory and its subdirectories. Supports locations accessible through the Python `os` module (e.g., Unity Catalog Volume, Workspace, Repos, etc.).
+# MAGIC
+# MAGIC ### Conversion Result Notebooks (Final Output)
+# MAGIC The conversion result Databricks notebooks are created by the 05_export_to_databricks_notebooks notebook and serve as the final output of the conversion process. Supports a path in Workspace or Repos.
+# MAGIC
+# MAGIC ### Conversion Result Table (Intermediate Output)
+# MAGIC The conversion result table is created by the 01_analyze_input_files notebook and serves as both an input and output for subsequent notebooks in the conversion process. It is a Delta Lake table that stores the analysis results of input SQL files, including token counts, file metadata, and conversion status.
+# MAGIC
+# MAGIC #### Table Naming
+# MAGIC The table name is constructed using parameters specified in the notebooks, following the format below:
+# MAGIC
+# MAGIC `{result_catalog}.{result_schema}.{result_table_prefix}_{YYYYMMDDHHmm}`
+# MAGIC
+# MAGIC For example, if the `result_catalog` is "my_catalog", the `result_schema` is "my_schema", the `result_table_prefix` is "conversion_targets", and the current time (UTC) is 2024-06-14 11:39, the table name will be:
+# MAGIC
+# MAGIC `my_catalog.my_schema.conversion_targets_202406141139`
+# MAGIC
+# MAGIC #### Table Schema
+# MAGIC The table schema is as follows:
+# MAGIC
+# MAGIC | Column Name | Data Type | Description |
+# MAGIC |---|---|---|
+# MAGIC | `input_file_number` | int | A unique integer identifier for each input file. The numbering starts from `1`. |
+# MAGIC | `input_file_path` | string | The full path to the input file. |
+# MAGIC | `input_file_encoding` | string | The detected encoding of the input file (e.g., `UTF-8`). |
+# MAGIC | `tokenizer_type` | string | The type of tokenizer used for token counting (e.g., `claude` or `openai`). |
+# MAGIC | `tokenizer_model` | string | The specific model or encoding used for tokenization (e.g., `claude` for Claude models or `o200k_base` for OpenAI models). |
+# MAGIC | `input_file_token_count` | int | The total number of tokens in the input file. |
+# MAGIC | `input_file_token_count_without_sql_comments` | int | The number of tokens in the input file excluding SQL comments. |
+# MAGIC | `input_file_content` | string | The entire content of the input file. |
+# MAGIC | `input_file_content_without_sql_comments` | string | The content of the input file excluding SQL comments. |
+# MAGIC | `is_conversion_target` | boolean | Indicates whether the file is a conversion target (True or False). This is determined in `01_analyze_input_files` based on a comparison between the token count of the input file (excluding SQL comments) and the `token_count_threshold`. It is automatically updated from `True` to `False` once the conversion process is successfully completed.
|
+# MAGIC | `model_serving_endpoint_for_conversion` | string | The model serving endpoint for the conversion process. |
+# MAGIC | `model_serving_endpoint_for_fix` | string | The model serving endpoint for syntax error fixing. |
+# MAGIC | `request_params_for_conversion` | string | Request parameters for the conversion process in JSON format. |
+# MAGIC | `request_params_for_fix` | string | Request parameters for syntax error fixing in JSON format. |
+# MAGIC | `result_content` | string | The converted content of the file after processing. (Initially `null`) |
+# MAGIC | `result_prompt_tokens` | int | The number of prompt tokens used for the conversion. (Initially `null`) |
+# MAGIC | `result_completion_tokens` | int | The number of completion tokens generated by the model. (Initially `null`) |
+# MAGIC | `result_total_tokens` | int | The total number of tokens (prompt + completion) used for the conversion. (Initially `null`) |
+# MAGIC | `result_processing_time_seconds` | float | The time in seconds it took to process the conversion request. (Initially `null`) |
+# MAGIC | `result_timestamp` | timestamp | The timestamp (UTC) when the `result_content` was generated or updated. (Initially `null`) |
+# MAGIC | `result_error` | string | Any errors encountered during the conversion process. (Initially `null`) |
+# MAGIC | `result_python_parse_error` | string | Any errors encountered during the Python function syntax check using `ast.parse`. |
+# MAGIC | `result_extracted_sqls` | array | The list of SQL statements extracted from the Python function. (Initially `null`) |
+# MAGIC | `result_sql_parse_errors` | array | Any errors encountered during the SQL syntax check using `EXPLAIN sql`. (Initially `null`) |
+# MAGIC
+# MAGIC ## 🔄 How to Re-convert Specific Files
+# MAGIC If the conversion result is not satisfactory, you can re-convert specific files by following these steps:
+# MAGIC
+# MAGIC 1. Use the 11_adjust_conversion_targets notebook to set the `is_conversion_target` field to `True` for the files you want to re-convert.
+# MAGIC 2. Re-run the 02_convert_sql_to_databricks notebook and the subsequent processes. Only the files with `is_conversion_target` set to `True` will be re-converted.
+# MAGIC     - To introduce more randomness into the LLM's conversion process and obtain different results on each run, it is recommended to set the `temperature` in `request_params` to a value above 0.5 if the model being used supports it.
+# MAGIC
+# MAGIC ## 💻 Verified Environments
+# MAGIC This notebook has been verified to work in the following environments:
+# MAGIC
+# MAGIC 1. Databricks serverless compute for notebooks and jobs
+# MAGIC 2. Databricks classic compute (All-purpose compute and Jobs compute)
+# MAGIC     - Recommended settings: Single-node cluster; Photon is not required
+# MAGIC     - Verified Databricks Runtime (DBR) versions:
+# MAGIC         - 15.3 LTS
+# MAGIC         - 14.3 LTS
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## 0. Set Up Configuration Parameters
+# MAGIC Major configuration parameters are set up in this section. If you need to change other parameters, change them in the respective notebooks.
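+# MAGIC
+# MAGIC As a purely illustrative example (every catalog, schema, and path name below is a placeholder rather than a value shipped with sql2dbx), a typical run might fill the widgets defined in the next cells roughly as follows:
+# MAGIC
+# MAGIC ```python
+# MAGIC # Hypothetical widget values for a T-SQL conversion run.
+# MAGIC example_widget_values = {
+# MAGIC     "input_dir": "/Volumes/my_catalog/my_schema/src_sql",  # placeholder Unity Catalog Volume
+# MAGIC     "endpoint_name": "databricks-claude-3-7-sonnet",
+# MAGIC     "result_catalog": "my_catalog",    # placeholder
+# MAGIC     "result_schema": "my_schema",      # placeholder
+# MAGIC     "token_count_threshold": "20000",
+# MAGIC     "sql_dialect": "tsql",
+# MAGIC     "comment_lang": "English",
+# MAGIC     "concurrency": "4",
+# MAGIC     "max_fix_attempts": "1",
+# MAGIC     "output_dir": "/Workspace/Users/someone@example.com/sql2dbx_output",  # placeholder
+# MAGIC }
+# MAGIC ```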
+
+# COMMAND ----------
+
+# DBTITLE 1,Install Packages
+# MAGIC %pip install -r requirements.txt
+# MAGIC dbutils.library.restartPython()
+
+# COMMAND ----------
+
+# DBTITLE 1,Import Libraries
+import json
+
+import pandas as pd
+from pyscripts.conversion_prompt_helper import ConversionPromptHelper
+from pyscripts.databricks_credentials import DatabricksCredentials
+from pyscripts.notebook_i18n import get_supported_languages
+
+# COMMAND ----------
+
+# DBTITLE 1,Configurations
+# Params for 01_analyze_input_files
+dbutils.widgets.text("input_dir", "", "1-1. Input Directory")
+dbutils.widgets.text("endpoint_name", "databricks-claude-3-7-sonnet", "1-2. Serving Endpoint Name")
+dbutils.widgets.text("result_catalog", "", "1-3. Result Catalog")
+dbutils.widgets.text("result_schema", "", "1-4. Result Schema")
+dbutils.widgets.text("token_count_threshold", "20000", "1-5. Token Count Threshold")
+dbutils.widgets.text("existing_result_table", "", "Existing Result Table (Optional)")
+
+# Params for 02_convert_sql_to_databricks
+dbutils.widgets.dropdown("sql_dialect", "tsql", ConversionPromptHelper.get_supported_sql_dialects(), "2-1. SQL Dialect")
+dbutils.widgets.dropdown("comment_lang", "English", get_supported_languages(), "2-2. Comment Language")
+dbutils.widgets.text("concurrency", "4", "2-3. Concurrent Requests")
+dbutils.widgets.dropdown("log_level", "INFO", ["DEBUG", "INFO"], "2-4. Log Level")
+dbutils.widgets.text("request_params", "", "Chat Request Params (Optional)")
+dbutils.widgets.text("conversion_prompt_yaml", "", "YAML path for Conversion Prompt (Optional)")
+
+# Params for 03_syntax_check_and_fix
+dbutils.widgets.text("max_fix_attempts", "1", "3-1. Maximum Fix Attempts")
+
+# Params for 05_export_to_databricks_notebooks
+dbutils.widgets.text("output_dir", "", "5-1. Output Directory")
+
+# COMMAND ----------
+
+# DBTITLE 1,Load Configurations
+input_dir = dbutils.widgets.get("input_dir")
+endpoint_name = dbutils.widgets.get("endpoint_name")
+result_catalog = dbutils.widgets.get("result_catalog")
+result_schema = dbutils.widgets.get("result_schema")
+token_count_threshold = int(dbutils.widgets.get("token_count_threshold"))
+existing_result_table = dbutils.widgets.get("existing_result_table")
+comment_lang = dbutils.widgets.get("comment_lang")
+concurrency = int(dbutils.widgets.get("concurrency"))
+request_params = dbutils.widgets.get("request_params")
+log_level = dbutils.widgets.get("log_level")
+max_fix_attempts = int(dbutils.widgets.get("max_fix_attempts"))
+output_dir = dbutils.widgets.get("output_dir")
+
+# Determine which conversion YAML to use
+_conversion_prompt_yaml = dbutils.widgets.get("conversion_prompt_yaml")
+sql_dialect = dbutils.widgets.get("sql_dialect")
+
+if _conversion_prompt_yaml:
+    conversion_prompt_yaml = _conversion_prompt_yaml
+else:
+    conversion_prompt_yaml = ConversionPromptHelper.get_default_yaml_for_sql_dialect(sql_dialect)
+
+input_dir, endpoint_name, result_catalog, result_schema, token_count_threshold, existing_result_table, conversion_prompt_yaml, comment_lang, request_params, log_level, max_fix_attempts, output_dir
+
+# COMMAND ----------
+
+# MAGIC %md
+# MAGIC ## 1. Analyze Input Files
+# MAGIC Analyzes the input SQL files, calculates token counts, and saves the results to a Delta table.
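+# MAGIC
+# MAGIC Once the cell below has run, the returned `result_table` can be inspected with ordinary Spark SQL. A minimal sketch (illustrative only; the column names follow the table schema documented above):
+# MAGIC
+# MAGIC ```python
+# MAGIC # Quick per-status summary of the analysis results.
+# MAGIC display(spark.sql(f"""
+# MAGIC     SELECT
+# MAGIC       is_conversion_target,
+# MAGIC       COUNT(*) AS file_count,
+# MAGIC       SUM(input_file_token_count_without_sql_comments) AS total_tokens
+# MAGIC     FROM {result_table}
+# MAGIC     GROUP BY is_conversion_target
+# MAGIC """))
+# MAGIC ```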
+ +# COMMAND ---------- + +# DBTITLE 1,Analyze Input Files +result_table = dbutils.notebook.run("01_analyze_input_files", 0, { + "input_dir": input_dir, + "endpoint_name": endpoint_name, + "result_catalog": result_catalog, + "result_schema": result_schema, + "token_count_threshold": token_count_threshold, + "existing_result_table": existing_result_table, +}) +print(f"Conversion result table: {result_table}") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Files Selected for Conversion +# MAGIC Files within token threshold: Will be converted to Databricks notebooks. + +# COMMAND ---------- + +# DBTITLE 1,Files Selected for Conversion +spark.sql(f""" + SELECT + input_file_number, + input_file_path, + input_file_token_count_without_sql_comments + FROM {result_table} + WHERE is_conversion_target = true + ORDER BY input_file_number +""").display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Files Exceeding Token Threshold +# MAGIC Files exceeding threshold: Need manual review. (Consider splitting into smaller files) + +# COMMAND ---------- + +# DBTITLE 1,Files Exceeding Token Threshold +spark.sql(f""" + SELECT + input_file_number, + input_file_path, + input_file_token_count_without_sql_comments + FROM {result_table} + WHERE is_conversion_target = false + ORDER BY input_file_number +""").display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 2. Convert SQL to Databricks +# MAGIC Converts the SQL code to a Python function that runs in a Databricks notebook using an LLM and updates the result table. + +# COMMAND ---------- + +# DBTITLE 1,Convert SQL to Databricks Notebooks +dbutils.notebook.run("02_convert_sql_to_databricks", 0, { + "endpoint_name": endpoint_name, + "result_table": result_table, + "conversion_prompt_yaml": conversion_prompt_yaml, + "comment_lang": comment_lang, + "concurrency": concurrency, + "request_params": request_params, + "log_level": log_level, +}) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Successfully Converted Files +# MAGIC The following table shows files that have been successfully converted to Databricks notebooks. + +# COMMAND ---------- + +# DBTITLE 1,Successfully Converted Files +spark.sql(f""" + SELECT + input_file_number, + input_file_path, + result_content, + input_file_token_count_without_sql_comments, + result_prompt_tokens, + result_completion_tokens, + result_total_tokens, + result_timestamp + FROM {result_table} + WHERE is_conversion_target = false + AND result_content IS NOT NULL + ORDER BY input_file_number +""").display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Files with Conversion Errors +# MAGIC The following table shows files that have conversion errors. + +# COMMAND ---------- + +# DBTITLE 1,Files with Conversion Errors +spark.sql(f""" + SELECT + input_file_number, + input_file_path, + result_error, + result_timestamp + FROM {result_table} + WHERE is_conversion_target = true + AND result_error IS NOT NULL + ORDER BY input_file_number +""").display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 3. Syntax Check and Fix +# MAGIC Performs static syntax checks on Python functions and the Spark SQL contained within them, and attempts to fix any errors found. 
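+# MAGIC
+# MAGIC Conceptually, the check combines a Python parse with a Spark SQL `EXPLAIN` on each extracted statement, roughly as in the simplified sketch below (illustrative only; the actual implementation lives in 03_01_static_syntax_check and its helper scripts):
+# MAGIC
+# MAGIC ```python
+# MAGIC import ast
+# MAGIC
+# MAGIC def sketch_static_checks(python_code: str, extracted_sqls: list) -> dict:
+# MAGIC     """Simplified illustration of the two checks described above."""
+# MAGIC     errors = {"python": None, "sql": []}
+# MAGIC     try:
+# MAGIC         ast.parse(python_code)           # Python syntax check
+# MAGIC     except SyntaxError as e:
+# MAGIC         errors["python"] = str(e)
+# MAGIC     for stmt in extracted_sqls:          # Spark SQL syntax check
+# MAGIC         try:
+# MAGIC             spark.sql(f"EXPLAIN {stmt}")
+# MAGIC         except Exception as e:
+# MAGIC             errors["sql"].append(str(e))
+# MAGIC     return errors
+# MAGIC ```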
+ +# COMMAND ---------- + +# DBTITLE 1,Function for Syntax Error File Count + + +def get_error_file_count(result_table: str) -> int: + """Get the count of files with syntax errors.""" + error_count = spark.sql(f""" + SELECT COUNT(*) as error_count + FROM {result_table} + WHERE result_python_parse_error IS NOT NULL + OR (result_sql_parse_errors IS NOT NULL AND size(result_sql_parse_errors) > 0) + """).collect()[0]['error_count'] + return error_count + +# COMMAND ---------- + + +# DBTITLE 1,Check and Fix Syntax Errors +for attempt in range(max_fix_attempts): + # Run static syntax check + print(f"Attempt {attempt + 1} of {max_fix_attempts}") + dbutils.notebook.run("03_01_static_syntax_check", 0, { + "result_table": result_table, + }) + + # Check if there are any errors + error_count = get_error_file_count(result_table) + if error_count == 0: + print("No syntax errors found. Exiting fix loop.") + break + + # Run fix syntax error + print(f"Found {error_count} files with syntax errors. Attempting to fix...") + dbutils.notebook.run("03_02_fix_syntax_error", 0, { + "endpoint_name": endpoint_name, + "result_table": result_table, + "concurrency": concurrency, + "request_params": request_params, + "log_level": log_level, + }) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Final Syntax Check +# MAGIC Performs a final static syntax check after all fix attempts. + +# COMMAND ---------- + +# DBTITLE 1,Run Final Syntax Check +dbutils.notebook.run("03_01_static_syntax_check", 0, { + "result_table": result_table, +}) +error_count = get_error_file_count(result_table) +print(f"Found {error_count} files with syntax errors.") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Syntax Check Results +# MAGIC The following table shows the syntax check results for all files, including both successful and failed checks. + +# COMMAND ---------- + +# DBTITLE 1,Syntax Check Status +spark.sql(f""" + SELECT + input_file_number, + input_file_path, + result_content, + CASE + WHEN result_python_parse_error IS NULL + AND (result_sql_parse_errors IS NULL OR size(result_sql_parse_errors) = 0) + THEN 'No errors' + ELSE 'Has errors' + END as check_status, + result_python_parse_error, + result_sql_parse_errors + FROM {result_table} + ORDER BY input_file_number +""").display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 4. Split Cells +# MAGIC Splits the converted Python code into multiple cells based on logical structure and control flow. + +# COMMAND ---------- + +# DBTITLE 1,Split Cells +dbutils.notebook.run("04_split_cells", 0, { + "result_table": result_table, + "log_level": log_level, +}) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 5. Export to Databricks Notebooks +# MAGIC Exports the converted code to Databricks notebooks. + +# COMMAND ---------- + +# DBTITLE 1,Export to Databricks Notebooks +export_results_json = dbutils.notebook.run("05_export_to_databricks_notebooks", 0, { + "result_table": result_table, + "output_dir": output_dir, + "comment_lang": comment_lang +}) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC %md +# MAGIC ### Results Summary +# MAGIC The following table shows the conversion and export results for all input SQL files. + +# COMMAND ---------- + +# DBTITLE 1,Conversion and Export Status +# Display output directory URL +full_url = f"{DatabricksCredentials().host}#workspace{output_dir}" +displayHTML(f'
<a href="{full_url}">Output directory URL: {full_url}</a>
') + +# Create a temporary view of export results +export_results = json.loads(export_results_json) +export_results_df = pd.DataFrame(export_results) +spark.createDataFrame(export_results_df).createOrReplaceTempView("temp_export_results") + +# Display complete status for all files +spark.sql(f""" + SELECT + r.input_file_number, + r.input_file_path, + CASE + WHEN r.result_content IS NULL THEN 'Not converted' + WHEN r.result_python_parse_error IS NOT NULL OR + (r.result_sql_parse_errors IS NOT NULL AND size(r.result_sql_parse_errors) > 0) + THEN 'Converted with errors' + ELSE 'Converted successfully' + END as conversion_status, + CASE + WHEN t.output_file_path IS NOT NULL THEN 'Exported successfully' + ELSE 'Not exported' + END as export_status, + t.output_file_path, + t.parse_error_count, + r.result_python_parse_error as python_errors, + r.result_sql_parse_errors as sql_errors + FROM {result_table} r + LEFT JOIN temp_export_results t + ON r.input_file_path = t.input_file_path + ORDER BY r.input_file_number +""").display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### Next Steps +# MAGIC The SQL to Databricks conversion process is now complete. The results of this process are available in the converted notebooks located in the specified output directory. It's crucial to review these results thoroughly to ensure the converted code meets your requirements and functions as expected in the Databricks environment. +# MAGIC +# MAGIC When reviewing the results, pay special attention to the following cases and take appropriate action: +# MAGIC +# MAGIC 1. Files with `Not converted` status: +# MAGIC - These input files couldn't be processed, often due to exceeding the token count threshold. +# MAGIC - Action: Consider splitting these input files into smaller, more manageable parts. You may also try increasing the `token_count_threshold` parameter if your LLM model can handle larger inputs, then re-run the conversion process. +# MAGIC +# MAGIC 2. Files with `Converted with errors` status: +# MAGIC - These files were converted but contain syntax errors. +# MAGIC - Action: Review the syntax error messages at the bottom of the output notebooks. Manually fix these errors in the converted notebooks. +# MAGIC +# MAGIC 3. Files with `Not exported` status: +# MAGIC - This status is rare but may occur if the converted content is too large. +# MAGIC - If you see a "Content size exceeds 10MB limit" message in the export process, it indicates that the input file might be too large. +# MAGIC - Action: Review and potentially reduce the size of the input SQL file, then try the conversion process again. diff --git a/sql2dbx/notebooks/00_main_ja.py b/sql2dbx/notebooks/00_main_ja.py new file mode 100644 index 00000000..5411160f --- /dev/null +++ b/sql2dbx/notebooks/00_main_ja.py @@ -0,0 +1,703 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # sql2dbx +# MAGIC sql2dbxは、SQLファイルをDatabricksノートブックへ変換する作業を自動化するツールです。大規模言語モデル(LLM)を活用し、各SQL方言用のシステムプロンプトに基づいて変換を行います。sql2dbxは、一連のDatabricksノートブックから構成されています。 +# MAGIC +# MAGIC sql2dbxによって生成されたDatabricksノートブックは、手動での調整が必要な場合がありますが、SQLベースの処理フローをDatabricks環境に移行する際の有用な出発点となります。 +# MAGIC +# MAGIC このメインノートブックは、SQLファイルをDatabricksノートブックに変換するsql2dbxの一連のプロセスのエントリーポイントとして機能します。 + +# COMMAND ---------- + +# MAGIC +# MAGIC %md-sandbox +# MAGIC ## 💫 変換フロー図 +# MAGIC 以下の図は、SQLファイル群をDatabricksノートブック群に変換するプロセスの流れを表します。 +# MAGIC +# MAGIC +# MAGIC +# MAGIC
+# MAGIC flowchart TD +# MAGIC input[SQLファイル群] -->|Input| analyze[[01_analyze_input_files]] +# MAGIC analyze <-->|Read & Write| conversionTable[変換結果テーブル] +# MAGIC +# MAGIC conversionTable <-->|Read & Write| convert[[02_convert_sql_to_databricks]] +# MAGIC convert -.->|Use| endpoint[モデルサービングエンドポイント] +# MAGIC convert -.->|Refer| prompts["変換プロンプト YAML
(SQL方言固有)"] +# MAGIC +# MAGIC conversionTable <-->|Read & Write| validate[[03_01_static_syntax_check]] +# MAGIC +# MAGIC conversionTable <-->|Read & Write| fixErrors[[03_02_fix_syntax_error]] +# MAGIC fixErrors -.->|Use| endpoint +# MAGIC +# MAGIC conversionTable <-->|Read & Write| splitCells[[04_split_cells]] +# MAGIC +# MAGIC conversionTable -->|Input| export[[05_export_to_databricks_notebooks]] +# MAGIC export -->|Output| notebooks[Databricksノートブック群] +# MAGIC +# MAGIC conversionTable <-->|Read & Write| adjust[[11_adjust_conversion_targets]] +# MAGIC +# MAGIC %% Layout control with invisible lines +# MAGIC convert --- validate --- fixErrors --- splitCells --- export +# MAGIC +# MAGIC %% Styling +# MAGIC classDef process fill:#E6E6FA,stroke:#333,stroke-width:2px; +# MAGIC class analyze,convert,validate,fixErrors,splitCells,adjust,export process; +# MAGIC classDef data fill:#E0F0E0,stroke:#333,stroke-width:2px; +# MAGIC class input,conversionTable,notebooks data; +# MAGIC classDef external fill:#FFF0DB,stroke:#333,stroke-width:2px; +# MAGIC class endpoint,prompts external; +# MAGIC +# MAGIC %% Make layout control lines invisible +# MAGIC linkStyle 12 stroke-width:0px; +# MAGIC linkStyle 13 stroke-width:0px; +# MAGIC linkStyle 14 stroke-width:0px; +# MAGIC linkStyle 15 stroke-width:0px; +# MAGIC
+ +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 📝 変換ステップ +# MAGIC メインノートブックは、以下のノートブックを順番に実行します: +# MAGIC +# MAGIC | ノートブック名 | 説明 | +# MAGIC |---|---| +# MAGIC | 01_analyze_input_files | 入力SQLファイルの分析とトークン数の計算を行い、結果をDeltaテーブルに保存します。 | +# MAGIC | 02_convert_sql_to_databricks | LLMを使用してSQLコードをDatabricksノートブックで実行可能なPython関数に変換し、結果テーブルを更新します。 | +# MAGIC | 03_01_static_syntax_check | Python関数とその中のSpark SQLの静的構文チェックを行い、検出されたエラーを結果テーブルに更新します。 | +# MAGIC | 03_02_fix_syntax_error | 前のステップで検出されたPython関数とSQL文の構文エラーをLLMを使用して修正し、結果テーブルを更新します。 | +# MAGIC | 04_split_cells | 変換されたコードを、Databricksノートブック内での可読性や整理を高めるため、複数のセルに分割します。 | +# MAGIC | 05_export_to_databricks_notebooks | 変換されたコードをDatabricksノートブックにエクスポートします。 | +# MAGIC | 11_adjust_conversion_targets | (オプション)特定のファイルを再変換したい場合に、`is_conversion_target`を`True`に設定することで変換対象を再指定します。 | +# MAGIC +# MAGIC ## 🎯 変換対象 +# MAGIC sql2dbxは、YAMLファイルで定義された変換プロンプトを使用して、様々なSQL方言からDatabricksノートブックへの変換を行います。以下の表は、各対応SQL方言のデフォルトYAMLパスを示しています。`sql_dialect`ウィジェットで特定の方言を選択すると、対応するデフォルトのYAMLファイルが自動的に使用されます。`conversion_prompt_yaml`ウィジェットでパスを指定した場合、SQL方言のデフォルトYAMLより優先されます。 +# MAGIC +# MAGIC | SQL方言 | ソースシステム例 | デフォルトのYAMLファイル | +# MAGIC | --- | --- | --- | +# MAGIC | `mysql` | MySQL / MariaDB / Amazon Aurora MySQL | mysql_to_databricks_notebook.yml | +# MAGIC | `netezza` | IBM Netezza | netezza_to_databricks_notebook.yml | +# MAGIC | `oracle` | Oracle Database / Oracle Exadata | oracle_to_databricks_notebook.yml | +# MAGIC | `postgresql` | PostgreSQL / Amazon Aurora PostgreSQL | postgresql_to_databricks_notebook.yml | +# MAGIC | `redshift` | Amazon Redshift | redshift_to_databricks_notebook.yml | +# MAGIC | `snowflake` | Snowflake | snowflake_to_databricks_notebook.yml | +# MAGIC | `teradata` | Teradata | teradata_to_databricks_notebook.yml | +# MAGIC | `tsql` | Azure Synapse Analytics / Microsoft SQL Server / Azure SQL Database / Azure SQL Managed Instance | tsql_to_databricks_notebook.yml | +# MAGIC +# MAGIC ### カスタム変換プロンプトの作成 (オプション) +# MAGIC カスタムの変換プロンプトを作成したい場合、YAMLファイルを構造化されたフォーマットで作成し、Databricksワークスペースに配置し、`conversion_prompt_yaml`ウィジェットでフルパスを指定します。これにより、LLMは指定されたYAMLファイルを参照して変換を行います。カスタム変換プロンプトには以下の2つのセクションが必要です。 +# MAGIC +# MAGIC 1. **`system_message`**: LLMに変換方法を指示する必須セクション +# MAGIC 2. **`few_shots`**: 具体的な入出力例を示すセクション (任意だが推奨) +# MAGIC +# MAGIC #### カスタム変換プロンプトのTips +# MAGIC 効果的な変換プロンプトを作成するためのポイントは以下の通りです。 +# MAGIC +# MAGIC **`system_message`** に含めるべき要素: +# MAGIC - 変換目的の明確な説明 +# MAGIC - 入力と出力の形式の定義 +# MAGIC - 特定の変換に必要な追加の指示 +# MAGIC - (任意だが推奨) コメント言語の指定 (`{comment_lang}`は指定された言語に自動置換されます) +# MAGIC - (任意だが推奨) SQLからDatabricksノートブック (Python) への共通的なインストラクション (`${common_python_instructions_and_guidelines}`) の参照 +# MAGIC +# MAGIC **`few_shots`** の効果的な使い方: +# MAGIC - 単純な例から複雑な例まで、典型的なケースを含める +# MAGIC - 各例がLLMの理解を助ける特定のパターンを示すようにする +# MAGIC +# MAGIC #### カスタム変換プロンプトの例 +# MAGIC 以下はカスタム変換プロンプト用YAMLファイルの基本的な例です。 +# MAGIC +# MAGIC ```yaml +# MAGIC system_message: | +# MAGIC Convert SQL code to Python code that runs on Databricks according to the following instructions: +# MAGIC +# MAGIC # Input and Output +# MAGIC - Input: A single SQL file containing one or multiple T-SQL statements (including but not limited to `CREATE OR ALTER PROCEDURE` statements). +# MAGIC - Output: Python code with Python comments (in {comment_lang}) explaining the code and any necessary context. +# MAGIC +# MAGIC ${common_python_instructions_and_guidelines} +# MAGIC +# MAGIC # Additional Instructions +# MAGIC 1. Convert SQL queries to spark.sql() format +# MAGIC 2. 
Add clear Python comments explaining the code +# MAGIC 3. Use DataFrame operations instead of loops when possible +# MAGIC 4. Handle errors using try-except blocks +# MAGIC +# MAGIC few_shots: +# MAGIC - role: user +# MAGIC content: | +# MAGIC SELECT name, age +# MAGIC FROM users +# MAGIC WHERE active = 1; +# MAGIC - role: assistant +# MAGIC content: | +# MAGIC # Get names and ages of active users +# MAGIC active_users = spark.sql(""" +# MAGIC SELECT name, age +# MAGIC FROM users +# MAGIC WHERE active = 1 +# MAGIC """) +# MAGIC display(active_users) +# MAGIC ``` +# MAGIC +# MAGIC ## 📢 前提条件 +# MAGIC メインノートブックを実行する前に、Databricksモデルサービングエンドポイントが利用可能であることを確認してください。以下のいずれかのオプションがあります: +# MAGIC +# MAGIC 1. Databricks基盤モデルAPIを使用します(最も簡単なセットアップのため推奨) +# MAGIC 2. 外部モデルサービングエンドポイントをセットアップします。手動で設定するか、次の自動化ノートブックを使用できます: +# MAGIC - Azure OpenAI Serviceエンドポイントセットアップ用ノートブック +# MAGIC - Amazon Bedrockエンドポイントセットアップ用ノートブック +# MAGIC +# MAGIC ## ❗ 重要な注意事項 +# MAGIC メインノートブックを実行する前に、以下の点を考慮してください: +# MAGIC +# MAGIC ### モデルの互換性 +# MAGIC sql2dbxは、大きなコンテキストウィンドウとSQL推論能力の強いモデルに最適化されています。以下のモデルは高精度な変換を行うことが確認されています: +# MAGIC +# MAGIC > **注意:** モデルの仕様は日々進化しています。実装前に、各モデルの公式ドキュメントで最新情報を確認してください。 +# MAGIC +# MAGIC #### 主な推奨モデル +# MAGIC 最小限のセットアップで最適なパフォーマンスを得るために、以下のモデルを推奨します: +# MAGIC +# MAGIC | モデル | APIモデルバージョン | 入力コンテキスト | 最大出力 | セットアップ要件 | 備考 | +# MAGIC |---|---|---|---|---|---| +# MAGIC | [Claude 3.7 Sonnet](https://docs.anthropic.com/en/docs/about-claude/models/all-models) | claude-3-7-sonnet-20250219 | 200Kトークン | 通常: 8,192トークン
拡張思考: 64,000トークン | 基盤モデルAPIを通じてすぐに使用可能 | 複雑なSQL変換に最適な選択肢 | +# MAGIC +# MAGIC ##### 拡張思考モード +# MAGIC Claude 3.7 Sonnetの拡張思考モードは、単純なSQLクエリでは任意ですが、複雑なSQL変換の処理に推奨されます。この機能を有効にするには、以下のように思考パラメータを使用して`request_params`ノートブックウィジェットを設定します: +# MAGIC +# MAGIC 例: +# MAGIC ```json +# MAGIC {"max_tokens": 64000, "thinking": {"type": "enabled", "budget_tokens": 16000}} +# MAGIC ``` +# MAGIC +# MAGIC > **注意:** 拡張思考モードを有効にすることで、複雑なSQLの変換精度が大幅に向上しますが、トークンの使用量と処理時間が増加します。入力となるSQLファイルのトークン数(SQLコメントと余分な空白を削除した後)が8,000以下の場合、拡張思考モードは安定した動作が期待できます。それ以上のトークン数ではエラーが発生する可能性があります。大規模なSQLファイルを処理する場合、小さなファイルに分割するか、拡張思考モードを使用せずに変換を行うことをお勧めします。 +# MAGIC +# MAGIC #### Azure環境向けの代替オプション +# MAGIC 一部の組織では、企業ポリシーや既存のAzureへの投資など、特定の理由からAzure OpenAIモデルを使用する必要がある場合があります。そのような場合、以下のモデルが利用可能です: +# MAGIC +# MAGIC | モデル | APIモデルバージョン | 入力コンテキスト | 最大出力 | セットアップ要件 | 備考 | +# MAGIC |---|---|---|---|---|---| +# MAGIC | [Azure OpenAI o1](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/reasoning) | 2024-12-17 | 200Kトークン | 100Kトークン | 外部モデルのセットアップが必要 | Azure環境に適したオプション | +# MAGIC | [Azure OpenAI o3-mini](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/reasoning) | 2025-01-31 | 200Kトークン | 100Kトークン | 外部モデルのセットアップが必要 | Azure環境に適したオプション | +# MAGIC +# MAGIC ##### OpenAI O-seriesモデルのパラメータに関する考慮事項 +# MAGIC `request_params`を指定する際には、いくつかの重要な考慮事項があります: +# MAGIC +# MAGIC 1. `reasoning_effort`パラメータ: +# MAGIC - `reasoning_effort`は、O-seriesモデルの推論プロセスにおける思考の深さを制御します +# MAGIC - `request_param`に`{"reasoning_effort": "high"}`のように指定することで、モデルはより深い思考を行い、複雑なSQLクエリをより正確に変換します +# MAGIC - トレードオフとして、消費トークン数および処理時間が増加します +# MAGIC 1. トークン制限パラメータの違い: +# MAGIC - O-seriesモデルでは`max_tokens`の使用は推奨されず、`max_completion_tokens`が使用されます +# MAGIC - sql2dbxはトークンを制限せずに実行した方が安定した動作をするため、`max_completion_tokens`を指定せずに実行することをお勧めします +# MAGIC 1. サポートされていないパラメータ: +# MAGIC - `temperature`や`top_p`などの生成パラメータは、O-seriesモデルではサポートされていません +# MAGIC - 詳細については[公式ドキュメント](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/reasoning)を参照してください +# MAGIC +# MAGIC #### その他の互換性のあるモデル +# MAGIC 以下のモデルもsql2dbxで動作することが確認されており、さまざまなパフォーマンス特性があります: +# MAGIC +# MAGIC | モデル | APIモデルバージョン | 入力コンテキスト | 最大出力 | セットアップ要件 | 備考 | +# MAGIC |---|---|---|---|---|---| +# MAGIC | [Claude 3.5 Sonnet](https://docs.anthropic.com/en/docs/about-claude/models/all-models) | claude-3-5-sonnet-20241022 | 200Kトークン | 8,192トークン | 外部モデルのセットアップが必要 | 互換性を確認済み | +# MAGIC | [Azure OpenAI GPT-4o](https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models) | 2024-05-13 | 128Kトークン | 4,096トークン | 外部モデルのセットアップが必要 | 互換性を確認済み | +# MAGIC | [Meta Llama 3.3 70B Instruct](https://github.com/meta-llama/llama-models/blob/main/models/llama3_3/MODEL_CARD.md) | llama-3-3-70b-instruct | 128Kトークン | 8,192トークン | 基盤モデルAPIを通じてすぐに使用可能 | 互換性を確認済み | +# MAGIC +# MAGIC ### SQL変換のためのトークン管理 +# MAGIC 入力となるSQLファイルのトークン数は変換プロセスに直接影響します。考慮すべき2つの重要な要素は次の通りです: +# MAGIC +# MAGIC 1. 入力となるSQLファイルのトークン数よりも生成されるDatabricksノートブックのトークン数の方が通常大きくなります。入力SQLファイルはコメントを除去し複数のスペースを1つにまとめることでトークン数を削減していますが、生成結果のノートブックはコメントやインデントを含むためです +# MAGIC 2. 
処理内容が多いSQLファイルの場合、一回の出力でノートブックを生成し切れない場合があります。その場合、生成済みの内容を保持しつつ、続きを段階的に出力する形になります +# MAGIC +# MAGIC これらの要因を考慮し、トークンの使用を効率的に管理することが重要です。モデルの処理可能な容量を超えるファイルは変換に失敗する可能性があるため、`token_count_threshold`パラメータで適切な閾値を設定することで、変換の成功率を高めることができます。 +# MAGIC +# MAGIC #### `token_count_threshold`パラメータの推奨トークン閾値 +# MAGIC `token_count_threshold`パラメータは、トークン数(ファイルサイズではなく)に基づいて処理するSQLファイルを決定します。SQLコンテンツは、SQLコメントと余分な空白を削除した後にトークン化されます。 +# MAGIC +# MAGIC | モデル | 推奨`token_count_threshold` | +# MAGIC |---|---| +# MAGIC | Claude 3.7 Sonnet(通常モード) | 20,000トークン(デフォルト) | +# MAGIC | Claude 3.7 Sonnet(拡張思考モード) | 8,000トークン | +# MAGIC +# MAGIC - 通常モードの20,000トークンという値は、実際のテスト結果に基づいて設定されています。テスト環境では最大60,000トークンまでの処理に成功した例もありますが、20,000トークンを超えると処理の安定性が低下することが確認されています。最も安定した動作を得るために、20,000トークンをデフォルト値として設定しています。 +# MAGIC - 他のモデル(o1、o3-miniなど)も同様に、20,000トークン程度までは比較的安定して動作することが確認されています。理論上はさらに大きな値も処理可能と考えられますが、実際の環境でテストすることをお勧めします。 +# MAGIC - 拡張思考モードの8,000トークンという制限も同様に、実際のテスト結果から導き出されたものです。この値を超えるとエラーが発生したり、結果が返ってこなかったりする場合があります。大規模なSQLファイルを処理する際は、より小さな論理的なセクションに分割することをお勧めします。 +# MAGIC +# MAGIC #### 入力ファイルのトークン数処理 +# MAGIC 01_analyze_input_filesノートブックは、以下の手順で入力ファイルのトークン数を計算します: +# MAGIC +# MAGIC 1. エンドポイント名または明示的なトークナイザー設定に基づいて適切なトークナイザーを決定します: +# MAGIC - Claudeモデルの場合:[Anthropicのドキュメント](https://docs.anthropic.com/en/docs/about-claude/models/all-models)に基づき、文字ベースの推定(約3.4文字/トークン)を使用 +# MAGIC - OpenAIやその他のモデルの場合:[openai/tiktoken](https://github.com/openai/tiktoken)ライブラリと適切なエンコーディングを使用 +# MAGIC 2. SQLコメントと余分な空白を削除した後、選択したトークナイザーを使用して各SQLファイルのトークン数を測定します +# MAGIC 3. トークン数が`token_count_threshold`以下のファイルは変換対象としてマークされます(`is_conversion_target = True`) +# MAGIC 4. 閾値を超えるファイルは変換プロセスから除外されます +# MAGIC +# MAGIC ## 🔌 パラメーター +# MAGIC メインノートブックには以下のパラメーターの設定が必要です。より詳細な設定が必要な場合は、このメインノートブックではなく、個別のノートブックを実行してください。個別のノートブックでは、特定のタスクに対してより詳細なカスタマイズが可能です。 +# MAGIC +# MAGIC 番号 (ID) | パラメーター名 | 必須 | 説明 | デフォルト値 +# MAGIC --- | --- | --- | --- | --- +# MAGIC 1-1 | `input_dir` | Yes | 変換対象のSQLファイル群を格納しているディレクトリ。Pythonの`os`モジュールでアクセス可能な場所をサポート(例:Unity Catalog Volume、Workspace、Reposなど)。 | +# MAGIC 1-2 | `endpoint_name` | Yes | Databricksモデルサービングエンドポイントの名前。「Serving」タブの下でエンドポイント名を確認できます。例:エンドポイントURLが`https:///serving-endpoints/hinak-oneenvgpt4o/invocations`の場合、`hinak-oneenvgpt4o`を指定します。 | `databricks-claude-3-7-sonnet` +# MAGIC 1-3 | `result_catalog` | Yes | 結果テーブルを保存する既存のカタログ。 | +# MAGIC 1-4 | `result_schema` | Yes | 指定された既存のカタログ内の、結果テーブルが配置される既存のスキーマ。 | +# MAGIC 1-5 | `token_count_threshold` | Yes | 変換プロセスに含める対象となるファイルの、SQLコメントを除いた最大トークン数。 | `20000` +# MAGIC 1-x | `existing_result_table` | No | 分析結果の保存に使用する既存の結果テーブル。指定された場合、新しいテーブルを作成する代わりにこのテーブルが使用されます。 | +# MAGIC 2-1 | `sql_dialect` | Yes | The SQL dialect of the input files. This parameter is used to determine the appropriate conversion prompts for the SQL dialect. 
| `tsql` +# MAGIC 2-2 | `comment_lang` | Yes | 変換したDatabricksノートブックに追加するコメントの言語。 | `English` +# MAGIC 2-3 | `concurrency` | Yes | モデルサービングエンドポイントに送信される同時リクエストの数。 | `4` +# MAGIC 2-4 | `log_level` | Yes | バッチ推論プロセスで使用するログレベル。`INFO`は標準的なログ、`DEBUG`は詳細なデバッグ情報を出力します。 | `INFO` +# MAGIC 2-x | `request_params` | No | JSON形式の追加チャットリクエストパラメータ(例:`{"max_tokens": 8192}`)(参照:[Databricks Foundation Model APIs](https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request))。空の値はモデルのデフォルトパラメーターを使用します。 | +# MAGIC 2-x | `conversion_prompt_yaml` | Yes | 変換プロンプトを含むYAMLファイルへのパス。特定のSQL方言(例:T-SQL)の変換用のシステムメッセージとfew-shot例が定義されている必要があります。 | +# MAGIC 3-1 | `max_fix_attempts` | Yes | 変換結果の構文エラーを自動修正する最大試行回数。 | `1` +# MAGIC 5-1 | `output_dir` | Yes | Databricksノートブックを保存するディレクトリ。WorkspaceまたはRepos内のパスをサポート。 | +# MAGIC +# MAGIC ## 📂 入出力 +# MAGIC 変換プロセスの主な入出力は以下の通りです: +# MAGIC +# MAGIC ### 入力SQLファイル +# MAGIC 入力となるSQLファイル群を`input_files_path`ディレクトリに保存してください。01_analyze_input_filesノートブックは、ディレクトリとそのサブディレクトリのすべてのファイルを処理します。Python `os`モジュールを通じてアクセス可能な場所(Unity Catalog Volume、Workspace、Reposなど)をサポートしています。 +# MAGIC +# MAGIC ### 変換結果ノートブック(最終出力) +# MAGIC 変換結果であるDatabricksノートブック群は05_export_to_databricks_notebooksノートブックによって出力され、変換プロセスの最終出力となります。出力先としてWorkspaceまたはReposのパスをサポートしています。 +# MAGIC +# MAGIC ### 変換結果テーブル(中間出力) +# MAGIC 変換結果テーブルは01_analyze_input_filesノートブックによって作成され、変換プロセスの後続のノートブックの入出力先として利用します。これはDelta Lakeテーブルで、入力SQLファイルの分析結果(トークン数、ファイルメタデータ、変換ステータスなど)を保存します。 +# MAGIC +# MAGIC #### テーブルの命名 +# MAGIC テーブル名は、ノートブックで指定されたパラメータを使用して、以下の形式で作成されます: +# MAGIC +# MAGIC `{result_catalog}.{result_schema}.{result_table_prefix}_{YYYYMMDDHHmm}` +# MAGIC +# MAGIC 例えば、`result_catalog`が"my_catalog"、`result_schema`が"my_schema"、`result_table_prefix`が"conversion_targets"、現在時刻(UTC)が2024-06-14 11:39の場合、テーブル名は以下のようになります: +# MAGIC +# MAGIC `my_catalog.my_schema.conversion_targets_202406141139` +# MAGIC +# MAGIC #### テーブルスキーマ +# MAGIC テーブルスキーマは以下の通りです: +# MAGIC +# MAGIC | カラム名 | データ型 | 説明 | +# MAGIC |---|---|---| +# MAGIC | `input_file_number` | int | 各入力ファイルの一意の整数識別子。番号付けは`1`から始まります。 | +# MAGIC | `input_file_path` | string | 入力ファイルへのフルパス。 | +# MAGIC | `input_file_encoding` | string | 検出された入力ファイルのエンコーディング(例:`UTF-8`)。 | +# MAGIC | `tokenizer_type` | string | トークン数を計算するために使用されたトークナイザーの種類(例:`claude`や`openai`)。 | +# MAGIC | `tokenizer_model` | string | トークナイザーで使用された特定のモデルやエンコーディング(例:Claudeモデル用の`claude`やOpenAIモデル用の`o200k_base`)。 | +# MAGIC | `input_file_token_count` | int | 入力ファイルの総トークン数。 | +# MAGIC | `input_file_token_count_without_sql_comments` | int | SQLコメントを除いた入力ファイルのトークン数。 | +# MAGIC | `input_file_content` | string | 入力ファイルの全内容。 | +# MAGIC | `input_file_content_without_sql_comments` | string | SQLコメントを除いた入力ファイルの内容。 | +# MAGIC | `is_conversion_target` | boolean | ファイルが変換対象かどうか(TrueまたはFalse)。`01_analyze_input_files`においてSQLコメントを除いた入力ファイルのトークン数と`token_count_threshold`との比較に基づいて決定されます。変換処理が正常に終了したら自動的に`True`から`False`に更新されます。 | +# MAGIC | `model_serving_endpoint_for_conversion` | string | 変換処理に使用したモデルサービングエンドポイントの名前。 | +# MAGIC | `model_serving_endpoint_for_fix` | string | 構文エラー修正に使用したモデルサービングエンドポイントの名前。 | +# MAGIC | `request_params_for_conversion` | string | 変換処理のリクエストパラメータ(JSON形式)。 | +# MAGIC | `request_params_for_fix` | string | 構文エラー修正のリクエストパラメータ(JSON形式)。 | +# MAGIC | `result_content` | string | 処理後のファイルの変換内容。(初期値は`null`) | +# MAGIC | `result_prompt_tokens` | int | 変換に使用されたプロンプトのトークン数。(初期値は`null`) | +# MAGIC | `result_completion_tokens` | int | モデルによって生成された完了トークン数。(初期値は`null`) | +# MAGIC | `result_total_tokens` | int | 
変換に使用された合計トークン数(プロンプト + 完了)。(初期値は`null`) | +# MAGIC | `result_processing_time_seconds` | float | 変換リクエストの処理にかかった時間(秒)。(初期値は`null`) | +# MAGIC | `result_timestamp` | timestamp | `result_content`が生成または更新された時のタイムスタンプ(UTC)。(初期値は`null`) | +# MAGIC | `result_error` | string | 変換プロセス中に遭遇したエラー。(初期値は`null`) | +# MAGIC | `result_python_parse_error` | string | `ast.parse`を使用したPython関数の構文チェック中に遭遇したエラー。 | +# MAGIC | `result_extracted_sqls` | array | Python関数から抽出されたSQLステートメントのリスト。(初期値は`null`) | +# MAGIC | `result_sql_parse_errors` | array | `EXPLAIN sql`を使用したSQL構文チェック中に遭遇したエラー。(初期値は`null`) | +# MAGIC +# MAGIC ## 🔄 特定のファイルを再変換する方法 +# MAGIC 変換結果に不満がある場合、以下の手順で特定のファイルを再変換できます: +# MAGIC +# MAGIC 1. 11_adjust_conversion_targetsノートブックを使用して、再変換したいファイルの`is_conversion_target`フィールドを`True`に設定します。 +# MAGIC 2. 02_convert_sql_to_databricksノートブックおよび後続の処理を再実行します。`is_conversion_target`が`True`とマークされたファイルのみが再変換されます。 +# MAGIC - LLMの変換プロセスにさらなるランダム性を導入し、実行ごとに異なる結果を得るには、モデルがサポートしている場合は`request_params`の`temperature`を0.5以上に設定することをお勧めします。 +# MAGIC +# MAGIC ## 💻 動作確認済みの環境 +# MAGIC このノートブックは以下の環境で動作確認しています: +# MAGIC +# MAGIC 1. ノートブックおよびジョブ用のDatabricksサーバーレスコンピュート +# MAGIC 2. Databricksクラシックコンピュート(汎用コンピュートおよびジョブコンピュート) +# MAGIC - 推奨設定:シングルノードクラスター、Photonは不要 +# MAGIC - 検証済みのDatabricksランタイム(DBR)バージョン +# MAGIC - 15.3 LTS +# MAGIC - 14.3 LTS + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 0. パラメーターの設定 +# MAGIC 主要なパラメーターをこのセクションで設定します。他の詳細なパラメーターを変更する必要がある場合は、このノートブックではなく各ノートブックを実行してください。 + +# COMMAND ---------- + +# DBTITLE 1,Install Packages +# MAGIC %pip install -r requirements.txt +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +# DBTITLE 1,Import Libraries +import json + +import pandas as pd + +from pyscripts.conversion_prompt_helper import ConversionPromptHelper +from pyscripts.databricks_credentials import DatabricksCredentials +from pyscripts.notebook_i18n import get_supported_languages + +# COMMAND ---------- + +# DBTITLE 1,Configurations +# 01_analyze_input_files用のパラメータ +dbutils.widgets.text("input_dir", "", "1-1. 入力ディレクトリ") +dbutils.widgets.text("endpoint_name", "databricks-claude-3-7-sonnet", "1-2. サービングエンドポイント名") +dbutils.widgets.text("result_catalog", "", "1-3. 結果カタログ") +dbutils.widgets.text("result_schema", "", "1-4. 結果スキーマ") +dbutils.widgets.text("token_count_threshold", "20000", "1-5. 入力トークン数の閾値") +dbutils.widgets.text("existing_result_table", "", "既存の結果テーブル(任意)") + +# 02_convert_sql_to_databricks用のパラメータ +dbutils.widgets.dropdown("sql_dialect", "tsql", ConversionPromptHelper.get_supported_sql_dialects(), "2-1. SQL方言") +dbutils.widgets.dropdown("comment_lang", "Japanese", get_supported_languages(), "2-2. コメント言語") +dbutils.widgets.text("concurrency", "4", "2-3. 同時リクエスト数") +dbutils.widgets.dropdown("log_level", "INFO", ["DEBUG", "INFO"], "2-4. ログレベル") +dbutils.widgets.text("request_params", "", "チャットリクエストパラメータ(任意)") +dbutils.widgets.text("conversion_prompt_yaml", "", "変換プロンプトのYAMLパス(任意)") + +# 03_syntax_check_and_fix用のパラメータ +dbutils.widgets.text("max_fix_attempts", "1", "3-1. 最大修正試行回数") + +# 05_export_to_databricks_notebooks用のパラメータ +dbutils.widgets.text("output_dir", "", "5-1. 
出力ディレクトリ") + +# COMMAND ---------- + +# DBTITLE 1,Load Configurations +input_dir = dbutils.widgets.get("input_dir") +endpoint_name = dbutils.widgets.get("endpoint_name") +result_catalog = dbutils.widgets.get("result_catalog") +result_schema = dbutils.widgets.get("result_schema") +token_count_threshold = int(dbutils.widgets.get("token_count_threshold")) +existing_result_table = dbutils.widgets.get("existing_result_table") +comment_lang = dbutils.widgets.get("comment_lang") +concurrency = int(dbutils.widgets.get("concurrency")) +request_params = dbutils.widgets.get("request_params") +log_level = dbutils.widgets.get("log_level") +max_fix_attempts = int(dbutils.widgets.get("max_fix_attempts")) +output_dir = dbutils.widgets.get("output_dir") + +# 使用するYAMLファイルを判断 +_conversion_prompt_yaml = dbutils.widgets.get("conversion_prompt_yaml") +sql_dialect = dbutils.widgets.get("sql_dialect") + +if _conversion_prompt_yaml: + conversion_prompt_yaml = _conversion_prompt_yaml +else: + conversion_prompt_yaml = ConversionPromptHelper.get_default_yaml_for_sql_dialect(sql_dialect) + +input_dir, endpoint_name, result_catalog, result_schema, token_count_threshold, existing_result_table, conversion_prompt_yaml, comment_lang, request_params, log_level, max_fix_attempts, output_dir + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 1. 入力ファイルの分析 +# MAGIC 入力SQLファイルの分析とトークン数の計算を行い、結果をDeltaテーブルに保存します。 + +# COMMAND ---------- + +# DBTITLE 1,Analyze Input Files +result_table = dbutils.notebook.run("01_analyze_input_files", 0, { + "input_dir": input_dir, + "result_catalog": result_catalog, + "result_schema": result_schema, + "token_count_threshold": token_count_threshold, + "existing_result_table": existing_result_table, +}) +print(f"Conversion result table: {result_table}") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### 変換対象として選択されたファイル +# MAGIC トークン閾値内のファイル:Databricksノートブックへ変換されます。 + +# COMMAND ---------- + +# DBTITLE 1,Files Selected for Conversion +spark.sql(f""" + SELECT + input_file_number, + input_file_path, + input_file_token_count_without_sql_comments + FROM {result_table} + WHERE is_conversion_target = true + ORDER BY input_file_number +""").display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### トークン閾値を超えるファイル +# MAGIC 閾値を超えるファイル:手動での確認が必要です。(より小さなファイルに分割することを検討してください) + +# COMMAND ---------- + +# DBTITLE 1,Files Exceeding Token Threshold +spark.sql(f""" + SELECT + input_file_number, + input_file_path, + input_file_token_count_without_sql_comments + FROM {result_table} + WHERE is_conversion_target = false + ORDER BY input_file_number +""").display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 2. 
SQLからDatabricksへの変換 +# MAGIC LLMを使用してSQLコードをDatabricksノートブックで実行可能なPython関数に変換し、結果テーブルを更新します。 + +# COMMAND ---------- + +# DBTITLE 1,Convert SQL to Databricks Notebooks +dbutils.notebook.run("02_convert_sql_to_databricks", 0, { + "endpoint_name": endpoint_name, + "result_table": result_table, + "conversion_prompt_yaml": conversion_prompt_yaml, + "comment_lang": comment_lang, + "concurrency": concurrency, + "request_params": request_params, + "log_level": log_level, +}) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### 正常に変換されたファイル +# MAGIC 以下の表は、Databricksノートブックに正常に変換されたファイルを示しています。 + +# COMMAND ---------- + +# DBTITLE 1,Successfully Converted Files +spark.sql(f""" + SELECT + input_file_number, + input_file_path, + result_content, + input_file_token_count_without_sql_comments, + result_prompt_tokens, + result_completion_tokens, + result_total_tokens, + result_timestamp + FROM {result_table} + WHERE is_conversion_target = false + AND result_content IS NOT NULL + ORDER BY input_file_number +""").display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### 変換エラーのあるファイル +# MAGIC 以下の表は、変換エラーのあるファイルを示しています。 + +# COMMAND ---------- + +# DBTITLE 1,Files with Conversion Errors +spark.sql(f""" + SELECT + input_file_number, + input_file_path, + result_error, + result_timestamp + FROM {result_table} + WHERE is_conversion_target = true + AND result_error IS NOT NULL + ORDER BY input_file_number +""").display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 3. 構文チェックと修正 +# MAGIC Python関数とその中のSpark SQLの静的構文チェックを行い、発見されたエラーの修正を試みます。 + +# COMMAND ---------- + +# DBTITLE 1,Function for Syntax Error File Count +def get_error_file_count(result_table: str) -> int: + """Get the count of files with syntax errors.""" + error_count = spark.sql(f""" + SELECT COUNT(*) as error_count + FROM {result_table} + WHERE result_python_parse_error IS NOT NULL + OR (result_sql_parse_errors IS NOT NULL AND size(result_sql_parse_errors) > 0) + """).collect()[0]['error_count'] + return error_count + +# COMMAND ---------- + +# DBTITLE 1,Check and Fix Syntax Errors +for attempt in range(max_fix_attempts): + # Run static syntax check + print(f"Attempt {attempt + 1} of {max_fix_attempts}") + dbutils.notebook.run("03_01_static_syntax_check", 0, { + "result_table": result_table, + }) + + # Check if there are any errors + error_count = get_error_file_count(result_table) + if error_count == 0: + print("No syntax errors found. Exiting fix loop.") + break + + # Run fix syntax error + print(f"Found {error_count} files with syntax errors. 
Attempting to fix...") + dbutils.notebook.run("03_02_fix_syntax_error", 0, { + "endpoint_name": endpoint_name, + "result_table": result_table, + "concurrency": concurrency, + "request_params": request_params, + "log_level": log_level, + }) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### 最終構文チェック +# MAGIC すべての修正試行後に最終的な静的構文チェックを実行します。 + +# COMMAND ---------- + +# DBTITLE 1,Run Final Syntax Check +dbutils.notebook.run("03_01_static_syntax_check", 0, { + "result_table": result_table, +}) +error_count = get_error_file_count(result_table) +print(f"Found {error_count} files with syntax errors.") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### 構文チェック結果 +# MAGIC 以下の表は、成功したチェックと失敗したチェックの両方を含む、すべてのファイルの構文チェック結果を示しています。 + +# COMMAND ---------- + +# DBTITLE 1,Syntax Check Status +spark.sql(f""" + SELECT + input_file_number, + input_file_path, + result_content, + CASE + WHEN result_python_parse_error IS NULL + AND (result_sql_parse_errors IS NULL OR size(result_sql_parse_errors) = 0) + THEN 'エラー無し' + ELSE 'エラーあり' + END as check_status, + result_python_parse_error, + result_sql_parse_errors + FROM {result_table} + ORDER BY input_file_number +""").display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 4. セル分割 +# MAGIC 変換されたPythonコードを、Databricksノートブック内での可読性や保守性を高めるために複数のセルに分割します。 + +# COMMAND ---------- + +# DBTITLE 1,Split Cells +dbutils.notebook.run("04_split_cells", 0, { + "result_table": result_table, + "log_level": log_level, +}) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## 5. Databricksノートブックへのエクスポート +# MAGIC 変換されたコードをDatabricksノートブックにエクスポートします。 + +# COMMAND ---------- + +# DBTITLE 1,Export to Databricks Notebooks +export_results_json = dbutils.notebook.run("05_export_to_databricks_notebooks", 0, { + "result_table": result_table, + "output_dir": output_dir, + "comment_lang": comment_lang +}) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC %md +# MAGIC ### 結果の概要 +# MAGIC 以下の表は、すべての入力SQLファイルの変換とエクスポートの結果を示しています。 + +# COMMAND ---------- + +# DBTITLE 1,Conversion and Export Status +# 出力先ディレクトリURLを表示 +full_url = f"{DatabricksCredentials().host}#workspace{output_dir}" +displayHTML(f'
<a href="{full_url}">出力先ディレクトリURL: {full_url}</a>
') + +# エクスポート結果の一時ビューを作成 +export_results = json.loads(export_results_json) +export_results_df = pd.DataFrame(export_results) +spark.createDataFrame(export_results_df).createOrReplaceTempView("temp_export_results") + +# すべてのファイルの完全なステータスを表示 +spark.sql(f""" + SELECT + r.input_file_number, + r.input_file_path, + CASE + WHEN r.result_content IS NULL THEN '変換無し' + WHEN r.result_python_parse_error IS NOT NULL OR + (r.result_sql_parse_errors IS NOT NULL AND size(r.result_sql_parse_errors) > 0) + THEN 'エラーありで変換' + ELSE '正常に変換' + END as conversion_status, + CASE + WHEN t.output_file_path IS NOT NULL THEN '正常にエクスポート' + ELSE 'エクスポート無し' + END as export_status, + t.output_file_path, + t.parse_error_count, + r.result_python_parse_error as python_errors, + r.result_sql_parse_errors as sql_errors + FROM {result_table} r + LEFT JOIN temp_export_results t + ON r.input_file_path = t.input_file_path + ORDER BY r.input_file_number +""").display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ### 次のステップ +# MAGIC 以上でSQLからDatabricksへのすべての変換プロセスが完了しました。変換結果は指定した出力ディレクトリにある変換済みノートブックで確認できます。これらの結果を十分に確認し、変換されたコードが要件を満たし、Databricks環境で正しく動作することを確認しましょう。 +# MAGIC +# MAGIC 結果を確認する際は、特に以下のケースに注意して適切に対応してください: +# MAGIC +# MAGIC 1. `変換無し` ステータスのファイル: +# MAGIC - これらの入力ファイルは、多くの場合トークン数の上限を超えているため処理できませんでした。 +# MAGIC - 対応:これらのファイルを小さく分割するか、LLMモデルがより大きな入力を処理できる場合は`token_count_threshold`パラメータを増やして再実行してください。 +# MAGIC +# MAGIC 2. `エラーありで変換` ステータスのファイル: +# MAGIC - 変換はされたものの、構文エラーを含んでいます。 +# MAGIC - 対応:出力ノートブック下部のエラーメッセージを確認し、変換されたノートブックで手動修正を行ってください。 +# MAGIC +# MAGIC 3. `エクスポート無し` ステータスのファイル: +# MAGIC - このステータスは稀ですが、変換後のコンテンツが大きすぎる場合に発生します。 +# MAGIC - エクスポートプロセスで`Content size exceeds 10MB limit` (コンテンツサイズが10MB制限を超えています) というメッセージが表示された場合、入力ファイルが大きすぎることを意味します。 +# MAGIC - 対応:入力SQLファイルのサイズを確認し、必要に応じて小さくしてから再度変換を試みてください。 diff --git a/sql2dbx/notebooks/01_analyze_input_files.py b/sql2dbx/notebooks/01_analyze_input_files.py new file mode 100644 index 00000000..bbdb81ab --- /dev/null +++ b/sql2dbx/notebooks/01_analyze_input_files.py @@ -0,0 +1,237 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # Analyze Input Files +# MAGIC This notebook is designed to count the tokens in files within the specified directory, aiding in understanding the size and structure of files for use with Large Language Models (LLMs). The results are saved to the target table in Delta Lake format. +# MAGIC +# MAGIC ## Task Overview +# MAGIC The following tasks are accomplished in this notebook: +# MAGIC +# MAGIC 1. **Directory Scanning**: The specified directory is scanned for files, and each file is prepared for analysis. +# MAGIC 2. **Tokenization**: Files are tokenized using the specified encoding to count the tokens effectively. +# MAGIC 3. **Result Compilation and Saving**: The token counts, along with file metadata, are compiled into a structured format. Files exceeding a predefined token threshold are filtered out. The results are saved to a Delta Lake table for further analysis or reference. 
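+# MAGIC
+# MAGIC For intuition about the tokenization step, the character-based approximation used for Claude endpoints (roughly 3.4 characters per token, as described in the 00_main notebook) can be sketched as follows. This is an illustration only; the actual counting is performed by `FileTokenCountHelper` later in this notebook:
+# MAGIC
+# MAGIC ```python
+# MAGIC def estimate_claude_tokens(text: str, chars_per_token: float = 3.4) -> int:
+# MAGIC     """Rough character-based token estimate (illustrative only)."""
+# MAGIC     return int(len(text) / chars_per_token)
+# MAGIC
+# MAGIC print(estimate_claude_tokens("SELECT name, age FROM users WHERE active = 1"))  # 12 (44 chars / 3.4)
+# MAGIC ```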
+ +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Install and import libraries + +# COMMAND ---------- + +# DBTITLE 1,Install Packages +# MAGIC %pip install -r requirements.txt +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +# DBTITLE 1,Import Libraries +from datetime import datetime, timezone + +from pyspark.sql.functions import col, lit, when +from pyspark.sql.types import (ArrayType, IntegerType, StringType, StructField, + StructType, TimestampType, FloatType) +from pyspark.sql.utils import AnalysisException + +from pyscripts.llm_token_count_helper import FileTokenCountHelper + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Set up configuration parameters + +# COMMAND ---------- + +# DBTITLE 1,Configurations +# Required Parameters +dbutils.widgets.text("input_dir", "", "Input Directory") +dbutils.widgets.text("endpoint_name", "databricks-claude-3-7-sonnet", "Serving Endpoint Name") +dbutils.widgets.text("result_catalog", "", "Result Catalog") +dbutils.widgets.text("result_schema", "", "Result Schema") + +# Required Parameters with Default Values +dbutils.widgets.dropdown("is_sql", "True", ["True", "False"], "Is SQL files or not") +dbutils.widgets.text("token_count_threshold", "20000", "Token Count Threshold") +dbutils.widgets.text("result_table_prefix", "conversion_targets", "Result Table Prefix") + +# Optional Parameters +dbutils.widgets.text("file_encoding", "", "File Encoding (Optional)") +dbutils.widgets.text("existing_result_table", "", "Existing Result Table (Optional)") +dbutils.widgets.text("tokenizer_type", "", "Tokenizer Type (Optional)") +dbutils.widgets.text("tokenizer_model", "", "Tokenizer Model (Optional)") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Parameters +# MAGIC Parameter Name | Required | Description | Default Value +# MAGIC --- | --- | --- | --- +# MAGIC `input_dir` | Yes | The directory containing the files for analysis. Supports locations accessible through Python `os` module (e.g., Unity Catalog Volume, Workspace, Repos, etc.). | +# MAGIC `endpoint_name` | Yes | The name of the Databricks Model Serving endpoint. You can find the endpoint name under the `Serving` tab. Example: If the endpoint URL is `https:///serving-endpoints/hinak-oneenvgpt4o/invocations`, specify `hinak-oneenvgpt4o`. | `databricks-claude-3-7-sonnet` +# MAGIC `result_catalog` | Yes | The existing catalog where the result table will be stored. | +# MAGIC `result_schema` | Yes | The existing schema under the specified catalog where the result table will reside. | +# MAGIC `is_sql` | Yes | Indicates whether the files in the directory are SQL files. If `True`, contents without SQL comments and token count will be added to the result; if `False`, these will be `None`. | `True` +# MAGIC `token_count_threshold` | Yes | Specifies the maximum token count allowed without SQL comments for files to be included in the following conversion process. | `20000` +# MAGIC `result_table_prefix` | Yes | The prefix for the result table name where the results will be stored. | `conversion_targets` +# MAGIC `file_encoding` | No | The encoding used for reading files. If unspecified, the notebook will attempt to detect the encoding automatically. | +# MAGIC `existing_result_table` | No | An optional parameter for subsequent runs. If this table exists, the notebook's processing will be skipped and the value of this parameter will be returned as output of this notebook. | +# MAGIC `tokenizer_type` | No | The type of tokenizer to use ('claude' or 'openai'). 
Only used if more specific control is needed than endpoint_name provides. | `""` +# MAGIC `tokenizer_model` | No | The specific model to use for tokenization. Only used if tokenizer_type is also specified. | `""` + +# COMMAND ---------- + +# DBTITLE 1,Load Configurations +input_dir = dbutils.widgets.get("input_dir") +endpoint_name = dbutils.widgets.get("endpoint_name") +file_encoding = dbutils.widgets.get("file_encoding") if dbutils.widgets.get("file_encoding") else None +is_sql = dbutils.widgets.get("is_sql") == "True" +token_count_threshold = int(dbutils.widgets.get("token_count_threshold")) +result_catalog = dbutils.widgets.get("result_catalog") +result_schema = dbutils.widgets.get("result_schema") +result_table_prefix = dbutils.widgets.get("result_table_prefix") +existing_result_table = dbutils.widgets.get("existing_result_table") +tokenizer_type = dbutils.widgets.get("tokenizer_type") if dbutils.widgets.get("tokenizer_type") else None +tokenizer_model = dbutils.widgets.get("tokenizer_model") if dbutils.widgets.get("tokenizer_model") else None + +input_dir, endpoint_name, file_encoding, is_sql, token_count_threshold, result_catalog, result_schema, result_table_prefix, existing_result_table, tokenizer_type, tokenizer_model + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Check if existing_result_table exists + +# COMMAND ---------- + +# DBTITLE 1,Check Result Table Existance +if existing_result_table: + try: + spark.table(existing_result_table) + dbutils.notebook.exit(existing_result_table) + except AnalysisException: + print(f"'existing_result_table' is specified but the table does not exist. Continuing with the notebook processing.") + pass +else: + print("The parameter 'existing_result_table' is not specified. Continuing with the notebook processing.") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Count tokens in all files within the specified directory + +# COMMAND ---------- + +# DBTITLE 1,Count Tokens +# Initialize FileTokenCountHelper with the appropriate parameters +# Priority: 1. Use explicit tokenizer_type if provided +# 2. 
Otherwise, use endpoint_name to determine the tokenizer +if tokenizer_type: + print(f"Using explicit tokenizer type: {tokenizer_type}") + helper = FileTokenCountHelper(tokenizer_type=tokenizer_type, tokenizer_model=tokenizer_model) +else: + print(f"Using endpoint name: {endpoint_name} to determine tokenizer") + helper = FileTokenCountHelper(endpoint_name=endpoint_name) + +# Print the actual tokenizer being used +print(f"Tokenizer selected: {helper.tokenizer_type} (model: {helper.tokenizer_model})") +results = helper.process_directory(input_dir=input_dir, file_encoding=file_encoding, is_sql=is_sql) + +# COMMAND ---------- + +# DBTITLE 1,Create Spark DataFrame +schema = StructType([ + StructField("input_file_number", IntegerType(), True), + StructField("input_file_path", StringType(), True), + StructField("input_file_encoding", StringType(), True), + StructField("tokenizer_type", StringType(), True), + StructField("tokenizer_model", StringType(), True), + StructField("input_file_token_count", IntegerType(), True), + StructField("input_file_token_count_without_sql_comments", IntegerType(), True), + StructField("input_file_content", StringType(), True), + StructField("input_file_content_without_sql_comments", StringType(), True), + StructField("is_conversion_target", StringType(), True), + StructField("model_serving_endpoint_for_conversion", StringType(), True), + StructField("model_serving_endpoint_for_fix", StringType(), True), + StructField("request_params_for_conversion", StringType(), True), + StructField("request_params_for_fix", StringType(), True), + StructField("result_content", StringType(), True), + StructField("result_prompt_tokens", IntegerType(), True), + StructField("result_completion_tokens", IntegerType(), True), + StructField("result_total_tokens", IntegerType(), True), + StructField("result_processing_time_seconds", FloatType(), True), + StructField("result_timestamp", TimestampType(), True), + StructField("result_error", StringType(), True), + StructField("result_python_parse_error", StringType(), True), + StructField("result_extracted_sqls", ArrayType(StringType()), True), + StructField("result_sql_parse_errors", ArrayType(StringType()), True), +]) + +result_df = (spark + .createDataFrame(results, schema=schema) + .withColumn("is_conversion_target", + when(col("input_file_token_count_without_sql_comments") > token_count_threshold, False) + .otherwise(True)) + .withColumn("model_serving_endpoint_for_conversion", lit(None).cast(StringType())) + .withColumn("model_serving_endpoint_for_fix", lit(None).cast(StringType())) + .withColumn("request_params_for_conversion", lit(None).cast(StringType())) + .withColumn("request_params_for_fix", lit(None).cast(StringType())) + .withColumn("result_content", lit(None).cast(StringType())) + .withColumn("result_prompt_tokens", lit(None).cast(IntegerType())) + .withColumn("result_completion_tokens", lit(None).cast(IntegerType())) + .withColumn("result_total_tokens", lit(None).cast(IntegerType())) + .withColumn("result_processing_time_seconds", lit(None).cast(FloatType())) + .withColumn("result_timestamp", lit(None).cast(TimestampType())) + .withColumn("result_error", lit(None).cast(StringType())) + .withColumn("result_python_parse_error", lit(None).cast(StringType())) + .withColumn("result_extracted_sqls", lit(None).cast(ArrayType(StringType()))) + .withColumn("result_sql_parse_errors", lit(None).cast(ArrayType(StringType()))) + ) + +display(result_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Excluded files from conversion process 
+# MAGIC Files exceeding the `token_count_threshold` are excluded from further conversion processing. Consider splitting these files manually or adjusting the threshold as needed. + +# COMMAND ---------- + +# DBTITLE 1,Warning for Token Count Threshold +warning_df = result_df.filter(col("is_conversion_target") == False) +if warning_df.count() > 0: + print(f"Warning: The following files do not meet the token count threshold of " + f"{token_count_threshold} and are excluded from conversion process.") + display(warning_df) +else: + print("No issues found. All files meet the token count threshold.") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Save the result dataframe into a target delta table + +# COMMAND ---------- + +# DBTITLE 1,Define Target Table +current_time = datetime.now(timezone.utc).strftime("%Y%m%d%H%M") +result_table = f"{result_catalog}.{result_schema}.{result_table_prefix}_{current_time}" +print(result_table) + +# COMMAND ---------- + +# DBTITLE 1,Save Result +result_df.write.format("delta").mode("overwrite").saveAsTable(result_table) +print(f"Successfully saved result into the table: {result_table}") + +# COMMAND ---------- + +# DBTITLE 1,Display Result Table +spark.table(result_table).display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Return the result table name + +# COMMAND ---------- + +# DBTITLE 1,Return Result Table Name +dbutils.notebook.exit(result_table) diff --git a/sql2dbx/notebooks/02_convert_sql_to_databricks.py b/sql2dbx/notebooks/02_convert_sql_to_databricks.py new file mode 100644 index 00000000..070df6dc --- /dev/null +++ b/sql2dbx/notebooks/02_convert_sql_to_databricks.py @@ -0,0 +1,248 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # Convert SQL to Databricks +# MAGIC This notebook facilitates the conversion of SQL-based workflows such as T-SQL Stored Procedures into Databricks-compatible code, allowing for the seamless migration to the Databricks environment. +# MAGIC +# MAGIC This notebook is inspired by the following reference: [chat-batch-inference-api - Databricks](https://learn.microsoft.com/en-us/azure/databricks/_extras/notebooks/source/machine-learning/large-language-models/chat-batch-inference-api.html). +# MAGIC +# MAGIC ## Task Overview +# MAGIC The following tasks are accomplished in this notebook: +# MAGIC +# MAGIC 1. **Configure SQL Dialect**: By specifying a custom YAML file in the `conversion_prompt_yaml` parameter, you can adapt the conversion process for various SQL dialects (e.g., T-SQL, PL/SQL, PostgreSQL). Each YAML file defines the system message and few-shot examples tailored to that specific SQL dialect. +# MAGIC 2. **Read Data**: Data is read from the input table and specified columns. The input table is assumed to have been created in the preceding notebook (01_analyze_input_files), and the `input_file_content_without_sql_comments` column is utilized for processing. +# MAGIC 3. **Request Construction and Submission**: Requests are constructed and sent to the specified Databricks model serving endpoint with concurrent processing. +# MAGIC 4. **Persist Results**: The results of the conversion process are added to the input table. 
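+
+# COMMAND ----------
+
+# DBTITLE 1,Illustrative Sketch: Assembling a Chat Request
+# A minimal sketch, not the actual helper implementation, of how a system message,
+# few-shot examples, and one input SQL script are typically combined into a chat
+# request for a Databricks Model Serving chat endpoint (see "Request Construction
+# and Submission" above). The assumption that `few_shots` is a list of
+# {"role": ..., "content": ...} dicts is illustrative only; the real structures are
+# defined in pyscripts/conversion_prompt_helper.py and pyscripts/batch_inference_helper.py.
+from typing import Dict, List
+
+
+def build_chat_messages_sketch(system_message: str, few_shots: List[Dict[str, str]],
+                               sql_text: str) -> List[Dict[str, str]]:
+    """Return a chat-completions style message list: system, few-shot pairs, then the input."""
+    return (
+        [{"role": "system", "content": system_message}]
+        + few_shots  # e.g. alternating user/assistant example pairs
+        + [{"role": "user", "content": sql_text}]
+    )
+
+
+# Example usage (hypothetical values):
+# build_chat_messages_sketch(system_message, few_shots, "SELECT TOP 10 * FROM dbo.orders")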
+ +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Install and import libraries + +# COMMAND ---------- + +# DBTITLE 1,Install Packages +# MAGIC %pip install -r requirements.txt +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +# DBTITLE 1,Import Libraries +import json + +from pyscripts.batch_inference_helper import (AsyncChatClient, + BatchInferenceManager, + BatchInferenceRequest) +from pyscripts.conversion_prompt_helper import ConversionPromptHelper + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Set up configuration parameters +# MAGIC + +# COMMAND ---------- + +# DBTITLE 1,Configurations +# Required Parameters +dbutils.widgets.text("result_table", "", "Conversion Result Table (Required)") + +# Required Parameters with Default Values +dbutils.widgets.text("endpoint_name", "databricks-claude-3-7-sonnet", "Serving Endpoint Name (Required)") +dbutils.widgets.text("conversion_prompt_yaml", + "pyscripts/conversion_prompt_yaml/tsql_to_databricks_notebook.yml", "YAML path for Conversion Prompt") +dbutils.widgets.dropdown("comment_lang", "English", [ + "Chinese", "English", "French", "German", "Italian", "Japanese", "Korean", "Portuguese", "Spanish"], "Comment Language") +dbutils.widgets.text("concurrency", "4", "Concurrency Requests") +dbutils.widgets.dropdown("log_level", "INFO", ["DEBUG", "INFO"], "Log Level") +dbutils.widgets.text("logging_interval", "1", "Logging Interval") +dbutils.widgets.text("timeout", "300", "Timeout Seconds") +dbutils.widgets.text("max_retries_backpressure", "10", "Max Retries on Backpressure") +dbutils.widgets.text("max_retries_other", "3", "Max Retries on Other Errors") + +# Optional Parameters +dbutils.widgets.text("request_params", "", "Chat Request Params") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Parameters +# MAGIC +# MAGIC Parameter Name | Required | Description | Default Value +# MAGIC --- | --- | --- | --- +# MAGIC `result_table` | Yes | The name of the conversion result table created in the previous notebook. | +# MAGIC `endpoint_name` | Yes | The name of the Databricks Model Serving endpoint. You can find the endpoint name under the `Serving` tab. Example: If the endpoint URL is `https:///serving-endpoints/hinak-oneenvgpt4o/invocations`, specify `hinak-oneenvgpt4o`. | `databricks-claude-3-7-sonnet` +# MAGIC `conversion_prompt_yaml` | Yes | The path to the YAML file containing the conversion prompts. Specify either a relative path from the notebook or a full path. | `pyscripts/conversion_prompt_yaml/tsql_to_databricks_notebook.yml` +# MAGIC `comment_lang` | Yes | The language for comments to be added to the converted Databricks notebooks. | `English` +# MAGIC `concurrency` | Yes | The number of concurrent requests sent to the model serving endpoint. | `4` +# MAGIC `log_level` | Yes | The logging level to use for the batch inference process. Options are `INFO` for standard logging or `DEBUG` for detailed debug information. | `INFO` +# MAGIC `logging_interval` | Yes | The number of requests processed before logging a progress update. Controls the frequency of progress reports during batch processing, showing the total requests processed and elapsed time. | `1` +# MAGIC `timeout` | Yes | The timeout for an HTTP request on the client side, in seconds. | `300` +# MAGIC `max_retries_backpressure` | Yes | The maximum number of retries on backpressure status code (such as `429` or `503`). | `10` +# MAGIC `max_retries_other` | Yes | The maximum number of retries on other errors (such as `5xx`, `408`, or `409`). 
| `3` +# MAGIC `request_params` | No | The extra chat request parameters in JSON format (reference: [Databricks Foundation Model APIs](https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request)). Empty value will use model's default parameters. | + +# COMMAND ---------- + +# DBTITLE 1,Load Configurations +# Load configurations from widgets +config_endpoint_name = dbutils.widgets.get("endpoint_name") +config_result_table = dbutils.widgets.get("result_table") +config_conversion_prompt_yaml = dbutils.widgets.get("conversion_prompt_yaml") +config_comment_lang = dbutils.widgets.get("comment_lang") +config_concurrecy = int(dbutils.widgets.get("concurrency")) +config_log_level = dbutils.widgets.get("log_level") +config_logging_interval = int(dbutils.widgets.get("logging_interval")) +config_timeout = int(dbutils.widgets.get("timeout")) +config_max_retries_backpressure = int(dbutils.widgets.get("max_retries_backpressure")) +config_max_retries_other = int(dbutils.widgets.get("max_retries_other")) + +# Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request +_request_params = dbutils.widgets.get("request_params") +config_request_params = json.loads(_request_params) if _request_params.strip() else None + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## System message & few-shots from YAML configuration +# MAGIC The system message and few-shot examples are loaded from the YAML configuration file specified in the `conversion_prompt_yaml` parameter. By creating and specifying different YAML files, you can adapt the conversion process for various SQL dialects (e.g., T-SQL, PL/SQL, PostgreSQL) without modifying the notebook code. + +# COMMAND ---------- + +# DBTITLE 1,Load System Message and Few-Shots +conv_prompt_helper = ConversionPromptHelper( + conversion_prompt_yaml=config_conversion_prompt_yaml, + comment_lang=config_comment_lang, +) +system_message = conv_prompt_helper.get_system_message() +few_shots = conv_prompt_helper.get_few_shots() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Run batch inference +# MAGIC The following code loads a Spark dataframe of the input data table and then converts that dataframe into a list of text that the model can process. + +# COMMAND ---------- + +# DBTITLE 1,Retrieve Pre-Update Data for Batch Inference Result Update +source_sdf = spark.table(config_result_table) +display(source_sdf) + +# COMMAND ---------- + +# DBTITLE 1,Retrieve Data for Batch Inference +input_sdf = spark.sql(f""" + SELECT input_file_number, input_file_content_without_sql_comments + FROM {config_result_table} + WHERE is_conversion_target = true +""") +display(input_sdf) + +# Check if there are any records +if input_sdf.count() == 0: + raise Exception( + "No records found for conversion. 
Please check if there are any records with is_conversion_target = true in the result table.") + +# COMMAND ---------- + +# DBTITLE 1,Create Batch Inference Requests +batch_inference_requests = [ + BatchInferenceRequest( + index=int(row[0]), + text=row[1], + system_message=system_message, + few_shots=few_shots) + for row in input_sdf.toPandas().itertuples(index=False, name=None) +] + +# COMMAND ---------- + +# DBTITLE 1,Display Batch Inference Requests +display_df = spark.createDataFrame([ + (req.index, req.text, req.system_message, str(req.few_shots)) + for req in batch_inference_requests +], ["index", "text", "system_message", "few_shots"]) + +display(display_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC The following records and stores the batch inference responses. + +# COMMAND ---------- + +# DBTITLE 1,Create Batch Inference Manager +batch_manager = BatchInferenceManager( + client=AsyncChatClient( + endpoint_name=config_endpoint_name, + request_params=config_request_params, + timeout=config_timeout, + max_retries_backpressure=config_max_retries_backpressure, + max_retries_other=config_max_retries_other, + log_level=config_log_level, + ), + concurrency=config_concurrecy, + logging_interval=config_logging_interval, + log_level=config_log_level, +) + +# COMMAND ---------- + +# DBTITLE 1,Batch Inference +batch_inference_responses = await batch_manager.batch_inference(batch_inference_requests) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Save results +# MAGIC The following stores the output to the result table and displays the results + +# COMMAND ---------- + +# DBTITLE 1,Load Notebook Utils +# MAGIC %run ./notebook_utils + +# COMMAND ---------- + +# DBTITLE 1,Organize Output +batch_inference_result_processor = BatchInferenceResultProcessor( + model_serving_endpoint_for_conversion=config_endpoint_name, + request_params_for_conversion=config_request_params, +) +output_sdf = batch_inference_result_processor.process_results(source_sdf, batch_inference_responses) +display(output_sdf) + +# COMMAND ---------- + +# DBTITLE 1,Save Result +output_sdf.write.mode("overwrite").saveAsTable(config_result_table) +print(f"Successfully saved result into the table: {config_result_table}") + +# COMMAND ---------- + +# DBTITLE 1,Display Result Table +spark.table(config_result_table).display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Cleaning results +# MAGIC The following performs cleaning on `result_content`. The reason for saving the data first and then performing cleaning is to enable time travel in case there are any issues with the cleaning process. + +# COMMAND ---------- + +# DBTITLE 1,Clean Result +cleand_df = clean_conversion_results(config_result_table) +display(cleand_df) + +# COMMAND ---------- + +# DBTITLE 1,Save Cleaned Result +cleand_df.write.mode("overwrite").saveAsTable(config_result_table) +print(f"Successfully saved cleaned result into the table: {config_result_table}") + +# COMMAND ---------- + +# DBTITLE 1,Display Cleaned Result Table +spark.table(config_result_table).display() diff --git a/sql2dbx/notebooks/03_01_static_syntax_check.py b/sql2dbx/notebooks/03_01_static_syntax_check.py new file mode 100644 index 00000000..89efe086 --- /dev/null +++ b/sql2dbx/notebooks/03_01_static_syntax_check.py @@ -0,0 +1,166 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # Static Syntax Check +# MAGIC This notebook performs a static syntax check on Python functions and the SQL statements extracted from these functions. 
The main goals of this notebook are: +# MAGIC +# MAGIC 1. **Python Function Syntax Check**: Using `ast.parse` to parse the input strings into an Abstract Syntax Tree (AST). Only correctly formatted Python functions can be parsed, thus ensuring static syntax correctness. +# MAGIC 2. **SQL Syntax Check**: Extracting SQL statements from the parsed Python functions and verifying their syntax using Spark's SQL parser. +# MAGIC +# MAGIC The extraction of SQL from Python functions is handled by the `spark_sql_extract_helper.py` script. The SQL syntax check is done using `EXPLAIN sql`. +# MAGIC +# MAGIC ## Task Overview +# MAGIC The following tasks are accomplished in this notebook: +# MAGIC +# MAGIC 1. **Load Data**: The data is loaded from the specified result table. +# MAGIC 2. **Parse Python Function and Extract SQL Statements**: Python functions are parsed using `ast.parse` to ensure they are valid, and SQL statements are extracted using the script. +# MAGIC 3. **Parse SQL Statements**: The extracted SQL statements are parsed to check for syntax errors using Spark's SQL parser. +# MAGIC 4. **Save Results**: The original data along with any syntax errors are saved back to the specified result table by adding new columns. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Install and import libraries + +# COMMAND ---------- + +# DBTITLE 1,Install Packages +# MAGIC %pip install -r requirements.txt +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +# DBTITLE 1,Import Libraries +from typing import List, Tuple + +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.functions import udf +from pyspark.sql.types import ArrayType, StringType, StructField, StructType +from pyscripts.spark_sql_extract_helper import SparkSQLExtractHelper + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Set up configuration parameters + +# COMMAND ---------- + +# DBTITLE 1,Configurations +dbutils.widgets.text("result_table", "", "Conversion Result Table (Required)") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Parameters +# MAGIC Parameter Name | Required | Description +# MAGIC --- | --- | --- +# MAGIC `result_table` | Yes | The name of the conversion result table created in the previous notebook. 
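+
+# COMMAND ----------
+
+# DBTITLE 1,Illustrative Sketch: ast-Based Python Syntax Check
+# A minimal sketch of the first check described above: `ast.parse` either succeeds,
+# meaning the generated code is syntactically valid Python, or raises SyntaxError with
+# a line and offset that can be reported. The real logic, including extraction of SQL
+# from `spark.sql(...)` calls, lives in pyscripts/spark_sql_extract_helper.py.
+import ast
+from typing import Optional
+
+
+def python_syntax_error_sketch(code: str) -> Optional[str]:
+    """Return None if `code` parses as Python, otherwise a short error description."""
+    try:
+        ast.parse(code)
+        return None
+    except SyntaxError as e:
+        return f"SyntaxError at line {e.lineno}, offset {e.offset}: {e.msg}"
+
+
+# Example usage (hypothetical snippet with a missing closing parenthesis):
+# python_syntax_error_sketch("df = spark.sql('SELECT 1'")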
+ +# COMMAND ---------- + +# DBTITLE 1,Load Configurations +result_table = dbutils.widgets.get("result_table") +result_table + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Display original table + +# COMMAND ---------- + +# DBTITLE 1,Display Original +original_df = spark.table(result_table) +display(original_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Parse Python function and extract SQL statements + +# COMMAND ---------- + +# DBTITLE 1,Define UDF +def extract_sqls(func_string: str) -> Tuple[str, List[str]]: + helper = SparkSQLExtractHelper() + return helper.extract_sql_from_string(func_string) + + +extract_sqls_udf = udf(extract_sqls, StructType([ + StructField("result_python_parse_error", StringType(), True), + StructField("result_extracted_sqls", ArrayType(StringType()), True) +])) + +# COMMAND ---------- + +# DBTITLE 1,Parse Python and Extract SQLs +# Extract SQL from the result_content column, adding error messages and SQL list as new columns +new_df = original_df.withColumn("parsed_result", extract_sqls_udf("result_content")) + +# Split parsed_result into two new columns: python_parse_error_str and extracted_sqls +new_df = (new_df + .withColumn("result_python_parse_error", new_df["parsed_result"].getItem("result_python_parse_error")) + .withColumn("result_extracted_sqls", new_df["parsed_result"].getItem("result_extracted_sqls")) + .drop("parsed_result") + ) + +# Display results +display(new_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Parse SQL statements + +# COMMAND ---------- + +# DBTITLE 1,Define SQL Parse Function +# Function to parse SQL statements on the driver +def parse_sql_statements(df: DataFrame) -> List[Tuple[str, List[str]]]: + result = [] + rows = df.collect() + for row in rows: + errors_for_row = [] + for idx, sql in enumerate(row['result_extracted_sqls']): + try: + # Attempt to parse and plan the SQL without actually running it + spark.sql(f"EXPLAIN {sql}") + except Exception as e: + error_message = str(e) + # If the error message contains "JVM stacktrace", remove it + jvm_index = error_message.find("JVM stacktrace:") + if jvm_index != -1: + error_message = error_message[:jvm_index].strip() + errors_for_row.append(f"Error in query {idx}: {error_message}") + result.append((row["input_file_number"], errors_for_row)) + return result + +# COMMAND ---------- + +# DBTITLE 1,Create DataFrame from Parsed Data +parsed_data = parse_sql_statements(new_df) +parsed_schema = StructType([ + StructField("input_file_number", StringType(), True), + StructField("result_sql_parse_errors", ArrayType(StringType()), True) +]) +parsed_errors_df = spark.createDataFrame(parsed_data, parsed_schema) +display(parsed_errors_df) + +# COMMAND ---------- + +# DBTITLE 1,Add Parsed Errors to DataFrame +final_df = (new_df + .drop("result_sql_parse_errors") + .join(parsed_errors_df, on="input_file_number", how="left") + ) +display(final_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Update table + +# COMMAND ---------- + +# DBTITLE 1,Update Table +final_df.write.mode("overwrite").saveAsTable(result_table) +print(f"Changes applied to the result table: {result_table}.") diff --git a/sql2dbx/notebooks/03_02_fix_syntax_error.py b/sql2dbx/notebooks/03_02_fix_syntax_error.py new file mode 100644 index 00000000..84307491 --- /dev/null +++ b/sql2dbx/notebooks/03_02_fix_syntax_error.py @@ -0,0 +1,258 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # Fix Syntax Error +# MAGIC This notebook is designed to fix syntax errors in Python functions and SQL statements that were 
identified in the previous static syntax check. The main objectives of this notebook are: +# MAGIC +# MAGIC 1. **Error Identification**: Retrieve Python functions and SQL statements that failed the syntax check from the result table. +# MAGIC 2. **Error Correction**: Use a Language Model (LLM) to automatically fix the identified syntax errors in both Python and SQL. +# MAGIC 3. **Batch Processing**: Implement a batch inference system to efficiently process multiple error corrections concurrently. +# MAGIC +# MAGIC The notebook utilizes a Databricks Serving endpoint to access the LLM for error correction. It processes the errors in batches, allowing for efficient handling of multiple correction requests simultaneously. +# MAGIC +# MAGIC ## Task Overview +# MAGIC The following tasks are accomplished in this notebook: +# MAGIC +# MAGIC 1. **Load Error Data**: Extract functions with syntax errors from the result table. +# MAGIC 2. **Process Batch Inference**: Send requests to the LLM in batches and collect the corrected code. +# MAGIC 3. **Save Corrected Results**: Store the corrected code back into the result table. +# MAGIC 4. **Clean Results**: Perform additional cleaning on the corrected code for consistency. +# MAGIC +# MAGIC This notebook plays a crucial role in automating the error correction process, significantly reducing the manual effort required to fix syntax errors in a large codebase. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Install and import libraries + +# COMMAND ---------- + +# DBTITLE 1,Install Packages +# MAGIC %pip install -r requirements.txt +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +# DBTITLE 1,Import Libraries +import json +from typing import Optional + +from pyscripts.batch_inference_helper import (AsyncChatClient, + BatchInferenceManager, + BatchInferenceRequest) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Set up configuration parameters + +# COMMAND ---------- + +# DBTITLE 1,Configurations +# Required Parameters +dbutils.widgets.text("result_table", "", "Conversion Result Table (Required)") + +# Required Parameters with Default Values +dbutils.widgets.text("endpoint_name", "databricks-claude-3-7-sonnet", "Serving Endpoint Name (Required)") +dbutils.widgets.text("concurrency", "4", "Concurrency Requests") +dbutils.widgets.dropdown("log_level", "INFO", ["DEBUG", "INFO"], "Logging Level") +dbutils.widgets.text("logging_interval", "1", "Logging Interval") +dbutils.widgets.text("timeout", "300", "Timeout Seconds") +dbutils.widgets.text("max_retries_backpressure", "10", "Max Retries on Backpressure") +dbutils.widgets.text("max_retries_other", "3", "Max Retries on Other Errors") + +# Optional Parameters +dbutils.widgets.text("request_params", "", "Chat Request Params") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Parameters +# MAGIC +# MAGIC Parameter Name | Required | Description | Default Value +# MAGIC --- | --- | --- | --- +# MAGIC `result_table` | Yes | The name of the conversion result table created in the previous notebook. | +# MAGIC `endpoint_name` | Yes | The name of the Databricks Model Serving endpoint. You can find the endpoint name under the `Serving` tab. Example: If the endpoint URL is `https:///serving-endpoints/hinak-oneenvgpt4o/invocations`, specify `hinak-oneenvgpt4o`. | `databricks-claude-3-7-sonnet` +# MAGIC `concurrency` | Yes | The number of concurrent requests sent to the model serving endpoint. | `4` +# MAGIC `log_level` | Yes | The logging level to use for the batch inference process. 
Options are `INFO` for standard logging or `DEBUG` for detailed debug information. | `INFO` +# MAGIC `logging_interval` | Yes | The number of requests processed before logging a progress update. Controls the frequency of progress reports during batch processing, showing the total requests processed and elapsed time. | `1` +# MAGIC `timeout` | Yes | The timeout for an HTTP request on the client side, in seconds. | `300` +# MAGIC `max_retries_backpressure` | Yes | The maximum number of retries on backpressure status code (such as `429` or `503`). | `10` +# MAGIC `max_retries_other` | Yes | The maximum number of retries on other errors (such as `5xx`, `408`, or `409`). | `3` +# MAGIC `request_params` | No | The extra chat request parameters in JSON format (reference: [Databricks Foundation Model APIs](https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request)). Empty value will use model's default parameters. | + +# COMMAND ---------- + +# DBTITLE 1,Load Configurations +# Load configurations from widgets +config_endpoint_name = dbutils.widgets.get("endpoint_name") +config_result_table = dbutils.widgets.get("result_table") +config_concurrecy = int(dbutils.widgets.get("concurrency")) +config_log_level = dbutils.widgets.get("log_level") +config_logging_interval = int(dbutils.widgets.get("logging_interval")) +config_timeout = int(dbutils.widgets.get("timeout")) +config_max_retries_backpressure = int(dbutils.widgets.get("max_retries_backpressure")) +config_max_retries_other = int(dbutils.widgets.get("max_retries_other")) + +# Reference: https://docs.databricks.com/en/machine-learning/foundation-models/api-reference.html#chat-request +_request_params = dbutils.widgets.get("request_params") +config_request_params = json.loads(_request_params) if _request_params.strip() else None + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Run batch inference +# MAGIC The following code loads a Spark dataframe of the input data table and then converts that dataframe into a list of text that the model can process. + +# COMMAND ---------- + +# DBTITLE 1,Function for Creating System Message +def create_system_message(python_error: Optional[str], sql_error: Optional[str]) -> str: + """ + Create a system message for an LLM to fix errors in Python code running in a Databricks notebook. + + Args: + python_error (Optional[str]): The Python parsing error message, if any. + sql_error (Optional[str]): The Spark SQL-related error message, if any. + + Returns: + str: A formatted system message with instructions and error details. + """ + message = f"""Fix the following errors in the Python code running in a Databricks notebook. +The code contains Spark SQL queries, and most errors are Spark SQL-related. + +Instructions: +1. Output only Python code and comments. No other text allowed. +2. Do not add explanations outside of Python code. +3. If asked to continue, resume the code without adding extra phrases. +4. Do not omit any part of the code. +5. Ensure proper handling of Spark SQL queries in the Databricks environment. +6. Prioritize fixing Spark SQL-related errors. 
+ +Errors to fix: +""" + if python_error: + message += f"{python_error}\n" + if sql_error: + message += f"{sql_error}\n" + return message + +# COMMAND ---------- + +# DBTITLE 1,Extract Input Data +input_sdf = spark.sql(f""" + SELECT + input_file_number, + result_content, + result_python_parse_error, + result_sql_parse_errors + FROM {config_result_table} + WHERE result_python_parse_error IS NOT NULL + OR (result_sql_parse_errors IS NOT NULL AND size(result_sql_parse_errors) > 0) +""") +display(input_sdf) + +# COMMAND ---------- + +# DBTITLE 1,Create Batch Inference Requests +input_data = input_sdf.collect() +batch_inference_requests = [ + BatchInferenceRequest( + index=row['input_file_number'], + text=row['result_content'], + system_message=create_system_message( + row['result_python_parse_error'], row['result_sql_parse_errors'])) + for row in input_data +] + +# COMMAND ---------- + +# DBTITLE 1,Display Batch Inference Requests +display_df = spark.createDataFrame([ + (req.index, req.text, req.system_message, str(req.few_shots)) + for req in batch_inference_requests +], ["index", "text", "system_message", "few_shots"]) + +display(display_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC The following records and stores the batch inference responses. + +# COMMAND ---------- + +# DBTITLE 1,Create Batch Inference Manager +batch_manager = BatchInferenceManager( + client=AsyncChatClient( + endpoint_name=config_endpoint_name, + request_params=config_request_params, + timeout=config_timeout, + max_retries_backpressure=config_max_retries_backpressure, + max_retries_other=config_max_retries_other, + log_level=config_log_level, + ), + concurrency=config_concurrecy, + logging_interval=config_logging_interval, + log_level=config_log_level, +) + +# COMMAND ---------- + +# DBTITLE 1,Batch Inference +batch_inference_responses = await batch_manager.batch_inference(batch_inference_requests) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Save results +# MAGIC The following stores the output to the result table and displays the results + +# COMMAND ---------- + +# DBTITLE 1,Load Notebook Utils +# MAGIC %run ./notebook_utils + +# COMMAND ---------- + +# DBTITLE 1,Organize Output +source_sdf = spark.table(config_result_table) +batch_inference_result_processor = BatchInferenceResultProcessor( + model_serving_endpoint_for_fix=config_endpoint_name, + request_params_for_fix=config_request_params, +) +output_sdf = batch_inference_result_processor.process_results(source_sdf, batch_inference_responses) +display(output_sdf) + +# COMMAND ---------- + +# DBTITLE 1,Save Result +output_sdf.write.mode("overwrite").saveAsTable(config_result_table) +print(f"Successfully saved result into the table: {config_result_table}") + +# COMMAND ---------- + +# DBTITLE 1,Display Result Table +spark.table(config_result_table).display() + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Cleaning results +# MAGIC The following performs cleaning on `result_content`. The reason for saving the data first and then performing cleaning is to enable time travel in case there are any issues with the cleaning process. 
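+
+# COMMAND ----------
+
+# DBTITLE 1,Illustrative Sketch: Rolling Back via Delta Time Travel
+# A minimal sketch, assuming the result table is a Delta table, of how an earlier
+# version could be inspected and restored if the cleaning step below ever needs to be
+# rolled back. The version number in the usage example is a placeholder; pick the
+# right one from the DESCRIBE HISTORY output.
+def restore_result_table_sketch(table_name: str, version: int) -> None:
+    """Overwrite `table_name` with one of its earlier Delta versions (time travel)."""
+    previous_df = spark.read.option("versionAsOf", version).table(table_name)
+    previous_df.write.mode("overwrite").saveAsTable(table_name)
+
+
+# Example usage (only if the cleaning below needs to be undone):
+# display(spark.sql(f"DESCRIBE HISTORY {config_result_table}"))
+# restore_result_table_sketch(config_result_table, version=1)  # placeholder version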
+ +# COMMAND ---------- + +# DBTITLE 1,Clean Result +cleand_df = clean_conversion_results(config_result_table) +display(cleand_df) + +# COMMAND ---------- + +# DBTITLE 1,Save Cleaned Result +cleand_df.write.mode("overwrite").saveAsTable(config_result_table) +print(f"Successfully saved cleaned result into the table: {config_result_table}") + +# COMMAND ---------- + +# DBTITLE 1,Display Cleaned Result Table +spark.table(config_result_table).display() diff --git a/sql2dbx/notebooks/04_split_cells.py b/sql2dbx/notebooks/04_split_cells.py new file mode 100644 index 00000000..f33673e9 --- /dev/null +++ b/sql2dbx/notebooks/04_split_cells.py @@ -0,0 +1,114 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # Split Cells +# MAGIC This notebook is designed to split the converted Python code into multiple cells based on logical structure and control flow. The main objectives of this notebook are: +# MAGIC +# MAGIC 1. **Code Analysis**: Analyze the Python code to identify control structures and logical blocks. +# MAGIC 2. **Cell Splitting**: Use the CellSplitHelper to determine appropriate cell boundaries. +# MAGIC 3. **Cell Insertion**: Insert cell separators at the determined positions. +# MAGIC +# MAGIC ## Task Overview +# MAGIC The following tasks are accomplished in this notebook: +# MAGIC +# MAGIC 1. **Load Converted Code**: Extract the converted Python code from the result table. +# MAGIC 2. **Apply Cell Splitting**: Insert cell separators using the CellSplitHelper. +# MAGIC 3. **Save Results**: Store the updated code with cell separators back into the result table. +# MAGIC +# MAGIC This notebook plays a crucial role in improving the readability and executability of the converted code in the Databricks environment. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Install and import libraries + +# COMMAND ---------- + +# DBTITLE 1,Install Packages +# MAGIC %pip install -r requirements.txt +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +# DBTITLE 1,Import Libraries +from pyscripts.cell_split_helper import CellSplitHelper +from pyspark.sql.functions import udf +from pyspark.sql.types import StringType + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Set up configuration parameters + +# COMMAND ---------- + +# DBTITLE 1,Configurations +# Required Parameters +dbutils.widgets.text("result_table", "", "Conversion Result Table (Required)") + +# Required Parameters with Default Values +dbutils.widgets.dropdown("log_level", "INFO", ["DEBUG", "INFO"], "Logging Level") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Parameters +# MAGIC +# MAGIC Parameter Name | Required | Description | Default Value +# MAGIC --- | --- | --- | --- +# MAGIC `result_table` | Yes | The name of the conversion result table created in the previous notebook. | +# MAGIC `log_level` | Yes | The logging level to use for the cell splitting process. Options are `INFO` for standard logging or `DEBUG` for detailed debug information. 
| `INFO` + +# COMMAND ---------- + +# DBTITLE 1,Load Configurations +# Load configurations from widgets +config_result_table = dbutils.widgets.get("result_table") +config_log_level = dbutils.widgets.get("log_level") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Extract Input Data + +# COMMAND ---------- + +# DBTITLE 1,Extract Input Data +input_sdf = spark.sql(f""" + SELECT + input_file_number, + result_content + FROM {config_result_table} + WHERE result_content IS NOT NULL +""") +display(input_sdf) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Apply Cell Splitting +# MAGIC The following applies cell splitting using the CellSplitHelper. + +# COMMAND ---------- + +# DBTITLE 1,Apply Cell Splitting +cell_split_helper = CellSplitHelper(log_level=config_log_level) +split_cells_udf = udf(cell_split_helper.split_cells, StringType()) +source_sdf = spark.table(config_result_table) + +# Apply split_cells_udf to the dataframe +output_sdf = source_sdf.withColumn( + "result_content", + split_cells_udf(source_sdf.result_content) +) +display(output_sdf) + +# COMMAND ---------- + +# DBTITLE 1,Save Result +output_sdf.write.mode("overwrite").saveAsTable(config_result_table) +print(f"Successfully saved result into the table: {config_result_table}") + +# COMMAND ---------- + +# DBTITLE 1,Display Result Table +spark.table(config_result_table).display() diff --git a/sql2dbx/notebooks/05_export_to_databricks_notebooks.py b/sql2dbx/notebooks/05_export_to_databricks_notebooks.py new file mode 100644 index 00000000..2f633c9c --- /dev/null +++ b/sql2dbx/notebooks/05_export_to_databricks_notebooks.py @@ -0,0 +1,134 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # Export to Databricks Notebooks +# MAGIC This notebook exports the converted code from the Delta table to Databricks notebooks. It iterates through the rows of the input table, retrieves the converted code, and then creates a corresponding Databricks notebook in the specified output directory. +# MAGIC +# MAGIC ## Task Overview +# MAGIC The following tasks are accomplished in this notebook: +# MAGIC +# MAGIC 1. **Load Data:** The data is loaded from the input table, which is the output of the previous conversion steps. +# MAGIC 2. **Prepare Notebook Content:** For each row in the table, the converted code is extracted and formatted into a Databricks notebook structure. +# MAGIC 3. **Export Notebooks:** The prepared notebooks are exported to the specified output directory using the `databricks-sdk` library. 
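+
+# COMMAND ----------
+
+# DBTITLE 1,Illustrative Sketch: Building SOURCE-Format Notebook Content
+# A minimal sketch of what "Prepare Notebook Content" can look like (see the task
+# overview above): the converted Python code is prefixed with the
+# "# Databricks notebook source" header expected by the SOURCE import format and
+# base64-encoded for the Workspace import API. The actual logic lives in
+# pyscripts/notebook_export_helper.py and may differ in detail.
+import base64
+
+
+def build_source_notebook_sketch(code: str) -> str:
+    """Return base64-encoded SOURCE-format notebook content for the Workspace import API."""
+    notebook_source = "# Databricks notebook source\n" + code
+    return base64.b64encode(notebook_source.encode("utf-8")).decode("utf-8")
+
+
+# Example usage (hypothetical converted snippet):
+# build_source_notebook_sketch("print('converted code goes here')")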
+ +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Install and import libraries + +# COMMAND ---------- + +# DBTITLE 1,Install Packages +# MAGIC %pip install -r requirements.txt +# MAGIC %pip install databricks-sdk --upgrade +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +# DBTITLE 1,Import Libraries +import json +import os +from dataclasses import asdict + +import pandas as pd +from databricks.sdk import WorkspaceClient +from databricks.sdk.service import workspace +from pyscripts.notebook_export_helper import ExportInput, NotebookExportHelper + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Set up configuration parameters + +# COMMAND ---------- + +# DBTITLE 1,Configurations +dbutils.widgets.text("result_table", "", "Conversion Result Table") +dbutils.widgets.text("output_dir", "", "Output Directory") +dbutils.widgets.text("comment_lang", "English", "Comment Language") + +# COMMAND ---------- + +# DBTITLE 1,Load Configurations +result_table = dbutils.widgets.get("result_table") +output_dir = dbutils.widgets.get("output_dir") +comment_lang = dbutils.widgets.get("comment_lang") + +result_table, output_dir, comment_lang + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Parameters +# MAGIC Parameter Name | Required | Description +# MAGIC --- | --- | --- +# MAGIC `result_table` | Yes | The name of the conversion result table created in the previous notebook. +# MAGIC `output_dir` | Yes | The directory where Databricks notebooks are saved. Supports the path in Workspace or Repos. + +# COMMAND ---------- + +# DBTITLE 1,Enable Auto Reload Import Modules +# MAGIC %load_ext autoreload +# MAGIC %autoreload 2 + +# COMMAND ---------- + +# DBTITLE 1,Prepare Export +helper = NotebookExportHelper() +df = spark.table(result_table) + +exporter_inputs = [ExportInput(input_file_path=row['input_file_path'], + output_dir=output_dir, + code=row['result_content'], + comment_lang=comment_lang, + python_parse_error=row['result_python_parse_error'], + sql_parse_error=row['result_sql_parse_errors'], + ) for row in df.collect()] +results = helper.process_notebooks(exporter_inputs) + +# COMMAND ---------- + +# DBTITLE 1,Export Notebooks +ws_client = WorkspaceClient() + +for output in results: + # Create directories if they don't exist + os.makedirs(os.path.dirname(output.output_file_path), exist_ok=True) + + # Check the size of the encoded content + if output.base64_encoded_content_size > 10 * 1024 * 1024: + output.export_error = "Content size exceeds 10MB limit" + continue + + try: + # Export notebook + ws_client.workspace.import_( + content=output.base64_encoded_content, + path=output.output_file_path, + format=workspace.ImportFormat.SOURCE, + language=workspace.Language.PYTHON, + overwrite=True, + ) + print(f"Exported notebook to {output.output_file_path}") + output.export_succeeded = True + except Exception as e: + output.export_error = str(e) + +# COMMAND ---------- + +# DBTITLE 1,Display Export Results +exclude_fields = {'base64_encoded_content'} +export_results_dict = [ + {k: v for k, v in output.__dict__.items() if k not in exclude_fields} + for output in results +] +display(pd.DataFrame(export_results_dict)) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Return the export results + +# COMMAND ---------- + +# DBTITLE 1,Return Export Results +dbutils.notebook.exit(json.dumps(export_results_dict)) diff --git a/sql2dbx/notebooks/11_adjust_conversion_targets.py b/sql2dbx/notebooks/11_adjust_conversion_targets.py new file mode 100644 index 00000000..1f16c7d5 --- /dev/null +++ 
b/sql2dbx/notebooks/11_adjust_conversion_targets.py @@ -0,0 +1,89 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # Adjust Conversion Targets +# MAGIC This notebook allows you to refine the selection of files for conversion from SQL to Databricks notebooks. By adjusting the `is_conversion_target` field in the target table, you can control which files are included or excluded in the conversion process. This is particularly useful after an initial conversion attempt, allowing you to focus on specific files that failed to convert or require adjustments before re-conversion. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Install and import libraries + +# COMMAND ---------- + +# DBTITLE 1,Install Packages +# MAGIC %pip install -r requirements.txt +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +# DBTITLE 1,Import Libraries +from pyspark.sql.functions import col, when +from pyscripts import utils + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Set up configuration parameters + +# COMMAND ---------- + +# DBTITLE 1,Configurations +dbutils.widgets.text("result_table", "", "Conversion Result Table") +dbutils.widgets.text("set_true_numbers", "", "Set True Numbers") +dbutils.widgets.text("set_false_numbers", "", "Set False Numbers") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Parameters +# MAGIC Parameter Name | Required | Description +# MAGIC --- | --- | --- +# MAGIC `result_table` | Yes | The name of the conversion result table created in the previous notebook. +# MAGIC `set_true_numbers` | No | A comma-separated list or range of `input_file_number` values to set `is_conversion_target` to True (e.g., `1, 2-4, 5-6, 7` which sets 1 through 7 as targets). +# MAGIC `set_false_numbers` | No | A comma-separated list or range of `input_file_number` values to set `is_conversion_target` to False (e.g., `1, 2-4, 5-6, 7` which sets 1 through 7 as non-targets). 
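+
+# COMMAND ----------
+
+# DBTITLE 1,Illustrative Sketch: Expanding Number Ranges
+# A minimal sketch of how a range expression like "1, 2-4, 7" expands into individual
+# `input_file_number` values. The notebook itself relies on
+# pyscripts.utils.parse_number_ranges; this standalone version only illustrates the
+# idea and may not match that helper's exact edge-case behavior.
+from typing import List
+
+
+def parse_number_ranges_sketch(expression: str) -> List[int]:
+    """Expand "1, 2-4, 7" into [1, 2, 3, 4, 7]; an empty string yields an empty list."""
+    numbers: List[int] = []
+    for part in (p.strip() for p in expression.split(",") if p.strip()):
+        if "-" in part:
+            start, end = (int(x) for x in part.split("-", 1))
+            numbers.extend(range(start, end + 1))
+        else:
+            numbers.append(int(part))
+    return numbers
+
+
+# Example usage:
+# parse_number_ranges_sketch("1, 2-4, 7")  # -> [1, 2, 3, 4, 7]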
+ +# COMMAND ---------- + +# DBTITLE 1,Load Configurations +result_table = dbutils.widgets.get("result_table") +set_true_numbers = utils.parse_number_ranges(dbutils.widgets.get("set_true_numbers")) +set_false_numbers = utils.parse_number_ranges(dbutils.widgets.get("set_false_numbers")) + +result_table, set_true_numbers, set_false_numbers + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Display table before update + +# COMMAND ---------- + +# DBTITLE 1,Display Original +original_df = spark.table(result_table) +display(original_df) + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Update table + +# COMMAND ---------- + +# DBTITLE 1,Update Table +update_df = original_df.withColumn( + "is_conversion_target", + when(col("input_file_number").isin(set_true_numbers), True) + .when(col("input_file_number").isin(set_false_numbers), False) + .otherwise(col("is_conversion_target")), +) + +display(update_df) + +# COMMAND ---------- + +# DBTITLE 1,Writing Output to a Table +if set_true_numbers or set_false_numbers: + update_df.write.mode("overwrite").saveAsTable(result_table) + print(f"Changes applied to the result table: {result_table}.") +else: + print("No changes applied to the result table: {result_table} because the parameters are empty.") diff --git a/sql2dbx/notebooks/__init__.py b/sql2dbx/notebooks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sql2dbx/notebooks/external_model/external_model_amazon_bedrock.py b/sql2dbx/notebooks/external_model/external_model_amazon_bedrock.py new file mode 100644 index 00000000..ed5645ba --- /dev/null +++ b/sql2dbx/notebooks/external_model/external_model_amazon_bedrock.py @@ -0,0 +1,95 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # External Model - Amazon Bedrock Setup +# MAGIC This notebook sets up an external model endpoint for Amazon Bedrock in Databricks. +# MAGIC +# MAGIC ## Prerequisites +# MAGIC - Ensure that you have access to the desired models in Amazon Bedrock. For information on managing model access, refer to the [Amazon Bedrock Model Access documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/model-access.html). +# MAGIC - Your AWS IAM user or role must have the necessary permissions to invoke Bedrock models, including the `bedrock:InvokeModel` action on Bedrock resources. +# MAGIC - You have set up the AWS credentials (Access Key ID and Secret Access Key) with the required permissions. +# MAGIC - You have the following information available: +# MAGIC - Your AWS Access Key ID and Secret Access Key +# MAGIC - The AWS region where your Bedrock models are deployed +# MAGIC - The name of the Bedrock model you want to use +# MAGIC +# MAGIC ## Parameters +# MAGIC | Parameter Name | Required | Description | Example | +# MAGIC |----------------|----------|-------------|---------| +# MAGIC | endpoint_name | Yes | Name for the new Databricks serving endpoint | `my-bedrock-endpoint` | +# MAGIC | model_name | Yes | Name of the external model | `claude-3-5-sonnet-20240620-v1:0` | +# MAGIC | region | Yes | Region where the external model is deployed | `us-east-1` | +# MAGIC | aws_access_key_id | Yes | AWS Access Key ID for Amazon Bedrock authentication | ***your-access-key-id*** | +# MAGIC | aws_secret_access_key | Yes | AWS Secret Access Key for Amazon Bedrock authentication | ***your-secret-access-key*** | +# MAGIC +# MAGIC **Note:** Ensure all these parameters are correctly set before running this notebook. You can find most of this information in your AWS Management Console or from your AWS administrator. 
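+
+# COMMAND ----------
+
+# DBTITLE 1,Illustrative Sketch: Smoke-Testing the Endpoint
+# A minimal sketch of how the endpoint created by the cells below could be verified
+# once it exists, using the MLflow Deployments client (mlflow is installed by the
+# helper notebook). The endpoint name in the usage example is a placeholder; run it
+# only after the create/update step has completed successfully.
+def smoke_test_endpoint_sketch(name: str):
+    """Send one tiny chat request to a serving endpoint and return the raw response."""
+    from mlflow.deployments import get_deploy_client  # deferred import: mlflow comes from ./helper_external_model
+    client = get_deploy_client("databricks")
+    return client.predict(
+        endpoint=name,
+        inputs={"messages": [{"role": "user", "content": "Reply with OK"}], "max_tokens": 16},
+    )
+
+
+# Example usage (placeholder endpoint name):
+# smoke_test_endpoint_sketch("my-bedrock-endpoint")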
+ +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Import helper notebook + +# COMMAND ---------- + +# DBTITLE 1,Import Helper Notebook +# MAGIC %run ./helper_external_model + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Set up configuration parameters + +# COMMAND ---------- + +# DBTITLE 1,Configurations +dbutils.widgets.text("endpoint_name", "", "Endpoint Name") +dbutils.widgets.text("model_name", "claude-3-5-sonnet-20240620-v1:0", "Model Name") +dbutils.widgets.text("region", "us-east-1", "AWS Region") +dbutils.widgets.text("aws_access_key_id", "", "AWS Access Key ID") +dbutils.widgets.text("aws_secret_access_key", "", "AWS Secret Access Key") + +# COMMAND ---------- + +# DBTITLE 1,Load Configurations +endpoint_name = dbutils.widgets.get("endpoint_name") +model_name = dbutils.widgets.get("model_name") +region = dbutils.widgets.get("region") +aws_access_key_id = dbutils.widgets.get("aws_access_key_id") +aws_secret_access_key = dbutils.widgets.get("aws_secret_access_key") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Endpoint Configuration and Creation + +# COMMAND ---------- + +# DBTITLE 1,Prepare Endpoint Configuration +endpoint_config = { + "served_entities": [ + { + "name": endpoint_name, + "external_model": { + "name": model_name, + "provider": "amazon-bedrock", + "task": "llm/v1/chat", + "amazon_bedrock_config": { + "aws_region": region, + "aws_access_key_id_plaintext": aws_access_key_id, + "aws_secret_access_key_plaintext": aws_secret_access_key, + "bedrock_provider": "anthropic", + } + } + } + ] +} + +# COMMAND ---------- + +# DBTITLE 1,Create or Update Endpoint +try: + endpoint_helper = ExternalModelEndpointHelper() + endpoint = endpoint_helper.create_or_update_endpoint(endpoint_name, endpoint_config) + print(f"Endpoint '{endpoint_name}' has been successfully created/updated.") + print("Endpoint details:", endpoint) +except Exception as e: + print(f"Failed to create or update endpoint: {str(e)}") diff --git a/sql2dbx/notebooks/external_model/external_model_azure_openai.py b/sql2dbx/notebooks/external_model/external_model_azure_openai.py new file mode 100644 index 00000000..26379258 --- /dev/null +++ b/sql2dbx/notebooks/external_model/external_model_azure_openai.py @@ -0,0 +1,99 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # External Model - Azure OpenAI Setup +# MAGIC This notebook sets up an external model endpoint for Azure OpenAI in Databricks. +# MAGIC +# MAGIC ## Prerequisites +# MAGIC - Ensure that the Supported models like `gpt-4` mentioned in [this document](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/external-models/) are available in your Azure OpenAI service. +# MAGIC - You have deployed an Azure OpenAI model. For deployment instructions, refer to the [Azure OpenAI deployment documentation](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/create-resource?pivots=web-portal). 
+# MAGIC - You have the following information for parameters available: +# MAGIC - Your Azure OpenAI API key +# MAGIC - The base URL for your Azure OpenAI API service +# MAGIC - The deployment name of your Azure OpenAI model +# MAGIC - The latest API version for Azure OpenAI (e.g., `2024-05-01-preview`) +# MAGIC +# MAGIC ## Parameters +# MAGIC | Parameter Name | Required | Description | Example | +# MAGIC |----------------|----------|-------------|---------| +# MAGIC | endpoint_name | Yes | Name for the new Databricks serving endpoint | `my-azure-openai-endpoint` | +# MAGIC | model_name | Yes | Name of the external model | `gpt-4o` | +# MAGIC | openai_api_key | Yes | Azure OpenAI API Key for authentication | ***your-api-key*** | +# MAGIC | openai_api_base | Yes | Base URL for the Azure OpenAI API service | `https://oneenvazureopenai.openai.azure.com` | +# MAGIC | openai_api_version | Yes | Version of the Azure OpenAI service to use | `2024-05-01-preview` | +# MAGIC | openai_deployment_name | Yes | Name of the deployment resource for Azure OpenAI service | `my-gpt4o-deployment` | +# MAGIC +# MAGIC **Note:** Ensure all these parameters are correctly set before running this notebook. You can find most of this information in your Azure portal under your Azure OpenAI resource. + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Import helper notebook + +# COMMAND ---------- + +# DBTITLE 1,Import Helper Notebook +# MAGIC %run ./helper_external_model + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Set up configuration parameters + +# COMMAND ---------- + +# DBTITLE 1,Configurations +dbutils.widgets.text("endpoint_name", "", "Endpoint Name") +dbutils.widgets.text("model_name", "gpt-4o", "Model Name") +dbutils.widgets.text("openai_api_key", "", "Azure OpenAI API Key") +dbutils.widgets.text("openai_api_base", "", "Azure OpenAI API Base URL") +dbutils.widgets.text("openai_api_version", "2024-05-01-preview", "Azure OpenAI API Version") +dbutils.widgets.text("openai_deployment_name", "", "Azure OpenAI Deployment Name") + +# COMMAND ---------- + +# DBTITLE 1,Load Configurations +endpoint_name = dbutils.widgets.get("endpoint_name") +model_name = dbutils.widgets.get("model_name") +openai_api_key = dbutils.widgets.get("openai_api_key") +openai_api_base = dbutils.widgets.get("openai_api_base") +openai_api_version = dbutils.widgets.get("openai_api_version") +openai_deployment_name = dbutils.widgets.get("openai_deployment_name") + +# COMMAND ---------- + +# MAGIC %md +# MAGIC ## Endpoint Configuration and Creation + +# COMMAND ---------- + +# DBTITLE 1,Prepare Endpoint Configuration +endpoint_config = { + "served_entities": [ + { + "name": endpoint_name, + "external_model": { + "name": model_name, + "provider": "openai", + "task": "llm/v1/chat", + "openai_config": { + "openai_api_type": "azure", + "openai_api_key_plaintext": openai_api_key, + "openai_api_base": openai_api_base, + "openai_api_version": openai_api_version, + "openai_deployment_name": openai_deployment_name + } + } + } + ] +} + +# COMMAND ---------- + +# DBTITLE 1,Create or Update Endpoint +try: + endpoint_helper = ExternalModelEndpointHelper() + endpoint = endpoint_helper.create_or_update_endpoint(endpoint_name, endpoint_config) + print(f"Endpoint '{endpoint_name}' has been successfully created/updated.") + print("Endpoint details:", endpoint) +except Exception as e: + print(f"Failed to create or update endpoint: {str(e)}") diff --git a/sql2dbx/notebooks/external_model/helper_external_model.py 
b/sql2dbx/notebooks/external_model/helper_external_model.py new file mode 100644 index 00000000..9b54c388 --- /dev/null +++ b/sql2dbx/notebooks/external_model/helper_external_model.py @@ -0,0 +1,97 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # Helper Notebook for External Model +# MAGIC This notebook contains utility functions for the management of secrets and external models. + +# COMMAND ---------- + +# DBTITLE 1,Install Packages +# MAGIC %pip install mlflow==2.* +# MAGIC dbutils.library.restartPython() + +# COMMAND ---------- + +# DBTITLE 1,External Model Helper Functions +from typing import Any, Dict, Optional + +from databricks.sdk import WorkspaceClient +from databricks.sdk.core import DatabricksError +from mlflow.deployments import DatabricksDeploymentClient +from requests.exceptions import HTTPError + + +class DatabricksSecretHelper: + """Manages secret scopes and secrets in Databricks.""" + + def __init__(self) -> None: + """Initialize the DatabricksSecretHelper.""" + self.client = WorkspaceClient() + + def create_scope_if_not_exists(self, scope_name: str) -> None: + """ + Create a scope if it doesn't exist. + + :param scope_name: Name of the scope to create or check + """ + existing_scopes = [scope.name for scope in self.client.secrets.list_scopes()] + if scope_name not in existing_scopes: + self.client.secrets.create_scope(scope_name) + print(f"Scope '{scope_name}' created.") + else: + print(f"Scope '{scope_name}' already exists.") + + def create_secret_if_not_exists(self, scope_name: str, secret_key: str, secret_value: str) -> None: + """ + Create a secret if it doesn't exist. + + :param scope_name: Name of the scope to store the secret + :param secret_key: Key of the secret + :param secret_value: Value of the secret + """ + try: + self.client.secrets.get_secret(scope_name, secret_key) + print(f"Secret '{secret_key}' already exists in scope '{scope_name}'.") + except DatabricksError as e: + error_message = str(e) + if "Failed to get secret" in error_message and "for scope" in error_message: + try: + self.client.secrets.put_secret(scope_name, secret_key, string_value=secret_value) + print(f"Secret '{secret_key}' created in scope '{scope_name}'.") + except DatabricksError as put_error: + print(f"Failed to create secret '{secret_key}' in scope '{scope_name}': {str(put_error)}") + raise + else: + print(f"Unexpected error accessing secret '{secret_key}' in scope '{scope_name}': {error_message}") + raise + + +class ExternalModelEndpointHelper: + """Manages the creation and update of Databricks external model endpoints.""" + + def __init__(self): + """Initialize the ExternalModelEndpointHelper.""" + self.client = DatabricksDeploymentClient("databricks") + + def create_or_update_endpoint(self, endpoint_name: str, endpoint_config: Dict[str, Any]) -> Dict[str, Any]: + """ + Create a new endpoint or update an existing one with the given configuration. 
+ + :param endpoint_name: Name of the endpoint to create or update + :param endpoint_config: Configuration for the endpoint + :return: The created or updated endpoint + """ + try: + existing_endpoint = self.client.get_endpoint(endpoint_name) + endpoint = self.client.update_endpoint(endpoint=endpoint_name, config=endpoint_config) + print(f"Endpoint '{endpoint_name}' has been successfully updated.") + except HTTPError as e: + if e.response.status_code == 404 and "RESOURCE_DOES_NOT_EXIST" in str(e): + endpoint = self.client.create_endpoint(name=endpoint_name, config=endpoint_config) + print(f"Endpoint '{endpoint_name}' has been successfully created.") + else: + print(f"An error occurred while getting, creating, or updating the endpoint: {str(e)}") + raise + except Exception as e: + print(f"An unexpected error occurred: {str(e)}") + raise + return endpoint diff --git a/sql2dbx/notebooks/notebook_utils.py b/sql2dbx/notebooks/notebook_utils.py new file mode 100644 index 00000000..b03ea4c9 --- /dev/null +++ b/sql2dbx/notebooks/notebook_utils.py @@ -0,0 +1,254 @@ +# Databricks notebook source +# MAGIC %md +# MAGIC # Notebook Utilities +# MAGIC This notebook contains utility functions for use in other notebooks. + +# COMMAND ---------- + +# DBTITLE 1,Import Libraries +import json +from datetime import datetime +from typing import Any, Dict, List, Optional + +from pyscripts.batch_inference_helper import BatchInferenceResponse +from pyscripts.conversion_result_clean_helper import \ + ConversionResultCleanHelper +from pyspark.sql import DataFrame +from pyspark.sql.functions import coalesce, col, lit, udf, when +from pyspark.sql.types import (ArrayType, FloatType, IntegerType, LongType, + StringType, StructField, StructType, + TimestampType) + +# COMMAND ---------- + +# DBTITLE 1,Define Functions +def clean_conversion_results(target_table: str, size_ratio_threshold: float = 0.9) -> DataFrame: + """ + Cleans the conversion results in the specified Delta table. + + Args: + target_table (str): The name of the target table. + size_ratio_threshold (float): The threshold for the ratio of cleaned content size to original content size. If the ratio is below this threshold, a warning is printed. Default is 0.9 (90%). + + Returns: + pyspark.sql.DataFrame: The cleaned DataFrame. + """ + original_df = spark.table(target_table) + cleaned_df = original_df + + # Apply each UDF function to clean the result_content column + helper = ConversionResultCleanHelper() + udf_functions = helper.get_udf_functions() + for udf_func in udf_functions: + clean_udf = udf(udf_func, StringType()) + cleaned_df = cleaned_df.withColumn("result_content", clean_udf(cleaned_df["result_content"])) + + # Compare the sizes of the original and cleaned content. If the cleaned content is under the threshold, print a warning. 
+ small_content_files_df = compare_content_sizes(original_df, cleaned_df, "result_content", size_ratio_threshold) + if small_content_files_df.count() > 0: + print(f"Warning: The following files have cleaned content sizes less than " + f"{size_ratio_threshold * 100}% of their original sizes, " + f"indicating potential data loss:") + display(small_content_files_df) + + # Update result_timestamp if the cleaned content is different from the original content + return ( + cleaned_df.alias("cleaned") + .join(original_df.select("input_file_number", "result_content").alias("original"), + on="input_file_number") + .withColumn( + "result_timestamp", + when(col("cleaned.result_content") != col("original.result_content"), datetime.now()) + .otherwise(col("cleaned.result_timestamp"))) + .drop(col("original.result_content")) + ) + + +def compare_content_sizes(df1: DataFrame, df2: DataFrame, col_name: str, threshold: float = 0.9) -> DataFrame: + """ + Compares the sizes of the specified column in two DataFrames and returns a DataFrame + containing rows where the size ratio is below the threshold. + + Args: + df1 (DataFrame): The first DataFrame (original). + df2 (DataFrame): The second DataFrame (cleaned). + col_name (str): The name of the column to compare sizes. + threshold (float): The threshold for the size ratio. Default is 0.9 (90%). + + Returns: + DataFrame: A DataFrame containing rows where the size ratio is below the threshold. + """ + def safe_len(value): + if value is None: + return 0 + return len(value) + + size_udf = udf(safe_len, IntegerType()) + df1 = df1.select( + "input_file_number", + "input_file_path", + size_udf(col(col_name)).alias("original_content_size")) + df2 = df2.select( + "input_file_number", + size_udf(col(col_name)).alias("cleaned_content_size")) + return ( + df1.join(df2, on="input_file_number") + .filter(col("cleaned_content_size") / col("original_content_size") < threshold) + .select( + "input_file_number", + "input_file_path", + "original_content_size", + "cleaned_content_size", + (col("cleaned_content_size") / col("original_content_size")).alias("size_ratio")) + ) + + +class BatchInferenceResultProcessor: + """ + A class to process batch inference results and merge them with source data in a Databricks environment. + """ + + def __init__(self, model_serving_endpoint_for_conversion: Optional[str] = None, + model_serving_endpoint_for_fix: Optional[str] = None, + request_params_for_conversion: Optional[Dict[str, Any]] = None, + request_params_for_fix: Optional[Dict[str, Any]] = None): + """ + Initialize the BatchInferenceResultProcessor with the schema for inference responses and model serving endpoints. + + Args: + model_serving_endpoint_for_conversion (Optional[str]): The model serving endpoint for conversion. + model_serving_endpoint_for_fix (Optional[str]): The model serving endpoint for fix. + request_params_for_conversion (Optional[Dict[str, Any]]): Request parameters for conversion. + request_params_for_fix (Optional[Dict[str, Any]]): Request parameters for fix. 
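To make the intended call pattern concrete, a short hypothetical sketch of how this processor is typically driven from a notebook follows; the endpoint name, request parameters, and table name are placeholders, not values defined by this patch, and `spark` / `responses` are assumed to come from the notebook context and the batch inference step respectively.

# Hypothetical usage sketch of BatchInferenceResultProcessor (same notebook context).
processor = BatchInferenceResultProcessor(
    model_serving_endpoint_for_conversion="databricks-claude-3-7-sonnet",      # placeholder
    request_params_for_conversion={"max_tokens": 4000, "temperature": 0.0},    # placeholder
)
source_df = spark.table("my_catalog.my_schema.conversion_targets")             # placeholder table
processed_df = processor.process_results(source_df, responses)                 # responses: List[BatchInferenceResponse]
processed_df.write.mode("overwrite").saveAsTable("my_catalog.my_schema.conversion_targets")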
+ """ + self.model_serving_endpoint_for_conversion = model_serving_endpoint_for_conversion + self.model_serving_endpoint_for_fix = model_serving_endpoint_for_fix + + # Convert request parameters to JSON strings during initialization + self.request_params_for_conversion_json = json.dumps( + request_params_for_conversion) if request_params_for_conversion is not None else None + self.request_params_for_fix_json = json.dumps( + request_params_for_fix) if request_params_for_fix is not None else None + + self.schema = StructType([ + StructField("input_file_number", LongType(), True), + StructField("result_content", StringType(), True), + StructField("result_prompt_tokens", IntegerType(), True), + StructField("result_completion_tokens", IntegerType(), True), + StructField("result_total_tokens", IntegerType(), True), + StructField("result_processing_time_seconds", FloatType(), True), + StructField("result_timestamp", TimestampType(), True), + StructField("result_error", StringType(), True), + ]) + + def process_results(self, source_sdf: DataFrame, responses: List[BatchInferenceResponse]) -> DataFrame: + """ + Process the batch inference results and merge them with the source DataFrame. + + Args: + source_sdf (DataFrame): The source DataFrame containing original data. + responses (List[BatchInferenceResponse]): The list of responses from batch inference. + + Returns: + DataFrame: The processed DataFrame with merged results. + """ + result_sdf = self._create_result_dataframe(responses) + joined_sdf = self._join_dataframes(source_sdf, result_sdf) + update_columns = self._get_update_columns() + select_columns = self._get_select_columns(source_sdf, update_columns) + return joined_sdf.select(*select_columns) + + def _create_result_dataframe(self, responses: List[BatchInferenceResponse]) -> DataFrame: + """Create a DataFrame from the batch inference responses.""" + current_time = datetime.now() + responses_with_timestamp = [ + ( + res.index, + res.content, + res.token_usage.prompt_tokens if res.token_usage else None, + res.token_usage.completion_tokens if res.token_usage else None, + res.token_usage.total_tokens if res.token_usage else None, + res.processing_time_seconds, + current_time, + res.error, + ) + for res in responses + ] + return spark.createDataFrame(responses_with_timestamp, schema=self.schema) + + def _join_dataframes(self, source_sdf: DataFrame, result_sdf: DataFrame) -> DataFrame: + """Join the source and result DataFrames.""" + return source_sdf.alias("source").join(result_sdf.alias("result"), on="input_file_number", how="left") + + def _get_update_columns(self) -> List: + """Get the list of columns to update or add.""" + return [ + # Update conversion target flag based on successful conversion + when((col("result.result_content").isNotNull()) & (col("result.result_error").isNull()), lit(False)) + .otherwise(col("source.is_conversion_target")).alias("is_conversion_target"), + + # Basic result columns + coalesce(col("result.result_content"), col("source.result_content")).alias("result_content"), + coalesce(col("result.result_prompt_tokens"), col("source.result_prompt_tokens")).alias("result_prompt_tokens"), + coalesce(col("result.result_completion_tokens"), col("source.result_completion_tokens")).alias("result_completion_tokens"), + coalesce(col("result.result_total_tokens"), col("source.result_total_tokens")).alias("result_total_tokens"), + coalesce(col("result.result_processing_time_seconds"), col("source.result_processing_time_seconds")).alias("result_processing_time_seconds"), + 
coalesce(col("result.result_timestamp"), col("source.result_timestamp")).alias("result_timestamp"), + + # Update result_error with appropriate error handling + # - Clear error if conversion succeeded (result_content exists and no new error) + # - Set new error if conversion failed + # - Keep existing error otherwise + when((col("result.result_content").isNotNull()) & (col("result.result_error").isNull()), lit(None)) + .when(col("result.result_error").isNotNull(), col("result.result_error")) + .otherwise(col("source.result_error")).alias("result_error"), + + # Reset analysis-related columns + lit(None).cast(StringType()).alias("result_python_parse_error"), + lit(None).cast(ArrayType(StringType())).alias("result_extracted_sqls"), + lit(None).cast(ArrayType(StringType())).alias("result_sql_parse_errors"), + + # Model serving endpoints and request params + coalesce( + lit(self.model_serving_endpoint_for_conversion), + col("source.model_serving_endpoint_for_conversion") + ).alias("model_serving_endpoint_for_conversion"), + + coalesce( + lit(self.model_serving_endpoint_for_fix), + col("source.model_serving_endpoint_for_fix") + ).alias("model_serving_endpoint_for_fix"), + + coalesce( + lit(self.request_params_for_conversion_json), + col("source.request_params_for_conversion") + ).alias("request_params_for_conversion"), + + coalesce( + lit(self.request_params_for_fix_json), + col("source.request_params_for_fix") + ).alias("request_params_for_fix"), + ] + + def _get_select_columns(self, source_sdf: DataFrame, update_columns: List) -> List: + """Get the list of columns to select in the final DataFrame.""" + excluded_columns = [ + "is_conversion_target", + "result_content", + "result_prompt_tokens", + "result_completion_tokens", + "result_total_tokens", + "result_processing_time_seconds", + "result_timestamp", + "result_error", + "result_python_parse_error", + "result_extracted_sqls", + "result_sql_parse_errors", + "model_serving_endpoint_for_conversion", + "model_serving_endpoint_for_fix", + "request_params_for_conversion", + "request_params_for_fix", + ] + select_columns = [col("source." + c) for c in source_sdf.columns if c not in excluded_columns] + select_columns.extend(update_columns) + return select_columns diff --git a/sql2dbx/notebooks/pyscripts/__init__.py b/sql2dbx/notebooks/pyscripts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sql2dbx/notebooks/pyscripts/batch_inference_helper.py b/sql2dbx/notebooks/pyscripts/batch_inference_helper.py new file mode 100644 index 00000000..e1567a1f --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/batch_inference_helper.py @@ -0,0 +1,718 @@ +""" +This module provides asynchronous batch inference capabilities for Databricks model serving endpoints. 
+ +Key components: +- AsyncChatClient: Handles API communication with retries and error handling +- BatchInferenceManager: Manages concurrent processing of multiple requests +- BatchInferenceRequest/Response: Data structures for request/response handling +""" +import asyncio +import json +import logging +import time +import traceback +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import httpx +from httpx import codes +from tenacity import (RetryCallState, RetryError, retry, + retry_if_exception_type, stop_after_attempt, + wait_random_exponential) + +from .databricks_credentials import DatabricksCredentials +from .utils import setup_logger + +# Default configuration values +DEFAULT_CONCURRENCY = 10 +DEFAULT_LOGGING_INTERVAL = 1 +DEFAULT_MAX_RETRIES_BACKPRESSURE = 20 +DEFAULT_MAX_RETRIES_OTHER = 5 +DEFAULT_TIMEOUT = 300 + +# Retry backoff configuration +# multiplier: Base wait time in seconds for exponential backoff +# max_wait: Maximum wait time limit in seconds +# These values are fixed to ensure proper load management for Databricks endpoints +# and maintain consistent retry behavior across all requests +DEFAULT_RETRY_MULTIPLIER = 5 +DEFAULT_RETRY_MAX_WAIT = 60 + + +@dataclass +class BatchInferenceRequest: + """ + A class to represent a single request in a batch inference process. + + Attributes: + index (int): The index of the request in the batch. + text (str): The input text for the inference. + system_message (str): The system message to guide the model's behavior. + few_shots (Optional[List[Dict[str, str]]]): Optional few-shot examples for the model. + """ + index: int + text: str + system_message: str + few_shots: Optional[List[Dict[str, str]]] = field(default=None) + + +@dataclass +class TokenUsage: + """ + A class to represent token usage information from an API response. + + Attributes: + prompt_tokens (int): Number of tokens used for the prompt. + completion_tokens (int): Number of tokens used for the completion/response. + total_tokens (int): Total number of tokens used (prompt + completion). + """ + prompt_tokens: int + completion_tokens: int + total_tokens: int + + +@dataclass +class BatchInferenceResponse: + """ + A class to represent a single response from a batch inference process. + + Attributes: + index (int): The index of the response, corresponding to the request index. + content (Optional[str]): The generated content from the model, if successful. + token_usage (Optional[TokenUsage]): Detailed token usage information. + error (Optional[str]): Any error message, if an error occurred during processing. + processing_time_seconds (Optional[float]): Total time in seconds spent processing this request. + """ + index: int + content: Optional[str] + token_usage: Optional[TokenUsage] + error: Optional[str] + processing_time_seconds: Optional[float] + + +class BatchInferenceManager: + """ + Manager for concurrent batch inference processing. + + Key features: + - Controls concurrent execution + - Monitors progress and handles errors + - Manages response aggregation + """ + + def __init__( + self, + client: 'AsyncChatClient', + concurrency: int = DEFAULT_CONCURRENCY, + logging_interval: int = DEFAULT_LOGGING_INTERVAL, + log_level: Union[int, str] = logging.INFO, + ): + """ + Initialize the BatchInferenceManager. + + Args: + client (AsyncChatClient): The AsyncChatClient instance to use for predictions. + concurrency (int): The number of concurrent requests allowed. 
+ logging_interval (int): The interval for logging progress. + log_level (int): The logging level for the manager. + """ + self.client = client + self.concurrency = concurrency + self.logging_interval = logging_interval + self.logger = setup_logger('BatchInferenceManager', level=log_level) + self.logger.info(f"Initialized BatchInferenceManager with concurrency: {concurrency}") + + async def batch_inference(self, requests: List[BatchInferenceRequest]) -> List[BatchInferenceResponse]: + """ + Perform batch inference on a list of requests. + + Args: + requests (List[BatchInferenceRequest]): A list of BatchInferenceRequest objects for each input. + + Returns: + List[BatchInferenceResponse]: A list of BatchInferenceResponse objects containing the results. + """ + self.logger.info(f"Starting batch inference for {len(requests)} requests") + semaphore = asyncio.Semaphore(self.concurrency) + counter = AsyncCounter() + start_time = time.perf_counter() + + tasks = [self._process_inference(request, semaphore, counter, start_time) + for request in requests] + responses = await asyncio.gather(*tasks) + await self.client.close() + + self.logger.info(f"Completed batch inference for {len(requests)} requests") + return responses + + async def _process_inference(self, request: BatchInferenceRequest, + semaphore: asyncio.Semaphore, counter: 'AsyncCounter', + start_time: float) -> BatchInferenceResponse: + """ + Process inference for a single request. + + Args: + request (BatchInferenceRequest): The request data containing text, system message, and few-shots. + semaphore (asyncio.Semaphore): Semaphore for controlling concurrency. + counter (AsyncCounter): Counter for tracking progress. + start_time (float): The start time of the batch process. + + Returns: + BatchInferenceResponse: The response object containing the result or error information. + """ + async with semaphore: + try: + self.logger.info(f"Starting inference for index {request.index}") + predict_start_time = time.perf_counter() + content, token_usage = await self.client.predict(request) + processing_time_seconds = round(time.perf_counter() - predict_start_time, 2) + response = BatchInferenceResponse( + index=request.index, + content=content, + token_usage=token_usage, + error=None, + processing_time_seconds=processing_time_seconds + ) + self.logger.info( + f"Completed inference for index {request.index} in {processing_time_seconds:.2f} seconds") + except Exception as e: + response = await self._handle_error(e, request.index) + + await counter.increment() + if counter.value % self.logging_interval == 0: + elapsed_time = time.perf_counter() - start_time + self.logger.info(f"Processed total {counter.value} requests in {elapsed_time:.2f} seconds.") + + return response + + async def _handle_error(self, e: Exception, request_index: int) -> BatchInferenceResponse: + """ + Handle error during inference and create appropriate response. 
+ + Args: + e: The exception that occurred + request_index: Index of the request that caused the error + + Returns: + BatchInferenceResponse: Response with error information + """ + # Extract original error from RetryError if possible + if isinstance(e, RetryError) and hasattr(e, 'last_attempt') and e.last_attempt.failed: + try: + original_error = e.last_attempt.exception() + if original_error: + e = original_error # Replace with the original error + except Exception: + # Keep the original RetryError if extraction fails + pass + + # Add response details for HTTP errors + error_message = str(e) + if isinstance(e, httpx.HTTPStatusError): + self.client._log_http_error(e, request_index, logger=self.logger) + try: + response_json = e.response.json() + error_message = f"{error_message}\nResponse details: {json.dumps(response_json, ensure_ascii=False)}" + except json.JSONDecodeError: + try: + response_text = e.response.text + if response_text: + error_message = f"{error_message}\nResponse details: {response_text[:500]}" + except Exception: + pass + elif isinstance(e, httpx.RequestError): + self.client._log_request_error(e, request_index, logger=self.logger) + else: + self.client._log_general_error(e, request_index, logger=self.logger) + + return BatchInferenceResponse( + index=request_index, + content=None, + token_usage=None, + error=error_message, + processing_time_seconds=None + ) + + +class AsyncChatClient: + """ + Asynchronous client for Databricks model serving endpoints. + + Key features: + - Handles API communication with retry mechanism + - Supports different retry strategies for backpressure + - Provides error logging and response handling + """ + + def __init__( + self, + endpoint_name: str, + request_params: Dict[str, Any], + timeout: int = DEFAULT_TIMEOUT, + max_retries_backpressure: int = DEFAULT_MAX_RETRIES_BACKPRESSURE, + max_retries_other: int = DEFAULT_MAX_RETRIES_OTHER, + log_level: Union[int, str] = logging.INFO, + ): + """ + Initialize the AsyncChatClient with the given parameters. + + Args: + endpoint_name (str): The name of the API endpoint. + request_params (Dict[str, Any]): Additional parameters for the API request. + timeout (int): The timeout for API requests in seconds. + max_retries_backpressure (int): Maximum number of retries for backpressure errors. + max_retries_other (int): Maximum number of retries for other errors. + log_level (int): The logging level for the client. + """ + self.client = httpx.AsyncClient(timeout=timeout) + self.endpoint_name = endpoint_name + self.request_params = request_params + self.max_retries_backpressure = max_retries_backpressure + self.max_retries_other = max_retries_other + self.logger = setup_logger('AsyncChatClient', level=log_level) + self.logger.info(f"Initialized AsyncChatClient with endpoint: {endpoint_name}") + self.logger.info(f"Request parameters: {self.request_params}") + + async def predict(self, request: 'BatchInferenceRequest') -> Tuple[str, TokenUsage]: + """ + Send a prediction request to the API and process the response. + + This method handles the main communication with the API, including + error handling and retries for transient failures. + + Args: + request (BatchInferenceRequest): The request object containing + the input for the prediction. + + Returns: + Tuple[str, TokenUsage]: A tuple containing the generated content and + the token usage information. + + Raises: + httpx.HTTPStatusError: If an HTTP error occurs that can't be resolved by retrying. + Exception: For any other unexpected errors. 
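For orientation, a condensed sketch of how AsyncChatClient and BatchInferenceManager are meant to be wired together; the endpoint name, system message, and inputs are placeholders and not part of this patch.

# Illustrative end-to-end sketch; assumes the pyscripts package is importable as elsewhere in sql2dbx.
import asyncio

from pyscripts.batch_inference_helper import (AsyncChatClient,
                                              BatchInferenceManager,
                                              BatchInferenceRequest)

client = AsyncChatClient(
    endpoint_name="databricks-claude-3-7-sonnet",               # placeholder endpoint
    request_params={"max_tokens": 4000, "temperature": 0.0},    # placeholder params
)
manager = BatchInferenceManager(client=client, concurrency=4, logging_interval=10)

requests = [
    BatchInferenceRequest(index=i, text=sql_text,
                          system_message="Convert this SQL to a Databricks notebook.")
    for i, sql_text in enumerate(["SELECT 1;", "SELECT 2;"])    # placeholder inputs
]

# In a plain script use asyncio.run(); inside a Databricks notebook cell,
# `await manager.batch_inference(requests)` can be used directly.
responses = asyncio.run(manager.batch_inference(requests))
for res in responses:
    print(res.index, res.error or (res.content or "")[:80])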
+ """ + @retry( + retry=retry_if_exception_type(httpx.HTTPStatusError), + stop=lambda rs: self._get_stop_condition(rs), + wait=wait_random_exponential(multiplier=DEFAULT_RETRY_MULTIPLIER, max=DEFAULT_RETRY_MAX_WAIT), + ) + async def _predict_with_retry(): + try: + databricks_host_and_token = DatabricksCredentials().get_host_and_token() + databricks_host = databricks_host_and_token["host"] + databricks_token = databricks_host_and_token["token"] + + url = f"{databricks_host}/serving-endpoints/{self.endpoint_name}/invocations" + headers = { + "Authorization": f"Bearer {databricks_token}", + "Content-Type": "application/json", + } + + messages = self._initialize_messages(request) + total_content = "" + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + + while True: + # Prepare the request body with the current messages + request_body = {"messages": messages} + if self.request_params: + request_body.update(self.request_params) + request_args = { + 'url': url, + 'headers': headers, + 'json': request_body + } + self._log_request_args(request_args, request.index, level=logging.INFO) + + # Send the request + response = await self.client.post(**request_args) + response.raise_for_status() + self._log_response_details(response, request.index, level=logging.DEBUG) + + # Process the response data + response_data = response.json() + content = self._parse_content(response_data["choices"][0]["message"]["content"]) + total_content += content + finish_reason = response_data["choices"][0]["finish_reason"] + + # Extract token usage information + token_usage = response_data["usage"] + current_prompt_tokens = token_usage.get("prompt_tokens", 0) + current_completion_tokens = token_usage.get("completion_tokens", 0) + current_total_tokens = token_usage.get("total_tokens", 0) + prompt_tokens += current_prompt_tokens + completion_tokens += current_completion_tokens + total_tokens += current_total_tokens + + self.logger.info( + f"Processed content for index {request.index}. " + f"Finish reason: {finish_reason}, " + f"Prompt tokens: {current_prompt_tokens}, " + f"Completion tokens: {current_completion_tokens}, " + f"Total tokens: {current_total_tokens}, " + f"Cumulative prompt tokens: {prompt_tokens}, " + f"Cumulative completion tokens: {completion_tokens}, " + f"Cumulative total tokens: {total_tokens}" + ) + + # Check if the response indicates that the model has finished + if finish_reason != "length": + break + + # If the response indicates that the finished reason is "length", + # we need to continue the conversation by appending the last message + messages.append({"role": "assistant", "content": content}) + messages.append({"role": "user", "content": f"The previous response ended with: '{content[-50:]}'. " + f"Please continue exactly from this point without repeating any content."}) + + return total_content, TokenUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens + ) + + # Handle HTTP errors (e.g., 4xx, 5xx) + except httpx.HTTPStatusError as e: + self._log_http_error(e, request.index) + raise + + # Handle non-HTTP errors (e.g., connection errors, timeouts) + except Exception as e: + self._log_general_error(e, request.index) + raise + + return await _predict_with_retry() + + async def close(self) -> None: + """ + Close the underlying HTTP client. + + This method should be called when the client is no longer needed + to ensure proper resource cleanup. 
+ """ + await self.client.aclose() + self.logger.info("Closed AsyncChatClient") + + def _initialize_messages(self, request: 'BatchInferenceRequest') -> List[Dict[str, str]]: + """ + Initialize the message list with system message, few-shot examples, and user message. + + Args: + request (BatchInferenceRequest): The request object containing + the input data. + + Returns: + List[Dict[str, str]]: A list of message dictionaries ready for API submission. + """ + messages = [] + if request.system_message: + messages.append({"role": "system", "content": request.system_message}) + if request.few_shots: + messages.extend(request.few_shots) + messages.append({"role": "user", "content": request.text}) + return messages + + @staticmethod + def _parse_content(content: Any) -> str: + """ + Parse the content from the API response, handling different response formats. + Claude 3.7 Sonnet with Extended thinking enabled returns a structured response format + where "content" field is a list of objects rather than a simple string: + - It contains a "reasoning" object with the model's thought process + - It also contains a "text" object with the final response text + We need to handle both formats for compatibility + + Args: + content: The content field from the API response, which could be either + a string (standard format) or a list of structured objects + (Extended thinking mode with reasoning and text objects). + + Returns: + str: The extracted text content, combining all "text" type objects in + Extended thinking mode or returning the raw string in standard mode. + + Raises: + TypeError: If the content format is unexpected. + """ + if isinstance(content, str): + return content + elif isinstance(content, list): + return "".join( + content_part.get("text", "") + for content_part in content + if content_part.get("type") == "text" + ) + raise TypeError(f"Unexpected content format: {type(content).__name__}") + + def _get_stop_condition(self, retry_state: RetryCallState) -> Callable[[RetryCallState], bool]: + """ + Determine the stop condition for retries based on the error type. + + Args: + retry_state: The current state of the retry mechanism. + + Returns: + Callable[[RetryCallState], bool]: A stop condition function appropriate for the current error type. + """ + exception = retry_state.outcome.exception() + current_attempt = retry_state.attempt_number + + if isinstance(exception, httpx.HTTPStatusError): + # Check for validation errors first - always stop immediately (no retry) + if self._is_parameter_validation_error(exception): + self.logger.info( + f"Validation error detected. Stopping retry immediately after attempt {current_attempt}." + ) + return stop_after_attempt(1)(retry_state) + + # Then check for backpressure errors (will retry) + elif self._is_backpressure(exception): + max_attempts = self.max_retries_backpressure + self.logger.info( + f"Backpressure error occurred (status: {exception.response.status_code}). " + f"Current attempt: {current_attempt}, max attempts: {max_attempts}, " + f"will {'stop' if current_attempt >= max_attempts else 'retry'}" + ) + return stop_after_attempt(max_attempts)(retry_state) + + # Other HTTP errors (will retry) + else: + max_attempts = self.max_retries_other + self.logger.info( + f"HTTP error occurred (status: {exception.response.status_code}). 
" + f"Current attempt: {current_attempt}, max attempts: {max_attempts}, " + f"will {'stop' if current_attempt >= max_attempts else 'retry'}" + ) + else: + # General non-HTTP errors (will retry) + max_attempts = self.max_retries_other + self.logger.info( + f"General error occurred ({type(exception).__name__}). " + f"Current attempt: {current_attempt}, max attempts: {max_attempts}, " + f"will {'stop' if current_attempt >= max_attempts else 'retry'}" + ) + return stop_after_attempt(max_attempts)(retry_state) + + def _is_parameter_validation_error(self, error: httpx.HTTPStatusError) -> bool: + """ + Check if the error is a parameter validation error that should not be retried. + + Args: + error: The HTTP status error + + Returns: + bool: True if this is a validation error, False otherwise + """ + if error.response.status_code == codes.BAD_REQUEST: + try: + response_data = error.response.json() + if response_data.get("message"): + error_msg = str(response_data["message"]).lower() + # List of patterns indicating validation errors (not backpressure) + validation_patterns = [ + "extra inputs are not permitted", + "invalid parameter", + "parameter validation", + "required parameter", + "missing parameter", + "invalid format", + "not a valid" + ] + + for pattern in validation_patterns: + if pattern in error_msg: + self.logger.debug( + f"Parameter validation error detected: '{error_msg}'. " + f"Will not retry as this is not a backpressure issue." + ) + return True + except (json.JSONDecodeError, AttributeError): + pass + return False + + def _is_backpressure(self, error: httpx.HTTPStatusError) -> bool: + """ + Check if the error is due to backpressure (HTTP 429 or 503) or + if it's a 400 with null message from Databricks Claude endpoint. + """ + if error.response.status_code in (codes.TOO_MANY_REQUESTS, codes.SERVICE_UNAVAILABLE): + return True + # Treating cases where the token limit is reached in a single inference as backpressure causes excessive retries, so it is temporarily commented out. + # if self._is_databricks_claude_specific_backpressure(error): + # return True + return False + + def _is_databricks_claude_specific_backpressure(self, error: httpx.HTTPStatusError) -> bool: + """ + Check for Databricks Claude specific backpressure pattern: + HTTP 400 with null message field. + + This is a temporary workaround for Databricks Claude endpoints + that may return 400 instead of 429 for certain size-related limitations. + """ + if (error.response.status_code == codes.BAD_REQUEST and + "databricks-claude" in self.endpoint_name): + try: + response_data = error.response.json() + if response_data.get("message") is None: + self.logger.warning( + "Databricks Claude specific backpressure detected: " + "400 BAD_REQUEST with null message. Treating as backpressure." + ) + return True + except (json.JSONDecodeError, AttributeError): + pass + return False + + def _log_request_args(self, request_args: Dict[str, Any], request_index: int, max_head: int = 200, max_tail=100, level=logging.INFO) -> None: + """ + Log the request arguments for debugging purposes. 
+ + Args: + request_args: The arguments used in the request + request_index: Index of the request being processed + max_head: Maximum number of characters to log from the start of the request body + max_tail: Maximum number of characters to log from the end of the request body + level: The logging level to use (default is INFO) + """ + request_body = request_args.get('json', {}) + body_str = json.dumps(request_body, ensure_ascii=False) + body_len = len(body_str) + self.logger.log( + level, + f"Request for index {request_index} - Request details: " + f"Size: {body_len} chars, " + f"Content: {body_str[:max_head]}... [truncated] ...{body_str[-max_tail:]}" + ) + + def _log_response_details(self, response: httpx.Response, request_index: int, max_head: int = 200, max_tail: int = 100, level=logging.DEBUG) -> None: + """ + Log the detailed response data at DEBUG level, with truncation. + + Args: + response: The HTTP response object + request_index: Index of the request being processed + max_head: Maximum characters to show from the beginning for text content + max_tail: Maximum characters to show from the end for text content + level: The logging level to use (default is DEBUG) + """ + try: + # For JSON responses, use the dictionary truncation method + log_data = self._truncate_long_strings_in_dict(response.json()) + body_content = json.dumps(log_data, ensure_ascii=False) + content_type = "JSON" + except json.JSONDecodeError: + # For non-JSON responses, use simple text truncation + body_str = response.text + body_content = f"{body_str[:max_head]}... [truncated] ...{body_str[-max_tail:]}" + content_type = "non-JSON" + + self.logger.log( + level, + f"Request for index {request_index} - Response details ({content_type}): " + f"Status: {response.status_code}, " + f"Headers: {dict(response.headers)}, " + f"Body: {body_content}" + ) + + def _truncate_long_strings_in_dict(self, data, max_length: int = 2000): + """ + Recursively truncate long string values in a dictionary or list. + + Args: + data: The dictionary or list to process + max_length: Maximum length for string values + + Returns: + A copy of the data with long strings truncated + """ + if isinstance(data, dict): + return {k: self._truncate_long_strings_in_dict(v, max_length) for k, v in data.items()} + elif isinstance(data, list): + return [self._truncate_long_strings_in_dict(item, max_length) for item in data] + elif isinstance(data, str) and len(data) > max_length: + # Truncate long string values + return data[:max_length // 2] + f" ... [truncated {len(data) - max_length} chars] ... " + data[-max_length // 2:] + else: + return data + + def _log_http_error(self, e: httpx.HTTPStatusError, request_index: int, logger=None) -> None: + """ + Log detailed information for HTTP errors. 
+ + Args: + e: The HTTP error exception + request_index: Index of the request that caused the error + logger: Optional logger to use instead of self.logger + """ + logger = logger or self.logger + logger.error(f"HTTP error in predict for index {request_index}: {str(e)}") + + # Log response body + try: + response_body = e.response.json() + logger.error(f"Error response body (JSON): {json.dumps(response_body, ensure_ascii=False)}") + except json.JSONDecodeError: + try: + response_body = e.response.text + logger.error(f"Error response body (Text): {response_body[:2000]}") + except Exception as text_err: + logger.error(f"Failed to extract response body text: {str(text_err)}") + + # Log response headers + try: + logger.error(f"Response headers: {dict(e.response.headers)}") + except Exception as header_err: + logger.error(f"Failed to extract response headers: {str(header_err)}") + + def _log_request_error(self, e: httpx.RequestError, request_index: int, logger=None) -> None: + """ + Log detailed information for request errors. + + Args: + e: The request error exception + request_index: Index of the request that caused the error + logger: Optional logger to use instead of self.logger + """ + logger = logger or self.logger + logger.error(f"Request error in predict for index {request_index}: {str(e)}") + + # Log request details if available + if hasattr(e, 'request'): + try: + logger.error(f"Request URL: {e.request.url}") + logger.error(f"Request method: {e.request.method}") + except Exception: + pass + + logger.error(f"Traceback: {traceback.format_exc()}") + + def _log_general_error(self, e: Exception, request_index: int, logger=None) -> None: + """ + Log detailed information for general (non-HTTP) errors. + + Args: + e: The exception that was raised + request_index: Index of the request that caused the error + logger: Optional logger to use instead of self.logger + """ + logger = logger or self.logger + logger.error(f"Unexpected error in predict for index {request_index}: {str(e)}") + logger.error(f"Error type: {type(e).__name__}") + logger.error(f"Traceback: {traceback.format_exc()}") + + +class AsyncCounter: + """ + Thread-safe counter for async operations. + """ + + def __init__(self): + self.value = 0 + + async def increment(self): + """Increment the counter value.""" + self.value += 1 diff --git a/sql2dbx/notebooks/pyscripts/cell_split_helper.py b/sql2dbx/notebooks/pyscripts/cell_split_helper.py new file mode 100644 index 00000000..d101bd7f --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/cell_split_helper.py @@ -0,0 +1,311 @@ +import ast +import logging +from typing import List, Optional, Tuple, Union + +from .utils import setup_logger + + +class CellSplitHelper: + """ + A helper class for splitting Python code into Databricks notebook cells. + + This class provides functionality to intelligently split Python code into + separate cells, considering code structure, comments, and blank lines. + + Attributes: + COMMAND_MARKER (str): The marker used to denote the start of a new cell in Databricks notebooks. + logger (logging.Logger): Logger for the class. + + Main methods: + split_cells: Splits the input code into Databricks notebook cells. + """ + COMMAND_MARKER = "# COMMAND ----------" + + def __init__(self, log_level: Union[int, str] = logging.INFO): + self.logger = setup_logger('CellSplitHelper', log_level) + + def split_cells(self, code: Optional[str]) -> Optional[str]: + """ + Split the input code into Databricks notebook cells. + + This method performs the following operations: + 1. 
Parse the code into an Abstract Syntax Tree (AST) + 2. Extract top-level nodes from the AST + 3. Merge nearby nodes into blocks + 4. Adjust block start lines to include preceding comments and blank lines + 5. Extend block endings downward to include trailing blank lines + 6. Format the output with `# COMMAND ----------` separators + + If a SyntaxError occurs during parsing, the entire code is returned as a single cell. + + Args: + code (Optional[str]): The input Python code to be split into cells. + + Returns: + Optional[str]: The code split into Databricks notebook cells, with `# COMMAND ----------` separating each cell. + Returns None if the input is invalid or empty. + Returns a single cell containing the entire code if a SyntaxError occurs during parsing. + """ + self.logger.debug(f"Starting cell splitting process.") + + # If the code is None or not a string, return None + if code is None or not isinstance(code, str): + self.logger.warning(f"Invalid input: code is None or not a string.") + return None + + # If stripping the code results in an empty string, return None + code_stripped = code.strip() + if not code_stripped: + self.logger.warning(f"Invalid input: code is empty after stripping.") + return None + + # Parse with AST + # If parsing fails, return the entire code as a single cell + try: + tree = ast.parse(code) + self.logger.debug("AST parsing successful.") + except Exception as e: + self.logger.error(f"Error occurred while parsing the code: {type(e).__name__} - {e}") + return self._create_single_cell(code) + + # Get top-level nodes + top_level_nodes = self._get_top_level_nodes(tree) + self.logger.debug(f"Found {len(top_level_nodes)} top-level nodes") + if not top_level_nodes: + self.logger.warning("No top-level nodes found, returning single cell") + return self._create_single_cell(code) + + # Merge nearby nodes into single blocks + merged_blocks = self._merge_nearby_nodes(top_level_nodes) + self.logger.debug(f"Merged into {len(merged_blocks)} blocks") + + # Adjust block start lines to include preceding comments and blank lines + lines = code.split("\n") + adjusted_blocks = self._adjust_block_start_lines(merged_blocks, lines) + self.logger.debug(f"Adjusted block start lines, total blocks: {len(adjusted_blocks)}") + + # Extend block endings downward and format the output + output_lines = self._extend_and_format_blocks(adjusted_blocks, lines) + self.logger.debug(f"Extended block endings and formatted output, total lines: {len(output_lines)}") + + # Merge consecutive COMMAND_MARKER lines + final_code = self._clean_command_lines("\n".join(output_lines)) + self.logger.debug("Final code generated.") + return final_code + + def _extend_and_format_blocks(self, adjusted_blocks: List[Tuple[int, int]], lines: List[str]) -> List[str]: + """ + Extend block endings downward to include trailing blank lines and format the output. + This method processes each block, extends it to include trailing blank lines, + and formats the output with COMMAND_MARKER separators. + + Args: + adjusted_blocks (List[Tuple[int, int]]): A list of tuples representing adjusted blocks. + lines (List[str]): The list of code lines. + + Returns: + List[str]: A list of formatted output lines including COMMAND_MARKER and extended blocks. 
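A small, self-contained illustration of the splitting behaviour described above; the expected output noted in the comment is indicative, since the exact grouping depends on the merge rules.

# Illustrative example of CellSplitHelper.split_cells; assumes the module is importable
# as pyscripts.cell_split_helper, as elsewhere in this patch.
from pyscripts.cell_split_helper import CellSplitHelper

source = (
    "import pyspark.sql.functions as F\n"
    "\n"
    "# Load the source table\n"
    "df = spark.table('my_table')\n"
    "\n"
    "def transform(df):\n"
    "    return df.filter(F.col('x') > 0)\n"
)

print(CellSplitHelper().split_cells(source))
# Each logical block is preceded by "# COMMAND ----------"; here the import, the
# commented table load, and the function definition typically land in separate cells,
# with the comment and blank lines pulled into the cell they introduce.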
+ """ + output_lines = [] + last_used_line = 0 + + for (block_start, block_end) in adjusted_blocks: + # Adjust to avoid overlap with the previous block + if block_start <= last_used_line: + block_start = last_used_line + 1 + + block_start = max(1, block_start) + block_end = min(block_end, len(lines)) + + # Include trailing blank lines + extended_end = self._extend_block_downward_for_blank_lines( + block_end, lines + ) + + if extended_end > block_end: + block_end = extended_end + + if block_start <= block_end: + output_lines.append(self.COMMAND_MARKER) + for ln in range(block_start, block_end + 1): + output_lines.append(lines[ln - 1]) + last_used_line = max(last_used_line, block_end) + + return output_lines + + def _get_top_level_nodes(self, tree: ast.AST) -> List[Tuple[int, int]]: + self.logger.debug("Getting top-level nodes from AST") + """ + Get the top-level nodes from the AST and sort them by line number. + + Args: + tree (ast.AST): The AST to analyze. + + Returns: + List[Tuple[int, int]]: A sorted list of tuples containing the start and end line numbers of top-level nodes. + """ + top_level_nodes = [] + for node in ast.iter_child_nodes(tree): + start_line = getattr(node, "lineno", None) + end_line = self._get_node_end_line(node) + if start_line is not None and end_line is not None: + top_level_nodes.append((start_line, end_line)) + return sorted(top_level_nodes, key=lambda x: x[0]) + + def _merge_nearby_nodes(self, nodes: List[Tuple[int, int]]) -> List[Tuple[int, int]]: + self.logger.debug(f"Merging nearby nodes, initial count: {len(nodes)}") + """ + Merge nearby nodes into single blocks. + This method combines nodes that are consecutive or very close (within one line) + into a single block. This helps to prevent overly fine-grained splitting of the code. + + Args: + nodes (List[Tuple[int, int]]): A list of tuples containing the start and end line numbers of nodes. + + Returns: + List[Tuple[int, int]]: A list of tuples representing merged blocks of nearby nodes. + """ + if not nodes: + return [] + + merged_blocks = [] + cur_start, cur_end = nodes[0] + for i in range(1, len(nodes)): + s, e = nodes[i] + # If the node starts within one line of the previous block's end, include it in the same block + if s <= cur_end + 1: + cur_end = max(cur_end, e) + else: + merged_blocks.append((cur_start, cur_end)) + cur_start, cur_end = s, e + merged_blocks.append((cur_start, cur_end)) + return merged_blocks + + def _adjust_block_start_lines(self, merged_blocks: List[Tuple[int, int]], lines: List[str]) -> List[Tuple[int, int]]: + self.logger.debug(f"Adjusting block start lines, block count: {len(merged_blocks)}") + """ + Adjust block start lines to include preceding comments and blank lines. + This method ensures that each block includes any relevant comments or blank lines + that precede it, helping to maintain context and readability in the split code. + + Args: + merged_blocks (List[Tuple[int, int]]): A list of tuples representing merged blocks of nearby nodes. + lines (List[str]): The list of code lines. + + Returns: + List[Tuple[int, int]]: A list of tuples representing adjusted blocks with updated start lines. 
+ """ + adjusted_blocks = [] + used_end_line = 0 + for (start, end) in merged_blocks: + real_start = self._pull_up_leading_comments_and_blank_lines( + start_line=start, + lines=lines, + limit_line=used_end_line + 1 # Don't go above the end of the previous block + 1 + ) + adjusted_blocks.append((real_start, end)) + used_end_line = end + return adjusted_blocks + + def _create_single_cell(self, code: str) -> str: + """ + Create a single cell with the given code. + + Args: + code (str): The code to be included in the cell. + + Returns: + str: A single cell with COMMAND_MARKER and the given code. + """ + return self._clean_command_lines(f"{self.COMMAND_MARKER}\n{code}") + + def _get_node_end_line(self, node: ast.AST) -> Optional[int]: + """ + Recursively get the end line of an AST node. + This method determines the last line of a given AST node by checking its + end_lineno attribute if available, or recursively checking its child nodes. + + The algorithm works as follows: + 1. If the node has an end_lineno attribute, return its value. + 2. Otherwise, start with the node's lineno as the maximum line. + 3. Recursively check all child nodes: + - If a child's end line is greater than the current maximum, update the maximum. + 4. Return the maximum line number found, or None if no valid line number was found. + + Args: + node (ast.AST): The AST node to analyze. + + Returns: + Optional[int]: The end line number of the node, or None if not found. + """ + if hasattr(node, "end_lineno") and node.end_lineno is not None: + return node.end_lineno + max_line = getattr(node, "lineno", 0) + for child in ast.iter_child_nodes(node): + child_end = self._get_node_end_line(child) + if child_end and child_end > max_line: + max_line = child_end + return max_line if max_line != 0 else None + + def _pull_up_leading_comments_and_blank_lines( + self, start_line: int, lines: List[str], limit_line: int + ) -> int: + """ + Include blank lines or comment lines above the block start line in the same block. + + Args: + start_line (int): The initial start line of the block. + lines (List[str]): The list of code lines. + limit_line (int): The upper limit line number to consider. + + Returns: + int: The new start line for the block, including leading comments and blank lines. + """ + idx = start_line - 1 + while idx - 1 >= limit_line - 1: + content = lines[idx - 1].strip() + if content == "" or content.startswith("#"): + idx -= 1 + else: + break + if idx < limit_line - 1: + idx = limit_line - 1 + return idx + 1 + + def _extend_block_downward_for_blank_lines( + self, block_end: int, lines: List[str] + ) -> int: + """ + Extend the block end to include trailing blank lines. + + Args: + block_end (int): The initial end line of the block. + lines (List[str]): The list of code lines. + + Returns: + int: The new end line for the block, including trailing blank lines. + """ + idx = block_end + while idx < len(lines): + content = lines[idx].strip() + if content == "": + idx += 1 + else: + break + return idx + + def _clean_command_lines(self, code: str) -> str: + """ + Merge consecutive COMMAND_MARKER lines into a single line. + + Args: + code (str): The code with potentially consecutive COMMAND_MARKER lines. + + Returns: + str: The code with merged COMMAND_MARKER lines. 
+ """ + marker = self.COMMAND_MARKER + double_marker = f"{marker}\n{marker}" + while double_marker in code: + code = code.replace(double_marker, marker) + return code diff --git a/sql2dbx/notebooks/pyscripts/conversion_prompt_helper.py b/sql2dbx/notebooks/pyscripts/conversion_prompt_helper.py new file mode 100644 index 00000000..b4191b61 --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/conversion_prompt_helper.py @@ -0,0 +1,224 @@ +from enum import Enum +from pathlib import Path +from typing import Dict, List, TypedDict, Union + +import yaml +from omegaconf import OmegaConf + + +class SupportedSQLDialect(Enum): + """Enumeration of supported SQL dialects, paired with default YAML filenames.""" + MYSQL = ("mysql", "mysql_to_databricks_notebook.yml") + NETEZZA = ("netezza", "netezza_to_databricks_notebook.yml") + ORACLE = ("oracle", "oracle_to_databricks_notebook.yml") + POSTGRESQL = ("postgresql", "postgresql_to_databricks_notebook.yml") + REDSHIFT = ("redshift", "redshift_to_databricks_notebook.yml") + SNOWFLAKE = ("snowflake", "snowflake_to_databricks_notebook.yml") + TERADATA = ("teradata", "teradata_to_databricks_notebook.yml") + TSQL = ("tsql", "tsql_to_databricks_notebook.yml") + + @property + def dialect_name(self): + return self.value[0] + + @property + def default_yaml_filename(self): + return self.value[1] + + +class FewShot(TypedDict): + """Type definition for few-shot examples in the conversion process. + + Attributes: + role: The role of the example (e.g., 'user', 'assistant'). + content: The content of the example. + """ + role: str + content: str + + +class ConversionPromptHelper: + """Helper class for managing SQL code conversion prompts. + + This class orchestrates the prompt management for SQL code conversion, + handling system messages and few-shot examples. + """ + _CONVERSION_PROMPT_YAML_DIR_NAME = "conversion_prompt_yaml" + _COMMON_INSTRUCTION_YAML_DIR_NAME = "common_instructions" + _COMMON_INSTRUCTION_YAML_FILE_NAME = "sql_to_databricks_notebook_common_python.yml" + + def __init__(self, conversion_prompt_yaml: str, comment_lang: str = None): + """Initialize the ConversionPromptHelper. + + Args: + conversion_prompt_yaml: Path to the YAML file containing prompts. + comment_lang: Language to be used for comments in the converted code. + """ + self.prompt_config = PromptConfig( + conversion_prompt_yaml=conversion_prompt_yaml, + comment_lang=comment_lang + ) + + def get_system_message(self) -> str: + """Retrieve the system message for the conversion process. + + Returns: + The formatted system message with the specified comment language. + """ + return self.prompt_config.get_system_message() + + def get_few_shots(self) -> List[FewShot]: + """Retrieve the few-shot examples for the conversion process. + + Returns: + A list of few-shot examples to be used in the conversion. + """ + return self.prompt_config.get_few_shots() + + @staticmethod + def get_supported_sql_dialects() -> List[str]: + """Return a list of supported SQL dialects.""" + return [dialect.dialect_name for dialect in SupportedSQLDialect] + + @staticmethod + def get_default_yaml_for_sql_dialect(dialect: str) -> str: + """Return the full path of the default YAML for the given dialect. + + Args: + dialect: SQL dialect name (e.g., 'tsql', 'snowflake'). + + Returns: + The resolved file path (absolute) for the default YAML. + + Raises: + ValueError: If the dialect is not supported. 
+ """ + # Find the matching enum member + for item in SupportedSQLDialect: + if item.dialect_name == dialect: + yaml_path = ConversionPromptHelper._get_yaml_base_dir() / item.default_yaml_filename + return str(yaml_path.resolve()) + raise ValueError(f"Unsupported sql dialect: {dialect}") + + @staticmethod + def _get_common_instruction_yaml() -> str: + """Return the full path of the common instruction YAML. + + Returns: + The resolved file path (absolute) for the common instruction YAML. + """ + yaml_path = ConversionPromptHelper._get_yaml_base_dir() / ConversionPromptHelper._COMMON_INSTRUCTION_YAML_DIR_NAME / \ + ConversionPromptHelper._COMMON_INSTRUCTION_YAML_FILE_NAME + return str(yaml_path.resolve()) + + @staticmethod + def _get_yaml_base_dir() -> Path: + """Return the base directory for YAML files. + + Returns: + The absolute path to the base directory containing YAML files. + """ + base_dir = Path(__file__).parent + return base_dir / ConversionPromptHelper._CONVERSION_PROMPT_YAML_DIR_NAME + + +class PromptConfig: + """Configuration class for managing conversion prompts. + + This class handles loading and managing prompt configurations from YAML files. + """ + + def __init__(self, conversion_prompt_yaml: str, comment_lang: str = None): + """Initialize the PromptConfig. + + Args: + conversion_prompt_yaml: Path to the YAML file containing prompts. + comment_lang: Language to be used for comments. + """ + self.conversion_prompt_yaml = conversion_prompt_yaml + self.comment_lang = comment_lang + self._prompts = self._load_prompts() + + def get_system_message(self) -> str: + """Get system message with comment language interpolated. + + Returns: + The system message with the comment language placeholders replaced. + """ + system_message = self._prompts["system_message"] + if self.comment_lang: + system_message = system_message.replace("{comment_lang}", self.comment_lang) + return system_message + + def get_few_shots(self) -> List[FewShot]: + """Get few-shot examples from the loaded prompts. + + Returns: + A list of few-shot examples, or an empty list if none are defined. + """ + return self._prompts.get("few_shots", []) + + def _load_prompts(self) -> Dict: + """Load prompts from the YAML file. + + Returns: + A dictionary containing the loaded prompts. + + Raises: + FileNotFoundError: If the specified YAML file is not found. + ValueError: If the YAML content is invalid. + """ + try: + common_yaml = self._load_yaml_file(ConversionPromptHelper._get_common_instruction_yaml()) + custom_yaml = self._load_yaml_file(self.conversion_prompt_yaml) + prompts = self._merge_yaml_files(common_yaml, custom_yaml) + if "system_message" not in prompts: + raise ValueError("YAML must contain 'system_message' key") + return prompts + except Exception as e: + raise Exception(f"Failed to load custom prompts: {e}") + + @staticmethod + def _load_yaml_file(file_path: Union[str, Path]) -> Dict: + """Common helper method to load a YAML file. + + Args: + file_path: Path to the YAML file to be loaded (string or Path object). + + Returns: + The loaded YAML content as a dictionary. + + Raises: + FileNotFoundError: If the file does not exist. + ValueError: If the YAML content is not a dictionary. 
+ """ + path = Path(file_path) if not isinstance(file_path, Path) else file_path + if not path.exists(): + raise FileNotFoundError(f"YAML file not found: {path}") + with open(path, 'r', encoding='utf-8') as f: + yaml_content = yaml.safe_load(f) + if not isinstance(yaml_content, dict): + raise ValueError(f"YAML content must be a dictionary: {path}") + return yaml_content + + @staticmethod + def _merge_yaml_files(common_yaml: dict, custom_yaml: dict) -> dict: + """ + Merges two YAML configuration dictionaries into a single dictionary. + + This method combines the keys and values from `common_yaml` and `custom_yaml`. + If there are overlapping keys, the values from `custom_yaml` will take precedence. + The resulting dictionary is resolved using OmegaConf to ensure all references + and interpolations are processed. + + Args: + common_yaml (dict): The base YAML configuration dictionary. + custom_yaml (dict): The custom YAML configuration dictionary that overrides + or extends the base configuration. + + Returns: + dict: A merged and resolved dictionary containing the combined configuration. + """ + combined_yaml = {**common_yaml, **custom_yaml} + conf = OmegaConf.create(combined_yaml) + return OmegaConf.to_container(conf, resolve=True) diff --git a/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/common_instructions/sql_to_databricks_notebook_common_python.yml b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/common_instructions/sql_to_databricks_notebook_common_python.yml new file mode 100644 index 00000000..8d857a41 --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/common_instructions/sql_to_databricks_notebook_common_python.yml @@ -0,0 +1,95 @@ +common_python_instructions_and_guidelines: | + # Instructions + 1. Convert the input SQL statements into Python code suitable for execution in a Databricks notebook. + 2. Strictly adhere to the guidelines provided below. + 3. The output should be Python code and Python comments only. + 4. DO NOT omit any part of the conversion. Instead, include any non-convertible parts as Python comments in the output code, along with explanations of the necessary changes or alternatives. + 5. DO NOT output explanation strings between cells. Just output Python code and Python comments. + 6. DO NOT include cell separators or notebook-level Markdown descriptions as these will be added by another program. + 7. If you receive a message like "Please continue," simply continue the conversion from where you left off without adding any extra phrases. + + # Guidelines + ## 1. Python-based: + - Use Python as the primary programming language. + + ## 2. Databricks Environment: + - Include a setup cell at the beginning with any necessary imports. + - DO NOT create a new SparkSession in the notebook. Databricks provides it by default. + + ## 3. Spark SQL Execution: + - Use `spark.sql()` to execute SQL statements. + - Write SQL statements directly within `spark.sql()`. + - DO NOT store SQL statements in variables unless absolutely necessary. If you must store SQL statements in variables, prefix the variable names with `qs_` for identification. + + ## 4. Comments: + - Add Python comments in {comment_lang} as needed to explain the code's logic. + - Start each cell with a high-level Python comment describing its purpose. + + # 5. Transaction and Rollback Handling: + - Databricks does not fully support transaction control statements like `BEGIN`, `COMMIT`, or `ROLLBACK`. 
+ - If transactional behavior is required, simulate it using Python's `try-except-finally` blocks for error handling. + - For rolling back a Delta table: + - Retrieve the most recent update timestamp from the first row of the table's HISTORY (inside the try block). + - Use the `RESTORE TABLE TIMESTAMP AS OF` command in a finally block to revert the Delta table. + + ## 6. Looping Constructs: + - DO NOT use `collect()`, for/while loops, or `iterrows()` for large data processing. + - Replace cursor/row-by-row logic with Spark DataFrame operations (`map`, `filter`, `groupBy`, `merge`, etc.). + - Use JOINs for bulk operations rather than iterating over rows. + + ## 7. Handling Non-Compatible Syntax + - If you encounter vendor-specific SQL syntax or functions that have no direct Spark/Delta equivalent, comment them out in Python with an explanation. + - If possible, provide an alternative approach (for example, rewrite CROSS APPLY with a lateral view + `explode()`). + + ## 8. Table, View, Column, and Schema Names + - Keep the original names as much as possible. + - If names include unsupported characters such as `@`, `#`, or `$`, remove or replace them. + - Avoid square brackets `[]`; use backticks `` for names with special characters or spaces. + + ## 9. Code Organization + - Group related SQL statements together into logical units (setup, transformations, cleanup, etc.). + - DO NOT output extra Markdown or notebook cell separators; only Python code and comments. + + ## 10. Error Handling + - Wrap critical operations in `try-except` blocks to gracefully handle failures. + - Optionally call `dbutils.notebook.exit("")` if you need to stop execution upon errors. + + ## 11. Parameter Handling with Widgets + - Use `dbutils.widgets.text()`, `dbutils.widgets.dropdown()`, etc. for input parameters. + - Use `dbutils.widgets.get()` to retrieve parameter values. + - Convert parameter values to the correct data type (e.g., `int(dbutils.widgets.get("param_name"))`). + - Define sensible default values where applicable. + - If the SQL includes stored procedures or other parameterized routines, convert them accordingly: + - Create widgets for each parameter. + - Place widget creation at the beginning of the code. + - Refer to the widget values inside the converted Python logic or spark.sql() calls. + + ## 12. Delta Tables + - Assume that all referenced tables (unless explicitly created) are Delta tables. + + ## 13. CREATE TEMP TABLE + - Databricks does not support the same `CREATE TEMP TABLE` syntax as some other SQL dialects. + - Use a Delta table instead; drop it at the end of the script or in a `finally` clause. + - DO NOT rely on Spark TEMP VIEWs as an equivalent for a true “temp table.” + + ## 14. DELETE with Alias + - Databricks does not support using a table alias in a DELETE statement. + - Always specify the full table name: + ```sql + DELETE FROM MyTable WHERE ... + ``` + + ## 15. DELETE with JOIN + - Databricks does not support a `JOIN` clause directly within `DELETE`. + - Create a temporary view or subquery that performs the join logic, then run the `DELETE` referencing that view or subquery. + + ## 16. UPDATE Statements: + - Databricks does not support `FROM` clauses in `UPDATE`. + - Use `MERGE INTO` for statements that require a join. For example: + ```sql + MERGE INTO target_table t + USING source_table s + ON t.key = s.key + WHEN MATCHED THEN UPDATE SET ... + ``` + - If `MERGE INTO` is not suitable, create a temporary view or subquery for the join, then perform the `UPDATE`. 
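As a concrete illustration of guideline 15 above, the JOIN-based delete can be rewritten with a temporary view and a subquery; the table and column names in this sketch are hypothetical and it assumes a Databricks notebook where `spark` is available.

# Hypothetical rewrite of "DELETE t FROM Target t JOIN Staging s ON t.id = s.id":
# move the join logic into a temp view, then delete via a subquery on that view.
spark.sql("""
    CREATE OR REPLACE TEMP VIEW ids_to_delete AS
    SELECT t.id
    FROM Target t
    JOIN Staging s ON t.id = s.id
""")
spark.sql("DELETE FROM Target WHERE id IN (SELECT id FROM ids_to_delete)")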
\ No newline at end of file diff --git a/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/mysql_to_databricks_notebook.yml b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/mysql_to_databricks_notebook.yml new file mode 100644 index 00000000..045facc6 --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/mysql_to_databricks_notebook.yml @@ -0,0 +1,240 @@ +system_message: | + Convert MySQL code to Python code that runs on Databricks according to the following instructions and guidelines: + + # Input and Output + - Input: A single SQL file containing one or multiple MySQL statements (including but not limited to CREATE PROCEDURE statements). + - Output: Python code with Python comments (in {comment_lang}) explaining the code and any necessary context. + + ${common_python_instructions_and_guidelines} + + # MySQL Specific Guidelines + ## 1. Function Mapping: + Map dialect-specific functions to their Databricks equivalents as follows: + + - `CONVERT_TZ(timestamp, from_tz, to_tz)` => No direct equivalent; consider `to_utc_timestamp(...)` and `from_utc_timestamp(...)`. + - `CURDATE()`, `NOW()`, `SYSDATE()` => `current_date()`, `current_timestamp()`. + - `DATEDIFF(date1, date2)` => `datediff(date1, date2)` (both return `date1 - date2` in days, so the argument order is unchanged). + - `DATE_ADD(date, INTERVAL n unit)` / `DATE_SUB(...)` => `date_add(...)` if operating in days, otherwise use `timestampadd` or similar logic. + - `DAYOFWEEK(...)`, `WEEK(...)`, `MONTH(...)`, `YEAR(...)`, etc. => Use Spark built-ins but be mindful of potential off-by-one or mode differences (e.g., `WEEKOFYEAR`). + - `FROM_UNIXTIME(ts, format)` => `from_unixtime(ts)` combined with `date_format(...)` if formatting is needed. + - `GROUP_CONCAT(col SEPARATOR ',')` => Use `concat_ws(',', collect_list(col))` or an equivalent approach in Spark SQL. + - `IFNULL(expr, replacement)` => `NVL(expr, replacement)` or `COALESCE(expr, replacement)`. + - `REGEXP_LIKE(str, pattern)` => `str RLIKE pattern` or `regexp_like(str, pattern)` (depending on Spark version). + - `STR_TO_DATE(expr, '%format')` => `to_date(expr, 'format')` or `to_timestamp(expr, 'format')` depending on whether you need a date or timestamp. + + ## 2. Data Type Mapping: + Map dialect-specific data types to Spark SQL data types as follows: + + - `BIGINT UNSIGNED` => Typically `INT`, `BIGINT`, or `DECIMAL(...)` if you need precision. + - `BLOB`, `LONGBLOB`, `MEDIUMBLOB`, `TINYBLOB` => `BINARY` or `STRING`. + - `CAST(... AS CHAR)` => `CAST(... AS STRING)`. + - `CAST(... AS UNSIGNED)` => `CAST(... AS INT)` or similar numeric type. + - `INT UNSIGNED` => Typically `INT`, `BIGINT`, or `DECIMAL(...)` if you need precision. + - `MEDIUMINT`, `TINYINT` => Often converted to `INT` (or `SMALLINT` for `TINYINT`). + - `MEDIUMTEXT`, `LONGTEXT`, `TEXT`, `TINYTEXT` => `STRING`. + - `YEAR` => `INT` or `STRING`. + + ## 3. Syntax Transformation: + Map dialect-specific syntax constructs to their Databricks equivalents as follows: + + - MySQL-specific commands such as `DELIMITER $$`, `LEAVE`, or `CALL procedure(...)` have no direct Spark SQL equivalent; comment them out or replace with Python logic. + - Stored procedures => Typically replaced with Databricks notebooks or Python functions using `dbutils.widgets` for parameter passing. + - Variables (e.g., `@var`) => Use Python variables or `dbutils.widgets` for external parameterization.
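To make the `GROUP_CONCAT` mapping above concrete, a conversion would typically look like the following minimal sketch (table and column names are hypothetical):

```python
# MySQL: SELECT customer_id, GROUP_CONCAT(tag SEPARATOR ',') AS tags FROM customer_tags GROUP BY customer_id
# Spark SQL replacement: collect the values into a list, then join them with concat_ws.
# Note that collect_list() does not guarantee element order, unlike GROUP_CONCAT with ORDER BY.
tag_lists = spark.sql("""
    SELECT customer_id,
           concat_ws(',', collect_list(tag)) AS tags
    FROM customer_tags
    GROUP BY customer_id
""")
display(tag_lists)
```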
+ +few_shots: +# Example 1: Basic Queries with Temporary Tables +- role: user + content: | + CREATE TEMPORARY TABLE RecentOrders ( + OrderID INT, + CustomerName VARCHAR(100), + TotalAmount DECIMAL(10,2), + OrderDate DATE + ); + + INSERT INTO RecentOrders + SELECT + o.OrderID, + CONCAT(c.FirstName, ' ', c.LastName) AS CustomerName, + SUM(oi.Price * oi.Quantity) AS TotalAmount, + o.OrderDate + FROM Orders o + JOIN Customers c ON o.CustomerID = c.CustomerID + JOIN OrderItems oi ON o.OrderID = oi.OrderID + WHERE o.OrderDate > DATE_SUB(CURRENT_DATE(), INTERVAL 3 MONTH) + GROUP BY o.OrderID, c.FirstName, c.LastName, o.OrderDate; + + SELECT CustomerName, SUM(TotalAmount) AS TotalSpent + FROM RecentOrders + GROUP BY CustomerName + ORDER BY TotalSpent DESC + LIMIT 10; +- role: assistant + content: | + # Create a Delta table to mimic the temporary table + spark.sql(""" + CREATE OR REPLACE TABLE RecentOrders ( + OrderID INT, + CustomerName STRING, + TotalAmount DECIMAL(10,2), + OrderDate DATE + ) + """) + + # Populate the table + spark.sql(""" + INSERT INTO RecentOrders + SELECT + o.OrderID, + concat_ws(' ', c.FirstName, c.LastName) AS CustomerName, + SUM(oi.Price * oi.Quantity) AS TotalAmount, + o.OrderDate + FROM Orders o + JOIN Customers c ON o.CustomerID = c.CustomerID + JOIN OrderItems oi ON o.OrderID = oi.OrderID + WHERE o.OrderDate > date_sub(current_date(), 90) + GROUP BY o.OrderID, c.FirstName, c.LastName, o.OrderDate + """) + + # Query the top spenders + top_customers = spark.sql(""" + SELECT CustomerName, SUM(TotalAmount) AS TotalSpent + FROM RecentOrders + GROUP BY CustomerName + ORDER BY TotalSpent DESC + LIMIT 10 + """) + display(top_customers) + + # Clean up the equivalent of the temporary table + spark.sql("DROP TABLE IF EXISTS RecentOrders") + +# Example 2: Stored Procedure +- role: user + content: | + DELIMITER $$ + CREATE PROCEDURE UpdateProductPrice( + IN p_ProductID INT, + IN p_NewPrice DECIMAL(10,2) + ) + BEGIN + IF (SELECT COUNT(*) FROM Products WHERE ProductID = p_ProductID) = 0 THEN + SELECT 'Product not found' AS msg; + LEAVE proc; + END IF; + + UPDATE Products + SET Price = p_NewPrice + WHERE ProductID = p_ProductID; + + SELECT * FROM Products WHERE ProductID = p_ProductID; + END $$ + DELIMITER ; +- role: assistant + content: | + # Use widgets for procedure parameters + dbutils.widgets.text("ProductID", "") + dbutils.widgets.text("NewPrice", "") + + # Validate and retrieve parameters + try: + product_id = int(dbutils.widgets.get("ProductID")) + new_price = float(dbutils.widgets.get("NewPrice")) + except ValueError: + print("Invalid parameter types.") + dbutils.notebook.exit("Error") + + # Check if product exists + exists_count = spark.sql(f"SELECT COUNT(*) AS cnt FROM Products WHERE ProductID = {product_id}") \ + .collect()[0]['cnt'] + if exists_count == 0: + print("Product not found") + dbutils.notebook.exit("No such product") + + # Update the product’s price + spark.sql(f""" + UPDATE Products + SET Price = {new_price} + WHERE ProductID = {product_id} + """) + + # Return updated product + updated_df = spark.sql(f"SELECT * FROM Products WHERE ProductID = {product_id}") + display(updated_df) + +# Example 3: Stored Procedure with Transaction Handling +- role: user + content: | + DELIMITER $$ + CREATE PROCEDURE UpdateProductPriceWithTransaction( + IN p_ProductID INT, + IN p_NewPrice DECIMAL(10,2) + ) + BEGIN + START TRANSACTION; + + IF (SELECT COUNT(*) FROM Products WHERE ProductID = p_ProductID) = 0 THEN + ROLLBACK; + SELECT 'No such product' AS msg; + LEAVE trans_proc; + END 
IF; + + UPDATE Products + SET Price = p_NewPrice + WHERE ProductID = p_ProductID; + + COMMIT; + + SELECT * FROM Products WHERE ProductID = p_ProductID; + END $$ + DELIMITER ; +- role: assistant + content: | + # Widgets for parameters + dbutils.widgets.text("ProductID", "") + dbutils.widgets.text("NewPrice", "") + + # Convert inputs + try: + product_id = int(dbutils.widgets.get("ProductID")) + new_price = float(dbutils.widgets.get("NewPrice")) + except ValueError: + print("Invalid inputs.") + dbutils.notebook.exit("Error") + + # For rollback simulation, capture latest version (Delta only) + try: + hist = spark.sql("DESCRIBE HISTORY Products LIMIT 1").collect()[0] + restore_ts = hist["timestamp"] + except: + restore_ts = None + + try: + # Check if product exists + exists_count = spark.sql(f""" + SELECT COUNT(*) AS cnt + FROM Products + WHERE ProductID = {product_id} + """).collect()[0]['cnt'] + + if exists_count == 0: + print("No such product") + dbutils.notebook.exit("No action taken") + + # “Transaction” update + spark.sql(f""" + UPDATE Products + SET Price = {new_price} + WHERE ProductID = {product_id} + """) + + except Exception as e: + # Mimic ROLLBACK with Delta RESTORE + if restore_ts: + spark.sql(f"RESTORE TABLE Products TO TIMESTAMP AS OF '{restore_ts}'") + print("Rollback: Table restored to previous version.") + raise e + + # Return updated product + final_df = spark.sql(f"SELECT * FROM Products WHERE ProductID = {product_id}") + display(final_df) \ No newline at end of file diff --git a/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/netezza_to_databricks_notebook.yml b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/netezza_to_databricks_notebook.yml new file mode 100644 index 00000000..99cb7dbf --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/netezza_to_databricks_notebook.yml @@ -0,0 +1,239 @@ +system_message: | + Convert IBM Netezza SQL (including NZPLSQL) code to Python code that runs on Databricks according to the following instructions and guidelines: + + # Input and Output + - Input: A single SQL file containing one or multiple Netezza SQL/NZPLSQL statements. + - Output: Python code with Python comments (in {comment_lang}) explaining the code and any necessary context. + + ${common_python_instructions_and_guidelines} + + ## 1. Function Mapping + Map dialect-specific functions to their Databricks equivalents as follows: + + - `CONVERT(type, expr)` => `CAST(expr AS type)` + - `DECODE(expr, val1, res1, val2, res2, ...)` => Rewrite as `CASE WHEN expr=val1 THEN res1 WHEN expr=val2 THEN res2 ELSE ... END` + - `getdate()` => `current_timestamp()` or `now()` + - `IF(condition, value_if_true, value_if_false)` => Use `CASE WHEN condition THEN value_if_true ELSE value_if_false END` + - `INDEX(str, substr)` or `INSTR(str, substr)` => `locate(substr, str)` (or `charindex(substr, str)`) + - `INT8(expr)` => `CAST(expr AS BIGINT)` or `CAST(expr AS LONG)` + - `ISNULL(expr)` => Typically `coalesce(expr, )` or `NVL(expr, )` + - `LEN(expr)` => `length(expr)` + - `NULLIFZERO(expr)` => `NULLIF(expr, 0)` + - `STRTOK(str, delim, token_number)` => Use `split(str, delim)` and pick the element by index in a higher-level expression + - `TO_CHAR(expr)` (no format) => `CAST(expr AS STRING)` + - `TO_CHAR(expr, 'format')` => `date_format(expr, 'format')` + - `ZEROIFNULL(expr)` => `coalesce(expr, 0)` + + ## 2. 
Data Type Mapping + Map dialect-specific data types to Spark SQL data types as follows: + + - `BYTEINT` => `TINYINT` + - `CHAR` / `CHARACTER(n)` => `STRING` or `VARCHAR(n)` (Spark often treats both simply as `STRING`) + - `CHARACTER VARYING(n)` => `VARCHAR(n)` or `STRING` + - `datetime` => `DATE` or `TIMESTAMP` depending on usage + - `DOUBLE PRECISION` => `DOUBLE` + - `numeric(p, s)` => `DECIMAL(p, s)` if exact precision needed; otherwise `DOUBLE` or `FLOAT` + - `NVARCHAR(...)` => `STRING` + - `TIME` / `TIME WITH TIME ZONE` => `STRING` + - `VARCHAR(...)` => `STRING` + + ## 3. Syntax Transformation + Map dialect-specific syntax constructs to their Databricks equivalents as follows: + + - Ignore or comment out Netezza-specific clauses (e.g., `DISTRIBUTE ON(...)`, `ORGANIZE ON(...)`) because Spark does not support them. + - For other unsupported features, add Python comments explaining the reason and any possible workaround. + +few_shots: +# Example 1: Basic Queries with Temporary Tables +- role: user + content: | + CREATE TEMP TABLE RecentOrders ( + OrderID INT, + CustomerName VARCHAR(100), + TotalAmount DECIMAL(10,2), + OrderDate DATE + ) + DISTRIBUTE ON (OrderID); + + INSERT INTO RecentOrders + SELECT + o.OrderID, + c.FirstName || ' ' || c.LastName AS CustomerName, + SUM(oi.Price * oi.Quantity) AS TotalAmount, + o.OrderDate + FROM Orders o + JOIN Customers c ON o.CustomerID = c.CustomerID + JOIN OrderItems oi ON o.OrderID = oi.OrderID + WHERE o.OrderDate > ADD_MONTHS(current_date, -3) + GROUP BY o.OrderID, c.FirstName, c.LastName, o.OrderDate; + + SELECT CustomerName, SUM(TotalAmount) AS TotalSpent + FROM RecentOrders + GROUP BY CustomerName + ORDER BY TotalSpent DESC + LIMIT 10; +- role: assistant + content: | + # Create a Delta table (comment out DISTRIBUTE ON since Spark doesn't support it) + spark.sql(""" + CREATE OR REPLACE TABLE RecentOrders ( + OrderID INT, + CustomerName STRING, + TotalAmount DECIMAL(10,2), + OrderDate DATE + ) + """) + + # Insert records + spark.sql(""" + INSERT INTO RecentOrders + SELECT + o.OrderID, + concat_ws(' ', c.FirstName, c.LastName) AS CustomerName, + SUM(oi.Price * oi.Quantity) AS TotalAmount, + o.OrderDate + FROM Orders o + JOIN Customers c ON o.CustomerID = c.CustomerID + JOIN OrderItems oi ON o.OrderID = oi.OrderID + WHERE o.OrderDate > date_add(current_date(), -90) + GROUP BY o.OrderID, c.FirstName, c.LastName, o.OrderDate + """) + + # Fetch top 10 customers by total spending + top_customers = spark.sql(""" + SELECT CustomerName, SUM(TotalAmount) AS TotalSpent + FROM RecentOrders + GROUP BY CustomerName + ORDER BY TotalSpent DESC + LIMIT 10 + """) + + # Display results + display(top_customers) + + # Clean up (drop the table) + spark.sql("DROP TABLE IF EXISTS RecentOrders") + +# Example 2: Stored Procedure +- role: user + content: | + CREATE PROCEDURE UpdateProductPrice( + IN P_ProductID INT, + IN P_NewPrice DECIMAL(10,2) + ) + LANGUAGE NZPLSQL + AS + BEGIN_PROC + DECLARE + v_exists INT; + BEGIN + SELECT COUNT(*) + INTO v_exists + FROM Products + WHERE ProductID = P_ProductID; + + IF v_exists = 0 THEN + RAISE NOTICE 'Product not found.'; + RETURN; + END IF; + + UPDATE Products + SET Price = P_NewPrice + WHERE ProductID = P_ProductID; + + SELECT * FROM Products WHERE ProductID = P_ProductID; + END; + END_PROC; +- role: assistant + content: | + # Use widgets for parameter input + dbutils.widgets.text("ProductID", "") + dbutils.widgets.text("NewPrice", "") + + # Validate input + try: + product_id = int(dbutils.widgets.get("ProductID")) + new_price = 
float(dbutils.widgets.get("NewPrice")) + except ValueError: + print("Invalid parameters.") + dbutils.notebook.exit("Error") + + # Check if product exists + exists_df = spark.sql(f"SELECT 1 FROM Products WHERE ProductID = {product_id}") + if exists_df.count() == 0: + print("Product not found.") + dbutils.notebook.exit("No such product.") + + # Update price + spark.sql(f""" + UPDATE Products + SET Price = {new_price} + WHERE ProductID = {product_id} + """) + + # Return updated product + updated_df = spark.sql(f"SELECT * FROM Products WHERE ProductID = {product_id}") + display(updated_df) + +# Example 3: Stored Procedure with Transaction Handling +- role: user + content: | + CREATE PROCEDURE UpdateProductPriceWithTransaction( + IN P_ProductID INT, + IN P_NewPrice DECIMAL(10,2) + ) + LANGUAGE NZPLSQL + AS + BEGIN_PROC + DECLARE + v_errmsg VARCHAR(256); + BEGIN + BEGIN; + UPDATE Products + SET Price = P_NewPrice + WHERE ProductID = P_ProductID; + COMMIT; + SELECT * FROM Products WHERE ProductID = P_ProductID; + EXCEPTION + WHEN OTHERS THEN + GET DIAGNOSTICS EXCEPTION 1 v_errmsg = MESSAGE_TEXT; + RAISE NOTICE 'Error encountered: %', v_errmsg; + ROLLBACK; + RETURN; + END; + END_PROC; +- role: assistant + content: | + # Widgets for parameters + dbutils.widgets.text("ProductID", "") + dbutils.widgets.text("NewPrice", "") + + # Convert inputs + product_id = int(dbutils.widgets.get("ProductID")) + new_price = float(dbutils.widgets.get("NewPrice")) + + # Capture the current timestamp for rollback (Delta time travel) + try: + hist = spark.sql("DESCRIBE HISTORY Products LIMIT 1").collect()[0] + restore_ts = hist["timestamp"] + except: + restore_ts = None + + try: + # Perform the update + spark.sql(f""" + UPDATE Products + SET Price = {new_price} + WHERE ProductID = {product_id} + """) + + except Exception as e: + # Rollback simulation via restore if available + if restore_ts is not None: + spark.sql(f"RESTORE TABLE Products TO TIMESTAMP AS OF '{restore_ts}'") + print("Table restored to previous version.") + raise e + + # Show final result + final_df = spark.sql(f"SELECT * FROM Products WHERE ProductID = {product_id}") + display(final_df) \ No newline at end of file diff --git a/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/oracle_to_databricks_notebook.yml b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/oracle_to_databricks_notebook.yml new file mode 100644 index 00000000..03085af0 --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/oracle_to_databricks_notebook.yml @@ -0,0 +1,238 @@ +system_message: | + Convert Oracle SQL (including PL/SQL) code to Python code that runs on Databricks according to the following instructions and guidelines: + + # Input and Output + - Input: A single SQL file containing one or multiple Oracle SQL or PL/SQL statements (such as `CREATE OR REPLACE PROCEDURE`, anonymous PL/SQL blocks, etc.). + - Output: Python code with Python comments (in `{comment_lang}`) explaining the code and highlighting non-convertible parts. + + ${common_python_instructions_and_guidelines} + + # Oracle Specific Guidelines + ## 1. Function Mapping: + Map dialect-specific functions to their Databricks equivalents as follows: + + - `DBMS_OUTPUT.put_line(...)` => `print(...)` + - `NVL(expr, replacement)` => Remain as `NVL(...)` (Spark supports this). + - `RAISE_APPLICATION_ERROR(code, msg)` => `raise Exception(msg)` or `print(msg)`, depending on context. + - `SUBSTR(expr, start, length)` => `substring(expr, start, length)` or keep as `SUBSTR(...)`. 
+ - `SYSDATE` => `current_date()` + - `SYSTIMESTAMP` => `current_timestamp()` + - `TO_CHAR(date_expr, 'fmt')` => `date_format(date_expr, 'Spark-compatible format')` + - `TO_NUMBER(expr)` => `CAST(expr AS DECIMAL(...))` or `CAST(expr AS INT)` depending on usage + - `TRIM(LEADING / TRAILING / BOTH)` => `LTRIM(...)`, `RTRIM(...)`, or `TRIM(...)`. If exact Oracle behavior is critical, comment and explain differences. + - `TRUNC(date_expr, 'DD')` => `date_trunc('DAY', date_expr)` + + ## 2. Data Type Mapping: + Map dialect-specific data types to Spark SQL data types as follows: + + - `BLOB` => `BINARY` + - `CLOB` / `NCLOB` => `STRING` + - `DATE` => `DATE` + - `NUMBER(p, s)` => `DECIMAL(p, s)` (or `INT` if p/s allow) + - `TIMESTAMP` => `TIMESTAMP` + - `VARCHAR2` / `NVARCHAR2` => `STRING` + + ## 3. Syntax Transformation: + Map dialect-specific syntax constructs to their Databricks equivalents as follows: + + - PL/SQL Conditionals (`IF` / `ELSIF` / `ELSE`) => Convert to Python `if` / `elif` / `else`. + - Sequences (`NEXTVAL`, `CURRVAL`) => No direct Spark equivalent. Comment out or implement a custom workaround (e.g., storing sequence values in a Delta table). + - `EXECUTE IMMEDIATE` => Build the dynamic SQL string in Python, then call `spark.sql(dynamic_sql)`. + - Outer Join Notation (`col = col(+)`) => Rewrite as `LEFT JOIN` or `RIGHT JOIN`. + - Bulk operations (`BULK COLLECT`, etc.) => Prefer set-based Spark DataFrame operations. + - Oracle Index / Partition Hints => Comment them out, since Spark does not support them. + - `IF EXISTS (SELECT ...)` => Check with `spark.sql(...).count() > 0`. + - Use Python variables (or widgets) instead of PL/SQL variables. + +few_shots: +# Example 1: Basic Queries with Temporary Tables +- role: user + content: | + CREATE GLOBAL TEMPORARY TABLE TEMP_ORDERS ( + OrderID NUMBER, + CustomerName VARCHAR2(100), + TotalAmount NUMBER(10,2), + OrderDate DATE + ) ON COMMIT PRESERVE ROWS; + + INSERT INTO TEMP_ORDERS + SELECT + o.OrderID, + c.FirstName || ' ' || c.LastName AS CustomerName, + SUM(oi.Price * oi.Quantity) AS TotalAmount, + o.OrderDate + FROM Orders o + JOIN Customers c ON o.CustomerID = c.CustomerID + JOIN OrderItems oi ON o.OrderID = oi.OrderID + WHERE o.OrderDate > ADD_MONTHS(SYSDATE, -3) + GROUP BY o.OrderID, c.FirstName, c.LastName, o.OrderDate; + + SELECT CustomerName, SUM(TotalAmount) AS TotalSpent + FROM TEMP_ORDERS + GROUP BY CustomerName + ORDER BY SUM(TotalAmount) DESC + FETCH FIRST 10 ROWS ONLY; +- role: assistant + content: | + # Create or replace a Delta table to mimic the temporary table + spark.sql(""" + CREATE OR REPLACE TABLE TempOrders ( + OrderID INT, + CustomerName STRING, + TotalAmount DECIMAL(10,2), + OrderDate DATE + ) + """) + + # Populate the temporary table + spark.sql(""" + INSERT INTO TempOrders + SELECT + o.OrderID, + concat_ws(' ', c.FirstName, c.LastName) AS CustomerName, + SUM(oi.Price * oi.Quantity) AS TotalAmount, + o.OrderDate + FROM Orders o + JOIN Customers c ON o.CustomerID = c.CustomerID + JOIN OrderItems oi ON o.OrderID = oi.OrderID + WHERE o.OrderDate > date_add(current_date(), -90) + GROUP BY o.OrderID, c.FirstName, c.LastName, o.OrderDate + """) + + # Query for top customers + top_customers = spark.sql(""" + SELECT CustomerName, SUM(TotalAmount) AS TotalSpent + FROM TempOrders + GROUP BY CustomerName + ORDER BY TotalSpent DESC + LIMIT 10 + """) + + # Display results + display(top_customers) + + # Explicitly drop the temporary equivalent table + spark.sql("DROP TABLE IF EXISTS TempOrders") + +# Example 2: Stored 
Procedure +- role: user + content: | + CREATE OR REPLACE PROCEDURE UpdateProductPrice( + p_ProductID IN NUMBER, + p_NewPrice IN NUMBER + ) AS + v_count NUMBER; + BEGIN + SELECT COUNT(*) + INTO v_count + FROM Products + WHERE ProductID = p_ProductID; + + IF v_count = 0 THEN + DBMS_OUTPUT.put_line('Product not found'); + RETURN; + END IF; + + UPDATE Products + SET Price = p_NewPrice + WHERE ProductID = p_ProductID; + + FOR rec IN ( + SELECT * FROM Products WHERE ProductID = p_ProductID + ) LOOP + DBMS_OUTPUT.put_line('Updated Product: ' || rec.ProductID || ' Price: ' || rec.Price); + END LOOP; + END; + / +- role: assistant + content: | + # Use widgets for parameter input + dbutils.widgets.text("ProductID", "") + dbutils.widgets.text("NewPrice", "") + + # Validate input + try: + product_id = int(dbutils.widgets.get("ProductID")) + new_price = float(dbutils.widgets.get("NewPrice")) + except ValueError: + print("Invalid parameters.") + dbutils.notebook.exit("Error") + + # Check if product exists + exists_df = spark.sql(f"SELECT 1 FROM Products WHERE ProductID = {product_id}") + if exists_df.count() == 0: + print("Product not found.") + dbutils.notebook.exit("No such product.") + + # Update price + spark.sql(f""" + UPDATE Products + SET Price = {new_price} + WHERE ProductID = {product_id} + """) + + # Return updated product + updated_df = spark.sql(f"SELECT * FROM Products WHERE ProductID = {product_id}") + display(updated_df) + +# Example 3: Stored Procedure with Transaction Handling +- role: user + content: | + CREATE OR REPLACE PROCEDURE UpdateProductPriceWithTransaction( + p_ProductID IN NUMBER, + p_NewPrice IN NUMBER + ) AS + BEGIN + SAVEPOINT start_trans; + BEGIN + UPDATE Products + SET Price = p_NewPrice + WHERE ProductID = p_ProductID; + COMMIT; + EXCEPTION + WHEN OTHERS THEN + ROLLBACK TO start_trans; + RAISE_APPLICATION_ERROR(-20001, 'Error updating product'); + END; + + FOR rec IN ( + SELECT * FROM Products WHERE ProductID = p_ProductID + ) LOOP + DBMS_OUTPUT.put_line('Updated Product: ' || rec.ProductID || ' Price: ' || rec.Price); + END LOOP; + END; + / +- role: assistant + content: | + # Widgets for parameters + dbutils.widgets.text("ProductID", "") + dbutils.widgets.text("NewPrice", "") + + # Convert inputs + product_id = int(dbutils.widgets.get("ProductID")) + new_price = float(dbutils.widgets.get("NewPrice")) + + # Attempt to capture the current table version/timestamp for rollback + try: + hist = spark.sql("DESCRIBE HISTORY Products LIMIT 1").collect()[0] + restore_ts = hist["timestamp"] + except: + restore_ts = None + + try: + # Perform the update + spark.sql(f""" + UPDATE Products + SET Price = {new_price} + WHERE ProductID = {product_id} + """) + except Exception as e: + # Roll back via restore if available + if restore_ts is not None: + spark.sql(f"RESTORE TABLE Products TO TIMESTAMP AS OF '{restore_ts}'") + print("Table restored to previous version.") + raise e + + # Show final result + final_df = spark.sql(f"SELECT * FROM Products WHERE ProductID = {product_id}") + display(final_df) \ No newline at end of file diff --git a/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/postgresql_to_databricks_notebook.yml b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/postgresql_to_databricks_notebook.yml new file mode 100644 index 00000000..828238e0 --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/postgresql_to_databricks_notebook.yml @@ -0,0 +1,233 @@ +system_message: | + Convert PostgreSQL (including PL/pgSQL) code to Python code that runs on Databricks 
according to the following instructions and guidelines: + + # Input and Output + - Input: A single SQL file containing one or multiple PostgreSQL statements (including but not limited to `CREATE FUNCTION` or `CREATE PROCEDURE` statements in PL/pgSQL). + - Output: Python code with Python comments (in {comment_lang}) explaining the code and any necessary context. + + ${common_python_instructions_and_guidelines} + + # PostgreSQL Specific Guidelines + ## 1. Function Mapping: + Map dialect-specific functions to their Databricks equivalents as follows: + + - `GENERATE_SERIES(start, stop, step)` => Use `sequence(start, stop, step)` combined with `explode()` + - `json#>'{a,b,c}'` => `get_json_object(json_str, '$.a.b.c')` + - `json#>>'{a,b,c}'` => `get_json_object(json_str, '$.a.b.c')` (when treating as string) + - `json->'key'` => `get_json_object(json_str, '$.key')` (may require additional processing if treating as JSON object) + - `json->>'key'` => `get_json_object(json_str, '$.key')` + - `LEVENSHTEIN(string1, string2)` => No direct equivalent; requires implementation via Python UDF + - `PERFORM function_call()` => Use `_ = spark.sql("...")` when result needs to be discarded + - `RAISE NOTICE 'message'` => Use Python's `print("message")` or logging + - `REGEXP_REPLACE(string, pattern, replacement, flags)` => `regexp_replace(string, pattern, replacement)` (flags converted to appropriate pattern modifiers) + - `STRING_AGG(col, delimiter)` => `concat_ws(delimiter, collect_list(col))` + + ## 2. Data Type Mapping: + Map dialect-specific data types to Spark SQL data types as follows: + + - `BYTEA` => `BINARY` + - `CHAR`, `VARCHAR`, `TEXT` => `STRING` + - `DOUBLE PRECISION`, `REAL` => `DOUBLE` + - `INTEGER` => `INT` + - `NUMERIC(p, s)`, `DECIMAL(p, s)` => `DECIMAL(p, s)` + - `SERIAL`, `BIGSERIAL`, `SMALLSERIAL` => `INT` or `LONG` (note that auto-increment is not directly supported) + - `UUID` => `STRING` + - Array types => Convert to corresponding `ARRAY` + - Composite types => Typically converted to `STRUCT<...>` + + ## 3. Syntax Transformation: + Map PostgreSQL-specific syntax constructs to their Databricks equivalents as follows: + + - `IF EXISTS (SELECT 1 FROM ...)` => `if spark.sql("SELECT COUNT(*) FROM ...").collect()[0][0] > 0:` + - `IF NOT EXISTS (SELECT 1 FROM ...)` => `if spark.sql("SELECT COUNT(*) FROM ...").collect()[0][0] == 0:` + - `GENERATE_SERIES` with date intervals => Generate date sequences in Python + - `WITH RECURSIVE` => Not directly supported; implement iterative logic in Python + - `LATERAL JOIN` => May need to be rewritten with UDFs or multiple operations + - `INSERT ... RETURNING ...` => Split into INSERT followed by a separate SELECT query + - PostgreSQL-specific date/time manipulations => Use Spark functions with appropriate adjustments + - `CREATE TABLE ... INHERITS (...)` => Not supported in Spark; create separate tables or use a unified schema + - `SELECT ... FOR UPDATE` => Not supported; lock management differs in distributed systems + - `DO $$ ... 
$$` => Convert to Python code outside of SQL + +few_shots: +# Example 1: Basic PostgreSQL Queries with Temporary Tables +- role: user + content: | + CREATE TEMP TABLE temp_recent_orders ( + order_id INTEGER, + customer_name VARCHAR(100), + total_amount DECIMAL(10,2), + order_date DATE + ); + + INSERT INTO temp_recent_orders + SELECT + o.order_id, + c.first_name || ' ' || c.last_name AS customer_name, + SUM(oi.price * oi.quantity) AS total_amount, + o.order_date + FROM orders o + JOIN customers c ON o.customer_id = c.customer_id + JOIN order_items oi ON o.order_id = oi.order_id + WHERE o.order_date > current_date - INTERVAL '3 months' + GROUP BY o.order_id, c.first_name, c.last_name, o.order_date; + + SELECT customer_name, SUM(total_amount) AS total_spent + FROM temp_recent_orders + GROUP BY customer_name + ORDER BY total_spent DESC + LIMIT 10; +- role: assistant + content: | + # Create a temporary table with Delta (using OR REPLACE to avoid conflicts) + spark.sql(""" + CREATE OR REPLACE TABLE temp_recent_orders ( + order_id INT, + customer_name STRING, + total_amount DECIMAL(10,2), + order_date DATE + ) + """) + + # Populate the temporary table + spark.sql(""" + INSERT INTO temp_recent_orders + SELECT + o.order_id, + concat_ws(' ', c.first_name, c.last_name) AS customer_name, + SUM(oi.price * oi.quantity) AS total_amount, + o.order_date + FROM orders o + JOIN customers c ON o.customer_id = c.customer_id + JOIN order_items oi ON o.order_id = oi.order_id + WHERE o.order_date > date_sub(current_date(), 90) + GROUP BY o.order_id, c.first_name, c.last_name, o.order_date + """) + + # Query for top customers + top_customers = spark.sql(""" + SELECT customer_name, SUM(total_amount) AS total_spent + FROM temp_recent_orders + GROUP BY customer_name + ORDER BY total_spent DESC + LIMIT 10 + """) + + # Display results + display(top_customers) + + # Explicitly drop the temporary equivalent table + spark.sql("DROP TABLE IF EXISTS temp_recent_orders") + +# Example 2: PostgreSQL Function +- role: user + content: | + CREATE OR REPLACE FUNCTION update_product_price( + p_product_id INTEGER, + p_new_price DECIMAL(10,2) + ) + RETURNS VOID AS $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM products WHERE product_id = p_product_id) THEN + RAISE NOTICE 'Product not found'; + RETURN; + END IF; + UPDATE products + SET price = p_new_price + WHERE product_id = p_product_id; + PERFORM (SELECT * FROM products WHERE product_id = p_product_id); + END; + $$ LANGUAGE plpgsql; +- role: assistant + content: | + # Use widgets for parameter input + dbutils.widgets.text("p_product_id", "") + dbutils.widgets.text("p_new_price", "") + + # Validate input + try: + product_id = int(dbutils.widgets.get("p_product_id")) + new_price = float(dbutils.widgets.get("p_new_price")) + except ValueError: + print("Invalid parameters.") + dbutils.notebook.exit("Error") + + # Check if product exists + exists_df = spark.sql(f"SELECT 1 FROM products WHERE product_id = {product_id}") + if exists_df.count() == 0: + print("Product not found.") + dbutils.notebook.exit("No such product.") + + # Update price + spark.sql(f""" + UPDATE products + SET price = {new_price} + WHERE product_id = {product_id} + """) + + # Return updated product + updated_df = spark.sql(f"SELECT * FROM products WHERE product_id = {product_id}") + display(updated_df) + +# Example 3: PostgreSQL Function with Transaction Handling +- role: user + content: | + CREATE OR REPLACE FUNCTION update_product_price_with_transaction( + p_product_id INTEGER, + p_new_price DECIMAL(10,2) + ) + RETURNS 
TABLE (product_id INTEGER, price DECIMAL(10,2)) AS $$ + BEGIN + BEGIN; + + UPDATE products + SET price = p_new_price + WHERE product_id = p_product_id; + + COMMIT; + + RETURN QUERY + SELECT products.product_id, products.price + FROM products + WHERE product_id = p_product_id; + + EXCEPTION + WHEN OTHERS THEN + ROLLBACK; + RAISE; + END; + $$ LANGUAGE plpgsql; +- role: assistant + content: | + # Widgets for parameters + dbutils.widgets.text("p_product_id", "") + dbutils.widgets.text("p_new_price", "") + + # Convert inputs + product_id = int(dbutils.widgets.get("p_product_id")) + new_price = float(dbutils.widgets.get("p_new_price")) + + # Capture the current timestamp for rollback + try: + hist = spark.sql("DESCRIBE HISTORY products LIMIT 1").collect()[0] + restore_ts = hist["timestamp"] + except: + restore_ts = None + + try: + # Perform the update + spark.sql(f""" + UPDATE products + SET price = {new_price} + WHERE product_id = {product_id} + """) + except Exception as e: + # Roll back via restore if available + if restore_ts is not None: + spark.sql(f"RESTORE TABLE products TO TIMESTAMP AS OF '{restore_ts}'") + print("Table restored to previous version.") + raise e + + # Show final result + final_df = spark.sql(f"SELECT product_id, price FROM products WHERE product_id = {product_id}") + display(final_df) \ No newline at end of file diff --git a/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/redshift_to_databricks_notebook.yml b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/redshift_to_databricks_notebook.yml new file mode 100644 index 00000000..129f5905 --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/redshift_to_databricks_notebook.yml @@ -0,0 +1,233 @@ +system_message: | + Convert Amazon Redshift SQL code to Python code that runs on Databricks according to the following instructions and guidelines: + + # Input and Output + - Input: A single SQL file containing one or multiple Redshift statements (including but not limited to CREATE PROCEDURE statements). + - Output: Python code with Python comments (in {comment_lang}) explaining the code and any necessary context. + + ${common_python_instructions_and_guidelines} + + # Redshift Specific Guidelines + ## 1. Function Mapping + Map dialect-specific functions to their Databricks equivalents as follows: + + - `ADD_MONTHS(expr, n)` => `add_months(expr, n)` + - `APPROXIMATE COUNT(DISTINCT expr)` => Use `approx_count_distinct(expr)` or add a comment if behavior differs + - `CONVERT_TIMEZONE('from_zone', 'to_zone', col)` => Consider Spark functions like `from_utc_timestamp(...)` / `to_utc_timestamp(...)` + - `DATEADD(unit, n, date_col)` => `date_add(date_col, n)` + - `DATEDIFF(unit, date_col1, date_col2)` => `datediff(date_col2, date_col1)` + - `LISTAGG(expr, delimiter)` => Often emulated via `collect_list()` + `concat_ws()` + - `SPLIT_TO_ARRAY(str, delimiter)` => `split(str, delimiter)` + - `SYSDATE` or `GETDATE()` => `current_timestamp()` + + ## 2. Data Type Mapping + Map dialect-specific data types to Spark SQL data types as follows: + + - `BIGINT` => `LONG` + - `BOOLEAN` => `BOOLEAN` + - `BYTEINT` => `TINYINT` + - `CHAR`, `VARCHAR`, `TEXT` => `STRING` + - `DECIMAL(p, s)` => `DECIMAL(p, s)` + - `INT` or `INTEGER` => `INT` + - `SUPER`, `HLLSKETCH` => `STRING` (note unsupported features) + - `TIMESTAMPTZ`, `TIMETZ` => `TIMESTAMP` (time zone logic may differ) + - `VARBYTE` => `BINARY` or `STRING` depending on usage + + ## 3. 
Syntax Transformation + Map dialect-specific syntax constructs to their Databricks equivalents as follows: + + - Redshift-specific statements (e.g., `UNLOAD`, `CREATE INDEX`, `DISTKEY`, `SORTKEY`, `ENCODE`) are not supported in Databricks; either comment them out or note they are ignored. + - IDENTITY columns in Redshift (`IDENTITY(start,inc)`) can be mapped to `GENERATED ALWAYS AS IDENTITY` in Delta, but exact behavior may differ. If not feasible, comment out or explain the limitation. + - Redshift system tables or catalogs (`SVL_`, `STV_`, etc.) have no direct Spark equivalent; comment out or adapt as needed. + +few_shots: +# Example 1: Basic Queries with Temporary Tables +- role: user + content: | + CREATE TEMP TABLE recentorders ( + orderid INT, + customername VARCHAR(100), + totalamount DECIMAL(10,2), + orderdate DATE + ); + + INSERT INTO recentorders + SELECT + o.orderid, + c.firstname || ' ' || c.lastname AS customername, + SUM(oi.price * oi.quantity) AS totalamount, + o.orderdate + FROM orders o + JOIN customers c ON o.customerid = c.customerid + JOIN orderitems oi ON o.orderid = oi.orderid + WHERE o.orderdate > DATEADD(month, -3, current_timestamp) + GROUP BY o.orderid, c.firstname, c.lastname, o.orderdate; + + SELECT customername, SUM(totalamount) AS totalspent + FROM recentorders + GROUP BY customername + ORDER BY totalspent DESC + LIMIT 10; +- role: assistant + content: | + # Create a temporary table with Delta (using OR REPLACE to avoid conflicts) + spark.sql(""" + CREATE OR REPLACE TABLE recentorders ( + orderid INT, + customername STRING, + totalamount DECIMAL(10,2), + orderdate DATE + ) + """) + + # Populate the table + spark.sql(""" + INSERT INTO recentorders + SELECT + o.orderid, + concat_ws(' ', c.firstname, c.lastname) AS customername, + SUM(oi.price * oi.quantity) AS totalamount, + o.orderdate + FROM orders o + JOIN customers c ON o.customerid = c.customerid + JOIN orderitems oi ON o.orderid = oi.orderid + WHERE o.orderdate > date_add(current_date(), -90) + GROUP BY o.orderid, c.firstname, c.lastname, o.orderdate + """) + + # Query for top spenders + top_spenders = spark.sql(""" + SELECT customername, SUM(totalamount) AS totalspent + FROM recentorders + GROUP BY customername + ORDER BY totalspent DESC + LIMIT 10 + """) + + # Display results + display(top_spenders) + + # Explicitly drop the temporary equivalent table + spark.sql("DROP TABLE IF EXISTS recentorders") + +# Example 2: Stored Procedure +- role: user + content: | + CREATE OR REPLACE PROCEDURE public.updateproductprice( + p_productid INT, + p_newprice DECIMAL(10,2) + ) + LANGUAGE plpgsql + AS $$ + BEGIN + IF (SELECT COUNT(*) FROM products WHERE productid = p_productid) = 0 THEN + RAISE EXCEPTION 'Product not found'; + END IF; + + UPDATE products + SET price = p_newprice + WHERE productid = p_productid; + + SELECT * FROM products WHERE productid = p_productid; + END; + $$; +- role: assistant + content: | + # Use Databricks widgets for procedure parameters + dbutils.widgets.text("ProductID", "") + dbutils.widgets.text("NewPrice", "") + + # Validate and convert parameters + try: + product_id = int(dbutils.widgets.get("ProductID")) + new_price = float(dbutils.widgets.get("NewPrice")) + except ValueError: + print("Invalid parameters.") + dbutils.notebook.exit("Error") + + # Check if product exists + exists_df = spark.sql(f"SELECT 1 FROM products WHERE productid = {product_id}") + if exists_df.count() == 0: + print("Product not found.") + dbutils.notebook.exit("No such product.") + + # Perform update + 
spark.sql(f""" + UPDATE products + SET price = {new_price} + WHERE productid = {product_id} + """) + + # Return updated product + updated_df = spark.sql(f"SELECT * FROM products WHERE productid = {product_id}") + display(updated_df) + +# Example 3: Stored Procedure with Transaction Handling +- role: user + content: | + CREATE OR REPLACE PROCEDURE public.updateproductpricewithtx( + p_productid INT, + p_newprice DECIMAL(10,2) + ) + LANGUAGE plpgsql + AS $$ + DECLARE + BEGIN + BEGIN; + + UPDATE products + SET price = p_newprice + WHERE productid = p_productid; + + IF p_newprice < 0 THEN + RAISE EXCEPTION 'Price cannot be negative'; + END IF; + + COMMIT; + EXCEPTION + WHEN OTHERS THEN + ROLLBACK; + RAISE; + END; + $$; +- role: assistant + content: | + # Widgets for parameters + dbutils.widgets.text("ProductID", "") + dbutils.widgets.text("NewPrice", "") + + # Convert inputs + try: + product_id = int(dbutils.widgets.get("ProductID")) + new_price = float(dbutils.widgets.get("NewPrice")) + except ValueError: + raise ValueError("Invalid input parameters.") + + # Capture the latest timestamp for rollback + try: + hist = spark.sql("DESCRIBE HISTORY products LIMIT 1").collect()[0] + restore_ts = hist["timestamp"] + except: + restore_ts = None + + try: + # Reject negative price logic (simulating Redshift's exception) + if new_price < 0: + raise ValueError("Price cannot be negative") + + # Update the product + spark.sql(f""" + UPDATE products + SET price = {new_price} + WHERE productid = {product_id} + """) + except Exception as e: + # Attempt rollback if we have a timestamp + if restore_ts: + spark.sql(f"RESTORE TABLE products TO TIMESTAMP AS OF '{restore_ts}'") + print("Rollback via table restore completed.") + raise e + + # Final check + final_df = spark.sql(f"SELECT * FROM products WHERE productid = {product_id}") + display(final_df) \ No newline at end of file diff --git a/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/snowflake_to_databricks_notebook.yml b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/snowflake_to_databricks_notebook.yml new file mode 100644 index 00000000..27cbb957 --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/snowflake_to_databricks_notebook.yml @@ -0,0 +1,257 @@ +system_message: | + Convert Snowflake SQL code to Python code that runs on Databricks according to the following instructions and guidelines: + + # Input and Output + - Input: A single SQL file containing one or multiple Snowflake SQL statements (including but not limited DDL, DML, or stored procedures). + - Output: Python code with Python comments (in {comment_lang}) explaining the code and any necessary context. + + ${common_python_instructions_and_guidelines} + + # Snowflake Specific Guidelines + ## 1. Function Mapping: + Map dialect-specific functions to their Databricks equivalents as follows: + + - `ARRAY_AGG(expr)` => `collect_list(expr)` + - `ARRAY_CONSTRUCT(...)` => `array(...)` for flat arrays, or `map(...)` / `struct(...)` if keys/values are involved. + - `DIV0(expr1, expr2)` => `case when expr2=0 then 0 else expr1/expr2 end` or Python try/except block. + - `FLATTEN(...)` => `explode()` or `posexplode()` plus a lateral view. Provide a Python comment on how to replicate the logic. + - `LEN(expr)` => `length(expr)` + - `LISTAGG(...)` => `collect_list(...)` plus `concat_ws(...)`, or a similar approach. + - `NULLIFZERO(expr)` => `case when expr = 0 then null else expr end`. 
+ - `OBJECT_CONSTRUCT(...)` => `map(...)`, JSON operations, or `struct(...)`; note any limitations in comments. + - `PIVOT` / `UNPIVOT` => `stack()`, multiple `union`s, or DataFrame operations. Document any complexity via Python comments. + - `REGEXP_EXTRACT(...)`, `REGEXP_EXTRACT_ALL(...)`, `REGEXP_REPLACE(...)`, `REGEXP_ILIKE(...)`, `RLIKE(...)` => Use Spark equivalents or Python comments for unsupported functions. + - `SQUARE(expr)` => `power(expr, 2)` or a Python comment. + - `TIMEDIFF`, `TIMESTAMPDIFF` => Often reversed argument order or different semantics in Spark. Make sure to confirm the logic does not invert the result. Add clarifying Python comments where needed. + - `TIMESTAMP_NTZ` / `TIMESTAMP_LTZ` / `TIMESTAMP_TZ` => All map to a single `timestamp` type in Spark. Note any differences (time zone handling, local time, etc.) in comments. + - `TIMEADD(unit, n, date)` or `TIMESTAMPADD(unit, n, date)` => `dateadd(unit, n, date)` + - `TO_CHAR(date, format)` => `date_format(date, format)` (or rely on the LLM to generate the correct pattern). + - `TO_TIMESTAMP_LTZ` / `TO_TIMESTAMP_TZ` / `TO_TIMESTAMP_NTZ` => In Spark, generally `to_timestamp()` or `cast(... as timestamp)`. If time zone data is lost, insert an explanatory comment. + - `VARIANCE_SAMP(expr)` => `variance(expr)` + - `ZEROIFNULL(expr)` => `coalesce(expr, 0)` or a `case when expr is null then 0 else expr end`. + + ## 2. Data Type Mapping: + Map dialect-specific data types to Spark SQL data types as follows: + + - `CHAR`, `VARCHAR`, `NCHAR`, `NVARCHAR`, `TEXT`, `VARCHAR2` => `STRING` + - `DATETIME`, `TIMESTAMP_NTZ`, `TIMESTAMP_LTZ`, `TIMESTAMP_TZ` => `TIMESTAMP` + - `FLOAT4` / `FLOAT8` / `DOUBLE PRECISION` => `DOUBLE` + - `INTEGER` / `INT` => `BIGINT` or `INT` + - `NUMBER(p, s)` => `DECIMAL(p, s)` + - `VARBINARY` => `BINARY` + - If `NOT NULL` or other DDL clauses do not map cleanly, use Python comments to explain any differences or potential issues. + + ## 3. Syntax Transformation: + Map dialect-specific syntax constructs to their Databricks equivalents as follows: + + - For Snowflake statements with no direct equivalent (e.g., `SHOW PRIMARY KEYS`, `PUT`, `GET`, etc.), exclude or comment them out. Add Python comments explaining why. + - Variable interpolation: `:var`, `$var` => Use Python f-strings or `.format()` instead. + - Identifier quoting: `"My Table Name"` => Use backticks (e.g., `My Table Name`) in Spark SQL. + - Special characters in names: `@`, `#`, `$` => Remove or replace them as these are forbidden in Spark. Add Python comments if this could break references. + - `BITSHIFTLEFT` => `<<` or use Python-level operations. + - `BITSHIFTRIGHT` => `>>` or use Python-level operations. + - `BITXOR` => `^` or use Python-level operations. + - `BOOLXOR` => Not directly supported. Use `(a OR b) AND NOT (a AND b)` or similar logical expression. + + ## 4. JavaScript-based Stored Procedures: + Snowflake allows creating stored procedures in JavaScript. In Databricks, these should be converted to Python code following these guidelines: + + - Replace `snowflake.execute({sqlText: ...})` with `spark.sql()` calls. + - Convert JavaScript local variables to Python variables. + - Replace parameter passing mechanisms with Databricks widgets (`dbutils.widgets`), environment variables, or function parameters in Python. + - Transform JavaScript `try { ... } catch (err) { ... }` blocks to Python's `try-except` blocks. + - Replicate JavaScript conditional logic (`if`, `for`, etc.) with equivalent Python control structures. 
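As a small illustration of the `FLATTEN` mapping above for the simple array case (table and column names are hypothetical), an array-flattening query would typically be converted along these lines:

```python
# Snowflake: SELECT o.order_id, f.value AS item FROM orders o, LATERAL FLATTEN(input => o.items) f
# Spark SQL sketch: explode() expands each element of the ARRAY column into its own row
# (use posexplode() instead if the element position is also needed).
exploded = spark.sql("""
    SELECT order_id, explode(items) AS item
    FROM orders
""")
display(exploded)
```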
+ +few_shots: +# Example 1: Basic Queries with Temporary Tables +- role: user + content: | + CREATE TEMPORARY TABLE TEMP_RECENT_ORDERS ( + OrderID NUMBER(10, 0), + CustomerName VARCHAR(100), + TotalAmount DECIMAL(10,2), + OrderDate DATE + ); + + INSERT INTO TEMP_RECENT_ORDERS + SELECT + o.OrderID, + -- Snowflake typically uses '||' for string concatenation: + c.FirstName || ' ' || c.LastName AS CustomerName, + SUM(oi.Price * oi.Quantity) AS TotalAmount, + o.OrderDate + FROM Orders o + JOIN Customers c ON o.CustomerID = c.CustomerID + JOIN OrderItems oi ON o.OrderID = oi.OrderID + WHERE o.OrderDate > DATEADD(MONTH, -3, CURRENT_TIMESTAMP()) + GROUP BY o.OrderID, c.FirstName, c.LastName, o.OrderDate; + + SELECT CustomerName, SUM(TotalAmount) AS TotalSpent + FROM TEMP_RECENT_ORDERS + GROUP BY CustomerName + ORDER BY TotalSpent DESC + LIMIT 10; +- role: assistant + content: | + # Create a Delta table (replacing the temporary concept) + spark.sql(""" + CREATE OR REPLACE TABLE TEMP_RECENT_ORDERS ( + OrderID INT, + CustomerName STRING, + TotalAmount DECIMAL(10,2), + OrderDate DATE + ) + """) + + # Populate the temporary table + spark.sql(""" + INSERT INTO TEMP_RECENT_ORDERS + SELECT + o.OrderID, + concat_ws(' ', c.FirstName, c.LastName) AS CustomerName, + SUM(oi.Price * oi.Quantity) AS TotalAmount, + o.OrderDate + FROM Orders o + JOIN Customers c ON o.CustomerID = c.CustomerID + JOIN OrderItems oi ON o.OrderID = oi.OrderID + WHERE o.OrderDate > date_add(current_timestamp(), -90) + GROUP BY o.OrderID, c.FirstName, c.LastName, o.OrderDate + """) + + # Query for top customers + top_customers = spark.sql(""" + SELECT CustomerName, SUM(TotalAmount) AS TotalSpent + FROM TEMP_RECENT_ORDERS + GROUP BY CustomerName + ORDER BY TotalSpent DESC + LIMIT 10 + """) + + # Display the results + display(top_customers) + + # Explicitly drop the temporary equivalent table + spark.sql("DROP TABLE IF EXISTS TEMP_RECENT_ORDERS") + +# EXAMPLE 2: JavaScript-based Stored Procedure +- role: user + content: | + CREATE OR REPLACE PROCEDURE PUBLIC.UPDATE_OR_INSERT_CUSTOMER( + "CUST_ID" NUMBER, + "NEW_NAME" VARCHAR + ) + RETURNS VARCHAR + LANGUAGE JAVASCRIPT + AS + ' + var checkSQL = `SELECT COUNT(*) AS CNT FROM CUSTOMERS WHERE CUSTOMER_ID = ` + CUST_ID; + var rs = snowflake.execute({sqlText: checkSQL}); + rs.next(); + var countVal = rs.getColumnValue(1); + + if (countVal > 0) { + var updateSQL = `UPDATE CUSTOMERS SET CUSTOMER_NAME = ''' + NEW_NAME + ''' WHERE CUSTOMER_ID = ` + CUST_ID; + snowflake.execute({sqlText: updateSQL}); + return "Updated existing customer."; + } else { + var insertSQL = `INSERT INTO CUSTOMERS (CUSTOMER_ID, CUSTOMER_NAME) VALUES (` + CUST_ID + `, ''' + NEW_NAME + ''')`; + snowflake.execute({sqlText: insertSQL}); + return "Inserted new customer."; + } + '; +- role: assistant + content: | + # Retrieve parameters via widgets (simulating procedure arguments) + dbutils.widgets.text("CUST_ID", "") + dbutils.widgets.text("NEW_NAME", "") + + try: + cust_id = int(dbutils.widgets.get("CUST_ID")) + new_name = dbutils.widgets.get("NEW_NAME") + except ValueError: + raise ValueError("Invalid CUST_ID. Must be an integer.") + + # Check if the customer exists + check_df = spark.sql(f"SELECT COUNT(*) AS CNT FROM CUSTOMERS WHERE CUSTOMER_ID = {cust_id}") + count_val = check_df.collect()[0]["CNT"] + + # Update or Insert logic + if count_val > 0: + # Update + spark.sql(f""" + UPDATE CUSTOMERS + SET CUSTOMER_NAME = '{new_name}' + WHERE CUSTOMER_ID = {cust_id} + """) + result_msg = "Updated existing customer." 
+ else: + # Insert + spark.sql(f""" + INSERT INTO CUSTOMERS (CUSTOMER_ID, CUSTOMER_NAME) + VALUES ({cust_id}, '{new_name}') + """) + result_msg = "Inserted new customer." + + # Return or display the result + print(result_msg) + dbutils.notebook.exit(result_msg) + +# Example 3: JavaScript-based Stored Procedure with Transaction Handling +- role: user + content: | + CREATE OR REPLACE PROCEDURE PUBLIC.UpdateProductPriceWithTransaction( + "PRODUCT_ID" NUMBER, + "NEW_PRICE" NUMBER(10,2) + ) + RETURNS TABLE (PRODUCTID NUMBER, PRICE NUMBER(10,2)) + LANGUAGE JAVASCRIPT + EXECUTE AS OWNER + AS + $$ + try { + snowflake.execute({ sqlText: 'BEGIN' }); + let updateStmt = "UPDATE PRODUCTS SET PRICE = " + NEW_PRICE + " WHERE PRODUCTID = " + PRODUCT_ID; + snowflake.execute({ sqlText: updateStmt }); + snowflake.execute({ sqlText: 'COMMIT' }); + return snowflake.execute({ + sqlText: "SELECT PRODUCTID, PRICE FROM PRODUCTS WHERE PRODUCTID = " + PRODUCT_ID + }); + } catch (err) { + snowflake.execute({ sqlText: 'ROLLBACK' }); + throw err; + } + $$; +- role: assistant + content: | + # Widgets for parameters + dbutils.widgets.text("PRODUCT_ID", "") + dbutils.widgets.text("NEW_PRICE", "") + + # Convert inputs + product_id = int(dbutils.widgets.get("PRODUCT_ID")) + new_price = float(dbutils.widgets.get("NEW_PRICE")) + + # Capture the current timestamp for rollback + try: + hist = spark.sql("DESCRIBE HISTORY Products LIMIT 1").collect()[0] + restore_ts = hist["timestamp"] + except: + restore_ts = None + + try: + # Perform the update + spark.sql(f""" + UPDATE Products + SET Price = {new_price} + WHERE ProductID = {product_id} + """) + except Exception as e: + # Rollback via restore if available + if restore_ts is not None: + spark.sql(f"RESTORE TABLE Products TO TIMESTAMP AS OF '{restore_ts}'") + print("Table restored to previous version.") + raise e + + # Show final result + final_df = spark.sql(f"SELECT * FROM Products WHERE ProductID = {product_id}") + display(final_df) \ No newline at end of file diff --git a/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/teradata_to_databricks_notebook.yml b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/teradata_to_databricks_notebook.yml new file mode 100644 index 00000000..bd5157e6 --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/teradata_to_databricks_notebook.yml @@ -0,0 +1,224 @@ +system_message: | + Convert Teradata SQL code to Python code that runs on Databricks according to the following instructions and guidelines: + + # Input and Output + - Input: A single SQL file containing one or multiple Teradata SQL statements. + - Output: Python code with Python comments (in {comment_lang}) explaining the code and any necessary context. + + ${common_python_instructions_and_guidelines} + + # Teradata Specific Guidelines + ## 1. Function Mapping: + Map dialect-specific functions to their Databricks equivalents as follows: + + - `POSITION(substr IN str)` => `INSTR(str, substr)` + - `SAMPLE n` => Spark has no direct random sample equivalent using the exact syntax. Use `TABLESAMPLE` (with caution, different behavior) or simply `LIMIT n` for a non-random limit. + - Date/Time Arithmetic: + - `current_date - n` => `dateadd('DAY', -n, current_date())` + - `INTERVAL '1' WEEK` => approximate with `dateadd('DAY', 7, current_date())` + - Custom intervals like `INTERVAL '1' QUARTER` => approximate via days or comment out if no direct mapping. 
+ - If no direct Spark equivalent (e.g., custom functions like `OREPLACE`, `OTRANSLATE`), comment them out or add Python notes. + + ## 2. Data Type Mapping: + Map dialect-specific data types to Spark SQL data types as follows: + + - `BIGINT` => `LONG` + - `BLOB` / `VARBYTE` => `BINARY` + - `BYTEINT` => `TINYINT` or `SMALLINT` + - `CHAR` / `VARCHAR` => `STRING` + - `DECIMAL(p, s)` => `DECIMAL(p, s)` + - `INTEGER` => `INT` + - `SMALLINT` => `SMALLINT` + - `ST_GEOMETRY` => `STRING` (comment out if unsupported) + - `TIMESTAMP(n)` => `TIMESTAMP` (precision may be truncated) + + ## 3. Syntax Transformation: + Map dialect-specific syntax constructs to their Databricks equivalents as follows: + + - `REPLACE ` => `CREATE OR REPLACE ` + - `TOP n` => `LIMIT n` + - `MINUS` => `EXCEPT` + - Teradata-specific statements like `LOCKING`, `MLOAD`, `FAST_LOAD`, `COLLECT STATISTICS` => remove or comment out. + - `CREATE MULTISET TABLE` / `CREATE VOLATILE TABLE` => `CREATE TABLE` (Delta, temp view, etc.). Comment differences. + - `UPDATE ... FROM`, `DELETE ... JOIN` => Not directly supported; consider `MERGE INTO` or a workaround with temp views. + - `PRIMARY INDEX`, `UNIQUE PRIMARY INDEX`, `PARTITION BY RANGE_N` => Spark does not support these directly. Remove or comment out. + - Character set conversions (e.g., `TRANSLATE(... USING ...)`) => remove or comment out. + - For hashing, encryption, or specialized Teradata functions (e.g., `HASHBYTES` analogs), comment or note as unsupported. + + ## 4. Abbreviations: + - `SEL` => `SELECT` + - `DEL` => `DELETE` + - `INS` => `INSERT` + - `UPD` => `UPDATE` + +few_shots: +# Example 1: Basic Queries with (Volatile) Tables +- role: user + content: | + CREATE VOLATILE TABLE RecentOrders ( + OrderID INTEGER, + CustomerName VARCHAR(100), + TotalAmount DECIMAL(10,2), + OrderDate DATE + ) + ON COMMIT PRESERVE ROWS; + + INSERT INTO RecentOrders + SELECT + o.OrderID, + c.FirstName || ' ' || c.LastName AS CustomerName, + SUM(oi.Price * oi.Quantity) AS TotalAmount, + o.OrderDate + FROM Orders o + JOIN Customers c ON o.CustomerID = c.CustomerID + JOIN OrderItems oi ON o.OrderID = oi.OrderID + WHERE o.OrderDate > ADD_MONTHS(CURRENT_DATE, -3) + GROUP BY o.OrderID, c.FirstName, c.LastName, o.OrderDate; + + SELECT TOP 10 CustomerName, SUM(TotalAmount) AS TotalSpent + FROM RecentOrders + GROUP BY CustomerName + ORDER BY TotalSpent DESC; +- role: assistant + content: | + # Create a table in Databricks (Delta). Volatile table doesn't exist, so we simply create or replace a normal table. 
+ spark.sql(""" + CREATE OR REPLACE TABLE RecentOrders ( + OrderID INT, + CustomerName STRING, + TotalAmount DECIMAL(10,2), + OrderDate DATE + ) + """) + + # Insert data into the table + spark.sql(""" + INSERT INTO RecentOrders + SELECT + o.OrderID, + concat_ws(' ', c.FirstName, c.LastName) AS CustomerName, + SUM(oi.Price * oi.Quantity) AS TotalAmount, + o.OrderDate + FROM Orders o + JOIN Customers c ON o.CustomerID = c.CustomerID + JOIN OrderItems oi ON o.OrderID = oi.OrderID + WHERE o.OrderDate > date_add(current_date(), -90) + GROUP BY o.OrderID, c.FirstName, c.LastName, o.OrderDate + """) + + # Retrieve top 10 by total spent + top_customers = spark.sql(""" + SELECT CustomerName, SUM(TotalAmount) AS TotalSpent + FROM RecentOrders + GROUP BY CustomerName + ORDER BY TotalSpent DESC + LIMIT 10 + """) + display(top_customers) + + # Drop the table (to emulate dropping the volatile table) + spark.sql("DROP TABLE IF EXISTS RecentOrders") + +# Example 2: Stored Procedure +- role: user + content: | + REPLACE PROCEDURE UpdateProductPrice( + IN ProductID INT, + IN NewPrice DECIMAL(10,2) + ) + BEGIN + IF (NOT EXISTS (SELECT 1 FROM Products WHERE ProductID = :ProductID)) THEN + CALL dbc.trace('Product not found'); + LEAVE proc_exit; + END IF; + + UPDATE Products + SET Price = :NewPrice + WHERE ProductID = :ProductID; + + SELECT * FROM Products WHERE ProductID = :ProductID; + END proc_exit; +- role: assistant + content: | + # Use Databricks widgets for parameter input (simulating procedure parameters) + dbutils.widgets.text("ProductID", "") + dbutils.widgets.text("NewPrice", "") + + # Convert inputs from string to numeric types + try: + product_id = int(dbutils.widgets.get("ProductID")) + new_price = float(dbutils.widgets.get("NewPrice")) + except ValueError: + print("Invalid parameters.") + dbutils.notebook.exit("Error") + + # Check if product exists + exists_df = spark.sql(f"SELECT 1 FROM Products WHERE ProductID = {product_id}") + if exists_df.count() == 0: + print("Product not found.") + dbutils.notebook.exit("No such product.") + + # Update product price + spark.sql(f""" + UPDATE Products + SET Price = {new_price} + WHERE ProductID = {product_id} + """) + + # Return updated product + updated_df = spark.sql(f"SELECT * FROM Products WHERE ProductID = {product_id}") + display(updated_df) + +# Example 3: Stored Procedure with Transaction Handling +- role: user + content: | + REPLACE PROCEDURE UpdateProductPriceWithTransaction( + IN ProductID INT, + IN NewPrice DECIMAL(10,2) + ) + BEGIN + BT; -- Begin Teradata transaction + + UPDATE Products + SET Price = :NewPrice + WHERE ProductID = :ProductID; + + ET; -- End transaction (commit) + + SELECT * FROM Products WHERE ProductID = :ProductID; + END; +- role: assistant + content: | + # Widgets for parameters + dbutils.widgets.text("ProductID", "") + dbutils.widgets.text("NewPrice", "") + + # Convert inputs + product_id = int(dbutils.widgets.get("ProductID")) + new_price = float(dbutils.widgets.get("NewPrice")) + + # Capture the current timestamp for rollback + try: + hist = spark.sql("DESCRIBE HISTORY Products LIMIT 1").collect()[0] + restore_ts = hist["timestamp"] + except: + restore_ts = None + + try: + # Equivalent to "BT" -> begin a logical block + spark.sql(f""" + UPDATE Products + SET Price = {new_price} + WHERE ProductID = {product_id} + """) + # "ET" -> we simply do nothing extra in Delta for commit + except Exception as e: + # If an error occurs, attempt rollback via RESTORE + if restore_ts is not None: + spark.sql(f"RESTORE TABLE Products TO 
TIMESTAMP AS OF '{restore_ts}'") + print("Table restored to previous version.") + raise e + + # Show final result + final_df = spark.sql(f"SELECT * FROM Products WHERE ProductID = {product_id}") + display(final_df) \ No newline at end of file diff --git a/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/tsql_to_databricks_notebook.yml b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/tsql_to_databricks_notebook.yml new file mode 100644 index 00000000..4b78e606 --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/conversion_prompt_yaml/tsql_to_databricks_notebook.yml @@ -0,0 +1,251 @@ +system_message: | + Convert T-SQL code to Python code that runs on Databricks according to the following instructions and guidelines: + + # Input and Output + - Input: A single SQL file containing one or multiple T-SQL statements (including but not limited to `CREATE OR ALTER PROCEDURE` statements). + - Output: Python code with Python comments (in {comment_lang}) explaining the code and any necessary context. + + ${common_python_instructions_and_guidelines} + + # T-SQL Specific Guidelines + ## 1. Function Mapping: + Map dialect-specific functions to their Databricks equivalents as follows: + + - `CONVERT(type, expr)` => `cast(expr AS type)` + - `DATEADD(unit, n, date)` => `date_add(date, n)` or an equivalent + - `DATEDIFF(unit, startDate, endDate)` => `datediff(endDate, startDate)` (note argument order) + - `DATEPART(WEEKDAY, col)` => `dayofweek(col)` + - `FORMAT(date, 'pattern')` => `date_format(date, 'pattern')` + - `GETDATE()` => `current_timestamp()` + - `HASHBYTES('MD5' | 'SHA1' | 'SHA2_256' | 'SHA2_512', expr)`: + - MD5 => `md5(expr)` + - SHA1 => `sha(expr)` + - SHA2_256 => `sha2(expr, 256)` + - SHA2_512 => `sha2(expr, 512)` + - `ISNULL(expr, replacement)` => `nvl(expr, replacement)` + - `JSON_VALUE(...)`, `JSON_QUERY(...)` => `get_json_object()`, `json_tuple()`, or comment out if not feasible. + - `OBJECT_NAME(@@PROCID)` => Use the notebook name or a defined variable for references. + - `SQUARE(expr)` => `expr * expr` or `power(expr, 2)` + - `SYSDATE` => `current_date()` + - `SYSDATETIME()` => `current_timestamp()` + - `TIMEFROMPARTS(...)`, `DATETIMEFROMPARTS(...)` => Consider Python datetime or `make_timestamp`. + - `TO_NUMBER(expr)` => `cast(expr AS INT)` or `cast(expr AS DECIMAL(...))` depending on usage + - `TRUNC(date, 'DD')` => `date_trunc('DAY', date)` + - For string concat, handle NULL carefully (e.g., `concat_ws`) + + ## 2. Data Type Mapping: + Map dialect-specific data types to Spark SQL data types as follows: + + - `BIGINT` => `LONG` + - `FLOAT` / `REAL` => `DOUBLE` + - `GEOGRAPHY` => `STRING` + - `GEOMETRY` => `STRING` + - `IMAGE` => `BINARY` + - `INT` / `INTEGER` => `INT` + - `MONEY` => `DECIMAL` or `DOUBLE` + - `NUMERIC(p, s)` / `DECIMAL(p, s)` => `DECIMAL(p, s)` + - `NTEXT` => `STRING` + - `NVARCHAR` => `STRING` + - `ROWVERSION`: + - T-SQL uses `ROWVERSION` for concurrency checks. Spark SQL has no direct equivalent. + - Often converted to `BINARY(8)` or left commented out if you only need the concurrency logic in T-SQL. + - `SMALLMONEY` => `DECIMAL` or `DOUBLE` + - `SQL_VARIANT` => `STRING` + - `TEXT` => `STRING` + - `UNIQUEIDENTIFIER` => `STRING` + - `VARBINARY` => `BINARY` + - `VARCHAR` => `STRING` + + ## 3. Syntax Transformation: + Map dialect-specific syntax constructs to their Databricks equivalents as follows: + + - For T-SQL statements with no direct equivalent (e.g., `CREATE NONCLUSTERED INDEX`, `EXEC `, etc.), exclude or comment them out. Add Python comments explaining why. 
+ - `CROSS APPLY` => Typically replaced by a lateral view + `explode()` + - `TOP(n)` => `LIMIT n` + - `IF EXISTS(SELECT 1 FROM ...)` => `spark.sql(...).count() > 0` + - `IF NOT EXISTS(SELECT 1 FROM ...)` => `spark.sql(...).count() <= 0` + - `DECLARE @var...`, `SET @var = ...` => Not supported. Use Python variables or Databricks widgets instead. + - Query Hints (e.g., `OPTION(MAXDOP=1)`, `WITH(NOLOCK)`) => Not supported. Comment out. + - `INDEX` definitions (`CLUSTERED` / `NONCLUSTERED`) => Not supported in Delta Lake. Comment out. + - Bracketed Identifiers: `[schema].[table]` => `schema.table` or `schema`.`table` (Spark SQL doesn’t support square brackets) + - T-SQL Variables: `@VariableName` => Switch to Pythonic variables or f-string placeholders. + +few_shots: +# Example 1: Basic Queries with Temporary Tables +- role: user + content: | + CREATE TABLE #RecentOrders ( + OrderID INT, + CustomerName VARCHAR(100), + TotalAmount DECIMAL(10,2), + OrderDate DATE + ); + + INSERT INTO #RecentOrders + SELECT + o.OrderID, + c.FirstName + ' ' + c.LastName AS CustomerName, + SUM(oi.Price * oi.Quantity) AS TotalAmount, + o.OrderDate + FROM Orders o + JOIN Customers c ON o.CustomerID = c.CustomerID + JOIN OrderItems oi ON o.OrderID = oi.OrderID + WHERE o.OrderDate > DATEADD(MONTH, -3, GETDATE()) + GROUP BY o.OrderID, c.FirstName, c.LastName, o.OrderDate; + + SELECT TOP 10 CustomerName, SUM(TotalAmount) AS TotalSpent + FROM #RecentOrders + GROUP BY CustomerName + ORDER BY TotalSpent DESC; +- role: assistant + content: | + # Create a temporary table with Delta (using OR REPLACE to avoid conflicts) + spark.sql(""" + CREATE OR REPLACE TABLE RecentOrders ( + OrderID INT, + CustomerName STRING, + TotalAmount DECIMAL(10,2), + OrderDate DATE + ) + """) + + # Populate the temporary table + spark.sql(""" + INSERT INTO RecentOrders + SELECT + o.OrderID, + concat_ws(' ', c.FirstName, c.LastName) AS CustomerName, + SUM(oi.Price * oi.Quantity) AS TotalAmount, + o.OrderDate + FROM Orders o + JOIN Customers c ON o.CustomerID = c.CustomerID + JOIN OrderItems oi ON o.OrderID = oi.OrderID + WHERE o.OrderDate > date_add(current_date(), -90) + GROUP BY o.OrderID, c.FirstName, c.LastName, o.OrderDate + """) + + # Query for top customers + top_customers = spark.sql(""" + SELECT CustomerName, SUM(TotalAmount) AS TotalSpent + FROM RecentOrders + GROUP BY CustomerName + ORDER BY TotalSpent DESC + LIMIT 10 + """) + + # Display results + display(top_customers) + + # Explicitly drop the temporary equivalent table + spark.sql("DROP TABLE IF EXISTS RecentOrders") + +# Example 2: Stored Procedure +- role: user + content: | + CREATE OR ALTER PROCEDURE dbo.UpdateProductPrice + @ProductID INT, + @NewPrice DECIMAL(10,2) + AS + BEGIN + SET NOCOUNT ON; + + IF NOT EXISTS (SELECT 1 FROM Products WHERE ProductID = @ProductID) + BEGIN + PRINT 'Product not found'; + RETURN; + END + + UPDATE Products + SET Price = @NewPrice + WHERE ProductID = @ProductID; + + SELECT * FROM Products WHERE ProductID = @ProductID; + END; +- role: assistant + content: | + # Use widgets for parameter input + dbutils.widgets.text("ProductID", "") + dbutils.widgets.text("NewPrice", "") + + # Validate input + try: + product_id = int(dbutils.widgets.get("ProductID")) + new_price = float(dbutils.widgets.get("NewPrice")) + except ValueError: + print("Invalid parameters.") + dbutils.notebook.exit("Error") + + # Check if product exists + exists_df = spark.sql(f"SELECT 1 FROM Products WHERE ProductID = {product_id}") + if exists_df.count() == 0: + print("Product not found.") 
+ dbutils.notebook.exit("No such product.") + + # Update price + spark.sql(f""" + UPDATE Products + SET Price = {new_price} + WHERE ProductID = {product_id} + """) + + # Return updated product + updated_df = spark.sql(f"SELECT * FROM Products WHERE ProductID = {product_id}") + display(updated_df) + +# Example 3: Stored Procedure with Transaction Handling +- role: user + content: | + CREATE OR ALTER PROCEDURE dbo.UpdateProductPriceWithTransaction + @ProductID INT, + @NewPrice DECIMAL(10,2) + AS + BEGIN + BEGIN TRANSACTION; + BEGIN TRY + UPDATE Products + SET Price = @NewPrice + WHERE ProductID = @ProductID; + COMMIT TRANSACTION; + END TRY + BEGIN CATCH + ROLLBACK TRANSACTION; + THROW; + END CATCH + + SELECT * FROM Products WHERE ProductID = @ProductID; + END; +- role: assistant + content: | + # Widgets for parameters + dbutils.widgets.text("ProductID", "") + dbutils.widgets.text("NewPrice", "") + + # Convert inputs + product_id = int(dbutils.widgets.get("ProductID")) + new_price = float(dbutils.widgets.get("NewPrice")) + + # Capture the current timestamp for rollback + try: + hist = spark.sql("DESCRIBE HISTORY Products LIMIT 1").collect()[0] + restore_ts = hist["timestamp"] + except: + restore_ts = None + + try: + # Perform the update + spark.sql(f""" + UPDATE Products + SET Price = {new_price} + WHERE ProductID = {product_id} + """) + + except Exception as e: + # Rollback via restore if available + if restore_ts is not None: + spark.sql(f"RESTORE TABLE Products TO TIMESTAMP AS OF '{restore_ts}'") + print("Table restored to previous version.") + raise e + + # Show final result + final_df = spark.sql(f"SELECT * FROM Products WHERE ProductID = {product_id}") + display(final_df) \ No newline at end of file diff --git a/sql2dbx/notebooks/pyscripts/conversion_result_clean_helper.py b/sql2dbx/notebooks/pyscripts/conversion_result_clean_helper.py new file mode 100644 index 00000000..e85a33c0 --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/conversion_result_clean_helper.py @@ -0,0 +1,53 @@ +from typing import Callable, List + + +class ConversionResultCleanHelper: + """ + Helper class to clean and extract Python code blocks from text generated by LLMs. + """ + TRIPLE_BACK_QUOTES_WITH_PYTHON = "```python" + TRIPLE_BACK_QUOTES = "```" + + def get_udf_functions(self) -> List[Callable[[str], str]]: + """ + Returns a list of UDF functions for use in PySpark. + """ + return [self.clean_python_code_blocks] + + def clean_python_code_blocks(self, text: str) -> str: + """ + Cleans Python code blocks with two simple rules: + 1. If text starts with ```python and has a closing ```, extract the code between them + 2. 
If ```python appears in the middle of text, simply remove it + + Args: + text: Input text containing Python code blocks + + Returns: + Cleaned code + """ + if text is None: + return None + + # If no ```python exists, return the original text + if self.TRIPLE_BACK_QUOTES_WITH_PYTHON not in text: + return text + + # Check if text starts with ```python (allowing for whitespace at the beginning) + if text.lstrip().startswith(self.TRIPLE_BACK_QUOTES_WITH_PYTHON): + # Case 1: Text starts with ```python - extract code between ```python and ``` + parts = text.split(self.TRIPLE_BACK_QUOTES_WITH_PYTHON, 1) + if len(parts) > 1 and self.TRIPLE_BACK_QUOTES in parts[1]: + code = parts[1].split(self.TRIPLE_BACK_QUOTES, 1)[0] + # Remove leading newline if present + if code.startswith("\n"): + code = code[1:] + # Ensure it ends with a newline + if code and not code.endswith("\n"): + code += "\n" + return code + + # Case 2: ```python appears in the middle - simply remove it + text = text.replace(self.TRIPLE_BACK_QUOTES_WITH_PYTHON, "") + + return text diff --git a/sql2dbx/notebooks/pyscripts/databricks_credentials.py b/sql2dbx/notebooks/pyscripts/databricks_credentials.py new file mode 100644 index 00000000..6af09ee2 --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/databricks_credentials.py @@ -0,0 +1,77 @@ +import os +from typing import Dict + + +class DatabricksCredentials: + """ + Provides access to Databricks host and authentication token from environment variables or dbutils. + + This class handles two authentication approaches: + 1. Environment variables: DATABRICKS_HOST and DATABRICKS_TOKEN + 2. Ephemeral tokens from dbutils (when running in Databricks notebooks) + + Important: Ephemeral tokens have a short lifespan, so this class retrieves fresh values + on each property access to ensure tokens are always valid, even after rotation. + """ + + def get_host_and_token(self) -> Dict[str, str]: + """ + Returns both host URL and authentication token as a dictionary. + + The values are freshly retrieved each time this method is called + to ensure token validity, as Databricks ephemeral tokens might expire quickly. + + Returns: + Dict with 'host' and 'token' keys + + Raises: + RuntimeError: If credentials cannot be retrieved from either source + """ + # First check environment variables + host = os.environ.get("DATABRICKS_HOST") + token = os.environ.get("DATABRICKS_TOKEN") + + if host and token: + return {"host": host, "token": token} + + # Fallback to dbutils if environment variables not found + try: + import IPython + ipython = IPython.get_ipython() + dbutils = ipython.user_ns["dbutils"] + ctx = dbutils.notebook.entry_point.getDbutils().notebook().getContext() + + # Get both values at once from the context + return { + "host": getattr(ctx, "apiUrl")().get(), + "token": getattr(ctx, "apiToken")().get() + } + except Exception as e: + raise RuntimeError( + "Could not retrieve Databricks credentials from environment or dbutils context." + ) from e + + @property + def host(self) -> str: + """ + Returns the current Databricks host URL. + + Note: This retrieves a fresh value on each access. + + Raises: + RuntimeError: If the host cannot be found in environment variables or via dbutils. + """ + return self.get_host_and_token()["host"] + + @property + def token(self) -> str: + """ + Returns the current Databricks authentication token. + + Note: This retrieves a fresh value on each access to ensure the token is valid, + as Databricks ephemeral tokens have a short lifespan. 
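+
+        Example (illustrative only; the header shape assumes standard bearer-token auth
+        and is not taken from the original code):
+            creds = DatabricksCredentials()
+            headers = {"Authorization": f"Bearer {creds.token}"}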
+ + Raises: + RuntimeError: If the token cannot be found in environment variables or via dbutils. + """ + return self.get_host_and_token()["token"] diff --git a/sql2dbx/notebooks/pyscripts/llm_token_count_helper.py b/sql2dbx/notebooks/pyscripts/llm_token_count_helper.py new file mode 100644 index 00000000..d7c1083d --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/llm_token_count_helper.py @@ -0,0 +1,107 @@ +import re +from dataclasses import dataclass +from typing import List, Optional + +from . import token_utils, utils + + +@dataclass +class FileTokenMetadata: + """Data class for storing metadata of a file with token counts.""" + input_file_number: Optional[int] + input_file_path: str + input_file_encoding: str + input_file_content: str + input_file_content_without_sql_comments: Optional[str] + input_file_token_count: int + input_file_token_count_without_sql_comments: Optional[int] + tokenizer_type: str + tokenizer_model: str + + +class FileTokenCountHelper: + def __init__(self, endpoint_name: str = None, tokenizer_type: str = None, tokenizer_model: str = None): + """ + Initialize the FileTokenCounter with endpoint name or explicit tokenizer settings. + + Args: + endpoint_name (str, optional): The name of the endpoint to determine the tokenizer type. + Used to infer tokenizer type and model if not explicitly provided. + tokenizer_type (str, optional): The type of tokenizer to use ('openai' or 'claude'). + If not provided, will be inferred from endpoint_name. + tokenizer_model (str, optional): The specific model to use for tokenization. + If not provided, will be inferred from tokenizer_type or endpoint_name. + """ + self.endpoint_name = endpoint_name + + # Use explicit tokenizer settings if provided + if tokenizer_type: + self.tokenizer_type = tokenizer_type + self.tokenizer_model = tokenizer_model or ('claude' if tokenizer_type == 'claude' else 'o200k_base') + # Otherwise infer from endpoint_name + elif endpoint_name: + self.tokenizer_type, self.tokenizer_model = token_utils.determine_tokenizer_from_endpoint(endpoint_name) + # Default to Claude if neither is provided + else: + self.tokenizer_type = 'claude' + self.tokenizer_model = 'claude' + + # Create the token counter + self.token_counter = token_utils.get_token_counter(self.tokenizer_type, self.tokenizer_model) + + def process_directory(self, input_dir: str, file_encoding: Optional[str] = None, + is_sql: bool = True) -> List[FileTokenMetadata]: + """ + Process all files in a directory and return a list of FileTokenMetadata objects with file details. + + Args: + input_dir (str): The directory containing the files to be processed. + file_encoding (Optional[str]): The encoding to use for reading the files. If not specified, the encoding is automatically detected using chardet.detect. + is_sql (bool): Flag indicating whether the files are SQL files. If True, SQL comments will be removed for token counting. + + Returns: + List[FileTokenMetadata]: A list of metadata objects for each processed file. + """ + results = [] + for i, file_path in enumerate(utils.list_files_recursively(input_dir), start=1): + sql_file_token_metadata = self.process_file( + input_file_path=file_path, input_file_number=i, file_encoding=file_encoding, is_sql=is_sql) + results.append(sql_file_token_metadata) + return results + + def process_file(self, input_file_path: str, input_file_number: Optional[int] = None, + file_encoding: Optional[str] = None, is_sql: bool = True) -> FileTokenMetadata: + """ + Process a file and return its details including token counts. 
+ + Args: + input_file_path (str): The path of the file to be processed. + input_file_number (Optional[int]): The number of the input file. If not provided, it will be generated automatically. + file_encoding (Optional[str]): The encoding to use for reading the file. If not specified, the encoding is automatically detected using chardet.detect. + is_sql (bool): Flag indicating whether the file is a SQL file. If True, SQL comments will be removed for token counting. + + Returns: + FileTokenMetadata: Metadata object containing file details and token counts. + """ + content, input_file_encoding = utils.get_file_content(input_file_path, encoding=file_encoding) + token_count = self.token_counter.count_tokens(content) + + content_without_sql_comments = None + token_count_without_sql_comments = None + + if is_sql: + content_without_sql_comments = utils.remove_sql_comments(content) + content_without_sql_comments = re.sub(r'\s+', ' ', content_without_sql_comments) + token_count_without_sql_comments = self.token_counter.count_tokens(content_without_sql_comments) + + return FileTokenMetadata( + input_file_number=input_file_number, + input_file_path=input_file_path, + input_file_encoding=input_file_encoding, + tokenizer_type=self.tokenizer_type, + tokenizer_model=self.tokenizer_model, + input_file_token_count=token_count, + input_file_token_count_without_sql_comments=token_count_without_sql_comments, + input_file_content=content, + input_file_content_without_sql_comments=content_without_sql_comments + ) diff --git a/sql2dbx/notebooks/pyscripts/notebook_export_helper.py b/sql2dbx/notebooks/pyscripts/notebook_export_helper.py new file mode 100644 index 00000000..7482acfb --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/notebook_export_helper.py @@ -0,0 +1,187 @@ +import base64 +import os +from dataclasses import dataclass +from typing import List, Optional + +from .notebook_i18n import MessageKey, get_language_messages + + +@dataclass +class ExportInput: + """ + Input data for exporting a script to a Databricks notebook. + + Attributes: + input_file_path (str): Path to the source script file + output_dir (str): Directory where the notebook will be saved + code (str): Content of the script to be converted + comment_lang (str): Language code for comments (e.g., "English", "Japanese") + python_parse_error (str, optional): Python syntax error message, if any + sql_parse_error (List[str], optional): List of SQL syntax error messages, if any + """ + input_file_path: str + output_dir: str + code: str + comment_lang: str + python_parse_error: str = None + sql_parse_error: List[str] = None + + +@dataclass +class ExportOutput: + """ + Output data from exporting a script to a Databricks notebook. + + Attributes: + input_file_path (str): Original input file path + output_file_path (str): Path where the notebook was saved + base64_encoded_content (str): Base64 encoded notebook content + base64_encoded_content_size (int): Size of the encoded content + export_succeeded (bool): Whether the export succeeded + export_error (str, optional): Error message if export failed + parse_error_count (int): Number of parse errors found + """ + input_file_path: str + output_file_path: str + base64_encoded_content: str + base64_encoded_content_size: int + export_succeeded: bool = False + export_error: Optional[str] = None + parse_error_count: int = 0 + + +class NotebookExportHelper: + """ + Helper class for exporting scripts to Databricks notebooks. 
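+
+    Example (hypothetical paths, shown only as a usage sketch):
+        helper = NotebookExportHelper()
+        outputs = helper.process_notebooks([
+            ExportInput(
+                input_file_path="/tmp/sql_in/report.sql",
+                output_dir="/tmp/notebooks_out",
+                code="df = spark.sql('SELECT 1')",
+                comment_lang="English",
+            )
+        ])
+        print(outputs[0].output_file_path, outputs[0].export_succeeded)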
+ + This class handles the conversion of script files to Databricks notebook format, + including internationalization support and error reporting. + """ + + def process_notebooks(self, exporter_inputs: List[ExportInput]) -> List[ExportOutput]: + """ + Processes a list of ExportInput objects, generates Databricks notebook content, + and returns a list of ExportOutput objects containing the output file paths + and their base64 encoded content. + + Args: + exporter_inputs (List[ExportInput]): List of input objects with script details + + Returns: + List[ExportOutput]: List of output objects with notebook details + """ + results = [] + unique_paths = self.generate_unique_output_paths(exporter_inputs) + + for exporter_input, output_file_path in zip(exporter_inputs, unique_paths): + # Skip if code is None or empty + if not exporter_input.code: + print(f"Skipping file due to empty code: {exporter_input.input_file_path}") + continue + + # Create Databricks notebook content and encode with base64 + notebook_content = self.create_notebook_content(output_file_path=output_file_path, ex_in=exporter_input) + encoded_content = base64.b64encode(notebook_content.encode('utf-8')).decode('utf-8') + + # Calculate parse error count + parse_error_count = (1 if exporter_input.python_parse_error else 0) + \ + (len(exporter_input.sql_parse_error) if exporter_input.sql_parse_error else 0) + + # Append ExportOutput object to the results list + results.append(ExportOutput( + input_file_path=exporter_input.input_file_path, + output_file_path=output_file_path, + base64_encoded_content=encoded_content, + base64_encoded_content_size=len(encoded_content), + parse_error_count=parse_error_count, + export_succeeded=True # Set to True as we've successfully created the content + )) + + return results + + def create_notebook_content(self, output_file_path: str, ex_in: ExportInput) -> str: + """ + Creates the content for a Databricks notebook. 
+ + This function generates the notebook content including: + - A header with the notebook name and source script information + - The original code from the input script + - Any Python or SQL syntax errors, if present + - A message indicating no errors if both Python and SQL parse errors are absent + + Args: + output_file_path (str): The path where the notebook will be saved + ex_in (ExportInput): An object containing the input script details and any syntax errors + + Returns: + str: The complete content of the Databricks notebook + """ + notebook_name = os.path.basename(output_file_path) + messages = get_language_messages(ex_in.comment_lang) + notebook_content = ( + f"# Databricks notebook source\n" + f"# MAGIC %md\n" + f"# MAGIC # {notebook_name}\n" + f"# MAGIC {messages[MessageKey.NOTEBOOK_DESCRIPTION]}\n" + f"# MAGIC \n" + f"# MAGIC {messages[MessageKey.SOURCE_SCRIPT]}: `{ex_in.input_file_path}`\n" + f"# COMMAND ----------\n" + f"{ex_in.code}\n" + f"# COMMAND ----------\n" + f"# MAGIC %md\n" + f"# MAGIC ## {messages[MessageKey.SYNTAX_CHECK_RESULTS]}\n" + ) + if ex_in.python_parse_error or ex_in.sql_parse_error: + notebook_content += ( + f"# MAGIC {messages[MessageKey.ERRORS_FROM_CHECKS]}\n" + ) + if ex_in.python_parse_error: + notebook_content += ( + f"# MAGIC ### {messages[MessageKey.PYTHON_SYNTAX_ERRORS]}\n" + f"# MAGIC ```\n" + f"# MAGIC {ex_in.python_parse_error}\n" + f"# MAGIC ```\n" + ) + if ex_in.sql_parse_error: + notebook_content += ( + f"# MAGIC ### {messages[MessageKey.SPARK_SQL_SYNTAX_ERRORS]}\n" + f"# MAGIC ```\n" + ) + for error in ex_in.sql_parse_error: + notebook_content += f"# MAGIC {error}\n" + notebook_content += f"# MAGIC ```\n" + else: + notebook_content += ( + f"# MAGIC {messages[MessageKey.NO_ERRORS_DETECTED]}\n" + f"# MAGIC {messages[MessageKey.REVIEW_CODE]}\n" + ) + return notebook_content + + def generate_unique_output_paths(self, exporter_inputs: List[ExportInput]) -> List[str]: + """ + Generates unique output file paths by appending a number if the file already exists. + The number is incremented from 1 until a unique path is found. 
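+
+        Example (illustrative, hypothetical paths): two inputs that both resolve to the
+        base name "etl_job" with output_dir "/tmp/out" are mapped to "/tmp/out/etl_job"
+        and "/tmp/out/etl_job_1" respectively.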
+ + Args: + exporter_inputs (List[ExportInput]): List of export input objects + + Returns: + List[str]: List of unique output file paths + """ + seen_paths = set() + unique_paths = [] + + for ex_in in exporter_inputs: + base_name = os.path.basename(ex_in.input_file_path) + name, _ = os.path.splitext(base_name) + output_file_path = os.path.join(ex_in.output_dir, name) + counter = 1 + + while output_file_path in seen_paths: + output_file_path = os.path.join(ex_in.output_dir, f"{name}_{counter}") + counter += 1 + + seen_paths.add(output_file_path) + unique_paths.append(output_file_path) + + return unique_paths diff --git a/sql2dbx/notebooks/pyscripts/notebook_i18n.py b/sql2dbx/notebooks/pyscripts/notebook_i18n.py new file mode 100644 index 00000000..dddfdc75 --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/notebook_i18n.py @@ -0,0 +1,233 @@ +from enum import Enum +from typing import Dict + + +class MessageKey(Enum): + """Enum representing all possible message keys used in the application.""" + NOTEBOOK_DESCRIPTION = "notebook_description" + SOURCE_SCRIPT = "source_script" + SYNTAX_CHECK_RESULTS = "syntax_check_results" + ERRORS_FROM_CHECKS = "errors_from_checks" + PYTHON_SYNTAX_ERRORS = "python_syntax_errors" + SPARK_SQL_SYNTAX_ERRORS = "spark_sql_syntax_errors" + NO_ERRORS_DETECTED = "no_errors_detected" + REVIEW_CODE = "review_code" + + +# English messages +MESSAGES_EN = { + MessageKey.NOTEBOOK_DESCRIPTION: ( + "This notebook was automatically converted from the script below. " + "It may contain errors, so use it as a starting point and make necessary corrections." + ), + MessageKey.SOURCE_SCRIPT: "Source script", + MessageKey.SYNTAX_CHECK_RESULTS: "Static Syntax Check Results", + MessageKey.ERRORS_FROM_CHECKS: ( + "These are errors from static syntax checks. Manual corrections are required for these errors." + ), + MessageKey.PYTHON_SYNTAX_ERRORS: "Python Syntax Errors", + MessageKey.SPARK_SQL_SYNTAX_ERRORS: "Spark SQL Syntax Errors", + MessageKey.NO_ERRORS_DETECTED: "No syntax errors were detected during the static check.", + MessageKey.REVIEW_CODE: ( + "However, please review the code carefully as some issues may only be detected during runtime." + ), +} + +# Japanese messages +MESSAGES_JA = { + MessageKey.NOTEBOOK_DESCRIPTION: ( + "このノートブックは以下のスクリプトから自動的に変換されました。" + "エラーが含まれている可能性があるため、出発点として使用し、必要な修正を行ってください。" + ), + MessageKey.SOURCE_SCRIPT: "ソーススクリプト", + MessageKey.SYNTAX_CHECK_RESULTS: "静的構文チェック結果", + MessageKey.ERRORS_FROM_CHECKS: ( + "以下は静的構文チェックの結果です。エラーがある場合、手動での修正が必要です。" + ), + MessageKey.PYTHON_SYNTAX_ERRORS: "Python構文エラー", + MessageKey.SPARK_SQL_SYNTAX_ERRORS: "Spark SQL構文エラー", + MessageKey.NO_ERRORS_DETECTED: "静的チェック中に構文エラーは検出されませんでした。", + MessageKey.REVIEW_CODE: ( + "ただし、一部の問題は実行時にのみ検出される可能性があるため、" + "コードを注意深く確認してください。" + ), +} + +# Chinese messages +MESSAGES_ZH = { + MessageKey.NOTEBOOK_DESCRIPTION: ( + "此笔记本是从以下脚本自动转换而来。它可能包含错误,请将其作为起点并进行必要的修正。" + ), + MessageKey.SOURCE_SCRIPT: "源脚本", + MessageKey.SYNTAX_CHECK_RESULTS: "静态语法检查结果", + MessageKey.ERRORS_FROM_CHECKS: ( + "这些是静态语法检查中发现的错误。这些错误需要手动修正。" + ), + MessageKey.PYTHON_SYNTAX_ERRORS: "Python语法错误", + MessageKey.SPARK_SQL_SYNTAX_ERRORS: "Spark SQL语法错误", + MessageKey.NO_ERRORS_DETECTED: "在静态检查中未检测到语法错误。", + MessageKey.REVIEW_CODE: ( + "但是,请仔细检查代码,因为某些问题可能只有在运行时才能检测到。" + ), +} + +# French messages +MESSAGES_FR = { + MessageKey.NOTEBOOK_DESCRIPTION: ( + "Ce notebook a été automatiquement converti à partir du script ci-dessous. 
" + "Il peut contenir des erreurs, utilisez-le comme point de départ et apportez les corrections nécessaires." + ), + MessageKey.SOURCE_SCRIPT: "Script source", + MessageKey.SYNTAX_CHECK_RESULTS: "Résultats de la vérification syntaxique statique", + MessageKey.ERRORS_FROM_CHECKS: ( + "Voici les erreurs détectées lors des vérifications syntaxiques statiques. " + "Des corrections manuelles sont nécessaires pour ces erreurs." + ), + MessageKey.PYTHON_SYNTAX_ERRORS: "Erreurs de syntaxe Python", + MessageKey.SPARK_SQL_SYNTAX_ERRORS: "Erreurs de syntaxe Spark SQL", + MessageKey.NO_ERRORS_DETECTED: "Aucune erreur de syntaxe n'a été détectée lors de la vérification statique.", + MessageKey.REVIEW_CODE: ( + "Cependant, veuillez examiner attentivement le code car certains problèmes " + "ne peuvent être détectés que lors de l'exécution." + ), +} + +# German messages +MESSAGES_DE = { + MessageKey.NOTEBOOK_DESCRIPTION: ( + "Dieses Notebook wurde automatisch aus dem unten stehenden Skript konvertiert. " + "Es kann Fehler enthalten, verwenden Sie es als Ausgangspunkt und nehmen Sie die notwendigen Korrekturen vor." + ), + MessageKey.SOURCE_SCRIPT: "Quellskript", + MessageKey.SYNTAX_CHECK_RESULTS: "Ergebnisse der statischen Syntaxprüfung", + MessageKey.ERRORS_FROM_CHECKS: ( + "Dies sind Fehler aus statischen Syntaxprüfungen. " + "Für diese Fehler sind manuelle Korrekturen erforderlich." + ), + MessageKey.PYTHON_SYNTAX_ERRORS: "Python-Syntaxfehler", + MessageKey.SPARK_SQL_SYNTAX_ERRORS: "Spark SQL-Syntaxfehler", + MessageKey.NO_ERRORS_DETECTED: "Bei der statischen Prüfung wurden keine Syntaxfehler festgestellt.", + MessageKey.REVIEW_CODE: ( + "Überprüfen Sie den Code jedoch sorgfältig, da einige Probleme " + "möglicherweise erst zur Laufzeit erkannt werden." + ), +} + +# Italian messages +MESSAGES_IT = { + MessageKey.NOTEBOOK_DESCRIPTION: ( + "Questo notebook è stato convertito automaticamente dallo script sottostante. " + "Potrebbe contenere errori, quindi usalo come punto di partenza e apporta le correzioni necessarie." + ), + MessageKey.SOURCE_SCRIPT: "Script sorgente", + MessageKey.SYNTAX_CHECK_RESULTS: "Risultati del controllo sintattico statico", + MessageKey.ERRORS_FROM_CHECKS: ( + "Questi sono errori derivanti dai controlli sintattici statici. " + "Sono necessarie correzioni manuali per questi errori." + ), + MessageKey.PYTHON_SYNTAX_ERRORS: "Errori di sintassi Python", + MessageKey.SPARK_SQL_SYNTAX_ERRORS: "Errori di sintassi Spark SQL", + MessageKey.NO_ERRORS_DETECTED: "Non sono stati rilevati errori di sintassi durante il controllo statico.", + MessageKey.REVIEW_CODE: ( + "Tuttavia, si prega di rivedere attentamente il codice poiché alcuni problemi " + "potrebbero essere rilevati solo durante l'esecuzione." + ), +} + +# Korean messages +MESSAGES_KO = { + MessageKey.NOTEBOOK_DESCRIPTION: ( + "이 노트북은 아래 스크립트에서 자동으로 변환되었습니다. " + "오류가 포함되어 있을 수 있으므로 시작점으로 사용하고 필요한 수정을 하십시오." + ), + MessageKey.SOURCE_SCRIPT: "소스 스크립트", + MessageKey.SYNTAX_CHECK_RESULTS: "정적 구문 검사 결과", + MessageKey.ERRORS_FROM_CHECKS: ( + "이것들은 정적 구문 검사에서 발견된 오류입니다. 이러한 오류에는 수동 수정이 필요합니다." + ), + MessageKey.PYTHON_SYNTAX_ERRORS: "Python 구문 오류", + MessageKey.SPARK_SQL_SYNTAX_ERRORS: "Spark SQL 구문 오류", + MessageKey.NO_ERRORS_DETECTED: "정적 검사 중 구문 오류가 감지되지 않았습니다.", + MessageKey.REVIEW_CODE: ( + "그러나 일부 문제는 런타임에만 감지될 수 있으므로 코드를 주의 깊게 검토하십시오." + ), +} + +# Portuguese messages +MESSAGES_PT = { + MessageKey.NOTEBOOK_DESCRIPTION: ( + "Este notebook foi convertido automaticamente do script abaixo. 
" + "Pode conter erros, então use-o como ponto de partida e faça as correções necessárias." + ), + MessageKey.SOURCE_SCRIPT: "Script fonte", + MessageKey.SYNTAX_CHECK_RESULTS: "Resultados da verificação de sintaxe estática", + MessageKey.ERRORS_FROM_CHECKS: ( + "Estes são erros de verificações de sintaxe estática. " + "Correções manuais são necessárias para esses erros." + ), + MessageKey.PYTHON_SYNTAX_ERRORS: "Erros de sintaxe Python", + MessageKey.SPARK_SQL_SYNTAX_ERRORS: "Erros de sintaxe Spark SQL", + MessageKey.NO_ERRORS_DETECTED: "Nenhum erro de sintaxe foi detectado durante a verificação estática.", + MessageKey.REVIEW_CODE: ( + "No entanto, revise o código cuidadosamente, pois alguns problemas " + "podem ser detectados apenas durante a execução." + ), +} + +# Spanish messages +MESSAGES_ES = { + MessageKey.NOTEBOOK_DESCRIPTION: ( + "Este notebook se convirtió automáticamente del script a continuación. " + "Puede contener errores, así que úselo como punto de partida y realice las correcciones necesarias." + ), + MessageKey.SOURCE_SCRIPT: "Script fuente", + MessageKey.SYNTAX_CHECK_RESULTS: "Resultados de la comprobación de sintaxis estática", + MessageKey.ERRORS_FROM_CHECKS: ( + "Estos son errores de las comprobaciones de sintaxis estáticas. " + "Se requieren correcciones manuales para estos errores." + ), + MessageKey.PYTHON_SYNTAX_ERRORS: "Errores de sintaxis de Python", + MessageKey.SPARK_SQL_SYNTAX_ERRORS: "Errores de sintaxis de Spark SQL", + MessageKey.NO_ERRORS_DETECTED: "No se detectaron errores de sintaxis durante la comprobación estática.", + MessageKey.REVIEW_CODE: ( + "Sin embargo, revise cuidadosamente el código, ya que algunos problemas " + "solo pueden detectarse durante la ejecución." + ), +} + +# Map of language codes to message dictionaries +LANGUAGE_MAP = { + "English": MESSAGES_EN, + "Japanese": MESSAGES_JA, + "Chinese": MESSAGES_ZH, + "French": MESSAGES_FR, + "German": MESSAGES_DE, + "Italian": MESSAGES_IT, + "Korean": MESSAGES_KO, + "Portuguese": MESSAGES_PT, + "Spanish": MESSAGES_ES, +} + + +def get_supported_languages() -> list[str]: + """ + Returns a list of supported language codes. + + Returns: + list[str]: List of supported language codes (e.g., ["English", "Japanese"]) + """ + return list(LANGUAGE_MAP.keys()) + + +def get_language_messages(lang_code: str) -> Dict[MessageKey, str]: + """ + Returns the message dictionary for the specified language code. + Falls back to English if the language code is not supported. + + Args: + lang_code (str): The language code to get messages for + + Returns: + Dict[MessageKey, str]: Dictionary mapping message keys to translated strings + """ + return LANGUAGE_MAP.get(lang_code, MESSAGES_EN) diff --git a/sql2dbx/notebooks/pyscripts/spark_sql_extract_helper.py b/sql2dbx/notebooks/pyscripts/spark_sql_extract_helper.py new file mode 100644 index 00000000..078d82b9 --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/spark_sql_extract_helper.py @@ -0,0 +1,167 @@ +import ast +from typing import List, Optional, Tuple, Union + + +class SparkSQLExtractHelper: + """A class to extract Spark SQL statements from a given Python function string.""" + + def __init__(self) -> None: + self.sql_statements: List[str] = [] + self.variables: dict = {} + + def extract_sql_from_string(self, func_string: str) -> Tuple[Optional[str], List[str]]: + """ + Parses a Python function string and extracts Spark SQL statements. 
+ + This method handles various input types and edge cases: + - When func_string is None, returns (None, []) to maintain consistency + with the success case return pattern + - For non-string inputs or parsing errors, returns an error message and empty list + + Args: + func_string: The Python function as a string. Can be None. + + Returns: + Tuple[Optional[str], List[str]]: A tuple containing: + - None if no error occurred, or an error message string + - A list of extracted SQL statements (empty if none found or if an error occurred) + """ + if func_string is None: + return None, [] + + try: + tree = ast.parse(func_string) + self.sql_statements = [] + self.variables = {} + self.visit(tree) + cleaned_statements = [self.clean_sql(sql) for sql in self.sql_statements] + return None, cleaned_statements + except SyntaxError as e: + return f"Python syntax error: {str(e)}", [] + except TypeError as e: + return f"Type error during parsing: {str(e)}", [] + except Exception as e: + return f"Unexpected error: {str(e)}", [] + + def clean_sql(self, sql: str) -> str: + """ + Cleans the extracted SQL statement by removing newline characters, trimming spaces, + and removing curly braces. + + Args: + sql (str): The extracted SQL statement. + + Returns: + str: The cleaned SQL statement. + """ + cleaned_sql = sql.replace('\n', ' ').replace('\r', ' ').strip() + cleaned_sql = cleaned_sql.replace('{', '').replace('}', '') + return cleaned_sql + + def visit(self, node: ast.AST) -> None: + """ + Visits each node in the AST tree and processes relevant nodes. + + Args: + node (ast.AST): The root node of the AST tree. + """ + for child in ast.walk(node): + if isinstance(child, ast.Assign): + self.visit_assign(child) + elif isinstance(child, ast.Call): + self.visit_call(child) + + def visit_assign(self, node: ast.Assign) -> None: + """ + Processes assignment nodes to extract variable values. + + Args: + node (ast.Assign): An assignment node in the AST tree. + """ + if isinstance(node.targets[0], ast.Name): + value = self.extract_value(node.value) + if value: + self.variables[node.targets[0].id] = value + + def extract_value(self, node: ast.AST) -> Optional[str]: + """ + Extracts the value from a given AST node. + + Args: + node (ast.AST): An AST node. + + Returns: + Optional[str]: The extracted value as a string, or None if the value could not be extracted. + """ + if isinstance(node, ast.Constant): + return str(node.value) + elif isinstance(node, ast.JoinedStr): + return self.process_f_string(node) + elif isinstance(node, ast.Name): + return self.variables.get(node.id, "_placeholder_") + elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.Add): + left = self.extract_value(node.left) + right = self.extract_value(node.right) + if left and right: + return left + right + return None + + def process_f_string(self, node: ast.JoinedStr) -> str: + """ + Processes an f-string node to extract its value. + + Args: + node (ast.JoinedStr): An f-string node in the AST tree. + + Returns: + str: The extracted f-string value. + """ + return ''.join(self.process_f_string_part(v) for v in node.values) + + def process_f_string_part(self, part: Union[ast.Constant, ast.FormattedValue]) -> str: + """ + Processes a part of an f-string node. + + Args: + part (Union[ast.Constant, ast.FormattedValue]): A part of an f-string node. + + Returns: + str: The extracted value as a string. 
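+
+        Example (illustrative): for the f-string f"SELECT * FROM {table}" where `table`
+        is not a tracked constant, the formatted part is rendered as "_placeholder_",
+        so the extracted SQL becomes "SELECT * FROM _placeholder_".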
+ """ + if isinstance(part, ast.Constant): + return str(part.value) + elif isinstance(part, ast.FormattedValue): + value = self.extract_value(part.value) + if value and isinstance(part.value, ast.Name): + return str(self.variables.get(part.value.id, "_placeholder_")) + return f"{value}" if value else "_placeholder_" + else: + return str(part) + + def visit_call(self, node: ast.Call) -> None: + """ + Processes call nodes to extract Spark SQL statements. + + Args: + node (ast.Call): A call node in the AST tree. + """ + if isinstance(node.func, ast.Attribute) and node.func.attr == 'sql' and isinstance(node.func.value, ast.Name) and node.func.value.id == 'spark': + if len(node.args) == 1: + arg = node.args[0] + sql = self.extract_value(arg) + if sql: + self.sql_statements.append(sql) + + +# Usage example +if __name__ == "__main__": + helper = SparkSQLExtractHelper() + error, sql_statements = helper.extract_sql_from_string(""" + def example_func(): + query = "SELECT * FROM {table}" + spark.sql(query) + """) + if error: + print("Error:", error) + else: + print("SQL Statements:", sql_statements) diff --git a/sql2dbx/notebooks/pyscripts/token_utils.py b/sql2dbx/notebooks/pyscripts/token_utils.py new file mode 100644 index 00000000..f4dbf27a --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/token_utils.py @@ -0,0 +1,134 @@ +from abc import ABC, abstractmethod +from enum import Enum +from typing import Dict, Tuple, Type + +import tiktoken + + +class TokenizerType(str, Enum): + """Enum representing different tokenizer types""" + OPENAI = "openai" + CLAUDE = "claude" + + +class BaseTokenCounter(ABC): + """Abstract base class for token counters""" + + @abstractmethod + def count_tokens(self, string: str) -> int: + """Returns the number of tokens in a text string.""" + pass + + +class OpenAITokenCounter(BaseTokenCounter): + """Token counter for OpenAI models using tiktoken""" + + def __init__(self, token_encoding_name: str = "o200k_base"): + """Initialize the TokenCounter with a specified token encoding.""" + self.encoding = tiktoken.get_encoding(token_encoding_name) + self.token_encoding_name = token_encoding_name + + def count_tokens(self, string: str) -> int: + """Returns the number of tokens in a text string.""" + return len(self.encoding.encode(string)) + + +class ClaudeTokenCounter(BaseTokenCounter): + """Token counter for Claude models based on character to token ratio""" + + def __init__(self, model: str = "claude"): + """ + Initialize the Claude token counter. + + Args: + model (str): The Claude model name (currently not used as the ratio + appears to be the same across Claude models, but included + for future compatibility) + """ + self.model = model + + def count_tokens(self, string: str) -> int: + """ + Estimate the number of tokens in the given text for Claude models. + + This estimation is based on Anthropic's documentation, which states that + approximately 200K tokens correspond to 680K Unicode characters. This implies + an average of about 3.4 characters per token for Claude models. + + Reference: + - https://docs.anthropic.com/en/docs/about-claude/models/all-models + + Args: + string (str): The input text for which to estimate the token count. + + Returns: + int: The estimated number of tokens in the input text. 
+ """ + CLAUDE_CHAR_TO_TOKEN_RATIO = 3.4 # Average characters per token for Claude models + char_count = len(string) + estimated_tokens = char_count / CLAUDE_CHAR_TO_TOKEN_RATIO + return int(estimated_tokens) + + +# Dictionary mapping tokenizer types to their counter classes +TOKEN_COUNTER_CLASSES: Dict[TokenizerType, Type[BaseTokenCounter]] = { + TokenizerType.OPENAI: OpenAITokenCounter, + TokenizerType.CLAUDE: ClaudeTokenCounter, +} + + +def get_token_counter(tokenizer_type: str = "claude", model: str = "claude") -> BaseTokenCounter: + """ + Factory function to get the appropriate token counter for the specified tokenizer type. + + Args: + tokenizer_type (str): Type of tokenizer to use ('openai' or 'claude') + model (str): Model name or encoding to use (for OpenAI this is the encoding name, + for Claude this is currently not used but reserved for future compatibility) + + Returns: + BaseTokenCounter: An instance of the appropriate token counter + """ + try: + tokenizer_enum = TokenizerType(tokenizer_type.lower()) + if tokenizer_enum == TokenizerType.CLAUDE: + return ClaudeTokenCounter() # model parameter is ignored for Claude currently + counter_class = TOKEN_COUNTER_CLASSES[tokenizer_enum] + return counter_class(model) + except (KeyError, ValueError): + raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}. " + f"Supported types are: {', '.join([t.value for t in TokenizerType])}") + + +def determine_tokenizer_from_endpoint(endpoint_name: str) -> Tuple[str, str]: + """ + Determine the tokenizer type and model based on endpoint name. + + Args: + endpoint_name (str): The endpoint name to determine which tokenizer to use. + If 'claude' is in the name, Claude tokenizer is used, + otherwise OpenAI tokenizer is used. + + Returns: + Tuple[str, str]: A tuple containing (tokenizer_type, model) + """ + if 'claude' in endpoint_name.lower(): + return 'claude', 'claude' + else: + return 'openai', 'o200k_base' + + +def get_token_counter_from_endpoint(endpoint_name: str) -> BaseTokenCounter: + """ + Get a token counter based on the endpoint name. + + Args: + endpoint_name (str): The endpoint name to determine which tokenizer to use. + If 'claude' is in the name, Claude tokenizer is used, + otherwise OpenAI tokenizer is used. + + Returns: + BaseTokenCounter: An instance of the appropriate token counter + """ + tokenizer_type, model = determine_tokenizer_from_endpoint(endpoint_name) + return get_token_counter(tokenizer_type, model) diff --git a/sql2dbx/notebooks/pyscripts/utils.py b/sql2dbx/notebooks/pyscripts/utils.py new file mode 100644 index 00000000..e5ed6f1e --- /dev/null +++ b/sql2dbx/notebooks/pyscripts/utils.py @@ -0,0 +1,98 @@ +import logging +import os +import re +import sys +from typing import Optional, Tuple + +import chardet + + +def setup_logger(name, level=logging.INFO): + """Function to setup a logger that outputs to stdout""" + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + logger = logging.getLogger(name) + logger.setLevel(level) + if logger.hasHandlers(): + logger.handlers.clear() + logger.addHandler(handler) + return logger + + +def list_files_recursively(input_dir: str) -> list[str]: + """ + Recursively list all files in the specified directory. + + Args: + input_dir (str): The directory to search for files. + + Returns: + list[str]: A list of file paths. 
+ """ + all_files = [] + for root, _, files in os.walk(input_dir): + for file in files: + file_path = os.path.join(root, file) + if os.path.isfile(file_path): + all_files.append(file_path) + return all_files + + +def get_file_content(input_file_path: str, encoding: Optional[str] = None) -> Tuple[str, str]: + """ + Returns the content of a specified file as a string along with its encoding. + + Args: + input_file_path (str): The path of the file to read. + encoding (Optional[str]): The encoding to use for reading the file. If not specified, chardet.detect is used. + + Returns: + Tuple[str, str]: A tuple containing the file content and its encoding. + """ + with open(input_file_path, 'rb') as file: + raw_data = file.read() + if encoding is None: + result = chardet.detect(raw_data) + encoding = result['encoding'] or 'utf-8' # Use 'utf-8' if encoding detection fails + content = raw_data.decode(encoding, errors='replace') + return content, encoding + + +def remove_sql_comments(sql_text: str) -> str: + """ + Removes both line and block comments from SQL text. + + Args: + sql_text (str): The SQL text to clean. + + Returns: + str: The SQL text without comments. + """ + no_line_comments = re.sub(r'--.*', '', sql_text) + no_comments = re.sub(r'/\*.*?\*/', '', no_line_comments, flags=re.DOTALL) + no_comments = re.sub(r'\n\s*\n', '\n\n', no_comments) # Remove multiple empty lines + return no_comments + + +def parse_number_ranges(input_string: str) -> list[int]: + """Parses a comma-separated string into a list of integers. + The string can contain single integers or hyphen-separated ranges (e.g., "5-8"). + + Args: + input_string: The string containing comma-separated integers or ranges. + + Returns: + A list containing all integers found in input_string. + """ + result_numbers = [] + if input_string: # Process only if input is not empty + for part in input_string.split(','): + part = part.strip() # Remove extra whitespace + if '-' in part: # Range detected + start_str, end_str = part.split('-') + start, end = int(start_str), int(end_str) + result_numbers.extend(range(start, end + 1)) # Add range of numbers + else: + result_numbers.append(int(part)) # Add single number + return result_numbers diff --git a/sql2dbx/notebooks/requirements.txt b/sql2dbx/notebooks/requirements.txt new file mode 100644 index 00000000..4de68e9a --- /dev/null +++ b/sql2dbx/notebooks/requirements.txt @@ -0,0 +1,8 @@ +anytree==2.13.0 +chardet==5.2.0 +databricks-sdk>=0.47.0 +httpx==0.28.1 +omegaconf==2.3.0 +PyYAML==6.0.2 +tenacity==9.1.2 +tiktoken==0.9.0 \ No newline at end of file diff --git a/sql2dbx/resources/sql2dbx_job_main.yml b/sql2dbx/resources/sql2dbx_job_main.yml new file mode 100644 index 00000000..64450f76 --- /dev/null +++ b/sql2dbx/resources/sql2dbx_job_main.yml @@ -0,0 +1,45 @@ +resources: + jobs: + sql2dbx_job_main: + name: sql2dbx_job_main + max_concurrent_runs: 1000 + tasks: + - task_key: 00_main + notebook_task: + notebook_path: ../notebooks/00_main.py + source: WORKSPACE + queue: + enabled: true + parameters: + # Params for 01_analyze_input_files + - name: input_dir + default: "" + - name: result_catalog + default: "" + - name: result_schema + default: "" + - name: token_count_threshold + default: "20000" + - name: existing_result_table + default: "" + # Params for 02_convert_sql_to_databricks + - name: endpoint_name + default: "databricks-claude-3-7-sonnet" + - name: sql_dialect + default: "tsql" + - name: conversion_prompt_yaml + default: "" + - name: comment_lang + default: "English" + - name: concurrency + 
default: "4" + - name: request_params + default: "" + - name: log_level + default: "DEBUG" + # Params for 03_convert_sql_to_databricks + - name: max_fix_attempts + default: "1" + # Params for 04_export_to_databricks_notebooks + - name: output_dir + default: "" diff --git a/sql2dbx/tests/__init__.py b/sql2dbx/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sql2dbx/tests/pyscripts/__init__.py b/sql2dbx/tests/pyscripts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sql2dbx/tests/pyscripts/test_cell_split_helper.py b/sql2dbx/tests/pyscripts/test_cell_split_helper.py new file mode 100644 index 00000000..0449d556 --- /dev/null +++ b/sql2dbx/tests/pyscripts/test_cell_split_helper.py @@ -0,0 +1,147 @@ +import unittest +import textwrap + +from notebooks.pyscripts.cell_split_helper import CellSplitHelper + + +class TestCellSplitHelperTopLevel(unittest.TestCase): + def setUp(self): + self.helper = CellSplitHelper() + + def test_empty_code(self): + """ + Test case for empty code input. Expected to return None. + """ + code = "" + result = self.helper.split_cells(code) + self.assertIsNone(result) + + def test_none_input(self): + """ + Test case for None input. Expected to return None. + """ + code = None + result = self.helper.split_cells(code) + self.assertIsNone(result) + + def test_syntax_error_code(self): + """ + Test case for code with syntax error. + Checks how it's handled - possibly returning "COMMAND marker + code" on parse failure. + """ + code = "if True:\n" # Unclosed block + result = self.helper.split_cells(code) + # Verify that the result is not None and contains COMMAND_MARKER + self.assertIsNotNone(result) + self.assertIn(self.helper.COMMAND_MARKER, result) + self.assertIn("if True:", result) + + def test_basic_single_statement(self): + """ + Test case for a single top-level statement. + Example: Assign + print + """ + code = textwrap.dedent("""\ + var1 = 1 + print(var1) + """) + result = self.helper.split_cells(code) + lines = result.split("\n") + self.assertEqual(lines[0], self.helper.COMMAND_MARKER) + self.assertEqual(lines[1], "var1 = 1") + self.assertEqual(lines[2], "print(var1)") + + def test_code_with_leading_comments(self): + """ + Test case for code with leading comments. + Verify that comments are included in the same block. + """ + code = textwrap.dedent("""\ + # leading comment + var1 = 1 + var2 = 2 + """) + result = self.helper.split_cells(code) + lines = result.split("\n") + self.assertEqual(lines[0], self.helper.COMMAND_MARKER) + self.assertEqual(lines[1], "# leading comment") + self.assertEqual(lines[2], "var1 = 1") + self.assertEqual(lines[3], "var2 = 2") + + def test_top_level_splitting_example(self): + """ + Test case for a larger sample to verify expected cell splitting. 
+ """ + code = textwrap.dedent("""\ + # Some initial vars + var1 = 1 + var2 = 2 + + # If block + if var1 < 10: + var2 = spark.sql("SELECT * FROM table WHERE col < 10") + display(var2) + + # Widgets + dbutils.widgets.text("paramA", "valA") + dbutils.widgets.text("paramB", "valB") + + # Another top-level statement + var3 = var1 + var2.count() + + # Try-except block + try: + var4 = spark.sql("SELECT 100") + except: + pass + + # For block + for i in range(3): + print(i) + + # Finally exit + dbutils.notebook.exit("Done") + """) + expected = textwrap.dedent(f"""\ + {self.helper.COMMAND_MARKER} + # Some initial vars + var1 = 1 + var2 = 2 + + {self.helper.COMMAND_MARKER} + # If block + if var1 < 10: + var2 = spark.sql("SELECT * FROM table WHERE col < 10") + display(var2) + + {self.helper.COMMAND_MARKER} + # Widgets + dbutils.widgets.text("paramA", "valA") + dbutils.widgets.text("paramB", "valB") + + {self.helper.COMMAND_MARKER} + # Another top-level statement + var3 = var1 + var2.count() + + {self.helper.COMMAND_MARKER} + # Try-except block + try: + var4 = spark.sql("SELECT 100") + except: + pass + + {self.helper.COMMAND_MARKER} + # For block + for i in range(3): + print(i) + + {self.helper.COMMAND_MARKER} + # Finally exit + dbutils.notebook.exit("Done") + """) + result = self.helper.split_cells(code) + self.assertEqual(result, expected) + + +if __name__ == '__main__': + unittest.main() diff --git a/sql2dbx/tests/pyscripts/test_conversion_prompt_helper.py b/sql2dbx/tests/pyscripts/test_conversion_prompt_helper.py new file mode 100644 index 00000000..7bfef121 --- /dev/null +++ b/sql2dbx/tests/pyscripts/test_conversion_prompt_helper.py @@ -0,0 +1,174 @@ +import os +import tempfile +import unittest +from pathlib import Path + +import yaml + +from notebooks.pyscripts.conversion_prompt_helper import ( + ConversionPromptHelper, PromptConfig, SupportedSQLDialect) + + +class TestSupportedSQLDialect(unittest.TestCase): + def test_enum_values(self): + """Test that the SupportedSQLDialect enum contains the expected values.""" + self.assertEqual(SupportedSQLDialect.POSTGRESQL.name, "POSTGRESQL") + self.assertEqual(SupportedSQLDialect.SNOWFLAKE.name, "SNOWFLAKE") + self.assertEqual(SupportedSQLDialect.TSQL.name, "TSQL") + + # Test the dialect_name property + self.assertEqual(SupportedSQLDialect.POSTGRESQL.dialect_name, "postgresql") + self.assertEqual(SupportedSQLDialect.SNOWFLAKE.dialect_name, "snowflake") + self.assertEqual(SupportedSQLDialect.TSQL.dialect_name, "tsql") + + # Test the default_yaml_filename property + self.assertEqual(SupportedSQLDialect.POSTGRESQL.default_yaml_filename, + "postgresql_to_databricks_notebook.yml") + self.assertEqual(SupportedSQLDialect.SNOWFLAKE.default_yaml_filename, + "snowflake_to_databricks_notebook.yml") + self.assertEqual(SupportedSQLDialect.TSQL.default_yaml_filename, + "tsql_to_databricks_notebook.yml") + + def test_get_supported_sql_dialects(self): + """Test that get_supported_sql_dialects returns all expected dialects.""" + dialects = ConversionPromptHelper.get_supported_sql_dialects() + + # Check that all expected dialects are present + self.assertIn("mysql", dialects) + self.assertIn("netezza", dialects) + self.assertIn("oracle", dialects) + self.assertIn("postgresql", dialects) + self.assertIn("redshift", dialects) + self.assertIn("snowflake", dialects) + self.assertIn("tsql", dialects) + self.assertIn("teradata", dialects) + + # Check the total count to ensure no extras or missing ones + self.assertEqual(len(dialects), 8) + + def 
test_get_default_yaml_for_sql_dialect(self): + """Test that get_default_yaml_for_sql_dialect returns the correct paths.""" + # Test for postgresql + postgresql_path = ConversionPromptHelper.get_default_yaml_for_sql_dialect("postgresql") + self.assertIsInstance(postgresql_path, str) + self.assertTrue(postgresql_path.endswith("postgresql_to_databricks_notebook.yml")) + + # Test for snowflake + snowflake_path = ConversionPromptHelper.get_default_yaml_for_sql_dialect("snowflake") + self.assertIsInstance(snowflake_path, str) + self.assertTrue(snowflake_path.endswith("snowflake_to_databricks_notebook.yml")) + + # Test for tsql + tsql_path = ConversionPromptHelper.get_default_yaml_for_sql_dialect("tsql") + self.assertIsInstance(tsql_path, str) + self.assertTrue(tsql_path.endswith("tsql_to_databricks_notebook.yml")) + + # Test with invalid dialect + with self.assertRaises(ValueError) as context: + ConversionPromptHelper.get_default_yaml_for_sql_dialect("invalid_dialect") + self.assertIn("Unsupported sql dialect", str(context.exception)) + + def test_path_resolution(self): + """Test that the path resolution for YAML files works correctly.""" + for dialect in SupportedSQLDialect: + yaml_path = ConversionPromptHelper.get_default_yaml_for_sql_dialect(dialect.dialect_name) + # The path should be an absolute path + self.assertTrue(Path(yaml_path).is_absolute()) + # The path should include the dialect-specific filename + self.assertTrue(yaml_path.endswith(dialect.default_yaml_filename)) + + +class TestConversionPromptHelper(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.valid_yaml_content = { + "system_message": "Convert SQL to Python ({comment_lang} comments)", + "few_shots": [ + { + "role": "user", + "content": "Sample SQL" + }, + { + "role": "assistant", + "content": "Sample Python" + } + ] + } + self.valid_yaml_path = os.path.join(self.temp_dir, "valid_prompt.yml") + with open(self.valid_yaml_path, 'w', encoding='utf-8') as f: + yaml.dump(self.valid_yaml_content, f) + + def tearDown(self): + if os.path.exists(self.temp_dir): + for file in os.listdir(self.temp_dir): + os.remove(os.path.join(self.temp_dir, file)) + os.rmdir(self.temp_dir) + + def test_conversion_prompt_helper_initialization(self): + """Test that ConversionPromptHelper initializes correctly""" + helper = ConversionPromptHelper(self.valid_yaml_path, "ja") + self.assertIsInstance(helper, ConversionPromptHelper) + self.assertIsInstance(helper.prompt_config, PromptConfig) + + def test_get_system_message(self): + """Test that get_system_message returns the correct message""" + helper = ConversionPromptHelper(self.valid_yaml_path, "ja") + expected = "Convert SQL to Python (ja comments)" + self.assertEqual(helper.get_system_message(), expected) + + def test_get_few_shots(self): + """Test that get_few_shots returns the correct few-shot examples""" + helper = ConversionPromptHelper(self.valid_yaml_path, "ja") + few_shots = helper.get_few_shots() + self.assertEqual(len(few_shots), 2) + self.assertEqual(few_shots[0]["role"], "user") + self.assertEqual(few_shots[0]["content"], "Sample SQL") + self.assertEqual(few_shots[1]["role"], "assistant") + self.assertEqual(few_shots[1]["content"], "Sample Python") + + def test_non_existent_yaml_file(self): + """Test that FileNotFoundError is raised when YAML file does not exist""" + with self.assertRaises(Exception) as context: + ConversionPromptHelper("non_existent.yml", "ja") + self.assertIn("YAML file not found", str(context.exception)) + + def 
test_invalid_yaml_content(self): + """Test that ValueError is raised when YAML content is not a dictionary""" + invalid_yaml_path = os.path.join(self.temp_dir, "invalid_prompt.yml") + with open(invalid_yaml_path, 'w', encoding='utf-8') as f: + yaml.dump(["invalid", "content"], f) + + with self.assertRaises(Exception) as context: + ConversionPromptHelper(invalid_yaml_path, "ja") + self.assertIn("YAML content must be a dictionary", str(context.exception)) + + def test_missing_system_message(self): + """Test that ValueError is raised when system_message key is missing""" + invalid_yaml_path = os.path.join(self.temp_dir, "missing_system_message.yml") + with open(invalid_yaml_path, 'w', encoding='utf-8') as f: + yaml.dump({"few_shots": []}, f) + + with self.assertRaises(Exception) as context: + ConversionPromptHelper(invalid_yaml_path, "ja") + self.assertIn("YAML must contain 'system_message' key", str(context.exception)) + + def test_yaml_without_few_shots(self): + """Test that empty list is returned when few_shots key is missing""" + yaml_path = os.path.join(self.temp_dir, "no_few_shots.yml") + content = {"system_message": "Test message"} + with open(yaml_path, 'w', encoding='utf-8') as f: + yaml.dump(content, f) + + helper = ConversionPromptHelper(yaml_path, "ja") + self.assertEqual(helper.get_few_shots(), []) + + def test_tsql_to_databricks_yaml(self): + """Test loading and parsing tsql_to_databricks_notebook.yml""" + yaml_path = ConversionPromptHelper.get_default_yaml_for_sql_dialect("tsql") + helper = ConversionPromptHelper(yaml_path, "ja") + system_message = helper.get_system_message() + self.assertIn("Instructions", system_message) + + +if __name__ == '__main__': + unittest.main() diff --git a/sql2dbx/tests/pyscripts/test_conversion_result_clean_helper.py b/sql2dbx/tests/pyscripts/test_conversion_result_clean_helper.py new file mode 100644 index 00000000..9a91182c --- /dev/null +++ b/sql2dbx/tests/pyscripts/test_conversion_result_clean_helper.py @@ -0,0 +1,142 @@ +import unittest + +from notebooks.pyscripts.conversion_result_clean_helper import \ + ConversionResultCleanHelper + + +class TestConversionResultCleanHelper(unittest.TestCase): + def setUp(self): + self.helper = ConversionResultCleanHelper() + + def test_none_input(self): + """Test that None input returns None.""" + self.assertIsNone(self.helper.clean_python_code_blocks(None)) + + def test_no_python_markers(self): + """Test text without any ```python markers.""" + text = "This is just regular text without any code blocks." 
+ self.assertEqual(self.helper.clean_python_code_blocks(text), text) + + def test_starts_with_python_marker(self): + """Test text that starts with ```python and has a closing ```.""" + text = """```python +def hello(): + print("Hello, world!") +``` +Some extra text after the code block.""" + + expected = """def hello(): + print("Hello, world!") +""" + self.assertEqual(self.helper.clean_python_code_blocks(text), expected) + + def test_starts_with_whitespace_then_python_marker(self): + """Test text that starts with whitespace followed by ```python.""" + text = """ + ```python +def hello(): + print("Hello, world!") +``` +Some extra text after the code block.""" + + expected = """def hello(): + print("Hello, world!") +""" + self.assertEqual(self.helper.clean_python_code_blocks(text), expected) + + def test_python_marker_in_middle(self): + """Test text with ```python in the middle.""" + text = """columns.customer_id```python columns.order_id, columns.product_id, columns.timestamp""" + expected = """columns.customer_id columns.order_id, columns.product_id, columns.timestamp""" + self.assertEqual(self.helper.clean_python_code_blocks(text), expected) + + def test_sql_with_python_marker_in_middle(self): + """Test SQL query with ```python in the middle.""" + text = """# Creating a temporary view +spark.sql(\"\"\" + CREATE OR REPLACE TEMPORARY VIEW sales_data AS + SELECT DISTINCT + data.year, + data.month, + data.day, + data.region, + data.store, + data.product, + data.category, + data.quantity, + data.price, + data.total_sales```python + data.discount, + data.promotion, + data.customer_segment, + data.profit\"\"\")""" + + expected = """# Creating a temporary view +spark.sql(\"\"\" + CREATE OR REPLACE TEMPORARY VIEW sales_data AS + SELECT DISTINCT + data.year, + data.month, + data.day, + data.region, + data.store, + data.product, + data.category, + data.quantity, + data.price, + data.total_sales + data.discount, + data.promotion, + data.customer_segment, + data.profit\"\"\")""" + self.assertEqual(self.helper.clean_python_code_blocks(text), expected) + + def test_starts_with_python_no_closing_marker(self): + """Test text that starts with ```python but has no closing ```.""" + text = """```python +def hello(): + print("Hello, world!") +This continues without a closing marker.""" + + expected = """ +def hello(): + print("Hello, world!") +This continues without a closing marker.""" + self.assertEqual(self.helper.clean_python_code_blocks(text), expected) + + def test_indented_code_block(self): + """Test that indentation is preserved in extracted code.""" + text = """```python + def hello(): + print("Hello, world!") + if True: + print("Indented") +```""" + + expected = """ def hello(): + print("Hello, world!") + if True: + print("Indented") +""" + self.assertEqual(self.helper.clean_python_code_blocks(text), expected) + + def test_text_before_python_marker(self): + """Test with text before ```python marker.""" + text = """Filter condition: + ```python + if data['status'] == 'active' and data['value'] > 100: + return True + else: + return False""" + + expected = """Filter condition: + + if data['status'] == 'active' and data['value'] > 100: + return True + else: + return False""" + self.assertEqual(self.helper.clean_python_code_blocks(text), expected) + + +if __name__ == '__main__': + unittest.main() diff --git a/sql2dbx/tests/pyscripts/test_llm_token_count_helper.py b/sql2dbx/tests/pyscripts/test_llm_token_count_helper.py new file mode 100644 index 00000000..24e21f71 --- /dev/null +++ 
b/sql2dbx/tests/pyscripts/test_llm_token_count_helper.py @@ -0,0 +1,96 @@ +import os +import tempfile +import unittest + +from notebooks.pyscripts.llm_token_count_helper import (FileTokenCountHelper, + FileTokenMetadata) +from notebooks.pyscripts.token_utils import (ClaudeTokenCounter, + OpenAITokenCounter) + + +class TestFileTokenCountHelper(unittest.TestCase): + def test_init_with_endpoint_name(self): + # Use Databricks Claude endpoint + helper = FileTokenCountHelper(endpoint_name="databricks-claude-3-7-sonnet") + self.assertEqual(helper.tokenizer_type, "claude") + self.assertEqual(helper.tokenizer_model, "claude") + self.assertIsInstance(helper.token_counter, ClaudeTokenCounter) + + # Use OpenAI endpoint + helper = FileTokenCountHelper(endpoint_name="gpt-4o") + self.assertEqual(helper.tokenizer_type, "openai") + self.assertEqual(helper.tokenizer_model, "o200k_base") + self.assertIsInstance(helper.token_counter, OpenAITokenCounter) + + # Use other + helper = FileTokenCountHelper(endpoint_name="other") + self.assertEqual(helper.tokenizer_type, "openai") + self.assertEqual(helper.tokenizer_model, "o200k_base") + self.assertIsInstance(helper.token_counter, OpenAITokenCounter) + + def test_init_with_explicit_tokenizer(self): + # Explicitly specify OpenAI tokenizer + helper = FileTokenCountHelper(tokenizer_type="openai", tokenizer_model="o200k_base") + self.assertEqual(helper.tokenizer_type, "openai") + self.assertEqual(helper.tokenizer_model, "o200k_base") + self.assertIsInstance(helper.token_counter, OpenAITokenCounter) + + # Explicitly specify Claude tokenizer + helper = FileTokenCountHelper(tokenizer_type="claude") + self.assertEqual(helper.tokenizer_type, "claude") + self.assertEqual(helper.tokenizer_model, "claude") + self.assertIsInstance(helper.token_counter, ClaudeTokenCounter) + + def test_default_init(self): + # Default settings (Claude tokenizer) + helper = FileTokenCountHelper() + self.assertEqual(helper.tokenizer_type, "claude") + self.assertEqual(helper.tokenizer_model, "claude") + self.assertIsInstance(helper.token_counter, ClaudeTokenCounter) + + def test_process_file(self): + with tempfile.TemporaryDirectory() as temp_dir: + # Create test SQL file + file_path = os.path.join(temp_dir, "test.sql") + with open(file_path, 'w') as f: + f.write("SELECT * FROM table; -- comment\n/* block comment */\nSELECT * FROM another_table;") + + # Process file using Claude tokenizer + helper = FileTokenCountHelper(tokenizer_type="claude") + metadata = helper.process_file(file_path) + + # Validate metadata + self.assertEqual(metadata.input_file_path, file_path) + self.assertGreater(metadata.input_file_token_count, 0) + self.assertIsNotNone(metadata.input_file_content_without_sql_comments) + self.assertGreater(metadata.input_file_token_count, metadata.input_file_token_count_without_sql_comments) + self.assertEqual(metadata.tokenizer_type, "claude") + + def test_process_directory(self): + with tempfile.TemporaryDirectory() as temp_dir: + # Create multiple test SQL files + file1 = os.path.join(temp_dir, "file1.sql") + file2 = os.path.join(temp_dir, "file2.sql") + + with open(file1, 'w') as f: + f.write("SELECT * FROM table1;") + with open(file2, 'w') as f: + f.write("SELECT * FROM table2;") + + # Process directory files + helper = FileTokenCountHelper(tokenizer_type="claude") + results = helper.process_directory(temp_dir) + + # Validate results + self.assertEqual(len(results), 2) + self.assertIsInstance(results[0], FileTokenMetadata) + self.assertIsInstance(results[1], FileTokenMetadata) + + # 
Verify file numbers are set correctly + file_numbers = [result.input_file_number for result in results] + self.assertIn(1, file_numbers) + self.assertIn(2, file_numbers) + + +if __name__ == '__main__': + unittest.main() diff --git a/sql2dbx/tests/pyscripts/test_notebook_export_helper.py b/sql2dbx/tests/pyscripts/test_notebook_export_helper.py new file mode 100644 index 00000000..9e099e6b --- /dev/null +++ b/sql2dbx/tests/pyscripts/test_notebook_export_helper.py @@ -0,0 +1,130 @@ +import os +import unittest + +from notebooks.pyscripts.notebook_export_helper import (ExportInput, + NotebookExportHelper) +from notebooks.pyscripts.notebook_i18n import MessageKey, get_language_messages + +TEST_TEXT = """print("Hello, World!") + +def foo(): + return "bar" +""" + + +class TestNotebookExportHelper(unittest.TestCase): + """Test cases for the NotebookExportHelper class.""" + + def setUp(self): + """Set up test fixtures.""" + self.helper = NotebookExportHelper() + self.test_text = TEST_TEXT + self.input_file_path = "test_input_file.txt" + self.output_dir = "dummy_output_dir" + self.comment_lang = "English" + + def test_create_notebook_content(self): + """Test creating notebook content with basic input.""" + code = 'print("Hello, World!")' + output_file_path = "test_output_file" + ex_in = ExportInput(input_file_path=self.input_file_path, code=code, + output_dir=self.output_dir, comment_lang=self.comment_lang) + content = self.helper.create_notebook_content(output_file_path, ex_in) + messages = get_language_messages(self.comment_lang) + + self.assertTrue(content.startswith("# Databricks notebook source")) + self.assertIn(code, content) + self.assertIn(self.input_file_path, content) + self.assertIn(messages[MessageKey.SYNTAX_CHECK_RESULTS], content) + self.assertIn(messages[MessageKey.NO_ERRORS_DETECTED], content) + self.assertIn(messages[MessageKey.REVIEW_CODE], content) + + def test_create_notebook_content_with_errors(self): + """Test creating notebook content with Python and SQL errors.""" + code = 'print("Hello, World!")' + output_file_path = "test_output_file" + python_error = "SyntaxError: invalid syntax" + sql_errors = ["Error in SQL query", "Another SQL error"] + ex_in = ExportInput(input_file_path=self.input_file_path, code=code, output_dir=self.output_dir, + comment_lang=self.comment_lang, python_parse_error=python_error, sql_parse_error=sql_errors) + content = self.helper.create_notebook_content(output_file_path, ex_in) + messages = get_language_messages(self.comment_lang) + + self.assertIn(messages[MessageKey.SYNTAX_CHECK_RESULTS], content) + self.assertIn(messages[MessageKey.ERRORS_FROM_CHECKS], content) + self.assertIn(messages[MessageKey.PYTHON_SYNTAX_ERRORS], content) + self.assertIn(python_error, content) + self.assertIn(messages[MessageKey.SPARK_SQL_SYNTAX_ERRORS], content) + for error in sql_errors: + self.assertIn(error, content) + + def test_create_notebook_content_without_errors(self): + """Test creating notebook content without errors.""" + code = 'print("Hello, World!")' + output_file_path = "test_output_file" + ex_in = ExportInput(input_file_path=self.input_file_path, code=code, + output_dir=self.output_dir, comment_lang=self.comment_lang) + content = self.helper.create_notebook_content(output_file_path, ex_in) + messages = get_language_messages(self.comment_lang) + + self.assertIn(messages[MessageKey.SYNTAX_CHECK_RESULTS], content) + self.assertIn(messages[MessageKey.NO_ERRORS_DETECTED], content) + self.assertIn(messages[MessageKey.REVIEW_CODE], content) + 
self.assertNotIn(messages[MessageKey.ERRORS_FROM_CHECKS], content) + + def test_generate_unique_output_paths(self): + """Test generating unique output paths when duplicate inputs exist.""" + exporter_inputs = [ + ExportInput(input_file_path="test_input_file.txt", code=self.test_text, + output_dir=self.output_dir, comment_lang=self.comment_lang), + ExportInput(input_file_path="test_input_file.txt", code=self.test_text, + output_dir=self.output_dir, comment_lang=self.comment_lang), + ExportInput(input_file_path="test_input_file.txt", code=self.test_text, + output_dir=self.output_dir, comment_lang=self.comment_lang) + ] + expected_paths = [ + os.path.join(self.output_dir, "test_input_file"), + os.path.join(self.output_dir, "test_input_file_1"), + os.path.join(self.output_dir, "test_input_file_2") + ] + unique_paths = self.helper.generate_unique_output_paths(exporter_inputs) + self.assertEqual(unique_paths, expected_paths) + + def test_process_notebooks(self): + """Test processing multiple notebooks.""" + exporter_inputs = [ + ExportInput(input_file_path="test_input_file1.txt", code=self.test_text, + output_dir=self.output_dir, comment_lang=self.comment_lang), + ExportInput(input_file_path="test_input_file2.txt", code=self.test_text, + output_dir=self.output_dir, comment_lang=self.comment_lang) + ] + results = self.helper.process_notebooks(exporter_inputs) + + self.assertEqual(len(results), 2) + for result in results: + self.assertTrue(isinstance(result.base64_encoded_content, str)) + self.assertTrue(isinstance(result.base64_encoded_content_size, int)) + self.assertTrue(result.output_file_path.startswith(self.output_dir)) + self.assertTrue(result.export_succeeded) + + def test_different_languages(self): + """Test notebook creation with all supported languages.""" + languages = ["English", "Japanese", "French", "German", "Italian", "Spanish", "Chinese", "Korean", "Portuguese"] + code = 'print("Hello, World!")' + output_file_path = "test_output_file" + + for lang in languages: + ex_in = ExportInput(input_file_path=self.input_file_path, code=code, + output_dir=self.output_dir, comment_lang=lang) + content = self.helper.create_notebook_content(output_file_path, ex_in) + messages = get_language_messages(lang) + + self.assertIn(messages[MessageKey.NOTEBOOK_DESCRIPTION], content) + self.assertIn(messages[MessageKey.SOURCE_SCRIPT], content) + self.assertIn(messages[MessageKey.SYNTAX_CHECK_RESULTS], content) + self.assertIn(messages[MessageKey.NO_ERRORS_DETECTED], content) + self.assertIn(messages[MessageKey.REVIEW_CODE], content) + + +if __name__ == '__main__': + unittest.main() diff --git a/sql2dbx/tests/pyscripts/test_notebook_i18n.py b/sql2dbx/tests/pyscripts/test_notebook_i18n.py new file mode 100644 index 00000000..719a685c --- /dev/null +++ b/sql2dbx/tests/pyscripts/test_notebook_i18n.py @@ -0,0 +1,106 @@ +import unittest +from enum import Enum + +from notebooks.pyscripts.notebook_i18n import (MessageKey, + get_language_messages, + get_supported_languages) + + +class TestI18n(unittest.TestCase): + """Test cases for the i18n module.""" + + def test_message_key_enum(self): + """Test that MessageKey is correctly defined as an Enum.""" + self.assertTrue(issubclass(MessageKey, Enum)) + self.assertEqual(MessageKey.NOTEBOOK_DESCRIPTION.value, "notebook_description") + self.assertEqual(MessageKey.SOURCE_SCRIPT.value, "source_script") + self.assertEqual(MessageKey.SYNTAX_CHECK_RESULTS.value, "syntax_check_results") + self.assertEqual(MessageKey.ERRORS_FROM_CHECKS.value, "errors_from_checks") + 
self.assertEqual(MessageKey.PYTHON_SYNTAX_ERRORS.value, "python_syntax_errors") + self.assertEqual(MessageKey.SPARK_SQL_SYNTAX_ERRORS.value, "spark_sql_syntax_errors") + self.assertEqual(MessageKey.NO_ERRORS_DETECTED.value, "no_errors_detected") + self.assertEqual(MessageKey.REVIEW_CODE.value, "review_code") + + def test_supported_languages(self): + """Test that get_supported_languages returns the expected languages.""" + languages = get_supported_languages() + expected_languages = [ + "English", "Japanese", "Chinese", "French", "German", + "Italian", "Korean", "Portuguese", "Spanish" + ] + self.assertEqual(sorted(languages), sorted(expected_languages)) + self.assertEqual(len(languages), 9) # Should be 9 languages supported + + def test_get_language_messages_returns_dict(self): + """Test that get_language_messages returns a dictionary.""" + for lang in get_supported_languages(): + messages = get_language_messages(lang) + self.assertIsInstance(messages, dict) + # Check that all MessageKey enum values are in the dictionary + for key in MessageKey: + self.assertIn(key, messages) + self.assertIsInstance(messages[key], str) + + def test_get_language_messages_fallback(self): + """Test that get_language_messages falls back to English for unsupported languages.""" + en_messages = get_language_messages("English") + # Test with an unsupported language + unsupported_messages = get_language_messages("Unsupported") + self.assertEqual(en_messages, unsupported_messages) + + def test_language_specific_messages(self): + """Test that each language has the correct specific messages.""" + # Test a sample message from each language + self.assertEqual( + get_language_messages("English")[MessageKey.SOURCE_SCRIPT], + "Source script" + ) + self.assertEqual( + get_language_messages("Japanese")[MessageKey.SOURCE_SCRIPT], + "ソーススクリプト" + ) + self.assertEqual( + get_language_messages("Chinese")[MessageKey.SOURCE_SCRIPT], + "源脚本" + ) + self.assertEqual( + get_language_messages("French")[MessageKey.SOURCE_SCRIPT], + "Script source" + ) + self.assertEqual( + get_language_messages("German")[MessageKey.SOURCE_SCRIPT], + "Quellskript" + ) + self.assertEqual( + get_language_messages("Italian")[MessageKey.SOURCE_SCRIPT], + "Script sorgente" + ) + self.assertEqual( + get_language_messages("Korean")[MessageKey.SOURCE_SCRIPT], + "소스 스크립트" + ) + self.assertEqual( + get_language_messages("Portuguese")[MessageKey.SOURCE_SCRIPT], + "Script fonte" + ) + self.assertEqual( + get_language_messages("Spanish")[MessageKey.SOURCE_SCRIPT], + "Script fuente" + ) + + def test_message_formatting(self): + """Test that multi-line messages are properly formatted.""" + for lang in get_supported_languages(): + messages = get_language_messages(lang) + # Checking that the notebook description doesn't have newlines + # as we're using string concatenation in the definition + notebook_desc = messages[MessageKey.NOTEBOOK_DESCRIPTION] + self.assertNotIn("\n", notebook_desc) + + # Check that messages are not empty + for key in MessageKey: + self.assertTrue(messages[key], f"Message for {key} in {lang} is empty") + + +if __name__ == "__main__": + unittest.main() diff --git a/sql2dbx/tests/pyscripts/test_spark_sql_extract_helper.py b/sql2dbx/tests/pyscripts/test_spark_sql_extract_helper.py new file mode 100644 index 00000000..df1f6aa9 --- /dev/null +++ b/sql2dbx/tests/pyscripts/test_spark_sql_extract_helper.py @@ -0,0 +1,169 @@ +import unittest + +from notebooks.pyscripts.spark_sql_extract_helper import SparkSQLExtractHelper + + +class 
TestSparkSQLExtractor(unittest.TestCase): + + def setUp(self) -> None: + self.helper = SparkSQLExtractHelper() + + def test_simple_sql_extraction(self): + func_string = ( + "def example_func():\n" + " query = \"SELECT * FROM table\"\n" + " spark.sql(query)\n" + ) + error, sql_statements = self.helper.extract_sql_from_string(func_string) + self.assertIsNone(error) + self.assertEqual(sql_statements, ["SELECT * FROM table"]) + + def test_multiple_sql_extractions(self): + func_string = ( + "def example_func():\n" + " query1 = \"SELECT * FROM table1\"\n" + " query2 = \"SELECT * FROM table2\"\n" + " spark.sql(query1)\n" + " spark.sql(query2)\n" + ) + error, sql_statements = self.helper.extract_sql_from_string(func_string) + self.assertIsNone(error) + self.assertEqual(sql_statements, ["SELECT * FROM table1", "SELECT * FROM table2"]) + + def test_f_string_extraction(self): + func_string = ( + "def example_func():\n" + " table_name = \"table\"\n" + " query = f\"SELECT * FROM {table_name}\"\n" + " spark.sql(query)\n" + ) + error, sql_statements = self.helper.extract_sql_from_string(func_string) + self.assertIsNone(error) + self.assertEqual(sql_statements, ["SELECT * FROM table"]) + + def test_syntax_error(self): + func_string = ( + "def example_func()\n" + " query = \"SELECT * FROM table\"\n" + " spark.sql(query)\n" + ) + error, sql_statements = self.helper.extract_sql_from_string(func_string) + self.assertTrue(error) + self.assertEqual(sql_statements, []) + + def test_variable_assignment(self): + func_string = ( + "def example_func():\n" + " table_name = \"table\"\n" + " query = \"SELECT * FROM \" + table_name\n" + " spark.sql(query)\n" + ) + error, sql_statements = self.helper.extract_sql_from_string(func_string) + self.assertIsNone(error) + self.assertEqual(sql_statements, ["SELECT * FROM table"]) + + def test_placeholder_replacement(self): + func_string = ( + "def example_func():\n" + " table_name = \"table\"\n" + " query = f\"SELECT * FROM {table_name} WHERE column = {_placeholder_}\"\n" + " spark.sql(query)\n" + ) + error, sql_statements = self.helper.extract_sql_from_string(func_string) + self.assertIsNone(error) + self.assertEqual(sql_statements, ["SELECT * FROM table WHERE column = _placeholder_"]) + + def test_curly_braces_removal(self): + func_string = ( + "def example_func():\n" + " query = \"SELECT * FROM {table}\"\n" + " spark.sql(query)\n" + ) + error, sql_statements = self.helper.extract_sql_from_string(func_string) + self.assertIsNone(error) + self.assertEqual(sql_statements, ["SELECT * FROM table"]) + + # Tests for None and non-string inputs + def test_none_input(self): + """Test handling of None input.""" + error, sql_statements = self.helper.extract_sql_from_string(None) + self.assertIsNone(error) # Should return None for error to be consistent with success case + self.assertEqual(sql_statements, []) + + def test_non_string_input_number(self): + """Test handling of non-string input (number).""" + error, sql_statements = self.helper.extract_sql_from_string(42) + self.assertEqual(error, "Type error during parsing: compile() arg 1 must be a string, bytes or AST object") + self.assertEqual(sql_statements, []) + + def test_non_string_input_list(self): + """Test handling of non-string input (list).""" + error, sql_statements = self.helper.extract_sql_from_string(["SELECT * FROM table"]) + self.assertEqual(error, "Type error during parsing: compile() arg 1 must be a string, bytes or AST object") + self.assertEqual(sql_statements, []) + + def test_empty_string_input(self): + """Test 
handling of empty string input.""" + error, sql_statements = self.helper.extract_sql_from_string("") + # Empty string is technically valid Python syntax, but will not yield SQL + self.assertIsNone(error) + self.assertEqual(sql_statements, []) + + def test_whitespace_only_input(self): + """Test handling of whitespace-only string input.""" + error, sql_statements = self.helper.extract_sql_from_string(" \n \t ") + # Whitespace is also valid Python syntax + self.assertIsNone(error) + self.assertEqual(sql_statements, []) + + # Test for invalid Python code that's not a syntax error + def test_valid_python_but_not_function(self): + """Test handling of valid Python code that's not a function definition.""" + func_string = "x = 10\ny = 20\nprint(x + y)" + error, sql_statements = self.helper.extract_sql_from_string(func_string) + self.assertIsNone(error) # This is valid Python but may not have SQL + self.assertEqual(sql_statements, []) + + # Test for multiple errors in the same function + def test_multiple_syntax_errors(self): + """Test handling of Python code with multiple syntax errors.""" + func_string = ( + "def example_func():\n" + " query = SELECT * FROM table\n" # Missing quotes + " spark.sql(query\n" # Missing closing parenthesis + ) + error, sql_statements = self.helper.extract_sql_from_string(func_string) + self.assertTrue(error) + self.assertTrue(error.startswith("Python syntax error:")) + self.assertEqual(sql_statements, []) + + # Test for SQL with special characters + def test_sql_with_special_characters(self): + """Test extraction of SQL with special characters.""" + func_string = ( + "def example_func():\n" + " query = \"SELECT * FROM table WHERE col1 LIKE '%special\\\\%'\"\n" + " spark.sql(query)\n" + ) + error, sql_statements = self.helper.extract_sql_from_string(func_string) + self.assertIsNone(error) + self.assertEqual(sql_statements, ["SELECT * FROM table WHERE col1 LIKE '%special\\%'"]) + + # Test for complex variable interpolation in f-strings + def test_complex_f_string_interpolation(self): + """Test extraction of SQL from complex f-string interpolation.""" + func_string = ( + "def example_func():\n" + " schema = \"public\"\n" + " table = \"users\"\n" + " conditions = \"age > 18\"\n" + " query = f\"SELECT * FROM {schema}.{table} WHERE {conditions}\"\n" + " spark.sql(query)\n" + ) + error, sql_statements = self.helper.extract_sql_from_string(func_string) + self.assertIsNone(error) + self.assertEqual(sql_statements, ["SELECT * FROM public.users WHERE age > 18"]) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/sql2dbx/tests/pyscripts/test_token_utils.py b/sql2dbx/tests/pyscripts/test_token_utils.py new file mode 100644 index 00000000..5dd8a8fb --- /dev/null +++ b/sql2dbx/tests/pyscripts/test_token_utils.py @@ -0,0 +1,114 @@ +import unittest + +from notebooks.pyscripts.token_utils import (ClaudeTokenCounter, + OpenAITokenCounter, TokenizerType, + determine_tokenizer_from_endpoint, + get_token_counter, + get_token_counter_from_endpoint) + + +class TestTokenUtils(unittest.TestCase): + def test_openai_token_counter(self): + counter = OpenAITokenCounter(token_encoding_name="o200k_base") + text = "This is a test string." + token_count = counter.count_tokens(text) + self.assertGreater(token_count, 0) + + # Test longer text + long_text = "This is a longer text that should be tokenized into multiple tokens. 
" * 10 + long_count = counter.count_tokens(long_text) + self.assertGreater(long_count, 50) + + def test_claude_token_counter(self): + counter = ClaudeTokenCounter() + text = "This is a test string." + token_count = counter.count_tokens(text) + self.assertGreater(token_count, 0) + # Verify that Claude token count is an estimation based on Character/3.4 + self.assertEqual(token_count, int(len(text) / 3.4)) + + # Test empty string + self.assertEqual(counter.count_tokens(""), 0) + + # Test longer text and verify character ratio + long_text = "This is a much longer text that contains multiple sentences. " * 10 + long_count = counter.count_tokens(long_text) + self.assertEqual(long_count, int(len(long_text) / 3.4)) + + def test_token_count_comparison(self): + """Test and compare token counts between different tokenizers""" + openai_counter = OpenAITokenCounter(token_encoding_name="o200k_base") + claude_counter = ClaudeTokenCounter() + + test_texts = [ + "", # Empty string + "Hello world", # Simple text + "OpenAI and Claude use different tokenization methods", # Medium text + # Text with special characters + "Email: test@example.com, URL: https://example.com/path?query=value", + # Multilingual text + "English, 日本語 (Japanese), Español (Spanish), Русский (Russian)", + # Code snippet + "def hello_world():\n print('Hello, world!')\n return True" + ] + + for text in test_texts: + openai_count = openai_counter.count_tokens(text) + claude_count = claude_counter.count_tokens(text) + + # Both should return values >= 0 + self.assertGreaterEqual(openai_count, 0) + self.assertGreaterEqual(claude_count, 0) + + # Empty string should be 0 tokens for both + if not text: + self.assertEqual(openai_count, 0) + self.assertEqual(claude_count, 0) + + # Very basic sanity check: Claude estimation should be somewhat related to OpenAI + # This is just a rough check, not expecting exact correlation + if len(text) > 20: # Only check for longer texts + # Claude's estimate should be within 50% of OpenAI's count for typical texts + # This is a very rough validation since they use different tokenization methods + self.assertLess(abs(claude_count - openai_count), max(openai_count, claude_count)) + + def test_get_token_counter(self): + # Verify that the default is Claude tokenizer + default_counter = get_token_counter() + self.assertIsInstance(default_counter, ClaudeTokenCounter) + + # Explicitly get OpenAI tokenizer + openai_counter = get_token_counter("openai", "o200k_base") + self.assertIsInstance(openai_counter, OpenAITokenCounter) + + # Explicitly get Claude tokenizer + claude_counter = get_token_counter("claude", "claude") + self.assertIsInstance(claude_counter, ClaudeTokenCounter) + + # Verify exception for invalid tokenizer type + with self.assertRaises(ValueError): + get_token_counter("invalid_tokenizer") + + def test_determine_tokenizer_from_endpoint(self): + # Test Databricks Claude endpoint detection + tokenizer_type, model = determine_tokenizer_from_endpoint("databricks-claude-3-7-sonnet") + self.assertEqual(tokenizer_type, "claude") + self.assertEqual(model, "claude") + + # Test OpenAI endpoint detection + tokenizer_type, model = determine_tokenizer_from_endpoint("gpt-4o") + self.assertEqual(tokenizer_type, "openai") + self.assertEqual(model, "o200k_base") + + def test_get_token_counter_from_endpoint(self): + # Get tokenizer from Databricks Claude endpoint + counter = get_token_counter_from_endpoint("databricks-claude-3-7-sonnet") + self.assertIsInstance(counter, ClaudeTokenCounter) + + # Get tokenizer from OpenAI 
endpoint + counter = get_token_counter_from_endpoint("gpt-4o") + self.assertIsInstance(counter, OpenAITokenCounter) + + +if __name__ == '__main__': + unittest.main() diff --git a/sql2dbx/tests/pyscripts/test_utils.py b/sql2dbx/tests/pyscripts/test_utils.py new file mode 100644 index 00000000..c41a648f --- /dev/null +++ b/sql2dbx/tests/pyscripts/test_utils.py @@ -0,0 +1,95 @@ +import os +import tempfile +import unittest + +from notebooks.pyscripts.utils import (get_file_content, + list_files_recursively, + parse_number_ranges, + remove_sql_comments) + + +class TestUtils(unittest.TestCase): + def test_remove_sql_comments(self): + sql_text = "SELECT * FROM table; -- This is a comment\n/* Block comment */\nSELECT * FROM another_table;" + cleaned_sql = remove_sql_comments(sql_text) + expected_sql = "SELECT * FROM table; \n\nSELECT * FROM another_table;" + self.assertEqual(cleaned_sql, expected_sql) + + def test_list_files_recursively(self): + with tempfile.TemporaryDirectory() as temp_dir: + file1 = os.path.join(temp_dir, "file1.sql") + file2 = os.path.join(temp_dir, "file2.sql") + subdir = os.path.join(temp_dir, "subdir") + os.mkdir(subdir) + file3 = os.path.join(subdir, "file3.sql") + + with open(file1, 'w') as f: + f.write("SELECT 1;") + with open(file2, 'w') as f: + f.write("SELECT 2;") + with open(file3, 'w') as f: + f.write("SELECT 3;") + + all_files = list_files_recursively(temp_dir) + self.assertIn(file1, all_files) + self.assertIn(file2, all_files) + self.assertIn(file3, all_files) + + def test_get_file_content(self): + with tempfile.TemporaryDirectory() as temp_dir: + file_path = os.path.join(temp_dir, "file.sql") + content = "This is test; これはテストです。" + with open(file_path, 'w', encoding='utf-8') as f: + f.write(content) + + file_content, encoding = get_file_content(file_path) + self.assertEqual(file_content, content) + self.assertEqual(encoding, 'utf-8') + + # Test with specified encoding + file_content, encoding = get_file_content(file_path, encoding='utf-8') + self.assertEqual(file_content, content) + self.assertEqual(encoding, 'utf-8') + + +class TestParseNumberRanges(unittest.TestCase): + """Test cases for the parse_number_ranges function.""" + + def test_single_integer(self): + """Test parsing a single integer.""" + self.assertEqual(parse_number_ranges("5"), [5]) + + def test_single_range(self): + """Test parsing a single range.""" + self.assertEqual(parse_number_ranges("2-6"), [2, 3, 4, 5, 6]) + + def test_multiple_integers(self): + """Test parsing multiple integers.""" + self.assertEqual(parse_number_ranges("1,3,8"), [1, 3, 8]) + + def test_multiple_ranges(self): + """Test parsing multiple ranges.""" + self.assertEqual(parse_number_ranges("1-3,5-7"), [1, 2, 3, 5, 6, 7]) + + def test_mixed_integers_and_ranges(self): + """Test parsing a mix of integers and ranges.""" + self.assertEqual(parse_number_ranges("2,4-6,9"), [2, 4, 5, 6, 9]) + + def test_empty_string(self): + """Test parsing an empty string.""" + self.assertEqual(parse_number_ranges(""), []) + + # Invalid input format tests + def test_invalid_range(self): + """Test parsing an invalid range (e.g., '1-3-5').""" + with self.assertRaises(ValueError): + parse_number_ranges("1-3-5") + + def test_non_numeric_input(self): + """Test parsing non-numeric input.""" + with self.assertRaises(ValueError): + parse_number_ranges("1,a,3") + + +if __name__ == '__main__': + unittest.main()
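Reviewer note (not part of the patch itself): the test modules above import from the `notebooks.pyscripts` package, so they assume the `sql2dbx` directory is the import root. A minimal sketch of a local test run under that assumption; the helper file name and discovery pattern below are illustrative and are not added by this patch:

# run_all_tests.py (hypothetical helper; run from inside the sql2dbx directory)
import sys
import unittest

if __name__ == "__main__":
    # Make the `notebooks` package importable the same way the tests import it.
    sys.path.insert(0, ".")
    # Discover every test_*.py module under tests/pyscripts and run it.
    suite = unittest.defaultTestLoader.discover("tests/pyscripts", pattern="test_*.py")
    result = unittest.TextTestRunner(verbosity=2).run(suite)
    sys.exit(0 if result.wasSuccessful() else 1)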