Skip to content

Commit

Permalink
refactor(db, webserver): Refactor InvitationDao to use query_as! (#1555)
Browse files Browse the repository at this point in the history
* refactor(db, webserver): Refactor InvitationDao to use query_as!

* Apply suggestion

* Make DateTimeUtc use String for type info

* Remove duplicate AsID impl
  • Loading branch information
boxbeam authored Feb 29, 2024
1 parent 18fdf8a commit 7f90102
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 28 deletions.
32 changes: 17 additions & 15 deletions ee/tabby-db/src/invitations.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use sqlx::{prelude::FromRow, query};
use uuid::Uuid;

use super::DbConn;
use crate::SQLXResultExt;
use crate::{DateTimeUtc, SQLXResultExt};

#[derive(FromRow)]
pub struct InvitationDAO {
pub id: i32,
pub id: i64,
pub email: String,
pub code: String,

pub created_at: DateTime<Utc>,
pub created_at: DateTimeUtc,
}

/// db read/write operations for `invitations` table
Expand All @@ -37,20 +36,23 @@ impl DbConn {
}

pub async fn get_invitation_by_code(&self, code: &str) -> Result<Option<InvitationDAO>> {
let token =
sqlx::query_as(r#"SELECT id, email, code, created_at FROM invitations WHERE code = ?"#)
.bind(code)
.fetch_optional(&self.pool)
.await?;
let token = sqlx::query_as!(
InvitationDAO,
r#"SELECT id as "id!", email, code, created_at as "created_at!" FROM invitations WHERE code = ?"#,
code
)
.fetch_optional(&self.pool)
.await?;

Ok(token)
}

pub async fn get_invitation_by_email(&self, email: &str) -> Result<Option<InvitationDAO>> {
let token = sqlx::query_as(
r#"SELECT id, email, code, created_at FROM invitations WHERE email = ?"#,
let token = sqlx::query_as!(
InvitationDAO,
r#"SELECT id as "id!", email, code, created_at as "created_at!" FROM invitations WHERE email = ?"#,
email
)
.bind(email)
.fetch_optional(&self.pool)
.await?;

Expand All @@ -63,7 +65,7 @@ impl DbConn {
}

let code = Uuid::new_v4().to_string();
let created_at = chrono::offset::Utc::now();
let created_at = chrono::offset::Utc::now().into();
let res = query!(
"INSERT INTO invitations (email, code, created_at) VALUES (?, ?, ?)",
email,
Expand All @@ -74,7 +76,7 @@ impl DbConn {
.await;

let res = res.unique_error("Failed to create invitation, email already exists")?;
let id = res.last_insert_rowid() as i32;
let id = res.last_insert_rowid();

Ok(InvitationDAO {
id,
Expand All @@ -84,7 +86,7 @@ impl DbConn {
})
}

pub async fn delete_invitation(&self, id: i32) -> Result<i32> {
pub async fn delete_invitation(&self, id: i64) -> Result<i64> {
let res = query!("DELETE FROM invitations WHERE id = ?", id)
.execute(&self.pool)
.await?;
Expand Down
37 changes: 34 additions & 3 deletions ee/tabby-db/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ pub use invitations::InvitationDAO;
pub use job_runs::JobRunDAO;
pub use repositories::RepositoryDAO;
pub use server_setting::ServerSettingDAO;
use sqlx::{query, query_scalar, sqlite::SqliteQueryResult, Pool, Sqlite, SqlitePool};
use sqlx::{
query, query_scalar, sqlite::SqliteQueryResult, Pool, Sqlite, SqlitePool, Type, Value, ValueRef,
};
pub use users::UserDAO;

pub mod cache;
Expand Down Expand Up @@ -200,10 +202,39 @@ impl DbConn {

pub struct DateTimeUtc(DateTime<Utc>);

impl From<DateTime<Utc>> for DateTimeUtc {
fn from(value: DateTime<Utc>) -> Self {
Self(value)
}
}

impl<'a> sqlx::Decode<'a, Sqlite> for DateTimeUtc {
fn decode(
value: <Sqlite as sqlx::database::HasValueRef<'a>>::ValueRef,
) -> std::prelude::v1::Result<Self, sqlx::error::BoxDynError> {
let time: NaiveDateTime = value.to_owned().decode();
Ok(time.into())
}
}

impl Type<Sqlite> for DateTimeUtc {
fn type_info() -> <Sqlite as sqlx::Database>::TypeInfo {
<String as Type<Sqlite>>::type_info()
}
}

impl<'a> sqlx::Encode<'a, Sqlite> for DateTimeUtc {
fn encode_by_ref(
&self,
buf: &mut <Sqlite as sqlx::database::HasArguments<'a>>::ArgumentBuffer,
) -> sqlx::encode::IsNull {
self.0.encode_by_ref(buf)
}
}

impl From<NaiveDateTime> for DateTimeUtc {
fn from(value: NaiveDateTime) -> Self {
let utc = DateTime::from_naive_utc_and_offset(value, Utc);
DateTimeUtc(utc)
DateTimeUtc(value.and_utc())
}
}

Expand Down
13 changes: 9 additions & 4 deletions ee/tabby-webserver/src/service/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl AuthenticationService for AuthenticationServiceImpl {
email.clone(),
pwd_hash,
!is_admin_initialized,
invitation.id,
invitation.id as i32,
)
.await?
} else {
Expand Down Expand Up @@ -308,7 +308,7 @@ impl AuthenticationService for AuthenticationServiceImpl {
}

async fn delete_invitation(&self, id: &ID) -> Result<ID> {
Ok(self.db.delete_invitation(id.as_rowid()?).await?.as_id())
Ok((self.db.delete_invitation(id.as_rowid()? as i64).await?).as_id())
}

async fn reset_user_auth_token(&self, id: &ID) -> Result<()> {
Expand Down Expand Up @@ -474,7 +474,12 @@ async fn get_or_create_oauth_user(db: &DbConn, email: &str) -> Result<(i32, bool
};
// safe to create with empty password for same reasons above
let id = db
.create_user_with_invitation(email.to_owned(), "".to_owned(), false, invitation.id)
.create_user_with_invitation(
email.to_owned(),
"".to_owned(),
false,
invitation.id as i32,
)
.await?;
let user = db.get_user(id).await?.unwrap();
Ok((user.id, user.is_admin))
Expand Down Expand Up @@ -743,7 +748,7 @@ mod tests {
// Used invitation should have been deleted, following delete attempt should fail.
assert!(service
.db
.delete_invitation(invitation.id.as_rowid().unwrap())
.delete_invitation(invitation.id.as_rowid().unwrap() as i64)
.await
.is_err());
}
Expand Down
12 changes: 6 additions & 6 deletions ee/tabby-webserver/src/service/dao.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ use crate::schema::{
impl From<InvitationDAO> for auth::Invitation {
fn from(val: InvitationDAO) -> Self {
Self {
id: val.id.as_id(),
id: (val.id as i32).as_id(),
email: val.email,
code: val.code,
created_at: val.created_at,
created_at: *val.created_at,
}
}
}
Expand Down Expand Up @@ -154,15 +154,15 @@ pub trait AsID {
fn as_id(&self) -> juniper::ID;
}

impl AsID for i32 {
impl AsID for i64 {
fn as_id(&self) -> juniper::ID {
(*self as i64).as_id()
juniper::ID::new(HASHER.encode(&[*self as u64]))
}
}

impl AsID for i64 {
impl AsID for i32 {
fn as_id(&self) -> juniper::ID {
juniper::ID::new(HASHER.encode(&[*self as u64]))
(*self as i64).as_id()
}
}

Expand Down

0 comments on commit 7f90102

Please sign in to comment.