Skip to content

Replace RecordBatch::with_schema_unchecked with RecordBatch::new_unchecked #7405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 13, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 39 additions & 96 deletions arrow-array/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,11 @@ impl RecordBatch {
/// Creates a `RecordBatch` from a schema and columns.
///
/// Expects the following:
/// * the vec of columns to not be empty
/// * the schema and column data types to have equal lengths
/// and match
/// * each array in columns to have the same length
///
/// * `!columns.is_empty()`
/// * `schema.fields.len() == columns.len()`
/// * `schema.fields[i].data_type() == columns[i].data_type()`
/// * `columns[i].len() == columns[j].len()`
///
/// If the conditions are not met, an error is returned.
///
Expand All @@ -240,6 +241,33 @@ impl RecordBatch {
Self::try_new_impl(schema, columns, &options)
}

/// Assembles a `RecordBatch` directly from its parts, performing no validation.
///
/// The three arguments are stored as-is. Use [`Self::try_new`] when the inputs
/// should be checked.
///
/// # Safety
///
/// The caller must guarantee all of the following:
///
/// * `schema.fields.len() == columns.len()`
/// * `schema.fields[i].data_type() == columns[i].data_type()`
/// * `columns[i].len() == row_count`
///
/// A schema that does not describe the underlying data exactly can lead to
/// undefined behavior, for example when the batch is converted to a
/// `StructArray` and that array is then accessed incorrectly.
pub unsafe fn new_unchecked(
    schema: SchemaRef,
    columns: Vec<Arc<dyn Array>>,
    row_count: usize,
) -> Self {
    // No checks here by design: upholding the invariants is the caller's job.
    Self { schema, columns, row_count }
}

/// Creates a `RecordBatch` from a schema and columns, with additional options,
/// such as whether to strictly validate field names.
///
Expand Down Expand Up @@ -340,6 +368,11 @@ impl RecordBatch {
})
}

/// Return the schema, columns and row count of this [`RecordBatch`]
pub fn into_parts(self) -> (SchemaRef, Vec<ArrayRef>, usize) {
(self.schema, self.columns, self.row_count)
}

/// Override the schema of this [`RecordBatch`]
///
/// Returns an error if `schema` is not a superset of the current schema
Expand All @@ -359,18 +392,6 @@ impl RecordBatch {
})
}

/// Replaces the schema of this [`RecordBatch`] without running any compatibility checks.
///
/// The existing columns and row count are carried over unchanged, and the call
/// always returns `Ok`. All responsibility for schema compatibility is pushed to
/// the caller: in particular, the caller guarantees that `schema` is a superset of
/// the current schema as determined by [`Schema::contains`].
pub fn with_schema_unchecked(self, schema: SchemaRef) -> Result<Self, ArrowError> {
    // Keep the data, discard the previous schema.
    let Self {
        columns, row_count, ..
    } = self;
    let rebuilt = Self {
        schema,
        columns,
        row_count,
    };
    Ok(rebuilt)
}

/// Returns the [`Schema`] of the record batch.
pub fn schema(&self) -> SchemaRef {
self.schema.clone()
Expand Down Expand Up @@ -756,14 +777,12 @@ impl RecordBatchOptions {
row_count: None,
}
}

/// Sets the `row_count` of `RecordBatchOptions` and returns this [`RecordBatch`]
/// Sets the row_count of RecordBatchOptions and returns self
pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
self.row_count = row_count;
self
}

/// Sets the `match_field_names` of `RecordBatchOptions` and returns this [`RecordBatch`]
/// Sets the match_field_names of RecordBatchOptions and returns self
pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
self.match_field_names = match_field_names;
self
Expand Down Expand Up @@ -1651,80 +1670,4 @@ mod tests {
"bar"
);
}

#[test]
fn test_batch_with_unchecked_schema() {
    /// Installs `schema` on a copy of `batch` without validation, projects the
    /// column at `idx`, and returns the rendered projection error, if any.
    fn projection_error(batch: &RecordBatch, schema: Schema, idx: usize) -> Option<String> {
        let swapped = batch
            .clone()
            .with_schema_unchecked(Arc::new(schema))
            .unwrap();
        swapped.project(&[idx]).err().map(|e| e.to_string())
    }

    // A single Utf8 column named "c" is the fixture for every case below.
    let strings: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
    let batch = RecordBatch::try_from_iter(vec![("c", strings)]).expect("valid conversion");

    // Empty schema applied to a batch that has a column.
    assert_eq!(
        projection_error(&batch, Schema::empty(), 0).as_deref(),
        Some("Schema error: project index 0 out of bounds, max field 0")
    );

    // Schema with more fields than the batch has columns.
    let two_fields = Schema::new(vec![
        Field::new("a", DataType::Utf8, false),
        Field::new("b", DataType::Int32, false),
    ]);
    assert_eq!(projection_error(&batch, two_fields.clone(), 0), None);
    assert_eq!(
        projection_error(&batch, two_fields, 1).as_deref(),
        Some("Schema error: project index 1 out of bounds, max field 1")
    );

    // Schema whose data type disagrees with the column.
    let int_field = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    assert_eq!(
        projection_error(&batch, int_field, 0).as_deref(),
        Some("Invalid argument error: column types must match schema types, expected Int32 but found Utf8 at column index 0")
    );

    // Schema naming the column "a" instead of "c": lookup by the old name finds nothing.
    let renamed = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)]));
    let relabelled = batch.clone().with_schema_unchecked(renamed).unwrap();
    assert!(relabelled.column_by_name("c").is_none());

    // Matching schema: the column is still reachable and unchanged.
    let matching = Arc::new(Schema::new(vec![Field::new("c", DataType::Utf8, false)]));
    let rebuilt = batch.clone().with_schema_unchecked(matching).unwrap();
    assert_eq!(rebuilt.column_by_name("c"), batch.column_by_name("c"));
}
}
Loading