Skip to content
This repository has been archived by the owner on Apr 25, 2023. It is now read-only.

Prevent decimal/float conversions where unstable. Allow casting of expressions. #292

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ connection-string = "0.1.10"
percent-encoding = "2"
tracing-core = "0.1"
async-trait = "0.1"
enumflags2 = "0.7"
thiserror = "1.0"
once_cell = "1.3"
num_cpus = "1.12"
Expand Down
Binary file modified db/test.db
Binary file not shown.
2 changes: 2 additions & 0 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
//! actual query building is in the [visitor](../visitor/index.html) module.
//!
//! For prelude, all important imports are in `quaint::ast::*`.
mod castable;
mod column;
mod compare;
mod conditions;
Expand All @@ -29,6 +30,7 @@ mod union;
mod update;
mod values;

pub use castable::*;
pub use column::{Column, DefaultValue, TypeDataLength, TypeFamily};
pub use compare::{Comparable, Compare, JsonCompare, JsonType};
pub use conditions::ConditionTree;
Expand Down
295 changes: 295 additions & 0 deletions src/ast/castable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
use enumflags2::{bitflags, BitFlags};
use std::borrow::Cow;

#[bitflags]
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq)]
enum CastDatabase {
Postgres = 1 << 0,
Mysql = 1 << 1,
Mssql = 1 << 2,
}

/// A typecast for an expression.
///
/// By default, casting is performed on all databases. To restrict this
/// behavior, use the corresponding methods
/// [on_postgres](struct.CastType.html#method.on_postgres),
/// [on_mysql](struct.CastType.html#method.on_mysql) or
/// [on_sql_server](struct.CastType.html#method.on_sql_server).
///
/// Always a no-op on SQLite.
#[derive(Debug, Clone, PartialEq)]
pub struct CastType<'a> {
kind: CastKind<'a>,
on_databases: BitFlags<CastDatabase>,
}

impl<'a> CastType<'a> {
/// A 16-bit integer.
///
/// - PostgreSQL: `int2`
/// - MySQL: `signed`
/// - SQL Server: `smallint`
pub fn int2() -> Self {
Self {
kind: CastKind::Int2,
on_databases: BitFlags::all(),
}
}

/// A 32-bit integer (int)
///
/// - PostgreSQL: `int4`
/// - MySQL: `signed`
/// - SQL Server: `int`
pub fn int4() -> Self {
Self {
kind: CastKind::Int4,
on_databases: BitFlags::all(),
}
}

/// A 64-bit integer (bigint)
///
/// - PostgreSQL: `int8`
/// - MySQL: `signed`
/// - SQL Server: `bigint`
pub fn int8() -> Self {
Self {
kind: CastKind::Int8,
on_databases: BitFlags::all(),
}
}

/// A 32-bit floating point number
///
/// - PostgreSQL: `float4`
/// - MySQL: `decimal`
/// - SQL Server: `real`
pub fn float4() -> Self {
Self {
kind: CastKind::Float4,
on_databases: BitFlags::all(),
}
}

/// A 64-bit floating point number
///
/// - PostgreSQL: `float8`
/// - MySQL: `decimal`
/// - SQL Server: `float`
pub fn float8() -> Self {
Self {
kind: CastKind::Float8,
on_databases: BitFlags::all(),
}
}

/// An arbitrary-precision numeric type
///
/// - PostgreSQL: `numeric`
/// - MySQL: `decimal`
/// - SQL Server: `numeric`
pub fn decimal() -> Self {
Self {
kind: CastKind::Decimal,
on_databases: BitFlags::all(),
}
}

/// True or false (or a bit)
///
/// - PostgreSQL: `boolean`
/// - MySQL: `unsigned`
/// - SQL Server: `bit`
pub fn boolean() -> Self {
Self {
kind: CastKind::Boolean,
on_databases: BitFlags::all(),
}
}

/// A unique identifier
///
/// - PostgreSQL: `uuid`
/// - MySQL: `char`
/// - SQL Server: `uniqueidentifier`
pub fn uuid() -> Self {
Self {
kind: CastKind::Uuid,
on_databases: BitFlags::all(),
}
}

/// Json data
///
/// - PostgreSQL: `json`
/// - MySQL: `nchar`
/// - SQL Server: `nvarchar`
pub fn json() -> Self {
Self {
kind: CastKind::Json,
on_databases: BitFlags::all(),
}
}

/// Jsonb data
///
/// - PostgreSQL: `jsonb`
/// - MySQL: `nchar`
/// - SQL Server: `nvarchar`
pub fn jsonb() -> Self {
Self {
kind: CastKind::Jsonb,
on_databases: BitFlags::all(),
}
}

/// Date value
///
/// - PostgreSQL: `date`
/// - MySQL: `date`
/// - SQL Server: `date`
pub fn date() -> Self {
Self {
kind: CastKind::Date,
on_databases: BitFlags::all(),
}
}

/// Time value
///
/// - PostgreSQL: `time`
/// - MySQL: `time`
/// - SQL Server: `time`
pub fn time() -> Self {
Self {
kind: CastKind::Time,
on_databases: BitFlags::all(),
}
}

/// Datetime value
///
/// - PostgreSQL: `datetime`
/// - MySQL: `datetime`
/// - SQL Server: `datetime2`
pub fn datetime() -> Self {
Self {
kind: CastKind::DateTime,
on_databases: BitFlags::all(),
}
}

/// Byte blob
///
/// - PostgreSQL: `bytea`
/// - MySQL: `binary`
/// - SQL Server: `bytes`
pub fn bytes() -> Self {
Self {
kind: CastKind::Bytes,
on_databases: BitFlags::all(),
}
}

/// Textual data
///
/// - PostgreSQL: `text`
/// - MySQL: `nchar`
/// - SQL Server: `nvarchar`
pub fn text() -> Self {
Self {
kind: CastKind::Text,
on_databases: BitFlags::all(),
}
}

/// Creates a new custom cast type.
pub fn custom(r#type: impl Into<Cow<'a, str>>) -> Self {
Self {
kind: CastKind::Custom(r#type.into()),
on_databases: BitFlags::all(),
}
}

/// Perform the given cast on PostgreSQL.
pub fn on_postgres(mut self) -> Self {
self.maybe_clear_databases();
self.on_databases.insert(CastDatabase::Postgres);

self
}

/// Perform the given cast on MySQL.
pub fn on_mysql(mut self) -> Self {
self.maybe_clear_databases();
self.on_databases.insert(CastDatabase::Mysql);

self
}

/// Perform the given cast on SQL Server.
pub fn on_sql_server(mut self) -> Self {
self.maybe_clear_databases();
self.on_databases.insert(CastDatabase::Mssql);

self
}

#[cfg(feature = "postgresql")]
pub(crate) fn postgres_enabled(&self) -> bool {
self.on_databases.contains(CastDatabase::Postgres)
}

#[cfg(feature = "mysql")]
pub(crate) fn mysql_enabled(&self) -> bool {
self.on_databases.contains(CastDatabase::Mysql)
}

#[cfg(feature = "mssql")]
pub(crate) fn mssql_enabled(&self) -> bool {
self.on_databases.contains(CastDatabase::Mssql)
}

#[cfg(any(feature = "mssql", feature = "mysql", feature = "mssql"))]
pub(crate) fn kind(&self) -> &CastKind<'a> {
&self.kind
}

fn maybe_clear_databases(&mut self) {
if self.on_databases.is_all() {
self.on_databases.remove(BitFlags::all());
}
}
}

