-
Notifications
You must be signed in to change notification settings - Fork 25
/
websocket.go
133 lines (106 loc) · 2.85 KB
/
websocket.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
package azuretls
import (
"context"
"errors"
"fmt"
http "github.com/Noooste/fhttp"
"github.com/Noooste/websocket"
"net"
url2 "net/url"
)
// ErrNilRequest is the sentinel error whose message is "request is nil".
var ErrNilRequest = errors.New("request is nil")
// Websocket bundles a websocket connection with the request and the handshake
// response that produced it. The embedded *websocket.Conn promotes the
// connection's methods onto Websocket.
type Websocket struct {
// Url is the URL the connection was dialed with.
Url string
// Headers is a header map attached to the connection.
// NOTE(review): what this map is meant to contain is not evident from the
// type alone — confirm against NewWebsocketWithContext.
Headers http.Header
// Request is the session request that was prepared for the handshake.
Request *Request
// Response is the HTTP response returned by the dialer for the handshake.
Response *http.Response
// dialer is the websocket dialer associated with this connection.
dialer *websocket.Dialer
// Conn is the embedded underlying websocket connection.
*websocket.Conn
}
// NewWebsocket opens a websocket connection to url by delegating to
// NewWebsocketWithContext with the session's own context (s.ctx). Every other
// argument is passed through unchanged.
func (s *Session) NewWebsocket(url string, readBufferSize, writeBufferSize int, args ...any) (*Websocket, error) {
	sessionCtx := s.ctx
	return s.NewWebsocketWithContext(sessionCtx, url, readBufferSize, writeBufferSize, args...)
}
// NewWebsocketWithContext opens a websocket connection to url, using ctx both
// to build the handshake request and to perform the dial.
//
// Parameters:
//   - ctx: context handed to buildRequest and to the dialer's DialContext.
//   - url: target URL; must be non-empty.
//   - readBufferSize, writeBufferSize: I/O buffer sizes given to the dialer;
//     non-positive values fall back to 1024.
//   - args: forwarded unchanged to the session's prepareRequest.
//
// Unless the prepared request sets NoCookie, cookies the session's CookieJar
// holds for the URL are appended to any Cookie header already on the request.
// The request is forced onto HTTP/1 before initConn is called.
//
// It returns an error when url is empty, or when prepareRequest, URL parsing,
// buildRequest, initConn or the dial itself fails; those errors are returned
// as-is so callers can keep comparing against the underlying values.
func (s *Session) NewWebsocketWithContext(ctx context.Context, url string, readBufferSize, writeBufferSize int, args ...any) (*Websocket, error) {
	if url == "" {
		return nil, errors.New("url is empty")
	}

	if readBufferSize <= 0 {
		readBufferSize = 1024
	}
	if writeBufferSize <= 0 {
		writeBufferSize = 1024
	}

	// A freshly allocated request can never be nil, so no nil check is needed.
	req := &Request{Url: url}

	if err := s.prepareRequest(req, args...); err != nil {
		return nil, err
	}

	req.HttpRequest = &http.Request{}

	var err error
	if req.parsedUrl, err = url2.Parse(req.Url); err != nil {
		return nil, err
	}

	if err = s.buildRequest(ctx, req); err != nil {
		return nil, err
	}

	if !req.NoCookie {
		// len of a nil slice is 0, so a separate nil check is unnecessary.
		if cookies := s.CookieJar.Cookies(req.parsedUrl); len(cookies) > 0 {
			cookieHeader := CookiesToString(cookies)
			// Keep cookies the caller already put on the request and append
			// the jar's cookies after them.
			if existing := req.HttpRequest.Header.Get("Cookie"); existing != "" {
				cookieHeader = existing + "; " + cookieHeader
			}
			req.HttpRequest.Header.Set("Cookie", cookieHeader)
		}
	}

	req.ForceHTTP1 = true

	// The dial hooks below only look connections up in s.Connections.hosts;
	// presumably initConn is what registers the entry for this host — confirm.
	if _, err = s.initConn(req); err != nil {
		return nil, err
	}

	dialer := &websocket.Dialer{
		HandshakeTimeout:  s.TimeOut,
		ReadBufferSize:    readBufferSize,
		WriteBufferSize:   writeBufferSize,
		EnableCompression: true,
		Jar:               s.CookieJar,
	}

	// Neither hook opens a socket: each hands back the connection already
	// stored for addr in the session's connection table.
	// NOTE(review): if rc.TLS is a nil pointer of a concrete type, the returned
	// net.Conn would be a non-nil interface wrapping nil — confirm it is always set.
	dialer.NetDialTLSContext = func(_ context.Context, _, addr string) (net.Conn, error) {
		s.Connections.mu.RLock()
		defer s.Connections.mu.RUnlock()

		if rc, ok := s.Connections.hosts[addr]; ok {
			return rc.TLS, nil
		}
		return nil, fmt.Errorf("no connection for %s", addr)
	}

	dialer.NetDialContext = func(_ context.Context, _, addr string) (net.Conn, error) {
		s.Connections.mu.RLock()
		defer s.Connections.mu.RUnlock()

		if rc, ok := s.Connections.hosts[addr]; ok {
			return rc.Conn, nil
		}
		return nil, fmt.Errorf("no connection for %s", addr)
	}

	conn, resp, err := dialer.DialContext(ctx, req.Url, req.HttpRequest.Header, req.HttpRequest.Header[http.HeaderOrderKey])
	if err != nil {
		return nil, err
	}

	return &Websocket{
		Url: req.Url,
		// NOTE(review): Headers has always been returned as an empty map;
		// confirm whether it should mirror the handshake request headers.
		Headers:  make(http.Header),
		Request:  req,
		Response: resp,
		// Keep the configured dialer on the returned value; previously it was
		// set on a scratch Websocket that was then discarded.
		dialer: dialer,
		Conn:   conn,
	}, nil
}