diff --git a/src/dialer.c b/src/dialer.c index f745440..c27c66c 100644 --- a/src/dialer.c +++ b/src/dialer.c @@ -120,28 +120,29 @@ int dialaddr_format( FAIL(); } -struct dialreq *dialreq_new(const size_t max_proxy) +struct dialreq *dialreq_new(const size_t num_proxy) { struct dialreq *restrict req = malloc( - sizeof(struct dialreq) + sizeof(struct proxy_req) * max_proxy); + sizeof(struct dialreq) + sizeof(struct proxy_req) * num_proxy); if (req == NULL) { return NULL; } - req->num_proxy = 0; + req->num_proxy = num_proxy; return req; } -bool dialreq_proxy( - struct dialreq *restrict req, const char *proxy_uri, size_t len) +bool dialreq_setproxy( + struct dialreq *restrict req, const size_t i, const char *proxy_uri, + const size_t urilen) { /* should be more than enough */ - if (len >= 1024) { + if (urilen >= 1024) { LOGE_F("proxy uri is too long: \"%s\"", proxy_uri); return false; } - char buf[len + 1]; - memcpy(buf, proxy_uri, len); - buf[len] = '\0'; + char buf[urilen + 1]; + memcpy(buf, proxy_uri, urilen); + buf[urilen] = '\0'; struct url uri; if (!url_parse(buf, &uri)) { LOGE_F("unable to parse uri: \"%s\"", proxy_uri); @@ -163,14 +164,12 @@ bool dialreq_proxy( LOGE_F("dialer: unknown scheme \"%s\"", uri.scheme); return false; } - const size_t n = req->num_proxy; - struct proxy_req *restrict proxy = &req->proxy[n]; + struct proxy_req *restrict proxy = &req->proxy[i]; proxy->proto = protocol; const size_t hostlen = strlen(uri.host); if (!dialaddr_set(&proxy->addr, uri.host, hostlen)) { return false; } - req->num_proxy = n + 1; return true; } @@ -179,10 +178,10 @@ struct dialreq *dialreq_parse(const char *csv) const size_t len = strlen(csv); char buf[len + 1]; memcpy(buf, csv, len + 1); - size_t n = 1; + size_t n = 0; for (size_t i = 0; i < len; i++) { if (buf[i] == ',') { - n++; + ++n; } } struct dialreq *req = dialreq_new(n); @@ -193,6 +192,7 @@ struct dialreq *dialreq_parse(const char *csv) bool direct = true; for (char *tok = strtok(buf, ","); tok != NULL; tok = strtok(NULL, ",")) { + LOGE_F("tok: %s", tok); if (direct) { if (!dialaddr_set(&req->addr, tok, strlen(tok))) { dialreq_free(req); @@ -201,7 +201,7 @@ struct dialreq *dialreq_parse(const char *csv) direct = false; continue; } - if (!dialreq_proxy(req, tok, strlen(tok))) { + if (!dialreq_setproxy(req, --n, tok, strlen(tok))) { dialreq_free(req); return NULL; } @@ -560,12 +560,13 @@ static int recv_socks4a_rsp(struct dialer *restrict d) static int recv_socks5_rsp(struct dialer *restrict d) { assert(d->state == STATE_HANDSHAKE2); - const size_t rsplen = sizeof(struct socks5_hdr) + - sizeof(struct in6_addr) + sizeof(in_port_t); - if (d->buf.len < rsplen) { - return (int)(rsplen - d->buf.len); - } const unsigned char *hdr = d->buf.data; + const size_t len = d->buf.len; + size_t expected = sizeof(struct socks5_hdr); + if (len < expected) { + return (int)(expected - len) + 1; + } + const uint8_t version = read_uint8(hdr + offsetof(struct socks5_hdr, version)); if (version != SOCKS5) { @@ -578,8 +579,24 @@ static int recv_socks5_rsp(struct dialer *restrict d) LOGE_F("SOCKS5 failure: %" PRIu8, command); return -1; } + const uint8_t addrtype = + read_uint8(hdr + offsetof(struct socks5_hdr, addrtype)); + switch (addrtype) { + case SOCKS5ADDR_IPV4: + expected += sizeof(struct in_addr) + sizeof(in_port_t); + break; + case SOCKS5ADDR_IPV6: + expected += sizeof(struct in6_addr) + sizeof(in_port_t); + break; + default: + LOGE_F("SOCKS5 invalid addrtype: %" PRIu8, addrtype); + return -1; + } + if (len < expected) { + return (int)(expected - len); + } /* protocol finished, remove header */ - if (!consume_rcvbuf(d, rsplen)) { + if (!consume_rcvbuf(d, expected)) { return -1; } return 0; @@ -609,7 +626,11 @@ static int recv_socks5_auth(struct dialer *restrict d) if (!consume_rcvbuf(d, rsplen)) { return -1; } + BUF_CONSUME(d->buf, rsplen); d->state = STATE_HANDSHAKE2; + if (!send_proxy_req(d)) { + return -1; + } return recv_socks5_rsp(d); } @@ -666,18 +687,19 @@ static int dialer_recv( LOG_BIN_F( LOG_LEVEL_VERBOSE, d->buf.data, d->buf.len, "recv: %zu bytes", d->buf.len); - const int want = recv_dispatch(d, req); - if (want < 0) { - return want; - } else if (want == 0) { + const int ret = recv_dispatch(d, req); + if (ret < 0) { + return ret; + } else if (ret == 0) { socket_rcvlowat(d->fd, 1); return 0; } - if (d->buf.len + (size_t)want > d->buf.cap) { + const size_t want = d->buf.len + (size_t)ret; + if (want > d->buf.cap) { LOGE("recv: header too long"); return -1; } - socket_rcvlowat(fd, (size_t)nrecv + (size_t)want); + socket_rcvlowat(fd, want); return 1; } diff --git a/src/dialer.h b/src/dialer.h index 9eb5484..b28bde3 100644 --- a/src/dialer.h +++ b/src/dialer.h @@ -52,8 +52,9 @@ struct dialreq { struct proxy_req proxy[]; }; -struct dialreq *dialreq_new(size_t max_proxy); -bool dialreq_proxy(struct dialreq *r, const char *proxy_uri, size_t len); +struct dialreq *dialreq_new(size_t num_proxy); +bool dialreq_setproxy( + struct dialreq *r, size_t i, const char *proxy_uri, size_t urilen); struct dialreq *dialreq_parse(const char *csv); void dialreq_free(struct dialreq *r); diff --git a/src/ruleset.c b/src/ruleset.c index 1a14e88..c02709d 100644 --- a/src/ruleset.c +++ b/src/ruleset.c @@ -72,7 +72,7 @@ static struct dialreq *pop_dialreq(lua_State *restrict L, int n) LOGE_F("ruleset: returned address #%d is not a string", 1); return NULL; } - struct dialreq *req = dialreq_new((size_t)n); + struct dialreq *req = dialreq_new((size_t)(n - 1)); if (req == NULL) { LOGOOM(); return NULL; @@ -90,7 +90,7 @@ static struct dialreq *pop_dialreq(lua_State *restrict L, int n) dialreq_free(req); return NULL; } - if (!dialreq_proxy(req, s, len)) { + if (!dialreq_setproxy(req, i - 1, s, len)) { LOGE_F("ruleset: returned address #%d is not valid", n - i); dialreq_free(req);