From 3b0ca28c8e779567e860b8337b3508e801231b75 Mon Sep 17 00:00:00 2001 From: Hans Svensson Date: Wed, 23 Aug 2023 09:43:49 +0200 Subject: [PATCH] Improve constraint solving (#480) * Clean up constraint solving a bit * Make unify always return true or false * Remove unused unify_throws field from Env * Better structure for constraint solving * Fix formatting of if_branches error * More cleanup --- src/aeso_ast_infer_types.erl | 328 ++++++++++++++++------------------- 1 file changed, 151 insertions(+), 177 deletions(-) diff --git a/src/aeso_ast_infer_types.erl b/src/aeso_ast_infer_types.erl index d74358b..ae75468 100644 --- a/src/aeso_ast_infer_types.erl +++ b/src/aeso_ast_infer_types.erl @@ -154,7 +154,6 @@ , in_pattern = false :: boolean() , in_guard = false :: boolean() , stateful = false :: boolean() - , unify_throws = true :: boolean() , current_const = none :: none | aeso_syntax:id() , current_function = none :: none | aeso_syntax:id() , what = top :: top | namespace | contract | contract_interface @@ -1661,7 +1660,7 @@ infer_letrec(Env, Defs) -> Got = proplists:get_value(Name, Funs), Expect = typesig_to_fun_t(TypeSig), unify(Env, Got, Expect, {check_typesig, Name, Got, Expect}), - solve_constraints(Env), + solve_all_constraints(Env), ?PRINT_TYPES("Checked ~s : ~s\n", [Name, pp(dereference_deep(Got))]), Res @@ -2576,61 +2575,118 @@ get_constraints() -> destroy_constraints() -> ets_delete(constraints). --spec solve_constraints(env()) -> ok. -solve_constraints(Env) -> - %% First look for record fields that appear in only one type definition - IsAmbiguous = - fun(#field_constraint{ - record_t = RecordType, - field = Field={id, _Attrs, FieldName}, - field_t = FieldType, - kind = Kind, - context = When }) -> - Arity = fun_arity(dereference_deep(FieldType)), - FieldInfos = case Arity of - none -> lookup_record_field(Env, FieldName, Kind); - _ -> lookup_record_field_arity(Env, FieldName, Arity, Kind) - end, - case FieldInfos of - [] -> - type_error({undefined_field, Field}), - false; - [#field_info{field_t = FldType, record_t = RecType}] -> - create_freshen_tvars(), - FreshFldType = freshen(FldType), - FreshRecType = freshen(RecType), - destroy_freshen_tvars(), - unify(Env, FreshFldType, FieldType, {field_constraint, FreshFldType, FieldType, When}), - unify(Env, FreshRecType, RecordType, {record_constraint, FreshRecType, RecordType, When}), - false; - _ -> - %% ambiguity--need cleverer strategy - true - end; - (_) -> true - end, - AmbiguousConstraints = lists:filter(IsAmbiguous, get_constraints()), +%% Solve all constraints by iterating until no-progress - % The two passes on AmbiguousConstraints are needed - solve_ambiguous_constraints(Env, AmbiguousConstraints ++ AmbiguousConstraints). +-spec solve_all_constraints(env()) -> ok. +solve_all_constraints(Env) -> + Constraints = [C || C <- get_constraints(), not one_shot_field_constraint(Env, C) ], + solve_constraints_top(Env, Constraints). --spec solve_ambiguous_constraints(env(), [constraint()]) -> ok. -solve_ambiguous_constraints(Env, Constraints) -> - Unknown = solve_known_record_types(Env, Constraints), - if Unknown == [] -> ok; - length(Unknown) < length(Constraints) -> - %% progress! Keep trying. - solve_ambiguous_constraints(Env, Unknown); +solve_constraints_top(Env, Constraints) -> + UnsolvedCs = solve_constraints(Env, Constraints), + Progress = solve_unknown_record_constraints(Env, UnsolvedCs), + + if length(UnsolvedCs) < length(Constraints) orelse Progress == true -> + solve_constraints_top(Env, UnsolvedCs); true -> - case solve_unknown_record_types(Env, Unknown) of - true -> %% Progress! - solve_ambiguous_constraints(Env, Unknown); - _ -> ok %% No progress. Report errors later. - end + ok end. +-spec solve_constraints(env(), [constraint()]) -> [constraint()]. +solve_constraints(Env, Constraints) -> + [ C1 || C <- Constraints, C1 <- [dereference_deep(C)], not solve_constraint(Env, C1) ]. + +solve_unknown_record_constraints(Env, Constraints) -> + FieldCs = lists:filter(fun(#field_constraint{record_t = {uvar, _, _}}) -> true; (_) -> false end, Constraints), + FieldCsUVars = lists:usort([UVar || #field_constraint{record_t = UVar = {uvar, _, _}} <- FieldCs]), + + FieldConstraint = fun(#field_constraint{ field = F, kind = K, context = Ctx }) -> {K, Ctx, F} end, + FieldsForUVar = fun(UVar) -> + [ FieldConstraint(FC) || FC = #field_constraint{record_t = U} <- FieldCs, U == UVar ] + end, + + + Solutions = [ solve_for_uvar(Env, UVar, FieldsForUVar(UVar)) || UVar <- FieldCsUVars ], + case lists:member(true, Solutions) of + true -> true; + false -> Solutions + end. + +%% -- Simple constraints -- +%% Returns true if solved (unified or type error) +solve_constraint(_Env, #field_constraint{record_t = {uvar, _, _}}) -> + false; +solve_constraint(Env, #field_constraint{record_t = RecordType, + field = Field = {id, _As, FieldName}, + field_t = FieldType, + context = When}) -> + RecId = record_type_name(RecordType), + Attrs = aeso_syntax:get_ann(RecId), + case lookup_type(Env, RecId) of + {_, {_Ann, {Formals, {What, Fields}}}} when What =:= record_t; What =:= contract_t -> + FieldTypes = [{Name, Type} || {field_t, _, {id, _, Name}, Type} <- Fields], + case proplists:get_value(FieldName, FieldTypes) of + undefined -> + type_error({missing_field, Field, RecId}); + FldType -> + solve_field_constraint(Env, FieldType, FldType, RecordType, app_t(Attrs, RecId, Formals), When) + end; + _ -> + type_error({not_a_record_type, instantiate(RecordType), When}) + end, + true; +solve_constraint(Env, C = #dependent_type_constraint{}) -> + check_named_argument_constraint(Env, C); +solve_constraint(Env, C = #named_argument_constraint{}) -> + check_named_argument_constraint(Env, C); +solve_constraint(_Env, {is_bytes, _}) -> false; +solve_constraint(Env, {add_bytes, Ann, _, A0, B0, C0}) -> + A = unfold_types_in_type(Env, dereference(A0)), + B = unfold_types_in_type(Env, dereference(B0)), + C = unfold_types_in_type(Env, dereference(C0)), + case {A, B, C} of + {{bytes_t, _, M}, {bytes_t, _, N}, _} -> unify(Env, {bytes_t, Ann, M + N}, C, {at, Ann}); + {{bytes_t, _, M}, _, {bytes_t, _, R}} when R >= M -> unify(Env, {bytes_t, Ann, R - M}, B, {at, Ann}); + {_, {bytes_t, _, N}, {bytes_t, _, R}} when R >= N -> unify(Env, {bytes_t, Ann, R - N}, A, {at, Ann}); + _ -> false + end; +solve_constraint(_, _) -> false. + +one_shot_field_constraint(Env, #field_constraint{record_t = RecordType, + field = Field = {id, _As, FieldName}, + field_t = FieldType, + kind = Kind, + context = When}) -> + Arity = fun_arity(dereference_deep(FieldType)), + FieldInfos = case Arity of + none -> lookup_record_field(Env, FieldName, Kind); + _ -> lookup_record_field_arity(Env, FieldName, Arity, Kind) + end, + + case FieldInfos of + [] -> + type_error({undefined_field, Field}), + true; + [#field_info{field_t = FldType, record_t = RecType}] -> + solve_field_constraint(Env, FieldType, FldType, RecordType, RecType, When), + true; + _ -> + false + end; +one_shot_field_constraint(_Env, _Constraint) -> + false. + + +solve_field_constraint(Env, FieldType, FldType, RecordType, RecType, When) -> + create_freshen_tvars(), + FreshFldType = freshen(FldType), + FreshRecType = freshen(RecType), + destroy_freshen_tvars(), + unify(Env, FreshFldType, FieldType, {field_constraint, FreshFldType, FieldType, When}), + unify(Env, FreshRecType, RecordType, {record_constraint, FreshRecType, RecordType, When}). + solve_then_destroy_and_report_unsolved_constraints(Env) -> - solve_constraints(Env), + solve_all_constraints(Env), destroy_and_report_unsolved_constraints(Env). destroy_and_report_unsolved_constraints(Env) -> @@ -2661,21 +2717,10 @@ destroy_and_report_unsolved_constraints(Env) -> (_) -> false end, OtherCs5), - Unsolved = [ S || S <- [ solve_constraint(Env, dereference_deep(C)) || C <- NamedArgCs ], - S == unsolved ], - [ type_error({unsolved_named_argument_constraint, C}) || C <- Unsolved ], - - Unknown = solve_known_record_types(Env, FieldCs), - if Unknown == [] -> ok; - true -> - case solve_unknown_record_types(Env, Unknown) of - true -> ok; - Errors -> [ type_error(Err) || Err <- Errors ] - end - end, - + check_field_constraints(Env, FieldCs), check_record_create_constraints(Env, CreateCs), check_is_contract_constraints(Env, ContractCs), + check_named_args_constraints(Env, NamedArgCs), check_bytes_constraints(Env, BytesCs), check_aens_resolve_constraints(Env, AensResolveCs), check_oracle_type_constraints(Env, OracleTypeCs), @@ -2693,20 +2738,21 @@ get_oracle_type(_Fun, _Args, _Ret) -> false. %% -- Named argument constraints -- -%% If false, a type error has been emitted, so it's safe to drop the constraint. --spec check_named_argument_constraint(env(), named_argument_constraint()) -> true | false | unsolved. +%% True if solved (unified or type error), false otherwise +-spec check_named_argument_constraint(env(), named_argument_constraint()) -> true | false. check_named_argument_constraint(_Env, #named_argument_constraint{ args = {uvar, _, _} }) -> - unsolved; + false; check_named_argument_constraint(Env, C = #named_argument_constraint{ args = Args, name = Id = {id, _, Name}, type = Type }) -> case [ T || {named_arg_t, _, {id, _, Name1}, T, _} <- Args, Name1 == Name ] of [] -> - type_error({bad_named_argument, Args, Id}), - false; - [T] -> unify(Env, T, Type, {check_named_arg_constraint, C}), true - end; + type_error({bad_named_argument, Args, Id}); + [T] -> + unify(Env, T, Type, {check_named_arg_constraint, C}) + end, + true; check_named_argument_constraint(Env, #dependent_type_constraint{ named_args_t = NamedArgsT0, named_args = NamedArgs, @@ -2723,10 +2769,11 @@ check_named_argument_constraint(Env, ArgEnv = maps:from_list([ {Name, GetVal(Name, Default)} || {named_arg_t, _, {id, _, Name}, _, Default} <- NamedArgsT ]), GenType1 = specialize_dependent_type(ArgEnv, GenType), - unify(Env, GenType1, SpecType, {check_expr, App, GenType1, SpecType}), - true; - _ -> unify(Env, GenType, SpecType, {check_expr, App, GenType, SpecType}), true - end. + unify(Env, GenType1, SpecType, {check_expr, App, GenType1, SpecType}); + _ -> + unify(Env, GenType, SpecType, {check_expr, App, GenType, SpecType}) + end, + true. specialize_dependent_type(Env, Type) -> case dereference(Type) of @@ -2742,53 +2789,16 @@ specialize_dependent_type(Env, Type) -> _ -> Type %% Currently no deep dependent types end. -%% -- Bytes constraints -- +check_field_constraints(Env, Constraints) -> + UnsolvedFieldCs = solve_constraints(Env, Constraints), + case solve_unknown_record_constraints(Env, UnsolvedFieldCs) of + true -> ok; + Errors -> [ type_error(Err) || Err <- Errors ] + end. -solve_constraint(_Env, #field_constraint{record_t = {uvar, _, _}}) -> - not_solved; -solve_constraint(Env, C = #field_constraint{record_t = RecType, - field = FieldName, - field_t = FieldType, - context = When}) -> - RecId = record_type_name(RecType), - Attrs = aeso_syntax:get_ann(RecId), - case lookup_type(Env, RecId) of - {_, {_Ann, {Formals, {What, Fields}}}} when What =:= record_t; What =:= contract_t -> - FieldTypes = [{Name, Type} || {field_t, _, {id, _, Name}, Type} <- Fields], - {id, _, FieldString} = FieldName, - case proplists:get_value(FieldString, FieldTypes) of - undefined -> - type_error({missing_field, FieldName, RecId}), - not_solved; - FldType -> - create_freshen_tvars(), - FreshFldType = freshen(FldType), - FreshRecType = freshen(app_t(Attrs, RecId, Formals)), - destroy_freshen_tvars(), - unify(Env, FreshFldType, FieldType, {field_constraint, FreshFldType, FieldType, When}), - unify(Env, FreshRecType, RecType, {record_constraint, FreshRecType, RecType, When}), - C - end; - _ -> - type_error({not_a_record_type, instantiate(RecType), When}), - not_solved - end; -solve_constraint(Env, C = #dependent_type_constraint{}) -> - check_named_argument_constraint(Env, C); -solve_constraint(Env, C = #named_argument_constraint{}) -> - check_named_argument_constraint(Env, C); -solve_constraint(_Env, {is_bytes, _}) -> ok; -solve_constraint(Env, {add_bytes, Ann, _, A0, B0, C0}) -> - A = unfold_types_in_type(Env, dereference(A0)), - B = unfold_types_in_type(Env, dereference(B0)), - C = unfold_types_in_type(Env, dereference(C0)), - case {A, B, C} of - {{bytes_t, _, M}, {bytes_t, _, N}, _} -> unify(Env, {bytes_t, Ann, M + N}, C, {at, Ann}); - {{bytes_t, _, M}, _, {bytes_t, _, R}} when R >= M -> unify(Env, {bytes_t, Ann, R - M}, B, {at, Ann}); - {_, {bytes_t, _, N}, {bytes_t, _, R}} when R >= N -> unify(Env, {bytes_t, Ann, R - N}, A, {at, Ann}); - _ -> ok - end; -solve_constraint(_, _) -> ok. +check_named_args_constraints(Env, Constraints) -> + UnsolvedNamedArgCs = solve_constraints(Env, Constraints), + [ type_error({unsolved_named_argument_constraint, C}) || C <- UnsolvedNamedArgCs ]. check_bytes_constraints(Env, Constraints) -> InAddConstraint = [ T || {add_bytes, _, _, A, B, C} <- Constraints, @@ -2885,30 +2895,6 @@ check_is_contract_constraints(Env, [C | Cs]) -> end, check_is_contract_constraints(Env, Cs). --spec solve_unknown_record_types(env(), [field_constraint()]) -> true | [tuple()]. -solve_unknown_record_types(Env, Unknown) -> - UVars = lists:usort([UVar || #field_constraint{record_t = UVar = {uvar, _, _}} <- Unknown]), - Solutions = [solve_for_uvar(Env, UVar, [{Kind, When, Field} - || #field_constraint{record_t = U, field = Field, kind = Kind, context = When} <- Unknown, - U == UVar]) - || UVar <- UVars], - case lists:member(true, Solutions) of - true -> true; - false -> Solutions - end. - -%% This will solve all kinds of constraints but will only return the -%% unsolved field constraints --spec solve_known_record_types(env(), [constraint()]) -> [field_constraint()]. -solve_known_record_types(Env, Constraints) -> - DerefConstraints = lists:map(fun(C = #field_constraint{record_t = RecordType}) -> - C#field_constraint{record_t = dereference(RecordType)}; - (C) -> dereference_deep(C) - end, Constraints), - SolvedConstraints = lists:map(fun(C) -> solve_constraint(Env, dereference_deep(C)) end, DerefConstraints), - Unsolved = DerefConstraints--SolvedConstraints, - lists:filter(fun(#field_constraint{}) -> true; (_) -> false end, Unsolved). - record_type_name({app_t, _Attrs, RecId, _Args}) when ?is_type_id(RecId) -> RecId; record_type_name(RecId) when ?is_type_id(RecId) -> @@ -3087,16 +3073,12 @@ unify0(Env, A, B, Variance, When) -> unify1(_Env, {uvar, _, R}, {uvar, _, R}, _Variance, _When) -> true; unify1(_Env, {uvar, _, _}, {fun_t, _, _, var_args, _}, _Variance, When) -> - type_error({unify_varargs, When}); -unify1(Env, {uvar, A, R}, T, _Variance, When) -> + type_error({unify_varargs, When}), + false; +unify1(_Env, {uvar, A, R}, T, _Variance, When) -> case occurs_check(R, T) of true -> - if - Env#env.unify_throws -> - cannot_unify({uvar, A, R}, T, none, When); - true -> - ok - end, + cannot_unify({uvar, A, R}, T, none, When), false; false -> ets_insert(type_vars, {R, T}), @@ -3123,18 +3105,13 @@ unify1(Env, A = {con, _, NameA}, B = {con, _, NameB}, Variance, When) -> case is_subtype(Env, NameA, NameB, Variance) of true -> true; false -> - if - Env#env.unify_throws -> - IsSubtype = is_subtype(Env, NameA, NameB, contravariant) orelse - is_subtype(Env, NameA, NameB, covariant), - Cxt = case IsSubtype of - true -> Variance; - false -> none - end, - cannot_unify(A, B, Cxt, When); - true -> - ok - end, + IsSubtype = is_subtype(Env, NameA, NameB, contravariant) orelse + is_subtype(Env, NameA, NameB, covariant), + Cxt = case IsSubtype of + true -> Variance; + false -> none + end, + cannot_unify(A, B, Cxt, When), false end; unify1(_Env, {qid, _, Name}, {qid, _, Name}, _Variance, _When) -> @@ -3148,9 +3125,11 @@ unify1(Env, {if_t, _, {id, _, Id}, Then1, Else1}, {if_t, _, {id, _, Id}, Then2, unify0(Env, Else1, Else2, Variance, When); unify1(_Env, {fun_t, _, _, _, _}, {fun_t, _, _, var_args, _}, _Variance, When) -> - type_error({unify_varargs, When}); + type_error({unify_varargs, When}), + false; unify1(_Env, {fun_t, _, _, var_args, _}, {fun_t, _, _, _, _}, _Variance, When) -> - type_error({unify_varargs, When}); + type_error({unify_varargs, When}), + false; unify1(Env, {fun_t, _, Named1, Args1, Result1}, {fun_t, _, Named2, Args2, Result2}, Variance, When) when length(Args1) == length(Args2) -> unify0(Env, Named1, Named2, opposite_variance(Variance), When) andalso @@ -3172,7 +3151,7 @@ unify1(Env, {tuple_t, _, As}, {tuple_t, _, Bs}, Variance, When) when length(As) == length(Bs) -> unify0(Env, As, Bs, Variance, When); unify1(Env, {named_arg_t, _, Id1, Type1, _}, {named_arg_t, _, Id2, Type2, _}, Variance, When) -> - unify1(Env, Id1, Id2, Variance, {arg_name, Id1, Id2, When}), + unify1(Env, Id1, Id2, Variance, {arg_name, Id1, Id2, When}) andalso unify1(Env, Type1, Type2, Variance, When); %% The grammar is a bit inconsistent about whether types without %% arguments are represented as applications to an empty list of @@ -3181,13 +3160,8 @@ unify1(Env, {app_t, _, T, []}, B, Variance, When) -> unify0(Env, T, B, Variance, When); unify1(Env, A, {app_t, _, T, []}, Variance, When) -> unify0(Env, A, T, Variance, When); -unify1(Env, A, B, _Variance, When) -> - if - Env#env.unify_throws -> - cannot_unify(A, B, none, When); - true -> - ok - end, +unify1(_Env, A, B, _Variance, When) -> + cannot_unify(A, B, none, When), false. is_subtype(_Env, NameA, NameB, invariant) -> @@ -4120,8 +4094,8 @@ pp_when({if_branches, Then, ThenType0, Else, ElseType0}) -> Branches = [ {Then, ThenType} | [ {B, ElseType} || B <- if_branches(Else) ] ], {pos(element(1, hd(Branches))), io_lib:format("when comparing the types of the if-branches\n" - "~s", [ [ io_lib:format("~s (at ~s)\n", [pp_typed(" - ", B, BType), pp_loc(B)]) - || {B, BType} <- Branches ] ])}; + "~s", [string:join([ io_lib:format("~s (at ~s)", [pp_typed(" - ", B, BType), pp_loc(B)]) + || {B, BType} <- Branches ], "\n")])}; pp_when({case_pat, Pat, PatType0, ExprType0}) -> {PatType, ExprType} = instantiate({PatType0, ExprType0}), {pos(Pat),