diff --git a/src/lib_shell/mempool_worker.ml b/src/lib_shell/mempool_worker.ml index 8b81a1bb1..c242a823b 100644 --- a/src/lib_shell/mempool_worker.ml +++ b/src/lib_shell/mempool_worker.ml @@ -47,7 +47,7 @@ module type T = sig | Not_in_branch (** Creates/tear-down a new mempool validator context. *) - val create : limits -> Distributed_db.chain_db -> t Lwt.t + val create : limits -> Distributed_db.chain_db -> t tzresult Lwt.t val shutdown : t -> unit Lwt.t (** parse a new operation and add it to the mempool context *) @@ -56,7 +56,7 @@ module type T = sig (** validate a new operation and add it to the mempool context *) val validate : t -> operation -> result tzresult Lwt.t - val chain_db : t -> Distributed_db.chain_db tzresult + val chain_db : t -> Distributed_db.chain_db val rpc_directory : t RPC_directory.t @@ -352,11 +352,12 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct type parameters = { limits : limits ; - chain_db : Distributed_db.chain_db + chain_db : Distributed_db.chain_db ; + validation_state : Proto.validation_state ; } (* internal worker state *) - type worker_state = + type state = { (* state of the validator. this is updated at each apply_operation *) mutable validation_state : Proto.validation_state ; @@ -375,22 +376,17 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct parameters : parameters ; } - type state = worker_state tzresult - type worker_view = { cache : Cache.t } - type view = worker_view tzresult + type view = { cache : Cache.t } - let view (state : state) _ : view = - state >|? fun state -> { cache = state.cache } + let view (state : state) _ : view = { cache = state.cache } let encoding = let open Data_encoding in - Error_monad.result_encoding ( - conv - (fun { cache } -> cache) - (fun cache -> { cache }) - Cache.encoding - ) + conv + (fun { cache } -> cache) + (fun cache -> { cache }) + Cache.encoding let pp ppf _view = Format.fprintf ppf "lots of operations" @@ -484,7 +480,7 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct (* this function update the internal state of the worker *) let validate_helper w parsed_op = - Lwt.return (Worker.state w) >>=? fun state -> + let state = Worker.state w in apply_operation state parsed_op >>= fun (validation_state,result) -> begin match validation_state with @@ -494,19 +490,17 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct return result let notify_helper w result { Operation.shell ; proto } = - match Worker.state w with + let state = Worker.state w in (* this function is called by on_validate where we take care of the error *) - | Error _err -> () - | Ok state -> - let protocol_data = - Data_encoding.Binary.of_bytes_exn - Proto.operation_data_encoding - proto in - Lwt_watcher.notify state.operation_stream (result, shell, protocol_data) + let protocol_data = + Data_encoding.Binary.of_bytes_exn + Proto.operation_data_encoding + proto in + Lwt_watcher.notify state.operation_stream (result, shell, protocol_data) (* memoization is done only at on_* level *) let on_validate w parsed_op = - Lwt.return (Worker.state w) >>=? fun state -> + let state = Worker.state w in match Cache.find_validated_opt state.cache parsed_op with | None | Some (Branch_delayed _) -> validate_helper w parsed_op >>=? fun result -> @@ -517,7 +511,7 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct | Some result -> return result let on_parse w raw_op = - Lwt.return (Worker.state w) >>=? fun state -> + let state = Worker.state w in match Cache.find_parsed_opt state.cache raw_op with | None -> parse_helper w raw_op >>= fun parsed_op -> @@ -532,14 +526,10 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct | Request.Parse raw_op -> on_parse w raw_op | Request.Validate parsed_op -> on_validate w parsed_op - let on_launch (_ : t) (_ : Name.t) ( { chain_db } as parameters ) = + let on_launch (_ : t) (_ : Name.t) ( { chain_db ; validation_state } as parameters ) = let chain_state = Distributed_db.chain_state chain_db in - Chain.data chain_state >>= fun - { current_head = predecessor ; current_mempool = _mempool ; - live_blocks ; live_operations } -> - let timestamp = Time.now () in - create ~predecessor ~timestamp () >>=? fun validation_state -> - return { + Chain.data chain_state >>= fun { current_mempool = _mempool ; live_blocks ; live_operations } -> + Lwt.return { validation_state ; cache = Cache.create () ; live_blocks ; @@ -549,16 +539,14 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct } let on_close w = - match Worker.state w with - | Error _err -> Lwt.return_unit - | Ok state -> - Lwt_watcher.shutdown_input state.operation_stream; - Cache.iter_validated (fun hash _ -> - Distributed_db.Operation.clear_or_cancel - state.parameters.chain_db hash) - state.cache ; - Cache.clear state.cache; - Lwt.return_unit + let state = Worker.state w in + Lwt_watcher.shutdown_input state.operation_stream; + Cache.iter_validated (fun hash _ -> + Distributed_db.Operation.clear_or_cancel + state.parameters.chain_db hash) + state.cache ; + Cache.clear state.cache; + Lwt.return_unit let on_error w r st errs = Worker.record_event w (Event.Request (r, st, Some errs)) ; @@ -575,7 +563,6 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct let chain_id = State.Chain.id chain_state in let module Handlers = struct type self = t - let on_launch = on_launch let on_close = on_close let on_error = on_error @@ -583,12 +570,15 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct let on_no_request _ = return_unit let on_request = on_request end in - Worker.launch - table - limits.worker_limits - (chain_id, Proto.hash) - { limits ; chain_db } - (module Handlers) + Chain.data chain_state >>= fun { current_head = predecessor } -> + let timestamp = Time.now () in + create ~predecessor ~timestamp () >>=? fun validation_state -> + (Worker.launch + table + limits.worker_limits + (chain_id, Proto.hash) + { limits ; chain_db ; validation_state } + (module Handlers) >>= return) (* Exporting functions *) @@ -600,7 +590,6 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct let chain_db t = let state = Worker.state t in - state >|? fun state -> state.parameters.chain_db let pending_rpc_directory : t RPC_directory.t = @@ -608,9 +597,8 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct RPC_directory.empty (Proto_services.S.Mempool.pending_operations RPC_path.open_root) (fun w () () -> - match Worker.state w with - | Error err -> RPC_answer.fail err - | Ok state -> RPC_answer.return (Cache.to_mempool state.cache) + let state = Worker.state w in + RPC_answer.return (Cache.to_mempool state.cache) ) let monitor_rpc_directory : t RPC_directory.t = @@ -618,24 +606,22 @@ module Make(Proto: Registered_protocol.T) : T with module Proto = Proto = struct RPC_directory.empty (Proto_services.S.Mempool.monitor_operations RPC_path.open_root) (fun w params () -> - match Worker.state w with - | Error err -> RPC_answer.fail err - | Ok state -> - let filter_result = function - | Applied _ -> params#applied - | Refused _ -> params#branch_refused - | Branch_refused _ -> params#refused - | Branch_delayed _ -> params#branch_delayed - | _ -> false in + let state = Worker.state w in + let filter_result = function + | Applied _ -> params#applied + | Refused _ -> params#branch_refused + | Branch_refused _ -> params#refused + | Branch_delayed _ -> params#branch_delayed + | _ -> false in - let op_stream, stopper = Lwt_watcher.create_stream state.operation_stream in - let shutdown () = Lwt_watcher.shutdown stopper in - let next () = - Lwt_stream.get op_stream >>= function - | Some (kind, shell, protocol_data) when filter_result kind -> - Lwt.return_some [ { Proto.shell ; protocol_data } ] - | _ -> Lwt.return_none in - RPC_answer.return_stream { next ; shutdown } + let op_stream, stopper = Lwt_watcher.create_stream state.operation_stream in + let shutdown () = Lwt_watcher.shutdown stopper in + let next () = + Lwt_stream.get op_stream >>= function + | Some (kind, shell, protocol_data) when filter_result kind -> + Lwt.return_some [ { Proto.shell ; protocol_data } ] + | _ -> Lwt.return_none in + RPC_answer.return_stream { next ; shutdown } ) (* /mempool//pending diff --git a/src/lib_shell/mempool_worker.mli b/src/lib_shell/mempool_worker.mli index ce528d7fe..d113f4497 100644 --- a/src/lib_shell/mempool_worker.mli +++ b/src/lib_shell/mempool_worker.mli @@ -48,7 +48,7 @@ module type T = sig | Not_in_branch (** Creates/tear-down a new mempool validator context. *) - val create : limits -> Distributed_db.chain_db -> t Lwt.t + val create : limits -> Distributed_db.chain_db -> t tzresult Lwt.t val shutdown : t -> unit Lwt.t (** parse a new operation and add it to the mempool context *) @@ -57,7 +57,7 @@ module type T = sig (** validate a new operation and add it to the mempool context *) val validate : t -> operation -> result tzresult Lwt.t - val chain_db : t -> Distributed_db.chain_db tzresult + val chain_db : t -> Distributed_db.chain_db val rpc_directory : t RPC_directory.t