RPC: Add CORS headers

This commit is contained in:
damian 2016-12-06 13:58:21 +01:00 committed by Vincent Bernardoff
parent 4e9c54efa9
commit 517893f707
4 changed files with 92 additions and 17 deletions

View File

@ -18,10 +18,43 @@ type server = (* hidden *)
module ConnectionMap = Map.Make(Cohttp.Connection) module ConnectionMap = Map.Make(Cohttp.Connection)
exception Invalid_method exception Invalid_method
exception Options_preflight
exception Cannot_parse_body of string exception Cannot_parse_body of string
let check_origin_matches origin allowed_origin =
String.equal "*" allowed_origin ||
String.equal allowed_origin origin ||
begin
let allowed_w_slash = allowed_origin ^ "/" in
let len_a_w_s = String.length allowed_w_slash in
let len_o = String.length origin in
(len_o >= len_a_w_s) &&
String.equal allowed_w_slash @@ String.sub origin 0 len_a_w_s
end
let find_matching_origin allowed_origins origin =
let matching_origins = List.filter (check_origin_matches origin) allowed_origins in
let compare_by_length_neg a b = ~- (compare (String.length a) (String.length b)) in
let matching_origins_sorted = List.sort compare_by_length_neg matching_origins in
match matching_origins_sorted with
| [] -> None
| x :: _ -> Some x
let make_cors_headers ?(headers=Cohttp.Header.init ())
cors_allowed_headers cors_allowed_origins origin_header =
let cors_headers = Cohttp.Header.add_multi headers
"Access-Control-Allow-Headers" cors_allowed_headers in
match origin_header with
| None -> cors_headers
| Some origin ->
match find_matching_origin cors_allowed_origins origin with
| None -> cors_headers
| Some allowed_origin ->
Cohttp.Header.add_multi cors_headers
"Access-Control-Allow-Origin" [allowed_origin]
(* Promise a running RPC server. Takes the port. *) (* Promise a running RPC server. Takes the port. *)
let launch port ?pre_hook ?post_hook root = let launch port ?pre_hook ?post_hook root cors_allowed_origins cors_allowed_headers =
(* launch the worker *) (* launch the worker *)
let cancelation, canceler, _ = Lwt_utils.canceler () in let cancelation, canceler, _ = Lwt_utils.canceler () in
let open Cohttp_lwt_unix in let open Cohttp_lwt_unix in
@ -71,6 +104,8 @@ let launch port ?pre_hook ?post_hook root =
let callback (io, con) req body = let callback (io, con) req body =
(* FIXME: check inbound adress *) (* FIXME: check inbound adress *)
let path = Utils.split_path (Uri.path (Cohttp.Request.uri req)) in let path = Utils.split_path (Uri.path (Cohttp.Request.uri req)) in
let req_headers = Cohttp.Request.headers req in
let origin_header = Cohttp.Header.get req_headers "origin" in
lwt_log_info "(%s) receive request to %s" lwt_log_info "(%s) receive request to %s"
(Cohttp.Connection.to_string con) (Uri.path (Cohttp.Request.uri req)) >>= fun () -> (Cohttp.Connection.to_string con) (Uri.path (Cohttp.Request.uri req)) >>= fun () ->
Lwt.catch Lwt.catch
@ -89,6 +124,7 @@ let launch port ?pre_hook ?post_hook root =
| Ok body -> Lwt.return (Some body) | Ok body -> Lwt.return (Some body)
end end
| `GET -> Lwt.return None | `GET -> Lwt.return None
| `OPTIONS -> Lwt.fail Options_preflight
| _ -> Lwt.fail Invalid_method | _ -> Lwt.fail Invalid_method
end >>= fun body -> end >>= fun body ->
handler body >>= fun { Answer.code ; body } -> handler body >>= fun { Answer.code ; body } ->
@ -106,8 +142,9 @@ let launch port ?pre_hook ?post_hook root =
(if Cohttp.Code.is_error code (if Cohttp.Code.is_error code
then "failed" then "failed"
else "success") >>= fun () -> else "success") >>= fun () ->
Lwt.return (Response.make ~flush:true ~status:(`Code code) (), let headers = make_cors_headers cors_allowed_headers cors_allowed_origins origin_header in
body)) Lwt.return (Response.make
~flush:true ~status:(`Code code) ~headers (), body))
(function (function
| Not_found | Cannot_parse _ -> | Not_found | Cannot_parse _ ->
lwt_log_info "(%s) not found" lwt_log_info "(%s) not found"
@ -123,6 +160,7 @@ let launch port ?pre_hook ?post_hook root =
let headers = let headers =
Cohttp.Header.add_multi (Cohttp.Header.init ()) Cohttp.Header.add_multi (Cohttp.Header.init ())
"Allow" ["POST"] in "Allow" ["POST"] in
let headers = make_cors_headers ~headers cors_allowed_headers cors_allowed_origins origin_header in
Lwt.return (Response.make Lwt.return (Response.make
~flush:true ~status:`Method_not_allowed ~flush:true ~status:`Method_not_allowed
~headers (), ~headers (),
@ -132,6 +170,12 @@ let launch port ?pre_hook ?post_hook root =
(Cohttp.Connection.to_string con) >>= fun () -> (Cohttp.Connection.to_string con) >>= fun () ->
Lwt.return (Response.make ~flush:true ~status:`Bad_request (), Lwt.return (Response.make ~flush:true ~status:`Bad_request (),
Cohttp_lwt_body.of_string msg) Cohttp_lwt_body.of_string msg)
| Options_preflight ->
lwt_log_info "(%s) RPC preflight"
(Cohttp.Connection.to_string con) >>= fun () ->
let headers = make_cors_headers cors_allowed_headers cors_allowed_origins origin_header in
Lwt.return (Response.make ~flush:true ~status:(`Code 200) ~headers (),
Cohttp_lwt_body.empty)
| e -> Lwt.fail e) | e -> Lwt.fail e)
and conn_closed (_, con) = and conn_closed (_, con) =
log_info "connection close %s" (Cohttp.Connection.to_string con) ; log_info "connection close %s" (Cohttp.Connection.to_string con) ;

