444 lines
16 KiB
OCaml
444 lines
16 KiB
OCaml
(**************************************************************************)
(*                                                                        *)
(*    Copyright (c) 2014 - 2017.                                          *)
(*    Dynamic Ledger Solutions, Inc. <contact@tezos.com>                  *)
(*                                                                        *)
(*    All rights reserved. No warranty, explicit or implicit, provided.   *)
(*                                                                        *)
(**************************************************************************)
|
|
|
|
open Lwt.Infix
|
|
|
|
module Utils = struct

  (** [split_path path] returns, in order, the non-empty components of
      [path] separated by ['/']. Leading, trailing and repeated slashes are
      ignored: [split_path "/a//b/"] is [["a"; "b"]] and [split_path ""] is
      [[]]. *)
  let split_path path =
    String.split_on_char '/' path
    |> List.filter (fun component -> component <> "")

end
|
|
|
|
(** Cross-origin resource sharing policy of a server. *)
type cors = {
  allowed_headers : string list ;
  (** Values sent in the [Access-Control-Allow-Headers] response header. *)
  allowed_origins : string list ;
  (** Origins that may receive an [Access-Control-Allow-Origin] header.
      ["*"] matches every origin; an entry also matches origins that extend
      it with a ['/']-separated suffix. *)
}
|
|
|
|
module Cors = struct

  (** Policy allowing no extra header and no origin. *)
  let default = { allowed_headers = [] ; allowed_origins = [] }

  (** [check_origin_matches origin allowed_origin] holds when
      [allowed_origin] is the wildcard ["*"], is equal to [origin], or
      followed by a ['/'] is a prefix of [origin]. *)
  let check_origin_matches origin allowed_origin =
    if String.equal "*" allowed_origin then true
    else if String.equal allowed_origin origin then true
    else
      let prefix = allowed_origin ^ "/" in
      let prefix_len = String.length prefix in
      String.length origin >= prefix_len
      && String.equal prefix (String.sub origin 0 prefix_len)

  (** [find_matching_origin allowed_origins origin] returns the longest
      entry of [allowed_origins] matching [origin] (the earliest one when
      several have that length), or [None] when nothing matches. *)
  let find_matching_origin allowed_origins origin =
    List.fold_left
      (fun best candidate ->
         if not (check_origin_matches origin candidate) then best
         else
           match best with
           | Some current
             when String.length current >= String.length candidate -> best
           | _ -> Some candidate)
      None allowed_origins

  (** [add_headers headers cors origin_header] extends [headers] with the
      [Access-Control-Allow-Headers] values of [cors] and, when
      [origin_header] is matched by one of [cors.allowed_origins], with an
      [Access-Control-Allow-Origin] header naming that allowed origin. *)
  let add_headers headers cors origin_header =
    let with_allowed_headers =
      Cohttp.Header.add_multi headers
        "Access-Control-Allow-Headers" cors.allowed_headers in
    let allowed_origin =
      match origin_header with
      | None -> None
      | Some origin -> find_matching_origin cors.allowed_origins origin in
    match allowed_origin with
    | None -> with_allowed_headers
    | Some origin ->
        Cohttp.Header.add_multi with_allowed_headers
          "Access-Control-Allow-Origin" [origin]

end
|
|
|
|
(** Maps keyed by Cohttp connections (the server keeps one per-connection
    shutdown function for streamed answers in such a map). *)
module ConnectionMap = Map.Make (Cohttp.Connection)
|
|
|
|
(** Logging primitives expected by [Make]. Each takes a [Format] format
    string; the plain variants return [unit], while the [lwt_] variants
    return a [unit Lwt.t] promise. *)
