Skip to content

Commit

Permalink
Pass target arrow type to array_decoder_factory
Browse files Browse the repository at this point in the history
  • Loading branch information
progval committed Jun 4, 2024
1 parent ac5a8ab commit b131626
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 72 deletions.
6 changes: 2 additions & 4 deletions src/array_decoder/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,16 @@ pub struct ListArrayDecoder {
}

impl ListArrayDecoder {
pub fn new(column: &Column, stripe: &Stripe) -> Result<Self> {
pub fn new(column: &Column, field: Arc<Field>, stripe: &Stripe) -> Result<Self> {
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);

let child = &column.children()[0];
let inner = array_decoder_factory(child, stripe)?;
let inner = array_decoder_factory(child, field.clone(), stripe)?;

let reader = stripe.stream_map().get(column, Kind::Length);
let lengths = get_rle_reader(column, reader)?;

let field = Arc::new(Field::from(child));

Ok(Self {
inner,
present,
Expand Down
22 changes: 8 additions & 14 deletions src/array_decoder/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,24 @@ pub struct MapArrayDecoder {
}

impl MapArrayDecoder {
pub fn new(column: &Column, stripe: &Stripe) -> Result<Self> {
pub fn new(
column: &Column,
keys_field: Arc<Field>,
values_field: Arc<Field>,
stripe: &Stripe,
) -> Result<Self> {
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);

let keys_column = &column.children()[0];
let keys = array_decoder_factory(keys_column, stripe)?;
let keys = array_decoder_factory(keys_column, keys_field.clone(), stripe)?;

let values_column = &column.children()[1];
let values = array_decoder_factory(values_column, stripe)?;
let values = array_decoder_factory(values_column, values_field.clone(), stripe)?;

let reader = stripe.stream_map().get(column, Kind::Length);
let lengths = get_rle_reader(column, reader)?;

// TODO: should it be "keys" and "values" (like arrow-rs)
// or "key" and "value" like PyArrow and in Schema.fbs?
let keys_field = Field::new("keys", keys_column.data_type().to_arrow_data_type(), false);
let keys_field = Arc::new(keys_field);
let values_field = Field::new(
"values",
values_column.data_type().to_arrow_data_type(),
true,
);
let values_field = Arc::new(values_field);

let fields = Fields::from(vec![keys_field, values_field]);

Ok(Self {
Expand Down
155 changes: 146 additions & 9 deletions src/array_decoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ use std::sync::Arc;
use arrow::array::{ArrayRef, BooleanArray, BooleanBuilder, PrimitiveArray, PrimitiveBuilder};
use arrow::buffer::NullBuffer;
use arrow::datatypes::{ArrowPrimitiveType, Decimal128Type, UInt64Type};
use arrow::datatypes::{DataType as ArrowDataType, Field};
use arrow::datatypes::{
Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef,
};
use arrow::record_batch::RecordBatch;
use snafu::ResultExt;
use snafu::{ensure, ResultExt};

use crate::column::{get_present_vec, Column};
use crate::error::{self, ArrowSnafu, Result};
use crate::error::{self, ArrowSnafu, MismatchedSchemaSnafu, UnsupportedTypeVariantSnafu, UnexpectedSnafu, Result};
use crate::proto::stream::Kind;
use crate::reader::decode::boolean_rle::BooleanIter;
use crate::reader::decode::byte_rle::ByteRleIter;
Expand Down Expand Up @@ -316,63 +317,130 @@ pub trait ArrayBatchDecoder: Send {

pub fn array_decoder_factory(
column: &Column,
field: Arc<Field>,
stripe: &Stripe,
) -> Result<Box<dyn ArrayBatchDecoder>> {
let field_type = field.data_type().clone();
let decoder: Box<dyn ArrayBatchDecoder> = match column.data_type() {
// TODO: try to make branches more generic and reduce duplication
DataType::Boolean { .. } => {
ensure!(
field_type == ArrowDataType::Boolean,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = Box::new(BooleanIter::new(iter));
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(BooleanArrayDecoder::new(iter, present))
}
DataType::Byte { .. } => {
ensure!(
field_type == ArrowDataType::Int8,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = Box::new(ByteRleIter::new(iter).map(|value| value.map(|value| value as i8)));
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(Int8ArrayDecoder::new(iter, present))
}
DataType::Short { .. } => {
ensure!(
field_type == ArrowDataType::Int16,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = get_rle_reader(column, iter)?;
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(Int16ArrayDecoder::new(iter, present))
}
DataType::Int { .. } => {
ensure!(
field_type == ArrowDataType::Int32,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = get_rle_reader(column, iter)?;
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(Int32ArrayDecoder::new(iter, present))
}
DataType::Long { .. } => {
ensure!(
field_type == ArrowDataType::Int64,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = get_rle_reader(column, iter)?;
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(Int64ArrayDecoder::new(iter, present))
}
DataType::Float { .. } => {
ensure!(
field_type == ArrowDataType::Float32,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = Box::new(FloatIter::new(iter, stripe.number_of_rows()));
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(Float32ArrayDecoder::new(iter, present))
}
DataType::Double { .. } => {
ensure!(
field_type == ArrowDataType::Float64,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = Box::new(FloatIter::new(iter, stripe.number_of_rows()));
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(Float64ArrayDecoder::new(iter, present))
}
DataType::String { .. } | DataType::Varchar { .. } | DataType::Char { .. } => {
ensure!(
field_type == ArrowDataType::Utf8,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
new_string_decoder(column, stripe)?
}
DataType::Binary { .. } => new_binary_decoder(column, stripe)?,
DataType::Binary { .. } => {
ensure!(
field_type == ArrowDataType::Binary,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
new_binary_decoder(column, stripe)?
}
DataType::Decimal {
precision, scale, ..
} => new_decimal_decoder(column, stripe, *precision, *scale)?,
Expand All @@ -388,10 +456,75 @@ pub fn array_decoder_factory(
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(DateArrayDecoder::new(iter, present))
}
DataType::Struct { .. } => Box::new(StructArrayDecoder::new(column, stripe)?),
DataType::List { .. } => Box::new(ListArrayDecoder::new(column, stripe)?),
DataType::Map { .. } => Box::new(MapArrayDecoder::new(column, stripe)?),
DataType::Union { .. } => Box::new(UnionArrayDecoder::new(column, stripe)?),
DataType::Struct { .. } => match field_type {
ArrowDataType::Struct(fields) => {
Box::new(StructArrayDecoder::new(column, fields, stripe)?)
}
_ => MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type,
}
.fail()?,
},
DataType::List { .. } => {
match field_type {
ArrowDataType::List(field) => {
Box::new(ListArrayDecoder::new(column, field, stripe)?)
}
// TODO: add support for ArrowDataType::LargeList
_ => MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type,
}
.fail()?,
}
}
DataType::Map { .. } => match field_type {
ArrowDataType::Map(entries, sorted) => {
ensure!(!sorted, UnsupportedTypeVariantSnafu { msg: "Sorted map" });
match entries.data_type() {
ArrowDataType::Struct(entries) => {
ensure!(
entries.len() == 2,
UnexpectedSnafu {
msg: format!(
"arrow Map with {} columns per entry (expected 2)",
entries.len()
)
}
);
let keys_field = entries[0].clone();
let values_field = entries[1].clone();

Box::new(MapArrayDecoder::new(
column,
keys_field,
values_field,
stripe,
)?)
}
_ => UnexpectedSnafu {
msg: format!("arrow Map with non-Struct entry type"),
}
.fail()?,
}
}
_ => MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type,
}
.fail()?,
},
DataType::Union { .. } => match field_type {
ArrowDataType::Union(fields, _) => {
Box::new(UnionArrayDecoder::new(column, fields, stripe)?)
}
_ => MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type,
}
.fail()?,
},
};

Ok(decoder)
Expand Down Expand Up @@ -440,8 +573,12 @@ impl NaiveStripeDecoder {
let mut decoders = Vec::with_capacity(stripe.columns().len());
let number_of_rows = stripe.number_of_rows();

for col in stripe.columns() {
let decoder = array_decoder_factory(col, &stripe)?;
for (col, field) in stripe
.columns()
.iter()
.zip(schema_ref.fields.iter().cloned())
{
let decoder = array_decoder_factory(col, field, &stripe)?;
decoders.push(decoder);
}

Expand Down
17 changes: 6 additions & 11 deletions src/array_decoder/struct_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use arrow::{
array::{ArrayRef, StructArray},
buffer::NullBuffer,
datatypes::{Field, Fields},
datatypes::Fields,
};
use snafu::ResultExt;

Expand All @@ -20,24 +20,19 @@ pub struct StructArrayDecoder {
}

impl StructArrayDecoder {
pub fn new(column: &Column, stripe: &Stripe) -> Result<Self> {
pub fn new(column: &Column, fields: Fields, stripe: &Stripe) -> Result<Self> {
println!("StructArrayDecoder column: {:#?}", column);
println!("StructArrayDecoder fields: {:#?}", fields);
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);

let decoders = column
.children()
.iter()
.map(|child| array_decoder_factory(child, stripe))
.zip(fields.iter().cloned())
.map(|(child, field)| array_decoder_factory(child, field, stripe))
.collect::<Result<Vec<_>>>()?;

let fields = column
.children()
.into_iter()
.map(Field::from)
.map(Arc::new)
.collect::<Vec<_>>();
let fields = Fields::from(fields);

Ok(Self {
decoders,
present,
Expand Down
Loading

0 comments on commit b131626

Please sign in to comment.