From 9c74a31467fd1674ac62869955ed74cb147fdf2e Mon Sep 17 00:00:00 2001 From: Thomas Leonard Date: Tue, 12 Jul 2022 15:24:52 +0100 Subject: [PATCH] Initial Eio port This switches capnp-rpc from Lwt to Eio. One particularly nice side effect of this is that `Service.return_lwt` has gone, as there is no distinction now between concurrent and non-concurrent service methods. Notes: - In this commit, everything is still using the "lwt" names to make the diff useful. In a future commit, this should be renamed. Also, some of the "unix" functions can be moved into the core library with Eio. This would likely be a good time to rename `capnp_rpc` to `capnp_rpc_protocol` or something, leaving the short name free for the main library. - Mirage support is gone. Ideally, the regular library should work with Mirage 5. --- Makefile | 4 +- README.md | 340 +++++------- capnp-rpc-lwt.opam | 12 +- capnp-rpc-lwt/capability.ml | 38 +- capnp-rpc-lwt/capnp_core.ml | 11 +- capnp-rpc-lwt/capnp_rpc_lwt.mli | 60 +- capnp-rpc-lwt/dune | 2 +- capnp-rpc-lwt/persistence.ml | 21 +- capnp-rpc-lwt/service.ml | 24 +- capnp-rpc-lwt/sturdy_ref.ml | 14 +- capnp-rpc-mirage.opam | 36 -- capnp-rpc-net.opam | 1 - capnp-rpc-net/capTP_capnp.ml | 101 ++-- capnp-rpc-net/capTP_capnp.mli | 13 +- capnp-rpc-net/capnp_rpc_net.ml | 6 +- capnp-rpc-net/capnp_rpc_net.mli | 16 +- capnp-rpc-net/endpoint.ml | 52 +- capnp-rpc-net/endpoint.mli | 24 +- capnp-rpc-net/restorer.ml | 88 ++- capnp-rpc-net/s.ml | 35 +- capnp-rpc-net/tls_eio.ml | 228 ++++++++ capnp-rpc-net/tls_wrapper.ml | 83 ++- capnp-rpc-net/tls_wrapper.mli | 25 +- capnp-rpc-net/two_party_network.ml | 2 +- capnp-rpc-net/vat.ml | 117 ++-- capnp-rpc-unix.opam | 4 +- capnp-rpc/capTP.ml | 11 +- capnp-rpc/capTP.mli | 8 +- examples/pipelining/dune | 2 +- examples/pipelining/echo.ml | 25 +- examples/pipelining/main.ml | 28 +- examples/sturdy-refs-2/dune | 2 +- examples/sturdy-refs-2/main.ml | 30 +- examples/sturdy-refs-3/dune | 2 +- examples/sturdy-refs-3/main.ml | 42 +- examples/sturdy-refs-4/db.ml | 10 +- examples/sturdy-refs-4/db.mli | 2 +- examples/sturdy-refs-4/dune | 2 +- examples/sturdy-refs-4/logger.ml | 9 +- examples/sturdy-refs-4/main.ml | 91 ++-- examples/sturdy-refs/dune | 2 +- examples/sturdy-refs/main.ml | 22 +- examples/testlib/calc.ml | 68 +-- examples/testlib/calc.mli | 11 +- examples/testlib/echo.ml | 17 +- examples/testlib/echo.mli | 6 +- examples/testlib/registry.ml | 21 +- examples/testlib/registry.mli | 11 +- examples/testlib/store.ml | 18 +- examples/testlib/store.mli | 6 +- examples/v1/dune | 2 +- examples/v1/echo.ml | 3 +- examples/v1/main.ml | 12 +- examples/v2/dune | 2 +- examples/v2/echo.ml | 22 +- examples/v2/fake_clock.ml | 9 + examples/v2/main.ml | 11 +- examples/v3/dune | 2 +- examples/v3/echo.ml | 22 +- examples/v3/fake_clock.ml | 9 + examples/v3/main.ml | 26 +- examples/v4/client.ml | 13 +- examples/v4/dune | 2 +- examples/v4/echo.ml | 22 +- examples/v4/fake_clock.ml | 9 + examples/v4/server.ml | 23 +- fuzz/fuzz.ml | 4 +- mirage/capnp_rpc_mirage.ml | 54 -- mirage/capnp_rpc_mirage.mli | 72 --- mirage/dune | 4 - mirage/network.ml | 93 ---- mirage/network.mli | 37 -- mirage/vat_config.ml | 57 -- test-bin/calc.ml | 40 +- test-bin/calc_direct.ml | 94 ++-- test-bin/dune | 2 +- test-bin/echo/dune | 2 +- test-bin/echo/echo.ml | 3 +- test-bin/echo/echo_bench.ml | 32 +- test-lwt/dune | 4 +- test-lwt/test_lwt.ml | 848 +++++++++++++++-------------- test-mirage/dune | 6 - test-mirage/test_mirage.ml | 142 ----- test-mirage/test_mirage.mli | 1 - test/testbed/connection.ml | 4 +- unix/capnp_rpc_unix.ml | 138 ++--- unix/capnp_rpc_unix.mli | 24 +- unix/dune | 4 +- unix/file_store.ml | 48 +- unix/network.ml | 72 ++- unix/network.mli | 9 +- unix/unix_flow.ml | 95 ---- unix/unix_flow.mli | 7 - unix/vat_network.ml | 2 +- 94 files changed, 1719 insertions(+), 2171 deletions(-) delete mode 100644 capnp-rpc-mirage.opam create mode 100644 capnp-rpc-net/tls_eio.ml create mode 100644 examples/v2/fake_clock.ml create mode 100644 examples/v3/fake_clock.ml create mode 100644 examples/v4/fake_clock.ml delete mode 100644 mirage/capnp_rpc_mirage.ml delete mode 100644 mirage/capnp_rpc_mirage.mli delete mode 100644 mirage/dune delete mode 100644 mirage/network.ml delete mode 100644 mirage/network.mli delete mode 100644 mirage/vat_config.ml delete mode 100644 test-mirage/dune delete mode 100644 test-mirage/test_mirage.ml delete mode 100644 test-mirage/test_mirage.mli delete mode 100644 unix/unix_flow.ml delete mode 100644 unix/unix_flow.mli diff --git a/Makefile b/Makefile index 322e22450..1f7fe3eda 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ default: test build-fuzz all: - dune build @install test/test.exe test-lwt/test_lwt.exe test-bin/calc.exe test-mirage/test_mirage.exe + dune build @install test/test.exe test-lwt/test_lwt.exe test-bin/calc.exe rm -rf _build/_tests dune runtest --no-buffer -j 1 @@ -19,7 +19,7 @@ clean: test: rm -rf _build/_tests - dune build test/test.exe test-lwt/test_lwt.exe test-bin/calc.exe test-mirage/test_mirage.exe test-bin/echo/echo_bench.exe @install + dune build test/test.exe test-lwt/test_lwt.exe test-bin/calc.exe test-bin/echo/echo_bench.exe @install #./_build/default/test/test.bc test core -ev 36 #./_build/default/test-lwt/test.bc test lwt -ev 3 dune build @runtest --no-buffer -j 1 diff --git a/README.md b/README.md index afe544712..b885da8c6 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # OCaml Cap'n Proto RPC library Copyright 2017 Docker, Inc. -Copyright 2019 Thomas Leonard. +Copyright 2022 Thomas Leonard. See [LICENSE.md](LICENSE.md) for details. [API documentation][api] @@ -38,7 +38,6 @@ See [LICENSE.md](LICENSE.md) for details. * [How can I release other resources when my service is released?](#how-can-i-release-other-resources-when-my-service-is-released) * [Is there an interactive version I can use for debugging?](#is-there-an-interactive-version-i-can-use-for-debugging) * [Can I set up a direct 2-party connection over a pre-existing channel?](#can-i-set-up-a-direct-2-party-connection-over-a-pre-existing-channel) - * [How can I use this with Mirage?](#how-can-i-use-this-with-mirage) * [Contributing](#contributing) * [Conceptual model](#conceptual-model) * [Building](#building) @@ -173,7 +172,6 @@ For the server, you should inherit from the generated `Api.Service.Echo.service` ```ocaml module Api = Echo_api.MakeRPC(Capnp_rpc_lwt) -open Lwt.Infix open Capnp_rpc_lwt let local = @@ -207,7 +205,7 @@ There's a bit of ugly boilerplate here, but it's quite simple: should always free them anyway. - `Service.Response.create Results.init_pointer` creates a new response message, using `Ping.Results.init_pointer` to initialise the payload contents. - `response` is the complete message to be sent back, and `results` is the data part of it. -- `Service.return` returns the results immediately (like `Lwt.return`). +- `Service.return` returns the results immediately (rather than returning a promise). The client implementation is similar, but uses `Api.Client` instead of `Api.Service`. Here, we have a *builder* for the parameters and a *reader* for the results. @@ -221,7 +219,7 @@ let ping t msg = let open Echo.Ping in let request, params = Capability.Request.create Params.init_pointer in Params.msg_set params msg; - Capability.call_for_value_exn t method_id request >|= Results.reply_get + Capability.call_for_value_exn t method_id request |> Results.reply_get ``` `Capability.call_for_value_exn` sends the request message to the service and waits for the response to arrive. @@ -235,19 +233,17 @@ With the boilerplate out of the way, we can now write a `main.ml` to test it: ```ocaml -open Lwt.Infix +open Eio.Std let () = Logs.set_level (Some Logs.Warning); Logs.set_reporter (Logs_fmt.reporter ()) let () = - Lwt_main.run begin - let service = Echo.local in - Echo.ping service "foo" >>= fun reply -> - Fmt.pr "Got reply %S@." reply; - Lwt.return_unit - end + Eio_main.run @@ fun _ -> + let service = Echo.local in + let reply = Echo.ping service "foo" in + traceln "Got reply %S" reply ```

@@ -260,7 +256,7 @@ Here's a suitable `dune` file to compile the schema file and then the generated ``` (executable (name main) - (libraries lwt.unix capnp-rpc-lwt logs.fmt) + (libraries eio_main capnp-rpc-lwt logs.fmt) (flags (:standard -w -53-55))) (rule @@ -286,7 +282,7 @@ $ opam depext -i capnp-rpc-lwt ```bash $ dune exec ./main.exe -Got reply "echo:foo" ++Got reply "echo:foo" ``` This isn't very exciting, so let's add some capabilities to the protocol... @@ -324,33 +320,29 @@ The new `heartbeat_impl` method looks like this: match callback with | None -> Service.fail "No callback parameter!" | Some callback -> - Service.return_lwt @@ fun () -> - Capability.with_ref callback (notify ~msg) + Capability.with_ref callback (notify ~clock msg) ``` Note that all parameters in Cap'n Proto are optional, so we have to check for `callback` not being set (data parameters such as `msg` get a default value from the schema, which is `""` for strings if not set explicitly). -`Service.return_lwt fn` runs `fn ()` and replies to the `heartbeat` call when it finishes. -Here, the whole of the rest of the method is the argument to `return_lwt`, which is a common pattern. - `Capability.with_ref x f` calls `f x` and then releases `x` (capabilities are ref-counted). -`notify callback msg` just sends a few messages to `callback` in a loop: +`notify ~clock msg callback` just sends a few messages to `callback` in a loop: ```ocaml -let (>>!=) = Lwt_result.bind (* Return errors *) - -let notify callback ~msg = +let notify ~clock msg callback = let rec loop = function | 0 -> - Lwt.return @@ Ok (Service.Response.create_empty ()) + Service.return_empty () | i -> - Callback.log callback msg >>!= fun () -> - Lwt_unix.sleep 1.0 >>= fun () -> - loop (i - 1) + match Callback.log callback msg with + | Error (`Capnp e) -> Service.error e + | Ok () -> + Eio.Time.sleep clock 1.0; + loop (i - 1) in loop 3 ``` @@ -377,6 +369,7 @@ In `main.ml`, we can now wrap a regular OCaml function as the callback: ```ocaml +open Eio.Std open Capnp_rpc_lwt let () = @@ -384,17 +377,17 @@ let () = Logs.set_reporter (Logs_fmt.reporter ()) let callback_fn msg = - Fmt.pr "Callback got %S@." msg + traceln "Callback got %S" msg let run_client service = Capability.with_ref (Echo.Callback.local callback_fn) @@ fun callback -> Echo.heartbeat service "foo" callback let () = - Lwt_main.run begin - let service = Echo.local in - run_client service - end + Eio_main.run @@ fun env -> + let clock = if Sys.getenv_opt "CI" = None then env#clock else Fake_clock.v in + let service = Echo.local ~clock in + run_client service ``` Step 1: The client creates the callback: @@ -420,12 +413,12 @@ Exercise: implement `Callback.local fn` (hint: it's similar to the original `pin And testing it should give (three times, at one second intervals): - + ```sh $ dune exec -- ./main.exe -Callback got "foo" -Callback got "foo" -Callback got "foo" ++Callback got "foo" ++Callback got "foo" ++Callback got "foo" ``` Note that the client gives the echo service permission to call its callback service by sending a message containing the callback to the service. @@ -444,7 +437,7 @@ Here's the new `main.ml` (the top half is the same as before): ```ocaml -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt let () = @@ -452,7 +445,7 @@ let () = Logs.set_reporter (Logs_fmt.reporter ()) let callback_fn msg = - Fmt.pr "Callback got %S@." msg + traceln "Callback got %S" msg let run_client service = Capability.with_ref (Echo.Callback.local callback_fn) @@ fun callback -> @@ -461,21 +454,23 @@ let run_client service = let secret_key = `Ephemeral let listen_address = `TCP ("127.0.0.1", 7000) -let start_server () = +let start_server ~sw ~clock net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key listen_address in let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in - let restore = Capnp_rpc_net.Restorer.single service_id Echo.local in - Capnp_rpc_unix.serve config ~restore >|= fun vat -> + let restore = Capnp_rpc_net.Restorer.single service_id (Echo.local ~clock) in + let vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in Capnp_rpc_unix.Vat.sturdy_uri vat service_id let () = - Lwt_main.run begin - start_server () >>= fun uri -> - Fmt.pr "Connecting to echo service at: %a@." Uri.pp_hum uri; - let client_vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in - Sturdy_ref.with_cap_exn sr run_client - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let clock = if Sys.getenv_opt "CI" = None then env#clock else Fake_clock.v in + let uri = start_server ~sw ~clock env#net in + traceln "Connecting to echo service at: %a" Uri.pp_hum uri; + let client_vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in + Sturdy_ref.with_cap_exn sr run_client; + raise Exit ```

