Skip to content

Commit

Permalink
Fix WASM transaction binding
Browse files Browse the repository at this point in the history
Correctly gets to a point of starting the tranasction and executing the
query, fails on parsing the results like normal queries do.
  • Loading branch information
Sergey Tatarintsev committed Nov 23, 2023
1 parent 272228e commit 39b2489
Show file tree
Hide file tree
Showing 13 changed files with 112 additions and 61 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions psl/psl-core/src/datamodel_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,7 @@ pub trait Connector: Send + Sync {
}
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen::prelude::wasm_bindgen)]
#[derive(Copy, Clone, Debug, PartialEq, Default, serde::Deserialize)]
pub enum Flavour {
#[default]
Cockroach,
Mongo,
Sqlserver,
Expand Down
2 changes: 1 addition & 1 deletion query-engine/driver-adapters/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ async-trait = "0.1"
once_cell = "1.15"
serde.workspace = true
serde_json.workspace = true
psl.workspace = true
tracing = "0.1"
tracing-core = "0.1"
metrics = "0.18"
Expand All @@ -21,6 +20,7 @@ num-bigint = "0.4.3"
bigdecimal = "0.3.0"
chrono = "0.4.20"
futures = "0.3"
web-sys = "0.3.65"

[dev-dependencies]
expect-test = "1"
Expand Down
25 changes: 11 additions & 14 deletions query-engine/driver-adapters/src/queryable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ pub(crate) use wasm::JsBaseQueryable;
use super::{
conversion,
proxy::{CommonProxy, DriverProxy, Query},
types::AdapterFlavour,
};
use crate::send_future::SendFuture;
use async_trait::async_trait;
use futures::Future;
use psl::datamodel_connector::Flavour;
use quaint::{
connector::{metrics, IsolationLevel, Transaction},
error::{Error, ErrorKind},
Expand All @@ -34,28 +34,26 @@ use tracing::{info_span, Instrument};

impl JsBaseQueryable {
pub(crate) fn new(proxy: CommonProxy) -> Self {
let flavour: Flavour = proxy.flavour.parse().unwrap();
let flavour: AdapterFlavour = proxy.flavour.parse().unwrap();
Self { proxy, flavour }
}

/// visit a quaint query AST according to the flavour of the JS connector
fn visit_quaint_query<'a>(&self, q: QuaintQuery<'a>) -> quaint::Result<(String, Vec<quaint::Value<'a>>)> {
match self.flavour {
Flavour::Mysql => visitor::Mysql::build(q),
Flavour::Postgres => visitor::Postgres::build(q),
Flavour::Sqlite => visitor::Sqlite::build(q),
_ => unimplemented!("Unsupported flavour for JS connector {:?}", self.flavour),
AdapterFlavour::Mysql => visitor::Mysql::build(q),
AdapterFlavour::Postgres => visitor::Postgres::build(q),
AdapterFlavour::Sqlite => visitor::Sqlite::build(q),
}
}

async fn build_query(&self, sql: &str, values: &[quaint::Value<'_>]) -> quaint::Result<Query> {
let sql: String = sql.to_string();

let converter = match self.flavour {
Flavour::Postgres => conversion::postgres::value_to_js_arg,
Flavour::Sqlite => conversion::sqlite::value_to_js_arg,
Flavour::Mysql => conversion::mysql::value_to_js_arg,
_ => unreachable!("Unsupported flavour for JS connector {:?}", self.flavour),
AdapterFlavour::Postgres => conversion::postgres::value_to_js_arg,
AdapterFlavour::Sqlite => conversion::sqlite::value_to_js_arg,
AdapterFlavour::Mysql => conversion::mysql::value_to_js_arg,
};

let args = values
Expand Down Expand Up @@ -127,7 +125,7 @@ impl QuaintQueryable for JsBaseQueryable {
return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build());
}

if self.flavour == Flavour::Sqlite {
if self.flavour == AdapterFlavour::Sqlite {
return match isolation_level {
IsolationLevel::Serializable => Ok(()),
_ => Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()),
Expand All @@ -140,9 +138,8 @@ impl QuaintQueryable for JsBaseQueryable {

fn requires_isolation_first(&self) -> bool {
match self.flavour {
Flavour::Mysql => true,
Flavour::Postgres | Flavour::Sqlite => false,
_ => unreachable!(),
AdapterFlavour::Mysql => true,
AdapterFlavour::Postgres | AdapterFlavour::Sqlite => false,
}
}
}
Expand Down
5 changes: 2 additions & 3 deletions query-engine/driver-adapters/src/queryable/wasm.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::types::AdapterFlavour;
use crate::wasm::proxy::{CommonProxy, DriverProxy};
use crate::{JsObjectExtern, JsQueryable};
use psl::datamodel_connector::Flavour;
use wasm_bindgen::prelude::wasm_bindgen;

/// A JsQueryable adapts a Proxy to implement quaint's Queryable interface. It has the
Expand All @@ -16,10 +16,9 @@ use wasm_bindgen::prelude::wasm_bindgen;
/// into a `quaint::connector::result_set::ResultSet`. A quaint `ResultSet` is basically a vector
/// of `quaint::Value` but said type is a tagged enum, with non-unit variants that cannot be converted to javascript as is.
#[wasm_bindgen(getter_with_clone)]
#[derive(Default)]
pub(crate) struct JsBaseQueryable {
pub(crate) proxy: CommonProxy,
pub flavour: Flavour,
pub flavour: AdapterFlavour,
}

pub fn from_wasm(driver: JsObjectExtern) -> JsQueryable {
Expand Down
27 changes: 26 additions & 1 deletion query-engine/driver-adapters/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#![allow(unused_imports)]

use std::str::FromStr;

#[cfg(not(target_arch = "wasm32"))]
use napi::bindgen_prelude::{FromNapiValue, ToNapiValue};

Expand All @@ -9,6 +11,28 @@ use tsify::Tsify;
use crate::conversion::JSArg;
use serde::{Deserialize, Serialize};

#[cfg_attr(target_arch = "wasm32", derive(Serialize, Deserialize, Tsify))]
#[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi, from_wasm_abi))]
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum AdapterFlavour {
Mysql,
Postgres,
Sqlite,
}

impl FromStr for AdapterFlavour {
type Err = String;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"postgres" => Ok(Self::Postgres),
"mysql" => Ok(Self::Mysql),
"sqlite" => Ok(Self::Sqlite),
_ => Err(format!("Unsupported adapter flavour: {:?}", s)),
}
}
}

/// This result set is more convenient to be manipulated from both Rust and NodeJS.
/// Quaint's version of ResultSet is:
///
Expand All @@ -27,7 +51,7 @@ use serde::{Deserialize, Serialize};
#[cfg_attr(target_arch = "wasm32", derive(Serialize, Deserialize, Tsify))]
#[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi, from_wasm_abi))]
#[cfg_attr(target_arch = "wasm32", serde(rename_all = "camelCase"))]
#[derive(Debug, Default)]
#[derive(Debug)]
pub struct JSResultSet {
pub column_types: Vec<ColumnType>,
pub column_names: Vec<String>,
Expand Down Expand Up @@ -190,6 +214,7 @@ pub struct Query {
#[cfg_attr(not(target_arch = "wasm32"), napi_derive::napi(object))]
#[cfg_attr(target_arch = "wasm32", derive(Serialize, Deserialize, Tsify))]
#[cfg_attr(target_arch = "wasm32", tsify(into_wasm_abi, from_wasm_abi))]
#[cfg_attr(target_arch = "wasm32", serde(rename_all = "camelCase"))]
#[derive(Debug, Default)]
pub struct TransactionOptions {
/// Whether or not to run a phantom query (i.e., a query that only influences Prisma event logs, but not the database itself)
Expand Down
30 changes: 19 additions & 11 deletions query-engine/driver-adapters/src/wasm/async_js_function.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
use js_sys::{Function as JsFunction, Promise as JsPromise};
use serde::{de::DeserializeOwned, Serialize};
use js_sys::{Function as JsFunction, JsString, Promise as JsPromise};
use serde::Serialize;
use std::marker::PhantomData;
use std::str::FromStr;
use wasm_bindgen::convert::FromWasmAbi;
use wasm_bindgen::describe::WasmDescribe;
use wasm_bindgen::{JsError, JsValue};
use wasm_bindgen_futures::JsFuture;

use super::error::into_quaint_error;
use super::from_js::FromJsValue;
use super::result::JsResult;

#[derive(Clone, Default)]
#[derive(Clone)]
pub(crate) struct AsyncJsFunction<ArgType, ReturnType>
where
ArgType: Serialize,
ReturnType: DeserializeOwned,
ReturnType: FromJsValue,
{
pub threadsafe_fn: JsFunction,

Expand All @@ -24,7 +26,7 @@ where
impl<T, R> From<JsFunction> for AsyncJsFunction<T, R>
where
T: Serialize,
R: DeserializeOwned,
R: FromJsValue,
{
fn from(js_fn: JsFunction) -> Self {
Self {
Expand All @@ -38,14 +40,20 @@ where
impl<T, R> AsyncJsFunction<T, R>
where
T: Serialize,
R: DeserializeOwned,
R: FromJsValue,
{
pub async fn call(&self, arg1: T) -> quaint::Result<R> {
let result = self.call_internal(arg1).await;

match result {
Ok(js_result) => js_result.into(),
Err(err) => Err(into_quaint_error(err)),
Ok(js_result) => {
web_sys::console::log_1(&JsString::from_str("OK JS").unwrap().into());
js_result.into()
}
Err(err) => {
web_sys::console::log_1(&JsString::from_str("CALL ERR").unwrap().into());
Err(into_quaint_error(err))
}
}
}

Expand All @@ -54,7 +62,7 @@ where
let promise = self.threadsafe_fn.call1(&JsValue::null(), &arg1)?;
let future = JsFuture::from(JsPromise::from(promise));
let value = future.await?;
let js_result: JsResult<R> = value.try_into()?;
let js_result = JsResult::<R>::from_js_value(value)?;

Ok(js_result)
}
Expand All @@ -63,7 +71,7 @@ where
impl<ArgType, ReturnType> WasmDescribe for AsyncJsFunction<ArgType, ReturnType>
where
ArgType: Serialize,
ReturnType: DeserializeOwned,
ReturnType: FromJsValue,
{
fn describe() {
JsFunction::describe();
Expand All @@ -73,7 +81,7 @@ where
impl<ArgType, ReturnType> FromWasmAbi for AsyncJsFunction<ArgType, ReturnType>
where
ArgType: Serialize,
ReturnType: DeserializeOwned,
ReturnType: FromJsValue,
{
type Abi = <JsFunction as FromWasmAbi>::Abi;

Expand Down
1 change: 1 addition & 0 deletions query-engine/driver-adapters/src/wasm/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use wasm_bindgen::JsValue;
/// transforms a Wasm error into a Quaint error
pub(crate) fn into_quaint_error(wasm_err: JsValue) -> QuaintError {
let status = "WASM_ERROR".to_string();
web_sys::console::log_1(&wasm_err);
let reason = Reflect::get(&wasm_err, &JsValue::from_str("stack"))
.ok()
.and_then(|value| value.as_string())
Expand Down
15 changes: 15 additions & 0 deletions query-engine/driver-adapters/src/wasm/from_js.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use serde::de::DeserializeOwned;
use wasm_bindgen::JsValue;

pub trait FromJsValue: Sized {
fn from_js_value(value: JsValue) -> Result<Self, JsValue>;
}

impl<T> FromJsValue for T
where
T: DeserializeOwned,
{
fn from_js_value(value: JsValue) -> Result<Self, JsValue> {
serde_wasm_bindgen::from_value(value).map_err(|e| JsValue::from(e))
}
}
1 change: 1 addition & 0 deletions query-engine/driver-adapters/src/wasm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod async_js_function;
mod error;
mod from_js;
mod js_object_extern;
pub(crate) mod proxy;
mod result;
Expand Down
2 changes: 0 additions & 2 deletions query-engine/driver-adapters/src/wasm/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ type JsResult<T> = core::result::Result<T, JsValue>;
/// querying and executing SQL (i.e. a client connector). The Proxy uses Wasm's JsFunction to
/// invoke the code within the node runtime that implements the client connector.
#[wasm_bindgen(getter_with_clone)]
#[derive(Default)]
pub(crate) struct CommonProxy {
/// Execute a query given as SQL, interpolating the given parameters.
query_raw: AsyncJsFunction<Query, JSResultSet>,
Expand All @@ -38,7 +37,6 @@ pub(crate) struct DriverProxy {
/// This a JS proxy for accessing the methods, specific
/// to JS transaction objects
#[wasm_bindgen(getter_with_clone)]
#[derive(Default)]
pub(crate) struct TransactionProxy {
/// transaction options
options: TransactionOptions,
Expand Down
31 changes: 12 additions & 19 deletions query-engine/driver-adapters/src/wasm/result.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
use js_sys::Boolean as JsBoolean;
use std::str::FromStr;

use js_sys::{Boolean as JsBoolean, JsString};
use quaint::error::{Error as QuaintError, ErrorKind};
use serde::de::DeserializeOwned;
use wasm_bindgen::{JsCast, JsValue};

use super::from_js::FromJsValue;
use crate::{error::DriverAdapterError, JsObjectExtern};

impl From<DriverAdapterError> for QuaintError {
Expand All @@ -26,36 +28,27 @@ impl From<DriverAdapterError> for QuaintError {
/// Wrapper for JS-side result type
pub(crate) enum JsResult<T>
where
T: DeserializeOwned,
T: FromJsValue,
{
Ok(T),
Err(DriverAdapterError),
}

impl<T> TryFrom<JsValue> for JsResult<T>
where
T: DeserializeOwned,
{
type Error = JsValue;

fn try_from(value: JsValue) -> Result<Self, Self::Error> {
Self::from_js_unknown(value)
}
}

impl<T> JsResult<T>
impl<T> FromJsValue for JsResult<T>
where
T: DeserializeOwned,
T: FromJsValue,
{
fn from_js_unknown(unknown: JsValue) -> Result<Self, JsValue> {
fn from_js_value(unknown: JsValue) -> Result<Self, JsValue> {
let object = unknown.unchecked_into::<JsObjectExtern>();

let ok: JsBoolean = object.get("ok".into())?.unchecked_into();
let ok = ok.value_of();

if ok {
let js_value: JsValue = object.get("value".into())?;
let deserialized = serde_wasm_bindgen::from_value::<T>(js_value)?;
web_sys::console::log_1(&JsString::from_str("BEFORE DESERIALIZE").unwrap().into());
let deserialized = T::from_js_value(js_value)?;
web_sys::console::log_1(&JsString::from_str(" DESERIALIZE").unwrap().into());
return Ok(Self::Ok(deserialized));
}

Expand All @@ -67,7 +60,7 @@ where

impl<T> From<JsResult<T>> for quaint::Result<T>
where
T: DeserializeOwned,
T: FromJsValue,
{
fn from(value: JsResult<T>) -> Self {
match value {
Expand Down
Loading

0 comments on commit 39b2489

Please sign in to comment.