Skip to content

Commit 0f7055a

Browse files
committedMay 29, 2024
feat: pbv
1 parent 976f0b1 commit 0f7055a

14 files changed

+567
-0
lines changed
 

‎.github/workflows/publish_to_pypi.yml

+157
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
name: CI

on:
  push:
    branches:
      - main
      - master
    tags:
      - '*'
  pull_request:
  workflow_dispatch:

# Cancel an in-progress run when a newer run starts for the same workflow/ref.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

permissions:
  contents: read

# Make sure CI fails on all warnings, including Clippy lints
env:
  RUSTFLAGS: "-Dwarnings"

jobs:
  # Lint, build and run the Python test-suite on every supported interpreter.
  linux_tests:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        target: [x86_64]
        python-version: ["3.8", "3.9", "3.10", "3.11"]
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}

      - name: Set up Rust
        run: rustup show
      - uses: mozilla-actions/sccache-action@v0.0.3
      - run: make .venv
      - run: make pre-commit
      - run: make install
      - run: make test

  linux:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        target: [x86_64, x86]
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Build wheels
        uses: PyO3/maturin-action@v1
        with:
          target: ${{ matrix.target }}
          args: --release --out dist --find-interpreter
          sccache: 'true'
          manylinux: auto
      - name: Upload wheels
        uses: actions/upload-artifact@v4
        with:
          # upload-artifact v4 rejects two uploads to the same artifact name,
          # so every job/matrix entry gets its own `wheels-*` artifact.
          name: wheels-linux-${{ matrix.target }}
          path: dist

  windows:
    runs-on: windows-latest
    strategy:
      matrix:
        target: [x64]
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
        with:
          python-version: '3.10'
          architecture: ${{ matrix.target }}

      - name: Build wheels
        uses: PyO3/maturin-action@v1
        with:
          target: ${{ matrix.target }}
          args: --release --out dist --find-interpreter
          sccache: 'true'
      - name: Upload wheels
        uses: actions/upload-artifact@v4
        with:
          name: wheels-windows-${{ matrix.target }}
          path: dist

  macos:
    runs-on: macos-latest
    strategy:
      matrix:
        target: [x86_64, aarch64]
    steps:
      - uses: actions/checkout@v3
      - uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Build wheels
        uses: PyO3/maturin-action@v1
        with:
          target: ${{ matrix.target }}
          args: --release --out dist --find-interpreter
          sccache: 'true'
      - name: Upload wheels
        uses: actions/upload-artifact@v4
        with:
          name: wheels-macos-${{ matrix.target }}
          path: dist

  sdist:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Build sdist
        uses: PyO3/maturin-action@v1
        with:
          command: sdist
          args: --out dist
      - name: Upload sdist
        uses: actions/upload-artifact@v4
        with:
          name: wheels-sdist
          path: dist

  release:
    name: Release
    if: "startsWith(github.ref, 'refs/tags/')"
    # Publishing is gated on the test job as well as on every build job, so a
    # tag whose tests fail is never uploaded.
    needs: [linux_tests, linux, windows, macos, sdist]
    runs-on: ubuntu-latest
    environment: pypi
    permissions:
      id-token: write  # IMPORTANT: mandatory for trusted publishing
    steps:
      # Without `name`, download-artifact v4 fetches every artifact of the run
      # into a directory named after the artifact (wheels-linux-x86_64/, ...).
      - uses: actions/download-artifact@v4
      - name: Publish to PyPI
        uses: PyO3/maturin-action@v1
        with:
          command: upload
          args: --non-interactive --skip-existing wheels-*/*
157+

‎.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@ Cargo.lock
1212

1313
# MSVC Windows builds of rustc generate these, which store debugging information
1414
*.pdb
15+
*.so
16+
*.pyc

‎.python-version

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
miniconda3-3.10-24.1.2-0

‎Cargo.toml

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
[package]
2+
name = "polars-pbv"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[lib]
7+
name = "polars_pbv"
8+
crate-type= ["cdylib"]
9+
10+
[dependencies]
11+
pyo3 = { version = "0.21.2", features = ["extension-module", "abi3-py38"] }
12+
pyo3-polars = { version = "0.13.0", features = ["derive"] }
13+
serde = { version = "1", features = ["derive"] }
14+
polars = { version = "0.39.2", default-features = false, features=["dtype-struct"]}
15+
16+
[target.'cfg(target_os = "linux")'.dependencies]
17+
jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] }
18+

‎Makefile

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
SHELL=/bin/bash
2+
3+
.venv: ## Set up virtual environment
4+
python3 -m venv .venv
5+
.venv/bin/pip install -r requirements.txt
6+
7+
install: .venv
8+
unset CONDA_PREFIX && \
9+
source .venv/bin/activate && maturin develop
10+
11+
install-release: .venv
12+
unset CONDA_PREFIX && \
13+
source .venv/bin/activate && maturin develop --release
14+
15+
pre-commit: .venv
16+
cargo fmt --all && cargo clippy --all-features
17+
.venv/bin/python -m ruff check . --fix --exit-non-zero-on-fix
18+
.venv/bin/python -m ruff format polars_pbv tests
19+
.venv/bin/mypy polars_pbv tests
20+
21+
test: .venv
22+
.venv/bin/python -m pytest tests
23+
24+
run: install
25+
source .venv/bin/activate && python run.py
26+
27+
run-release: install-release
28+
source .venv/bin/activate && python run.py
29+

‎polars_pbv/__init__.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
from typing import TYPE_CHECKING
5+
6+
import polars as pl
7+
8+
from polars_pbv.utils import parse_into_expr, register_plugin, parse_version
9+
10+
if TYPE_CHECKING:
11+
from polars.type_aliases import IntoExpr
12+
13+
if parse_version(pl.__version__) < parse_version("0.20.16"):
14+
from polars.utils.udfs import _get_shared_lib_location
15+
16+
lib: str | Path = _get_shared_lib_location(__file__)
17+
else:
18+
lib = Path(__file__).parent
19+
20+
def pbv(
    price: IntoExpr,
    volume: IntoExpr,
    window_size: int,
    bins: int,
    center: bool = True,
) -> pl.Expr:
    """
    Build the expression that invokes the ``price_by_volume`` plugin function.

    Parameters
    ----------
    price
        Price input; parsed with `parse_into_expr` (strings become column names).
    volume
        Volume input; parsed with `parse_into_expr` (strings become column names).
    window_size
        Forwarded to the plugin as the ``window_size`` keyword argument.
    bins
        Forwarded to the plugin as the ``bins`` keyword argument.
    center
        Forwarded to the plugin as the ``center_label`` keyword argument.

    Returns
    -------
    polars.Expr
    """
    plugin_kwargs = {
        "window_size": window_size,
        "bins": bins,
        "center_label": center,
    }
    price = parse_into_expr(price)
    volume = parse_into_expr(volume)
    plugin_inputs = [price, volume]
    return register_plugin(
        args=plugin_inputs,
        symbol="price_by_volume",
        is_elementwise=False,
        lib=lib,
        kwargs=plugin_kwargs,
    )

‎polars_pbv/utils.py

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from __future__ import annotations
2+
3+
import re
4+
from typing import TYPE_CHECKING, Sequence, Any
5+
6+
import polars as pl
7+
8+
if TYPE_CHECKING:
9+
from polars.type_aliases import IntoExpr, PolarsDataType
10+
from pathlib import Path
11+
12+
13+
def parse_into_expr(
    expr: IntoExpr,
    *,
    str_as_lit: bool = False,
    list_as_lit: bool = True,
    dtype: PolarsDataType | None = None,
) -> pl.Expr:
    """
    Turn a single input value into a Polars expression.

    Parameters
    ----------
    expr
        The value to convert. An existing expression is returned unchanged.
    str_as_lit
        If `True`, a string becomes a string literal. If `False` (default),
        a string is treated as a column name.
    list_as_lit
        If `True` (default), a list is passed to `lit` as-is. If `False`,
        the list is first wrapped in a `Series` and that is passed to `lit`.
    dtype
        Forwarded to the `lit` constructor whenever a literal is created.

    Returns
    -------
    polars.Expr
    """
    # Already an expression: nothing to do.
    if isinstance(expr, pl.Expr):
        return expr

    # Strings default to column references.
    if isinstance(expr, str) and not str_as_lit:
        return pl.col(expr)

    # Lists may be promoted to a Series literal on request.
    if isinstance(expr, list) and not list_as_lit:
        return pl.lit(pl.Series(expr), dtype=dtype)

    # Everything else (including strings/lists in literal mode) is a literal.
    return pl.lit(expr, dtype=dtype)
51+
52+
53+
def register_plugin(
    *,
    symbol: str,
    is_elementwise: bool,
    kwargs: dict[str, Any] | None = None,
    args: list[IntoExpr],
    lib: str | Path,
) -> pl.Expr:
    """
    Register a plugin call through whichever API the installed Polars offers.

    Polars versions below 0.20.16 go through `Expr.register_plugin` on the
    first argument; newer versions use `polars.plugins.register_plugin_function`.
    """
    needs_legacy_api = parse_version(pl.__version__) < parse_version("0.20.16")

    if not needs_legacy_api:
        from polars.plugins import register_plugin_function

        return register_plugin_function(
            args=args,
            plugin_path=lib,
            function_name=symbol,
            kwargs=kwargs,
            is_elementwise=is_elementwise,
        )

    # Legacy path: the first argument is the receiver, the rest are extra inputs.
    receiver = args[0]
    assert isinstance(receiver, pl.Expr)
    assert isinstance(lib, str)
    return receiver.register_plugin(
        lib=lib,
        symbol=symbol,
        args=args[1:],
        kwargs=kwargs,
        is_elementwise=is_elementwise,
    )
80+
81+
def parse_version(version: Sequence[str | int]) -> tuple[int, ...]:
    """
    Convert a version into a tuple of ints suitable for comparison.

    A string is split on "."; any other sequence is used component by
    component. Non-digit characters are stripped from every component before
    it is converted to `int`. Vendored from Polars.
    """
    components = version.split(".") if isinstance(version, str) else version
    digits_only = (re.sub(r"\D", "", str(component)) for component in components)
    return tuple(int(digits) for digits in digits_only)
87+

‎pyproject.toml

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[build-system]
2+
requires = ["maturin>=1.0,<2.0", "polars>=0.20.6"]
3+
build-backend = "maturin"
4+
5+
[project]
6+
name = "polars-pbv"
7+
requires-python = ">=3.8"
8+
classifiers = [
9+
"Programming Language :: Rust",
10+
"Programming Language :: Python :: Implementation :: CPython",
11+
"Programming Language :: Python :: Implementation :: PyPy",
12+
]

‎requirements.txt

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
polars
2+
maturin
3+
ruff
4+
pytest
5+
mypy

‎run.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Small demo script for the plugin (invoked by `make run` / `make run-release`).
import polars as pl
from polars_pbv import pbv

# `polars_pbv` only exports `pbv`; the previous `pig_latinnify` import was a
# leftover that made this script fail at import time.
df = pl.DataFrame({
    'price': [100, 101, 102, 103, 104, 105, 106],
    'volume': [200, 220, 250, 240, 260, 300, 280],
})
result = df.with_columns(pbv=pbv('price', 'volume', window_size=6, bins=3))
print(result)
9+

‎src/expressions.rs

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#![allow(clippy::unused_unit)]
2+
use polars::prelude::*;
3+
use pyo3_polars::derive::polars_expr;
4+
use serde::Deserialize;
5+
6+
#[derive(Deserialize)]
7+
pub struct PriceByVolumeKwargs {
8+
window_size: i32,
9+
bins: i32,
10+
center_label: bool,
11+
}
12+
13+
// fn price_by_volume_dtype(input_fields: &[Field]) -> PolarsResult<Field> {
14+
// let field = Field::new(
15+
// "pbv",
16+
// DataType::List(Box::new(input_fields[1].dtype.clone())),
17+
// );
18+
// Ok(field)
19+
// }
20+
21+
fn price_by_volume_dtype(input_fields: &[Field]) -> PolarsResult<Field> {
22+
let field_price = Field::new(
23+
"price",
24+
DataType::List(Box::new(input_fields[0].dtype.clone())),
25+
);
26+
let field_volume = Field::new(
27+
"volume",
28+
DataType::List(Box::new(input_fields[1].dtype.clone())),
29+
);
30+
let v: Vec<Field> = vec![field_price, field_volume];
31+
Ok(Field::new("pbv", DataType::Struct(v)))
32+
}
33+
34+
#[polars_expr(output_type_func=price_by_volume_dtype)]
35+
fn price_by_volume(inputs: &[Series], kwargs: PriceByVolumeKwargs) -> PolarsResult<Series> {
36+
let price = &inputs[0].to_float()?;
37+
let volume = &inputs[1].to_float()?;
38+
let window_size = kwargs.window_size as usize;
39+
let bins = kwargs.bins;
40+
let mut pbv = vec![];
41+
let mut label = vec![];
42+
for i in 1..(price.len() + 1) {
43+
println!("i: {}", i);
44+
if i < (window_size) {
45+
pbv.push(None);
46+
label.push(None);
47+
} else {
48+
let mut volume_at_price = vec![];
49+
let mut price_label = vec![];
50+
let start = (i - window_size) as i64;
51+
let window_price = price.slice(start, window_size);
52+
let window_volume = volume.slice(start, window_size);
53+
let max_price: f64 = window_price.max()?.unwrap();
54+
let min_price: f64 = window_price.min()?.unwrap();
55+
let range = max_price - min_price;
56+
let interval = range / bins as f64;
57+
for n in 0..bins {
58+
let lower_bound = min_price + n as f64 * interval;
59+
let upper_bound = min_price + (n + 1) as f64 * interval;
60+
let center = (lower_bound + upper_bound) / 2.0;
61+
if n == bins - 1 {
62+
println!("start: {}, lower: {}", start, lower_bound);
63+
let v: f64 = window_volume
64+
.filter(&window_price.gt_eq(lower_bound)?)?
65+
.sum()?;
66+
volume_at_price.push(v);
67+
} else {
68+
println!(
69+
"start: {}, lower: {}, upper: {}",
70+
start, lower_bound, upper_bound
71+
);
72+
let mask = window_price.gt_eq(lower_bound)? & window_price.lt(upper_bound)?;
73+
let v = window_volume.filter(&mask)?.sum()?;
74+
volume_at_price.push(v);
75+
}
76+
if kwargs.center_label {
77+
price_label.push(center);
78+
} else {
79+
price_label.push(lower_bound);
80+
}
81+
}
82+
println!("{:?}", volume_at_price);
83+
pbv.push(Some(Series::new("volume", &volume_at_price)));
84+
label.push(Some(Series::new("price", &price_label)));
85+
}
86+
}
87+
let label_series = Series::new("price", &label);
88+
let pbv_series = Series::new("volume", &pbv);
89+
let out = StructChunked::new("pbv", &vec![label_series, pbv_series])?;
90+
Ok(out.into_series())
91+
}

‎src/lib.rs

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
mod expressions;
2+
mod utils;
3+
4+
#[cfg(target_os = "linux")]
5+
use jemallocator::Jemalloc;
6+
7+
#[global_allocator]
8+
#[cfg(target_os = "linux")]
9+
static ALLOC: Jemalloc = Jemalloc;
10+
11+
use pyo3::types::PyModule;
12+
use pyo3::{pymodule, PyResult, Python};
13+
14+
#[pymodule]
15+
fn polars_pbv(_py: Python, m: &PyModule) -> PyResult<()> {
16+
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
17+
Ok(())
18+
}

‎src/utils.rs

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
use polars::prelude::*;
2+
3+
// This function is useful for writing functions which
4+
// accept pairs of List columns. Delete if unneded.
5+
#[allow(dead_code)]
6+
pub(crate) fn binary_amortized_elementwise<'a, T, K, F>(
7+
ca: &'a ListChunked,
8+
weights: &'a ListChunked,
9+
mut f: F,
10+
) -> ChunkedArray<T>
11+
where
12+
T: PolarsDataType,
13+
T::Array: ArrayFromIter<Option<K>>,
14+
F: FnMut(&Series, &Series) -> Option<K> + Copy,
15+
{
16+
unsafe {
17+
ca.amortized_iter()
18+
.zip(weights.amortized_iter())
19+
.map(|(lhs, rhs)| match (lhs, rhs) {
20+
(Some(lhs), Some(rhs)) => f(lhs.as_ref(), rhs.as_ref()),
21+
_ => None,
22+
})
23+
.collect_ca(ca.name())
24+
}
25+
}
26+
27+
// This function is useful for writing functions which
28+
// accept pairs of columns and produce String output. Delete if unneeded.
29+
//
30+
// To use it, you will also need to import the following:
31+
//
32+
// use polars_arrow::array::Array;
33+
// use polars_arrow::array::MutablePlString;
34+
// use polars_core::utils::align_chunks_binary;
35+
// use std::fmt::Write;
36+
//
37+
// and make sure you have
38+
//
39+
// polars-arrow = { version = "0.37.0", default-features = false }
40+
// polars-core = { version = "0.37.0", default-features = false }
41+
//
42+
// in your `Cargo.toml` file.
43+
// Only uncomment if needed
44+
// pub(crate) fn binary_apply_to_buffer_generic<T, K, F>(
45+
// lhs: &ChunkedArray<T>,
46+
// rhs: &ChunkedArray<K>,
47+
// mut f: F,
48+
// ) -> StringChunked
49+
// where
50+
// T: PolarsDataType,
51+
// K: PolarsDataType,
52+
// F: for<'a> FnMut(T::Physical<'a>, K::Physical<'a>) -> String,
53+
// {
54+
// let (lhs, rhs) = align_chunks_binary(lhs, rhs);
55+
// let chunks = lhs
56+
// .downcast_iter()
57+
// .zip(rhs.downcast_iter())
58+
// .map(|(lhs_arr, rhs_arr)| {
59+
// let mut buf = String::new();
60+
// let mut mutarr = MutablePlString::with_capacity(lhs_arr.len());
61+
62+
// for (lhs_opt_val, rhs_opt_val) in lhs_arr.iter().zip(rhs_arr.iter()) {
63+
// match (lhs_opt_val, rhs_opt_val) {
64+
// (Some(lhs_val), Some(rhs_val)) => {
65+
// let res = f(lhs_val, rhs_val);
66+
// buf.clear();
67+
// write!(buf, "{res}").unwrap();
68+
// mutarr.push(Some(&buf))
69+
// }
70+
// _ => mutarr.push_null(),
71+
// }
72+
// }
73+
74+
// mutarr.freeze().boxed()
75+
// })
76+
// .collect();
77+
// unsafe { ChunkedArray::from_chunks("placeholder", chunks) }
78+
// }
79+

‎tests/test_pbv.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import polars as pl
2+
from polars_pbv import pbv
3+
4+
def test_pbv():
    """Two full 6-row windows over 7 rows, 3 bins, labelled by lower bound."""
    price_col = [100, 101, 102, 103, 104, 105, 106]
    volume_col = [200, 220, 250, 240, 260, 300, 280]
    window_size = 6
    bins = 3
    df = pl.DataFrame({
        'price': price_col,
        'volume': volume_col
    })

    # Both windows span 5 price units. Compute the bin edges with the same
    # arithmetic as the plugin (min + n * range / bins) so the exact comparison
    # below is not defeated by hand-rounded literals such as 101.666667.
    interval = 5.0 / bins

    def lower_bounds(min_price):
        return [min_price + n * interval for n in range(bins)]

    # The plugin sums volumes as floats, so the expected volumes are floats too.
    expected_df = pl.DataFrame({
        "price": [*[None] * 5, lower_bounds(100.0), lower_bounds(101.0)],
        "volume": [
            *[None] * 5,
            [200.0 + 220.0, 250.0 + 240.0, 260.0 + 300.0],
            [220.0 + 250.0, 240.0 + 260.0, 300.0 + 280.0],
        ],
    }).select(
        pl.struct("price", "volume").alias("pbv")
    )
    result = df.select(
        pbv("price", "volume", window_size=window_size, bins=bins, center=False).alias("pbv"),
    )
    assert result.equals(expected_df)

0 commit comments

Comments
 (0)
Please sign in to comment.