Add Bytes.concat and Bytes.split to type checker

This commit is contained in:
Ulf Norell 2019-09-09 12:23:13 +02:00
parent 0f612ead90
commit e2ab41eeb2
4 changed files with 159 additions and 16 deletions

View File

@ -372,7 +372,8 @@ global_env() ->
Option = fun(T) -> {app_t, Ann, {id, Ann, "option"}, [T]} end,
Map = fun(A, B) -> {app_t, Ann, {id, Ann, "map"}, [A, B]} end,
Pair = fun(A, B) -> {tuple_t, Ann, [A, B]} end,
Fun = fun(Ts, T) -> {type_sig, Ann, none, [], Ts, T} end,
FunC = fun(C, Ts, T) -> {type_sig, Ann, C, [], Ts, T} end,
Fun = fun(Ts, T) -> FunC(none, Ts, T) end,
Fun1 = fun(S, T) -> Fun([S], T) end,
%% Lambda = fun(Ts, T) -> {fun_t, Ann, [], Ts, T} end,
%% Lambda1 = fun(S, T) -> Lambda([S], T) end,
@ -510,7 +511,10 @@ global_env() ->
BytesScope = #scope
{ funs = MkDefs(
[{"to_int", Fun1(Bytes(any), Int)},
{"to_str", Fun1(Bytes(any), String)}]) },
{"to_str", Fun1(Bytes(any), String)},
{"concat", FunC(bytes_concat, [Bytes(any), Bytes(any)], Bytes(any))},
{"split", FunC(bytes_split, [Bytes(any)], Pair(Bytes(any), Bytes(any)))}
]) },
%% Conversion
IntScope = #scope{ funs = MkDefs([{"to_str", Fun1(Int, String)}]) },
@ -942,8 +946,8 @@ infer_letfun(Env0, {letfun, Attrib, Fun = {id, NameAttrib, Name}, Args, What, Bo
Env = Env0#env{ stateful = aeso_syntax:get_ann(stateful, Attrib, false),
current_function = Fun },
check_unique_arg_names(Fun, Args),
ArgTypes = [{ArgName, check_type(Env, arg_type(T))} || {arg, _, ArgName, T} <- Args],
ExpectedType = check_type(Env, arg_type(What)),
ArgTypes = [{ArgName, check_type(Env, arg_type(ArgAnn, T))} || {arg, ArgAnn, ArgName, T} <- Args],
ExpectedType = check_type(Env, arg_type(NameAttrib, What)),
NewBody={typed, _, _, ResultType} = check_expr(bind_vars(ArgTypes, Env), Body, ExpectedType),
NewArgs = [{arg, A1, {id, A2, ArgName}, T}
|| {{_, T}, {arg, A1, {id, A2, ArgName}, _}} <- lists:zip(ArgTypes, Args)],
@ -962,11 +966,14 @@ check_unique_arg_names(Fun, Args) ->
print_typesig({Name, TypeSig}) ->
?PRINT_TYPES("Inferred ~s : ~s\n", [Name, pp(TypeSig)]).
arg_type({id, Attrs, "_"}) ->
fresh_uvar(Attrs);
arg_type({app_t, Attrs, Name, Args}) ->
{app_t, Attrs, Name, [arg_type(T) || T <- Args]};
arg_type(T) ->
arg_type(ArgAnn, {id, Ann, "_"}) ->
case aeso_syntax:get_ann(origin, Ann, user) of
system -> fresh_uvar(ArgAnn);
user -> fresh_uvar(Ann)
end;
arg_type(ArgAnn, {app_t, Attrs, Name, Args}) ->
{app_t, Attrs, Name, [arg_type(ArgAnn, T) || T <- Args]};
arg_type(_, T) ->
T.
app_t(_Ann, Name, []) -> Name;
@ -1118,7 +1125,7 @@ infer_expr(Env, {list_comp, AttrsL, Yield, [{comprehension_if, AttrsIF, Cond}|Re
, {list_comp, AttrsL, TypedYield, [{comprehension_if, AttrsIF, NewCond}|TypedRest]}
, ResType};
infer_expr(Env, {list_comp, AsLC, Yield, [{letval, AsLV, Pattern, Type, E}|Rest]}) ->
NewE = {typed, _, _, PatType} = infer_expr(Env, {typed, AsLV, E, arg_type(Type)}),
NewE = {typed, _, _, PatType} = infer_expr(Env, {typed, AsLV, E, arg_type(AsLV, Type)}),
BlockType = fresh_uvar(AsLV),
{'case', _, NewPattern, NewRest} =
infer_case( Env
@ -1342,7 +1349,7 @@ infer_block(Env, Attrs, [Def={letfun, Ann, _, _, _, _}|Rest], BlockType) ->
NewE = bind_var({id, Ann, Name}, FunT, Env),
[LetFun|infer_block(NewE, Attrs, Rest, BlockType)];
infer_block(Env, _, [{letval, Attrs, Pattern, Type, E}|Rest], BlockType) ->
NewE = {typed, _, _, PatType} = infer_expr(Env, {typed, Attrs, E, arg_type(Type)}),
NewE = {typed, _, _, PatType} = infer_expr(Env, {typed, Attrs, E, arg_type(aeso_syntax:get_ann(Pattern), Type)}),
{'case', _, NewPattern, {typed, _, {block, _, NewRest}, _}} =
infer_case(Env, Attrs, Pattern, PatType, {block, Attrs, Rest}, BlockType),
[{letval, Attrs, NewPattern, Type, NewE}|NewRest];
@ -1542,7 +1549,8 @@ destroy_and_report_unsolved_named_argument_constraints(Env) ->
%% -- Bytes constraints --
-type byte_constraint() :: {is_bytes, utype()}.
-type byte_constraint() :: {is_bytes, utype()}
| {add_bytes, aeso_syntax:ann(), concat | split, utype(), utype(), utype()}.
create_bytes_constraints() ->
ets_new(bytes_constraints, [bag]).
@ -1554,21 +1562,52 @@ get_bytes_constraints() ->
add_bytes_constraint(Constraint) ->
ets_insert(bytes_constraints, Constraint).
solve_bytes_constraints(_Env) ->
solve_bytes_constraints(Env) ->
[ solve_bytes_constraint(Env, C) || C <- get_bytes_constraints() ],
ok.
solve_bytes_constraint(_Env, {is_bytes, _}) -> ok;
solve_bytes_constraint(Env, {add_bytes, Ann, _, A0, B0, C0}) ->
A = unfold_types_in_type(Env, dereference(A0)),
B = unfold_types_in_type(Env, dereference(B0)),
C = unfold_types_in_type(Env, dereference(C0)),
case {A, B, C} of
{{bytes_t, _, M}, {bytes_t, _, N}, _} -> unify(Env, {bytes_t, Ann, M + N}, C, {at, Ann});
{{bytes_t, _, M}, _, {bytes_t, _, R}} when R >= M -> unify(Env, {bytes_t, Ann, R - M}, B, {at, Ann});
{_, {bytes_t, _, N}, {bytes_t, _, R}} when R >= N -> unify(Env, {bytes_t, Ann, R - N}, A, {at, Ann});
_ -> ok
end.
destroy_bytes_constraints() ->
ets_delete(bytes_constraints).
destroy_and_report_unsolved_bytes_constraints(Env) ->
[ check_bytes_constraint(Env, C) || C <- get_bytes_constraints() ],
Constraints = get_bytes_constraints(),
InAddConstraint = [ T || {add_bytes, _, _, A, B, C} <- Constraints,
T <- [A, B, C],
element(1, T) /= bytes_t ],
%% Skip is_bytes constraints for types that occur in add_bytes constraints
%% (no need to generate error messages for both is_bytes and add_bytes).
Skip = fun({is_bytes, T}) -> lists:member(T, InAddConstraint);
(_) -> false end,
[ check_bytes_constraint(Env, C) || C <- Constraints, not Skip(C) ],
destroy_bytes_constraints().
check_bytes_constraint(Env, {is_bytes, Type}) ->
Type1 = unfold_types_in_type(Env, instantiate(Type)),
case Type1 of
{bytes_t, _, _} -> ok;
_ -> type_error({cannot_unify, Type1, {bytes_t, [], any}, {at, Type}})
_ ->
type_error({unknown_byte_length, Type})
end;
check_bytes_constraint(Env, {add_bytes, Ann, Fun, A0, B0, C0}) ->
A = unfold_types_in_type(Env, instantiate(A0)),
B = unfold_types_in_type(Env, instantiate(B0)),
C = unfold_types_in_type(Env, instantiate(C0)),
case {A, B, C} of
{{bytes_t, _, _M}, {bytes_t, _, _N}, {bytes_t, _, _R}} ->
ok; %% If all are solved we checked M + N == R in solve_bytes_constraint.
_ -> type_error({unsolved_bytes_constraint, Ann, Fun, A, B, C})
end.
%% -- Field constraints --
@ -1999,6 +2038,8 @@ occurs_check1(R, [H | T]) ->
occurs_check(R, H) orelse occurs_check(R, T);
occurs_check1(_, []) -> false.
fresh_uvar([{origin, system}]) ->
error(oh_no_you_dont);
fresh_uvar(Attrs) ->
{uvar, Attrs, make_ref()}.
@ -2040,7 +2081,11 @@ freshen_type_sig(Ann, TypeSig = {type_sig, _, Constr, _, _, _}) ->
apply_typesig_constraint(Ann, Constr, FunT),
FunT.
apply_typesig_constraint(_Ann, none, _FunT) -> ok.
apply_typesig_constraint(_Ann, none, _FunT) -> ok;
apply_typesig_constraint(Ann, bytes_concat, {fun_t, _, [], [A, B], C}) ->
add_bytes_constraint({add_bytes, Ann, concat, A, B, C});
apply_typesig_constraint(Ann, bytes_split, {fun_t, _, [], [C], {tuple_t, _, [A, B]}}) ->
add_bytes_constraint({add_bytes, Ann, split, A, B, C}).
%% Dereferences all uvars and replaces the uninstantiated ones with a
%% succession of tvars.
@ -2326,6 +2371,21 @@ mk_error({bad_top_level_decl, Decl}) ->
Msg = io_lib:format("The definition of '~s' must appear inside a ~s.\n",
[pp_expr("", Id), What]),
mk_t_err(pos(Decl), Msg);
mk_error({unknown_byte_length, Type}) ->
Msg = io_lib:format("Cannot resolve length of byte array.\n", []),
mk_t_err(pos(Type), Msg);
mk_error({unsolved_bytes_constraint, Ann, concat, A, B, C}) ->
Msg = io_lib:format("Failed to resolve byte array lengths in call to Bytes.concat with arguments of type\n"
"~s (at ~s)\n~s (at ~s)\nand result type\n~s (at ~s)\n",
[pp_type(" - ", A), pp_loc(A), pp_type(" - ", B),
pp_loc(B), pp_type(" - ", C), pp_loc(C)]),
mk_t_err(pos(Ann), Msg);
mk_error({unsolved_bytes_constraint, Ann, split, A, B, C}) ->
Msg = io_lib:format("Failed to resolve byte array lengths in call to Bytes.split with argument of type\n"
"~s (at ~s)\nand result types\n~s (at ~s)\n~s (at ~s)\n",
[ pp_type(" - ", C), pp_loc(C), pp_type(" - ", A), pp_loc(A),
pp_type(" - ", B), pp_loc(B)]),
mk_t_err(pos(Ann), Msg);
mk_error(Err) ->
Msg = io_lib:format("Unknown error: ~p\n", [Err]),
mk_t_err(pos(0, 0), Msg).

View File

@ -511,6 +511,50 @@ failing_contracts() ->
[<<?Pos(3, 5)
"Unbound variable Chain.event at line 3, column 5\n"
"Did you forget to define the event type?">>])
, ?TYPE_ERROR(bad_bytes_concat,
[<<?Pos(12, 40)
"Failed to resolve byte array lengths in call to Bytes.concat with arguments of type\n"
" - 'g (at line 12, column 20)\n"
" - 'h (at line 12, column 23)\n"
"and result type\n"
" - bytes(10) (at line 12, column 28)">>,
<<?Pos(13, 28)
"Failed to resolve byte array lengths in call to Bytes.concat with arguments of type\n"
" - 'd (at line 13, column 20)\n"
" - 'e (at line 13, column 23)\n"
"and result type\n"
" - 'f (at line 13, column 14)">>,
<<?Pos(15, 5)
"Cannot unify bytes(26)\n"
" and bytes(25)\n"
"at line 15, column 5">>,
<<?Pos(17, 5)
"Failed to resolve byte array lengths in call to Bytes.concat with arguments of type\n"
" - bytes(6) (at line 16, column 24)\n"
" - 'b (at line 16, column 34)\n"
"and result type\n"
" - 'c (at line 16, column 39)">>,
<<?Pos(19, 25)
"Cannot resolve length of byte array.">>])
, ?TYPE_ERROR(bad_bytes_split,
[<<?Pos(13, 5)
"Failed to resolve byte array lengths in call to Bytes.split with argument of type\n"
" - 'f (at line 12, column 20)\n"
"and result types\n"
" - 'e (at line 13, column 5)\n"
" - bytes(20) (at line 12, column 29)">>,
<<?Pos(16, 5)
"Failed to resolve byte array lengths in call to Bytes.split with argument of type\n"
" - bytes(15) (at line 15, column 24)\n"
"and result types\n"
" - 'c (at line 16, column 5)\n"
" - 'd (at line 16, column 5)">>,
<<?Pos(19, 5)
"Failed to resolve byte array lengths in call to Bytes.split with argument of type\n"
" - 'b (at line 18, column 20)\n"
"and result types\n"
" - bytes(20) (at line 18, column 25)\n"
" - 'a (at line 19, column 5)">>])
].
-define(Path(File), "code_errors/" ??File).

View File

@ -0,0 +1,19 @@
contract BytesConcat =
entrypoint test1(x : bytes(10), y : bytes(20)) =
Bytes.concat(x, y)
entrypoint test2(x : bytes(10), y) : bytes(15) =
Bytes.concat(x, y)
entrypoint test3(x, y : bytes(20)) : bytes(25) =
Bytes.concat(x, y)
entrypoint fail1(x, y) : bytes(10) = Bytes.concat(x, y)
entrypoint fail2(x, y) = Bytes.concat(x, y)
entrypoint fail3(x : bytes(6), y : bytes(20)) : bytes(25) =
Bytes.concat(x, y)
entrypoint fail4(x : bytes(6), y) : _ =
Bytes.concat(x, y)
entrypoint fail5(x) = Bytes.to_str(x)

View File

@ -0,0 +1,20 @@
contract BytesSplit =
entrypoint test1(x) : bytes(10) * bytes(20) =
Bytes.split(x)
entrypoint test2(x : bytes(15)) : bytes(10) * _ =
Bytes.split(x)
entrypoint test3(x : bytes(25)) : _ * bytes(20) =
Bytes.split(x)
entrypoint fail1(x) : _ * bytes(20) =
Bytes.split(x)
entrypoint fail2(x : bytes(15)) : _ =
Bytes.split(x)
entrypoint fail3(x) : bytes(20) * _ =
Bytes.split(x)