diff --git a/src/node/net/RPC.ml b/src/node/net/RPC.ml index 6bde5a48c..1b56982f3 100644 --- a/src/node/net/RPC.ml +++ b/src/node/net/RPC.ml @@ -36,30 +36,53 @@ exception Invalid_method exception Cannot_parse_body of string (* Promise a running RPC server. Takes the port. *) -let launch port root = +let launch port ?pre_hook ?post_hook root = (* launch the worker *) let cancelation, canceler, _ = Lwt_utils.canceler () in let open Cohttp_lwt_unix in - let create_stream, shutdown_stream = - let streams = ref ConnectionMap.empty in - let create _io con (s: _ Answer.stream) = - let running = ref true in - let stream = - Lwt_stream.from - (fun () -> - if not !running then Lwt.return None else - s.next () >|= function - | None -> None - | Some x -> Some (Data_encoding.Json.to_string x)) in - let shutdown () = running := false ; s.shutdown () in - streams := ConnectionMap.add con shutdown !streams ; - stream - in - let shutdown con = - try ConnectionMap.find con !streams () - with Not_found -> () in - create, shutdown + let streams = ref ConnectionMap.empty in + let create_stream _io con to_string (s: _ Answer.stream) = + let running = ref true in + let stream = + Lwt_stream.from + (fun () -> + if not !running then Lwt.return None else + s.next () >|= function + | None -> None + | Some x -> Some (to_string x)) in + let shutdown () = running := false ; s.shutdown () in + streams := ConnectionMap.add con shutdown !streams ; + stream in + let shutdown_stream con = + try ConnectionMap.find con !streams () + with Not_found -> () in + let call_hook (io, con) req ?(answer_404 = false) hook = + match hook with + | None -> Lwt.return None + | Some hook -> + Lwt.catch + (fun () -> + hook (Uri.path (Cohttp.Request.uri req)) + >>= fun { Answer.code ; body } -> + if code = 404 && not answer_404 then + Lwt.return None + else + let body = match body with + | Answer.Empty -> + Cohttp_lwt_body.empty + | Single body -> + Cohttp_lwt_body.of_string body + | Stream s -> + let stream = + create_stream io con (fun s -> s) s in + Cohttp_lwt_body.of_stream stream in + Lwt.return_some + (Response.make ~flush:true ~status:(`Code code) (), + body)) + (function + | Not_found -> Lwt.return None + | exn -> Lwt.fail exn) in let callback (io, con) req body = (* FIXME: check inbound adress *) let path = Utils.split_path (Uri.path (Cohttp.Request.uri req)) in @@ -67,39 +90,48 @@ let launch port root = (Cohttp.Connection.to_string con) (Uri.path (Cohttp.Request.uri req)) >>= fun () -> Lwt.catch (fun () -> - lookup root () path >>= fun handler -> - begin - match req.meth with - | `POST -> begin - Cohttp_lwt_body.to_string body >>= fun body -> - match Data_encoding.Json.from_string body with - | Error msg -> Lwt.fail (Cannot_parse_body msg) - | Ok body -> Lwt.return (Some body) - end - | `GET -> Lwt.return None - | _ -> Lwt.fail Invalid_method - end >>= fun body -> - handler body >>= fun { Answer.code ; body } -> - let body = match body with - | Empty -> - Cohttp_lwt_body.empty - | Single json -> - Cohttp_lwt_body.of_string (Data_encoding.Json.to_string json) - | Stream s -> - let stream = create_stream io con s in - Cohttp_lwt_body.of_stream stream in - lwt_log_info "(%s) RPC %s" - (Cohttp.Connection.to_string con) - (if Cohttp.Code.is_error code - then "failed" - else "success") >>= fun () -> - Lwt.return (Response.make ~flush:true ~status:(`Code code) (), body)) + call_hook (io, con) req pre_hook >>= function + | Some res -> + Lwt.return res + | None -> + lookup root () path >>= fun handler -> + begin + match req.meth with + | `POST -> begin + Cohttp_lwt_body.to_string body >>= fun body -> + match Data_encoding.Json.from_string body with + | Error msg -> Lwt.fail (Cannot_parse_body msg) + | Ok body -> Lwt.return (Some body) + end + | `GET -> Lwt.return None + | _ -> Lwt.fail Invalid_method + end >>= fun body -> + handler body >>= fun { Answer.code ; body } -> + let body = match body with + | Empty -> + Cohttp_lwt_body.empty + | Single json -> + Cohttp_lwt_body.of_string (Data_encoding.Json.to_string json) + | Stream s -> + let stream = + create_stream io con Data_encoding.Json.to_string s in + Cohttp_lwt_body.of_stream stream in + lwt_log_info "(%s) RPC %s" + (Cohttp.Connection.to_string con) + (if Cohttp.Code.is_error code + then "failed" + else "success") >>= fun () -> + Lwt.return (Response.make ~flush:true ~status:(`Code code) (), + body)) (function | Not_found | Cannot_parse _ -> lwt_log_info "(%s) not found" (Cohttp.Connection.to_string con) >>= fun () -> - Lwt.return (Response.make ~flush:true ~status:`Not_found (), - Cohttp_lwt_body.empty) + (call_hook (io, con) req ~answer_404: true post_hook >>= function + | Some res -> Lwt.return res + | None -> + Lwt.return (Response.make ~flush:true ~status:`Not_found (), + Cohttp_lwt_body.empty)) | Invalid_method -> lwt_log_info "(%s) bad method" (Cohttp.Connection.to_string con) >>= fun () -> diff --git a/src/node/net/RPC.mli b/src/node/net/RPC.mli index 77a47182d..0a8c5432b 100644 --- a/src/node/net/RPC.mli +++ b/src/node/net/RPC.mli @@ -272,7 +272,6 @@ val register_custom_lookup3: ('a -> 'b -> 'c -> string list -> custom_lookup Lwt.t) -> 'prefix directory - (** Registring a description service. *) val register_describe_directory_service: 'prefix directory -> @@ -283,13 +282,22 @@ val register_describe_directory_service: type server (** Promise a running RPC serve ; takes the port. To call - an RPX at /p/a/t/h/ in the provided service, one must call the URI + an RPC at /p/a/t/h/ in the provided service, one must call the URI /call/p/a/t/h/. Calling /list/p/a/t/h/ will list the services prefixed by /p/a/t/h/, if any. Calling /schema/p/a/t/h/ will describe the input and output of the service, if it is callable. Calling /pipe will read a sequence of services to call in - sequence from the request body, see {!pipe_encoding}. *) -val launch : int -> unit directory -> server Lwt.t + sequence from the request body, see {!pipe_encoding}. + + The optional [pre_hook] is called with the path part of the URL + before resolving each request, to delegate the answering to + another resolution mechanism. Its result is ignored if the return + code is [404]. The optional [post_hook] is called if both the + [pre_hook] and the serviced answered with a [404] code. *) +val launch : int -> + ?pre_hook: (string -> string Answer.answer Lwt.t) -> + ?post_hook: (string -> string Answer.answer Lwt.t) -> + unit directory -> server Lwt.t (** Kill an RPC server. *) val shutdown : server -> unit Lwt.t