Source code for torchpme.potentials.coulomb

from typing import Optional

import torch

from .potential import Potential


class CoulombPotential(Potential):
    """
    Smoothed electrostatic Coulomb potential :math:`1/r`.

    Here :math:`r` is the inter-particle distance. It can be used to compute:

    1. the full :math:`1/r` potential
    2. its short-range (SR) and long-range (LR) parts, the split being determined by
       a length-scale parameter (``smearing``)
    3. the Fourier transform of the LR part

    :param smearing: float or torch.Tensor containing the parameter often called
        "sigma" in publications, which determines the length-scale at which the
        short-range and long-range parts of the naive :math:`1/r` potential are
        separated. The smearing parameter corresponds to the "width" of a Gaussian
        smearing of the particle density. It must be set to use
        :meth:`lr_from_dist`, :meth:`lr_from_k_sq`, :meth:`self_contribution` and
        :meth:`background_correction`; these raise ``ValueError`` otherwise.
    :param exclusion_radius: A length scale that defines a *local environment*
        within which the potential should be smoothly zeroed out, as it will be
        described by a separate model.
    :param exclusion_degree: Controls the sharpness of the transition in the cutoff
        function applied within the ``exclusion_radius``. The cutoff is computed as
        a raised cosine with exponent ``exclusion_degree``.
    """

    def __init__(
        self,
        smearing: Optional[float] = None,
        exclusion_radius: Optional[float] = None,
        exclusion_degree: int = 1,
    ):
        # All parameters are stored/handled by the ``Potential`` base class.
        super().__init__(smearing, exclusion_radius, exclusion_degree)

    def from_dist(self, dist: torch.Tensor) -> torch.Tensor:
        """
        Full :math:`1/r` potential as a function of :math:`r`.

        :param dist: torch.tensor containing the distances at which the potential is
            to be evaluated.
        :return: element-wise :math:`1/r`; entries with ``dist == 0`` evaluate to
            ``inf``.
        """
        return 1.0 / dist

    def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
        r"""
        Long-range part of the range-separated :math:`1/r` potential,
        :math:`\mathrm{erf}(r / (\sqrt{2}\,\sigma)) / r` with
        :math:`\sigma` = ``smearing``.

        Used to subtract out the interior contributions after computing the LR part
        in reciprocal (Fourier) space.

        Note that ``dist == 0`` yields NaN (0/0). The finite :math:`r \to 0` limit,
        :math:`\sqrt{2/\pi}/\sigma`, is what :meth:`self_contribution` returns.

        :param dist: torch.tensor containing the distances at which the potential is
            to be evaluated.
        :raises ValueError: if ``smearing`` was not set.
        """
        if self.smearing is None:
            raise ValueError(
                "Cannot compute long-range contribution without specifying `smearing`."
            )
        return torch.erf(dist / self.smearing / 2.0**0.5) / dist

    def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
        r"""
        Fourier transform of the LR part potential in terms of :math:`\mathbf{k^2}`.

        Evaluates :math:`4\pi\, e^{-\sigma^2 k^2 / 2} / k^2` and returns 0 where
        :math:`k^2 = 0`, so the divergent :math:`k = 0` term is dropped.

        :param k_sq: torch.tensor containing the squared lengths (2-norms) of the
            wave vectors k at which the Fourier-transformed potential is to be
            evaluated
        :raises ValueError: if ``smearing`` was not set.
        """
        if self.smearing is None:
            raise ValueError(
                "Cannot compute long-range kernel without specifying `smearing`."
            )
        # Avoid NaNs in backward by never dividing by an actual zero, even in the
        # branch that ``torch.where`` discards. See
        # https://github.com/jax-ml/jax/issues/1052
        # https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
        masked = torch.where(k_sq == 0, 1.0, k_sq)
        return torch.where(
            k_sq == 0,
            0.0,
            4 * torch.pi * torch.exp(-0.5 * self.smearing**2 * masked) / masked,
        )

    def self_contribution(self) -> torch.Tensor:
        # Self-correction for the 1/r potential: the r -> 0 limit of
        # erf(r / (sqrt(2) sigma)) / r, i.e. sqrt(2 / pi) / sigma.
        if self.smearing is None:
            raise ValueError(
                "Cannot compute self contribution without specifying `smearing`."
            )
        return (2 / torch.pi) ** 0.5 / self.smearing

    def background_correction(self) -> torch.Tensor:
        # "Charge neutrality" correction for the 1/r potential. Note that
        # 4*pi/k^2 * (exp(-sigma^2 k^2 / 2) - 1) -> -2*pi*sigma^2 as k -> 0, i.e.
        # minus twice the value returned here; how it is scaled (total charge,
        # volume) is up to the caller.
        if self.smearing is None:
            raise ValueError(
                "Cannot compute background correction without specifying `smearing`."
            )
        return torch.pi * self.smearing**2

    @staticmethod
    def pbc_correction(
        periodic: Optional[torch.Tensor],
        positions: torch.Tensor,
        cell: torch.Tensor,
        charges: torch.Tensor,
    ) -> torch.Tensor:
        # "2D periodicity" (slab) correction for the 1/r potential.
        # A missing ``periodic`` mask means periodic along all three cell vectors.
        if periodic is None:
            periodic = torch.tensor([True, True, True], device=cell.device)

        n_periodic = torch.sum(periodic).item()
        if n_periodic == 3:
            # Fully periodic: no correction term.
            return torch.zeros_like(charges)
        if n_periodic != 2:
            raise ValueError(
                "K-space summation is not implemented for 1D or non-periodic systems."
            )

        # Exactly one non-periodic direction, kept as a 1-element index tensor.
        axis = torch.where(~periodic)[0]
        basis_len = torch.linalg.norm(cell[axis])

        # NOTE(review): Cartesian column ``axis`` of ``positions`` is used as the
        # coordinate along cell vector ``axis``; this presumably assumes that cell
        # vector is aligned with the matching Cartesian axis -- confirm.
        max_distance = torch.max(positions[:, axis]) - torch.min(positions[:, axis])
        if max_distance > basis_len / 3:
            raise ValueError(
                f"Maximum distance along non-periodic axis ({max_distance}) "
                f"exceeds one third of cell size ({basis_len})."
            )

        charge_tot = torch.sum(charges, dim=0)
        z_i = positions[:, axis].view(-1, 1)
        M_axis = torch.sum(charges * z_i, dim=0)  # sum_j q_j z_j
        M_axis_sq = torch.sum(charges * z_i**2, dim=0)  # sum_j q_j z_j^2
        V = torch.abs(torch.linalg.det(cell))

        # Per-atom term: the derivative w.r.t. q_i of
        # (2 pi / V) * [M^2 - Q * sum_j q_j z_j^2 - Q^2 L^2 / 12].
        return (4.0 * torch.pi / V) * (
            z_i * M_axis
            - 0.5 * (M_axis_sq + charge_tot * z_i**2)
            - charge_tot / 12.0 * basis_len**2
        )

    # Reuse the generic documentation of the base-class hooks.
    self_contribution.__doc__ = Potential.self_contribution.__doc__
    background_correction.__doc__ = Potential.background_correction.__doc__
    pbc_correction.__doc__ = Potential.pbc_correction.__doc__