Skip to content

Commit

Permalink
dialer: support client-side tcp fastopen
Browse files Browse the repository at this point in the history
Signed-off-by: hexian000 <hexian000@outlook.com>
  • Loading branch information
hexian000 committed Sep 23, 2023
1 parent b7ab27e commit 9be599d
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 77 deletions.
158 changes: 83 additions & 75 deletions src/dialer.c
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,9 @@ dialer_stop(struct dialer *restrict d, struct ev_loop *loop, const bool ok)
case STATE_DONE:
break;
}
if (!ok && d->fd != -1) {
CLOSE_FD(d->fd);
d->fd = -1;
if (!ok && d->w_socket.fd != -1) {
CLOSE_FD(d->w_socket.fd);
ev_io_set(&d->w_socket, -1, EV_NONE);
}
assert(!ev_is_active(&d->w_socket));
d->state = STATE_DONE;
Expand All @@ -269,7 +269,7 @@ static bool
send_req(struct dialer *restrict d, const unsigned char *buf, const size_t len)
{
LOG_BIN_F(LOG_LEVEL_VERBOSE, buf, len, "send: %zu bytes", len);
const ssize_t nsend = send(d->fd, buf, len, 0);
const ssize_t nsend = send(d->w_socket.fd, buf, len, 0);
if (nsend < 0) {
const int err = errno;
LOGE_F("send: %s", strerror(err));
Expand Down Expand Up @@ -322,11 +322,14 @@ send_http_req(struct dialer *restrict d, const struct dialaddr *restrict addr)
b += n;
APPEND(b, " HTTP/1.1\r\n\r\n");

socket_rcvlowat(d->fd, STRLENB("HTTP/2 200 \r\n\r\n"));
if (!send_req(d, (unsigned char *)buf, (size_t)(b - buf))) {
return false;
}
socket_rcvlowat(d->w_socket.fd, STRLENB("HTTP/2 200 \r\n\r\n"));
#undef APPEND
#undef STRLENB
#undef STRLEN
return send_req(d, (unsigned char *)buf, (size_t)(b - buf));
return true;
}

static bool send_socks4a_req(
Expand Down Expand Up @@ -378,8 +381,11 @@ static bool send_socks4a_req(
len += n + 1;
} break;
}
socket_rcvlowat(d->fd, SOCKS4_RSP_MINLEN);
return send_req(d, buf, len);
if (!send_req(d, buf, len)) {
return false;
}
socket_rcvlowat(d->w_socket.fd, SOCKS4_RSP_MINLEN);
return true;
}

static bool send_socks5_auth(
Expand Down Expand Up @@ -447,8 +453,11 @@ send_socks5_req(struct dialer *restrict d, const struct dialaddr *restrict addr)
len += sizeof(uint16_t);
} break;
}
socket_rcvlowat(d->fd, SOCKS5_RSP_MINLEN);
return send_req(d, buf, len);
if (!send_req(d, buf, len)) {
return false;
}
socket_rcvlowat(d->w_socket.fd, SOCKS5_RSP_MINLEN);
return true;
}

