Source code for torchpme.lib.mesh_interpolator

from typing import Optional

import torch


[docs] class MeshInterpolator(torch.nn.Module): """ Class for handling all steps related to interpolations in the context of a mesh based Ewald summation. In particular, this includes two core functionalities: 1. "forwards" interpolation, in which the "charges" or more general "particle weights" of atoms are assigned to grid points of a mesh. This is done in the :func:`points_to_mesh` function. 2. "backwards" interpolation, in which values defined on a mesh are interpolated to arbitrary positions typically lying between mesh points. This is done in the :func:`mesh_to_points` function. Since the computation of the interpolation weights for both of the above types of calculations is identical, this is performed in a separate function called :func:`compute_weights`. See also the :ref:`example-mesh-demo` for a demonstration of the functionalities of this class. :param cell: torch.tensor of shape ``(3, 3)``, where ``cell[i]`` is the i-th basis vector of the unit cell :param ns_mesh: toch.tensor of shape ``(3,)`` Number of mesh points to use along each of the three axes :param interpolation_nodes: int The number ``n`` of nodes used in the interpolation per coordinate axis. The total number of interpolation nodes in 3D will be ``n^3``. In general, for ``n`` nodes, the interpolation will be performed by piecewise polynomials of degree ``n - 1`` (e.g. ``n = 4`` for cubic interpolation). For Lagrange interpolation, only the values ``3, 4, 5, 6, 7`` are supported. For P3M interpolation, only the values ``1, 2, 3, 4, 5`` are supported. :param method: str The interpolation method to use. Either "Lagrange" or "P3M". """ def __init__( self, cell: torch.Tensor, ns_mesh: torch.Tensor, interpolation_nodes: int, method: str, ): super().__init__() if method not in ["Lagrange", "P3M"]: raise ValueError( f"method '{method}' is not supported. Choose from 'Lagrange' or 'P3M'" ) self.method: str = method self.interpolation_nodes: int = interpolation_nodes self.update(cell, ns_mesh) # TorchScript requires to initialize all attributes in __init__ self.interpolation_weights: torch.Tensor = torch.zeros( 1, device=self._device, dtype=self._dtype ) self.x_shifts: torch.Tensor = torch.zeros(1, device=self._device) self.y_shifts: torch.Tensor = torch.zeros(1, device=self._device) self.z_shifts: torch.Tensor = torch.zeros(1, device=self._device) self.x_indices: torch.Tensor = torch.zeros(1, device=self._device) self.y_indices: torch.Tensor = torch.zeros(1, device=self._device) self.z_indices: torch.Tensor = torch.zeros(1, device=self._device)
[docs] def update( self, cell: Optional[torch.Tensor] = None, ns_mesh: Optional[torch.Tensor] = None, ) -> None: """ Update buffers and derived attributes of the instance. Call this to reuse a ``MeshInterpolator`` object when the ``cell`` parameters or the mesh resolution changes. If neither ``cell`` nor ``ns_mesh`` are passed there is nothing to be done. :param cell: torch.tensor of shape ``(3, 3)``, where ``cell[i]`` is the i-th basis vector of the unit cell :param ns_mesh: toch.tensor of shape ``(3,)`` Number of mesh points to use along each of the three axes """ if cell is not None: if cell.shape != (3, 3): raise ValueError( f"cell of shape {list(cell.shape)} should be of shape (3, 3)" ) self.cell = cell self.inverse_cell = cell.clone() self._dtype = cell.dtype self._device = cell.device if self.cell.is_cuda: # use function that does not synchronize with the CPU self.inverse_cell = torch.linalg.inv_ex(cell)[0] else: self.inverse_cell = torch.linalg.inv(cell) if ns_mesh is not None: if ns_mesh.shape != (3,): raise ValueError( f"shape {list(ns_mesh.shape)} of `ns_mesh` has to be (3,)" ) self.ns_mesh = ns_mesh if self.cell.device != self.ns_mesh.device: raise ValueError( "`cell` and `ns_mesh` are on different devices, got " f"{self.cell.device} and {self.ns_mesh.device}" )
[docs] def get_mesh_xyz(self) -> torch.Tensor: """ Returns the Cartesian positions of the mesh points. :return: torch.tensor of shape ``(nx, ny, nz, 3)`` containing the positions of the grid points """ nx = self.ns_mesh[0] ny = self.ns_mesh[1] nz = self.ns_mesh[2] grid_scaled = torch.stack( torch.meshgrid( torch.arange(nx, dtype=self._dtype, device=self._device) / nx, torch.arange(ny, dtype=self._dtype, device=self._device) / ny, torch.arange(nz, dtype=self._dtype, device=self._device) / nz, indexing="ij", ), dim=-1, ) return torch.matmul(grid_scaled, self.cell)
def _compute_1d_weights(self, x: torch.Tensor) -> torch.Tensor: if self.method == "Lagrange": return self._compute_1d_weights_Lagrange(x) if self.method == "P3M": return self._compute_1d_weights_P3M(x) raise ValueError("Only `method` `Lagrange` and `P3M` are allowed") def _compute_1d_weights_P3M(self, x: torch.Tensor) -> torch.Tensor: """ Generate the smooth interpolation weights used to smear the particles onto a mesh. The details of the method are described in `J. Chem. Phys. 109, 7678–7693 (1998) <https://doi.org/10.1063/1.477414>`_ :param x: torch.tensor of shape ``(n,)`` Set of relative positions in the interval [-1/2, 1/2]. :return: torch.tensor of shape ``(interpolation_nodes, n)`` Interpolation weights """ # Compute weights based on the given interpolation_nodes if self.interpolation_nodes == 1: return torch.ones( (1, x.shape[0], x.shape[1]), dtype=self._dtype, device=self._device ) if self.interpolation_nodes == 2: return torch.stack([0.5 * (1 - 2 * x), 0.5 * (1 + 2 * x)]) x2 = x * x if self.interpolation_nodes == 3: return torch.stack( [ 1 / 8 * (1 - 4 * x + 4 * x2), 1 / 4 * (3 - 4 * x2), 1 / 8 * (1 + 4 * x + 4 * x2), ] ) x3 = x * x2 if self.interpolation_nodes == 4: return torch.stack( [ 1 / 48 * (1 - 6 * x + 12 * x2 - 8 * x3), 1 / 48 * (23 - 30 * x - 12 * x2 + 24 * x3), 1 / 48 * (23 + 30 * x - 12 * x2 - 24 * x3), 1 / 48 * (1 + 6 * x + 12 * x2 + 8 * x3), ] ) x4 = x * x3 if self.interpolation_nodes == 5: return torch.stack( [ 1 / 384 * (1 - 8 * x + 24 * x2 - 32 * x3 + 16 * x4), 1 / 96 * (19 - 44 * x + 24 * x2 + 16 * x3 - 16 * x4), 1 / 192 * (115 - 120 * x2 + 48 * x4), 1 / 96 * (19 + 44 * x + 24 * x2 - 16 * x3 - 16 * x4), 1 / 384 * (1 + 8 * x + 24 * x2 + 32 * x3 + 16 * x4), ] ) raise ValueError("Only `interpolation_nodes` from 1 to 5 are allowed") def _compute_1d_weights_Lagrange(self, x: torch.Tensor) -> torch.Tensor: """ Generate the smooth interpolation weights used to smear the particles onto a mesh. The details of the method are described in `J. Chem. Phys. 103, 3668-3679 (1995) <https://doi.org/10.1063/1.470043>`_ :param x: torch.tensor of shape ``(n,)`` Set of relative positions in the interval [-1/2, 1/2]. :return: torch.tensor of shape ``(interpolation_nodes, n)`` Interpolation weights """ # Compute weights based on the given interpolation_nodes x2 = x * x if self.interpolation_nodes == 3: return torch.stack( [ 1 / 2 * (-x + x2), 1 / 2 * (2 - 2 * x2), 1 / 2 * (x + x2), ] ) x3 = x * x2 if self.interpolation_nodes == 4: return torch.stack( [ 1 / 48 * (-3 + 2 * x + 12 * x2 - 8 * x3), 1 / 48 * (27 - 54 * x - 12 * x2 + 24 * x3), 1 / 48 * (27 + 54 * x - 12 * x2 - 24 * x3), 1 / 48 * (-3 - 2 * x + 12 * x2 + 8 * x3), ] ) x4 = x * x3 if self.interpolation_nodes == 5: return torch.stack( [ 1 / 24 * (2 * x - x2 - 2 * x3 + x4), 1 / 24 * (-16 * x + 16 * x2 + 4 * x3 - 4 * x4), 1 / 24 * (24 - 30 * x2 + 6 * x4), 1 / 24 * (16 * x + 16 * x2 - 4 * x3 - 4 * x4), 1 / 24 * (-2 * x - x2 + 2 * x3 + x4), ] ) x5 = x * x4 if self.interpolation_nodes == 6: return torch.stack( [ 1 / 3840 * (45 - 18 * x - 200 * x2 + 80 * x3 + 80 * x4 - 32 * x5), 1 / 3840 * (-375 + 250 * x + 1560 * x2 - 1040 * x3 - 240 * x4 + 160 * x5), 1 / 3840 * (2250 - 4500 * x - 1360 * x2 + 2720 * x3 + 160 * x4 - 320 * x5), 1 / 3840 * (2250 + 4500 * x - 1360 * x2 - 2720 * x3 + 160 * x4 + 320 * x5), 1 / 3840 * (-375 - 250 * x + 1560 * x2 + 1040 * x3 - 240 * x4 - 160 * x5), 1 / 3840 * (45 + 18 * x - 200 * x2 - 80 * x3 + 80 * x4 + 32 * x5), ] ) x6 = x * x5 if self.interpolation_nodes == 7: return torch.stack( [ 1 / 720 * (-12 * x + 4 * x2 + 15 * x3 - 5 * x4 - 3 * x5 + x6), 1 / 720 * (108 * x - 54 * x2 - 120 * x3 + 60 * x4 + 12 * x5 - 6 * x6), 1 / 720 * (-540 * x + 540 * x2 + 195 * x3 - 195 * x4 - 15 * x5 + 15 * x6), 1 / 720 * (720 - 980 * x2 + 280 * x4 - 20 * x6), 1 / 720 * (540 * x + 540 * x2 - 195 * x3 - 195 * x4 + 15 * x5 + 15 * x6), 1 / 720 * (-108 * x - 54 * x2 + 120 * x3 + 60 * x4 - 12 * x5 - 6 * x6), 1 / 720 * (12 * x + 4 * x2 - 15 * x3 - 5 * x4 + 3 * x5 + x6), ] ) raise ValueError("Only `interpolation_nodes` from 3 to 7 are allowed")
[docs] def compute_weights(self, positions: torch.Tensor): """ Compute the interpolation weights of each atom for a given cell (specified during initialization of this class). The weights are not returned, but are used when calling the forward (:func:`points_to_mesh`) and backward (:func:`mesh_to_points`) interpolation functions. :param positions: torch.tensor of shape ``(N, 3)`` containing the Cartesian coordinates of the ``N`` particles within the supercell. """ if positions.device != self._device: raise ValueError( f"`positions` device {positions.device} is not the same as instance " f"device {self._device}" ) n_positions = len(positions) if positions.shape != (n_positions, 3): raise ValueError( f"shape {list(positions.shape)} of `positions` has to be (N, 3)" ) # Compute positions relative to the mesh basis vectors positions_rel = self.ns_mesh * torch.matmul(positions, self.inverse_cell) # Calculate positions and distances based on interpolation nodes even = self.interpolation_nodes % 2 == 0 if even: # For Lagrange interpolation, when the order is odd, the relative position # of a charge is the midpoint of the two nearest gridpoints. For P3M, the # same is true for even orders. positions_rel_idx = torch.floor(positions_rel).long() offsets = positions_rel - (positions_rel_idx + 1 / 2) else: # For Lagrange interpolation, when the order is even, the relative position # of a charge is the nearest gridpoint. For P3M, the same is true for # odd orders. positions_rel_idx = torch.round(positions_rel).long() offsets = positions_rel - positions_rel_idx # Compute weights based on distances and number of nodes self.interpolation_weights = self._compute_1d_weights(offsets) # Calculate indices of mesh points on which the particle weights are # interpolated. For each particle, its weight is "smeared" onto # `interpolation_nodes**3` mesh points, which can be achived using meshgrid # below. indices_to_interpolate = torch.stack( [ (positions_rel_idx + i) % self.ns_mesh for i in range( 1 - (self.interpolation_nodes + 1) // 2, 1 + self.interpolation_nodes // 2, ) ], dim=0, ) # Generate shifts for x, y, z axes and flatten for indexing x_shifts, y_shifts, z_shifts = torch.meshgrid( torch.arange(self.interpolation_nodes, device=self._device), torch.arange(self.interpolation_nodes, device=self._device), torch.arange(self.interpolation_nodes, device=self._device), indexing="ij", ) self.x_shifts = x_shifts.flatten() self.y_shifts = y_shifts.flatten() self.z_shifts = z_shifts.flatten() # Generate a flattened representation of all the indices # of the mesh points on which we wish to interpolate the # density. self.x_indices = indices_to_interpolate[self.x_shifts, :, 0] self.y_indices = indices_to_interpolate[self.y_shifts, :, 1] self.z_indices = indices_to_interpolate[self.z_shifts, :, 2]
[docs] def points_to_mesh(self, particle_weights: torch.Tensor) -> torch.Tensor: """ Generate a discretized density from interpolation weights. It assumes that :func:`compute_weights` has been called before to compute all the necessary weights and indices. :param particle_weights: torch.tensor of shape ``(n_points, n_channels)`` ``particle_weights[i,a]`` is the weight (charge) that point (atom) i has to generate the "a-th" potential. In practice, this can be used to compute e.g. the Na and Cl contributions to the potential separately by using a one-hot encoding of the types. :return: torch.tensor of shape ``(n_channels, n_mesh, n_mesh, n_mesh)`` Discrete density """ if particle_weights.device != self._device: raise ValueError( f"`particle_weights` device {particle_weights.device} is not the same " f"as instance device {self._device}" ) if particle_weights.dim() != 2: raise ValueError( f"`particle_weights` of dimension {particle_weights.dim()} has to be " "of dimension 2" ) # Update mesh values by combining particle weights and interpolation weights n_channels = particle_weights.shape[1] nx = int(self.ns_mesh[0]) ny = int(self.ns_mesh[1]) nz = int(self.ns_mesh[2]) rho_mesh = torch.zeros( (n_channels, nx, ny, nz), dtype=self._dtype, device=self._device ) for a in range(n_channels): rho_mesh[a].index_put_( (self.x_indices, self.y_indices, self.z_indices), ( particle_weights[:, a] * self.interpolation_weights[self.x_shifts, :, 0] * self.interpolation_weights[self.y_shifts, :, 1] * self.interpolation_weights[self.z_shifts, :, 2] ), accumulate=True, ) return rho_mesh
[docs] def mesh_to_points(self, mesh_vals: torch.Tensor) -> torch.Tensor: """ Take a function defined on a mesh and interpolate its values on arbitrary positions. :param mesh_vals: torch.tensor of shape ``(n_channels, nx, ny, nz)`` The tensor contains the values of a function evaluated on a three-dimensional mesh. ``(nx, ny, nz)`` are the number of points along each of the three directions, while ``n_channels`` provides the number of such functions that are treated simulateously for the present system. :return: interpolated_values: torch.tensor of shape ``(n_points, n_channels)`` Values of the interpolated function. """ if mesh_vals.dim() != 4: raise ValueError( f"`mesh_vals` of dimension {mesh_vals.dim()} has to be of dimension 4" ) return ( ( mesh_vals[:, self.x_indices, self.y_indices, self.z_indices] * self.interpolation_weights[self.x_shifts, :, 0] * self.interpolation_weights[self.y_shifts, :, 1] * self.interpolation_weights[self.z_shifts, :, 2] ) .sum(dim=1) .T )