RPCs: fix missing CORS headers.

This commit is contained in:
Benjamin Canou 2017-04-19 16:54:46 +02:00
parent c1b4a74bf7
commit 442f2d00a0

View File

@ -17,7 +17,7 @@ type server = (* hidden *)
module ConnectionMap = Map.Make(Cohttp.Connection) module ConnectionMap = Map.Make(Cohttp.Connection)
exception Invalid_method exception Invalid_method of { allowed : RPC.meth list }
exception Cannot_parse_body of string exception Cannot_parse_body of string
let check_origin_matches origin allowed_origin = let check_origin_matches origin allowed_origin =
@ -110,6 +110,17 @@ let launch ?pre_hook ?post_hook ?(host="::") mode root cors_allowed_origins cors
let path = Utils.split_path (Uri.path (Cohttp.Request.uri req)) in let path = Utils.split_path (Uri.path (Cohttp.Request.uri req)) in
let req_headers = Cohttp.Request.headers req in let req_headers = Cohttp.Request.headers req in
let origin_header = Cohttp.Header.get req_headers "origin" in let origin_header = Cohttp.Header.get req_headers "origin" in
let answer_with_cors_headers ?headers ?body status =
let headers = match headers with
| None -> Cohttp.Header.init ()
| Some headers -> headers in
let body = match body with
| None -> Cohttp_lwt_body.empty
| Some body -> body in
let headers =
make_cors_headers ~headers
cors_allowed_headers cors_allowed_origins origin_header in
Lwt.return (Response.make ~flush:true ~status ~headers (), body) in
lwt_log_info "(%s) receive request to %s" lwt_log_info "(%s) receive request to %s"
(Cohttp.Connection.to_string con) (Uri.path (Cohttp.Request.uri req)) >>= fun () -> (Cohttp.Connection.to_string con) (Uri.path (Cohttp.Request.uri req)) >>= fun () ->
Lwt.catch Lwt.catch
@ -118,11 +129,25 @@ let launch ?pre_hook ?post_hook ?(host="::") mode root cors_allowed_origins cors
| Some res -> | Some res ->
Lwt.return res Lwt.return res
| None -> | None ->
let existing_methods () =
let supported_meths =
[ `OPTIONS ; `POST ; `PUT ; `PATCH ; `DELETE ; `GET ; `HEAD ] in
Lwt_list.filter_map_p
(fun meth ->
Lwt.catch
(fun () ->
lookup root ~meth () path >>= fun _handler ->
Lwt.return_some meth)
(function Not_found | Cannot_parse _ -> Lwt.return_none
| exn -> Lwt.fail exn))
supported_meths >>= function
| [] -> Lwt.fail Not_found (* No handler at all -> 404 *)
| meths -> Lwt.return meths in
Lwt.catch Lwt.catch
(fun () -> (fun () ->
lookup root ~meth:req.meth () path >>= fun handler -> lookup root ~meth:req.meth () path >>= fun handler ->
Lwt.return_some handler) Lwt.return_some handler)
(function Not_found -> Lwt.return_none (function Not_found | Cannot_parse _ -> Lwt.return_none
| exn -> Lwt.fail exn) >>= function | exn -> Lwt.fail exn) >>= function
| None -> | None ->
begin begin
@ -135,35 +160,19 @@ let launch ?pre_hook ?post_hook ?(host="::") mode root cors_allowed_origins cors
lookup root ~meth () path >>= fun _handler -> lookup root ~meth () path >>= fun _handler ->
(* unless [lookup] failed with [Not_found] -> 404 *) (* unless [lookup] failed with [Not_found] -> 404 *)
Lwt.return [ meth ] Lwt.return [ meth ]
| None -> | None -> existing_methods ()
let supported_meths = else
[ `POST ; `PUT ; `PATCH ; `DELETE ; `GET ; `HEAD ] in existing_methods () >>= fun allowed ->
Lwt_list.filter_map_p Lwt.fail (Invalid_method { allowed })
(fun meth ->
Lwt.catch
(fun () ->
lookup root ~meth () path >>= fun _handler ->
Lwt.return_some meth)
(function Not_found -> Lwt.return_none
| exn -> Lwt.fail exn))
supported_meths >>= function
| [] -> Lwt.fail Not_found (* No handler -> 404 *)
| meths -> Lwt.return meths
else Lwt.fail Not_found
end >>= fun cors_allowed_meths -> end >>= fun cors_allowed_meths ->
lwt_log_info "(%s) RPC preflight" lwt_log_info "(%s) RPC preflight"
(Cohttp.Connection.to_string con) >>= fun () -> (Cohttp.Connection.to_string con) >>= fun () ->
let headers = let headers =
Cohttp.Header.add Cohttp.Header.add_multi
(Cohttp.Header.init ()) (Cohttp.Header.init ())
"Access-Control-Allow-Methods" "Access-Control-Allow-Methods"
(String.concat ", " (List.map Cohttp.Code.string_of_method cors_allowed_meths) in
(List.map Cohttp.Code.string_of_method cors_allowed_meths)) in answer_with_cors_headers ~headers `OK
let headers =
make_cors_headers ~headers
cors_allowed_headers cors_allowed_origins origin_header in
Lwt.return (Response.make ~flush:true ~status:(`Code 200) ~headers (),
Cohttp_lwt_body.empty)
| Some handler -> | Some handler ->
begin match req.meth with begin match req.meth with
| `POST | `POST
@ -178,7 +187,9 @@ let launch ?pre_hook ?post_hook ?(host="::") mode root cors_allowed_origins cors
| `GET | `GET
| `HEAD | `HEAD
| `OPTIONS -> Lwt.return None | `OPTIONS -> Lwt.return None
| _ -> Lwt.fail Invalid_method | _ ->
existing_methods () >>= fun allowed ->
Lwt.fail (Invalid_method { allowed })
end >>= fun body -> end >>= fun body ->
handler body >>= fun { Answer.code ; body } -> handler body >>= fun { Answer.code ; body } ->
let body = match body with let body = match body with
@ -197,37 +208,27 @@ let launch ?pre_hook ?post_hook ?(host="::") mode root cors_allowed_origins cors
else "success") >>= fun () -> else "success") >>= fun () ->
let headers = let headers =
Cohttp.Header.init_with "Content-Type" "application/json" in Cohttp.Header.init_with "Content-Type" "application/json" in
let headers = answer_with_cors_headers ~headers ~body (`Code code))
make_cors_headers ~headers
cors_allowed_headers cors_allowed_origins origin_header
in
Lwt.return (Response.make
~flush:true ~status:(`Code code) ~headers (), body))
(function (function
| Not_found | Cannot_parse _ -> | Not_found | Cannot_parse _ ->
lwt_log_info "(%s) not found" lwt_log_info "(%s) not found"
(Cohttp.Connection.to_string con) >>= fun () -> (Cohttp.Connection.to_string con) >>= fun () ->
(call_hook (io, con) req ~answer_404: true post_hook >>= function (call_hook (io, con) req ~answer_404: true post_hook >>= function
| Some res -> Lwt.return res | Some res -> Lwt.return res
| None -> | None -> answer_with_cors_headers `Not_found)
Lwt.return (Response.make ~flush:true ~status:`Not_found (), | Invalid_method { allowed } ->
Cohttp_lwt_body.empty))
| Invalid_method ->
lwt_log_info "(%s) bad method" lwt_log_info "(%s) bad method"
(Cohttp.Connection.to_string con) >>= fun () -> (Cohttp.Connection.to_string con) >>= fun () ->
let headers = let headers =
Cohttp.Header.add_multi (Cohttp.Header.init ()) Cohttp.Header.add_multi (Cohttp.Header.init ())
"Allow" ["POST"] in "Allow"
let headers = make_cors_headers ~headers cors_allowed_headers cors_allowed_origins origin_header in (List.map Cohttp.Code.string_of_method allowed) in
Lwt.return (Response.make answer_with_cors_headers ~headers `Method_not_allowed
~flush:true ~status:`Method_not_allowed
~headers (),
Cohttp_lwt_body.empty)
| Cannot_parse_body msg -> | Cannot_parse_body msg ->
lwt_log_info "(%s) can't parse RPC body" lwt_log_info "(%s) can't parse RPC body"
(Cohttp.Connection.to_string con) >>= fun () -> (Cohttp.Connection.to_string con) >>= fun () ->
Lwt.return (Response.make ~flush:true ~status:`Bad_request (), let body = Cohttp_lwt_body.of_string msg in
Cohttp_lwt_body.of_string msg) answer_with_cors_headers ~body `Bad_request
| e -> Lwt.fail e) | e -> Lwt.fail e)
and conn_closed (_, con) = and conn_closed (_, con) =
log_info "connection closed %s" (Cohttp.Connection.to_string con) ; log_info "connection closed %s" (Cohttp.Connection.to_string con) ;