# Source code for rascal.utils.filter

import logging
import numpy as np
from .io import BaseIO
from ..models.sparse_points import SparsePoints
from ..representations.spherical_invariants import SphericalInvariants

LOGGER = logging.getLogger(__name__)

# skcosmo is an optional dependency: when it cannot be imported the selector
# classes are replaced by None so that this module still imports, and a
# warning (including the original import error) is logged.
try:
    from skcosmo._selection import _FPS, _CUR
except ImportError as ie:
    # Logger.warn() is a deprecated alias; Logger.warning() is the supported
    # spelling and does not emit a DeprecationWarning.
    LOGGER.warning(
        "Warning: skcosmo module not found. CUR and FPS filters will be unavailable."
    )
    LOGGER.warning("Original error:\n%s", ie)
    _FPS = _CUR = None

# Index conversion utilities


def _indices_manager_to_perstructure(managers, selected_ids_global):
    """Convert manager-global center indexing to per-structure format

    That is, change them to a list of lists of structure-local atom indices
    that is accepted as input to SparsePoints

    Parameters
    ----------
    managers : AtomsList or list(ase.Atoms)
        list of atomic structures; only ``len()`` of each entry is used
    selected_ids_global : list of int
        global indices to convert

    Returns
    -------
    selected_ids : list of list(int)
        list the atom indices (within their structure) that have been
        selected; one (possibly empty) list per structure, each keeping the
        order in which its indices appear in ``selected_ids_global``

    Raises
    ------
    ValueError
        if any index is negative or not smaller than the total number of
        atoms in ``managers``
    """
    selected_ids_global = np.asarray(selected_ids_global)
    natoms_list = [len(manager) for manager in managers]
    split_idces = np.cumsum(natoms_list)
    # An empty list of structures contains zero atoms (cumsum is then empty)
    total_atoms = split_idces[-1] if len(split_idces) > 0 else 0
    # Reject anything that cannot be mapped onto a structure. Negative
    # indices are included here: they would otherwise match no structure
    # and be dropped without an error.
    out_of_range = (selected_ids_global < 0) | (selected_ids_global >= total_atoms)
    if np.any(out_of_range):
        bad_indices = selected_ids_global[out_of_range]
        bad_indices_str = np.array2string(
            bad_indices, threshold=5, edgeitems=2, separator=", "
        )
        raise ValueError(f"Selected index(es): {bad_indices_str} out of range")
    # Every index is now known to fall inside exactly one structure
    selected_ids = []
    structure_start_idx = 0
    for structure_end_idx in split_idces:
        in_this_structure = (selected_ids_global >= structure_start_idx) & (
            selected_ids_global < structure_end_idx
        )
        # Shift from manager-global to structure-local numbering
        selected_ids.append(
            list(selected_ids_global[in_this_structure] - structure_start_idx)
        )
        structure_start_idx = structure_end_idx
    return selected_ids


