From c4b3ac35658f641202a6dd960c8a3382d23037da Mon Sep 17 00:00:00 2001 From: "Meir Shpilraien (Spielrein)" Date: Wed, 1 Feb 2023 12:16:39 +0200 Subject: [PATCH] Make call function generic. (#264) * Make call function generic. * Avoid RedisModuleString retain if there is not need to. * Code reuse * Support u8 slice and u8 vec. * Format fixes * Added tests for the new call functionality. * Improve tests * Format fixes * No need for AsBytes trait * Use try_into instead of into. * Apply suggestions from code review Co-authored-by: Guy Korland --------- Co-authored-by: Guy Korland --- Cargo.toml | 4 +++ examples/call.rs | 77 ++++++++++++++++++++++++++++++++++++++++++++ src/context/mod.rs | 70 ++++++++++++++++++++++++++++++++++------ src/redismodule.rs | 21 ++++++++++-- src/redisvalue.rs | 16 ++++++++- tests/integration.rs | 17 ++++++++++ 6 files changed, 192 insertions(+), 13 deletions(-) create mode 100644 examples/call.rs diff --git a/Cargo.toml b/Cargo.toml index 03bac1fd..10bda721 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,10 @@ crate-type = ["cdylib"] name = "string" crate-type = ["cdylib"] +[[example]] +name = "call" +crate-type = ["cdylib"] + [[example]] name = "keys_pos" crate-type = ["cdylib"] diff --git a/examples/call.rs b/examples/call.rs new file mode 100644 index 00000000..0e3da189 --- /dev/null +++ b/examples/call.rs @@ -0,0 +1,77 @@ +#[macro_use] +extern crate redis_module; + +use redis_module::{Context, RedisError, RedisResult, RedisString}; + +fn call_test(ctx: &Context, _: Vec) -> RedisResult { + let res: String = ctx.call("ECHO", &["TEST"])?.try_into()?; + if "TEST" != &res { + return Err(RedisError::Str("Failed calling 'ECHO TEST'")); + } + + let res: String = ctx.call("ECHO", vec!["TEST"].as_slice())?.try_into()?; + if "TEST" != &res { + return Err(RedisError::Str( + "Failed calling 'ECHO TEST' dynamic str vec", + )); + } + + let res: String = ctx.call("ECHO", &[b"TEST"])?.try_into()?; + if "TEST" != &res { + return Err(RedisError::Str( + "Failed calling 'ECHO TEST' with static [u8]", + )); + } + + let res: String = ctx.call("ECHO", vec![b"TEST"].as_slice())?.try_into()?; + if "TEST" != &res { + return Err(RedisError::Str( + "Failed calling 'ECHO TEST' dynamic &[u8] vec", + )); + } + + let res: String = ctx.call("ECHO", &[&"TEST".to_string()])?.try_into()?; + if "TEST" != &res { + return Err(RedisError::Str("Failed calling 'ECHO TEST' with String")); + } + + let res: String = ctx + .call("ECHO", vec![&"TEST".to_string()].as_slice())? + .try_into()?; + if "TEST" != &res { + return Err(RedisError::Str( + "Failed calling 'ECHO TEST' dynamic &[u8] vec", + )); + } + + let res: String = ctx + .call("ECHO", &[&ctx.create_string("TEST")])? + .try_into()?; + if "TEST" != &res { + return Err(RedisError::Str( + "Failed calling 'ECHO TEST' with RedisString", + )); + } + + let res: String = ctx + .call("ECHO", vec![&ctx.create_string("TEST")].as_slice())? + .try_into()?; + if "TEST" != &res { + return Err(RedisError::Str( + "Failed calling 'ECHO TEST' with dynamic array of RedisString", + )); + } + + Ok("pass".into()) +} + +////////////////////////////////////////////////////// + +redis_module! { + name: "call", + version: 1, + data_types: [], + commands: [ + ["call.test", call_test, "", 0, 0, 0], + ], +} diff --git a/src/context/mod.rs b/src/context/mod.rs index 1c86709f..560ef239 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -30,6 +30,61 @@ pub struct Context { pub ctx: *mut raw::RedisModuleCtx, } +pub struct StrCallArgs<'a> { + is_owner: bool, + args: Vec<*mut raw::RedisModuleString>, + // Phantom is used to make sure the object will not live longer than actual arguments slice + phantom: std::marker::PhantomData<&'a raw::RedisModuleString>, +} + +impl<'a> Drop for StrCallArgs<'a> { + fn drop(&mut self) { + if self.is_owner { + self.args.iter_mut().for_each(|v| unsafe { + raw::RedisModule_FreeString.unwrap()(std::ptr::null_mut(), *v) + }); + } + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> From<&'a [&T]> for StrCallArgs<'a> { + fn from(vals: &'a [&T]) -> Self { + StrCallArgs { + is_owner: true, + args: vals + .iter() + .map(|v| RedisString::create_from_slice(std::ptr::null_mut(), v.as_ref()).take()) + .collect(), + phantom: std::marker::PhantomData, + } + } +} + +impl<'a> From<&'a [&RedisString]> for StrCallArgs<'a> { + fn from(vals: &'a [&RedisString]) -> Self { + StrCallArgs { + is_owner: false, + args: vals.iter().map(|v| v.inner).collect(), + phantom: std::marker::PhantomData, + } + } +} + +impl<'a, const SIZE: usize, T: ?Sized> From<&'a [&T; SIZE]> for StrCallArgs<'a> +where + for<'b> &'a [&'b T]: Into>, +{ + fn from(vals: &'a [&T; SIZE]) -> Self { + vals.as_ref().into() + } +} + +impl<'a> StrCallArgs<'a> { + fn args_mut(&mut self) -> &mut [*mut raw::RedisModuleString] { + &mut self.args + } +} + impl Context { pub const fn new(ctx: *mut raw::RedisModuleCtx) -> Self { Self { ctx } @@ -97,14 +152,9 @@ impl Context { } } - pub fn call(&self, command: &str, args: &[&str]) -> RedisResult { - let terminated_args: Vec = args - .iter() - .map(|s| RedisString::create(self.ctx, s)) - .collect(); - - let mut inner_args: Vec<*mut raw::RedisModuleString> = - terminated_args.iter().map(|s| s.inner).collect(); + pub fn call<'a, T: Into>>(&self, command: &str, args: T) -> RedisResult { + let mut call_args: StrCallArgs = args.into(); + let final_args = call_args.args_mut(); let cmd = CString::new(command).unwrap(); let reply: *mut raw::RedisModuleCallReply = unsafe { @@ -113,8 +163,8 @@ impl Context { self.ctx, cmd.as_ptr(), raw::FMT, - inner_args.as_mut_ptr(), - terminated_args.len(), + final_args.as_mut_ptr(), + final_args.len(), ) }; let result = Self::parse_call_reply(reply); diff --git a/src/redismodule.rs b/src/redismodule.rs index 1b87d494..e83e62e6 100644 --- a/src/redismodule.rs +++ b/src/redismodule.rs @@ -95,6 +95,12 @@ pub struct RedisString { } impl RedisString { + pub(crate) fn take(mut self) -> *mut raw::RedisModuleString { + let inner = self.inner; + self.inner = std::ptr::null_mut(); + inner + } + pub fn new(ctx: *mut raw::RedisModuleCtx, inner: *mut raw::RedisModuleString) -> Self { raw::string_retain_string(ctx, inner); Self { ctx, inner } @@ -108,6 +114,15 @@ impl RedisString { Self { ctx, inner } } + #[allow(clippy::not_unsafe_ptr_arg_deref)] + pub fn create_from_slice(ctx: *mut raw::RedisModuleCtx, s: &[u8]) -> Self { + let inner = unsafe { + raw::RedisModule_CreateString.unwrap()(ctx, s.as_ptr().cast::(), s.len()) + }; + + Self { ctx, inner } + } + pub fn from_redis_module_string( ctx: *mut raw::RedisModuleCtx, inner: *mut raw::RedisModuleString, @@ -198,8 +213,10 @@ impl RedisString { impl Drop for RedisString { fn drop(&mut self) { - unsafe { - raw::RedisModule_FreeString.unwrap()(self.ctx, self.inner); + if !self.inner.is_null() { + unsafe { + raw::RedisModule_FreeString.unwrap()(self.ctx, self.inner); + } } } } diff --git a/src/redisvalue.rs b/src/redisvalue.rs index 81b320fd..5bb002fd 100644 --- a/src/redisvalue.rs +++ b/src/redisvalue.rs @@ -1,4 +1,4 @@ -use crate::RedisString; +use crate::{RedisError, RedisString}; #[derive(Debug, PartialEq)] pub enum RedisValue { @@ -14,6 +14,20 @@ pub enum RedisValue { NoReply, // No reply at all (as opposed to a Null reply) } +impl TryFrom for String { + type Error = RedisError; + fn try_from(val: RedisValue) -> Result { + match val { + RedisValue::SimpleStringStatic(s) => Ok(s.to_string()), + RedisValue::SimpleString(s) => Ok(s), + RedisValue::BulkString(s) => Ok(s), + RedisValue::BulkRedisString(s) => Ok(s.try_as_str()?.to_string()), + RedisValue::StringBuffer(s) => Ok(std::str::from_utf8(&s)?.to_string()), + _ => Err(RedisError::Str("Can not convert result to String")), + } + } +} + impl From<()> for RedisValue { fn from(_: ()) -> Self { Self::Null diff --git a/tests/integration.rs b/tests/integration.rs index aaba0f35..3b96ac27 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -234,3 +234,20 @@ fn test_stream_reader() -> Result<()> { Ok(()) } + +#[test] +fn test_call() -> Result<()> { + let port: u16 = 6488; + let _guards = vec![start_redis_server_with_module("call", port) + .with_context(|| "failed to start redis server")?]; + let mut con = + get_redis_connection(port).with_context(|| "failed to connect to redis server")?; + + let res: String = redis::cmd("call.test") + .query(&mut con) + .with_context(|| "failed to run string.set")?; + + assert_eq!(&res, "pass"); + + Ok(()) +}