diff --git a/src/node/net/RPC_server.ml b/src/node/net/RPC_server.ml index 8d532e411..62a829348 100644 --- a/src/node/net/RPC_server.ml +++ b/src/node/net/RPC_server.ml @@ -18,10 +18,43 @@ type server = (* hidden *) module ConnectionMap = Map.Make(Cohttp.Connection) exception Invalid_method +exception Options_preflight exception Cannot_parse_body of string +let check_origin_matches origin allowed_origin = + String.equal "*" allowed_origin || + String.equal allowed_origin origin || + begin + let allowed_w_slash = allowed_origin ^ "/" in + let len_a_w_s = String.length allowed_w_slash in + let len_o = String.length origin in + (len_o >= len_a_w_s) && + String.equal allowed_w_slash @@ String.sub origin 0 len_a_w_s + end + +let find_matching_origin allowed_origins origin = + let matching_origins = List.filter (check_origin_matches origin) allowed_origins in + let compare_by_length_neg a b = ~- (compare (String.length a) (String.length b)) in + let matching_origins_sorted = List.sort compare_by_length_neg matching_origins in + match matching_origins_sorted with + | [] -> None + | x :: _ -> Some x + +let make_cors_headers ?(headers=Cohttp.Header.init ()) + cors_allowed_headers cors_allowed_origins origin_header = + let cors_headers = Cohttp.Header.add_multi headers + "Access-Control-Allow-Headers" cors_allowed_headers in + match origin_header with + | None -> cors_headers + | Some origin -> + match find_matching_origin cors_allowed_origins origin with + | None -> cors_headers + | Some allowed_origin -> + Cohttp.Header.add_multi cors_headers + "Access-Control-Allow-Origin" [allowed_origin] + (* Promise a running RPC server. Takes the port. *) -let launch port ?pre_hook ?post_hook root = +let launch port ?pre_hook ?post_hook root cors_allowed_origins cors_allowed_headers = (* launch the worker *) let cancelation, canceler, _ = Lwt_utils.canceler () in let open Cohttp_lwt_unix in @@ -71,6 +104,8 @@ let launch port ?pre_hook ?post_hook root = let callback (io, con) req body = (* FIXME: check inbound adress *) let path = Utils.split_path (Uri.path (Cohttp.Request.uri req)) in + let req_headers = Cohttp.Request.headers req in + let origin_header = Cohttp.Header.get req_headers "origin" in lwt_log_info "(%s) receive request to %s" (Cohttp.Connection.to_string con) (Uri.path (Cohttp.Request.uri req)) >>= fun () -> Lwt.catch @@ -89,6 +124,7 @@ let launch port ?pre_hook ?post_hook root = | Ok body -> Lwt.return (Some body) end | `GET -> Lwt.return None + | `OPTIONS -> Lwt.fail Options_preflight | _ -> Lwt.fail Invalid_method end >>= fun body -> handler body >>= fun { Answer.code ; body } -> @@ -106,8 +142,9 @@ let launch port ?pre_hook ?post_hook root = (if Cohttp.Code.is_error code then "failed" else "success") >>= fun () -> - Lwt.return (Response.make ~flush:true ~status:(`Code code) (), - body)) + let headers = make_cors_headers cors_allowed_headers cors_allowed_origins origin_header in + Lwt.return (Response.make + ~flush:true ~status:(`Code code) ~headers (), body)) (function | Not_found | Cannot_parse _ -> lwt_log_info "(%s) not found" @@ -123,6 +160,7 @@ let launch port ?pre_hook ?post_hook root = let headers = Cohttp.Header.add_multi (Cohttp.Header.init ()) "Allow" ["POST"] in + let headers = make_cors_headers ~headers cors_allowed_headers cors_allowed_origins origin_header in Lwt.return (Response.make ~flush:true ~status:`Method_not_allowed ~headers (), @@ -132,6 +170,12 @@ let launch port ?pre_hook ?post_hook root = (Cohttp.Connection.to_string con) >>= fun () -> Lwt.return (Response.make ~flush:true ~status:`Bad_request (), Cohttp_lwt_body.of_string msg) + | Options_preflight -> + lwt_log_info "(%s) RPC preflight" + (Cohttp.Connection.to_string con) >>= fun () -> + let headers = make_cors_headers cors_allowed_headers cors_allowed_origins origin_header in + Lwt.return (Response.make ~flush:true ~status:(`Code 200) ~headers (), + Cohttp_lwt_body.empty) | e -> Lwt.fail e) and conn_closed (_, con) = log_info "connection close %s" (Cohttp.Connection.to_string con) ; diff --git a/src/node/net/RPC_server.mli b/src/node/net/RPC_server.mli index 9366b3606..de93fce85 100644 --- a/src/node/net/RPC_server.mli +++ b/src/node/net/RPC_server.mli @@ -20,6 +20,16 @@ type server callable. Calling /pipe will read a sequence of services to call in sequence from the request body, see {!pipe_encoding}. + The arguments cors_allowed_origins and cors_allowed_headers define + the cross-origin resource sharing using the headers + Access-Control-Allow-Origin and Access-Control-Allow-Headers. The + argument cors_allowed_headers sets the content of + Access-Control-Allow-Headers. Since you cannot have multiple + values for Access-Control-Allow-Origin, the server accepts a list + in cors_allowed_origins and matches it against the origin of the + incoming request; then returns the longest element of the passed + list as the content of Access-Control-Allow-Origin. + The optional [pre_hook] is called with the path part of the URL before resolving each request, to delegate the answering to another resolution mechanism. Its result is ignored if the return @@ -28,7 +38,10 @@ type server val launch : int -> ?pre_hook: (string -> string RPC.Answer.answer Lwt.t) -> ?post_hook: (string -> string RPC.Answer.answer Lwt.t) -> - unit RPC.directory -> server Lwt.t + unit RPC.directory -> + string list -> + string list -> + server Lwt.t (** Kill an RPC server. *) val shutdown : server -> unit Lwt.t diff --git a/src/node_main.ml b/src/node_main.ml index a1e1c8f58..229dd7a43 100644 --- a/src/node_main.ml +++ b/src/node_main.ml @@ -63,6 +63,8 @@ type cfg = { (* rpc *) rpc_addr : (Ipaddr.t * int) option ; + cors_origins : string list ; + cors_headers : string list ; (* log *) log_output : [`Stderr | `File of string | `Syslog | `Null] ; @@ -93,6 +95,8 @@ let default_cfg_of_base_dir base_dir = { (* rpc *) rpc_addr = None ; + cors_origins = [] ; + cors_headers = ["content-type"] ; (* log *) log_output = `Stderr ; @@ -153,8 +157,10 @@ module Cfg_file = struct (opt "peers-cache" string) let rpc = - obj1 + obj3 (opt "addr" string) + (dft "cors-origin" (list string) []) + (dft "cors-header" (list string) []) let log = obj1 @@ -163,9 +169,9 @@ module Cfg_file = struct let t = conv (fun { store ; context ; protocol ; - min_connections ; max_connections ; expected_connections; - net_addr ; net_port ; local_discovery ; peers; - closed ; peers_cache ; rpc_addr; log_output } -> + min_connections ; max_connections ; expected_connections ; + net_addr ; net_port ; local_discovery ; peers ; + closed ; peers_cache ; rpc_addr ; cors_origins ; cors_headers ; log_output } -> let net_addr = string_of_sockaddr (net_addr, net_port) in let rpc_addr = Utils.map_option string_of_sockaddr rpc_addr in let peers = ListLabels.map peers ~f:string_of_sockaddr in @@ -173,11 +179,14 @@ module Cfg_file = struct ((Some store, Some context, Some protocol), (Some min_connections, Some max_connections, Some expected_connections, Some net_addr, local_discovery, Some peers, closed, Some peers_cache), - rpc_addr, Some log_output)) + (rpc_addr, cors_origins, cors_headers), + Some log_output)) (fun ( (store, context, protocol), - (min_connections, max_connections, expected_connections, - net_addr, local_discovery, peers, closed, peers_cache), rpc_addr, log_output) -> + (min_connections, max_connections, expected_connections, net_addr, + local_discovery, peers, closed, peers_cache), + (rpc_addr, cors_origins, cors_headers), + log_output) -> let open Utils in let store = unopt default_cfg.store store in let context = unopt default_cfg.context context in @@ -196,7 +205,7 @@ module Cfg_file = struct store ; context ; protocol ; min_connections; max_connections; expected_connections; net_addr; net_port ; local_discovery; peers; closed; peers_cache; - rpc_addr; log_output + rpc_addr; cors_origins ; cors_headers ; log_output } ) (obj4 @@ -273,10 +282,16 @@ module Cmdline = struct let rpc_addr = let doc = "The TCP socket address at which this RPC server instance can be reached" in Arg.(value & opt (some sockaddr_converter) None & info ~docs:"RPC" ~doc ~docv:"ADDR:PORT" ["rpc-addr"]) + let cors_origins = + let doc = "CORS origin allowed by the RPC server via Access-Control-Allow-Origin; may be used multiple times" in + Arg.(value & opt_all string [] & info ~docs:"RPC" ~doc ~docv:"ORIGIN" ["cors-origin"]) + let cors_headers = + let doc = "Header reported by Access-Control-Allow-Headers reported during CORS preflighting; may be used multiple times" in + Arg.(value & opt_all string [] & info ~docs:"RPC" ~doc ~docv:"HEADER" ["cors-header"]) let parse base_dir config_file sandbox sandbox_param log_level min_connections max_connections expected_connections - net_saddr local_discovery peers closed rpc_addr reset_cfg update_cfg = + net_saddr local_discovery peers closed rpc_addr cors_origins cors_headers reset_cfg update_cfg = let base_dir = Utils.(unopt (unopt default_cfg.base_dir base_dir) sandbox) in let config_file = Utils.(unopt ((unopt base_dir sandbox) // "config")) config_file in let no_config () = @@ -317,6 +332,8 @@ module Cmdline = struct peers = (match peers with [] -> cfg.peers | _ -> peers) ; closed = closed || cfg.closed ; rpc_addr = Utils.first_some rpc_addr cfg.rpc_addr ; + cors_origins = (match cors_origins with [] -> cfg.cors_origins | _ -> cors_origins) ; + cors_headers = (match cors_headers with [] -> cfg.cors_headers | _ -> cors_headers) ; log_output = cfg.log_output ; } in @@ -328,7 +345,8 @@ module Cmdline = struct ret (const parse $ base_dir $ config_file $ sandbox $ sandbox_param $ v $ min_connections $ max_connections $ expected_connections - $ net_addr $ local_discovery $ peers $ closed $ rpc_addr + $ net_addr $ local_discovery $ peers $ closed + $ rpc_addr $ cors_origins $ cors_headers $ reset_config $ update_config ), let doc = "The Tezos daemon" in @@ -420,7 +438,7 @@ let init_node { sandbox ; sandbox_param ; ?patch_context net_params -let init_rpc { rpc_addr } node = +let init_rpc { rpc_addr ; cors_origins ; cors_headers } node = match rpc_addr with | None -> lwt_log_notice "Not listening to RPC calls." >>= fun () -> @@ -428,7 +446,7 @@ let init_rpc { rpc_addr } node = | Some (_addr, port) -> lwt_log_notice "Starting the RPC server listening on port %d." port >>= fun () -> let dir = Node_rpc.build_rpc_directory node in - RPC_server.launch port dir >>= fun server -> + RPC_server.launch port dir cors_origins cors_headers >>= fun server -> Lwt.return (Some server) let init_signal () = diff --git a/src/webclient_main.ml b/src/webclient_main.ml index 1f79b809e..52ca74bfd 100644 --- a/src/webclient_main.ml +++ b/src/webclient_main.ml @@ -163,7 +163,7 @@ let http_proxy port = | None -> Lwt.return (RPC.Answer.Empty)) >>= fun body -> Lwt.return { RPC.Answer.code = 404 ; body } in - RPC_server.launch ~pre_hook ~post_hook port root + RPC_server.launch ~pre_hook ~post_hook port root [] [] let web_port = Client_config.in_both_groups @@ new Config_file.int_cp [ "web" ; "port" ] 8080