From 113d699b01baeec6ec15da0d393d616d633eac4e Mon Sep 17 00:00:00 2001 From: Gaith Hallak Date: Fri, 27 May 2022 15:45:26 +0400 Subject: [PATCH] Fix variance inference --- src/aeso_ast_infer_types.erl | 68 ++------- test/aeso_compiler_tests.erl | 131 ++++++++---------- ...rphism_variance_switching_custom_types.aes | 63 +++++---- 3 files changed, 112 insertions(+), 150 deletions(-) diff --git a/src/aeso_ast_infer_types.erl b/src/aeso_ast_infer_types.erl index a28cf12..fffc56f 100644 --- a/src/aeso_ast_infer_types.erl +++ b/src/aeso_ast_infer_types.erl @@ -809,7 +809,6 @@ infer(Contracts, Options) -> create_options(Options), ets_new(defined_contracts, [bag]), ets_new(type_vars, [set]), - ets_new(type_vars_uvar, [set]), ets_new(warnings, [bag]), ets_new(type_vars_variance, [set]), %% Set the variance for builtin types @@ -1126,7 +1125,7 @@ infer_type_vars_variance(TypeParams, Cons) -> % args from all type constructors FlatArgs = lists:flatten([Args || {constr_t, _, _, Args} <- Cons]) ++ [Type || {field_t, _, _, Type} <- Cons], - Vs = lists:flatten([element(1, infer_type_vars_variance(Arg)) || Arg <- FlatArgs]), + Vs = lists:flatten([infer_type_vars_variance(Arg) || Arg <- FlatArgs]), lists:map(fun({tvar, _, TVar}) -> S = sets:from_list([Variance || {TV, Variance} <- Vs, TV == TVar]), IsCovariant = sets:is_element(covariant, S), @@ -1139,7 +1138,10 @@ infer_type_vars_variance(TypeParams, Cons) -> end end, TypeParams). --spec infer_type_vars_variance(utype()) -> {[{name(), variance()}], integer()}. +-spec infer_type_vars_variance(utype()) -> [{name(), variance()}]. +infer_type_vars_variance(Types) + when is_list(Types) -> + lists:flatten([infer_type_vars_variance(T) || T <- Types]); infer_type_vars_variance({app_t, _, Type, Args}) -> Variances = case ets_lookup(type_vars_variance, qname(Type)) of [{_, Vs}] -> Vs; @@ -1147,36 +1149,14 @@ infer_type_vars_variance({app_t, _, Type, Args}) -> end, TypeVarsVariance = [{TVar, Variance} || {{tvar, _, TVar}, Variance} <- lists:zip(Args, Variances)], - {TypeVarsVariance, 0}; -infer_type_vars_variance({fun_t, _, [], [{app_t, _, Type, Args}], Res}) -> - Variances = case ets_lookup(type_vars_variance, qname(Type)) of - [{_, Vs}] -> Vs; - _ -> lists:duplicate(length(Args), covariant) - end, - TypeVarsVariance = [{TVar, Variance} - || {{tvar, _, TVar}, Variance} <- lists:zip(Args, Variances)], - FlipVariance = fun({TVar, covariant}) -> {TVar, contravariant}; - ({TVar, contravariant}) -> {TVar, covariant} - end, - {TVVs, Depth} = infer_type_vars_variance(Res), - Cur = case (Depth + 1) rem 2 of - 0 -> TypeVarsVariance; - 1 -> lists:map(FlipVariance, TypeVarsVariance) - end, - {Cur ++ TVVs, Depth + 1}; -infer_type_vars_variance({fun_t, _, [], [{tvar, _, TVar}], Res}) -> - {TVVs, Depth} = infer_type_vars_variance(Res), - Cur = case (Depth + 1) rem 2 of - 0 -> {TVar, covariant}; - 1 -> {TVar, contravariant} - end, - {[Cur | TVVs], Depth + 1}; -infer_type_vars_variance({fun_t, _, [], _, Res}) -> - {X, Depth} = infer_type_vars_variance(Res), - {X, Depth + 1}; -infer_type_vars_variance({tvar, _, TVar}) -> - {[{TVar, covariant}], 0}; -infer_type_vars_variance(_) -> {[], 0}. + TypeVarsVariance; +infer_type_vars_variance({tvar, _, TVar}) -> [{TVar, covariant}]; +infer_type_vars_variance({fun_t, _, [], Args, Res}) -> + ArgsVariance = infer_type_vars_variance(Args), + ResVariance = infer_type_vars_variance(Res), + FlippedArgsVariance = lists:map(fun({TVar, Variance}) -> {TVar, opposite_variance(Variance)} end, ArgsVariance), + FlippedArgsVariance ++ ResVariance; +infer_type_vars_variance(_) -> []. opposite_variance(invariant) -> invariant; opposite_variance(covariant) -> contravariant; @@ -1773,16 +1753,6 @@ infer_expr(Env, {app, Ann, Fun, Args0} = App) -> ResultType = fresh_uvar(Ann), when_warning(warn_unused_functions, fun() -> register_function_call(Namespace ++ qname(CurrentFun), Name) end), - % the uvars of tvars are stored so that no variance switching happens - % in between them (e.g. in TC('a => 'a), 'a should be a single type) - case FunType of - {fun_t, _, _, _, {app_t, _, _, TArgs}} -> - lists:foreach(fun({uvar, _, URef}) -> - ets_insert(type_vars_uvar, {URef}); - (_) -> ok - end, TArgs); - _ -> ok - end, unify(Env, FunType, {fun_t, [], NamedArgsVar, ArgTypes, GeneralResultType}, When), when_warning(warn_negative_spend, fun() -> warn_potential_negative_spend(Ann, NewFun1, NewArgs) end), add_constraint( @@ -2143,7 +2113,7 @@ next_count() -> ets_tables() -> [options, type_vars, constraints, freshen_tvars, type_errors, defined_contracts, warnings, function_calls, all_functions, - type_vars_variance, type_vars_uvar]. + type_vars_variance]. clean_up_ets() -> [ catch ets_delete(Tab) || Tab <- ets_tables() ], @@ -2687,7 +2657,7 @@ unify(Env, A, B, When) -> unify0(Env, A, B, covariant, When). unify0(_, {id, _, "_"}, _, _Variance, _When) -> true; unify0(_, _, {id, _, "_"}, _Variance, _When) -> true; -unify0(Env, A, B, Variance0, When) -> +unify0(Env, A, B, Variance, When) -> Options = case When of %% Improve source location for map_in_map_key errors {check_expr, E, _, _} -> [{ann, aeso_syntax:get_ann(E)}]; @@ -2695,14 +2665,6 @@ unify0(Env, A, B, Variance0, When) -> end, A1 = dereference(unfold_types_in_type(Env, A, Options)), B1 = dereference(unfold_types_in_type(Env, B, Options)), - Variance = case A of - {uvar, _,URef} -> - case ets_lookup(type_vars_uvar, URef) of - [_] -> invariant; - _ -> Variance0 - end; - _ -> Variance0 - end, unify1(Env, A1, B1, Variance, When). unify1(_Env, {uvar, _, R}, {uvar, _, R}, _Variance, _When) -> diff --git a/test/aeso_compiler_tests.erl b/test/aeso_compiler_tests.erl index 04268b6..f1feeda 100644 --- a/test/aeso_compiler_tests.erl +++ b/test/aeso_compiler_tests.erl @@ -884,114 +884,105 @@ failing_contracts() -> "when checking the type of the expression `some_animal : Animal` against the expected type `Cat`">> ]) , ?TYPE_ERROR(polymorphism_variance_switching_custom_types, - [<>, - <>, - < Animal) => dt_inv(Animal)`\n" - "to arguments\n" - " `f_a_to_c : (Animal) => Cat`">>, - < Cat) => dt_inv(Cat)`\n" - "to arguments\n" - " `f_c_to_a : (Cat) => Animal`">>, < Cat) => dt_inv(Cat)`\nto arguments\n `f_c_to_a : (Cat) => Animal`">>, + <>, - <>, < Animal) => dt_inv(Animal)`\n" - "to arguments\n" - " `f_a_to_c : (Animal) => Cat`">>, + "when checking the type of the expression `DT_INV(f_a_to_a) : dt_inv(Animal)` against the expected type `dt_inv(Cat)`">>, < Cat) => dt_inv(Cat)`\n" - "to arguments\n" - " `f_c_to_a : (Cat) => Animal`">>, - <>, + < Cat) => dt_inv(Cat)`\nto arguments\n `f_c_to_a : (Cat) => Animal`">>, + <>, - <>, - <>, - <>, - <>, - <>, - <>, - <>, - < (unit) => Animal) => dt_co_twice(Animal)`\n" + " `DT_CO_TWICE : (((Cat) => unit) => Cat) => dt_co_twice(Cat)`\n" "to arguments\n" - " `f_a_to_u_to_c : (Animal) => (unit) => Cat`">>, - < (unit) => Cat) => dt_co_twice(Cat)`\n" - "to arguments\n" - " `f_c_to_u_to_a : (Cat) => (unit) => Animal`">>, - < unit) => Animal`">>, + <>, - < (unit) => Animal) => dt_co_twice(Animal)`\n" - "to arguments\n" - " `f_a_to_u_to_c : (Animal) => (unit) => Cat`">>, <>, + < (unit) => Cat) => dt_co_twice(Cat)`\n" + " `DT_CO_TWICE : (((Cat) => unit) => Cat) => dt_co_twice(Cat)`\n" "to arguments\n" - " `f_c_to_u_to_a : (Cat) => (unit) => Animal`">>, - < unit) => Animal`">>, + <>, - <>, + <>, - <>, - <>, - <>, + <>, - <>, + <>, + <>, + <>, + <>, + <>, - < (Animal) => unit) => dt_contra_twice(Animal)`\nto arguments\n `f_a_to_c_to_u : (Animal) => (Cat) => unit`">>, + <>, + <>, + <> + "when checking the application of\n" + " `DT_CONTRA_TWICE : ((Animal) => (Animal) => unit) => dt_contra_twice(Animal)`\n" + "to arguments\n" + " `f_a_to_c_to_u : (Animal) => (Cat) => unit`">> ]) , ?TYPE_ERROR(polymorphism_variance_switching_records, [< unit) datatype dt_contra_nest_b('a) = DT_CONTRA_NEST_B(unit => dt_contra('a)) datatype dt_co_nest_b('a) = DT_CO_NEST_B(unit => dt_co('a)) - datatype dt_co_twice('a) = DT_CO_TWICE('a => unit => 'a) - datatype dt_a_co_b_contra('a, 'b) = DT_A_CO_B_CONTRA('a => 'b => unit) + datatype dt_co_twice('a) = DT_CO_TWICE(('a => unit) => 'a) + datatype dt_contra_twice('a) = DT_CONTRA_TWICE('a => 'a => unit) + datatype dt_a_contra_b_contra('a, 'b) = DT_A_CONTRA_B_CONTRA('a => 'b => unit) function f_a_to_a_to_u(_ : Animal) : (Animal => unit) = f_a_to_u function f_a_to_c_to_u(_ : Animal) : (Cat => unit) = f_c_to_u @@ -43,10 +44,10 @@ main contract Main = stateful function f_c_to_a(_ : Cat) : Animal = f_a() stateful function f_c_to_c(_ : Cat) : Cat = f_c() - stateful function f_a_to_u_to_a(_ : Animal) : (unit => Animal) = f_u_to_a - stateful function f_a_to_u_to_c(_ : Animal) : (unit => Cat) = f_u_to_c - stateful function f_c_to_u_to_a(_ : Cat) : (unit => Animal) = f_u_to_a - stateful function f_c_to_u_to_c(_ : Cat) : (unit => Cat) = f_u_to_c + stateful function f_a_to_u_to_a(_ : (Animal => unit)) : Animal = f_a() + stateful function f_a_to_u_to_c(_ : (Animal => unit)) : Cat = f_c() + stateful function f_c_to_u_to_a(_ : (Cat => unit)) : Animal = f_a() + stateful function f_c_to_u_to_c(_ : (Cat => unit)) : Cat = f_c() stateful function f_u_to_dt_co_a(_ : unit) : dt_co(Animal) = DT_CO(f_u_to_a) stateful function f_u_to_dt_co_c(_ : unit) : dt_co(Cat) = DT_CO(f_u_to_c) @@ -63,7 +64,7 @@ main contract Main = let vb4 : dt_co(Cat) = DT_CO(f_u_to_c) // success let vc1 : dt_inv(Animal) = DT_INV(f_a_to_a) // success - let vc2 : dt_inv(Animal) = DT_INV(f_a_to_c) // fail + let vc2 : dt_inv(Animal) = DT_INV(f_a_to_c) // success let vc3 : dt_inv(Animal) = DT_INV(f_c_to_a) // fail let vc4 : dt_inv(Animal) = DT_INV(f_c_to_c) // fail let vc5 : dt_inv(Cat) = DT_INV(f_a_to_a) // fail @@ -71,8 +72,8 @@ main contract Main = let vc7 : dt_inv(Cat) = DT_INV(f_c_to_a) // fail let vc8 : dt_inv(Cat) = DT_INV(f_c_to_c) // success - let vd1 : dt_biv(Animal) = DT_BIV(f_u_to_u) // success - let vd2 : dt_biv(Cat) = DT_BIV(f_u_to_u) // success + let vd1 : dt_biv(Animal) = DT_BIV(f_u_to_u) : dt_biv(Cat) // success + let vd2 : dt_biv(Cat) = DT_BIV(f_u_to_u) : dt_biv(Animal) // success let ve1 : dt_inv_sep(Animal) = DT_INV_SEP_A(f_a_to_u) // success let ve2 : dt_inv_sep(Animal) = DT_INV_SEP_A(f_c_to_u) // fail @@ -104,29 +105,37 @@ main contract Main = let vi4 : dt_co_nest_b(Cat) = DT_CO_NEST_B(f_u_to_dt_co_c) // success let vj1 : dt_co_twice(Animal) = DT_CO_TWICE(f_a_to_u_to_a) // success - let vj2 : dt_co_twice(Animal) = DT_CO_TWICE(f_a_to_u_to_c) // fail + let vj2 : dt_co_twice(Animal) = DT_CO_TWICE(f_a_to_u_to_c) // success let vj3 : dt_co_twice(Animal) = DT_CO_TWICE(f_c_to_u_to_a) // fail let vj4 : dt_co_twice(Animal) = DT_CO_TWICE(f_c_to_u_to_c) // success let vj5 : dt_co_twice(Cat) = DT_CO_TWICE(f_a_to_u_to_a) // fail let vj6 : dt_co_twice(Cat) = DT_CO_TWICE(f_a_to_u_to_c) // fail let vj7 : dt_co_twice(Cat) = DT_CO_TWICE(f_c_to_u_to_a) // fail - let vj8 : dt_co_twice(Cat) = DT_CO_TWICE(f_c_to_u_to_c) // success - let vk01 : dt_a_co_b_contra(Animal, Animal) = DT_A_CO_B_CONTRA(f_a_to_a_to_u) // success - let vk02 : dt_a_co_b_contra(Animal, Animal) = DT_A_CO_B_CONTRA(f_a_to_c_to_u) // fail - let vk03 : dt_a_co_b_contra(Animal, Animal) = DT_A_CO_B_CONTRA(f_c_to_a_to_u) // success - let vk04 : dt_a_co_b_contra(Animal, Animal) = DT_A_CO_B_CONTRA(f_c_to_c_to_u) // fail - let vk05 : dt_a_co_b_contra(Animal, Cat) = DT_A_CO_B_CONTRA(f_a_to_a_to_u) // success - let vk06 : dt_a_co_b_contra(Animal, Cat) = DT_A_CO_B_CONTRA(f_a_to_c_to_u) // success - let vk07 : dt_a_co_b_contra(Animal, Cat) = DT_A_CO_B_CONTRA(f_c_to_a_to_u) // success - let vk08 : dt_a_co_b_contra(Animal, Cat) = DT_A_CO_B_CONTRA(f_c_to_c_to_u) // success - let vk09 : dt_a_co_b_contra(Cat, Animal) = DT_A_CO_B_CONTRA(f_a_to_a_to_u) // fail - let vk10 : dt_a_co_b_contra(Cat, Animal) = DT_A_CO_B_CONTRA(f_a_to_c_to_u) // fail - let vk11 : dt_a_co_b_contra(Cat, Animal) = DT_A_CO_B_CONTRA(f_c_to_a_to_u) // success - let vk12 : dt_a_co_b_contra(Cat, Animal) = DT_A_CO_B_CONTRA(f_c_to_c_to_u) // fail - let vk13 : dt_a_co_b_contra(Cat, Cat) = DT_A_CO_B_CONTRA(f_a_to_a_to_u) // fail - let vk14 : dt_a_co_b_contra(Cat, Cat) = DT_A_CO_B_CONTRA(f_a_to_c_to_u) // fail - let vk15 : dt_a_co_b_contra(Cat, Cat) = DT_A_CO_B_CONTRA(f_c_to_a_to_u) // success - let vk16 : dt_a_co_b_contra(Cat, Cat) = DT_A_CO_B_CONTRA(f_c_to_c_to_u) // success + let vk01 : dt_a_contra_b_contra(Animal, Animal) = DT_A_CONTRA_B_CONTRA(f_a_to_a_to_u) // success + let vk02 : dt_a_contra_b_contra(Animal, Animal) = DT_A_CONTRA_B_CONTRA(f_a_to_c_to_u) // fail + let vk03 : dt_a_contra_b_contra(Animal, Animal) = DT_A_CONTRA_B_CONTRA(f_c_to_a_to_u) // fail + let vk04 : dt_a_contra_b_contra(Animal, Animal) = DT_A_CONTRA_B_CONTRA(f_c_to_c_to_u) // fail + let vk05 : dt_a_contra_b_contra(Animal, Cat) = DT_A_CONTRA_B_CONTRA(f_a_to_a_to_u) // success + let vk06 : dt_a_contra_b_contra(Animal, Cat) = DT_A_CONTRA_B_CONTRA(f_a_to_c_to_u) // success + let vk07 : dt_a_contra_b_contra(Animal, Cat) = DT_A_CONTRA_B_CONTRA(f_c_to_a_to_u) // fail + let vk08 : dt_a_contra_b_contra(Animal, Cat) = DT_A_CONTRA_B_CONTRA(f_c_to_c_to_u) // fail + let vk09 : dt_a_contra_b_contra(Cat, Animal) = DT_A_CONTRA_B_CONTRA(f_a_to_a_to_u) // success + let vk10 : dt_a_contra_b_contra(Cat, Animal) = DT_A_CONTRA_B_CONTRA(f_a_to_c_to_u) // fail + let vk11 : dt_a_contra_b_contra(Cat, Animal) = DT_A_CONTRA_B_CONTRA(f_c_to_a_to_u) // success + let vk12 : dt_a_contra_b_contra(Cat, Animal) = DT_A_CONTRA_B_CONTRA(f_c_to_c_to_u) // fail + let vk13 : dt_a_contra_b_contra(Cat, Cat) = DT_A_CONTRA_B_CONTRA(f_a_to_a_to_u) // success + let vk14 : dt_a_contra_b_contra(Cat, Cat) = DT_A_CONTRA_B_CONTRA(f_a_to_c_to_u) // success + let vk15 : dt_a_contra_b_contra(Cat, Cat) = DT_A_CONTRA_B_CONTRA(f_c_to_a_to_u) // success + let vk16 : dt_a_contra_b_contra(Cat, Cat) = DT_A_CONTRA_B_CONTRA(f_c_to_c_to_u) // success + + let vl1 : dt_contra_twice(Animal) = DT_CONTRA_TWICE(f_a_to_a_to_u) // success + let vl2 : dt_contra_twice(Animal) = DT_CONTRA_TWICE(f_a_to_c_to_u) // fail + let vl3 : dt_contra_twice(Animal) = DT_CONTRA_TWICE(f_c_to_a_to_u) // fail + let vl4 : dt_contra_twice(Animal) = DT_CONTRA_TWICE(f_c_to_c_to_u) // fail + let vl5 : dt_contra_twice(Cat) = DT_CONTRA_TWICE(f_a_to_a_to_u) // success + let vl6 : dt_contra_twice(Cat) = DT_CONTRA_TWICE(f_a_to_c_to_u) // fail + let vl7 : dt_contra_twice(Cat) = DT_CONTRA_TWICE(f_c_to_a_to_u) // success + let vl8 : dt_contra_twice(Cat) = DT_CONTRA_TWICE(f_c_to_c_to_u) // success ()