Source code for torchpme.potentials.potential_dipole
from typing import Optional
import torch
from .potential import Potential
[docs]
class PotentialDipole(torch.nn.Module):
    r"""
    Pair potential energy function between point dipoles.
    The intercation is described as
    .. math::
        V(\vec{r}) = \frac{(\vec{\mu}_i \cdot \vec{\mu}_j)}{r^3} -
            \frac{3  (\vec{\mu}_i \cdot \vec{r})  (\vec{\mu}_j \cdot \vec{r}) }{r^5}
    where :math:`r=|\vec{r}|`.
    :param smearing: float or torch.Tensor containing the parameter often called "sigma"
        in publications, which determines the length-scale at which the short-range and
        long-range parts of the naive :math:`1/r` potential are separated. The smearing
        parameter corresponds to the "width" of a Gaussian smearing of the particle
        density.
    :param exclusion_radius: A length scale that defines a *local environment* within
        which the potential should be smoothly zeroed out, as it will be described by a
        separate model.
    :param epsilon: Dielectric constant of the medium in which the dipoles are embedded.
    :param exclusion_degree: Controls the sharpness of the transition in the cutoff function
        applied within the ``exclusion_radius``. The cutoff is computed as a raised cosine
        with exponent ``exclusion_degree``
    """
    def __init__(
        self,
        smearing: Optional[float] = None,
        exclusion_radius: Optional[float] = None,
        exclusion_degree: int = 1,
        epsilon: float = 0.0,
    ):
        super().__init__()
        self.exclusion_degree = exclusion_degree
        if smearing is not None:
            self.register_buffer(
                "smearing", torch.tensor(smearing, dtype=torch.float64)
            )
        else:
            self.smearing = None
        if exclusion_radius is not None:
            self.register_buffer(
                "exclusion_radius",
                torch.tensor(exclusion_radius, dtype=torch.float64),
            )
        else:
            self.exclusion_radius = None
        self.register_buffer("epsilon", torch.tensor(epsilon, dtype=torch.float64))
[docs]
    @torch.jit.export
    def f_cutoff(self, vector: torch.Tensor) -> torch.Tensor:
        r"""
        Default cutoff function defining the *local* region that should be excluded from
        the computation of a long-range model. Defaults to a shifted cosine
        :math:`1 - ((1 - \cos \pi r/r_\mathrm{cut})/2) ^ n`. where :math:`n` is the
        ``exclusion_degree`` parameter.
        :param vector: torch.tensor containing the vectors at which the potential is to
            be evaluated.
        """
        r_mag = torch.norm(vector, dim=1, keepdim=True)
        if self.exclusion_radius is None:
            raise ValueError(
                "Cannot compute cutoff function when `exclusion_radius` is not set"
            )
        return torch.where(
            r_mag < self.exclusion_radius,
            1
            - ((1 - torch.cos(torch.pi * (r_mag / self.exclusion_radius))) * 0.5)
            ** self.exclusion_degree,
            0.0,
        ) 
[docs]
    def from_dist(self, vector: torch.Tensor) -> torch.Tensor:
        r"""
        Full dipolar potential as a function of :math:`\mathbf{r}`.
        :param vector: torch.tensor containing the vectors at which the potential is to
            be evaluated.
        """
        r_mag = torch.norm(vector, dim=1, keepdim=True)
        scalar_potential = 1.0 / (r_mag**3)
        r_outer = torch.bmm(vector.unsqueeze(2), vector.unsqueeze(1))
        return scalar_potential.unsqueeze(-1) * torch.eye(3).to(r_outer).unsqueeze(
            0
        ) - 3.0 * r_outer / (r_mag**5).unsqueeze(-1) 
[docs]
    @torch.jit.export
    def sr_from_dist(self, vector: torch.Tensor) -> torch.Tensor:
        """
        Short-range part of the pair potential in real space.
        :param dist: torch.tensor containing the distance vectors at which the potential
            is to be evaluated.
        """
        if self.smearing is None:
            raise ValueError(
                "Cannot compute range-separated potential when `smearing` is not specified."
            )
        if self.exclusion_radius is None:
            return self.from_dist(vector) - self.lr_from_dist(vector)
        return -self.lr_from_dist(vector) * self.f_cutoff(vector).unsqueeze(-1) 
[docs]
    @torch.jit.export
    def lr_from_dist(self, vector: torch.Tensor) -> torch.Tensor:
        r"""
        Long-range of the range-separated dipolar potential.
        Used to subtract out the interior contributions after computing the long-range
        part in reciprocal (Fourier) space.
        :param vector: torch.tensor containing the vectors at which the potential is to
            be evaluated.
        """
        if self.smearing is None:
            raise ValueError(
                "Cannot compute long-range contribution without specifying `smearing`."
            )
        alpha = 1 / (2 * self.smearing**2)
        r_mag = torch.norm(vector, dim=1, keepdim=True)
        r_outer = torch.bmm(vector.unsqueeze(2), vector.unsqueeze(1))
        B1 = torch.erfc(torch.sqrt(alpha) * r_mag) / r_mag**3
        B2 = 2 * torch.sqrt(alpha / torch.pi) * torch.exp(-alpha * r_mag**2) / r_mag**2
        B = 1.0 / (r_mag**3) - B1 - B2
        C1 = 3.0 * torch.erfc(torch.sqrt(alpha) * r_mag) / r_mag**5
        C2 = (
            2
            * torch.sqrt(alpha / torch.pi)
            * (2 * alpha + 3 / r_mag**2)
            * torch.exp(-alpha * r_mag**2)
            / r_mag**2
        )
        C = 3.0 / (r_mag**5) - C1 - C2
        return B.unsqueeze(-1) * torch.eye(3).to(r_outer).unsqueeze(
            0
        ) - r_outer * C.unsqueeze(-1) 
[docs]
    def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
        r"""
        Fourier transform of the long-range part of the potential.
        :param k_sq: torch.tensor containing the squared lengths (2-norms) of the wave
            vectors k at which the Fourier-transformed potential is to be evaluated
        """
        if self.smearing is None:
            raise ValueError(
                "Cannot compute long-range kernel without specifying `smearing`."
            )
        # avoid NaNs in backward, see
        # https://github.com/jax-ml/jax/issues/1052
        # https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
        masked = torch.where(k_sq == 0, 1.0, k_sq)
        return torch.where(
            k_sq == 0,
            0.0,
            4 * torch.pi * torch.exp(-0.5 * self.smearing**2 * masked) / masked,
        ) 
[docs]
    def self_contribution(self) -> torch.Tensor:
        if self.smearing is None:
            raise ValueError(
                "Cannot compute long-range contribution without specifying `smearing`."
            )
        alpha = 1 / (2 * self.smearing**2)
        return 4 * torch.pi / 3 * torch.sqrt((alpha / torch.pi) ** 3) 
[docs]
    def background_correction(self, volume) -> torch.Tensor:
        if self.epsilon == 0.0:
            return self.epsilon
        return 4 * torch.pi / (2 * self.epsilon + 1) / volume 
    self_contribution.__doc__ = Potential.self_contribution.__doc__
    background_correction.__doc__ = Potential.background_correction.__doc__