%---------------------------------------------------------------------------%
% vim: ft=mercury ts=4 sw=4 et
%---------------------------------------------------------------------------%
% Copyright (C) 1994-2012 The University of Melbourne.
% Copyright (C) 2014-2021, 2024-2025 The Mercury team.
% This file may only be copied under the terms of the GNU General
% Public License - see the file COPYING in the Mercury distribution.
%---------------------------------------------------------------------------%
%
% File: mode_util.m.
% Main author: fjh.
%
% This module contains utility predicates for dealing with modes.
%
%---------------------------------------------------------------------------%

:- module hlds.mode_util.
:- interface.

:- import_module hlds.hlds_goal.
:- import_module hlds.hlds_module.
:- import_module parse_tree.
:- import_module parse_tree.prog_data.

:- import_module list.

%---------------------------------------------------------------------------%
%
% Breaking down modes into their initial and final insts.
%

    % Return the initial inst (the first field) or the final inst
    % (the second field) of the given from_to_insts.
    %
:- func from_to_insts_to_init_inst(from_to_insts) = mer_inst.
:- func from_to_insts_to_final_inst(from_to_insts) = mer_inst.

%---------------------%

    % Return the initial and final instantiatedness for the given mode.
    % A user-defined mode is looked up in the mode table, its parameters
    % are replaced by its actual arguments, and the result is expanded
    % again (recursively, if it is itself user-defined). If an unqualified
    % mode name is not found as given, it is also looked up qualified
    % with the public builtin module.
    % Fail if the mode is undefined.
    %
:- pred mode_get_insts_semidet(module_info::in, mer_mode::in,
    mer_inst::out, mer_inst::out) is semidet.

    % Return the initial and final instantiatedness for the given mode.
    % Throw an exception if the mode is undefined.
    %
    % mode_get_from_to_insts does the same job, but returns the two insts
    % packaged together as a from_to_insts.
    %
:- pred mode_get_insts(module_info::in, mer_mode::in,
    mer_inst::out, mer_inst::out) is det.
:- pred mode_get_from_to_insts(module_info::in, mer_mode::in,
    from_to_insts::out) is det.

%---------------------%

    % Return the initial or final instantiatedness for the given mode.
    % Throw an exception if the mode is undefined.
    %
:- func mode_get_initial_inst(module_info, mer_mode) = mer_inst.
:- func mode_get_final_inst(module_info, mer_mode) = mer_inst.

    % Return the initial or final instantiatedness for each of
    % the given modes, in the same order as the modes.
    % mode_list_get_initial_final_insts returns both lists at once.
    % Throw an exception if any mode is undefined.
    %
:- pred mode_list_get_initial_insts(module_info::in,
    list(mer_mode)::in, list(mer_inst)::out) is det.
:- pred mode_list_get_final_insts(module_info::in,
    list(mer_mode)::in, list(mer_inst)::out) is det.
:- pred mode_list_get_initial_final_insts(module_info::in,
    list(mer_mode)::in, list(mer_inst)::out, list(mer_inst)::out) is det.

%---------------------------------------------------------------------------%
%
% Converting between from_to_insts and modes.
%

    % from_to_insts_to_mode returns the from_to_mode that has the same
    % initial and final insts as the given from_to_insts.
    %
    % mode_to_from_to_insts returns the initial and final insts of the given
    % mode, expanding it first if it is user-defined.
    % It throws an exception if the mode is undefined.
    %
:- func from_to_insts_to_mode(from_to_insts) = mer_mode.
:- func mode_to_from_to_insts(module_info, mer_mode) = from_to_insts.

%---------------------------------------------------------------------------%
%
% Converting between unify_modes and from_to_insts/modes.
%

    % Split a unify_mode into the from_to_insts of its left hand side
    % and the from_to_insts of its right hand side, in that order.
    %
:- pred unify_mode_to_lhs_rhs_from_to_insts(unify_mode::in,
    from_to_insts::out, from_to_insts::out) is det.

    % Return the mode (or the from_to_insts) of the operand on the given
    % side (lhs or rhs) of the unification described by the unify_mode.
    %
:- func unify_mode_to_lhs_mode(unify_mode) = mer_mode.
:- func unify_mode_to_rhs_mode(unify_mode) = mer_mode.
:- func unify_mode_to_lhs_from_to_insts(unify_mode) = from_to_insts.
:- func unify_mode_to_rhs_from_to_insts(unify_mode) = from_to_insts.

    % Given the modes (or the from_to_insts) of the two sides of
    % a unification, return the unify_mode.
    % modes_to_unify_mode throws an exception if either mode is undefined.
    %
