Skip to content

Commit

Permalink
Merge pull request #2 from mhils/pr-5435
Browse files Browse the repository at this point in the history
Pr 5435
  • Loading branch information
meitinger authored Oct 24, 2022
2 parents 995ff02 + 7b1d188 commit 98ce1c2
Show file tree
Hide file tree
Showing 16 changed files with 129 additions and 83 deletions.
2 changes: 2 additions & 0 deletions mitmproxy/proxy/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ class CloseConnection(ConnectionCommand):
all other connections will ultimately be closed during cleanup.
"""


class CloseTcpConnection(CloseConnection):
half_close: bool
"""
If True, only close our half of the connection by sending a FIN packet.
Expand Down
8 changes: 6 additions & 2 deletions mitmproxy/proxy/layers/http/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from ._http1 import Http1Client, Http1Connection, Http1Server
from ._http2 import Http2Client, Http2Server
from ._http3 import Http3Client, Http3Server
from ..quic import QuicStreamEvent
from ...context import Context
from ...mode_specs import ReverseMode, UpstreamMode

Expand Down Expand Up @@ -790,7 +791,8 @@ def passthrough(self, event: events.Event) -> layer.CommandGenerator[None]:
# The easiest approach for this is to just always full close for now.
# Alternatively, we could signal that we want a half close only through ResponseProtocolError,
# but that is more complex to implement.
command.half_close = False
if isinstance(command, commands.CloseTcpConnection):
command = commands.CloseConnection(command.connection)
yield command
else:
yield command
Expand Down Expand Up @@ -886,7 +888,7 @@ def _handle_event(self, event: events.Event):
if isinstance(event, events.ConnectionClosed):
# The peer has closed it - let's close it too!
yield commands.CloseConnection(event.connection)
else:
elif isinstance(event, (events.DataReceived, QuicStreamEvent)):
# The peer has sent data or another connection activity occurred.
# This can happen with HTTP/2 servers that already send a settings frame.
child_layer: HttpConnection
Expand All @@ -899,6 +901,8 @@ def _handle_event(self, event: events.Event):
self.connections[self.context.server] = child_layer
yield from self.event_to_child(child_layer, events.Start())
yield from self.event_to_child(child_layer, event)
else:
raise AssertionError(f"Unexpected event: {event}")
else:
handler = self.connections[event.connection]
yield from self.event_to_child(handler, event)
Expand Down
2 changes: 1 addition & 1 deletion mitmproxy/proxy/layers/http/_http1.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def send(self, event: HttpEvent) -> layer.CommandGenerator[None]:
if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
yield commands.SendData(self.conn, b"0\r\n\r\n")
elif http1.expected_http_body_size(self.request, self.response) == -1:
yield commands.CloseConnection(self.conn, half_close=True)
yield commands.CloseTcpConnection(self.conn, half_close=True)
yield from self.mark_done(request=True)
else:
raise AssertionError(f"Unexpected event: {event}")
Expand Down
12 changes: 5 additions & 7 deletions mitmproxy/proxy/layers/http/_http3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from mitmproxy.net.http import status_codes
from mitmproxy.proxy import commands, context, events, layer
from mitmproxy.proxy.layers.quic import (
QuicConnectionClosed,
QuicStreamEvent,
error_code_to_str,
get_connection_error,
)
from mitmproxy.proxy.utils import expect

Expand Down Expand Up @@ -164,12 +164,10 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
# report a protocol error for all remaining open streams when a connection is closed
elif isinstance(event, events.ConnectionClosed):
self._handle_event = self.done # type: ignore
close_event = get_connection_error(self.conn)
msg = (
"peer closed connection"
if close_event is None else
close_event.reason_phrase or error_code_to_str(close_event.error_code)
)
if isinstance(event, QuicConnectionClosed):
msg = event.reason_phrase or error_code_to_str(event.error_code)
else:
msg = "peer closed connection"
for stream_id in self.h3_conn.get_reserved_stream_ids():
yield ReceiveHttp(self.ReceiveProtocolError(stream_id, msg))

