diff --git a/.gitignore b/.gitignore
index 213f266..7132e97 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,7 @@ __pycache__
 /cli_helpers_dev
 .idea/
 .cache/
+.vscode/
+**/.ropeproject/
+*.swp
+
diff --git a/AUTHORS b/AUTHORS
index 40f2b90..d3aae74 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -25,6 +25,7 @@ This project receives help from these awesome contributors:
 - Mel Dafert
 - Andrii Kohut
 - Roland Walker
+- Liu Zhao (astroshot)
 
 Thanks
 ------
diff --git a/CHANGELOG b/CHANGELOG
index da2b436..faa759c 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -1,6 +1,10 @@
 Changelog
 =========
 
+Features
+-------------
+* New formatter is added to export query result to sql format (such as sql-insert, sql-update).
+
 TBD
 -------------
 * don't escape newlines, etc. in ascii tables, and add ascii_escaped table format
diff --git a/cli_helpers/tabular_output/sql_output_adapter.py b/cli_helpers/tabular_output/sql_output_adapter.py
new file mode 100644
index 0000000..eb673cf
--- /dev/null
+++ b/cli_helpers/tabular_output/sql_output_adapter.py
@@ -0,0 +1,124 @@
+# -*- coding: utf-8 -*-
+"""Output adapter that renders query results as SQL INSERT/UPDATE statements."""
+
+supported_formats = (
+    "sql-insert",
+    "sql-update",
+    "sql-update-1",
+    "sql-update-2",
+)
+
+preprocessors = ()
+
+# Set by register_new_formatter(); adapter() reads the executed query from it.
+formatter = None
+
+
+def escape_for_sql_statement(value):
+    """Render a single value as a SQL literal.
+
+    None becomes NULL, bytes become a hexadecimal literal and any other value
+    is converted to text and single-quoted, with embedded single quotes
+    doubled so that the value cannot terminate the literal early.
+    """
+    if value is None:
+        return "NULL"
+    if isinstance(value, bytes):
+        return "X'{}'".format(value.hex())
+    return "'{}'".format(str(value).replace("'", "''"))
+
+
+def adapter(data, headers, table_format=None, **kwargs):
+    """Yield the lines of the SQL statements that reproduce a query result.
+
+    Parameters:
+        data: query result, an iterable of rows
+        headers: column names, in the same order as the values of each row
+        table_format: "sql-insert" or "sql-update[-N]", where N is the number
+            of leading columns used in the WHERE clause (1 when omitted)
+        kwargs:
+            extract_tables: function that returns the tables referenced by a
+                query, for example pgcli.packages.parseutils.tables.extract_tables
+            delimiter: character that surrounds table and column names so they
+                cannot clash with sql keywords. For example, mysql uses ` and
+                postgres uses " (the default)
+
+    Raises:
+        ValueError: if extract_tables is missing, register_new_formatter has
+            not been called yet, or table_format is not supported
+    """
+    extract_table_func = kwargs.get("extract_tables")
+    if not extract_table_func:
+        raise ValueError("extract_tables function should be registered first")
+    if formatter is None:
+        raise ValueError("register_new_formatter should be called first")
+
+    delimiter = kwargs.get("delimiter")
+    if not isinstance(delimiter, str):
+        delimiter = '"'
+
+    def quote(identifier):
+        return "{0}{1}{0}".format(delimiter, identifier)
+
+    def condition(column, value):
+        # "column = NULL" never matches a row, so NULL keys need IS NULL.
+        if value is None:
+            return "{} IS NULL".format(quote(column))
+        return "{} = {}".format(quote(column), escape_for_sql_statement(value))
+
+    tables = extract_table_func(formatter.query)
+    if tables:
+        schema, name = tables[0][:2]
+        if schema:
+            # Quote each part on its own: "schema"."table", not "schema.table".
+            table_name = "{}.{}".format(quote(schema), quote(name))
+        else:
+            table_name = quote(name)
+    else:
+        table_name = quote("DUAL")
+
+    if table_format == "sql-insert":
+        yield "INSERT INTO {} ({}) VALUES".format(
+            table_name, ", ".join(quote(column) for column in headers)
+        )
+        prefix = " "
+        for row in data:
+            values = ", ".join(escape_for_sql_statement(v) for v in row)
+            yield "{}({})".format(prefix, values)
+            prefix = ", "
+        yield ";"
+    elif isinstance(table_format, str) and table_format.startswith("sql-update"):
+        parts = table_format.split("-")
+        # "sql-update-N" uses the first N columns as the key of the WHERE clause.
+        keys = int(parts[-1]) if len(parts) > 2 else 1
+        for row in data:
+            yield "UPDATE {} SET".format(table_name)
+            prefix = " "
+            for i, value in enumerate(row[keys:], keys):
+                yield "{}{} = {}".format(
+                    prefix, quote(headers[i]), escape_for_sql_statement(value)
+                )
+                prefix = ", "
+            where = " AND ".join(condition(headers[i], row[i]) for i in range(keys))
+            yield "WHERE {};".format(where)
+    else:
+        raise ValueError("unsupported table format: {!r}".format(table_format))
+
+
+def register_new_formatter(TabularOutputFormatter, **kwargs):
+    """Register every format of supported_formats on TabularOutputFormatter.
+
+    Parameters:
+        TabularOutputFormatter: default TabularOutputFormatter imported from cli_helpers
+        kwargs: options forwarded to adapter; extract_tables is required and
+            delimiter is optional.
+            For example {"delimiter": "`", "extract_tables": extract_tables}
+    """
+    global formatter
+    formatter = TabularOutputFormatter
+    for sql_format in supported_formats:
+        # Each format needs its own dict: a shared one would leave every format
+        # registered with the table_format of the last one.
+        TabularOutputFormatter.register_new_formatter(
+            sql_format, adapter, preprocessors, dict(kwargs, table_format=sql_format)
+        )
diff --git a/tests/tabular_output/test_sql_output_adapter.py b/tests/tabular_output/test_sql_output_adapter.py
new file mode 100644
index 0000000..0e4c5c9
--- /dev/null
+++ b/tests/tabular_output/test_sql_output_adapter.py
@@ -0,0 +1,162 @@
+# -*- coding: utf-8 -*-
+
+from collections import namedtuple
+
+from cli_helpers.tabular_output import TabularOutputFormatter
+from cli_helpers.tabular_output.sql_output_adapter import escape_for_sql_statement, adapter, register_new_formatter
+
+TableReference = namedtuple(
+    "TableReference", ["schema", "name", "alias", "is_function"]
+)
+
+TableReference.ref = property(
+    lambda self: self.alias
+    or (
+        self.name
+        if self.name.islower() or self.name[0] == '"'
+        else '"' + self.name + '"'
+    )
+)
+
+
+def test_escape_for_sql_statement_bytes():
+    bts = b"837124ab3e8dc0f"
+    escaped_bytes = escape_for_sql_statement(bts)
+    assert escaped_bytes == "X'383337313234616233653864633066'"
+
+
+def test_escape_for_sql_statement_quote_and_null():
+    assert escape_for_sql_statement("O'Reilly") == "'O''Reilly'"
+    assert escape_for_sql_statement(None) == "NULL"
+
+
+def __mock_extract_tables(sql):
+    """
+    mock function for extract tables
+    in mycli, pass `mycli.packages.parseutils.extract_tables`
+    in pgcli, pass `pgcli.packages.parseutils.extract_tables`
+
+    :param sql: sql query
+    :return:
+    """
+    table_refs = (TableReference(schema=None, name='user', alias='"user"', is_function=False),)
+    return table_refs
+
+
+def test_output_sql_insert():
+    global formatter
+    formatter = TabularOutputFormatter
+    register_new_formatter(formatter)
+    data = [
+        [
+            1,
+            "Jackson",
+            "jackson_test@gmail.com",
+            "132454789",
+            "",
+            "2022-09-09 19:44:32.712343+08",
+            "2022-09-09 19:44:32.712343+08",
+        ]
+    ]
+    header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"]
+    table_format = "sql-insert"
+    kwargs = {
+        "column_types": [int, str, str, str, str, str, str],
+        "sep_title": "RECORD {n}",
+        "sep_character": "-",
+        "sep_length": (1, 25),
+        "missing_value": "",
+        "integer_format": "",
+        "float_format": "",
+        "disable_numparse": True,
+        "preserve_whitespace": True,
+        "max_field_width": 500,
+        "extract_tables": __mock_extract_tables,
+    }
+
+    formatter.query = 'SELECT * FROM "user";'
+    # For postgresql
+    kwargs["delimiter"] = '"'
+    output = adapter(data, header, table_format=table_format, **kwargs)
+    output_list = [l for l in output]
+    expected = [
+        'INSERT INTO "user" ("id", "name", "email", "phone", "description", "created_at", "updated_at") VALUES',
+        " ('1', 'Jackson', 'jackson_test@gmail.com', '132454789', '', "
+        + "'2022-09-09 19:44:32.712343+08', '2022-09-09 19:44:32.712343+08')",
+        ";",
+    ]
+    assert expected == output_list
+
+    # For mysql
+    kwargs["delimiter"] = "`"
+    output = adapter(data, header, table_format=table_format, **kwargs)
+    output_list = [l for l in output]
+    expected = [
+        'INSERT INTO `user` (`id`, `name`, `email`, `phone`, `description`, `created_at`, `updated_at`) VALUES',
+        " ('1', 'Jackson', 'jackson_test@gmail.com', '132454789', '', "
+        + "'2022-09-09 19:44:32.712343+08', '2022-09-09 19:44:32.712343+08')",
+        ";",
+    ]
+    assert expected == output_list
+
+
+def test_output_sql_update_pg():
+    global formatter
+    formatter = TabularOutputFormatter
+    register_new_formatter(formatter)
+    data = [
+        [
+            1,
+            "Jackson",
+            "jackson_test@gmail.com",
+            "132454789",
+            "",
+            "2022-09-09 19:44:32.712343+08",
+            "2022-09-09 19:44:32.712343+08",
+        ]
+    ]
+    header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"]
+    table_format = "sql-update"
+    kwargs = {
+        "column_types": [int, str, str, str, str, str, str],
+        "sep_title": "RECORD {n}",
+        "sep_character": "-",
+        "sep_length": (1, 25),
+        "missing_value": "",
+        "integer_format": "",
+        "float_format": "",
+        "disable_numparse": True,
+        "preserve_whitespace": True,
+        "max_field_width": 500,
+        "extract_tables": __mock_extract_tables,
+    }
+    formatter.query = 'SELECT * FROM "user";'
+    # For postgresql
+    kwargs["delimiter"] = '"'
+    output = adapter(data, header, table_format=table_format, **kwargs)
+    output_list = [l for l in output]
+    expected = [
+        'UPDATE "user" SET',
+        ' "name" = \'Jackson\'',
+        ', "email" = \'jackson_test@gmail.com\'',
+        ', "phone" = \'132454789\'',
+        ', "description" = \'\'',
+        ', "created_at" = \'2022-09-09 19:44:32.712343+08\'',
+        ', "updated_at" = \'2022-09-09 19:44:32.712343+08\'',
+        'WHERE "id" = \'1\';']
+    assert expected == output_list
+
+    # For mysql
+    kwargs["delimiter"] = "`"
+    output = adapter(data, header, table_format=table_format, **kwargs)
+    output_list = [l for l in output]
+    expected = [
+        'UPDATE `user` SET',
+        " `name` = 'Jackson'",
+        ", `email` = 'jackson_test@gmail.com'",
+        ", `phone` = '132454789'",
+        ", `description` = ''",
+        ", `created_at` = '2022-09-09 19:44:32.712343+08'",
+        ", `updated_at` = '2022-09-09 19:44:32.712343+08'",
+        "WHERE `id` = '1';"]
+    assert expected == output_list