Skip to content

Commit

Permalink
feat(vector): add sub function
Browse files Browse the repository at this point in the history
  • Loading branch information
KKould committed Dec 24, 2024
1 parent 2082c4b commit 3c50b24
Show file tree
Hide file tree
Showing 4 changed files with 244 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/common/function/src/scalars/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mod distance;
pub(crate) mod impl_conv;
mod scalar_add;
mod scalar_mul;
mod sub;

use std::sync::Arc;

Expand All @@ -38,5 +39,6 @@ impl VectorFunction {
// scalar calculation
registry.register(Arc::new(scalar_add::ScalarAddFunction));
registry.register(Arc::new(scalar_mul::ScalarMulFunction));
registry.register(Arc::new(sub::SubFunction));
}
}
182 changes: 182 additions & 0 deletions src/common/function/src/scalars/vector/sub.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::borrow::Cow;
use std::fmt::Display;

use common_query::error::InvalidFuncArgsSnafu;
use common_query::prelude::Signature;
use datatypes::prelude::ConcreteDataType;
use datatypes::scalars::ScalarVectorBuilder;
use datatypes::vectors::{BinaryVectorBuilder, MutableVector, VectorRef};
use nalgebra::DVectorView;
use snafu::ensure;

use crate::function::{Function, FunctionContext};
use crate::helper;
use crate::scalars::vector::impl_conv::{as_veclit, as_veclit_if_const, veclit_to_binlit};

const NAME: &str = "vec_sub";

/// Subtracts corresponding elements of two vectors, returns a vector.
///
/// # Example
///
/// ```sql
/// SELECT vec_to_string(vec_sub("[1.0, 1.0]", "[1.0, 2.0]")) as result;
///
/// +---------------------------------------------------------------+
/// | vec_to_string(vec_sub(Utf8("[1.0, 1.0]"),Utf8("[1.0, 2.0]"))) |
/// +---------------------------------------------------------------+
/// | [0,-1] |
/// +---------------------------------------------------------------+
///
/// -- Negative scalar to simulate subtraction
/// SELECT vec_to_string(vec_sub('[-1.0, -1.0]', '[1.0, 2.0]'));
///
/// +-----------------------------------------------------------------+
/// | vec_to_string(vec_sub(Utf8("[-1.0, -1.0]"),Utf8("[1.0, 2.0]"))) |
/// +-----------------------------------------------------------------+
/// | [-2,-3] |
/// +-----------------------------------------------------------------+
///
#[derive(Debug, Clone, Default)]
pub struct SubFunction;

impl Function for SubFunction {
fn name(&self) -> &str {
NAME
}

fn return_type(
&self,
_input_types: &[ConcreteDataType],
) -> common_query::error::Result<ConcreteDataType> {
Ok(ConcreteDataType::binary_datatype())
}

fn signature(&self) -> Signature {
helper::one_of_sigs2(
vec![
ConcreteDataType::string_datatype(),
ConcreteDataType::binary_datatype(),
],
vec![
ConcreteDataType::string_datatype(),
ConcreteDataType::binary_datatype(),
],
)
}

fn eval(
&self,
_func_ctx: FunctionContext,
columns: &[VectorRef],
) -> common_query::error::Result<VectorRef> {
ensure!(
columns.len() == 2,
InvalidFuncArgsSnafu {
err_msg: format!(
"The length of the args is not correct, expect exactly two, have: {}",
columns.len()
)
}
);
let arg0 = &columns[0];
let arg1 = &columns[1];

let len = arg0.len();
let mut result = BinaryVectorBuilder::with_capacity(len);
if len == 0 {
return Ok(result.to_vector());
}

let arg0_const = as_veclit_if_const(arg0)?;
let arg1_const = as_veclit_if_const(arg1)?;

for i in 0..len {
let arg0 = match arg0_const.as_ref() {
Some(arg0) => Some(Cow::Borrowed(arg0.as_ref())),
None => as_veclit(arg0.get_ref(i))?,
};
let arg1 = match arg1_const.as_ref() {
Some(arg1) => Some(Cow::Borrowed(arg1.as_ref())),
None => as_veclit(arg1.get_ref(i))?,
};
let (Some(arg0), Some(arg1)) = (arg0, arg1) else {
result.push_null();
continue;
};
let vec0 = DVectorView::from_slice(&arg0, arg0.len());
let vec1 = DVectorView::from_slice(&arg1, arg1.len());

let vec_res = vec0 - vec1;
let veclit = vec_res.as_slice();
let binlit = veclit_to_binlit(veclit);
result.push(Some(&binlit));
}

Ok(result.to_vector())
}
}

impl Display for SubFunction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", NAME.to_ascii_uppercase())
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use datatypes::vectors::StringVector;

use super::*;