:- pred modes_to_unify_mode(module_info::in,
    mer_mode::in, mer_mode::in, unify_mode::out) is det.
:- pred from_to_insts_to_unify_mode(from_to_insts::in, from_to_insts::in,
    unify_mode::out) is det.

    % Given two lists of modes (or of from_to_insts) of equal length,
    % with each pair giving the modes of the two sides of a unification,
    % return a unify_mode for each corresponding pair of modes.
    % Throw an exception if the two lists have different lengths.
    %
:- pred modes_to_unify_modes(module_info::in,
    list(mer_mode)::in, list(mer_mode)::in, list(unify_mode)::out) is det.
:- pred from_to_insts_to_unify_modes(
    list(from_to_insts)::in, list(from_to_insts)::in, list(unify_mode)::out)
    is det.

%---------------------------------------------------------------------------%
%
% Figure out which arguments are live in callers.
%

    % Given the modes of a predicate's arguments, work out which arguments
    % are live (might be used again by the caller of that predicate) and
    % which are dead. An argument is considered dead if its final inst is
    % `clobbered', and live otherwise.
    %
:- pred get_arg_lives(module_info::in, list(mer_mode)::in, list(is_live)::out)
    is det.

    % Given a list of variables, and a list of livenesses,
    % select the live variables, keeping their relative order.
    % Throw an exception if the two lists have different lengths.
    %
:- pred get_live_vars(list(prog_var)::in, list(is_live)::in,
    list(prog_var)::out) is det.

%---------------------------------------------------------------------------%
%
% Constructing bound_functors.
%

    % Convert a list of constructors to a list of bound_functors where the
    % arguments are `ground' with the given uniqueness. There is one
    % bound_functor for each constructor, in the same order.
    %
    % NOTE: the list(bound_functor) is not sorted and may contain duplicates.
    %
