import logging
import numpy as np
from .io import BaseIO
from ..models.sparse_points import SparsePoints
from ..representations.spherical_invariants import SphericalInvariants
LOGGER = logging.getLogger(__name__)
# skcosmo is an optional dependency: without it the CUR and FPS based filters
# cannot build their selectors, so fall back to None sentinels and warn once
# at import time instead of failing to import this module.
try:
    from skcosmo._selection import _FPS, _CUR
except ImportError as ie:
    # Logger.warn is a deprecated alias of Logger.warning
    LOGGER.warning(
        "Warning: skcosmo module not found. CUR and FPS filters will be unavailable."
    )
    LOGGER.warning("Original error:\n%s", ie)
    _FPS = _CUR = None
# Index conversion utilities
def _indices_manager_to_perstructure(managers, selected_ids_global):
    """Convert manager-global center indexing to per-structure format
    That is, change them to a list of lists of structure-local atom indices
    that is accepted as input to SparsePoints
    Parameters
    ----------
    managers : AtomsList or list(ase.Atoms)
        list of atomic structures
    selected_ids_global : list of int
        global indices to convert
    Returns
    -------
    selected_ids : list of list(int)
        list the atom indices (within their structure) that have been selected
    """
    selected_ids = []
    selected_ids_global = np.array(selected_ids_global)
    natoms_list = [len(manager) for manager in managers]
    split_idces = np.cumsum(natoms_list)
    structure_start_idx = 0
    # Handle any out-of-range indices
    if np.any(selected_ids_global >= split_idces[-1]):
        bad_indices = selected_ids_global[selected_ids_global >= split_idces[-1]]
        bad_indices_str = np.array2string(
            bad_indices, threshold=5, edgeitems=2, separator=", "
        )
        raise ValueError(f"Selected index(es): {bad_indices_str} out of range")
    # Do the actual conversion if OK
    for structure_end_idx in split_idces:
        this_structure_idces = selected_ids_global[
            (selected_ids_global >= structure_start_idx)
            & (selected_ids_global < structure_end_idx)
        ]
        selected_structure_idces = this_structure_idces - structure_start_idx
        selected_ids.append(list(selected_structure_idces))
        structure_start_idx = structure_end_idx
    assert sum(len(ids) for ids in selected_ids) == len(selected_ids_global)
    return selected_ids
