-
Notifications
You must be signed in to change notification settings - Fork 0
/
server.go
278 lines (248 loc) · 7.28 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
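// Package zrpc implements RPC over network connections. This file contains the
// server side: service registration, the connection handshake, request dispatch
// and the HTTP CONNECT transport.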
package zrpc
import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"reflect"
"strings"
"sync"
"time"
"github.com/ilaziness/zrpc/codec"
)
// Option is the handshake a client sends first on every connection. The server
// decodes it as JSON before anything else and uses it to decide how to read the
// rest of the stream:
//
//	| Option | Header1 | Body1 | Header2 | Body2 | ...
//
// One connection therefore carries a single Option followed by any number of
// header/body pairs, each pair being one RPC call.
type Option struct {
	// CodecType picks the codec used for every header and body after the Option.
	CodecType codec.Type
	// ConnectTimeout is not read by the server code in this file (presumably the
	// client's dial timeout — confirm). HandleTimeout is the limit on handling a
	// single call; zero means no limit.
	ConnectTimeout, HandleTimeout time.Duration
}

// DefaultOption is the package-default Option: gob encoding, a ten-second
// ConnectTimeout, and HandleTimeout left at zero (no limit).
var DefaultOption = &Option{
	CodecType:      codec.GobType,
	ConnectTimeout: 10 * time.Second,
}
// DefaultServer is the package-level *Server behind the Accept, Register and
// HandleHTTP convenience functions.
var DefaultServer = NewServer()

// Accept accepts connections on lis and serves each one with DefaultServer.
func Accept(lis net.Listener) {
	DefaultServer.Accept(lis)
}
// Server dispatches RPC calls to the services registered on it. Its zero value
// is usable; NewServer is the conventional constructor.
type Server struct {
	// serviceMap maps a service name to its *service. sync.Map lets Register
	// and lookups run concurrently without an extra lock.
	serviceMap sync.Map
}

// NewServer returns a Server with no registered services.
func NewServer() *Server {
	return new(Server)
}
// Register wraps rcvr with newService and stores it on server under the
// service's name. If a service with that name is already registered, the
// existing one is kept and an error is returned.
func (server *Server) Register(rcvr interface{}) error {
	svc := newService(rcvr)
	_, alreadyDefined := server.serviceMap.LoadOrStore(svc.name, svc)
	if !alreadyDefined {
		return nil
	}
	return errors.New("rpc: service already defined: " + svc.name)
}

// Register registers rcvr on DefaultServer; see (*Server).Register.
func Register(rcvr interface{}) error {
	return DefaultServer.Register(rcvr)
}
// findService resolves a "Service.Method" string to the registered service and
// its method. The split is made at the last '.', so everything before it is
// treated as the service name. On an unknown method the service is still
// returned alongside the error.
func (server *Server) findService(serviceMethod string) (svc *service, mtype *methodType, err error) {
	sep := strings.LastIndex(serviceMethod, ".")
	if sep == -1 {
		return nil, nil, errors.New("rpc server: service/method request ill-formed: " + serviceMethod)
	}
	serviceName := serviceMethod[:sep]
	methodName := serviceMethod[sep+1:]

	entry, registered := server.serviceMap.Load(serviceName)
	if !registered {
		return nil, nil, errors.New("rpc server: can't find service " + serviceName)
	}
	svc = entry.(*service)

	if mtype = svc.method[methodName]; mtype == nil {
		return svc, nil, errors.New("rpc server: can't find method " + methodName)
	}
	return svc, mtype, nil
}
// Accept accepts connections from lis in a loop and serves each one in its
// own goroutine via ServeConn. It returns after the first Accept error, which
// is logged.
func (server *Server) Accept(lis net.Listener) {
	conn, err := lis.Accept()
	for ; err == nil; conn, err = lis.Accept() {
		go server.ServeConn(conn)
	}
	log.Println("rpc server: accept error:", err)
}
// ServeConn serves RPC on a single connection and blocks until it is done.
//
// It first decodes the client's JSON-encoded Option, then builds the codec
// named by Option.CodecType and serves requests on it. Each call is bounded by
// the HandleTimeout the client sent (zero means no limit). The connection is
// always closed on return.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
	defer func() {
		_ = conn.Close()
	}()

	// The Option is always JSON, whatever codec the rest of the stream uses.
	// NOTE(review): json.Decoder reads ahead; any request bytes it buffers past
	// the Option are lost, because the codec below reads from conn directly.
	// This relies on the client not pipelining requests right behind the
	// Option — confirm against the client.
	var opt Option
	if err := json.NewDecoder(conn).Decode(&opt); err != nil {
		log.Println("rpc server: options error: ", err)
		return
	}
	log.Println("option:", opt)

	f := codec.NewCodecFuncMap[opt.CodecType]
	if f == nil {
		log.Printf("rpc server: invalid codec type %s", opt.CodecType)
		return
	}
	// Honour the per-connection HandleTimeout negotiated by the client, not
	// the package default.
	server.serveCodecTimeout(f(conn), opt.HandleTimeout)
}

// invalidRequest is the placeholder body sent with error responses.
var invalidRequest = struct{}{}

// serveCodec serves requests on cc using DefaultOption.HandleTimeout as the
// per-call limit. It is kept for callers that have no negotiated Option;
// ServeConn uses serveCodecTimeout directly.
func (server *Server) serveCodec(cc codec.Codec) {
	server.serveCodecTimeout(cc, DefaultOption.HandleTimeout)
}

// serveCodecTimeout reads requests from cc until a header cannot be read, and
// handles each request in its own goroutine with the given timeout (zero means
// no limit). It then waits for in-flight handlers and closes cc.
func (server *Server) serveCodecTimeout(cc codec.Codec, timeout time.Duration) {
	sending := new(sync.Mutex) // serializes response writes from concurrent handlers
	wg := new(sync.WaitGroup)
	// A single connection can carry many requests.
	for {
		req, err := server.readRequest(cc)
		if err != nil {
			if req == nil {
				break // header unreadable: not recoverable, so close the connection
			}
			// The header was read, so the client can be told what went wrong.
			req.h.Error = err.Error()
			server.sendResponse(cc, req.h, invalidRequest, sending)
			continue
		}
		wg.Add(1)
		go server.handleRequest(cc, req, sending, wg, timeout)
	}
	wg.Wait()
	_ = cc.Close()
}
// request holds everything needed to execute one RPC method call.
type request struct {
	h      *codec.Header // header as read from the connection
	argv   reflect.Value // decoded call argument
	replyv reflect.Value // value the method writes its reply into
	mtype  *methodType   // method to invoke
	svc    *service      // service that owns mtype
}
// readRequestHeader reads the next request header from cc. io.EOF and
// io.ErrUnexpectedEOF are returned without being logged, since they indicate
// the stream ended; any other read error is logged and returned.
func (server *Server) readRequestHeader(cc codec.Codec) (*codec.Header, error) {
	hdr := new(codec.Header)
	err := cc.ReadHeader(hdr)
	switch err {
	case nil:
		return hdr, nil
	case io.EOF, io.ErrUnexpectedEOF:
		// end of stream: returned silently
	default:
		log.Println("rpc server: read header error:", err)
	}
	return nil, err
}
// readRequest reads one complete RPC request (header, then body) from cc.
//
// It returns (nil, err) when the header cannot be read: the stream is unusable
// and the caller should stop. It returns (req, err) with req.h set when the
// header was read but the call cannot be served, because the service/method is
// unknown or the body cannot be decoded. In that case the caller can reply
// with the error and keep reading.
func (server *Server) readRequest(cc codec.Codec) (*request, error) {
	h, err := server.readRequestHeader(cc)
	if err != nil {
		return nil, err
	}
	log.Println("rpc server header:", h)
	req := &request{h: h}

	req.svc, req.mtype, err = server.findService(h.ServiceMethod)
	if err != nil {
		// The body must still be consumed; otherwise the next header read would
		// start in the middle of this request's body and desync the stream.
		// NOTE(review): assumes every codec's ReadBody discards the value when
		// given nil (true for encoding/gob's Decoder.Decode, as used by net/rpc)
		// — confirm for all codecs.
		_ = cc.ReadBody(nil)
		return req, err
	}

	req.argv = req.mtype.newArgv()
	req.replyv = req.mtype.newReplyv()

	// ReadBody needs a pointer to decode into.
	argvi := req.argv.Interface()
	if req.argv.Type().Kind() != reflect.Pointer {
		argvi = req.argv.Addr().Interface()
	}
	if err = cc.ReadBody(argvi); err != nil {
		log.Println("rpc server: read argv err:", err)
		// Report the failure instead of invoking the method with zero-valued arguments.
		return req, err
	}
	return req, nil
}
// sendResponse writes one response (header plus body) to cc. The sending
// mutex is held for the whole write, so responses from concurrent handlers do
// not interleave on the connection. Write errors are logged, not returned.
func (server *Server) sendResponse(cc codec.Codec, h *codec.Header, body interface{}, sending *sync.Mutex) {
	sending.Lock()
	defer sending.Unlock()
	err := cc.Write(h, body)
	if err == nil {
		return
	}
	log.Println("rpc server: write response error:", err)
}
// handleRequest invokes the method described by req and writes exactly one
// response for it. That response is one of:
//   - the method's reply;
//   - the error the method returned;
//   - a timeout error, when timeout is non-zero and elapses first.
//
// A zero timeout waits indefinitely.
//
// The method runs in its own goroutine so a slow call can be abandoned on
// timeout. That goroutine reports through a buffered channel, so it always
// runs to completion and exits, even after the timeout path has already
// replied. The call itself cannot be cancelled. Only handleRequest writes the
// response, so a late-finishing call never sends a second response or touches
// req.h concurrently.
func (server *Server) handleRequest(cc codec.Codec, req *request, sending *sync.Mutex, wg *sync.WaitGroup, timeout time.Duration) {
	defer wg.Done()

	// Buffered: the send succeeds even if nobody receives after a timeout,
	// so the worker goroutine never leaks.
	done := make(chan error, 1)
	go func() {
		done <- req.svc.call(req.mtype, req.argv, req.replyv)
	}()

	// A nil channel is never ready, so with no timeout the select below just
	// waits for the call to finish.
	var expired <-chan time.Time
	if timeout != 0 {
		timer := time.NewTimer(timeout)
		defer timer.Stop()
		expired = timer.C
	}

	select {
	case err := <-done:
		if err != nil {
			req.h.Error = err.Error()
			server.sendResponse(cc, req.h, invalidRequest, sending)
			return
		}
		server.sendResponse(cc, req.h, req.replyv.Interface(), sending)
	case <-expired:
		req.h.Error = fmt.Sprintf("rpc server: request handle timeout: expect within %s", timeout)
		server.sendResponse(cc, req.h, invalidRequest, sending)
	}
}
// Settings for the HTTP transport.
const (
	// defaultRPCPath is where HandleHTTP mounts the RPC CONNECT handler.
	defaultRPCPath = "/_zprc_"
	// defaultDebugPath is where HandleHTTP mounts the debugHTTP handler.
	defaultDebugPath = "/debug/zrpc"
	// connected is the status text ServeHTTP writes back after a successful CONNECT.
	connected = "200 Connected to Gee RPC"
)
// ServeHTTP implements http.Handler for the RPC endpoint.
//
// A client starts an RPC session with an HTTP CONNECT request. The server then
// takes over the underlying connection via http.Hijacker, writes a status line
// built from connected, and serves RPC on it with ServeConn. Any other method
// gets 405. A ResponseWriter that cannot be hijacked gets 500.
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	if req.Method != http.MethodConnect {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusMethodNotAllowed)
		_, _ = io.WriteString(w, "405 must CONNECT\n")
		return
	}
	// Not every ResponseWriter supports hijacking (HTTP/2 writers do not).
	// An unchecked type assertion would panic here.
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		log.Print("rpc hijacking ", req.RemoteAddr, ": response writer does not support hijacking")
		http.Error(w, "500 connection hijacking not supported", http.StatusInternalServerError)
		return
	}
	// After Hijack the HTTP server no longer manages conn; ServeConn closes it.
	conn, _, err := hijacker.Hijack()
	if err != nil {
		log.Print("rpc hijacking ", req.RemoteAddr, ": ", err.Error())
		return
	}
	// A failed write surfaces as an error inside ServeConn, so it is ignored here.
	_, _ = io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
	server.ServeConn(conn)
}
// HandleHTTP mounts server on http.DefaultServeMux: the RPC CONNECT handler at
// defaultRPCPath and the debugHTTP handler at defaultDebugPath. An HTTP server
// still has to be started separately, typically with http.Serve in a goroutine.
func (server *Server) HandleHTTP() {
	mux := http.DefaultServeMux
	mux.Handle(defaultRPCPath, server)
	mux.Handle(defaultDebugPath, debugHTTP{server})
	log.Println("rpc server debug path:", defaultDebugPath)
}

// HandleHTTP mounts DefaultServer's handlers; see (*Server).HandleHTTP.
func HandleHTTP() {
	DefaultServer.HandleHTTP()
}