:- pred constructors_to_bound_functors(module_info::in, uniqueness::in,
    type_ctor::in, list(constructor)::in, list(bound_functor)::out) is det.

    % Convert a list of constructors to a list of bound_functors where the
    % arguments are `any' with the given uniqueness. There is one
    % bound_functor for each constructor, in the same order.
    %
    % NOTE: the list(bound_functor) is not sorted and may contain duplicates.
    %
:- pred constructors_to_bound_any_insts(module_info::in, uniqueness::in,
    type_ctor::in, list(constructor)::in, list(bound_functor)::out) is det.

%---------------------------------------------------------------------------%
%
% Miscellaneous operations.
%

    % Return a map of all the inst variables in the given modes,
    % and the sub-insts to which they are constrained.
    % Throw an exception if the same inst variable is found constrained
    % to two different sub-insts, or if a bare inst_var inst is found.
    %
:- pred get_constrained_inst_vars(module_info::in, list(mer_mode)::in,
    head_inst_vars::out) is det.

    % Succeed iff inst_is_not_partly_unique holds for both the initial
    % and the final inst of the given mode.
    % Throw an exception if the mode is undefined.
    %
:- pred mode_is_free_of_uniqueness(module_info::in, mer_mode::in) is semidet.

%---------------------------------------------------------------------------%
%---------------------------------------------------------------------------%

:- implementation.

:- import_module hlds.hlds_inst_mode.
:- import_module hlds.inst_lookup.
:- import_module hlds.inst_test.
:- import_module mdbcomp.
:- import_module mdbcomp.builtin_modules.
:- import_module mdbcomp.sym_name.
:- import_module parse_tree.prog_mode.

:- import_module map.
:- import_module require.
:- import_module set.
:- import_module set_tree234.
:- import_module term.
:- import_module varset.

%---------------------------------------------------------------------------%

from_to_insts_to_init_inst(FromToInsts) = InitInst :-
    % from_to_insts has a single constructor, so this deconstruction
    % always succeeds; we keep only its first (initial) field.
    from_to_insts(InitInst, _) = FromToInsts.

from_to_insts_to_final_inst(FromToInsts) = FinalInst :-
    % As above, but we keep only the second (final) field.
    from_to_insts(_, FinalInst) = FromToInsts.

%---------------------%

mode_get_insts_semidet(ModuleInfo, Mode0, InitialInst, FinalInst) :-
    (
        % A from_to_mode carries both insts explicitly.
        Mode0 = from_to_mode(InitialInst, FinalInst)
    ;
        % A user-defined mode must be looked up in the mode table,
        % instantiated with its actual arguments, and then expanded again,
        % since its body may itself be a user-defined mode.
        Mode0 = user_defined_mode(ModeName, ArgInsts),
        ModeArity = list.length(ArgInsts),
        module_info_get_mode_table(ModuleInfo, ModeTable),
        mode_table_get_mode_defns(ModeTable, ModeDefnMap),
        ( if
            map.search(ModeDefnMap, mode_ctor(ModeName, ModeArity),
                FoundDefn)
        then
            HLDSModeDefn = FoundDefn
        else
            % The name as given is not in the table. If it is unqualified,
            % retry with the public builtin module qualifier; otherwise fail.
            % XXX This is a makeshift fix for a problem that requires more
            % investigation (without this fix the compiler occasionally
            % throws an exception in mode_get_insts/4).
            ModeName = unqualified(BaseName),
            BuiltinModeName =
                qualified(mercury_public_builtin_module, BaseName),
            map.search(ModeDefnMap, mode_ctor(BuiltinModeName, ModeArity),
                HLDSModeDefn)
        ),
        HLDSModeDefn = hlds_mode_defn(_VarSet, ParamVars, BodyDefn,
            _Context, _Status),
        BodyDefn = hlds_mode_body(BodyMode),
        mode_substitute_arg_list(BodyMode, ParamVars, ArgInsts, ExpandedMode),
        mode_get_insts_semidet(ModuleInfo, ExpandedMode,
            InitialInst, FinalInst)
    ).

mode_get_insts(ModuleInfo, Mode, InitInst, FinalInst) :-
    % This is the det wrapper around mode_get_insts_semidet;
    % an undefined mode is treated as an internal error.
    ( if
        mode_get_insts_semidet(ModuleInfo, Mode, FoundInit, FoundFinal)
    then
        InitInst = FoundInit,
        FinalInst = FoundFinal
    else
        unexpected($pred, "mode_get_insts_semidet failed")
    ).

mode_get_from_to_insts(ModuleInfo, Mode,
        from_to_insts(InitInst, FinalInst)) :-
    % Package up the two insts computed by mode_get_insts.
    mode_get_insts(ModuleInfo, Mode, InitInst, FinalInst).

%---------------------%

mode_get_initial_inst(ModuleInfo, Mode) = InitInst :-
    mode_get_insts(ModuleInfo, Mode, InitInst, _FinalInst).

mode_get_final_inst(ModuleInfo, Mode) = FinalInst :-
    mode_get_insts(ModuleInfo, Mode, _InitInst, FinalInst).

mode_list_get_initial_insts(ModuleInfo, Modes, InitInsts) :-
    % Apply mode_get_initial_inst to each mode, preserving order.
    InitInsts = list.map(mode_get_initial_inst(ModuleInfo), Modes).

mode_list_get_final_insts(ModuleInfo, Modes, FinalInsts) :-
    % Apply mode_get_final_inst to each mode, preserving order.
    FinalInsts = list.map(mode_get_final_inst(ModuleInfo), Modes).

mode_list_get_initial_final_insts(ModuleInfo, Modes, InitInsts, FinalInsts) :-
    % Expand each mode once, collecting both of its insts in parallel.
    list.map2(mode_get_insts(ModuleInfo), Modes, InitInsts, FinalInsts).

%---------------------------------------------------------------------------%

from_to_insts_to_mode(from_to_insts(InitInst, FinalInst)) =
    from_to_mode(InitInst, FinalInst).

mode_to_from_to_insts(ModuleInfo, Mode) = FromToInsts :-
    % This is just the function form of mode_get_from_to_insts.
    mode_get_from_to_insts(ModuleInfo, Mode, FromToInsts).

%---------------------------------------------------------------------------%

unify_mode_to_lhs_rhs_from_to_insts(UnifyMode, LHSInsts, RHSInsts) :-
    % A unify_mode stores four insts: the initial and final insts of the
    % left hand side, followed by the initial and final insts of the right.
    UnifyMode = unify_modes_li_lf_ri_rf(LHSInit, LHSFinal, RHSInit, RHSFinal),
    LHSInsts = from_to_insts(LHSInit, LHSFinal),
    RHSInsts = from_to_insts(RHSInit, RHSFinal).

unify_mode_to_lhs_mode(unify_modes_li_lf_ri_rf(LHSInit, LHSFinal, _, _)) =
    from_to_mode(LHSInit, LHSFinal).

unify_mode_to_rhs_mode(unify_modes_li_lf_ri_rf(_, _, RHSInit, RHSFinal)) =
    from_to_mode(RHSInit, RHSFinal).

unify_mode_to_lhs_from_to_insts(
        unify_modes_li_lf_ri_rf(LHSInit, LHSFinal, _, _)) =
    from_to_insts(LHSInit, LHSFinal).

unify_mode_to_rhs_from_to_insts(
        unify_modes_li_lf_ri_rf(_, _, RHSInit, RHSFinal)) =
    from_to_insts(RHSInit, RHSFinal).

%---------------------%

modes_to_unify_mode(ModuleInfo, ModeX, ModeY, UnifyMode) :-
    % Expand each side to its initial/final insts, then combine them.
    mode_get_from_to_insts(ModuleInfo, ModeX, FromToInstsX),
    mode_get_from_to_insts(ModuleInfo, ModeY, FromToInstsY),
    from_to_insts_to_unify_mode(FromToInstsX, FromToInstsY, UnifyMode).

from_to_insts_to_unify_mode(from_to_insts(InitX, FinalX),
        from_to_insts(InitY, FinalY),
        unify_modes_li_lf_ri_rf(InitX, FinalX, InitY, FinalY)).

modes_to_unify_modes(ModuleInfo, ModesX, ModesY, UnifyModes) :-
    % Walk both lists in lockstep; they must run out at the same time.
    (
        ModesX = [],
        (
            ModesY = [],
            UnifyModes = []
        ;
            ModesY = [_ | _],
            unexpected($pred, "length mismatch")
        )
    ;
        ModesX = [HeadModeX | TailModesX],
        (
            ModesY = [],
            unexpected($pred, "length mismatch")
        ;
            ModesY = [HeadModeY | TailModesY],
            modes_to_unify_mode(ModuleInfo, HeadModeX, HeadModeY,
                HeadUnifyMode),
            modes_to_unify_modes(ModuleInfo, TailModesX, TailModesY,
                TailUnifyModes),
            UnifyModes = [HeadUnifyMode | TailUnifyModes]
        )
    ).

from_to_insts_to_unify_modes(FromToInstsXs, FromToInstsYs, UnifyModes) :-
    % Walk both lists in lockstep; they must run out at the same time.
    (
        FromToInstsXs = [],
        (
            FromToInstsYs = [],
            UnifyModes = []
        ;
            FromToInstsYs = [_ | _],
            unexpected($pred, "length mismatch")
        )
    ;
        FromToInstsXs = [HeadX | TailXs],
        (
            FromToInstsYs = [],
            unexpected($pred, "length mismatch")
        ;
            FromToInstsYs = [HeadY | TailYs],
            from_to_insts_to_unify_mode(HeadX, HeadY, HeadUnifyMode),
            from_to_insts_to_unify_modes(TailXs, TailYs, TailUnifyModes),
            UnifyModes = [HeadUnifyMode | TailUnifyModes]
        )
    ).

%---------------------------------------------------------------------------%

get_arg_lives(ModuleInfo, Modes, IsLives) :-
    % An argument is dead if its final inst is `clobbered'; every other
    % argument is assumed to be live.
    ClassifyArg =
        ( pred(ArgMode::in, ArgIsLive::out) is det :-
            ArgFinalInst = mode_get_final_inst(ModuleInfo, ArgMode),
            ( if inst_is_clobbered(ModuleInfo, ArgFinalInst) then
                ArgIsLive = is_dead
            else
                ArgIsLive = is_live
            )
        ),
    list.map(ClassifyArg, Modes, IsLives).

