375 lines
15 KiB
OCaml
375 lines
15 KiB
OCaml
(*****************************************************************************)
|
|
(* *)
|
|
(* Open Source License *)
|
|
(* Copyright (c) 2018 Dynamic Ledger Solutions, Inc. <contact@tezos.com> *)
|
|
(* *)
|
|
(* Permission is hereby granted, free of charge, to any person obtaining a *)
|
|
(* copy of this software and associated documentation files (the "Software"),*)
|
|
(* to deal in the Software without restriction, including without limitation *)
|
|
(* the rights to use, copy, modify, merge, publish, distribute, sublicense, *)
|
|
(* and/or sell copies of the Software, and to permit persons to whom the *)
|
|
(* Software is furnished to do so, subject to the following conditions: *)
|
|
(* *)
|
|
(* The above copyright notice and this permission notice shall be included *)
|
|
(* in all copies or substantial portions of the Software. *)
|
|
(* *)
|
|
(* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR*)
|
|
(* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *)
|
|
(* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *)
|
|
(* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER*)
|
|
(* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING *)
|
|
(* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER *)
|
|
(* DEALINGS IN THE SOFTWARE. *)
|
|
(* *)
|
|
(*****************************************************************************)
|
|
|
|
open Lwt.Infix
|
|
|
|
(* Immutable finite maps keyed by Cohttp connection identifiers
   ([Cohttp.Connection.t]).  The server below uses it to remember, for each
   connection, the function that stops the streamed answer attached to it. *)
module ConnectionMap = Map.Make(Cohttp.Connection)
|
|
|
|
(** Logging interface required by the [Make] functor below.

    Every printer takes a [Format] format string followed by its
    arguments.  Each level is provided twice: a direct printer returning
    [unit], and an [lwt_]-prefixed printer returning a [unit Lwt.t]
    promise so that it can be sequenced with [>>=]. *)
