Skip to content

Commit

Permalink
进行了部分优化
Browse files Browse the repository at this point in the history
  • Loading branch information
侯锐 committed Jan 9, 2019
1 parent 0f66d51 commit bdd59d3
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 99 deletions.
41 changes: 34 additions & 7 deletions core/network/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,36 @@ import (
"net"
)

// 对连接进行包装
type Connection struct {
Conn net.Conn
conn net.Conn
closed bool
}

// 使用已有连接创建自定义连接
func New(conn net.Conn) Connection {
log.Printf("New connection %s -> %s\n", conn.LocalAddr(), conn.RemoteAddr())
return Connection{conn: conn, closed: false}
}

// 创建自定义连接并根据address进行远程连接
func Connect(address string) (Connection, error) {
conn, err := net.Dial("tcp", address)
if err != nil {
log.Println(err)
return Connection{}, err
}
log.Printf("New connection %s -> %s\n", conn.LocalAddr(), conn.RemoteAddr())
return Connection{conn: conn, closed: false}, nil
}

func (c Connection) Read(length uint32) ([]byte, error) {
conn := c.Conn
if c.closed {
return nil, errors.New("Connection has been closed\n")
}

var buf = make([]byte, length)
var bufSize, err = conn.Read(buf)
var bufSize, err = c.conn.Read(buf)
if err != nil {
c.Close()
return nil, err
Expand All @@ -24,8 +45,12 @@ func (c Connection) Read(length uint32) ([]byte, error) {
}

func (c Connection) Write(data []byte) error {
if c.closed {
return errors.New("Connection has been closed\n")
}

// 读操作默认不会超时
_, err := c.Conn.Write(data)
_, err := c.conn.Write(data)
if err != nil {
log.Println(err)
return err
Expand Down Expand Up @@ -74,13 +99,15 @@ func (c Connection) WriteWithLength(source []byte) error {
}

func (c Connection) Close() {
c.Conn.Close()
log.Printf("Close connection %s -> %s\n", c.LocalAddress(), c.RemoteAddress())
c.closed = true
c.conn.Close()
}

func (c Connection) RemoteAddress() string {
return c.Conn.RemoteAddr().String()
return c.conn.RemoteAddr().String()
}

func (c Connection) LocalAddress() string {
return c.Conn.LocalAddr().String()
return c.conn.LocalAddr().String()
}
24 changes: 24 additions & 0 deletions core/network/connections.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package network

import "log"

// 当前所有的连接
var connections = make([]Connection, 50)

func remove(conn Connection) {
for i, c := range connections {
if c == conn {
connections = append(connections[:i], connections[i+1:]...)
return
}
}
log.Printf("%s not in connections\n", conn)
}

func add(conn Connection) {
connections = append(connections, conn)
}

func GetConnections() []Connection {
return connections
}
2 changes: 1 addition & 1 deletion local/assets/html/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

var socket;
if (window.WebSocket) {
socket = new WebSocket('ws://' + window.location.host + '/traffic');
socket = new WebSocket('ws://' + window.location.host + '/ws');
socket.onmessage = function (event) {
console.log('收到信息:' + event.data);
var data = event.data.split(separator);
Expand Down
80 changes: 5 additions & 75 deletions local/http/server.go
Original file line number Diff line number Diff line change
@@ -1,94 +1,24 @@
package http

import (
"fmt"
"github.com/gorilla/websocket"
"github.com/ritterhou/stinger/core/common"
"github.com/ritterhou/stinger/local/resource"
"github.com/ritterhou/stinger/local/socks"
"io"
"log"
"net/http"
"strconv"
"time"
)

var download, upload uint64

// 计算带宽以及流量
func bandwidthTraffic() {
log.Printf("Moniting bandwidth traffic.")

ticker := time.NewTicker(1 * time.Second)
lastDownload := socks.TotalDownload
lastUpload := socks.TotalUpload
for range ticker.C {
t := time.Now()
now := t.Format("2006-01-02 15:04:05")

download = socks.TotalDownload - lastDownload
upload = socks.TotalUpload - lastUpload
if upload != 0 && download != 0 {
fmt.Printf("%s %s ↓ %s ↑", now, common.ByteFormat(download), common.ByteFormat(upload))
fmt.Printf(" (%s ↓ %s ↑)\n", common.ByteFormat(socks.TotalDownload), common.ByteFormat(socks.TotalUpload))
}
lastDownload = socks.TotalDownload
lastUpload = socks.TotalUpload
}
}

var upgrade = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}

// 流量跟踪数据
func traffic(w http.ResponseWriter, req *http.Request) {
conn, err := upgrade.Upgrade(w, req, nil)
if err != nil {
log.Println(err)
return
}

messageType, p, err := conn.ReadMessage()
if err != nil {
log.Println(err)
return
}
separator := string(p)
log.Println("The separator is", separator)

ticker := time.NewTicker(1 * time.Second)
lastDownload := download
lastUpload := upload
for range ticker.C {
if lastDownload != download || lastUpload != upload {
lastDownload = download
lastUpload = upload
message := fmt.Sprintf("%s%s%s", common.ByteFormat(download), separator, common.ByteFormat(upload))
if err := conn.WriteMessage(messageType, []byte(message)); err != nil {
log.Println(err)
conn.Close()
break
}
}
}
log.Println("Stop sending traffic to", conn.RemoteAddr())
}

var indexHtml string

func index(w http.ResponseWriter, req *http.Request) {
io.WriteString(w, indexHtml)
}

func StartServer(port int) {
go bandwidthTraffic()

indexHtml = resource.GetContent("/html/index.html")

// 首页
http.HandleFunc("/", index)
http.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
io.WriteString(w, indexHtml)
})

// PAC文件获取
pacConf := getPac()
Expand All @@ -98,8 +28,8 @@ func StartServer(port int) {
io.WriteString(w, pacConf)
})

// 获取流量以及网速信息
http.HandleFunc("/traffic", traffic)
// WebSocket
http.HandleFunc("/ws", ws)

log.Printf("HTTP Server working on http://0.0.0.0:%d\n", port)
err := http.ListenAndServe("0.0.0.0:"+strconv.Itoa(port), nil)
Expand Down
74 changes: 74 additions & 0 deletions local/http/ws.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package http

import (
"fmt"
"github.com/gorilla/websocket"
"github.com/ritterhou/stinger/core/common"
"github.com/ritterhou/stinger/local/socks"
"log"
"net/http"
"time"
)

var download, upload uint64

// 计算带宽以及流量
func bandwidthTraffic() {
log.Printf("Moniting bandwidth traffic.")

ticker := time.NewTicker(1 * time.Second)
lastDownload := socks.TotalDownload
lastUpload := socks.TotalUpload
for range ticker.C {
t := time.Now()
now := t.Format("2006-01-02 15:04:05")

download = socks.TotalDownload - lastDownload
upload = socks.TotalUpload - lastUpload
if upload != 0 && download != 0 {
fmt.Printf("%s %s ↓ %s ↑", now, common.ByteFormat(download), common.ByteFormat(upload))
fmt.Printf(" (%s ↓ %s ↑)\n", common.ByteFormat(socks.TotalDownload), common.ByteFormat(socks.TotalUpload))
}
lastDownload = socks.TotalDownload
lastUpload = socks.TotalUpload
}
}

var upgrade = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}

// 与网页端进行WebSocket连接
func ws(w http.ResponseWriter, req *http.Request) {
conn, err := upgrade.Upgrade(w, req, nil)
if err != nil {
log.Println(err)
return
}

messageType, p, err := conn.ReadMessage()
if err != nil {
log.Println(err)
return
}
separator := string(p)
log.Println("The separator is", separator)

ticker := time.NewTicker(1 * time.Second)
lastDownload := download
lastUpload := upload
for range ticker.C {
if lastDownload != download || lastUpload != upload {
lastDownload = download
lastUpload = upload
message := fmt.Sprintf("%s%s%s", common.ByteFormat(download), separator, common.ByteFormat(upload))
if err := conn.WriteMessage(messageType, []byte(message)); err != nil {
log.Println(err)
conn.Close()
break
}
}
}
log.Println("Stop sending traffic to", conn.RemoteAddr())
}
2 changes: 1 addition & 1 deletion local/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func startProxyServer(proxyPort int) {
log.Println("Error accepting:", err)
continue
}
go handlerSocks5(network.Connection{Conn: conn})
go handlerSocks5(network.New(conn))
}
}

