Mesh Interpolator
- class torchpme.lib.MeshInterpolator(cell: Tensor, ns_mesh: Tensor, interpolation_nodes: int, method: str)
Class for handling all steps related to interpolation in the context of a mesh-based Ewald summation.
In particular, this includes two core functionalities:
1. “Forward” interpolation, in which the “charges” (or, more generally, the “particle weights”) of atoms are assigned to the points of a mesh. This is done in the points_to_mesh() function.
2. “Backward” interpolation, in which values defined on a mesh are interpolated to arbitrary positions, typically lying between mesh points. This is done in the mesh_to_points() function.
Since the computation of the interpolation weights is identical for both of the above types of calculations, it is performed in a separate function called compute_weights().
See also the Examples of the MeshInterpolator class for a demonstration of the functionalities of this class.
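The division of labor described above can be illustrated with a deliberately simplified, hypothetical 1D analogue written in plain Python: linear (“cloud-in-cell”, two nodes) weights on a periodic mesh. The function names mirror the methods of this class, but this is not the torchpme implementation; it only shows how one set of weights drives both directions.

```python
import math

# Hypothetical 1D analogue of MeshInterpolator (not torchpme code):
# linear, 2-node ("cloud-in-cell") weights on a periodic mesh.


def compute_weights(positions, box_length, n_mesh):
    """For each particle, return the two mesh indices it touches and their weights."""
    spacing = box_length / n_mesh
    indices, weights = [], []
    for x in positions:
        s = (x % box_length) / spacing       # position in units of the mesh spacing
        i0 = math.floor(s)                   # mesh point at or to the left of x
        t = s - i0                           # fractional offset in [0, 1)
        indices.append((i0 % n_mesh, (i0 + 1) % n_mesh))  # periodic wrap-around
        weights.append((1.0 - t, t))
    return indices, weights


def points_to_mesh(charges, indices, weights, n_mesh):
    """Forward interpolation: spread particle charges onto the mesh."""
    mesh = [0.0] * n_mesh
    for q, idx, w in zip(charges, indices, weights):
        for i, wi in zip(idx, w):
            mesh[i] += q * wi
    return mesh


def mesh_to_points(mesh_vals, indices, weights):
    """Backward interpolation: gather mesh values at the particle positions."""
    return [sum(mesh_vals[i] * wi for i, wi in zip(idx, w))
            for idx, w in zip(indices, weights)]


# The weights are computed once and then reused in both directions.
positions = [0.3, 2.75, 9.9]
charges = [1.0, -1.0, 0.5]
n_mesh = 8
idx, w = compute_weights(positions, box_length=10.0, n_mesh=n_mesh)
density = points_to_mesh(charges, idx, w, n_mesh)
potential = [math.cos(2 * math.pi * k / n_mesh) for k in range(n_mesh)]
at_points = mesh_to_points(potential, idx, w)
```

Because the same weights appear in both directions, the two operations are transposes of each other: the sum over mesh points of density times potential equals the sum over particles of charge times interpolated potential, and the total charge on the mesh equals the sum of the particle charges.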
- Parameters:
cell (Tensor) – torch.tensor of shape (3, 3), where cell[i] is the i-th basis vector of the unit cell.
ns_mesh (Tensor) – torch.tensor of shape (3,). Number of mesh points to use along each of the three axes.
interpolation_nodes (int) – The number n of nodes used in the interpolation per coordinate axis. The total number of interpolation nodes in 3D will be n^3. In general, for n nodes, the interpolation will be performed by piecewise polynomials of degree n - 1 (e.g. n = 4 for cubic interpolation). For Lagrange interpolation, only the values 3, 4, 5, 6, 7 are supported. For P3M interpolation, only the values 1, 2, 3, 4, 5 are supported.
method (str) – The interpolation method to use. Either “Lagrange” or “P3M”.
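To make the relation between interpolation_nodes and the polynomial degree concrete, here is a small exact-arithmetic sketch of 1D Lagrange weights for n nodes, plus the tensor-product construction that yields n^3 weights per particle in 3D. Placing the nodes at integer offsets 0 to n - 1 is an assumption made for illustration; this is not the torchpme implementation, and it does not cover the P3M scheme.

```python
from fractions import Fraction
from itertools import product


def lagrange_weights(n, x):
    """Weights of the n-node Lagrange interpolant evaluated at x.

    Illustrative assumption: the n nodes sit at integer offsets 0, 1, ..., n - 1.
    The interpolant is a polynomial of degree n - 1.
    """
    weights = []
    for j in range(n):
        w = Fraction(1)
        for m in range(n):
            if m != j:
                w *= (x - m) / Fraction(j - m)
        weights.append(w)
    return weights


n = 4                        # n = 4 nodes -> cubic (degree 3) interpolation
x = Fraction(13, 10)         # evaluation point between the 2nd and 3rd node
w = lagrange_weights(n, x)

# The weights sum to 1 and reproduce every polynomial of degree <= n - 1 exactly.
assert sum(w) == 1
for degree in range(n):
    assert sum(wj * j**degree for j, wj in enumerate(w)) == x**degree

# In 3D, node (i, j, k) gets the product of the three 1D weights,
# so each particle is spread over n**3 = 64 nodes.
wx, wy, wz = (lagrange_weights(n, Fraction(c, 10)) for c in (13, 15, 11))
w3d = [a * b * c for a, b, c in product(wx, wy, wz)]
print(len(w3d), sum(w3d))    # 64 1
```

Using Fraction keeps the arithmetic exact, so the reproduction checks hold with == rather than a tolerance.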
- update(cell: Tensor | None = None, ns_mesh: Tensor | None = None) → None
Update buffers and derived attributes of the instance.
Call this to reuse a MeshInterpolator object when the cell parameters or the mesh resolution change. If neither cell nor ns_mesh is passed, there is nothing to be done.
- get_mesh_xyz() → Tensor
Returns the Cartesian positions of the mesh points.
- Returns:
torch.tensor of shape (nx, ny, nz, 3) containing the positions of the grid points.
- Return type:
Tensor
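A plain-Python sketch of what such an array contains, under the assumption (not stated above) that mesh point (i, j, k) sits at fractional coordinates (i/nx, j/ny, k/nz) of the cell, i.e. at i/nx * cell[0] + j/ny * cell[1] + k/nz * cell[2]. The helper name mesh_xyz is hypothetical, and torchpme may use a different origin or ordering.

```python
def mesh_xyz(cell, ns_mesh):
    """Cartesian mesh positions as nested lists of shape (nx, ny, nz, 3).

    Assumption for illustration: point (i, j, k) sits at
    i/nx * cell[0] + j/ny * cell[1] + k/nz * cell[2].
    """
    nx, ny, nz = ns_mesh
    return [[[[i / nx * cell[0][d] + j / ny * cell[1][d] + k / nz * cell[2][d]
               for d in range(3)]
              for k in range(nz)]
             for j in range(ny)]
            for i in range(nx)]


# A triclinic cell: rows are the basis vectors, as for the `cell` parameter.
cell = [[4.0, 0.0, 0.0],
        [1.0, 3.0, 0.0],
        [0.0, 0.0, 2.0]]
xyz = mesh_xyz(cell, (4, 3, 2))
print(len(xyz), len(xyz[0]), len(xyz[0][0]), len(xyz[0][0][0]))  # 4 3 2 3
print(xyz[1][1][1])  # 1/4 of cell[0] + 1/3 of cell[1] + 1/2 of cell[2]
```

Since the basis vectors need not be orthogonal, the mesh is regular in fractional coordinates but generally skewed in Cartesian space.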
- compute_weights(positions: Tensor)
Compute the interpolation weights of each atom for a given cell (specified during initialization of this class). The weights are not returned, but are used when calling the forward (points_to_mesh()) and backward (mesh_to_points()) interpolation functions.
- Parameters:
positions (Tensor) – torch.tensor of shape (N, 3) containing the Cartesian coordinates of the N particles within the supercell.
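The first ingredient of such weights is locating each particle on the mesh. A minimal sketch, assuming periodic boundary conditions and the row convention for cell given above: Cartesian positions are converted to fractional coordinates, wrapped into the cell, and split into a base mesh index and a fractional offset per axis. The helper names are hypothetical and this is not the torchpme code; the per-axis weights would then be evaluated from these offsets according to the chosen method.

```python
import math


def inverse_3x3(m):
    """Inverse of a 3x3 matrix given as nested lists (adjugate / determinant)."""
    (a, b, c), (d, e, f), (g, h, i) = m
    det = a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)
    if abs(det) < 1e-12:
        raise ValueError("cell is singular")
    adj = [[e * i - f * h, c * h - b * i, b * f - c * e],
           [f * g - d * i, a * i - c * g, c * d - a * f],
           [d * h - e * g, b * g - a * h, a * e - b * d]]
    return [[x / det for x in row] for row in adj]


def mesh_coordinates(positions, cell, ns_mesh):
    """Per particle and axis, return (base mesh index, fractional offset in [0, 1)).

    With cell[i] the i-th basis vector, r = s @ cell, hence s = r @ inv(cell).
    """
    inv = inverse_3x3(cell)
    result = []
    for r in positions:
        s = [sum(r[k] * inv[k][a] for k in range(3)) for a in range(3)]
        per_axis = []
        for a in range(3):
            u = (s[a] % 1.0) * ns_mesh[a]   # wrap into the cell, in units of mesh spacing
            base = math.floor(u)
            per_axis.append((base % ns_mesh[a], u - base))
        result.append(per_axis)
    return result


cell = [[4.0, 0.0, 0.0],
        [1.0, 3.0, 0.0],
        [0.0, 0.0, 2.0]]
# Fractional coordinates (0.6, 0.2, 0.7); the second point is shifted by cell[1].
positions = [[2.6, 0.6, 1.4], [3.6, 3.6, 1.4]]
coords = mesh_coordinates(positions, cell, ns_mesh=(4, 3, 2))
```

Both positions map to the same base indices and offsets, which is what periodicity requires of any weight computation built on top of them.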
- points_to_mesh(particle_weights: Tensor) → Tensor
Generate a discretized density from interpolation weights. It assumes that compute_weights() has been called before to compute all the necessary weights and indices.
- Parameters:
particle_weights (Tensor) – torch.tensor of shape (n_points, n_channels). particle_weights[i, a] is the weight (charge) that point (atom) i has to generate the “a-th” potential. In practice, this can be used to compute, e.g., the Na and Cl contributions to the potential separately by using a one-hot encoding of the types.
- Returns:
torch.tensor of shape (n_channels, nx, ny, nz). Discrete density.
- Return type:
Tensor
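As an illustration of the channel dimension, the sketch below builds particle_weights for an NaCl-like system by one-hot encoding the atom types and scaling each row by the atom's charge, then spreads them on a 1D periodic mesh with nearest-grid-point assignment, the simplest scheme with a single node per axis. The charge scaling, the 1D mesh, and the helper names are illustrative assumptions, not torchpme behavior.

```python
def one_hot_weights(types, charges, species):
    """particle_weights[i][a] = charge of atom i if it belongs to species a, else 0."""
    return [[q if t == s else 0.0 for s in species] for t, q in zip(types, charges)]


def spread_ngp(positions, particle_weights, box_length, n_mesh):
    """Nearest-grid-point assignment on a periodic 1D mesh.

    Returns nested lists of shape (n_channels, n_mesh).
    """
    n_channels = len(particle_weights[0])
    density = [[0.0] * n_mesh for _ in range(n_channels)]
    spacing = box_length / n_mesh
    for x, pw in zip(positions, particle_weights):
        k = round((x % box_length) / spacing) % n_mesh   # closest mesh point
        for a in range(n_channels):
            density[a][k] += pw[a]
    return density


types = ["Na", "Cl", "Cl", "Na"]
charges = [1.0, -1.0, -1.0, 1.0]
species = ["Na", "Cl"]                 # channel 0: Na, channel 1: Cl
pw = one_hot_weights(types, charges, species)
density = spread_ngp([0.4, 3.1, 5.8, 8.9], pw, box_length=10.0, n_mesh=5)
print(density)   # [[1.0, 0.0, 0.0, 0.0, 1.0], [0.0, 0.0, -1.0, -1.0, 0.0]]
```

Because the spreading is linear in particle_weights, summing the channels gives the same density as spreading the total charges in a single channel.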
- mesh_to_points(mesh_vals: Tensor) → Tensor
Take a function defined on a mesh and interpolate its values at arbitrary positions. As with points_to_mesh(), the weights computed by compute_weights() are used, so that function must have been called beforehand.
- Parameters:
mesh_vals (Tensor) – torch.tensor of shape (n_channels, nx, ny, nz). The tensor contains the values of a function evaluated on a three-dimensional mesh. (nx, ny, nz) are the numbers of points along each of the three directions, while n_channels is the number of such functions that are treated simultaneously for the present system.
- Returns:
interpolated_values: torch.tensor of shape (n_points, n_channels). Values of the interpolated function.
- Return type:
Tensor
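The shape convention (channels first on the mesh, channels last on the particles) can be seen in this sketch of backward interpolation with two nodes per axis (trilinear weights) on a periodic 3D mesh. It takes precomputed (base index, fractional offset) pairs per axis as input; the helper name, the input format, and the use of trilinear weights are assumptions for illustration only.

```python
from itertools import product


def gather_trilinear(mesh_vals, coords):
    """Backward interpolation with 2 nodes per axis on a periodic mesh.

    mesh_vals: nested lists of shape (n_channels, nx, ny, nz)
    coords:    per point, three (base_index, fractional_offset) pairs
    returns:   nested lists of shape (n_points, n_channels)
    """
    n_channels = len(mesh_vals)
    nx, ny, nz = len(mesh_vals[0]), len(mesh_vals[0][0]), len(mesh_vals[0][0][0])
    out = []
    for (i0, tx), (j0, ty), (k0, tz) in coords:
        vals = [0.0] * n_channels
        for di, dj, dk in product((0, 1), repeat=3):   # the 2**3 = 8 surrounding nodes
            w = ((tx if di else 1.0 - tx)
                 * (ty if dj else 1.0 - ty)
                 * (tz if dk else 1.0 - tz))
            i, j, k = (i0 + di) % nx, (j0 + dj) % ny, (k0 + dk) % nz
            for c in range(n_channels):
                vals[c] += w * mesh_vals[c][i][j][k]
        out.append(vals)
    return out


n = 3
mesh_vals = [
    [[[2.0] * n for _ in range(n)] for _ in range(n)],        # channel 0: constant
    [[[float(i)] * n for _ in range(n)] for i in range(n)],   # channel 1: x index
]
coords = [
    ((1, 0.0), (0, 0.3), (2, 0.9)),   # exactly on an x-plane of the mesh
    ((0, 0.5), (1, 0.5), (1, 0.5)),   # halfway between x-index 0 and 1
    ((2, 0.5), (0, 0.5), (0, 0.5)),   # halfway between x-index 2 and 0 (periodic wrap)
]
values = gather_trilinear(mesh_vals, coords)   # shape (3 points, 2 channels)
```

A constant channel is returned unchanged at every point because the eight weights sum to 1, while the last point shows the periodic wrap from the final mesh plane back to the first.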