diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs
index a6c2aee7cbc..cf20f9a059c 100644
--- a/arrow-array/src/record_batch.rs
+++ b/arrow-array/src/record_batch.rs
@@ -359,6 +359,24 @@ impl RecordBatch {
         })
     }
 
+    /// Overrides the schema of this [`RecordBatch`] without performing any
+    /// schema compatibility checks.
+    ///
+    /// This pushes all schema compatibility responsibilities to the caller.
+    /// In particular, the caller guarantees that `schema` is a superset of the
+    /// current schema as determined by [`Schema::contains`].
+    ///
+    /// The columns and row count are carried over unchanged, so this currently
+    /// always returns `Ok`. A mismatched schema is not reported here; it only
+    /// surfaces later, e.g. as an error from [`RecordBatch::project`].
+    pub fn with_schema_unchecked(self, schema: SchemaRef) -> Result<Self, ArrowError> {
+        Ok(Self {
+            schema,
+            columns: self.columns,
+            row_count: self.row_count,
+        })
+    }
+
     /// Returns the [`Schema`] of the record batch.
     pub fn schema(&self) -> SchemaRef {
         self.schema.clone()
@@ -744,12 +762,14 @@ impl RecordBatchOptions {
             row_count: None,
         }
     }
-    /// Sets the row_count of RecordBatchOptions and returns self
+
+    /// Sets the `row_count` of this [`RecordBatchOptions`] and returns the updated options
     pub fn with_row_count(mut self, row_count: Option<usize>) -> Self {
         self.row_count = row_count;
         self
     }
-    /// Sets the match_field_names of RecordBatchOptions and returns self
+
+    /// Sets the `match_field_names` of this [`RecordBatchOptions`] and returns the updated options
     pub fn with_match_field_names(mut self, match_field_names: bool) -> Self {
         self.match_field_names = match_field_names;
         self
@@ -1637,4 +1657,85 @@ mod tests {
             "bar"
         );
     }
+
+    #[test]
+    fn test_batch_with_unchecked_schema() {
+        fn apply_schema_unchecked(
+            record_batch: &RecordBatch,
+            schema_ref: SchemaRef,
+            idx: usize,
+        ) -> Option<ArrowError> {
+            record_batch
+                .clone()
+                .with_schema_unchecked(schema_ref)
+                .unwrap()
+                .project(&[idx])
+                .err()
+        }
+
+        let c: ArrayRef = Arc::new(StringArray::from(vec!["d", "e", "f"]));
+
+        let record_batch =
+            RecordBatch::try_from_iter(vec![("c", c.clone())]).expect("valid conversion");
+
+        // Test empty schema for non-empty schema batch
+        let invalid_schema_empty = Schema::empty();
+        assert_eq!(
+            apply_schema_unchecked(&record_batch, invalid_schema_empty.into(), 0)
+                .unwrap()
+                .to_string(),
+            "Schema error: project index 0 out of bounds, max field 0"
+        );
+
+        // Wrong number of columns
+        let invalid_schema_more_cols = Schema::new(vec![
+            Field::new("a", DataType::Utf8, false),
+            Field::new("b", DataType::Int32, false),
+        ]);
+
+        assert!(
+            apply_schema_unchecked(&record_batch, invalid_schema_more_cols.clone().into(), 0)
+                .is_none()
+        );
+
+        assert_eq!(
+            apply_schema_unchecked(&record_batch, invalid_schema_more_cols.into(), 1)
+                .unwrap()
+                .to_string(),
+            "Schema error: project index 1 out of bounds, max field 1"
+        );
+
+        // Wrong datatype
+        let invalid_schema_wrong_datatype =
+            Schema::new(vec![Field::new("a", DataType::Int32, false)]);
+        assert_eq!(
+            apply_schema_unchecked(&record_batch, invalid_schema_wrong_datatype.into(), 0)
+                .unwrap()
+                .to_string(),
+            "Invalid argument error: column types must match schema types, expected Int32 but found Utf8 at column index 0"
+        );
+
+        // Wrong column name: `a` instead of `c`
+        let invalid_schema_wrong_col_name =
+            Schema::new(vec![Field::new("a", DataType::Utf8, false)]);
+
+        assert!(record_batch
+            .clone()
+            .with_schema_unchecked(invalid_schema_wrong_col_name.into())
+            .unwrap()
+            .column_by_name("c")
+            .is_none());
+
+        // Valid schema
+        let valid_schema = Schema::new(vec![Field::new("c", DataType::Utf8, false)]);
+
+        assert_eq!(
+            record_batch
+                .clone()
+                .with_schema_unchecked(valid_schema.into())
+                .unwrap()
+                .column_by_name("c"),
+            record_batch.column_by_name("c")
+        );
+    }
 }