def _indices_perspecies_manager_to_perstructure(managers, selected_ids_by_sp, sps):
    """Convert per-species, manager-global center indexing to per-structure
    That is, change them to a list of lists of structure-local atom indices
    that is accepted as input to SparsePoints
    See _get_index_mappings_sample_per_species() to make the intput
    Parameters
    ----------
    managers : AtomsList
        list of atomic structures
    selected_ids_by_sp : dict
        indices to convert; each dictionary entry (keyed by species)
        is a list of indices into the array of all atoms in the manager
        of only that species
    sps : list(int) or set(int)
        unique center atom species present in managers
    Returns
    -------
    selected_ids : list of lists
        list the atom indices (within their structure) that have been selected
        They are ordered, first by species (sorted, irrespective of the order
        passed in), then by selection order.
    Notes
    -----
    Asking for indices for a species not present in the AtomsList will
    result in an out-of-range error (since the slice of atoms of that species
    is of size zero).
    This function does not yet support lists of ASE Atoms (instead of a list
    of managers) as input, but it should be fairly easy to add support in
    the future if required.
    """
    if len(set(sps)) != len(sps):
        raise ValueError(f"List of species contains duplicated entries: {sps}")
    selected_ids = [[] for ii in range(len(managers))]
    structure_sp_start_idx = {sp: 0 for sp in sps}
    for sp in sps:
        selected_ids_by_sp[sp] = np.array(selected_ids_by_sp[sp])
    for structure, structure_selected_ids in zip(managers, selected_ids):
        perspecies_counter = {sp: 0 for sp in sps}
        structure_index_mapping = {sp: [] for sp in sps}
        for perstructure_idx, atom in enumerate(structure):
            atom_sp = atom.atom_type
            if atom_sp not in sps:
                raise ValueError(
                    f"Atom of type {atom_sp} found but was not listed in sps: {sps}"
                )
            structure_index_mapping[atom_sp].append(perstructure_idx)
            perspecies_counter[atom_sp] += 1
        for sp in sorted(list(sps)):
            selected_ids_sp = selected_ids_by_sp[sp]
            structure_sp_end_idx = structure_sp_start_idx[sp] + perspecies_counter[sp]
            this_structure_sp_idces = (
                selected_ids_sp[
                    (selected_ids_sp >= structure_sp_start_idx[sp])
                    & (selected_ids_sp < structure_sp_end_idx)
                ]
                - structure_sp_start_idx[sp]
            )
            structure_selected_ids.extend(
                [
                    structure_index_mapping[sp][sp_idx]
                    for sp_idx in this_structure_sp_idces
                ]
            )
            structure_sp_start_idx[sp] += perspecies_counter[sp]
    for sp in sps:
        selected_out_of_range = (
            np.array(selected_ids_by_sp[sp]) >= structure_sp_start_idx[sp]
        )
        if np.any(selected_out_of_range):
            bad_indices = np.array(selected_ids_by_sp[sp])[selected_out_of_range]
            bad_indices_str = np.array2string(
                bad_indices, threshold=5, edgeitems=2, separator=", "
            )
            error_str = (
                f"Selected index(es): {bad_indices_str} for species {sp} out of range"
            )
            if 0 in bad_indices:
                error_str += " (species does not appear to be present)"
            raise ValueError(error_str)
    # Check that we haven't missed anything
    assert sum(len(ids) for ids in selected_ids) == sum(
        len(selected_ids_by_sp[sp]) for sp in sps
    )
    return selected_ids
def _split_feature_matrix_by_species(managers, X, sps):
    """Does exactly what it says on the tin
    Parameters
    ----------
    managers : AtomsList
        list of atomic structures
    X : np.ndarray (2-D)
        feature matrix computed from managers; rows must correspond to atoms
    sps : list(int)
        list of unique center atom species present in managers
    Returns
    -------
    dict(np.ndarray)
        The feature matrix split into matrices each corresponding to one
        of the atomic species requested
    Warnings
    --------
    This function does not check that the list of species provided actually
    corresponds to those present in managers; it only performs the selection
    (which would be empty for nonexistent species).
    """
    X_per_species = {}
    global_species_list = []
    for structure in managers:
        global_species_list.extend([atom.atom_type for atom in structure])
    global_species_list = np.array(global_species_list)
    for sp in sps:
        X_per_species[sp] = X[global_species_list == sp]
    return X_per_species
