diff --git a/src/lib_shell/chain.ml b/src/lib_shell/chain.ml index 9cc05acb7..f168d71af 100644 --- a/src/lib_shell/chain.ml +++ b/src/lib_shell/chain.ml @@ -39,6 +39,7 @@ type data = State.chain_data = { current_mempool: Mempool.t ; live_blocks: Block_hash.Set.t ; live_operations: Operation_hash.Set.t ; + test_chain: Chain_id.t option ; } let data chain_state = @@ -87,6 +88,7 @@ let locked_set_head chain_store data block = current_mempool = Mempool.empty ; live_blocks ; live_operations ; + test_chain = None ; } let set_head chain_state block = diff --git a/src/lib_shell/chain.mli b/src/lib_shell/chain.mli index 481168c8e..13bd60391 100644 --- a/src/lib_shell/chain.mli +++ b/src/lib_shell/chain.mli @@ -23,6 +23,7 @@ type data = { current_mempool: Mempool.t ; live_blocks: Block_hash.Set.t ; live_operations: Operation_hash.Set.t ; + test_chain: Chain_id.t option ; } (** Reading atomically all the chain data. *) diff --git a/src/lib_shell/node.ml b/src/lib_shell/node.ml index ef0805402..1a19ce319 100644 --- a/src/lib_shell/node.ml +++ b/src/lib_shell/node.ml @@ -127,14 +127,13 @@ let create { genesis ; store_root ; context_root ; chain_validator_limits = init_p2p p2p_params >>=? fun p2p -> State.read - ~store_root ~context_root ?patch_context () >>=? fun state -> + ~store_root ~context_root ?patch_context genesis >>=? fun (state, mainchain_state) -> let distributed_db = Distributed_db.create state p2p in Validator.create state distributed_db peer_validator_limits block_validator_limits prevalidator_limits chain_validator_limits >>= fun validator -> - may_create_chain state genesis >>= fun mainchain_state -> Validator.activate validator ?max_child_ttl mainchain_state >>= fun mainchain_validator -> let shutdown () = diff --git a/src/lib_shell/state.ml b/src/lib_shell/state.ml index 01d1ec508..818747d04 100644 --- a/src/lib_shell/state.ml +++ b/src/lib_shell/state.ml @@ -23,6 +23,7 @@ end type global_state = { global_data: global_data Shared.t ; protocol_store: Store.Protocol.store Shared.t ; + main_chain: Chain_id.t ; } and global_data = { @@ -60,6 +61,7 @@ and chain_data = { current_mempool: Mempool.t ; live_blocks: Block_hash.Set.t ; live_operations: Operation_hash.Set.t ; + test_chain: Chain_id.t option ; } and block = { @@ -242,6 +244,12 @@ module Chain = struct type t = chain_state type chain_state = t + let main { main_chain } = main_chain + let test chain_state = + read_chain_data chain_state begin fun _ chain_data -> + Lwt.return chain_data.test_chain + end + let allocate ~genesis ~faked_genesis_hash ~expiration ~allow_forked_chain ~current_head @@ -258,6 +266,7 @@ module Chain = struct current_mempool = Mempool.empty ; live_blocks = Block_hash.Set.singleton genesis.block ; live_operations = Operation_hash.Set.empty ; + test_chain = None ; } ; chain_data_store ; } @@ -367,12 +376,18 @@ module Chain = struct locked_read_all state data end - let get state id = + let get_exn state id = Shared.use state.global_data begin fun data -> - try return (Chain_id.Table.find data.chains id) - with Not_found -> fail (Unknown_chain id) + Lwt.return (Chain_id.Table.find data.chains id) end + let get state id = + Lwt.catch + (fun () -> get_exn state id >>= return) + (function + | Not_found -> fail (Unknown_chain id) + | exn -> Lwt.fail exn) + let all state = Shared.use state.global_data begin fun { chains } -> Lwt.return @@ @@ -706,6 +721,9 @@ let fork_testchain block protocol expiration = } in Chain.locked_create block.chain_state.global_state data chain_id ~expiration genesis commit >>= fun chain -> + update_chain_data block.chain_state begin fun _ chain_data -> + Lwt.return (Some { chain_data with test_chain = Some chain.chain_id }, ()) + end >>= fun () -> return chain end @@ -793,11 +811,16 @@ module Current_mempool = struct end +let may_create_chain state chain genesis = + Chain.get state chain >>= function + | Ok chain -> Lwt.return chain + | Error _ -> Chain.create state genesis + let read ?patch_context ~store_root ~context_root - () = + genesis = Store.init store_root >>=? fun global_store -> Context.init ?patch_context ~root:context_root >>= fun context_index -> let global_data = { @@ -805,12 +828,15 @@ let read global_store ; context_index ; } in + let main_chain = Chain_id.of_block_hash genesis.Chain.block in let state = { global_data = Shared.create global_data ; protocol_store = Shared.create @@ Store.Protocol.get global_store ; + main_chain ; } in Chain.read_all state >>=? fun () -> - return state + may_create_chain state main_chain genesis >>= fun main_chain_state -> + return (state, main_chain_state) let close { global_data } = Shared.use global_data begin fun { global_store } -> diff --git a/src/lib_shell/state.mli b/src/lib_shell/state.mli index 18fc33a07..a419d20a8 100644 --- a/src/lib_shell/state.mli +++ b/src/lib_shell/state.mli @@ -19,18 +19,6 @@ type t type global_state = t -(** Read the internal state of the node and initialize - the databases. *) -val read: - ?patch_context:(Context.t -> Context.t Lwt.t) -> - store_root:string -> - context_root:string -> - unit -> - global_state tzresult Lwt.t - -val close: - global_state -> unit Lwt.t - (** {2 Network} ************************************************************) (** Data specific to a given chain (e.g the main chain or the current @@ -58,6 +46,10 @@ module Chain : sig (** Look up for a chain by the hash of its genesis block. *) val get: global_state -> Chain_id.t -> chain_state tzresult Lwt.t + val get_exn: global_state -> Chain_id.t -> chain_state Lwt.t + + val main: global_state -> Chain_id.t + val test: chain_state -> Chain_id.t option Lwt.t (** Returns all the known chains. *) val all: global_state -> chain_state list Lwt.t @@ -174,6 +166,7 @@ type chain_data = { current_mempool: Mempool.t ; live_blocks: Block_hash.Set.t ; live_operations: Operation_hash.Set.t ; + test_chain: Chain_id.t option ; } val read_chain_data: @@ -224,3 +217,15 @@ module Current_mempool : sig not the provided one. *) end + +(** Read the internal state of the node and initialize + the databases. *) +val read: + ?patch_context:(Context.t -> Context.t Lwt.t) -> + store_root:string -> + context_root:string -> + Chain.genesis -> + (global_state * Chain.t) tzresult Lwt.t + +val close: + global_state -> unit Lwt.t diff --git a/src/lib_shell/test/test_locator.ml b/src/lib_shell/test/test_locator.ml index 3d4cd3ede..f9d0cb19c 100644 --- a/src/lib_shell/test/test_locator.ml +++ b/src/lib_shell/test/test_locator.ml @@ -54,10 +54,11 @@ let incr_fitness fitness = let init_chain base_dir : State.Chain.t Lwt.t = let store_root = base_dir // "store" in let context_root = base_dir // "context" in - State.read ~store_root ~context_root () >>= function + State.read + ~store_root ~context_root state_genesis_block >>= function | Error _ -> Pervasives.failwith "read err" - | Ok (state:State.global_state) -> - State.Chain.create state state_genesis_block + | Ok (_state, chain) -> + Lwt.return chain let block_header diff --git a/src/lib_shell/test/test_state.ml b/src/lib_shell/test/test_state.ml index 4acce23dd..0e36da8c0 100644 --- a/src/lib_shell/test/test_state.ml +++ b/src/lib_shell/test/test_state.ml @@ -127,7 +127,6 @@ type state = { vblock: (string, State.Block.t) Hashtbl.t ; state: State.t ; chain: State.Chain.t ; - init: unit -> State.t tzresult Lwt.t; } let vblock s = Hashtbl.find s.vblock @@ -142,15 +141,12 @@ let wrap_state_init f base_dir = begin let store_root = base_dir // "store" in let context_root = base_dir // "context" in - let init () = - State.read - ~store_root - ~context_root - () in - init () >>=? fun state -> - State.Chain.create state genesis >>= fun chain -> + State.read + ~store_root + ~context_root + genesis >>=? fun (state, chain) -> build_example_tree chain >>= fun vblock -> - f { state ; chain ; vblock ; init } >>=? fun () -> + f { state ; chain ; vblock } >>=? fun () -> return () end