get_live_vars(Vars, IsLives, LiveVars) :-
    % Walk both lists in lockstep, keeping each variable whose liveness
    % is is_live. The lists must have the same length.
    (
        Vars = [],
        (
            IsLives = [],
            LiveVars = []
        ;
            IsLives = [_ | _],
            unexpected($pred, "length mismatch")
        )
    ;
        Vars = [HeadVar | TailVars],
        (
            IsLives = [],
            unexpected($pred, "length mismatch")
        ;
            IsLives = [HeadIsLive | TailIsLives],
            get_live_vars(TailVars, TailIsLives, TailLiveVars),
            (
                HeadIsLive = is_live,
                LiveVars = [HeadVar | TailLiveVars]
            ;
                HeadIsLive = is_dead,
                LiveVars = TailLiveVars
            )
        )
    ).

%---------------------------------------------------------------------------%

constructors_to_bound_functors(_ModuleInfo, Uniq, TypeCtor, Constructors,
        BoundFunctors) :-
    % Every argument of every functor gets the same inst:
    % ground, with the given uniqueness.
    ArgInst = ground(Uniq, none_or_default_func),
    constructors_to_bound_functors_loop_over_ctors(TypeCtor, ArgInst,
        Constructors, BoundFunctors).

constructors_to_bound_any_insts(_ModuleInfo, Uniq, TypeCtor, Constructors,
        BoundFunctors) :-
    % Every argument of every functor gets the same inst:
    % any, with the given uniqueness.
    ArgInst = any(Uniq, none_or_default_func),
    constructors_to_bound_functors_loop_over_ctors(TypeCtor, ArgInst,
        Constructors, BoundFunctors).

    % Build one bound_functor per constructor, in the same order as the
    % constructors, giving every argument of each functor the inst ArgInst.
    % The uniqueness is already embedded in ArgInst, and the module_info
    % is not needed, so neither is passed to this loop.
    %