class Filter(BaseIO):
    """
    A super class for filtering representations based upon a standard
    sample or feature selection class.

    This is mainly a wrapper around selectors (implemented e.g. in
    scikit-cosmo) that handles the semantic-index transformations
    required after selection.

    Parameters
    ----------
    representation : Calculator
        Representation calculator associated with the kernel
    Nselect : int or dict
        number of points to select. If act_on='sample per species' then it
        should be a dictionary mapping atom type to the number of samples,
        e.g. Nselect = {1: 200, 6: 100, 8: 50}.
    selector : selector or dict
        selector to use for filtering. The selector should have a `fit`
        function, which when called will select from the input matrix the
        desired features / samples, and a `get_support` function which takes
        parameters `indices` and `ordered`, and returns a list of selection
        indices, in the order that they were selected, when `indices=True`
        and `ordered=True`. For act_on='sample per species' it must be a
        dictionary mapping atom type to a selector (or to None, in which case
        nothing is selected for that species).
    act_on : string
        Select how to apply the selection. Can be either of 'sample',
        'sample per species', 'feature'. Default 'sample per species'.
        Note that for 'feature' mode only the SphericalInvariants
        representation is supported.
    """

    def __init__(
        self,
        representation,
        Nselect,
        selector,
        act_on="sample per species",
    ):
        self._representation = representation
        self.Nselect = Nselect
        # Subclasses validate and set the mode against their own list of
        # supported modes before calling this constructor. On a plain Filter
        # the attribute does not exist yet, hence getattr.
        if getattr(self, "act_on", None) is None:
            self._check_set_mode(act_on)
        # effectively selected list of indices at the filter step
        # the indices have been reordered for efficiency and compatibility
        # with the c++ routines
        self.selected_ids = None
        # for 'sample' selection
        self.selected_sample_ids = None
        # for 'sample per species' selection
        self.selected_sample_ids_by_sp = None
        # for feature selection
        self.selected_feature_ids_global = None
        self._selector = selector

    def _check_set_mode(
        self, act_on, modes=("sample", "sample per species", "feature")
    ):
        """Check that the supplied act_on is one of the supported modes

        Set the mode if it is valid, raise a ValueError with a helpful
        message otherwise.

        A list of valid modes can be supplied in case it differs from the
        superclass default.
        """
        if act_on not in modes:
            valid_modes = ['"{}"'.format(mode) for mode in modes]
            if len(valid_modes) > 1:
                valid_modes[-1] = "or " + valid_modes[-1]
            valid_modes_str = ", ".join(valid_modes)
            raise ValueError('"act_on" should be one of: ' + valid_modes_str)
        self.act_on = act_on

    def select(self, managers):
        """Perform selection of samples/features.

        Parameters
        ----------
        managers : AtomsList
            list of structures containing features computed with representation

        Returns
        -------
        Filter (self)
            Returns self; use `filter()` to perform the actual filtering
            operation
        """
        X = managers.get_features(self._representation)
        if self.act_on == "sample per species":
            self.selected_sample_ids_by_sp = {}
            sps = list(self.Nselect.keys())
            LOGGER.info(
                f"The number of pseudo points selected by central atom species is: {self.Nselect}"
            )
            X_by_sp = _split_feature_matrix_by_species(managers, X, sps)
            for sp in sps:
                LOGGER.info(f"Selecting species: {sp}")
                if self._selector[sp] is not None:
                    self._selector[sp].fit(X_by_sp[sp])
                    in_sample_indices = np.array(
                        self._selector[sp].get_support(indices=True, ordered=True),
                        dtype=int,
                    )
                    self.selected_sample_ids_by_sp[sp] = in_sample_indices
                else:
                    # No selector for this species: nothing is selected
                    self.selected_sample_ids_by_sp[sp] = []
            return self
        self._selector.fit(X)
        selection = self._selector.get_support(indices=True, ordered=True)
        if self.act_on == "sample":
            self.selected_sample_ids = selection
        else:
            self.selected_feature_ids_global = selection
        return self

    def _check_n_select(self, n_select):
        """Raise ValueError if n_select asks for more than was selected.

        In 'sample per species' mode n_select is a dict; each of its species
        must be present in self.Nselect and is checked separately.
        """
        error_msg = (
            f"It is only possible to filter {self.Nselect} {self.act_on}(s), "
            f"you have requested {n_select}"
        )
        if self.act_on == "sample per species":
            for sp, n_sp in n_select.items():
                if sp not in self.Nselect or n_sp > self.Nselect[sp]:
                    raise ValueError(error_msg)
        elif n_select > self.Nselect:
            raise ValueError(error_msg)

    def filter(self, managers, n_select=None):
        """Apply the fitted selection to a new set of managers

        Parameters
        ----------
        managers : AtomsList
            list of structures containing features computed with representation
        n_select : int or dict
            number of selections to return, must not exceed self.Nselect.
            In 'sample per species' mode, a dict mapping species to the
            number of samples (each must not exceed the matching
            self.Nselect entry).

        Returns
        -------
        SparsePoints or dict
            Selected samples. The format depends on self.act_on - if it is
            "sample" or "sample per species", then a SparsePoints instance
            is directly returned. If it is "feature", then it is a dictionary
            containing the "coefficient_subselection" key that can be directly
            passed to the SphericalInvariants constructor.

        Warnings
        --------
        In "sample per species" mode the selected points are ordered within
        each structure first by species, then in order of selection. If
        self.act_on=="sample", the sparse points are afterwards sorted by
        species. In "feature" mode the selected features are sorted by
        global feature index. The permutation from selection order is
        stored under "selected_feature_ids_global_selection_ordering".

        Raises
        ------
        ValueError
            if requesting more selected samples or features than were used
            to initialize the representation
        """
        if n_select is None:
            n_select = self.Nselect
        else:
            self._check_n_select(n_select)
        if self.act_on == "sample per species":
            sps = list(n_select.keys())
            selected_ids_by_sp = {
                sp: self.selected_sample_ids_by_sp[sp][: n_select[sp]] for sp in sps
            }
            self.selected_ids = _indices_perspecies_manager_to_perstructure(
                managers, selected_ids_by_sp, sps
            )
            sparse_points = SparsePoints(self._representation)
            sparse_points.extend(managers, self.selected_ids)
            return sparse_points
        elif self.act_on == "sample":
            selected_ids_global = self.selected_sample_ids[:n_select]
            self.selected_ids = _indices_manager_to_perstructure(
                managers, selected_ids_global
            )
            # The sparse points will be reordered since they're not per-species
            # but the resulting object is still usable
            sparse_points = SparsePoints(self._representation)
            sparse_points.extend(managers, self.selected_ids)
            return sparse_points
        elif self.act_on == "feature":
            if not isinstance(self._representation, SphericalInvariants):
                raise ValueError(
                    "Feature filtering currently only supported for SphericalInvariants"
                )
            feat_idx2coeff_idx = self._representation.get_feature_index_mapping(
                managers
            )
            coefficient_subselection = {
                key: [] for key in feat_idx2coeff_idx[0].keys()
            }
            # np.asarray so that plain lists (e.g. restored via _set_data)
            # support the fancy indexing below
            selection_order_ids = np.asarray(self.selected_feature_ids_global)[
                :n_select
            ]
            selected_ids_sorting = np.argsort(selection_order_ids)
            selected_feature_ids = selection_order_ids[selected_ids_sorting]
            for idx in selected_feature_ids:
                coef_idx = feat_idx2coeff_idx[idx]
                for key in coefficient_subselection:
                    coefficient_subselection[key].append(int(coef_idx[key]))
            self.selected_ids = dict(
                coefficient_subselection=coefficient_subselection,
                # keep the global indices and ordering for ease of use
                selected_feature_ids_global=selected_feature_ids.tolist(),
                selected_feature_ids_global_selection_ordering=selected_ids_sorting.tolist(),
            )
            return self.selected_ids

    def select_and_filter(self, managers):
        """Run `select()` then `filter()` on the same managers."""
        return self.select(managers).filter(managers)

    def _get_data(self):
        return dict(
            selected_ids=self.selected_ids,
            selected_sample_ids=self.selected_sample_ids,
            selected_sample_ids_by_sp=self.selected_sample_ids_by_sp,
            selected_feature_ids_global=self.selected_feature_ids_global,
        )

    def _set_data(self, data):
        self.selected_ids = data["selected_ids"]
        self.selected_sample_ids = data["selected_sample_ids"]
        self.selected_sample_ids_by_sp = data["selected_sample_ids_by_sp"]
        self.selected_feature_ids_global = data["selected_feature_ids_global"]

    def _get_init_params(self):
        return dict(
            representation=self._representation,
            Nselect=self.Nselect,
            act_on=self.act_on,
        )
