Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Ensure Python version matches version used to serialize credential provider #19375

Merged
merged 4 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion crates/polars-io/src/cloud/credential_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,14 @@ mod python_impl {
use super::IntoCredentialProvider;

#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PythonCredentialProvider(pub(super) Arc<PythonFunction>);

impl From<PythonFunction> for PythonCredentialProvider {
fn from(value: PythonFunction) -> Self {
Self(Arc::new(value))
}
}

impl IntoCredentialProvider for PythonCredentialProvider {
#[cfg(feature = "aws")]
fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
Expand Down Expand Up @@ -665,6 +670,32 @@ mod python_impl {
state.write_usize(Arc::as_ptr(&self.0) as *const () as usize)
}
}

#[cfg(feature = "serde")]
mod _serde_impl {
    use polars_utils::python_function::PySerializeWrap;

    use super::PythonCredentialProvider;

    /// Serializes the provider by handing a reference to the inner Python
    /// function to `PySerializeWrap`, which performs the actual encoding.
    impl serde::Serialize for PythonCredentialProvider {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            // Borrow the function out of the `Arc`; the wrapper only needs `&T`.
            let function: &super::PythonFunction = &self.0;
            serde::Serialize::serialize(&PySerializeWrap(function), serializer)
        }
    }

    /// Deserializes a `PySerializeWrap<PythonFunction>` and converts the
    /// recovered function into a provider via its `From` implementation.
    impl<'de> serde::Deserialize<'de> for PythonCredentialProvider {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            let PySerializeWrap(function) =
                <PySerializeWrap<super::PythonFunction> as serde::Deserialize>::deserialize(
                    deserializer,
                )?;
            Ok(Self::from(function))
        }
    }
}
}

#[cfg(test)]
Expand Down
23 changes: 12 additions & 11 deletions crates/polars-plan/src/dsl/python_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub static mut CALL_DF_UDF_PYTHON: Option<
> = None;

pub use polars_utils::python_function::{
PythonFunction, PYTHON_SERDE_MAGIC_BYTE_MARK, PYTHON_VERSION_MINOR,
PythonFunction, PYTHON3_VERSION, PYTHON_SERDE_MAGIC_BYTE_MARK,
};

