Improve constraint solving (#480)

* Clean up constraint solving a bit

* Make unify always return true or false

* Remove unused unify_throws field from Env

* Better structure for constraint solving

* Fix formatting of if_branches error

* More cleanup
This commit is contained in:
Hans Svensson 2023-08-23 09:43:49 +02:00 committed by GitHub
parent 86d7b36ba7
commit 3b0ca28c8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -154,7 +154,6 @@
, in_pattern = false :: boolean() , in_pattern = false :: boolean()
, in_guard = false :: boolean() , in_guard = false :: boolean()
, stateful = false :: boolean() , stateful = false :: boolean()
, unify_throws = true :: boolean()
, current_const = none :: none | aeso_syntax:id() , current_const = none :: none | aeso_syntax:id()
, current_function = none :: none | aeso_syntax:id() , current_function = none :: none | aeso_syntax:id()
, what = top :: top | namespace | contract | contract_interface , what = top :: top | namespace | contract | contract_interface
@ -1661,7 +1660,7 @@ infer_letrec(Env, Defs) ->
Got = proplists:get_value(Name, Funs), Got = proplists:get_value(Name, Funs),
Expect = typesig_to_fun_t(TypeSig), Expect = typesig_to_fun_t(TypeSig),
unify(Env, Got, Expect, {check_typesig, Name, Got, Expect}), unify(Env, Got, Expect, {check_typesig, Name, Got, Expect}),
solve_constraints(Env), solve_all_constraints(Env),
?PRINT_TYPES("Checked ~s : ~s\n", ?PRINT_TYPES("Checked ~s : ~s\n",
[Name, pp(dereference_deep(Got))]), [Name, pp(dereference_deep(Got))]),
Res Res
@ -2576,61 +2575,118 @@ get_constraints() ->
destroy_constraints() -> destroy_constraints() ->
ets_delete(constraints). ets_delete(constraints).
-spec solve_constraints(env()) -> ok. %% Solve all constraints by iterating until no-progress
solve_constraints(Env) ->
%% First look for record fields that appear in only one type definition -spec solve_all_constraints(env()) -> ok.
IsAmbiguous = solve_all_constraints(Env) ->
fun(#field_constraint{ Constraints = [C || C <- get_constraints(), not one_shot_field_constraint(Env, C) ],
record_t = RecordType, solve_constraints_top(Env, Constraints).
field = Field={id, _Attrs, FieldName},
solve_constraints_top(Env, Constraints) ->
UnsolvedCs = solve_constraints(Env, Constraints),
Progress = solve_unknown_record_constraints(Env, UnsolvedCs),
if length(UnsolvedCs) < length(Constraints) orelse Progress == true ->
solve_constraints_top(Env, UnsolvedCs);
true ->
ok
end.
-spec solve_constraints(env(), [constraint()]) -> [constraint()].
solve_constraints(Env, Constraints) ->
[ C1 || C <- Constraints, C1 <- [dereference_deep(C)], not solve_constraint(Env, C1) ].
solve_unknown_record_constraints(Env, Constraints) ->
FieldCs = lists:filter(fun(#field_constraint{record_t = {uvar, _, _}}) -> true; (_) -> false end, Constraints),
FieldCsUVars = lists:usort([UVar || #field_constraint{record_t = UVar = {uvar, _, _}} <- FieldCs]),
FieldConstraint = fun(#field_constraint{ field = F, kind = K, context = Ctx }) -> {K, Ctx, F} end,
FieldsForUVar = fun(UVar) ->
[ FieldConstraint(FC) || FC = #field_constraint{record_t = U} <- FieldCs, U == UVar ]
end,
Solutions = [ solve_for_uvar(Env, UVar, FieldsForUVar(UVar)) || UVar <- FieldCsUVars ],
case lists:member(true, Solutions) of
true -> true;
false -> Solutions
end.
%% -- Simple constraints --
%% Returns true if solved (unified or type error)
solve_constraint(_Env, #field_constraint{record_t = {uvar, _, _}}) ->
false;
solve_constraint(Env, #field_constraint{record_t = RecordType,
field = Field = {id, _As, FieldName},
field_t = FieldType,
context = When}) ->
RecId = record_type_name(RecordType),
Attrs = aeso_syntax:get_ann(RecId),
case lookup_type(Env, RecId) of
{_, {_Ann, {Formals, {What, Fields}}}} when What =:= record_t; What =:= contract_t ->
FieldTypes = [{Name, Type} || {field_t, _, {id, _, Name}, Type} <- Fields],
case proplists:get_value(FieldName, FieldTypes) of
undefined ->
type_error({missing_field, Field, RecId});
FldType ->
solve_field_constraint(Env, FieldType, FldType, RecordType, app_t(Attrs, RecId, Formals), When)
end;
_ ->
type_error({not_a_record_type, instantiate(RecordType), When})
end,
true;
solve_constraint(Env, C = #dependent_type_constraint{}) ->
check_named_argument_constraint(Env, C);
solve_constraint(Env, C = #named_argument_constraint{}) ->
check_named_argument_constraint(Env, C);
solve_constraint(_Env, {is_bytes, _}) -> false;
solve_constraint(Env, {add_bytes, Ann, _, A0, B0, C0}) ->
A = unfold_types_in_type(Env, dereference(A0)),
B = unfold_types_in_type(Env, dereference(B0)),
C = unfold_types_in_type(Env, dereference(C0)),
case {A, B, C} of
{{bytes_t, _, M}, {bytes_t, _, N}, _} -> unify(Env, {bytes_t, Ann, M + N}, C, {at, Ann});
{{bytes_t, _, M}, _, {bytes_t, _, R}} when R >= M -> unify(Env, {bytes_t, Ann, R - M}, B, {at, Ann});
{_, {bytes_t, _, N}, {bytes_t, _, R}} when R >= N -> unify(Env, {bytes_t, Ann, R - N}, A, {at, Ann});
_ -> false
end;
solve_constraint(_, _) -> false.
one_shot_field_constraint(Env, #field_constraint{record_t = RecordType,
field = Field = {id, _As, FieldName},
field_t = FieldType, field_t = FieldType,
kind = Kind, kind = Kind,
context = When }) -> context = When}) ->
Arity = fun_arity(dereference_deep(FieldType)), Arity = fun_arity(dereference_deep(FieldType)),
FieldInfos = case Arity of FieldInfos = case Arity of
none -> lookup_record_field(Env, FieldName, Kind); none -> lookup_record_field(Env, FieldName, Kind);
_ -> lookup_record_field_arity(Env, FieldName, Arity, Kind) _ -> lookup_record_field_arity(Env, FieldName, Arity, Kind)
end, end,
case FieldInfos of case FieldInfos of
[] -> [] ->
type_error({undefined_field, Field}), type_error({undefined_field, Field}),
false; true;
[#field_info{field_t = FldType, record_t = RecType}] -> [#field_info{field_t = FldType, record_t = RecType}] ->
solve_field_constraint(Env, FieldType, FldType, RecordType, RecType, When),
true;
_ ->
false
end;
one_shot_field_constraint(_Env, _Constraint) ->
false.
solve_field_constraint(Env, FieldType, FldType, RecordType, RecType, When) ->
create_freshen_tvars(), create_freshen_tvars(),
FreshFldType = freshen(FldType), FreshFldType = freshen(FldType),
FreshRecType = freshen(RecType), FreshRecType = freshen(RecType),
destroy_freshen_tvars(), destroy_freshen_tvars(),
unify(Env, FreshFldType, FieldType, {field_constraint, FreshFldType, FieldType, When}), unify(Env, FreshFldType, FieldType, {field_constraint, FreshFldType, FieldType, When}),
unify(Env, FreshRecType, RecordType, {record_constraint, FreshRecType, RecordType, When}), unify(Env, FreshRecType, RecordType, {record_constraint, FreshRecType, RecordType, When}).
false;
_ ->
%% ambiguity--need cleverer strategy
true
end;
(_) -> true
end,
AmbiguousConstraints = lists:filter(IsAmbiguous, get_constraints()),
% The two passes on AmbiguousConstraints are needed
solve_ambiguous_constraints(Env, AmbiguousConstraints ++ AmbiguousConstraints).
-spec solve_ambiguous_constraints(env(), [constraint()]) -> ok.
solve_ambiguous_constraints(Env, Constraints) ->
Unknown = solve_known_record_types(Env, Constraints),
if Unknown == [] -> ok;
length(Unknown) < length(Constraints) ->
%% progress! Keep trying.
solve_ambiguous_constraints(Env, Unknown);
true ->
case solve_unknown_record_types(Env, Unknown) of
true -> %% Progress!
solve_ambiguous_constraints(Env, Unknown);
_ -> ok %% No progress. Report errors later.
end
end.
solve_then_destroy_and_report_unsolved_constraints(Env) -> solve_then_destroy_and_report_unsolved_constraints(Env) ->
solve_constraints(Env), solve_all_constraints(Env),
destroy_and_report_unsolved_constraints(Env). destroy_and_report_unsolved_constraints(Env).
destroy_and_report_unsolved_constraints(Env) -> destroy_and_report_unsolved_constraints(Env) ->
@ -2661,21 +2717,10 @@ destroy_and_report_unsolved_constraints(Env) ->
(_) -> false (_) -> false
end, OtherCs5), end, OtherCs5),
Unsolved = [ S || S <- [ solve_constraint(Env, dereference_deep(C)) || C <- NamedArgCs ], check_field_constraints(Env, FieldCs),
S == unsolved ],
[ type_error({unsolved_named_argument_constraint, C}) || C <- Unsolved ],
Unknown = solve_known_record_types(Env, FieldCs),
if Unknown == [] -> ok;
true ->
case solve_unknown_record_types(Env, Unknown) of
true -> ok;
Errors -> [ type_error(Err) || Err <- Errors ]
end
end,
check_record_create_constraints(Env, CreateCs), check_record_create_constraints(Env, CreateCs),
check_is_contract_constraints(Env, ContractCs), check_is_contract_constraints(Env, ContractCs),
check_named_args_constraints(Env, NamedArgCs),
check_bytes_constraints(Env, BytesCs), check_bytes_constraints(Env, BytesCs),
check_aens_resolve_constraints(Env, AensResolveCs), check_aens_resolve_constraints(Env, AensResolveCs),
check_oracle_type_constraints(Env, OracleTypeCs), check_oracle_type_constraints(Env, OracleTypeCs),
@ -2693,20 +2738,21 @@ get_oracle_type(_Fun, _Args, _Ret) -> false.
%% -- Named argument constraints -- %% -- Named argument constraints --
%% If false, a type error has been emitted, so it's safe to drop the constraint. %% True if solved (unified or type error), false otherwise
-spec check_named_argument_constraint(env(), named_argument_constraint()) -> true | false | unsolved. -spec check_named_argument_constraint(env(), named_argument_constraint()) -> true | false.
check_named_argument_constraint(_Env, #named_argument_constraint{ args = {uvar, _, _} }) -> check_named_argument_constraint(_Env, #named_argument_constraint{ args = {uvar, _, _} }) ->
unsolved; false;
check_named_argument_constraint(Env, check_named_argument_constraint(Env,
C = #named_argument_constraint{ args = Args, C = #named_argument_constraint{ args = Args,
name = Id = {id, _, Name}, name = Id = {id, _, Name},
type = Type }) -> type = Type }) ->
case [ T || {named_arg_t, _, {id, _, Name1}, T, _} <- Args, Name1 == Name ] of case [ T || {named_arg_t, _, {id, _, Name1}, T, _} <- Args, Name1 == Name ] of
[] -> [] ->
type_error({bad_named_argument, Args, Id}), type_error({bad_named_argument, Args, Id});
false; [T] ->
[T] -> unify(Env, T, Type, {check_named_arg_constraint, C}), true unify(Env, T, Type, {check_named_arg_constraint, C})
end; end,
true;
check_named_argument_constraint(Env, check_named_argument_constraint(Env,
#dependent_type_constraint{ named_args_t = NamedArgsT0, #dependent_type_constraint{ named_args_t = NamedArgsT0,
named_args = NamedArgs, named_args = NamedArgs,
@ -2723,10 +2769,11 @@ check_named_argument_constraint(Env,
ArgEnv = maps:from_list([ {Name, GetVal(Name, Default)} ArgEnv = maps:from_list([ {Name, GetVal(Name, Default)}
|| {named_arg_t, _, {id, _, Name}, _, Default} <- NamedArgsT ]), || {named_arg_t, _, {id, _, Name}, _, Default} <- NamedArgsT ]),
GenType1 = specialize_dependent_type(ArgEnv, GenType), GenType1 = specialize_dependent_type(ArgEnv, GenType),
unify(Env, GenType1, SpecType, {check_expr, App, GenType1, SpecType}), unify(Env, GenType1, SpecType, {check_expr, App, GenType1, SpecType});
true; _ ->
_ -> unify(Env, GenType, SpecType, {check_expr, App, GenType, SpecType}), true unify(Env, GenType, SpecType, {check_expr, App, GenType, SpecType})
end. end,
true.
specialize_dependent_type(Env, Type) -> specialize_dependent_type(Env, Type) ->
case dereference(Type) of case dereference(Type) of
@ -2742,53 +2789,16 @@ specialize_dependent_type(Env, Type) ->
_ -> Type %% Currently no deep dependent types _ -> Type %% Currently no deep dependent types
end. end.
%% -- Bytes constraints -- check_field_constraints(Env, Constraints) ->
UnsolvedFieldCs = solve_constraints(Env, Constraints),
case solve_unknown_record_constraints(Env, UnsolvedFieldCs) of
true -> ok;
Errors -> [ type_error(Err) || Err <- Errors ]
end.
solve_constraint(_Env, #field_constraint{record_t = {uvar, _, _}}) -> check_named_args_constraints(Env, Constraints) ->
not_solved; UnsolvedNamedArgCs = solve_constraints(Env, Constraints),
solve_constraint(Env, C = #field_constraint{record_t = RecType, [ type_error({unsolved_named_argument_constraint, C}) || C <- UnsolvedNamedArgCs ].
field = FieldName,
field_t = FieldType,
context = When}) ->
RecId = record_type_name(RecType),
Attrs = aeso_syntax:get_ann(RecId),
case lookup_type(Env, RecId) of
{_, {_Ann, {Formals, {What, Fields}}}} when What =:= record_t; What =:= contract_t ->
FieldTypes = [{Name, Type} || {field_t, _, {id, _, Name}, Type} <- Fields],
{id, _, FieldString} = FieldName,
case proplists:get_value(FieldString, FieldTypes) of
undefined ->
type_error({missing_field, FieldName, RecId}),
not_solved;
FldType ->
create_freshen_tvars(),
FreshFldType = freshen(FldType),
FreshRecType = freshen(app_t(Attrs, RecId, Formals)),
destroy_freshen_tvars(),
unify(Env, FreshFldType, FieldType, {field_constraint, FreshFldType, FieldType, When}),
unify(Env, FreshRecType, RecType, {record_constraint, FreshRecType, RecType, When}),
C
end;
_ ->
type_error({not_a_record_type, instantiate(RecType), When}),
not_solved
end;
solve_constraint(Env, C = #dependent_type_constraint{}) ->
check_named_argument_constraint(Env, C);
solve_constraint(Env, C = #named_argument_constraint{}) ->
check_named_argument_constraint(Env, C);
solve_constraint(_Env, {is_bytes, _}) -> ok;
solve_constraint(Env, {add_bytes, Ann, _, A0, B0, C0}) ->
A = unfold_types_in_type(Env, dereference(A0)),
B = unfold_types_in_type(Env, dereference(B0)),
C = unfold_types_in_type(Env, dereference(C0)),
case {A, B, C} of
{{bytes_t, _, M}, {bytes_t, _, N}, _} -> unify(Env, {bytes_t, Ann, M + N}, C, {at, Ann});
{{bytes_t, _, M}, _, {bytes_t, _, R}} when R >= M -> unify(Env, {bytes_t, Ann, R - M}, B, {at, Ann});
{_, {bytes_t, _, N}, {bytes_t, _, R}} when R >= N -> unify(Env, {bytes_t, Ann, R - N}, A, {at, Ann});
_ -> ok
end;
solve_constraint(_, _) -> ok.
check_bytes_constraints(Env, Constraints) -> check_bytes_constraints(Env, Constraints) ->
InAddConstraint = [ T || {add_bytes, _, _, A, B, C} <- Constraints, InAddConstraint = [ T || {add_bytes, _, _, A, B, C} <- Constraints,
@ -2885,30 +2895,6 @@ check_is_contract_constraints(Env, [C | Cs]) ->
end, end,
check_is_contract_constraints(Env, Cs). check_is_contract_constraints(Env, Cs).
-spec solve_unknown_record_types(env(), [field_constraint()]) -> true | [tuple()].
solve_unknown_record_types(Env, Unknown) ->
UVars = lists:usort([UVar || #field_constraint{record_t = UVar = {uvar, _, _}} <- Unknown]),
Solutions = [solve_for_uvar(Env, UVar, [{Kind, When, Field}
|| #field_constraint{record_t = U, field = Field, kind = Kind, context = When} <- Unknown,
U == UVar])
|| UVar <- UVars],
case lists:member(true, Solutions) of
true -> true;
false -> Solutions
end.
%% This will solve all kinds of constraints but will only return the
%% unsolved field constraints
-spec solve_known_record_types(env(), [constraint()]) -> [field_constraint()].
solve_known_record_types(Env, Constraints) ->
DerefConstraints = lists:map(fun(C = #field_constraint{record_t = RecordType}) ->
C#field_constraint{record_t = dereference(RecordType)};
(C) -> dereference_deep(C)
end, Constraints),
SolvedConstraints = lists:map(fun(C) -> solve_constraint(Env, dereference_deep(C)) end, DerefConstraints),
Unsolved = DerefConstraints--SolvedConstraints,
lists:filter(fun(#field_constraint{}) -> true; (_) -> false end, Unsolved).
record_type_name({app_t, _Attrs, RecId, _Args}) when ?is_type_id(RecId) -> record_type_name({app_t, _Attrs, RecId, _Args}) when ?is_type_id(RecId) ->
RecId; RecId;
record_type_name(RecId) when ?is_type_id(RecId) -> record_type_name(RecId) when ?is_type_id(RecId) ->
@ -3087,16 +3073,12 @@ unify0(Env, A, B, Variance, When) ->
unify1(_Env, {uvar, _, R}, {uvar, _, R}, _Variance, _When) -> unify1(_Env, {uvar, _, R}, {uvar, _, R}, _Variance, _When) ->
true; true;
unify1(_Env, {uvar, _, _}, {fun_t, _, _, var_args, _}, _Variance, When) -> unify1(_Env, {uvar, _, _}, {fun_t, _, _, var_args, _}, _Variance, When) ->
type_error({unify_varargs, When}); type_error({unify_varargs, When}),
unify1(Env, {uvar, A, R}, T, _Variance, When) -> false;
unify1(_Env, {uvar, A, R}, T, _Variance, When) ->
case occurs_check(R, T) of case occurs_check(R, T) of
true -> true ->
if cannot_unify({uvar, A, R}, T, none, When),
Env#env.unify_throws ->
cannot_unify({uvar, A, R}, T, none, When);
true ->
ok
end,
false; false;
false -> false ->
ets_insert(type_vars, {R, T}), ets_insert(type_vars, {R, T}),
@ -3123,18 +3105,13 @@ unify1(Env, A = {con, _, NameA}, B = {con, _, NameB}, Variance, When) ->
case is_subtype(Env, NameA, NameB, Variance) of case is_subtype(Env, NameA, NameB, Variance) of
true -> true; true -> true;
false -> false ->
if
Env#env.unify_throws ->
IsSubtype = is_subtype(Env, NameA, NameB, contravariant) orelse IsSubtype = is_subtype(Env, NameA, NameB, contravariant) orelse
is_subtype(Env, NameA, NameB, covariant), is_subtype(Env, NameA, NameB, covariant),
Cxt = case IsSubtype of Cxt = case IsSubtype of
true -> Variance; true -> Variance;
false -> none false -> none
end, end,
cannot_unify(A, B, Cxt, When); cannot_unify(A, B, Cxt, When),
true ->
ok
end,
false false
end; end;
unify1(_Env, {qid, _, Name}, {qid, _, Name}, _Variance, _When) -> unify1(_Env, {qid, _, Name}, {qid, _, Name}, _Variance, _When) ->
@ -3148,9 +3125,11 @@ unify1(Env, {if_t, _, {id, _, Id}, Then1, Else1}, {if_t, _, {id, _, Id}, Then2,
unify0(Env, Else1, Else2, Variance, When); unify0(Env, Else1, Else2, Variance, When);
unify1(_Env, {fun_t, _, _, _, _}, {fun_t, _, _, var_args, _}, _Variance, When) -> unify1(_Env, {fun_t, _, _, _, _}, {fun_t, _, _, var_args, _}, _Variance, When) ->
type_error({unify_varargs, When}); type_error({unify_varargs, When}),
false;
unify1(_Env, {fun_t, _, _, var_args, _}, {fun_t, _, _, _, _}, _Variance, When) -> unify1(_Env, {fun_t, _, _, var_args, _}, {fun_t, _, _, _, _}, _Variance, When) ->
type_error({unify_varargs, When}); type_error({unify_varargs, When}),
false;
unify1(Env, {fun_t, _, Named1, Args1, Result1}, {fun_t, _, Named2, Args2, Result2}, Variance, When) unify1(Env, {fun_t, _, Named1, Args1, Result1}, {fun_t, _, Named2, Args2, Result2}, Variance, When)
when length(Args1) == length(Args2) -> when length(Args1) == length(Args2) ->
unify0(Env, Named1, Named2, opposite_variance(Variance), When) andalso unify0(Env, Named1, Named2, opposite_variance(Variance), When) andalso
@ -3172,7 +3151,7 @@ unify1(Env, {tuple_t, _, As}, {tuple_t, _, Bs}, Variance, When)
when length(As) == length(Bs) -> when length(As) == length(Bs) ->
unify0(Env, As, Bs, Variance, When); unify0(Env, As, Bs, Variance, When);
unify1(Env, {named_arg_t, _, Id1, Type1, _}, {named_arg_t, _, Id2, Type2, _}, Variance, When) -> unify1(Env, {named_arg_t, _, Id1, Type1, _}, {named_arg_t, _, Id2, Type2, _}, Variance, When) ->
unify1(Env, Id1, Id2, Variance, {arg_name, Id1, Id2, When}), unify1(Env, Id1, Id2, Variance, {arg_name, Id1, Id2, When}) andalso
unify1(Env, Type1, Type2, Variance, When); unify1(Env, Type1, Type2, Variance, When);
%% The grammar is a bit inconsistent about whether types without %% The grammar is a bit inconsistent about whether types without
%% arguments are represented as applications to an empty list of %% arguments are represented as applications to an empty list of
@ -3181,13 +3160,8 @@ unify1(Env, {app_t, _, T, []}, B, Variance, When) ->
unify0(Env, T, B, Variance, When); unify0(Env, T, B, Variance, When);
unify1(Env, A, {app_t, _, T, []}, Variance, When) -> unify1(Env, A, {app_t, _, T, []}, Variance, When) ->
unify0(Env, A, T, Variance, When); unify0(Env, A, T, Variance, When);
unify1(Env, A, B, _Variance, When) -> unify1(_Env, A, B, _Variance, When) ->
if cannot_unify(A, B, none, When),
Env#env.unify_throws ->
cannot_unify(A, B, none, When);
true ->
ok
end,
false. false.
is_subtype(_Env, NameA, NameB, invariant) -> is_subtype(_Env, NameA, NameB, invariant) ->
@ -4120,8 +4094,8 @@ pp_when({if_branches, Then, ThenType0, Else, ElseType0}) ->
Branches = [ {Then, ThenType} | [ {B, ElseType} || B <- if_branches(Else) ] ], Branches = [ {Then, ThenType} | [ {B, ElseType} || B <- if_branches(Else) ] ],
{pos(element(1, hd(Branches))), {pos(element(1, hd(Branches))),
io_lib:format("when comparing the types of the if-branches\n" io_lib:format("when comparing the types of the if-branches\n"
"~s", [ [ io_lib:format("~s (at ~s)\n", [pp_typed(" - ", B, BType), pp_loc(B)]) "~s", [string:join([ io_lib:format("~s (at ~s)", [pp_typed(" - ", B, BType), pp_loc(B)])
|| {B, BType} <- Branches ] ])}; || {B, BType} <- Branches ], "\n")])};
pp_when({case_pat, Pat, PatType0, ExprType0}) -> pp_when({case_pat, Pat, PatType0, ExprType0}) ->
{PatType, ExprType} = instantiate({PatType0, ExprType0}), {PatType, ExprType} = instantiate({PatType0, ExprType0}),
{pos(Pat), {pos(Pat),