Skip to content

Commit

Permalink
update function signatures to use _bound versions
Browse files Browse the repository at this point in the history
  • Loading branch information
emgeee committed Sep 10, 2024
1 parent 04e65e8 commit c368bc1
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
5 changes: 3 additions & 2 deletions src/expr/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::fmt::{self, Display, Formatter};
use crate::common::df_schema::PyDFSchema;
use crate::errors::py_type_err;
use crate::expr::logical_node::LogicalNode;
use crate::expr::sort_expr::{py_sort_expr_list, PySortExpr};
use crate::expr::PyExpr;
use crate::sql::logical::PyLogicalPlan;

Expand Down Expand Up @@ -114,9 +115,9 @@ impl PyWindow {
}

/// Returns order by columns in a window function expression
pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PyExpr>> {
pub fn get_sort_exprs(&self, expr: PyExpr) -> PyResult<Vec<PySortExpr>> {
match expr.expr.unalias() {
Expr::WindowFunction(WindowFunction { order_by, .. }) => py_expr_list(&order_by),
Expr::WindowFunction(WindowFunction { order_by, .. }) => py_sort_expr_list(&order_by),
other => Err(not_window_function_err(other)),
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::sync::Arc;
use pyo3::{prelude::*, types::PyTuple};

use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
use datafusion::arrow::pyarrow::FromPyArrow;
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::error::DataFusionError;
Expand All @@ -43,16 +44,15 @@ fn to_rust_function(func: PyObject) -> ScalarFunctionImplementation {
.iter()
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
.collect::<Vec<_>>();
let py_args = PyTuple::new(py, py_args);
let py_args = PyTuple::new_bound(py, py_args);

// 2. call function
let value = func
.as_ref(py)
.call(py_args, None)
.call_bound(py, py_args, None)
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;

// 3. cast to arrow::array::Array
let array_data = ArrayData::from_pyarrow(value).unwrap();
let array_data = ArrayData::from_pyarrow_bound(value.bind(py)).unwrap();
Ok(make_array(array_data))
})
},
Expand Down

0 comments on commit c368bc1

Please sign in to comment.