Expand Down
13 changes: 5 additions & 8 deletions mitmproxy/proxy/layers/http/_http_h3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
from mitmproxy import connection
from mitmproxy.proxy import commands, layer
from mitmproxy.proxy.layers.quic import (
CloseQuicConnection,
QuicStreamDataReceived,
QuicStreamEvent,
QuicStreamReset,
ResetQuicStream,
SendQuicStreamData,
set_connection_error,
)


Expand Down Expand Up @@ -87,13 +87,10 @@ def close(
# we'll get closed if a protocol error occurs in `H3Connection.handle_event`
# we note the error on the connection and yield a CloseConnection
# this will then call `QuicConnection.close` with the proper values
# once the `Http3Connection` receives `ConnectionClosed`, it will send out `*ProtocolError`
set_connection_error(self.conn, ConnectionTerminated(
error_code=error_code,
frame_type=frame_type,
reason_phrase=reason_phrase,
))
self.pending_commands.append(commands.CloseConnection(self.conn))
# once the `Http3Connection` receives `ConnectionClosed`, it will send out `ProtocolError`
self.pending_commands.append(
CloseQuicConnection(self.conn, error_code, frame_type, reason_phrase)
)

def get_next_available_stream_id(self, is_unidirectional: bool = False) -> int:
# since we always reserve the ID, we have to "find" the next ID like `QuicConnection` does
Expand Down
11 changes: 6 additions & 5 deletions mitmproxy/proxy/layers/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,14 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
assert isinstance(spec, ReverseMode)
self.context.server.address = spec.address

if spec.scheme in ("https", "http3", "quic", "tls", "dtls"):
if spec.scheme in ("http3", "quic"):
if not self.context.options.keep_host_header:
self.context.server.sni = spec.address[0]
if spec.scheme == "http3" or spec.scheme == "quic":
self.child_layer = quic.ServerQuicLayer(self.context)
else:
self.child_layer = tls.ServerTLSLayer(self.context)
self.child_layer = quic.ServerQuicLayer(self.context)
elif spec.scheme in ("https", "tls", "dtls"):
if not self.context.options.keep_host_header:
self.context.server.sni = spec.address[0]
self.child_layer = tls.ServerTLSLayer(self.context)
elif spec.scheme == "udp":
self.child_layer = udp.UDPLayer(self.context)
elif spec.scheme == "http" or spec.scheme == "tcp":
Expand Down
118 changes: 82 additions & 36 deletions mitmproxy/proxy/layers/quic.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,56 @@ def __init__(self, connection: connection.Connection, stream_id: int, error_code
self.error_code = error_code


class CloseQuicConnection(commands.CloseConnection):
"""Close a QUIC connection."""

error_code: int
"The error code which was specified when closing the connection."

frame_type: int | None
"The frame type which caused the connection to be closed, or `None`."

reason_phrase: str
"The human-readable reason for which the connection was closed."

# XXX: A bit much boilerplate right now. Should switch to dataclasses.
def __init__(
self,
conn: connection.Connection,
error_code: int,
frame_type: int | None,
reason_phrase: str,
):
super().__init__(conn)
self.error_code = error_code
self.frame_type = frame_type
self.reason_phrase = reason_phrase


class QuicConnectionClosed(events.ConnectionClosed):
"""QUIC connection has been closed."""
error_code: int
"The error code which was specified when closing the connection."

frame_type: int | None
"The frame type which caused the connection to be closed, or `None`."

reason_phrase: str
"The human-readable reason for which the connection was closed."

def __init__(
self,
conn: connection.Connection,
error_code: int,
frame_type: int | None,
reason_phrase: str,
):
super().__init__(conn)
self.error_code = error_code
self.frame_type = frame_type
self.reason_phrase = reason_phrase


class QuicSecretsLogger:
logger: tls.MasterSecretLogger

Expand Down Expand Up @@ -208,28 +258,12 @@ def error_code_to_str(error_code: int) -> str:
return f"unknown error (0x{error_code:x})"


def get_connection_error(conn: connection.Connection) -> quic_events.ConnectionTerminated | None:
"""Returns the QUIC close event that is associated with the given connection."""

close_event = getattr(conn, "quic_error", None)
if close_event is None:
return None
assert isinstance(close_event, quic_events.ConnectionTerminated)
return close_event


def is_success_error_code(error_code: int) -> bool:
"""Returns whether the given error code actually indicates no error."""

return error_code in (QuicErrorCode.NO_ERROR, H3ErrorCode.H3_NO_ERROR)


def set_connection_error(conn: connection.Connection, close_event: quic_events.ConnectionTerminated) -> None:
"""Stores the given close event for the given connection."""

setattr(conn, "quic_error", close_event)


@dataclass
class QuicClientHello(Exception):
"""Helper error only used in `quic_parse_client_hello`."""
Expand Down Expand Up @@ -299,7 +333,7 @@ class QuicStreamLayer(layer.Layer):
"""Virtual client connection for this stream. Use this in QuicRawLayer instead of `context.client`."""
server: connection.Server
"""Virtual server connection for this stream. Use this in QuicRawLayer instead of `context.server`."""
child_layer: layer.Layer
child_layer: TCPLayer
"""The stream's child layer."""

def __init__(self, context: context.Context, ignore: bool, stream_id: int) -> None:
Expand Down Expand Up @@ -335,11 +369,21 @@ def __init__(self, context: context.Context, ignore: bool, stream_id: int) -> No
if ignore else
layer.NextLayer(context)
)
if ignore:
self.child_layer = TCPLayer(context, ignore=True)
else:
tcp_layer = TCPLayer(context)
# This can potentially move to a smarter place later on,
# but it's useful debugging info in mitmproxy for now.
tcp_layer.flow.metadata["quic_is_unidirectional"] = stream_is_unidirectional(stream_id)
tcp_layer.flow.metadata["quic_initiator"] = "client" if stream_is_client_initiated(stream_id) else "server"
tcp_layer.flow.metadata["quic_stream_id_client"] = stream_id
self.child_layer = tcp_layer
self.handle_event = self.child_layer.handle_event # type: ignore
self._handle_event = self.child_layer._handle_event # type: ignore

def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
pass
raise AssertionError

def open_server_stream(self, server_stream_id) -> None:
assert self._server_stream_id is None
Expand All @@ -354,6 +398,8 @@ def open_server_stream(self, server_stream_id) -> None:
if stream_is_unidirectional(server_stream_id) else
connection.ConnectionState.OPEN
)
if self.child_layer.flow:
self.child_layer.flow.metadata["quic_stream_id_server"] = server_stream_id

