diff --git a/.gitignore b/.gitignore index 009d65bc..d88fc2d0 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,6 @@ private/ /perf.* /flamegraph.svg + +# IDEA +.idea/ diff --git a/src/arrow_reader.rs b/src/arrow_reader.rs index 550ab971..db6827a5 100644 --- a/src/arrow_reader.rs +++ b/src/arrow_reader.rs @@ -16,6 +16,7 @@ // under the License. use std::collections::HashMap; +use std::ops::Range; use std::sync::Arc; use arrow::datatypes::SchemaRef; @@ -28,7 +29,7 @@ use crate::projection::ProjectionMask; use crate::reader::metadata::{read_metadata, FileMetadata}; use crate::reader::ChunkReader; use crate::schema::RootDataType; -use crate::stripe::Stripe; +use crate::stripe::{Stripe, StripeMetadata}; const DEFAULT_BATCH_SIZE: usize = 8192; @@ -38,6 +39,7 @@ pub struct ArrowReaderBuilder { pub(crate) batch_size: usize, pub(crate) projection: ProjectionMask, pub(crate) schema_ref: Option, + pub(crate) range: Option>, } impl ArrowReaderBuilder { @@ -48,6 +50,7 @@ impl ArrowReaderBuilder { batch_size: DEFAULT_BATCH_SIZE, projection: ProjectionMask::all(), schema_ref: None, + range: None, } } @@ -70,6 +73,11 @@ impl ArrowReaderBuilder { self } + pub fn with_range(mut self, range: Range) -> Self { + self.range = Some(range); + self + } + /// Returns the currently computed schema /// /// Unless [`with_schema`](Self::with_schema) was called, this is computed dynamically @@ -108,6 +116,7 @@ impl ArrowReaderBuilder { file_metadata: self.file_metadata, projected_data_type, stripe_index: 0, + range: self.range, }; ArrowReader { cursor, @@ -176,14 +185,32 @@ pub(crate) struct Cursor { pub file_metadata: Arc, pub projected_data_type: RootDataType, pub stripe_index: usize, + pub range: Option>, +} + +impl Cursor { + fn get_stripe_metadatas(&self) -> Vec { + if let Some(range) = self.range.clone() { + self.file_metadata + .stripe_metadatas() + .iter() + .filter(|info| { + let offset = info.offset() as usize; + !(offset < range.start || offset >= range.end) + }) + .map(|info| info.to_owned()) + .collect::>() + } else { + self.file_metadata.stripe_metadatas().to_vec() + } + } } impl Iterator for Cursor { type Item = Result; fn next(&mut self) -> Option { - self.file_metadata - .stripe_metadatas() + self.get_stripe_metadatas() .get(self.stripe_index) .map(|info| { let stripe = Stripe::new( diff --git a/src/async_arrow_reader.rs b/src/async_arrow_reader.rs index 6fe123a5..eda2e35e 100644 --- a/src/async_arrow_reader.rs +++ b/src/async_arrow_reader.rs @@ -104,6 +104,13 @@ impl StripeFactory { .cloned(); if let Some(info) = info { + if let Some(range) = self.inner.range.clone() { + let offset = info.offset() as usize; + if offset < range.start || offset >= range.end { + self.inner.stripe_index += 1; + return Ok((self, None)); + } + } match self.read_next_stripe_inner(&info).await { Ok(stripe) => Ok((self, Some(stripe))), Err(err) => Err(err), @@ -214,6 +221,7 @@ impl ArrowReaderBuilder { file_metadata: self.file_metadata, projected_data_type, stripe_index: 0, + range: self.range, }; ArrowStreamReader::new(cursor, self.batch_size, schema_ref) } diff --git a/src/datafusion/physical_exec.rs b/src/datafusion/physical_exec.rs index fef3f3fd..56fcbb86 100644 --- a/src/datafusion/physical_exec.rs +++ b/src/datafusion/physical_exec.rs @@ -151,11 +151,15 @@ impl FileOpener for OrcOpener { // Offset by 1 since index 0 is the root let projection = self.projection.iter().map(|i| i + 1).collect::>(); Ok(Box::pin(async move { - let builder = ArrowReaderBuilder::try_new_async(reader) + let mut builder = ArrowReaderBuilder::try_new_async(reader) .await .map_err(ArrowError::from)?; let projection_mask = ProjectionMask::roots(builder.file_metadata().root_data_type(), projection); + if let Some(range) = file_meta.range.clone() { + let range = range.start as usize..range.end as usize; + builder = builder.with_range(range); + } let reader = builder .with_batch_size(batch_size) .with_projection(projection_mask) diff --git a/tests/basic/main.rs b/tests/basic/main.rs index 6d9bb718..df315ba6 100644 --- a/tests/basic/main.rs +++ b/tests/basic/main.rs @@ -16,6 +16,7 @@ // under the License. use std::fs::File; +use std::ops::Range; use std::sync::Arc; use arrow::datatypes::{DataType, Decimal128Type, DecimalType, Field, Schema, TimeUnit}; @@ -48,11 +49,32 @@ async fn new_arrow_stream_reader_root(path: &str) -> ArrowStreamReader, +) -> ArrowStreamReader { + let f = tokio::fs::File::open(path).await.unwrap(); + ArrowReaderBuilder::try_new_async(f) + .await + .unwrap() + .with_range(range) + .build_async() +} + fn new_arrow_reader_root(path: &str) -> ArrowReader { let f = File::open(path).expect("no file found"); ArrowReaderBuilder::try_new(f).unwrap().build() } +fn new_arrow_reader_range(path: &str, range: Range) -> ArrowReader { + let f = File::open(path).expect("no file found"); + ArrowReaderBuilder::try_new(f) + .unwrap() + .with_range(range) + .build() +} + fn basic_path(path: &str) -> String { let dir = env!("CARGO_MANIFEST_DIR"); format!("{}/tests/basic/data/{}", dir, path) @@ -360,6 +382,44 @@ pub fn basic_test_0() { assert_batches_eq(&batch, &expected); } +#[test] +pub fn basic_test_with_range() { + let path = basic_path("test.orc"); + let reader = new_arrow_reader_range(&path, 0..2000); + let batch = reader.collect::, _>>().unwrap(); + + assert_eq!(5, batch[0].column(0).len()); +} + +#[test] +pub fn basic_test_with_range_without_data() { + let path = basic_path("test.orc"); + let reader = new_arrow_reader_range(&path, 100..2000); + let batch = reader.collect::, _>>().unwrap(); + + assert_eq!(0, batch.len()); +} + +#[cfg(feature = "async")] +#[tokio::test] +pub async fn async_basic_test_with_range() { + let path = basic_path("test.orc"); + let reader = new_arrow_stream_reader_range(&path, 0..2000).await; + let batch = reader.try_collect::>().await.unwrap(); + + assert_eq!(5, batch[0].column(0).len()); +} + +#[cfg(feature = "async")] +#[tokio::test] +pub async fn async_basic_test_with_range_without_data() { + let path = basic_path("test.orc"); + let reader = new_arrow_stream_reader_range(&path, 100..2000).await; + let batch = reader.try_collect::>().await.unwrap(); + + assert_eq!(0, batch.len()); +} + #[cfg(feature = "async")] #[tokio::test] pub async fn async_basic_test_0() {