Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: pgcentralfoundation/pgrx
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: d1b78d2f7c47ec0654185f01992c6987c1c1e823
Choose a base ref
..
head repository: pgcentralfoundation/pgrx
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 652dfa8e3b25fd71cb2b50cd7c2a89901e849715
Choose a head ref
Showing with 166 additions and 0 deletions.
  1. +166 −0 pgx-tests/src/tests/spi_tests.rs
166 changes: 166 additions & 0 deletions pgx-tests/src/tests/spi_tests.rs
Original file line number Diff line number Diff line change
@@ -183,6 +183,172 @@ mod tests {
Spi::run("SELECT tests.do_panic();");
}

#[pg_test]
fn test_inserting_null() -> Result<(), pgx::spi::Error> {
Spi::connect(|mut client| {
client.update("CREATE TABLE tests.null_test (id uuid)", None, None);
});
assert_eq!(
Spi::get_one_with_args::<i32>(
"INSERT INTO tests.null_test VALUES ($1) RETURNING 1",
vec![(PgBuiltInOids::UUIDOID.oid(), None)],
)?
.unwrap(),
1
);
Ok(())
}

#[pg_test]
fn test_cursor() {
Spi::connect(|mut client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, None);
client.update(
"INSERT INTO tests.cursor_table (id) \
SELECT i FROM generate_series(1, 10) AS t(i)",
None,
None,
);
let mut portal = client.open_cursor("SELECT * FROM tests.cursor_table", None).unwrap();

fn sum_all(table: pgx::SpiTupleTable) -> i32 {
table.map(|r| r.by_ordinal(1).unwrap().value::<i32>().unwrap()).sum()
}
assert_eq!(sum_all(portal.fetch(3)), 1 + 2 + 3);
assert_eq!(sum_all(portal.fetch(3)), 4 + 5 + 6);
assert_eq!(sum_all(portal.fetch(3)), 7 + 8 + 9);
assert_eq!(sum_all(portal.fetch(3)), 10);
});
}

#[pg_test]
fn test_cursor_by_name() -> Result<(), pgx::spi::Error> {
let cursor_name = Spi::connect(|mut client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, None);
client.update(
"INSERT INTO tests.cursor_table (id) \
SELECT i FROM generate_series(1, 10) AS t(i)",
None,
None,
);
client.open_cursor("SELECT * FROM tests.cursor_table", None).map(|mut cursor| {
assert_eq!(sum_all(cursor.fetch(3)), 1 + 2 + 3);
cursor.detach_into_name()
})
})?;

fn sum_all(table: pgx::SpiTupleTable) -> i32 {
table.map(|r| r.by_ordinal(1).unwrap().value::<i32>().unwrap()).sum()
}
Spi::connect(|client| {
client.find_cursor(&cursor_name).map(|mut cursor| {
assert_eq!(sum_all(cursor.fetch(3)), 4 + 5 + 6);
assert_eq!(sum_all(cursor.fetch(3)), 7 + 8 + 9);
cursor.detach_into_name();
})
})?;

Spi::connect(|client| {
client.find_cursor(&cursor_name).map(|mut cursor| {
assert_eq!(sum_all(cursor.fetch(3)), 10);
})
})?;
Ok(())
}

#[pg_test(error = "syntax error at or near \"THIS\"")]
fn test_cursor_failure() {
Spi::connect(|client| client.open_cursor("THIS IS NOT SQL", None).map(|_| ())).unwrap();
}

#[pg_test(error = "cursor: CursorNotFound(\"NOT A CURSOR\")")]
fn test_cursor_not_found() {
Spi::connect(|client| client.find_cursor("NOT A CURSOR").map(|_| ())).expect("cursor");
}

#[pg_test]
fn test_columns() {
use pgx::{PgBuiltInOids, PgOid};
Spi::connect(|client| {
let res = client.select("SELECT 42 AS a, 'test' AS b", None, None);

assert_eq!(2, res.columns());

assert_eq!(res.column_type_oid(1).unwrap(), PgOid::BuiltIn(PgBuiltInOids::INT4OID));

assert_eq!(res.column_type_oid(2).unwrap(), PgOid::BuiltIn(PgBuiltInOids::TEXTOID));

assert_eq!(res.column_name(1).unwrap(), "a");

assert_eq!(res.column_name(2).unwrap(), "b");
});

Spi::connect(|mut client| {
let res = client.update("SET TIME ZONE 'PST8PDT'", None, None);

assert_eq!(0, res.columns());
});
}

#[pg_test]
fn test_connect_return_anything() {
struct T;
assert!(matches!(Spi::connect(|_| Ok::<_, ()>(Some(T))).unwrap().unwrap(), T));
}

#[pg_test]
fn test_spi_non_mut() -> Result<(), pgx::spi::Error> {
// Ensures update and cursor APIs do not need mutable reference to SpiClient
Spi::connect(|mut client| {
client.update("SELECT 1", None, None);
let cursor = client.open_cursor("SELECT 1", None)?.detach_into_name();
client.find_cursor(&cursor).map(|_| ())
})
}

#[pg_test]
fn test_open_multiple_tuptables() {
// Regression test to ensure a new `SpiTupTable` instance does not override the
// effective length of an already open one due to misuse of Spi statics
Spi::connect(|client| {
let a = client.select("SELECT 1", None, None).first();
let _b = client.select("SELECT 1 WHERE 'f'", None, None);
assert!(!a.is_empty());
assert_eq!(1, a.len());
assert!(a.get_heap_tuple().is_some());
assert_eq!(Ok(Some(1)), a.get_datum::<i32>(1));
})
}

#[pg_test]
fn test_open_multiple_tuptables_rev() {
// Regression test to ensure a new `SpiTupTable` instance does not override the
// effective length of an already open one.
// Same as `test_open_multiple_tuptables`, but with the second tuptable being empty
Spi::connect(|client| {
let a = client.select("SELECT 1 WHERE 'f'", None, None).first();
let _b = client.select("SELECT 1", None, None);
assert!(a.is_empty());
assert_eq!(0, a.len());
assert!(a.get_heap_tuple().is_none());
assert_eq!(Err(pgx::spi::Error::InvalidPosition), a.get_datum::<i32>(1));
});
}

#[pg_test]
fn test_spi_unwind_safe() {
struct T;
assert!(matches!(Spi::connect(|_| Ok::<_, ()>(Some(T))).unwrap().unwrap(), T));
}

#[pg_test]
fn test_error_propagation() {
#[derive(Debug)]
struct Error;
let result = Spi::connect(|_| Err::<(), _>(Error));
assert!(matches!(result, Err(Error)))
}

#[pg_test]
fn test_option() {
assert!(Spi::get_one::<i32>("SELECT NULL::integer").unwrap().is_none());