def stream_id(self, client: bool) -> int | None:
return self._client_stream_id if client else self._server_stream_id
Expand Down Expand Up @@ -481,17 +527,13 @@ def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:

# handle close events that target this context
elif (
isinstance(event, events.ConnectionClosed)
isinstance(event, QuicConnectionClosed)
and (
event.connection is self.context.client
or event.connection is self.context.server
)
):
# copy the connection error
from_client = event.connection is self.context.client
close_event = get_connection_error(event.connection)
if close_event is not None:
set_connection_error(self.context.server if from_client else self.context.client, close_event)

# always forward to the datagram layer
yield from self.event_to_child(self.datagram_layer, event)
Expand Down Expand Up @@ -552,7 +594,9 @@ def event_to_child(self, child_layer: layer.Layer, event: events.Event) -> layer
if command.connection.state & connection.ConnectionState.CAN_WRITE:
command.connection.state &= ~connection.ConnectionState.CAN_WRITE
yield SendQuicStreamData(quic_conn, stream_id, b"", end_stream=True)
if not command.half_close:
# XXX: Use `command.connection.state & connection.ConnectionState.CAN_READ` instead?
only_close_our_half = isinstance(command, commands.CloseTcpConnection) and command.half_close
if not only_close_our_half:
if (
stream_is_client_initiated(stream_id) == to_client
or not stream_is_unidirectional(stream_id)
Expand Down Expand Up @@ -605,21 +649,24 @@ def __init__(self, context: context.Context, conn: connection.Connection) -> Non
conn.tls = True

def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
# turn Wakeup events into empty DataReceived events
if (
isinstance(event, events.Wakeup)
and event.command in self._wakeup_commands
):
# TunnelLayer has no understanding of wakeups, so we turn this into an empty DataReceived event
# which TunnelLayer recognizes as belonging to our connection.
assert self.quic
timer = self._wakeup_commands.pop(event.command)
if self.quic._state is not QuicConnectionState.TERMINATED:
self.quic.handle_timer(now=max(timer, self._loop.time()))
event = events.DataReceived(self.tunnel_connection, b"")
yield from super()._handle_event(event)
yield from super()._handle_event(
events.DataReceived(self.tunnel_connection, b"")
)
else:
yield from super()._handle_event(event)

def _handle_command(self, command: commands.Command) -> layer.CommandGenerator[None]:
"""Turns stream commands into aioquic connection invocations."""

if (
isinstance(command, QuicStreamCommand)
and command.connection is self.conn
Expand Down Expand Up @@ -783,13 +830,12 @@ def receive_data(self, data: bytes) -> layer.CommandGenerator[None]:
# handle post-handshake events
while event := self.quic.next_event():
if isinstance(event, quic_events.ConnectionTerminated):
set_connection_error(self.conn, event)
if self.debug:
reason = event.reason_phrase or error_code_to_str(event.error_code)
yield commands.Log(
f"{self.debug}[quic] close_notify {self.conn} (reason={reason})", DEBUG
)
yield commands.CloseConnection(self.conn)
yield CloseQuicConnection(self.conn, event.error_code, event.frame_type, event.reason_phrase)
return # we don't handle any further events, nor do/can we transmit data, so exit
elif isinstance(event, quic_events.DatagramFrameReceived):
yield from self.event_to_child(events.DataReceived(self.conn, event.data))
Expand All @@ -801,6 +847,7 @@ def receive_data(self, data: bytes) -> layer.CommandGenerator[None]:
quic_events.ConnectionIdIssued,
quic_events.ConnectionIdRetired,
quic_events.PingAcknowledged,
quic_events.ProtocolNegotiated,
)):
pass
else:
Expand All @@ -820,16 +867,15 @@ def send_data(self, data: bytes) -> layer.CommandGenerator[None]:
self.quic.send_datagram_frame(data)
yield from self.tls_interact()

