From c559cc1162a1f64b3257eb977ff53471743d4e5e Mon Sep 17 00:00:00 2001 From: kamille Date: Mon, 9 Sep 2024 21:49:57 +0800 Subject: [PATCH] draft. --- datafusion/functions/src/string/common.rs | 122 +++++++++++++++++---- datafusion/functions/src/string/rtrim.rs | 40 +++++++ datafusion/functions/src/unicode/substr.rs | 2 +- 3 files changed, 140 insertions(+), 24 deletions(-) diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 9365a6d833319..e2b69b58ff015 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -27,11 +27,14 @@ use arrow::array::{ }; use arrow::buffer::{Buffer, MutableBuffer, NullBuffer}; use arrow::datatypes::DataType; +use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; use datafusion_common::cast::{as_generic_string_array, as_string_view_array}; use datafusion_common::Result; use datafusion_common::{exec_err, ScalarValue}; use datafusion_expr::ColumnarValue; +use crate::unicode::substr::make_and_append_view; + pub(crate) enum TrimType { Left, Right, @@ -83,16 +86,33 @@ fn string_view_trim<'a, T: OffsetSizeTrait>( func: fn(&'a str, &'a str) -> &'a str, args: &'a [ArrayRef], ) -> Result { - let string_array = as_string_view_array(&args[0])?; + let string_view_array = as_string_view_array(&args[0])?; + let mut views_buf = Vec::with_capacity(string_view_array.len()); + let mut null_builder = NullBufferBuilder::new(string_view_array.len()); match args.len() { 1 => { - let result = string_array - .iter() - .map(|string| string.map(|string: &str| func(string, " "))) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) + for (idx, raw) in string_view_array.views().iter().enumerate() { + unsafe { + // Safety: + // idx is always smaller or equal to string_view_array.views.len() + let origin_str = string_view_array.value_unchecked(idx); + let trim_str = func(origin_str, " "); + + // Safety: + // `trim_str` is computed from `str::trim_xxx_matches`, + // and its addr is ensured to be >= `origin_str`'s + let start = trim_str.as_ptr().offset_from(origin_str.as_ptr()) as u32; + + make_and_append_view( + &mut views_buf, + &mut null_builder, + raw, + trim_str, + start, + ); + } + } } 2 => { let characters_array = as_string_view_array(&args[1])?; @@ -102,35 +122,91 @@ fn string_view_trim<'a, T: OffsetSizeTrait>( return Ok(new_null_array( // The schema is expecting utf8 as null &DataType::Utf8, - string_array.len(), + string_view_array.len(), )); } let characters = characters_array.value(0); - let result = string_array - .iter() - .map(|item| item.map(|string| func(string, characters))) - .collect::>(); - return Ok(Arc::new(result) as ArrayRef); + + for (idx, raw) in string_view_array.views().iter().enumerate() { + unsafe { + // Safety: + // idx is always smaller or equal to string_view_array.views.len() + let origin_str = string_view_array.value_unchecked(idx); + let trim_str = func(origin_str, characters); + + // Safety: + // `trim_str` is computed from `str::trim_xxx_matches`, + // and its addr is ensured to be >= `origin_str`'s + let start = + trim_str.as_ptr().offset_from(origin_str.as_ptr()) as u32; + + make_and_append_view( + &mut views_buf, + &mut null_builder, + raw, + trim_str, + start, + ); + } + } } - let result = string_array + for (idx, (raw, characters_opt)) in string_view_array + .views() .iter() .zip(characters_array.iter()) - .map(|(string, characters)| match (string, characters) { - (Some(string), Some(characters)) => Some(func(string, characters)), - _ => None, - }) - .collect::>(); - - Ok(Arc::new(result) as ArrayRef) + .enumerate() + { + if let Some(characters) = characters_opt { + unsafe { + // Safety: + // idx is always smaller or equal to string_view_array.views.len() + let origin_str = string_view_array.value_unchecked(idx); + let trim_str = func(origin_str, characters); + + // Safety: + // `trim_str` is computed from `str::trim_xxx_matches`, + // and its addr is ensured to be >= `origin_str`'s + let start = + trim_str.as_ptr().offset_from(origin_str.as_ptr()) as u32; + + make_and_append_view( + &mut views_buf, + &mut null_builder, + raw, + trim_str, + start, + ); + } + } else { + null_builder.append_null(); + views_buf.push(0); + } + } } other => { - exec_err!( + return exec_err!( "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2." - ) + ); } } + + let views_buf = ScalarBuffer::from(views_buf); + let nulls_buf = null_builder.finish(); + + // Safety: + // (1) The blocks of the given views are all provided + // (2) Each of the range `view.offset+start..end` of view in views_buf is within + // the bounds of each of the blocks + unsafe { + let array = StringViewArray::new_unchecked( + views_buf, + string_view_array.data_buffers().to_vec(), + nulls_buf, + ); + Ok(Arc::new(array) as ArrayRef) + } } fn string_trim<'a, T: OffsetSizeTrait>( diff --git a/datafusion/functions/src/string/rtrim.rs b/datafusion/functions/src/string/rtrim.rs index ec53f3ed74307..52d0826137fa0 100644 --- a/datafusion/functions/src/string/rtrim.rs +++ b/datafusion/functions/src/string/rtrim.rs @@ -101,3 +101,43 @@ impl ScalarUDFImpl for RtrimFunc { } } } + +#[cfg(test)] +mod tests { + use arrow::array::{Array, StringArray, StringViewArray}; + use arrow::datatypes::DataType::{Utf8, Utf8View}; + + use datafusion_common::{exec_err, Result, ScalarValue}; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + + use crate::unicode::substr::SubstrFunc; + use crate::utils::test::test_function; + + #[test] + fn test_functions() { + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(None)), + ColumnarValue::Scalar(ScalarValue::from(1i64)), + ], + Ok(None), + &str, + Utf8View, + StringViewArray + ); + test_function!( + SubstrFunc::new(), + &[ + ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from( + "alphabet" + )))), + ColumnarValue::Scalar(ScalarValue::from(0i64)), + ], + Ok(Some("alphabet")), + &str, + Utf8View, + StringViewArray + ); + } +} diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 8a70b380669cc..88e5f04163f61 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -145,7 +145,7 @@ fn get_true_start_end(input: &str, start: usize, count: i64) -> (usize, usize) { /// Make a `u128` based on the given substr, start(offset to view.offset), and /// push into to the given buffers -fn make_and_append_view( +pub fn make_and_append_view( views_buffer: &mut Vec, null_builder: &mut NullBufferBuilder, raw: &u128,