Skip to content
This repository has been archived by the owner on Apr 25, 2023. It is now read-only.

Commit

Permalink
Expression casting
Browse files Browse the repository at this point in the history
Introduces a new trait `Castable`, implemented so far by anything that
can be converted to an `Expression`.

Usage:

```rust
Select::default().value(1.cast_to(CastType::int2()))
```

Will convert the given value to an `int2` or a corresponding type in the
database. Casting will happen on all visitors except SQLite. To
whitelist where casting should happen, it should happen in the
`CastType`:

```rust
Select::default().value(1.cast_to(CastType::int2().on_postgres()))
```

This will only cast on PostgreSQL. The methods can be chained:

```
1.cast_to(CastType::int2().on_postgres().on_sql_server());
```

There are certain restrictions what can be casted, and the casting
mechanism tries to find the closest allowed type from each database.
  • Loading branch information
Julius de Bruijn committed May 18, 2021
1 parent 042e341 commit b25718d
Show file tree
Hide file tree
Showing 16 changed files with 745 additions and 38 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ connection-string = "0.1.10"
percent-encoding = "2"
tracing-core = "0.1"
async-trait = "0.1"
enumflags2 = "0.7"
thiserror = "1.0"
once_cell = "1.3"
num_cpus = "1.12"
Expand Down
2 changes: 2 additions & 0 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
//! actual query building is in the [visitor](../visitor/index.html) module.
//!
//! For prelude, all important imports are in `quaint::ast::*`.
mod castable;
mod column;
mod compare;
mod conditions;
Expand All @@ -29,6 +30,7 @@ mod union;
mod update;
mod values;

pub use castable::*;
pub use column::{Column, DefaultValue, TypeDataLength, TypeFamily};
pub use compare::{Comparable, Compare};
pub use conditions::ConditionTree;
Expand Down
291 changes: 291 additions & 0 deletions src/ast/castable.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
use enumflags2::{bitflags, BitFlags};
use std::borrow::Cow;

#[bitflags]
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq)]
enum CastDatabase {
Postgres = 1 << 0,
Mysql = 1 << 1,
Mssql = 1 << 2,
}

/// A typecast for an expression.
///
/// By default, casting is performed on all databases. To restrict this
/// behavior, use the corresponding methods
/// [on_postgres](struct.CastType.html#method.on_postgres),
/// [on_mysql](struct.CastType.html#method.on_mysql) or
/// [on_sql_server](struct.CastType.html#method.on_sql_server).
///
/// Always a no-op on SQLite.
#[derive(Debug, Clone, PartialEq)]
pub struct CastType<'a> {
kind: CastKind<'a>,
on_databases: BitFlags<CastDatabase>,
}

