Source code for torchpme.potentials.potential_dipole

from typing import Optional

import torch

from .potential import Potential


[docs] class PotentialDipole(torch.nn.Module): r""" Pair potential energy function between point dipoles. The intercation is described as .. math:: V(\vec{r}) = \frac{(\vec{\mu}_i \cdot \vec{\mu}_j)}{r^3} - \frac{3 (\vec{\mu}_i \cdot \vec{r}) (\vec{\mu}_j \cdot \vec{r}) }{r^5} where :math:`r=|\vec{r}|`. :param smearing: float or torch.Tensor containing the parameter often called "sigma" in publications, which determines the length-scale at which the short-range and long-range parts of the naive :math:`1/r` potential are separated. The smearing parameter corresponds to the "width" of a Gaussian smearing of the particle density. :param exclusion_radius: A length scale that defines a *local environment* within which the potential should be smoothly zeroed out, as it will be described by a separate model. :param epsilon: Dielectric constant of the medium in which the dipoles are embedded. """ def __init__( self, smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, epsilon: float = 0.0, ): super().__init__() if smearing is not None: self.register_buffer( "smearing", torch.tensor(smearing, dtype=torch.float64) ) else: self.smearing = None if exclusion_radius is not None: self.register_buffer( "exclusion_radius", torch.tensor(exclusion_radius, dtype=torch.float64), ) else: self.exclusion_radius = None self.register_buffer("epsilon", torch.tensor(epsilon, dtype=torch.float64))
[docs] @torch.jit.export def f_cutoff(self, vector: torch.Tensor) -> torch.Tensor: r""" Default cutoff function defining the *local* region that should be excluded from the computation of a long-range model. Defaults to a shifted cosine :math:`(1+\cos \pi r/r_\mathrm{cut})/2`. :param vector: torch.tensor containing the vectors at which the potential is to be evaluated. """ r_mag = torch.norm(vector, dim=1, keepdim=True) if self.exclusion_radius is None: raise ValueError( "Cannot compute cutoff function when `exclusion_radius` is not set" ) return torch.where( r_mag < self.exclusion_radius, (1 + torch.cos(r_mag * (torch.pi / self.exclusion_radius))) * 0.5, 0.0, )
[docs] def from_dist(self, vector: torch.Tensor) -> torch.Tensor: r""" Full dipolar potential as a function of :math:`\mathbf{r}`. :param vector: torch.tensor containing the vectors at which the potential is to be evaluated. """ r_mag = torch.norm(vector, dim=1, keepdim=True) scalar_potential = 1.0 / (r_mag**3) r_outer = torch.bmm(vector.unsqueeze(2), vector.unsqueeze(1)) return scalar_potential.unsqueeze(-1) * torch.eye(3).to(r_outer).unsqueeze( 0 ) - 3.0 * r_outer / (r_mag**5).unsqueeze(-1)
[docs] @torch.jit.export def sr_from_dist(self, vector: torch.Tensor) -> torch.Tensor: """ Short-range part of the pair potential in real space. :param dist: torch.tensor containing the distance vectors at which the potential is to be evaluated. """ if self.smearing is None: raise ValueError( "Cannot compute range-separated potential when `smearing` is not specified." ) if self.exclusion_radius is None: return self.from_dist(vector) - self.lr_from_dist(vector) return -self.lr_from_dist(vector) * self.f_cutoff(vector).unsqueeze(-1)
[docs] @torch.jit.export def lr_from_dist(self, vector: torch.Tensor) -> torch.Tensor: r""" Long-range of the range-separated dipolar potential. Used to subtract out the interior contributions after computing the long-range part in reciprocal (Fourier) space. :param vector: torch.tensor containing the vectors at which the potential is to be evaluated. """ if self.smearing is None: raise ValueError( "Cannot compute long-range contribution without specifying `smearing`." ) alpha = 1 / (2 * self.smearing**2) r_mag = torch.norm(vector, dim=1, keepdim=True) r_outer = torch.bmm(vector.unsqueeze(2), vector.unsqueeze(1)) B1 = torch.erfc(torch.sqrt(alpha) * r_mag) / r_mag**3 B2 = 2 * torch.sqrt(alpha / torch.pi) * torch.exp(-alpha * r_mag**2) / r_mag**2 B = 1.0 / (r_mag**3) - B1 - B2 C1 = 3.0 * torch.erfc(torch.sqrt(alpha) * r_mag) / r_mag**5 C2 = ( 2 * torch.sqrt(alpha / torch.pi) * (2 * alpha + 3 / r_mag**2) * torch.exp(-alpha * r_mag**2) / r_mag**2 ) C = 3.0 / (r_mag**5) - C1 - C2 return B.unsqueeze(-1) * torch.eye(3).to(r_outer).unsqueeze( 0 ) - r_outer * C.unsqueeze(-1)
[docs] def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor: r""" Fourier transform of the long-range part of the potential. :param k_sq: torch.tensor containing the squared lengths (2-norms) of the wave vectors k at which the Fourier-transformed potential is to be evaluated """ if self.smearing is None: raise ValueError( "Cannot compute long-range kernel without specifying `smearing`." ) # avoid NaNs in backward, see # https://github.com/jax-ml/jax/issues/1052 # https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf masked = torch.where(k_sq == 0, 1.0, k_sq) return torch.where( k_sq == 0, 0.0, 4 * torch.pi * torch.exp(-0.5 * self.smearing**2 * masked) / masked, )
[docs] def self_contribution(self) -> torch.Tensor: if self.smearing is None: raise ValueError( "Cannot compute long-range contribution without specifying `smearing`." ) alpha = 1 / (2 * self.smearing**2) return 4 * torch.pi / 3 * torch.sqrt((alpha / torch.pi) ** 3)
[docs] def background_correction(self, volume) -> torch.Tensor: if self.epsilon == 0.0: return self.epsilon return 4 * torch.pi / (2 * self.epsilon + 1) / volume
self_contribution.__doc__ = Potential.self_contribution.__doc__ background_correction.__doc__ = Potential.background_correction.__doc__