module type LOGGING = sig

  (* Debug level. *)
  val debug : ('a, Format.formatter, unit, unit) format4 -> 'a
  val lwt_debug : ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a

  (* Info level. *)
  val log_info : ('a, Format.formatter, unit, unit) format4 -> 'a
  val lwt_log_info : ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a

  (* Notice level. *)
  val log_notice : ('a, Format.formatter, unit, unit) format4 -> 'a
  val lwt_log_notice : ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a

  (* Warning level. *)
  val warn : ('a, Format.formatter, unit, unit) format4 -> 'a
  val lwt_warn : ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a

  (* Error level. *)
  val log_error : ('a, Format.formatter, unit, unit) format4 -> 'a
  val lwt_log_error : ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a

end
|
|
|
|
(** [Make (Encoding) (Log)] builds an HTTP server for Resto directories on
    top of Cohttp/Lwt.  [Encoding] parameterises the service, directory and
    media-type modules; [Log] receives the server's diagnostics. *)
module Make (Encoding : Resto.ENCODING)(Log : LOGGING) = struct

  open Log
  open Cohttp

  module Service = Resto.MakeService(Encoding)
  module Directory = Resto_directory.Make(Encoding)

  module Media_type = Media_type.Make(Encoding)

  (** State of a running server.
      - [root]: the directory used to dispatch incoming requests.
      - [streams]: for each connection that has a streamed answer, the
        function that stops that stream (see [create_stream]).
      - [cors]: CORS configuration applied to answers and preflights.
      - [media_types]: the media types the server can read and write.
      - [default_media_type]: the ["type/subtype"] name and media type used
        when a request has no [content-type] or [accept] header.
      - [stopper]: waking it up resolves the [stop] promise handed to
        [Cohttp_lwt_unix.Server.create].
      - [worker]: the promise returned by [Cohttp_lwt_unix.Server.create]. *)
  type server = {
    root : unit Directory.directory ;
    mutable streams : (unit -> unit) ConnectionMap.t ;
    cors : Cors.t ;
    media_types : Media_type.t list ;
    default_media_type : string * Media_type.t ;
    stopper : unit Lwt.u ;
    mutable worker : unit Lwt.t ;
  }

  (** [create_stream server con to_string s] exposes the streamed answer
      [s] as an [Lwt_stream.t] of chunks serialised with [to_string].

      A stop function is registered in [server.streams] under [con],
      replacing any previous entry for that connection.  Calling it makes
      the stream end, calls [s.shutdown], and unregisters itself. *)
  let create_stream server con to_string s =
    let running = ref true in
    let stream =
      Lwt_stream.from
        (fun () ->
           if not !running then
             Lwt.return None
           else
             s.Resto_directory.Answer.next () >|= function
             | None -> None
             | Some x -> Some (to_string x)) in
    let shutdown () =
      running := false ;
      s.shutdown () ;
      server.streams <- ConnectionMap.remove con server.streams in
    server.streams <- ConnectionMap.add con shutdown server.streams ;
    stream

  (* Bind for [result Lwt.t]: continue with [f] on [Ok], propagate [Error]. *)
  let (>>=?) m f =
    m >>= function
    | Ok x -> f x
    | Error err -> Lwt.return_error err

  (** Request handler for [server].

      Methods matching [Resto.meth] are dispatched through [server.root].
      The request's [content-type] selects the input media type, its
      [accept] header selects the output media type, and the handler's
      answer is turned into a response.  [OPTIONS] answers CORS preflights.
      [HEAD] and any other method get 501 Not Implemented.  Dispatch and
      parsing failures are mapped to the matching HTTP error status. *)
  let callback server (_io, con) req body =
    (* FIXME: check inbound address *)
    let uri = Request.uri req in
    let path = Uri.pct_decode (Uri.path uri) in
    lwt_log_info "(%s) receive request to %s"
      (Connection.to_string con) path >>= fun () ->
    let path = Utils.split_path path in
    let req_headers = Request.headers req in
    (* 400 Bad Request carrying a plain-text explanation [msg]. *)
    let bad_request_text msg =
      let headers = Header.init () in
      let headers = Header.add headers "content-type" "text/plain" in
      Lwt.return
        (Response.make ~status:`Bad_request ~headers (),
         Cohttp_lwt.Body.of_string msg) in
    begin
      match Request.meth req with
      | #Resto.meth as meth -> begin
          Directory.lookup server.root ()
            meth path >>=? fun (Directory.Service s) ->
          (* Input media type, from the request's content-type. *)
          begin
            match Header.get req_headers "content-type" with
            | None -> Lwt.return_ok (snd server.default_media_type)
            | Some content_type ->
                match Utils.split_path content_type with
                | [x ; y] -> begin
                    match Media_type.find_media (x, y) server.media_types with
                    | None ->
                        Lwt.return_error (`Unsupported_media_type content_type)
                    | Some media_type ->
                        Lwt.return_ok media_type
                  end
                | _ ->
                    Lwt.return_error (`Unsupported_media_type content_type)
          end >>=? fun input_media_type ->
          lwt_debug "(%s) input media type %s"
            (Connection.to_string con)
            (Media_type.name input_media_type) >>= fun () ->
          (* Output media type (and its content-type name), from accept. *)
          begin
            match Header.get req_headers "accept" with
            | None -> Lwt.return_ok server.default_media_type
            | Some accepted ->
                match Media_type.resolve_accept_header
                        server.media_types (Some accepted) with
                | None -> Lwt.return_error `Not_acceptable
                | Some media_type -> Lwt.return_ok media_type
          end >>=? fun (output_content_type, output_media_type) ->
          (* Repeated query parameters are joined with commas before parsing. *)
          begin
            match Resto.Query.parse s.types.query
                    (List.map
                       (fun (k, l) -> (k, String.concat "," l))
                       (Uri.query uri)) with
            | exception (Resto.Query.Invalid s) ->
                Lwt.return_error (`Cannot_parse_query s)
            | query -> Lwt.return_ok query
          end >>=? fun query ->
          lwt_debug "(%s) output media type %s"
            (Connection.to_string con)
            (Media_type.name output_media_type) >>= fun () ->
          let output = output_media_type.construct s.types.output
          and error = function
            | None -> Cohttp_lwt.Body.empty, Transfer.Fixed 0L
            | Some e ->
                let s = output_media_type.construct s.types.error e in
                Cohttp_lwt.Body.of_string s,
                Transfer.Fixed (Int64.of_int (String.length s)) in
          let headers = Header.init () in
          let headers =
            Header.add headers "content-type" output_content_type in
          let headers = Cors.add_allow_origin
              headers server.cors (Header.get req_headers "origin") in
          (* Error answer with [status] and the optional payload [e]. *)
          let error_response status e =
            let body, encoding = error e in
            Lwt.return_ok
              (Response.make ~status ~encoding ~headers (), body) in
          begin
            match s.types.input with
            | Service.No_input ->
                s.handler query () >>= Lwt.return_ok
            | Service.Input input ->
                Cohttp_lwt.Body.to_string body >>= fun body ->
                match
                  input_media_type.destruct input body
                with
                | Error s ->
                    Lwt.return_error (`Cannot_parse_body s)
                | Ok body ->
                    s.handler query body >>= Lwt.return_ok
          end >>=? function
          | `Ok o ->
              let body = output o in
              let encoding =
                Transfer.Fixed (Int64.of_int (String.length body)) in
              Lwt.return_ok
                (Response.make ~status:`OK ~encoding ~headers (),
                 Cohttp_lwt.Body.of_string body)
          | `OkStream o ->
              let body = create_stream server con output o in
              let encoding = Transfer.Chunked in
              Lwt.return_ok
                (Response.make ~status:`OK ~encoding ~headers (),
                 Cohttp_lwt.Body.of_stream body)
          | `Created s ->
              (* NOTE(review): this response (and 204 below) uses a fresh
                 header set, so the CORS headers built above are not
                 sent -- confirm this is intended. *)
              let headers = Header.init () in
              let headers =
                match s with
                | None -> headers
                | Some s -> Header.add headers "location" s in
              Lwt.return_ok
                (Response.make ~status:`Created ~headers (),
                 Cohttp_lwt.Body.empty)
          | `No_content ->
              Lwt.return_ok
                (Response.make ~status:`No_content (),
                 Cohttp_lwt.Body.empty)
          | `Unauthorized e -> error_response `Unauthorized e
          | `Forbidden e -> error_response `Forbidden e
          | `Not_found e -> error_response `Not_found e
          | `Conflict e -> error_response `Conflict e
          | `Error e -> error_response `Internal_server_error e
        end
      | `HEAD ->
          (* TODO: HEAD requests are not supported yet. *)
          Lwt.return_error `Not_implemented
      | `OPTIONS ->
          let origin_header = Header.get req_headers "origin" in
          begin
            (* Default OPTIONS handler for CORS preflight: without an origin
               (or without a requested method), list every allowed method;
               otherwise only check the requested one. *)
            if origin_header = None then
              Directory.allowed_methods server.root () path
            else
              match Header.get req_headers
                      "Access-Control-Request-Method" with
              | None ->
                  Directory.allowed_methods server.root () path
              | Some meth ->
                  match Code.method_of_string meth with
                  | #Resto.meth as meth ->
                      Directory.lookup server.root () meth path >>=? fun _handler ->
                      Lwt.return_ok [ meth ]
                  | _ ->
                      Lwt.return_error `Not_found
          end >>=? fun cors_allowed_meths ->
          lwt_log_info "(%s) RPC preflight"
            (Connection.to_string con) >>= fun () ->
          let headers = Header.init () in
          let headers =
            Header.add_multi headers
              "Access-Control-Allow-Methods"
              (List.map Resto.string_of_meth cors_allowed_meths) in
          let headers = Cors.add_headers headers server.cors origin_header in
          Lwt.return_ok
            (Response.make ~flush:true ~status:`OK ~headers (),
             Cohttp_lwt.Body.empty)
      | _ ->
          Lwt.return_error `Not_implemented
    end >>= function
    | Ok answer -> Lwt.return answer
    | Error `Not_implemented ->
        Lwt.return
          (Response.make ~status:`Not_implemented (),
           Cohttp_lwt.Body.empty)
    | Error `Method_not_allowed methods ->
        let headers = Header.init () in
        let headers =
          Header.add_multi headers "allow"
            (List.map Resto.string_of_meth methods) in
        Lwt.return
          (Response.make ~status:`Method_not_allowed ~headers (),
           Cohttp_lwt.Body.empty)
    | Error `Cannot_parse_path (context, arg, value) ->
        bad_request_text
          (Format.asprintf
             "Failed to parse an argument in path. After \"%s\", \
              the value \"%s\" is not acceptable for type \"%s\""
             (String.concat "/" context) value arg.name)
    | Error `Cannot_parse_body s ->
        bad_request_text
          (Format.asprintf "Failed to parse the request body: %s" s)
    | Error `Cannot_parse_query s ->
        bad_request_text
          (Format.asprintf "Failed to parse the query string: %s" s)
    | Error `Not_acceptable ->
        let accepted_encoding =
          Media_type.acceptable_encoding server.media_types in
        Lwt.return
          (Response.make ~status:`Not_acceptable (),
           Cohttp_lwt.Body.of_string accepted_encoding)
    | Error `Unsupported_media_type _ ->
        Lwt.return
          (Response.make ~status:`Unsupported_media_type (),
           Cohttp_lwt.Body.empty)
    | Error `Not_found ->
        Lwt.return
          (Response.make ~status:`Not_found (),
           Cohttp_lwt.Body.empty)

  (** Promise a running RPC server.

      [launch ?host ?cors ~media_types mode root] initialises a Conduit
      context bound to [host] (default ["::"]) and starts a Cohttp server
      on [mode] that serves [root] with [callback].  The server loop runs
      in the returned [server.worker]; the promise resolves as soon as the
      loop has been started.
      @raise Invalid_argument when [Media_type.first_complete_media
      media_types] returns [None]. *)
  let launch
      ?(host="::")
      ?(cors = Cors.default)
      ~media_types
      mode root =
    let default_media_type =
      match Media_type.first_complete_media media_types with
      | None -> invalid_arg "RestoCohttp.launch(empty media type list)"
      | Some ((l, r), m) -> l^"/"^r, m in
    let stop, stopper = Lwt.wait () in
    let server = {
      root ;
      streams = ConnectionMap.empty ;
      cors ;
      media_types ;
      default_media_type ;
      stopper ;
      worker = Lwt.return_unit ;
    } in
    Conduit_lwt_unix.init ~src:host () >>= fun ctx ->
    let ctx = Cohttp_lwt_unix.Net.init ~ctx () in
    server.worker <- begin
      (* Stop the streamed answer still attached to a closed connection. *)
      let conn_closed (_, con) =
        log_info "connection closed %s" (Connection.to_string con) ;
        (* Only the lookup may legitimately raise [Not_found]; an exception
           escaping the stop function itself must not be swallowed. *)
        (match ConnectionMap.find con server.streams with
         | exception Not_found -> ()
         | stop_stream -> stop_stream ())
      and on_exn = function
        | Unix.Unix_error (Unix.EADDRINUSE, "bind", _) ->
            log_error "RPC server port already taken, \
                       the node will be shutdown" ;
            exit 1
        (* Peers disconnecting abruptly are not worth reporting. *)
        | Unix.Unix_error (Unix.ECONNRESET, _, _)
        | Unix.Unix_error (Unix.EPIPE, _, _) -> ()
        | exn -> !Lwt.async_exception_hook exn
      (* Wraps the functor-level [callback]: an uncaught [Not_found] becomes
         a 404, any other exception a 500 whose body is the exception. *)
      and callback (io, con) req body =
        Lwt.catch
          begin fun () -> callback server (io, con) req body end
          begin function
            | Not_found ->
                let status = `Not_found in
                let body = Cohttp_lwt.Body.empty in
                Lwt.return (Response.make ~status (), body)
            | exn ->
                let headers = Header.init () in
                let headers =
                  Header.add headers "content-type" "text/ocaml.exception" in
                let status = `Internal_server_error in
                let body = Cohttp_lwt.Body.of_string (Printexc.to_string exn) in
                Lwt.return (Response.make ~status ~headers (), body)
          end
      in
      Cohttp_lwt_unix.Server.create ~stop ~ctx ~mode ~on_exn
        (Cohttp_lwt_unix.Server.make ~callback ~conn_closed ())
    end ;
    Lwt.return server

  (** [shutdown server] wakes up the server's stop promise, waits for
      [server.worker] to resolve, then stops every stream still registered
      in [server.streams]. *)
  let shutdown server =
    Lwt.wakeup_later server.stopper () ;
    server.worker >>= fun () ->
    (* Each stop function removes itself from [server.streams]; iterating
       over the immutable map read here is unaffected by those updates. *)
    ConnectionMap.iter (fun _ f -> f ()) server.streams ;
    Lwt.return_unit

end
|