Splitting Equistore TensorMaps
Module for splitting lists of TensorMap objects into multiple TensorMap objects along a given axis.
- equisolve.utils.split_data.split_data(tensors: List[TensorMap] | TensorMap, axis: str, names: List[str] | str, n_groups: int, group_sizes: List[int] | List[float] | None = None, seed: int | None = None) → Tuple[List[List[TensorMap]], List[Labels]]
Splits a list of TensorMap objects into multiple TensorMap objects along a given axis.

For either the “samples” or “properties” axis, the unique indices for the specified metadata names are found. If seed is set, these indices are shuffled. They are then divided into n_groups groups, whose sizes are specified by the group_sizes argument.

These grouped indices are then used to split the list of input tensors. The split tensors are returned together with the grouped labels. The tensors are returned as a list of lists of TensorMap objects.

Each list in the returned list of lists corresponds to the split TensorMap at the same position in the input tensors list. Each nested list contains TensorMap objects that share no common indices for the specified axis and names. However, the metadata on all other axes (including the keys) will be equivalent.

All TensorMap objects passed in tensors must have the same set of unique indices for the specified axis and names. For instance, when splitting an input and an output tensor (as in supervised machine learning), the output tensor must have structure indices 0 -> 10 if the input tensor does.
- Parameters:
  - tensors – input list of TensorMap objects, each of which will be split into n_groups new TensorMap objects.
  - axis – a str equal to either “samples” or “properties”. This is the axis along which the input TensorMap objects will be split.
  - names – a list of str (or a single str) giving the samples/properties names by which the tensors will be split.
  - n_groups – an int giving how many new TensorMap objects each of the tensors passed in tensors will be split into. If group_sizes is None (default), the data are split into n_groups evenly sized groups according to the unique metadata for the specified axis and names, rounded to the nearest integer.
  - group_sizes – an ordered list of float or int giving the sizes of the groups to split each input TensorMap into. A list of int is interpreted as absolute group sizes, whereas a list of float is interpreted as relative sizes. In the former case, the sum of the list must be <= the total number of unique indices present in the input tensors for the chosen axis and names. In the latter, the sum of the list must be <= 1.
  - seed – an int that seeds the numpy random number generator. It controls the shuffling of the unique indices, which dictates the data that end up in each of the split output tensors. If None (default), the indices are not shuffled. If an int, the indices are shuffled with the random seed set to this value.
- Return split_tensors:
  list of list of TensorMap. The i-th element of the list contains the n_groups TensorMap objects obtained by splitting the i-th TensorMap of the input list tensors.
- Return grouped_labels:
  list of Labels giving the unique indices (for the specified axis and names) that are present in each of the returned groups of TensorMap objects. The length of this list is n_groups.
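The rules above for turning n_groups and group_sizes into concrete group sizes can be sketched as a small pure-Python helper. resolve_group_sizes is hypothetical and not part of equisolve. The documentation only says an even split is made "to the nearest integer", so handing the remainder to the first groups, and rounding relative sizes down, are assumptions of this sketch rather than equisolve's documented behaviour.

```python
import math
from typing import List, Optional, Sequence, Union


def resolve_group_sizes(
    n_unique: int,
    n_groups: int,
    group_sizes: Optional[Sequence[Union[int, float]]] = None,
) -> List[int]:
    """Convert n_groups/group_sizes into absolute group sizes (hypothetical helper)."""
    if group_sizes is None:
        # Even split. Assumption: any remainder is handed to the first groups,
        # so every unique index ends up in some group.
        base, remainder = divmod(n_unique, n_groups)
        return [base + (1 if i < remainder else 0) for i in range(n_groups)]

    if len(group_sizes) != n_groups:
        raise ValueError("group_sizes must contain exactly n_groups entries")

    if all(isinstance(size, int) for size in group_sizes):
        # Absolute sizes: their sum may not exceed the number of unique indices.
        if sum(group_sizes) > n_unique:
            raise ValueError("absolute group sizes exceed the number of unique indices")
        return list(group_sizes)

    # Relative sizes: their sum may not exceed 1 (small tolerance for float error).
    if sum(group_sizes) > 1 + 1e-12:
        raise ValueError("relative group sizes must sum to <= 1")
    # Assumption: round down so the absolute sizes never exceed n_unique; the
    # 1e-9 guards against products that land just below an integer.
    return [math.floor(size * n_unique + 1e-9) for size in group_sizes]
```

For 10 unique structure indices, resolve_group_sizes(10, 2, [0.8, 0.2]) gives [8, 2], and resolve_group_sizes(10, 3, [7, 2, 1]) returns the absolute sizes unchanged.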
Examples
Split a TensorMap tensor into 2 new TensorMaps along the “samples” axis for the “structure” metadata. Without specifying group_sizes, the data will be split equally by structure index. If the number of unique structure indices present in the input data is not exactly divisible by n_groups, the group sizes will be rounded to the nearest integer. Without specifying seed, no shuffling of the structure indices will occur and they will be grouped in ascending order. For instance, if the input tensor has structure indices 0 -> 9 (inclusive), the first new tensor will contain only structure indices 0 -> 4 (inc.) and the second will contain only 5 -> 9 (inc.).
```python
from equisolve.utils import split_data

[[new_tensor_1, new_tensor_2]], grouped_labels = split_data(
    tensors=tensor,
    axis="samples",
    names=["structure"],
    n_groups=2,
)
```
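To make the first example concrete, the sketch below models one block's samples as plain (structure, atom) tuples instead of metatensor Labels. The helpers group_unique_indices and split_samples are hypothetical and only illustrate the procedure described above: sort (or, with a seed, shuffle) the unique structure indices, cut them into consecutive groups, and keep the sample rows whose structure index falls in each group. Python's random module stands in for the numpy generator that equisolve seeds, so a seeded order here will not reproduce equisolve's.

```python
import random
from typing import Dict, List, Optional, Sequence, Tuple


def group_unique_indices(
    unique_indices: Sequence[int],
    sizes: Sequence[int],
    seed: Optional[int] = None,
) -> List[List[int]]:
    """Sort (or shuffle) the unique indices, then cut them into consecutive groups."""
    indices = sorted(unique_indices)  # no seed: ascending order
    if seed is not None:
        # Stand-in RNG: this order will not match equisolve's numpy-based shuffle.
        random.Random(seed).shuffle(indices)
    groups: List[List[int]] = []
    start = 0
    for size in sizes:
        groups.append(indices[start : start + size])
        start += size
    return groups


def split_samples(
    samples: Sequence[Tuple[int, ...]],
    name_position: int,
    groups: Sequence[Sequence[int]],
) -> List[List[int]]:
    """For each group, list the row numbers of the samples whose value at
    name_position (e.g. the "structure" column) belongs to that group."""
    group_of: Dict[int, int] = {
        index: group_id for group_id, group in enumerate(groups) for index in group
    }
    rows: List[List[int]] = [[] for _ in groups]
    for row, sample in enumerate(samples):
        group_id = group_of.get(sample[name_position])
        if group_id is not None:  # indices outside every group are dropped
            rows[group_id].append(row)
    return rows


# 10 structures with 2 atoms each: samples (0, 0), (0, 1), (1, 0), ...
samples = [(structure, atom) for structure in range(10) for atom in range(2)]
groups = group_unique_indices(range(10), sizes=[5, 5])
print(groups)                             # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
print(split_samples(samples, 0, groups))  # rows 0-9, then rows 10-19
```

With 10 structures of 2 atoms each and sizes [5, 5], the first group selects rows 0 -> 9 (structures 0 -> 4) and the second rows 10 -> 19 (structures 5 -> 9), matching the split described in the example. Every column other than “structure” is carried along untouched, which is why the other metadata stay equivalent across the split tensors.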
Split 2 tensors corresponding to input and output data into train and test data, with a relative 80:20 ratio. If both input and output tensors contain structure indices 0 -> 9 (inclusive), the in_train and out_train tensors will contain 8 of these structures and the in_test and out_test tensors the remaining 2. As we want to specify relative group sizes, we pass group_sizes as a list of float. Specifying the seed shuffles the structure indices before the groups are made, so the train tensors will not simply hold structures 0 -> 7 (inc.) and the test tensors 8 -> 9 (inc.); that contiguous split is what you would get without a seed.
```python
from equisolve.utils import split_data

[[in_train, in_test], [out_train, out_test]], grouped_labels = split_data(
    tensors=[input_tensor, output_tensor],
    axis="samples",
    names=["structure"],
    n_groups=2,  # for train-test split
    group_sizes=[0.8, 0.2],  # relative, an 80% 20% train-test split
    seed=100,
)
```
Split 2 tensors corresponding to input and output data into train, test, and validation data. If the input and output tensors have the same 10 structure indices, we can split them such that the train, test, and validation tensors contain 7, 2, and 1 structures, respectively. We want to specify absolute group sizes, so we pass a list of int. Specifying the seed shuffles the structure indices before they are grouped.
```python
import metatensor
from equisolve.utils import split_data

# Find the unique structure indices in the input tensor
unique_structure_indices = metatensor.unique_metadata(
    tensor=input_tensor,
    axis="samples",
    names=["structure"],
)

# They run from 0 -> 9 (inclusive):
# Labels(
#     [(0,), (1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,)],
#     dtype=[('structure', '<i4')]
# )

# Verify that the output has the same unique structure indices
assert unique_structure_indices == metatensor.unique_metadata(
    tensor=output_tensor,
    axis="samples",
    names=["structure"],
)

# Split the data by structure index, with an absolute split of 7, 2, 1
# for the train, test, and validation tensors, respectively
[[in_train, in_test, in_val], [out_train, out_test, out_val]], grouped_labels = split_data(
    tensors=[input_tensor, output_tensor],
    axis="samples",
    names=["structure"],
    n_groups=3,  # for train-test-validation
    group_sizes=[7, 2, 1],  # absolute; 7, 2, 1 for train, test, val
    seed=100,
)

# Inspect the grouped structure indices:
# [
#     Labels([(3,), (7,), (1,), (8,), (0,), (9,), (2,)], dtype=[('structure', '<i4')]),
#     Labels([(4,), (6,)], dtype=[('structure', '<i4')]),
#     Labels([(5,)], dtype=[('structure', '<i4')]),
# ]
```
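The grouped structure indices shown above can be sanity-checked in plain Python: no structure may appear in more than one group, and because 7 + 2 + 1 equals the 10 unique structures, the groups should cover all of them. The helper is_partition below is hypothetical (not part of equisolve or metatensor) and works on the index values copied out of grouped_labels as lists of int. When group_sizes sum to fewer than the number of unique indices, the coverage check would instead need to be a subset check.

```python
from typing import Sequence, Set


def is_partition(groups: Sequence[Sequence[int]], unique_indices: Sequence[int]) -> bool:
    """True if no index is shared between groups and together they cover unique_indices."""
    seen: Set[int] = set()
    for group in groups:
        for index in group:
            if index in seen:  # index shared between (or repeated within) groups
                return False
            seen.add(index)
    return seen == set(unique_indices)


# Structure indices copied from the grouped_labels output above
train, test, val = [3, 7, 1, 8, 0, 9, 2], [4, 6], [5]
print(is_partition([train, test, val], range(10)))  # True
```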