#[derive(Debug, Clone, PartialEq)]
pub(crate) enum CastKind<'a> {
Int2,
Int4,
Int8,
Float4,
Float8,
Decimal,
Boolean,
Uuid,
Json,
Jsonb,
Date,
Time,
DateTime,
Bytes,
Text,
Custom(Cow<'a, str>),
}

/// An item that can be cast to a different type.
pub trait Castable<'a, T>
where
T: Sized,
{
/// Map the result of the underlying item into a different type.
fn cast_as(self, r#type: CastType<'a>) -> T;
}
15 changes: 15 additions & 0 deletions src/ast/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,36 @@ use crate::{
};
use std::borrow::Cow;

/// The maximum length of the column.
#[derive(Debug, Clone, Copy)]
pub enum TypeDataLength {
/// Number of either bytes or characters.
Constant(u16),
/// Stored outside of the row in the heap, usually either two or four
/// gigabytes.
Maximum,
}

/// The type family of the column.
#[derive(Debug, Clone, Copy)]
pub enum TypeFamily {
/// Textual data with an optional length.
Text(Option<TypeDataLength>),
/// Integers.
Int,
/// Floating point values, 32-bit.
Float,
/// Floating point values, 64-bit.
Double,
/// Trues and falses.
Boolean,
/// Unique identifiers.
Uuid,
/// Date, time and datetime.
DateTime,
/// Numerics with an arbitrary scale and precision.
Decimal(Option<(u8, u8)>),
/// Blobs with an optional length.
Bytes(Option<TypeDataLength>),
}

Expand Down Expand Up @@ -104,6 +118,7 @@ impl<'a> From<Column<'a>> for Expression<'a> {
Expression {
kind: ExpressionKind::Column(Box::new(col)),
alias: None,
cast: None,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/ast/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ impl<'a> From<Compare<'a>> for Expression<'a> {
Expression {
kind: ExpressionKind::Compare(cmp),
alias: None,
cast: None,
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/ast/conditions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ impl<'a> From<ConditionTree<'a>> for Expression<'a> {
Expression {
kind: ExpressionKind::ConditionTree(ct),
alias: None,
cast: None,
}
}
}
Expand All @@ -138,6 +139,7 @@ impl<'a> From<Select<'a>> for ConditionTree<'a> {
let exp = Expression {
kind: ExpressionKind::Value(Box::new(sel.into())),
alias: None,
cast: None,
};

ConditionTree::single(exp)
Expand Down
Loading