Skip to content

Commit

Permalink
feat: shared cache between grpc & engine and fix partial deser
Browse files Browse the repository at this point in the history
  • Loading branch information
Larkooo committed Nov 8, 2024
1 parent 14ffc8e commit ddd9921
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 53 deletions.
6 changes: 4 additions & 2 deletions bin/torii/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use torii_core::engine::{Engine, EngineConfig, IndexingFlags, Processors};
use torii_core::executor::Executor;
use torii_core::processors::store_transaction::StoreTransactionProcessor;
use torii_core::simple_broker::SimpleBroker;
use torii_core::sql::cache::ModelCache;
use torii_core::sql::Sql;
use torii_core::types::{Contract, ContractType, Model, ToriiConfig};
use torii_server::proxy::Proxy;
Expand Down Expand Up @@ -218,7 +219,8 @@ async fn main() -> anyhow::Result<()> {
executor.run().await.unwrap();
});

let db = Sql::new(pool.clone(), sender.clone(), &contracts).await?;
let model_cache = Arc::new(ModelCache::new(pool.clone()));
let db = Sql::new(pool.clone(), sender.clone(), &contracts, model_cache.clone()).await?;

let processors = Processors {
transaction: vec![Box::new(StoreTransactionProcessor)],
Expand Down Expand Up @@ -256,7 +258,7 @@ async fn main() -> anyhow::Result<()> {

let shutdown_rx = shutdown_tx.subscribe();
let (grpc_addr, grpc_server) =
torii_grpc::server::new(shutdown_rx, &pool, block_rx, world_address, Arc::clone(&provider))
torii_grpc::server::new(shutdown_rx, &pool, block_rx, world_address, Arc::clone(&provider), model_cache)
.await?;

let mut libp2p_relay_server = torii_relay::server::Relay::new(
Expand Down
55 changes: 19 additions & 36 deletions crates/dojo/types/src/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,46 +190,29 @@ impl Primitive {
}
}

pub fn to_sql_value(&self) -> Result<String, PrimitiveError> {
let value = self.serialize()?;

if value.is_empty() {
return Err(PrimitiveError::MissingFieldElement);
}

pub fn to_sql_value(&self) -> String {
match self {
// Integers
Primitive::I8(_) => Ok(format!("{}", try_from_felt::<i8>(value[0])?)),
Primitive::I16(_) => Ok(format!("{}", try_from_felt::<i16>(value[0])?)),
Primitive::I32(_) => Ok(format!("{}", try_from_felt::<i32>(value[0])?)),
Primitive::I64(_) => Ok(format!("{}", try_from_felt::<i64>(value[0])?)),
Primitive::I8(i8) => format!("{}", i8.unwrap_or_default()),
Primitive::I16(i16) => format!("{}", i16.unwrap_or_default()),
Primitive::I32(i32) => format!("{}", i32.unwrap_or_default()),
Primitive::I64(i64) => format!("{}", i64.unwrap_or_default()),

Primitive::U8(_)
| Primitive::U16(_)
| Primitive::U32(_)
| Primitive::USize(_)
| Primitive::Bool(_) => Ok(format!("{}", value[0])),
Primitive::U8(u8) => format!("{}", u8.unwrap_or_default()),
Primitive::U16(u16) => format!("{}", u16.unwrap_or_default()),
Primitive::U32(u32) => format!("{}", u32.unwrap_or_default()),
Primitive::USize(u32) => format!("{}", u32.unwrap_or_default()),
Primitive::Bool(bool) => format!("{}", bool.unwrap_or_default()),

// Hex string
Primitive::I128(_) => Ok(format!("{:#064x}", try_from_felt::<i128>(value[0])?)),
Primitive::ContractAddress(_)
| Primitive::ClassHash(_)
| Primitive::Felt252(_)
| Primitive::U128(_)
| Primitive::U64(_) => Ok(format!("{:#064x}", value[0])),

Primitive::U256(_) => {
if value.len() < 2 {
Err(PrimitiveError::NotEnoughFieldElements)
} else {
let mut buffer = [0u8; 32];
let value0_bytes = value[0].to_bytes_be();
let value1_bytes = value[1].to_bytes_be();
buffer[16..].copy_from_slice(&value0_bytes[16..]);
buffer[..16].copy_from_slice(&value1_bytes[16..]);
Ok(format!("0x{}", hex::encode(buffer)))
}
}
Primitive::I128(i128) => format!("{:#064x}", i128.unwrap_or_default()),
Primitive::ContractAddress(felt) => format!("{:#064x}", felt.unwrap_or_default()),
Primitive::ClassHash(felt) => format!("{:#064x}", felt.unwrap_or_default()),
Primitive::Felt252(felt) => format!("{:#064x}", felt.unwrap_or_default()),
Primitive::U128(u128) => format!("{:#064x}", u128.unwrap_or_default()),
Primitive::U64(u64) => format!("{:#064x}", u64.unwrap_or_default()),

Primitive::U256(u256) => format!("0x{:064x}", u256.unwrap_or_default()),
}
}

Expand Down Expand Up @@ -436,7 +419,7 @@ mod tests {
let primitive = Primitive::U256(Some(U256::from_be_hex(
"aaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbccccccccccccccccdddddddddddddddd",
)));
let sql_value = primitive.to_sql_value().unwrap();
let sql_value = primitive.to_sql_value();
let serialized = primitive.serialize().unwrap();

let mut deserialized = primitive;
Expand Down
5 changes: 5 additions & 0 deletions crates/dojo/types/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ impl Ty {
}

pub fn deserialize(&mut self, felts: &mut Vec<Felt>) -> Result<(), PrimitiveError> {
if felts.is_empty() {
// return early if there are no felts to deserialize
return Ok(());
}

match self {
Ty::Primitive(c) => {
c.deserialize(felts)?;
Expand Down
16 changes: 13 additions & 3 deletions crates/dojo/world/src/contracts/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use cainome::cairo_serde::{CairoSerde as _, ContractAddress, Error as CainomeErr
use dojo_types::packing::{PackingError, ParseError};
use dojo_types::primitive::{Primitive, PrimitiveError};
use dojo_types::schema::{Enum, EnumOption, Member, Struct, Ty};
use starknet::core::types::Felt;
use starknet::core::types::{BlockId, Felt};
use starknet::core::utils::{
cairo_short_string_to_felt, parse_cairo_short_string, CairoShortStringToFeltError,
NonAsciiNameError, ParseCairoShortStringError,
Expand Down Expand Up @@ -86,13 +86,22 @@ where
namespace: &str,
name: &str,
world: &'a WorldContractReader<P>,
) -> Result<ModelRPCReader<'a, P>, ModelError> {
Self::new_with_block(namespace, name, world, world.block_id).await
}

pub async fn new_with_block(
namespace: &str,
name: &str,
world: &'a WorldContractReader<P>,
block_id: BlockId,
) -> Result<ModelRPCReader<'a, P>, ModelError> {
let model_selector = naming::compute_selector_from_names(namespace, name);

// Events are also considered like models from a off-chain perspective. They both have
// introspection and convey type information.
let (contract_address, class_hash) =
match world.resource(&model_selector).block_id(world.block_id).call().await? {
match world.resource(&model_selector).block_id(block_id).call().await? {
abigen::world::Resource::Model((address, hash)) => (address, hash),
abigen::world::Resource::Event((address, hash)) => (address, hash),
_ => return Err(ModelError::ModelNotFound),
Expand All @@ -104,7 +113,8 @@ where
return Err(ModelError::ModelNotFound);
}

let model_reader = ModelContractReader::new(contract_address.into(), world.provider());
let mut model_reader = ModelContractReader::new(contract_address.into(), world.provider());
model_reader.set_block(block_id);

Ok(Self {
namespace: namespace.into(),
Expand Down
10 changes: 10 additions & 0 deletions crates/dojo/world/src/contracts/world.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::result::Result;

use starknet::core::types::BlockId;
use starknet::providers::Provider;

pub use super::abigen::world::{
Expand Down Expand Up @@ -33,4 +34,13 @@ where
) -> Result<ModelRPCReader<'_, P>, ModelError> {
ModelRPCReader::new(namespace, name, self).await
}

pub async fn model_reader_with_block(
&self,
namespace: &str,
name: &str,
block_id: BlockId,
) -> Result<ModelRPCReader<'_, P>, ModelError> {
ModelRPCReader::new_with_block(namespace, name, self, block_id).await
}
}
6 changes: 3 additions & 3 deletions crates/torii/core/src/processors/register_event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use async_trait::async_trait;
use dojo_world::contracts::abigen::world::Event as WorldEvent;
use dojo_world::contracts::model::ModelReader;
use dojo_world::contracts::world::WorldContractReader;
use starknet::core::types::Event;
use starknet::core::types::{BlockId, Event};
use starknet::providers::Provider;
use tracing::{debug, info};

Expand Down Expand Up @@ -34,7 +34,7 @@ where
&self,
world: &WorldContractReader<P>,
db: &mut Sql,
_block_number: u64,
block_number: u64,
block_timestamp: u64,
_event_id: &str,
event: &Event,
Expand All @@ -59,7 +59,7 @@ where

// Called model here by language, but it's an event. Torii rework will make clear
// distinction.
let model = world.model_reader(&namespace, &name).await?;
let model = world.model_reader_with_block(&namespace, &name, BlockId::Number(block_number)).await?;
let schema = model.schema().await?;
let layout = model.layout().await?;

Expand Down
6 changes: 3 additions & 3 deletions crates/torii/core/src/processors/register_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use async_trait::async_trait;
use dojo_world::contracts::abigen::world::Event as WorldEvent;
use dojo_world::contracts::model::ModelReader;
use dojo_world::contracts::world::WorldContractReader;
use starknet::core::types::Event;
use starknet::core::types::{BlockId, Event};
use starknet::providers::Provider;
use tracing::{debug, info};

Expand Down Expand Up @@ -34,7 +34,7 @@ where
&self,
world: &WorldContractReader<P>,
db: &mut Sql,
_block_number: u64,
block_number: u64,
block_timestamp: u64,
_event_id: &str,
event: &Event,
Expand All @@ -57,7 +57,7 @@ where
let namespace = event.namespace.to_string().unwrap();
let name = event.name.to_string().unwrap();

let model = world.model_reader(&namespace, &name).await?;
let model = world.model_reader_with_block(&namespace, &name, BlockId::Number(block_number)).await?;
let schema = model.schema().await?;
let layout = model.layout().await?;

Expand Down
7 changes: 5 additions & 2 deletions crates/torii/core/src/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ impl Sql {
pool: Pool<Sqlite>,
executor: UnboundedSender<QueryMessage>,
contracts: &HashMap<Felt, ContractType>,
model_cache: Arc<ModelCache>,
) -> Result<Self> {
for contract in contracts {
executor.send(QueryMessage::other(
Expand All @@ -78,7 +79,7 @@ impl Sql {
let db = Self {
pool: pool.clone(),
executor,
model_cache: Arc::new(ModelCache::new(pool.clone())),
model_cache,
local_cache,
};

Expand Down Expand Up @@ -325,6 +326,8 @@ impl Sql {
)
.await;

println!("selector: {:?}", selector);
println!("set model cache: {:?}", model);
Ok(())
}

Expand Down Expand Up @@ -768,7 +771,7 @@ impl Sql {
match &member.ty {
Ty::Primitive(ty) => {
columns.push(format!("external_{}", &member.name));
arguments.push(Argument::String(ty.to_sql_value().unwrap()));
arguments.push(Argument::String(ty.to_sql_value()));
}
Ty::Enum(e) => {
columns.push(format!("external_{}", &member.name));
Expand Down
10 changes: 6 additions & 4 deletions crates/torii/grpc/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ impl DojoWorld {
block_rx: Receiver<u64>,
world_address: Felt,
provider: Arc<JsonRpcClient<HttpTransport>>,
model_cache: Arc<ModelCache>,
) -> Self {
let model_cache = Arc::new(ModelCache::new(pool.clone()));
let entity_manager = Arc::new(EntityManager::default());
let event_message_manager = Arc::new(EventMessageManager::default());
let event_manager = Arc::new(EventManager::default());
Expand Down Expand Up @@ -624,7 +624,7 @@ impl DojoWorld {
Some(ValueType::String(value)) => value,
Some(ValueType::Primitive(value)) => {
let primitive: Primitive = value.try_into()?;
primitive.to_sql_value()?
primitive.to_sql_value()
}
None => return Err(QueryError::MissingParam("value_type".into()).into()),
};
Expand Down Expand Up @@ -969,6 +969,7 @@ fn map_row_to_entity(
schemas: &[Ty],
dont_include_hashed_keys: bool,
) -> Result<proto::types::Entity, Error> {
println!("schemas: {:?}", schemas);
let hashed_keys = Felt::from_str(&row.get::<String, _>("id")).map_err(ParseError::FromStr)?;
let models = schemas
.iter()
Expand Down Expand Up @@ -1060,7 +1061,7 @@ fn build_composite_clause(
Some(ValueType::String(value)) => value,
Some(ValueType::Primitive(value)) => {
let primitive: Primitive = value.try_into()?;
primitive.to_sql_value()?
primitive.to_sql_value()
}
None => return Err(QueryError::MissingParam("value_type".into()).into()),
};
Expand Down Expand Up @@ -1368,6 +1369,7 @@ pub async fn new(
block_rx: Receiver<u64>,
world_address: Felt,
provider: Arc<JsonRpcClient<HttpTransport>>,
model_cache: Arc<ModelCache>,
) -> Result<
(SocketAddr, impl Future<Output = Result<(), tonic::transport::Error>> + 'static),
std::io::Error,
Expand All @@ -1380,7 +1382,7 @@ pub async fn new(
.build()
.unwrap();

let world = DojoWorld::new(pool.clone(), block_rx, world_address, provider);
let world = DojoWorld::new(pool.clone(), block_rx, world_address, provider, model_cache);
let server = WorldServer::new(world)
.accept_compressed(CompressionEncoding::Gzip)
.send_compressed(CompressionEncoding::Gzip);
Expand Down

0 comments on commit ddd9921

Please sign in to comment.