def send_close(self, half_close: bool) -> layer.CommandGenerator[None]:
def send_close(self, command: commands.CloseConnection) -> layer.CommandGenerator[None]:
# properly close the QUIC connection
if self.quic is not None:
close_event = get_connection_error(self.conn)
if close_event is None:
self.quic.close()
if isinstance(command, CloseQuicConnection):
self.quic.close(command.error_code, command.frame_type, command.reason_phrase)
else:
self.quic.close(close_event.error_code, close_event.frame_type, close_event.reason_phrase)
self.quic.close()
yield from self.tls_interact()
yield from super().send_close(half_close)
yield from super().send_close(command)


class ServerQuicLayer(QuicLayer):
Expand Down
2 changes: 1 addition & 1 deletion mitmproxy/proxy/layers/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]:
yield TcpEndHook(self.flow)
self.flow.live = False
else:
yield commands.CloseConnection(send_to, half_close=True)
yield commands.CloseTcpConnection(send_to, half_close=True)
else:
raise AssertionError(f"Unexpected event: {event}")

Expand Down
4 changes: 2 additions & 2 deletions mitmproxy/proxy/layers/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,9 +440,9 @@ def send_data(self, data: bytes) -> layer.CommandGenerator[None]:
pass
yield from self.tls_interact()

def send_close(self, half_close: bool) -> layer.CommandGenerator[None]:
def send_close(self, command: commands.CloseConnection) -> layer.CommandGenerator[None]:
# We should probably shutdown the TLS connection properly here.
yield from super().send_close(half_close)
yield from super().send_close(command)


class ServerTLSLayer(TLSLayer):
Expand Down
4 changes: 3 additions & 1 deletion mitmproxy/proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,10 @@ def server_event(self, event: events.Event) -> None:
assert writer
if not writer.is_closing():
writer.write(command.data)
elif isinstance(command, commands.CloseConnection):
elif isinstance(command, commands.CloseTcpConnection):
self.close_connection(command.connection, command.half_close)
elif isinstance(command, commands.CloseConnection):
self.close_connection(command.connection, False)
elif isinstance(command, commands.StartHook):
asyncio_utils.create_task(
self.hook_task(command),
Expand Down
Loading

0 comments on commit 98ce1c2

Please sign in to comment.