diff --git a/pgx-examples/schemas/src/lib.rs b/pgx-examples/schemas/src/lib.rs index 8408159e7..34945059e 100644 --- a/pgx-examples/schemas/src/lib.rs +++ b/pgx-examples/schemas/src/lib.rs @@ -100,7 +100,7 @@ mod tests { #[pg_test] fn test_my_some_schema_type() { - Spi::connect(|c| { + Spi::connect(|mut c| { // "MySomeSchemaType" is in 'some_schema', so it needs to be discoverable c.update("SET search_path TO some_schema,public", None, None); assert_eq!( diff --git a/pgx-tests/src/tests/bgworker_tests.rs b/pgx-tests/src/tests/bgworker_tests.rs index 754faea75..c73a15423 100644 --- a/pgx-tests/src/tests/bgworker_tests.rs +++ b/pgx-tests/src/tests/bgworker_tests.rs @@ -25,7 +25,7 @@ pub extern "C" fn bgworker(arg: pg_sys::Datum) { if arg > 0 { BackgroundWorker::transaction(|| { Spi::run("CREATE TABLE tests.bgworker_test (v INTEGER);"); - Spi::connect(|client| { + Spi::connect(|mut client| { client.update( "INSERT INTO tests.bgworker_test VALUES ($1);", None, @@ -71,7 +71,7 @@ pub extern "C" fn bgworker_return_value(arg: pg_sys::Datum) { }; while BackgroundWorker::wait_latch(Some(Duration::from_millis(100))) {} BackgroundWorker::transaction(|| { - Spi::connect(|c| { + Spi::connect(|mut c| { c.update( "INSERT INTO tests.bgworker_test_return VALUES ($1)", None, diff --git a/pgx-tests/src/tests/spi_tests.rs b/pgx-tests/src/tests/spi_tests.rs index efc88c64b..6744fbb79 100644 --- a/pgx-tests/src/tests/spi_tests.rs +++ b/pgx-tests/src/tests/spi_tests.rs @@ -186,7 +186,7 @@ mod tests { #[pg_test] fn test_inserting_null() -> Result<(), pgx::spi::Error> { - Spi::connect(|client| { + Spi::connect(|mut client| { client.update("CREATE TABLE tests.null_test (id uuid)", None, None); }); assert_eq!( @@ -202,7 +202,7 @@ mod tests { #[pg_test] fn test_cursor() { - Spi::connect(|client| { + Spi::connect(|mut client| { client.update("CREATE TABLE tests.cursor_table (id int)", None, None); client.update( "INSERT INTO tests.cursor_table (id) \ @@ -224,7 +224,7 @@ mod tests { #[pg_test] fn test_cursor_prepared_statement() -> Result<(), pgx::spi::Error> { - Spi::connect(|client| { + Spi::connect(|mut client| { client.update("CREATE TABLE tests.cursor_table (id int)", None, None); client.update( "INSERT INTO tests.cursor_table (id) \ @@ -248,7 +248,7 @@ mod tests { #[pg_test] fn test_cursor_by_name() -> Result<(), pgx::spi::Error> { - let cursor_name = Spi::connect(|client| { + let cursor_name = Spi::connect(|mut client| { client.update("CREATE TABLE tests.cursor_table (id int)", None, None); client.update( "INSERT INTO tests.cursor_table (id) \ @@ -308,7 +308,7 @@ mod tests { assert_eq!(res.column_name(2).unwrap(), "b"); }); - Spi::connect(|client| { + Spi::connect(|mut client| { let res = client.update("SET TIME ZONE 'PST8PDT'", None, None); assert_eq!(0, res.columns()); @@ -324,7 +324,7 @@ mod tests { #[pg_test] fn test_spi_non_mut() -> Result<(), pgx::spi::Error> { // Ensures update and cursor APIs do not need mutable reference to SpiClient - Spi::connect(|client| { + Spi::connect(|mut client| { client.update("SELECT 1", None, None); let cursor = client.open_cursor("SELECT 1", None)?.detach_into_name(); client.find_cursor(&cursor).map(|_| ()) @@ -413,4 +413,20 @@ mod tests { fn test_option() { assert!(Spi::get_one::("SELECT NULL::integer").unwrap().is_none()); } + + #[pg_test(error = "CREATE TABLE is not allowed in a non-volatile function")] + fn test_readwrite_in_readonly() { + // This is supposed to run in read-only + Spi::connect(|client| client.select("CREATE TABLE a ()", None, None)); + } + + #[pg_test] + fn test_readwrite_in_select_readwrite() { + Spi::connect(|mut client| { + // This is supposed to switch connection to read-write and run it there + client.update("CREATE TABLE a (id INT)", None, None); + // This is supposed to run in read-write + client.select("INSERT INTO a VALUES (1)", None, None); + }); + } } diff --git a/pgx-tests/src/tests/srf_tests.rs b/pgx-tests/src/tests/srf_tests.rs index 8e23abbae..be6a66a07 100644 --- a/pgx-tests/src/tests/srf_tests.rs +++ b/pgx-tests/src/tests/srf_tests.rs @@ -177,7 +177,7 @@ mod tests { #[pg_test] fn test_srf_setof_datum_detoasting_with_borrow() { - let cnt = Spi::connect(|client| { + let cnt = Spi::connect(|mut client| { // build up a table with one large column that Postgres will be forced to TOAST client.update("CREATE TABLE test_srf_datum_detoasting AS SELECT array_to_string(array_agg(g),' ') s FROM (SELECT 'a' g FROM generate_series(1, 1000000)) x;", None, None); @@ -195,7 +195,7 @@ mod tests { #[pg_test] fn test_srf_table_datum_detoasting_with_borrow() { - let cnt = Spi::connect(|client| { + let cnt = Spi::connect(|mut client| { // build up a table with one large column that Postgres will be forced to TOAST client.update("CREATE TABLE test_srf_datum_detoasting AS SELECT array_to_string(array_agg(g),' ') s FROM (SELECT 'a' g FROM generate_series(1, 1000000)) x;", None, None); diff --git a/pgx-tests/src/tests/struct_type_tests.rs b/pgx-tests/src/tests/struct_type_tests.rs index da814d243..d8bb0f10c 100644 --- a/pgx-tests/src/tests/struct_type_tests.rs +++ b/pgx-tests/src/tests/struct_type_tests.rs @@ -124,7 +124,7 @@ mod tests { #[pg_test] fn test_complex_storage_and_retrieval() -> Result<(), pgx::spi::Error> { - let complex = Spi::connect(|client| { + let complex = Spi::connect(|mut client| { client.update( "CREATE TABLE complex_test AS SELECT s as id, (s || '.0, 2.0' || s)::complex as value FROM generate_series(1, 1000) s;\ SELECT value FROM complex_test ORDER BY id;", None, None).first().get_one::>() diff --git a/pgx/src/spi.rs b/pgx/src/spi.rs index 707da3c4b..12f69b5a2 100644 --- a/pgx/src/spi.rs +++ b/pgx/src/spi.rs @@ -131,7 +131,22 @@ pub enum Error { pub struct Spi; // TODO: should `'conn` be invariant? -pub struct SpiClient<'conn>(PhantomData<&'conn SpiConnection>); +pub struct SpiClient<'conn> { + phantom: PhantomData<&'conn SpiConnection>, + // This field indicates whether queries be readonly. Unless any `update` has been used + // `readonly` will be `true`. + // Postgres docs say: + // + // It is generally unwise to mix read-only and read-write commands within a single function + // using SPI; that could result in very confusing behavior, since the read-only queries + // would not see the results of any database updates done by the read-write queries. + // + // TODO: Alternatively, we can detect if the command counter (or something?) has incremented and if yes + // then we set read_only=false, else we can set it to true. + // However, we would still need to remember the previous value, which will be larger than the boolean. + // So, unless somebody will send commands to Postgres bypassing this SPI API, this flag seems sufficient. + readonly: bool, +} /// a struct to manage our SPI connection lifetime struct SpiConnection(PhantomData<*mut ()>); @@ -156,7 +171,7 @@ impl Drop for SpiConnection { impl SpiConnection { /// Return a client that with a lifetime scoped to this connection. fn client(&self) -> SpiClient<'_> { - SpiClient(PhantomData) + SpiClient { phantom: PhantomData, readonly: true } } } @@ -173,7 +188,6 @@ pub trait Query { fn execute( self, client: &SpiClient, - read_only: bool, limit: Option, arguments: Self::Arguments, ) -> Self::Result; @@ -193,11 +207,10 @@ impl<'a> Query for &'a String { fn execute( self, client: &SpiClient, - read_only: bool, limit: Option, arguments: Self::Arguments, ) -> Self::Result { - self.as_str().execute(client, read_only, limit, arguments) + self.as_str().execute(client, limit, arguments) } fn open_cursor<'c: 'cc, 'cc>( @@ -222,8 +235,7 @@ impl<'a> Query for &'a str { fn execute( self, - _client: &SpiClient, - read_only: bool, + client: &SpiClient, limit: Option, arguments: Self::Arguments, ) -> Self::Result { @@ -249,13 +261,15 @@ impl<'a> Query for &'a str { argtypes.as_mut_ptr(), datums.as_mut_ptr(), nulls.as_ptr(), - read_only, + client.readonly, limit.unwrap_or(0), ) } } // SAFETY: arguments are prepared above - None => unsafe { pg_sys::SPI_execute(src.as_ptr(), read_only, limit.unwrap_or(0)) }, + None => unsafe { + pg_sys::SPI_execute(src.as_ptr(), client.readonly, limit.unwrap_or(0)) + }, }; SpiClient::prepare_tuple_table(status_code) @@ -263,7 +277,7 @@ impl<'a> Query for &'a str { fn open_cursor<'c: 'cc, 'cc>( self, - _client: &'cc SpiClient<'c>, + client: &'cc SpiClient<'c>, args: Self::Arguments, ) -> Result, Error> { let src = std::ffi::CString::new(self).expect("query contained a null byte"); @@ -283,12 +297,12 @@ impl<'a> Query for &'a str { argtypes.as_mut_ptr(), datums.as_mut_ptr(), nulls.as_ptr(), - false, + client.readonly, 0, ) }) .ok_or(Error::PortalIsNull)?; - Ok(SpiCursor { ptr, _phantom: PhantomData }) + Ok(SpiCursor { ptr, __marker: PhantomData }) } } @@ -316,13 +330,13 @@ pub struct SpiHeapTupleData { impl Spi { pub fn get_one(query: &str) -> Result, Error> { - Spi::connect(|client| client.select(query, Some(1), None).first().get_one()) + Spi::connect(|mut client| client.update(query, Some(1), None).first().get_one()) } pub fn get_two( query: &str, ) -> Result<(Option, Option), Error> { - Spi::connect(|client| client.select(query, Some(1), None).first().get_two::()) + Spi::connect(|mut client| client.update(query, Some(1), None).first().get_two::()) } pub fn get_three< @@ -332,21 +346,25 @@ impl Spi { >( query: &str, ) -> Result<(Option, Option, Option), Error> { - Spi::connect(|client| client.select(query, Some(1), None).first().get_three::()) + Spi::connect(|mut client| { + client.update(query, Some(1), None).first().get_three::() + }) } pub fn get_one_with_args( query: &str, args: Vec<(PgOid, Option)>, ) -> Result, Error> { - Spi::connect(|client| client.select(query, Some(1), Some(args)).first().get_one()) + Spi::connect(|mut client| client.update(query, Some(1), Some(args)).first().get_one()) } pub fn get_two_with_args( query: &str, args: Vec<(PgOid, Option)>, ) -> Result<(Option, Option), Error> { - Spi::connect(|client| client.select(query, Some(1), Some(args)).first().get_two::()) + Spi::connect(|mut client| { + client.update(query, Some(1), Some(args)).first().get_two::() + }) } pub fn get_three_with_args< @@ -357,8 +375,8 @@ impl Spi { query: &str, args: Vec<(PgOid, Option)>, ) -> Result<(Option, Option, Option), Error> { - Spi::connect(|client| { - client.select(query, Some(1), Some(args)).first().get_three::() + Spi::connect(|mut client| { + client.update(query, Some(1), Some(args)).first().get_three::() }) } @@ -377,7 +395,7 @@ impl Spi { /// /// The statement runs in read/write mode pub fn run_with_args(query: &str, args: Option)>>) { - Spi::connect(|client| { + Spi::connect(|mut client| { client.update(query, None, args); }) } @@ -392,7 +410,7 @@ impl Spi { query: &str, args: Option)>>, ) -> Result { - Spi::connect(|client| { + Spi::connect(|mut client| { let table = client.update(&format!("EXPLAIN (format json) {}", query), None, args).first(); Ok(table.get_one::()?.unwrap()) @@ -452,33 +470,22 @@ impl Spi { impl<'a> SpiClient<'a> { /// perform a SELECT statement pub fn select(&self, query: Q, limit: Option, args: Q::Arguments) -> Q::Result { - // Postgres docs say: - // - // It is generally unwise to mix read-only and read-write commands within a single function - // using SPI; that could result in very confusing behavior, since the read-only queries - // would not see the results of any database updates done by the read-write queries. - // - // As such, we don't actually set read-only to true here - - // TODO: can we detect if the command counter (or something?) has incremented and if yes - // then we set read_only=false, else we can set it to true? - // Is this even a good idea? - self.execute(query, false, limit, args) + self.execute(query, limit, args) } /// perform any query (including utility statements) that modify the database in some way - pub fn update(&self, query: Q, limit: Option, args: Q::Arguments) -> Q::Result { - self.execute(query, false, limit, args) - } - - fn execute( - &self, + pub fn update( + &mut self, query: Q, - read_only: bool, limit: Option, args: Q::Arguments, ) -> Q::Result { - query.execute(&self, read_only, limit, args) + self.readonly = false; + self.execute(query, limit, args) + } + + fn execute(&self, query: Q, limit: Option, args: Q::Arguments) -> Q::Result { + query.execute(&self, limit, args) } fn prepare_tuple_table(status_code: i32) -> SpiTupleTable { @@ -502,11 +509,21 @@ impl<'a> SpiClient<'a> { /// Rows may be then fetched using [`SpiCursor::fetch`]. /// /// See [`SpiCursor`] docs for usage details. - pub fn open_cursor( - &self, + pub fn open_cursor(&self, query: Q, args: Q::Arguments) -> Result { + query.open_cursor(&self, args) + } + + /// Set up a cursor that will execute the specified update (mutating) query + /// + /// Rows may be then fetched using [`SpiCursor::fetch`]. + /// + /// See [`SpiCursor`] docs for usage details. + pub fn open_cursor_mut( + &mut self, query: Q, args: Q::Arguments, - ) -> Result, Error> { + ) -> Result { + self.readonly = false; query.open_cursor(&self, args) } @@ -522,7 +539,7 @@ impl<'a> SpiClient<'a> { let ptr = NonNull::new(unsafe { pg_sys::SPI_cursor_find(name.as_pg_cstr()) }) .ok_or(Error::CursorNotFound(name.to_string()))?; - Ok(SpiCursor { ptr, _phantom: PhantomData }) + Ok(SpiCursor { ptr, __marker: PhantomData }) } } @@ -581,7 +598,7 @@ type CursorName = String; /// ``` pub struct SpiCursor<'client> { ptr: NonNull, - _phantom: PhantomData<&'client SpiClient<'client>>, + __marker: PhantomData<&'client SpiClient<'client>>, } impl SpiCursor<'_> { @@ -665,11 +682,10 @@ impl<'a> Query for &'a OwnedPreparedStatement { fn execute( self, client: &SpiClient, - read_only: bool, limit: Option, arguments: Self::Arguments, ) -> Self::Result { - (&self.0).execute(client, read_only, limit, arguments) + (&self.0).execute(client, limit, arguments) } fn open_cursor<'c: 'cc, 'cc>( @@ -688,11 +704,10 @@ impl Query for OwnedPreparedStatement { fn execute( self, client: &SpiClient, - read_only: bool, limit: Option, arguments: Self::Arguments, ) -> Self::Result { - (&self.0).execute(client, read_only, limit, arguments) + (&self.0).execute(client, limit, arguments) } fn open_cursor<'c: 'cc, 'cc>( @@ -725,8 +740,7 @@ impl<'a: 'b, 'b> Query for &'b PreparedStatement<'a> { fn execute( self, - _client: &SpiClient, - read_only: bool, + client: &SpiClient, limit: Option, arguments: Self::Arguments, ) -> Self::Result { @@ -751,7 +765,7 @@ impl<'a: 'b, 'b> Query for &'b PreparedStatement<'a> { self.plan, datums.as_mut_ptr(), nulls.as_mut_ptr(), - read_only, + client.readonly, limit.unwrap_or(0), ) }; @@ -761,7 +775,7 @@ impl<'a: 'b, 'b> Query for &'b PreparedStatement<'a> { fn open_cursor<'c: 'cc, 'cc>( self, - _client: &'cc SpiClient<'c>, + client: &'cc SpiClient<'c>, args: Self::Arguments, ) -> Result, Error> { let args = args.unwrap_or_default(); @@ -775,11 +789,11 @@ impl<'a: 'b, 'b> Query for &'b PreparedStatement<'a> { self.plan, datums.as_mut_ptr(), nulls.as_ptr(), - false, + client.readonly, ) }) .ok_or(Error::PortalIsNull)?; - Ok(SpiCursor { ptr, _phantom: PhantomData }) + Ok(SpiCursor { ptr, __marker: PhantomData }) } } @@ -790,11 +804,10 @@ impl<'a> Query for PreparedStatement<'a> { fn execute( self, client: &SpiClient, - read_only: bool, limit: Option, arguments: Self::Arguments, ) -> Self::Result { - (&self).execute(client, read_only, limit, arguments) + (&self).execute(client, limit, arguments) } fn open_cursor<'c: 'cc, 'cc>(