ligo/vendors/ocplib-resto/lib_resto-cohttp/server.ml
Vincent Bernardoff 76bf25646c Resto: fix CORS
CORS also requires sending the "Access-Control-Allow-Origin" header on
direct requests and on requests that follow a preflight request.
2018-05-12 23:58:20 +02:00

358 lines
14 KiB
OCaml

(**************************************************************************)
(* *)
(* Copyright (c) 2014 - 2018. *)
(* Dynamic Ledger Solutions, Inc. <contact@tezos.com> *)
(* *)
(* All rights reserved. No warranty, explicit or implicit, provided. *)
(* *)
(**************************************************************************)
open Lwt.Infix
(* Maps keyed by Cohttp connection identifiers. In this file it is used
   to associate each connection with the shutdown closure of the stream
   it is currently serving. *)
module ConnectionMap = Map.Make(Cohttp.Connection)
(* Logging interface expected by the [Make] functor below.
   Every function takes a [Format]-style format string followed by the
   arguments that the format string requires. *)
module type LOGGING = sig
(* Loggers whose final result is [unit].
   NOTE(review): the names suggest increasing severity levels
   (debug, info, notice, warning, error); the actual meaning depends on
   the implementation supplied to the functor. *)
val debug: ('a, Format.formatter, unit, unit) format4 -> 'a
val log_info: ('a, Format.formatter, unit, unit) format4 -> 'a
val log_notice: ('a, Format.formatter, unit, unit) format4 -> 'a
val warn: ('a, Format.formatter, unit, unit) format4 -> 'a
val log_error: ('a, Format.formatter, unit, unit) format4 -> 'a
(* Counterparts of the loggers above whose final result is a
   [unit Lwt.t] promise, so that callers can sequence on it. *)
val lwt_debug: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
val lwt_log_info: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
val lwt_log_notice: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
val lwt_warn: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
val lwt_log_error: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
end
module Make (Encoding : Resto.ENCODING)(Log : LOGGING) = struct
open Log
open Cohttp
module Service = Resto.MakeService(Encoding)
module Directory = Resto_directory.Make(Encoding)
module Media_type = Media_type.Make(Encoding)
(* State of one RPC server instance, built by [launch]. *)
type server = {
(* Directory in which [callback] looks up the requested services. *)
root : unit Directory.directory ;
(* Shutdown closure of the stream currently served on each connection.
   Entries are added by [create_stream]; each closure removes its own
   entry when it runs. *)
mutable streams : (unit -> unit) ConnectionMap.t ;
(* CORS configuration handed to [Cors.add_allow_origin] and
   [Cors.add_headers] when building responses. *)
cors : Cors.t ;
(* Media types considered when resolving the "content-type" and
   "accept" request headers. *)
media_types : Media_type.t list ;
(* Fallback used when a request carries no "content-type" or "accept"
   header: the "type/subtype" name paired with the media type itself. *)
default_media_type : string * Media_type.t ;
(* Resolver woken up by [shutdown]; the matching promise is given to
   the Cohttp server as its [~stop] argument in [launch]. *)
stopper : unit Lwt.u ;
(* Promise returned by [Cohttp_lwt_unix.Server.create]; [shutdown]
   waits on it. Initialised to [Lwt.return_unit] before the server is
   actually started. *)
mutable worker : unit Lwt.t ;
}
(* [create_stream server con to_string s] turns the answer stream [s]
   into an [Lwt_stream.t] of strings, converting every element with
   [to_string].

   A shutdown closure is registered in [server.streams] under [con].
   When invoked, it makes the resulting stream yield [None] on its next
   pull, calls the shutdown function of [s], and unregisters itself. *)
let create_stream server con to_string s =
  let running = ref true in
  (* Producer for [Lwt_stream.from]: stops as soon as the stream has
     been shut down, otherwise forwards the next converted element. *)
  let pull () =
    if !running then
      s.Resto_directory.Answer.next () >|= fun item ->
      match item with
      | Some x -> Some (to_string x)
      | None -> None
    else
      Lwt.return None in
  let stream = Lwt_stream.from pull in
  let unregister () =
    server.streams <- ConnectionMap.remove con server.streams in
  let shutdown () =
    running := false ;
    s.Resto_directory.Answer.shutdown () ;
    unregister () in
  server.streams <- ConnectionMap.add con shutdown server.streams ;
  stream
(* Bind for promises of [result]: [m >>=? f] waits for [m], passes the
   payload of [Ok] to [f], and propagates an [Error] unchanged without
   calling [f]. *)
let (>>=?) m f =
  Lwt.bind m
    (function
      | Error err -> Lwt.return_error err
      | Ok x -> f x)
(* [callback server (_io, con) req body] builds the HTTP response to one
   request received on connection [con].

   - For the methods covered by [Resto.meth], the service is looked up in
     [server.root]; the input media type is resolved from the
     "content-type" header and the output media type from the "accept"
     header (falling back to [server.default_media_type] when a header is
     absent); the query string is parsed; the request body is decoded if
     the service expects an input; and the handler's answer is turned
     into a response.
   - [`HEAD] is answered with "not implemented".
   - [`OPTIONS] is handled as a CORS preflight request.
   - Any other method is answered with "not implemented".

   Errors collected along the way are mapped to HTTP statuses by the
   final [function].

   Every response built from a handler answer carries the header added by
   [Cors.add_allow_origin], including the body-less [`Created] and
   [`No_content] answers: without it a browser would refuse to expose a
   successful cross-origin response to the calling script. *)
let callback server (_io, con) req body =
  (* FIXME: check inbound address *)
  let uri = Request.uri req in
  lwt_log_info "(%s) receive request to %s"
    (Connection.to_string con) (Uri.path uri) >>= fun () ->
  let path = Utils.split_path (Uri.path uri) in
  let req_headers = Request.headers req in
  let origin_header = Header.get req_headers "origin" in
  begin
    match Request.meth req with
    | #Resto.meth as meth -> begin
        Directory.lookup server.root ()
          meth path >>=? fun (Directory.Service s) ->
        (* Media type used to decode the request body. *)
        begin
          match Header.get req_headers "content-type" with
          | None -> Lwt.return_ok (snd server.default_media_type)
          | Some content_type ->
              match Utils.split_path content_type with
              | [x ; y] -> begin
                  match Media_type.find_media (x, y) server.media_types with
                  | None ->
                      Lwt.return_error (`Unsupported_media_type content_type)
                  | Some media_type ->
                      Lwt.return_ok media_type
                end
              | _ ->
                  Lwt.return_error (`Unsupported_media_type content_type)
        end >>=? fun input_media_type ->
        lwt_debug "(%s) input media type %s"
          (Connection.to_string con)
          (Media_type.name input_media_type) >>= fun () ->
        (* Media type used to encode the response body. *)
        begin
          match Header.get req_headers "accept" with
          | None -> Lwt.return_ok server.default_media_type
          | Some accepted ->
              match Media_type.resolve_accept_header
                      server.media_types (Some accepted) with
              | None -> Lwt.return_error `Not_acceptable
              | Some media_type -> Lwt.return_ok media_type
        end >>=? fun (output_content_type, output_media_type) ->
        (* A query parameter given several times is passed to the parser
           as a single comma-separated value. *)
        begin
          match Resto.Query.parse s.types.query
                  (List.map
                     (fun (k, l) -> (k, String.concat "," l))
                     (Uri.query uri)) with
          | exception (Resto.Query.Invalid msg) ->
              Lwt.return_error (`Cannot_parse_query msg)
          | query -> Lwt.return_ok query
        end >>=? fun query ->
        lwt_debug "(%s) ouput media type %s"
          (Connection.to_string con)
          (Media_type.name output_media_type) >>= fun () ->
        let output = output_media_type.construct s.types.output
        and error = function
          | None -> Cohttp_lwt.Body.empty, Transfer.Fixed 0L
          | Some e ->
              let payload = output_media_type.construct s.types.error e in
              Cohttp_lwt.Body.of_string payload,
              Transfer.Fixed (Int64.of_int (String.length payload)) in
        (* Headers for answers that carry a body. *)
        let headers = Header.init () in
        let headers =
          Header.add headers "content-type" output_content_type in
        let headers =
          Cors.add_allow_origin headers server.cors origin_header in
        (* Headers for body-less answers: no content-type, but the CORS
           allow-origin header is still required. *)
        let bodyless_headers =
          Cors.add_allow_origin (Header.init ()) server.cors origin_header in
        begin
          match s.types.input with
          | Service.No_input ->
              s.handler query () >>= Lwt.return_ok
          | Service.Input input ->
              Cohttp_lwt.Body.to_string body >>= fun body ->
              match
                input_media_type.destruct input body
              with
              | Error msg ->
                  Lwt.return_error (`Cannot_parse_body msg)
              | Ok body ->
                  s.handler query body >>= Lwt.return_ok
        end >>=? function
        | `Ok o ->
            let body = output o in
            let encoding =
              Transfer.Fixed (Int64.of_int (String.length body)) in
            Lwt.return_ok
              (Response.make ~status:`OK ~encoding ~headers (),
               Cohttp_lwt.Body.of_string body)
        | `OkStream o ->
            let body = create_stream server con output o in
            let encoding = Transfer.Chunked in
            Lwt.return_ok
              (Response.make ~status:`OK ~encoding ~headers (),
               Cohttp_lwt.Body.of_stream body)
        | `Created location ->
            let headers =
              match location with
              | None -> bodyless_headers
              | Some l -> Header.add bodyless_headers "location" l in
            Lwt.return_ok
              (Response.make ~status:`Created ~headers (),
               Cohttp_lwt.Body.empty)
        | `No_content ->
            Lwt.return_ok
              (Response.make ~status:`No_content
                 ~headers:bodyless_headers (),
               Cohttp_lwt.Body.empty)
        | `Unauthorized e ->
            let body, encoding = error e in
            let status = `Unauthorized in
            Lwt.return_ok
              (Response.make ~status ~encoding ~headers (), body)
        | `Forbidden e ->
            let body, encoding = error e in
            let status = `Forbidden in
            Lwt.return_ok
              (Response.make ~status ~encoding ~headers (), body)
        | `Not_found e ->
            let body, encoding = error e in
            let status = `Not_found in
            Lwt.return_ok
              (Response.make ~status ~encoding ~headers (), body)
        | `Conflict e ->
            let body, encoding = error e in
            let status = `Conflict in
            Lwt.return_ok
              (Response.make ~status ~encoding ~headers (), body)
        | `Error e ->
            let body, encoding = error e in
            let status = `Internal_server_error in
            Lwt.return_ok
              (Response.make ~status ~encoding ~headers (), body)
      end
    | `HEAD ->
        (* TODO ??? *)
        Lwt.return_error `Not_implemented
    | `OPTIONS ->
        begin
          (* Default OPTIONS handler for CORS preflight: without an
             "origin" or an "Access-Control-Request-Method" header, list
             every method allowed on the path; otherwise only check that
             the requested method is served there. *)
          match origin_header with
          | None ->
              Directory.allowed_methods server.root () path
          | Some _ ->
              match Header.get req_headers
                      "Access-Control-Request-Method" with
              | None ->
                  Directory.allowed_methods server.root () path
              | Some meth ->
                  match Code.method_of_string meth with
                  | #Resto.meth as meth ->
                      Directory.lookup server.root () meth path >>=? fun _handler ->
                      Lwt.return_ok [ meth ]
                  | _ ->
                      Lwt.return_error `Not_found
        end >>=? fun cors_allowed_meths ->
        lwt_log_info "(%s) RPC preflight"
          (Connection.to_string con) >>= fun () ->
        let headers = Header.init () in
        let headers =
          Header.add_multi headers
            "Access-Control-Allow-Methods"
            (List.map Resto.string_of_meth cors_allowed_meths) in
        let headers = Cors.add_headers headers server.cors origin_header in
        Lwt.return_ok
          (Response.make ~flush:true ~status:`OK ~headers (),
           Cohttp_lwt.Body.empty)
    | _ ->
        Lwt.return_error `Not_implemented
  end >>= function
  | Ok answer -> Lwt.return answer
  | Error `Not_implemented ->
      Lwt.return
        (Response.make ~status:`Not_implemented (),
         Cohttp_lwt.Body.empty)
  | Error `Method_not_allowed methods ->
      let headers = Header.init () in
      let headers =
        Header.add_multi headers "allow"
          (List.map Resto.string_of_meth methods) in
      Lwt.return
        (Response.make ~status:`Method_not_allowed ~headers (),
         Cohttp_lwt.Body.empty)
  | Error `Cannot_parse_path (context, arg, value) ->
      let headers = Header.init () in
      let headers =
        Header.add headers "content-type" "text/plain" in
      Lwt.return
        (Response.make ~status:`Bad_request ~headers (),
         Format.kasprintf Cohttp_lwt.Body.of_string
           "Failed to parsed an argument in path. After \"%s\", \
            the value \"%s\" is not acceptable for type \"%s\""
           (String.concat "/" context) value arg.name)
  | Error `Cannot_parse_body s ->
      let headers = Header.init () in
      let headers =
        Header.add headers "content-type" "text/plain" in
      Lwt.return
        (Response.make ~status:`Bad_request ~headers (),
         Format.kasprintf Cohttp_lwt.Body.of_string
           "Failed to parse the request body: %s" s)
  | Error `Cannot_parse_query s ->
      let headers = Header.init () in
      let headers =
        Header.add headers "content-type" "text/plain" in
      Lwt.return
        (Response.make ~status:`Bad_request ~headers (),
         Format.kasprintf Cohttp_lwt.Body.of_string
           "Failed to parse the query string: %s" s)
  | Error `Not_acceptable ->
      let accepted_encoding =
        Media_type.acceptable_encoding server.media_types in
      Lwt.return
        (Response.make ~status:`Not_acceptable (),
         Cohttp_lwt.Body.of_string accepted_encoding)
  | Error `Unsupported_media_type _ ->
      Lwt.return
        (Response.make ~status:`Unsupported_media_type (),
         Cohttp_lwt.Body.empty)
  | Error `Not_found ->
      Lwt.return
        (Response.make ~status:`Not_found (),
         Cohttp_lwt.Body.empty)
(* Promise a running RPC server.

   [launch ?host ?cors ~media_types mode root] starts a Cohttp server
   that answers requests with the services registered in [root].

   - [host] (default ["::"]) is passed as [~src] to
     [Conduit_lwt_unix.init].
   - [cors] (default [Cors.default]) is the CORS configuration stored in
     the server.
   - [media_types] is the list of supported media types; its first
     complete one becomes the default media type.
   - [mode] is forwarded to [Cohttp_lwt_unix.Server.create].

   The returned promise resolves to the server handle once the conduit
   context is initialised; the Cohttp server itself runs in
   [server.worker] until [shutdown] is called.

   Raises [Invalid_argument] when [Media_type.first_complete_media]
   finds no media type in [media_types]. *)
let launch
    ?(host="::")
    ?(cors = Cors.default)
    ~media_types
    mode root =
  let default_media_type =
    match Media_type.first_complete_media media_types with
    | None -> invalid_arg "RestoCohttp.launch(empty media type list)"
    | Some ((l, r), m) -> l^"/"^r, m in
  let stop, stopper = Lwt.wait () in
  let server = {
    root ;
    streams = ConnectionMap.empty ;
    cors ;
    media_types ;
    default_media_type ;
    stopper ;
    worker = Lwt.return_unit ;
  } in
  Conduit_lwt_unix.init ~src:host () >>= fun ctx ->
  let ctx = Cohttp_lwt_unix.Net.init ~ctx () in
  server.worker <- begin
    (* When a connection closes, shut down the stream it was serving,
       if any. Only the map lookup is guarded against [Not_found]: an
       exception raised by the shutdown closure itself is not silently
       discarded. *)
    let conn_closed (_, con) =
      log_info "connection closed %s" (Connection.to_string con) ;
      begin
        match ConnectionMap.find con server.streams with
        | shutdown_stream -> shutdown_stream ()
        | exception Not_found -> ()
      end
    (* A port that is already bound is fatal; connection resets and
       broken pipes are ignored; everything else goes to Lwt's
       asynchronous exception hook. *)
    and on_exn = function
      | Unix.Unix_error (Unix.EADDRINUSE, "bind", _) ->
          log_error "RPC server port already taken, \
                     the node will be shutdown" ;
          exit 1
      | Unix.Unix_error (ECONNRESET, _, _)
      | Unix.Unix_error (EPIPE, _, _) -> ()
      | exn -> !Lwt.async_exception_hook exn
    (* Wrap the request handler so that an escaping [Not_found] yields
       a 404 and any other exception yields a 500 whose body is the
       printed exception. *)
    and callback (io, con) req body =
      Lwt.catch
        begin fun () -> callback server (io, con) req body end
        begin function
          | Not_found ->
              let status = `Not_found in
              let body = Cohttp_lwt.Body.empty in
              Lwt.return (Response.make ~status (), body)
          | exn ->
              let headers = Header.init () in
              let headers =
                Header.add headers "content-type" "text/ocaml.exception" in
              let status = `Internal_server_error in
              let body = Cohttp_lwt.Body.of_string (Printexc.to_string exn) in
              Lwt.return (Response.make ~status ~headers (), body)
        end
    in
    Cohttp_lwt_unix.Server.create ~stop ~ctx ~mode ~on_exn
      (Cohttp_lwt_unix.Server.make ~callback ~conn_closed ())
  end ;
  Lwt.return server
(* [shutdown server] wakes up the server's stopper, waits for the worker
   promise to resolve, and then runs the shutdown closure of every
   stream still registered in [server.streams]. *)
let shutdown server =
  (* Reads [server.streams] only when called, i.e. after the worker has
     finished. *)
  let stop_remaining_streams () =
    ConnectionMap.iter (fun _con stop -> stop ()) server.streams in
  Lwt.wakeup_later server.stopper () ;
  server.worker >|= stop_remaining_streams
end