(**************************************************************************)
(*                                                                        *)
(*    Copyright (c) 2014 - 2017.                                          *)
(*    Dynamic Ledger Solutions, Inc.                                      *)
(*                                                                        *)
(*    All rights reserved. No warranty, explicit or implicit, provided.   *)
(*                                                                        *)
(**************************************************************************)

open Lwt.Infix

module Utils = struct

  (** [split_path path] cuts [path] at every ['/'] and returns the
      non-empty components in order. Leading, trailing and repeated
      slashes produce no component, so both the empty string and a
      string made only of slashes yield the empty list. *)
  let split_path path =
    let l = String.length path in
    (* [do_slashes] skips separators; [do_component] scans one component
       spanning the half-open index range from [i] to [j]. *)
    let rec do_slashes acc i =
      if i >= l then
        List.rev acc
      else if String.get path i = '/' then
        do_slashes acc (i + 1)
      else
        do_component acc i i
    and do_component acc i j =
      if j >= l then
        if i = j then
          List.rev acc
        else
          List.rev (String.sub path i (j - i) :: acc)
      else if String.get path j = '/' then
        do_slashes (String.sub path i (j - i) :: acc) j
      else
        do_component acc i (j + 1) in
    do_slashes [] 0

end

(** CORS settings: the values advertised in the
    Access-Control-Allow-Headers header and the origins that may be
    echoed back in Access-Control-Allow-Origin. *)
type cors = {
  allowed_headers : string list ;
  allowed_origins : string list ;
}

module Cors = struct

  (** No allowed header and no allowed origin. *)
  let default = { allowed_headers = [] ; allowed_origins = [] }

  (** [check_origin_matches origin allowed_origin] holds when
      [allowed_origin] is the wildcard "*" , when it is equal to [origin],
      or when [origin] starts with [allowed_origin] followed by a
      slash. *)
  let check_origin_matches origin allowed_origin =
    String.equal "*" allowed_origin ||
    String.equal allowed_origin origin ||
    begin
      let allowed_w_slash = allowed_origin ^ "/" in
      let len_a_w_s = String.length allowed_w_slash in
      let len_o = String.length origin in
      (* The length test guards [String.sub] against a short [origin]. *)
      (len_o >= len_a_w_s) &&
      String.equal allowed_w_slash (String.sub origin 0 len_a_w_s)
    end

  (** [find_matching_origin allowed_origins origin] returns the longest
      element of [allowed_origins] accepted by {!check_origin_matches},
      or [None] when no element matches. *)
  let find_matching_origin allowed_origins origin =
    let matching_origins =
      List.filter (check_origin_matches origin) allowed_origins in
    (* Negated comparison: longest candidates come first. *)
    let compare_by_length_neg a b =
      ~- (compare (String.length a) (String.length b)) in
    let matching_origins_sorted =
      List.sort compare_by_length_neg matching_origins in
    match matching_origins_sorted with
    | [] -> None
    | x :: _ -> Some x

  (** [add_headers headers cors origin_header] adds
      Access-Control-Allow-Headers (from [cors.allowed_headers]) to
      [headers] and, when [origin_header] is [Some origin] and [origin]
      matches one of [cors.allowed_origins], also adds
      Access-Control-Allow-Origin set to the matching allowed origin. *)
  let add_headers headers cors origin_header =
    let cors_headers =
      Cohttp.Header.add_multi headers
        "Access-Control-Allow-Headers" cors.allowed_headers in
    match origin_header with
    | None -> cors_headers
    | Some origin ->
        match find_matching_origin cors.allowed_origins origin with
        | None -> cors_headers
        | Some allowed_origin ->
            Cohttp.Header.add_multi cors_headers
              "Access-Control-Allow-Origin" [allowed_origin]

end

(** Maps keyed by Cohttp connection identifiers. *)
module ConnectionMap = Map.Make(Cohttp.Connection)

(** Logging functions required by {!Make}: plain variants returning
    [unit] and [lwt_] variants returning [unit Lwt.t]. *)
