diff --git a/src/lib_shell/mempool_worker.ml b/src/lib_shell/mempool_worker.ml index f26bb792e..40a3bedb8 100644 --- a/src/lib_shell/mempool_worker.ml +++ b/src/lib_shell/mempool_worker.ml @@ -51,7 +51,7 @@ module type T = sig val shutdown : t -> unit Lwt.t (** parse a new operation and add it to the mempool context *) - val parse : t -> Operation.t -> operation tzresult + val parse : Operation.t -> operation tzresult (** validate a new operation and add it to the mempool context *) val validate : t -> operation -> result tzresult Lwt.t @@ -227,62 +227,58 @@ module Make(Static: STATIC)(Proto: Registered_protocol.T) (* parsed operations' cache. used for memoization *) module ParsedCache = struct - type t = { - table: operation tzresult Operation_hash.Table.t ; - ring: Operation_hash.t Ring.t ; - } + type t = operation tzresult Operation_hash.Table.t - let create () : t = { - table = Operation_hash.Table.create Static.max_size_parsed_cache ; - ring = Ring.create Static.max_size_parsed_cache ; - } + let encoding = + (Operation_hash.Table.encoding + (Error_monad.result_encoding operation_encoding)) + + let create () : t = + Operation_hash.Table.create 1000 let add t raw_op parsed_op = let hash = Operation.hash raw_op in - Option.iter - ~f:(Operation_hash.Table.remove t.table) - (Ring.add_and_return_erased t.ring hash); - Operation_hash.Table.replace t.table hash parsed_op + Operation_hash.Table.replace t hash parsed_op + + let mem t raw_op = + let hash = Operation.hash raw_op in + Operation_hash.Table.mem t hash let find_opt t raw_op = let hash = Operation.hash raw_op in - Operation_hash.Table.find_opt t.table hash + Operation_hash.Table.find_opt t hash let find_hash_opt t hash = - Operation_hash.Table.find_opt t.table hash + Operation_hash.Table.find_opt t hash let rem t hash = - (* NOTE: hashes are not removed from the ring. As a result, the cache size - * bound can be lowered. This is a non-issue because it's only a cache. *) - Operation_hash.Table.remove t.table hash + Operation_hash.Table.remove t hash end (* validated operations' cache. used for memoization *) module ValidatedCache = struct - type t = (result * Operation.t) Operation_hash.Table.t + type t = result Operation_hash.Table.t let encoding = - let open Data_encoding in - Operation_hash.Table.encoding ( - tup2 - result_encoding - Operation.encoding - ) + Operation_hash.Table.encoding result_encoding let create () = Operation_hash.Table.create 1000 let add t parsed_op result = Operation_hash.Table.replace t parsed_op.hash result + let mem t parsed_op = + Operation_hash.Table.mem t parsed_op.hash + let find_opt t parsed_op = Operation_hash.Table.find_opt t parsed_op.hash let iter f t = Operation_hash.Table.iter f t - let to_mempool t = + let to_mempool t parsed_cache = let empty = { Proto_services.Mempool.applied = [] ; refused = Operation_hash.Map.empty ; @@ -297,40 +293,44 @@ module Make(Static: STATIC)(Proto: Registered_protocol.T) op.Operation.proto in { Proto.shell = op.shell ; protocol_data } in Operation_hash.Table.fold - (fun hash (result,raw_op) acc -> - let proto_op = map_op raw_op in - match result with - | Applied _ -> { - acc with - Proto_services.Mempool.applied = - (hash, proto_op)::acc.Proto_services.Mempool.applied - } - | Branch_refused err -> { - acc with - Proto_services.Mempool.branch_refused = - Operation_hash.Map.add - hash - (proto_op,err) - acc.Proto_services.Mempool.branch_refused - } - | Branch_delayed err -> { - acc with - Proto_services.Mempool.branch_delayed = - Operation_hash.Map.add - hash - (proto_op,err) - acc.Proto_services.Mempool.branch_delayed - } - | Refused err -> { - acc with - Proto_services.Mempool.refused = - Operation_hash.Map.add - hash - (proto_op,err) - acc.Proto_services.Mempool.refused - } - | _ -> acc - ) t empty + (fun hash result acc -> + match ParsedCache.find_hash_opt parsed_cache hash with + (* XXX this invariant should be better enforced *) + | None | Some (Error _) -> assert false + | Some (Ok op) -> begin + match result with + | Applied _ -> { + acc with + Proto_services.Mempool.applied = + (hash, map_op op.raw)::acc.Proto_services.Mempool.applied + } + | Branch_refused err -> { + acc with + Proto_services.Mempool.branch_refused = + Operation_hash.Map.add + hash + (map_op op.raw,err) + acc.Proto_services.Mempool.branch_refused + } + | Branch_delayed err -> { + acc with + Proto_services.Mempool.branch_delayed = + Operation_hash.Map.add + hash + (map_op op.raw,err) + acc.Proto_services.Mempool.branch_delayed + } + | Refused err -> { + acc with + Proto_services.Mempool.refused = + Operation_hash.Map.add + hash + (map_op op.raw,err) + acc.Proto_services.Mempool.refused + } + | _ -> acc + end) + t empty let clear t = Operation_hash.Table.clear t @@ -365,7 +365,7 @@ module Make(Static: STATIC)(Proto: Registered_protocol.T) parameters : parameters ; } - type view = { cache : Cache.t } + type view = { cache : ValidatedCache.t } let view (state : state) _ : view = { cache = state.cache } @@ -374,7 +374,7 @@ module Make(Static: STATIC)(Proto: Registered_protocol.T) conv (fun { cache } -> cache) (fun cache -> { cache }) - Cache.encoding + ValidatedCache.encoding let pp ppf _view = Format.fprintf ppf "lots of operations" @@ -491,10 +491,10 @@ module Make(Static: STATIC)(Proto: Registered_protocol.T) (* memoization is done only at on_* level *) let on_validate w parsed_op = let state = Worker.state w in - match Cache.find_validated_opt state.cache parsed_op with + match ValidatedCache.find_opt state.cache parsed_op with | None | Some (Branch_delayed _) -> validate_helper w parsed_op >>= fun result -> - Cache.add_validated state.cache parsed_op result; + ValidatedCache.add state.cache parsed_op result; (* operations are notified only the first time *) notify_helper w result parsed_op.raw ; Lwt.return result @@ -511,6 +511,10 @@ module Make(Static: STATIC)(Proto: Registered_protocol.T) Chain.data chain_state >>= fun { current_mempool = _mempool ; live_blocks ; live_operations } -> + (* remove all operations that are already included *) + Operation_hash.Set.iter (fun hash -> + ParsedCache.rem parsed_cache hash + ) live_operations; Lwt.return { validation_state ; cache = ValidatedCache.create () ; @@ -523,11 +527,11 @@ module Make(Static: STATIC)(Proto: Registered_protocol.T) let on_close w = let state = Worker.state w in Lwt_watcher.shutdown_input state.operation_stream; - Cache.iter_validated (fun hash _ -> + ValidatedCache.iter (fun hash _ -> Distributed_db.Operation.clear_or_cancel state.parameters.chain_db hash) state.cache ; - Cache.clear state.cache; + ValidatedCache.clear state.cache; Lwt.return_unit let on_error w r st errs = @@ -568,12 +572,11 @@ module Make(Static: STATIC)(Proto: Registered_protocol.T) Worker.push_request_and_wait t (Request.Validate parsed_op) (* atomic parse + memoization *) - let parse t raw_op = - let state = Worker.state t in - begin match Cache.find_parsed_opt state.cache raw_op with + let parse raw_op = + begin match ParsedCache.find_opt parsed_cache raw_op with | None -> - let parsed_op = parse_helper t raw_op in - Cache.add_parsed state.cache raw_op parsed_op; + let parsed_op = parse_helper raw_op in + ParsedCache.add parsed_cache raw_op parsed_op; parsed_op | Some parsed_op -> parsed_op end @@ -588,7 +591,7 @@ module Make(Static: STATIC)(Proto: Registered_protocol.T) (Proto_services.S.Mempool.pending_operations RPC_path.open_root) (fun w () () -> let state = Worker.state w in - RPC_answer.return (Cache.to_mempool state.cache) + RPC_answer.return (ValidatedCache.to_mempool state.cache parsed_cache) ) let monitor_rpc_directory : t RPC_directory.t = diff --git a/src/lib_shell/mempool_worker.mli b/src/lib_shell/mempool_worker.mli index a485641b7..f5c33e8c9 100644 --- a/src/lib_shell/mempool_worker.mli +++ b/src/lib_shell/mempool_worker.mli @@ -51,8 +51,8 @@ module type T = sig val create : limits -> Distributed_db.chain_db -> t tzresult Lwt.t val shutdown : t -> unit Lwt.t - (** parse a new operation and add it to the mempool context *) - val parse : t -> Operation.t -> operation tzresult + (** parse a new operation *) + val parse : Operation.t -> operation tzresult (** validate a new operation and add it to the mempool context *) val validate : t -> operation -> result tzresult Lwt.t