diff --git a/core/network/connection.go b/core/network/connection.go index a210f72..502f1f8 100644 --- a/core/network/connection.go +++ b/core/network/connection.go @@ -7,15 +7,36 @@ import ( "net" ) +// 对连接进行包装 type Connection struct { - Conn net.Conn + conn net.Conn + closed bool +} + +// 使用已有连接创建自定义连接 +func New(conn net.Conn) Connection { + log.Printf("New connection %s -> %s\n", conn.LocalAddr(), conn.RemoteAddr()) + return Connection{conn: conn, closed: false} +} + +// 创建自定义连接并根据address进行远程连接 +func Connect(address string) (Connection, error) { + conn, err := net.Dial("tcp", address) + if err != nil { + log.Println(err) + return Connection{}, err + } + log.Printf("New connection %s -> %s\n", conn.LocalAddr(), conn.RemoteAddr()) + return Connection{conn: conn, closed: false}, nil } func (c Connection) Read(length uint32) ([]byte, error) { - conn := c.Conn + if c.closed { + return nil, errors.New("Connection has been closed\n") + } var buf = make([]byte, length) - var bufSize, err = conn.Read(buf) + var bufSize, err = c.conn.Read(buf) if err != nil { c.Close() return nil, err @@ -24,8 +45,12 @@ func (c Connection) Read(length uint32) ([]byte, error) { } func (c Connection) Write(data []byte) error { + if c.closed { + return errors.New("Connection has been closed\n") + } + // 读操作默认不会超时 - _, err := c.Conn.Write(data) + _, err := c.conn.Write(data) if err != nil { log.Println(err) return err @@ -74,13 +99,15 @@ func (c Connection) WriteWithLength(source []byte) error { } func (c Connection) Close() { - c.Conn.Close() + log.Printf("Close connection %s -> %s\n", c.LocalAddress(), c.RemoteAddress()) + c.closed = true + c.conn.Close() } func (c Connection) RemoteAddress() string { - return c.Conn.RemoteAddr().String() + return c.conn.RemoteAddr().String() } func (c Connection) LocalAddress() string { - return c.Conn.LocalAddr().String() + return c.conn.LocalAddr().String() } diff --git a/core/network/connections.go b/core/network/connections.go new file mode 100644 index 0000000..249c987 --- /dev/null +++ b/core/network/connections.go @@ -0,0 +1,24 @@ +package network + +import "log" + +// 当前所有的连接 +var connections = make([]Connection, 50) + +func remove(conn Connection) { + for i, c := range connections { + if c == conn { + connections = append(connections[:i], connections[i+1:]...) + return + } + } + log.Printf("%s not in connections\n", conn) +} + +func add(conn Connection) { + connections = append(connections, conn) +} + +func GetConnections() []Connection { + return connections +} diff --git a/local/assets/html/index.html b/local/assets/html/index.html index ba0a5b6..9e1839f 100644 --- a/local/assets/html/index.html +++ b/local/assets/html/index.html @@ -26,7 +26,7 @@ var socket; if (window.WebSocket) { - socket = new WebSocket('ws://' + window.location.host + '/traffic'); + socket = new WebSocket('ws://' + window.location.host + '/ws'); socket.onmessage = function (event) { console.log('收到信息:' + event.data); var data = event.data.split(separator); diff --git a/local/http/server.go b/local/http/server.go index ae27de8..f8745d6 100644 --- a/local/http/server.go +++ b/local/http/server.go @@ -1,94 +1,24 @@ package http import ( - "fmt" - "github.com/gorilla/websocket" - "github.com/ritterhou/stinger/core/common" "github.com/ritterhou/stinger/local/resource" - "github.com/ritterhou/stinger/local/socks" "io" "log" "net/http" "strconv" - "time" ) -var download, upload uint64 - -// 计算带宽以及流量 -func bandwidthTraffic() { - log.Printf("Moniting bandwidth traffic.") - - ticker := time.NewTicker(1 * time.Second) - lastDownload := socks.TotalDownload - lastUpload := socks.TotalUpload - for range ticker.C { - t := time.Now() - now := t.Format("2006-01-02 15:04:05") - - download = socks.TotalDownload - lastDownload - upload = socks.TotalUpload - lastUpload - if upload != 0 && download != 0 { - fmt.Printf("%s %s ↓ %s ↑", now, common.ByteFormat(download), common.ByteFormat(upload)) - fmt.Printf(" (%s ↓ %s ↑)\n", common.ByteFormat(socks.TotalDownload), common.ByteFormat(socks.TotalUpload)) - } - lastDownload = socks.TotalDownload - lastUpload = socks.TotalUpload - } -} - -var upgrade = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, -} - -// 流量跟踪数据 -func traffic(w http.ResponseWriter, req *http.Request) { - conn, err := upgrade.Upgrade(w, req, nil) - if err != nil { - log.Println(err) - return - } - - messageType, p, err := conn.ReadMessage() - if err != nil { - log.Println(err) - return - } - separator := string(p) - log.Println("The separator is", separator) - - ticker := time.NewTicker(1 * time.Second) - lastDownload := download - lastUpload := upload - for range ticker.C { - if lastDownload != download || lastUpload != upload { - lastDownload = download - lastUpload = upload - message := fmt.Sprintf("%s%s%s", common.ByteFormat(download), separator, common.ByteFormat(upload)) - if err := conn.WriteMessage(messageType, []byte(message)); err != nil { - log.Println(err) - conn.Close() - break - } - } - } - log.Println("Stop sending traffic to", conn.RemoteAddr()) -} - var indexHtml string -func index(w http.ResponseWriter, req *http.Request) { - io.WriteString(w, indexHtml) -} - func StartServer(port int) { go bandwidthTraffic() indexHtml = resource.GetContent("/html/index.html") // 首页 - http.HandleFunc("/", index) + http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { + io.WriteString(w, indexHtml) + }) // PAC文件获取 pacConf := getPac() @@ -98,8 +28,8 @@ func StartServer(port int) { io.WriteString(w, pacConf) }) - // 获取流量以及网速信息 - http.HandleFunc("/traffic", traffic) + // WebSocket + http.HandleFunc("/ws", ws) log.Printf("HTTP Server working on http://0.0.0.0:%d\n", port) err := http.ListenAndServe("0.0.0.0:"+strconv.Itoa(port), nil) diff --git a/local/http/ws.go b/local/http/ws.go new file mode 100644 index 0000000..096eb20 --- /dev/null +++ b/local/http/ws.go @@ -0,0 +1,74 @@ +package http + +import ( + "fmt" + "github.com/gorilla/websocket" + "github.com/ritterhou/stinger/core/common" + "github.com/ritterhou/stinger/local/socks" + "log" + "net/http" + "time" +) + +var download, upload uint64 + +// 计算带宽以及流量 +func bandwidthTraffic() { + log.Printf("Moniting bandwidth traffic.") + + ticker := time.NewTicker(1 * time.Second) + lastDownload := socks.TotalDownload + lastUpload := socks.TotalUpload + for range ticker.C { + t := time.Now() + now := t.Format("2006-01-02 15:04:05") + + download = socks.TotalDownload - lastDownload + upload = socks.TotalUpload - lastUpload + if upload != 0 && download != 0 { + fmt.Printf("%s %s ↓ %s ↑", now, common.ByteFormat(download), common.ByteFormat(upload)) + fmt.Printf(" (%s ↓ %s ↑)\n", common.ByteFormat(socks.TotalDownload), common.ByteFormat(socks.TotalUpload)) + } + lastDownload = socks.TotalDownload + lastUpload = socks.TotalUpload + } +} + +var upgrade = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, +} + +// 与网页端进行WebSocket连接 +func ws(w http.ResponseWriter, req *http.Request) { + conn, err := upgrade.Upgrade(w, req, nil) + if err != nil { + log.Println(err) + return + } + + messageType, p, err := conn.ReadMessage() + if err != nil { + log.Println(err) + return + } + separator := string(p) + log.Println("The separator is", separator) + + ticker := time.NewTicker(1 * time.Second) + lastDownload := download + lastUpload := upload + for range ticker.C { + if lastDownload != download || lastUpload != upload { + lastDownload = download + lastUpload = upload + message := fmt.Sprintf("%s%s%s", common.ByteFormat(download), separator, common.ByteFormat(upload)) + if err := conn.WriteMessage(messageType, []byte(message)); err != nil { + log.Println(err) + conn.Close() + break + } + } + } + log.Println("Stop sending traffic to", conn.RemoteAddr()) +} diff --git a/local/main.go b/local/main.go index 0a0b49c..e3292a3 100644 --- a/local/main.go +++ b/local/main.go @@ -75,7 +75,7 @@ func startProxyServer(proxyPort int) { log.Println("Error accepting:", err) continue } - go handlerSocks5(network.Connection{Conn: conn}) + go handlerSocks5(network.New(conn)) } } diff --git a/local/socks/socks5.go b/local/socks/socks5.go index 24fa4b5..480947f 100644 --- a/local/socks/socks5.go +++ b/local/socks/socks5.go @@ -8,7 +8,6 @@ import ( "github.com/ritterhou/stinger/core/common" "github.com/ritterhou/stinger/core/network" "log" - "net" "strconv" "sync/atomic" ) @@ -110,8 +109,9 @@ func ConnectRemote(conn network.Connection, remoteServer string, password string port := binary.BigEndian.Uint16(portBytes) // 构建最终目标的地址 targetAddr := host + ":" + strconv.Itoa(int(port)) + // 尝试连接到远程主机 - c, err := net.Dial("tcp", remoteServer) + serverConn, err := network.Connect(remoteServer) if err != nil { err = conn.Write([]byte{5, 3, 0, 1, 0, 0, 0, 0, 0, 0}) if err != nil { @@ -120,7 +120,6 @@ func ConnectRemote(conn network.Connection, remoteServer string, password string } return network.Connection{}, errors.New("can't connect to remote server " + remoteServer) } - serverConn := network.Connection{Conn: c} // 发送密码进行验证 err = serverConn.WriteWithLength([]byte(password)) if err != nil { @@ -178,7 +177,7 @@ func HandlerSocks5Data(localConn network.Connection, remoteConn network.Connecti // 浏览器 -> local buf, err := localConn.Read(1024) if err != nil { - log.Println(localConn.RemoteAddress() + " -> " + err.Error()) + //log.Println(localConn.RemoteAddress() + " -> " + err.Error()) remoteConn.Close() break } @@ -189,7 +188,7 @@ func HandlerSocks5Data(localConn network.Connection, remoteConn network.Connecti // local -> server err = remoteConn.WriteWithLength(buf) if err != nil { - log.Println(remoteConn.RemoteAddress() + " -> " + err.Error()) + //log.Println(remoteConn.RemoteAddress() + " -> " + err.Error()) localConn.Close() break } @@ -201,7 +200,7 @@ func HandlerSocks5Data(localConn network.Connection, remoteConn network.Connecti // server -> local buf, err := remoteConn.ReadWithLength() if err != nil { - log.Println(remoteConn.RemoteAddress() + " -> " + err.Error()) + //log.Println(remoteConn.RemoteAddress() + " -> " + err.Error()) localConn.Close() break } @@ -212,7 +211,7 @@ func HandlerSocks5Data(localConn network.Connection, remoteConn network.Connecti // local -> 浏览器 err = localConn.Write(buf) if err != nil { - log.Println(localConn.RemoteAddress() + " -> " + err.Error()) + //log.Println(localConn.RemoteAddress() + " -> " + err.Error()) remoteConn.Close() break } diff --git a/server/main.go b/server/main.go index f2d3bb9..aeb779b 100644 --- a/server/main.go +++ b/server/main.go @@ -63,8 +63,7 @@ func startProxyServer(proxyPort int) { } //log.Printf("Connection established %s -> %s \n", conn.RemoteAddr(), conn.LocalAddr()) - c := network.Connection{Conn: conn} - go handlerClient(c) + go handlerClient(network.New(conn)) } } @@ -99,7 +98,7 @@ func handlerClient(localConn network.Connection) { targetAddr := string(targetAddrBytes) //log.Println(targetAddr) - c, err := net.Dial("tcp", targetAddr) + remoteConn, err := network.Connect(targetAddr) if err != nil { log.Println("can't connect to target address", targetAddr) err = localConn.Write([]byte{1}) // 远程主机连接失败 @@ -116,14 +115,13 @@ func handlerClient(localConn network.Connection) { log.Println(err) return } - remoteConn := network.Connection{Conn: c} go func() { for { // local -> server buf, err := localConn.ReadWithLength() if err != nil { - log.Println(localConn.RemoteAddress() + " -> " + err.Error()) + //log.Println(localConn.RemoteAddress() + " -> " + err.Error()) remoteConn.Close() break } @@ -131,7 +129,7 @@ func handlerClient(localConn network.Connection) { // server -> remote err = remoteConn.Write(buf) if err != nil { - log.Println(remoteConn.RemoteAddress() + " -> " + err.Error()) + //log.Println(remoteConn.RemoteAddress() + " -> " + err.Error()) localConn.Close() break } @@ -143,7 +141,7 @@ func handlerClient(localConn network.Connection) { // remote -> server buf, err := remoteConn.Read(1024) if err != nil { - log.Println(remoteConn.RemoteAddress() + " -> " + err.Error()) + //log.Println(remoteConn.RemoteAddress() + " -> " + err.Error()) localConn.Close() break } @@ -151,7 +149,7 @@ func handlerClient(localConn network.Connection) { // server -> local err = localConn.WriteWithLength(buf) if err != nil { - log.Println(localConn.RemoteAddress() + " -> " + err.Error()) + //log.Println(localConn.RemoteAddress() + " -> " + err.Error()) remoteConn.Close() break }