pub struct PythonUdfExpression {
Expand Down Expand Up @@ -57,17 +57,17 @@ impl PythonUdfExpression {
// Handle pickle metadata
let use_cloudpickle = buf[0];
if use_cloudpickle != 0 {
Copy link
Collaborator Author

@nameexhaustion nameexhaustion Oct 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about this, but the existing code for PythonUDF deserialization doesn't enforce the Python version equality if it was serialized with cloudpickle. Maybe we should?

let ser_py_version = buf[1];
let cur_py_version = *PYTHON_VERSION_MINOR;
let ser_py_version = &buf[1..3];
let cur_py_version = *PYTHON3_VERSION;
polars_ensure!(
ser_py_version == cur_py_version,
InvalidOperation:
"current Python version (3.{}) does not match the Python version used to serialize the UDF (3.{})",
cur_py_version,
ser_py_version
"current Python version {:?} does not match the Python version used to serialize the UDF {:?}",
(3, cur_py_version[0], cur_py_version[1]),
(3, ser_py_version[0], ser_py_version[1] )
Copy link
Collaborator Author

@nameexhaustion nameexhaustion Oct 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed this to also check the micro version — from my research, pickle requires the exact same Python version.

was added from #19175

);
}
let buf = &buf[2..];
let buf = &buf[3..];

// Load UDF metadata
let mut reader = Cursor::new(buf);
Expand Down Expand Up @@ -141,8 +141,8 @@ impl ColumnsUdf for PythonUdfExpression {
.getattr("dumps")
.unwrap();
let pickle_result = pickle.call1((self.python_function.clone_ref(py),));
let (dumped, use_cloudpickle, py_version) = match pickle_result {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still serialize the correct Python version even if we are using cloudpickle

Ok(dumped) => (dumped, false, 0),
let (dumped, use_cloudpickle) = match pickle_result {
Ok(dumped) => (dumped, false),
Err(_) => {
let cloudpickle = PyModule::import_bound(py, "cloudpickle")
.map_err(from_pyerr)?
Expand All @@ -151,12 +151,13 @@ impl ColumnsUdf for PythonUdfExpression {
let dumped = cloudpickle
.call1((self.python_function.clone_ref(py),))
.map_err(from_pyerr)?;
(dumped, true, *PYTHON_VERSION_MINOR)
(dumped, true)
},
};

// Write pickle metadata
buf.extend_from_slice(&[use_cloudpickle as u8, py_version]);
buf.push(use_cloudpickle as u8);
buf.extend_from_slice(&*PYTHON3_VERSION);

// Write UDF metadata
ciborium::ser::into_writer(
Expand Down
226 changes: 181 additions & 45 deletions crates/polars-utils/src/python_function.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
use once_cell::sync::Lazy;
use polars_error::{polars_bail, PolarsError, PolarsResult};
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedBytes;
use pyo3::types::PyBytes;
#[cfg(feature = "serde")]
use serde::ser::Error;
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub use serde_wrap::{
PySerializeWrap, TrySerializeToBytes, PYTHON3_VERSION,
SERDE_MAGIC_BYTE_MARK as PYTHON_SERDE_MAGIC_BYTE_MARK,
};

#[cfg(feature = "serde")]
pub const PYTHON_SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYUDF".as_bytes();
pub static PYTHON_VERSION_MINOR: Lazy<u8> = Lazy::new(get_python_minor_version);
use crate::flatten;

#[derive(Debug)]
pub struct PythonFunction(pub PyObject);
Expand Down Expand Up @@ -42,64 +41,201 @@ impl PartialEq for PythonFunction {
}

#[cfg(feature = "serde")]
impl Serialize for PythonFunction {
impl serde::Serialize for PythonFunction {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: Serializer,
S: serde::Serializer,
{
Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "cloudpickle")
.or_else(|_| PyModule::import_bound(py, "pickle"))
.expect("unable to import 'cloudpickle' or 'pickle'")
.getattr("dumps")
.unwrap();

let python_function = self.0.clone_ref(py);

let dumped = pickle
.call1((python_function,))
.map_err(|s| S::Error::custom(format!("cannot pickle {s}")))?;
let dumped = dumped.extract::<PyBackedBytes>().unwrap();

serializer.serialize_bytes(&dumped)
})
use serde::ser::Error;
serializer.serialize_bytes(
self.try_serialize_to_bytes()
.map_err(|e| S::Error::custom(e.to_string()))?
.as_slice(),
)
}
}

#[cfg(feature = "serde")]
impl<'a> Deserialize<'a> for PythonFunction {
impl<'a> serde::Deserialize<'a> for PythonFunction {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'a>,
D: serde::Deserializer<'a>,
{
use serde::de::Error;
let bytes = Vec::<u8>::deserialize(deserializer)?;
Self::try_deserialize_bytes(bytes.as_slice()).map_err(|e| D::Error::custom(e.to_string()))
}
}

Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "pickle")
.expect("unable to import 'pickle'")
.getattr("loads")
#[cfg(feature = "serde")]
impl TrySerializeToBytes for PythonFunction {
    /// Serializes the wrapped Python object by delegating to
    /// `serialize_pyobject_with_cloudpickle_fallback`.
    fn try_serialize_to_bytes(&self) -> polars_error::PolarsResult<Vec<u8>> {
        let Self(py_object) = self;
        serialize_pyobject_with_cloudpickle_fallback(py_object)
    }

    /// Rebuilds a `PythonFunction` from bytes by delegating to
    /// `deserialize_pyobject_bytes_maybe_cloudpickle`.
    fn try_deserialize_bytes(bytes: &[u8]) -> polars_error::PolarsResult<Self> {
        deserialize_pyobject_bytes_maybe_cloudpickle::<Self>(bytes)
    }
}

pub fn serialize_pyobject_with_cloudpickle_fallback(py_object: &PyObject) -> PolarsResult<Vec<u8>> {
Python::with_gil(|py| {
let pickle = PyModule::import_bound(py, "pickle")
.expect("unable to import 'pickle'")
.getattr("dumps")
.unwrap();

let dumped = pickle.call1((py_object.clone_ref(py),));

let (dumped, used_cloudpickle) = if let Ok(v) = dumped {
(v, false)
} else {
let cloudpickle = PyModule::import_bound(py, "cloudpickle")
.map_err(from_pyerr)?
.getattr("dumps")
.unwrap();
let arg = (PyBytes::new_bound(py, &bytes),);
let python_function = pickle
.call1(arg)
.map_err(|s| D::Error::custom(format!("cannot pickle {s}")))?;
let dumped = cloudpickle
.call1((py_object.clone_ref(py),))
.map_err(from_pyerr)?;
(dumped, true)
};

Ok(Self(python_function.into()))
})
let py_bytes = dumped.extract::<PyBackedBytes>().map_err(from_pyerr)?;

Ok(flatten(
&[&[used_cloudpickle as u8, b'C'][..], py_bytes.as_ref()],
None,
))
})
}

/// Loads a Python object from `bytes` using Python's `pickle.loads`.
///
/// The input must start with a two-byte header: a flag byte that is `0` or
/// `1`, followed by the literal byte `b'C'`. Everything after the header is
/// passed to `pickle.loads`, and the resulting object is converted into `T`.
///
/// # Errors
/// Returns a `ComputeError` if the header is missing or malformed, or if
/// `pickle.loads` raises.
///
/// # Panics
/// Panics if the `pickle` module cannot be imported or has no `loads`
/// attribute.
pub fn deserialize_pyobject_bytes_maybe_cloudpickle<T: for<'a> From<PyObject>>(
    bytes: &[u8],
) -> PolarsResult<T> {
    // TODO: Actually deserialize with cloudpickle if it's set.
    let pickled = match bytes {
        // The flag byte is validated but otherwise unused for now.
        [0 | 1, b'C', payload @ ..] => payload,
        _ => {
            polars_bail!(ComputeError: "deserialize_pyobject_bytes_maybe_cloudpickle: invalid start bytes")
        },
    };

    Python::with_gil(|py| {
        let loads = PyModule::import_bound(py, "pickle")
            .expect("unable to import 'pickle'")
            .getattr("loads")
            .unwrap();
        let loaded = loads
            .call1((PyBytes::new_bound(py, pickled),))
            .map_err(from_pyerr)?;
        let py_object = PyObject::from(loaded);
        Ok(T::from(py_object))
    })
}

#[cfg(feature = "serde")]
mod serde_wrap {
    use once_cell::sync::Lazy;
    use polars_error::PolarsResult;

