diff --git a/capnp-rpc-net/capTP_capnp.ml b/capnp-rpc-net/capTP_capnp.ml index 74488522..994972ad 100644 --- a/capnp-rpc-net/capTP_capnp.ml +++ b/capnp-rpc-net/capTP_capnp.ml @@ -1,33 +1,5 @@ open Eio.Std -module Metrics = struct - open Prometheus - - let namespace = "capnp" - - let subsystem = "net" - - let connections = - let help = "Number of live capnp-rpc connections" in - Gauge.v ~help ~namespace ~subsystem "connections" - - let messages_inbound_received_total = - let help = "Total number of messages received" in - Counter.v ~help ~namespace ~subsystem "messages_inbound_received_total" - - let messages_outbound_enqueued_total = - let help = "Total number of messages enqueued to be transmitted" in - Counter.v ~help ~namespace ~subsystem "messages_outbound_enqueued_total" - - let messages_outbound_sent_total = - let help = "Total number of messages transmitted" in - Counter.v ~help ~namespace ~subsystem "messages_outbound_sent_total" - - let messages_outbound_dropped_total = - let help = "Total number of messages lost due to disconnections" in - Counter.v ~help ~namespace ~subsystem "messages_outbound_dropped_total" -end - module Log = Capnp_rpc.Debug.Log module Builder = Capnp_rpc.Private.Schema.Builder @@ -45,7 +17,6 @@ module Make (Network : S.NETWORK) = struct sw : Switch.t; endpoint : Endpoint.t; conn : Conn.t; - xmit_queue : Capnp.Message.rw Capnp.BytesMessage.Message.t Eio.Stream.t; mutable disconnecting : bool; } @@ -60,48 +31,12 @@ module Make (Network : S.NETWORK) = struct let tags t = Conn.tags t.conn - let drop_queue q = - let len = Eio.Stream.length q in - Prometheus.Counter.inc Metrics.messages_outbound_dropped_total (float_of_int len) - (* Queue.clear q -- could close stream here instead *) - - (* [flush ~xmit_queue endpoint] writes each message in [xmit_queue] to [endpoint]. *) - let rec flush ~xmit_queue endpoint = - let next = Eio.Stream.take xmit_queue in - match Endpoint.send endpoint next with - | Error `Closed -> - Endpoint.disconnect endpoint; (* We'll read a close soon *) - drop_queue xmit_queue; - `Stop_daemon - | Error (`Msg msg) -> - Log.warn (fun f -> f "Error sending messages: %s (will shutdown connection)" msg); - Endpoint.disconnect endpoint; - drop_queue xmit_queue; - `Stop_daemon - | Ok () -> - Prometheus.Counter.inc_one Metrics.messages_outbound_sent_total; - flush ~xmit_queue endpoint - | exception ex -> - drop_queue xmit_queue; - raise ex - - (* Enqueue [message] in [xmit_queue] and ensure the flush thread is running. *) - let queue_send ~xmit_queue message = - Log.debug (fun f -> - let module M = Capnp_rpc.Private.Schema.MessageWrapper.Message in - f "queue_send: %d/%d allocated bytes in %d segs" - (M.total_size message) - (M.total_alloc_size message) - (M.num_segments message)); - Eio.Stream.add xmit_queue message; - Prometheus.Counter.inc_one Metrics.messages_outbound_enqueued_total - let return_not_implemented t x = Log.debug (fun f -> f ~tags:(tags t) "Returning Unimplemented"); let open Builder in let m = Message.init_root () in let _ : Builder.Message.t = Message.unimplemented_set_reader m x in - queue_send ~xmit_queue:t.xmit_queue (Message.to_message m) + Endpoint.send t.endpoint (Message.to_message m) let listen t = let rec loop () = @@ -110,7 +45,6 @@ module Make (Network : S.NETWORK) = struct | Ok msg -> let open Reader.Message in let msg = of_message msg in - Prometheus.Counter.inc_one Metrics.messages_inbound_received_total; match Parse.message msg with | #Endpoint_types.In.t as msg -> Log.debug (fun f -> @@ -140,7 +74,7 @@ module Make (Network : S.NETWORK) = struct loop () let send_abort t ex = - queue_send ~xmit_queue:t.xmit_queue (Serialise.message (`Abort ex)) + Endpoint.send t.endpoint (Serialise.message (`Abort ex)) let disconnect t ex = if not t.disconnecting then ( @@ -153,9 +87,7 @@ module Make (Network : S.NETWORK) = struct let disconnecting t = t.disconnecting let connect ~sw ~restore ?(tags=Logs.Tag.empty) endpoint = - let xmit_queue = Eio.Stream.create 100 in (* todo: tune this? make it configurable? *) - Fiber.fork_daemon ~sw (fun () -> flush ~xmit_queue endpoint); - let queue_send msg = Eio.Stream.add xmit_queue (Serialise.message msg) in + let queue_send msg = Endpoint.send endpoint (Serialise.message msg) in let restore = Restorer.fn restore in let fork = Fiber.fork ~sw in let conn = Conn.create ~restore ~tags ~fork ~queue_send in @@ -163,12 +95,10 @@ module Make (Network : S.NETWORK) = struct sw; conn; endpoint; - xmit_queue; disconnecting = false; } let listen t = - Prometheus.Gauge.inc_one Metrics.connections; let tags = Conn.tags t.conn in begin match listen t with @@ -182,7 +112,6 @@ module Make (Network : S.NETWORK) = struct send_abort t (Capnp_rpc.Exception.v ~ty:`Failed (Printexc.to_string ex)) end; Log.info (fun f -> f ~tags "Connection closed"); - Prometheus.Gauge.dec_one Metrics.connections; Eio.Cancel.protect (fun () -> disconnect t (Capnp_rpc.Exception.v ~ty:`Disconnected "Connection closed") ) diff --git a/capnp-rpc-net/endpoint.ml b/capnp-rpc-net/endpoint.ml index 845644ff..237ebabe 100644 --- a/capnp-rpc-net/endpoint.ml +++ b/capnp-rpc-net/endpoint.ml @@ -1,5 +1,27 @@ open Eio.Std +module Metrics = struct + open Prometheus + + let namespace = "capnp" + + let subsystem = "net" + + let connections = + let help = "Number of live capnp-rpc connections" in + Gauge.v ~help ~namespace ~subsystem "connections" + + let messages_inbound_received_total = + let help = "Total number of messages received" in + Counter.v ~help ~namespace ~subsystem "messages_inbound_received_total" + + let messages_outbound_enqueued_total = + let help = "Total number of messages enqueued to be transmitted" in + Counter.v ~help ~namespace ~subsystem "messages_outbound_enqueued_total" +end + +module Write = Eio.Buf_write + let src = Logs.Src.create "endpoint" ~doc:"Send and receive Cap'n'Proto messages" module Log = (val Logs.src_log src: Logs.LOG) @@ -11,17 +33,13 @@ type flow = Eio.Flow.two_way_ty r type t = { flow : flow; + writer : Write.t; decoder : Capnp.Codecs.FramedStream.t; peer_id : Auth.Digest.t; } let peer_id t = t.peer_id -let of_flow ~peer_id flow = - let decoder = Capnp.Codecs.FramedStream.empty compression in - let flow = (flow :> flow) in - { flow; decoder; peer_id } - let dump_msg = let next = ref 0 in fun data -> @@ -32,25 +50,62 @@ let dump_msg = output_string ch data; close_out ch +let disconnect t = + try + Eio.Flow.shutdown t.flow `All + with Eio.Io (Eio.Net.E Connection_reset _, _) -> + (* TCP connection already shut down, so TLS shutdown failed. Ignore. *) + () + let send t msg = - let data = Capnp.Codecs.serialize ~compression msg in - if record_sent_messages then dump_msg data; - match Eio.Flow.copy_string data t.flow with - | () - | exception End_of_file -> Ok () + Log.debug (fun f -> + let module M = Capnp_rpc.Private.Schema.MessageWrapper.Message in + f "queue_send: %d/%d allocated bytes in %d segs" + (M.total_size msg) + (M.total_alloc_size msg) + (M.num_segments msg)); + Capnp.Codecs.serialize_iter_copyless ~compression msg ~f:(fun x len -> Write.string t.writer x ~len); + Prometheus.Counter.inc_one Metrics.messages_outbound_enqueued_total; + if record_sent_messages then dump_msg (Capnp.Codecs.serialize ~compression msg) + +let rec run_writer t = + let bufs = Write.await_batch t.writer in + match Eio.Flow.single_write t.flow bufs with + | n -> Write.shift t.writer n; run_writer t | exception (Eio.Io (Eio.Net.E Connection_reset _, _) as ex) -> Log.info (fun f -> f "%a" Eio.Exn.pp ex); - Error `Closed + disconnect t (* We'll read a close soon *) | exception ex -> Eio.Fiber.check (); - Error (`Msg (Printexc.to_string ex)) + Log.warn (fun f -> f "Error sending messages: %a (will shutdown connection)" Fmt.exn ex); + disconnect t + +let of_flow ~sw ~peer_id flow = + let decoder = Capnp.Codecs.FramedStream.empty compression in + let flow = (flow :> flow) in + let writer = Write.create 4096 in + let t = { flow; writer; decoder; peer_id } in + Prometheus.Gauge.inc_one Metrics.connections; + Switch.on_release sw (fun () -> Prometheus.Gauge.dec_one Metrics.connections); + Fiber.fork_daemon ~sw (fun () -> run_writer t; `Stop_daemon); + t let rec recv t = match Capnp.Codecs.FramedStream.get_next_frame t.decoder with - | Ok msg -> Ok (Capnp.BytesMessage.Message.readonly msg) + | Ok msg -> + Prometheus.Counter.inc_one Metrics.messages_inbound_received_total; + (* We often want to send multiple response messages while processing a batch of requests, + so pause the writer to collect them. We'll unpause on the next read. *) + Write.pause t.writer; + Ok (Capnp.BytesMessage.Message.readonly msg) | Error Capnp.Codecs.FramingError.Unsupported -> failwith "Unsupported Cap'n'Proto frame received" | Error Capnp.Codecs.FramingError.Incomplete -> Log.debug (fun f -> f "Incomplete; waiting for more data..."); + (* We probably scheduled one or more application fibers to run while handling the last + batch of messages. Given them a chance to run now while the writer is paused, because + they might want to send more messages immediately. *) + Fiber.yield (); + Write.unpause t.writer; let buf = Cstruct.create 4096 in (* TODO: make this efficient *) match Eio.Flow.single_read t.flow buf with | got -> @@ -62,10 +117,3 @@ let rec recv t = | exception (Eio.Io (Eio.Net.E Connection_reset _, _) as ex) -> Log.info (fun f -> f "%a" Eio.Exn.pp ex); Error `Closed - -let disconnect t = - try - Eio.Flow.shutdown t.flow `All - with Eio.Io (Eio.Net.E Connection_reset _, _) -> - (* TCP connection already shut down, so TLS shutdown failed. Ignore. *) - () diff --git a/capnp-rpc-net/endpoint.mli b/capnp-rpc-net/endpoint.mli index 674a6a29..54db0917 100644 --- a/capnp-rpc-net/endpoint.mli +++ b/capnp-rpc-net/endpoint.mli @@ -1,20 +1,24 @@ (** Send and receive capnp messages over a byte-stream. *) +open Eio.Std + val src : Logs.src (** Control the log level. *) type t (** A wrapper for a byte-stream (flow). *) -val send : t -> 'a Capnp.BytesMessage.Message.t -> (unit, [`Closed | `Msg of string]) result -(** [send t msg] transmits [msg]. *) +val send : t -> 'a Capnp.BytesMessage.Message.t -> unit +(** [send t msg] enqueues [msg]. *) val recv : t -> (Capnp.Message.ro Capnp.BytesMessage.Message.t, [> `Closed]) result (** [recv t] reads the next message from the remote peer. It returns [Error `Closed] if the connection to the peer is lost. *) -val of_flow : peer_id:Auth.Digest.t -> _ Eio.Flow.two_way -> t -(** [of_flow ~peer_id flow] sends and receives on [flow]. *) +val of_flow : sw:Switch.t -> peer_id:Auth.Digest.t -> _ Eio.Flow.two_way -> t +(** [of_flow ~sw ~peer_id flow] sends and receives on [flow]. + + [sw] is used to run a fiber writing messages in batches. *) val peer_id : t -> Auth.Digest.t (** [peer_id t] is the fingerprint of the peer's public key, diff --git a/capnp-rpc-net/tls_wrapper.ml b/capnp-rpc-net/tls_wrapper.ml index f2cc73fa..6a7a9100 100644 --- a/capnp-rpc-net/tls_wrapper.ml +++ b/capnp-rpc-net/tls_wrapper.ml @@ -6,12 +6,12 @@ let error fmt = fmt |> Fmt.kstr @@ fun msg -> Error (`Msg msg) -let plain_endpoint flow = - Endpoint.of_flow ~peer_id:Auth.Digest.insecure flow +let plain_endpoint ~sw flow = + Endpoint.of_flow ~sw ~peer_id:Auth.Digest.insecure flow -let connect_as_server flow secret_key = +let connect_as_server ~sw flow secret_key = match secret_key with - | None -> Ok (plain_endpoint flow) + | None -> Ok (plain_endpoint ~sw flow) | Some key -> Log.info (fun f -> f "Doing TLS server-side handshake..."); let tls_config = Secret_key.tls_server_config key in @@ -26,15 +26,15 @@ let connect_as_server flow secret_key = | None -> error "No client certificate found" | Some client_cert -> let peer_id = Digest.of_certificate client_cert in - Ok (Endpoint.of_flow ~peer_id flow) + Ok (Endpoint.of_flow ~sw ~peer_id flow) -let connect_as_client flow secret_key auth = +let connect_as_client ~sw flow secret_key auth = match Digest.authenticator auth with - | None -> Ok (plain_endpoint flow) + | None -> Ok (plain_endpoint ~sw flow) | Some authenticator -> let tls_config = Secret_key.tls_client_config ~authenticator (Lazy.force secret_key) in Log.info (fun f -> f "Doing TLS client-side handshake..."); match Tls_eio.client_of_flow tls_config flow with | exception (Failure msg) -> error "TLS connection failed: %s" msg | exception ex -> Eio.Fiber.check (); error "TLS connection failed: %a" Fmt.exn ex - | flow -> Ok (Endpoint.of_flow ~peer_id:auth flow) + | flow -> Ok (Endpoint.of_flow ~sw ~peer_id:auth flow) diff --git a/capnp-rpc-net/tls_wrapper.mli b/capnp-rpc-net/tls_wrapper.mli index 81c214f3..4b6fa78e 100644 --- a/capnp-rpc-net/tls_wrapper.mli +++ b/capnp-rpc-net/tls_wrapper.mli @@ -2,12 +2,14 @@ open Auth open Eio.Std val connect_as_server : + sw:Switch.t -> [> Eio.Flow.two_way_ty | Eio.Resource.close_ty] r -> Auth.Secret_key.t option -> (Endpoint.t, [> `Msg of string]) result val connect_as_client : + sw:Switch.t -> [> Eio.Flow.two_way_ty | Eio.Resource.close_ty] r -> Auth.Secret_key.t Lazy.t -> Digest.t -> (Endpoint.t, [> `Msg of string]) result -(** [connect_as_client underlying key digest] is an endpoint using flow [underlying]. +(** [connect_as_client ~sw underlying key digest] is an endpoint using flow [underlying]. If [digest] requires TLS, it performs a TLS handshake. It uses [key] as its private key and checks that the server is the one required by [auth]. *) diff --git a/test-bin/calc_direct.ml b/test-bin/calc_direct.ml index 7ac6dde6..6e7ef293 100644 --- a/test-bin/calc_direct.ml +++ b/test-bin/calc_direct.ml @@ -35,7 +35,7 @@ module Parent = struct Switch.run @@ fun sw -> (* Run Cap'n Proto RPC protocol on [socket]: *) let p = Eio_unix.Net.import_socket_stream ~sw ~close_unix:true socket - |> Capnp_rpc_net.Endpoint.of_flow + |> Capnp_rpc_net.Endpoint.of_flow ~sw ~peer_id:Capnp_rpc_net.Auth.Digest.insecure in Logs.info (fun f -> f "Connecting to child process..."); @@ -60,7 +60,7 @@ module Child = struct let service = Calc.local ~sw in let restore = Capnp_rpc_net.Restorer.single service_name service in (* Run Cap'n Proto RPC protocol on [socket]: *) - let endpoint = Capnp_rpc_net.Endpoint.of_flow socket + let endpoint = Capnp_rpc_net.Endpoint.of_flow socket ~sw ~peer_id:Capnp_rpc_net.Auth.Digest.insecure in let conn = Capnp_rpc_unix.CapTP.connect ~sw ~restore endpoint in diff --git a/test-bin/echo/echo_bench.ml b/test-bin/echo/echo_bench.ml index aad9e785..30a222d8 100755 --- a/test-bin/echo/echo_bench.ml +++ b/test-bin/echo/echo_bench.ml @@ -7,8 +7,7 @@ let () = Logs.set_reporter (Logs_fmt.reporter ()) let run_client service = - (* let n = 100000 in *) (* XXX: improve speed *) - let n = 1000 in + let n = 100000 in let ops = List.init n (fun i -> let payload = Int.to_string i in let desired_result = "echo:" ^ payload in diff --git a/unix/capnp_rpc_unix.ml b/unix/capnp_rpc_unix.ml index ae900458..c691523a 100644 --- a/unix/capnp_rpc_unix.ml +++ b/unix/capnp_rpc_unix.ml @@ -150,8 +150,8 @@ let with_cap_exn ?progress sr f = | Error ex -> Fmt.failwith "%a" Capnp_rpc.Exception.pp ex | Ok x -> Capnp_rpc.Capability.with_ref x f -let handle_connection ?tags ~secret_key vat client = - match Network.accept_connection ~secret_key client with +let handle_connection ?tags ~sw ~secret_key vat client = + match Network.accept_connection ~sw ~secret_key client with | Error (`Msg msg) -> Log.warn (fun f -> f ?tags "Rejecting new connection: %s" msg) | Ok ep -> @@ -189,7 +189,7 @@ let listen ?tags ~sw (config, vat, socket) = let secret_key = if config.Vat_config.serve_tls then Some (Vat_config.secret_key config) else None in Fiber.fork ~sw (fun () -> (* We don't use [Net.accept_fork] here because this returns immediately after connecting. *) - handle_connection ?tags ~secret_key vat client + handle_connection ?tags ~sw ~secret_key vat client ) done diff --git a/unix/network.ml b/unix/network.ml index 8eb90110..db9c1462 100644 --- a/unix/network.ml +++ b/unix/network.ml @@ -96,14 +96,14 @@ let connect net ~sw ~secret_key (addr, auth) = try_set_nodelay socket; Keepalive.try_set_idle socket 60 end; - Tls_wrapper.connect_as_client socket secret_key auth + Tls_wrapper.connect_as_client ~sw socket secret_key auth | exception ex -> Fiber.check (); error "@[Network connection for %a failed:@,%a@]" Location.pp addr Fmt.exn ex -let accept_connection ~secret_key flow = +let accept_connection ~sw ~secret_key flow = Eio_unix.Resource.fd_opt flow |> Option.iter (fun fd -> Eio_unix.Fd.use_exn "TCP_NODELAY" fd try_set_nodelay); - Tls_wrapper.connect_as_server flow secret_key + Tls_wrapper.connect_as_server ~sw flow secret_key let v t = (t :> [`Generic] Eio.Net.ty r) diff --git a/unix/network.mli b/unix/network.mli index ef68b701..c9ca61a5 100644 --- a/unix/network.mli +++ b/unix/network.mli @@ -33,10 +33,11 @@ include Capnp_rpc_net.S.NETWORK with val v : _ Eio.Net.t -> t val accept_connection : + sw:Switch.t -> secret_key:Capnp_rpc_net.Auth.Secret_key.t option -> [> Eio.Flow.two_way_ty | Eio.Resource.close_ty] r -> (Capnp_rpc_net.Endpoint.t, [> `Msg of string]) result -(** [accept_connection ~switch ~secret_key flow] is a new endpoint for [flow]. +(** [accept_connection ~sw ~secret_key flow] is a new endpoint for [flow]. If [secret_key] is not [None], it is used to perform a TLS server-side handshake. Otherwise, the connection is not encrypted. *)