# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*-
#
# Copyright (c) 2023 Authors and contributors
# (see the AUTHORS.rst file for the full list of names)
#
# Released under the BSD 3-Clause "New" or "Revised" License
# SPDX-License-Identifier: BSD-3-Clause
"""Wrappers for the sample selectors of `scikit-matter`_.
.. _`scikit-matter`: https://scikit-matter.readthedocs.io/en/latest/selection.html
"""
from skmatter._selection import _CUR, _FPS
from ._selection import GreedySelector
# Farthest Point Sampling (FPS) sample selector.
class FPS(GreedySelector):
    """Greedy sample selection based on Farthest Point Sampling (FPS).

    The number of samples kept in each block is controlled by `n_to_select`:

    * an `int` selects that many samples in every block. The value must be
      less than or equal to the smallest number of samples found in any
      block.
    * a `dict` must have keys that are tuples corresponding to the key
      values of each block; its `int` values give the number of samples to
      select for each block individually.
    * ``-1`` selects all samples of every block. This is useful, for
      instance, for plotting Hausdorff distances, which can be accessed
      through the ``selector.haussdorf_at_select`` property after calling
      ``fit()``.

    Every constructor argument is forwarded to :py:class:`GreedySelector`
    together with the scikit-matter ``_FPS`` selector class and
    ``selection_type="sample"``.

    Refer to :py:class:`skmatter.sample_selection.FPS` for full documentation
    of the individual parameters.
    """

    # NOTE(review): the docstring refers to ``haussdorf_at_select`` exactly as
    # the original text did; confirm this matches the name of the property
    # that GreedySelector actually exposes.

    def __init__(
        self,
        initialize=0,
        n_to_select=None,
        score_threshold=None,
        score_threshold_type="absolute",
        progress_bar=False,
        full=False,
        random_state=0,
    ):
        # Gather the FPS-specific options first so that the base-class call
        # reads as "which selector, which axis" followed by "its options".
        # The keys and their order match the keyword arguments one-to-one.
        selector_options = {
            "initialize": initialize,
            "n_to_select": n_to_select,
            "score_threshold": score_threshold,
            "score_threshold_type": score_threshold_type,
            "progress_bar": progress_bar,
            "full": full,
            "random_state": random_state,
        }
        super().__init__(
            selector_class=_FPS,
            selection_type="sample",
            **selector_options,
        )
# CUR sample selector.
class CUR(GreedySelector):
    """Greedy sample selection based on CUR.

    The number of samples kept in each block is controlled by `n_to_select`:

    * an `int` selects that many samples in every block. The value must be
      less than or equal to the smallest number of samples found in any
      block.
    * a `dict` must have keys that are tuples corresponding to the key
      values of each block; its `int` values give the number of samples to
      select for each block individually.
    * ``-1`` selects all samples of every block.

    Every constructor argument is forwarded to :py:class:`GreedySelector`
    together with the scikit-matter ``_CUR`` selector class and
    ``selection_type="sample"``.

    Refer to :py:class:`skmatter.sample_selection.CUR` for full documentation
    of the individual parameters.
    """

    def __init__(
        self,
        recompute_every=1,
        k=1,
        tolerance=1e-12,
        n_to_select=None,
        score_threshold=None,
        score_threshold_type="absolute",
        progress_bar=False,
        full=False,
        random_state=0,
    ):
        # Gather the CUR-specific options first so that the base-class call
        # reads as "which selector, which axis" followed by "its options".
        # The keys and their order match the keyword arguments one-to-one.
        selector_options = {
            "recompute_every": recompute_every,
            "k": k,
            "tolerance": tolerance,
            "n_to_select": n_to_select,
            "score_threshold": score_threshold,
            "score_threshold_type": score_threshold_type,
            "progress_bar": progress_bar,
            "full": full,
            "random_state": random_state,
        }
        super().__init__(
            selector_class=_CUR,
            selection_type="sample",
            **selector_options,
        )