def _indices_perspecies_manager_to_perstructure(managers, selected_ids_by_sp, sps):
    """Convert per-species, manager-global center indexing to per-structure

    That is, change them to a list of lists of structure-local atom indices
    that is accepted as input to SparsePoints

    See _get_index_mappings_sample_per_species() to make the input

    Parameters
    ----------
    managers : AtomsList
        list of atomic structures; each structure is iterated and every atom
        must expose an ``atom_type`` attribute
    selected_ids_by_sp : dict
        indices to convert; each dictionary entry (keyed by species)
        is a list of indices into the array of all atoms in the manager
        of only that species. The dictionary is not modified.
    sps : list(int) or set(int)
        unique center atom species present in managers

    Returns
    -------
    selected_ids : list of lists
        list the atom indices (within their structure) that have been selected
        They are ordered, first by species (sorted, irrespective of the order
        passed in), then by selection order.

    Raises
    ------
    ValueError
        if ``sps`` contains duplicates, if an atom of a species not listed in
        ``sps`` is encountered, or if any selected index is negative or not
        smaller than the number of atoms of its species

    Notes
    -----
    Asking for indices for a species not present in the AtomsList will
    result in an out-of-range error (since the slice of atoms of that species
    is of size zero).

    This function does not yet support lists of ASE Atoms (instead of a list
    of managers) as input, but it should be fairly easy to add support in
    the future if required.
    """
    if len(set(sps)) != len(sps):
        raise ValueError(f"List of species contains duplicated entries: {sps}")
    # Work on array views held in a local dict so that the caller's
    # dictionary is left untouched
    ids_by_sp = {sp: np.asarray(selected_ids_by_sp[sp]) for sp in sps}
    sorted_sps = sorted(sps)
    selected_ids = [[] for _ in range(len(managers))]
    # Running count, per species, of atoms seen in all previous structures
    structure_sp_start_idx = {sp: 0 for sp in sps}
    for structure, structure_selected_ids in zip(managers, selected_ids):
        # Map (species, n-th atom of that species in this structure) to the
        # atom's index within the structure
        structure_index_mapping = {sp: [] for sp in sps}
        for perstructure_idx, atom in enumerate(structure):
            atom_sp = atom.atom_type
            if atom_sp not in sps:
                raise ValueError(
                    f"Atom of type {atom_sp} found but was not listed in sps: {sps}"
                )
            structure_index_mapping[atom_sp].append(perstructure_idx)
        for sp in sorted_sps:
            sp_ids = ids_by_sp[sp]
            sp_start = structure_sp_start_idx[sp]
            sp_end = sp_start + len(structure_index_mapping[sp])
            # Per-species indices that land in this structure, renumbered
            # relative to this structure's first atom of that species
            local_sp_idces = sp_ids[(sp_ids >= sp_start) & (sp_ids < sp_end)] - sp_start
            structure_selected_ids.extend(
                structure_index_mapping[sp][sp_idx] for sp_idx in local_sp_idces
            )
            structure_sp_start_idx[sp] = sp_end
    # structure_sp_start_idx now holds the total atom count of each species.
    # Negative indices are rejected too; they match no structure above and
    # would otherwise be dropped without an error.
    for sp in sps:
        sp_ids = ids_by_sp[sp]
        selected_out_of_range = (sp_ids < 0) | (sp_ids >= structure_sp_start_idx[sp])
        if np.any(selected_out_of_range):
            bad_indices = sp_ids[selected_out_of_range]
            bad_indices_str = np.array2string(
                bad_indices, threshold=5, edgeitems=2, separator=", "
            )
            error_str = (
                f"Selected index(es): {bad_indices_str} for species {sp} out of range"
            )
            if 0 in bad_indices:
                error_str += " (species does not appear to be present)"
            raise ValueError(error_str)
    return selected_ids


def _split_feature_matrix_by_species(managers, X, sps):
    """Partition the rows of a feature matrix by center-atom species

    Parameters
    ----------
    managers : AtomsList
        list of atomic structures
    X : np.ndarray (2-D)
        feature matrix computed from managers; rows must correspond to atoms
    sps : list(int)
        list of unique center atom species present in managers

    Returns
    -------
    dict(np.ndarray)
        One entry per requested species, holding the rows of X that belong
        to atoms of that species

    Warnings
    --------
    No check is made that the requested species actually occur in managers;
    a species that does not occur simply yields an empty selection.
    """
    # Species of every atom, in the same (structure-by-structure) order as
    # the rows of X
    center_species = np.array(
        [atom.atom_type for structure in managers for atom in structure]
    )
    return {sp: X[center_species == sp] for sp in sps}


