Skip to content

Commit

Permalink
proc_macro: support fn inputs which are refs: &T and/or Option<&T>
Browse files Browse the repository at this point in the history
- make_cache_key_type converts keys of Option<&T> to Option<T>
- Use the __private ToFullyOwned trait in generated code impl
  • Loading branch information
BaxHugh committed Apr 9, 2024
1 parent 7662b49 commit 61bce17
Show file tree
Hide file tree
Showing 9 changed files with 370 additions and 7 deletions.
23 changes: 20 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,22 @@ async = ["futures", "tokio", "async-trait"]
async_tokio_rt_multi_thread = ["async", "tokio/rt-multi-thread"]
redis_store = ["redis", "r2d2", "serde", "serde_json"]
redis_connection_manager = ["redis_store", "redis/connection-manager"]
redis_async_std = ["redis_store", "async", "redis/aio", "redis/async-std-comp", "redis/tls", "redis/async-std-tls-comp"]
redis_tokio = ["redis_store", "async", "redis/aio", "redis/tokio-comp", "redis/tls", "redis/tokio-native-tls-comp"]
redis_async_std = [
"redis_store",
"async",
"redis/aio",
"redis/async-std-comp",
"redis/tls",
"redis/async-std-tls-comp",
]
redis_tokio = [
"redis_store",
"async",
"redis/aio",
"redis/tokio-comp",
"redis/tls",
"redis/tokio-native-tls-comp",
]
redis_ahash = ["redis_store", "redis/ahash"]
disk_store = ["sled", "serde", "rmp-serde", "directories"]
wasm = ["instant/wasm-bindgen"]
Expand Down Expand Up @@ -103,7 +117,7 @@ optional = true
version = "0.1"

[dev-dependencies]
googletest = "0.11.0"
googletest.workspace = true
tempfile = "3.10.1"

[dev-dependencies.async-std]
Expand All @@ -116,6 +130,9 @@ version = "1"
[dev-dependencies.serial_test]
version = "3"

[workspace.dependencies]
googletest = "0.11.0"

[workspace]
members = ["cached_proc_macro", "examples/wasm"]

Expand Down
4 changes: 4 additions & 0 deletions cached_proc_macro/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,7 @@ quote = "1.0.6"
darling = "0.20.8"
proc-macro2 = "1.0.49"
syn = "2.0.52"
derive-syn-parse = "0.2.0"

