# Source code for torchpme.lib.kvectors

import torch


def get_ns_mesh(cell: torch.Tensor, mesh_spacing: float):
    """
    Computes the mesh size given a target mesh spacing and cell, rounding each
    value up to a power of 2 to help with FFT.

    For each basis vector ``cell[i]`` the number of mesh points is first estimated
    as ``2 * |cell[i]| / mesh_spacing + 1`` and then rounded **up** to the next
    power of 2 (a value that already is a power of 2 is kept unchanged).

    :param cell: torch.tensor of shape ``(3, 3)``, where ``cell[i]`` is the i-th basis
        vector of the unit cell
    :param mesh_spacing: float, target spacing between mesh points, expressed in the
        same length units as ``cell``. Must be strictly positive.
    :return: integer torch.tensor of length 3 containing the mesh size
        ``[nx, ny, nz]``
    :raises ValueError: if ``mesh_spacing`` is not strictly positive
    """
    # A zero/negative spacing would give inf or negative exponents below, which
    # either yields garbage after ``.long()`` or an obscure RuntimeError.
    if mesh_spacing <= 0:
        raise ValueError(f"`mesh_spacing` must be positive, got {mesh_spacing}")

    basis_norms = torch.linalg.norm(cell, dim=1)
    ns_approx = basis_norms / mesh_spacing
    ns_actual_approx = 2 * ns_approx + 1  # actual number of mesh points

    # ns = [nx, ny, nz], rounded up to a power of 2 (helps for FFT efficiency).
    # Using a Python-int base keeps the result on ``cell``'s device and avoids
    # allocating an extra CPU tensor on every call.
    exponents = torch.ceil(torch.log2(ns_actual_approx)).long()
    return 2**exponents
def _generate_kvectors( cell: torch.Tensor, ns: torch.Tensor, for_ewald: bool ) -> torch.Tensor: # Check that all provided parameters have the correct shapes and are consistent # with each other if cell.shape != (3, 3): raise ValueError(f"cell of shape {list(cell.shape)} should be of shape (3, 3)") if ns.shape != (3,): raise ValueError(f"ns of shape {list(ns.shape)} should be of shape (3, )") if ns.device != cell.device: raise ValueError( f"`ns` and `cell` are not on the same device, got {ns.device} and " f"{cell.device}." ) if cell.is_cuda: # use function that does not synchronize with the CPU inverse_cell = torch.linalg.inv_ex(cell)[0] else: inverse_cell = torch.linalg.inv(cell) reciprocal_cell = 2 * torch.pi * inverse_cell.T bx = reciprocal_cell[0] by = reciprocal_cell[1] bz = reciprocal_cell[2] # Generate all reciprocal space vectors from real FFT! # The frequencies from the fftfreq function are of the form [0, 1/n, 2/n, ...] # These are then converted to [0, 1, 2, ...] by multiplying with n. # get the frequencies, multiply with n, then w/ the reciprocal space vectors kxs = (bx * ns[0]) * torch.fft.fftfreq( ns[0], device=cell.device, dtype=cell.dtype ).unsqueeze(-1) kys = (by * ns[1]) * torch.fft.fftfreq( ns[1], device=cell.device, dtype=cell.dtype ).unsqueeze(-1) if for_ewald: kzs = (bz * ns[2]) * torch.fft.fftfreq( ns[2], device=cell.device, dtype=cell.dtype ).unsqueeze(-1) else: kzs = (bz * ns[2]) * torch.fft.rfftfreq( ns[2], device=cell.device, dtype=cell.dtype ).unsqueeze(-1) # then take the cartesian product (all possible combinations, same as meshgrid) # via broadcasting (to avoid instantiating intermediates), and sum up return kxs[:, None, None] + kys[None, :, None] + kzs[None, None, :]
def generate_kvectors_for_mesh(cell: torch.Tensor, ns: torch.Tensor) -> torch.Tensor:
    """
    Compute all reciprocal space vectors for Fourier space sums.

    This variant is used in combination with **mesh based calculators** using the
    fast fourier transform (FFT) algorithm. The x and y axes contain the full set
    of FFT frequencies. Along the last (z) axis only the non-negative frequencies
    of a real FFT are generated (as returned by :func:`torch.fft.rfftfreq`), so
    that axis has ``nz // 2 + 1`` entries rather than ``nz``.

    :param cell: torch.tensor of shape ``(3, 3)``, where ``cell[i]`` is the i-th basis
        vector of the unit cell
    :param ns: torch.tensor of shape ``(3,)`` and dtype int
        ``ns = [nx, ny, nz]`` contains the number of mesh points in the x-, y- and
        z-direction, respectively. It must be on the same device as ``cell``. For
        faster performance during the Fast Fourier Transform (FFT) it is
        recommended to use values of nx, ny and nz that are powers of 2.
    :return: torch.tensor of shape ``(nx, ny, nz // 2 + 1, 3)`` containing all
        reciprocal space vectors that will be used in the (FFT-based) mesh
        calculators. Note that ``k_vectors[0,0,0] = [0,0,0]`` always is the zero
        vector.
    :raises ValueError: if ``cell`` or ``ns`` have the wrong shape, or if they are
        not on the same device.

    .. seealso::

        :func:`generate_kvectors_for_ewald` for a function to be used for Ewald
        calculators.
    """
    return _generate_kvectors(cell=cell, ns=ns, for_ewald=False)
def generate_kvectors_for_ewald(
    cell: torch.Tensor,
    ns: torch.Tensor,
) -> torch.Tensor:
    """
    Compute all reciprocal space vectors for Fourier space sums.

    Intended for the **Ewald calculator**, where the sum over reciprocal space
    vectors is carried out explicitly instead of via the fast Fourier transform
    (FFT). Unlike :func:`generate_kvectors_for_mesh`, the complete set of
    frequencies is produced along every axis (not the roughly-half real-FFT set),
    and the result is flattened into a list of vectors.

    :param cell: torch.tensor of shape ``(3, 3)``, where ``cell[i]`` is the i-th basis
        vector of the unit cell
    :param ns: torch.tensor of shape ``(3,)`` and dtype int
        ``ns = [nx, ny, nz]`` contains the number of mesh points in the x-, y- and
        z-direction, respectively.
    :return: torch.tensor of shape ``(n, 3)`` with ``n = nx * ny * nz``, holding
        every reciprocal space vector used by the Ewald calculator. The first row,
        ``k_vectors[0]``, is always the zero vector ``[0, 0, 0]``.

    .. seealso::

        :func:`generate_kvectors_for_mesh` for a function to be used with mesh based
        calculators like PME.
    """
    # Full (nx, ny, nz, 3) grid of vectors, using fftfreq along all three axes.
    kvector_grid = _generate_kvectors(cell=cell, ns=ns, for_ewald=True)
    # Collapse the three grid axes into a single list of 3D vectors.
    return kvector_grid.reshape(-1, 3)