module type LOGGING = sig

  val debug: ('a, Format.formatter, unit, unit) format4 -> 'a
  val log_info: ('a, Format.formatter, unit, unit) format4 -> 'a
  val log_notice: ('a, Format.formatter, unit, unit) format4 -> 'a
  val warn: ('a, Format.formatter, unit, unit) format4 -> 'a
  val log_error: ('a, Format.formatter, unit, unit) format4 -> 'a

  val lwt_debug: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
  val lwt_log_info: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
  val lwt_log_notice: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
  val lwt_warn: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
  val lwt_log_error: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a

end

module Make (Encoding : Resto.ENCODING)(Log : LOGGING) = struct

  open Log
  open Cohttp

  module Service = Resto.MakeService(Encoding)
  module Directory = RestoDirectory.MakeDirectory(Encoding)

  (** A wire format: [name] is the media type in type/subtype form,
      [construct] serialises a value and [destruct] parses one, both
      driven by an [Encoding.t]. *)
  type media_type = {
    name: string ;
    construct: 'a. 'a Encoding.t -> 'a -> string ;
    destruct: 'a. 'a Encoding.t -> string -> ('a, string) result ;
  }

  module Media_type = struct

    (* Inspired from ocaml-webmachine *)

    (** [media_match r media] tells whether the media range carried by
        the Accept entry [r] covers [media].
        @raise Invalid_argument if [media.name] does not split into
        exactly two slash-separated components. *)
    let media_match (_, (range, _)) media =
      let type_, subtype =
        match Utils.split_path media.name with
        | [x ; y] -> x, y
        | _ ->
            Format.kasprintf invalid_arg
              "invalid media_type '%s'" media.name in
      let open Accept in
      match range with
      | AnyMedia -> true
      | AnyMediaSubtype type_' -> type_' = type_
      | MediaType (type_', subtype') -> type_' = type_ && subtype' = subtype

    (** [match_header provided header] walks the media ranges of the
        Accept header value [header] in the order produced by
        [Accept.qsort] and returns the first element of [provided]
        matched by one of them, or [None]. *)
    let match_header provided header =
      let ranges = Accept.(media_ranges header |> qsort) in
      let rec loop = function
        | [] -> None
        | r :: rs ->
            try Some(List.find (media_match r) provided)
            with Not_found -> loop rs in
      loop ranges

    (** [bare_media_type content_type] drops everything from the first
        semicolon on (the media type parameters such as a charset),
        trims surrounding blanks and lowercases the result. *)
    let bare_media_type content_type =
      let media =
        match String.index content_type ';' with
        | exception Not_found -> content_type
        | i -> String.sub content_type 0 i in
      String.lowercase_ascii (String.trim media)

    (** [find_by_content_type provided content_type] selects the element
        of [provided] designated by a Content-Type header value.

        An exact match on [name] is tried first, so any header value
        that was accepted before parameters were supported still selects
        the same media type. Otherwise the header is reduced with
        {!bare_media_type} and compared case-insensitively, so that a
        value carrying parameters or different letter case is accepted
        as well. Returns [None] when nothing matches. *)
    let find_by_content_type provided content_type =
      let find_where pred =
        match List.find pred provided with
        | exception Not_found -> None
        | media -> Some media in
      match find_where (fun media -> media.name = content_type) with
      | Some _ as exact -> exact
      | None ->
          let bare = bare_media_type content_type in
          find_where
            (fun media -> String.lowercase_ascii media.name = bare)

  end

  (** State of a running server. [streams] maps each connection that
      currently serves a streamed answer to the function shutting that
      stream down; waking [stopper] stops the listening loop held in
      [worker]. *)
  type server = {
    root : unit Directory.directory ;
    mutable streams : (unit -> unit) ConnectionMap.t ;
    cors : cors ;
    media_types : media_type list ;
    default_media_type : media_type ;
    stopper : unit Lwt.u ;
    mutable worker : unit Lwt.t ;
  }

  (** [create_stream server con to_string s] turns the streamed answer
      [s] into an [Lwt_stream] of strings, each element being converted
      with [to_string]. A shutdown function is registered in
      [server.streams] under [con]; once called, the stream yields
      [None], [s.shutdown] is invoked and the registration is removed. *)
  let create_stream server con to_string s =
    let running = ref true in
    let stream =
      Lwt_stream.from
        (fun () ->
           if not !running then
             Lwt.return None
           else
             s.RestoDirectory.Answer.next () >|= function
             | None -> None
             | Some x -> Some (to_string x)) in
    let shutdown () =
      running := false ;
      s.shutdown () ;
      server.streams <- ConnectionMap.remove con server.streams in
    server.streams <- ConnectionMap.add con shutdown server.streams ;
    stream

  (** Bind for promises of results: continues with [f] on [Ok] and
      propagates [Error] unchanged. *)
  let (>>=?) m f =
    m >>= function
    | Ok x -> f x
    | Error err -> Lwt.return_error err

  (** [callback server (io, con) req body] handles one HTTP request.

      For the methods covered by [Resto.meth], the service is looked up
      in [server.root], the input media type is chosen from the
      content-type header and the output media type from the accept
      header (both default to [server.default_media_type] when the
      header is absent), the query string is parsed, the handler is run
      and its answer is mapped to an HTTP response. OPTIONS requests are
      answered as CORS preflights. HEAD and any other method yield
      501 Not Implemented. Every error collected along the way is turned
      into the corresponding HTTP status at the end. *)
  let callback server (_io, con) req body =
    (* FIXME: check inbound address *)
    (* NOTE(review): CORS headers are only added to OPTIONS answers
       below; other responses carry no Access-Control-Allow-Origin
       header. Confirm whether cross-origin callers are expected to read
       those responses. *)
    let uri = Request.uri req in
    lwt_log_info "(%s) receive request to %s"
      (Connection.to_string con) (Uri.path uri) >>= fun () ->
    let path = Utils.split_path (Uri.path uri) in
    let req_headers = Request.headers req in
    begin
      match Request.meth req with
      | #Resto.meth as meth -> begin
          Directory.lookup server.root ()
            meth path >>=? fun (Directory.Service s) ->
          (* Input format: taken from content-type, tolerating media
             type parameters and letter case. *)
          begin
            match Header.get req_headers "content-type" with
            | None -> Lwt.return_ok server.default_media_type
            | Some content_type ->
                match
                  Media_type.find_by_content_type
                    server.media_types content_type
                with
                | None ->
                    Lwt.return_error (`Unsupported_media_type content_type)
                | Some media_type -> Lwt.return_ok media_type
          end >>=? fun input_media_type ->
          (* Output format: negotiated from the accept header. *)
          begin
            match Header.get req_headers "accept" with
            | None -> Lwt.return_ok server.default_media_type
            | Some accepted ->
                match Media_type.match_header
                        server.media_types (Some accepted) with
                | None -> Lwt.return_error `Not_acceptable
                | Some media_type -> Lwt.return_ok media_type
          end >>=? fun output_media_type ->
          (* Repeated values of a query parameter are joined with
             commas before parsing. *)
          begin
            match Resto.Query.parse s.types.query
                    (List.map
                       (fun (k, l) -> (k, String.concat "," l))
                       (Uri.query uri)) with
            | exception (Resto.Query.Invalid s) ->
                Lwt.return_error (`Cannot_parse_query s)
            | query -> Lwt.return_ok query
          end >>=? fun query ->
          let output = output_media_type.construct s.types.output
          and error = function
            | None ->
                Cohttp_lwt_body.empty, Transfer.Fixed 0L
            | Some e ->
                let s = output_media_type.construct s.types.error e in
                Cohttp_lwt_body.of_string s,
                Transfer.Fixed (Int64.of_int (String.length s)) in
          let headers = Header.init () in
          let headers =
            Header.add headers "content-type" output_media_type.name in
          begin
            match s.types.input with
            | Service.No_input ->
                s.handler query () >>= Lwt.return_ok
            | Service.Input input ->
                Cohttp_lwt_body.to_string body >>= fun body ->
                match input_media_type.destruct input body with
                | Error s ->
                    Lwt.return_error (`Cannot_parse_body s)
                | Ok body ->
                    s.handler query body >>= Lwt.return_ok
          end >>=? function
          | `Ok o ->
              let body = output o in
              let encoding =
                Transfer.Fixed (Int64.of_int (String.length body)) in
              Lwt.return_ok
                (Response.make ~status:`OK ~encoding ~headers (),
                 Cohttp_lwt_body.of_string body)
          | `OkStream o ->
              let body = create_stream server con output o in
              let encoding = Transfer.Chunked in
              Lwt.return_ok
                (Response.make ~status:`OK ~encoding ~headers (),
                 Cohttp_lwt_body.of_stream body)
          | `Created s ->
              (* Fresh headers: no content-type since the body is
                 empty. *)
              let headers = Header.init () in
              let headers =
                match s with
                | None -> headers
                | Some s -> Header.add headers "location" s in
              Lwt.return_ok
                (Response.make ~status:`Created ~headers (),
                 Cohttp_lwt_body.empty)
          | `No_content ->
              Lwt.return_ok
                (Response.make ~status:`No_content (),
                 Cohttp_lwt_body.empty)
          | `Unauthorized e ->
              let body, encoding = error e in
              let status = `Unauthorized in
              Lwt.return_ok
                (Response.make ~status ~encoding ~headers (), body)
          | `Forbidden e ->
              let body, encoding = error e in
              let status = `Forbidden in
              Lwt.return_ok
                (Response.make ~status ~encoding ~headers (), body)
          | `Not_found e ->
              let body, encoding = error e in
              let status = `Not_found in
              Lwt.return_ok
                (Response.make ~status ~encoding ~headers (), body)
          | `Conflict e ->
              let body, encoding = error e in
              let status = `Conflict in
              Lwt.return_ok
                (Response.make ~status ~encoding ~headers (), body)
          | `Error e ->
              let body, encoding = error e in
              let status = `Internal_server_error in
              Lwt.return_ok
                (Response.make ~status ~encoding ~headers (), body)
        end
      | `HEAD ->
          (* TODO ??? *)
          Lwt.return_error `Not_implemented
      | `OPTIONS ->
          let req_headers = Request.headers req in
          let origin_header = Header.get req_headers "origin" in
          begin
            (* Default OPTIONS handler for CORS preflight *)
            if origin_header = None then
              Directory.allowed_methods server.root () path
            else
              match Header.get req_headers
                      "Access-Control-Request-Method" with
              | None ->
                  Directory.allowed_methods server.root () path
              | Some meth ->
                  match Code.method_of_string meth with
                  | #Resto.meth as meth ->
                      (* Only checks that the service exists. *)
                      Directory.lookup server.root ()
                        meth path >>=? fun _handler ->
                      Lwt.return_ok [ meth ]
                  | _ ->
                      Lwt.return_error `Not_found
          end >>=? fun cors_allowed_meths ->
          lwt_log_info "(%s) RPC preflight"
            (Connection.to_string con) >>= fun () ->
          let headers = Header.init () in
          let headers =
            Header.add_multi headers
              "Access-Control-Allow-Methods"
              (List.map Resto.string_of_meth cors_allowed_meths) in
          let headers = Cors.add_headers headers server.cors origin_header in
          Lwt.return_ok
            (Response.make ~flush:true ~status:`OK ~headers (),
             Cohttp_lwt_body.empty)
      | _ ->
          Lwt.return_error `Not_implemented
    end >>= function
    | Ok answer -> Lwt.return answer
    | Error `Not_implemented ->
        Lwt.return
          (Response.make ~status:`Not_implemented (),
           Cohttp_lwt_body.empty)
    | Error `Method_not_allowed methods ->
        let headers = Header.init () in
        let headers =
          Header.add_multi headers "allow"
            (List.map Resto.string_of_meth methods) in
        Lwt.return
          (Response.make ~status:`Method_not_allowed ~headers (),
           Cohttp_lwt_body.empty)
    | Error `Cannot_parse_path (context, arg, value) ->
        let headers = Header.init () in
        let headers =
          Header.add headers "content-type" "text/plain" in
        Lwt.return
          (Response.make ~status:`Bad_request ~headers (),
           Format.kasprintf Cohttp_lwt_body.of_string
             "Failed to parsed an argument in path. After \"%s\", \
              the value \"%s\" is not acceptable for type \"%s\""
             (String.concat "/" context) value arg.name)
    | Error `Cannot_parse_body s ->
        let headers = Header.init () in
        let headers =
          Header.add headers "content-type" "text/plain" in
        Lwt.return
          (Response.make ~status:`Bad_request ~headers (),
           Format.kasprintf Cohttp_lwt_body.of_string
             "Failed to parse the request body: %s" s)
    | Error `Cannot_parse_query s ->
        let headers = Header.init () in
        let headers =
          Header.add headers "content-type" "text/plain" in
        Lwt.return
          (Response.make ~status:`Bad_request ~headers (),
           Format.kasprintf Cohttp_lwt_body.of_string
             "Failed to parse the query string: %s" s)
    | Error `Not_acceptable ->
        (* The body lists the media types the server can produce. *)
        let accepted_encoding =
          String.concat ", "
            (List.map (fun f -> f.name) server.media_types) in
        Lwt.return
          (Response.make ~status:`Not_acceptable (),
           Cohttp_lwt_body.of_string accepted_encoding)
    | Error `Unsupported_media_type _ ->
        Lwt.return
          (Response.make ~status:`Unsupported_media_type (),
           Cohttp_lwt_body.empty)
    | Error `Not_found ->
        Lwt.return
          (Response.make ~status:`Not_found (),
           Cohttp_lwt_body.empty)

  (* Promise a running RPC server. *)

  (** [launch ?host ?cors ~media_types mode root] starts a server for
      the directory [root] and returns its handle.

      [host] (default "::") is passed as [~src] to
      [Conduit_lwt_unix.init]; [cors] defaults to {!Cors.default};
      [mode] is forwarded unchanged to [Server.create]. The first
      element of [media_types] is used whenever a request has no
      content-type or no accept header.

      Exceptions escaping the request handler are reported to the client
      as a 500 response whose body is the printed exception.

      @raise Invalid_argument if [media_types] is empty. *)
  let launch ?(host="::") ?(cors = Cors.default) ~media_types mode root =
    (* A single match both rejects the empty list and extracts the
       default, avoiding polymorphic equality on records that hold
       closures and the partial [List.hd]. *)
    let default_media_type =
      match media_types with
      | [] -> invalid_arg "RestoCohttp.launch(empty media type list)"
      | first :: _ -> first in
    let stop, stopper = Lwt.wait () in
    let server = {
      root ;
      streams = ConnectionMap.empty ;
      cors ;
      media_types ;
      default_media_type ;
      stopper ;
      worker = Lwt.return_unit ;
    } in
    let open Cohttp_lwt_unix in
    Conduit_lwt_unix.init ~src:host () >>= fun ctx ->
    let ctx = Cohttp_lwt_unix_net.init ~ctx () in
    server.worker <- begin
      let conn_closed (_, con) =
        log_info "connection closed %s" (Connection.to_string con) ;
        (* Shut down the stream attached to this connection, if any.
           NOTE(review): a Not_found raised by the shutdown function
           itself is silenced here as well. *)
        try ConnectionMap.find con server.streams ()
        with Not_found -> ()
      and on_exn = function
        | Unix.Unix_error (Unix.EADDRINUSE, "bind", _) ->
            log_error "RPC server port already taken, \
                       the node will be shutdown" ;
            exit 1
        | Unix.Unix_error (ECONNRESET, _, _)
        | Unix.Unix_error (EPIPE, _, _) -> ()
        | exn -> !Lwt.async_exception_hook exn
      (* Not recursive: the [callback] called below is the request
         handler defined above in this module. *)
      and callback (io, con) req body =
        Lwt.catch
          begin fun () -> callback server (io, con) req body end
          begin fun exn ->
            let headers = Header.init () in
            let headers =
              Header.add headers "content-type" "text/ocaml.exception" in
            let status = `Internal_server_error in
            let body = Cohttp_lwt_body.of_string (Printexc.to_string exn) in
            Lwt.return (Response.make ~status ~headers (), body)
          end
      in
      Server.create ~stop ~ctx ~mode ~on_exn
        (Server.make ~callback ~conn_closed ())
    end ;
    Lwt.return server

  (** [shutdown server] wakes the stopper, waits for the listening loop
      to terminate, then calls the shutdown function of every stream
      still registered. *)
  let shutdown server =
    Lwt.wakeup_later server.stopper () ;
    server.worker >>= fun () ->
    (* Each [f] removes its own entry from [server.streams]; iterating
       over the persistent map value read here is unaffected. *)
    ConnectionMap.iter (fun _ f -> f ()) server.streams ;
    Lwt.return_unit

end