View File

@ -20,6 +20,16 @@ type server
callable. Calling /pipe will read a sequence of services to call in callable. Calling /pipe will read a sequence of services to call in
sequence from the request body, see {!pipe_encoding}. sequence from the request body, see {!pipe_encoding}.
The arguments cors_allowed_origins and cors_allowed_headers define
the cross-origin resource sharing using the headers
Access-Control-Allow-Origin and Access-Control-Allow-Headers. The
argument cors_allowed_headers sets the content of
Access-Control-Allow-Headers. Since you cannot have multiple
values for Access-Control-Allow-Origin, the server accepts a list
in cors_allowed_origins and matches it against the origin of the
incoming request; then returns the longest element of the passed
list as the content of Access-Control-Allow-Origin.
The optional [pre_hook] is called with the path part of the URL The optional [pre_hook] is called with the path part of the URL
before resolving each request, to delegate the answering to before resolving each request, to delegate the answering to
another resolution mechanism. Its result is ignored if the return another resolution mechanism. Its result is ignored if the return
@ -28,7 +38,10 @@ type server
val launch : int -> val launch : int ->
?pre_hook: (string -> string RPC.Answer.answer Lwt.t) -> ?pre_hook: (string -> string RPC.Answer.answer Lwt.t) ->
?post_hook: (string -> string RPC.Answer.answer Lwt.t) -> ?post_hook: (string -> string RPC.Answer.answer Lwt.t) ->
unit RPC.directory -> server Lwt.t unit RPC.directory ->
string list ->
string list ->
server Lwt.t
(** Kill an RPC server. *) (** Kill an RPC server. *)
val shutdown : server -> unit Lwt.t val shutdown : server -> unit Lwt.t

View File

