From 6f101b1000fbcd235462029aeddef79c19c2213a Mon Sep 17 00:00:00 2001 From: wangbo Date: Fri, 1 Sep 2023 11:25:33 +0800 Subject: [PATCH] dev --- gateway/internal/eventhandler.go | 20 ++++++++++++++++++-- gateway/server.go | 12 +++++++++++- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/gateway/internal/eventhandler.go b/gateway/internal/eventhandler.go index e3a0f0c2df8f..76d02e48df5d 100644 --- a/gateway/internal/eventhandler.go +++ b/gateway/internal/eventhandler.go @@ -15,19 +15,32 @@ type EventHandler struct { Status *status.Status writer io.Writer marshaler jsonpb.Marshaler + + Message proto.Message + RespHandler func(writer io.Writer, status *status.Status, message proto.Message) } -func NewEventHandler(writer io.Writer, resolver jsonpb.AnyResolver) *EventHandler { - return &EventHandler{ +type HandlerOption func(handler *EventHandler) + +func NewEventHandler(writer io.Writer, resolver jsonpb.AnyResolver, opts ...HandlerOption) *EventHandler { + handler := &EventHandler{ writer: writer, marshaler: jsonpb.Marshaler{ EmitDefaults: true, AnyResolver: resolver, }, } + for _, opt := range opts { + opt(handler) + } + return handler } func (h *EventHandler) OnReceiveResponse(message proto.Message) { + if h.RespHandler != nil { + h.Message = message + return + } if err := h.marshaler.Marshal(h.writer, message); err != nil { logx.Error(err) } @@ -35,6 +48,9 @@ func (h *EventHandler) OnReceiveResponse(message proto.Message) { func (h *EventHandler) OnReceiveTrailers(status *status.Status, _ metadata.MD) { h.Status = status + if h.RespHandler != nil { + h.RespHandler(h.writer, h.Status, h.Message) + } } func (h *EventHandler) OnResolveMethod(_ *desc.MethodDescriptor) { diff --git a/gateway/server.go b/gateway/server.go index 71d1e554eff0..d944d336fb47 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -3,11 +3,13 @@ package gateway import ( "context" "fmt" + "io" "net/http" "strings" "github.com/fullstorydev/grpcurl" "github.com/golang/protobuf/jsonpb" + "github.com/golang/protobuf/proto" "github.com/jhump/protoreflect/grpcreflect" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/mr" @@ -16,6 +18,7 @@ import ( "github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/zrpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type ( @@ -25,6 +28,7 @@ type ( upstreams []Upstream processHeader func(http.Header) []string dialer func(conf zrpc.RpcClientConf) zrpc.Client + respHandler func(writer io.Writer, status *status.Status, message proto.Message) } // Option defines the method to customize Server. @@ -55,6 +59,10 @@ func (s *Server) Stop() { s.Server.Stop() } +func (s *Server) SetRespHandler(handler func(writer io.Writer, status *status.Status, message proto.Message)) { + s.respHandler = handler +} + func (s *Server) build() error { if err := s.ensureUpstreamNames(); err != nil { return err @@ -128,7 +136,9 @@ func (s *Server) buildHandler(source grpcurl.DescriptorSource, resolver jsonpb.A } w.Header().Set(httpx.ContentType, httpx.JsonContentType) - handler := internal.NewEventHandler(w, resolver) + handler := internal.NewEventHandler(w, resolver, func(eventHandler *internal.EventHandler) { + eventHandler.RespHandler = s.respHandler + }) if err := grpcurl.InvokeRPC(r.Context(), source, cli.Conn(), rpcPath, s.prepareMetadata(r.Header), handler, parser.Next); err != nil { httpx.ErrorCtx(r.Context(), w, err)