diff --git a/edb/edgeql/quote.py b/edb/edgeql/quote.py index 5455cc5dcf9..2859fa64d32 100644 --- a/edb/edgeql/quote.py +++ b/edb/edgeql/quote.py @@ -35,17 +35,25 @@ ''') -def escape_string(s: str) -> str: - split = re.split(r"(\n|\\\\|\\')", s) +def quote_literal(string: str) -> str: - if len(split) == 1: - return s.replace(r"'", r"\'") + def escape_string(s: str) -> str: + # characters escaped according to + # https://www.edgedb.com/docs/reference/edgeql/lexical#strings + result = s - return ''.join((r if i % 2 else r.replace(r"'", r"\'")) - for i, r in enumerate(split)) + # escape backslash first + result = result.replace('\\', '\\\\') + result = result.replace('\'', '\\\'') + result = result.replace('\b', '\\b') + result = result.replace('\f', '\\f') + result = result.replace('\n', '\\n') + result = result.replace('\r', '\\r') + result = result.replace('\t', '\\t') + + return result -def quote_literal(string: str) -> str: return "'" + escape_string(string) + "'" diff --git a/edb/server/compiler/ddl.py b/edb/server/compiler/ddl.py index db63446c909..f41cd50c09f 100644 --- a/edb/server/compiler/ddl.py +++ b/edb/server/compiler/ddl.py @@ -721,8 +721,6 @@ def _describe_current_migration( **extra, } ) - .encode('unicode_escape') - .decode('utf-8') ) desc_ql = edgeql.parse_query( diff --git a/tests/edgeql/__init__.py b/tests/edgeql/__init__.py new file mode 100644 index 00000000000..8de1eeb735c --- /dev/null +++ b/tests/edgeql/__init__.py @@ -0,0 +1,26 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2016-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import unittest + + +def suite(): + test_loader = unittest.TestLoader() + test_suite = test_loader.discover('.', pattern='test_*.py') + return test_suite diff --git a/tests/edgeql/test_quote.py b/tests/edgeql/test_quote.py new file mode 100644 index 00000000000..a8924972aa4 --- /dev/null +++ b/tests/edgeql/test_quote.py @@ -0,0 +1,45 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2018-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +import edb.edgeql.quote as qlquote + +class QuoteTests(unittest.TestCase): + + def test_quote_string(self): + self.assertEqual(qlquote.quote_literal(""), "''"), + self.assertEqual(qlquote.quote_literal("abc"), "'abc'") + self.assertEqual(qlquote.quote_literal("\""), "'\"'") + self.assertEqual(qlquote.quote_literal("\b"), "'\\b'") + self.assertEqual(qlquote.quote_literal("\f"), "'\\f'") + self.assertEqual(qlquote.quote_literal("\n"), "'\\n'") + self.assertEqual(qlquote.quote_literal("\r"), "'\\r'") + self.assertEqual(qlquote.quote_literal("\t"), "'\\t'") + self.assertEqual(qlquote.quote_literal("\'"), "'\\\''") + self.assertEqual(qlquote.quote_literal("\\"), "'\\\\'") + self.assertEqual(qlquote.quote_literal("\\b"), "'\\\\b'") + self.assertEqual(qlquote.quote_literal("\\f"), "'\\\\f'") + self.assertEqual(qlquote.quote_literal("\\n"), "'\\\\n'") + self.assertEqual(qlquote.quote_literal("\\r"), "'\\\\r'") + self.assertEqual(qlquote.quote_literal("\\t"), "'\\\\t'") + self.assertEqual(qlquote.quote_literal("\\\'"), "'\\\\\\\''") + self.assertEqual(qlquote.quote_literal("\\\\"), "'\\\\\\\\'") + self.assertEqual(qlquote.quote_literal( + "abc\"efg\nhij\'klm\\nop"), + "'abc\"efg\\nhij\\\'klm\\\\nop'")