diff --git a/src/lib_client_base/client_context.ml b/src/lib_client_base/client_context.ml index bd8566f47..ee2203375 100644 --- a/src/lib_client_base/client_context.ml +++ b/src/lib_client_base/client_context.ml @@ -49,6 +49,7 @@ class simple_printer log = end class type wallet = object + method with_lock : ( unit -> 'a Lwt.t) -> 'a Lwt.t method load : string -> default:'a -> 'a Data_encoding.encoding -> 'a tzresult Lwt.t method write : string -> 'a -> 'a Data_encoding.encoding -> unit tzresult Lwt.t end @@ -95,6 +96,7 @@ class proxy_context (obj : full) = object 'p -> 'q -> 'i -> (unit -> unit) tzresult Lwt.t = obj#call_streamed_service method error : type a b. (a, b) lwt_format -> a = obj#error method generic_json_call = obj#generic_json_call + method with_lock : type a. ( unit -> a Lwt.t) -> a Lwt.t = obj#with_lock method load : type a. string -> default:a -> a Data_encoding.encoding -> a tzresult Lwt.t = obj#load method log : type a. string -> (a, unit) lwt_format -> a = obj#log method message : type a. (a, unit) lwt_format -> a = obj#message diff --git a/src/lib_client_base/client_context.mli b/src/lib_client_base/client_context.mli index 0e62cdf42..2341db9fa 100644 --- a/src/lib_client_base/client_context.mli +++ b/src/lib_client_base/client_context.mli @@ -29,6 +29,7 @@ class type io = object end class type wallet = object + method with_lock : ( unit -> 'a Lwt.t) -> 'a Lwt.t method load : string -> default:'a -> 'a Data_encoding.encoding -> 'a tzresult Lwt.t method write : string -> 'a -> 'a Data_encoding.encoding -> unit tzresult Lwt.t end diff --git a/src/lib_client_base_unix/client_context_unix.ml b/src/lib_client_base_unix/client_context_unix.ml index 3485b371f..68d8d80de 100644 --- a/src/lib_client_base_unix/client_context_unix.ml +++ b/src/lib_client_base_unix/client_context_unix.ml @@ -10,38 +10,62 @@ open Client_context class unix_wallet ~base_dir : wallet = object (self) - method private filename alias_name = - Filename.concat - base_dir - (Str.(global_replace (regexp_string " ") "_" alias_name) ^ "s") - method load : type a. string -> default:a -> a Data_encoding.encoding -> a tzresult Lwt.t = - fun alias_name ~default encoding -> - let filename = self#filename alias_name in - if not (Sys.file_exists filename) then - return default - else - Lwt_utils_unix.Json.read_file filename - |> generic_trace - "couldn't to read the %s file" alias_name >>=? fun json -> - match Data_encoding.Json.destruct encoding json with - | exception _ -> (* TODO print_error *) - failwith "didn't understand the %s file" alias_name - | data -> - return data + method private filename alias_name = + Filename.concat + base_dir + (Str.(global_replace (regexp_string " ") "_" alias_name) ^ "s") - method write : - type a. string -> a -> a Data_encoding.encoding -> unit tzresult Lwt.t = - fun alias_name list encoding -> - Lwt.catch - (fun () -> - Lwt_utils_unix.create_dir base_dir >>= fun () -> - let filename = self#filename alias_name in - let json = Data_encoding.Json.construct encoding list in - Lwt_utils_unix.Json.write_file filename json) - (fun exn -> Lwt.return (error_exn exn)) - |> generic_trace "could not write the %s alias file." alias_name -end + method with_lock : type a. ( unit -> a Lwt.t) -> a Lwt.t = + (fun f -> + let unlock fd = + let fd = Lwt_unix.unix_file_descr fd in + Unix.lockf fd Unix.F_ULOCK 0; + Unix.close fd + in + let lock () = + Lwt_unix.openfile (Filename.concat base_dir "wallet_lock") + Lwt_unix.[O_CREAT; O_WRONLY] 0o644 >>= fun fd -> + Lwt_unix.lockf fd Unix.F_LOCK 0 >>= fun () -> + Lwt.return (fd,(Lwt_unix.on_signal Sys.sigint + (fun _s -> + unlock fd; + exit 0 (* exit code? *) ))) + in + lock () >>= fun (fd,sh) -> + (* catch might be useless if f always uses the error monad *) + Lwt.catch f (function e -> Lwt.return (unlock fd; raise e)) >>= fun res -> + Lwt.return (unlock fd) >>= fun () -> + Lwt_unix.disable_signal_handler sh; + Lwt.return res) + + method load : type a. string -> default:a -> a Data_encoding.encoding -> a tzresult Lwt.t = + fun alias_name ~default encoding -> + let filename = self#filename alias_name in + if not (Sys.file_exists filename) then + return default + else + Lwt_utils_unix.Json.read_file filename + |> generic_trace + "could not read the %s alias file" alias_name >>=? fun json -> + match Data_encoding.Json.destruct encoding json with + | exception _ -> (* TODO print_error *) + failwith "did not understand the %s alias file" alias_name + | data -> + return data + + method write : + type a. string -> a -> a Data_encoding.encoding -> unit tzresult Lwt.t = + fun alias_name list encoding -> + Lwt.catch + (fun () -> + Lwt_utils_unix.create_dir base_dir >>= fun () -> + let filename = self#filename alias_name in + let json = Data_encoding.Json.construct encoding list in + Lwt_utils_unix.Json.write_file filename json) + (fun exn -> Lwt.return (error_exn exn)) + |> generic_trace "could not write the %s alias file." alias_name + end class unix_prompter = object method prompt : type a. (a, string tzresult) lwt_format -> a =