#[test]
fn test_sub() {
let func = SubFunction;

let input0 = Arc::new(StringVector::from(vec![
Some("[1.0,2.0,3.0]".to_string()),
Some("[4.0,5.0,6.0]".to_string()),
None,
Some("[2.0,3.0,3.0]".to_string()),
]));
let input1 = Arc::new(StringVector::from(vec![
Some("[1.0,1.0,1.0]".to_string()),
Some("[6.0,5.0,4.0]".to_string()),
Some("[3.0,2.0,2.0]".to_string()),
None,
]));

let result = func
.eval(FunctionContext::default(), &[input0, input1])
.unwrap();

let result = result.as_ref();
assert_eq!(result.len(), 4);
assert_eq!(
result.get_ref(0).as_binary().unwrap(),
Some(veclit_to_binlit(&[0.0, 1.0, 2.0]).as_slice())
);
assert_eq!(
result.get_ref(1).as_binary().unwrap(),
Some(veclit_to_binlit(&[-2.0, 0.0, 2.0]).as_slice())
);
assert!(result.get_ref(2).is_null());
assert!(result.get_ref(3).is_null());
}
}
48 changes: 48 additions & 0 deletions tests/cases/standalone/common/function/vector/vector.result
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,51 @@ SELECT vec_to_string(parse_vec('[]'));
| [] |
+--------------------------------------+

SELECT vec_to_string(vec_sub('[1.0, 1.0]', '[1.0, 2.0]'));

+---------------------------------------------------------------+
| vec_to_string(vec_sub(Utf8("[1.0, 1.0]"),Utf8("[1.0, 2.0]"))) |
+---------------------------------------------------------------+
| [0,-1] |
+---------------------------------------------------------------+

SELECT vec_to_string(vec_sub('[-1.0, -1.0]', '[1.0, 2.0]'));

+-----------------------------------------------------------------+
| vec_to_string(vec_sub(Utf8("[-1.0, -1.0]"),Utf8("[1.0, 2.0]"))) |
+-----------------------------------------------------------------+
| [-2,-3] |
+-----------------------------------------------------------------+

SELECT vec_to_string(vec_sub('[1.0, 1.0]', parse_vec('[1.0, 2.0]')));

+--------------------------------------------------------------------------+
| vec_to_string(vec_sub(Utf8("[1.0, 1.0]"),parse_vec(Utf8("[1.0, 2.0]")))) |
+--------------------------------------------------------------------------+
| [0,-1] |
+--------------------------------------------------------------------------+

SELECT vec_to_string(vec_sub('[-1.0, -1.0]', parse_vec('[1.0, 2.0]')));

+----------------------------------------------------------------------------+
| vec_to_string(vec_sub(Utf8("[-1.0, -1.0]"),parse_vec(Utf8("[1.0, 2.0]")))) |
+----------------------------------------------------------------------------+
| [-2,-3] |
+----------------------------------------------------------------------------+

SELECT vec_to_string(vec_sub(parse_vec('[1.0, 1.0]'), '[1.0, 2.0]'));

+--------------------------------------------------------------------------+
| vec_to_string(vec_sub(parse_vec(Utf8("[1.0, 1.0]")),Utf8("[1.0, 2.0]"))) |
+--------------------------------------------------------------------------+
| [0,-1] |
+--------------------------------------------------------------------------+

SELECT vec_to_string(vec_sub(parse_vec('[-1.0, -1.0]'), '[1.0, 2.0]'));

+----------------------------------------------------------------------------+
| vec_to_string(vec_sub(parse_vec(Utf8("[-1.0, -1.0]")),Utf8("[1.0, 2.0]"))) |
+----------------------------------------------------------------------------+
| [-2,-3] |
+----------------------------------------------------------------------------+

12 changes: 12 additions & 0 deletions tests/cases/standalone/common/function/vector/vector.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,15 @@ SELECT vec_to_string(parse_vec('[1.0, 2.0]'));
SELECT vec_to_string(parse_vec('[1.0, 2.0, 3.0]'));

SELECT vec_to_string(parse_vec('[]'));

SELECT vec_to_string(vec_sub('[1.0, 1.0]', '[1.0, 2.0]'));

SELECT vec_to_string(vec_sub('[-1.0, -1.0]', '[1.0, 2.0]'));

SELECT vec_to_string(vec_sub('[1.0, 1.0]', parse_vec('[1.0, 2.0]')));

SELECT vec_to_string(vec_sub('[-1.0, -1.0]', parse_vec('[1.0, 2.0]')));

SELECT vec_to_string(vec_sub(parse_vec('[1.0, 1.0]'), '[1.0, 2.0]'));

SELECT vec_to_string(vec_sub(parse_vec('[-1.0, -1.0]'), '[1.0, 2.0]'));

0 comments on commit 3c50b24

Please sign in to comment.