From 0564fc483195e9f4a374936cdac9a659faed3a7f Mon Sep 17 00:00:00 2001 From: Yurii Rashkovskii Date: Thu, 24 Nov 2022 12:22:56 -0800 Subject: [PATCH] Problem: text type representation is not always efficient For this reason, Postgres allows types to have an external binary representation. Also, some clients insist on using binary representation. Solution: introduce SendRecvFuncs trait and `sendrecvfuncs` attribute These are used to specify how external binary representation encoding is accomplished. --- Cargo.lock | 1 + pgx-examples/custom_types/README.md | 23 +++ pgx-macros/src/lib.rs | 43 +++++- pgx-tests/Cargo.toml | 3 +- pgx-tests/src/tests/mod.rs | 1 + pgx-tests/src/tests/postgres_type_tests.rs | 61 +++++++- pgx-tests/src/tests/stringinfo_tests.rs | 41 ++++++ pgx-utils/src/sql_entity_graph/mod.rs | 13 +- .../sql_entity_graph/postgres_type/entity.rs | 131 ++++++++++++++++-- .../src/sql_entity_graph/postgres_type/mod.rs | 64 ++++++++- pgx/src/lib.rs | 2 + pgx/src/sendrecvfuncs.rs | 24 ++++ pgx/src/stringinfo.rs | 48 +++++++ 13 files changed, 426 insertions(+), 29 deletions(-) create mode 100644 pgx-tests/src/tests/stringinfo_tests.rs create mode 100644 pgx/src/sendrecvfuncs.rs diff --git a/Cargo.lock b/Cargo.lock index fa56b08ecc..c0dff7ab9a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1513,6 +1513,7 @@ dependencies = [ name = "pgx-tests" version = "0.6.0-alpha.2" dependencies = [ + "bytes", "eyre", "libc", "once_cell", diff --git a/pgx-examples/custom_types/README.md b/pgx-examples/custom_types/README.md index b72a48c6d0..ca7ee16276 100644 --- a/pgx-examples/custom_types/README.md +++ b/pgx-examples/custom_types/README.md @@ -110,6 +110,29 @@ fn do_a_thing(mut input: PgVarlena) -> PgVarlena { } ``` +## External Binary Representation + +PostgreSQL allows types to have an external binary representation for more efficient communication with +clients (as a matter of fact, Rust's [postgres](https://crates.io/crates/postgres) crate uses binary types +exclusively). 
By default, `PostgresType` types do not have an external binary representation; however, one can +be added by specifying the `#[sendrecvfuncs]` attribute on the type and implementing the `SendRecvFuncs` trait: + +```rust +#[derive(PostgresType, Serialize, Deserialize, Debug, PartialEq)] +#[sendrecvfuncs] +pub struct BinaryEncodedType(Vec<u8>); + +impl SendRecvFuncs for BinaryEncodedType { + fn send(&self) -> Vec<u8> { + self.0.clone() + } + + fn recv(buffer: &[u8]) -> Self { + Self(buffer.to_vec()) + } +} +``` + ## Notes - For serde-compatible types, you can use the `#[inoutfuncs]` annotation (instead of `#[pgvarlena_inoutfuncs]`) if you'd diff --git a/pgx-macros/src/lib.rs b/pgx-macros/src/lib.rs index ad439ead67..cd4b20e61c 100644 --- a/pgx-macros/src/lib.rs +++ b/pgx-macros/src/lib.rs @@ -681,9 +681,13 @@ Optionally accepts the following attributes: * `inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the type. * `pgvarlena_inoutfuncs(some_in_fn, some_out_fn)`: Define custom in/out functions for the `PgVarlena` of this type. +* `sendrecvfuncs`: Define binary send/receive functions for the type. * `sql`: Same arguments as [`#[pgx(sql = ..)]`](macro@pgx). 
*/ -#[proc_macro_derive(PostgresType, attributes(inoutfuncs, pgvarlena_inoutfuncs, requires, pgx))] +#[proc_macro_derive( + PostgresType, + attributes(inoutfuncs, pgvarlena_inoutfuncs, sendrecvfuncs, requires, pgx) +)] pub fn postgres_type(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as syn::DeriveInput); @@ -696,6 +700,8 @@ fn impl_postgres_type(ast: DeriveInput) -> proc_macro2::TokenStream { let has_lifetimes = generics.lifetimes().next(); let funcname_in = Ident::new(&format!("{}_in", name).to_lowercase(), name.span()); let funcname_out = Ident::new(&format!("{}_out", name).to_lowercase(), name.span()); + let funcname_send = Ident::new(&format!("{}_send", name).to_lowercase(), name.span()); + let funcname_recv = Ident::new(&format!("{}_recv", name).to_lowercase(), name.span()); let mut args = parse_postgres_type_args(&ast.attrs); let mut stream = proc_macro2::TokenStream::new(); @@ -710,7 +716,7 @@ fn impl_postgres_type(ast: DeriveInput) -> proc_macro2::TokenStream { _ => panic!("#[derive(PostgresType)] can only be applied to structs or enums"), } - if args.is_empty() { + if !args.contains(&PostgresTypeAttribute::InOutFuncs) && !args.contains(&PostgresTypeAttribute::PgVarlenaInOutFuncs) { // assume the user wants us to implement the InOutFuncs args.insert(PostgresTypeAttribute::Default); } @@ -803,7 +809,34 @@ fn impl_postgres_type(ast: DeriveInput) -> proc_macro2::TokenStream { }); } - let sql_graph_entity_item = PostgresType::from_derive_input(ast).unwrap(); + if args.contains(&PostgresTypeAttribute::SendReceiveFuncs) { + stream.extend(quote! 
{ + #[doc(hidden)] + #[pg_extern(immutable,parallel_safe,strict)] + pub fn #funcname_recv #generics(input: ::pgx::Internal) -> #name #generics { + let mut buffer0 = unsafe { + input + .get_mut::<::pgx::pg_sys::StringInfoData>() + .expect("Can't retrieve StringInfo pointer") + }; + let mut buffer = StringInfo::from_pg(buffer0 as *mut _).expect("failed to construct StringInfo"); + let slice = buffer.read(..).expect("failure reading StringInfo"); + ::pgx::SendRecvFuncs::recv(slice) + } + + #[doc(hidden)] + #[pg_extern(immutable,parallel_safe,strict)] + pub fn #funcname_send #generics(input: #name #generics) -> Vec { + ::pgx::SendRecvFuncs::send(&input) + } + }); + } + + let sql_graph_entity_item = PostgresType::from_derive_input( + ast, + args.contains(&PostgresTypeAttribute::SendReceiveFuncs), + ) + .unwrap(); sql_graph_entity_item.to_tokens(&mut stream); stream @@ -895,6 +928,7 @@ fn impl_guc_enum(ast: DeriveInput) -> proc_macro2::TokenStream { enum PostgresTypeAttribute { InOutFuncs, PgVarlenaInOutFuncs, + SendReceiveFuncs, Default, } @@ -912,6 +946,9 @@ fn parse_postgres_type_args(attributes: &[Attribute]) -> HashSet { categorized_attributes.insert(PostgresTypeAttribute::PgVarlenaInOutFuncs); } + "sendrecvfuncs" => { + categorized_attributes.insert(PostgresTypeAttribute::SendReceiveFuncs); + } _ => { // we can just ignore attributes we don't understand diff --git a/pgx-tests/Cargo.toml b/pgx-tests/Cargo.toml index f3da270ffc..600f7ffc3a 100644 --- a/pgx-tests/Cargo.toml +++ b/pgx-tests/Cargo.toml @@ -20,7 +20,7 @@ pg12 = [ "pgx/pg12" ] pg13 = [ "pgx/pg13" ] pg14 = [ "pgx/pg14" ] pg15 = [ "pgx/pg15" ] -pg_test = [ ] +pg_test = [ "bytes" ] [package.metadata.docs.rs] features = ["pg14"] @@ -44,6 +44,7 @@ serde_json = "1.0.88" time = "0.3.17" eyre = "0.6.8" thiserror = "1.0" +bytes = { version = "1.2.1", optional = true } [dependencies.pgx] path = "../pgx" diff --git a/pgx-tests/src/tests/mod.rs b/pgx-tests/src/tests/mod.rs index 5924c59c42..eb4fe06084 100644 --- 
a/pgx-tests/src/tests/mod.rs +++ b/pgx-tests/src/tests/mod.rs @@ -39,6 +39,7 @@ mod schema_tests; mod shmem_tests; mod spi_tests; mod srf_tests; +mod stringinfo_tests; mod struct_type_tests; mod trigger_tests; mod uuid_tests; diff --git a/pgx-tests/src/tests/postgres_type_tests.rs b/pgx-tests/src/tests/postgres_type_tests.rs index 563b24eb4d..6a9c2efb8b 100644 --- a/pgx-tests/src/tests/postgres_type_tests.rs +++ b/pgx-tests/src/tests/postgres_type_tests.rs @@ -8,7 +8,7 @@ Use of this source code is governed by the MIT license that can be found in the */ use pgx::cstr_core::CStr; use pgx::prelude::*; -use pgx::{InOutFuncs, PgVarlena, PgVarlenaInOutFuncs, StringInfo}; +use pgx::{InOutFuncs, PgVarlena, PgVarlenaInOutFuncs, SendRecvFuncs, StringInfo}; use serde::{Deserialize, Serialize}; use std::str::FromStr; @@ -152,16 +152,32 @@ pub enum JsonEnumType { E2 { b: f32 }, } +#[derive(PostgresType, Serialize, Deserialize, Debug, PartialEq)] +#[sendrecvfuncs] +pub struct BinaryEncodedType(Vec); + +impl SendRecvFuncs for BinaryEncodedType { + fn send(&self) -> Vec { + self.0.clone() + } + + fn recv(buffer: &[u8]) -> Self { + Self(buffer.to_vec()) + } +} + #[cfg(any(test, feature = "pg_test"))] #[pgx::pg_schema] mod tests { + use std::error::Error; + use postgres::types::{FromSql, IsNull, ToSql, Type}; + use postgres::types::private::BytesMut; #[allow(unused_imports)] use crate as pgx_tests; - use crate::tests::postgres_type_tests::{ - CustomTextFormatSerializedEnumType, CustomTextFormatSerializedType, JsonEnumType, JsonType, - VarlenaEnumType, VarlenaType, - }; + use crate::tests::postgres_type_tests::{BinaryEncodedType, CustomTextFormatSerializedEnumType, + CustomTextFormatSerializedType, JsonEnumType, JsonType, + VarlenaEnumType, VarlenaType}; use pgx::prelude::*; use pgx::PgVarlena; @@ -246,4 +262,39 @@ mod tests { .expect("SPI returned NULL"); assert!(matches!(result, JsonEnumType::E1 { a } if a == 1.0)); } + + #[pg_test] + fn test_binary_encoded_type() { + impl ToSql 
for BinaryEncodedType { + fn to_sql(&self, _ty: &Type, out: &mut BytesMut) -> Result> where Self: Sized { + use bytes::BufMut; + out.put_slice(self.0.as_slice()); + Ok(IsNull::No) + } + + fn accepts(_ty: &Type) -> bool where Self: Sized { + true + } + + postgres::types::to_sql_checked!(); + } + + impl<'a> FromSql<'a> for BinaryEncodedType { + fn from_sql(_ty: &Type, raw: &'a [u8]) -> Result> { + Ok(Self(raw.to_vec())) + } + + fn accepts(_ty: &Type) -> bool { + true + } + } + + // postgres client uses binary types so we can use it to test this functionality + let (mut client, _) = pgx_tests::client().unwrap(); + let val = + BinaryEncodedType(vec![0,1,2]); + let result = client.query("SELECT $1::BinaryEncodedType", &[&val]).unwrap(); + let val1: BinaryEncodedType = result[0].get(0); + assert_eq!(val, val1); + } } diff --git a/pgx-tests/src/tests/stringinfo_tests.rs b/pgx-tests/src/tests/stringinfo_tests.rs new file mode 100644 index 0000000000..de71fb3df9 --- /dev/null +++ b/pgx-tests/src/tests/stringinfo_tests.rs @@ -0,0 +1,41 @@ +/* +Portions Copyright 2019-2021 ZomboDB, LLC. +Portions Copyright 2021-2022 Technology Concepts & Design, Inc. + +All rights reserved. + +Use of this source code is governed by the MIT license that can be found in the LICENSE file. 
+*/ + +#[cfg(any(test, feature = "pg_test"))] +#[pgx::pg_schema] +mod tests { + #[allow(unused_imports)] + use crate as pgx_tests; + + use pgx::*; + + #[pg_test] + fn test_string_info_read_full() { + let mut string_info = StringInfo::from(vec![1,2,3,4,5]); + assert_eq!(string_info.read(..), Some(&[1,2,3,4,5][..])); + assert_eq!(string_info.read(..), Some(&[][..])); + assert_eq!(string_info.read(..=1), None); + } + + #[pg_test] + fn test_string_info_read_offset() { + let mut string_info = StringInfo::from(vec![1,2,3,4,5]); + assert_eq!(string_info.read(1..), Some(&[2,3,4,5][..])); + assert_eq!(string_info.read(..), Some(&[][..])); + } + + #[pg_test] + fn test_string_info_read_cap() { + let mut string_info = StringInfo::from(vec![1,2,3,4,5]); + assert_eq!(string_info.read(..=1), Some(&[1][..])); + assert_eq!(string_info.read(1..=2), Some(&[3][..])); + assert_eq!(string_info.read(..), Some(&[4,5][..])); + } + +} diff --git a/pgx-utils/src/sql_entity_graph/mod.rs b/pgx-utils/src/sql_entity_graph/mod.rs index 8e3fad10a9..1c567e8672 100644 --- a/pgx-utils/src/sql_entity_graph/mod.rs +++ b/pgx-utils/src/sql_entity_graph/mod.rs @@ -205,7 +205,8 @@ impl ToSql for SqlGraphEntity { if context.graph.neighbors_undirected(context.externs.get(item).unwrap().clone()).any(|neighbor| { let neighbor_item = &context.graph[neighbor]; match neighbor_item { - SqlGraphEntity::Type(PostgresTypeEntity { in_fn, in_fn_module_path, out_fn, out_fn_module_path, .. }) => { + SqlGraphEntity::Type(PostgresTypeEntity { in_fn, in_fn_module_path, out_fn, out_fn_module_path, send_fn, recv_fn, + send_fn_module_path, recv_fn_module_path, .. 
}) => { let is_in_fn = item.full_path.starts_with(in_fn_module_path) && item.full_path.ends_with(in_fn); if is_in_fn { tracing::trace!(r#type = %neighbor_item.dot_identifier(), "Skipping, is an in_fn."); @@ -214,7 +215,15 @@ impl ToSql for SqlGraphEntity { if is_out_fn { tracing::trace!(r#type = %neighbor_item.dot_identifier(), "Skipping, is an out_fn."); } - is_in_fn || is_out_fn + let is_send_fn = send_fn.is_some() && item.full_path.starts_with(send_fn_module_path) && item.full_path.ends_with(send_fn.unwrap_or_default()); + if is_send_fn { + tracing::trace!(r#type = %neighbor_item.dot_identifier(), "Skipping, is an send_fn."); + } + let is_recv_fn = recv_fn.is_some() && item.full_path.starts_with(recv_fn_module_path) && item.full_path.ends_with(recv_fn.unwrap_or_default()); + if is_recv_fn { + tracing::trace!(r#type = %neighbor_item.dot_identifier(), "Skipping, is an recv_fn."); + } + is_in_fn || is_out_fn || is_send_fn || is_recv_fn }, _ => false, } diff --git a/pgx-utils/src/sql_entity_graph/postgres_type/entity.rs b/pgx-utils/src/sql_entity_graph/postgres_type/entity.rs index 2c761c5c2b..f844037fbd 100644 --- a/pgx-utils/src/sql_entity_graph/postgres_type/entity.rs +++ b/pgx-utils/src/sql_entity_graph/postgres_type/entity.rs @@ -37,6 +37,10 @@ pub struct PostgresTypeEntity { pub in_fn_module_path: String, pub out_fn: &'static str, pub out_fn_module_path: String, + pub send_fn: Option<&'static str>, + pub send_fn_module_path: String, + pub recv_fn: Option<&'static str>, + pub recv_fn_module_path: String, pub to_sql_config: ToSqlConfigEntity, } @@ -101,8 +105,12 @@ impl ToSql for PostgresTypeEntity { // - CREATE TYPE; // - CREATE FUNCTION _in; // - CREATE FUNCTION _out; + // - CREATE FUNCTION _send; (optional) + // - CREATE FUNCTION _recv; (optional) // - CREATE TYPE (...); + let mut functions = String::new(); + let in_fn_module_path = if !item.in_fn_module_path.is_empty() { item.in_fn_module_path.clone() } else { @@ -132,6 +140,8 @@ impl ToSql for 
PostgresTypeEntity { tracing::trace!(in_fn = ?in_fn_path, "Found matching `in_fn`"); let in_fn_sql = in_fn.to_sql(context)?; tracing::trace!(%in_fn_sql); + functions.push_str(in_fn_sql.as_str()); + functions.push('\n'); let out_fn_module_path = if !item.out_fn_module_path.is_empty() { item.out_fn_module_path.clone() @@ -162,6 +172,8 @@ impl ToSql for PostgresTypeEntity { tracing::trace!(out_fn = ?out_fn_path, "Found matching `out_fn`"); let out_fn_sql = out_fn.to_sql(context)?; tracing::trace!(%out_fn_sql); + functions.push_str(out_fn_sql.as_str()); + functions.push('\n'); let shell_type = format!( "\n\ @@ -177,30 +189,119 @@ impl ToSql for PostgresTypeEntity { ); tracing::trace!(sql = %shell_type); - let materialized_type = format!("\n\ + let full_path = item.full_path; + let file = item.file; + let line = item.line; + let schema = context.schema_prefix_for(&self_index); + let name = item.name; + let schema_prefix_in_fn = context.schema_prefix_for(&in_fn_graph_index); + let in_fn = item.in_fn; + let in_fn_path = in_fn_path; + let schema_prefix_out_fn = context.schema_prefix_for(&out_fn_graph_index); + let out_fn = item.out_fn; + let out_fn_path = out_fn_path; + + let materialized_type = match (item.send_fn, item.recv_fn) { + (Some(send_fn), Some(recv_fn)) => { + let send_fn_module_path = if !item.send_fn_module_path.is_empty() { + item.send_fn_module_path.clone() + } else { + item.module_path.to_string() // Presume a local + }; + let send_fn_path = format!( + "{module_path}{maybe_colons}{send_fn}", + module_path = send_fn_module_path, + maybe_colons = if !send_fn_module_path.is_empty() { "::" } else { "" }, + ); + let (_, _index) = context + .externs + .iter() + .find(|(k, _v)| (**k).full_path == send_fn_path.as_str()) + .ok_or_else(|| eyre::eyre!("Did not find `send_fn: {}`.", send_fn_path))?; + let (send_fn_graph_index, send_fn) = context + .graph + .neighbors_undirected(self_index) + .find_map(|neighbor| match &context.graph[neighbor] { + 
SqlGraphEntity::Function(func) if func.full_path == send_fn_path => { + Some((neighbor, func)) + } + _ => None, + }) + .ok_or_else(|| eyre!("Could not find send_fn graph entity."))?; + tracing::trace!(send_fn = ?send_fn_path, "Found matching `send_fn`"); + let send_fn_sql = send_fn.to_sql(context)?; + tracing::trace!(%send_fn_sql); + functions.push_str(send_fn_sql.as_str()); + functions.push('\n'); + + let recv_fn_module_path = if !item.recv_fn_module_path.is_empty() { + item.recv_fn_module_path.clone() + } else { + item.module_path.to_string() // Presume a local + }; + let recv_fn_path = format!( + "{module_path}{maybe_colons}{recv_fn}", + module_path = recv_fn_module_path, + maybe_colons = if !recv_fn_module_path.is_empty() { "::" } else { "" }, + ); + let (_, _index) = context + .externs + .iter() + .find(|(k, _v)| (**k).full_path == recv_fn_path.as_str()) + .ok_or_else(|| eyre::eyre!("Did not find `recv_fn: {}`.", recv_fn_path))?; + let (recv_fn_graph_index, recv_fn) = context + .graph + .neighbors_undirected(self_index) + .find_map(|neighbor| match &context.graph[neighbor] { + SqlGraphEntity::Function(func) if func.full_path == recv_fn_path => { + Some((neighbor, func)) + } + _ => None, + }) + .ok_or_else(|| eyre!("Could not find recv_fn graph entity."))?; + tracing::trace!(recv_fn = ?recv_fn_path, "Found matching `recv_fn`"); + let recv_fn_sql = recv_fn.to_sql(context)?; + tracing::trace!(%recv_fn_sql); + functions.push_str(recv_fn_sql.as_str()); + functions.push('\n'); + + let schema_prefix_send_fn = context.schema_prefix_for(&send_fn_graph_index); + let send_fn = item.send_fn.unwrap(); + let send_fn_path = send_fn_path; + let schema_prefix_recv_fn = context.schema_prefix_for(&recv_fn_graph_index); + let recv_fn = item.recv_fn.unwrap(); + let recv_fn_path = recv_fn_path; + format!("\n\ -- {file}:{line}\n\ -- {full_path}\n\ CREATE TYPE {schema}{name} (\n\ \tINTERNALLENGTH = variable,\n\ \tINPUT = {schema_prefix_in_fn}{in_fn}, /* {in_fn_path} */\n\ \tOUTPUT = 
{schema_prefix_out_fn}{out_fn}, /* {out_fn_path} */\n\ + \tSEND = {schema_prefix_send_fn}{send_fn}, /* {send_fn_path} */\n\ + \tRECEIVE = {schema_prefix_recv_fn}{recv_fn}, /* {recv_fn_path} */\n\ \tSTORAGE = extended\n\ );\ - ", - full_path = item.full_path, - file = item.file, - line = item.line, - schema = context.schema_prefix_for(&self_index), - name = item.name, - schema_prefix_in_fn = context.schema_prefix_for(&in_fn_graph_index), - in_fn = item.in_fn, - in_fn_path = in_fn_path, - schema_prefix_out_fn = context.schema_prefix_for(&out_fn_graph_index), - out_fn = item.out_fn, - out_fn_path = out_fn_path, - ); + " + ) + } + _ => { + format!("\n\ + -- {file}:{line}\n\ + -- {full_path}\n\ + CREATE TYPE {schema}{name} (\n\ + \tINTERNALLENGTH = variable,\n\ + \tINPUT = {schema_prefix_in_fn}{in_fn}, /* {in_fn_path} */\n\ + \tOUTPUT = {schema_prefix_out_fn}{out_fn}, /* {out_fn_path} */\n\ + \tSTORAGE = extended\n\ + );\ + " + ) + } + }; + tracing::trace!(sql = %materialized_type); - Ok(shell_type + "\n" + &in_fn_sql + "\n" + &out_fn_sql + "\n" + &materialized_type) + Ok(shell_type + "\n" + &functions + &materialized_type) } } diff --git a/pgx-utils/src/sql_entity_graph/postgres_type/mod.rs b/pgx-utils/src/sql_entity_graph/postgres_type/mod.rs index 724f8d3b95..2b00a0ca57 100644 --- a/pgx-utils/src/sql_entity_graph/postgres_type/mod.rs +++ b/pgx-utils/src/sql_entity_graph/postgres_type/mod.rs @@ -51,6 +51,8 @@ pub struct PostgresType { generics: Generics, in_fn: Ident, out_fn: Ident, + send_fn: Option, + recv_fn: Option, to_sql_config: ToSqlConfig, } @@ -60,15 +62,20 @@ impl PostgresType { generics: Generics, in_fn: Ident, out_fn: Ident, + send_fn: Option, + recv_fn: Option, to_sql_config: ToSqlConfig, ) -> Result { if !to_sql_config.overrides_default() { crate::ident_is_acceptable_to_postgres(&name)?; } - Ok(Self { generics, name, in_fn, out_fn, to_sql_config }) + Ok(Self { generics, name, in_fn, out_fn, send_fn, recv_fn, to_sql_config }) } - pub fn 
from_derive_input(derive_input: DeriveInput) -> Result { + pub fn from_derive_input( + derive_input: DeriveInput, + sendrecv: bool, + ) -> Result { match derive_input.data { syn::Data::Struct(_) | syn::Data::Enum(_) => {} syn::Data::Union(_) => { @@ -85,11 +92,29 @@ impl PostgresType { &format!("{}_out", derive_input.ident).to_lowercase(), derive_input.ident.span(), ); + + let (funcname_send, funcname_recv) = if sendrecv { + ( + Some(Ident::new( + &format!("{}_send", derive_input.ident).to_lowercase(), + derive_input.ident.span(), + )), + Some(Ident::new( + &format!("{}_recv", derive_input.ident).to_lowercase(), + derive_input.ident.span(), + )), + ) + } else { + (None, None) + }; + Self::new( derive_input.ident, derive_input.generics, funcname_in, funcname_out, + funcname_send, + funcname_recv, to_sql_config, ) } @@ -104,7 +129,23 @@ impl Parse for PostgresType { Ident::new(&format!("{}_in", parsed.ident).to_lowercase(), parsed.ident.span()); let funcname_out = Ident::new(&format!("{}_out", parsed.ident).to_lowercase(), parsed.ident.span()); - Self::new(parsed.ident, parsed.generics, funcname_in, funcname_out, to_sql_config) + + let (mut send_fn, mut recv_fn) = (None, None); + + if parsed.attrs.iter().any(|attr| attr.path.is_ident("sendrecvfuncs")) { + send_fn.replace(Ident::new(&format!("{}_send", parsed.ident).to_lowercase(), parsed.ident.span())); + recv_fn.replace(Ident::new(&format!("{}_recv", parsed.ident).to_lowercase(), parsed.ident.span())); + } + + Self::new( + parsed.ident, + parsed.generics, + funcname_in, + funcname_out, + send_fn, + recv_fn, + to_sql_config, + ) } } @@ -146,6 +187,9 @@ impl ToTokens for PostgresType { let in_fn = &self.in_fn; let out_fn = &self.out_fn; + let send_fn = self.send_fn.as_ref().map(|s| quote! { Some(stringify!(#s)) }).unwrap_or_else(|| quote! { None }); + let recv_fn = self.recv_fn.as_ref().map(|s| quote! { Some(stringify!(#s)) }).unwrap_or_else(|| quote! 
{ None }); + let sql_graph_entity_fn_name = syn::Ident::new(&format!("__pgx_internals_type_{}", self.name), Span::call_site()); @@ -210,6 +254,20 @@ impl ToTokens for PostgresType { let _ = path_items.pop(); // Drop the one we don't want. path_items.join("::") }, + send_fn: #send_fn, + send_fn_module_path: { + let send_fn = #send_fn.unwrap_or(""); + let mut path_items: Vec<_> = send_fn.split("::").collect(); + let _ = path_items.pop(); // Drop the one we don't want. + path_items.join("::") + }, + recv_fn: #recv_fn, + recv_fn_module_path: { + let recv_fn = #recv_fn.unwrap_or(""); + let mut path_items: Vec<_> = recv_fn.split("::").collect(); + let _ = path_items.pop(); // Drop the one we don't want. + path_items.join("::") + }, to_sql_config: #to_sql_config, }; ::pgx::utils::sql_entity_graph::SqlGraphEntity::Type(submission) diff --git a/pgx/src/lib.rs b/pgx/src/lib.rs index 12f30fad65..02d3978500 100644 --- a/pgx/src/lib.rs +++ b/pgx/src/lib.rs @@ -59,6 +59,7 @@ pub mod namespace; pub mod nodes; pub mod pgbox; pub mod rel; +pub mod sendrecvfuncs; pub mod shmem; pub mod spi; pub mod spinlock; @@ -92,6 +93,7 @@ pub use namespace::*; pub use nodes::*; pub use pgbox::*; pub use rel::*; +pub use sendrecvfuncs::*; pub use shmem::*; pub use spi::*; pub use stringinfo::*; diff --git a/pgx/src/sendrecvfuncs.rs b/pgx/src/sendrecvfuncs.rs new file mode 100644 index 0000000000..377ef3f938 --- /dev/null +++ b/pgx/src/sendrecvfuncs.rs @@ -0,0 +1,24 @@ +/* +Portions Copyright 2019-2021 ZomboDB, LLC. +Portions Copyright 2021-2022 Technology Concepts & Design, Inc. + +All rights reserved. + +Use of this source code is governed by the MIT license that can be found in the LICENSE file. +*/ + +//! Helper trait for the `#[derive(PostgresType)]` proc macro for overriding custom Postgres type +//! send/receive functions +//! + +/// `#[derive(PostgresType)]` types need to implement this trait to provide the binary +/// send/receive functions for that type. 
They also *must* specify `#[sendrecvfuncs]` attribute. +pub trait SendRecvFuncs { + /// Convert `Self` into a binary + fn send(&self) -> Vec; + + /// Given a binary representation of `Self`, parse it into a `Self`. + /// + /// It is expected that malformed input will raise an `error!()` or `panic!()` + fn recv(buffer: &[u8]) -> Self; +} diff --git a/pgx/src/stringinfo.rs b/pgx/src/stringinfo.rs index d4f54c905a..437563aa4b 100644 --- a/pgx/src/stringinfo.rs +++ b/pgx/src/stringinfo.rs @@ -11,7 +11,9 @@ Use of this source code is governed by the MIT license that can be found in the #![allow(dead_code, non_snake_case)] use crate::{pg_sys, void_mut_ptr}; +use std::collections::Bound; use std::io::Error; +use std::ops::RangeBounds; /// StringInfoData holds information about an extensible string that is allocated by Postgres' /// memory system, but generally follows Rust's drop semantics @@ -189,6 +191,52 @@ impl StringInfo { } } + /// Reads a range of bytes, modifying the underlying cursor to reflect what was read + /// + /// Returns None if the underlying remaining binary is smaller than requested with the range. + /// + /// Ranges can start from an offset, resulting in skipped information. + /// + /// Most common use-case for this is reading the underlying data in full: + /// + /// ```no_run + /// string_info.read(..) 
+    /// ```
+    ///
+    /// Both bounds are offsets relative to the current cursor. The end bound is the offset at
+    /// which reading stops: `..=n` yields the bytes before offset `n`, and `..n` the bytes
+    /// before offset `n - 1`.
+    /// NOTE(review): this differs from slice indexing; it is kept as-is because existing
+    /// callers may rely on it -- confirm before changing.
+    ///
+    /// Returns `None`, leaving the cursor untouched, when the range is inverted or reaches
+    /// past the remaining bytes.
+    pub fn read<R: RangeBounds<usize>>(&mut self, range: R) -> Option<&[u8]> {
+        use std::ffi::c_int;
+
+        // SAFETY: self.sid will never be null
+        let (len, cursor) = unsafe { ((*self.sid).len, (*self.sid).cursor) };
+        // Refuse negative or inconsistent bookkeeping instead of wrapping around to a huge
+        // `usize`.
+        let len = usize::try_from(len).ok()?;
+        let cursor = usize::try_from(cursor).ok()?;
+        let remaining = len.checked_sub(cursor)?;
+
+        let start = match range.start_bound() {
+            Bound::Included(bound) => *bound,
+            Bound::Excluded(bound) => bound.checked_add(1)?,
+            Bound::Unbounded => 0,
+        };
+        let end = match range.end_bound() {
+            Bound::Included(bound) => *bound,
+            Bound::Excluded(bound) => bound.checked_sub(1)?,
+            Bound::Unbounded => remaining,
+        };
+
+        // The skipped prefix *and* the returned bytes must both fit into what is left.
+        // Checking only `end - start` would accept e.g. `4..=8` with 5 bytes remaining and
+        // hand out a slice that extends past the end of the buffer.
+        if end > remaining {
+            return None;
+        }
+        // `None` for inverted ranges (start > end) rather than an arithmetic underflow.
+        let total = end.checked_sub(start)?;
+
+        // SAFETY: self.sid will never be null, and `cursor + start + total == cursor + end`
+        // does not exceed `len`, so the slice stays within the `len` bytes behind `data`.
+        Some(unsafe {
+            if (*self.sid).data.is_null() {
+                &[]
+            } else {
+                let offset = cursor + start;
+                let result = std::slice::from_raw_parts(
+                    (*self.sid).data.add(offset) as *const u8,
+                    total,
+                );
+                // `offset + total <= len`, and `len` originated from a `c_int`, so the cast
+                // cannot truncate.
+                (*self.sid).cursor = (offset + total) as c_int;
+                result
+            }
+        })
+    }
+
     /// A mutable `&[u8]` byte slice representation
     #[inline]
     pub fn as_bytes_mut(&mut self) -> &mut [u8] {