Skip to content

feat: Adding with_schema_unchecked method for RecordBatch #7402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 10, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 92 additions & 2 deletions arrow-array/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,18 @@ impl RecordBatch {
})
}

/// Overrides the schema of this [`RecordBatch`]
/// without additional schema checks. Note, however, that this pushes all the schema compatibility responsibilities
/// to the caller site. In particular, the caller guarantees that `schema` is a superset
/// of the current schema as determined by [`Schema::contains`].
pub fn with_schema_unchecked(self, schema: SchemaRef) -> Result<Self, ArrowError> {
Ok(Self {
schema,
columns: self.columns,
row_count: self.row_count,
})
}

/// Returns the [`Schema`] of the record batch.
pub fn schema(&self) -> SchemaRef {
self.schema.clone()
Expand Down Expand Up @@ -744,12 +756,14 @@ impl RecordBatchOptions {
row_count: None,
}
}
/// Sets the row_count of RecordBatchOptions and returns self

/// Sets the `row_count` of `RecordBatchOptions` and returns this [`RecordBatch`]
pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
self.row_count = row_count;
self
}
/// Sets the match_field_names of RecordBatchOptions and returns self

/// Sets the `match_field_names` of `RecordBatchOptions` and returns this [`RecordBatch`]
pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
self.match_field_names = match_field_names;
self
Expand Down Expand Up @@ -1637,4 +1651,80 @@ mod tests {
"bar"
);
}

#[test]
fn test_batch_with_unchecked_schema() {
fn apply_schema_unchecked(
record_batch: &RecordBatch,
schema_ref: SchemaRef,
idx: usize,
) -> Option<ArrowError> {
record_batch
.clone()
.with_schema_unchecked(schema_ref)
.unwrap()
.project(&[idx])
.err()
}

let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));

let record_batch =
RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");

// Test empty schema for non-empty schema batch
let invalid_schema_empty = Schema::empty();
assert_eq!(
apply_schema_unchecked(&record_batch, invalid_schema_empty.into(), 0)
.unwrap()
.to_string(),
"Schema error: project index 0 out of bounds, max field 0"
);

// Wrong number of columns
let invalid_schema_more_cols = Schema::new(vec![
Field::new("a", DataType::Utf8, false),
Field::new("b", DataType::Int32, false),
]);

assert!(
apply_schema_unchecked(&record_batch, invalid_schema_more_cols.clone().into(), 0)
.is_none()
);

assert_eq!(
apply_schema_unchecked(&record_batch, invalid_schema_more_cols.into(), 1)
.unwrap()
.to_string(),
"Schema error: project index 1 out of bounds, max field 1"
);

// Wrong datatype
let invalid_schema_wrong_datatype =
Schema::new(vec![Field::new("a", DataType::Int32, false)]);
assert_eq!(apply_schema_unchecked(&record_batch, invalid_schema_wrong_datatype.into(), 0).unwrap().to_string(), "Invalid argument error: column types must match schema types, expected Int32 but found Utf8 at column index 0");

// Wrong column name. A instead C
let invalid_schema_wrong_col_name =
Schema::new(vec![Field::new("a", DataType::Utf8, false)]);

assert!(record_batch
.clone()
.with_schema_unchecked(invalid_schema_wrong_col_name.into())
.unwrap()
.column_by_name("c")
.is_none());

// Valid schema
let valid_schema = Schema::new(vec![Field::new("c", DataType::Utf8, false)]);

assert_eq!(
record_batch
.clone()
.with_schema_unchecked(valid_schema.into())
.unwrap()
.column_by_name("c"),
record_batch.column_by_name("c")
);
}
}
Loading