From 5aed8b3ef5ce56e46d8903d43ae7ec157fec523a Mon Sep 17 00:00:00 2001 From: Ulf Norell Date: Mon, 13 May 2019 13:39:17 +0200 Subject: [PATCH 1/8] Check stateful annotations Functions must be annotated as `stateful` in order to - Update the contract state (using `put`) - Call `Chain.spend` or other primitive functions that cost tokens - Call an Oracle or AENS function that requires a signature - Make a remote call with a non-zero value - Construct a lambda calling a stateful function It does not need to be stateful to - Read the contract state - Call another contract with value=0, even when the remote function is stateful --- src/aeso_ast_infer_types.erl | 44 +++++++++++++++++++++---- test/aeso_abi_tests.erl | 2 +- test/aeso_compiler_tests.erl | 9 ++++++ test/contracts/basic_auth.aes | 2 +- test/contracts/bitcoin_auth.aes | 2 +- test/contracts/counter.aes | 2 +- test/contracts/dutch_auction.aes | 2 +- test/contracts/environment.aes | 4 +-- test/contracts/factorial.aes | 2 +- test/contracts/fundme.aes | 2 +- test/contracts/maps.aes | 24 +++++++------- test/contracts/oracles.aes | 26 +++++++-------- test/contracts/remote_call.aes | 2 +- test/contracts/simple_storage.aes | 2 +- test/contracts/spend_test.aes | 8 ++--- test/contracts/state_handling.aes | 18 +++++------ test/contracts/stateful.aes | 54 +++++++++++++++++++++++++++++++ test/contracts/variant_types.aes | 6 ++-- 18 files changed, 152 insertions(+), 59 deletions(-) create mode 100644 test/contracts/stateful.aes diff --git a/src/aeso_ast_infer_types.erl b/src/aeso_ast_infer_types.erl index 3f65c9b..be61478 100644 --- a/src/aeso_ast_infer_types.erl +++ b/src/aeso_ast_infer_types.erl @@ -104,6 +104,8 @@ , fields = #{} :: #{ name() => [field_info()] } %% fields are global , namespace = [] :: qname() , in_pattern = false :: boolean() + , stateful = false :: boolean() + , current_function = none :: none | aeso_syntax:id() }). -type env() :: #env{}. @@ -197,7 +199,7 @@ bind_state(Env) -> false -> {id, Ann, "event"} %% will cause type error if used(?) end, Env1 = bind_funs([{"state", State}, - {"put", {fun_t, Ann, [], [State], Unit}}], Env), + {"put", {type_sig, [stateful | Ann], [], [State], Unit}}], Env), %% We bind Chain.event in a local 'Chain' namespace. pop_scope( @@ -357,11 +359,12 @@ global_env() -> Pair = fun(A, B) -> {tuple_t, Ann, [A, B]} end, Fun = fun(Ts, T) -> {type_sig, Ann, [], Ts, T} end, Fun1 = fun(S, T) -> Fun([S], T) end, - TVar = fun(X) -> {tvar, Ann, "'" ++ X} end, + StateFun = fun(Ts, T) -> {type_sig, [stateful|Ann], [], Ts, T} end, + TVar = fun(X) -> {tvar, Ann, "'" ++ X} end, SignId = {id, Ann, "signature"}, SignDef = {bytes, Ann, <<0:64/unit:8>>}, Signature = {named_arg_t, Ann, SignId, SignId, {typed, Ann, SignDef, SignId}}, - SignFun = fun(Ts, T) -> {type_sig, Ann, [Signature], Ts, T} end, + SignFun = fun(Ts, T) -> {type_sig, [stateful|Ann], [Signature], Ts, T} end, TTL = {qid, Ann, ["Chain", "ttl"]}, Fee = Int, [A, Q, R, K, V] = lists:map(TVar, ["a", "q", "r", "k", "v"]), @@ -390,7 +393,7 @@ global_env() -> ChainScope = #scope { funs = MkDefs( %% Spend transaction. - [{"spend", Fun([Address, Int], Unit)}, + [{"spend", StateFun([Address, Int], Unit)}, %% Chain environment {"balance", Fun1(Address, Int)}, {"block_hash", Fun1(Int, Int)}, @@ -420,7 +423,7 @@ global_env() -> { funs = MkDefs( [{"register", SignFun([Address, Fee, TTL], Oracle(Q, R))}, {"query_fee", Fun([Oracle(Q, R)], Fee)}, - {"query", Fun([Oracle(Q, R), Q, Fee, TTL, TTL], Query(Q, R))}, + {"query", StateFun([Oracle(Q, R), Q, Fee, TTL, TTL], Query(Q, R))}, {"get_question", Fun([Oracle(Q, R), Query(Q, R)], Q)}, {"respond", SignFun([Oracle(Q, R), Query(Q, R), R], Unit)}, {"extend", SignFun([Oracle(Q, R), TTL], Unit)}, @@ -862,7 +865,9 @@ infer_letrec(Env, Defs) -> [print_typesig(S) || S <- TypeSigs], {TypeSigs, NewDefs}. -infer_letfun(Env, {letfun, Attrib, Fun = {id, NameAttrib, Name}, Args, What, Body}) -> +infer_letfun(Env0, {letfun, Attrib, Fun = {id, NameAttrib, Name}, Args, What, Body}) -> + Env = Env0#env{ stateful = aeso_syntax:get_ann(stateful, Attrib, false), + current_function = Fun }, check_unique_arg_names(Fun, Args), ArgTypes = [{ArgName, check_type(Env, arg_type(T))} || {arg, _, ArgName, T} <- Args], ExpectedType = check_type(Env, arg_type(What)), @@ -904,6 +909,7 @@ lookup_name(Env, As, Id, Options) -> {Id, fresh_uvar(As)}; {QId, {_, Ty}} -> Freshen = proplists:get_value(freshen, Options, false), + check_stateful(Env, Id, Ty), Ty1 = case Ty of {type_sig, _, _, _, _} -> freshen_type(typesig_to_fun_t(Ty)); _ when Freshen -> freshen_type(Ty); @@ -912,6 +918,23 @@ lookup_name(Env, As, Id, Options) -> {set_qname(QId, Id), Ty1} end. +check_stateful(#env{ stateful = false, current_function = Fun }, Id, Type = {type_sig, _, _, _, _}) -> + case aeso_syntax:get_ann(stateful, Type, false) of + false -> ok; + true -> + type_error({stateful_not_allowed, Id, Fun}) + end; +check_stateful(_Env, _Id, _Type) -> ok. + +%% Hack: don't allow passing the 'value' named arg if not stateful. This only +%% works since the user can't create functions with named arguments. +check_stateful_named_arg(#env{ stateful = false, current_function = Fun }, {id, _, "value"}, Default) -> + case Default of + {int, _, 0} -> ok; + _ -> type_error({value_arg_not_allowed, Default, Fun}) + end; +check_stateful_named_arg(_, _, _) -> ok. + check_expr(Env, Expr, Type) -> E = {typed, _, _, Type1} = infer_expr(Env, Expr), unify(Env, Type1, Type, {check_expr, Expr, Type1, Type}), @@ -1072,6 +1095,7 @@ infer_expr(Env, {lam, Attrs, Args, Body}) -> infer_named_arg(Env, NamedArgs, {named_arg, Ann, Id, E}) -> CheckedExpr = {typed, _, _, ArgType} = infer_expr(Env, E), + check_stateful_named_arg(Env, Id, E), add_named_argument_constraint( #named_argument_constraint{ args = NamedArgs, @@ -1821,7 +1845,7 @@ create_type_errors() -> ets_new(type_errors, [bag]). destroy_and_report_type_errors(Env) -> - Errors = ets_tab2list(type_errors), + Errors = lists:reverse(ets_tab2list(type_errors)), %% io:format("Type errors now: ~p\n", [Errors]), PPErrors = [ pp_error(unqualify(Env, Err)) || Err <- Errors ], ets_delete(type_errors), @@ -1940,6 +1964,12 @@ pp_error({namespace, _Pos, {con, Pos, Name}, _Def}) -> pp_error({repeated_arg, Fun, Arg}) -> io_lib:format("Repeated argument ~s to function ~s (at ~s).\n", [Arg, pp(Fun), pp_loc(Fun)]); +pp_error({stateful_not_allowed, Id, Fun}) -> + io_lib:format("Cannot reference stateful function ~s (at ~s)\nin the definition of non-stateful function ~s.\n", + [pp(Id), pp_loc(Id), pp(Fun)]); +pp_error({value_arg_not_allowed, Value, Fun}) -> + io_lib:format("Cannot pass non-zero value argument ~s (at ~s)\nin the definition of non-stateful function ~s.\n", + [pp_expr("", Value), pp_loc(Value), pp(Fun)]); pp_error(Err) -> io_lib:format("Unknown error: ~p\n", [Err]). diff --git a/test/aeso_abi_tests.erl b/test/aeso_abi_tests.erl index 5357c1d..2ac6480 100644 --- a/test/aeso_abi_tests.erl +++ b/test/aeso_abi_tests.erl @@ -160,7 +160,7 @@ oracle_test() -> permissive_literals_fail_test() -> Contract = "contract OracleTest =\n" - " function haxx(o : oracle(list(string), option(int))) =\n" + " stateful function haxx(o : oracle(list(string), option(int))) =\n" " Chain.spend(o, 1000000)\n", {error, <<"Type errors\nCannot unify", _/binary>>} = aeso_compiler:check_call(Contract, "haxx", ["#123"], []), diff --git a/test/aeso_compiler_tests.erl b/test/aeso_compiler_tests.erl index 485793d..62bee65 100644 --- a/test/aeso_compiler_tests.erl +++ b/test/aeso_compiler_tests.erl @@ -307,4 +307,13 @@ failing_contracts() -> " ak_2gx9MEFxKvY9vMG5YnqnXWv1hCsX7rgnfvBLJS4aQurustR1rt : address\n" "against the expected type\n" " bytes(32)">>]} + , {"stateful", + [<<"Cannot reference stateful function Chain.spend (at line 13, column 33)\nin the definition of non-stateful function fail1.">>, + <<"Cannot reference stateful function local_spend (at line 14, column 33)\nin the definition of non-stateful function fail2.">>, + <<"Cannot reference stateful function Chain.spend (at line 16, column 15)\nin the definition of non-stateful function fail3.">>, + <<"Cannot reference stateful function Chain.spend (at line 20, column 31)\nin the definition of non-stateful function fail4.">>, + <<"Cannot reference stateful function Chain.spend (at line 35, column 53)\nin the definition of non-stateful function fail5.">>, + <<"Cannot pass non-zero value argument 1000 (at line 48, column 55)\nin the definition of non-stateful function fail6.">>, + <<"Cannot pass non-zero value argument 1000 (at line 49, column 54)\nin the definition of non-stateful function fail7.">>, + <<"Cannot pass non-zero value argument 1000 (at line 52, column 17)\nin the definition of non-stateful function fail8.">>]} ]. diff --git a/test/contracts/basic_auth.aes b/test/contracts/basic_auth.aes index 3ec9e55..641c01e 100644 --- a/test/contracts/basic_auth.aes +++ b/test/contracts/basic_auth.aes @@ -4,7 +4,7 @@ contract BasicAuth = function init() = { nonce = 1, owner = Call.caller } - function authorize(n : int, s : signature) : bool = + stateful function authorize(n : int, s : signature) : bool = require(n >= state.nonce, "Nonce too low") require(n =< state.nonce, "Nonce too high") put(state{ nonce = n + 1 }) diff --git a/test/contracts/bitcoin_auth.aes b/test/contracts/bitcoin_auth.aes index 79762ce..ba60c31 100644 --- a/test/contracts/bitcoin_auth.aes +++ b/test/contracts/bitcoin_auth.aes @@ -3,7 +3,7 @@ contract BitcoinAuth = function init(owner' : bytes(64)) = { nonce = 1, owner = owner' } - function authorize(n : int, s : signature) : bool = + stateful function authorize(n : int, s : signature) : bool = require(n >= state.nonce, "Nonce too low") require(n =< state.nonce, "Nonce too high") put(state{ nonce = n + 1 }) diff --git a/test/contracts/counter.aes b/test/contracts/counter.aes index 4015cef..3d77194 100644 --- a/test/contracts/counter.aes +++ b/test/contracts/counter.aes @@ -5,5 +5,5 @@ contract Counter = function init(val) = { value = val } function get() = state.value - function tick() = put(state{ value = state.value + 1 }) + stateful function tick() = put(state{ value = state.value + 1 }) diff --git a/test/contracts/dutch_auction.aes b/test/contracts/dutch_auction.aes index 6106146..deca3ef 100644 --- a/test/contracts/dutch_auction.aes +++ b/test/contracts/dutch_auction.aes @@ -10,7 +10,7 @@ contract DutchAuction = sold : bool } // Add to work around current lack of predefined functions - private function spend(to, amount) = + private stateful function spend(to, amount) = let total = Contract.balance Chain.spend(to, amount) total - amount diff --git a/test/contracts/environment.aes b/test/contracts/environment.aes index 3b48a40..3f7f721 100644 --- a/test/contracts/environment.aes +++ b/test/contracts/environment.aes @@ -12,7 +12,7 @@ contract Environment = function init(remote) = {remote = remote} - function set_remote(remote) = put({remote = remote}) + stateful function set_remote(remote) = put({remote = remote}) // -- Information about the this contract --- @@ -38,7 +38,7 @@ contract Environment = // Value function call_value() : int = Call.value - function nested_value(value : int) : int = + stateful function nested_value(value : int) : int = state.remote.call_value(value = value / 2) // Gas price diff --git a/test/contracts/factorial.aes b/test/contracts/factorial.aes index 447196e..7a610b5 100644 --- a/test/contracts/factorial.aes +++ b/test/contracts/factorial.aes @@ -9,7 +9,7 @@ contract Factorial = function init(worker) = {worker = worker} - function set_worker(worker) = put(state{worker = worker}) + stateful function set_worker(worker) = put(state{worker = worker}) function fac(x : int) : int = if(x == 0) 1 diff --git a/test/contracts/fundme.aes b/test/contracts/fundme.aes index eed32c9..08f5d69 100644 --- a/test/contracts/fundme.aes +++ b/test/contracts/fundme.aes @@ -15,7 +15,7 @@ contract FundMe = private function require(b : bool, err : string) = if(!b) abort(err) - private function spend(args : spend_args) = + private stateful function spend(args : spend_args) = Chain.spend(args.recipient, args.amount) public function init(beneficiary, deadline, goal) : state = diff --git a/test/contracts/maps.aes b/test/contracts/maps.aes index 02d4b13..b7aa48d 100644 --- a/test/contracts/maps.aes +++ b/test/contracts/maps.aes @@ -17,8 +17,8 @@ contract Maps = { ["one"] = {x = 1, y = 2}, ["two"] = {x = 3, y = 4}, ["three"] = {x = 5, y = 6} } - function map_state_i() = put(state{ map_i = map_i() }) - function map_state_s() = put(state{ map_s = map_s() }) + stateful function map_state_i() = put(state{ map_i = map_i() }) + stateful function map_state_s() = put(state{ map_s = map_s() }) // m[k] function get_i(k, m : map(int, pt)) = m[k] @@ -35,20 +35,20 @@ contract Maps = // m{[k] = v} function set_i(k, p, m : map(int, pt)) = m{ [k] = p } function set_s(k, p, m : map(string, pt)) = m{ [k] = p } - function set_state_i(k, p) = put(state{ map_i = set_i(k, p, state.map_i) }) - function set_state_s(k, p) = put(state{ map_s = set_s(k, p, state.map_s) }) + stateful function set_state_i(k, p) = put(state{ map_i = set_i(k, p, state.map_i) }) + stateful function set_state_s(k, p) = put(state{ map_s = set_s(k, p, state.map_s) }) // m{f[k].x = v} function setx_i(k, x, m : map(int, pt)) = m{ [k].x = x } function setx_s(k, x, m : map(string, pt)) = m{ [k].x = x } - function setx_state_i(k, x) = put(state{ map_i[k].x = x }) - function setx_state_s(k, x) = put(state{ map_s[k].x = x }) + stateful function setx_state_i(k, x) = put(state{ map_i[k].x = x }) + stateful function setx_state_s(k, x) = put(state{ map_s[k].x = x }) // m{[k] @ x = v } function addx_i(k, d, m : map(int, pt)) = m{ [k].x @ x = x + d } function addx_s(k, d, m : map(string, pt)) = m{ [k].x @ x = x + d } - function addx_state_i(k, d) = put(state{ map_i[k].x @ x = x + d }) - function addx_state_s(k, d) = put(state{ map_s[k].x @ x = x + d }) + stateful function addx_state_i(k, d) = put(state{ map_i[k].x @ x = x + d }) + stateful function addx_state_s(k, d) = put(state{ map_s[k].x @ x = x + d }) // m{[k = def] @ x = v } function addx_def_i(k, v, d, m : map(int, pt)) = m{ [k = v].x @ x = x + d } @@ -77,8 +77,8 @@ contract Maps = // Map.delete function delete_i(k, m : map(int, pt)) = Map.delete(k, m) function delete_s(k, m : map(string, pt)) = Map.delete(k, m) - function delete_state_i(k) = put(state{ map_i = delete_i(k, state.map_i) }) - function delete_state_s(k) = put(state{ map_s = delete_s(k, state.map_s) }) + stateful function delete_state_i(k) = put(state{ map_i = delete_i(k, state.map_i) }) + stateful function delete_state_s(k) = put(state{ map_s = delete_s(k, state.map_s) }) // Map.size function size_i(m : map(int, pt)) = Map.size(m) @@ -95,6 +95,6 @@ contract Maps = // Map.from_list function fromlist_i(xs : list((int, pt))) = Map.from_list(xs) function fromlist_s(xs : list((string, pt))) = Map.from_list(xs) - function fromlist_state_i(xs) = put(state{ map_i = fromlist_i(xs) }) - function fromlist_state_s(xs) = put(state{ map_s = fromlist_s(xs) }) + stateful function fromlist_state_i(xs) = put(state{ map_i = fromlist_i(xs) }) + stateful function fromlist_state_s(xs) = put(state{ map_s = fromlist_s(xs) }) diff --git a/test/contracts/oracles.aes b/test/contracts/oracles.aes index 4f125fc..e6e44b4 100644 --- a/test/contracts/oracles.aes +++ b/test/contracts/oracles.aes @@ -9,22 +9,22 @@ contract Oracles = type oracle_id = oracle(query_t, answer_t) type query_id = oracle_query(query_t, answer_t) - function registerOracle(acct : address, + stateful function registerOracle(acct : address, qfee : fee, ttl : ttl) : oracle_id = Oracle.register(acct, qfee, ttl) - function registerIntIntOracle(acct : address, + stateful function registerIntIntOracle(acct : address, qfee : fee, ttl : ttl) : oracle(int, int) = Oracle.register(acct, qfee, ttl) - function registerStringStringOracle(acct : address, + stateful function registerStringStringOracle(acct : address, qfee : fee, ttl : ttl) : oracle(string, string) = Oracle.register(acct, qfee, ttl) - function signedRegisterOracle(acct : address, + stateful function signedRegisterOracle(acct : address, sign : signature, qfee : fee, ttl : ttl) : oracle_id = @@ -33,7 +33,7 @@ contract Oracles = function queryFee(o : oracle_id) : fee = Oracle.query_fee(o) - function createQuery(o : oracle_id, + stateful function createQuery(o : oracle_id, q : query_t, qfee : fee, qttl : ttl, @@ -42,7 +42,7 @@ contract Oracles = Oracle.query(o, q, qfee, qttl, rttl) // Do not use in production! - function unsafeCreateQuery(o : oracle_id, + stateful function unsafeCreateQuery(o : oracle_id, q : query_t, qfee : fee, qttl : ttl, @@ -50,7 +50,7 @@ contract Oracles = Oracle.query(o, q, qfee, qttl, rttl) // Do not use in production! - function unsafeCreateQueryThenErr(o : oracle_id, + stateful function unsafeCreateQueryThenErr(o : oracle_id, q : query_t, qfee : fee, qttl : ttl, @@ -59,21 +59,21 @@ contract Oracles = require(qfee >= 100000000000000000, "causing a late error") res - function extendOracle(o : oracle_id, + stateful function extendOracle(o : oracle_id, ttl : ttl) : () = Oracle.extend(o, ttl) - function signedExtendOracle(o : oracle_id, + stateful function signedExtendOracle(o : oracle_id, sign : signature, // Signed oracle address ttl : ttl) : () = Oracle.extend(o, signature = sign, ttl) - function respond(o : oracle_id, + stateful function respond(o : oracle_id, q : query_id, r : answer_t) : () = Oracle.respond(o, q, r) - function signedRespond(o : oracle_id, + stateful function signedRespond(o : oracle_id, q : query_id, sign : signature, r : answer_t) : () = @@ -96,13 +96,13 @@ contract Oracles = datatype complexQuestion = Why(int) | How(string) datatype complexAnswer = NoAnswer | Answer(complexQuestion, string, int) - function complexOracle(question) = + stateful function complexOracle(question) = let o = Oracle.register(Contract.address, 0, FixedTTL(1000)) : oracle(complexQuestion, complexAnswer) let q = Oracle.query(o, question, 0, RelativeTTL(100), RelativeTTL(100)) Oracle.respond(o, q, Answer(question, "magic", 1337)) Oracle.get_answer(o, q) - function signedComplexOracle(question, sig) = + stateful function signedComplexOracle(question, sig) = let o = Oracle.register(signature = sig, Contract.address, 0, FixedTTL(1000)) : oracle(complexQuestion, complexAnswer) let q = Oracle.query(o, question, 0, RelativeTTL(100), RelativeTTL(100)) Oracle.respond(o, q, Answer(question, "magic", 1337), signature = sig) diff --git a/test/contracts/remote_call.aes b/test/contracts/remote_call.aes index 6c7998f..333fd47 100644 --- a/test/contracts/remote_call.aes +++ b/test/contracts/remote_call.aes @@ -11,7 +11,7 @@ contract Remote3 = contract RemoteCall = - function call(r : Remote1, x : int) : int = + stateful function call(r : Remote1, x : int) : int = r.main(gas = 10000, value = 10, x) function staged_call(r1 : Remote1, r2 : Remote2, x : int) = diff --git a/test/contracts/simple_storage.aes b/test/contracts/simple_storage.aes index 2c45a53..ee19616 100644 --- a/test/contracts/simple_storage.aes +++ b/test/contracts/simple_storage.aes @@ -24,5 +24,5 @@ contract SimpleStorage = function get() : int = state.data - function set(value : int) = + stateful function set(value : int) = put(state{data = value}) diff --git a/test/contracts/spend_test.aes b/test/contracts/spend_test.aes index 21140e4..ee95fd5 100644 --- a/test/contracts/spend_test.aes +++ b/test/contracts/spend_test.aes @@ -4,19 +4,19 @@ contract SpendContract = contract SpendTest = - function spend(to, amount) = + stateful function spend(to, amount) = let total = Contract.balance Chain.spend(to, amount) total - amount - function withdraw(amount) : int = + stateful function withdraw(amount) : int = spend(Call.caller, amount) - function withdraw_from(account, amount) = + stateful function withdraw_from(account, amount) = account.withdraw(amount) withdraw(amount) - function spend_from(from, to, amount) = + stateful function spend_from(from, to, amount) = from.withdraw(amount) Chain.spend(to, amount) Chain.balance(to) diff --git a/test/contracts/state_handling.aes b/test/contracts/state_handling.aes index 7fdd196..932f8ca 100644 --- a/test/contracts/state_handling.aes +++ b/test/contracts/state_handling.aes @@ -27,13 +27,13 @@ contract StateHandling = function read_s() = state.s function read_m() = state.m - function update(new_state : state) = put(new_state) - function update_i(new_i) = put(state{ i = new_i }) - function update_s(new_s) = put(state{ s = new_s }) - function update_m(new_m) = put(state{ m = new_m }) + stateful function update(new_state : state) = put(new_state) + stateful function update_i(new_i) = put(state{ i = new_i }) + stateful function update_s(new_s) = put(state{ s = new_s }) + stateful function update_m(new_m) = put(state{ m = new_m }) function pass_it(r : Remote) = r.look_at(state) - function nop(r : Remote) = put(state{ i = state.i }) + stateful function nop(r : Remote) = put(state{ i = state.i }) function return_it_s(r : Remote, big : bool) = let x = r.return_s(big) String.length(x) @@ -50,10 +50,10 @@ contract StateHandling = function pass_update_s(r : Remote, s) = r.fun_update_s(state, s) function pass_update_m(r : Remote, m) = r.fun_update_m(state, m) - function remote_update_i (r : Remote, i) = put(r.fun_update_i(state, i)) - function remote_update_s (r : Remote, s) = put(r.fun_update_s(state, s)) - function remote_update_m (r : Remote, m) = put(r.fun_update_m(state, m)) - function remote_update_mk(r : Remote, k, v) = put(r.fun_update_mk(state, k, v)) + stateful function remote_update_i (r : Remote, i) = put(r.fun_update_i(state, i)) + stateful function remote_update_s (r : Remote, s) = put(r.fun_update_s(state, s)) + stateful function remote_update_m (r : Remote, m) = put(r.fun_update_m(state, m)) + stateful function remote_update_mk(r : Remote, k, v) = put(r.fun_update_mk(state, k, v)) // remote called function look_at(s : state) = () diff --git a/test/contracts/stateful.aes b/test/contracts/stateful.aes new file mode 100644 index 0000000..5121999 --- /dev/null +++ b/test/contracts/stateful.aes @@ -0,0 +1,54 @@ + +contract Remote = + stateful function remote_spend : (address, int) => () + function remote_pure : int => int + +contract Stateful = + + private function pure(x) = x + 1 + private stateful function local_spend(a) = + Chain.spend(a, 1000) + + // Non-stateful functions cannot mention stateful functions + function fail1(a : address) = Chain.spend(a, 1000) + function fail2(a : address) = local_spend(a) + function fail3(a : address) = + let foo = Chain.spend + foo(a, 1000) + + // Private functions must also be annotated + private function fail4(a) = Chain.spend(a, 1000) + + // If annotated, stateful functions are allowed + stateful function ok1(a : address) = Chain.spend(a, 1000) + + // And pure functions are always allowed + stateful function ok2(a : address) = pure(5) + stateful function ok3(a : address) = + let foo = pure + foo(5) + + // No error here (fail4 is annotated as not stateful) + function ok4(a : address) = fail4(a) + + // Lamdbas are checked at the construction site + private function fail5() : address => () = (a) => Chain.spend(a, 1000) + + // .. so you can pass a stateful lambda to a non-stateful higher-order + // function: + private function apply(f : 'a => 'b, x) = f(x) + stateful function ok5(a : address) = + apply((val) => Chain.spend(a, val), 1000) + + // It doesn't matter if remote calls are stateful or not + function ok6(r : Remote) = r.remote_spend(Contract.address, 1000) + function ok7(r : Remote) = r.remote_pure(5) + + // But you can't send any tokens if not stateful + function fail6(r : Remote) = r.remote_spend(value = 1000, Contract.address, 1000) + function fail7(r : Remote) = r.remote_pure(value = 1000, 5) + function fail8(r : Remote) = + let foo = r.remote_pure + foo(value = 1000, 5) + function ok8(r : Remote) = r.remote_spend(Contract.address, 1000, value = 0) + diff --git a/test/contracts/variant_types.aes b/test/contracts/variant_types.aes index bdd88a9..59079c2 100644 --- a/test/contracts/variant_types.aes +++ b/test/contracts/variant_types.aes @@ -11,11 +11,11 @@ contract VariantTypes = function require(b) = if(!b) abort("required") - function start(bal : int) = + stateful function start(bal : int) = switch(state) Stopped => put(Started({owner = Call.caller, balance = bal, color = Grey(0)})) - function stop() = + stateful function stop() = switch(state) Started(st) => require(Call.caller == st.owner) @@ -23,7 +23,7 @@ contract VariantTypes = st.balance function get_color() = switch(state) Started(st) => st.color - function set_color(c) = switch(state) Started(st) => put(Started(st{color = c})) + stateful function set_color(c) = switch(state) Started(st) => put(Started(st{color = c})) function get_state() = state From 6bd2b7c4836e41712cd90e29222416756d26786a Mon Sep 17 00:00:00 2001 From: Ulf Norell Date: Mon, 13 May 2019 17:50:34 +0200 Subject: [PATCH 2/8] Remember source location when computing used names --- src/aeso_syntax_utils.erl | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/aeso_syntax_utils.erl b/src/aeso_syntax_utils.erl index 8a7f374..95b61d2 100644 --- a/src/aeso_syntax_utils.erl +++ b/src/aeso_syntax_utils.erl @@ -98,29 +98,31 @@ fold(Alg = #alg{zero = Zero, plus = Plus, scoped = Scoped}, Fun, K, X) -> %% Name dependencies used_ids(E) -> - [ X || {term, [X]} <- used(E) ]. + [ X || {{term, [X]}, _} <- used(E) ]. used_types(T) -> - [ X || {type, [X]} <- used(T) ]. + [ X || {{type, [X]}, _} <- used(T) ]. -type entity() :: {term, [string()]} | {type, [string()]} | {namespace, [string()]}. --spec entity_alg() -> alg([entity()]). +-spec entity_alg() -> alg(#{entity() => aeso_syntax:ann()}). entity_alg() -> IsBound = fun({K, _}) -> lists:member(K, [bound_term, bound_type]) end, Unbind = fun(bound_term) -> term; (bound_type) -> type end, + Remove = fun(Keys, Map) -> lists:foldl(fun maps:remove/2, Map, Keys) end, Scoped = fun(Xs, Ys) -> - {Bound, Others} = lists:partition(IsBound, Ys), + Bound = [E || E <- maps:keys(Ys), IsBound(E)], + Others = Remove(Bound, Ys), Bound1 = [ {Unbind(Tag), X} || {Tag, X} <- Bound ], - lists:umerge(Xs -- Bound1, Others) + maps:merge(Remove(Bound1, Xs), Others) end, - #alg{ zero = [] - , plus = fun lists:umerge/2 + #alg{ zero = #{} + , plus = fun maps:merge/2 , scoped = Scoped }. --spec used(_) -> [entity()]. +-spec used(_) -> [{entity(), aeso_syntax:ann()}]. used(D) -> Kind = fun(expr) -> term; (bind_expr) -> bound_term; @@ -128,14 +130,14 @@ used(D) -> (bind_type) -> bound_type end, NS = fun(Xs) -> {namespace, lists:droplast(Xs)} end, - NotBound = fun({Tag, _}) -> not lists:member(Tag, [bound_term, bound_type]) end, + NotBound = fun({{Tag, _}, _}) -> not lists:member(Tag, [bound_term, bound_type]) end, Xs = - fold(entity_alg(), - fun(K, {id, _, X}) -> [{Kind(K), [X]}]; - (K, {qid, _, Xs}) -> [{Kind(K), Xs}, NS(Xs)]; - (K, {con, _, X}) -> [{Kind(K), [X]}]; - (K, {qcon, _, Xs}) -> [{Kind(K), Xs}, NS(Xs)]; - (_, _) -> [] - end, decl, D), + maps:to_list(fold(entity_alg(), + fun(K, {id, Ann, X}) -> #{{Kind(K), [X]} => Ann}; + (K, {qid, Ann, Xs}) -> #{{Kind(K), Xs} => Ann, NS(Xs) => Ann}; + (K, {con, Ann, X}) -> #{{Kind(K), [X]} => Ann}; + (K, {qcon, Ann, Xs}) -> #{{Kind(K), Xs} => Ann, NS(Xs) => Ann}; + (_, _) -> #{} + end, decl, D)), lists:filter(NotBound, Xs). From 74d4048d9fc726faab9f5719115b4a878c8b044a Mon Sep 17 00:00:00 2001 From: Ulf Norell Date: Mon, 13 May 2019 17:51:47 +0200 Subject: [PATCH 3/8] Check that init doesn't read or write the state --- src/aeso_ast_infer_types.erl | 49 ++++++++++++++++++++++++ test/aeso_compiler_tests.erl | 13 +++++++ test/contracts/bad_init_state_access.aes | 13 +++++++ 3 files changed, 75 insertions(+) create mode 100644 test/contracts/bad_init_state_access.aes diff --git a/src/aeso_ast_infer_types.erl b/src/aeso_ast_infer_types.erl index be61478..fde4db0 100644 --- a/src/aeso_ast_infer_types.erl +++ b/src/aeso_ast_infer_types.erl @@ -605,6 +605,8 @@ infer_contract(Env, What, Defs) -> SCCs = aeso_utils:scc(DepGraph), %% io:format("Dependency sorted functions:\n ~p\n", [SCCs]), {Env4, Defs1} = check_sccs(Env3, FunMap, SCCs, []), + %% Check that `init` doesn't read or write the state + check_state_dependencies(Env4, Defs1), destroy_and_report_type_errors(Env4), {Env4, TypeDefs ++ Decls ++ Defs1}. @@ -935,6 +937,47 @@ check_stateful_named_arg(#env{ stateful = false, current_function = Fun }, {id, end; check_stateful_named_arg(_, _, _) -> ok. +%% Check that `init` doesn't read or write the state +check_state_dependencies(Env, Defs) -> + Top = Env#env.namespace, + GetState = Top ++ ["state"], + SetState = Top ++ ["put"], + Init = Top ++ ["init"], + UsedNames = fun(X) -> [{Xs, Ann} || {{term, Xs}, Ann} <- aeso_syntax_utils:used(X)] end, + Funs = [ {Top ++ [Name], Fun} || Fun = {letfun, _, {id, _, Name}, _Args, _Type, _Body} <- Defs ], + Deps = maps:from_list([{Name, UsedNames(Def)} || {Name, Def} <- Funs]), + case maps:get(Init, Deps, false) of + false -> ok; %% No init, so nothing to check + _ -> + [ type_error({init_depends_on_state, state, Chain}) + || Chain <- get_call_chains(Deps, Init, GetState) ], + [ type_error({init_depends_on_state, put, Chain}) + || Chain <- get_call_chains(Deps, Init, SetState) ], + ok + end. + +%% Compute all paths (not sharing intermediate nodes) from Start to Stop in Graph. +get_call_chains(Graph, Start, Stop) -> + get_call_chains(Graph, #{}, queue:from_list([{Start, [], []}]), Stop, []). + +get_call_chains(_Graph, _Visit, [], _, Acc) -> lists:reverse(Acc); +get_call_chains(Graph, Visited, [{Stop, Path} | Queue], Stop, Acc) -> + get_call_chains(Graph, Visited, Queue, Stop, [lists:reverse(Path) | Acc]); +get_call_chains(Graph, Visited, Queue, Stop, Acc) -> + case queue:out(Queue) of + {empty, _} -> lists:reverse(Acc); + {{value, {Stop, Ann, Path}}, Queue1} -> + get_call_chains(Graph, Visited, Queue1, Stop, [lists:reverse([{Stop, Ann} | Path]) | Acc]); + {{value, {Node, Ann, Path}}, Queue1} -> + case maps:is_key(Node, Visited) of + true -> get_call_chains(Graph, Visited, Queue1, Stop, Acc); + false -> + Calls = maps:get(Node, Graph, []), + NewQ = queue:from_list([{New, Ann1, [{Node, Ann} | Path]} || {New, Ann1} <- Calls]), + get_call_chains(Graph, Visited#{Node => true}, queue:join(Queue1, NewQ), Stop, Acc) + end + end. + check_expr(Env, Expr, Type) -> E = {typed, _, _, Type1} = infer_expr(Env, Expr), unify(Env, Type1, Type, {check_expr, Expr, Type1, Type}), @@ -1970,6 +2013,12 @@ pp_error({stateful_not_allowed, Id, Fun}) -> pp_error({value_arg_not_allowed, Value, Fun}) -> io_lib:format("Cannot pass non-zero value argument ~s (at ~s)\nin the definition of non-stateful function ~s.\n", [pp_expr("", Value), pp_loc(Value), pp(Fun)]); +pp_error({init_depends_on_state, Which, [_Init | Chain]}) -> + WhichCalls = fun("put") -> ""; ("state") -> ""; (_) -> ", which calls" end, + io_lib:format("The init function should return the initial state as its result and cannot ~s the state,\nbut it calls\n~s", + [if Which == put -> "write"; true -> "read" end, + [ io_lib:format(" - ~s (at ~s)~s\n", [Fun, pp_loc(Ann), WhichCalls(Fun)]) + || {[_, Fun], Ann} <- Chain]]); pp_error(Err) -> io_lib:format("Unknown error: ~p\n", [Err]). diff --git a/test/aeso_compiler_tests.erl b/test/aeso_compiler_tests.erl index 62bee65..7d161ff 100644 --- a/test/aeso_compiler_tests.erl +++ b/test/aeso_compiler_tests.erl @@ -316,4 +316,17 @@ failing_contracts() -> <<"Cannot pass non-zero value argument 1000 (at line 48, column 55)\nin the definition of non-stateful function fail6.">>, <<"Cannot pass non-zero value argument 1000 (at line 49, column 54)\nin the definition of non-stateful function fail7.">>, <<"Cannot pass non-zero value argument 1000 (at line 52, column 17)\nin the definition of non-stateful function fail8.">>]} + , {"bad_init_state_access", + [<<"The init function should return the initial state as its result and cannot write the state,\n" + "but it calls\n" + " - set_state (at line 11, column 5), which calls\n" + " - roundabout (at line 8, column 36), which calls\n" + " - put (at line 7, column 37)">>, + <<"The init function should return the initial state as its result and cannot read the state,\n" + "but it calls\n" + " - new_state (at line 12, column 5), which calls\n" + " - state (at line 5, column 27)">>, + <<"The init function should return the initial state as its result and cannot read the state,\n" + "but it calls\n" + " - state (at line 13, column 13)">>]} ]. diff --git a/test/contracts/bad_init_state_access.aes b/test/contracts/bad_init_state_access.aes new file mode 100644 index 0000000..3a0aa52 --- /dev/null +++ b/test/contracts/bad_init_state_access.aes @@ -0,0 +1,13 @@ +contract BadInit = + + type state = int + + function new_state(n) = state + n + + stateful function roundabout(n) = put(n) + stateful function set_state(n) = roundabout(n) + + stateful function init() = + set_state(4) + new_state(0) + state + state From d8dd6b900ff2c143056f5538a8539b6f9530162b Mon Sep 17 00:00:00 2001 From: Ulf Norell Date: Mon, 13 May 2019 17:51:56 +0200 Subject: [PATCH 4/8] Remove unused test contract --- test/contracts/remote_oracles.aes | 99 ------------------------------- 1 file changed, 99 deletions(-) delete mode 100644 test/contracts/remote_oracles.aes diff --git a/test/contracts/remote_oracles.aes b/test/contracts/remote_oracles.aes deleted file mode 100644 index 1b921c0..0000000 --- a/test/contracts/remote_oracles.aes +++ /dev/null @@ -1,99 +0,0 @@ -contract Oracles = - - function registerOracle : - (address, - int, - Chain.ttl) => oracle(string, int) - - function createQuery : - (oracle(string, int), - string, - int, - Chain.ttl, - Chain.ttl) => oracle_query(string, int) - - function unsafeCreateQuery : - (oracle(string, int), - string, - int, - Chain.ttl, - Chain.ttl) => oracle_query(string, int) - - function respond : - (oracle(string, int), - oracle_query(string, int), - int) => () - -contract OraclesErr = - - function unsafeCreateQueryThenErr : - (oracle(string, int), - string, - int, - Chain.ttl, - Chain.ttl) => oracle_query(string, int) - -contract RemoteOracles = - - public function callRegisterOracle( - r : Oracles, - acct : address, - qfee : int, - ttl : Chain.ttl) : oracle(string, int) = - r.registerOracle(acct, qfee, ttl) - - public function callCreateQuery( - r : Oracles, - value : int, - o : oracle(string, int), - q : string, - qfee : int, - qttl : Chain.ttl, - rttl : Chain.ttl) : oracle_query(string, int) = - require(value =< Call.value, "insufficient value") - r.createQuery(value = value, o, q, qfee, qttl, rttl) - - // Do not use in production! - public function callUnsafeCreateQuery( - r : Oracles, - value : int, - o : oracle(string, int), - q : string, - qfee : int, - qttl : Chain.ttl, - rttl : Chain.ttl) : oracle_query(string, int) = - r.unsafeCreateQuery(value = value, o, q, qfee, qttl, rttl) - - // Do not use in production! - public function callUnsafeCreateQueryThenErr( - r : OraclesErr, - value : int, - o : oracle(string, int), - q : string, - qfee : int, - qttl : Chain.ttl, - rttl : Chain.ttl) : oracle_query(string, int) = - r.unsafeCreateQueryThenErr(value = value, o, q, qfee, qttl, rttl) - - // Do not use in production! - public function callUnsafeCreateQueryAndThenErr( - r : Oracles, - value : int, - o : oracle(string, int), - q : string, - qfee : int, - qttl : Chain.ttl, - rttl : Chain.ttl) : oracle_query(string, int) = - let x = r.unsafeCreateQuery(value = value, o, q, qfee, qttl, rttl) - switch(0) 1 => () - x // Never reached. - - public function callRespond( - r : Oracles, - o : oracle(string, int), - q : oracle_query(string, int), - qr : int) = - r.respond(o, q, qr) - - private function require(b : bool, err : string) = - if(!b) abort(err) From 389072fb12b2332806ac6007ccd30c78ae382833 Mon Sep 17 00:00:00 2001 From: Ulf Norell Date: Tue, 14 May 2019 09:32:52 +0200 Subject: [PATCH 5/8] Add stateful to __call --- src/aeso_compiler.erl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aeso_compiler.erl b/src/aeso_compiler.erl index 1927d98..2fb5763 100644 --- a/src/aeso_compiler.erl +++ b/src/aeso_compiler.erl @@ -188,7 +188,7 @@ insert_call_function(Code, FunName, Args, Options) -> [ Code, "\n\n", lists:duplicate(Ind, " "), - "function __call() = ", FunName, "(", string:join(Args, ","), ")\n" + "stateful function __call() = ", FunName, "(", string:join(Args, ","), ")\n" ]). -spec insert_init_function(string(), options()) -> string(). From d051fa6c89a855d34f37aa591d3cc27e851fa7c0 Mon Sep 17 00:00:00 2001 From: Ulf Norell Date: Tue, 14 May 2019 09:42:17 +0200 Subject: [PATCH 6/8] Remove bad code --- src/aeso_ast_infer_types.erl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/aeso_ast_infer_types.erl b/src/aeso_ast_infer_types.erl index fde4db0..1e5bfc8 100644 --- a/src/aeso_ast_infer_types.erl +++ b/src/aeso_ast_infer_types.erl @@ -960,9 +960,6 @@ check_state_dependencies(Env, Defs) -> get_call_chains(Graph, Start, Stop) -> get_call_chains(Graph, #{}, queue:from_list([{Start, [], []}]), Stop, []). -get_call_chains(_Graph, _Visit, [], _, Acc) -> lists:reverse(Acc); -get_call_chains(Graph, Visited, [{Stop, Path} | Queue], Stop, Acc) -> - get_call_chains(Graph, Visited, Queue, Stop, [lists:reverse(Path) | Acc]); get_call_chains(Graph, Visited, Queue, Stop, Acc) -> case queue:out(Queue) of {empty, _} -> lists:reverse(Acc); From 9e555a31210ba9b943e20fe3197b8c250a32ac7a Mon Sep 17 00:00:00 2001 From: Ulf Norell Date: Tue, 14 May 2019 09:42:27 +0200 Subject: [PATCH 7/8] Fix type definition --- src/aeso_syntax.erl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aeso_syntax.erl b/src/aeso_syntax.erl index b3cc3ce..ddb95e5 100644 --- a/src/aeso_syntax.erl +++ b/src/aeso_syntax.erl @@ -25,7 +25,7 @@ -type ann_origin() :: system | user. -type ann_format() :: '?:' | hex | infix | prefix | elif. --type ann() :: [{line, ann_line()} | {col, ann_col()} | {format, ann_format()} | {origin, ann_origin()}]. +-type ann() :: [{line, ann_line()} | {col, ann_col()} | {format, ann_format()} | {origin, ann_origin()} | stateful | private]. -type name() :: string(). -type id() :: {id, ann(), name()}. From cf5a8aeb5fd153e4c6d94f2fb22ca2a477960fad Mon Sep 17 00:00:00 2001 From: Ulf Norell Date: Tue, 14 May 2019 10:10:53 +0200 Subject: [PATCH 8/8] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index cb18703..cf6d2e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- `stateful` annotations are now properly enforced. Functions must be marked stateful + in order to update the state or spend tokens. ### Changed ### Removed