@ -63,6 +63,8 @@ type cfg = {
(* rpc *) (* rpc *)
rpc_addr : (Ipaddr.t * int) option ; rpc_addr : (Ipaddr.t * int) option ;
cors_origins : string list ;
cors_headers : string list ;
(* log *) (* log *)
log_output : [`Stderr | `File of string | `Syslog | `Null] ; log_output : [`Stderr | `File of string | `Syslog | `Null] ;
@ -93,6 +95,8 @@ let default_cfg_of_base_dir base_dir = {
(* rpc *) (* rpc *)
rpc_addr = None ; rpc_addr = None ;
cors_origins = [] ;
cors_headers = ["content-type"] ;
(* log *) (* log *)
log_output = `Stderr ; log_output = `Stderr ;
@ -153,8 +157,10 @@ module Cfg_file = struct
(opt "peers-cache" string) (opt "peers-cache" string)
let rpc = let rpc =
obj1 obj3
(opt "addr" string) (opt "addr" string)
(dft "cors-origin" (list string) [])
(dft "cors-header" (list string) [])
let log = let log =
obj1 obj1
@ -163,9 +169,9 @@ module Cfg_file = struct
let t = let t =
conv conv
(fun { store ; context ; protocol ; (fun { store ; context ; protocol ;
min_connections ; max_connections ; expected_connections; min_connections ; max_connections ; expected_connections ;
net_addr ; net_port ; local_discovery ; peers; net_addr ; net_port ; local_discovery ; peers ;
closed ; peers_cache ; rpc_addr; log_output } -> closed ; peers_cache ; rpc_addr ; cors_origins ; cors_headers ; log_output } ->
let net_addr = string_of_sockaddr (net_addr, net_port) in let net_addr = string_of_sockaddr (net_addr, net_port) in
let rpc_addr = Utils.map_option string_of_sockaddr rpc_addr in let rpc_addr = Utils.map_option string_of_sockaddr rpc_addr in
let peers = ListLabels.map peers ~f:string_of_sockaddr in let peers = ListLabels.map peers ~f:string_of_sockaddr in
@ -173,11 +179,14 @@ module Cfg_file = struct
((Some store, Some context, Some protocol), ((Some store, Some context, Some protocol),
(Some min_connections, Some max_connections, Some expected_connections, (Some min_connections, Some max_connections, Some expected_connections,
Some net_addr, local_discovery, Some peers, closed, Some peers_cache), Some net_addr, local_discovery, Some peers, closed, Some peers_cache),
rpc_addr, Some log_output)) (rpc_addr, cors_origins, cors_headers),
Some log_output))
(fun ( (fun (
(store, context, protocol), (store, context, protocol),
(min_connections, max_connections, expected_connections, (min_connections, max_connections, expected_connections, net_addr,
net_addr, local_discovery, peers, closed, peers_cache), rpc_addr, log_output) -> local_discovery, peers, closed, peers_cache),
(rpc_addr, cors_origins, cors_headers),
log_output) ->
let open Utils in let open Utils in
let store = unopt default_cfg.store store in let store = unopt default_cfg.store store in
let context = unopt default_cfg.context context in let context = unopt default_cfg.context context in
@ -196,7 +205,7 @@ module Cfg_file = struct
store ; context ; protocol ; store ; context ; protocol ;
min_connections; max_connections; expected_connections; min_connections; max_connections; expected_connections;
net_addr; net_port ; local_discovery; peers; closed; peers_cache; net_addr; net_port ; local_discovery; peers; closed; peers_cache;
rpc_addr; log_output rpc_addr; cors_origins ; cors_headers ; log_output
} }
) )
(obj4 (obj4
@ -273,10 +282,16 @@ module Cmdline = struct
let rpc_addr = let rpc_addr =
let doc = "The TCP socket address at which this RPC server instance can be reached" in let doc = "The TCP socket address at which this RPC server instance can be reached" in
Arg.(value & opt (some sockaddr_converter) None & info ~docs:"RPC" ~doc ~docv:"ADDR:PORT" ["rpc-addr"]) Arg.(value & opt (some sockaddr_converter) None & info ~docs:"RPC" ~doc ~docv:"ADDR:PORT" ["rpc-addr"])
let cors_origins =
let doc = "CORS origin allowed by the RPC server via Access-Control-Allow-Origin; may be used multiple times" in
Arg.(value & opt_all string [] & info ~docs:"RPC" ~doc ~docv:"ORIGIN" ["cors-origin"])
let cors_headers =
let doc = "Header reported by Access-Control-Allow-Headers reported during CORS preflighting; may be used multiple times" in
Arg.(value & opt_all string [] & info ~docs:"RPC" ~doc ~docv:"HEADER" ["cors-header"])
let parse base_dir config_file sandbox sandbox_param log_level let parse base_dir config_file sandbox sandbox_param log_level
min_connections max_connections expected_connections min_connections max_connections expected_connections
net_saddr local_discovery peers closed rpc_addr reset_cfg update_cfg = net_saddr local_discovery peers closed rpc_addr cors_origins cors_headers reset_cfg update_cfg =
let base_dir = Utils.(unopt (unopt default_cfg.base_dir base_dir) sandbox) in let base_dir = Utils.(unopt (unopt default_cfg.base_dir base_dir) sandbox) in
let config_file = Utils.(unopt ((unopt base_dir sandbox) // "config")) config_file in let config_file = Utils.(unopt ((unopt base_dir sandbox) // "config")) config_file in
let no_config () = let no_config () =
@ -317,6 +332,8 @@ module Cmdline = struct
peers = (match peers with [] -> cfg.peers | _ -> peers) ; peers = (match peers with [] -> cfg.peers | _ -> peers) ;
closed = closed || cfg.closed ; closed = closed || cfg.closed ;
rpc_addr = Utils.first_some rpc_addr cfg.rpc_addr ; rpc_addr = Utils.first_some rpc_addr cfg.rpc_addr ;
cors_origins = (match cors_origins with [] -> cfg.cors_origins | _ -> cors_origins) ;
cors_headers = (match cors_headers with [] -> cfg.cors_headers | _ -> cors_headers) ;
log_output = cfg.log_output ; log_output = cfg.log_output ;
} }
in in
@ -328,7 +345,8 @@ module Cmdline = struct
ret (const parse $ base_dir $ config_file ret (const parse $ base_dir $ config_file
$ sandbox $ sandbox_param $ v $ sandbox $ sandbox_param $ v
$ min_connections $ max_connections $ expected_connections $ min_connections $ max_connections $ expected_connections
$ net_addr $ local_discovery $ peers $ closed $ rpc_addr $ net_addr $ local_discovery $ peers $ closed
$ rpc_addr $ cors_origins $ cors_headers
$ reset_config $ update_config $ reset_config $ update_config
), ),
let doc = "The Tezos daemon" in let doc = "The Tezos daemon" in
@ -420,7 +438,7 @@ let init_node { sandbox ; sandbox_param ;
?patch_context ?patch_context
net_params net_params
let init_rpc { rpc_addr } node = let init_rpc { rpc_addr ; cors_origins ; cors_headers } node =
match rpc_addr with match rpc_addr with
| None -> | None ->
lwt_log_notice "Not listening to RPC calls." >>= fun () -> lwt_log_notice "Not listening to RPC calls." >>= fun () ->
@ -428,7 +446,7 @@ let init_rpc { rpc_addr } node =
| Some (_addr, port) -> | Some (_addr, port) ->
lwt_log_notice "Starting the RPC server listening on port %d." port >>= fun () -> lwt_log_notice "Starting the RPC server listening on port %d." port >>= fun () ->
let dir = Node_rpc.build_rpc_directory node in let dir = Node_rpc.build_rpc_directory node in
RPC_server.launch port dir >>= fun server -> RPC_server.launch port dir cors_origins cors_headers >>= fun server ->
Lwt.return (Some server) Lwt.return (Some server)
let init_signal () = let init_signal () =

View File

@ -163,7 +163,7 @@ let http_proxy port =
| None -> | None ->
Lwt.return (RPC.Answer.Empty)) >>= fun body -> Lwt.return (RPC.Answer.Empty)) >>= fun body ->
Lwt.return { RPC.Answer.code = 404 ; body } in Lwt.return { RPC.Answer.code = 404 ; body } in
RPC_server.launch ~pre_hook ~post_hook port root RPC_server.launch ~pre_hook ~post_hook port root [] []
let web_port = Client_config.in_both_groups @@ let web_port = Client_config.in_both_groups @@
new Config_file.int_cp [ "web" ; "port" ] 8080 new Config_file.int_cp [ "web" ; "port" ] 8080