diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs index 61e12b8a..1d855a8f 100644 --- a/src/operation_dispatcher.rs +++ b/src/operation_dispatcher.rs @@ -1,10 +1,10 @@ use crate::configuration::{Extension, ExtensionType, FailureMode}; use crate::envoy::RateLimitDescriptor; use crate::policy::Policy; -use crate::service::{GrpcMessage, GrpcServiceHandler}; +use crate::service::{GetMapValuesBytes, GrpcCall, GrpcMessage, GrpcServiceHandler}; use protobuf::RepeatedField; use proxy_wasm::hostcalls; -use proxy_wasm::types::Status; +use proxy_wasm::types::{Bytes, MapType, Status}; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; @@ -29,24 +29,6 @@ impl State { } } -fn grpc_call( - upstream_name: &str, - service_name: &str, - method_name: &str, - initial_metadata: Vec<(&str, &[u8])>, - message: Option<&[u8]>, - timeout: Duration, -) -> Result { - hostcalls::dispatch_grpc_call( - upstream_name, - service_name, - method_name, - initial_metadata, - message, - timeout, - ) -} - type Procedure = (Rc, GrpcMessage); #[allow(dead_code)] @@ -56,6 +38,8 @@ pub(crate) struct Operation { result: Result, extension: Rc, procedure: Procedure, + grpc_call: GrpcCall, + get_map_values_bytes: GetMapValuesBytes, } #[allow(dead_code)] @@ -66,17 +50,19 @@ impl Operation { result: Err(Status::Empty), extension, procedure, + grpc_call, + get_map_values_bytes, } } - pub fn set_action(&mut self, procedure: Procedure) { - self.procedure = procedure; - } - - pub fn trigger(&mut self) { + fn trigger(&mut self) { if let State::Done = self.state { } else { - self.result = self.procedure.0.send(grpc_call, self.procedure.1.clone()); + self.result = self.procedure.0.send( + self.get_map_values_bytes, + self.grpc_call, + self.procedure.1.clone(), + ); self.state.next(); } } @@ -172,6 +158,28 @@ impl OperationDispatcher { } } +fn grpc_call( + upstream_name: &str, + service_name: &str, + method_name: &str, + initial_metadata: Vec<(&str, &[u8])>, + message: Option<&[u8]>, + timeout: Duration, +) -> Result { + hostcalls::dispatch_grpc_call( + upstream_name, + service_name, + method_name, + initial_metadata, + message, + timeout, + ) +} + +fn get_map_values_bytes(map_type: MapType, key: &str) -> Result, Status> { + hostcalls::get_map_value_bytes(map_type, key) +} + #[cfg(test)] mod tests { use super::*; @@ -189,6 +197,10 @@ mod tests { Ok(200) } + fn get_map_values_bytes(_map_type: MapType, _key: &str) -> Result, Status> { + Ok(Some(Vec::new())) + } + fn build_grpc_service_handler() -> GrpcServiceHandler { GrpcServiceHandler::new(Rc::new(Default::default()), Rc::new(Default::default())) } @@ -203,33 +215,33 @@ mod tests { } } - #[test] - fn operation_getters() { - let extension = Rc::new(Extension::default()); - let operation = Operation::new( - extension, - ( + fn build_operation() -> Operation { + Operation { + state: State::Pending, + result: Ok(200), + extension: Rc::new(Extension::default()), + procedure: ( Rc::new(build_grpc_service_handler()), GrpcMessage::RateLimit(build_message()), ), - ); + grpc_call, + get_map_values_bytes, + } + } + + #[test] + fn operation_getters() { + let operation = build_operation(); assert_eq!(operation.get_state(), State::Pending); assert_eq!(operation.get_extension_type(), ExtensionType::RateLimit); assert_eq!(operation.get_failure_mode(), FailureMode::Deny); - assert_eq!(operation.get_result(), Result::Ok(1)); + assert_eq!(operation.get_result(), Ok(200)); } #[test] fn operation_transition() { - let extension = Rc::new(Extension::default()); - let mut operation = Operation::new( - extension, - ( - Rc::new(build_grpc_service_handler()), - GrpcMessage::RateLimit(build_message()), - ), - ); + let mut operation = build_operation(); assert_eq!(operation.get_state(), State::Pending); operation.trigger(); assert_eq!(operation.get_state(), State::Waiting); @@ -242,23 +254,16 @@ mod tests { fn operation_dispatcher_push_actions() { let operation_dispatcher = OperationDispatcher::default(); - assert_eq!(operation_dispatcher.operations.borrow().len(), 1); - let extension = Rc::new(Extension::default()); - operation_dispatcher.push_operations(vec![Operation::new( - extension, - ( - Rc::new(build_grpc_service_handler()), - GrpcMessage::RateLimit(build_message()), - ), - )]); + assert_eq!(operation_dispatcher.operations.borrow().len(), 0); + operation_dispatcher.push_operations(vec![build_operation()]); - assert_eq!(operation_dispatcher.operations.borrow().len(), 2); + assert_eq!(operation_dispatcher.operations.borrow().len(), 1); } #[test] fn operation_dispatcher_get_current_action_state() { let operation_dispatcher = OperationDispatcher::default(); - + operation_dispatcher.push_operations(vec![build_operation()]); assert_eq!( operation_dispatcher.get_current_operation_state(), Some(State::Pending) @@ -267,14 +272,7 @@ mod tests { #[test] fn operation_dispatcher_next() { - let extension = Rc::new(Extension::default()); - let operation = Operation::new( - extension, - ( - Rc::new(build_grpc_service_handler()), - GrpcMessage::RateLimit(build_message()), - ), - ); + let operation = build_operation(); let operation_dispatcher = OperationDispatcher::default(); operation_dispatcher.push_operations(vec![operation]); diff --git a/src/service.rs b/src/service.rs index 38fe2fce..aec2aff0 100644 --- a/src/service.rs +++ b/src/service.rs @@ -10,7 +10,6 @@ use protobuf::reflect::MessageDescriptor; use protobuf::{ Clear, CodedInputStream, CodedOutputStream, Message, ProtobufResult, UnknownFields, }; -use proxy_wasm::hostcalls; use proxy_wasm::types::{Bytes, MapType, Status}; use std::any::Any; use std::cell::OnceCell; @@ -182,7 +181,7 @@ impl GrpcService { } } -type GrpcCall = fn( +pub type GrpcCall = fn( upstream_name: &str, service_name: &str, method_name: &str, @@ -191,6 +190,8 @@ type GrpcCall = fn( timeout: Duration, ) -> Result; +pub type GetMapValuesBytes = fn(map_type: MapType, key: &str) -> Result, Status>; + pub struct GrpcServiceHandler { service: Rc, header_resolver: Rc, @@ -204,11 +205,16 @@ impl GrpcServiceHandler { } } - pub fn send(&self, grpc_call: GrpcCall, message: GrpcMessage) -> Result { + pub fn send( + &self, + get_map_values_bytes: GetMapValuesBytes, + grpc_call: GrpcCall, + message: GrpcMessage, + ) -> Result { let msg = Message::write_to_bytes(&message).unwrap(); let metadata = self .header_resolver - .get() + .get(get_map_values_bytes) .iter() .map(|(header, value)| (*header, value.as_slice())) .collect(); @@ -249,12 +255,12 @@ impl HeaderResolver { } } - pub fn get(&self) -> &Vec<(&'static str, Bytes)> { + pub fn get(&self, get_map_values_bytes: GetMapValuesBytes) -> &Vec<(&'static str, Bytes)> { self.headers.get_or_init(|| { let mut headers = Vec::new(); for header in TracingHeader::all() { if let Ok(Some(value)) = - hostcalls::get_map_value_bytes(MapType::HttpRequestHeaders, (*header).as_str()) + get_map_values_bytes(MapType::HttpRequestHeaders, (*header).as_str()) { headers.push(((*header).as_str(), value)); }