Skip to content

Commit

Permalink
Merge pull request #30 from ethan-lowman-fp/main
Browse files Browse the repository at this point in the history
Support newtype refs by mirroring diesel's AsExpression derive
  • Loading branch information
quodlibetor authored Mar 17, 2024
2 parents e582c18 + eb9e3c1 commit 23001b5
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 26 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# Unreleased

* Add support for structs with internal references to DieselNewTypes (`ethan-lowman-fp` [#30](https://github.com/quodlibetor/diesel-derive-newtype/pull/30)):

```rust
#[derive(DieselNewType)]
pub struct MyIdString(String);

#[derive(Insertable, Queryable)]
#[diesel(table_name = my_entities)]
pub struct NewMyEntity<'a> {
id: &'a MyIdString, // <-- &'a of DieselNewType
}
```

# 2.1.0

* Update for Diesel 2.1 (`@marhag87`), not compatible with Diesel 2.0.x.
Expand Down
25 changes: 19 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ fn expand_sql_types(ast: &syn::DeriveInput) -> TokenStream {

// Required to be able to insert/read from the db, don't allow searching
let to_sql_impl = gen_tosql(&name, &wrapped_ty);
let as_expr_impl = gen_asexpresions(&name, &wrapped_ty);
let as_expr_impl = gen_asexpressions(&name, &wrapped_ty);

// raw deserialization
let from_sql_impl = gen_from_sql(&name, &wrapped_ty);
Expand Down Expand Up @@ -192,7 +192,7 @@ fn gen_tosql(name: &syn::Ident, wrapped_ty: &syn::Type) -> TokenStream {
}
}

fn gen_asexpresions(name: &syn::Ident, wrapped_ty: &syn::Type) -> TokenStream {
fn gen_asexpressions(name: &syn::Ident, wrapped_ty: &syn::Type) -> TokenStream {
quote! {

impl<ST> diesel::expression::AsExpression<ST> for #name
Expand All @@ -201,10 +201,10 @@ fn gen_asexpresions(name: &syn::Ident, wrapped_ty: &syn::Type) -> TokenStream {
diesel::expression::Expression<SqlType=ST>,
ST: diesel::sql_types::SingleValue,
{
type Expression = diesel::internal::derives::as_expression::Bound<ST, #wrapped_ty>;
type Expression = diesel::internal::derives::as_expression::Bound<ST, Self>;

fn as_expression(self) -> Self::Expression {
diesel::internal::derives::as_expression::Bound::new(self.0)
diesel::internal::derives::as_expression::Bound::new(self)
}
}

Expand All @@ -214,10 +214,23 @@ fn gen_asexpresions(name: &syn::Ident, wrapped_ty: &syn::Type) -> TokenStream {
diesel::expression::Expression<SqlType=ST>,
ST: diesel::sql_types::SingleValue,
{
type Expression = diesel::internal::derives::as_expression::Bound<ST, &'expr #wrapped_ty>;
type Expression = diesel::internal::derives::as_expression::Bound<ST, Self>;

fn as_expression(self) -> Self::Expression {
diesel::internal::derives::as_expression::Bound::new(&self.0)
diesel::internal::derives::as_expression::Bound::new(self)
}
}

impl<'expr2, 'expr, ST> diesel::expression::AsExpression<ST> for &'expr2 &'expr #name
where
diesel::internal::derives::as_expression::Bound<ST, #wrapped_ty>:
diesel::expression::Expression<SqlType=ST>,
ST: diesel::sql_types::SingleValue,
{
type Expression = diesel::internal::derives::as_expression::Bound<ST, Self>;

fn as_expression(self) -> Self::Expression {
diesel::internal::derives::as_expression::Bound::new(self)
}
}
}
Expand Down
159 changes: 139 additions & 20 deletions tests/db-roundtrips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,43 @@ use diesel::sqlite::SqliteConnection;
use diesel_derive_newtype::DieselNewType;

#[derive(Debug, Clone, PartialEq, Eq, Hash, DieselNewType)]
pub struct MyId(String);
pub struct MyIdString(String);

#[derive(Debug, Clone, PartialEq, Eq, Hash, DieselNewType)]
pub struct MyI32(i32);

#[derive(Debug, Clone, PartialEq, Eq, Hash, DieselNewType)]
pub struct MyNullableString(Option<String>);

#[derive(Debug, Clone, PartialEq, Eq, Hash, DieselNewType)]
pub struct MyNullableI32(Option<i32>);

#[derive(Debug, Clone, PartialEq, Identifiable, Insertable, Queryable)]
#[diesel(table_name = my_entities)]
pub struct MyEntity {
id: MyId,
id: MyIdString,
my_i32: MyI32,
my_nullable_string: MyNullableString,
my_nullable_i32: MyNullableI32,
val: i32,
}

#[derive(Debug, Clone, PartialEq, Insertable)]
#[diesel(table_name = my_entities)]
pub struct MyEntityInternalRefs<'a> {
id: &'a MyIdString,
my_i32: MyI32,
my_nullable_string: &'a MyNullableString,
my_nullable_i32: &'a MyNullableI32,
val: i32,
}

table! {
my_entities {
id -> Text,
my_i32 -> Integer,
my_nullable_string -> Nullable<Text>,
my_nullable_i32 -> Nullable<Integer>,
val -> Integer,
}
}
Expand All @@ -24,9 +49,12 @@ table! {
fn setup() -> SqliteConnection {
let mut conn = SqliteConnection::establish(":memory:").unwrap();
let setup = sql::<diesel::sql_types::Bool>(
"CREATE TABLE IF NOT EXISTS my_entities (
"CREATE TABLE my_entities (
id TEXT PRIMARY KEY,
val Int
my_i32 int NOT NULL,
my_nullable_string TEXT,
my_nullable_i32 int,
val Int NOT NULL
)",
);
setup.execute(&mut conn).expect("Can't create table");
Expand All @@ -37,7 +65,56 @@ fn setup() -> SqliteConnection {
fn does_roundtrip() {
let mut conn = setup();
let obj = MyEntity {
id: MyId("WooHoo".into()),
id: MyIdString("WooHoo".into()),
my_i32: MyI32(10),
my_nullable_string: MyNullableString(Some("WooHoo".into())),
my_nullable_i32: MyNullableI32(Some(10)),
val: 1,
};

diesel::insert_into(my_entities::table)
.values(&obj)
.execute(&mut conn)
.expect("Couldn't insert struct into my_entities");

let found: Vec<MyEntity> = my_entities::table.load(&mut conn).unwrap();
println!("found: {:?}", found);
assert_eq!(found[0], obj);
}

#[test]
fn does_roundtrip_with_ref() {
let mut conn = setup();
let obj = MyEntityInternalRefs {
id: &MyIdString("WooHoo".into()),
my_i32: MyI32(10),
my_nullable_string: &MyNullableString(Some("WooHoo".into())),
my_nullable_i32: &MyNullableI32(Some(10)),
val: 1,
};

diesel::insert_into(my_entities::table)
.values(&obj)
.execute(&mut conn)
.expect("Couldn't insert struct into my_entities");

let found: Vec<MyEntity> = my_entities::table.load(&mut conn).unwrap();
println!("found: {:?}", found);
assert_eq!(found[0].id, *obj.id);
assert_eq!(found[0].my_i32, obj.my_i32);
assert_eq!(found[0].my_nullable_string, *obj.my_nullable_string);
assert_eq!(found[0].my_nullable_i32, *obj.my_nullable_i32);
assert_eq!(found[0].val, obj.val);
}

#[test]
fn does_roundtrip_nulls() {
let mut conn = setup();
let obj = MyEntity {
id: MyIdString("WooHoo".into()),
my_i32: MyI32(10),
my_nullable_string: MyNullableString(None),
my_nullable_i32: MyNullableI32(None),
val: 1,
};

Expand All @@ -56,11 +133,17 @@ fn queryable() {
let mut conn = setup();
let objs = vec![
MyEntity {
id: MyId("WooHoo".into()),
id: MyIdString("WooHoo".into()),
my_i32: MyI32(10),
my_nullable_string: MyNullableString(Some("WooHoo".into())),
my_nullable_i32: MyNullableI32(Some(10)),
val: 1,
},
MyEntity {
id: MyId("boo".into()),
id: MyIdString("boo".into()),
my_i32: MyI32(20),
my_nullable_string: MyNullableString(None),
my_nullable_i32: MyNullableI32(None),
val: 2,
},
];
Expand All @@ -70,7 +153,7 @@ fn queryable() {
.execute(&mut conn)
.expect("Couldn't insert struct into my_entities");

let ids: Vec<MyId> = my_entities::table
let ids: Vec<MyIdString> = my_entities::table
.select(my_entities::columns::id)
.load(&mut conn)
.unwrap();
Expand All @@ -82,17 +165,26 @@ fn queryable() {
fn query_as_id() {
let mut conn = setup();
let expected = MyEntity {
id: MyId("WooHoo".into()),
id: MyIdString("WooHoo".into()),
my_i32: MyI32(10),
my_nullable_string: MyNullableString(Some("WooHoo".into())),
my_nullable_i32: MyNullableI32(Some(10)),
val: 1,
};
let objs = vec![
MyEntity {
id: MyId("loop".into()),
id: MyIdString("loop".into()),
my_i32: MyI32(0),
my_nullable_string: MyNullableString(Some("loop".into())),
my_nullable_i32: MyNullableI32(Some(0)),
val: 0,
},
expected.clone(),
MyEntity {
id: MyId("boo".into()),
id: MyIdString("boo".into()),
my_i32: MyI32(20),
my_nullable_string: MyNullableString(None),
my_nullable_i32: MyNullableI32(None),
val: 2,
},
];
Expand All @@ -103,7 +195,7 @@ fn query_as_id() {
.expect("Couldn't insert struct into my_entities");

let ids: Vec<MyEntity> = my_entities::table
.filter(my_entities::id.eq(MyId("WooHoo".into())))
.filter(my_entities::id.eq(MyIdString("WooHoo".into())))
.load(&mut conn)
.unwrap();
assert_eq!(ids, vec![expected])
Expand All @@ -113,17 +205,26 @@ fn query_as_id() {
fn query_as_underlying_type() {
let mut conn = setup();
let expected = MyEntity {
id: MyId("WooHoo".into()),
id: MyIdString("WooHoo".into()),
my_i32: MyI32(10),
my_nullable_string: MyNullableString(Some("WooHoo".into())),
my_nullable_i32: MyNullableI32(Some(10)),
val: 1,
};
let objs = vec![
MyEntity {
id: MyId("loop".into()),
my_i32: MyI32(0),
id: MyIdString("loop".into()),
my_nullable_string: MyNullableString(Some("loop".into())),
my_nullable_i32: MyNullableI32(Some(0)),
val: 0,
},
expected.clone(),
MyEntity {
id: MyId("boo".into()),
id: MyIdString("boo".into()),
my_i32: MyI32(20),
my_nullable_string: MyNullableString(None),
my_nullable_i32: MyNullableI32(None),
val: 2,
},
];
Expand All @@ -144,17 +245,26 @@ fn query_as_underlying_type() {
fn set() {
let mut conn = setup();
let expected = MyEntity {
id: MyId("WooHoo".into()),
id: MyIdString("WooHoo".into()),
my_i32: MyI32(10),
my_nullable_string: MyNullableString(Some("WooHoo".into())),
my_nullable_i32: MyNullableI32(Some(10)),
val: 1,
};
let objs = vec![
MyEntity {
id: MyId("loop".into()),
id: MyIdString("loop".into()),
my_i32: MyI32(0),
my_nullable_string: MyNullableString(Some("loop".into())),
my_nullable_i32: MyNullableI32(Some(0)),
val: 0,
},
expected.clone(),
MyEntity {
id: MyId("boo".into()),
id: MyIdString("boo".into()),
my_i32: MyI32(20),
my_nullable_string: MyNullableString(None),
my_nullable_i32: MyNullableI32(None),
val: 2,
},
];
Expand All @@ -164,7 +274,7 @@ fn set() {
.execute(&mut conn)
.expect("Couldn't insert struct into my_entities");

let new_id = MyId("Oh My".into());
let new_id = MyIdString("Oh My".into());
diesel::update(my_entities::table.find(&expected.id))
.set(my_entities::id.eq(&new_id))
.execute(&mut conn)
Expand All @@ -173,5 +283,14 @@ fn set() {
.filter(my_entities::id.eq(&new_id))
.load(&mut conn)
.unwrap();
assert_eq!(updated_ids, vec![MyEntity { id: new_id, val: 1 }])
assert_eq!(
updated_ids,
vec![MyEntity {
id: new_id,
my_i32: MyI32(10),
my_nullable_string: MyNullableString(Some("WooHoo".into())),
my_nullable_i32: MyNullableI32(Some(10)),
val: 1
}]
)
}

0 comments on commit 23001b5

Please sign in to comment.