    use crate::flatten;

    /// Byte mark that every payload produced by [`PySerializeWrap`] must begin with.
    pub const SERDE_MAGIC_BYTE_MARK: &[u8] = "PLPYFN".as_bytes();
    /// [minor, micro]
    ///
    /// Computed lazily on first access via `get_python3_version` in the parent module.
    pub static PYTHON3_VERSION: Lazy<[u8; 2]> = Lazy::new(super::get_python3_version);

    /// Serializes a Python object without additional system metadata. This is intended to be used
    /// together with `PySerializeWrap`, which attaches e.g. Python version metadata.
    pub trait TrySerializeToBytes: Sized {
        /// Encodes `self` into a byte buffer.
        fn try_serialize_to_bytes(&self) -> PolarsResult<Vec<u8>>;
        /// Rebuilds a value from bytes previously produced by `try_serialize_to_bytes`.
        fn try_deserialize_bytes(bytes: &[u8]) -> PolarsResult<Self>;
    }

    /// Serialization wrapper for T: TrySerializeToBytes that attaches Python
    /// version metadata.
    ///
    /// Deserialization expects the byte layout
    /// `SERDE_MAGIC_BYTE_MARK | minor | micro | payload`.
    pub struct PySerializeWrap<T>(pub T);

    impl<T: TrySerializeToBytes> serde::Serialize for PySerializeWrap<&T> {
        /// Emits the magic mark, the current `[minor, micro]` version and the
        /// payload from `try_serialize_to_bytes` as a single bytes value.
        ///
        /// Any error from `try_serialize_to_bytes` is converted to a custom
        /// serializer error carrying its string representation.
        fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            use serde::ser::Error;
            let dumped = self
                .0
                .try_serialize_to_bytes()
                .map_err(|e| S::Error::custom(e.to_string()))?;