class CURFilter(Filter):
    """Filter selecting samples or features with the skcosmo `_CUR` selector.

    Parameters are as for Filter, except that the selector is built here.
    For 'sample' and 'feature' modes a single `_CUR` selector is created.
    For 'sample per species' one is created per species, or None where
    Nselect requests zero points. `selector_args` is forwarded as keyword
    arguments to every `_CUR` constructor.

    Raises
    ------
    ValueError
        if act_on is not a supported mode
    ImportError
        if skcosmo could not be imported
    """

    def __init__(
        self,
        representation,
        Nselect,
        act_on="sample per species",
        selector_args=None,
        **kwargs,
    ):
        modes = ["sample", "sample per species", "feature"]
        self._check_set_mode(act_on, modes)
        if _CUR is None:
            raise ImportError(
                "CURFilter requires the skcosmo package, which could not be imported"
            )
        # None instead of a shared mutable {} default
        if selector_args is None:
            selector_args = {}
        if act_on == "sample":
            selector = _CUR(
                selection_type="sample", n_to_select=Nselect, **selector_args
            )
        elif act_on == "sample per species":
            selector = {
                sp: _CUR(selection_type="sample", n_to_select=n_sp, **selector_args)
                if n_sp > 0
                else None
                for sp, n_sp in Nselect.items()
            }
        else:
            # _check_set_mode guarantees act_on == "feature" here
            selector = _CUR(
                selection_type="feature", n_to_select=Nselect, **selector_args
            )
        super().__init__(
            representation=representation,
            Nselect=Nselect,
            act_on=act_on,
            selector=selector,
            **kwargs,
        )
