import torch
from torch.nn.utils.rnn import pad_sequence
[docs]
def get_ns_mesh(cell: torch.Tensor, mesh_spacing: float):
"""
Computes the mesh size given a target mesh spacing and cell
getting the closest powers of 2 to help with FFT.
:param cell: torch.tensor of shape ``(3, 3)``, where ``cell[i]`` is the i-th basis
vector of the unit cell
:param mesh_spacing: float
:param differentiable: boll
:return: torch.tensor of length 3 containing the mesh size
"""
basis_norms = torch.linalg.norm(cell, dim=1)
ns_approx = basis_norms / mesh_spacing
ns_actual_approx = 2 * ns_approx + 1 # actual number of mesh points
# ns = [nx, ny, nz], closest power of 2 (helps for FT efficiency)
return torch.tensor(2).pow(torch.ceil(torch.log2(ns_actual_approx)).long())
def _generate_kvectors(
cell: torch.Tensor, ns: torch.Tensor, for_ewald: bool
) -> torch.Tensor:
# Check that all provided parameters have the correct shapes and are consistent
# with each other
if cell.shape != (3, 3):
raise ValueError(f"cell of shape {list(cell.shape)} should be of shape (3, 3)")
if ns.shape != (3,):
raise ValueError(f"ns of shape {list(ns.shape)} should be of shape (3, )")
if ns.device != cell.device:
raise ValueError(
f"`ns` and `cell` are not on the same device, got {ns.device} and "
f"{cell.device}."
)
if cell.is_cuda:
# use function that does not synchronize with the CPU
inverse_cell = torch.linalg.inv_ex(cell)[0]
else:
inverse_cell = torch.linalg.inv(cell)
reciprocal_cell = 2 * torch.pi * inverse_cell.T
bx = reciprocal_cell[0]
by = reciprocal_cell[1]
bz = reciprocal_cell[2]
# Generate all reciprocal space vectors from real FFT!
# The frequencies from the fftfreq function are of the form [0, 1/n, 2/n, ...]
# These are then converted to [0, 1, 2, ...] by multiplying with n.
# get the frequencies, multiply with n, then w/ the reciprocal space vectors
kxs = (bx * ns[0]) * torch.fft.fftfreq(
ns[0], device=cell.device, dtype=cell.dtype
).unsqueeze(-1)
kys = (by * ns[1]) * torch.fft.fftfreq(
ns[1], device=cell.device, dtype=cell.dtype
).unsqueeze(-1)
if for_ewald:
kzs = (bz * ns[2]) * torch.fft.fftfreq(
ns[2], device=cell.device, dtype=cell.dtype
).unsqueeze(-1)
else:
kzs = (bz * ns[2]) * torch.fft.rfftfreq(
ns[2], device=cell.device, dtype=cell.dtype
).unsqueeze(-1)
# then take the cartesian product (all possible combinations, same as meshgrid)
# via broadcasting (to avoid instantiating intermediates), and sum up
return kxs[:, None, None] + kys[None, :, None] + kzs[None, None, :]
[docs]
def generate_kvectors_for_mesh(cell: torch.Tensor, ns: torch.Tensor) -> torch.Tensor:
"""
Compute all reciprocal space vectors for Fourier space sums.
This variant is used in combination with **mesh based calculators** using the fast
fourier transform (FFT) algorithm.
:param cell: torch.tensor of shape ``(3, 3)``, where ``cell[i]`` is the i-th basis
vector of the unit cell
:param ns: torch.tensor of shape ``(3,)`` and dtype int
``ns = [nx, ny, nz]`` contains the number of mesh points in the x-, y- and
z-direction, respectively. For faster performance during the Fast Fourier
Transform (FFT) it is recommended to use values of nx, ny and nz that are
powers of 2.
:return: torch.tensor of shape ``(nx, ny, nz, 3)`` containing all reciprocal
space vectors that will be used in the (FFT-based) mesh calculators.
Note that ``k_vectors[0,0,0] = [0,0,0]`` always is the zero vector.
.. seealso::
:func:`generate_kvectors_for_ewald` for a function to be used for Ewald
calculators.
"""
return _generate_kvectors(cell=cell, ns=ns, for_ewald=False)
[docs]
def generate_kvectors_for_ewald(
cell: torch.Tensor,
ns: torch.Tensor,
) -> torch.Tensor:
"""
Compute all reciprocal space vectors for Fourier space sums.
This variant is used with the **Ewald calculator**, in which the sum over the
reciprocal space vectors is performed explicitly rather than using the fast Fourier
transform (FFT) algorithm.
The main difference with :func:`generate_kvectors_for_mesh` is the shape of the
output tensor (see documentation on return) and the fact that the full set of
reciprocal space vectors is returned, rather than the FFT-optimized set that roughly
contains only half of the vectors.
:param cell: torch.tensor of shape ``(3, 3)``, where ``cell[i]`` is the i-th basis
vector of the unit cell
:param ns: torch.tensor of shape ``(3,)`` and dtype int
``ns = [nx, ny, nz]`` contains the number of mesh points in the x-, y- and
z-direction, respectively.
:return: torch.tensor of shape ``(n, 3)`` containing all reciprocal
space vectors that will be used in the Ewald calculator.
Note that ``k_vectors[0] = [0,0,0]`` always is the zero vector.
.. seealso::
:func:`generate_kvectors_for_mesh` for a function to be used with mesh based
calculators like PME.
"""
return _generate_kvectors(cell=cell, ns=ns, for_ewald=True).reshape(-1, 3)
def compute_batched_kvectors(
lr_wavelength: float,
cells: torch.Tensor,
) -> torch.Tensor:
r"""
Generate k-vectors for multiple systems in batches.
:param lr_wavelength: Spatial resolution used for the long-range (reciprocal space)
part of the Ewald sum. More concretely, all Fourier space vectors with a
wavelength >= this value will be kept. If not set to a global value, it will be
set to half the smearing parameter to ensure convergence of the
long-range part to a relative precision of 1e-5.
:param cell: torch.tensor of shape ``(B, 3, 3)``, where ``cell[i]`` is the i-th
basis vector of the unit cell for system i in the batch of size B.
"""
all_kvectors = []
k_cutoff = 2 * torch.pi / lr_wavelength
for cell in cells:
basis_norms = torch.linalg.norm(cell, dim=1)
ns_float = k_cutoff * basis_norms / 2 / torch.pi
ns = torch.ceil(ns_float).long()
kvectors = generate_kvectors_for_ewald(ns=ns, cell=cell)
all_kvectors.append(kvectors)
# We do not return masks here; instead, we rely on the fact that for the Coulomb
# potential, the k = 0 vector is ignored in the calculations and can therefore be
# safely padded with zeros.
return pad_sequence(all_kvectors, batch_first=True)