From 8356f30f8e649d0e4d3876363afccf26c046eada Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Wed, 2 Oct 2024 17:31:08 +0000 Subject: [PATCH] consolidate escape logic, match clickhouse-connect --- src/sql/escape.rs | 42 +++++++++++++++++++----------------------- src/sql/ser.rs | 10 +--------- 2 files changed, 20 insertions(+), 32 deletions(-) diff --git a/src/sql/escape.rs b/src/sql/escape.rs index e43cde6a..fae17063 100644 --- a/src/sql/escape.rs +++ b/src/sql/escape.rs @@ -1,35 +1,31 @@ use std::fmt; +// Trust clickhouse-connect https://github.com/ClickHouse/clickhouse-connect/blob/5d85563410f3ec378cb199ec51d75e033211392c/clickhouse_connect/driver/binding.py#L15 + // See https://clickhouse.tech/docs/en/sql-reference/syntax/#syntax-string-literal -pub(crate) fn string(src: &str, dst: impl fmt::Write) -> fmt::Result { - escape(src, dst, '\'') +pub(crate) fn string(src: &str, dst: &mut impl fmt::Write) -> fmt::Result { + dst.write_char('\'')?; + escape(src, dst)?; + dst.write_char('\'') } // See https://clickhouse.tech/docs/en/sql-reference/syntax/#syntax-identifiers -pub(crate) fn identifier(src: &str, dst: impl fmt::Write) -> fmt::Result { - escape(src, dst, '`') +pub(crate) fn identifier(src: &str, dst: &mut impl fmt::Write) -> fmt::Result { + dst.write_char('\'')?; + escape(src, dst)?; + dst.write_char('\'') } -fn escape(src: &str, mut dst: impl fmt::Write, ch: char) -> fmt::Result { - dst.write_char(ch)?; - - // TODO: escape newlines? - for (idx, part) in src.split(ch).enumerate() { - if idx > 0 { - dst.write_char('\\')?; - dst.write_char(ch)?; - } - - for (idx, part) in part.split('\\').enumerate() { - if idx > 0 { - dst.write_str("\\\\")?; - } - - dst.write_str(part)?; - } +pub(crate) fn escape(src: &str, dst: &mut impl fmt::Write) -> fmt::Result { + const REPLACE: &[char] = &['\\', '\'', '`', '\t', '\n']; + let mut rest = src; + while let Some(nextidx) = rest.find(REPLACE) { + let (before, after) = rest.split_at(nextidx); + rest = after; + dst.write_str(before)?; + dst.write_char('\\')?; } - - dst.write_char(ch) + dst.write_str(rest) } #[test] diff --git a/src/sql/ser.rs b/src/sql/ser.rs index 5715c12a..6cbdb22b 100644 --- a/src/sql/ser.rs +++ b/src/sql/ser.rs @@ -321,15 +321,7 @@ impl<'a, W: Write> Serializer for ParamSerializer<'a, W> { fn serialize_str(self, value: &str) -> Result { // ClickHouse expects strings in params to be unquoted until inside a nested type // nested types go through serialize_seq which'll quote strings - let mut rest = value; - while let Some(nextidx) = rest.find('\\') { - let (before, after) = rest.split_at(nextidx + 1); - rest = after; - self.writer.write_str(before)?; - self.writer.write_char('\\')?; - } - self.writer.write_str(rest)?; - Ok(()) + Ok(escape::escape(value, self.writer)?) } #[inline]