class Filter(BaseIO):
    """
    A super class for filtering representations based upon a standard
    sample or feature selection class.

    This is mainly a wrapper around selectors (implemented e.g. in
    scikit-cosmo) that handles the semantic-index transformations required
    after selection.

    Parameters
    ----------
    representation : Calculator
        Representation calculator associated with the kernel
    Nselect: int
        number of points to select. If act_on='sample per species' then it
        should be a dictionary mapping atom type to the number of samples,
        e.g. Nselect = {1:200,6:100,8:50}.
    selector: selector to use for filtering. The selector should
        have a `fit` function, which when called will select from the input
        matrix the desired features / samples and a `get_support` function
        which takes parameters `indices` and `ordered`, and returns a list of
        selection indices, in the order that they were selected,
        when `indices=True` and `ordered=True`.
        If act_on='sample per species' it must be a dictionary mapping atom
        type to such a selector (or to None to select nothing for that type).
    act_on: string
        Select how to apply the selection. Can be either of 'sample',
        'sample per species','feature'. Default 'sample per species'.
        Note that for 'feature' mode only the SphericalInvariants
        representation is supported.

    """

    def __init__(
        self,
        representation,
        Nselect,
        selector,
        act_on="sample per species",
    ):
        self._representation = representation
        self.Nselect = Nselect
        # Subclasses validate and set act_on themselves before calling this
        # constructor; only set it here if that has not happened. getattr is
        # needed because the attribute may not exist at all yet.
        if getattr(self, "act_on", None) is None:
            self._check_set_mode(act_on)
        # effectively selected list of indices at the filter step
        # the indices have been reordered for efficiency and compatibility
        # with the c++ routines
        self.selected_ids = None
        # for 'sample' selection
        self.selected_sample_ids = None
        # for 'sample per species' selection
        self.selected_sample_ids_by_sp = None
        # for feature selection
        self.selected_feature_ids_global = None
        self._selector = selector

    def _check_set_mode(
        self, act_on, modes=("sample", "sample per species", "feature")
    ):
        """Check that the supplied act_on is one of the supported modes

        Set the mode if it is valid, raise a ValueError with a helpful
        message otherwise

        A list of valid modes can be supplied in case it differs from
        the superclass default.
        """
        if act_on in modes:
            self.act_on = act_on
        else:
            valid_modes = ['"{}"'.format(mode) for mode in modes]
            if len(valid_modes) > 1:
                valid_modes[-1] = "or " + valid_modes[-1]
            valid_modes_str = ", ".join(valid_modes)
            raise ValueError('"act_on" should be one of: ' + valid_modes_str)

    def select(self, managers):
        """Perform selection of samples/features.

        Parameters
        ----------
        managers : AtomsList
            list of structures containing features computed with
            representation

        Returns
        -------
        Filter (self)
            Returns self; use `filter()` to perform the actual filtering
            operation
        """
        X = managers.get_features(self._representation)
        if self.act_on == "sample per species":
            self.selected_sample_ids_by_sp = {}
            sps = list(self.Nselect.keys())
            LOGGER.info(
                f"The number of pseudo points selected by central atom species is: {self.Nselect}"
            )
            X_by_sp = _split_feature_matrix_by_species(managers, X, sps)
            for sp in sps:
                LOGGER.info(f"Selecting species: {sp}")
                sp_selector = self._selector[sp]
                # A None selector means nothing is to be selected for sp
                if sp_selector is None:
                    self.selected_sample_ids_by_sp[sp] = []
                    continue
                sp_selector.fit(X_by_sp[sp])
                self.selected_sample_ids_by_sp[sp] = np.array(
                    sp_selector.get_support(indices=True, ordered=True),
                    dtype=int,
                )
            return self
        self._selector.fit(X)
        support = self._selector.get_support(indices=True, ordered=True)
        if self.act_on == "sample":
            self.selected_sample_ids = support
        else:
            self.selected_feature_ids_global = support
        return self

    def filter(self, managers, n_select=None):
        """Apply the fitted selection to a new set of managers

        Parameters
        ----------
        managers : AtomsList
            list of structures containing features computed with
            representation
        n_select : int or dict
            number of selections to return, must not exceed self.Nselect.
            For act_on="sample per species" this is a dictionary mapping
            atom type to a number, like Nselect. Defaults to self.Nselect.

        Returns
        -------
        SparsePoints or list(int) or dict
            Selected samples. The format depends on self.act_on - if it is
            "sample" or "sample per species", then a SparsePoints instance
            is directly returned.  If it is "feature", then it is a
            dictionary containing the "coefficient_subselection" key that
            can be directly passed to the SphericalInvariants constructor.

        Warnings
        --------
        Note that the selected points are sorted in order of selection,
        _except_ if self.act_on=="sample", in which case the sparse points
        are afterwards sorted by species.

        Raises
        ------
        ValueError
            if requesting more selected samples or features than were used
            to initialize the representation
        """
        if n_select is None:
            n_select = self.Nselect
        else:
            if self.act_on == "sample per species":
                # Both are per-species dicts here; dicts cannot be ordered
                # with ">", so compare species by species. A species that
                # is absent from Nselect counts as zero available.
                too_many = any(
                    n_sp > self.Nselect.get(sp, 0) for sp, n_sp in n_select.items()
                )
            else:
                too_many = n_select > self.Nselect
            if too_many:
                raise ValueError(
                    f"It is only possible to filter {self.Nselect} {self.act_on}(s), "
                    f"you have requested {n_select}"
                )
        if self.act_on == "sample per species":
            sps = list(n_select.keys())
            selected_ids_by_sp = {
                key: val[: n_select[key]]
                for key, val in self.selected_sample_ids_by_sp.items()
            }
            self.selected_ids = _indices_perspecies_manager_to_perstructure(
                managers, selected_ids_by_sp, sps
            )
            sparse_points = SparsePoints(self._representation)
            sparse_points.extend(managers, self.selected_ids)
            return sparse_points
        elif self.act_on == "sample":
            selected_ids_global = self.selected_sample_ids[:n_select]
            self.selected_ids = _indices_manager_to_perstructure(
                managers, selected_ids_global
            )
            # The sparse points will be reordered since they're not
            # per-species but the resulting object is still usable
            sparse_points = SparsePoints(self._representation)
            sparse_points.extend(managers, self.selected_ids)
            return sparse_points
        elif self.act_on == "feature":
            if not isinstance(self._representation, SphericalInvariants):
                raise ValueError(
                    "Feature filtering currently only supported for SphericalInvariants"
                )
            feat_idx2coeff_idx = self._representation.get_feature_index_mapping(
                managers
            )
            self.selected_ids = {key: [] for key in feat_idx2coeff_idx[0].keys()}
            # The stored indices may be a plain sequence rather than an
            # array; the fancy indexing below needs an array.
            feature_ids_global = np.asarray(self.selected_feature_ids_global)
            # Positions (in selection order) of the first n_select features,
            # arranged so that the feature indices come out ascending
            selected_ids_sorting = np.argsort(feature_ids_global[:n_select])
            selected_feature_ids = feature_ids_global[selected_ids_sorting]
            for idx in selected_feature_ids:
                coef_idx = feat_idx2coeff_idx[idx]
                for key in self.selected_ids.keys():
                    self.selected_ids[key].append(int(coef_idx[key]))
            self.selected_ids = dict(coefficient_subselection=self.selected_ids)
            # keep the global indices and ordering for ease of use
            self.selected_ids[
                "selected_feature_ids_global"
            ] = selected_feature_ids.tolist()
            self.selected_ids[
                "selected_feature_ids_global_selection_ordering"
            ] = selected_ids_sorting.tolist()
            return self.selected_ids

    def select_and_filter(self, managers):
        """Run `select()` then `filter()` on the same managers

        Returns whatever `filter(managers)` returns.
        """
        return self.select(managers).filter(managers)

    def _get_data(self):
        """Return the selection results held by this filter as a dict"""
        return dict(
            selected_ids=self.selected_ids,
            selected_sample_ids=self.selected_sample_ids,
            selected_sample_ids_by_sp=self.selected_sample_ids_by_sp,
            selected_feature_ids_global=self.selected_feature_ids_global,
        )

    def _set_data(self, data):
        """Restore the selection results from a dict as made by _get_data"""
        self.selected_ids = data["selected_ids"]
        self.selected_sample_ids = data["selected_sample_ids"]
        self.selected_sample_ids_by_sp = data["selected_sample_ids_by_sp"]
        self.selected_feature_ids_global = data["selected_feature_ids_global"]

    def _get_init_params(self):
        """Return constructor arguments as a dict

        Note that the selector itself is not included.
        """
        return dict(
            representation=self._representation,
            Nselect=self.Nselect,
            act_on=self.act_on,
        )