module type LOGGING = sig

  val debug : ('a, Format.formatter, unit, unit) format4 -> 'a
  val log_info : ('a, Format.formatter, unit, unit) format4 -> 'a
  val log_notice : ('a, Format.formatter, unit, unit) format4 -> 'a
  val warn : ('a, Format.formatter, unit, unit) format4 -> 'a
  val log_error : ('a, Format.formatter, unit, unit) format4 -> 'a

  val lwt_debug : ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
  val lwt_log_info : ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
  val lwt_log_notice : ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
  val lwt_warn : ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
  val lwt_log_error : ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a

end
|
|
|
|
module Make (Encoding : Resto.ENCODING)(Log : LOGGING) = struct
|
|
|
|
(* The logging functions of [Log] are used unqualified below. *)
open Log
(* Brings [Header], [Request], [Response], [Code], [Connection],
   [Transfer] and [Accept] into scope; opened after [Log] on purpose. *)
open Cohttp

(** Services and directories specialised to the [Encoding] parameter. *)
module Service = Resto.MakeService (Encoding)
module Directory = Resto_directory.Make (Encoding)
|
|
|
|
(** A body format the server can read and write. *)
type media_type = {
  name : string ;
  (** MIME name; must be of the form "type/subtype" for [Accept]
      negotiation to work. *)
  construct : 'a. 'a Encoding.t -> 'a -> string ;
  (** Serialises a value according to an encoding. *)
  destruct : 'a. 'a Encoding.t -> string -> ('a, string) result ;
  (** Parses a body according to an encoding, [Error msg] on failure. *)
}
|
|
|
|
module Media_type = struct

  (* Inspired from ocaml-webmachine *)

  (** [media_match (q, (range, params)) media] tells whether the [Accept]
      range [range] covers [media]. The quality value and the range
      parameters are ignored here.
      @raise Invalid_argument if [media.name] is not "type/subtype". *)
  let media_match (_, (range, _)) media =
    let type_, subtype =
      match Utils.split_path media.name with
      | [x ; y] -> x, y
      | _ ->
          Format.kasprintf invalid_arg "invalid media_type '%s'" media.name in
    let open Accept in
    match range with
    | AnyMedia -> true
    | AnyMediaSubtype type_' -> type_' = type_
    | MediaType (type_', subtype') -> type_' = type_ && subtype' = subtype

  (** [match_header provided header] picks, among [provided], the media
      type preferred by the [Accept] header value [header]. Ranges are tried
      by decreasing quality and, for each, the first matching element of
      [provided] wins. Returns [None] when no acceptable range matches. *)
  let match_header provided header =
    let ranges =
      Accept.(media_ranges header |> qsort)
      (* A quality of 0 means "not acceptable" (RFC 7231, 5.3.1): such
         ranges must never select a media type. *)
      |> List.filter (fun (q, _) -> q > 0) in
    let rec loop = function
      | [] -> None
      | r :: rs ->
          match List.find (media_match r) provided with
          | media -> Some media
          | exception Not_found -> loop rs in
    loop ranges

end
|
|
|
|
(** State of a running RPC server. *)
type server = {
  root : unit Directory.directory ;
  (** Services answered by the server. *)
  mutable streams : (unit -> unit) ConnectionMap.t ;
  (** Shutdown function of the streamed answer attached to a connection. *)
  cors : cors ;
  (** CORS policy applied to responses. *)
  media_types : media_type list ;
  (** Supported body formats. *)
  default_media_type : media_type ;
  (** Format used when the request does not name one. *)
  stopper : unit Lwt.u ;
  (** Waking it up stops the HTTP server. *)
  mutable worker : unit Lwt.t ;
  (** Promise of the running HTTP server. *)
}
|
|
|
|
(** [create_stream server con to_string s] exposes the streamed answer [s]
    as an [Lwt_stream.t] of chunks serialised with [to_string]. It also
    registers, under [con] in [server.streams], a function that stops the
    stream, calls [s.shutdown] and unregisters itself. *)
let create_stream server con to_string s =
  let running = ref true in
  let next_chunk () =
    if !running then
      Lwt.map
        (function
          | None -> None
          | Some item -> Some (to_string item))
        (s.Resto_directory.Answer.next ())
    else
      Lwt.return None in
  let stop () =
    running := false ;
    s.shutdown () ;
    server.streams <- ConnectionMap.remove con server.streams in
  server.streams <- ConnectionMap.add con stop server.streams ;
  (* [Lwt_stream.from] is lazy: [next_chunk] only runs when the body is
     consumed, so building it after registration is harmless. *)
  Lwt_stream.from next_chunk
|
|
|
|
(** Bind for [('a, 'e) result Lwt.t] promises: continues with [f] on
    [Ok x] and propagates [Error err] unchanged. *)
let (>>=?) m f =
  Lwt.bind m (function
      | Error err -> Lwt.return_error err
      | Ok x -> f x)
|
|
|
|
(** [callback server (io, con) req body] serves one HTTP request.

    For methods in [Resto.meth], the service is looked up in [server.root]:
    - the request body is decoded with the media type named by
      [Content-Type] (default [server.default_media_type]);
    - the answer is encoded with the media type negotiated from [Accept];
    - CORS headers from [server.cors] are attached to the response.

    [OPTIONS] answers CORS preflight requests. [HEAD] and every other
    method get [501 Not Implemented]. Lookup, decoding and negotiation
    failures become 4xx responses. *)
let callback server (_io, con) req body =
  (* FIXME: check inbound adress *)
  let uri = Request.uri req in
  lwt_log_info "(%s) receive request to %s"
    (Connection.to_string con) (Uri.path uri) >>= fun () ->
  let path = Utils.split_path (Uri.path uri) in
  let req_headers = Request.headers req in
  let origin_header = Header.get req_headers "origin" in
  begin
    match Request.meth req with
    | #Resto.meth as meth -> begin
        Directory.lookup server.root ()
          meth path >>=? fun (Directory.Service s) ->
        begin
          match Header.get req_headers "content-type" with
          | None -> Lwt.return_ok server.default_media_type
          | Some content_type ->
              (* Media type names are case-insensitive and clients commonly
                 append parameters such as "; charset=utf-8" (RFC 7231,
                 3.1.1.1): compare only the bare "type/subtype". *)
              let requested =
                let bare =
                  match String.index content_type ';' with
                  | exception Not_found -> content_type
                  | i -> String.sub content_type 0 i in
                String.lowercase_ascii (String.trim bare) in
              match
                List.find
                  (fun { name ; _ } -> String.lowercase_ascii name = requested)
                  server.media_types
              with
              | exception Not_found ->
                  Lwt.return_error (`Unsupported_media_type content_type)
              | media_type -> Lwt.return_ok media_type
        end >>=? fun input_media_type ->
        begin
          match Header.get req_headers "accept" with
          | None -> Lwt.return_ok server.default_media_type
          | Some accepted ->
              match Media_type.match_header
                      server.media_types (Some accepted) with
              | None -> Lwt.return_error `Not_acceptable
              | Some media_type -> Lwt.return_ok media_type
        end >>=? fun output_media_type ->
        begin
          match Resto.Query.parse s.types.query
                  (List.map
                     (fun (k, l) -> (k, String.concat "," l))
                     (Uri.query uri)) with
          | exception (Resto.Query.Invalid msg) ->
              Lwt.return_error (`Cannot_parse_query msg)
          | query -> Lwt.return_ok query
        end >>=? fun query ->
        let output = output_media_type.construct s.types.output
        and error = function
          | None -> Cohttp_lwt.Body.empty, Transfer.Fixed 0L
          | Some e ->
              let encoded = output_media_type.construct s.types.error e in
              Cohttp_lwt.Body.of_string encoded,
              Transfer.Fixed (Int64.of_int (String.length encoded)) in
        (* Browsers discard cross-origin answers lacking CORS headers, so
           they must be sent on actual responses, not only on preflights.
           With the default policy this adds nothing. *)
        let cors_headers =
          Cors.add_headers (Header.init ()) server.cors origin_header in
        let headers =
          Header.add cors_headers "content-type" output_media_type.name in
        (* Error answers: optional encoded error body with given status. *)
        let respond_error status e =
          let body, encoding = error e in
          Lwt.return_ok (Response.make ~status ~encoding ~headers (), body) in
        begin
          match s.types.input with
          | Service.No_input ->
              s.handler query () >>= Lwt.return_ok
          | Service.Input input ->
              Cohttp_lwt.Body.to_string body >>= fun body ->
              match input_media_type.destruct input body with
              | Error msg ->
                  Lwt.return_error (`Cannot_parse_body msg)
              | Ok body ->
                  s.handler query body >>= Lwt.return_ok
        end >>=? function
        | `Ok o ->
            let body = output o in
            let encoding =
              Transfer.Fixed (Int64.of_int (String.length body)) in
            Lwt.return_ok
              (Response.make ~status:`OK ~encoding ~headers (),
               Cohttp_lwt.Body.of_string body)
        | `OkStream o ->
            let body = create_stream server con output o in
            let encoding = Transfer.Chunked in
            Lwt.return_ok
              (Response.make ~status:`OK ~encoding ~headers (),
               Cohttp_lwt.Body.of_stream body)
        | `Created location ->
            let headers =
              match location with
              | None -> cors_headers
              | Some location -> Header.add cors_headers "location" location in
            Lwt.return_ok
              (Response.make ~status:`Created ~headers (),
               Cohttp_lwt.Body.empty)
        | `No_content ->
            Lwt.return_ok
              (Response.make ~status:`No_content ~headers:cors_headers (),
               Cohttp_lwt.Body.empty)
        | `Unauthorized e -> respond_error `Unauthorized e
        | `Forbidden e -> respond_error `Forbidden e
        | `Not_found e -> respond_error `Not_found e
        | `Conflict e -> respond_error `Conflict e
        | `Error e -> respond_error `Internal_server_error e
      end
    | `HEAD ->
        (* TODO ??? *)
        Lwt.return_error `Not_implemented
    | `OPTIONS ->
        begin
          (* Default OPTIONS handler for CORS preflight *)
          if origin_header = None then
            Directory.allowed_methods server.root () path
          else
            match Header.get req_headers
                    "Access-Control-Request-Method" with
            | None ->
                Directory.allowed_methods server.root () path
            | Some meth ->
                match Code.method_of_string meth with
                | #Resto.meth as meth ->
                    Directory.lookup server.root () meth path >>=? fun _handler ->
                    Lwt.return_ok [ meth ]
                | _ ->
                    Lwt.return_error `Not_found
        end >>=? fun cors_allowed_meths ->
        lwt_log_info "(%s) RPC preflight"
          (Connection.to_string con) >>= fun () ->
        let headers =
          Header.add_multi (Header.init ())
            "Access-Control-Allow-Methods"
            (List.map Resto.string_of_meth cors_allowed_meths) in
        let headers = Cors.add_headers headers server.cors origin_header in
        Lwt.return_ok
          (Response.make ~flush:true ~status:`OK ~headers (),
           Cohttp_lwt.Body.empty)
    | _ ->
        Lwt.return_error `Not_implemented
  end >>= function
  | Ok answer -> Lwt.return answer
  | Error `Not_implemented ->
      Lwt.return
        (Response.make ~status:`Not_implemented (),
         Cohttp_lwt.Body.empty)
  | Error `Method_not_allowed methods ->
      let headers =
        Header.add_multi (Header.init ()) "allow"
          (List.map Resto.string_of_meth methods) in
      Lwt.return
        (Response.make ~status:`Method_not_allowed ~headers (),
         Cohttp_lwt.Body.empty)
  | Error `Cannot_parse_path (context, arg, value) ->
      let headers = Header.add (Header.init ()) "content-type" "text/plain" in
      Lwt.return
        (Response.make ~status:`Bad_request ~headers (),
         Format.kasprintf Cohttp_lwt.Body.of_string
           "Failed to parse an argument in path. After \"%s\", \
            the value \"%s\" is not acceptable for type \"%s\""
           (String.concat "/" context) value arg.name)
  | Error `Cannot_parse_body s ->
      let headers = Header.add (Header.init ()) "content-type" "text/plain" in
      Lwt.return
        (Response.make ~status:`Bad_request ~headers (),
         Format.kasprintf Cohttp_lwt.Body.of_string
           "Failed to parse the request body: %s" s)
  | Error `Cannot_parse_query s ->
      let headers = Header.add (Header.init ()) "content-type" "text/plain" in
      Lwt.return
        (Response.make ~status:`Bad_request ~headers (),
         Format.kasprintf Cohttp_lwt.Body.of_string
           "Failed to parse the query string: %s" s)
  | Error `Not_acceptable ->
      (* The body lists the media types the server can produce. *)
      let accepted_encoding =
        String.concat ", " (List.map (fun f -> f.name) server.media_types) in
      let headers = Header.add (Header.init ()) "content-type" "text/plain" in
      Lwt.return
        (Response.make ~status:`Not_acceptable ~headers (),
         Cohttp_lwt.Body.of_string accepted_encoding)
  | Error `Unsupported_media_type _ ->
      Lwt.return
        (Response.make ~status:`Unsupported_media_type (),
         Cohttp_lwt.Body.empty)
  | Error `Not_found ->
      Lwt.return
        (Response.make ~status:`Not_found (),
         Cohttp_lwt.Body.empty)