class FPSFilter(Filter):
    """Filter selecting samples or features with the skcosmo `_FPS` selector.

    Parameters are as for Filter, except that the selector is built here.
    For 'sample' and 'feature' modes a single `_FPS` selector is created.
    For 'sample per species' one is created per species, or None where
    Nselect requests zero points. `selector_args` is forwarded as keyword
    arguments to every `_FPS` constructor.

    Raises
    ------
    ValueError
        if act_on is not a supported mode
    ImportError
        if skcosmo could not be imported
    """

    def __init__(
        self,
        representation,
        Nselect,
        act_on="sample per species",
        selector_args=None,
        **kwargs,
    ):
        modes = ["sample", "sample per species", "feature"]
        self._check_set_mode(act_on, modes)
        if _FPS is None:
            raise ImportError(
                "FPSFilter requires the skcosmo package, which could not be imported"
            )
        # None instead of a shared mutable {} default
        if selector_args is None:
            selector_args = {}
        if act_on == "sample":
            selector = _FPS(
                selection_type="sample", n_to_select=Nselect, **selector_args
            )
        elif act_on == "sample per species":
            selector = {
                sp: _FPS(selection_type="sample", n_to_select=n_sp, **selector_args)
                if n_sp > 0
                else None
                for sp, n_sp in Nselect.items()
            }
        else:
            # _check_set_mode guarantees act_on == "feature" here
            selector = _FPS(
                selection_type="feature", n_to_select=Nselect, **selector_args
            )
        super().__init__(
            representation=representation,
            Nselect=Nselect,
            act_on=act_on,
            selector=selector,
            **kwargs,
        )

    def get_fps_distances(self):
        """Return the Hausdorff distances over the course of selection

        This may be a useful (rough) indicator for choosing how many points to
        select, as a small distance generally indicates that the selected point
        is close to the existing set of selected points and therefore probably
        does not add much additional information.

        Returns either an array of Hausdorff distances, or a species-indexed
        dict of arrays (for the "sample per species" mode). Species for which
        no points were requested (no selector) get an empty array.
        """
        if self.act_on == "sample per species":
            return {
                sp: selector.get_select_distance()
                if selector is not None
                else np.zeros(0)
                for sp, selector in self._selector.items()
            }
        return self._selector.get_select_distance()