:- pred constructors_to_bound_functors_loop_over_ctors(type_ctor::in,
    mer_inst::in, list(constructor)::in, list(bound_functor)::out) is det.

constructors_to_bound_functors_loop_over_ctors(_, _, [], []).
constructors_to_bound_functors_loop_over_ctors(TypeCtor, ArgInst,
        [Ctor | Ctors], [BoundFunctor | BoundFunctors]) :-
    Ctor = ctor(_Ordinal, _MaybeExistConstraints, Name, Args, _Arity, _Ctxt),
    Arity = list.length(Args),
    % One copy of ArgInst for each constructor argument.
    list.duplicate(Arity, ArgInst, ArgInsts),
    DuCtor = du_ctor(Name, Arity, TypeCtor),
    BoundFunctor = bound_functor(du_data_ctor(DuCtor), ArgInsts),
    constructors_to_bound_functors_loop_over_ctors(TypeCtor, ArgInst,
        Ctors, BoundFunctors).

%---------------------------------------------------------------------------%

:- type inst_expansions == set_tree234(inst_name).

get_constrained_inst_vars(ModuleInfo, Modes, Map) :-
    % Thread two accumulators through the modes: the constraint map being
    % built, and the set of defined inst names already expanded (so that
    % each one is searched only once). Only the map is returned.
    InitMap = map.init,
    InitExpansions = set_tree234.init,
    list.foldl2(get_constrained_insts_in_mode(ModuleInfo), Modes,
        InitMap, Map, InitExpansions, _FinalExpansions).

:- pred get_constrained_insts_in_mode(module_info::in, mer_mode::in,
    head_inst_vars::in, head_inst_vars::out,
    inst_expansions::in, inst_expansions::out) is det.

get_constrained_insts_in_mode(ModuleInfo, Mode, !Map, !Expansions) :-
    % Either end of the mode may mention constrained inst vars;
    % search the initial inst first, then the final inst.
    mode_get_insts(ModuleInfo, Mode, InitialInst, FinalInst),
    list.foldl2(get_constrained_insts_in_inst(ModuleInfo),
        [InitialInst, FinalInst], !Map, !Expansions).

:- pred get_constrained_insts_in_inst(module_info::in, mer_inst::in,
    head_inst_vars::in, head_inst_vars::out,
    inst_expansions::in, inst_expansions::out) is det.

