(**************************************************************************)
(*                                                                        *)
(*    Copyright (c) 2014 - 2018.                                          *)
(*    Dynamic Ledger Solutions, Inc. <contact@tezos.com>                  *)
(*                                                                        *)
(*    All rights reserved. No warranty, explicit or implicit, provided.   *)
(*                                                                        *)
(**************************************************************************)

open Lwt.Infix

module ConnectionMap = Map.Make(Cohttp.Connection)

module type LOGGING = sig

  val debug: ('a, Format.formatter, unit, unit) format4 -> 'a
  val log_info: ('a, Format.formatter, unit, unit) format4 -> 'a
  val log_notice: ('a, Format.formatter, unit, unit) format4 -> 'a
  val warn: ('a, Format.formatter, unit, unit) format4 -> 'a
  val log_error: ('a, Format.formatter, unit, unit) format4 -> 'a

  val lwt_debug: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
  val lwt_log_info: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
  val lwt_log_notice: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
  val lwt_warn: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a
  val lwt_log_error: ('a, Format.formatter, unit, unit Lwt.t) format4 -> 'a

end

module Make (Encoding : Resto.ENCODING)(Log : LOGGING) = struct

  open Log
  open Cohttp

  module Service = Resto.MakeService(Encoding)
  module Directory = Resto_directory.Make(Encoding)

  module Media_type = Media_type.Make(Encoding)

  type server = {
    root : unit Directory.directory ;
    mutable streams : (unit -> unit) ConnectionMap.t ;
    cors : Cors.t ;
    media_types : Media_type.t list ;
    default_media_type : string * Media_type.t ;
    stopper : unit Lwt.u ;
    mutable worker : unit Lwt.t ;
  }

  let create_stream server con to_string s =
    let running = ref true in
    let stream =
      Lwt_stream.from
        (fun () ->
           if not !running then
             Lwt.return None
           else
             s.Resto_directory.Answer.next () >|= function
             | None -> None
             | Some x -> Some (to_string x)) in
    let shutdown () =
      running := false ;
      s.shutdown () ;
      server.streams <- ConnectionMap.remove con server.streams in
    server.streams <- ConnectionMap.add con shutdown server.streams ;
    stream

  let (>>=?) m f =
    m >>= function
    | Ok x -> f x
    | Error err -> Lwt.return_error err

  let callback server (_io, con) req body =
    (* FIXME: check inbound adress *)
    let uri = Request.uri req in
    lwt_log_info "(%s) receive request to %s"
      (Connection.to_string con) (Uri.path uri) >>= fun () ->
    let path = Utils.split_path (Uri.path uri) in
    let req_headers = Request.headers req in
    begin
      match Request.meth req with
      | #Resto.meth as meth -> begin
          Directory.lookup server.root ()
            meth path >>=? fun (Directory.Service s) ->
          begin
            match Header.get req_headers "content-type" with
            | None -> Lwt.return_ok (snd server.default_media_type)
            | Some content_type ->
                match Utils.split_path content_type with
                | [x ; y] -> begin
                    match Media_type.find_media (x, y) server.media_types with
                    | None ->
                        Lwt.return_error (`Unsupported_media_type content_type)
                    | Some media_type ->
                        Lwt.return_ok media_type
                  end
                | _ ->
                    Lwt.return_error (`Unsupported_media_type content_type)
          end >>=? fun input_media_type ->
          lwt_debug "(%s) input media type %s"
            (Connection.to_string con)
            (Media_type.name input_media_type) >>= fun () ->
          begin
            match Header.get req_headers "accept" with
            | None -> Lwt.return_ok server.default_media_type
            | Some accepted ->
                match Media_type.resolve_accept_header
                        server.media_types (Some accepted) with
                | None -> Lwt.return_error `Not_acceptable
                | Some media_type -> Lwt.return_ok media_type
          end >>=? fun (output_content_type, output_media_type) ->
          begin
            match Resto.Query.parse s.types.query
                    (List.map
                       (fun (k, l) -> (k, String.concat "," l))
                       (Uri.query uri)) with
            | exception (Resto.Query.Invalid s) ->
                Lwt.return_error (`Cannot_parse_query s)
            | query -> Lwt.return_ok query
          end >>=? fun query ->
          lwt_debug "(%s) ouput media type %s"
            (Connection.to_string con)
            (Media_type.name output_media_type) >>= fun () ->
          let output = output_media_type.construct s.types.output
          and error = function
            | None -> Cohttp_lwt.Body.empty, Transfer.Fixed 0L
            | Some e ->
                let s = output_media_type.construct s.types.error e in
                Cohttp_lwt.Body.of_string s,
                Transfer.Fixed (Int64.of_int (String.length s)) in
          let headers = Header.init () in
          let headers =
            Header.add headers "content-type" output_content_type in
          let headers = Cors.add_allow_origin
              headers server.cors (Header.get req_headers "origin") in
          begin
            match s.types.input with
            | Service.No_input ->
                s.handler query () >>= Lwt.return_ok
            | Service.Input input ->
                Cohttp_lwt.Body.to_string body >>= fun body ->
                match
                  input_media_type.destruct input body
                with
                | Error s ->
                    Lwt.return_error (`Cannot_parse_body s)
                | Ok body ->
                    s.handler query body >>= Lwt.return_ok
          end >>=? function
          | `Ok o ->
              let body = output o in
              let encoding =
                Transfer.Fixed (Int64.of_int (String.length body)) in
              Lwt.return_ok
                (Response.make ~status:`OK ~encoding ~headers (),
                 Cohttp_lwt.Body.of_string body)
          | `OkStream o ->
              let body = create_stream server con output o in
              let encoding = Transfer.Chunked in
              Lwt.return_ok
                (Response.make ~status:`OK ~encoding ~headers (),
                 Cohttp_lwt.Body.of_stream body)
          | `Created s ->
              let headers = Header.init () in
              let headers =
                match s with
                | None -> headers
                | Some s -> Header.add headers "location" s in
              Lwt.return_ok
                (Response.make ~status:`Created  ~headers (),
                 Cohttp_lwt.Body.empty)
          | `No_content ->
              Lwt.return_ok
                (Response.make ~status:`No_content (),
                 Cohttp_lwt.Body.empty)
          | `Unauthorized e ->
              let body, encoding = error e in
              let status = `Unauthorized in
              Lwt.return_ok
                (Response.make ~status ~encoding ~headers (), body)
          | `Forbidden e ->
              let body, encoding = error e in
              let status = `Forbidden in
              Lwt.return_ok
                (Response.make ~status ~encoding ~headers (), body)
          | `Not_found e ->
              let body, encoding = error e in
              let status = `Not_found in
              Lwt.return_ok
                (Response.make ~status ~encoding ~headers (), body)
          | `Conflict e ->
              let body, encoding = error e in
              let status = `Conflict in
              Lwt.return_ok
                (Response.make ~status ~encoding ~headers (), body)
          | `Error e ->
              let body, encoding = error e in
              let status = `Internal_server_error in
              Lwt.return_ok
                (Response.make ~status ~encoding ~headers (), body)
        end
      | `HEAD ->
          (* TODO ??? *)
          Lwt.return_error `Not_implemented
      | `OPTIONS ->
          let req_headers = Request.headers req in
          let origin_header = Header.get req_headers "origin" in
          begin
            (* Default OPTIONS handler for CORS preflight *)
            if origin_header = None then
              Directory.allowed_methods server.root () path
            else
              match Header.get req_headers
                      "Access-Control-Request-Method" with
              | None ->
                  Directory.allowed_methods server.root () path
              | Some meth ->
                  match Code.method_of_string meth with
                  | #Resto.meth as meth ->
                      Directory.lookup server.root () meth path >>=? fun _handler ->
                      Lwt.return_ok [ meth ]
                  | _ ->
                      Lwt.return_error `Not_found
          end >>=? fun cors_allowed_meths ->
          lwt_log_info "(%s) RPC preflight"
            (Connection.to_string con) >>= fun () ->
          let headers = Header.init () in
          let headers =
            Header.add_multi headers
              "Access-Control-Allow-Methods"
              (List.map Resto.string_of_meth cors_allowed_meths) in
          let headers = Cors.add_headers headers server.cors origin_header in
          Lwt.return_ok
            (Response.make ~flush:true ~status:`OK ~headers (),
             Cohttp_lwt.Body.empty)
      | _ ->
          Lwt.return_error `Not_implemented
    end >>= function
    | Ok answer -> Lwt.return answer
    | Error `Not_implemented ->
        Lwt.return
          (Response.make ~status:`Not_implemented (),
           Cohttp_lwt.Body.empty)
    | Error `Method_not_allowed methods ->
        let headers = Header.init () in
        let headers =
          Header.add_multi headers "allow"
            (List.map Resto.string_of_meth methods) in
        Lwt.return
          (Response.make ~status:`Method_not_allowed ~headers (),
           Cohttp_lwt.Body.empty)
    | Error `Cannot_parse_path (context, arg, value) ->
        let headers = Header.init () in
        let headers =
          Header.add headers "content-type" "text/plain" in
        Lwt.return
          (Response.make ~status:`Bad_request ~headers (),
           Format.kasprintf Cohttp_lwt.Body.of_string
             "Failed to parsed an argument in path. After \"%s\", \
              the value \"%s\" is not acceptable for type \"%s\""
             (String.concat "/" context) value arg.name)
    | Error `Cannot_parse_body s ->
        let headers = Header.init () in
        let headers =
          Header.add headers "content-type" "text/plain" in
        Lwt.return
          (Response.make ~status:`Bad_request ~headers (),
           Format.kasprintf Cohttp_lwt.Body.of_string
             "Failed to parse the request body: %s" s)
    | Error `Cannot_parse_query s ->
        let headers = Header.init () in
        let headers =
          Header.add headers "content-type" "text/plain" in
        Lwt.return
          (Response.make ~status:`Bad_request ~headers (),
           Format.kasprintf Cohttp_lwt.Body.of_string
             "Failed to parse the query string: %s" s)
    | Error `Not_acceptable ->
        let accepted_encoding =
          Media_type.acceptable_encoding server.media_types in
        Lwt.return
          (Response.make ~status:`Not_acceptable (),
           Cohttp_lwt.Body.of_string accepted_encoding)
    | Error `Unsupported_media_type _ ->
        Lwt.return
          (Response.make ~status:`Unsupported_media_type (),
           Cohttp_lwt.Body.empty)
    | Error `Not_found ->
        Lwt.return
          (Response.make ~status:`Not_found (),
           Cohttp_lwt.Body.empty)

  (* Promise a running RPC server. *)

  let launch
      ?(host="::")
      ?(cors = Cors.default)
      ~media_types
      mode root =
    let default_media_type =
      match Media_type.first_complete_media media_types with
      | None -> invalid_arg "RestoCohttp.launch(empty media type list)"
      | Some ((l, r), m) -> l^"/"^r, m in
    let stop, stopper = Lwt.wait () in
    let server = {
      root ;
      streams = ConnectionMap.empty ;
      cors ;
      media_types ;
      default_media_type ;
      stopper ;
      worker = Lwt.return_unit ;
    } in
    Conduit_lwt_unix.init ~src:host () >>= fun ctx ->
    let ctx = Cohttp_lwt_unix.Net.init ~ctx () in
    server.worker <- begin
      let conn_closed (_, con) =
        log_info "connection closed %s" (Connection.to_string con) ;
        try ConnectionMap.find con server.streams ()
        with Not_found -> ()
      and on_exn = function
        | Unix.Unix_error (Unix.EADDRINUSE, "bind", _) ->
            log_error "RPC server port already taken, \
                       the node will be shutdown" ;
            exit 1
        | Unix.Unix_error (ECONNRESET, _, _)
        | Unix.Unix_error (EPIPE, _, _)  -> ()
        | exn -> !Lwt.async_exception_hook exn
      and callback (io, con) req body =
        Lwt.catch
          begin fun () -> callback server (io, con) req body end
          begin function
            | Not_found ->
                let status = `Not_found in
                let body = Cohttp_lwt.Body.empty in
                Lwt.return (Response.make ~status (), body)
            | exn ->
                let headers = Header.init () in
                let headers =
                  Header.add headers "content-type" "text/ocaml.exception" in
                let status = `Internal_server_error in
                let body = Cohttp_lwt.Body.of_string (Printexc.to_string exn) in
                Lwt.return (Response.make ~status ~headers (), body)
          end
      in
      Cohttp_lwt_unix.Server.create ~stop ~ctx ~mode ~on_exn
        (Cohttp_lwt_unix.Server.make ~callback ~conn_closed ())
    end ;
    Lwt.return server

  let shutdown server =
    Lwt.wakeup_later server.stopper () ;
    server.worker >>= fun () ->
    ConnectionMap.iter (fun _ f -> f ()) server.streams ;
    Lwt.return_unit

end