From e2ab41eeb22b1cd6776fa3d1badc3e993fd83ae4 Mon Sep 17 00:00:00 2001 From: Ulf Norell Date: Mon, 9 Sep 2019 12:23:13 +0200 Subject: [PATCH] Add Bytes.concat and Bytes.split to type checker --- src/aeso_ast_infer_types.erl | 92 ++++++++++++++++++++++++----- test/aeso_compiler_tests.erl | 44 ++++++++++++++ test/contracts/bad_bytes_concat.aes | 19 ++++++ test/contracts/bad_bytes_split.aes | 20 +++++++ 4 files changed, 159 insertions(+), 16 deletions(-) create mode 100644 test/contracts/bad_bytes_concat.aes create mode 100644 test/contracts/bad_bytes_split.aes diff --git a/src/aeso_ast_infer_types.erl b/src/aeso_ast_infer_types.erl index 1763354..7b725de 100644 --- a/src/aeso_ast_infer_types.erl +++ b/src/aeso_ast_infer_types.erl @@ -372,7 +372,8 @@ global_env() -> Option = fun(T) -> {app_t, Ann, {id, Ann, "option"}, [T]} end, Map = fun(A, B) -> {app_t, Ann, {id, Ann, "map"}, [A, B]} end, Pair = fun(A, B) -> {tuple_t, Ann, [A, B]} end, - Fun = fun(Ts, T) -> {type_sig, Ann, none, [], Ts, T} end, + FunC = fun(C, Ts, T) -> {type_sig, Ann, C, [], Ts, T} end, + Fun = fun(Ts, T) -> FunC(none, Ts, T) end, Fun1 = fun(S, T) -> Fun([S], T) end, %% Lambda = fun(Ts, T) -> {fun_t, Ann, [], Ts, T} end, %% Lambda1 = fun(S, T) -> Lambda([S], T) end, @@ -510,7 +511,10 @@ global_env() -> BytesScope = #scope { funs = MkDefs( [{"to_int", Fun1(Bytes(any), Int)}, - {"to_str", Fun1(Bytes(any), String)}]) }, + {"to_str", Fun1(Bytes(any), String)}, + {"concat", FunC(bytes_concat, [Bytes(any), Bytes(any)], Bytes(any))}, + {"split", FunC(bytes_split, [Bytes(any)], Pair(Bytes(any), Bytes(any)))} + ]) }, %% Conversion IntScope = #scope{ funs = MkDefs([{"to_str", Fun1(Int, String)}]) }, @@ -942,8 +946,8 @@ infer_letfun(Env0, {letfun, Attrib, Fun = {id, NameAttrib, Name}, Args, What, Bo Env = Env0#env{ stateful = aeso_syntax:get_ann(stateful, Attrib, false), current_function = Fun }, check_unique_arg_names(Fun, Args), - ArgTypes = [{ArgName, check_type(Env, arg_type(T))} || {arg, _, ArgName, T} <- Args], - ExpectedType = check_type(Env, arg_type(What)), + ArgTypes = [{ArgName, check_type(Env, arg_type(ArgAnn, T))} || {arg, ArgAnn, ArgName, T} <- Args], + ExpectedType = check_type(Env, arg_type(NameAttrib, What)), NewBody={typed, _, _, ResultType} = check_expr(bind_vars(ArgTypes, Env), Body, ExpectedType), NewArgs = [{arg, A1, {id, A2, ArgName}, T} || {{_, T}, {arg, A1, {id, A2, ArgName}, _}} <- lists:zip(ArgTypes, Args)], @@ -962,11 +966,14 @@ check_unique_arg_names(Fun, Args) -> print_typesig({Name, TypeSig}) -> ?PRINT_TYPES("Inferred ~s : ~s\n", [Name, pp(TypeSig)]). -arg_type({id, Attrs, "_"}) -> - fresh_uvar(Attrs); -arg_type({app_t, Attrs, Name, Args}) -> - {app_t, Attrs, Name, [arg_type(T) || T <- Args]}; -arg_type(T) -> +arg_type(ArgAnn, {id, Ann, "_"}) -> + case aeso_syntax:get_ann(origin, Ann, user) of + system -> fresh_uvar(ArgAnn); + user -> fresh_uvar(Ann) + end; +arg_type(ArgAnn, {app_t, Attrs, Name, Args}) -> + {app_t, Attrs, Name, [arg_type(ArgAnn, T) || T <- Args]}; +arg_type(_, T) -> T. app_t(_Ann, Name, []) -> Name; @@ -1118,7 +1125,7 @@ infer_expr(Env, {list_comp, AttrsL, Yield, [{comprehension_if, AttrsIF, Cond}|Re , {list_comp, AttrsL, TypedYield, [{comprehension_if, AttrsIF, NewCond}|TypedRest]} , ResType}; infer_expr(Env, {list_comp, AsLC, Yield, [{letval, AsLV, Pattern, Type, E}|Rest]}) -> - NewE = {typed, _, _, PatType} = infer_expr(Env, {typed, AsLV, E, arg_type(Type)}), + NewE = {typed, _, _, PatType} = infer_expr(Env, {typed, AsLV, E, arg_type(AsLV, Type)}), BlockType = fresh_uvar(AsLV), {'case', _, NewPattern, NewRest} = infer_case( Env @@ -1342,7 +1349,7 @@ infer_block(Env, Attrs, [Def={letfun, Ann, _, _, _, _}|Rest], BlockType) -> NewE = bind_var({id, Ann, Name}, FunT, Env), [LetFun|infer_block(NewE, Attrs, Rest, BlockType)]; infer_block(Env, _, [{letval, Attrs, Pattern, Type, E}|Rest], BlockType) -> - NewE = {typed, _, _, PatType} = infer_expr(Env, {typed, Attrs, E, arg_type(Type)}), + NewE = {typed, _, _, PatType} = infer_expr(Env, {typed, Attrs, E, arg_type(aeso_syntax:get_ann(Pattern), Type)}), {'case', _, NewPattern, {typed, _, {block, _, NewRest}, _}} = infer_case(Env, Attrs, Pattern, PatType, {block, Attrs, Rest}, BlockType), [{letval, Attrs, NewPattern, Type, NewE}|NewRest]; @@ -1542,7 +1549,8 @@ destroy_and_report_unsolved_named_argument_constraints(Env) -> %% -- Bytes constraints -- --type byte_constraint() :: {is_bytes, utype()}. +-type byte_constraint() :: {is_bytes, utype()} + | {add_bytes, aeso_syntax:ann(), concat | split, utype(), utype(), utype()}. create_bytes_constraints() -> ets_new(bytes_constraints, [bag]). @@ -1554,21 +1562,52 @@ get_bytes_constraints() -> add_bytes_constraint(Constraint) -> ets_insert(bytes_constraints, Constraint). -solve_bytes_constraints(_Env) -> +solve_bytes_constraints(Env) -> + [ solve_bytes_constraint(Env, C) || C <- get_bytes_constraints() ], ok. +solve_bytes_constraint(_Env, {is_bytes, _}) -> ok; +solve_bytes_constraint(Env, {add_bytes, Ann, _, A0, B0, C0}) -> + A = unfold_types_in_type(Env, dereference(A0)), + B = unfold_types_in_type(Env, dereference(B0)), + C = unfold_types_in_type(Env, dereference(C0)), + case {A, B, C} of + {{bytes_t, _, M}, {bytes_t, _, N}, _} -> unify(Env, {bytes_t, Ann, M + N}, C, {at, Ann}); + {{bytes_t, _, M}, _, {bytes_t, _, R}} when R >= M -> unify(Env, {bytes_t, Ann, R - M}, B, {at, Ann}); + {_, {bytes_t, _, N}, {bytes_t, _, R}} when R >= N -> unify(Env, {bytes_t, Ann, R - N}, A, {at, Ann}); + _ -> ok + end. + destroy_bytes_constraints() -> ets_delete(bytes_constraints). destroy_and_report_unsolved_bytes_constraints(Env) -> - [ check_bytes_constraint(Env, C) || C <- get_bytes_constraints() ], + Constraints = get_bytes_constraints(), + InAddConstraint = [ T || {add_bytes, _, _, A, B, C} <- Constraints, + T <- [A, B, C], + element(1, T) /= bytes_t ], + %% Skip is_bytes constraints for types that occur in add_bytes constraints + %% (no need to generate error messages for both is_bytes and add_bytes). + Skip = fun({is_bytes, T}) -> lists:member(T, InAddConstraint); + (_) -> false end, + [ check_bytes_constraint(Env, C) || C <- Constraints, not Skip(C) ], destroy_bytes_constraints(). check_bytes_constraint(Env, {is_bytes, Type}) -> Type1 = unfold_types_in_type(Env, instantiate(Type)), case Type1 of {bytes_t, _, _} -> ok; - _ -> type_error({cannot_unify, Type1, {bytes_t, [], any}, {at, Type}}) + _ -> + type_error({unknown_byte_length, Type}) + end; +check_bytes_constraint(Env, {add_bytes, Ann, Fun, A0, B0, C0}) -> + A = unfold_types_in_type(Env, instantiate(A0)), + B = unfold_types_in_type(Env, instantiate(B0)), + C = unfold_types_in_type(Env, instantiate(C0)), + case {A, B, C} of + {{bytes_t, _, _M}, {bytes_t, _, _N}, {bytes_t, _, _R}} -> + ok; %% If all are solved we checked M + N == R in solve_bytes_constraint. + _ -> type_error({unsolved_bytes_constraint, Ann, Fun, A, B, C}) end. %% -- Field constraints -- @@ -1999,6 +2038,8 @@ occurs_check1(R, [H | T]) -> occurs_check(R, H) orelse occurs_check(R, T); occurs_check1(_, []) -> false. +fresh_uvar([{origin, system}]) -> + error(oh_no_you_dont); fresh_uvar(Attrs) -> {uvar, Attrs, make_ref()}. @@ -2040,7 +2081,11 @@ freshen_type_sig(Ann, TypeSig = {type_sig, _, Constr, _, _, _}) -> apply_typesig_constraint(Ann, Constr, FunT), FunT. -apply_typesig_constraint(_Ann, none, _FunT) -> ok. +apply_typesig_constraint(_Ann, none, _FunT) -> ok; +apply_typesig_constraint(Ann, bytes_concat, {fun_t, _, [], [A, B], C}) -> + add_bytes_constraint({add_bytes, Ann, concat, A, B, C}); +apply_typesig_constraint(Ann, bytes_split, {fun_t, _, [], [C], {tuple_t, _, [A, B]}}) -> + add_bytes_constraint({add_bytes, Ann, split, A, B, C}). %% Dereferences all uvars and replaces the uninstantiated ones with a %% succession of tvars. @@ -2326,6 +2371,21 @@ mk_error({bad_top_level_decl, Decl}) -> Msg = io_lib:format("The definition of '~s' must appear inside a ~s.\n", [pp_expr("", Id), What]), mk_t_err(pos(Decl), Msg); +mk_error({unknown_byte_length, Type}) -> + Msg = io_lib:format("Cannot resolve length of byte array.\n", []), + mk_t_err(pos(Type), Msg); +mk_error({unsolved_bytes_constraint, Ann, concat, A, B, C}) -> + Msg = io_lib:format("Failed to resolve byte array lengths in call to Bytes.concat with arguments of type\n" + "~s (at ~s)\n~s (at ~s)\nand result type\n~s (at ~s)\n", + [pp_type(" - ", A), pp_loc(A), pp_type(" - ", B), + pp_loc(B), pp_type(" - ", C), pp_loc(C)]), + mk_t_err(pos(Ann), Msg); +mk_error({unsolved_bytes_constraint, Ann, split, A, B, C}) -> + Msg = io_lib:format("Failed to resolve byte array lengths in call to Bytes.split with argument of type\n" + "~s (at ~s)\nand result types\n~s (at ~s)\n~s (at ~s)\n", + [ pp_type(" - ", C), pp_loc(C), pp_type(" - ", A), pp_loc(A), + pp_type(" - ", B), pp_loc(B)]), + mk_t_err(pos(Ann), Msg); mk_error(Err) -> Msg = io_lib:format("Unknown error: ~p\n", [Err]), mk_t_err(pos(0, 0), Msg). diff --git a/test/aeso_compiler_tests.erl b/test/aeso_compiler_tests.erl index 3ba5fc9..9f19e82 100644 --- a/test/aeso_compiler_tests.erl +++ b/test/aeso_compiler_tests.erl @@ -511,6 +511,50 @@ failing_contracts() -> [<>]) + , ?TYPE_ERROR(bad_bytes_concat, + [<>, + <>, + <>, + <>, + <>]) + , ?TYPE_ERROR(bad_bytes_split, + [<>, + <>, + <>]) ]. -define(Path(File), "code_errors/" ??File). diff --git a/test/contracts/bad_bytes_concat.aes b/test/contracts/bad_bytes_concat.aes new file mode 100644 index 0000000..fa8ca01 --- /dev/null +++ b/test/contracts/bad_bytes_concat.aes @@ -0,0 +1,19 @@ +contract BytesConcat = + + entrypoint test1(x : bytes(10), y : bytes(20)) = + Bytes.concat(x, y) + + entrypoint test2(x : bytes(10), y) : bytes(15) = + Bytes.concat(x, y) + + entrypoint test3(x, y : bytes(20)) : bytes(25) = + Bytes.concat(x, y) + + entrypoint fail1(x, y) : bytes(10) = Bytes.concat(x, y) + entrypoint fail2(x, y) = Bytes.concat(x, y) + entrypoint fail3(x : bytes(6), y : bytes(20)) : bytes(25) = + Bytes.concat(x, y) + entrypoint fail4(x : bytes(6), y) : _ = + Bytes.concat(x, y) + + entrypoint fail5(x) = Bytes.to_str(x) diff --git a/test/contracts/bad_bytes_split.aes b/test/contracts/bad_bytes_split.aes new file mode 100644 index 0000000..97fcb89 --- /dev/null +++ b/test/contracts/bad_bytes_split.aes @@ -0,0 +1,20 @@ +contract BytesSplit = + + entrypoint test1(x) : bytes(10) * bytes(20) = + Bytes.split(x) + + entrypoint test2(x : bytes(15)) : bytes(10) * _ = + Bytes.split(x) + + entrypoint test3(x : bytes(25)) : _ * bytes(20) = + Bytes.split(x) + + entrypoint fail1(x) : _ * bytes(20) = + Bytes.split(x) + + entrypoint fail2(x : bytes(15)) : _ = + Bytes.split(x) + + entrypoint fail3(x) : bytes(20) * _ = + Bytes.split(x) +