get_constrained_insts_in_inst(ModuleInfo, Inst, !Map, !Expansions) :-
    % Search Inst for constrained_inst_vars, recording in !Map the sub-inst
    % that each such inst var is constrained to. !Expansions holds the names
    % of the defined insts already expanded, so that each one is looked up
    % and searched at most once; this also stops recursively defined insts
    % from causing infinite recursion.
    (
        % These insts have no sub-insts, so there is nothing to search.
        ( Inst = free
        ; Inst = not_reached
        )
    ;
        Inst = bound(_, InstResults, BoundFunctors),
        (
            % NOTE(review): this code treats inst_test_results_fgtc as
            % guaranteeing that there are no inst vars to find; confirm
            % against the definition of inst_test_results_fgtc.
            InstResults = inst_test_results_fgtc
        ;
            InstResults = inst_test_results(_, _, _, InstVarsResult, _, _),
            % Skip the argument insts only if the cached test results say
            % that they contain no inst vars at all.
            ( if
                InstVarsResult =
                    inst_result_contains_inst_vars_known(InstVars),
                set.is_empty(InstVars)
            then
                true
            else
                list.foldl2(get_constrained_insts_in_bound_functor(ModuleInfo),
                    BoundFunctors, !Map, !Expansions)
            )
        ;
            % There is no cached information, so search the arguments.
            InstResults = inst_test_no_results,
            list.foldl2(get_constrained_insts_in_bound_functor(ModuleInfo),
                BoundFunctors, !Map, !Expansions)
        )
    ;
        % For any and ground insts, only the argument modes of
        % a higher-order inst are searched.
        ( Inst = any(_, HOInstInfo)
        ; Inst = ground(_, HOInstInfo)
        ),
        (
            HOInstInfo = none_or_default_func
        ;
            HOInstInfo = higher_order(PredInstInfo),
            get_constrained_insts_in_ho_inst(ModuleInfo, PredInstInfo,
                !Map, !Expansions)
        )
    ;
        % Every inst var in InstVars is recorded with the same SubInst.
        % SubInst itself is not searched any further here.
        Inst = constrained_inst_vars(InstVars, _),
        inst_expand_and_remove_constrained_inst_vars(ModuleInfo,
            Inst, SubInst),
        set.fold(add_constrained_inst(SubInst), InstVars, !Map)
    ;
        % insert_new fails if InstName is already in !.Expansions,
        % i.e. if this defined inst has already been searched.
        Inst = defined_inst(InstName),
        ( if insert_new(InstName, !Expansions) then
            inst_lookup(ModuleInfo, InstName, ExpandedInst),
            get_constrained_insts_in_inst(ModuleInfo, ExpandedInst,
                !Map, !Expansions)
        else
            true
        )
    ;
        % A bare inst_var reaching this point is an internal error.
        Inst = inst_var(_),
        unexpected($pred, "inst_var")
    ).

:- pred get_constrained_insts_in_bound_functor(module_info::in,
    bound_functor::in, head_inst_vars::in, head_inst_vars::out,
    inst_expansions::in, inst_expansions::out) is det.

get_constrained_insts_in_bound_functor(ModuleInfo,
        bound_functor(_ConsId, ArgInsts), !Map, !Expansions) :-
    % The cons_id is ignored; only the argument insts are searched.
    list.foldl2(get_constrained_insts_in_inst(ModuleInfo), ArgInsts,
        !Map, !Expansions).

:- pred get_constrained_insts_in_ho_inst(module_info::in, pred_inst_info::in,
    head_inst_vars::in, head_inst_vars::out,
    inst_expansions::in, inst_expansions::out) is det.

get_constrained_insts_in_ho_inst(ModuleInfo,
        pred_inst_info(_, ArgModes, _, _), !Map, !Expansions) :-
    % Only the argument modes of the higher-order inst are searched;
    % the other fields of the pred_inst_info are not examined.
    list.foldl2(get_constrained_insts_in_mode(ModuleInfo), ArgModes,
        !Map, !Expansions).

:- pred add_constrained_inst(mer_inst::in, inst_var::in,
    head_inst_vars::in, head_inst_vars::out) is det.

add_constrained_inst(SubInst, InstVar, !Map) :-
    % Record SubInst as the constraint on InstVar. If InstVar already has
    % a recorded constraint, it must be identical to SubInst.
    ( if map.insert(InstVar, SubInst, !Map) then
        true
    else
        map.lookup(!.Map, InstVar, OldSubInst),
        ( if OldSubInst = SubInst then
            true
        else
            unexpected($pred, "SubInst differs")
        )
    ).

%---------------------------------------------------------------------------%

mode_is_free_of_uniqueness(ModuleInfo, Mode) :-
    % Both ends of the mode must pass inst_is_not_partly_unique.
    mode_get_insts(ModuleInfo, Mode, InitInst, FinalInst),
    list.all_true(inst_is_not_partly_unique(ModuleInfo),
        [InitInst, FinalInst]).

%---------------------------------------------------------------------------%
:- end_module hlds.mode_util.
%---------------------------------------------------------------------------%