Expand Down
13 changes: 6 additions & 7 deletions local/socks/socks5.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/ritterhou/stinger/core/common"
"github.com/ritterhou/stinger/core/network"
"log"
"net"
"strconv"
"sync/atomic"
)
Expand Down Expand Up @@ -110,8 +109,9 @@ func ConnectRemote(conn network.Connection, remoteServer string, password string
port := binary.BigEndian.Uint16(portBytes)
// 构建最终目标的地址
targetAddr := host + ":" + strconv.Itoa(int(port))

// 尝试连接到远程主机
c, err := net.Dial("tcp", remoteServer)
serverConn, err := network.Connect(remoteServer)
if err != nil {
err = conn.Write([]byte{5, 3, 0, 1, 0, 0, 0, 0, 0, 0})
if err != nil {
Expand All @@ -120,7 +120,6 @@ func ConnectRemote(conn network.Connection, remoteServer string, password string
}
return network.Connection{}, errors.New("can't connect to remote server " + remoteServer)
}
serverConn := network.Connection{Conn: c}
// 发送密码进行验证
err = serverConn.WriteWithLength([]byte(password))
if err != nil {
Expand Down Expand Up @@ -178,7 +177,7 @@ func HandlerSocks5Data(localConn network.Connection, remoteConn network.Connecti
// 浏览器 -> local
buf, err := localConn.Read(1024)
if err != nil {
log.Println(localConn.RemoteAddress() + " -> " + err.Error())
//log.Println(localConn.RemoteAddress() + " -> " + err.Error())
remoteConn.Close()
break
}
Expand All @@ -189,7 +188,7 @@ func HandlerSocks5Data(localConn network.Connection, remoteConn network.Connecti
// local -> server
err = remoteConn.WriteWithLength(buf)
if err != nil {
log.Println(remoteConn.RemoteAddress() + " -> " + err.Error())
//log.Println(remoteConn.RemoteAddress() + " -> " + err.Error())
localConn.Close()
break
}
Expand All @@ -201,7 +200,7 @@ func HandlerSocks5Data(localConn network.Connection, remoteConn network.Connecti
// server -> local
buf, err := remoteConn.ReadWithLength()
if err != nil {
log.Println(remoteConn.RemoteAddress() + " -> " + err.Error())
//log.Println(remoteConn.RemoteAddress() + " -> " + err.Error())
localConn.Close()
break
}
Expand All @@ -212,7 +211,7 @@ func HandlerSocks5Data(localConn network.Connection, remoteConn network.Connecti
// local -> 浏览器
err = localConn.Write(buf)
if err != nil {
log.Println(localConn.RemoteAddress() + " -> " + err.Error())
//log.Println(localConn.RemoteAddress() + " -> " + err.Error())
remoteConn.Close()
break
}
Expand Down
Loading

0 comments on commit bdd59d3

Please sign in to comment.