Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix overwrite of dtypes for DF.load_csv/2 #955

Merged
merged 2 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions lib/explorer/polars_backend/data_frame.ex
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,6 @@ defmodule Explorer.PolarsBackend.DataFrame do

{columns, with_projection} = column_names_or_projection(columns)

dtypes_list =
if not Enum.empty?(dtypes) do
Map.to_list(dtypes)
end

df =
Native.df_load_csv(
contents,
Expand All @@ -212,7 +207,7 @@ defmodule Explorer.PolarsBackend.DataFrame do
delimiter,
true,
columns,
dtypes_list,
Map.to_list(dtypes),
encoding,
nil_values,
parse_dates,
Expand Down
20 changes: 10 additions & 10 deletions native/explorer/src/dataframe/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pub fn df_from_csv(
};

let dataframe = CsvReadOptions::default()
.with_schema_overwrite(schema_from_dtypes_pairs(dtypes)?)
.with_infer_schema_length(infer_schema_length)
.with_has_header(has_header)
.with_n_rows(stop_after_n_rows)
Expand All @@ -56,7 +57,6 @@ pub fn df_from_csv(
.with_projection(projection.map(Arc::new))
.with_rechunk(do_rechunk)
.with_columns(column_names.map(Arc::from))
.with_schema_overwrite(Some(schema_from_dtypes_pairs(dtypes)?))
.with_parse_options(
CsvParseOptions::default()
.with_encoding(encoding)
Expand All @@ -74,13 +74,17 @@ pub fn df_from_csv(

// NOTE(review): this is a diff hunk whose +/- markers were lost, so the removed
// and the added lines of the change appear together below.
/// Builds a schema from `(column name, dtype)` pairs.
///
/// Each `ExSeriesDtype` is converted with `DataType::try_from`; a failed
/// conversion is propagated to the caller through `?`.
///
/// In the added version an empty `dtypes` list yields `Ok(None)`, so the
/// result can be handed directly to the `with_schema_overwrite` /
/// `with_dtype_overwrite` calls seen elsewhere in this diff.
pub fn schema_from_dtypes_pairs(
dtypes: Vec<(&str, ExSeriesDtype)>,
// Old (removed) return type: always a schema.
) -> Result<Arc<Schema>, ExplorerError> {
// New (added) return type: `None` means "no dtypes were supplied".
) -> Result<Option<Arc<Schema>>, ExplorerError> {
if dtypes.is_empty() {
return Ok(None);
}

let mut schema = Schema::new();
for (name, ex_dtype) in dtypes {
let dtype = DataType::try_from(&ex_dtype)?;
schema.with_column(name.into(), dtype);
}
// Old (removed) return value.
Ok(Arc::new(schema))
// New (added) return value.
Ok(Some(Arc::new(schema)))
}

#[rustler::nif(schedule = "DirtyIo")]
Expand Down Expand Up @@ -152,7 +156,7 @@ pub fn df_load_csv(
delimiter_as_byte: u8,
do_rechunk: bool,
column_names: Option<Vec<String>>,
dtypes: Option<Vec<(&str, ExSeriesDtype)>>,
dtypes: Vec<(&str, ExSeriesDtype)>,
encoding: &str,
null_vals: Vec<String>,
parse_dates: bool,
Expand All @@ -165,12 +169,8 @@ pub fn df_load_csv(

let cursor = Cursor::new(binary.as_slice());

let read_options = match dtypes {
Some(val) => CsvReadOptions::default().with_schema(Some(schema_from_dtypes_pairs(val)?)),
None => CsvReadOptions::default(),
};

let dataframe = read_options
let dataframe = CsvReadOptions::default()
.with_schema_overwrite(schema_from_dtypes_pairs(dtypes)?)
.with_has_header(has_header)
.with_infer_schema_length(infer_schema_length)
.with_n_rows(stop_after_n_rows)
Expand Down
2 changes: 1 addition & 1 deletion native/explorer/src/lazyframe/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ pub fn lf_from_csv(
.with_skip_rows_after_header(skip_rows_after_header)
.with_rechunk(do_rechunk)
.with_encoding(encoding)
.with_dtype_overwrite(Some(schema_from_dtypes_pairs(dtypes)?))
.with_dtype_overwrite(schema_from_dtypes_pairs(dtypes)?)
.with_null_values(Some(NullValues::AllColumns(null_vals)))
.with_eol_char(eol_delimiter.unwrap_or(b'\n'))
.finish()?;
Expand Down
249 changes: 249 additions & 0 deletions test/explorer/data_frame/csv_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,130 @@ defmodule Explorer.DataFrame.CSVTest do
assert city[13] == "Aberdeen, Aberdeen City, UK"
end

# Checks DF.load_csv!/2 with an explicit :string dtype for every column:
# the column order must follow the CSV header and every value must come back
# as its raw text, even when the dtypes are listed in a different order.
test "load_csv/2 dtypes - all as strings" do
csv =
"""
id,first_name,last_name,email,gender,ip_address,salary,latitude,longitude
1,Torey,Geraghty,[email protected],Male,119.110.38.172,14036.68,38.9187037,-76.9611991
2,Nevin,Mandrake,[email protected],Male,161.2.124.233,32530.27,41.4176872,-8.7653155
3,Melisenda,Guiso,[email protected],Female,192.152.64.134,9177.8,21.3772424,110.2485736
4,Noble,Doggett,[email protected],Male,252.234.29.244,20328.76,37.268428,55.1487513
5,Janaya,Claypoole,[email protected],Female,150.191.214.252,21442.93,15.3553417,120.5293228
6,Sarah,Hugk,[email protected],Female,211.158.246.13,79709.16,28.168408,120.482198
7,Ulberto,Simenon,[email protected],Male,206.56.108.90,16248.98,48.4046776,-0.9746208
8,Kevon,Lingner,[email protected],Male,181.71.212.116,7497.64,-23.351784,-47.6931718
9,Sada,Garbert,[email protected],Female,170.42.190.231,15969.95,30.3414125,114.1543243
10,Salmon,Shoulders,[email protected],Male,68.138.106.143,19996.71,49.2152833,17.7687416
"""

# Column names in the same order as the CSV header row above.
headers = ~w(id first_name last_name email gender ip_address salary latitude longitude)

# Out of order on purpose.
df = DF.load_csv!(csv, dtypes: for(l <- Enum.shuffle(headers), do: {l, :string}))

# The resulting column order must match the header, not the shuffled dtypes.
assert DF.names(df) == headers

# Numeric-looking columns (id, salary, latitude, longitude) are expected as
# strings too, since every column was requested as :string.
assert DF.to_columns(df, atom_keys: true) == %{
email: [
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]"
],
first_name: [
"Torey",
"Nevin",
"Melisenda",
"Noble",
"Janaya",
"Sarah",
"Ulberto",
"Kevon",
"Sada",
"Salmon"
],
gender: [
"Male",
"Male",
"Female",
"Male",
"Female",
"Female",
"Male",
"Male",
"Female",
"Male"
],
id: ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
ip_address: [
"119.110.38.172",
"161.2.124.233",
"192.152.64.134",
"252.234.29.244",
"150.191.214.252",
"211.158.246.13",
"206.56.108.90",
"181.71.212.116",
"170.42.190.231",
"68.138.106.143"
],
last_name: [
"Geraghty",
"Mandrake",
"Guiso",
"Doggett",
"Claypoole",
"Hugk",
"Simenon",
"Lingner",
"Garbert",
"Shoulders"
],
latitude: [
"38.9187037",
"41.4176872",
"21.3772424",
"37.268428",
"15.3553417",
"28.168408",
"48.4046776",
"-23.351784",
"30.3414125",
"49.2152833"
],
longitude: [
"-76.9611991",
"-8.7653155",
"110.2485736",
"55.1487513",
"120.5293228",
"120.482198",
"-0.9746208",
"-47.6931718",
"114.1543243",
"17.7687416"
],
salary: [
"14036.68",
"32530.27",
"9177.8",
"20328.76",
"21442.93",
"79709.16",
"16248.98",
"7497.64",
"15969.95",
"19996.71"
]
}
end

def assert_csv(type, csv_value, parsed_value, from_csv_options) do
data = "column\n#{csv_value}\n"
# parsing should work as expected
Expand Down Expand Up @@ -182,6 +306,131 @@ defmodule Explorer.DataFrame.CSVTest do
}
end

# Checks DF.from_csv!/2 with an explicit :string dtype for every column:
# the column order must follow the CSV header and every value must come back
# as its raw text, even when the dtypes are listed in a different order.
@tag :tmp_dir
test "dtypes - all as strings", config do
# NOTE(review): tmp_csv/2 is defined outside this hunk; presumably it writes
# the contents to a file under config.tmp_dir and returns its path - confirm.
csv =
tmp_csv(config.tmp_dir, """
id,first_name,last_name,email,gender,ip_address,salary,latitude,longitude
1,Torey,Geraghty,[email protected],Male,119.110.38.172,14036.68,38.9187037,-76.9611991
2,Nevin,Mandrake,[email protected],Male,161.2.124.233,32530.27,41.4176872,-8.7653155
3,Melisenda,Guiso,[email protected],Female,192.152.64.134,9177.8,21.3772424,110.2485736
4,Noble,Doggett,[email protected],Male,252.234.29.244,20328.76,37.268428,55.1487513
5,Janaya,Claypoole,[email protected],Female,150.191.214.252,21442.93,15.3553417,120.5293228
6,Sarah,Hugk,[email protected],Female,211.158.246.13,79709.16,28.168408,120.482198
7,Ulberto,Simenon,[email protected],Male,206.56.108.90,16248.98,48.4046776,-0.9746208
8,Kevon,Lingner,[email protected],Male,181.71.212.116,7497.64,-23.351784,-47.6931718
9,Sada,Garbert,[email protected],Female,170.42.190.231,15969.95,30.3414125,114.1543243
10,Salmon,Shoulders,[email protected],Male,68.138.106.143,19996.71,49.2152833,17.7687416
""")

# Column names in the same order as the CSV header row above.
headers = ~w(id first_name last_name email gender ip_address salary latitude longitude)

# Out of order on purpose.
df = DF.from_csv!(csv, dtypes: for(l <- Enum.shuffle(headers), do: {l, :string}))

# The resulting column order must match the header, not the shuffled dtypes.
assert DF.names(df) == headers

# Numeric-looking columns (id, salary, latitude, longitude) are expected as
# strings too, since every column was requested as :string.
assert DF.to_columns(df, atom_keys: true) == %{
email: [
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]",
"[email protected]"
],
first_name: [
"Torey",
"Nevin",
"Melisenda",
"Noble",
"Janaya",
"Sarah",
"Ulberto",
"Kevon",
"Sada",
"Salmon"
],
gender: [
"Male",
"Male",
"Female",
"Male",
"Female",
"Female",
"Male",
"Male",
"Female",
"Male"
],
id: ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
ip_address: [
"119.110.38.172",
"161.2.124.233",
"192.152.64.134",
"252.234.29.244",
"150.191.214.252",
"211.158.246.13",
"206.56.108.90",
"181.71.212.116",
"170.42.190.231",
"68.138.106.143"
],
last_name: [
"Geraghty",
"Mandrake",
"Guiso",
"Doggett",
"Claypoole",
"Hugk",
"Simenon",
"Lingner",
"Garbert",
"Shoulders"
],
latitude: [
"38.9187037",
"41.4176872",
"21.3772424",
"37.268428",
"15.3553417",
"28.168408",
"48.4046776",
"-23.351784",
"30.3414125",
"49.2152833"
],
longitude: [
"-76.9611991",
"-8.7653155",
"110.2485736",
"55.1487513",
"120.5293228",
"120.482198",
"-0.9746208",
"-47.6931718",
"114.1543243",
"17.7687416"
],
salary: [
"14036.68",
"32530.27",
"9177.8",
"20328.76",
"21442.93",
"79709.16",
"16248.98",
"7497.64",
"15969.95",
"19996.71"
]
}
end

@tag :tmp_dir
test "dtypes - parse datetime", config do
csv =
Expand Down
Loading