|
|
|
|
(** Promise a running RPC server.

    [launch ?host ?cors ~media_types mode root] binds on [host] (default
    ["::"]) and serves the services of [root] using [callback]. The first
    element of [media_types] is the default body format.
    @raise Invalid_argument if [media_types] is empty. *)
let launch
    ?(host="::")
    ?(cors = Cors.default)
    ~media_types
    mode root =
  let default_media_type =
    match media_types with
    | [] -> invalid_arg "RestoCohttp.launch(empty media type list)"
    | first :: _ -> first in
  let stop, stopper = Lwt.wait () in
  let server = {
    root ;
    streams = ConnectionMap.empty ;
    cors ;
    media_types ;
    default_media_type ;
    stopper ;
    worker = Lwt.return_unit ;
  } in
  Conduit_lwt_unix.init ~src:host () >>= fun ctx ->
  let ctx = Cohttp_lwt_unix.Net.init ~ctx () in
  (* Stop the streamed answer still attached to a closing connection.
     The [try] also covers the stop function itself, as it always did. *)
  let conn_closed (_, con) =
    log_info "connection closed %s" (Connection.to_string con) ;
    try
      let stop_stream = ConnectionMap.find con server.streams in
      stop_stream ()
    with Not_found -> () in
  let on_exn = function
    | Unix.Unix_error (Unix.EADDRINUSE, "bind", _) ->
        log_error "RPC server port already taken, \
                   the node will be shutdown" ;
        exit 1
    | Unix.Unix_error ((Unix.ECONNRESET | Unix.EPIPE), _, _) -> ()
    | exn -> !Lwt.async_exception_hook exn in
  (* Any exception escaping [callback] becomes a 500 whose body is the
     printed exception. *)
  let handle conn req body =
    Lwt.catch
      (fun () -> callback server conn req body)
      (fun exn ->
         let headers =
           Header.add (Header.init ()) "content-type" "text/ocaml.exception" in
         let status = `Internal_server_error in
         let body = Cohttp_lwt.Body.of_string (Printexc.to_string exn) in
         Lwt.return (Response.make ~status ~headers (), body)) in
  server.worker <-
    Cohttp_lwt_unix.Server.create ~stop ~ctx ~mode ~on_exn
      (Cohttp_lwt_unix.Server.make ~callback:handle ~conn_closed ()) ;
  Lwt.return server
|
|
|
|
(** [shutdown server] wakes up the server's stopper, waits for the HTTP
    server promise to resolve, then calls every stream shutdown function
    still registered. *)
let shutdown server =
  Lwt.wakeup_later server.stopper () ;
  Lwt.bind server.worker (fun () ->
      ConnectionMap.iter (fun _con stop_stream -> stop_stream ()) server.streams ;
      Lwt.return_unit)
|
|
|
|
end
|