[dev-dependencies]
googletest.workspace = true
2 changes: 2 additions & 0 deletions cached_proc_macro/src/cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
#visibility #signature_no_muts {
use cached::Cached;
use cached::CloneCached;
use cached::proc_macro::__private::ToFullyOwned as _;
let key = #key_convert_block;
#do_set_return_block
}
Expand All @@ -328,6 +329,7 @@ pub fn cached(args: TokenStream, input: TokenStream) -> TokenStream {
#(#attributes)*
#visibility #prime_sig {
use cached::Cached;
use cached::proc_macro::__private::ToFullyOwned as _;
let key = #key_convert_block;
#prime_do_set_return_block
}
Expand Down
226 changes: 222 additions & 4 deletions cached_proc_macro/src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,17 @@ pub(super) fn make_cache_key_type(

(quote! {}, quote! {#key_convert_block})
}
(None, None, _) => (
quote! {(#(#input_tys),*)},
quote! {(#(#input_names.clone()),*)},
),
(None, None, _) => {
let key_tys = input_tys
.into_iter()
.map(convert_option_of_ref_to_option_of_owned_type)
.map(convert_ref_to_owned_type)
.collect::<Vec<Type>>();
(
quote! {(#(#key_tys),*)},
quote! {(#(#input_names.to_fully_owned()),*)},
)
}
(Some(_), None, _) => panic!("key requires convert to be set"),
(None, Some(_), None) => panic!("convert requires key or type to be set"),
}
Expand Down Expand Up @@ -218,3 +225,214 @@ pub(super) fn check_with_cache_flag(with_cached_flag: bool, output_string: Strin
&& !output_string.contains("Return")
&& !output_string.contains("cached::Return")
}

use ref_inputs::*;
mod ref_inputs {
use super::*;

pub(super) fn is_option(ty: &Type) -> bool {
if let Type::Path(typepath) = ty {
let segments = &typepath.path.segments;
if segments.len() == 1 {
let segment = segments.first().unwrap();
if segment.ident == "Option" {
return true;
}
} else if segments.len() == 3 {
let segment_idents = segments
.iter()
.map(|s| s.ident.to_string())
.collect::<Vec<_>>();
if segment_idents == ["std", "option", "Option"] {
return true;
}
}
}
false
}

fn option_generic_arg_unchecked(ty: &Type) -> Type {
if let Type::Path(typepath) = ty {
let segment = &typepath
.path
.segments
.last()
.expect("option_generic_arg_unchecked: empty path");
if let PathArguments::AngleBracketed(brackets) = &segment.arguments {
if let Some(syn::GenericArgument::Type(inner_ty)) = brackets.args.first() {
return inner_ty.clone();
}
}
}
panic!("option_generic_arg_unchecked: could not extract inner type");
}

pub(super) fn is_option_of_ref(ty: &Type) -> bool {
if is_option(ty) {
let inner_ty = option_generic_arg_unchecked(ty);
if let Type::Reference(_) = inner_ty {
return true;
}
}

false
}

pub(super) fn convert_ref_to_owned_type(ty: Type) -> Type {
match ty {
Type::Reference(reftype) => *reftype.elem,
_ => ty,
}
}

pub(super) fn convert_option_of_ref_to_option_of_owned_type(ty: Type) -> Type {
if is_option_of_ref(&ty) {
let inner_ty = option_generic_arg_unchecked(&ty);
if let Type::Reference(reftype) = inner_ty {
let elem = *reftype.elem;
return parse_quote! { Option< #elem > };
}
}
ty
}
}

#[cfg(test)]
mod test {
use super::*;
use googletest::{assert_that, matchers::eq};
use syn::parse_quote;

macro_rules! type_test {
($test_name:ident, $target_fn:ident syn_ref, $input_type:ty, $expected:expr) => {
#[googletest::test]
fn $test_name() {
let ty = &parse_quote! { $input_type };
assert_that!($target_fn(ty), eq($expected));
}
};
($test_name:ident, $target_fn:ident syn_owned, $input_type:ty, $expected:expr) => {
#[googletest::test]
fn $test_name() {
let ty = parse_quote! { $input_type };
assert_that!($target_fn(ty), eq($expected));
}
};
}

mod convert_ref_to_owned_type {
use super::*;

type_test! {
returns_the_owned_type_when_given_a_ref_type,
convert_ref_to_owned_type syn_owned,
&T,
parse_quote!{ T }
}

type_test! {
returns_the_same_type_when_given_a_non_ref_type,
convert_ref_to_owned_type syn_owned,
T,
parse_quote!{ T }
}
}

mod convert_option_of_ref_to_option_of_owned_type {
use super::*;

type_test! {
returns_the_owned_option_type_when_given_option_of_ref,
convert_option_of_ref_to_option_of_owned_type syn_owned,
Option<&T>,
parse_quote!{ Option<T> }
}

type_test! {
returns_the_same_type_when_given_a_non_option_type,
convert_option_of_ref_to_option_of_owned_type syn_owned,
T,
parse_quote!{ T }
}

type_test! {
returns_the_same_type_when_given_an_option_of_non_ref_type,
convert_option_of_ref_to_option_of_owned_type syn_owned,
Option<T>,
parse_quote!{ Option<T> }
}
}

mod is_option {

mod when_arg_is_ref {
use super::super::*;
type_test!(returns_true_for_option, is_option syn_ref, Option<&T>, true);
type_test!(
returns_true_for_option_with_fully_qualified_core_path,
is_option syn_ref,
std::option::Option<&T>,
true
);
type_test!(
returns_false_for_custom_type_named_option,
is_option syn_ref,
my_module::Option<&T>,
false
);
}

mod when_arg_is_not_ref {
use super::super::*;
type_test!(returns_true_for_option, is_option syn_ref, Option<T>, true);
type_test!(
returns_true_for_option_with_fully_qualified_core_path,
is_option syn_ref,
std::option::Option<T>,
true
);
type_test!(
returns_false_for_custom_type_named_option,
is_option syn_ref,
my_module::Option<T>,
false
);
type_test!(returns_false_for_simple_type, is_option syn_ref, T, false);
type_test!(returns_false_for_a_generic_type, is_option syn_ref, Vec<T>, false);
}
}

mod is_option_of_ref {
use super::*;
type_test!(
returns_true_for_option_of_ref,
is_option_of_ref syn_ref,
Option<&T>,
true
);
type_test!(
returns_true_for_option_of_ref_with_fully_qualified_core_path,
is_option_of_ref syn_ref,
std::option::Option<&T>,
true
);
type_test!(
returns_false_for_custom_type_named_option_with_ref_generic_arg,
is_option_of_ref syn_ref,
my_module::Option<&T>,
false
);
type_test!(
returns_false_for_option_of_non_ref,
is_option_of_ref syn_ref,
Option<T>,
false
);
type_test!(
returns_false_for_option_of_non_ref_with_fully_qualified_core_path,
is_option_of_ref syn_ref,
std::option::Option<T>,
false
);
}
}
5 changes: 5 additions & 0 deletions cached_proc_macro/src/io_cached.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,10 +406,12 @@ pub fn io_cached(args: TokenStream, input: TokenStream) -> TokenStream {
let async_trait = if asyncness.is_some() && !args.disk {
quote! {
use cached::IOCachedAsync;
use cached::proc_macro::__private::ToFullyOwned as _;
}
} else {
quote! {
use cached::IOCached;
use cached::proc_macro::__private::ToFullyOwned as _;
}
};

Expand All @@ -435,6 +437,7 @@ pub fn io_cached(args: TokenStream, input: TokenStream) -> TokenStream {
// Cached function
#(#attributes)*
#visibility #signature_no_muts {
use cached::proc_macro::__private::ToFullyOwned as _;
let init = || async { #cache_create };
#async_trait
let key = #key_convert_block;
Expand Down Expand Up @@ -464,6 +467,7 @@ pub fn io_cached(args: TokenStream, input: TokenStream) -> TokenStream {
#(#attributes)*
#visibility #signature_no_muts {
use cached::IOCached;
use cached::proc_macro::__private::ToFullyOwned as _;
let key = #key_convert_block;
{
// check if the result is cached
Expand All @@ -479,6 +483,7 @@ pub fn io_cached(args: TokenStream, input: TokenStream) -> TokenStream {
#[allow(dead_code)]
#visibility #prime_sig {
use cached::IOCached;
use cached::proc_macro::__private::ToFullyOwned as _;
let key = #key_convert_block;
#do_set_return_block
}
Expand Down
8 changes: 8 additions & 0 deletions cached_proc_macro_types/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
mod to_fully_owned;

// Not public API.
#[doc(hidden)]
pub mod __private {
pub use super::to_fully_owned::ToFullyOwned;
}

/// Used to wrap a function result so callers can see whether the result was cached.
#[derive(Clone)]
pub struct Return<T> {
Expand Down
Loading

0 comments on commit 61bce17

Please sign in to comment.