impl<'a> CastType<'a> {
/// A 16-bit integer.
///
/// - PostgreSQL: `int2`
/// - MySQL: `signed`
/// - SQL Server: `smallint`
pub fn int2() -> Self {
Self {
kind: CastKind::Int2,
on_databases: BitFlags::all(),
}
}

/// A 32-bit integer (int)
///
/// - PostgreSQL: `int4`
/// - MySQL: `signed`
/// - SQL Server: `int`
pub fn int4() -> Self {
Self {
kind: CastKind::Int4,
on_databases: BitFlags::all(),
}
}

/// A 64-bit integer (bigint)
///
/// - PostgreSQL: `int8`
/// - MySQL: `signed`
/// - SQL Server: `bigint`
pub fn int8() -> Self {
Self {
kind: CastKind::Int8,
on_databases: BitFlags::all(),
}
}

/// A 32-bit floating point number
///
/// - PostgreSQL: `float4`
/// - MySQL: `decimal`
/// - SQL Server: `real`
pub fn float4() -> Self {
Self {
kind: CastKind::Float4,
on_databases: BitFlags::all(),
}
}

/// A 64-bit floating point number
///
/// - PostgreSQL: `float8`
/// - MySQL: `decimal`
/// - SQL Server: `float`
pub fn float8() -> Self {
Self {
kind: CastKind::Float8,
on_databases: BitFlags::all(),
}
}

/// An arbitrary-precision numeric type
///
/// - PostgreSQL: `numeric`
/// - MySQL: `decimal`
/// - SQL Server: `numeric`
pub fn decimal() -> Self {
Self {
kind: CastKind::Decimal,
on_databases: BitFlags::all(),
}
}

/// True or false (or a bit)
///
/// - PostgreSQL: `boolean`
/// - MySQL: `unsigned`
/// - SQL Server: `bit`
pub fn boolean() -> Self {
Self {
kind: CastKind::Boolean,
on_databases: BitFlags::all(),
}
}

/// A unique identifier
///
/// - PostgreSQL: `uuid`
/// - MySQL: `char`
/// - SQL Server: `uniqueidentifier`
pub fn uuid() -> Self {
Self {
kind: CastKind::Uuid,
on_databases: BitFlags::all(),
}
}

/// Json data
///
/// - PostgreSQL: `json`
/// - MySQL: `nchar`
/// - SQL Server: `nvarchar`
pub fn json() -> Self {
Self {
kind: CastKind::Json,
on_databases: BitFlags::all(),
}
}

/// Jsonb data
///
/// - PostgreSQL: `jsonb`
/// - MySQL: `nchar`
/// - SQL Server: `nvarchar`
pub fn jsonb() -> Self {
Self {
kind: CastKind::Jsonb,
on_databases: BitFlags::all(),
}
}

/// Date value
///
/// - PostgreSQL: `date`
/// - MySQL: `date`
/// - SQL Server: `date`
pub fn date() -> Self {
Self {
kind: CastKind::Date,
on_databases: BitFlags::all(),
}
}

/// Time value
///
/// - PostgreSQL: `time`
/// - MySQL: `time`
/// - SQL Server: `time`
pub fn time() -> Self {
Self {
kind: CastKind::Time,
on_databases: BitFlags::all(),
}
}

/// Datetime value
///
/// - PostgreSQL: `datetime`
/// - MySQL: `datetime`
/// - SQL Server: `datetime2`
pub fn datetime() -> Self {
Self {
kind: CastKind::DateTime,
on_databases: BitFlags::all(),
}
}

/// Byte blob
///
/// - PostgreSQL: `bytea`
/// - MySQL: `binary`
/// - SQL Server: `bytes`
pub fn bytes() -> Self {
Self {
kind: CastKind::Bytes,
on_databases: BitFlags::all(),
}
}

/// Textual data
///
/// - PostgreSQL: `text`
/// - MySQL: `nchar`
/// - SQL Server: `nvarchar`
pub fn text() -> Self {
Self {
kind: CastKind::Text,
on_databases: BitFlags::all(),
}
}

/// Creates a new custom cast type.
pub fn custom(r#type: impl Into<Cow<'a, str>>) -> Self {
Self {
kind: CastKind::Custom(r#type.into()),
on_databases: BitFlags::all(),
}
}

/// Perform the given cast on PostgreSQL.
pub fn on_postgres(mut self) -> Self {
self.maybe_clear_databases();
self.on_databases.insert(CastDatabase::Postgres);

self
}

/// Perform the given cast on MySQL.
pub fn on_mysql(mut self) -> Self {
self.maybe_clear_databases();
self.on_databases.insert(CastDatabase::Mysql);

self
}

/// Perform the given cast on SQL Server.
pub fn on_sql_server(mut self) -> Self {
self.maybe_clear_databases();
self.on_databases.insert(CastDatabase::Mssql);

self
}

pub(crate) fn postgres_enabled(&self) -> bool {
self.on_databases.contains(CastDatabase::Postgres)
}

pub(crate) fn mysql_enabled(&self) -> bool {
self.on_databases.contains(CastDatabase::Mysql)
}

pub(crate) fn mssql_enabled(&self) -> bool {
self.on_databases.contains(CastDatabase::Mssql)
}

pub(crate) fn kind(&self) -> &CastKind<'a> {
&self.kind
}

fn maybe_clear_databases(&mut self) {
if self.on_databases.is_all() {
self.on_databases.remove(BitFlags::all());
}
}
}

#[derive(Debug, Clone, PartialEq)]
pub(crate) enum CastKind<'a> {
Int2,
Int4,
Int8,
Float4,
Float8,
Decimal,
Boolean,
Uuid,
Json,
Jsonb,
Date,
Time,
DateTime,
Bytes,
Text,
Custom(Cow<'a, str>),
}

/// An item that can be cast to a different type.
pub trait Castable<'a, T>
where
T: Sized,
{
/// Map the result of the underlying item into a different type.
fn cast_as(self, r#type: CastType<'a>) -> T;
}
15 changes: 15 additions & 0 deletions src/ast/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,36 @@ use crate::{
};
use std::borrow::Cow;

/// The maximum length of the column.
#[derive(Debug, Clone, Copy)]
pub enum TypeDataLength {
/// Number of either bytes or characters.
Constant(u16),
/// Stored outside of the row in the heap, usually either two or four
/// gigabytes.
Maximum,
}

/// The type family of the column.
#[derive(Debug, Clone, Copy)]
pub enum TypeFamily {
/// Textual data with an optional length.
Text(Option<TypeDataLength>),
/// Integers.
Int,
/// Floating point values, 32-bit.
Float,
/// Floating point values, 64-bit.
Double,
/// Trues and falses.
Boolean,
/// Unique identifiers.
Uuid,
/// Date, time and datetime.
DateTime,
/// Numerics with an arbitrary scale and precision.
Decimal(Option<(u8, u8)>),
/// Blobs with an optional length.
Bytes(Option<TypeDataLength>),
}

Expand Down Expand Up @@ -104,6 +118,7 @@ impl<'a> From<Column<'a>> for Expression<'a> {
Expression {
kind: ExpressionKind::Column(Box::new(col)),
alias: None,
cast: None,
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/ast/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ impl<'a> From<Compare<'a>> for Expression<'a> {
Expression {
kind: ExpressionKind::Compare(cmp),
alias: None,
cast: None,
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/ast/conditions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ impl<'a> From<ConditionTree<'a>> for Expression<'a> {
Expression {
kind: ExpressionKind::ConditionTree(ct),
alias: None,
cast: None,
}
}
}
Expand All @@ -138,6 +139,7 @@ impl<'a> From<Select<'a>> for ConditionTree<'a> {
let exp = Expression {
kind: ExpressionKind::Value(Box::new(sel.into())),
alias: None,
cast: None,
};

ConditionTree::single(exp)
Expand Down
Loading

0 comments on commit b25718d

Please sign in to comment.