Source code for torchpme.metatensor.calculator

import torch

try:
    from metatensor.torch import Labels, TensorBlock, TensorMap
    from metatensor.torch.atomistic import System
except ImportError:
    raise ImportError(
        "metatensor.torch is required for torchpme.metatensor but is not installed. "
        "Try installing it with:\npip install metatensor[torch]"
    ) from None

from .. import calculators as torch_calculators


[docs] class Calculator(torch.nn.Module): """ Base calculator for the metatensor interface. This is just a thin wrapper around the corresponding generic torch :class:`torchpme.calculators.Calculator`. If you want to wrap a ``metatensor`` interface around another calculator, you can just define the class and set the static member ``_base_calculator`` to the corresponding torch calculator. """ _base_calculator: type[torch_calculators.Calculator] = torch_calculators.Calculator def __init__(self, *args, **kwargs): super().__init__() self._calculator = self._base_calculator(*args, **kwargs) @staticmethod def _validate_compute_parameters(system: System, neighbors: TensorBlock) -> None: dtype = system.positions.dtype device = system.positions.device if neighbors.values.dtype != dtype: raise ValueError( f"dtype of `neighbors` ({neighbors.values.dtype}) must be the same " f"as `system` ({dtype})" ) if neighbors.values.device != device: raise ValueError( f"device of `neighbors` ({neighbors.values.device}) must be the same " f"as `system` ({device})" ) # Check metadata of neighbors samples_names = neighbors.samples.names if ( len(samples_names) != 5 or samples_names[0] != "first_atom" or samples_names[1] != "second_atom" or samples_names[2] != "cell_shift_a" or samples_names[3] != "cell_shift_b" or samples_names[4] != "cell_shift_c" ): raise ValueError( "Invalid samples for `neighbors`: the sample names must be " "'first_atom', 'second_atom', 'cell_shift_a', 'cell_shift_b', " "'cell_shift_c'" ) components_labels = Labels( ["xyz"], torch.arange(3, dtype=torch.int32, device=device).unsqueeze(1), ) components = neighbors.components if len(components) != 1 or components[0] != components_labels: raise ValueError( "Invalid components for `neighbors`: there should be a single " "'xyz'=[0, 1, 2] component" ) properties_labels = Labels( ["distance"], torch.zeros(1, 1, dtype=torch.int32, device=device) ) if neighbors.properties != properties_labels: raise ValueError( "Invalid properties for `neighbors`: there should be a single " "'distance'=0 property" ) if "charges" not in system.known_data(): raise ValueError("`system` does not contain `charges` data") charge_tensor = system.get_data("charges") if len(charge_tensor) != 1: raise ValueError( f"Charge tensor have exactlty one block but has {len(charge_tensor)} " "blocks" ) n_charge_components = len(charge_tensor.block().components) if n_charge_components > 0: raise ValueError( "TensorBlock containg the charges should not have components; " f"found {n_charge_components}" )
[docs] def forward(self, system: System, neighbors: TensorBlock) -> TensorMap: """ Compute the potential "energy". The ``system`` must contain a custom data field ``charges``. The potential will be calculated for each ``"charges_channel"``, which will also be the properties name of the returned :class:`metatensor.torch.TensorMap`. :param system: System to run the calculations. The system must have attached ``"charges"`` using the :meth:`add_data <metatensor.torch.atomistic.System.add_data>` method. :param neighbors: The neighbor list. If a neighbor list is attached to a :class:`metatensor.torch.atomistic.System` it can be extracted with the :meth:`get_neighborlist <metatensor.torch.atomistic.System.get_neighborlist>` method using a :class:`NeighborListOptions <metatensor.torch.atomistic.NeighborListOptions>`. Note to use the same ``full_list`` option for these options as provided for ``full_neighbor_list`` in the constructor. .. note:: Although ``neighbors`` can be attached to the ``system``, they are required to be passed explicitly here. While it's possible to design the class to automatically extract the neighbor list by accepting a :class:`NeighborListOptions <metatensor.torch.atomistic.NeighborListOptions>` directly in the constructor, we chose explicit passing for consistency with the torch interface. :return: :class:`metatensor.torch.TensorMap` containing the potential """ self._validate_compute_parameters(system, neighbors) device = system.positions.device charges = system.get_data("charges").block().values n_atoms = len(system) samples = torch.zeros((n_atoms, 2), device=device, dtype=torch.int32) samples[:, 0] = 0 samples[:, 1] = torch.arange(n_atoms, device=device, dtype=torch.int32) neighbor_indices = neighbors.samples.view(["first_atom", "second_atom"]).values if device.type == "cpu": # move to 64-bit integers, for some reason indexing 64-bit is a lot faster # than using 32-bit integers on CPU. CUDA seems fine with either types neighbor_indices = neighbor_indices.to( torch.int64, memory_format=torch.contiguous_format ) neighbor_distances = torch.linalg.norm(neighbors.values, dim=1).squeeze(1) potential = self._calculator.forward( charges=charges, cell=system.cell, positions=system.positions, neighbor_indices=neighbor_indices, neighbor_distances=neighbor_distances, ) properties_values = torch.arange( charges.shape[1], device=device, dtype=torch.int32 ) block = TensorBlock( values=potential, samples=Labels(["system", "atom"], samples), components=[], properties=Labels("charges_channel", properties_values.unsqueeze(1)), ) keys = Labels("_", torch.zeros(1, 1, dtype=torch.int32, device=device)) return TensorMap(keys=keys, blocks=[block])