class CURFilter(Filter):
    """Filter that selects with the `_CUR` selector from skcosmo

    Parameters
    ----------
    representation : Calculator
        Representation calculator, passed on to Filter
    Nselect : int or dict
        number of points to select; for act_on="sample per species" a
        dictionary mapping atom type to the number of samples. Species
        mapped to a number that is not positive get no selector.
    act_on : string
        one of "sample", "sample per species" or "feature"
    selector_args : dict, optional
        extra keyword arguments passed to every `_CUR` constructor call
    **kwargs
        passed on to the Filter constructor

    Raises
    ------
    ValueError
        if act_on is not one of the supported modes
    ImportError
        if skcosmo could not be imported, so that no selector is available
    """

    def __init__(
        self,
        representation,
        Nselect,
        act_on="sample per species",
        selector_args=None,
        **kwargs,
    ):
        # None stands in for "no extra arguments" to avoid a shared mutable
        # default
        selector_args = {} if selector_args is None else selector_args
        modes = ["sample", "sample per species", "feature"]
        self._check_set_mode(act_on, modes)
        # _CUR is None when the optional skcosmo import failed; fail with an
        # explicit message instead of "'NoneType' object is not callable"
        if _CUR is None:
            raise ImportError("skcosmo module not found; CURFilter is unavailable")
        if act_on == "sample":
            selector = _CUR(
                selection_type="sample", n_to_select=Nselect, **selector_args
            )
        elif act_on == "sample per species":
            # One independent selector per species; None where nothing is
            # to be selected
            selector = {
                n: _CUR(
                    selection_type="sample", n_to_select=Nselect[n], **selector_args
                )
                if Nselect[n] > 0
                else None
                for n in Nselect
            }
        else:
            assert act_on == "feature"
            selector = _CUR(
                selection_type="feature", n_to_select=Nselect, **selector_args
            )
        super().__init__(
            representation=representation,
            Nselect=Nselect,
            act_on=act_on,
            selector=selector,
            **kwargs,
        )
