ligo/vendors/ocplib-resto/lib_resto-cohttp/restoCohttp.ml

444 lines
16 KiB
OCaml
Raw Normal View History

(**************************************************************************)
(* *)
(* Copyright (c) 2014 - 2017. *)
(* Dynamic Ledger Solutions, Inc. <contact@tezos.com> *)
(* *)
(* All rights reserved. No warranty, explicit or implicit, provided. *)
(* *)
(**************************************************************************)
open Lwt.Infix
module Utils = struct
let split_path path =
let l = String.length path in
let rec do_slashes acc i =
if i >= l then
List.rev acc
else if String.get path i = '/' then
do_slashes acc (i + 1)
else
do_component acc i i
and do_component acc i j =
if j >= l then
if i = j then
List.rev acc
else
List.rev (String.sub path i (j - i) :: acc)
else if String.get path j = '/' then
do_slashes (String.sub path i (j - i) :: acc) j
else
do_component acc i (j + 1) in
do_slashes [] 0
end
type cors = {
allowed_headers : string list ;
allowed_origins : string list ;
}
module Cors = struct
let default = { allowed_headers = [] ; allowed_origins = [] }
let check_origin_matches origin allowed_origin =
String.equal "*" allowed_origin ||
String.equal allowed_origin origin ||
begin
let allowed_w_slash = allowed_origin ^ "/" in
let len_a_w_s = String.length allowed_w_slash in
let len_o = String.length origin in
(len_o >= len_a_w_s) &&
String.equal allowed_w_slash @@ String.sub origin 0 len_a_w_s
end
let find_matching_origin allowed_origins origin =
let matching_origins =
List.filter (check_origin_matches origin) allowed_origins in
let compare_by_length_neg a b =
~- (compare (String.length a) (String.length b)) in
let matching_origins_sorted =
List.sort compare_by_length_neg matching_origins in
match matching_origins_sorted with
| [] -> None
| x :: _ -> Some x
let add_headers headers cors origin_header =
let cors_headers =
Cohttp.Header.add_multi headers
"Access-Control-Allow-Headers" cors.allowed_headers in
match origin_header with
| None -> cors_headers
| Some origin ->
match find_matching_origin cors.allowed_origins origin with
| None -> cors_headers
| Some allowed_origin ->
Cohttp.Header.add_multi cors_headers
"Access-Control-Allow-Origin" [allowed_origin]
end
module ConnectionMap = Map.Make(Cohttp.Connection)
module type LOGGING = sig
val debug: ('a, Format.formatter, unit, unit) format4 -> 'a
val log_info: ('a, Format.formatter, unit, unit) format4 -> 'a
val log_notice: ('a, Format.formatter, unit, unit) format4 -> 'a
val warn: ('a, Format.formatter, unit, unit) format4 -> 'a
val log_error: ('a, Format.formatter, unit, unit) format4 -> 'a
val lwt_debug: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
val lwt_log_info: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
val lwt_log_notice: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
val lwt_warn: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
val lwt_log_error: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
end
module Make (Encoding : Resto.ENCODING)(Log : LOGGING) = struct
open Log
open Cohttp
module Service = Resto.MakeService(Encoding)
module Directory = RestoDirectory.MakeDirectory(Encoding)
type media_type = {
name: string ;
construct: 'a. 'a Encoding.t -> 'a -> string ;
destruct: 'a. 'a Encoding.t -> string -> ('a, string) result ;
}
module Media_type = struct
(* Inspired from ocaml-webmachine *)
let media_match (_, (range, _)) media =
let type_, subtype =
match Utils.split_path media.name with
| [x ; y] -> x, y
| _ ->
Format.kasprintf invalid_arg "invalid media_type '%s'" media.name in
let open Accept in
match range with
| AnyMedia -> true
| AnyMediaSubtype type_' -> type_' = type_
| MediaType (type_', subtype') -> type_' = type_ && subtype' = subtype
let match_header provided header =
let ranges = Accept.(media_ranges header |> qsort) in
let rec loop = function
| [] -> None
| r :: rs ->
try Some(List.find (media_match r) provided)
with Not_found -> loop rs
in
loop ranges
end
type server = {
root : unit Directory.directory ;
mutable streams : (unit -> unit) ConnectionMap.t ;
cors : cors ;
media_types : media_type list ;
default_media_type : media_type ;
stopper : unit Lwt.u ;
mutable worker : unit Lwt.t ;
}
let create_stream server con to_string s =
let running = ref true in
let stream =
Lwt_stream.from
(fun () ->
if not !running then
Lwt.return None
else
s.RestoDirectory.Answer.next () >|= function
| None -> None
| Some x -> Some (to_string x)) in
let shutdown () =
running := false ;
s.shutdown () ;
server.streams <- ConnectionMap.remove con server.streams in
server.streams <- ConnectionMap.add con shutdown server.streams ;
stream
let (>>=?) m f =
m >>= function
| Ok x -> f x
| Error err -> Lwt.return_error err
let callback server (_io, con) req body =
(* FIXME: check inbound adress *)
let uri = Request.uri req in
lwt_log_info "(%s) receive request to %s"
(Connection.to_string con) (Uri.path uri) >>= fun () ->
let path = Utils.split_path (Uri.path uri) in
let req_headers = Request.headers req in
begin
match Request.meth req with
| #Resto.meth as meth -> begin
Directory.lookup server.root ()
meth path >>=? fun (Directory.Service s) ->
begin
match Header.get req_headers "content-type" with
| None -> Lwt.return_ok server.default_media_type
| Some content_type ->
match List.find (fun { name ; _ } -> name = content_type)
server.media_types with
| exception Not_found ->
Lwt.return_error (`Unsupported_media_type content_type)
| media_type -> Lwt.return_ok media_type
end >>=? fun input_media_type ->
begin
match Header.get req_headers "accept" with
| None -> Lwt.return_ok server.default_media_type
| Some accepted ->
match Media_type.match_header
server.media_types (Some accepted) with
| None -> Lwt.return_error `Not_acceptable
| Some media_type -> Lwt.return_ok media_type
end >>=? fun output_media_type ->
begin
match Resto.Query.parse s.types.query
(List.map
(fun (k, l) -> (k, String.concat "," l))
(Uri.query uri)) with
| exception (Resto.Query.Invalid s) ->
Lwt.return_error (`Cannot_parse_query s)
| query -> Lwt.return_ok query
end >>=? fun query ->
let output = output_media_type.construct s.types.output
and error = function
| None -> Cohttp_lwt.Body.empty, Transfer.Fixed 0L
| Some e ->
let s = output_media_type.construct s.types.error e in
Cohttp_lwt.Body.of_string s,
Transfer.Fixed (Int64.of_int (String.length s)) in
let headers = Header.init () in
let headers =
Header.add headers "content-type" output_media_type.name in
begin
match s.types.input with
| Service.No_input ->
s.handler query () >>= Lwt.return_ok
| Service.Input input ->
Cohttp_lwt.Body.to_string body >>= fun body ->
match
input_media_type.destruct input body
with
| Error s ->
Lwt.return_error (`Cannot_parse_body s)
| Ok body ->
s.handler query body >>= Lwt.return_ok
end >>=? function
| `Ok o ->
let body = output o in
let encoding =
Transfer.Fixed (Int64.of_int (String.length body)) in
Lwt.return_ok
(Response.make ~status:`OK ~encoding ~headers (),
Cohttp_lwt.Body.of_string body)
| `OkStream o ->
let body = create_stream server con output o in
let encoding = Transfer.Chunked in
Lwt.return_ok
(Response.make ~status:`OK ~encoding ~headers (),
Cohttp_lwt.Body.of_stream body)
| `Created s ->
let headers = Header.init () in
let headers =
match s with
| None -> headers
| Some s -> Header.add headers "location" s in
Lwt.return_ok
(Response.make ~status:`Created ~headers (),
Cohttp_lwt.Body.empty)
| `No_content ->
Lwt.return_ok
(Response.make ~status:`No_content (),
Cohttp_lwt.Body.empty)
| `Unauthorized e ->
let body, encoding = error e in
let status = `Unauthorized in
Lwt.return_ok
(Response.make ~status ~encoding ~headers (), body)
| `Forbidden e ->
let body, encoding = error e in
let status = `Forbidden in
Lwt.return_ok
(Response.make ~status ~encoding ~headers (), body)
| `Not_found e ->
let body, encoding = error e in
let status = `Not_found in
Lwt.return_ok
(Response.make ~status ~encoding ~headers (), body)
| `Conflict e ->
let body, encoding = error e in
let status = `Conflict in
Lwt.return_ok
(Response.make ~status ~encoding ~headers (), body)
| `Error e ->
let body, encoding = error e in
let status = `Internal_server_error in
Lwt.return_ok
(Response.make ~status ~encoding ~headers (), body)
end
| `HEAD ->
(* TODO ??? *)
Lwt.return_error `Not_implemented
| `OPTIONS ->
let req_headers = Request.headers req in
let origin_header = Header.get req_headers "origin" in
begin
(* Default OPTIONS handler for CORS preflight *)
if origin_header = None then
Directory.allowed_methods server.root () path
else
match Header.get req_headers
"Access-Control-Request-Method" with
| None ->
Directory.allowed_methods server.root () path
| Some meth ->
match Code.method_of_string meth with
| #Resto.meth as meth ->
Directory.lookup server.root () meth path >>=? fun _handler ->
Lwt.return_ok [ meth ]
| _ ->
Lwt.return_error `Not_found
end >>=? fun cors_allowed_meths ->
lwt_log_info "(%s) RPC preflight"
(Connection.to_string con) >>= fun () ->
let headers = Header.init () in
let headers =
Header.add_multi headers
"Access-Control-Allow-Methods"
(List.map Resto.string_of_meth cors_allowed_meths) in
let headers = Cors.add_headers headers server.cors origin_header in
Lwt.return_ok
(Response.make ~flush:true ~status:`OK ~headers (),
Cohttp_lwt.Body.empty)
| _ ->
Lwt.return_error `Not_implemented
end >>= function
| Ok answer -> Lwt.return answer
| Error `Not_implemented ->
Lwt.return
(Response.make ~status:`Not_implemented (),
Cohttp_lwt.Body.empty)
| Error `Method_not_allowed methods ->
let headers = Header.init () in
let headers =
Header.add_multi headers "allow"
(List.map Resto.string_of_meth methods) in
Lwt.return
(Response.make ~status:`Method_not_allowed ~headers (),
Cohttp_lwt.Body.empty)
| Error `Cannot_parse_path (context, arg, value) ->
let headers = Header.init () in
let headers =
Header.add headers "content-type" "text/plain" in
Lwt.return
(Response.make ~status:`Bad_request ~headers (),
Format.kasprintf Cohttp_lwt.Body.of_string
"Failed to parsed an argument in path. After \"%s\", \
the value \"%s\" is not acceptable for type \"%s\""
(String.concat "/" context) value arg.name)
| Error `Cannot_parse_body s ->
let headers = Header.init () in
let headers =
Header.add headers "content-type" "text/plain" in
Lwt.return
(Response.make ~status:`Bad_request ~headers (),
Format.kasprintf Cohttp_lwt.Body.of_string
"Failed to parse the request body: %s" s)
| Error `Cannot_parse_query s ->
let headers = Header.init () in
let headers =
Header.add headers "content-type" "text/plain" in
Lwt.return
(Response.make ~status:`Bad_request ~headers (),
Format.kasprintf Cohttp_lwt.Body.of_string
"Failed to parse the query string: %s" s)
| Error `Not_acceptable ->
let accepted_encoding =
String.concat ", "
(List.map (fun f -> f.name)
server.media_types) in
Lwt.return
(Response.make ~status:`Not_acceptable (),
Cohttp_lwt.Body.of_string accepted_encoding)
| Error `Unsupported_media_type _ ->
Lwt.return
(Response.make ~status:`Unsupported_media_type (),
Cohttp_lwt.Body.empty)
| Error `Not_found ->
Lwt.return
(Response.make ~status:`Not_found (),
Cohttp_lwt.Body.empty)
(* Promise a running RPC server. *)
let launch
?(host="::")
?(cors = Cors.default)
~media_types
mode root =
if media_types = [] then
invalid_arg "RestoCohttp.launch(empty media type list)" ;
let default_media_type = List.hd media_types in
let stop, stopper = Lwt.wait () in
let server = {
root ;
streams = ConnectionMap.empty ;
cors ;
media_types ;
default_media_type ;
stopper ;
worker = Lwt.return_unit ;
} in
Conduit_lwt_unix.init ~src:host () >>= fun ctx ->
let ctx = Cohttp_lwt_unix.Net.init ~ctx () in
server.worker <- begin
let conn_closed (_, con) =
log_info "connection closed %s" (Connection.to_string con) ;
try ConnectionMap.find con server.streams ()
with Not_found -> ()
and on_exn = function
| Unix.Unix_error (Unix.EADDRINUSE, "bind", _) ->
log_error "RPC server port already taken, \
the node will be shutdown" ;
exit 1
| Unix.Unix_error (ECONNRESET, _, _)
| Unix.Unix_error (EPIPE, _, _) -> ()
| exn -> !Lwt.async_exception_hook exn
and callback (io, con) req body =
Lwt.catch
begin fun () -> callback server (io, con) req body end
begin fun exn ->
let headers = Header.init () in
let headers =
Header.add headers "content-type" "text/ocaml.exception" in
let status = `Internal_server_error in
let body = Cohttp_lwt.Body.of_string (Printexc.to_string exn) in
Lwt.return (Response.make ~status ~headers (), body)
end
in
Cohttp_lwt_unix.Server.create ~stop ~ctx ~mode ~on_exn
(Cohttp_lwt_unix.Server.make ~callback ~conn_closed ())
end ;
Lwt.return server
let shutdown server =
Lwt.wakeup_later server.stopper () ;
server.worker >>= fun () ->
ConnectionMap.iter (fun _ f -> f ()) server.streams ;
Lwt.return_unit
end