diff --git a/src/array_decoder/list.rs b/src/array_decoder/list.rs index 25dd266..275e224 100644 --- a/src/array_decoder/list.rs +++ b/src/array_decoder/list.rs @@ -23,18 +23,16 @@ pub struct ListArrayDecoder { } impl ListArrayDecoder { - pub fn new(column: &Column, stripe: &Stripe) -> Result { + pub fn new(column: &Column, field: Arc, stripe: &Stripe) -> Result { let present = get_present_vec(column, stripe)? .map(|iter| Box::new(iter.into_iter()) as Box + Send>); let child = &column.children()[0]; - let inner = array_decoder_factory(child, stripe)?; + let inner = array_decoder_factory(child, field.clone(), stripe)?; let reader = stripe.stream_map().get(column, Kind::Length); let lengths = get_rle_reader(column, reader)?; - let field = Arc::new(Field::from(child)); - Ok(Self { inner, present, diff --git a/src/array_decoder/map.rs b/src/array_decoder/map.rs index 9099615..9a0d182 100644 --- a/src/array_decoder/map.rs +++ b/src/array_decoder/map.rs @@ -23,30 +23,24 @@ pub struct MapArrayDecoder { } impl MapArrayDecoder { - pub fn new(column: &Column, stripe: &Stripe) -> Result { + pub fn new( + column: &Column, + keys_field: Arc, + values_field: Arc, + stripe: &Stripe, + ) -> Result { let present = get_present_vec(column, stripe)? .map(|iter| Box::new(iter.into_iter()) as Box + Send>); let keys_column = &column.children()[0]; - let keys = array_decoder_factory(keys_column, stripe)?; + let keys = array_decoder_factory(keys_column, keys_field.clone(), stripe)?; let values_column = &column.children()[1]; - let values = array_decoder_factory(values_column, stripe)?; + let values = array_decoder_factory(values_column, values_field.clone(), stripe)?; let reader = stripe.stream_map().get(column, Kind::Length); let lengths = get_rle_reader(column, reader)?; - // TODO: should it be "keys" and "values" (like arrow-rs) - // or "key" and "value" like PyArrow and in Schema.fbs? - let keys_field = Field::new("keys", keys_column.data_type().to_arrow_data_type(), false); - let keys_field = Arc::new(keys_field); - let values_field = Field::new( - "values", - values_column.data_type().to_arrow_data_type(), - true, - ); - let values_field = Arc::new(values_field); - let fields = Fields::from(vec![keys_field, values_field]); Ok(Self { diff --git a/src/array_decoder/mod.rs b/src/array_decoder/mod.rs index 77a92c6..a556306 100644 --- a/src/array_decoder/mod.rs +++ b/src/array_decoder/mod.rs @@ -3,14 +3,15 @@ use std::sync::Arc; use arrow::array::{ArrayRef, BooleanArray, BooleanBuilder, PrimitiveArray, PrimitiveBuilder}; use arrow::buffer::NullBuffer; use arrow::datatypes::{ArrowPrimitiveType, Decimal128Type, UInt64Type}; +use arrow::datatypes::{DataType as ArrowDataType, Field}; use arrow::datatypes::{ Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef, }; use arrow::record_batch::RecordBatch; -use snafu::ResultExt; +use snafu::{ensure, ResultExt}; use crate::column::{get_present_vec, Column}; -use crate::error::{self, ArrowSnafu, Result}; +use crate::error::{self, ArrowSnafu, MismatchedSchemaSnafu, UnsupportedTypeVariantSnafu, UnexpectedSnafu, Result}; use crate::proto::stream::Kind; use crate::reader::decode::boolean_rle::BooleanIter; use crate::reader::decode::byte_rle::ByteRleIter; @@ -316,11 +317,20 @@ pub trait ArrayBatchDecoder: Send { pub fn array_decoder_factory( column: &Column, + field: Arc, stripe: &Stripe, ) -> Result> { + let field_type = field.data_type().clone(); let decoder: Box = match column.data_type() { // TODO: try make branches more generic, reduce duplication DataType::Boolean { .. } => { + ensure!( + field_type == ArrowDataType::Boolean, + MismatchedSchemaSnafu { + orc_type: column.data_type().clone(), + arrow_type: field_type + } + ); let iter = stripe.stream_map().get(column, Kind::Data); let iter = Box::new(BooleanIter::new(iter)); let present = get_present_vec(column, stripe)? @@ -328,6 +338,13 @@ pub fn array_decoder_factory( Box::new(BooleanArrayDecoder::new(iter, present)) } DataType::Byte { .. } => { + ensure!( + field_type == ArrowDataType::Int8, + MismatchedSchemaSnafu { + orc_type: column.data_type().clone(), + arrow_type: field_type + } + ); let iter = stripe.stream_map().get(column, Kind::Data); let iter = Box::new(ByteRleIter::new(iter).map(|value| value.map(|value| value as i8))); let present = get_present_vec(column, stripe)? @@ -335,6 +352,13 @@ pub fn array_decoder_factory( Box::new(Int8ArrayDecoder::new(iter, present)) } DataType::Short { .. } => { + ensure!( + field_type == ArrowDataType::Int16, + MismatchedSchemaSnafu { + orc_type: column.data_type().clone(), + arrow_type: field_type + } + ); let iter = stripe.stream_map().get(column, Kind::Data); let iter = get_rle_reader(column, iter)?; let present = get_present_vec(column, stripe)? @@ -342,6 +366,13 @@ pub fn array_decoder_factory( Box::new(Int16ArrayDecoder::new(iter, present)) } DataType::Int { .. } => { + ensure!( + field_type == ArrowDataType::Int32, + MismatchedSchemaSnafu { + orc_type: column.data_type().clone(), + arrow_type: field_type + } + ); let iter = stripe.stream_map().get(column, Kind::Data); let iter = get_rle_reader(column, iter)?; let present = get_present_vec(column, stripe)? @@ -349,6 +380,13 @@ pub fn array_decoder_factory( Box::new(Int32ArrayDecoder::new(iter, present)) } DataType::Long { .. } => { + ensure!( + field_type == ArrowDataType::Int64, + MismatchedSchemaSnafu { + orc_type: column.data_type().clone(), + arrow_type: field_type + } + ); let iter = stripe.stream_map().get(column, Kind::Data); let iter = get_rle_reader(column, iter)?; let present = get_present_vec(column, stripe)? @@ -356,6 +394,13 @@ pub fn array_decoder_factory( Box::new(Int64ArrayDecoder::new(iter, present)) } DataType::Float { .. } => { + ensure!( + field_type == ArrowDataType::Float32, + MismatchedSchemaSnafu { + orc_type: column.data_type().clone(), + arrow_type: field_type + } + ); let iter = stripe.stream_map().get(column, Kind::Data); let iter = Box::new(FloatIter::new(iter, stripe.number_of_rows())); let present = get_present_vec(column, stripe)? @@ -363,6 +408,13 @@ pub fn array_decoder_factory( Box::new(Float32ArrayDecoder::new(iter, present)) } DataType::Double { .. } => { + ensure!( + field_type == ArrowDataType::Float64, + MismatchedSchemaSnafu { + orc_type: column.data_type().clone(), + arrow_type: field_type + } + ); let iter = stripe.stream_map().get(column, Kind::Data); let iter = Box::new(FloatIter::new(iter, stripe.number_of_rows())); let present = get_present_vec(column, stripe)? @@ -370,9 +422,25 @@ pub fn array_decoder_factory( Box::new(Float64ArrayDecoder::new(iter, present)) } DataType::String { .. } | DataType::Varchar { .. } | DataType::Char { .. } => { + ensure!( + field_type == ArrowDataType::Utf8, + MismatchedSchemaSnafu { + orc_type: column.data_type().clone(), + arrow_type: field_type + } + ); new_string_decoder(column, stripe)? } - DataType::Binary { .. } => new_binary_decoder(column, stripe)?, + DataType::Binary { .. } => { + ensure!( + field_type == ArrowDataType::Binary, + MismatchedSchemaSnafu { + orc_type: column.data_type().clone(), + arrow_type: field_type + } + ); + new_binary_decoder(column, stripe)? + } DataType::Decimal { precision, scale, .. } => new_decimal_decoder(column, stripe, *precision, *scale)?, @@ -388,10 +456,75 @@ pub fn array_decoder_factory( .map(|iter| Box::new(iter.into_iter()) as Box + Send>); Box::new(DateArrayDecoder::new(iter, present)) } - DataType::Struct { .. } => Box::new(StructArrayDecoder::new(column, stripe)?), - DataType::List { .. } => Box::new(ListArrayDecoder::new(column, stripe)?), - DataType::Map { .. } => Box::new(MapArrayDecoder::new(column, stripe)?), - DataType::Union { .. } => Box::new(UnionArrayDecoder::new(column, stripe)?), + DataType::Struct { .. } => match field_type { + ArrowDataType::Struct(fields) => { + Box::new(StructArrayDecoder::new(column, fields, stripe)?) + } + _ => MismatchedSchemaSnafu { + orc_type: column.data_type().clone(), + arrow_type: field_type, + } + .fail()?, + }, + DataType::List { .. } => { + match field_type { + ArrowDataType::List(field) => { + Box::new(ListArrayDecoder::new(column, field, stripe)?) + } + // TODO: add support for ArrowDataType::LargeList + _ => MismatchedSchemaSnafu { + orc_type: column.data_type().clone(), + arrow_type: field_type, + } + .fail()?, + } + } + DataType::Map { .. } => match field_type { + ArrowDataType::Map(entries, sorted) => { + ensure!(!sorted, UnsupportedTypeVariantSnafu { msg: "Sorted map" }); + match entries.data_type() { + ArrowDataType::Struct(entries) => { + ensure!( + entries.len() == 2, + UnexpectedSnafu { + msg: format!( + "arrow Map with {} columns per entry (expected 2)", + entries.len() + ) + } + ); + let keys_field = entries[0].clone(); + let values_field = entries[1].clone(); + + Box::new(MapArrayDecoder::new( + column, + keys_field, + values_field, + stripe, + )?) + } + _ => UnexpectedSnafu { + msg: format!("arrow Map with non-Struct entry type"), + } + .fail()?, + } + } + _ => MismatchedSchemaSnafu { + orc_type: column.data_type().clone(), + arrow_type: field_type, + } + .fail()?, + }, + DataType::Union { .. } => match field_type { + ArrowDataType::Union(fields, _) => { + Box::new(UnionArrayDecoder::new(column, fields, stripe)?) + } + _ => MismatchedSchemaSnafu { + orc_type: column.data_type().clone(), + arrow_type: field_type, + } + .fail()?, + }, }; Ok(decoder) @@ -440,8 +573,12 @@ impl NaiveStripeDecoder { let mut decoders = Vec::with_capacity(stripe.columns().len()); let number_of_rows = stripe.number_of_rows(); - for col in stripe.columns() { - let decoder = array_decoder_factory(col, &stripe)?; + for (col, field) in stripe + .columns() + .iter() + .zip(schema_ref.fields.iter().cloned()) + { + let decoder = array_decoder_factory(col, field, &stripe)?; decoders.push(decoder); } diff --git a/src/array_decoder/struct_decoder.rs b/src/array_decoder/struct_decoder.rs index 1c08196..9905707 100644 --- a/src/array_decoder/struct_decoder.rs +++ b/src/array_decoder/struct_decoder.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use arrow::{ array::{ArrayRef, StructArray}, buffer::NullBuffer, - datatypes::{Field, Fields}, + datatypes::Fields, }; use snafu::ResultExt; @@ -20,24 +20,19 @@ pub struct StructArrayDecoder { } impl StructArrayDecoder { - pub fn new(column: &Column, stripe: &Stripe) -> Result { + pub fn new(column: &Column, fields: Fields, stripe: &Stripe) -> Result { + println!("StructArrayDecoder column: {:#?}", column); + println!("StructArrayDecoder fields: {:#?}", fields); let present = get_present_vec(column, stripe)? .map(|iter| Box::new(iter.into_iter()) as Box + Send>); let decoders = column .children() .iter() - .map(|child| array_decoder_factory(child, stripe)) + .zip(fields.iter().cloned()) + .map(|(child, field)| array_decoder_factory(child, field, stripe)) .collect::>>()?; - let fields = column - .children() - .into_iter() - .map(Field::from) - .map(Arc::new) - .collect::>(); - let fields = Fields::from(fields); - Ok(Self { decoders, present, diff --git a/src/array_decoder/union.rs b/src/array_decoder/union.rs index 009a7fd..6b8ced5 100644 --- a/src/array_decoder/union.rs +++ b/src/array_decoder/union.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use arrow::array::{ArrayRef, UnionArray}; use arrow::buffer::Buffer; -use arrow::datatypes::Field; +use arrow::datatypes::UnionFields; use snafu::ResultExt; use crate::column::{get_present_vec, Column}; @@ -18,14 +18,14 @@ use super::{array_decoder_factory, derive_present_vec, ArrayBatchDecoder}; pub struct UnionArrayDecoder { // fields and variants should have same length // TODO: encode this assumption into types - fields: Vec, + fields: UnionFields, variants: Vec>, tags: Box> + Send>, present: Option + Send>>, } impl UnionArrayDecoder { - pub fn new(column: &Column, stripe: &Stripe) -> Result { + pub fn new(column: &Column, fields: UnionFields, stripe: &Stripe) -> Result { let present = get_present_vec(column, stripe)? .map(|iter| Box::new(iter.into_iter()) as Box + Send>); @@ -35,21 +35,10 @@ impl UnionArrayDecoder { let variants = column .children() .iter() - .map(|child| array_decoder_factory(child, stripe)) + .zip(fields.iter()) + .map(|(child, (_id, field))| array_decoder_factory(child, field.clone(), stripe)) .collect::>>()?; - let fields = column - .children() - .into_iter() - .enumerate() - .map(|(idx, col)| { - let dt = col.data_type().to_arrow_data_type(); - // Naming matching what's set in schema.rs - // TODO: unify this across the files - Field::new(format!("_union_{idx}"), dt, true) - }) - .collect::>(); - Ok(Self { fields, variants, @@ -134,8 +123,8 @@ impl ArrayBatchDecoder for UnionArrayDecoder { let type_ids = Buffer::from_vec(tags); let child_arrays = self .fields - .clone() - .into_iter() + .iter() + .map(|(_id, field)| field.as_ref().clone()) .zip(child_arrays) .collect::>(); let array = UnionArray::try_new(&field_type_ids, type_ids, None, child_arrays) diff --git a/src/column.rs b/src/column.rs index 6deb9a8..393ab4f 100644 --- a/src/column.rs +++ b/src/column.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use arrow::datatypes::Field; use bytes::Bytes; use snafu::ResultExt; @@ -12,7 +11,7 @@ use crate::reader::ChunkReader; use crate::schema::DataType; use crate::stripe::Stripe; -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Column { number_of_rows: u64, footer: Arc, @@ -20,20 +19,6 @@ pub struct Column { data_type: DataType, } -impl From for Field { - fn from(value: Column) -> Self { - let dt = value.data_type.to_arrow_data_type(); - Field::new(value.name, dt, true) - } -} - -impl From<&Column> for Field { - fn from(value: &Column) -> Self { - let dt = value.data_type.to_arrow_data_type(); - Field::new(value.name.clone(), dt, true) - } -} - impl Column { pub fn new( name: &str, diff --git a/src/error.rs b/src/error.rs index b7ff01c..ac7687b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,12 +1,14 @@ use std::io; use std::string::FromUtf8Error; +use arrow::datatypes::DataType as ArrowDataType; use arrow::error::ArrowError; use snafu::prelude::*; use snafu::Location; use crate::proto; use crate::proto::r#type::Kind; +use crate::schema::DataType; #[derive(Debug, Snafu)] #[snafu(visibility(pub))] @@ -73,12 +75,29 @@ pub enum OrcError { #[snafu(display("unsupported type: {:?}", kind))] UnsupportedType { location: Location, kind: Kind }, + #[snafu(display("unsupported type variant: {}", msg))] + UnsupportedTypeVariant { + location: Location, + msg: &'static str, + }, + #[snafu(display("Field not found: {:?}", name))] FieldNotFound { location: Location, name: String }, #[snafu(display("Invalid column : {:?}", name))] InvalidColumn { location: Location, name: String }, + #[snafu(display( + "Cannot decode ORC type {:?} into Arrow type {:?}", + orc_type, + arrow_type, + ))] + MismatchedSchema { + location: Location, + orc_type: DataType, + arrow_type: ArrowDataType, + }, + #[snafu(display("Invalid encoding for column '{}': {:?}", name, encoding))] InvalidColumnEncoding { location: Location,