class FPSFilter(Filter):
    """Filter that selects with the `_FPS` selector from skcosmo

    Parameters
    ----------
    representation : Calculator
        Representation calculator, passed on to Filter
    Nselect : int or dict
        number of points to select; for act_on="sample per species" a
        dictionary mapping atom type to the number of samples. Species
        mapped to a number that is not positive get no selector.
    act_on : string
        one of "sample", "sample per species" or "feature"
    selector_args : dict, optional
        extra keyword arguments passed to every `_FPS` constructor call
    **kwargs
        passed on to the Filter constructor

    Raises
    ------
    ValueError
        if act_on is not one of the supported modes
    ImportError
        if skcosmo could not be imported, so that no selector is available
    """

    def __init__(
        self,
        representation,
        Nselect,
        act_on="sample per species",
        selector_args=None,
        **kwargs,
    ):
        # None stands in for "no extra arguments" to avoid a shared mutable
        # default
        selector_args = {} if selector_args is None else selector_args
        modes = ["sample", "sample per species", "feature"]
        self._check_set_mode(act_on, modes)
        # _FPS is None when the optional skcosmo import failed; fail with an
        # explicit message instead of "'NoneType' object is not callable"
        if _FPS is None:
            raise ImportError("skcosmo module not found; FPSFilter is unavailable")
        if act_on == "sample":
            selector = _FPS(
                selection_type="sample", n_to_select=Nselect, **selector_args
            )
        elif act_on == "sample per species":
            # One independent selector per species; None where nothing is
            # to be selected
            selector = {
                n: _FPS(
                    selection_type="sample", n_to_select=Nselect[n], **selector_args
                )
                if Nselect[n] > 0
                else None
                for n in Nselect
            }
        else:
            assert act_on == "feature"
            selector = _FPS(
                selection_type="feature", n_to_select=Nselect, **selector_args
            )
        super().__init__(
            representation=representation,
            Nselect=Nselect,
            act_on=act_on,
            selector=selector,
            **kwargs,
        )

    def get_fps_distances(self):
        """Return the Hausdorff distances over the course of selection

        This may be a useful (rough) indicator for choosing how many points
        to select, as a small distance generally indicates that the selected
        point is close to the existing set of selected points and therefore
        probably does not add much additional information.

        Returns either an array of Hausdorff distances, or a species-indexed
        dict of arrays (for the "sample per species" mode). Species for which
        no selector was created get an empty array.
        """
        if self.act_on == "sample per species":
            return {
                # Species with nothing to select have a None selector
                sp: (
                    np.empty(0)
                    if sp_selector is None
                    else sp_selector.get_select_distance()
                )
                for sp, sp_selector in self._selector.items()
            }
        return self._selector.get_select_distance()