static bool send_proxy_req(struct dialer *restrict d)
Expand Down Expand Up @@ -483,7 +492,7 @@ static bool send_proxy_req(struct dialer *restrict d)
static bool consume_rcvbuf(struct dialer *restrict d, const size_t n)
{
LOGV_F("consume_rcvbuf: %zu bytes", n);
const ssize_t nrecv = recv(d->fd, d->buf.data, n, 0);
const ssize_t nrecv = recv(d->w_socket.fd, d->buf.data, n, 0);
if (nrecv < 0) {
const int err = errno;
LOGE_F("recv: %s", strerror(err));
Expand Down Expand Up @@ -669,10 +678,9 @@ recv_dispatch(struct dialer *restrict d, const struct proxy_req *restrict req)
FAIL();
}

static int dialer_recv(
struct dialer *restrict d, const int fd,
const struct proxy_req *restrict req)
static int dialer_recv(struct dialer *restrict d)
{
const int fd = d->w_socket.fd;
const ssize_t nrecv = recv(fd, d->buf.data, d->buf.cap, MSG_PEEK);
if (nrecv < 0) {
const int err = errno;
Expand All @@ -698,11 +706,12 @@ static int dialer_recv(
LOG_BIN_F(
LOG_LEVEL_VERBOSE, d->buf.data, d->buf.len, "recv: %zu bytes",
d->buf.len);
const int ret = recv_dispatch(d, req);

const int ret = recv_dispatch(d, &d->req->proxy[d->jump]);
if (ret < 0) {
return ret;
} else if (ret == 0) {
socket_rcvlowat(d->fd, 1);
socket_rcvlowat(d->w_socket.fd, 1);
return 0;
}
const size_t want = d->buf.len + (size_t)ret;
Expand All @@ -714,64 +723,50 @@ static int dialer_recv(
return 1;
}

static int on_connected(struct dialer *restrict d, const int fd)
{
assert(d->state == STATE_CONNECT);
const struct dialreq *restrict req = d->req;
const int sockerr = socket_get_error(fd);
if (sockerr != 0) {
LOGE_F("connect: %s", strerror(sockerr));
d->syserr = sockerr;
return -1;
}

if (d->jump >= req->num_proxy) {
return 0;
}

d->state = STATE_HANDSHAKE1;
if (!send_proxy_req(d)) {
return -1;
}
return 1;
}

static void socket_cb(struct ev_loop *loop, struct ev_io *watcher, int revents)
{
CHECK_EV_ERROR(revents);
struct dialer *restrict d = watcher->data;
const int fd = watcher->fd;

if (revents & EV_WRITE) {
ev_io_stop(loop, watcher);
const int ret = on_connected(d, fd);
if (ret < 0) {
assert(d->state == STATE_CONNECT);
const int sockerr = socket_get_error(d->w_socket.fd);
if (sockerr != 0) {
LOGE_F("connect: %s", strerror(sockerr));
d->syserr = sockerr;
DIALER_RETURN(d, loop, false);
} else if (ret == 0) {
}
if (d->req->num_proxy == 0) {
DIALER_RETURN(d, loop, true);
}
ev_io_set(watcher, fd, EV_READ);
ev_io_start(loop, watcher);
return;
d->state = STATE_HANDSHAKE1;
if (!send_proxy_req(d)) {
DIALER_RETURN(d, loop, false);
}
modify_io_events(loop, watcher, EV_READ);
}

assert(d->state == STATE_HANDSHAKE1 || d->state == STATE_HANDSHAKE2);
const int ret = dialer_recv(d, fd, &d->req->proxy[d->jump]);
if (ret < 0) {
DIALER_RETURN(d, loop, false);
} else if (ret > 0) {
/* want more data */
return;
}
if (revents & EV_READ) {
assert(d->state == STATE_HANDSHAKE1 ||
d->state == STATE_HANDSHAKE2);
const int ret = dialer_recv(d);
if (ret < 0) {
DIALER_RETURN(d, loop, false);
} else if (ret > 0) {
/* want more data */
return;
}

d->buf.len = 0;
d->jump++;
if (d->jump >= d->req->num_proxy) {
DIALER_RETURN(d, loop, true);
}
d->buf.len = 0;
d->jump++;
if (d->jump >= d->req->num_proxy) {
DIALER_RETURN(d, loop, true);
}

d->state = STATE_HANDSHAKE1;
if (!send_proxy_req(d)) {
DIALER_RETURN(d, loop, false);
d->state = STATE_HANDSHAKE1;
if (!send_proxy_req(d)) {
DIALER_RETURN(d, loop, false);
}
}
}

Expand Down Expand Up @@ -801,27 +796,41 @@ static bool connect_sa(
#endif
socket_set_tcp(fd, conf->tcp_nodelay, conf->tcp_keepalive);
socket_set_buffer(fd, conf->tcp_sndbuf, conf->tcp_rcvbuf);
#if WITH_TCP_FASTOPEN
if (conf->tcp_fastopen) {
socket_set_fastopen_connect(fd, true);
}
#endif
ev_io_set(&d->w_socket, fd, EV_NONE);
if (LOGLEVEL(LOG_LEVEL_VERBOSE)) {
char addr_str[64];
format_sa(sa, addr_str, sizeof(addr_str));
LOG_F(LOG_LEVEL_VERBOSE, "dialer: connect \"%s\"", addr_str);
}
d->state = STATE_CONNECT;
if (connect(fd, sa, getsocklen(sa)) != 0) {
const int err = errno;
if (err != EINTR && err != EINPROGRESS) {
LOGE_F("connect: %s", strerror(err));
CLOSE_FD(fd);
d->syserr = err;
CLOSE_FD(fd);
return false;
}
modify_io_events(loop, &d->w_socket, EV_WRITE);
return true;
}
if (LOGLEVEL(LOG_LEVEL_VERBOSE)) {
char addr_str[64];
format_sa(sa, addr_str, sizeof(addr_str));
LOG_F(LOG_LEVEL_VERBOSE, "dialer: CONNECT \"%s\"", addr_str);
}
d->fd = fd;

struct ev_io *restrict w_socket = &d->w_socket;
ev_io_set(w_socket, fd, EV_WRITE);
ev_io_start(loop, w_socket);
if (d->req->num_proxy == 0) {
modify_io_events(loop, &d->w_socket, EV_WRITE);
return true;
}

d->state = STATE_CONNECT;
d->state = STATE_HANDSHAKE1;
if (!send_proxy_req(d)) {
CLOSE_FD(fd);
return false;
}
modify_io_events(loop, &d->w_socket, EV_READ);
return true;
}

Expand Down Expand Up @@ -876,7 +885,6 @@ void dialer_init(struct dialer *restrict d, const struct event_cb cb)
d->resolve_handle = INVALID_HANDLE;
d->jump = 0;
d->state = STATE_INIT;
d->fd = -1;
d->syserr = 0;
{
struct ev_io *restrict w_socket = &d->w_socket;
Expand Down Expand Up @@ -940,7 +948,7 @@ void dialer_start(
int dialer_get(struct dialer *d)
{
assert(d->state == STATE_DONE);
return d->fd;
return d->w_socket.fd;
}

void dialer_cancel(struct dialer *restrict d, struct ev_loop *loop)
Expand Down
2 changes: 1 addition & 1 deletion src/dialer.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ struct dialer {
handle_t resolve_handle;
size_t jump;
int state;
int fd, syserr;
int syserr;
struct ev_io w_socket;
struct {
BUFFER_HDR;
Expand Down
2 changes: 1 addition & 1 deletion src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ static void print_usage(const char *argv0)
" --reuseport allow multiple instances to listen on the same address\n"
#endif
#if WITH_TCP_FASTOPEN
" --no-fastopen disable server-side TCP fast open (RFC 7413)\n"
" --no-fastopen disable TCP fast open (RFC 7413)\n"
#endif
#if WITH_TPROXY
" --tproxy operate as a transparent proxy\n"
Expand Down
2 changes: 2 additions & 0 deletions src/ruleset.c
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ static int api_setidle_(lua_State *restrict L)

static int luaopen_neosocksd(lua_State *restrict L)
{
lua_newtable(L);
lua_seti(L, LUA_REGISTRYINDEX, RIDX_ASYNC_CALLBACKS);
const luaL_Reg apilib[] = {
{ "invoke", api_invoke_ },
{ "resolve", api_resolve_ },
Expand Down
18 changes: 18 additions & 0 deletions src/sockutil.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,24 @@ void socket_set_fastopen(const int fd, const int backlog)
#endif
}

void socket_set_fastopen_connect(const int fd, const bool enabled)
{
#ifdef TCP_FASTOPEN_CONNECT
int val = enabled ? 1 : 0;
if (setsockopt(
fd, IPPROTO_TCP, TCP_FASTOPEN_CONNECT, &val, sizeof(val))) {
const int err = errno;
LOGW_F("TCP_FASTOPEN_CONNECT: %s", strerror(err));
}
#else
UNUSED(fd);
if (enabled) {
LOGW_F("TCP_FASTOPEN_CONNECT: %s",
"not supported in current build");
}
#endif
}

void socket_set_buffer(const int fd, const size_t send, const size_t recv)
{
int val;
Expand Down
1 change: 1 addition & 0 deletions src/sockutil.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ bool socket_set_nonblock(int fd);
void socket_set_reuseport(int fd, bool reuseport);
void socket_set_tcp(int fd, bool nodelay, bool keepalive);
void socket_set_fastopen(int fd, int backlog);
void socket_set_fastopen_connect(int fd, bool enabled);
void socket_set_buffer(int fd, size_t send, size_t recv);
void socket_bind_netdev(int fd, const char *netdev);
void socket_set_transparent(int fd, bool tproxy);
Expand Down
27 changes: 27 additions & 0 deletions src/util.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,33 @@ void uninit(void)
resolver_atexit_cb();
}

void modify_io_events(
struct ev_loop *loop, struct ev_io *restrict watcher, const int events)
{
const int fd = watcher->fd;
assert(fd != -1);
const int ioevents = events & (EV_READ | EV_WRITE);
if (ioevents == EV_NONE) {
if (ev_is_active(watcher)) {
LOGD_F("io fd=%d stop", fd);
ev_io_stop(loop, watcher);
}
return;
}
if (ioevents != (watcher->events & (EV_READ | EV_WRITE))) {
ev_io_stop(loop, watcher);
#ifdef ev_io_modify
ev_io_modify(watcher, ioevents);
#else
ev_io_set(watcher, fd, ioevents);
#endif
}
if (!ev_is_active(watcher)) {
LOGD_F("io fd=%d events=0x%x", fd, ioevents);
ev_io_start(loop, watcher);
}
}

void drop_privileges(const char *user)
{
if (getuid() != 0) {
Expand Down
3 changes: 3 additions & 0 deletions src/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ typedef uintptr_t handle_t;
} while (0)

struct ev_loop;
struct ev_io;

void modify_io_events(struct ev_loop *loop, struct ev_io *watcher, int events);

struct event_cb {
void (*cb)(struct ev_loop *loop, void *ctx);
Expand Down

0 comments on commit 9be599d

Please sign in to comment.