From 159f4be66f6c09b86b25a619c4b74e4ca06a9cfe Mon Sep 17 00:00:00 2001 From: Roland Bock Date: Sun, 3 Jul 2022 08:19:28 +0200 Subject: [PATCH] Support schema in ddl #444 --- scripts/ddl2cpp | 49 +++++++++++++++++++++------------ tests/mysql/usage/TabSample.sql | 2 +- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/scripts/ddl2cpp b/scripts/ddl2cpp index 582741ae6..86584a487 100755 --- a/scripts/ddl2cpp +++ b/scripts/ddl2cpp @@ -42,8 +42,8 @@ ddlNumber = pp.Word(pp.nums + "+-.", pp.nums + "+-.Ee") ddlString = ( pp.QuotedString("'") | pp.QuotedString('"', escQuote='""') | pp.QuotedString("`") ) -ddlTerm = pp.Word(pp.alphas + "_", pp.alphanums + "_.$") -ddlName = pp.Or([ddlTerm, ddlString, pp.Combine(ddlString + "." + ddlString)]) +ddlTerm = pp.Word(pp.alphas + "_", pp.alphanums + "_$") +ddlName = ddlTerm | ddlString ddlOperator = pp.Or( map(pp.CaselessLiteral, ["+", "-", "*", "/", "<", "<=", ">", ">=", "=", "%"]) ) @@ -69,7 +69,7 @@ ddlBracedArguments << ddlLeft + pp.delimitedList(ddlExpression) + ddlRight ddlBracedExpression << ddlLeft + ddlExpression + ddlRight ddlArguments = pp.Suppress(pp.Group(pp.delimitedList(ddlExpression))) -ddlFunctionCall << ddlName + ddlLeft + pp.Optional(ddlArguments) + ddlRight +ddlFunctionCall << pp.Optional(ddlName + ".") + ddlName + ddlLeft + pp.Optional(ddlArguments) + ddlRight # Column and constraint parsers ddlBooleanTypes = [ @@ -270,6 +270,7 @@ ddlCreateTable = pp.Group( + pp.Suppress(pp.Optional(ddlOrReplace)) + pp.CaselessLiteral("TABLE") + pp.Suppress(pp.Optional(ddlIfNotExists)) + + pp.Optional(ddlName.setResultsName("schema") + pp.Suppress('.')) + ddlName.setResultsName("tableName") + ddlLeft + pp.Group(pp.delimitedList(pp.Suppress(ddlConstraint) | ddlColumn)).setResultsName( @@ -386,14 +387,17 @@ def testRational(): def testTable(): text = """ - CREATE TABLE "public"."dk" ( + CREATE TABLE 'public'.'dk' ( "id" int8 NOT NULL DEFAULT nextval('dk_id_seq'::regclass), "strange" NUMERIC(314, 15), "last_update" timestamp(6) DEFAULT now(), PRIMARY KEY (id) -) + ) """ result = ddlCreateTable.parseString(text, parseAll=True) + table = result[0] + assert table.schema == "public" + assert table.tableName == "dk" def testParser(): @@ -422,8 +426,10 @@ def get_include_guard_name(namespace, inputfile): return val.upper() -def identity_naming_func(s): - return s +def identity_naming_func(name, schema = None): + if schema: + return schema + '__' + name; + return name def repl_camel_case_func(m): @@ -433,14 +439,18 @@ def repl_camel_case_func(m): return m.group(1) + m.group(2).upper() -def class_name_naming_func(s): - s = s.replace(".", "_") - return re.sub("(^|\s|[_0-9])(\S)", repl_camel_case_func, s) +def class_name_naming_func(name, schema = None): + if schema: + name = schema + "__" + name + name = name.replace(".", "_") + return re.sub("(^|\s|[_0-9])(\S)", repl_camel_case_func, name) -def member_name_naming_func(s): - s = s.replace(".", "_") - return re.sub("(\s|_|[0-9])(\S)", repl_camel_case_func, s) +def member_name_naming_func(name, schema = None): + if schema: + name = schema + "_" + name + name = name.replace(".", "_") + return re.sub("(\s|_|[0-9])(\S)", repl_camel_case_func, name) def repl_func_for_args(m): @@ -579,11 +589,16 @@ def createHeader(): header = beginHeader(pathToHeader, namespace, nsList) DataTypeError = False for create in tableCreations: - sqlTableName = create.tableName + if identityNaming: + sqlSchema = create.schema + sqlTableName = create.tableName + else: + sqlSchema = None + sqlTableName = create.schema + '.' + create.tableName if splitTables: header = beginHeader(pathToHeader + sqlTableName + ".h", namespace, nsList) - tableClass = toClassName(sqlTableName) - tableMember = toMemberName(sqlTableName) + tableClass = toClassName(sqlTableName, sqlSchema) + tableMember = toMemberName(sqlTableName, sqlSchema) tableNamespace = tableClass + "_" tableTemplateParameters = tableClass print(" namespace " + tableNamespace, file=header) @@ -682,7 +697,7 @@ def createHeader(): print(" struct _alias_t", file=header) print(" {", file=header) print( - ' static constexpr const char _literal[] = "' + sqlTableName + '";', + ' static constexpr const char _literal[] = "' + (sqlSchema + '.' if sqlSchema else '') + sqlTableName + '";', file=header, ) print( diff --git a/tests/mysql/usage/TabSample.sql b/tests/mysql/usage/TabSample.sql index 380cf58e4..2f90f9698 100644 --- a/tests/mysql/usage/TabSample.sql +++ b/tests/mysql/usage/TabSample.sql @@ -1,4 +1,4 @@ -CREATE TABLE tab_sample ( +CREATE TABLE public.tab_sample ( alpha bigint(20) DEFAULT NULL AUTO_INCREMENT, beta tinyint(1) DEFAULT NULL, gamma varchar(255) DEFAULT NULL