@@ -491,7 +486,7 @@ $ opam depext -i capnp-rpc-unix Running this will give something like: - + ```sh $ dune exec ./main.exe Connecting to echo service at: capnp://sha-256:3Tj5y5Q2qpqN3Sbh0GRPxgORZw98_NtrU2nLI0-Tn6g@127.0.0.1:7000/eBIndzZyoVDxaJdZ8uh_xBx5V1lfXWTJCDX-qEkgNZ4 @@ -546,8 +541,9 @@ In `start_server`: and the name. This means that the ID will be stable as long as the server's key doesn't change. The name used ("main" here) isn't important - it just needs to be unique. -- `let restore = Restorer.single service_id Echo.local` configures a simple "restorer" that - answers requests for `service_id` with our `Echo.local` service. +- `let restore = Capnp_rpc_net.Restorer.single service_id (Echo.local ~clock)` + configures a simple "restorer" that answers requests for `service_id` with + our `Echo.local` service. - `Capnp_rpc_unix.serve config ~restore` creates the service vat using the previous configuration items and starts it listening for incoming connections. @@ -573,7 +569,7 @@ Edit the `dune` file to build a client and server: ``` (executables (names client server) - (libraries lwt.unix capnp-rpc-lwt logs.fmt capnp-rpc-unix) + (libraries eio_main capnp-rpc-lwt logs.fmt capnp-rpc-unix) (flags (:standard -w -53-55))) (rule @@ -586,7 +582,7 @@ Here's a suitable `server.ml`: ```ocaml -open Lwt.Infix +open Eio.Std open Capnp_rpc_net let () = @@ -596,16 +592,17 @@ let () = let cap_file = "echo.cap" let serve config = - Lwt_main.run begin - let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in - let restore = Restorer.single service_id Echo.local in - Capnp_rpc_unix.serve config ~restore >>= fun vat -> - match Capnp_rpc_unix.Cap_file.save_service vat service_id cap_file with - | Error `Msg m -> failwith m - | Ok () -> - Fmt.pr "Server running. Connect using %S.@." cap_file; - fst @@ Lwt.wait () (* Wait forever *) - end + Eio_main.run @@ fun env -> + let clock = if Sys.getenv_opt "CI" = None then env#clock else Fake_clock.v in + Switch.run @@ fun sw -> + let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in + let restore = Restorer.single service_id (Echo.local ~clock) in + let vat = Capnp_rpc_unix.serve ~sw ~net:env#net ~restore config in + match Capnp_rpc_unix.Cap_file.save_service vat service_id cap_file with + | Error `Msg m -> failwith m + | Ok () -> + traceln "Server running. Connect using %S." cap_file; + Fiber.await_cancel () open Cmdliner @@ -625,6 +622,7 @@ And here's the corresponding `client.ml`: ```ocaml +open Eio.Std open Capnp_rpc_lwt let () = @@ -632,18 +630,18 @@ let () = Logs.set_reporter (Logs_fmt.reporter ()) let callback_fn msg = - Fmt.pr "Callback got %S@." msg + traceln "Callback got %S" msg let run_client service = Capability.with_ref (Echo.Callback.local callback_fn) @@ fun callback -> Echo.heartbeat service "foo" callback let connect uri = - Lwt_main.run begin - let client_vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in - Capnp_rpc_unix.with_cap_exn sr run_client - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let client_vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in + Capnp_rpc_unix.with_cap_exn sr run_client open Cmdliner @@ -754,7 +752,7 @@ We can test it as follows: ```ocaml let run_client service = let logger = Echo.get_logger service in - Echo.Callback.log logger "Message from client" >|= function + match Echo.Callback.log logger "Message from client" with | Ok () -> () | Error (`Capnp err) -> Fmt.epr "Server's logger failed: %a" Capnp_rpc.Error.pp err @@ -769,8 +767,10 @@ This should print (in the server's output) something like: ```sh $ dune exec ./main.exe -[client] Connecting to echo service... -[server] Received "Message from client" ++[client] Connecting to echo service... ++[server] Received "Message from client" +Fatal error: exception Exit +[2] ``` In this case, we didn't wait for the `getLogger` call to return before using the logger. @@ -838,32 +838,34 @@ let make_service ~config ~services name = Restorer.Table.add services id service; name, id -let start_server () = +let start_server ~sw net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key listen_address in let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in let services = Restorer.Table.create make_sturdy in let restore = Restorer.of_table services in let services = List.map (make_service ~config ~services) ["alice"; "bob"] in - Capnp_rpc_unix.serve config ~restore >|= fun vat -> + let vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in services |> List.iter (fun (name, id) -> let cap_file = name ^ ".cap" in Capnp_rpc_unix.Cap_file.save_service vat id cap_file |> or_fail; Printf.printf "[server] saved %S\n%!" cap_file ) -let run_client cap_file msg = - let vat = Capnp_rpc_unix.client_only_vat () in +let run_client ~sw ~net cap_file msg = + let vat = Capnp_rpc_unix.client_only_vat ~sw net in let sr = Capnp_rpc_unix.Cap_file.load vat cap_file |> or_fail in Printf.printf "[client] loaded %S\n%!" cap_file; Sturdy_ref.with_cap_exn sr @@ fun cap -> Logger.log cap msg let () = - Lwt_main.run begin - start_server () >>= fun () -> - run_client "./alice.cap" "Message from Alice" >>= fun () -> - run_client "./bob.cap" "Message from Bob" - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let net = env#net in + start_server ~sw net; + run_client ~sw ~net "./alice.cap" "Message from Alice"; + run_client ~sw ~net "./bob.cap" "Message from Bob"; + raise Exit ``` @@ -875,6 +877,8 @@ $ dune exec ./main.exe [server] "alice" says "Message from Alice" [client] loaded "./bob.cap" [server] "bob" says "Message from Bob" +Fatal error: exception Exit +[2] ``` ### Creating services dynamically @@ -901,17 +905,19 @@ We can use the new API like this: ```ocaml let () = - Lwt_main.run begin - start_server () >>= fun root_uri -> - let vat = Capnp_rpc_unix.client_only_vat () in - let root_sr = Capnp_rpc_unix.Vat.import vat root_uri |> or_fail in - Sturdy_ref.with_cap_exn root_sr @@ fun root -> - Logger.log root "Message from Admin" >>= fun () -> - let for_alice = Logger.sub root "alice" in - let for_bob = Logger.sub root "bob" in - Logger.log for_alice "Message from Alice" >>= fun () -> - Logger.log for_bob "Message from Bob" - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let net = env#net in + let root_uri = start_server ~sw net in + let vat = Capnp_rpc_unix.client_only_vat ~sw net in + let root_sr = Capnp_rpc_unix.Vat.import vat root_uri |> or_fail in + Sturdy_ref.with_cap_exn root_sr @@ fun root -> + Logger.log root "Message from Admin"; + let for_alice = Logger.sub root "alice" in + let for_bob = Logger.sub root "bob" in + Logger.log for_alice "Message from Alice"; + Logger.log for_bob "Message from Bob"; + raise Exit ``` @@ -920,6 +926,8 @@ $ dune exec ./main.exe [server] "root" says "Message from Admin" [server] "root/alice" says "Message from Alice" [server] "root/bob" says "Message from Bob" +Fatal error: exception Exit +[2] ``` ### The Persistence API @@ -936,12 +944,13 @@ the admin can request the sturdy ref like this: ```ocaml - (* The admin creates a logger for Alice and saves it: *) - let for_alice = Logger.sub root "alice" in - Persistence.save_exn for_alice >>= fun uri -> - Capnp_rpc_unix.Cap_file.save_uri uri "alice.cap" |> or_fail; - (* Alice uses it: *) - run_client "alice.cap" + (* The admin creates a logger for Alice and saves it: *) + let for_alice = Logger.sub root "alice" in + let uri = Persistence.save_exn for_alice in + Capnp_rpc_unix.Cap_file.save_uri uri "alice.cap" |> or_fail; + (* Alice uses it: *) + run_client ~sw ~net "alice.cap"; + raise Exit ``` If successful, the client can use this sturdy ref to connect directly to the logger in future: @@ -951,6 +960,8 @@ If successful, the client can use this sturdy ref to connect directly to the log $ dune exec ./main.exe [server] "root" says "Message from Admin" [server] "root/alice" says "Message from Alice" +Fatal error: exception Exit +[2] ``` If you try the above, it will fail with `Unimplemented: Unknown interface 14468694717054801553UL`. @@ -1007,7 +1018,7 @@ include Restorer.LOADER type loader = [`Logger_beacebd78653e9af] Sturdy_ref.t -> label:string -> Restorer.resolution (** A function to create a new in-memory logger with the given label and sturdy-ref. *) -val create : make_sturdy:(Restorer.Id.t -> Uri.t) -> string -> t * loader Lwt.u +val create : make_sturdy:(Restorer.Id.t -> Uri.t) -> string -> t * loader Eio.Promise.u (** [create ~make_sturdy dir] is a database that persists services in [dir] and a resolver to let you set the loader (we're not ready to set the loader when we create the database). *) @@ -1040,7 +1051,7 @@ We can use this with `File_store` to implement `Db`: ```ocaml -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt open Capnp_rpc_net @@ -1051,7 +1062,7 @@ type loader = [`Logger_beacebd78653e9af] Sturdy_ref.t -> label:string -> Restore type t = { store : Store.Reader.SavedService.struct_t File_store.t; - loader : loader Lwt.t; + loader : loader Promise.t; make_sturdy : Restorer.Id.t -> Uri.t; } @@ -1074,16 +1085,16 @@ let save_new t ~label = let load t sr digest = match File_store.load t.store ~digest with - | None -> Lwt.return Restorer.unknown_service_id + | None -> Restorer.unknown_service_id | Some saved_service -> let logger = Store.Reader.SavedService.logger_get saved_service in let label = Store.Reader.SavedLogger.label_get logger in let sr = Capnp_rpc_lwt.Sturdy_ref.cast sr in - t.loader >|= fun loader -> + let loader = Promise.await t.loader in loader sr ~label let create ~make_sturdy dir = - let loader, set_loader = Lwt.wait () in + let loader, set_loader = Promise.create () in if not (Sys.file_exists dir) then Unix.mkdir dir 0o755; let store = File_store.create dir in {store; loader; make_sturdy}, set_loader @@ -1097,33 +1108,33 @@ The main `start_server` function then uses `Db` to create the table: ```ocaml let serve config = - Lwt_main.run begin - (* Create the on-disk store *) - let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in - let db, set_loader = Db.create ~make_sturdy "./store" in - (* Create the restorer *) - let services = Restorer.Table.of_loader (module Db) db in - let restore = Restorer.of_table services in - (* Add the root service *) - let persist_new ~label = - let id = Db.save_new db ~label in - Capnp_rpc_net.Restorer.restore restore id - in - let root_id = Capnp_rpc_unix.Vat_config.derived_id config "root" in - let root = - let sr = Capnp_rpc_net.Restorer.Table.sturdy_ref services root_id in - Logger.local ~persist_new sr "root" - in - Restorer.Table.add services root_id root; - (* Tell the database how to restore saved loggers *) - Lwt.wakeup set_loader (fun sr ~label -> Restorer.grant @@ Logger.local ~persist_new sr label); - (* Run the server *) - Capnp_rpc_unix.serve config ~restore >>= fun _vat -> - let uri = Capnp_rpc_unix.Vat_config.sturdy_uri config root_id in - Capnp_rpc_unix.Cap_file.save_uri uri "admin.cap" |> or_fail; - print_endline "Wrote admin.cap"; - fst @@ Lwt.wait () (* Wait forever *) - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + (* Create the on-disk store *) + let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in + let db, set_loader = Db.create ~make_sturdy "./store" in + (* Create the restorer *) + let services = Restorer.Table.of_loader ~sw (module Db) db in + let restore = Restorer.of_table services in + (* Add the root service *) + let persist_new ~label = + let id = Db.save_new db ~label in + Capnp_rpc_net.Restorer.restore restore id + in + let root_id = Capnp_rpc_unix.Vat_config.derived_id config "root" in + let root = + let sr = Capnp_rpc_net.Restorer.Table.sturdy_ref services root_id in + Logger.local ~persist_new sr "root" + in + Restorer.Table.add services root_id root; + (* Tell the database how to restore saved loggers *) + Promise.resolve set_loader (fun sr ~label -> Restorer.grant @@ Logger.local ~persist_new sr label); + (* Run the server *) + let _vat = Capnp_rpc_unix.serve ~sw ~net:env#net ~restore config in + let uri = Capnp_rpc_unix.Vat_config.sturdy_uri config root_id in + Capnp_rpc_unix.Cap_file.save_uri uri "admin.cap" |> or_fail; + print_endline "Wrote admin.cap"; + Fiber.await_cancel () ``` The server implementation of the `sub` method gets the label from the parameters @@ -1136,14 +1147,13 @@ and uses `persist_new` to save the new logger to the database: let sub_label = Params.label_get params in release_param_caps (); let label = Printf.sprintf "%s/%s" label sub_label in - Service.return_lwt @@ fun () -> - persist_new ~label >|= function - | Error e -> Error (`Capnp (`Exception e)) + match persist_new ~label with + | Error e -> Service.error (`Exception e) | Ok logger -> let response, results = Service.Response.create Results.init_pointer in Results.logger_set results (Some logger); Capability.dec_ref logger; - Ok response + Service.return response ``` @@ -1257,12 +1267,12 @@ The solution here is to construct `Frontend` with a *promise* for the sturdy ref ```ocaml -let run_frontend backend_uri = - let backend_promise, resolver = Lwt.wait () in +let run_frontend ~sw ~net backend_uri = + let backend_promise, resolver = Promise.create () in let frontend = Frontend.make backend_promise in let restore = Restorer.single id frontend in - Capnp_rpc_unix.serve config ~restore >|= fun vat -> - Lwt.wakeup resolver (Capnp_rpc_unix.Vat.import_exn vat backend_uri) + let vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in + Promise.resolve resolver (Capnp_rpc_unix.Vat.import_exn vat backend_uri) ``` ### How can I release other resources when my service is released? @@ -1351,58 +1361,6 @@ parent: application: Waiting for child to exit... parent: application: Done ``` -### How can I use this with Mirage? - -Note: `capnp` uses the `stdint` library, which has C stubs and -[might need patching](https://github.com/mirage/mirage/issues/885) to work with the Xen backend. - explains why OCaml doesn't have unsigned integer support. - -Here is a suitable `config.ml`: - - -```ocaml -open Mirage - -let main = - foreign - ~packages:[package "capnp-rpc-mirage"; package "mirage-dns"] - "Unikernel.Make" (random @-> mclock @-> stackv4 @-> job) - -let stack = generic_stackv4 default_network - -let () = - register "test" [main $ default_random $ default_monotonic_clock $ stack] -``` - -This should work as the `unikernel.ml`: - - -```ocaml -open Lwt.Infix - -open Capnp_rpc_lwt - -module Make (R : Mirage_random.S) (C : Mirage_clock.MCLOCK) (Stack : Mirage_stack.V4) = struct - module Mirage_capnp = Capnp_rpc_mirage.Make (R) (C) (Stack) - - let secret_key = `Ephemeral - - let listen_address = `TCP 7000 - let public_address = `TCP ("localhost", 7000) - - let start () () stack = - let dns = Mirage.Network.Dns.create stack in - let net = Mirage_capnp.network ~dns stack in - let config = Mirage_capnp.Vat_config.create ~secret_key ~public_address listen_address in - let service_id = Mirage_capnp.Vat_config.derived_id config "main" in - let restore = Restorer.single service_id Echo.local in - Mirage_capnp.serve net config ~restore >>= fun vat -> - let uri = Mirage_capnp.Vat.sturdy_uri vat service_id in - Logs.app (fun f -> f "Main service: %a" Uri.pp_hum uri); - Lwt.wait () |> fst -end -``` - ## Contributing ### Conceptual model diff --git a/capnp-rpc-lwt.opam b/capnp-rpc-lwt.opam index bd843ee98..0aed78591 100644 --- a/capnp-rpc-lwt.opam +++ b/capnp-rpc-lwt.opam @@ -15,16 +15,26 @@ depends: [ "conf-capnproto" {build} "capnp" {>= "3.4.0"} "capnp-rpc" {= version} - "lwt" "astring" "fmt" {>= "0.8.7"} "logs" "asetmap" "uri" {>= "1.6.0"} "dune" {>= "3.0"} + "eio" ] build: [ ["dune" "build" "-p" name "-j" jobs] ["dune" "runtest" "-p" name "-j" jobs] {with-test} ] dev-repo: "git+https://github.com/mirage/capnp-rpc.git" +pin-depends: [ + ["eio.dev" "git+https://github.com/ocaml-multicore/eio.git#2aadabf4249a0fb14f0f70a06542103fc07ff08f"] + ["eio_linux.dev" "git+https://github.com/ocaml-multicore/eio.git#2aadabf4249a0fb14f0f70a06542103fc07ff08f"] + ["eio_luv.dev" "git+https://github.com/ocaml-multicore/eio.git#2aadabf4249a0fb14f0f70a06542103fc07ff08f"] + ["eio_main.dev" "git+https://github.com/ocaml-multicore/eio.git#2aadabf4249a0fb14f0f70a06542103fc07ff08f"] + ["base.v0.14" "git+https://github.com/kit-ty-kate/base.git#a3cf0042e943c9c979ff7424912c71c6236f68f3"] + ["stdint.0.7.0.1~alpha-repo" "git+https://github.com/kit-ty-kate/ocaml-stdint.git#cb95ca6dff6bd58aa555b872a5db6558837d52db"] + ["sexplib.v0.14.1~alpha-repo" "git+https://github.com/kit-ty-kate/sexplib.git#864ef15e927761b4d98d194b638293a7ca2fbff3"] + ["sexplib0.v0.14.1~alpha-repo" "git+https://github.com/kit-ty-kate/sexplib0.git#f13a9b23e7f0a68d9e5d81af30c35cbd419f8b25"] +] diff --git a/capnp-rpc-lwt/capability.ml b/capnp-rpc-lwt/capability.ml index 1492cc0a8..07702829a 100644 --- a/capnp-rpc-lwt/capability.ml +++ b/capnp-rpc-lwt/capability.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_core module Log = Capnp_rpc.Debug.Log @@ -14,9 +14,9 @@ let inc_ref = Core_types.inc_ref let dec_ref = Core_types.dec_ref let with_ref t fn = - Lwt.finalize + Fun.protect (fun () -> fn t) - (fun () -> dec_ref t; Lwt.return_unit) + ~finally:(fun () -> dec_ref t) let pp f x = x#pp f @@ -26,10 +26,10 @@ let when_released (x:Core_types.cap) f = x#when_released f let problem x = x#problem let wait_until_settled (x : _ t) = - let result, set_result = Lwt.wait () in + let result, set_result = Promise.create () in let rec aux x = if x#blocker = None then ( - Lwt.wakeup set_result () + Promise.resolve set_result () ) else ( x#when_more_resolved (fun x -> Core_types.dec_ref x; @@ -38,16 +38,16 @@ let wait_until_settled (x : _ t) = ) in aux x; - result + Promise.await result let await_settled t = - wait_until_settled t >|= fun () -> + wait_until_settled t; match problem t with | None -> Ok () | Some ex -> Error ex let await_settled_exn t = - wait_until_settled t >|= fun () -> + wait_until_settled t; match problem t with | None -> () | Some e -> Fmt.failwith "%a" Capnp_rpc.Exception.pp e @@ -72,31 +72,33 @@ let call (target : 't capability_t) (m : ('t, 'a, 'b) method_t) (req : 'a Reques results let call_and_wait cap (m : ('t, 'a, 'b StructStorage.reader_t) method_t) req = - let p, r = Lwt.task () in + let p, r = Promise.create () in let result = call cap m req in let finish = lazy (Core_types.dec_ref result) in - Lwt.on_cancel p (fun () -> Lazy.force finish); result#when_resolved (function - | Error e -> Lwt.wakeup r (Error (`Capnp e)) + | Error e -> Promise.resolve_error r (`Capnp e) | Ok resp -> Lazy.force finish; let payload = Msg.Response.readable resp in let release_response_caps () = Core_types.Response_payload.release resp in let contents = Schema.Reader.Payload.content_get payload |> Schema.Reader.of_pointer in - Lwt.wakeup r @@ Ok (contents, release_response_caps) + Promise.resolve_ok r (contents, release_response_caps) ); - p + try Promise.await p + with ex -> + Lazy.force finish; + raise ex let call_for_value cap m req = - call_and_wait cap m req >|= function + match call_and_wait cap m req with | Error _ as response -> response | Ok (response, release_response_caps) -> release_response_caps (); Ok response let call_for_value_exn cap m req = - call_for_value cap m req >>= function - | Ok x -> Lwt.return x + match call_for_value cap m req with + | Ok x -> x | Error (`Capnp e) -> Log.debug (fun f -> f "Error calling %t(%a): %a" cap#pp @@ -105,11 +107,11 @@ let call_for_value_exn cap m req = Fmt.failwith "%a: %a" Capnp.RPC.MethodID.pp m Capnp_rpc.Error.pp e let call_for_unit cap m req = - call_for_value cap m req >|= function + match call_for_value cap m req with | Ok _ -> Ok () | Error _ as e -> e -let call_for_unit_exn cap m req = call_for_value_exn cap m req >|= ignore +let call_for_unit_exn cap m req = call_for_value_exn cap m req |> ignore let call_for_caps cap m req fn = let q = call cap m req in diff --git a/capnp-rpc-lwt/capnp_core.ml b/capnp-rpc-lwt/capnp_core.ml index e859526e1..b83490c62 100644 --- a/capnp-rpc-lwt/capnp_core.ml +++ b/capnp-rpc-lwt/capnp_core.ml @@ -1,13 +1,10 @@ -open Lwt.Infix - module Capnp_content = struct include Msg let ref_leak_detected fn = - Lwt.async (fun () -> - Lwt.pause () >|= fun () -> - fn () - ) + (* XXX: need to call [fn] in the leaked cap's domain, from + the event loop. *) + fn () end module Core_types = Capnp_rpc.Core_types(Capnp_content) @@ -19,6 +16,6 @@ module type ENDPOINT = Capnp_rpc.Message_types.ENDPOINT with module Core_types = Core_types class type sturdy_ref = object - method connect : (Core_types.cap, Capnp_rpc.Exception.t) result Lwt.t + method connect : (Core_types.cap, Capnp_rpc.Exception.t) result method to_uri_with_secrets : Uri.t end diff --git a/capnp-rpc-lwt/capnp_rpc_lwt.mli b/capnp-rpc-lwt/capnp_rpc_lwt.mli index cce83768b..ffd58427a 100644 --- a/capnp-rpc-lwt/capnp_rpc_lwt.mli +++ b/capnp-rpc-lwt/capnp_rpc_lwt.mli @@ -1,4 +1,4 @@ -(** Cap'n Proto RPC using the Cap'n Proto serialisation and Lwt for concurrency. *) +(** Cap'n Proto RPC using the Cap'n Proto serialisation and Eio for concurrency. *) open Capnp.RPC @@ -55,7 +55,7 @@ module Capability : sig believed to be healthy. Once a capability is broken, it will never work again and any calls made on it will fail with exception [ex]. *) - val await_settled : 'a t -> (unit, Capnp_rpc.Exception.t) Lwt_result.t + val await_settled : 'a t -> (unit, Capnp_rpc.Exception.t) result (** [await_settled t] resolves once [t] is a "settled" (non-promise) reference. If [t] is a near, far or broken reference, this returns immediately. If it is currently a local or remote promise, it waits until it isn't. @@ -64,13 +64,10 @@ module Capability : sig @return [Ok ()] on success, or [Error _] if [t] failed. @since 1.2 *) - val await_settled_exn : 'a t -> unit Lwt.t + val await_settled_exn : 'a t -> unit (** Like [await_settled], but raises an exception on error. @since 1.2 *) - val wait_until_settled : 'a t -> unit Lwt.t - [@@deprecated "Use await_settled instead."] - val equal : 'a t -> 'a t -> (bool, [`Unsettled]) result (** [equal a b] indicates whether [a] and [b] designate the same settled service. Returns [Error `Unsettled] if [a] or [b] is still a promise (and they therefore @@ -104,7 +101,7 @@ module Capability : sig instead for a simpler interface). *) val call_and_wait : 't t -> ('t, 'a, 'b StructStorage.reader_t) MethodID.t -> - 'a Request.t -> (('b StructStorage.reader_t * (unit -> unit)), [> `Capnp of Capnp_rpc.Error.t]) Lwt_result.t + 'a Request.t -> (('b StructStorage.reader_t * (unit -> unit)), [> `Capnp of Capnp_rpc.Error.t]) result (** [call_and_wait t m req] does [call t m req] and waits for the response. This is simpler than using [call], but doesn't support pipelining (you can't use any capabilities in the response in another message until the @@ -114,26 +111,25 @@ module Capability : sig contain that you didn't use (remembering that future versions of the protocol might add new optional capabilities you don't know about yet). If you don't need any capabilities from the result, consider using [call_for_value] instead. - Doing [Lwt.cancel] on the result will send a cancel message to the target - for remote calls. *) + Cancelling the fiber will send a cancel message to the target for remote calls. *) val call_for_value : 't t -> ('t, 'a, 'b StructStorage.reader_t) MethodID.t -> - 'a Request.t -> ('b StructStorage.reader_t, [> `Capnp of Capnp_rpc.Error.t]) Lwt_result.t + 'a Request.t -> ('b StructStorage.reader_t, [> `Capnp of Capnp_rpc.Error.t]) result (** [call_for_value t m req] is similar to [call_and_wait], but automatically releases any capabilities in the response before returning. Use this if you aren't expecting any capabilities in the response. *) val call_for_value_exn : 't t -> ('t, 'a, 'b StructStorage.reader_t) MethodID.t -> - 'a Request.t -> 'b StructStorage.reader_t Lwt.t - (** Wrapper for [call_for_value] that turns errors into Lwt failures. *) + 'a Request.t -> 'b StructStorage.reader_t + (** Wrapper for [call_for_value] that turns errors into exceptions. *) val call_for_unit : 't t -> ('t, 'a, 'b StructStorage.reader_t) MethodID.t -> - 'a Request.t -> (unit, [> `Capnp of Capnp_rpc.Error.t]) Lwt_result.t + 'a Request.t -> (unit, [> `Capnp of Capnp_rpc.Error.t]) result (** Wrapper for [call_for_value] that ignores the result. *) val call_for_unit_exn : 't t -> ('t, 'a, 'b StructStorage.reader_t) MethodID.t -> - 'a Request.t -> unit Lwt.t - (** Wrapper for [call_for_unit] that turns errors into Lwt failures. *) + 'a Request.t -> unit + (** Wrapper for [call_for_unit] that turns errors into exceptions. *) val call_for_caps : 't t -> ('t, 'a, 'b StructStorage.reader_t) MethodID.t -> 'a Request.t -> ('b StructRef.t -> 'c) -> 'c @@ -166,7 +162,7 @@ module Capability : sig peer. Any time you extract a capability from a struct or struct promise, it must eventually be freed by calling [dec_ref] on it. *) - val with_ref : 'a t -> ('a t -> 'b Lwt.t) -> 'b Lwt.t + val with_ref : 'a t -> ('a t -> 'b) -> 'b (** [with_ref t fn] runs [fn t] and then calls [dec_ref t] (whether [fn] succeeds or not). *) @@ -185,20 +181,20 @@ module Sturdy_ref : sig (e.g. a "Swiss number") *) - val connect : 'a t -> ('a Capability.t, Capnp_rpc.Exception.t) result Lwt.t + val connect : 'a t -> ('a Capability.t, Capnp_rpc.Exception.t) result (** [connect t] returns a live reference to [t]'s service. *) - val connect_exn : 'a t -> 'a Capability.t Lwt.t - (** [connect_exn] is a wrapper for [connect] that returns a failed Lwt thread on error. *) + val connect_exn : 'a t -> 'a Capability.t + (** [connect_exn] is a wrapper for [connect] that raises an exception on error. *) val with_cap : 'a t -> - ('a Capability.t -> ('b, [> `Capnp of Capnp_rpc.Exception.t] as 'e) Lwt_result.t) -> - ('b, 'e) Lwt_result.t + ('a Capability.t -> ('b, [> `Capnp of Capnp_rpc.Exception.t] as 'e) result) -> + ('b, 'e) result (** [with_cap t f] uses [connect t] to get a live-ref [x], then does [Capability.with_ref x f]. *) - val with_cap_exn : 'a t -> ('a Capability.t -> 'b Lwt.t) -> 'b Lwt.t + val with_cap_exn : 'a t -> ('a Capability.t -> 'b) -> 'b (** [with_cap_exn t f] uses [connect_exn t] to get a live-ref [x], then does [Capability.with_ref x f]. *) @@ -251,22 +247,10 @@ module Service : sig val return_empty : unit -> 'a StructRef.t (** [return_empty ()] is a promise for a response with no payload. *) - val return_lwt : (unit -> ('a Response.t, [< `Capnp of Capnp_rpc.Error.t]) Lwt_result.t) -> 'a StructRef.t - (** [return_lwt fn] is a local promise for the result of Lwt thread [fn ()]. - If [fn ()] fails, the error is logged and an "Internal error" returned to the caller. - If it returns an [Error] value then that error is returned to the caller. - Note that this does not support pipelining (any messages sent to the response - will be queued locally until it [fn] has produced a result), so it may be better - to return immediately a result containing a promise in some cases. *) - val fail : ?ty:Capnp_rpc.Exception.ty -> ('a, Format.formatter, unit, 'b StructRef.t) format4 -> 'a (** [fail msg] is an exception with reason [msg]. *) - val fail_lwt : - ?ty:Capnp_rpc.Exception.ty -> - ('a, Format.formatter, unit, (_, [> `Capnp of Capnp_rpc.Error.t]) Lwt_result.t) format4 -> - 'a - (** [fail_lwt msg] is like [fail msg], but can be used with [return_lwt]. *) + val error : Capnp_rpc.Error.t -> 'a StructRef.t end (**/**) @@ -314,7 +298,7 @@ end module Persistence : sig class type ['a] persistent = object - method save : ('a Sturdy_ref.t, Capnp_rpc.Exception.t) result Lwt.t + method save : ('a Sturdy_ref.t, Capnp_rpc.Exception.t) result end val with_persistence : @@ -333,11 +317,11 @@ module Persistence : sig (** [with_sturdy_ref sr Service.Foo.local obj] is like [Service.Foo.local obj], but responds to [save] calls by returning [sr]. *) - val save : 'a Capability.t -> (Uri.t, [> `Capnp of Capnp_rpc.Error.t]) Lwt_result.t + val save : 'a Capability.t -> (Uri.t, [> `Capnp of Capnp_rpc.Error.t]) result (** [save cap] calls the persistent [save] method on [cap]. Note that not all capabilities can be saved. todo: this should return an ['a Sturdy_ref.t]; see {!Sturdy_ref.reader}. *) - val save_exn : 'a Capability.t -> Uri.t Lwt.t + val save_exn : 'a Capability.t -> Uri.t (** [save_exn] is a wrapper for [save] that returns a failed thread on error. *) end diff --git a/capnp-rpc-lwt/dune b/capnp-rpc-lwt/dune index bf31965e4..bfd19e674 100644 --- a/capnp-rpc-lwt/dune +++ b/capnp-rpc-lwt/dune @@ -3,7 +3,7 @@ (public_name capnp-rpc-lwt) (ocamlc_flags :standard -w -55-53) (ocamlopt_flags :standard -w -55-53) - (libraries astring capnp capnp-rpc fmt logs lwt uri)) + (libraries astring capnp capnp-rpc fmt logs eio uri)) (rule (targets rpc_schema.ml rpc_schema.mli) diff --git a/capnp-rpc-lwt/persistence.ml b/capnp-rpc-lwt/persistence.ml index f01b50e60..9c90e5910 100644 --- a/capnp-rpc-lwt/persistence.ml +++ b/capnp-rpc-lwt/persistence.ml @@ -1,9 +1,7 @@ -open Lwt.Infix - module Api = Persistent.Make(Capnp.BytesMessage) class type ['a] persistent = object - method save : ('a Sturdy_ref.t, Capnp_rpc.Exception.t) result Lwt.t + method save : ('a Sturdy_ref.t, Capnp_rpc.Exception.t) result end let with_persistence @@ -16,13 +14,12 @@ let with_persistence if method_id = Capnp.RPC.MethodID.method_id Api.Client.Persistent.Save.method_id then ( let open Api.Service.Persistent.Save in release_params (); - Service.return_lwt @@ fun () -> - persistent#save >|= function - | Error e -> Error (`Capnp (`Exception e)) + match persistent#save with + | Error e -> Service.error (`Exception e) | Ok sr -> let resp, results = Service.Response.create Results.init_pointer in Sturdy_ref.builder Results.sturdy_ref_get results sr; - Ok resp + Service.return resp ) else ( release_params (); Service.fail ~ty:`Unimplemented "Unknown persistence method %d" method_id @@ -39,18 +36,18 @@ let with_persistence let with_sturdy_ref sr local impl = let persistent = object - method save = Lwt.return (Ok sr) + method save = Ok sr end in with_persistence persistent local impl let save cap = let open Api.Client.Persistent.Save in let request = Capability.Request.create_no_args () in - Capability.call_for_value cap method_id request >|= function + match Capability.call_for_value cap method_id request with | Error _ as e -> e | Ok response -> Ok (Sturdy_ref.reader Results.sturdy_ref_get response) let save_exn cap = - save cap >>= function - | Error (`Capnp e) -> Lwt.fail_with (Fmt.to_to_string Capnp_rpc.Error.pp e) - | Ok x -> Lwt.return x + match save cap with + | Error (`Capnp e) -> failwith (Fmt.to_to_string Capnp_rpc.Error.pp e) + | Ok x -> x diff --git a/capnp-rpc-lwt/service.ml b/capnp-rpc-lwt/service.ml index 1bb0ddcdf..dfbe50faf 100644 --- a/capnp-rpc-lwt/service.ml +++ b/capnp-rpc-lwt/service.ml @@ -1,5 +1,4 @@ open Capnp_core -open Lwt.Infix module Log = Capnp_rpc.Debug.Log @@ -61,27 +60,6 @@ let return resp = let return_empty () = return @@ Response.create_empty () -(* A convenient way to implement a simple blocking local function, where - pipelining is not supported (messages sent to the result promise will be - queued up at this host until it returns). *) -let return_lwt fn = - let result, resolver = Local_struct_promise.make () in - Lwt.async (fun () -> - Lwt.catch (fun () -> - fn () >|= function - | Ok resp -> Core_types.resolve_ok resolver @@ Response.finish resp; - | Error (`Capnp e) -> Core_types.resolve_payload resolver (Error e) - ) - (fun ex -> - Log.warn (fun f -> f "Uncaught exception: %a" Fmt.exn ex); - Core_types.resolve_exn resolver @@ Capnp_rpc.Exception.v "Internal error"; - Lwt.return_unit - ); - ); - result - let fail = Core_types.fail -let fail_lwt ?ty fmt = - fmt |> Fmt.kstr @@ fun msg -> - Lwt_result.fail (`Capnp (`Exception (Capnp_rpc.Exception.v ?ty msg))) +let error = Core_types.broken_struct diff --git a/capnp-rpc-lwt/sturdy_ref.ml b/capnp-rpc-lwt/sturdy_ref.ml index cc2a70b1d..fbf729b58 100644 --- a/capnp-rpc-lwt/sturdy_ref.ml +++ b/capnp-rpc-lwt/sturdy_ref.ml @@ -1,13 +1,11 @@ -open Lwt.Infix - class type [+'a] t = Capnp_core.sturdy_ref let connect t = t#connect let connect_exn t = - connect t >>= function - | Ok x -> Lwt.return x - | Error e -> Lwt.fail_with (Fmt.to_to_string Capnp_rpc.Exception.pp e) + match connect t with + | Ok x -> x + | Error e -> failwith (Fmt.to_to_string Capnp_rpc.Exception.pp e) let reader fn s = fn s |> Schema.ReaderOps.string_of_pointer |> Uri.of_string @@ -18,10 +16,10 @@ let builder fn (s : 'a Capnp.BytesMessage.StructStorage.builder_t) (sr : 'a t) = let cast t = t let with_cap t f = - connect t >>= function + match connect t with | Ok x -> Capability.with_ref x f - | Error e -> Lwt_result.fail (`Capnp e) + | Error e -> Error (`Capnp e) let with_cap_exn t f = - connect_exn t >>= fun x -> + let x = connect_exn t in Capability.with_ref x f diff --git a/capnp-rpc-mirage.opam b/capnp-rpc-mirage.opam deleted file mode 100644 index 81a95cc0a..000000000 --- a/capnp-rpc-mirage.opam +++ /dev/null @@ -1,36 +0,0 @@ -opam-version: "2.0" -synopsis: - "Cap'n Proto is a capability-based RPC system with bindings for many languages" -description: - "This package provides a version of the Cap'n Proto RPC system for use with MirageOS." -maintainer: "Thomas Leonard " -authors: "Thomas Leonard " -license: "Apache-2.0" -homepage: "https://github.com/mirage/capnp-rpc" -doc: "https://mirage.github.io/capnp-rpc/" -bug-reports: "https://github.com/mirage/capnp-rpc/issues" -depends: [ - "ocaml" {>= "4.08.0"} - "capnp" {>= "3.1.0"} - "capnp-rpc-net" {= version} - "fmt" {>= "0.8.7"} - "logs" - "dns-client" {>= "6.0.0"} - "tls-mirage" - "tcpip" {>= "7.0.0"} - "alcotest" {>= "1.0.1" & with-test} - "alcotest-lwt" {>= "1.0.1" & with-test} - "arp" {>= "3.0.0" & with-test} - "asetmap" {with-test} - "astring" {with-test} - "ethernet" {>= "3.0.0" & with-test} - "io-page-unix" {with-test} - "mirage-vnetif" {with-test} - "mirage-crypto-rng" {>= "0.7.0" & with-test} - "dune" {>= "3.0"} -] -build: [ - ["dune" "build" "-p" name "-j" jobs] - ["dune" "runtest" "-p" name "-j" jobs] {with-test} -] -dev-repo: "git+https://github.com/mirage/capnp-rpc.git" diff --git a/capnp-rpc-net.opam b/capnp-rpc-net.opam index 14e3af9d4..f13b11ccc 100644 --- a/capnp-rpc-net.opam +++ b/capnp-rpc-net.opam @@ -21,7 +21,6 @@ depends: [ "logs" "asetmap" "cstruct" {>= "6.0.0"} - "mirage-flow" {>= "2.0.0"} "tls" {>= "0.13.1"} "base64" {>= "3.0.0"} "uri" {>= "1.6.0"} diff --git a/capnp-rpc-net/capTP_capnp.ml b/capnp-rpc-net/capTP_capnp.ml index e879ea63b..cd815f973 100644 --- a/capnp-rpc-net/capTP_capnp.ml +++ b/capnp-rpc-net/capTP_capnp.ml @@ -1,5 +1,5 @@ open Capnp_rpc_lwt -open Lwt.Infix +open Eio.Std module Metrics = struct open Prometheus @@ -43,6 +43,7 @@ module Make (Network : S.NETWORK) = struct module Serialise = Serialise.Make(Endpoint_types) type t = { + sw : Switch.t; endpoint : Endpoint.t; conn : Conn.t; xmit_queue : Capnp.Message.rw Capnp.BytesMessage.Message.t Queue.t; @@ -51,16 +52,6 @@ module Make (Network : S.NETWORK) = struct let bootstrap t id = Conn.bootstrap t.conn id |> Cast.cap_of_raw - let async_tagged label fn = - Lwt.async - (fun () -> - Lwt.catch fn - (fun ex -> - Log.warn (fun f -> f "Uncaught async exception in %S: %a" label Fmt.exn ex); - Lwt.return_unit - ) - ) - let pp_msg f call = let open Reader in let call = Private.Msg.Request.readable call in @@ -76,30 +67,32 @@ module Make (Network : S.NETWORK) = struct (* [flush ~xmit_queue endpoint] writes each message in the queue until it is empty. Invariant: - Whenever Lwt blocks or switches threads, a flush thread is running iff the + Whenever Eio blocks or switches threads, a flush thread is running iff the queue is non-empty. *) let rec flush ~xmit_queue endpoint = (* We keep the item on the queue until it is transmitted, as the queue state tells us whether there is a [flush] currently running. *) let next = Queue.peek xmit_queue in - Endpoint.send endpoint next >>= function + match Endpoint.send endpoint next with | Error `Closed -> - Endpoint.disconnect endpoint >|= fun () -> (* We'll read a close soon *) + Endpoint.disconnect endpoint; (* We'll read a close soon *) drop_queue xmit_queue - | Error e -> - Log.warn (fun f -> f "Error sending messages: %a (will shutdown connection)" Endpoint.pp_error e); - Endpoint.disconnect endpoint >|= fun () -> + | Error (`Msg msg) -> + Log.warn (fun f -> f "Error sending messages: %s (will shutdown connection)" msg); + Endpoint.disconnect endpoint; drop_queue xmit_queue | Ok () -> Prometheus.Counter.inc_one Metrics.messages_outbound_sent_total; ignore (Queue.pop xmit_queue); if not (Queue.is_empty xmit_queue) then flush ~xmit_queue endpoint - else (* queue is empty and flush thread is done *) - Lwt.return_unit + (* else queue is empty and flush thread is done *) + | exception ex -> + drop_queue xmit_queue; + raise ex (* Enqueue [message] in [xmit_queue] and ensure the flush thread is running. *) - let queue_send ~xmit_queue endpoint message = + let queue_send ~sw ~xmit_queue endpoint message = Log.debug (fun f -> let module M = Capnp_rpc_lwt.Private.Schema.MessageWrapper.Message in f "queue_send: %d/%d allocated bytes in %d segs" @@ -109,19 +102,19 @@ module Make (Network : S.NETWORK) = struct let was_idle = Queue.is_empty xmit_queue in Queue.add message xmit_queue; Prometheus.Counter.inc_one Metrics.messages_outbound_enqueued_total; - if was_idle then async_tagged "Message sender thread" (fun () -> flush ~xmit_queue endpoint) + if was_idle then Eio.Fiber.fork ~sw (fun () -> flush ~xmit_queue endpoint) let return_not_implemented t x = Log.debug (fun f -> f ~tags:(tags t) "Returning Unimplemented"); let open Builder in let m = Message.init_root () in let _ : Builder.Message.t = Message.unimplemented_set_reader m x in - queue_send ~xmit_queue:t.xmit_queue t.endpoint (Message.to_message m) + queue_send ~sw:t.sw ~xmit_queue:t.xmit_queue t.endpoint (Message.to_message m) let listen t = let rec loop () = - Endpoint.recv t.endpoint >>= function - | Error e -> Lwt.return e + match Endpoint.recv t.endpoint with + | Error e -> e | Ok msg -> let open Reader.Message in let msg = of_message msg in @@ -135,8 +128,8 @@ module Make (Network : S.NETWORK) = struct | `Abort _ -> t.disconnecting <- true; Conn.handle_msg t.conn msg; - Endpoint.disconnect t.endpoint >>= fun () -> - Lwt.return `Aborted + Endpoint.disconnect t.endpoint; + `Aborted | _ -> Conn.handle_msg t.conn msg; loop () @@ -154,48 +147,52 @@ module Make (Network : S.NETWORK) = struct in loop () + let send_abort t ex = + queue_send ~sw:t.sw ~xmit_queue:t.xmit_queue t.endpoint (Serialise.message (`Abort ex)) + let disconnect t ex = if not t.disconnecting then ( t.disconnecting <- true; - queue_send ~xmit_queue:t.xmit_queue t.endpoint (Serialise.message (`Abort ex)); - Endpoint.disconnect t.endpoint >|= fun () -> + send_abort t ex; + Endpoint.disconnect t.endpoint; Conn.disconnect t.conn ex - ) else ( - Lwt.return_unit ) let disconnecting t = t.disconnecting - let connect ~restore ?(tags=Logs.Tag.empty) endpoint = + let connect ~sw ~restore ?(tags=Logs.Tag.empty) endpoint = let xmit_queue = Queue.create () in - let queue_send msg = queue_send ~xmit_queue endpoint (Serialise.message msg) in + let queue_send msg = queue_send ~sw ~xmit_queue endpoint (Serialise.message msg) in let restore = Restorer.fn restore in - let conn = Conn.create ~restore ~tags ~queue_send in - let t = { + let fork = Fiber.fork ~sw in + let conn = Conn.create ~restore ~tags ~fork ~queue_send in + { + sw; conn; endpoint; xmit_queue; disconnecting = false; - } in + } + + let listen t = Prometheus.Gauge.inc_one Metrics.connections; - Lwt.async (fun () -> - Lwt.catch - (fun () -> - listen t >|= fun (`Closed | `Aborted) -> () - ) - (fun ex -> - Log.warn (fun f -> - f ~tags "Uncaught exception handling CapTP connection: %a (dropping connection)" Fmt.exn ex - ); - queue_send @@ `Abort (Capnp_rpc.Exception.v ~ty:`Failed (Printexc.to_string ex)); - Lwt.return_unit - ) - >>= fun () -> - Log.info (fun f -> f ~tags "Connection closed"); - Prometheus.Gauge.dec_one Metrics.connections; + let tags = Conn.tags t.conn in + begin + match listen t with + | `Closed | `Aborted -> () + | exception Eio.Cancel.Cancelled ex -> + Log.debug (fun f -> f ~tags "Cancelled: %a" Fmt.exn ex) + | exception ex -> + Log.warn (fun f -> + f ~tags "Uncaught exception handling CapTP connection: %a (dropping connection)" Fmt.exn ex + ); + send_abort t (Capnp_rpc.Exception.v ~ty:`Failed (Printexc.to_string ex)) + end; + Log.info (fun f -> f ~tags "Connection closed"); + Prometheus.Gauge.dec_one Metrics.connections; + Eio.Cancel.protect (fun () -> disconnect t (Capnp_rpc.Exception.v ~ty:`Disconnected "Connection closed") - ); - t + ) let dump f t = Conn.dump f t.conn end diff --git a/capnp-rpc-net/capTP_capnp.mli b/capnp-rpc-net/capTP_capnp.mli index ec3b3914c..6160955b4 100644 --- a/capnp-rpc-net/capTP_capnp.mli +++ b/capnp-rpc-net/capTP_capnp.mli @@ -4,17 +4,22 @@ module Make (N : S.NETWORK) : sig type t (** A Cap'n Proto RPC protocol handler. *) - val connect : restore:Restorer.t -> ?tags:Logs.Tag.set -> Endpoint.t -> t - (** [connect ~restore ~switch endpoint] is fresh CapTP protocol handler that sends and + val connect : sw:Eio.Switch.t -> restore:Restorer.t -> ?tags:Logs.Tag.set -> Endpoint.t -> t + (** [connect ~sw ~restore ~switch endpoint] is fresh CapTP protocol handler that sends and receives messages using [endpoint]. [restore] is used to respond to "Bootstrap" messages. - If the connection fails then [endpoint] will be disconnected. *) + If the connection fails then [endpoint] will be disconnected. + You must call {!listen} to run the loop handling messages. + @param sw Used to run methods and to run the transmit thread. *) + + val listen : t -> unit + (** [listen t] reads and handles incoming messages until the connection is finished. *) val bootstrap : t -> string -> 'a Capnp_rpc_lwt.Capability.t (** [bootstrap t object_id] is the peer's bootstrap object [object_id], if any. Use [object_id = ""] for the main, public object. *) - val disconnect : t -> Capnp_rpc.Exception.t -> unit Lwt.t + val disconnect : t -> Capnp_rpc.Exception.t -> unit (** [disconnect t reason] releases all resources used by the connection. *) val disconnecting : t -> bool diff --git a/capnp-rpc-net/capnp_rpc_net.ml b/capnp-rpc-net/capnp_rpc_net.ml index 21a1045da..a3d8354b9 100644 --- a/capnp-rpc-net/capnp_rpc_net.ml +++ b/capnp-rpc-net/capnp_rpc_net.ml @@ -11,11 +11,9 @@ module type VAT_NETWORK = S.VAT_NETWORK with type service_id := Restorer.Id.t and type 'a sturdy_ref := 'a Sturdy_ref.t -module Networking (N : S.NETWORK) (F : Mirage_flow.S) = struct - type flow = F.flow - +module Networking (N : S.NETWORK) = struct module Network = N - module Vat = Vat.Make (N) (F) + module Vat = Vat.Make (N) module CapTP = Vat.CapTP end diff --git a/capnp-rpc-net/capnp_rpc_net.mli b/capnp-rpc-net/capnp_rpc_net.mli index 5ad7a3f3a..c3637a07c 100644 --- a/capnp-rpc-net/capnp_rpc_net.mli +++ b/capnp-rpc-net/capnp_rpc_net.mli @@ -90,7 +90,7 @@ module Restorer : sig (** [make_sturdy t id] converts an ID to a full URI, by adding the hosting vat's address and fingerprint. *) - val load : t -> 'a Sturdy_ref.t -> string -> resolution Lwt.t + val load : t -> 'a Sturdy_ref.t -> string -> resolution (** [load t sr digest] is called to restore the service with key [digest]. [sr] is a sturdy ref that refers to the service, which the service might want to hand out to clients. @@ -109,9 +109,10 @@ module Restorer : sig [make_sturdy id] converts an ID to a full URI, by adding the hosting vat's address and fingerprint. *) - val of_loader : (module LOADER with type t = 'loader) -> 'loader -> t - (** [of_loader (module Loader) l] is a new caching table that uses - [Loader.load l sr (Loader.hash id)] to restore services that aren't in the cache. *) + val of_loader : sw:Eio.Switch.t -> (module LOADER with type t = 'loader) -> 'loader -> t + (** [of_loader ~sw (module Loader) l] is a new caching table that uses + [Loader.load l sr (Loader.hash id)] to restore services that aren't in the cache. + The load function runs in a new fiber in [sw]. *) val add : t -> Id.t -> 'a Capability.t -> unit (** [add t id cap] adds a mapping to [t]. @@ -130,7 +131,7 @@ module Restorer : sig val of_table : Table.t -> t - val restore : t -> Id.t -> ('a Capability.t, Capnp_rpc.Exception.t) result Lwt.t + val restore : t -> Id.t -> ('a Capability.t, Capnp_rpc.Exception.t) result (** [restore t id] restores [id] using [t]. You don't normally need to call this directly, as the Vat will do it automatically. *) end @@ -141,8 +142,7 @@ module type VAT_NETWORK = S.VAT_NETWORK with type service_id := Restorer.Id.t and type 'a sturdy_ref := 'a Sturdy_ref.t -module Networking (N : S.NETWORK) (Flow : Mirage_flow.S) : VAT_NETWORK with - module Network = N and - type flow = Flow.flow +module Networking (N : S.NETWORK) : VAT_NETWORK with + module Network = N module Capnp_address = Capnp_address diff --git a/capnp-rpc-net/endpoint.ml b/capnp-rpc-net/endpoint.ml index 1ae0e03de..8918127c7 100644 --- a/capnp-rpc-net/endpoint.ml +++ b/capnp-rpc-net/endpoint.ml @@ -1,5 +1,3 @@ -open Lwt.Infix - let src = Logs.Src.create "endpoint" ~doc:"Send and receive Cap'n'Proto messages" module Log = (val Logs.src_log src: Logs.LOG) @@ -7,21 +5,20 @@ let compression = `None let record_sent_messages = false -type flow = Flow : (module Mirage_flow.S with type flow = 'a) * 'a -> flow +type flow = Eio.Flow.two_way type t = { flow : flow; decoder : Capnp.Codecs.FramedStream.t; - switch : Lwt_switch.t; peer_id : Auth.Digest.t; } let peer_id t = t.peer_id -let of_flow (type flow) ~switch ~peer_id (module F : Mirage_flow.S with type flow = flow) (flow:flow) = - let generic_flow = Flow ((module F), flow) in +let of_flow ~peer_id flow = let decoder = Capnp.Codecs.FramedStream.empty compression in - { flow = generic_flow; decoder; switch; peer_id } + let flow = (flow :> Eio.Flow.two_way) in + { flow; decoder; peer_id } let dump_msg = let next = ref 0 in @@ -34,37 +31,36 @@ let dump_msg = close_out ch let send t msg = - let (Flow ((module F), flow)) = t.flow in let data = Capnp.Codecs.serialize ~compression msg in if record_sent_messages then dump_msg data; - F.write flow (Cstruct.of_string data) >|= function - | Ok () - | Error `Closed as e -> e - | Error e -> Error (`Msg (Fmt.to_to_string F.pp_write_error e)) + match Eio.Flow.copy_string data t.flow with + | () + | exception End_of_file -> Ok () + | exception Eio.Net.Connection_reset ex -> + Log.info (fun f -> f "Connection reset: %a" Fmt.exn ex); + Error `Closed + | exception ex -> + Eio.Fiber.check (); + Error (`Msg (Printexc.to_string ex)) let rec recv t = - let (Flow ((module F), flow)) = t.flow in match Capnp.Codecs.FramedStream.get_next_frame t.decoder with - | _ when not (Lwt_switch.is_on t.switch) -> Lwt.return @@ Error `Closed - | Ok msg -> Lwt.return (Ok (Capnp.BytesMessage.Message.readonly msg)) + | Ok msg -> Ok (Capnp.BytesMessage.Message.readonly msg) | Error Capnp.Codecs.FramingError.Unsupported -> failwith "Unsupported Cap'n'Proto frame received" | Error Capnp.Codecs.FramingError.Incomplete -> Log.debug (fun f -> f "Incomplete; waiting for more data..."); - F.read flow >>= function - | Ok (`Data data) -> - Log.debug (fun f -> f "Read %d bytes" (Cstruct.length data)); - Capnp.Codecs.FramedStream.add_fragment t.decoder (Cstruct.to_string data); + let buf = Cstruct.create 4096 in (* TODO: make this efficient *) + match Eio.Flow.read t.flow buf with + | got -> + Log.debug (fun f -> f "Read %d bytes" got); + Capnp.Codecs.FramedStream.add_fragment t.decoder (Cstruct.to_string buf ~len:got); recv t - | Ok `Eof -> + | exception End_of_file -> Log.info (fun f -> f "Connection closed"); - Lwt_switch.turn_off t.switch >|= fun () -> Error `Closed - | Error ex when Lwt_switch.is_on t.switch -> Capnp_rpc.Debug.failf "recv: %a" F.pp_error ex - | Error _ -> Lwt.return (Error `Closed) + | exception Eio.Net.Connection_reset ex -> + Log.info (fun f -> f "Connection reset: %a" Fmt.exn ex); + Error `Closed let disconnect t = - Lwt_switch.turn_off t.switch - -let pp_error f = function - | `Closed -> Fmt.string f "Connection closed" - | `Msg m -> Fmt.string f m + Eio.Flow.shutdown t.flow `All diff --git a/capnp-rpc-net/endpoint.mli b/capnp-rpc-net/endpoint.mli index 1fe0b9e01..23f366cce 100644 --- a/capnp-rpc-net/endpoint.mli +++ b/capnp-rpc-net/endpoint.mli @@ -6,27 +6,19 @@ val src : Logs.src type t (** A wrapper for a byte-stream (flow). *) -val send : t -> 'a Capnp.BytesMessage.Message.t -> (unit, [`Closed | `Msg of string]) result Lwt.t +val send : t -> 'a Capnp.BytesMessage.Message.t -> (unit, [`Closed | `Msg of string]) result (** [send t msg] transmits [msg]. *) -val recv : t -> (Capnp.Message.ro Capnp.BytesMessage.Message.t, [> `Closed]) result Lwt.t +val recv : t -> (Capnp.Message.ro Capnp.BytesMessage.Message.t, [> `Closed]) result (** [recv t] reads the next message from the remote peer. - It returns [Error `Closed] if the connection to the peer is lost - (this will also happen if the switch is turned off). *) + It returns [Error `Closed] if the connection to the peer is lost. *) -val of_flow : switch:Lwt_switch.t -> peer_id:Auth.Digest.t -> - (module Mirage_flow.S with type flow = 'flow) -> 'flow -> t -(** [of_flow ~switch ~peer_id (module F) flow] sends and receives on [flow]. - The caller should arrange for [flow] to be closed when the switch is turned off. - If the flow is closed, the switch will be turned off. - If the flow returns an error when the switch is off, the endpoint will return [`Closed] - instead of the underlying error. *) +val of_flow : peer_id:Auth.Digest.t -> #Eio.Flow.two_way -> t +(** [of_flow ~peer_id flow] sends and receives on [flow]. *) val peer_id : t -> Auth.Digest.t (** [peer_id t] is the fingerprint of the peer's public key, - or [Auth.Digest.insecure] TLS isn't being used. *) + or [Auth.Digest.insecure] if TLS isn't being used. *) -val disconnect : t -> unit Lwt.t -(** [disconnect t] turns off [t]'s switch. *) - -val pp_error : [< `Closed | `Msg of string] Fmt.t +val disconnect : t -> unit +(** [disconnect t] shuts down the underlying flow. *) diff --git a/capnp-rpc-net/restorer.ml b/capnp-rpc-net/restorer.ml index 41198fdf4..f0a8e8fcd 100644 --- a/capnp-rpc-net/restorer.ml +++ b/capnp-rpc-net/restorer.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt module Core_types = Private.Capnp_core.Core_types @@ -33,10 +33,10 @@ module type LOADER = sig type t val hash : t -> Auth.hash val make_sturdy : t -> Id.t -> Uri.t - val load : t -> 'a Sturdy_ref.t -> string -> resolution Lwt.t + val load : t -> 'a Sturdy_ref.t -> string -> resolution end -type t = Id.t -> resolution Lwt.t +type t = Id.t -> resolution let grant x : resolution = Ok (Cast.cap_to_raw x) let reject ex = Error ex @@ -45,21 +45,16 @@ let unknown_service_id = reject (Capnp_rpc.Exception.v "Unknown persistent servi let fn (r:t) = fun k object_id -> - Lwt.async (fun () -> - Lwt.try_bind - (fun () -> r object_id) - (fun r -> k r; Lwt.return_unit) - (fun ex -> - Log.err (fun f -> f "Uncaught exception restoring object: %a" Fmt.exn ex); - k (reject (Capnp_rpc.Exception.v "Internal error restoring object")); - Lwt.return_unit - ) - ) - -let restore (f:t) x = f x |> Lwt_result.map Cast.cap_of_raw + match r object_id with + | r -> k r + | exception ex -> + Log.err (fun f -> f "Uncaught exception restoring object: %a" Fmt.exn ex); + k (reject (Capnp_rpc.Exception.v "Internal error restoring object")) + +let restore (f:t) x = f x |> Result.map Cast.cap_of_raw let none : t = fun _ -> - Lwt.return @@ Error (Capnp_rpc.Exception.v "This vat has no restorer") + Error (Capnp_rpc.Exception.v "This vat has no restorer") let single id cap = let cap = Cast.cap_to_raw cap in @@ -69,20 +64,20 @@ let single id cap = let requested_id = Mirage_crypto.Hash.digest `SHA256 (Cstruct.of_string requested_id) in if Cstruct.equal id requested_id then ( Core_types.inc_ref cap; - Lwt.return (Ok cap) - ) else Lwt.return unknown_service_id + Ok cap + ) else unknown_service_id module Table = struct type digest = string type entry = - | Cached of resolution Lwt.t + | Cached of resolution Promise.or_exn | Manual of Core_types.cap (* We hold a ref on the cap *) type t = { hash : Mirage_crypto.Hash.hash; cache : (digest, entry) Hashtbl.t; - load : Id.t -> digest -> resolution Lwt.t; + load : Id.t -> digest -> resolution Promise.or_exn; make_sturdy : Id.t -> Uri.t; } @@ -91,7 +86,7 @@ module Table = struct let create make_sturdy = let hash = `SHA256 in let cache = Hashtbl.create 53 in - let load _ _ = Lwt.return unknown_service_id in + let load _ _ = Promise.create_resolved (Ok unknown_service_id) in { hash; cache; load; make_sturdy } let hash t id = @@ -102,43 +97,42 @@ module Table = struct match Hashtbl.find t.cache digest with | Manual cap -> Core_types.inc_ref cap; - Lwt.return @@ Ok cap + Ok cap | Cached res -> - begin res >>= function - | Error _ as e -> Lwt.return e + begin match Promise.await_exn res with + | Error _ as e -> e | Ok cap -> Core_types.inc_ref cap; - Lwt.pause () >|= fun () -> + Fiber.yield (); Ok cap end | exception Not_found -> let cap = t.load id digest in Hashtbl.add t.cache digest (Cached cap); - Lwt.try_bind - (fun () -> cap) - (fun result -> - begin match result with - | Error _ -> Hashtbl.remove t.cache digest - | Ok cap -> cap#when_released (fun () -> Hashtbl.remove t.cache digest) - end; - (* Ensure all [inc_ref]s are done before handing over to the user. *) - Lwt.pause () >|= fun () -> - result - ) - (fun ex -> - Hashtbl.remove t.cache digest; - Lwt.fail ex - ) - - let of_loader (type l) (module L : LOADER with type t = l) loader = + match Promise.await_exn cap with + | result -> + begin match result with + | Error _ -> Hashtbl.remove t.cache digest + | Ok cap -> cap#when_released (fun () -> Hashtbl.remove t.cache digest) + end; + (* Ensure all [inc_ref]s are done before handing over to the user. *) + Fiber.yield (); + result + | exception ex -> + Hashtbl.remove t.cache digest; + raise ex + + let of_loader (type l) ~sw (module L : LOADER with type t = l) loader = let hash = (L.hash loader :> Mirage_crypto.Hash.hash) in let cache = Hashtbl.create 53 in let rec load id digest = - let sr : Private.Capnp_core.sturdy_ref = object - method connect = resolve t id - method to_uri_with_secrets = L.make_sturdy loader id - end in - L.load loader (Cast.sturdy_of_raw sr) digest + Fiber.fork_promise ~sw (fun () -> + let sr : Private.Capnp_core.sturdy_ref = object + method connect = resolve t id + method to_uri_with_secrets = L.make_sturdy loader id + end in + L.load loader (Cast.sturdy_of_raw sr) digest + ) and t = { hash; cache; load; make_sturdy = L.make_sturdy loader } in t diff --git a/capnp-rpc-net/s.ml b/capnp-rpc-net/s.ml index 261ec0df6..bda4ddec5 100644 --- a/capnp-rpc-net/s.ml +++ b/capnp-rpc-net/s.ml @@ -29,15 +29,14 @@ module type NETWORK = sig val connect : t -> - switch:Lwt_switch.t -> + sw:Eio.Switch.t -> secret_key:Auth.Secret_key.t Lazy.t -> Address.t -> - (Endpoint.t, [> `Msg of string]) result Lwt.t - (** [connect t ~switch ~secret_key address] connects to [address], proves ownership of + (Endpoint.t, [> `Msg of string]) result + (** [connect t ~sw ~secret_key address] connects to [address], proves ownership of [secret_key] (if TLS is being used), and returns the resulting endpoint. Returns an error if no connection can be established or the target fails - to authenticate itself. - If [switch] is turned off, the connection should be terminated. *) + to authenticate itself. *) val parse_third_party_cap_id : Private.Schema.Reader.pointer_t -> Types.third_party_cap_id end @@ -49,9 +48,6 @@ module type VAT_NETWORK = sig type +'a capability (** An ['a capability] is a capability reference to a service of type ['a]. *) - type flow - (** A bi-directional byte-stream. *) - type restorer (** A function for restoring persistent capabilities from sturdy ref service IDs. *) @@ -69,17 +65,22 @@ module type VAT_NETWORK = sig type t (** A CapTP connection to a remote peer. *) - val connect : restore:restorer -> ?tags:Logs.Tag.set -> Endpoint.t -> t - (** [connect ~restore ~switch endpoint] is fresh CapTP protocol handler that sends and + val connect : sw:Eio.Switch.t -> restore:restorer -> ?tags:Logs.Tag.set -> Endpoint.t -> t + (** [connect ~sw ~restore ~switch endpoint] is fresh CapTP protocol handler that sends and receives messages using [endpoint]. [restore] is used to respond to "Bootstrap" messages. - If the connection fails then [endpoint] will be disconnected. *) + If the connection fails then [endpoint] will be disconnected. + You must call {!listen} to run the loop handling messages. + @param sw Used to run methods and to run the transmit thread. *) + + val listen : t -> unit + (** [listen t] reads and handles incoming messages until the connection is finished. *) val bootstrap : t -> service_id -> 'a capability (** [bootstrap t object_id] is the peer's bootstrap object [object_id], if any. Use [object_id = ""] for the main, public object. *) - val disconnect : t -> Capnp_rpc.Exception.t -> unit Lwt.t + val disconnect : t -> Capnp_rpc.Exception.t -> unit (** [disconnect reason] closes the connection, sending [reason] to the peer to explain why. Capabilities and questions at both ends will break, with [reason] as the problem. *) @@ -99,23 +100,21 @@ module type VAT_NETWORK = sig (** A local Vat. *) val create : - ?switch:Lwt_switch.t -> ?tags:Logs.Tag.set -> ?restore:restorer -> ?address:Network.Address.t -> + sw:Eio.Switch.t -> secret_key:Auth.Secret_key.t Lazy.t -> Network.t -> t - (** [create ~switch ~restore ~address ~secret_key network] is a new Vat that + (** [create ~sw ~restore ~address ~secret_key network] is a new Vat that uses [restore] to restore sturdy refs hosted at this vat to live capabilities for peers. The Vat will suggest that other parties connect to it using [address]. Turning off the switch will disconnect any active connections. *) - val add_connection : t -> switch:Lwt_switch.t -> mode:[`Accept|`Connect] -> Endpoint.t -> CapTP.t Lwt.t - (** [add_connection t ~switch ~mode endpoint] runs the CapTP protocol over [endpoint], + val add_connection : t -> mode:[`Accept|`Connect] -> Endpoint.t -> CapTP.t + (** [add_connection t ~mode endpoint] runs the CapTP protocol over [endpoint], which is a connection to another vat. - When the connection ends, [switch] will be turned off, and turning off [switch] will - end the connection. [mode] is used if two Vats connect to each other at the same time to decide which connection to drop. Use [`Connect] if [t] initiated the new connection. Note that [add_connection] may return an existing connection. *) diff --git a/capnp-rpc-net/tls_eio.ml b/capnp-rpc-net/tls_eio.ml new file mode 100644 index 000000000..308743e62 --- /dev/null +++ b/capnp-rpc-net/tls_eio.ml @@ -0,0 +1,228 @@ +module Flow = Eio.Flow + +type error = [ `Tls_alert of Tls.Packet.alert_type + | `Tls_failure of Tls.Engine.failure + | `Read of exn + | `Write of exn ] + +type write_error = [ `Closed | error ] +(** The type for write errors. *) + +let pp_error ppf = function + | `Tls_failure f -> Fmt.string ppf @@ Tls.Engine.string_of_failure f + | `Tls_alert a -> Fmt.string ppf @@ Tls.Packet.alert_type_to_string a + | `Read ex -> Fmt.exn ppf @@ ex + | `Write ex -> Fmt.exn ppf @@ ex + +let pp_write_error ppf = function + | #error as e -> pp_error ppf e + | `Closed -> Fmt.string ppf "Closed" + +type flow = { + role : [ `Server | `Client ]; + flow : Flow.two_way; + mutable state : [ `Active of Tls.Engine.state + | `Eof + | `Error of error ]; + mutable linger : Cstruct.t list; +} + +let tls_alert a = `Error (`Tls_alert a) +let tls_fail f = `Error (`Tls_failure f) + +let lift_read_result = function + | Ok (`Data _ | `Eof as x) -> x + | Error e -> `Error (`Read e) + +let lift_write_result = function + | Ok () -> `Ok () + | Error e -> `Error (`Write e) + +let check_write flow f_res = + let res = lift_write_result f_res in + ( match flow.state, res with + | `Active _, (`Eof | `Error _ as e) -> + flow.state <- e ; Flow.shutdown flow.flow `All + | _ -> () ); + match f_res with + | Ok () -> Ok () + | Error e -> Error (`Write e :> write_error) + +let copy src dst = + match Flow.copy src dst with + | () -> Ok () + | exception ex -> Error ex + +let read_react flow = + match flow.state with + | `Eof | `Error _ as e -> e + | `Active _ -> + let cbuf = Cstruct.create 4096 in + match Flow.read flow.flow cbuf with + | exception (End_of_file as ex) -> flow.state <- `Eof; raise ex + | exception ex -> flow.state <- `Error (`Read ex); raise ex + | got -> + match flow.state with + | `Eof | `Error _ as e -> e + | `Active tls -> + let cbuf = Cstruct.sub cbuf 0 got in + match Tls.Engine.handle_tls tls cbuf with + | Ok (res, `Response resp, `Data data) -> + flow.state <- ( match res with + | `Ok tls -> `Active tls + | `Eof -> `Eof + | `Alert alert -> tls_alert alert ); + let _ = + match resp with + | None -> Ok () + | Some buf -> copy (Flow.cstruct_source [buf]) flow.flow |> check_write flow + in + ( match res with + | `Ok _ -> () + | _ -> Flow.shutdown flow.flow `All); + `Ok data + | Error (fail, `Response resp) -> + let reason = tls_fail fail in + flow.state <- reason ; + begin try Flow.copy (Flow.cstruct_source [resp]) flow.flow with _ -> () end; + Flow.shutdown flow.flow `All; + reason + +let rec read_into t buf = + let got, bufs = Cstruct.fillv ~src:t.linger ~dst:buf in + t.linger <- bufs; + if got > 0 then got + else ( + read_react t |> function + | `Ok None -> read_into t buf + | `Ok (Some next) -> + t.linger <- t.linger @ [next]; + read_into t buf + | `Eof -> raise End_of_file + | `Error (`Read ex | `Write ex) -> raise ex + | `Error e -> raise (Failure (Fmt.to_to_string pp_error e)) + ) + +let writev flow bufs = + match flow.state with + | `Eof -> Error `Closed + | `Error e -> Error (e :> write_error) + | `Active tls -> + match Tls.Engine.send_application_data tls bufs with + | Some (tls, answer) -> + flow.state <- `Active tls ; + copy (Flow.cstruct_source [answer]) flow.flow |> check_write flow + | None -> + (* "Impossible" due to handshake draining. *) + assert false + +let write flow buf = writev flow [buf] + +let close flow = + match flow.state with + | `Active tls -> + flow.state <- `Eof ; + let (_, buf) = Tls.Engine.send_close_notify tls in + (* XXX: need a switch here *) + copy (Flow.cstruct_source [buf]) flow.flow |> fun _ -> Flow.shutdown flow.flow `All + | _ -> () + +(* + * XXX bad XXX + * This is a point that should particularly be protected from concurrent r/w. + * Doing this before a `t` is returned is safe; redoing it during rekeying is + * not, as the API client already sees the `t` and can mistakenly interleave + * writes while this is in progress. + * *) +let rec drain_handshake flow = + match flow.state with + | `Active tls when not (Tls.Engine.handshake_in_progress tls) -> () + | _ -> + (* read_react re-throws *) + read_react flow |> function + | `Ok mbuf -> + flow.linger <- Option.to_list mbuf @ flow.linger ; + drain_handshake flow + | `Error e -> raise (Failure (Fmt.to_to_string pp_write_error e)) + | `Eof -> raise End_of_file + +let wrap flow = + object (_ : ) + method probe _ = None + + method read_into buf = read_into flow buf + + method read_methods = [] (* TODO: this would be faster *) + + method copy src = + let buf = Cstruct.create 4096 in + let got = Flow.read src buf in (* XXX: wrap errors? *) + match write flow (Cstruct.sub buf 0 got) with + | Ok () -> () + | Error `Closed -> raise End_of_file + | Error (`Read ex | `Write ex) -> raise ex + | Error err -> raise (Failure (Fmt.to_to_string pp_write_error err)) + + method epoch = + match flow.state with + | `Eof | `Error _ -> Error () + | `Active tls -> + match Tls.Engine.epoch tls with + | `InitialEpoch -> assert false (* `drain_handshake` invariant. *) + | `Epoch e -> Ok e + + method reneg ?authenticator ?acceptable_cas ?cert ?(drop = true) () = + match flow.state with + | `Eof -> raise End_of_file + | `Error e -> raise (Failure (Fmt.to_to_string pp_write_error e)) + | `Active tls -> + match Tls.Engine.reneg ?authenticator ?acceptable_cas ?cert tls with + | None -> + (* XXX make this impossible to reach *) + invalid_arg "Renegotiation already in progress" + | Some (tls', buf) -> + if drop then flow.linger <- [] ; + flow.state <- `Active tls' ; + copy (Flow.cstruct_source [buf]) flow.flow |> fun _ -> + drain_handshake flow + + method key_update ?request () = + match flow.state with + | `Eof -> Error `Closed + | `Error e -> Error (e :> write_error) + | `Active tls -> + match Tls.Engine.key_update ?request tls with + | Error _ -> invalid_arg "Key update failed" + | Ok (tls', buf) -> + flow.state <- `Active tls' ; + copy (Flow.cstruct_source [buf]) flow.flow |> check_write flow + + method close = close flow + method shutdown _ = close flow (* XXX *) + end + +let client_of_flow conf ?host flow = + let conf' = match host with + | None -> conf + | Some host -> Tls.Config.peer conf host + in + let (tls, init) = Tls.Engine.client conf' in + let tls_flow = { + role = `Client ; + flow = flow ; + state = `Active tls ; + linger = [] ; + } in + copy (Flow.cstruct_source [init]) flow |> fun _ -> + drain_handshake tls_flow; + wrap tls_flow + +let server_of_flow conf flow = + let tls_flow = { + role = `Server ; + flow = flow ; + state = `Active (Tls.Engine.server conf) ; + linger = [] ; + } in + drain_handshake tls_flow; + wrap tls_flow diff --git a/capnp-rpc-net/tls_wrapper.ml b/capnp-rpc-net/tls_wrapper.ml index 25e615ebc..1d98aad31 100644 --- a/capnp-rpc-net/tls_wrapper.ml +++ b/capnp-rpc-net/tls_wrapper.ml @@ -1,57 +1,42 @@ module Log = Capnp_rpc.Debug.Log -open Lwt.Infix open Auth let error fmt = fmt |> Fmt.kstr @@ fun msg -> Error (`Msg msg) -module Make (Underlying : Mirage_flow.S) = struct - module Flow = struct - include Tls_mirage.Make(Underlying) - - let read flow = - read flow >|= function - | Error (`Write `Closed) -> Ok `Eof (* This can happen, despite being a write error on a read! *) - | x -> x - - let writev flow bufs = - writev flow bufs >|= function - | Error (`Write `Closed) -> Error `Closed - | x -> x - - let write flow buf = writev flow [buf] - end - - let plain_endpoint ~switch flow = - Endpoint.of_flow ~switch ~peer_id:Auth.Digest.insecure (module Underlying) flow - - let connect_as_server ~switch flow secret_key = - match secret_key with - | None -> Lwt.return @@ Ok (plain_endpoint ~switch flow) - | Some key -> - Log.info (fun f -> f "Doing TLS server-side handshake..."); - let tls_config = Secret_key.tls_server_config key in - Flow.server_of_flow tls_config flow >|= function - | Error e -> error "TLS connection failed: %a" Flow.pp_write_error e - | Ok flow -> - match Flow.epoch flow with - | Error () -> failwith "Unknown error getting TLS epoch data" - | Ok data -> - match data.Tls.Core.peer_certificate with - | None -> error "No client certificate found" - | Some client_cert -> - let peer_id = Digest.of_certificate client_cert in - Ok (Endpoint.of_flow ~switch ~peer_id (module Flow) flow) - - let connect_as_client ~switch flow secret_key auth = - match Digest.authenticator auth with - | None -> Lwt.return @@ Ok (plain_endpoint ~switch flow) - | Some authenticator -> - let tls_config = Secret_key.tls_client_config ~authenticator (Lazy.force secret_key) in - Log.info (fun f -> f "Doing TLS client-side handshake..."); - Flow.client_of_flow tls_config flow >|= function - | Error e -> error "TLS connection failed: %a" Flow.pp_write_error e - | Ok flow -> Ok (Endpoint.of_flow ~switch ~peer_id:auth (module Flow) flow) -end +let plain_endpoint flow = + Endpoint.of_flow ~peer_id:Auth.Digest.insecure flow + +let connect_as_server flow secret_key = + let flow = (flow :> Eio.Flow.two_way) in + match secret_key with + | None -> Ok (plain_endpoint flow) + | Some key -> + Log.info (fun f -> f "Doing TLS server-side handshake..."); + let tls_config = Secret_key.tls_server_config key in + match Tls_eio.server_of_flow tls_config flow with + | exception (Failure msg) -> error "TLS connection failed: %s" msg + | exception ex -> error "TLS connection failed: %a" Fmt.exn ex + | flow -> + match flow#epoch with + | Error () -> failwith "Unknown error getting TLS epoch data" + | Ok data -> + match data.Tls.Core.peer_certificate with + | None -> error "No client certificate found" + | Some client_cert -> + let peer_id = Digest.of_certificate client_cert in + Ok (Endpoint.of_flow ~peer_id flow) + +let connect_as_client flow secret_key auth = + let flow = (flow :> Eio.Flow.two_way) in + match Digest.authenticator auth with + | None -> Ok (plain_endpoint flow) + | Some authenticator -> + let tls_config = Secret_key.tls_client_config ~authenticator (Lazy.force secret_key) in + Log.info (fun f -> f "Doing TLS client-side handshake..."); + match Tls_eio.client_of_flow tls_config flow with + | exception (Failure msg) -> error "TLS connection failed: %s" msg + | exception ex -> error "TLS connection failed: %a" Fmt.exn ex + | flow -> Ok (Endpoint.of_flow ~peer_id:auth flow) diff --git a/capnp-rpc-net/tls_wrapper.mli b/capnp-rpc-net/tls_wrapper.mli index f99c7b562..a64d902b6 100644 --- a/capnp-rpc-net/tls_wrapper.mli +++ b/capnp-rpc-net/tls_wrapper.mli @@ -1,17 +1,12 @@ open Auth -module Make (Underlying : Mirage_flow.S) : sig - (** Make an [Endpoint] from an [Underlying.flow], using TLS if appropriate. *) - - val connect_as_server : - switch:Lwt_switch.t -> Underlying.flow -> Auth.Secret_key.t option -> - (Endpoint.t, [> `Msg of string]) result Lwt.t - - val connect_as_client : - switch:Lwt_switch.t -> Underlying.flow -> Auth.Secret_key.t Lazy.t -> Digest.t -> - (Endpoint.t, [> `Msg of string]) result Lwt.t - (** [connect_as_client ~switch underlying key digest] is an endpoint using flow [underlying]. - If [digest] requires TLS, it performs a TLS handshake. It uses [key] as its private key - and checks that the server is the one required by [auth]. *) -end - +val connect_as_server : + #Eio.Flow.two_way -> Auth.Secret_key.t option -> + (Endpoint.t, [> `Msg of string]) result + +val connect_as_client : + #Eio.Flow.two_way -> Auth.Secret_key.t Lazy.t -> Digest.t -> + (Endpoint.t, [> `Msg of string]) result +(** [connect_as_client underlying key digest] is an endpoint using flow [underlying]. + If [digest] requires TLS, it performs a TLS handshake. It uses [key] as its private key + and checks that the server is the one required by [auth]. *) diff --git a/capnp-rpc-net/two_party_network.ml b/capnp-rpc-net/two_party_network.ml index d800c9f40..a407d96d5 100644 --- a/capnp-rpc-net/two_party_network.ml +++ b/capnp-rpc-net/two_party_network.ml @@ -22,4 +22,4 @@ type t = unit let parse_third_party_cap_id _ = `Two_party_only -let connect () ~switch:_ ~secret_key:_ _ = assert false +let connect () ~sw:_ ~secret_key:_ _ = assert false diff --git a/capnp-rpc-net/vat.ml b/capnp-rpc-net/vat.ml index bf1d7c0ff..e9bf1abad 100644 --- a/capnp-rpc-net/vat.ml +++ b/capnp-rpc-net/vat.ml @@ -1,83 +1,97 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt module Log = Capnp_rpc.Debug.Log module ID_map = Auth.Digest.Map -module Make (Network : S.NETWORK) (Underlying : Mirage_flow.S) = struct +module Condition = struct + type t = (unit Promise.t * unit Promise.u) ref + + let create () : t = ref (Promise.create ()) + + let await t = Promise.await (fst !t) + + let notify t = + Promise.resolve (snd !t) (); + t := Promise.create () +end + +module Make (Network : S.NETWORK) = struct module CapTP = CapTP_capnp.Make (Network) let hash = `SHA256 (* Only support a single hash for now *) - type connection_attempt = (CapTP.t, Capnp_rpc.Exception.t) result Lwt.t + type connection_attempt = (CapTP.t, Capnp_rpc.Exception.t) result Eio.Promise.or_exn type t = { + sw : Eio.Switch.t; network : Network.t; - switch : Lwt_switch.t option; secret_key : Auth.Secret_key.t Lazy.t; address : Network.Address.t option; restore : Restorer.t; tags : Logs.Tag.set; - connection_removed : unit Lwt_condition.t; (* Fires when a connection is removed *) + connection_removed : Condition.t; (* Fires when a connection is removed *) mutable connecting : connection_attempt ID_map.t; (* Out-going connections being attempted. *) mutable connections : CapTP.t ID_map.t; (* Accepted connections *) mutable anon_connections : CapTP.t list; (* Connections not using TLS. *) } - let create ?switch ?(tags=Logs.Tag.empty) ?(restore=Restorer.none) ?address ~secret_key network = + let create ?(tags=Logs.Tag.empty) ?(restore=Restorer.none) ?address ~sw ~secret_key network = let t = { + sw; network; - switch; secret_key; address; restore; tags; - connection_removed = Lwt_condition.create (); + connection_removed = Condition.create (); connecting = ID_map.empty; connections = ID_map.empty; anon_connections = []; } in - Lwt_switch.add_hook switch (fun () -> + Switch.on_release sw (fun () -> let ex = Capnp_rpc.Exception.v ~ty:`Disconnected "Vat shut down" in - ID_map.bindings t.connections |> Lwt_list.iter_p (fun (_, c) -> CapTP.disconnect c ex) >>= fun () -> + ID_map.bindings t.connections |> Fiber.iter (fun (_, c) -> CapTP.disconnect c ex); t.connections <- ID_map.empty; - Lwt_list.iter_p (fun c -> CapTP.disconnect c ex) t.anon_connections >|= fun () -> + Fiber.iter (fun c -> CapTP.disconnect c ex) t.anon_connections; t.anon_connections <- []; - ID_map.iter (fun _ th -> Lwt.cancel th) t.connecting; - t.connecting <- ID_map.empty; + (* If sw is being released then the connection fibers must have finished. *) + assert (ID_map.is_empty t.connecting); ); t - let add_tls_connection t ~switch endpoint = - let conn = CapTP.connect ~tags:t.tags ~restore:t.restore endpoint in - let peer_id = Endpoint.peer_id endpoint in - t.connections <- ID_map.add peer_id conn t.connections; - Lwt_switch.add_hook (Some switch) (fun () -> - begin match ID_map.find peer_id t.connections with - | Some x when x == conn -> t.connections <- ID_map.remove peer_id t.connections - | Some _ (* Already replaced by a new one? *) - | None -> () - end; - CapTP.disconnect conn (Capnp_rpc.Exception.v ~ty:`Disconnected "Switch turned off") >|= fun () -> - Lwt_condition.broadcast t.connection_removed () + let spawn_connection t ~add ~remove endpoint = + let conn = CapTP.connect ~sw:t.sw ~tags:t.tags ~restore:t.restore endpoint in + Fiber.fork ~sw:t.sw (fun () -> + add conn; + Fun.protect (fun () -> CapTP.listen conn) + ~finally:(fun () -> + remove conn; + Condition.notify t.connection_removed + ) ); conn - let add_connection t ~switch ~(mode:[`Accept|`Connect]) endpoint = - let tags = t.tags in + let add_tls_connection t endpoint = + let peer_id = Endpoint.peer_id endpoint in + spawn_connection t endpoint + ~add:(fun conn -> t.connections <- ID_map.add peer_id conn t.connections) + ~remove:(fun conn -> + match ID_map.find peer_id t.connections with + | Some x when x == conn -> t.connections <- ID_map.remove peer_id t.connections + | Some _ (* Already replaced by a new one? *) + | None -> () + ) + + let add_connection t ~(mode:[`Accept|`Connect]) endpoint = let peer_id = Endpoint.peer_id endpoint in if peer_id = Auth.Digest.insecure then ( - let conn = CapTP.connect ~tags ~restore:t.restore endpoint in - t.anon_connections <- conn :: t.anon_connections; - Lwt_switch.add_hook (Some switch) (fun () -> - t.anon_connections <- List.filter ((!=) conn) t.anon_connections; - CapTP.disconnect conn (Capnp_rpc.Exception.v ~ty:`Disconnected "Switch turned off") >|= fun () -> - Lwt_condition.broadcast t.connection_removed () - ); - Lwt.return conn + spawn_connection t endpoint + ~add:(fun conn -> t.anon_connections <- conn :: t.anon_connections) + ~remove:(fun conn -> t.anon_connections <- List.filter ((!=) conn) t.anon_connections) ) else match ID_map.find peer_id t.connections with - | None -> Lwt.return @@ add_tls_connection t ~switch endpoint + | None -> add_tls_connection t endpoint | Some existing -> Log.info (fun f -> f ~tags:t.tags "Trying to add a connection, but we already have one for this vat"); (* This can happen if two vats call each other at exactly the same time. @@ -127,35 +141,34 @@ module Make (Network : S.NETWORK) (Underlying : Mirage_flow.S) = struct let my_id = Auth.Secret_key.digest ~hash (Lazy.force t.secret_key) in let keep_new = (my_id > peer_id) = (mode = `Connect) in if keep_new then ( - let conn = add_tls_connection t ~switch endpoint in + let conn = add_tls_connection t endpoint in let reason = Capnp_rpc.Exception.v "Closing duplicate connection" in - CapTP.disconnect existing reason >|= fun () -> + CapTP.disconnect existing reason; conn ) else ( - Lwt_switch.turn_off switch >|= fun () -> existing ) let public_address t = t.address let connect_anon t addr ~service = - let switch = Lwt_switch.create () in - Network.connect t.network ~switch ~secret_key:t.secret_key addr >>= function - | Error (`Msg m) -> Lwt.return @@ Error (Capnp_rpc.Exception.v m) + match Network.connect ~sw:t.sw t.network ~secret_key:t.secret_key addr with + | Error (`Msg m) -> Error (Capnp_rpc.Exception.v m) | Ok ep -> - add_connection t ~switch ep ~mode:`Connect >|= fun conn -> + let conn = add_connection t ep ~mode:`Connect in Ok (CapTP.bootstrap conn service) let initiate_connection t remote_id addr service = (* We need to start a new connection attempt. *) - let switch = Lwt_switch.create () in let conn = - Network.connect t.network ~switch ~secret_key:t.secret_key addr >>= function - | Error (`Msg m) -> Lwt.return @@ Error (Capnp_rpc.Exception.v m) - | Ok ep -> add_connection t ~switch ep ~mode:`Connect >|= fun conn -> Ok conn + Fiber.fork_promise ~sw:t.sw @@ fun () -> + match Network.connect ~sw:t.sw t.network ~secret_key:t.secret_key addr with + | Error (`Msg m) -> Error (Capnp_rpc.Exception.v m) + | Ok ep -> + Ok (add_connection t ep ~mode:`Connect) in t.connecting <- ID_map.add remote_id conn t.connecting; - conn >|= fun conn -> + let conn = Promise.await_exn conn in t.connecting <- ID_map.remove remote_id t.connecting; match conn with | Ok conn -> Ok (CapTP.bootstrap conn service) @@ -167,17 +180,17 @@ module Make (Network : S.NETWORK) (Underlying : Mirage_flow.S) = struct Restorer.restore t.restore service else match ID_map.find remote_id t.connections with | Some conn when CapTP.disconnecting conn -> - Lwt_condition.wait t.connection_removed >>= fun () -> + Condition.await t.connection_removed; connect_auth t remote_id addr ~service | Some conn -> (* Already connected; use that. *) - Lwt.return @@ Ok (CapTP.bootstrap conn service) + Ok (CapTP.bootstrap conn service) | None -> match ID_map.find remote_id t.connecting with | None -> initiate_connection t remote_id addr service | Some conn -> (* We're already trying to establish a connection, wait for that. *) - conn >|= function + match Promise.await_exn conn with | Ok conn -> Ok (CapTP.bootstrap conn service) | Error _ as e -> e @@ -186,7 +199,7 @@ module Make (Network : S.NETWORK) (Underlying : Mirage_flow.S) = struct method connect = let (addr, service) = sr in let remote_id = Network.Address.digest addr in - Lwt_result.map Cast.cap_to_raw ( + Result.map Cast.cap_to_raw ( if remote_id = Auth.Digest.insecure then connect_anon t addr ~service else connect_auth t remote_id addr ~service ) diff --git a/capnp-rpc-unix.opam b/capnp-rpc-unix.opam index 531d3aefc..dce811a19 100644 --- a/capnp-rpc-unix.opam +++ b/capnp-rpc-unix.opam @@ -13,7 +13,6 @@ depends: [ "ocaml" {>= "4.08.0"} "capnp-rpc-net" {= version} "cmdliner" {>= "1.1.0"} - "cstruct-lwt" "astring" "fmt" {>= "0.8.7"} "logs" @@ -21,11 +20,10 @@ depends: [ "base64" {>= "3.0.0"} "dune" {>= "3.0"} "alcotest" {>= "1.0.1" & with-test} - "alcotest-lwt" { >= "1.0.1" & with-test} "mirage-crypto-rng" {>= "0.7.0"} "mdx" {with-test} - "lwt" "asetmap" {with-test} + "eio_main" ] build: [ ["dune" "build" "-p" name "-j" jobs] diff --git a/capnp-rpc/capTP.ml b/capnp-rpc/capTP.ml index 87ce05854..94316c3a2 100644 --- a/capnp-rpc/capTP.ml +++ b/capnp-rpc/capTP.ml @@ -705,6 +705,7 @@ module Make (EP : Message_types.ENDPOINT) = struct tags : Logs.Tag.set; embargoes : (EmbargoId.t * Cap_proxy.resolver_cap) Embargoes.t; restore : restorer; + fork : (unit -> unit) -> unit; questions : Question.t Questions.t; answers : Answer.t Answers.t; @@ -740,11 +741,12 @@ module Make (EP : Message_types.ENDPOINT) = struct let default_restore k _object_id = k @@ Error (Exception.v "This vat has no restorer") - let create ?(restore=default_restore) ~tags ~queue_send = + let create ?(restore=default_restore) ~tags ~fork ~queue_send = { queue_send = (queue_send :> EP.Out.t -> unit); tags; restore = restore; + fork; questions = Questions.make (); answers = Answers.make (); imports = Imports.make (); @@ -1465,8 +1467,10 @@ module Make (EP : Message_types.ENDPOINT) = struct | `Local target -> Log.debug (fun f -> f ~tags:t.tags "Handling call: (%t).call %a" target#pp Core_types.Request_payload.pp msg); - target#call answer_resolver msg; (* Takes ownership of [caps]. *) - dec_ref target + t.fork (fun () -> + target#call answer_resolver msg; (* Takes ownership of [caps]. *) + dec_ref target + ) | #message_target_cap as target -> Log.debug (fun f -> f ~tags:t.tags "Forwarding call: (%a).call %a" pp_message_target_cap target Core_types.Request_payload.pp msg); @@ -1476,6 +1480,7 @@ module Make (EP : Message_types.ENDPOINT) = struct let promise, answer_resolver = Local_struct_promise.make () in let answer = Answer.create id ~answer:promise in Answers.set t.answers id answer; + t.fork @@ fun () -> object_id |> t.restore @@ fun service -> if Answer.needs_return answer && t.disconnected = None then ( let results = diff --git a/capnp-rpc/capTP.mli b/capnp-rpc/capTP.mli index 93eb5fa0f..51dda76f0 100644 --- a/capnp-rpc/capTP.mli +++ b/capnp-rpc/capTP.mli @@ -12,11 +12,13 @@ module Make (EP : Message_types.ENDPOINT) : sig capability. *) val create : ?restore:restorer -> tags:Logs.Tag.set -> + fork:((unit -> unit) -> unit) -> queue_send:([> EP.Out.t] -> unit) -> t - (** [create ~bootstrap ~tags ~queue_send] is a handler for a connection to a remote peer. + (** [create ~restore ~tags ~queue_send] is a handler for a connection to a remote peer. Messages will be sent to the peer by calling [queue_send] (which MUST deliver them in order). - If the remote peer asks for the bootstrap object, it will be given a reference to [bootstrap]. - Log messages will be tagged with [tags]. *) + If the remote peer asks for a bootstrap object, [restore] will be used to get it. + Log messages will be tagged with [tags]. + @param fork is used when dispatching a local method handler. *) val bootstrap : t -> string -> EP.Core_types.cap (** [bootstrap t object_id] returns a reference to the remote peer's bootstrap object, if any. diff --git a/examples/pipelining/dune b/examples/pipelining/dune index fc80455c8..e7a552a10 100644 --- a/examples/pipelining/dune +++ b/examples/pipelining/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc-unix logs.fmt) + (libraries eio_main capnp-rpc-unix logs.fmt) (flags (:standard -w -53-55))) (rule diff --git a/examples/pipelining/echo.ml b/examples/pipelining/echo.ml index a293321b3..7a3bdef9b 100644 --- a/examples/pipelining/echo.ml +++ b/examples/pipelining/echo.ml @@ -1,6 +1,6 @@ module Api = Echo_api.MakeRPC(Capnp_rpc_lwt) -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt module Callback = struct @@ -26,23 +26,23 @@ module Callback = struct Capability.call_for_unit t method_id request end -let (>>!=) = Lwt_result.bind (* Return errors *) - -let notify callback ~msg = +let notify ~clock msg callback = let rec loop = function | 0 -> - Lwt.return @@ Ok (Service.Response.create_empty ()) + Service.return_empty () | i -> - Callback.log callback msg >>!= fun () -> - Lwt_unix.sleep 1.0 >>= fun () -> - loop (i - 1) + match Callback.log callback msg with + | Error (`Capnp e) -> Service.error e + | Ok () -> + Eio.Time.sleep clock 1.0; + loop (i - 1) in loop 3 let service_logger = - Callback.local (Printf.printf "[server] Received %S\n%!") + Callback.local (traceln "[server] Received %S") -let local = +let local ~clock = let module Echo = Api.Service.Echo in Echo.local @@ object inherit Echo.service @@ -63,8 +63,7 @@ let local = match callback with | None -> Service.fail "No callback parameter!" | Some callback -> - Service.return_lwt @@ fun () -> - Capability.with_ref callback (notify ~msg) + Capability.with_ref callback (notify ~clock msg) (* $MDX part-begin=server-get-logger *) method get_logger_impl _ release_params = @@ -82,7 +81,7 @@ let ping t msg = let open Echo.Ping in let request, params = Capability.Request.create Params.init_pointer in Params.msg_set params msg; - Capability.call_for_value_exn t method_id request >|= Results.reply_get + Capability.call_for_value_exn t method_id request |> Results.reply_get let heartbeat t msg callback = let open Echo.Heartbeat in diff --git a/examples/pipelining/main.ml b/examples/pipelining/main.ml index 8d0adbd94..ac139ec64 100644 --- a/examples/pipelining/main.ml +++ b/examples/pipelining/main.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt let () = @@ -6,12 +6,12 @@ let () = Logs.set_reporter (Logs_fmt.reporter ()) let callback_fn msg = - Fmt.pr "Callback got %S@." msg + traceln "Callback got %S" msg (* $MDX part-begin=run-client *) let run_client service = let logger = Echo.get_logger service in - Echo.Callback.log logger "Message from client" >|= function + match Echo.Callback.log logger "Message from client" with | Ok () -> () | Error (`Capnp err) -> Fmt.epr "Server's logger failed: %a" Capnp_rpc.Error.pp err @@ -20,18 +20,20 @@ let run_client service = let secret_key = `Ephemeral let listen_address = `TCP ("127.0.0.1", 7000) -let start_server () = +let start_server ~sw ~clock net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key listen_address in let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in - let restore = Capnp_rpc_net.Restorer.single service_id Echo.local in - Capnp_rpc_unix.serve config ~restore >|= fun vat -> + let restore = Capnp_rpc_net.Restorer.single service_id (Echo.local ~clock) in + let vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in Capnp_rpc_unix.Vat.sturdy_uri vat service_id let () = - Lwt_main.run begin - start_server () >>= fun uri -> - Fmt.pr "[client] Connecting to echo service...@."; - let client_vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in - Sturdy_ref.with_cap_exn sr run_client - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let clock = env#clock in + let uri = start_server ~sw ~clock env#net in + traceln "[client] Connecting to echo service..."; + let client_vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in + Sturdy_ref.with_cap_exn sr run_client; + raise Exit diff --git a/examples/sturdy-refs-2/dune b/examples/sturdy-refs-2/dune index 137d3fd39..5a91aaf46 100644 --- a/examples/sturdy-refs-2/dune +++ b/examples/sturdy-refs-2/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc-unix logs.fmt) + (libraries eio_main capnp-rpc-unix logs.fmt) (flags (:standard -w -53-55))) (rule diff --git a/examples/sturdy-refs-2/main.ml b/examples/sturdy-refs-2/main.ml index 2d57c832f..acd8f70b7 100644 --- a/examples/sturdy-refs-2/main.ml +++ b/examples/sturdy-refs-2/main.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt module Restorer = Capnp_rpc_net.Restorer @@ -14,7 +14,7 @@ let or_fail = function | Ok x -> x | Error (`Msg m) -> failwith m -let start_server () = +let start_server ~sw net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key listen_address in let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in let services = Restorer.Table.create make_sturdy in @@ -22,20 +22,22 @@ let start_server () = let root_id = Capnp_rpc_unix.Vat_config.derived_id config "root" in let root = Logger.local "root" in Restorer.Table.add services root_id root; - Capnp_rpc_unix.serve config ~restore >|= fun _vat -> + let _vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in Capnp_rpc_unix.Vat_config.sturdy_uri config root_id (* $MDX part-begin=main *) let () = - Lwt_main.run begin - start_server () >>= fun root_uri -> - let vat = Capnp_rpc_unix.client_only_vat () in - let root_sr = Capnp_rpc_unix.Vat.import vat root_uri |> or_fail in - Sturdy_ref.with_cap_exn root_sr @@ fun root -> - Logger.log root "Message from Admin" >>= fun () -> - let for_alice = Logger.sub root "alice" in - let for_bob = Logger.sub root "bob" in - Logger.log for_alice "Message from Alice" >>= fun () -> - Logger.log for_bob "Message from Bob" - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let net = env#net in + let root_uri = start_server ~sw net in + let vat = Capnp_rpc_unix.client_only_vat ~sw net in + let root_sr = Capnp_rpc_unix.Vat.import vat root_uri |> or_fail in + Sturdy_ref.with_cap_exn root_sr @@ fun root -> + Logger.log root "Message from Admin"; + let for_alice = Logger.sub root "alice" in + let for_bob = Logger.sub root "bob" in + Logger.log for_alice "Message from Alice"; + Logger.log for_bob "Message from Bob"; + raise Exit (* $MDX part-end *) diff --git a/examples/sturdy-refs-3/dune b/examples/sturdy-refs-3/dune index 137d3fd39..5a91aaf46 100644 --- a/examples/sturdy-refs-3/dune +++ b/examples/sturdy-refs-3/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc-unix logs.fmt) + (libraries eio_main capnp-rpc-unix logs.fmt) (flags (:standard -w -53-55))) (rule diff --git a/examples/sturdy-refs-3/main.ml b/examples/sturdy-refs-3/main.ml index f53219889..b87da1fc9 100644 --- a/examples/sturdy-refs-3/main.ml +++ b/examples/sturdy-refs-3/main.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt module Restorer = Capnp_rpc_net.Restorer @@ -14,7 +14,7 @@ let or_fail = function | Ok x -> x | Error (`Msg m) -> failwith m -let start_server () = +let start_server ~sw net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key listen_address in let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in let services = Restorer.Table.create make_sturdy in @@ -27,28 +27,30 @@ let start_server () = in (* $MDX part-end *) Restorer.Table.add services root_id root; - Capnp_rpc_unix.serve config ~restore >|= fun _vat -> + let _vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in Capnp_rpc_unix.Vat_config.sturdy_uri config root_id -let run_client cap_file = - let vat = Capnp_rpc_unix.client_only_vat () in +let run_client ~sw ~net cap_file = + let vat = Capnp_rpc_unix.client_only_vat ~sw net in let sr = Capnp_rpc_unix.Cap_file.load vat cap_file |> or_fail in Sturdy_ref.with_cap_exn sr @@ fun for_alice -> Logger.log for_alice "Message from Alice" let () = - Lwt_main.run begin - start_server () >>= fun root_uri -> - let vat = Capnp_rpc_unix.client_only_vat () in - let root_sr = Capnp_rpc_unix.Vat.import vat root_uri |> or_fail in - Sturdy_ref.with_cap_exn root_sr @@ fun root -> - Logger.log root "Message from Admin" >>= fun () -> - (* $MDX part-begin=save *) - (* The admin creates a logger for Alice and saves it: *) - let for_alice = Logger.sub root "alice" in - Persistence.save_exn for_alice >>= fun uri -> - Capnp_rpc_unix.Cap_file.save_uri uri "alice.cap" |> or_fail; - (* Alice uses it: *) - run_client "alice.cap" - (* $MDX part-end *) - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let net = env#net in + let root_uri = start_server ~sw net in + let vat = Capnp_rpc_unix.client_only_vat ~sw net in + let root_sr = Capnp_rpc_unix.Vat.import vat root_uri |> or_fail in + Sturdy_ref.with_cap_exn root_sr @@ fun root -> + Logger.log root "Message from Admin"; + (* $MDX part-begin=save *) + (* The admin creates a logger for Alice and saves it: *) + let for_alice = Logger.sub root "alice" in + let uri = Persistence.save_exn for_alice in + Capnp_rpc_unix.Cap_file.save_uri uri "alice.cap" |> or_fail; + (* Alice uses it: *) + run_client ~sw ~net "alice.cap"; + raise Exit + (* $MDX part-end *) diff --git a/examples/sturdy-refs-4/db.ml b/examples/sturdy-refs-4/db.ml index 32bcada09..fdda6f953 100644 --- a/examples/sturdy-refs-4/db.ml +++ b/examples/sturdy-refs-4/db.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt open Capnp_rpc_net @@ -9,7 +9,7 @@ type loader = [`Logger_beacebd78653e9af] Sturdy_ref.t -> label:string -> Restore type t = { store : Store.Reader.SavedService.struct_t File_store.t; - loader : loader Lwt.t; + loader : loader Promise.t; make_sturdy : Restorer.Id.t -> Uri.t; } @@ -32,16 +32,16 @@ let save_new t ~label = let load t sr digest = match File_store.load t.store ~digest with - | None -> Lwt.return Restorer.unknown_service_id + | None -> Restorer.unknown_service_id | Some saved_service -> let logger = Store.Reader.SavedService.logger_get saved_service in let label = Store.Reader.SavedLogger.label_get logger in let sr = Capnp_rpc_lwt.Sturdy_ref.cast sr in - t.loader >|= fun loader -> + let loader = Promise.await t.loader in loader sr ~label let create ~make_sturdy dir = - let loader, set_loader = Lwt.wait () in + let loader, set_loader = Promise.create () in if not (Sys.file_exists dir) then Unix.mkdir dir 0o755; let store = File_store.create dir in {store; loader; make_sturdy}, set_loader diff --git a/examples/sturdy-refs-4/db.mli b/examples/sturdy-refs-4/db.mli index 6ededf061..75a9805d2 100644 --- a/examples/sturdy-refs-4/db.mli +++ b/examples/sturdy-refs-4/db.mli @@ -6,7 +6,7 @@ include Restorer.LOADER type loader = [`Logger_beacebd78653e9af] Sturdy_ref.t -> label:string -> Restorer.resolution (** A function to create a new in-memory logger with the given label and sturdy-ref. *) -val create : make_sturdy:(Restorer.Id.t -> Uri.t) -> string -> t * loader Lwt.u +val create : make_sturdy:(Restorer.Id.t -> Uri.t) -> string -> t * loader Eio.Promise.u (** [create ~make_sturdy dir] is a database that persists services in [dir] and a resolver to let you set the loader (we're not ready to set the loader when we create the database). *) diff --git a/examples/sturdy-refs-4/dune b/examples/sturdy-refs-4/dune index 97144379c..8fceed74b 100644 --- a/examples/sturdy-refs-4/dune +++ b/examples/sturdy-refs-4/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc-unix logs.fmt cmdliner) + (libraries eio_main capnp-rpc-unix logs.fmt cmdliner) (flags (:standard -w -53-55))) (rule diff --git a/examples/sturdy-refs-4/logger.ml b/examples/sturdy-refs-4/logger.ml index 421ae484a..739018146 100644 --- a/examples/sturdy-refs-4/logger.ml +++ b/examples/sturdy-refs-4/logger.ml @@ -1,5 +1,3 @@ -open Lwt.Infix - module Api = Api.MakeRPC(Capnp_rpc_lwt) open Capnp_rpc_lwt @@ -22,14 +20,13 @@ let local ~persist_new sr label = let sub_label = Params.label_get params in release_param_caps (); let label = Printf.sprintf "%s/%s" label sub_label in - Service.return_lwt @@ fun () -> - persist_new ~label >|= function - | Error e -> Error (`Capnp (`Exception e)) + match persist_new ~label with + | Error e -> Service.error (`Exception e) | Ok logger -> let response, results = Service.Response.create Results.init_pointer in Results.logger_set results (Some logger); Capability.dec_ref logger; - Ok response + Service.return response (* $MDX part-end *) end diff --git a/examples/sturdy-refs-4/main.ml b/examples/sturdy-refs-4/main.ml index bc2bee3db..c504b533d 100644 --- a/examples/sturdy-refs-4/main.ml +++ b/examples/sturdy-refs-4/main.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt module Restorer = Capnp_rpc_net.Restorer @@ -13,56 +13,55 @@ let or_fail = function (* $MDX part-begin=server *) let serve config = - Lwt_main.run begin - (* Create the on-disk store *) - let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in - let db, set_loader = Db.create ~make_sturdy "./store" in - (* Create the restorer *) - let services = Restorer.Table.of_loader (module Db) db in - let restore = Restorer.of_table services in - (* Add the root service *) - let persist_new ~label = - let id = Db.save_new db ~label in - Capnp_rpc_net.Restorer.restore restore id - in - let root_id = Capnp_rpc_unix.Vat_config.derived_id config "root" in - let root = - let sr = Capnp_rpc_net.Restorer.Table.sturdy_ref services root_id in - Logger.local ~persist_new sr "root" - in - Restorer.Table.add services root_id root; - (* Tell the database how to restore saved loggers *) - Lwt.wakeup set_loader (fun sr ~label -> Restorer.grant @@ Logger.local ~persist_new sr label); - (* Run the server *) - Capnp_rpc_unix.serve config ~restore >>= fun _vat -> - let uri = Capnp_rpc_unix.Vat_config.sturdy_uri config root_id in - Capnp_rpc_unix.Cap_file.save_uri uri "admin.cap" |> or_fail; - print_endline "Wrote admin.cap"; - fst @@ Lwt.wait () (* Wait forever *) - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + (* Create the on-disk store *) + let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in + let db, set_loader = Db.create ~make_sturdy "./store" in + (* Create the restorer *) + let services = Restorer.Table.of_loader ~sw (module Db) db in + let restore = Restorer.of_table services in + (* Add the root service *) + let persist_new ~label = + let id = Db.save_new db ~label in + Capnp_rpc_net.Restorer.restore restore id + in + let root_id = Capnp_rpc_unix.Vat_config.derived_id config "root" in + let root = + let sr = Capnp_rpc_net.Restorer.Table.sturdy_ref services root_id in + Logger.local ~persist_new sr "root" + in + Restorer.Table.add services root_id root; + (* Tell the database how to restore saved loggers *) + Promise.resolve set_loader (fun sr ~label -> Restorer.grant @@ Logger.local ~persist_new sr label); + (* Run the server *) + let _vat = Capnp_rpc_unix.serve ~sw ~net:env#net ~restore config in + let uri = Capnp_rpc_unix.Vat_config.sturdy_uri config root_id in + Capnp_rpc_unix.Cap_file.save_uri uri "admin.cap" |> or_fail; + print_endline "Wrote admin.cap"; + Fiber.await_cancel () (* $MDX part-end *) let log cap_file msg = - Lwt_main.run begin - let vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Cap_file.load vat cap_file |> or_fail in - Sturdy_ref.with_cap_exn sr @@ fun logger -> - Logger.log logger msg - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Cap_file.load vat cap_file |> or_fail in + Sturdy_ref.with_cap_exn sr @@ fun logger -> + Logger.log logger msg let sub cap_file label = - Lwt_main.run begin - let sub_file = label ^ ".cap" in - if Sys.file_exists sub_file then Fmt.failwith "%S already exists!" sub_file; - let vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Cap_file.load vat cap_file |> or_fail in - Sturdy_ref.with_cap_exn sr @@ fun logger -> - let sub = Logger.sub logger label in - Persistence.save_exn sub >>= fun uri -> - Capnp_rpc_unix.Cap_file.save_uri uri sub_file |> or_fail; - Printf.printf "Wrote %S\n%!" sub_file; - Lwt.return_unit - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let sub_file = label ^ ".cap" in + if Sys.file_exists sub_file then Fmt.failwith "%S already exists!" sub_file; + let vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Cap_file.load vat cap_file |> or_fail in + Sturdy_ref.with_cap_exn sr @@ fun logger -> + let sub = Logger.sub logger label in + let uri = Persistence.save_exn sub in + Capnp_rpc_unix.Cap_file.save_uri uri sub_file |> or_fail; + Printf.printf "Wrote %S\n%!" sub_file open Cmdliner diff --git a/examples/sturdy-refs/dune b/examples/sturdy-refs/dune index 137d3fd39..5a91aaf46 100644 --- a/examples/sturdy-refs/dune +++ b/examples/sturdy-refs/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc-unix logs.fmt) + (libraries eio_main capnp-rpc-unix logs.fmt) (flags (:standard -w -53-55))) (rule diff --git a/examples/sturdy-refs/main.ml b/examples/sturdy-refs/main.ml index 01492e265..47c30f4c6 100644 --- a/examples/sturdy-refs/main.ml +++ b/examples/sturdy-refs/main.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt module Restorer = Capnp_rpc_net.Restorer @@ -21,30 +21,32 @@ let make_service ~config ~services name = Restorer.Table.add services id service; name, id -let start_server () = +let start_server ~sw net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key listen_address in let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in let services = Restorer.Table.create make_sturdy in let restore = Restorer.of_table services in let services = List.map (make_service ~config ~services) ["alice"; "bob"] in - Capnp_rpc_unix.serve config ~restore >|= fun vat -> + let vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in services |> List.iter (fun (name, id) -> let cap_file = name ^ ".cap" in Capnp_rpc_unix.Cap_file.save_service vat id cap_file |> or_fail; Printf.printf "[server] saved %S\n%!" cap_file ) -let run_client cap_file msg = - let vat = Capnp_rpc_unix.client_only_vat () in +let run_client ~sw ~net cap_file msg = + let vat = Capnp_rpc_unix.client_only_vat ~sw net in let sr = Capnp_rpc_unix.Cap_file.load vat cap_file |> or_fail in Printf.printf "[client] loaded %S\n%!" cap_file; Sturdy_ref.with_cap_exn sr @@ fun cap -> Logger.log cap msg let () = - Lwt_main.run begin - start_server () >>= fun () -> - run_client "./alice.cap" "Message from Alice" >>= fun () -> - run_client "./bob.cap" "Message from Bob" - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let net = env#net in + start_server ~sw net; + run_client ~sw ~net "./alice.cap" "Message from Alice"; + run_client ~sw ~net "./bob.cap" "Message from Bob"; + raise Exit (* $MDX part-end *) diff --git a/examples/testlib/calc.ml b/examples/testlib/calc.ml index bac69b1f2..de1f513e9 100644 --- a/examples/testlib/calc.ml +++ b/examples/testlib/calc.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt module Api = Calculator.MakeRPC(Capnp_rpc_lwt) @@ -85,10 +85,10 @@ module Value = struct let read v = let open Api.Client.Calculator.Value.Read in let req = Capability.Request.create_no_args () in - Capability.call_for_value_exn v method_id req >|= Results.value_get + Capability.call_for_value_exn v method_id req |> Results.value_get let final_read v = - read v >|= fun result -> + let result = read v in Capability.dec_ref v; result @@ -114,26 +114,31 @@ let call_fn fn args = let open Api.Client.Calculator.Function.Call in let req, p = Capability.Request.create Params.init_pointer in ignore (Params.params_set_list p args); - Capability.call_for_value_exn fn method_id req >|= Results.value_get + Capability.call_for_value_exn fn method_id req |> Results.value_get -let pp_result_lwt f x = - match Lwt.state x with - | Lwt.Return v -> Fmt.float f v - | Lwt.Fail ex -> Fmt.exn f ex - | Lwt.Sleep -> Fmt.string f "(still calculating)" +let pp_result_promise f x = + match Promise.peek x with + | Some (Ok v) -> Fmt.float f v + | Some (Error ex) -> Fmt.exn f ex + | None -> Fmt.string f "(still calculating)" -(* Evaluate an expression, where some sub-expressions may require remote calls. *) -let rec eval ?(args=[||]) : _ -> Api.Reader.Calculator.Value.t Capability.t = +(* Evaluate an expression, where some sub-expressions may require remote calls. + Immediately returns a service for the result, while the calculation continues in [sw]. *) +let rec eval ~sw ?(args=[||]) : _ -> Api.Reader.Calculator.Value.t Capability.t = let open Expr in function | Float f -> Value.local f | Prev v -> Capability.inc_ref v; v | Param p -> Value.local args.(p) | Call (f, params) -> - let params = params |> Lwt_list.map_p (fun p -> - let value = eval ~args p in - Value.final_read value - ) in - let result = params >>= call_fn f in + let result = Fiber.fork_promise ~sw (fun () -> + params + |> Fiber.map (fun p -> + let value = eval ~sw ~args p in + Value.final_read value + ) + |> call_fn f + ) + in let open Api.Service.Calculator in Value.local @@ object inherit Value.service @@ -141,17 +146,15 @@ let rec eval ?(args=[||]) : _ -> Api.Reader.Calculator.Value.t Capability.t = val id = Capnp_rpc.Debug.OID.next () method! pp f = - Fmt.pf f "EvalResultValue(%a) = %a" Capnp_rpc.Debug.OID.pp id pp_result_lwt result + Fmt.pf f "EvalResultValue(%a) = %a" Capnp_rpc.Debug.OID.pp id pp_result_promise result method read_impl _ release_params = let open Value.Read in release_params (); - Service.return_lwt (fun () -> - result >|= fun result -> - let resp, c = Service.Response.create Results.init_pointer in - Results.value_set c result; - Ok resp - ) + let result = Promise.await_exn result in + let resp, c = Service.Response.create Results.init_pointer in + Results.value_set c result; + Service.return resp end module Fn = struct @@ -168,15 +171,14 @@ module Fn = struct let open Function.Call in let args = Params.params_get_array params in assert (Array.length args = n_args); - let value = eval ~args body in - release_params (); (* Functions return floats, not Value objects, so we have to wait here. *) - Service.return_lwt (fun () -> - Value.final_read value >|= fun value -> - let resp, r = Service.Response.create ~message_size:200 Results.init_pointer in - Results.value_set r value; - Ok resp - ) + Switch.run @@ fun sw -> + let value = eval ~sw ~args body in + release_params (); + let value = Value.final_read value in + let resp, r = Service.Response.create ~message_size:200 Results.init_pointer in + Results.value_set r value; + Service.return resp end let local_binop op : Api.Builder.Calculator.Function.t Capability.t = @@ -204,7 +206,7 @@ module Fn = struct end (* The main calculator service *) -let local = +let local ~sw = let module Calculator = Api.Service.Calculator in Calculator.local @@ object inherit Calculator.service @@ -224,7 +226,7 @@ let local = let open Calculator.Evaluate in let expr = Expr.parse (Params.expression_get params) in release_params (); - let value_obj = eval expr in + let value_obj = eval ~sw expr in Expr.release expr; let resp, results = Service.Response.create ~message_size:200 Results.init_pointer in Results.value_set results (Some value_obj); diff --git a/examples/testlib/calc.mli b/examples/testlib/calc.mli index 8d6eb835e..ac59997e8 100644 --- a/examples/testlib/calc.mli +++ b/examples/testlib/calc.mli @@ -7,10 +7,10 @@ type t = [`Calculator_97983392df35cc36] Capability.t module rec Value : sig type t = [`Value_c3e69d34d3ee48d2] Capability.t - val read : t -> float Lwt.t + val read : t -> float (** [read t] reads the value of the remote value object. *) - val final_read : t -> float Lwt.t + val final_read : t -> float (** [final_read t] reads the value and dec_ref's [t]. *) val local : float -> t @@ -20,7 +20,7 @@ end and Fn : sig type t = [`Function_ede83a3d96840394] Capability.t - val call : t -> float list -> float Lwt.t + val call : t -> float list -> float (** [call fn args] does [fn args]. *) val local : int -> Expr.t -> Fn.t @@ -58,5 +58,6 @@ val evaluate : t -> Expr.t -> Value.t val getOperator : t -> [`Add | `Subtract | `Multiply | `Divide] -> Fn.t (** [getOperator t op] is a remote operator function provided by [t]. *) -val local : t -(** A capability to a local calculator service *) +val local : sw:Eio.Switch.t -> t +(** A capability to a local calculator service. + It may immediately return a promise of a result, while continuing the calculation in [sw]. *) diff --git a/examples/testlib/echo.ml b/examples/testlib/echo.ml index f040f9f93..51c6f1d9a 100644 --- a/examples/testlib/echo.ml +++ b/examples/testlib/echo.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt type t = Api.Service.Echo.t Capability.t @@ -10,7 +10,7 @@ let local () = Echo.local @@ object inherit Echo.service - val mutable blocked = Lwt.wait () + val mutable blocked = Promise.create () val mutable count = 0 val id = Capnp_rpc.Debug.OID.next () @@ -25,16 +25,15 @@ let local () = Results.reply_set results (Fmt.str "got:%d:%s" count msg); count <- count + 1; if Params.slow_get params then ( - Service.return_lwt (fun () -> - fst blocked >|= fun () -> Ok resp - ) + Promise.await (fst blocked); + Service.return resp ) else Service.return resp method unblock_impl _ release_params = release_params (); - Lwt.wakeup (snd blocked) (); - blocked <- Lwt.wait (); + Promise.resolve (snd blocked) (); + blocked <- Promise.create (); Service.return_empty () end @@ -45,14 +44,14 @@ let ping t ?(slow=false) msg = let req, p = Capability.Request.create Params.init_pointer in Params.slow_set p slow; Params.msg_set p msg; - Capability.call_for_value_exn t method_id req >|= Results.reply_get + Capability.call_for_value_exn t method_id req |> Results.reply_get let ping_result t ?(slow=false) msg = let open Echo.Ping in let req, p = Capability.Request.create Params.init_pointer in Params.slow_set p slow; Params.msg_set p msg; - Capability.call_for_value t method_id req >|= function + match Capability.call_for_value t method_id req with | Ok x -> Ok (Results.reply_get x) | Error _ as e -> e diff --git a/examples/testlib/echo.mli b/examples/testlib/echo.mli index f42797de0..c9b75ab5d 100644 --- a/examples/testlib/echo.mli +++ b/examples/testlib/echo.mli @@ -5,13 +5,13 @@ type t = [`Echo_bb48258560861cec] Capability.t val local : unit -> t (** [local ()] is a capability to a new local echo service. *) -val ping : t -> ?slow:bool -> string -> string Lwt.t +val ping : t -> ?slow:bool -> string -> string (** [ping t msg] sends [msg] to [t] and returns its response. If [slow] is given, the service will wait until [unblock] is called before replying. *) -val ping_result : t -> ?slow:bool -> string -> (string, [> `Capnp of Capnp_rpc.Error.t]) Lwt_result.t +val ping_result : t -> ?slow:bool -> string -> (string, [> `Capnp of Capnp_rpc.Error.t]) result (** [ping t msg] sends [msg] to [t] and returns its response. If [slow] is given, the service will wait until [unblock] is called before replying. *) -val unblock : t -> unit Lwt.t +val unblock : t -> unit (** [unblock t] tells the service to return any blocked ping responses. *) diff --git a/examples/testlib/registry.ml b/examples/testlib/registry.ml index 513499b19..8b5a1084a 100644 --- a/examples/testlib/registry.ml +++ b/examples/testlib/registry.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt type t = Api.Service.Registry.t Capability.t @@ -17,12 +17,12 @@ let version_service = Service.return resp end -let local () = +let local ~sw () = let module Registry = Api.Service.Registry in Registry.local @@ object inherit Registry.service - val mutable blocked = Lwt.wait () + val mutable blocked = Promise.create () val mutable echo_service = Echo.local () method! release = Capability.dec_ref echo_service @@ -45,9 +45,8 @@ let local () = let open Registry.EchoService in let resp, results = Service.Response.create Results.init_pointer in Results.service_set results (Some echo_service); - Service.return_lwt (fun () -> - fst blocked >|= fun () -> Ok resp - ) + Promise.await (fst blocked); + Service.return resp method echo_service_promise_impl _params release_params = release_params (); @@ -56,8 +55,8 @@ let local () = let promise, resolver = Capability.promise () in Results.service_set results (Some promise); Capability.dec_ref promise; - Lwt.async (fun () -> - fst blocked >|= fun () -> + Fiber.fork ~sw (fun () -> + Promise.await (fst blocked); Capability.inc_ref echo_service; Capability.resolve_ok resolver echo_service ); @@ -65,8 +64,8 @@ let local () = method unblock_impl _ release_params = release_params (); - Lwt.wakeup (snd blocked) (); - blocked <- Lwt.wait (); + Promise.resolve (snd blocked) (); + blocked <- Promise.create (); Service.return_empty () method complex_impl _ release_params = @@ -131,5 +130,5 @@ module Version = struct let read t = let open Version.Read in let req = Capability.Request.create_no_args () in - Capability.call_for_value_exn t method_id req >|= Results.version_get + Capability.call_for_value_exn t method_id req |> Results.version_get end diff --git a/examples/testlib/registry.mli b/examples/testlib/registry.mli index f0e442ee8..71e05e23f 100644 --- a/examples/testlib/registry.mli +++ b/examples/testlib/registry.mli @@ -1,14 +1,15 @@ +open Eio.Std open Capnp_rpc_lwt module Version : sig type t = [`Version_ed7d11372e0a7243] Capability.t - val read : t -> string Lwt.t + val read : t -> string end type t = [`Registry_d9975f668b337b6d] Capability.t -val set_echo_service : t -> Echo.t -> unit Lwt.t +val set_echo_service : t -> Echo.t -> unit val echo_service : t -> Echo.t (** Waits until unblocked before returning. *) @@ -17,10 +18,10 @@ val echo_service_promise : t -> Echo.t (** Returns a promise immediately. Resolves promise when unblocked. (should appear to work the same as [echo_service] to users) *) -val unblock : t -> unit Lwt.t +val unblock : t -> unit val complex : t -> Echo.t * Version.t (** [complex t] returns two capabilities in a single, somewhat complex, message. *) -val local : unit -> t -(** [local ()] is a new local registry. *) +val local : sw:Switch.t -> unit -> t +(** [local ~sw ()] is a new local registry. *) diff --git a/examples/testlib/store.ml b/examples/testlib/store.ml index 5177d2747..731ec2951 100644 --- a/examples/testlib/store.ml +++ b/examples/testlib/store.ml @@ -1,4 +1,3 @@ -open Lwt.Infix open Capnp_rpc_lwt open Capnp_rpc_net @@ -57,7 +56,7 @@ module File = struct let get t = let open Api.Client.File.Get in let request = Capability.Request.create_no_args () in - Capability.call_for_value_exn t method_id request >|= Results.data_get + Capability.call_for_value_exn t method_id request |> Results.data_get let local (db:DB.t) sr digest = let module File = Api.Service.File in @@ -92,14 +91,14 @@ module File = struct let load t sr digest = if DB.mem t.db digest then ( let sr = Sturdy_ref.cast sr in - Lwt.return @@ Restorer.grant @@ local t.db sr digest + Restorer.grant @@ local t.db sr digest ) else ( - Lwt.return Restorer.unknown_service_id + Restorer.unknown_service_id ) end - let table ~make_sturdy db = - Restorer.Table.of_loader (module Loader) {Loader.db; make_sturdy} + let table ~sw ~make_sturdy db = + Restorer.Table.of_loader ~sw (module Loader) {Loader.db; make_sturdy} end type t = Api.Client.Store.t Capability.t @@ -121,12 +120,11 @@ let local ~restore db = let open Store.CreateFile in release_params (); let id = DB.add db in - Service.return_lwt @@ fun () -> - Restorer.restore restore id >|= function - | Error e -> Error (`Capnp (`Exception e)) + match Restorer.restore restore id with + | Error e -> Service.error (`Exception e) | Ok x -> let resp, results = Service.Response.create Results.init_pointer in Results.file_set results (Some x); Capability.dec_ref x; - Ok resp + Service.return resp end diff --git a/examples/testlib/store.mli b/examples/testlib/store.mli index b7ca9037c..b0a56bb81 100644 --- a/examples/testlib/store.mli +++ b/examples/testlib/store.mli @@ -16,13 +16,13 @@ end module File : sig type t = [`File_aec5916d9557ed0e] Capability.t - val set : t -> string -> unit Lwt.t + val set : t -> string -> unit (** [set t data] saves [data] as [t]'s contents. *) - val get : t -> string Lwt.t + val get : t -> string (** [get t] is the current contents of [t]. *) - val table : make_sturdy:(Restorer.Id.t -> Uri.t) -> DB.t -> Restorer.Table.t + val table : sw:Eio.Switch.t -> make_sturdy:(Restorer.Id.t -> Uri.t) -> DB.t -> Restorer.Table.t (** [table ~make_sturdy db] is a table of file services, backed by [db]. [make_sturdy] is used to generate sturdy URIs for files. *) end diff --git a/examples/v1/dune b/examples/v1/dune index 9d2eaf789..3eb517104 100644 --- a/examples/v1/dune +++ b/examples/v1/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc-lwt logs.fmt) + (libraries eio_main capnp-rpc-lwt logs.fmt) (flags (:standard -w -53-55))) (rule diff --git a/examples/v1/echo.ml b/examples/v1/echo.ml index 8fa53dbed..1ed710272 100644 --- a/examples/v1/echo.ml +++ b/examples/v1/echo.ml @@ -1,7 +1,6 @@ (* $MDX part-begin=server *) module Api = Echo_api.MakeRPC(Capnp_rpc_lwt) -open Lwt.Infix open Capnp_rpc_lwt let local = @@ -26,5 +25,5 @@ let ping t msg = let open Echo.Ping in let request, params = Capability.Request.create Params.init_pointer in Params.msg_set params msg; - Capability.call_for_value_exn t method_id request >|= Results.reply_get + Capability.call_for_value_exn t method_id request |> Results.reply_get (* $MDX part-end *) diff --git a/examples/v1/main.ml b/examples/v1/main.ml index a8a396bec..a15068fcb 100644 --- a/examples/v1/main.ml +++ b/examples/v1/main.ml @@ -1,13 +1,11 @@ -open Lwt.Infix +open Eio.Std let () = Logs.set_level (Some Logs.Warning); Logs.set_reporter (Logs_fmt.reporter ()) let () = - Lwt_main.run begin - let service = Echo.local in - Echo.ping service "foo" >>= fun reply -> - Fmt.pr "Got reply %S@." reply; - Lwt.return_unit - end + Eio_main.run @@ fun _ -> + let service = Echo.local in + let reply = Echo.ping service "foo" in + traceln "Got reply %S" reply diff --git a/examples/v2/dune b/examples/v2/dune index 9d2eaf789..3eb517104 100644 --- a/examples/v2/dune +++ b/examples/v2/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc-lwt logs.fmt) + (libraries eio_main capnp-rpc-lwt logs.fmt) (flags (:standard -w -53-55))) (rule diff --git a/examples/v2/echo.ml b/examples/v2/echo.ml index 933e76f35..af2406249 100644 --- a/examples/v2/echo.ml +++ b/examples/v2/echo.ml @@ -1,6 +1,5 @@ module Api = Echo_api.MakeRPC(Capnp_rpc_lwt) -open Lwt.Infix open Capnp_rpc_lwt module Callback = struct @@ -27,21 +26,21 @@ module Callback = struct end (* $MDX part-begin=notify *) -let (>>!=) = Lwt_result.bind (* Return errors *) - -let notify callback ~msg = +let notify ~clock msg callback = let rec loop = function | 0 -> - Lwt.return @@ Ok (Service.Response.create_empty ()) + Service.return_empty () | i -> - Callback.log callback msg >>!= fun () -> - Lwt_unix.sleep 1.0 >>= fun () -> - loop (i - 1) + match Callback.log callback msg with + | Error (`Capnp e) -> Service.error e + | Ok () -> + Eio.Time.sleep clock 1.0; + loop (i - 1) in loop 3 (* $MDX part-end *) -let local = +let local ~clock = let module Echo = Api.Service.Echo in Echo.local @@ object inherit Echo.service @@ -63,8 +62,7 @@ let local = match callback with | None -> Service.fail "No callback parameter!" | Some callback -> - Service.return_lwt @@ fun () -> - Capability.with_ref callback (notify ~msg) + Capability.with_ref callback (notify ~clock msg) (* $MDX part-end *) end @@ -74,7 +72,7 @@ let ping t msg = let open Echo.Ping in let request, params = Capability.Request.create Params.init_pointer in Params.msg_set params msg; - Capability.call_for_value_exn t method_id request >|= Results.reply_get + Capability.call_for_value_exn t method_id request |> Results.reply_get (* $MDX part-begin=client-heartbeat *) let heartbeat t msg callback = diff --git a/examples/v2/fake_clock.ml b/examples/v2/fake_clock.ml new file mode 100644 index 000000000..2736d35a2 --- /dev/null +++ b/examples/v2/fake_clock.ml @@ -0,0 +1,9 @@ +open Eio.Std + +(* We don't want delays while running the tests, so we replace the clock with this fake one. *) +let v = + object + inherit Eio.Time.clock + method sleep_until _ = Fiber.yield () + method now = 0.0 + end diff --git a/examples/v2/main.ml b/examples/v2/main.ml index ee9fb597f..dc7c3e0c1 100644 --- a/examples/v2/main.ml +++ b/examples/v2/main.ml @@ -1,3 +1,4 @@ +open Eio.Std open Capnp_rpc_lwt let () = @@ -5,14 +6,14 @@ let () = Logs.set_reporter (Logs_fmt.reporter ()) let callback_fn msg = - Fmt.pr "Callback got %S@." msg + traceln "Callback got %S" msg let run_client service = Capability.with_ref (Echo.Callback.local callback_fn) @@ fun callback -> Echo.heartbeat service "foo" callback let () = - Lwt_main.run begin - let service = Echo.local in - run_client service - end + Eio_main.run @@ fun env -> + let clock = if Sys.getenv_opt "CI" = None then env#clock else Fake_clock.v in + let service = Echo.local ~clock in + run_client service diff --git a/examples/v3/dune b/examples/v3/dune index fc80455c8..e7a552a10 100644 --- a/examples/v3/dune +++ b/examples/v3/dune @@ -1,6 +1,6 @@ (executable (name main) - (libraries lwt.unix capnp-rpc-unix logs.fmt) + (libraries eio_main capnp-rpc-unix logs.fmt) (flags (:standard -w -53-55))) (rule diff --git a/examples/v3/echo.ml b/examples/v3/echo.ml index 933e76f35..af2406249 100644 --- a/examples/v3/echo.ml +++ b/examples/v3/echo.ml @@ -1,6 +1,5 @@ module Api = Echo_api.MakeRPC(Capnp_rpc_lwt) -open Lwt.Infix open Capnp_rpc_lwt module Callback = struct @@ -27,21 +26,21 @@ module Callback = struct end (* $MDX part-begin=notify *) -let (>>!=) = Lwt_result.bind (* Return errors *) - -let notify callback ~msg = +let notify ~clock msg callback = let rec loop = function | 0 -> - Lwt.return @@ Ok (Service.Response.create_empty ()) + Service.return_empty () | i -> - Callback.log callback msg >>!= fun () -> - Lwt_unix.sleep 1.0 >>= fun () -> - loop (i - 1) + match Callback.log callback msg with + | Error (`Capnp e) -> Service.error e + | Ok () -> + Eio.Time.sleep clock 1.0; + loop (i - 1) in loop 3 (* $MDX part-end *) -let local = +let local ~clock = let module Echo = Api.Service.Echo in Echo.local @@ object inherit Echo.service @@ -63,8 +62,7 @@ let local = match callback with | None -> Service.fail "No callback parameter!" | Some callback -> - Service.return_lwt @@ fun () -> - Capability.with_ref callback (notify ~msg) + Capability.with_ref callback (notify ~clock msg) (* $MDX part-end *) end @@ -74,7 +72,7 @@ let ping t msg = let open Echo.Ping in let request, params = Capability.Request.create Params.init_pointer in Params.msg_set params msg; - Capability.call_for_value_exn t method_id request >|= Results.reply_get + Capability.call_for_value_exn t method_id request |> Results.reply_get (* $MDX part-begin=client-heartbeat *) let heartbeat t msg callback = diff --git a/examples/v3/fake_clock.ml b/examples/v3/fake_clock.ml new file mode 100644 index 000000000..2736d35a2 --- /dev/null +++ b/examples/v3/fake_clock.ml @@ -0,0 +1,9 @@ +open Eio.Std + +(* We don't want delays while running the tests, so we replace the clock with this fake one. *) +let v = + object + inherit Eio.Time.clock + method sleep_until _ = Fiber.yield () + method now = 0.0 + end diff --git a/examples/v3/main.ml b/examples/v3/main.ml index d039ae5b4..3c67514b9 100644 --- a/examples/v3/main.ml +++ b/examples/v3/main.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt let () = @@ -6,7 +6,7 @@ let () = Logs.set_reporter (Logs_fmt.reporter ()) let callback_fn msg = - Fmt.pr "Callback got %S@." msg + traceln "Callback got %S" msg let run_client service = Capability.with_ref (Echo.Callback.local callback_fn) @@ fun callback -> @@ -15,18 +15,20 @@ let run_client service = let secret_key = `Ephemeral let listen_address = `TCP ("127.0.0.1", 7000) -let start_server () = +let start_server ~sw ~clock net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key listen_address in let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in - let restore = Capnp_rpc_net.Restorer.single service_id Echo.local in - Capnp_rpc_unix.serve config ~restore >|= fun vat -> + let restore = Capnp_rpc_net.Restorer.single service_id (Echo.local ~clock) in + let vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in Capnp_rpc_unix.Vat.sturdy_uri vat service_id let () = - Lwt_main.run begin - start_server () >>= fun uri -> - Fmt.pr "Connecting to echo service at: %a@." Uri.pp_hum uri; - let client_vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in - Sturdy_ref.with_cap_exn sr run_client - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let clock = if Sys.getenv_opt "CI" = None then env#clock else Fake_clock.v in + let uri = start_server ~sw ~clock env#net in + traceln "Connecting to echo service at: %a" Uri.pp_hum uri; + let client_vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in + Sturdy_ref.with_cap_exn sr run_client; + raise Exit diff --git a/examples/v4/client.ml b/examples/v4/client.ml index 95e556e4e..5aa68e339 100644 --- a/examples/v4/client.ml +++ b/examples/v4/client.ml @@ -1,3 +1,4 @@ +open Eio.Std open Capnp_rpc_lwt let () = @@ -5,18 +6,18 @@ let () = Logs.set_reporter (Logs_fmt.reporter ()) let callback_fn msg = - Fmt.pr "Callback got %S@." msg + traceln "Callback got %S" msg let run_client service = Capability.with_ref (Echo.Callback.local callback_fn) @@ fun callback -> Echo.heartbeat service "foo" callback let connect uri = - Lwt_main.run begin - let client_vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in - Capnp_rpc_unix.with_cap_exn sr run_client - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let client_vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in + Capnp_rpc_unix.with_cap_exn sr run_client open Cmdliner diff --git a/examples/v4/dune b/examples/v4/dune index 26674af9a..564e9b038 100644 --- a/examples/v4/dune +++ b/examples/v4/dune @@ -1,6 +1,6 @@ (executables (names client server) - (libraries lwt.unix capnp-rpc-lwt logs.fmt capnp-rpc-unix) + (libraries eio_main capnp-rpc-lwt logs.fmt capnp-rpc-unix) (flags (:standard -w -53-55))) (rule diff --git a/examples/v4/echo.ml b/examples/v4/echo.ml index bfedc5664..ce345321d 100644 --- a/examples/v4/echo.ml +++ b/examples/v4/echo.ml @@ -1,6 +1,5 @@ module Api = Echo_api.MakeRPC(Capnp_rpc_lwt) -open Lwt.Infix open Capnp_rpc_lwt module Callback = struct @@ -26,20 +25,20 @@ module Callback = struct Capability.call_for_unit t method_id request end -let (>>!=) = Lwt_result.bind (* Return errors *) - -let notify callback ~msg = +let notify ~clock msg callback = let rec loop = function | 0 -> - Lwt.return @@ Ok (Service.Response.create_empty ()) + Service.return_empty () | i -> - Callback.log callback msg >>!= fun () -> - Lwt_unix.sleep 1.0 >>= fun () -> - loop (i - 1) + match Callback.log callback msg with + | Error (`Capnp e) -> Service.error e + | Ok () -> + Eio.Time.sleep clock 1.0; + loop (i - 1) in loop 3 -let local = +let local ~clock = let module Echo = Api.Service.Echo in Echo.local @@ object inherit Echo.service @@ -60,8 +59,7 @@ let local = match callback with | None -> Service.fail "No callback parameter!" | Some callback -> - Service.return_lwt @@ fun () -> - Capability.with_ref callback (notify ~msg) + Capability.with_ref callback (notify ~clock msg) end module Echo = Api.Client.Echo @@ -70,7 +68,7 @@ let ping t msg = let open Echo.Ping in let request, params = Capability.Request.create Params.init_pointer in Params.msg_set params msg; - Capability.call_for_value_exn t method_id request >|= Results.reply_get + Capability.call_for_value_exn t method_id request |> Results.reply_get let heartbeat t msg callback = let open Echo.Heartbeat in diff --git a/examples/v4/fake_clock.ml b/examples/v4/fake_clock.ml new file mode 100644 index 000000000..2736d35a2 --- /dev/null +++ b/examples/v4/fake_clock.ml @@ -0,0 +1,9 @@ +open Eio.Std + +(* We don't want delays while running the tests, so we replace the clock with this fake one. *) +let v = + object + inherit Eio.Time.clock + method sleep_until _ = Fiber.yield () + method now = 0.0 + end diff --git a/examples/v4/server.ml b/examples/v4/server.ml index 6439c6e11..f9bc05bca 100644 --- a/examples/v4/server.ml +++ b/examples/v4/server.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std open Capnp_rpc_net let () = @@ -8,16 +8,17 @@ let () = let cap_file = "echo.cap" let serve config = - Lwt_main.run begin - let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in - let restore = Restorer.single service_id Echo.local in - Capnp_rpc_unix.serve config ~restore >>= fun vat -> - match Capnp_rpc_unix.Cap_file.save_service vat service_id cap_file with - | Error `Msg m -> failwith m - | Ok () -> - Fmt.pr "Server running. Connect using %S.@." cap_file; - fst @@ Lwt.wait () (* Wait forever *) - end + Eio_main.run @@ fun env -> + let clock = if Sys.getenv_opt "CI" = None then env#clock else Fake_clock.v in + Switch.run @@ fun sw -> + let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in + let restore = Restorer.single service_id (Echo.local ~clock) in + let vat = Capnp_rpc_unix.serve ~sw ~net:env#net ~restore config in + match Capnp_rpc_unix.Cap_file.save_service vat service_id cap_file with + | Error `Msg m -> failwith m + | Ok () -> + traceln "Server running. Connect using %S." cap_file; + Fiber.await_cancel () open Cmdliner diff --git a/fuzz/fuzz.ml b/fuzz/fuzz.ml index 0b4fe272b..acffbeaf5 100644 --- a/fuzz/fuzz.ml +++ b/fuzz/fuzz.ml @@ -214,9 +214,11 @@ module Endpoint = struct let check t = Conn.check t.conn + let fork f = f () + let create ~restore ~tags ~dump ~local_id ~remote_id xmit_queue recv_queue = let queue_send x = Queue.add x xmit_queue in - let conn = Conn.create ~restore ~tags ~queue_send in + let conn = Conn.create ~restore ~tags ~queue_send ~fork in { local_id; remote_id; diff --git a/mirage/capnp_rpc_mirage.ml b/mirage/capnp_rpc_mirage.ml deleted file mode 100644 index 52af4027f..000000000 --- a/mirage/capnp_rpc_mirage.ml +++ /dev/null @@ -1,54 +0,0 @@ -open Lwt.Infix - -module Log = Capnp_rpc.Debug.Log - -module Location = Network.Location - -module Make (R : Mirage_random.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) (P : Mirage_clock.PCLOCK) (Stack : Tcpip.Stack.V4V6) = struct - - module Dns = Dns_client_mirage.Make(R)(T)(M)(P)(Stack) - module Network = Network.Make(R)(T)(M)(P)(Stack) - module Vat_config = Vat_config.Make(Network) - module Vat_network = Capnp_rpc_net.Networking(Network)(Stack.TCP) - - type flow = Stack.TCP.flow - - module CapTP = Vat_network.CapTP - module Vat = Vat_network.Vat - - let network ~dns stack = {Network.dns; stack} - - let handle_connection ?tags ~secret_key vat flow = - let switch = Lwt_switch.create () in - Network.accept_connection ~switch ~secret_key flow >>= function - | Error (`Msg msg) -> - Log.warn (fun f -> f ?tags "Rejecting new connection: %s" msg); - Lwt.return_unit - | Ok ep -> - Vat.add_connection vat ~switch ~mode:`Accept ep >|= fun (_ : CapTP.t) -> - () - - let serve ?switch ?tags ?restore t config = - let {Vat_config.secret_key = _; serve_tls; listen_address; public_address} = config in - let vat = - let auth = Vat_config.auth config in - let secret_key = lazy (fst (Lazy.force config.secret_key)) in - Vat.create ?switch ?tags ?restore ~address:(public_address, auth) ~secret_key t - in - match listen_address with - | `TCP port -> - Stack.TCP.listen (Stack.tcp t.stack) ~port (fun flow -> - Log.info (fun f -> f ?tags "Accepting new connection"); - let secret_key = if serve_tls then Some (Vat_config.secret_key config) else None in - Lwt.async (fun () -> handle_connection ?tags ~secret_key vat flow); - Lwt.return_unit - ); - Log.info (fun f -> f ?tags "Waiting for %s connections on %a" - (if serve_tls then "(encrypted)" else "UNENCRYPTED") - Vat_config.Listen_address.pp listen_address); - Lwt.return vat - - let client_only_vat ?switch ?tags ?restore t = - let secret_key = lazy (Capnp_rpc_net.Auth.Secret_key.generate ()) in - Vat.create ?switch ?tags ?restore ~secret_key t -end diff --git a/mirage/capnp_rpc_mirage.mli b/mirage/capnp_rpc_mirage.mli deleted file mode 100644 index 7e766a3dc..000000000 --- a/mirage/capnp_rpc_mirage.mli +++ /dev/null @@ -1,72 +0,0 @@ -(** Helpers for using {!Capnp_rpc_lwt} with MirageOS. *) - -open Capnp_rpc_net - -module Location = Network.Location - -module Make (R : Mirage_random.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) (P : Mirage_clock.PCLOCK) (Stack : Tcpip.Stack.V4V6) : sig - include Capnp_rpc_net.VAT_NETWORK with - type flow = Stack.TCP.flow and - module Network = Network.Make(R)(T)(M)(P)(Stack) - - module Vat_config : sig - module Listen_address : sig - type t = [`TCP of int] - val pp : t Fmt.t - end - - type t - - val create : - public_address:Location.t -> - secret_key:[< `PEM of string | `Ephemeral] -> - ?serve_tls:bool -> - Listen_address.t -> t - (** [create ~public_address ~secret_key listen_address] is the configuration for a server vat that - listens on address [listen_address]. - [secret_key] may be one of: - - [`PEM data]: the given PEM-encoded data is used as the key. - - [`Ephemeral]: a new key is generated (if needed) and not saved anywhere. - If [serve_tls] is [false] then the vat accepts unencrypted incoming connections. - If [true] (the default), the vat performs a server TLS handshake, using - [secret_key] to prove its identity to clients. - The vat will suggest that others connect to it at [public_address]. *) - - val secret_key : t -> Auth.Secret_key.t - (** [secret_key t] returns the vat's secret yet, generating it if this is the first time - it has been used. *) - - val hashed_secret : t -> string - (** [hashed_secret t] is the SHA256 digest of the secret key file. - This is useful as an input to {!Restorer.Id.derived}. *) - - val derived_id : t -> string -> Restorer.Id.t - (** [derived_id t name] is a secret service ID derived from name and the - vat's secret key (using {!Restorer.Id.derived}). It won't change - (unless the vat's key changes). *) - - val sturdy_uri : t -> Restorer.Id.t -> Uri.t - (** [sturdy_uri t id] is a sturdy URI for [id] at the vat that would be - created by [t]. *) - end - - val network : dns:Network.Dns.t -> Stack.t -> Network.t - - val serve : - ?switch:Lwt_switch.t -> - ?tags:Logs.Tag.set -> - ?restore:Restorer.t -> - Network.t -> - Vat_config.t -> - Vat.t Lwt.t - (** [serve ~restore net vat_config] is a new vat that is listening for new connections - as specified by [vat_config]. After connecting to it, clients can get access - to services using [restore]. *) - - val client_only_vat : - ?switch:Lwt_switch.t -> - ?tags:Logs.Tag.set -> - ?restore:Restorer.t -> - Network.t -> Vat.t - (** [client_only_vat net] is a new vat that does not listen for incoming connections. *) -end diff --git a/mirage/dune b/mirage/dune deleted file mode 100644 index cd38a5e3d..000000000 --- a/mirage/dune +++ /dev/null @@ -1,4 +0,0 @@ -(library - (name capnp_rpc_mirage) - (public_name capnp-rpc-mirage) - (libraries capnp-rpc-lwt capnp-rpc-net capnp-rpc fmt logs dns-client.mirage tcpip)) diff --git a/mirage/network.ml b/mirage/network.ml deleted file mode 100644 index d43a9859c..000000000 --- a/mirage/network.ml +++ /dev/null @@ -1,93 +0,0 @@ -open Lwt.Infix -module Log = Capnp_rpc.Debug.Log - -let error fmt = - fmt |> Fmt.kstr @@ fun msg -> - Error (`Msg msg) - -module Location = struct - type t = [ - | `TCP of string * int - ] - - let tcp ~host ~port = `TCP (host, port) - - let pp f x = Capnp_rpc_net.Capnp_address.(Location.pp f (x :> Location.t)) - - let equal = ( = ) -end - -module Make (R : Mirage_random.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) (P : Mirage_clock.PCLOCK) (Stack : Tcpip.Stack.V4V6) = struct - - module Dns = Dns_client_mirage.Make(R)(T)(M)(P)(Stack) - module Tls_wrapper = Capnp_rpc_net.Tls_wrapper.Make(Stack.TCP) - - module Address = struct - module Full = Capnp_rpc_net.Capnp_address - - type t = Location.t * Capnp_rpc_net.Auth.Digest.t - - let digest t = Full.digest (t :> Full.t) - let to_uri (t, id) = Full.to_uri ((t :> Full.t), id) - let pp f t = Full.pp f (t :> Full.t) - - let parse_uri uri = - match Full.parse_uri uri with - | Error _ as e -> e - | Ok ((`TCP _, _), _ as x) -> Ok x - | Ok ((`Unix x, _), _) -> - error "Unix-domain addresses are not available with Mirage networking (%S)" x - - let equal (addr, auth) (addr_b, auth_b) = - Location.equal addr addr_b && - Capnp_rpc_net.Auth.Digest.equal auth auth_b - end - - module Types = struct - type provision_id - type recipient_id - type third_party_cap_id = [`Two_party_only] - type join_key_part - end - - let parse_third_party_cap_id _ = `Two_party_only - - type t = { - stack : Stack.t; - dns : Dns.t; - } - - let addr_of_host dns host = - match Ipaddr.of_string host with - | Ok ip -> Lwt.return @@ Ok ip - | Error (`Msg _) -> - match Domain_name.of_string host with - | Ok dn -> begin - match Domain_name.host dn with - | Ok h -> begin - Dns.gethostbyname dns h >>= function - | Ok addr -> Lwt.return_ok (Ipaddr.V4 addr) - | Error (`Msg error_msg) -> Lwt.return @@ error "Unknown host %S : %s" host error_msg - end - | Error (`Msg error_msg) -> Lwt.return @@ error "Invalid hostname %S : %s" host error_msg - end - | Error (`Msg error_msg) -> Lwt.return @@ error "Bad domain name %S : %s" host error_msg - - let ( >>*= ) x f = - x >>= function - | Error _ as e -> Lwt.return e - | Ok y -> f y - - let connect t ~switch ~secret_key (addr, auth) = - match addr with - | `TCP (host, port) -> - Logs.info (fun f -> f "Connecting to %s:%d..." host port); - addr_of_host t.dns host >>*= fun addr -> - let tcp = Stack.tcp t.stack in - Stack.TCP.create_connection tcp (addr, port) >>= function - | Error e -> Lwt.return @@ error "Failed to connect to %S:%d: %a" host port Stack.TCP.pp_error e - | Ok flow -> Tls_wrapper.connect_as_client ~switch flow secret_key auth - - let accept_connection ~switch ~secret_key flow = - Tls_wrapper.connect_as_server ~switch flow secret_key -end diff --git a/mirage/network.mli b/mirage/network.mli deleted file mode 100644 index fecfdaf33..000000000 --- a/mirage/network.mli +++ /dev/null @@ -1,37 +0,0 @@ -(** A capnp network build from a Mirage network stack. *) - -module Location : sig - type t = [ - | `TCP of string * int - ] - - val pp : t Fmt.t - - val equal : t -> t -> bool - - val tcp : host:string -> port:int -> t - (** [tcp ~host port] is [`TCP (host, port)]. *) -end - -module Make (R : Mirage_random.S) (T : Mirage_time.S) (M : Mirage_clock.MCLOCK) (P : Mirage_clock.PCLOCK) (Stack : Tcpip.Stack.V4V6) : sig - - module Dns : module type of Dns_client_mirage.Make(R)(T)(M)(P)(Stack) - - type t = { - stack : Stack.t; - dns : Dns.t; - } - - include Capnp_rpc_net.S.NETWORK with - type t := t and - type Address.t = Location.t * Capnp_rpc_net.Auth.Digest.t - - val accept_connection : - switch:Lwt_switch.t -> - secret_key:Capnp_rpc_net.Auth.Secret_key.t option -> - Stack.TCP.flow -> - (Capnp_rpc_net.Endpoint.t, [> `Msg of string]) result Lwt.t - (** [accept_connection ~switch ~secret_key flow] is a new endpoint for [flow]. - If [secret_key] is not [None], it is used to perform a TLS server-side handshake. - Otherwise, the connection is not encrypted. *) -end diff --git a/mirage/vat_config.ml b/mirage/vat_config.ml deleted file mode 100644 index d13a5b91d..000000000 --- a/mirage/vat_config.ml +++ /dev/null @@ -1,57 +0,0 @@ -module Auth = Capnp_rpc_net.Auth -module Log = Capnp_rpc.Debug.Log - -module Secret_hash : sig - type t - - val of_pem_data : string -> t - val to_string : t -> string -end = struct - type t = string - - let of_pem_data data = Mirage_crypto.Hash.SHA256.digest (Cstruct.of_string data) |> Cstruct.to_string - let to_string x = x -end - -module Make (N : Capnp_rpc_net.S.NETWORK with type Address.t = Network.Location.t * Auth.Digest.t) = struct - module Listen_address = struct - type t = [`TCP of int] - - let pp f (`TCP port) = - Fmt.pf f "tcp:%d" port - end - - type t = { - secret_key : (Auth.Secret_key.t * Secret_hash.t) Lazy.t; - serve_tls : bool; - listen_address : Listen_address.t; - public_address : Network.Location.t; - } - - let secret_key t = fst @@ Lazy.force t.secret_key - - let hashed_secret t = Secret_hash.to_string @@ snd @@ Lazy.force t.secret_key - - let create ~public_address ~secret_key ?(serve_tls=true) listen_address = - let secret_key = lazy ( - match secret_key with - | `PEM data -> (Auth.Secret_key.of_pem_data data, Secret_hash.of_pem_data data) - | `Ephemeral -> - let key = Auth.Secret_key.generate () in - let data = Auth.Secret_key.to_pem_data key in - (key, Secret_hash.of_pem_data data) - ) in - { secret_key; serve_tls; listen_address; public_address } - - let derived_id t name = - let secret = hashed_secret t in - Capnp_rpc_net.Restorer.Id.derived ~secret name - - let auth t = - if t.serve_tls then Capnp_rpc_net.Auth.Secret_key.digest (secret_key t) - else Capnp_rpc_net.Auth.Digest.insecure - - let sturdy_uri t service = - let address = (t.public_address, auth t) in - N.Address.to_uri (address, Capnp_rpc_net.Restorer.Id.to_string service) -end diff --git a/test-bin/calc.ml b/test-bin/calc.ml index 881ad436e..075493540 100644 --- a/test-bin/calc.ml +++ b/test-bin/calc.ml @@ -1,4 +1,4 @@ -open Lwt.Infix +open Eio.Std module Vat = Capnp_rpc_unix.Vat module Calc = Testlib.Calc @@ -30,29 +30,29 @@ let reporter = (* Run as server *) let serve vat_config = - Lwt_main.run begin - let service_id = Capnp_rpc_net.Restorer.Id.public "" in - let restore = Capnp_rpc_net.Restorer.single service_id Calc.local in - Capnp_rpc_unix.serve vat_config ~restore >>= fun vat -> - let sr = Vat.sturdy_uri vat service_id in - Fmt.pr "Waiting for incoming connections at:@.%a@." Uri.pp_hum sr; - fst @@ Lwt.wait () - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let service_id = Capnp_rpc_net.Restorer.Id.public "" in + let service = Calc.local ~sw in + let restore = Capnp_rpc_net.Restorer.single service_id service in + let vat = Capnp_rpc_unix.serve ~sw ~net:env#net ~restore vat_config in + let sr = Vat.sturdy_uri vat service_id in + traceln "Waiting for incoming connections at:@.%a" Uri.pp_hum sr; + Fiber.await_cancel () (* Run as client *) let connect addr = - Lwt_main.run begin - let vat = Capnp_rpc_unix.client_only_vat () in - let sr = Vat.import_exn vat addr in - Capnp_rpc_unix.with_cap_exn sr @@ fun calc -> - Logs.info (fun f -> f "Evaluating expression..."); - let remote_add = Calc.getOperator calc `Add in - let result = Calc.evaluate calc Calc.Expr.(Call (remote_add, [Float 40.0; Float 2.0])) in - Calc.Value.read result >>= fun v -> - Fmt.pr "Result: %f@." v; - Lwt.return_unit - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Vat.import_exn vat addr in + Capnp_rpc_unix.with_cap_exn sr @@ fun calc -> + Logs.info (fun f -> f "Evaluating expression..."); + let remote_add = Calc.getOperator calc `Add in + let result = Calc.evaluate calc Calc.Expr.(Call (remote_add, [Float 40.0; Float 2.0])) in + let v = Calc.Value.read result in + traceln "Result: %f" v (* Command-line parsing *) diff --git a/test-bin/calc_direct.ml b/test-bin/calc_direct.ml index 1f75b66bc..4e27149b1 100644 --- a/test-bin/calc_direct.ml +++ b/test-bin/calc_direct.ml @@ -1,7 +1,7 @@ (* Run the calc service as a child process, connecting directly over a socketpair. Unlike a normal connection, there is no encryption or use of sturdy refs here. *) -open Lwt.Infix +open Eio.Std module Calc = Testlib.Calc @@ -32,45 +32,40 @@ end module Parent = struct let run socket = Logging.init "parent"; + Switch.run @@ fun sw -> (* Run Cap'n Proto RPC protocol on [socket]: *) - Lwt_switch.with_switch @@ fun switch -> - let p = Lwt_unix.of_unix_file_descr socket - |> Capnp_rpc_unix.Unix_flow.connect ~switch - |> Capnp_rpc_net.Endpoint.of_flow (module Capnp_rpc_unix.Unix_flow) + let p = Eio_unix.FD.as_socket ~sw ~close_unix:true socket + |> Capnp_rpc_net.Endpoint.of_flow ~peer_id:Capnp_rpc_net.Auth.Digest.insecure - ~switch in + in Logs.info (fun f -> f "Connecting to child process..."); - let conn = Capnp_rpc_unix.CapTP.connect ~restore:Capnp_rpc_net.Restorer.none p in + let conn = Capnp_rpc_unix.CapTP.connect ~sw ~restore:Capnp_rpc_net.Restorer.none p in + Fiber.fork_daemon ~sw (fun () -> Capnp_rpc_unix.CapTP.listen conn; `Stop_daemon); (* Get the child's service object: *) let calc = Capnp_rpc_unix.CapTP.bootstrap conn service_name in (* Use the service: *) Logs.app (fun f -> f "Sending request..."); let remote_mul = Calc.getOperator calc `Multiply in let result = Calc.evaluate calc Calc.Expr.(Call (remote_mul, [Float 21.0; Float 2.0])) in - Calc.Value.read result >>= fun v -> + let v = Calc.Value.read result in Logs.app (fun f -> f "Result: %f" v); - Logs.app (fun f -> f "Shutting down..."); - Lwt.return_unit + Logs.app (fun f -> f "Shutting down...") end module Child = struct - let service = Calc.local - let run socket = Logging.init "child"; - Lwt_main.run begin - Lwt_switch.with_switch @@ fun switch -> - let restore = Capnp_rpc_net.Restorer.single service_name service in - (* Run Cap'n Proto RPC protocol on [socket]: *) - let endpoint = Capnp_rpc_unix.Unix_flow.connect (Lwt_unix.of_unix_file_descr socket) - |> Capnp_rpc_net.Endpoint.of_flow (module Capnp_rpc_unix.Unix_flow) - ~peer_id:Capnp_rpc_net.Auth.Digest.insecure - ~switch - in - let _ : Capnp_rpc_unix.CapTP.t = Capnp_rpc_unix.CapTP.connect ~restore endpoint in - Logs.info (fun f -> f "Serving requests..."); - fst (Lwt.wait ()) (* Wait forever *) - end + Switch.run @@ fun sw -> + let socket = Eio_unix.FD.as_socket ~sw ~close_unix:false socket in + let service = Calc.local ~sw in + let restore = Capnp_rpc_net.Restorer.single service_name service in + (* Run Cap'n Proto RPC protocol on [socket]: *) + let endpoint = Capnp_rpc_net.Endpoint.of_flow socket + ~peer_id:Capnp_rpc_net.Auth.Digest.insecure + in + let conn = Capnp_rpc_unix.CapTP.connect ~sw ~restore endpoint in + Logs.info (fun f -> f "Serving requests..."); + Capnp_rpc_unix.CapTP.listen conn end let find_our_path prog = @@ -82,25 +77,32 @@ let find_our_path prog = else Fmt.failwith "Can't find path to own binary %S from %S" prog (Sys.getcwd ()) ) +let await_exit pid = + Eio_unix.run_in_systhread @@ fun () -> + let rec aux () = + match Unix.waitpid [] pid with + | exception Unix.Unix_error (Unix.EINTR, _, _) -> aux () + | _pid, status -> status + in + aux () + let () = - Lwt_main.run begin - match Sys.argv with - | [| prog |] -> - (* We are the parent. *) - let prog = find_our_path prog in - let p, c = Unix.(socketpair PF_UNIX SOCK_STREAM 0 ~cloexec:true) in - Unix.clear_close_on_exec c; - (* Run the child, passing the socket as its stdin. *) - let child = Lwt_process.open_process_none ~stdin:(`FD_move c) ("", [| prog; "--child" |]) in - Parent.run p >>= fun () -> - Logs.info (fun f -> f "Waiting for child to exit..."); - child#terminate; - child#status >>= fun _ -> - Logs.info (fun f -> f "Done"); - Lwt.return_unit - | [| _prog; "--child" |] -> - (* We are the child. Our socket is on stdin. *) - Child.run Unix.stdin - | _ -> - failwith "Run this command without arguments." - end + Eio_main.run @@ fun _env -> + match Sys.argv with + | [| prog |] -> + (* We are the parent. *) + let prog = find_our_path prog in + let p, c = Unix.(socketpair PF_UNIX SOCK_STREAM 0 ~cloexec:true) in + Unix.clear_close_on_exec c; + (* Run the child, passing the socket as its stdin. *) + let child = Unix.create_process prog [| prog; "--child" |] c Unix.stdout Unix.stderr in + Parent.run p; + Logs.info (fun f -> f "Waiting for child to exit..."); + Unix.kill child Sys.sigkill; + let _ : Unix.process_status = await_exit child in + Logs.info (fun f -> f "Done") + | [| _prog; "--child" |] -> + (* We are the child. Our socket is on stdin. *) + Child.run Unix.stdin + | _ -> + failwith "Run this command without arguments." diff --git a/test-bin/dune b/test-bin/dune index 9c46081cf..8220e654d 100644 --- a/test-bin/dune +++ b/test-bin/dune @@ -1,3 +1,3 @@ (executables (names calc calc_direct) - (libraries testlib cmdliner astring logs.fmt fmt.tty capnp-rpc-unix)) + (libraries testlib cmdliner astring logs.fmt fmt.tty capnp-rpc-unix eio_main)) diff --git a/test-bin/echo/dune b/test-bin/echo/dune index 19562b5df..b638b44ed 100644 --- a/test-bin/echo/dune +++ b/test-bin/echo/dune @@ -1,6 +1,6 @@ (executable (name echo_bench) - (libraries lwt.unix capnp-rpc capnp-rpc-lwt capnp-rpc-net capnp-rpc-unix logs.fmt) + (libraries eio_main capnp-rpc capnp-rpc-lwt capnp-rpc-net capnp-rpc-unix logs.fmt) (flags (:standard -w -53-55))) (rule diff --git a/test-bin/echo/echo.ml b/test-bin/echo/echo.ml index 4b207c644..763838b20 100755 --- a/test-bin/echo/echo.ml +++ b/test-bin/echo/echo.ml @@ -1,6 +1,5 @@ module Api = Echo_api.MakeRPC(Capnp_rpc_lwt) -open Lwt.Infix open Capnp_rpc_lwt (*-- Server ----------------------------------------*) @@ -29,4 +28,4 @@ let ping t msg = let message_size = 200 + String.length msg in (* (rough estimate) *) let request, params = Capability.Request.create ~message_size Params.init_pointer in Params.msg_set params msg; - Capability.call_for_value_exn t method_id request >|= Results.reply_get + Capability.call_for_value_exn t method_id request |> Results.reply_get diff --git a/test-bin/echo/echo_bench.ml b/test-bin/echo/echo_bench.ml index 5c6a13b65..938023251 100755 --- a/test-bin/echo/echo_bench.ml +++ b/test-bin/echo/echo_bench.ml @@ -1,5 +1,4 @@ - -open Lwt.Infix +open Eio.Std open Capnp_rpc_lwt @@ -8,36 +7,37 @@ let () = Logs.set_reporter (Logs_fmt.reporter ()) let run_client service = - let n = 100000 in + (* let n = 100000 in *) (* XXX: improve speed *) + let n = 1000 in let ops = List.init n (fun i -> let payload = Int.to_string i in let desired_result = "echo:" ^ payload in fun () -> - Echo.ping service payload >|= fun res -> + let res = Echo.ping service payload in assert (res = desired_result) ) in let st = Unix.gettimeofday () in - Lwt_stream.of_list ops |> Lwt_stream.iter_n ~max_concurrency:12 (fun v -> v ()) >>= fun () -> + ops |> Fiber.iter ~max_fibers:12 (fun v -> v ()); let ed = Unix.gettimeofday () in let rate = (Int.to_float n) /. (ed -. st) in - Logs.info (fun m -> m "rate = %f" rate ); - Lwt.return_unit + Logs.info (fun m -> m "rate = %f" rate ) let secret_key = `Ephemeral let listen_address = `TCP ("127.0.0.1", 7000) -let start_server () = +let start_server ~sw net = let config = Capnp_rpc_unix.Vat_config.create ~secret_key ~serve_tls:false listen_address in let service_id = Capnp_rpc_unix.Vat_config.derived_id config "main" in let restore = Capnp_rpc_net.Restorer.single service_id Echo.local in - Capnp_rpc_unix.serve config ~restore >|= fun vat -> + let vat = Capnp_rpc_unix.serve ~sw ~net ~restore config in Capnp_rpc_unix.Vat.sturdy_uri vat service_id let () = - Lwt_main.run begin - start_server () >>= fun uri -> - Fmt.pr "Connecting to echo service at: %a@." Uri.pp_hum uri; - let client_vat = Capnp_rpc_unix.client_only_vat () in - let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in - Sturdy_ref.with_cap_exn sr run_client - end + Eio_main.run @@ fun env -> + Switch.run @@ fun sw -> + let uri = start_server ~sw env#net in + Fmt.pr "Connecting to echo service at: %a@." Uri.pp_hum uri; + let client_vat = Capnp_rpc_unix.client_only_vat ~sw env#net in + let sr = Capnp_rpc_unix.Vat.import_exn client_vat uri in + Sturdy_ref.with_cap_exn sr run_client; + raise Exit diff --git a/test-lwt/dune b/test-lwt/dune index 30ea84cd2..f2b55d006 100644 --- a/test-lwt/dune +++ b/test-lwt/dune @@ -1,5 +1,5 @@ (test (package capnp-rpc-unix) (name test_lwt) - (libraries capnp-rpc-lwt capnp-rpc-unix alcotest-lwt testlib logs.fmt - testbed)) + (libraries capnp-rpc-lwt capnp-rpc-unix testlib logs.fmt + testbed eio_main)) diff --git a/test-lwt/test_lwt.ml b/test-lwt/test_lwt.ml index 4b30ba299..e2758bc70 100644 --- a/test-lwt/test_lwt.ml +++ b/test-lwt/test_lwt.ml @@ -1,6 +1,6 @@ +open Eio.Std open Astring open Testlib -open Lwt.Infix open Capnp_rpc_lwt open Capnp_rpc_net @@ -8,17 +8,23 @@ module Test_utils = Testbed.Test_utils module Vat = Capnp_rpc_unix.Vat module CapTP = Capnp_rpc_unix.CapTP -module Unix_flow = Capnp_rpc_unix.Unix_flow -module Tls_wrapper = Capnp_rpc_net.Tls_wrapper.Make(Unix_flow) +module Tls_wrapper = Capnp_rpc_net.Tls_wrapper module Exception = Capnp_rpc.Exception +exception Simulated_failure + +let ( let/ ) x f = f (x ()) +let ( and/ ) x y () = Fiber.pair x y + +let _debug () = + Logs.Src.set_level Capnp_rpc.Debug.src (Some Logs.Debug) + type cs = { client : Vat.t; server : Vat.t; client_key : Auth.Secret_key.t; server_key : Auth.Secret_key.t; serve_tls : bool; - server_switch : Lwt_switch.t; } let ensure_removed path = @@ -27,7 +33,7 @@ let ensure_removed path = let next_port = ref 8000 -let get_test_address ~switch name = +let get_test_address ~sw name = match Sys.os_type with | "Win32" -> (* No Unix-domain sockets on Windows *) @@ -36,7 +42,7 @@ let get_test_address ~switch name = `TCP ("127.0.0.1", port) | _ -> let socket_path = Filename.(concat (Filename.get_temp_dir_name ())) name in - Lwt_switch.add_hook (Some switch) (fun () -> Lwt.return @@ ensure_removed socket_path); + Switch.on_release sw (fun () -> ensure_removed socket_path); `Unix socket_path (* Have the client ask the server for its bootstrap object, and return the @@ -71,206 +77,199 @@ let () = Logs.(set_level (Some Logs.Info)) let server_pem = `PEM (Auth.Secret_key.to_pem_data server_key) -let make_vats_full ?(serve_tls=false) ~client_switch ~server_switch ~restore () = - let server_config = - let addr = get_test_address ~switch:server_switch "capnp-rpc-test-server" in - Capnp_rpc_unix.Vat_config.create ~secret_key:server_pem ~serve_tls addr +let make_vats_full ?(serve_tls=false) ?server_sw ~sw ~net ~restore () = + let server = + let sw = Option.value server_sw ~default:sw in + let server_config = + let addr = get_test_address ~sw "capnp-rpc-test-server" in + Capnp_rpc_unix.Vat_config.create ~secret_key:server_pem ~serve_tls addr + in + Capnp_rpc_unix.serve ~sw ~net ~tags:Test_utils.server_tags ~restore server_config in - Capnp_rpc_unix.serve ~switch:server_switch ~tags:Test_utils.server_tags ~restore server_config >>= fun server -> - Lwt.return { - client = Vat.create ~switch:client_switch ~tags:Test_utils.client_tags ~secret_key:(lazy client_key) (); + { + client = Vat.create ~sw ~tags:Test_utils.client_tags ~secret_key:(lazy client_key) net; server; client_key; server_key; serve_tls; - server_switch; } -let make_vats ?serve_tls ~switch ~service () = - let server_switch = Lwt_switch.create () in - Lwt_switch.add_hook (Some switch) (fun () -> Lwt_switch.turn_off server_switch); - let id = Restorer.Id.public "" in - let restore = Restorer.single id service in - Lwt_switch.add_hook (Some switch) (fun () -> Capability.dec_ref service; Lwt.return_unit); - make_vats_full ?serve_tls ~client_switch:switch ~server_switch ~restore () - -(* Generic Lwt running for Alcotest. *) -let run_lwt name ?(expected_warnings=0) fn = - Alcotest_lwt.test_case name `Quick @@ fun sw () -> +let with_vats ?serve_tls ?server_sw ~net ~service fn = + try + Switch.run @@ fun sw -> + let id = Restorer.Id.public "" in + let restore = Restorer.single id service in + Switch.on_release sw (fun () -> Capability.dec_ref service); + fn @@ make_vats_full ?serve_tls ?server_sw ~sw ~net ~restore (); + Logs.info (fun f -> f "Test finished; shutting down vats..."); + raise Exit (* Stop vats *) + with Exit -> () + +(* Generic running for Alcotest. *) +let run_eio ~net name ?(expected_warnings=0) fn = + Alcotest.test_case name `Quick @@ fun () -> let warnings_at_start = Logs.(err_count () + warn_count ()) in Logs.info (fun f -> f "Start test-case"); - let finished = ref false in - Lwt_switch.add_hook (Some sw) (fun () -> - if not !finished then !Lwt.async_exception_hook (Failure "Switch turned off early"); - Lwt.return_unit - ); - fn sw >>= fun () -> finished := true; - Lwt_switch.turn_off sw >|= fun () -> - Gc.full_major (); - Lwt.wakeup_paused (); - Gc.full_major (); - Lwt.wakeup_paused (); + fn ~net; Gc.full_major (); let warnings_at_end = Logs.(err_count () + warn_count ()) in Alcotest.(check int) "Check log for warnings" expected_warnings (warnings_at_end - warnings_at_start) -let test_simple switch ~serve_tls = - make_vats ~switch ~serve_tls ~service:(Echo.local ()) () >>= fun cs -> - get_bootstrap cs >>= fun service -> - Echo.ping service "ping" >>= fun reply -> +let test_simple ~net ~serve_tls = + with_vats ~net ~serve_tls ~service:(Echo.local ()) @@ fun cs -> + let service = get_bootstrap cs in + let reply = Echo.ping service "ping" in Alcotest.(check string) "Ping response" "got:0:ping" reply; - Capability.dec_ref service; - Lwt.return () + Capability.dec_ref service -let test_bad_crypto switch = - make_vats ~switch ~serve_tls:true ~service:(Echo.local ()) () >>= fun cs -> +let test_bad_crypto ~net = + with_vats ~net ~serve_tls:true ~service:(Echo.local ()) @@ fun cs -> let id = Restorer.Id.public "" in let uri = Vat.sturdy_uri cs.server id in let bad_digest = Auth.Secret_key.digest ~hash:`SHA256 bad_key in let uri = Auth.Digest.add_to_uri bad_digest uri in let sr = Capnp_rpc_unix.Vat.import_exn cs.client uri in let old_warnings = Logs.warn_count () in - Sturdy_ref.connect sr >>= function + match Sturdy_ref.connect sr with | Ok _ -> Alcotest.fail "Wrong TLS key should have been rejected" | Error e -> let msg = Fmt.to_to_string Capnp_rpc.Exception.pp e in assert (String.is_prefix ~affix:"Failed: TLS connection failed: authentication failure" msg); - (* Wait for server to log warning *) - let rec wait () = - if Logs.warn_count () = old_warnings then Lwt.pause () >>= wait - else Lwt.return_unit - in - wait () - -let test_parallel switch = - make_vats ~switch ~service:(Echo.local ()) () >>= fun cs -> - get_bootstrap cs >>= fun service -> - let reply1 = Echo.ping service ~slow:true "ping1" in - Echo.ping service "ping2" >|= Alcotest.(check string) "Ping2 response" "got:1:ping2" >>= fun () -> - assert (Lwt.state reply1 = Lwt.Sleep); - Echo.unblock service >>= fun () -> - reply1 >|= Alcotest.(check string) "Ping1 response" "got:0:ping1" >>= fun () -> - Capability.dec_ref service; - Lwt.return () + Logs.info (fun f -> f "Wait for server to log warning..."); + while Logs.warn_count () = old_warnings do + Fiber.yield () + done + +let test_parallel ~net = + with_vats ~net ~service:(Echo.local ()) @@ fun cs -> + Switch.run @@ fun sw -> + let service = get_bootstrap cs in + let reply1 = Fiber.fork_promise ~sw (fun () -> Echo.ping service ~slow:true "ping1") in + Echo.ping service "ping2" |> Alcotest.(check string) "Ping2 response" "got:1:ping2"; + assert (Promise.peek reply1 = None); + Echo.unblock service; + Promise.await_exn reply1 |> Alcotest.(check string) "Ping1 response" "got:0:ping1"; + Capability.dec_ref service -let test_registry switch = - let registry_impl = Registry.local () in - make_vats ~switch ~service:registry_impl () >>= fun cs -> - get_bootstrap cs >>= fun registry -> +let test_registry ~net = + Switch.run @@ fun sw -> + let registry_impl = Registry.local ~sw () in + with_vats ~net ~service:registry_impl @@ fun cs -> + let registry = get_bootstrap cs in Capability.with_ref (Registry.echo_service registry) @@ fun echo_service -> - Registry.unblock registry >>= fun () -> - Echo.ping echo_service "ping" >|= Alcotest.(check string) "Ping response" "got:0:ping" >>= fun () -> - Capability.dec_ref registry; - Lwt.return () + Registry.unblock registry; + Echo.ping echo_service "ping" |> Alcotest.(check string) "Ping response" "got:0:ping"; + Capability.dec_ref registry -let test_embargo switch = - let registry_impl = Registry.local () in +let test_embargo ~net = + Switch.run @@ fun sw -> + let registry_impl = Registry.local ~sw () in let local_echo = Echo.local () in - make_vats ~switch ~service:registry_impl () >>= fun cs -> - get_bootstrap cs >>= fun registry -> - Registry.set_echo_service registry local_echo >>= fun () -> + with_vats ~net ~service:registry_impl @@ fun cs -> + let registry = get_bootstrap cs in + Registry.set_echo_service registry local_echo; Capability.dec_ref local_echo; let echo_service = Registry.echo_service registry in - let reply1 = Echo.ping echo_service "ping" in - Registry.unblock registry >>= fun () -> - reply1 >|= Alcotest.(check string) "Ping response" "got:0:ping" >>= fun () -> + let reply1 = Fiber.fork_promise ~sw (fun () -> Echo.ping echo_service "ping") in + Registry.unblock registry; + Promise.await_exn reply1 |> Alcotest.(check string) "Ping response" "got:0:ping"; (* Flush, to ensure we resolve the echo_service's location. *) - Echo.ping echo_service "ping" >|= Alcotest.(check string) "Ping response" "got:1:ping" >>= fun () -> + Echo.ping echo_service "ping" |> Alcotest.(check string) "Ping response" "got:1:ping"; (* Test local connection. *) - Echo.ping echo_service "ping" >|= Alcotest.(check string) "Ping response" "got:2:ping" >>= fun () -> + Echo.ping echo_service "ping" |> Alcotest.(check string) "Ping response" "got:2:ping"; Capability.dec_ref echo_service; - Capability.dec_ref registry; - Lwt.return () + Capability.dec_ref registry -let test_resolve switch = - let registry_impl = Registry.local () in +let test_resolve ~net = + Switch.run @@ fun sw -> + let registry_impl = Registry.local ~sw () in let local_echo = Echo.local () in - make_vats ~switch ~service:registry_impl () >>= fun cs -> - get_bootstrap cs >>= fun registry -> - Registry.set_echo_service registry local_echo >>= fun () -> + with_vats ~net ~service:registry_impl @@ fun cs -> + let registry = get_bootstrap cs in + Registry.set_echo_service registry local_echo; Capability.dec_ref local_echo; let echo_service = Registry.echo_service_promise registry in - let reply1 = Echo.ping echo_service "ping" in - Registry.unblock registry >>= fun () -> - reply1 >|= Alcotest.(check string) "Ping response" "got:0:ping" >>= fun () -> + let reply1 = Fiber.fork_promise ~sw (fun () -> Echo.ping echo_service "ping") in + Registry.unblock registry; + Promise.await_exn reply1 |> Alcotest.(check string) "Ping response" "got:0:ping"; (* Flush, to ensure we resolve the echo_service's location. *) - Echo.ping echo_service "ping" >|= Alcotest.(check string) "Ping response" "got:1:ping" >>= fun () -> + Echo.ping echo_service "ping" |> Alcotest.(check string) "Ping response" "got:1:ping"; (* Test local connection. *) - Echo.ping echo_service "ping" >|= Alcotest.(check string) "Ping response" "got:2:ping" >>= fun () -> + Echo.ping echo_service "ping" |> Alcotest.(check string) "Ping response" "got:2:ping"; Capability.dec_ref echo_service; - Capability.dec_ref registry; - Lwt.return () - -let test_cancel switch = - make_vats ~switch ~service:(Echo.local ()) () >>= fun cs -> - get_bootstrap cs >>= fun service -> - let reply1 = Echo.ping service ~slow:true "ping1" in - assert (Lwt.state reply1 = Lwt.Sleep); - Lwt.cancel reply1; - Lwt.try_bind - (fun () -> reply1) - (fun _ -> Alcotest.fail "Should have been cancelled!") - (function - | Lwt.Canceled -> Lwt.return () - | ex -> Lwt.fail ex + Capability.dec_ref registry + +(* todo: we stop waiting and we send a finish message, but we don't currently + abort the service operation. *) +let test_cancel ~net = + with_vats ~net ~service:(Echo.local ()) @@ fun cs -> + let service = get_bootstrap cs in + Fiber.first + (fun () -> + ignore (Echo.ping service ~slow:true "ping1" : string); + assert false ) - >>= fun () -> - Echo.unblock service >|= fun () -> + (fun () -> + Echo.ping service "ping" |> Alcotest.(check string) "Ping response" "got:1:ping" + ); + Echo.unblock service; + Echo.ping service "ping" |> Alcotest.(check string) "Ping response" "got:2:ping"; Capability.dec_ref service let float = Alcotest.testable Fmt.float (=) -let test_calculator switch = +let test_calculator ~net = let open Calc in - Capability.inc_ref Calc.local; - make_vats ~switch ~service:Calc.local () >>= fun cs -> - get_bootstrap cs >>= fun c -> - Calc.evaluate c (Float 1.) |> Value.final_read >|= Alcotest.check float "Simple calc" 1. >>= fun () -> + Switch.run @@ fun sw -> + let service = Calc.local ~sw in + with_vats ~net ~service @@ fun cs -> + let c = get_bootstrap cs in + Calc.evaluate c (Float 1.) |> Value.final_read |> Alcotest.check float "Simple calc" 1.; let local_add = Calc.Fn.add in let expr = Expr.(Call (local_add, [Float 1.; Float 2.])) in - Calc.evaluate c expr |> Value.final_read >|= Alcotest.check float "Complex with local fn" 3. >>= fun () -> + Calc.evaluate c expr |> Value.final_read |> Alcotest.check float "Complex with local fn" 3.; let remote_add = Calc.getOperator c `Add in - Calc.Fn.call remote_add [5.; 3.] >|= Alcotest.check float "Check fn" 8. >>= fun () -> + Calc.Fn.call remote_add [5.; 3.] |> Alcotest.check float "Check fn" 8.; let expr = Expr.(Call (remote_add, [Float 1.; Float 2.])) in - Calc.evaluate c expr |> Value.final_read >|= Alcotest.check float "Complex with remote fn" 3. >>= fun () -> + Calc.evaluate c expr |> Value.final_read |> Alcotest.check float "Complex with remote fn" 3.; Capability.dec_ref remote_add; - Capability.dec_ref c; - Lwt.return () + Capability.dec_ref c -let test_calculator2 switch = +let test_calculator2 ~net = let open Calc in - Capability.inc_ref Calc.local; - make_vats ~switch ~service:Calc.local () >>= fun cs -> - get_bootstrap cs >>= fun c -> + Switch.run @@ fun sw -> + let service = Calc.local ~sw in + with_vats ~net ~service @@ fun cs -> + let c = get_bootstrap cs in let remote_add = Calc.getOperator c `Add in let remote_mul = Calc.getOperator c `Multiply in let expr = Expr.(Call (remote_mul, [Float 4.; Float 6.])) in let result = Calc.evaluate c expr in - let expr = Expr.(Call (remote_add, [Prev result; Float 3.])) in - let add3 = Calc.evaluate c expr |> Value.final_read in - let expr = Expr.(Call (remote_add, [Prev result; Float 5.])) in - let add5 = Calc.evaluate c expr |> Value.final_read in - add3 >>= fun add3 -> - add5 >>= fun add5 -> + let/ add3 () = + let expr = Expr.(Call (remote_add, [Prev result; Float 3.])) in + Calc.evaluate c expr |> Value.final_read + and/ add5 () = + let expr = Expr.(Call (remote_add, [Prev result; Float 5.])) in + Calc.evaluate c expr |> Value.final_read + in Alcotest.check float "First" 27.0 add3; Alcotest.check float "Second" 29.0 add5; Capability.dec_ref result; Capability.dec_ref remote_add; Capability.dec_ref remote_mul; - Capability.dec_ref c; - Lwt.return () + Capability.dec_ref c -let test_indexing switch = - let registry_impl = Registry.local () in - make_vats ~switch ~service:registry_impl () >>= fun cs -> - get_bootstrap cs >>= fun registry -> +let test_indexing ~net = + Switch.run @@ fun sw -> + let registry_impl = Registry.local ~sw () in + with_vats ~net ~service:registry_impl @@ fun cs -> + let registry = get_bootstrap cs in let echo_service, version = Registry.complex registry in - Echo.ping echo_service "ping" >|= Alcotest.(check string) "Ping response" "got:0:ping" >>= fun () -> - Registry.Version.read version >|= Alcotest.(check string) "Version response" "0.1" >>= fun () -> + Echo.ping echo_service "ping" |> Alcotest.(check string) "Ping response" "got:0:ping"; + Registry.Version.read version |> Alcotest.(check string) "Version response" "0.1"; Capability.dec_ref registry; Capability.dec_ref echo_service; - Capability.dec_ref version; - Lwt.return () + Capability.dec_ref version let cmd_result t = let pp f (x : ('a Cmdliner.Cmd.eval_ok, Cmdliner.Cmd.eval_error) result) = @@ -338,17 +337,16 @@ let test_sturdy_uri () = let sr = (`Unix "/sock", auth), "main" in check "Secure Unix" "capnp://sha-256:s16WV4JeGusAL_nTjvICiQOFqm3LqYrDj3K-HXdMi8s@/sock/bWFpbg" sr -let test_sturdy_self switch = +let test_sturdy_self ~net = let service = Echo.local () in Capability.inc_ref service; - make_vats ~switch ~serve_tls:true ~service () >>= fun cs -> + with_vats ~net ~serve_tls:true ~service @@ fun cs -> let id = Restorer.Id.public "" in let sr = Vat.sturdy_uri cs.server id |> Vat.import_exn cs.server in - Sturdy_ref.connect_exn sr >>= fun service2 -> + let service2 = Sturdy_ref.connect_exn sr in Alcotest.check cap "Restore from same vat" service service2; Capability.dec_ref service2; - Capability.dec_ref service; - Lwt.return () + Capability.dec_ref service let expect_non_exn = function | Ok x -> x @@ -357,7 +355,8 @@ let expect_non_exn = function let except = Alcotest.testable Capnp_rpc.Exception.pp (=) let except_ty = Alcotest.testable Capnp_rpc.Exception.pp_ty (=) -let test_table_restorer _switch = +let test_table_restorer ~net:_ = + Switch.run @@ fun sw -> let make_sturdy id = Uri.make ~path:(Restorer.Id.to_string id) () in let table = Restorer.Table.create make_sturdy in let echo_id = Restorer.Id.public "echo" in @@ -365,120 +364,136 @@ let test_table_restorer _switch = let broken_id = Restorer.Id.public "broken" in let unknown_id = Restorer.Id.public "unknown" in Restorer.Table.add table echo_id @@ Echo.local (); - Restorer.Table.add table registry_id @@ Registry.local (); + Restorer.Table.add table registry_id @@ Registry.local ~sw (); Restorer.Table.add table broken_id @@ Capability.broken (Capnp_rpc.Exception.v "broken"); let r = Restorer.of_table table in - Restorer.restore r echo_id >|= expect_non_exn >>= fun a1 -> - Echo.ping a1 "ping" >>= fun reply -> + let a1 = Restorer.restore r echo_id |> expect_non_exn in + let reply = Echo.ping a1 "ping" in Alcotest.(check string) "Ping response" "got:0:ping" reply; - Restorer.restore r echo_id >|= expect_non_exn >>= fun a2 -> + let a2 = Restorer.restore r echo_id |> expect_non_exn in Alcotest.check cap "Same cap" a1 a2; - Restorer.restore r registry_id >|= expect_non_exn >>= fun r1 -> + let r1 = Restorer.restore r registry_id |> expect_non_exn in assert (a1 <> r1); - Restorer.restore r broken_id >|= expect_non_exn >>= fun x -> + let x = Restorer.restore r broken_id |> expect_non_exn in let expected = Some (Capnp_rpc.Exception.v "broken") in Alcotest.(check (option except)) "Broken response" expected (Capability.problem x); - Restorer.restore r unknown_id >>= fun x -> + let x = Restorer.restore r unknown_id in let expected = Error (Capnp_rpc.Exception.v "Unknown persistent service ID") in Alcotest.(check (result reject except)) "Missing mapping" expected x; Capability.dec_ref a1; Capability.dec_ref a2; Capability.dec_ref r1; Restorer.Table.remove table echo_id; - Restorer.Table.clear table; - Lwt.return () + Restorer.Table.clear table module Loader = struct - type t = string -> Restorer.resolution Lwt.t + type t = string -> Restorer.resolution let hash _ = `SHA256 let make_sturdy _ id = Uri.make ~path:(Restorer.Id.to_string id) () let load t _sr digest = t digest end -let test_fn_restorer _switch = +module Cond = struct + type t = (unit Promise.t * unit Promise.u) ref + + let create () : t = ref (Promise.create ()) + + let await t = Promise.await (fst !t) + + let notify t = + Promise.resolve (snd !t) (); + t := Promise.create () +end + +let test_fn_restorer ~net:_ = + Switch.run @@ fun sw -> let cap = Alcotest.testable Capability.pp (=) in let a = Restorer.Id.public "a" in let b = Restorer.Id.public "b" in let c = Restorer.Id.public "c" in let current_c = ref (Restorer.reject (Exception.v "Broken C")) in - let delay = Lwt_condition.create () in + let delay = Cond.create () in let digest = Restorer.Id.digest (Loader.hash ()) in let load d = - if d = digest a then Lwt.return @@ Restorer.grant @@ Echo.local () - else if d = digest b then Lwt_condition.wait delay >|= fun () -> Restorer.grant @@ Echo.local () - else if d = digest c then Lwt_condition.wait delay >|= fun () -> !current_c - else Lwt.return @@ Restorer.unknown_service_id + if d = digest a then Restorer.grant @@ Echo.local () + else if d = digest b then (Cond.await delay; Restorer.grant @@ Echo.local ()) + else if d = digest c then (Cond.await delay; !current_c) + else Restorer.unknown_service_id in - let table = Restorer.Table.of_loader (module Loader) load in + let table = Restorer.Table.of_loader ~sw (module Loader) load in let restorer = Restorer.of_table table in let restore x = Restorer.restore restorer x in (* Check that restoring the same ID twice caches the capability. *) - restore a >|= expect_non_exn >>= fun a1 -> - restore a >|= expect_non_exn >>= fun a2 -> + let a1 = restore a |> expect_non_exn in + let a2 = restore a |> expect_non_exn in Alcotest.check cap "Restore cached" a1 a2; Capability.dec_ref a1; Capability.dec_ref a2; (* But if it's released, the next lookup loads a fresh one. *) - restore a >|= expect_non_exn >>= fun a3 -> + let a3 = restore a |> expect_non_exn in if a1 = a3 then Alcotest.fail "Returned released cap!"; Capability.dec_ref a3; (* Doing two lookups in parallel only does one load. *) - let b1 = restore b in - let b2 = restore b in - assert (Lwt.state b1 = Lwt.Sleep); - Lwt_condition.broadcast delay (); - b1 >|= expect_non_exn >>= fun b1 -> - b2 >|= expect_non_exn >>= fun b2 -> + let b1 = Fiber.fork_promise ~sw (fun () -> restore b) in + let b2 = Fiber.fork_promise ~sw (fun () -> restore b) in + assert (Promise.peek b1 = None); + Cond.notify delay; + let b1 = Promise.await_exn b1 |> expect_non_exn in + let b2 = Promise.await_exn b2 |> expect_non_exn in Alcotest.check cap "Restore delayed cached" b1 b2; Restorer.Table.clear table; (* (should have no effect) *) Capability.dec_ref b1; Capability.dec_ref b2; (* Failed lookups aren't cached. *) - let c1 = restore c in - Lwt_condition.broadcast delay (); - c1 >>= fun c1 -> + let c1 = Fiber.fork_promise ~sw (fun () -> restore c) in + Cond.notify delay; + let c1 = Promise.await_exn c1 in let reject = Alcotest.result cap except in Alcotest.check reject "C initially fails" (Error (Exception.v "Broken C")) c1; - let c2 = restore c in + let c2 = Fiber.fork_promise ~sw (fun () -> restore c) in let c_service = Echo.local () in current_c := Restorer.grant c_service; - Lwt_condition.broadcast delay (); - c2 >|= expect_non_exn >>= fun c2 -> + Cond.notify delay; + let c2 = Promise.await_exn c2 |> expect_non_exn in Alcotest.check cap "C now works" c_service c2; Capability.dec_ref c2; (* Two users; one frees the cap immediately *) let b1 = - restore b >|= expect_non_exn >|= fun b1 -> + Fiber.fork_promise ~sw @@ fun () -> + restore b |> expect_non_exn |> fun b1 -> Capability.dec_ref b1; b1 in - let b2 = restore b in - Lwt_condition.broadcast delay (); - b1 >>= fun b1 -> - b2 >|= expect_non_exn >>= fun b2 -> + let b2 = Fiber.fork_promise ~sw (fun () -> restore b) in + Cond.notify delay; + let b1 = Promise.await_exn b1 in + let b2 = Promise.await_exn b2 |> expect_non_exn in Alcotest.check cap "Cap not freed" b1 b2; - Capability.dec_ref b2; - Lwt.return_unit - -let test_broken switch = - make_vats ~switch ~service:(Echo.local ()) () >>= fun cs -> - get_bootstrap cs >>= fun service -> - Echo.ping service "ping" >|= Alcotest.(check string) "Ping response" "got:0:ping" >>= fun () -> - let problem, set_problem = Lwt.wait () in - Capability.when_broken (fun x -> Lwt.wakeup set_problem x) service; - Alcotest.check (Alcotest.option except) "Still OK" None @@ Capability.problem service; - assert (Lwt.state problem = Lwt.Sleep); - Logs.info (fun f -> f "Turning off server..."); - Lwt_switch.turn_off cs.server_switch >>= fun () -> - problem >>= fun problem -> - Alcotest.check except_ty "Broken callback ran" `Disconnected problem.ty; - assert (Capability.problem service <> None); - Lwt.catch - (fun () -> Echo.ping service "ping" >|= fun _ -> Alcotest.fail "Should have failed!") - (fun _ -> Lwt.return ()) - >|= fun () -> - Capability.dec_ref service + Capability.dec_ref b2 + +let test_broken ~net = + try + Switch.run (fun server_sw -> + with_vats ~server_sw ~net ~service:(Echo.local ()) @@ fun cs -> + let service = get_bootstrap cs in + Echo.ping service "ping" |> Alcotest.(check string) "Ping response" "got:0:ping"; + let problem, set_problem = Promise.create () in + Capability.when_broken (fun x -> Promise.resolve set_problem x) service; + Alcotest.check (Alcotest.option except) "Still OK" None @@ Capability.problem service; + assert (Promise.peek problem = None); + Logs.info (fun f -> f "Turning off server..."); + Switch.fail server_sw Simulated_failure; + let problem = Promise.await problem in + Alcotest.check except_ty "Broken callback ran" `Disconnected problem.ty; + assert (Capability.problem service <> None); + try + ignore (Echo.ping service "ping" : string); + Alcotest.fail "Should have failed!" + with Failure _ -> + Capability.dec_ref service + ) + with Simulated_failure -> () (* [when_broken] follows promises. *) let test_broken2 () = @@ -507,149 +522,163 @@ let test_broken4 () = Capability.dec_ref promise; Alcotest.check (Alcotest.option except) "Released, not called" None !problem -let test_parallel_connect switch = - make_vats ~switch ~serve_tls:true ~service:(Echo.local ()) () >>= fun cs -> - let service = get_bootstrap cs in - let service2 = get_bootstrap cs in - service >>= fun service -> - service2 >>= fun service2 -> - Capability.await_settled_exn service >>= fun () -> - Capability.await_settled_exn service2 >>= fun () -> +let test_parallel_connect ~net = + with_vats ~net ~serve_tls:true ~service:(Echo.local ()) @@ fun cs -> + let/ service () = get_bootstrap cs + and/ service2 () = get_bootstrap cs in + Capability.await_settled_exn service; + Capability.await_settled_exn service2; Alcotest.check cap "Shared connection" service service2; Capability.dec_ref service; - Capability.dec_ref service2; - Lwt.return_unit - -let test_parallel_fails switch = - make_vats ~switch ~serve_tls:true ~service:(Echo.local ()) () >>= fun cs -> - let service = get_bootstrap cs in - let service2 = get_bootstrap cs in - service >>= fun service -> - service2 >>= fun service2 -> - Lwt_switch.turn_off cs.server_switch >>= fun () -> - Capability.await_settled_exn service >>= fun () -> - Capability.await_settled_exn service2 >>= fun () -> - Alcotest.check cap "Shared failure" service service2; - Capability.dec_ref service; - Capability.dec_ref service2; - (* Restart server (ignore new client) *) - Lwt.pause () >>= fun () -> - make_vats ~switch ~serve_tls:true ~service:(Echo.local ()) () >>= fun _cs2 -> - get_bootstrap cs >>= fun service -> - Echo.ping service "ping" >|= Alcotest.(check string) "Ping response" "got:0:ping" >>= fun () -> - Capability.dec_ref service; - Lwt.return_unit - -let test_crossed_calls switch = + Capability.dec_ref service2 + +let test_parallel_fails ~net = + try + Switch.run (fun server_sw -> + with_vats ~net ~server_sw ~serve_tls:true ~service:(Echo.local ()) @@ fun cs -> + let/ service () = get_bootstrap cs + and/ service2 () = get_bootstrap cs in + Switch.fail server_sw Simulated_failure; + let p, r = Promise.create () in + Capability.when_broken (Promise.resolve r) service2; + ignore (Promise.await p : Exception.t); + Alcotest.check cap "Shared failure" service service2; + Capability.dec_ref service; + Capability.dec_ref service2; + (* Restart server (ignore new client) *) + Fiber.yield (); + with_vats ~net ~serve_tls:true ~service:(Echo.local ()) @@ fun cs -> + let service = get_bootstrap cs in + Echo.ping service "ping" |> Alcotest.(check string) "Ping response" "got:0:ping"; + Capability.dec_ref service + ) + with Simulated_failure -> () + +let test_crossed_calls ~net = (* Would be good to control the ordering here, to test the various cases. Currently, it's not certain which path is actually tested. *) - let id = Restorer.Id.public "" in - let make_vat ~secret_key ~tags addr = - let service = Echo.local () in - let restore = Restorer.(single id) service in - let config = - let secret_key = `PEM (Auth.Secret_key.to_pem_data secret_key) in - let name = Fmt.str "capnp-rpc-test-%s" addr in - Capnp_rpc_unix.Vat_config.create ~secret_key (get_test_address ~switch name) + try + Switch.run @@ fun sw -> + let id = Restorer.Id.public "" in + let make_vat ~secret_key ~tags addr = + let service = Echo.local () in + let restore = Restorer.(single id) service in + let config = + let secret_key = `PEM (Auth.Secret_key.to_pem_data secret_key) in + let name = Fmt.str "capnp-rpc-test-%s" addr in + Capnp_rpc_unix.Vat_config.create ~secret_key (get_test_address ~sw name) + in + let vat = Capnp_rpc_unix.serve ~net ~sw ~tags ~restore config in + Switch.on_release sw (fun () -> Capability.dec_ref service); + vat in - Capnp_rpc_unix.serve ~switch ~tags ~restore config >>= fun vat -> - Lwt_switch.add_hook (Some switch) (fun () -> Capability.dec_ref service; Lwt.return_unit); - Lwt.return vat - in - make_vat ~secret_key:client_key ~tags:Test_utils.client_tags "client" >>= fun client -> - make_vat ~secret_key:server_key ~tags:Test_utils.server_tags "server" >>= fun server -> - let sr_to_client = Capnp_rpc_unix.Vat.sturdy_uri client id |> Vat.import_exn server in - let sr_to_server = Capnp_rpc_unix.Vat.sturdy_uri server id |> Vat.import_exn client in - let to_client = Sturdy_ref.connect_exn sr_to_client in - let to_server = Sturdy_ref.connect_exn sr_to_server in - to_client >>= fun to_client -> - to_server >>= fun to_server -> - Logs.info (fun f -> f ~tags:Test_utils.client_tags "%a" Capnp_rpc_unix.Vat.dump client); - Logs.info (fun f -> f ~tags:Test_utils.server_tags "%a" Capnp_rpc_unix.Vat.dump server); - let s_got = Echo.ping_result to_client "ping" in - let c_got = Echo.ping_result to_server "ping" in - s_got >>= fun s_got -> - c_got >>= fun c_got -> - begin match c_got, s_got with - | Ok x, Ok y -> Lwt.return (x, y) - | Ok x, Error _ -> - (* Server got an error. Try client again. *) - Sturdy_ref.connect_exn sr_to_client >>= fun to_client -> - Capability.with_ref to_client @@ fun to_client -> - Echo.ping to_client "ping" >|= fun s_got -> (x, s_got) - | Error _, Ok y -> - (* Client got an error. Try server again. *) - Sturdy_ref.connect_exn sr_to_server >>= fun to_server -> - Capability.with_ref to_server @@ fun to_server -> - Echo.ping to_server "ping" >|= fun c_got -> (c_got, y) - | Error (`Capnp e1), Error (`Capnp e2) -> - Fmt.failwith "@[Both connections failed!@,%a@,%a@]" - Capnp_rpc.Error.pp e1 - Capnp_rpc.Error.pp e2 - end >>= fun (c_got, s_got) -> - Alcotest.(check string) "Client's ping response" "got:0:ping" c_got; - Alcotest.(check string) "Server's ping response" "got:0:ping" s_got; - Capability.dec_ref to_client; - Capability.dec_ref to_server; - Lwt.return_unit + let client = make_vat ~secret_key:client_key ~tags:Test_utils.client_tags "client" in + let server = make_vat ~secret_key:server_key ~tags:Test_utils.server_tags "server" in + let sr_to_client = Capnp_rpc_unix.Vat.sturdy_uri client id |> Vat.import_exn server in + let sr_to_server = Capnp_rpc_unix.Vat.sturdy_uri server id |> Vat.import_exn client in + let/ to_client () = Sturdy_ref.connect_exn sr_to_client + and/ to_server () = Sturdy_ref.connect_exn sr_to_server in + Logs.info (fun f -> f ~tags:Test_utils.client_tags "%a" Capnp_rpc_unix.Vat.dump client); + Logs.info (fun f -> f ~tags:Test_utils.server_tags "%a" Capnp_rpc_unix.Vat.dump server); + let/ s_got () = Echo.ping_result to_client "ping" + and/ c_got () = Echo.ping_result to_server "ping" in + let c_got, s_got = + match c_got, s_got with + | Ok x, Ok y -> (x, y) + | Ok x, Error _ -> + (* Server got an error. Try client again. *) + let to_client = Sturdy_ref.connect_exn sr_to_client in + Capability.with_ref to_client @@ fun to_client -> + Echo.ping to_client "ping" |> fun s_got -> (x, s_got) + | Error _, Ok y -> + (* Client got an error. Try server again. *) + let to_server = Sturdy_ref.connect_exn sr_to_server in + Capability.with_ref to_server @@ fun to_server -> + Echo.ping to_server "ping" |> fun c_got -> (c_got, y) + | Error (`Capnp e1), Error (`Capnp e2) -> + Fmt.failwith "@[Both connections failed!@,%a@,%a@]" + Capnp_rpc.Error.pp e1 + Capnp_rpc.Error.pp e2 + in + Alcotest.(check string) "Client's ping response" "got:0:ping" c_got; + Alcotest.(check string) "Server's ping response" "got:0:ping" s_got; + Capability.dec_ref to_client; + Capability.dec_ref to_server; + raise Simulated_failure + with Simulated_failure -> () (* Run test_crossed_calls several times to try to trigger the various behaviours. *) -let test_crossed_calls _switch = - let rec aux i = - if i = 0 then Lwt.return_unit - else ( - Lwt_switch.with_switch test_crossed_calls >>= fun () -> - aux (i - 1) - ) - in - aux 10 - -let test_store switch = - (* Persistent server configuration *) - let db = Store.DB.create () in - let config = - let addr = get_test_address ~switch "capnp-rpc-test-server" in - Capnp_rpc_unix.Vat_config.create ~secret_key:server_pem addr - in - let main_id = Restorer.Id.generate () in - let start_server ~switch () = - let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in - let table = Store.File.table ~make_sturdy db in - Lwt_switch.add_hook (Some switch) (fun () -> Restorer.Table.clear table; Lwt.return_unit); - let restore = Restorer.of_table table in - let service = Store.local ~restore db in - Restorer.Table.add table main_id service; - Capnp_rpc_unix.serve ~switch ~restore ~tags:Test_utils.server_tags config - in - (* Start server *) - let server_switch = Lwt_switch.create () in - start_server ~switch:server_switch () >>= fun server -> - let store_uri = Capnp_rpc_unix.Vat.sturdy_uri server main_id in - (* Set up client *) - let client = Capnp_rpc_unix.client_only_vat ~tags:Test_utils.client_tags ~switch () in - let sr = Capnp_rpc_unix.Vat.import_exn client store_uri in - Sturdy_ref.with_cap_exn sr @@ fun store -> - (* Try creating a file *) - let file = Store.create_file store in - Store.File.set file "Hello" >>= fun () -> - Persistence.save_exn file >>= fun file_sr -> - let file_sr = Vat.import_exn client file_sr in (* todo: get rid of this step *) - (* Shut down server *) - Lwt.async (fun () -> Lwt_switch.turn_off server_switch); - let broken, set_broken = Lwt.wait () in - Capability.when_broken (Lwt.wakeup set_broken) file; - broken >>= fun _ex -> - assert (Capability.problem file <> None); - (* Restart server *) - start_server ~switch () >>= fun _server -> - (* Reconnect client *) - Sturdy_ref.with_cap_exn file_sr @@ fun file -> - Store.File.get file >>= fun data -> - Alcotest.(check string) "Read file" "Hello" data; - Lwt.return_unit - -let test_file_store _switch = - Lwt_io.with_temp_dir ~prefix:"capnp-tests-" @@ fun tmpdir -> +let test_crossed_calls ~net = + for _ = 1 to 10 do + test_crossed_calls ~net + done + +let test_store ~net = + try + Switch.run @@ fun sw -> + (* Persistent server configuration *) + let db = Store.DB.create () in + let config = + let addr = get_test_address ~sw "capnp-rpc-test-server" in + Capnp_rpc_unix.Vat_config.create ~secret_key:server_pem addr + in + let main_id = Restorer.Id.generate () in + let start_server ~sw () = + let make_sturdy = Capnp_rpc_unix.Vat_config.sturdy_uri config in + let table = Store.File.table ~sw ~make_sturdy db in + Switch.on_release sw (fun () -> Restorer.Table.clear table); + let restore = Restorer.of_table table in + let service = Store.local ~restore db in + Restorer.Table.add table main_id service; + Capnp_rpc_unix.serve ~sw ~net ~restore ~tags:Test_utils.server_tags config + in + (* Start server *) + Switch.run @@ fun server_switch -> + let server = start_server ~sw:server_switch () in + let store_uri = Capnp_rpc_unix.Vat.sturdy_uri server main_id in + (* Set up client *) + let client = Capnp_rpc_unix.client_only_vat ~tags:Test_utils.client_tags ~sw net in + let sr = Capnp_rpc_unix.Vat.import_exn client store_uri in + Sturdy_ref.with_cap_exn sr @@ fun store -> + (* Try creating a file *) + let file = Store.create_file store in + Store.File.set file "Hello"; + let file_sr = Persistence.save_exn file in + let file_sr = Vat.import_exn client file_sr in (* todo: get rid of this step *) + (* Shut down server *) + Switch.fail server_switch Simulated_failure; + let broken, set_broken = Promise.create () in + Capability.when_broken (Promise.resolve set_broken) file; + ignore (Promise.await broken : Exception.t); + assert (Capability.problem file <> None); + (* Restart server *) + let _server = start_server ~sw () in + (* Reconnect client *) + Sturdy_ref.with_cap_exn file_sr @@ fun file -> + let data = Store.File.get file in + Alcotest.(check string) "Read file" "Hello" data + with Simulated_failure -> () + +let ( / ) = Eio.Path.( / ) + +let rmtree dir = + Eio.Path.read_dir dir + |> List.iter (fun leaf -> + let path = dir / leaf in + traceln "rm %a" Eio.Path.pp path; + Eio.Path.unlink path + ); + traceln "rmdir %a" Eio.Path.pp dir; + Eio.Path.rmdir dir; + traceln "Removed" + +let with_temp_dir path fn = + Eio.Path.mkdir path ~perm:0o700; + Fun.protect (fun () -> Eio.Path.with_open_dir path fn) + ~finally:(fun () -> rmtree path) + +let test_file_store ~dir ~net:_ = + with_temp_dir (dir / "capnp-tests") @@ fun tmpdir -> let module S = Capnp_rpc_unix.File_store in let s = S.create tmpdir in Alcotest.(check (option reject)) "Missing file" None @@ S.load s ~digest:"missing"; @@ -661,84 +690,95 @@ let test_file_store _switch = Builder.to_reader b in S.save s ~digest:"!/.." data; - Alcotest.(check (option string)) "Restored" (Some "Test") @@ Option.map Reader.text_get (S.load s ~digest:"!/.."); - Lwt.return_unit + Alcotest.(check (option string)) "Restored" (Some "Test") @@ Option.map Reader.text_get (S.load s ~digest:"!/..") let capnp_error = Alcotest.of_pp Capnp_rpc.Exception.pp -let test_await_settled _switch = +let test_await_settled ~net:_ = (* Ok *) + Switch.run @@ fun sw -> let p, r = Capability.promise () in - let check = Capability.await_settled p in + let check = Fiber.fork_promise ~sw (fun () -> Capability.await_settled p) in Capability.resolve_ok r @@ Echo.local (); - check >>= fun check -> + let check = Promise.await_exn check in Alcotest.(check (result unit capnp_error)) "Check await success" (Ok ()) check; Capability.dec_ref p; (* Error *) let p, r = Capability.promise () in - let check = Capability.await_settled p in + let check = Fiber.fork_promise ~sw (fun () -> Capability.await_settled p) in let err = Capnp_rpc.Exception.v "Test" in Capability.resolve_exn r err; - check >>= fun check -> - Alcotest.(check (result unit capnp_error)) "Check await failure" (Error err) check; - Lwt.return_unit + let check = Promise.await_exn check in + Alcotest.(check (result unit capnp_error)) "Check await failure" (Error err) check (* The client disconnects before the server has finished loading the bootstrap object. *) -let test_late_bootstrap switch = - let connected, set_connected = Lwt.wait () in - let service, set_service = Lwt.wait () in - let module Loader = struct - type t = unit - let hash () = `SHA256 - let make_sturdy () _id = assert false - let load () _sr _name = - Lwt.wakeup_later set_connected (); - service - end in - let table = Capnp_rpc_net.Restorer.Table.of_loader (module Loader) () in - let restore = Restorer.of_table table in - let client_switch = Lwt_switch.create () in - make_vats_full ~client_switch ~server_switch:switch ~restore () >>= fun cs -> - let service = get_bootstrap cs in - connected >>= fun () -> - Lwt_switch.turn_off client_switch >>= fun () -> - Lwt.wakeup set_service @@ Capnp_rpc_net.Restorer.grant @@ Echo.local (); - service >>= fun _ -> - Lwt.return () - -let run name fn = Alcotest_lwt.test_case_sync name `Quick fn - -let rpc_tests = [ - run_lwt "Simple" (test_simple ~serve_tls:false); - run_lwt "Crypto" (test_simple ~serve_tls:true); - run_lwt "Bad crypto" test_bad_crypto ~expected_warnings:1; - run_lwt "Parallel" test_parallel; - run_lwt "Embargo" test_embargo; - run_lwt "Resolve" test_resolve; - run_lwt "Registry" test_registry; - run_lwt "Calculator" test_calculator; - run_lwt "Calculator 2" test_calculator2; - run_lwt "Cancel" test_cancel; - run_lwt "Indexing" test_indexing; - run "Options" test_options; - run "Sturdy URI" test_sturdy_uri; - run_lwt "Sturdy self" test_sturdy_self; - run_lwt "Table restorer" test_table_restorer; - run_lwt "Fn restorer" test_fn_restorer; - run_lwt "Broken ref" test_broken; - run "Broken ref 2" test_broken2; - run "Broken ref 3" test_broken3; - run "Broken ref 4" test_broken4; - run_lwt "Parallel connect" test_parallel_connect; - run_lwt "Parallel fails" test_parallel_fails; - run_lwt "Crossed calls" test_crossed_calls; - run_lwt "Store" test_store; - run_lwt "File store" test_file_store; - run_lwt "Await settled" test_await_settled; - run_lwt "Late bootstrap" test_late_bootstrap; -] +let test_late_bootstrap ~net = + try + Switch.run @@ fun server_sw -> + Switch.run @@ fun client_switch -> + let connected, set_connected = Promise.create () in + let service, set_service = Promise.create () in + let module Loader = struct + type t = unit + let hash () = `SHA256 + let make_sturdy () _id = assert false + let load () _sr _name = + Promise.resolve set_connected (); + Promise.await service + end in + let table = Capnp_rpc_net.Restorer.Table.of_loader ~sw:server_sw (module Loader) () in + let restore = Restorer.of_table table in + let cs = make_vats_full ~sw:client_switch ~server_sw ~restore ~net () in + let service = get_bootstrap cs in + Promise.await connected; + Eio.Cancel.protect @@ fun () -> + Switch.fail client_switch Simulated_failure; + Promise.resolve set_service @@ Capnp_rpc_net.Restorer.grant @@ Echo.local (); + let service = Capability.await_settled service |> Result.get_error in + Logs.info (fun f -> f "client got: %a" Capnp_rpc.Exception.pp service); + assert (service.Capnp_rpc.Exception.ty = `Disconnected); + (* The restorer yields once before returning the cap, + so we wait too, to ensure it's done. *) + Fiber.yield () + with Simulated_failure -> () + +let run name fn = Alcotest.test_case name `Quick fn + +let rpc_tests ~net ~dir = + let run_eio = run_eio ~net in + [ + run_eio "Simple" (test_simple ~serve_tls:false); + run_eio "Crypto" (test_simple ~serve_tls:true); + run_eio "Bad crypto" test_bad_crypto ~expected_warnings:1; + run_eio "Parallel" test_parallel; + run_eio "Embargo" test_embargo; + run_eio "Resolve" test_resolve; + run_eio "Registry" test_registry; + run_eio "Calculator" test_calculator; + run_eio "Calculator 2" test_calculator2; + run_eio "Cancel" test_cancel; + run_eio "Indexing" test_indexing; + run "Options" test_options; + run "Sturdy URI" test_sturdy_uri; + run_eio "Sturdy self" test_sturdy_self; + run_eio "Table restorer" test_table_restorer; + run_eio "Fn restorer" test_fn_restorer; + run_eio "Broken ref" test_broken; + run "Broken ref 2" test_broken2; + run "Broken ref 3" test_broken3; + run "Broken ref 4" test_broken4; + run_eio "Parallel connect" test_parallel_connect; + run_eio "Parallel fails" test_parallel_fails; + run_eio "Crossed calls" test_crossed_calls; + run_eio "Store" test_store; + run_eio "File store" (test_file_store ~dir); + run_eio "Await settled" test_await_settled; + run_eio "Late bootstrap" test_late_bootstrap; + ] let () = - Alcotest_lwt.run ~and_exit:false "capnp-rpc" [ - "lwt", rpc_tests; - ] |> Lwt_main.run + Eio_main.run @@ fun env -> + (* Eio_unix.Ctf.with_tracing "/tmp/trace.ctf" @@ fun () -> *) + Alcotest.run ~and_exit:false "capnp-rpc" [ + "eio", rpc_tests ~net:env#net ~dir:env#cwd; + ] diff --git a/test-mirage/dune b/test-mirage/dune deleted file mode 100644 index 0aa5204c1..000000000 --- a/test-mirage/dune +++ /dev/null @@ -1,6 +0,0 @@ -(test - (name test_mirage) - (package capnp-rpc-mirage) - (libraries io-page-unix capnp-rpc-lwt capnp-rpc-mirage alcotest-lwt testlib - logs.fmt testbed tcpip.ipv4 tcpip.ipv6 tcpip.stack-direct mirage-vnetif ethernet - arp.mirage tcpip.tcp tcpip.icmpv4 mirage-crypto-rng.lwt)) diff --git a/test-mirage/test_mirage.ml b/test-mirage/test_mirage.ml deleted file mode 100644 index 3fb160f1e..000000000 --- a/test-mirage/test_mirage.ml +++ /dev/null @@ -1,142 +0,0 @@ -open Lwt.Infix -open Capnp_rpc_lwt -open Capnp_rpc_net -open Testlib - -module Time = struct - let sleep_ns ns = Lwt_unix.sleep (Duration.to_f ns) -end - -module Clock = struct - let period_ns () = None - let elapsed_ns () = 0L -end - -module PClock = struct - let now_d_ps () = (0, 0L) - let current_tz_offset_s () = None - let period_d_ps () = None -end - -module Random = struct - type g = unit - - let generate ?g n = ignore g; Cstruct.create n -end - -module Stack = struct - module B = Basic_backend.Make - module V = Vnetif.Make(B) - module E = Ethernet.Make(V) - module A = Arp.Make(E)(Time) - module I4 = Static_ipv4.Make(Random)(Clock)(E)(A) - module I6 = Ipv6.Make(V)(E)(Random)(Time)(Clock) - module I = Tcpip_stack_direct.IPV4V6(I4)(I6) - module U = Udp.Make(I)(Random) - module T = Tcp.Flow.Make(I)(Time)(Clock)(Random) - module Icmp = Icmpv4.Make(I4) - include Tcpip_stack_direct.MakeV4V6(Time)(Random)(V)(E)(A)(I)(Icmp)(U)(T) - - let create_network () = B.create ~use_async_readers:true ~yield:Lwt.pause () - - let create_interface backend cidr = - V.connect backend >>= fun v -> - E.connect v >>= fun e -> - A.connect e >>= fun a -> - I4.connect ~cidr e a >>= fun i4 -> - I6.connect ~no_init:true v e >>= fun i6 -> - I.connect ~ipv4_only:true ~ipv6_only:false i4 i6 >>= fun i -> - U.connect i >>= fun u -> - T.connect i >>= fun t -> - Icmp.connect i4 >>= fun icmp -> - connect v e a i icmp u t -end -module Mirage = Capnp_rpc_mirage.Make(Random)(Time)(Clock)(PClock)(Stack) -module Vat = Mirage.Vat - -type cs = { - client : Vat.t; - server : Vat.t; - client_key : Auth.Secret_key.t; - server_key : Auth.Secret_key.t; - serve_tls : bool; - server_switch : Lwt_switch.t; -} - -(* Have the client ask the server for its bootstrap object, and return the - resulting client-side proxy to it. *) -let get_bootstrap cs = - let id = Restorer.Id.public "" in - let sr = Vat.sturdy_uri cs.server id |> Vat.import_exn cs.client in - Sturdy_ref.connect_exn sr - -let create_iface network cidr = - Stack.create_interface network (Ipaddr.V4.Prefix.of_string_exn cidr) >|= fun stack -> - let dns = Mirage.Network.Dns.create stack in - Mirage.network ~dns stack - -let () = Mirage_crypto_rng_lwt.initialize () -let server_key = Auth.Secret_key.generate () -let client_key = Auth.Secret_key.generate () - -let server_pem = `PEM (Auth.Secret_key.to_pem_data server_key) - -let make_vats ?(serve_tls=false) ~switch ~service () = - let id = Restorer.Id.public "" in - let restore = Restorer.single id service in - let server_config = - Mirage.Vat_config.create ~secret_key:server_pem ~serve_tls ~public_address:(`TCP ("10.0.0.1", 7000)) (`TCP 7000) - in - let net = Stack.create_network () in - create_iface net "10.0.0.1/8" >>= fun server_net -> - create_iface net "10.0.0.2/8" >>= fun client_net -> - let server_switch = Lwt_switch.create () in - Mirage.serve ~switch:server_switch ~tags:Testbed.Test_utils.server_tags ~restore server_net server_config >>= fun server -> - Lwt_switch.add_hook (Some switch) (fun () -> Lwt_switch.turn_off server_switch); - Lwt_switch.add_hook (Some switch) (fun () -> Capability.dec_ref service; Lwt.return_unit); - Lwt.return { - client = Vat.create ~switch ~tags:Testbed.Test_utils.client_tags ~secret_key:(lazy client_key) client_net; - server; - client_key; - server_key; - serve_tls; - server_switch; - } - -(* Generic Lwt running for Alcotest. *) -let run_lwt name ?(expected_warnings=0) fn = - Alcotest_lwt.test_case name `Quick @@ fun sw () -> - let warnings_at_start = Logs.(err_count () + warn_count ()) in - Logs.info (fun f -> f "Start test-case"); - let finished = ref false in - Lwt_switch.add_hook (Some sw) (fun () -> - if not !finished then !Lwt.async_exception_hook (Failure "Switch turned off early"); - Lwt.return_unit - ); - fn sw >>= fun () -> finished := true; - Lwt_switch.turn_off sw >|= fun () -> - Gc.full_major (); - Lwt.wakeup_paused (); - Gc.full_major (); - Lwt.wakeup_paused (); - Gc.full_major (); - let warnings_at_end = Logs.(err_count () + warn_count ()) in - Alcotest.(check int) "Check log for warnings" expected_warnings (warnings_at_end - warnings_at_start) - -let test_simple switch ~serve_tls = - make_vats ~switch ~serve_tls ~service:(Echo.local ()) () >>= fun cs -> - get_bootstrap cs >>= fun service -> - Echo.ping service "ping" >>= fun reply -> - Alcotest.(check string) "Ping response" "got:0:ping" reply; - Capability.dec_ref service; - Lwt.return () - -let rpc_tests = [ - run_lwt "Simple" (test_simple ~serve_tls:false); - run_lwt "TLS" (test_simple ~serve_tls:true); -] - -let () = - Alcotest_lwt.run ~and_exit:false "capnp-rpc" [ - "mirage", rpc_tests; - ] |> Lwt_main.run diff --git a/test-mirage/test_mirage.mli b/test-mirage/test_mirage.mli deleted file mode 100644 index 08c8a5838..000000000 --- a/test-mirage/test_mirage.mli +++ /dev/null @@ -1 +0,0 @@ -(* (no public API) *) diff --git a/test/testbed/connection.ml b/test/testbed/connection.ml index d52e2f6cc..770131f43 100644 --- a/test/testbed/connection.ml +++ b/test/testbed/connection.ml @@ -81,13 +81,15 @@ module Endpoint (EP : Capnp_direct.ENDPOINT) = struct | _ -> k @@ Error (Capnp_rpc.Exception.v "Only a main interface is available") ) + let fork fn = fn () + let create ?bootstrap ~tags (xmit_queue:[EP.Out.t | `Unimplemented of EP.In.t] Queue.t) (recv_queue:[EP.In.t | `Unimplemented of EP.Out.t] Queue.t) = let queue_send x = Queue.add (x :> [EP.Out.t | `Unimplemented of EP.In.t]) xmit_queue in let bootstrap = (bootstrap :> EP.Core_types.cap option) in let restore = restore_single bootstrap in - let conn = Conn.create ?restore ~tags ~queue_send in + let conn = Conn.create ?restore ~tags ~queue_send ~fork in { conn; recv_queue; diff --git a/unix/capnp_rpc_unix.ml b/unix/capnp_rpc_unix.ml index c930f1e1a..d16ff4bcf 100644 --- a/unix/capnp_rpc_unix.ml +++ b/unix/capnp_rpc_unix.ml @@ -1,13 +1,10 @@ +open Eio.Std open Astring -open Lwt.Infix module Log = Capnp_rpc.Debug.Log -module Unix_flow = Unix_flow let () = Mirage_crypto_rng_lwt.initialize () -type flow = Unix_flow.flow - module CapTP = Vat_network.CapTP module Vat = Vat_network.Vat module Network = Network @@ -95,8 +92,8 @@ module Console = struct clear (); messages := msg :: !messages; show (); - Lwt.finalize f - (fun () -> + Fun.protect f + ~finally:(fun () -> clear (); let rec remove_first = function | [] -> assert false @@ -104,8 +101,7 @@ module Console = struct | x :: xs -> x :: remove_first xs in messages := remove_first !messages; - show (); - Lwt.return_unit + show () ) end @@ -122,7 +118,7 @@ let rec connect_with_progress ?(mode=`Auto) sr = let did_log = ref false in Log.info (fun f -> did_log := true; f "Connecting to %a..." pp sr); if !did_log then ( - Sturdy_ref.connect sr >|= function + match Sturdy_ref.connect sr with | Ok _ as x -> Log.info (fun f -> f "Connected to %a" pp sr); x | Error _ as e -> e ) else ( @@ -133,108 +129,80 @@ let rec connect_with_progress ?(mode=`Auto) sr = ) | `Batch -> Fmt.epr "Connecting to %a... %!" pp sr; - begin Sturdy_ref.connect sr >|= function + begin match Sturdy_ref.connect sr with | Ok _ as x -> Fmt.epr "OK@."; x | Error _ as x -> Fmt.epr "ERROR@."; x end | `Console -> - let x = Sturdy_ref.connect sr in - Lwt.choose [Lwt_unix.sleep 0.5; Lwt.map ignore x] >>= fun () -> - if Lwt.is_sleeping x then ( - Console.with_msg (Fmt.str "[ connecting to %a ]" pp sr) - (fun () -> x) - ) else x + Switch.run @@ fun sw -> + let x = Fiber.fork_promise ~sw (fun () -> Sturdy_ref.connect sr) in + Fiber.first + (fun () -> Promise.await_exn x) + (fun () -> + Eio_unix.sleep 0.5; + Console.with_msg (Fmt.str "[ connecting to %a ]" pp sr) + (fun () -> Promise.await_exn x) + ) | `Silent -> Sturdy_ref.connect sr let with_cap_exn ?progress sr f = - connect_with_progress ?mode:progress sr >>= function + match connect_with_progress ?mode:progress sr with | Error ex -> Fmt.failwith "%a" Capnp_rpc.Exception.pp ex | Ok x -> Capnp_rpc_lwt.Capability.with_ref x f let handle_connection ?tags ~secret_key vat client = - Lwt.catch (fun () -> - let switch = Lwt_switch.create () in - let raw_flow = Unix_flow.connect ~switch client in - Network.accept_connection ~switch ~secret_key raw_flow >>= function - | Error (`Msg msg) -> - Log.warn (fun f -> f ?tags "Rejecting new connection: %s" msg); - Lwt.return_unit - | Ok ep -> - Vat.add_connection vat ~switch ~mode:`Accept ep >|= fun (_ : CapTP.t) -> - () - ) - (fun ex -> - Log.err (fun f -> f "Uncaught exception handling connection: %a" Fmt.exn ex); - Lwt.return_unit - ) - -let addr_of_host host = - match Unix.gethostbyname host with - | exception Not_found -> - Capnp_rpc.Debug.failf "Unknown host %S" host - | addr -> - if Array.length addr.Unix.h_addr_list = 0 then - Capnp_rpc.Debug.failf "No addresses found for host name %S" host - else - addr.Unix.h_addr_list.(0) - -let serve ?switch ?tags ?restore config = + match Network.accept_connection ~secret_key client with + | Error (`Msg msg) -> + Log.warn (fun f -> f ?tags "Rejecting new connection: %s" msg) + | Ok ep -> + let _ : CapTP.t = Vat.add_connection vat ~mode:`Accept ep in + () + +let create_server ?tags ?restore ~sw ~net config = let {Vat_config.backlog; secret_key = _; serve_tls; listen_address; public_address} = config in let vat = let auth = Vat_config.auth config in let secret_key = lazy (fst (Lazy.force config.secret_key)) in - Vat.create ?switch ?tags ?restore ~address:(public_address, auth) ~secret_key () + Vat.create ?tags ?restore ~sw ~address:(public_address, auth) ~secret_key net in let socket = match listen_address with - | `Unix path -> - begin match Unix.lstat path with - | { Unix.st_kind = Unix.S_SOCK; _ } -> Unix.unlink path - | _ -> () - | exception Unix.Unix_error(Unix.ENOENT, _, _) -> () - end; - let socket = Unix.(socket PF_UNIX SOCK_STREAM 0) in - Unix.bind socket (Unix.ADDR_UNIX path); - socket + | `Unix _ as addr -> Eio.Net.listen ~sw ~backlog ~reuse_addr:true net addr | `TCP (host, port) -> - let socket = Unix.(socket PF_INET SOCK_STREAM 0) in - Unix.setsockopt socket Unix.SO_REUSEADDR true; - Unix.setsockopt socket Unix.SO_KEEPALIVE true; - Keepalive.try_set_idle socket 60; - Unix.bind socket (Unix.ADDR_INET (addr_of_host host, port)); + let addr = Network.addr_of_host host in + let socket = Eio.Net.listen ~sw ~backlog ~reuse_addr:true net (`Tcp (addr, port)) in + let unix_socket = Option.get (Eio_unix.FD.peek_opt socket) in + Unix.setsockopt unix_socket Unix.SO_KEEPALIVE true; + Keepalive.try_set_idle unix_socket 60; socket in - Unix.listen socket backlog; Log.info (fun f -> f ?tags "Waiting for %s connections on %a" (if serve_tls then "(encrypted)" else "UNENCRYPTED") Vat_config.Listen_address.pp listen_address); - let lwt_socket = Lwt_unix.of_unix_file_descr socket in - let rec loop () = - Lwt_switch.check switch; - Lwt_unix.accept lwt_socket >>= fun (client, _addr) -> - Log.info (fun f -> f ?tags "Accepting new connection"); - let secret_key = if serve_tls then Some (Vat_config.secret_key config) else None in - Lwt.async (fun () -> handle_connection ?tags ~secret_key vat client); - loop () - in - Lwt.async (fun () -> - Lwt.catch - (fun () -> - let th = loop () in - Lwt_switch.add_hook switch (fun () -> Lwt.cancel th; Lwt.return_unit); - th - ) - (function - | Lwt.Canceled -> Lwt.return_unit - | ex -> Lwt.fail ex - ) - >>= fun () -> - Lwt_unix.close lwt_socket + vat, socket + +let listen ?tags ~sw (config, vat, socket) = + while true do + let client, addr = Eio.Net.accept ~sw socket in + Log.info (fun f -> f ?tags "Accepting new connection from %a" Eio.Net.Sockaddr.pp addr); + let secret_key = if config.Vat_config.serve_tls then Some (Vat_config.secret_key config) else None in + Fiber.fork ~sw (fun () -> + (* We don't use [Net.accept_fork] here because this returns immediately after connecting. *) + handle_connection ?tags ~secret_key vat client + ) + done + +let serve ?tags ?restore ~sw ~net config = + let net = (net : #Eio.Net.t :> Eio.Net.t) in + let (vat, socket) = create_server ?tags ?restore ~sw ~net config in + Fiber.fork ~sw (fun () -> + listen ?tags ~sw (config, vat, socket) ); - Lwt.return vat + vat -let client_only_vat ?switch ?tags ?restore () = +let client_only_vat ?tags ?restore ~sw net = + let net = (net : #Eio.Net.t :> Eio.Net.t) in let secret_key = lazy (Capnp_rpc_net.Auth.Secret_key.generate ()) in - Vat.create ?switch ?tags ?restore ~secret_key () + Vat.create ?tags ?restore ~secret_key ~sw net let manpage_capnp_options = Vat_config.docs diff --git a/unix/capnp_rpc_unix.mli b/unix/capnp_rpc_unix.mli index 9d391a38c..e6f576494 100644 --- a/unix/capnp_rpc_unix.mli +++ b/unix/capnp_rpc_unix.mli @@ -3,10 +3,7 @@ open Capnp_rpc_lwt open Capnp_rpc_net -module Unix_flow = Unix_flow - include Capnp_rpc_net.VAT_NETWORK with - type flow = Unix_flow.flow and module Network = Network (** Configuration for a {!Vat}. *) @@ -66,7 +63,7 @@ module File_store : sig type 'a t (** A store of values of type ['a]. *) - val create : string -> 'a t + val create : _ Eio.Path.t -> 'a t (** [create dir] is a store for Cap'n Proto structs. Items are stored inside [dir]. *) @@ -102,7 +99,7 @@ val sturdy_uri : Uri.t Cmdliner.Arg.conv val connect_with_progress : ?mode:[`Auto | `Log | `Batch | `Console | `Silent] -> - 'a Sturdy_ref.t -> ('a Capability.t, Capnp_rpc.Exception.t) Lwt_result.t + 'a Sturdy_ref.t -> ('a Capability.t, Capnp_rpc.Exception.t) result (** [connect_with_progress sr] is like [Sturdy_ref.connect], but shows that a connection is in progress. Note: On failure, it does {e not} display the error, which should instead be handled by the caller. @param mode Controls how progress is displayed: @@ -116,26 +113,27 @@ val connect_with_progress : val with_cap_exn : ?progress:[`Auto | `Log | `Batch | `Console | `Silent] -> 'a Sturdy_ref.t -> - ('a Capability.t -> 'b Lwt.t) -> - 'b Lwt.t + ('a Capability.t -> 'b) -> + 'b (** Like [Sturdy_ref.with_cap_exn], but using [connect_with_progress] to show progress. *) val serve : - ?switch:Lwt_switch.t -> ?tags:Logs.Tag.set -> ?restore:Restorer.t -> + sw:Eio.Switch.t -> + net:#Eio.Net.t -> Vat_config.t -> - Vat.t Lwt.t -(** [serve ~restore vat_config] is a new vat that is listening for new connections + Vat.t +(** [serve ~restore ~sw ~net vat_config] is a new vat that is listening for new connections as specified by [vat_config]. After connecting to it, clients can get access to services using [restore]. *) val client_only_vat : - ?switch:Lwt_switch.t -> ?tags:Logs.Tag.set -> ?restore:Restorer.t -> - unit -> Vat.t -(** [client_only_vat ()] is a new vat that does not listen for incoming connections. *) + sw:Eio.Switch.t -> + #Eio.Net.t -> Vat.t +(** [client_only_vat net] is a new vat that does not listen for incoming connections. *) val manpage_capnp_options : string (** [manpage_capnp_options] is the title of the section of the man-page containing the Cap'n Proto options. diff --git a/unix/dune b/unix/dune index 95bba0fb4..bc7125ad5 100644 --- a/unix/dune +++ b/unix/dune @@ -1,5 +1,5 @@ (library (name capnp_rpc_unix) (public_name capnp-rpc-unix) - (libraries lwt.unix astring capnp-rpc-lwt capnp-rpc-net capnp-rpc fmt logs - mirage-crypto-rng.lwt cmdliner cstruct-lwt extunix)) + (libraries eio.unix astring capnp-rpc-lwt capnp-rpc-net capnp-rpc fmt logs + mirage-crypto-rng.lwt cmdliner cstruct extunix)) diff --git a/unix/file_store.ml b/unix/file_store.ml index da1059ea2..7683b114a 100644 --- a/unix/file_store.ml +++ b/unix/file_store.ml @@ -2,53 +2,43 @@ open Capnp_rpc_lwt module ReaderOps = Capnp.Runtime.ReaderInc.Make(Capnp_rpc_lwt) +let ( / ) = Eio.Path.( / ) + type 'a t = { - dir : string; + dir : Eio.Fs.dir Eio.Path.t; } -let create dir = { dir } +let create dir = { dir = (dir :> Eio.Fs.dir Eio.Path.t) } -let path_of_digest t digest = - match Base64.encode ~alphabet:Base64.uri_safe_alphabet ~pad:false digest with - | Ok filename -> Filename.concat t.dir filename - | Error (`Msg m) -> failwith m (* Encoding can't really fail *) +let leaf_of_digest digest = + Base64.encode_exn ~alphabet:Base64.uri_safe_alphabet ~pad:false digest let segments_of_reader = function | None -> [] | Some ss -> Message.to_storage ss.StructStorage.data.Slice.msg let save t ~digest data = - let path = path_of_digest t digest in - let tmp_path = path ^ ".new" in - let ch = open_out_bin tmp_path in - Fun.protect ~finally:(fun () -> close_out ch) (fun () -> + let leaf = leaf_of_digest digest in + let tmp_leaf = leaf ^ ".new" in + Eio.Path.with_open_out ~create:(`Exclusive 0o644) (t.dir / tmp_leaf) (fun flow -> let segments = segments_of_reader data in segments |> List.iter (fun {Message.segment; bytes_consumed} -> - output ch segment 0 bytes_consumed + let buf = Cstruct.of_bytes segment ~len:bytes_consumed in + Eio.Flow.copy (Eio.Flow.cstruct_source [buf]) flow ); ); - Unix.rename tmp_path path + Eio.Path.rename (t.dir / tmp_leaf) (t.dir / leaf) let remove t ~digest = - let path = path_of_digest t digest in - Unix.unlink path + Eio.Path.unlink (t.dir / leaf_of_digest digest) let load t ~digest = - let path = path_of_digest t digest in - if Sys.file_exists path then ( - let ch = open_in_bin path in - let segment = - Fun.protect ~finally:(fun () -> close_in ch) (fun () -> - let len = in_channel_length ch in - let segment = Bytes.create len in - really_input ch segment 0 len; - segment - ) - in - let msg = Message.of_storage [segment] in + let leaf = leaf_of_digest digest in + match Eio.Path.load (t.dir / leaf) with + | segment -> + let msg = Message.of_storage [Bytes.unsafe_of_string segment] in let reader = ReaderOps.get_root_struct (Message.readonly msg) in Some reader - ) else ( - Logs.info (fun f -> f "File %S not found" path); + | exception Eio.Fs.Not_found _ -> + Logs.info (fun f -> f "File %S not found" leaf); None - ) diff --git a/unix/network.ml b/unix/network.ml index 6e397f339..409fe2403 100644 --- a/unix/network.ml +++ b/unix/network.ml @@ -1,7 +1,5 @@ -open Lwt.Infix - module Log = Capnp_rpc.Debug.Log -module Tls_wrapper = Capnp_rpc_net.Tls_wrapper.Make(Unix_flow) +module Tls_wrapper = Capnp_rpc_net.Tls_wrapper module Location = struct open Astring @@ -50,7 +48,7 @@ module Types = struct type join_key_part end -type t = unit +type t = Eio.Net.t let error fmt = fmt |> Fmt.kstr @@ fun msg -> @@ -58,45 +56,41 @@ let error fmt = let parse_third_party_cap_id _ = `Two_party_only +let gethostbyname name = + Eio_unix.run_in_systhread (fun () -> Unix.gethostbyname name) + let addr_of_host host = - match Unix.gethostbyname host with + match gethostbyname host with | exception Not_found -> Capnp_rpc.Debug.failf "Unknown host %S" host | addr -> if Array.length addr.Unix.h_addr_list = 0 then Capnp_rpc.Debug.failf "No addresses found for host name %S" host else - addr.Unix.h_addr_list.(0) - -let connect_socket = function - | `Unix path -> - Log.info (fun f -> f "Connecting to %S..." path); - let socket = Lwt_unix.(socket PF_UNIX SOCK_STREAM 0) in - Lwt.catch - (fun () -> Lwt_unix.connect socket (Unix.ADDR_UNIX path) >|= fun () -> socket) - (fun ex -> Lwt_unix.close socket >>= fun () -> Lwt.fail ex) - | `TCP (host, port) -> - Log.info (fun f -> f "Connecting to %s:%d..." host port); - let socket = Lwt_unix.(socket PF_INET SOCK_STREAM 0) in - Lwt.catch - (fun () -> - Lwt_unix.setsockopt socket Unix.SO_KEEPALIVE true; - Keepalive.try_set_idle (Lwt_unix.unix_file_descr socket) 60; - Lwt_unix.connect socket (Unix.ADDR_INET (addr_of_host host, port)) >|= fun () -> - socket - ) - (fun ex -> Lwt_unix.close socket >>= fun () -> Lwt.fail ex) - -let connect () ~switch ~secret_key (addr, auth) = - Lwt.try_bind - (fun () -> connect_socket addr) - (fun socket -> - let flow = Unix_flow.connect ~switch socket in - Tls_wrapper.connect_as_client ~switch flow secret_key auth - ) - (fun ex -> - Lwt.return @@ error "@[Network connection for %a failed:@,%a@]" Location.pp addr Fmt.exn ex - ) - -let accept_connection ~switch ~secret_key flow = - Tls_wrapper.connect_as_server ~switch flow secret_key + Eio_unix.Ipaddr.of_unix addr.Unix.h_addr_list.(0) + +let connect net ~sw ~secret_key (addr, auth) = + let eio_addr = + match addr with + | `Unix _ as x -> x + | `TCP (host, port) -> + let host = addr_of_host host in + `Tcp (host, port) + in + Log.info (fun f -> f "Connecting to %a..." Eio.Net.Sockaddr.pp eio_addr); + match Eio.Net.connect ~sw net eio_addr with + | socket -> + begin match addr with + | `Unix _ -> () + | `TCP _ -> + (* TODO: check it's OK to set keep-alives after connecting *) + let socket = Option.get (Eio_unix.FD.peek_opt socket) in + Unix.setsockopt socket Unix.SO_KEEPALIVE true; + Keepalive.try_set_idle socket 60 + end; + Tls_wrapper.connect_as_client socket secret_key auth + | exception ex -> + error "@[Network connection for %a failed:@,%a@]" Location.pp addr Fmt.exn ex + +let accept_connection ~secret_key flow = + Tls_wrapper.connect_as_server flow secret_key diff --git a/unix/network.mli b/unix/network.mli index 7ba6d427c..8d5b095c6 100644 --- a/unix/network.mli +++ b/unix/network.mli @@ -25,14 +25,15 @@ module Location : sig end include Capnp_rpc_net.S.NETWORK with - type t = unit and + type t = Eio.Net.t and type Address.t = Location.t * Capnp_rpc_net.Auth.Digest.t val accept_connection : - switch:Lwt_switch.t -> secret_key:Capnp_rpc_net.Auth.Secret_key.t option -> - Unix_flow.flow -> - (Capnp_rpc_net.Endpoint.t, [> `Msg of string]) result Lwt.t + #Eio.Flow.two_way -> + (Capnp_rpc_net.Endpoint.t, [> `Msg of string]) result (** [accept_connection ~switch ~secret_key flow] is a new endpoint for [flow]. If [secret_key] is not [None], it is used to perform a TLS server-side handshake. Otherwise, the connection is not encrypted. *) + +val addr_of_host : string -> Eio.Net.Ipaddr.v4v6 diff --git a/unix/unix_flow.ml b/unix/unix_flow.ml deleted file mode 100644 index 8836b1e04..000000000 --- a/unix/unix_flow.ml +++ /dev/null @@ -1,95 +0,0 @@ -open Lwt.Infix - -(* Slightly rude to set signal handlers in a library, but SIGPIPE makes no sense - in a modern application. *) -let () = if not Sys.win32 then Sys.(set_signal sigpipe Signal_ignore) - -type flow = { - fd : Lwt_unix.file_descr; - mutable current_write : int Lwt.t option; - mutable current_read : int Lwt.t option; - mutable closed : bool; -} -type error = [`Exception of exn] -type write_error = [`Closed | `Exception of exn] - -let opt_cancel = function - | None -> () - | Some x -> Lwt.cancel x - -let close t = - if t.closed then Lwt.return_unit - else ( - t.closed <- true; - opt_cancel t.current_read; - opt_cancel t.current_write; - Lwt_unix.close t.fd - ) - -let pp_error f = function - | `Exception ex -> Fmt.exn f ex - | `Closed -> Fmt.string f "Closed" - -let pp_write_error = pp_error - -let write t buf = - let rec aux buf = - if t.closed then Lwt.return (Error `Closed) - else ( - assert (t.current_write = None); - let write_thread = Lwt_cstruct.write t.fd buf in - t.current_write <- Some write_thread; - write_thread >>= fun wrote -> - t.current_write <- None; - if wrote = Cstruct.length buf then Lwt.return (Ok ()) - else aux (Cstruct.shift buf wrote) - ) - in - Lwt.catch - (fun () -> aux buf) - (function - | Unix.Unix_error (Unix.ECONNRESET, _, _) - | Unix.Unix_error (Unix.ENOTCONN, _, _) (* macos *) - | Unix.Unix_error (Unix.EPIPE, _, _) -> Lwt.return @@ Error `Closed - | ex -> Lwt.return @@ Error (`Exception ex)) - -let rec writev t = function - | [] -> Lwt.return (Ok ()) - | x :: xs -> - write t x >>= function - | Ok () -> writev t xs - | Error _ as e -> Lwt.return e - -let read t = - let len = 4096 in - let buf = Cstruct.create_unsafe len in - Lwt.try_bind - (fun () -> - assert (t.current_read = None); - if t.closed then raise Lwt.Canceled; - let read_thread = Lwt_cstruct.read t.fd buf in - t.current_read <- Some read_thread; - read_thread - ) - (function - | 0 -> - Lwt.return @@ Ok `Eof - | got -> - t.current_read <- None; - Lwt.return @@ Ok (`Data (Cstruct.sub buf 0 got)) - ) - (function - | Lwt.Canceled - | Unix.Unix_error (Unix.EPIPE, _, _) - | Unix.Unix_error (Unix.ECONNRESET, _, _) -> Lwt_result.return `Eof - | ex -> Lwt.return @@ Error (`Exception ex) - ) - -let connect ?switch fd = - let t = { fd; closed = false; current_read = None; current_write = None } in - Lwt_switch.add_hook switch (fun () -> close t); - t - -let socketpair ?switch () = - let a, b = Lwt_unix.(socketpair PF_UNIX SOCK_STREAM 0) in - connect ?switch a, connect ?switch b diff --git a/unix/unix_flow.mli b/unix/unix_flow.mli deleted file mode 100644 index 78d85bfee..000000000 --- a/unix/unix_flow.mli +++ /dev/null @@ -1,7 +0,0 @@ -(** Wraps a Unix [file_descr] to provide the Mirage flow API. *) - -include Mirage_flow.S - -val connect : ?switch:Lwt_switch.t -> Lwt_unix.file_descr -> flow - -val socketpair : ?switch:Lwt_switch.t -> unit -> flow * flow diff --git a/unix/vat_network.ml b/unix/vat_network.ml index d5d810f8e..922b2eda7 100644 --- a/unix/vat_network.ml +++ b/unix/vat_network.ml @@ -1 +1 @@ -include Capnp_rpc_net.Networking (Network) (Unix_flow) +include Capnp_rpc_net.Networking (Network)