diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs index 91efc538..2f038eec 100644 --- a/src/operation_dispatcher.rs +++ b/src/operation_dispatcher.rs @@ -96,7 +96,7 @@ impl OperationDispatcher { policy.actions.iter().for_each(|action| { // TODO(didierofrivia): Error handling if let Some(service) = self.service_handlers.get(&action.extension) { - let message = service.build_message(policy.domain.clone(), descriptors.clone()); + let message = GrpcMessage::new(service.get_extension_type(), policy.domain.clone(), descriptors.clone()); operations.push(Operation::new((service.clone(), message))) } }); diff --git a/src/service.rs b/src/service.rs index 4d553b77..b28a162f 100644 --- a/src/service.rs +++ b/src/service.rs @@ -2,33 +2,144 @@ pub(crate) mod auth; pub(crate) mod rate_limit; use crate::configuration::{ExtensionType, FailureMode}; -use crate::envoy::{RateLimitDescriptor, RateLimitRequest}; -use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; +use crate::envoy::{CheckRequest, RateLimitDescriptor, RateLimitRequest}; +use crate::service::auth::{AuthService, AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; use crate::service::rate_limit::{RateLimitService, RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate}; -use protobuf::Message; +use protobuf::reflect::MessageDescriptor; +use protobuf::{ + Clear, CodedInputStream, CodedOutputStream, Message, ProtobufResult, UnknownFields, +}; use proxy_wasm::hostcalls; use proxy_wasm::hostcalls::dispatch_grpc_call; use proxy_wasm::types::{Bytes, MapType, Status}; +use std::any::Any; use std::cell::OnceCell; +use std::fmt::{Debug}; use std::rc::Rc; use std::time::Duration; -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum GrpcMessage { - //Auth(CheckRequest), + Auth(CheckRequest), RateLimit(RateLimitRequest), } -impl GrpcMessage { - pub fn get_message(&self) -> &RateLimitRequest { - //TODO(didierofrivia): Should return Message +impl Default for GrpcMessage { + fn default() -> Self { + GrpcMessage::RateLimit(RateLimitRequest::new()) + } +} + +impl Clear for GrpcMessage { + fn clear(&mut self) { match self { - GrpcMessage::RateLimit(message) => message, + GrpcMessage::Auth(msg) => msg.clear(), + GrpcMessage::RateLimit(msg) => msg.clear(), } } } +impl Message for GrpcMessage { + fn descriptor(&self) -> &'static MessageDescriptor { + match self { + GrpcMessage::Auth(msg) => msg.descriptor(), + GrpcMessage::RateLimit(msg) => msg.descriptor(), + } + } + + fn is_initialized(&self) -> bool { + match self { + GrpcMessage::Auth(msg) => msg.is_initialized(), + GrpcMessage::RateLimit(msg) => msg.is_initialized(), + } + } + + fn merge_from(&mut self, is: &mut CodedInputStream) -> ProtobufResult<()> { + match self { + GrpcMessage::Auth(msg) => msg.merge_from(is), + GrpcMessage::RateLimit(msg) => msg.merge_from(is), + } + } + + fn write_to_with_cached_sizes(&self, os: &mut CodedOutputStream) -> ProtobufResult<()> { + match self { + GrpcMessage::Auth(msg) => msg.write_to_with_cached_sizes(os), + GrpcMessage::RateLimit(msg) => msg.write_to_with_cached_sizes(os), + } + } + + fn write_to_bytes(&self) -> ProtobufResult> { + match self { + GrpcMessage::Auth(msg) => msg.write_to_bytes(), + GrpcMessage::RateLimit(msg) => msg.write_to_bytes(), + } + } + + fn compute_size(&self) -> u32 { + match self { + GrpcMessage::Auth(msg) => msg.compute_size(), + GrpcMessage::RateLimit(msg) => msg.compute_size(), + } + } + + fn get_cached_size(&self) -> u32 { + match self { + GrpcMessage::Auth(msg) => msg.get_cached_size(), + GrpcMessage::RateLimit(msg) => msg.get_cached_size(), + } + } + + fn get_unknown_fields(&self) -> &UnknownFields { + match self { + GrpcMessage::Auth(msg) => msg.get_unknown_fields(), + GrpcMessage::RateLimit(msg) => msg.get_unknown_fields(), + } + } + + fn mut_unknown_fields(&mut self) -> &mut UnknownFields { + match self { + GrpcMessage::Auth(msg) => msg.mut_unknown_fields(), + GrpcMessage::RateLimit(msg) => msg.mut_unknown_fields(), + } + } + + fn as_any(&self) -> &dyn Any { + match self { + GrpcMessage::Auth(msg) => msg.as_any(), + GrpcMessage::RateLimit(msg) => msg.as_any(), + } + } + + fn new() -> Self + where + Self: Sized, + { + // Returning default value + GrpcMessage::default() + } + + fn default_instance() -> &'static Self + where + Self: Sized, + { + #[allow(non_upper_case_globals)] + static instance: ::protobuf::rt::LazyV2 = ::protobuf::rt::LazyV2::INIT; + instance.get(|| GrpcMessage::RateLimit(RateLimitRequest::new())) + } +} + +impl GrpcMessage { + // Using domain as ce_host for the time being, we might pass a DataType in the future. + pub fn new(extension_type: ExtensionType, domain: String, descriptors: protobuf::RepeatedField) -> Self { + match extension_type { + ExtensionType::RateLimit => GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)), + ExtensionType::Auth => GrpcMessage::Auth(AuthService::message(domain.clone())) + } + } + +} + #[derive(Default)] pub struct GrpcService { endpoint: String, @@ -102,7 +213,7 @@ impl GrpcServiceHandler { } pub fn send(&self, message: GrpcMessage) -> Result { - let msg = Message::write_to_bytes(message.get_message()).unwrap(); + let msg = Message::write_to_bytes(&message).unwrap(); let metadata = self .header_resolver .get() @@ -120,18 +231,8 @@ impl GrpcServiceHandler { ) } - // Using domain as ce_host for the time being, we might pass a DataType in the future. - //TODO(didierofrivia): Make it work with Message. for both Auth and RL - pub fn build_message( - &self, - domain: String, - descriptors: protobuf::RepeatedField, - ) -> GrpcMessage { - /*match self.service.extension_type { - //ExtensionType::Auth => GrpcMessage::Auth(AuthService::message(domain.clone())), - //ExtensionType::RateLimit => GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)), - }*/ - GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)) + pub fn get_extension_type(&self) -> ExtensionType { + self.service.extension_type.clone() } }