            // NOTE(review): `flatten` presumably concatenates the slices in order,
            // matching the layout the deserializer below reads — confirm.
            serializer.serialize_bytes(
                flatten(
                    &[SERDE_MAGIC_BYTE_MARK, &*PYTHON3_VERSION, dumped.as_slice()],
                    None,
                )
                .as_slice(),
            )
        }
    }

    impl<'a, T: TrySerializeToBytes> serde::Deserialize<'a> for PySerializeWrap<T> {
        /// Reads a byte buffer, validates the magic mark and the Python
        /// version, then hands the remaining bytes to `T::try_deserialize_bytes`.
        ///
        /// Returns a custom deserializer error when the buffer is too short,
        /// does not start with the magic mark, was written by a different
        /// Python `[minor, micro]` version, or when `T` fails to decode.
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
            D: serde::Deserializer<'a>,
        {
            use serde::de::Error;
            let bytes = Vec::<u8>::deserialize(deserializer)?;

            // Too short to even contain the magic mark.
            let Some((magic, rem)) = bytes.split_at_checked(SERDE_MAGIC_BYTE_MARK.len()) else {
                return Err(D::Error::custom(
                    "unexpected EOF when reading serialized pyobject magic byte mark",
                ));
            };

            if magic != SERDE_MAGIC_BYTE_MARK {
                return Err(D::Error::custom(
                    "serialized pyobject did not begin with magic byte mark",
                ));
            }

            let bytes = rem;

            // The two bytes after the mark are the [minor, micro] version.
            let [a, b, rem @ ..] = bytes else {
                return Err(D::Error::custom(
                    "unexpected EOF when reading serialized pyobject python version",
                ));
            };

            let py3_version = [*a, *b];

            if py3_version != *PYTHON3_VERSION {
                return Err(D::Error::custom(format!(
                    "python version that pyobject was serialized with {:?} \
                    differs from system python version {:?}",
                    (3, py3_version[0], py3_version[1]),
                    (3, PYTHON3_VERSION[0], PYTHON3_VERSION[1]),
                )));
            }

            let bytes = rem;

            T::try_deserialize_bytes(bytes)
                .map(Self)
                .map_err(|e| D::Error::custom(e.to_string()))
        }
    }
}

/// Get the minor Python version from the `sys` module.
fn get_python_minor_version() -> u8 {
/// Get the [minor, micro] Python3 version from the `sys` module.
fn get_python3_version() -> [u8; 2] {
Python::with_gil(|py| {
PyModule::import_bound(py, "sys")
let version_info = PyModule::import_bound(py, "sys")
.unwrap()
.getattr("version_info")
.unwrap()
.getattr("minor")
.unwrap()
.extract()
.unwrap()
.unwrap();

[
version_info.getattr("minor").unwrap().extract().unwrap(),
version_info.getattr("micro").unwrap().extract().unwrap(),
]
})
}

/// Converts a Python exception into a `PolarsError::ComputeError` whose
/// message is the exception's display text prefixed with
/// `"error raised in python: "`.
fn from_pyerr(e: PyErr) -> PolarsError {
    let message = format!("error raised in python: {e}");
    PolarsError::ComputeError(message.into())
}
Loading