diff --git a/src/lib_client_base_unix/client_context_unix.ml b/src/lib_client_base_unix/client_context_unix.ml index 68d8d80de..bc4171d68 100644 --- a/src/lib_client_base_unix/client_context_unix.ml +++ b/src/lib_client_base_unix/client_context_unix.ml @@ -11,61 +11,61 @@ open Client_context class unix_wallet ~base_dir : wallet = object (self) - method private filename alias_name = - Filename.concat - base_dir - (Str.(global_replace (regexp_string " ") "_" alias_name) ^ "s") + method private filename alias_name = + Filename.concat + base_dir + (Str.(global_replace (regexp_string " ") "_" alias_name) ^ "s") - method with_lock : type a. ( unit -> a Lwt.t) -> a Lwt.t = - (fun f -> - let unlock fd = - let fd = Lwt_unix.unix_file_descr fd in - Unix.lockf fd Unix.F_ULOCK 0; - Unix.close fd - in - let lock () = - Lwt_unix.openfile (Filename.concat base_dir "wallet_lock") - Lwt_unix.[O_CREAT; O_WRONLY] 0o644 >>= fun fd -> - Lwt_unix.lockf fd Unix.F_LOCK 0 >>= fun () -> - Lwt.return (fd,(Lwt_unix.on_signal Sys.sigint - (fun _s -> - unlock fd; - exit 0 (* exit code? *) ))) - in - lock () >>= fun (fd,sh) -> - (* catch might be useless if f always uses the error monad *) - Lwt.catch f (function e -> Lwt.return (unlock fd; raise e)) >>= fun res -> - Lwt.return (unlock fd) >>= fun () -> - Lwt_unix.disable_signal_handler sh; - Lwt.return res) + method with_lock : type a. ( unit -> a Lwt.t) -> a Lwt.t = + (fun f -> + let unlock fd = + let fd = Lwt_unix.unix_file_descr fd in + Unix.lockf fd Unix.F_ULOCK 0; + Unix.close fd + in + let lock () = + Lwt_unix.openfile (Filename.concat base_dir "wallet_lock") + Lwt_unix.[O_CREAT; O_WRONLY] 0o644 >>= fun fd -> + Lwt_unix.lockf fd Unix.F_LOCK 0 >>= fun () -> + Lwt.return (fd,(Lwt_unix.on_signal Sys.sigint + (fun _s -> + unlock fd; + exit 0 (* exit code? *) ))) + in + lock () >>= fun (fd,sh) -> + (* catch might be useless if f always uses the error monad *) + Lwt.catch f (function e -> Lwt.return (unlock fd; raise e)) >>= fun res -> + Lwt.return (unlock fd) >>= fun () -> + Lwt_unix.disable_signal_handler sh; + Lwt.return res) - method load : type a. string -> default:a -> a Data_encoding.encoding -> a tzresult Lwt.t = - fun alias_name ~default encoding -> - let filename = self#filename alias_name in - if not (Sys.file_exists filename) then - return default - else - Lwt_utils_unix.Json.read_file filename - |> generic_trace - "could not read the %s alias file" alias_name >>=? fun json -> - match Data_encoding.Json.destruct encoding json with - | exception _ -> (* TODO print_error *) - failwith "did not understand the %s alias file" alias_name - | data -> - return data + method load : type a. string -> default:a -> a Data_encoding.encoding -> a tzresult Lwt.t = + fun alias_name ~default encoding -> + let filename = self#filename alias_name in + if not (Sys.file_exists filename) then + return default + else + Lwt_utils_unix.Json.read_file filename + |> generic_trace + "could not read the %s alias file" alias_name >>=? fun json -> + match Data_encoding.Json.destruct encoding json with + | exception _ -> (* TODO print_error *) + failwith "did not understand the %s alias file" alias_name + | data -> + return data - method write : - type a. string -> a -> a Data_encoding.encoding -> unit tzresult Lwt.t = - fun alias_name list encoding -> - Lwt.catch - (fun () -> - Lwt_utils_unix.create_dir base_dir >>= fun () -> - let filename = self#filename alias_name in - let json = Data_encoding.Json.construct encoding list in - Lwt_utils_unix.Json.write_file filename json) - (fun exn -> Lwt.return (error_exn exn)) - |> generic_trace "could not write the %s alias file." alias_name - end + method write : + type a. string -> a -> a Data_encoding.encoding -> unit tzresult Lwt.t = + fun alias_name list encoding -> + Lwt.catch + (fun () -> + Lwt_utils_unix.create_dir base_dir >>= fun () -> + let filename = self#filename alias_name in + let json = Data_encoding.Json.construct encoding list in + Lwt_utils_unix.Json.write_file filename json) + (fun exn -> Lwt.return (error_exn exn)) + |> generic_trace "could not write the %s alias file." alias_name +end class unix_prompter = object method prompt : type a. (a, string tzresult) lwt_format -> a = diff --git a/src/proto_alpha/lib_baking/client_baking_nonces.ml b/src/proto_alpha/lib_baking/client_baking_nonces.ml index 4daa8ad33..10bcd829b 100644 --- a/src/proto_alpha/lib_baking/client_baking_nonces.ml +++ b/src/proto_alpha/lib_baking/client_baking_nonces.ml @@ -37,26 +37,26 @@ let mem (wallet : #Client_context.wallet) block_hash = let find (wallet : #Client_context.wallet) block_hash = wallet#with_lock ( fun () -> - load wallet >>|? fun data -> - try Some (List.assoc block_hash data) - with Not_found -> None) + load wallet >>|? fun data -> + try Some (List.assoc block_hash data) + with Not_found -> None) let add (wallet : #Client_context.wallet) block_hash nonce = wallet#with_lock ( fun () -> - load wallet >>=? fun data -> - save wallet ((block_hash, nonce) :: - List.remove_assoc block_hash data)) + load wallet >>=? fun data -> + save wallet ((block_hash, nonce) :: + List.remove_assoc block_hash data)) let del (wallet : #Client_context.wallet) block_hash = wallet#with_lock ( fun () -> - load wallet >>=? fun data -> - save wallet (List.remove_assoc block_hash data)) + load wallet >>=? fun data -> + save wallet (List.remove_assoc block_hash data)) let dels (wallet : #Client_context.wallet) hashes = wallet#with_lock ( fun () -> - load wallet >>=? fun data -> - save wallet @@ - List.fold_left - (fun data hash -> List.remove_assoc hash data) - data hashes) + load wallet >>=? fun data -> + save wallet @@ + List.fold_left + (fun data hash -> List.remove_assoc hash data) + data hashes) diff --git a/src/proto_alpha/lib_baking/test/proto_alpha_helpers.ml b/src/proto_alpha/lib_baking/test/proto_alpha_helpers.ml index 36e596679..16eb5e8e6 100644 --- a/src/proto_alpha/lib_baking/test/proto_alpha_helpers.ml +++ b/src/proto_alpha/lib_baking/test/proto_alpha_helpers.ml @@ -37,6 +37,7 @@ let no_write_context ?(block = `Head 0) config : #Client_context.full = object a -> a Data_encoding.encoding -> unit Error_monad.tzresult Lwt.t = fun _ _ _ -> return () + method with_lock : type a. (unit -> a Lwt.t) -> a Lwt.t = fun f -> f () method block = block method confirmations = None method prompt : type a. (a, string tzresult) Client_context.lwt_format -> a =