Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass target arrow type to array_decoder_factory #92

Merged
merged 2 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/array_decoder/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,16 @@ pub struct ListArrayDecoder {
}

impl ListArrayDecoder {
pub fn new(column: &Column, stripe: &Stripe) -> Result<Self> {
pub fn new(column: &Column, field: Arc<Field>, stripe: &Stripe) -> Result<Self> {
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);

let child = &column.children()[0];
let inner = array_decoder_factory(child, stripe)?;
let inner = array_decoder_factory(child, field.clone(), stripe)?;

let reader = stripe.stream_map().get(column, Kind::Length);
let lengths = get_rle_reader(column, reader)?;

let field = Arc::new(Field::from(child));

Ok(Self {
inner,
present,
Expand Down
22 changes: 8 additions & 14 deletions src/array_decoder/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,24 @@ pub struct MapArrayDecoder {
}

impl MapArrayDecoder {
pub fn new(column: &Column, stripe: &Stripe) -> Result<Self> {
pub fn new(
column: &Column,
keys_field: Arc<Field>,
values_field: Arc<Field>,
stripe: &Stripe,
) -> Result<Self> {
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);

let keys_column = &column.children()[0];
let keys = array_decoder_factory(keys_column, stripe)?;
let keys = array_decoder_factory(keys_column, keys_field.clone(), stripe)?;

let values_column = &column.children()[1];
let values = array_decoder_factory(values_column, stripe)?;
let values = array_decoder_factory(values_column, values_field.clone(), stripe)?;

let reader = stripe.stream_map().get(column, Kind::Length);
let lengths = get_rle_reader(column, reader)?;

// TODO: should it be "keys" and "values" (like arrow-rs)
// or "key" and "value" like PyArrow and in Schema.fbs?
let keys_field = Field::new("keys", keys_column.data_type().to_arrow_data_type(), false);
let keys_field = Arc::new(keys_field);
let values_field = Field::new(
"values",
values_column.data_type().to_arrow_data_type(),
true,
);
let values_field = Arc::new(values_field);

let fields = Fields::from(vec![keys_field, values_field]);

Ok(Self {
Expand Down
195 changes: 184 additions & 11 deletions src/array_decoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@ use std::sync::Arc;
use arrow::array::{ArrayRef, BooleanArray, BooleanBuilder, PrimitiveArray, PrimitiveBuilder};
use arrow::buffer::NullBuffer;
use arrow::datatypes::{ArrowPrimitiveType, Decimal128Type, UInt64Type};
use arrow::datatypes::{DataType as ArrowDataType, Field};
use arrow::datatypes::{
Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef,
TimeUnit,
};
use arrow::record_batch::RecordBatch;
use snafu::ResultExt;
use snafu::{ensure, ResultExt};

use crate::column::{get_present_vec, Column};
use crate::error::{self, ArrowSnafu, Result};
use crate::error::{
self, ArrowSnafu, MismatchedSchemaSnafu, Result, UnexpectedSnafu, UnsupportedTypeVariantSnafu,
};
use crate::proto::stream::Kind;
use crate::reader::decode::boolean_rle::BooleanIter;
use crate::reader::decode::byte_rle::ByteRleIter;
Expand Down Expand Up @@ -316,82 +320,247 @@ pub trait ArrayBatchDecoder: Send {

pub fn array_decoder_factory(
column: &Column,
field: Arc<Field>,
stripe: &Stripe,
) -> Result<Box<dyn ArrayBatchDecoder>> {
Comment on lines 321 to 325
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The key API change 👍

let field_type = field.data_type().clone();
let decoder: Box<dyn ArrayBatchDecoder> = match column.data_type() {
// TODO: try make branches more generic, reduce duplication
DataType::Boolean { .. } => {
ensure!(
field_type == ArrowDataType::Boolean,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = Box::new(BooleanIter::new(iter));
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(BooleanArrayDecoder::new(iter, present))
}
DataType::Byte { .. } => {
ensure!(
field_type == ArrowDataType::Int8,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = Box::new(ByteRleIter::new(iter).map(|value| value.map(|value| value as i8)));
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(Int8ArrayDecoder::new(iter, present))
}
DataType::Short { .. } => {
ensure!(
field_type == ArrowDataType::Int16,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = get_rle_reader(column, iter)?;
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(Int16ArrayDecoder::new(iter, present))
}
DataType::Int { .. } => {
ensure!(
field_type == ArrowDataType::Int32,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = get_rle_reader(column, iter)?;
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(Int32ArrayDecoder::new(iter, present))
}
DataType::Long { .. } => {
ensure!(
field_type == ArrowDataType::Int64,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = get_rle_reader(column, iter)?;
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(Int64ArrayDecoder::new(iter, present))
}
DataType::Float { .. } => {
ensure!(
field_type == ArrowDataType::Float32,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = Box::new(FloatIter::new(iter, stripe.number_of_rows()));
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(Float32ArrayDecoder::new(iter, present))
}
DataType::Double { .. } => {
ensure!(
field_type == ArrowDataType::Float64,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = Box::new(FloatIter::new(iter, stripe.number_of_rows()));
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(Float64ArrayDecoder::new(iter, present))
}
DataType::String { .. } | DataType::Varchar { .. } | DataType::Char { .. } => {
ensure!(
field_type == ArrowDataType::Utf8,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
new_string_decoder(column, stripe)?
}
DataType::Binary { .. } => new_binary_decoder(column, stripe)?,
DataType::Binary { .. } => {
ensure!(
field_type == ArrowDataType::Binary,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
new_binary_decoder(column, stripe)?
}
DataType::Decimal {
precision, scale, ..
} => new_decimal_decoder(column, stripe, *precision, *scale)?,
DataType::Timestamp { .. } => new_timestamp_decoder(column, stripe)?,
} => {
ensure!(
field_type == ArrowDataType::Decimal128(*precision as u8, *scale as i8),
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
new_decimal_decoder(column, stripe, *precision, *scale)?
}
DataType::Timestamp { .. } => {
// TODO: add support for any precision
ensure!(
field_type == ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
new_timestamp_decoder(column, stripe)?
}
DataType::TimestampWithLocalTimezone { .. } => {
// TODO: add support for any precision and for arbitrary timezones
ensure!(
field_type == ArrowDataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
new_timestamp_instant_decoder(column, stripe)?
}

DataType::Date { .. } => {
// TODO: allow Date64
ensure!(
field_type == ArrowDataType::Date32,
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type
}
);
let iter = stripe.stream_map().get(column, Kind::Data);
let iter = get_rle_reader(column, iter)?;
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);
Box::new(DateArrayDecoder::new(iter, present))
}
DataType::Struct { .. } => Box::new(StructArrayDecoder::new(column, stripe)?),
DataType::List { .. } => Box::new(ListArrayDecoder::new(column, stripe)?),
DataType::Map { .. } => Box::new(MapArrayDecoder::new(column, stripe)?),
DataType::Union { .. } => Box::new(UnionArrayDecoder::new(column, stripe)?),
DataType::Struct { .. } => match field_type {
ArrowDataType::Struct(fields) => {
Box::new(StructArrayDecoder::new(column, fields, stripe)?)
}
_ => MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type,
}
.fail()?,
},
DataType::List { .. } => {
match field_type {
ArrowDataType::List(field) => {
Box::new(ListArrayDecoder::new(column, field, stripe)?)
}
// TODO: add support for ArrowDataType::LargeList
_ => MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type,
}
.fail()?,
}
}
DataType::Map { .. } => {
let ArrowDataType::Map(entries, sorted) = field_type else {
MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type,
}
.fail()?
};
ensure!(!sorted, UnsupportedTypeVariantSnafu { msg: "Sorted map" });
let ArrowDataType::Struct(entries) = entries.data_type() else {
UnexpectedSnafu {
msg: "arrow Map with non-Struct entry type".to_owned(),
}
.fail()?
};
ensure!(
entries.len() == 2,
UnexpectedSnafu {
msg: format!(
"arrow Map with {} columns per entry (expected 2)",
entries.len()
)
}
);
let keys_field = entries[0].clone();
let values_field = entries[1].clone();

Box::new(MapArrayDecoder::new(
column,
keys_field,
values_field,
stripe,
)?)
}
DataType::Union { .. } => match field_type {
ArrowDataType::Union(fields, _) => {
Box::new(UnionArrayDecoder::new(column, fields, stripe)?)
}
_ => MismatchedSchemaSnafu {
orc_type: column.data_type().clone(),
arrow_type: field_type,
}
.fail()?,
},
};

Ok(decoder)
Expand Down Expand Up @@ -440,8 +609,12 @@ impl NaiveStripeDecoder {
let mut decoders = Vec::with_capacity(stripe.columns().len());
let number_of_rows = stripe.number_of_rows();

for col in stripe.columns() {
let decoder = array_decoder_factory(col, &stripe)?;
for (col, field) in stripe
.columns()
.iter()
.zip(schema_ref.fields.iter().cloned())
{
let decoder = array_decoder_factory(col, field, &stripe)?;
decoders.push(decoder);
}

Expand Down
17 changes: 6 additions & 11 deletions src/array_decoder/struct_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use arrow::{
array::{ArrayRef, StructArray},
buffer::NullBuffer,
datatypes::{Field, Fields},
datatypes::Fields,
};
use snafu::ResultExt;

Expand All @@ -20,24 +20,19 @@ pub struct StructArrayDecoder {
}

impl StructArrayDecoder {
pub fn new(column: &Column, stripe: &Stripe) -> Result<Self> {
pub fn new(column: &Column, fields: Fields, stripe: &Stripe) -> Result<Self> {
println!("StructArrayDecoder column: {:#?}", column);
println!("StructArrayDecoder fields: {:#?}", fields);
progval marked this conversation as resolved.
Show resolved Hide resolved
let present = get_present_vec(column, stripe)?
.map(|iter| Box::new(iter.into_iter()) as Box<dyn Iterator<Item = bool> + Send>);

let decoders = column
.children()
.iter()
.map(|child| array_decoder_factory(child, stripe))
.zip(fields.iter().cloned())
.map(|(child, field)| array_decoder_factory(child, field, stripe))
.collect::<Result<Vec<_>>>()?;

let fields = column
.children()
.into_iter()
.map(Field::from)
.map(Arc::new)
.collect::<Vec<_>>();
let fields = Fields::from(fields);

Ok(Self {
decoders,
present,
Expand Down
Loading
Loading