diff --git a/src/lib_shell/prevalidator.ml b/src/lib_shell/prevalidator.ml index 841efd887..8d70b1ce0 100644 --- a/src/lib_shell/prevalidator.ml +++ b/src/lib_shell/prevalidator.ml @@ -320,6 +320,16 @@ module Make(Proto: Registered_protocol.T)(Arg: ARG): T = struct Worker.push_request_now w Advertise ; Lwt.return_unit) + let is_endorsement ( op : Prevalidation.operation ) = + Proto.acceptable_passes { + shell = op.raw.shell ; + protocol_data = op.protocol_data } = [0] + + let is_endorsement_raw op = + match Prevalidation.parse op with + |Ok op -> is_endorsement op + |Error _ -> false + let handle_unprocessed w pv = begin match pv.validation_state with | Error err -> @@ -358,8 +368,9 @@ module Make(Proto: Registered_protocol.T)(Arg: ARG): T = struct | Ok op -> Prevalidation.apply_operation state op >>= function | Applied (new_acc_validation_state, _) -> - if pv.applied_count <= 2000 (* this test is a quick fix while we wait for the new mempool *) - || Proto.acceptable_passes { shell = op.raw.shell ; protocol_data = op.protocol_data } = [0] then begin + if pv.applied_count <= 2000 + (* this test is a quick fix while we wait for the new mempool *) + || is_endorsement op then begin notify_operation pv `Applied op.raw ; let new_mempool = Mempool.{ acc_mempool with known_valid = op.hash :: acc_mempool.known_valid } in pv.applied <- (op.hash, op.raw) :: pv.applied ; @@ -403,12 +414,19 @@ module Make(Proto: Registered_protocol.T)(Arg: ARG): T = struct List.rev_map fst pv.applied ; pending = Operation_hash.Map.fold - (fun k _ s -> Operation_hash.Set.add k s) + (fun k (op,_) s -> + if is_endorsement_raw op then + Operation_hash.Set.add k s + else s) pv.branch_delays @@ Operation_hash.Map.fold - (fun k _ s -> Operation_hash.Set.add k s) + (fun k (op,_) s -> + if is_endorsement_raw op then + Operation_hash.Set.add k s + else s) pv.branch_refusals @@ - Operation_hash.Set.empty } ; + Operation_hash.Set.empty + } ; State.Current_mempool.set (Distributed_db.chain_state pv.chain_db) ~head:(State.Block.hash pv.predecessor) pv.mempool >>= fun () -> Lwt.return_unit