class PotentialDipole(torch.nn.Module):
    r"""
    Pair potential energy function between point dipoles.

    The interaction is described as

    .. math::

        V(\vec{r}) = \frac{(\vec{\mu}_i \cdot \vec{\mu}_j)}{r^3}
        - \frac{3 (\vec{\mu}_i \cdot \vec{r}) (\vec{\mu}_j \cdot \vec{r})}{r^5}

    where :math:`r=|\vec{r}|`. The distance-dependent methods of this class do
    not contract with the dipoles themselves. Instead they return, for every
    input vector, the :math:`3 \times 3` interaction tensor

    .. math::

        T(\vec{r}) = \frac{\mathbf{I}}{r^3} - \frac{3\, \vec{r} \otimes \vec{r}}{r^5},

    so that :math:`V(\vec{r}) = \vec{\mu}_i^{\,T}\, T(\vec{r})\, \vec{\mu}_j`.

    :param smearing: float or torch.Tensor containing the parameter often called
        "sigma" in publications, which determines the length-scale at which the
        short-range and long-range parts of the naive :math:`1/r` potential are
        separated. The smearing parameter corresponds to the "width" of a
        Gaussian smearing of the particle density. It is required by
        ``sr_from_dist``, ``lr_from_dist``, ``lr_from_k_sq`` and
        ``self_contribution``.
    :param exclusion_radius: A length scale that defines a *local environment*
        within which the potential should be smoothly zeroed out, as it will be
        described by a separate model. It is required by ``f_cutoff``.
    :param epsilon: Dielectric constant of the medium in which the dipoles are
        embedded. Stored as a ``float64`` buffer.
    """

    def __init__(
        self,
        smearing: Optional[float] = None,
        exclusion_radius: Optional[float] = None,
        epsilon: float = 0.0,
    ):
        super().__init__()
        # Set parameters are registered as float64 buffers, so they move together
        # with the module. Unset optional parameters remain plain ``None``
        # attributes, and each method that needs them checks for that.
        if smearing is not None:
            self.register_buffer(
                "smearing", torch.tensor(smearing, dtype=torch.float64)
            )
        else:
            self.smearing = None

        if exclusion_radius is not None:
            self.register_buffer(
                "exclusion_radius",
                torch.tensor(exclusion_radius, dtype=torch.float64),
            )
        else:
            self.exclusion_radius = None

        self.register_buffer("epsilon", torch.tensor(epsilon, dtype=torch.float64))

    @torch.jit.export
    def f_cutoff(self, vector: torch.Tensor) -> torch.Tensor:
        r"""
        Default cutoff function defining the *local* region that should be
        excluded from the computation of a long-range model.

        Defaults to a shifted cosine :math:`(1+\cos \pi r/r_\mathrm{cut})/2` for
        :math:`r < r_\mathrm{cut}`, and zero beyond.

        :param vector: torch.tensor containing the vectors at which the
            potential is to be evaluated, one vector per row.
        :return: tensor with one cutoff value per row of ``vector`` and a
            trailing dimension of size 1.
        :raises ValueError: if ``exclusion_radius`` was not set.
        """
        # Validate before doing any work.
        if self.exclusion_radius is None:
            raise ValueError(
                "Cannot compute cutoff function when `exclusion_radius` is not set"
            )

        r_mag = torch.norm(vector, dim=1, keepdim=True)
        return torch.where(
            r_mag < self.exclusion_radius,
            (1 + torch.cos(r_mag * (torch.pi / self.exclusion_radius))) * 0.5,
            0.0,
        )

    def from_dist(self, vector: torch.Tensor) -> torch.Tensor:
        r"""
        Full dipolar potential as a function of :math:`\mathbf{r}`.

        Returns :math:`\mathbf{I}/r^3 - 3\,\mathbf{r}\mathbf{r}^T/r^5` for each
        row of ``vector``. The result has shape ``(n, 3, 3)``.

        :param vector: torch.tensor containing the vectors at which the
            potential is to be evaluated. ``torch.bmm`` is used on it, so it
            must be 2-D with one 3-vector per row.
        """
        r_mag = torch.norm(vector, dim=1, keepdim=True)  # (n, 1)
        r_outer = torch.bmm(vector.unsqueeze(2), vector.unsqueeze(1))  # (n, 3, 3)
        identity = torch.eye(3, dtype=r_outer.dtype, device=r_outer.device)

        scalar_potential = 1.0 / (r_mag**3)
        return scalar_potential.unsqueeze(-1) * identity.unsqueeze(
            0
        ) - 3.0 * r_outer / (r_mag**5).unsqueeze(-1)

    @torch.jit.export
    def sr_from_dist(self, vector: torch.Tensor) -> torch.Tensor:
        r"""
        Short-range part of the pair potential in real space.

        Without an ``exclusion_radius``, this is ``from_dist - lr_from_dist``.
        With one, it is ``-lr_from_dist`` scaled by ``f_cutoff``. In that case
        the full potential inside the local environment is left to a separate
        model.

        :param vector: torch.tensor containing the distance vectors at which
            the potential is to be evaluated.
        :raises ValueError: if ``smearing`` was not set.
        """
        if self.smearing is None:
            raise ValueError(
                "Cannot compute range-separated potential when `smearing` is not specified."
            )
        if self.exclusion_radius is None:
            return self.from_dist(vector) - self.lr_from_dist(vector)
        # f_cutoff returns (n, 1); add a dim to broadcast against (n, 3, 3).
        return -self.lr_from_dist(vector) * self.f_cutoff(vector).unsqueeze(-1)

    @torch.jit.export
    def lr_from_dist(self, vector: torch.Tensor) -> torch.Tensor:
        r"""
        Long-range part of the range-separated dipolar potential.

        Used to subtract out the interior contributions after computing the
        long-range part in reciprocal (Fourier) space. With
        :math:`\alpha = 1/(2\sigma^2)`, the method returns
        :math:`B(r)\,\mathbf{I} - C(r)\,\mathbf{r}\mathbf{r}^T`, where

        .. math::

            B(r) = \frac{\mathrm{erf}(\sqrt{\alpha} r)}{r^3}
                 - \frac{2\sqrt{\alpha/\pi}\, e^{-\alpha r^2}}{r^2},
            \qquad
            C(r) = \frac{3\,\mathrm{erf}(\sqrt{\alpha} r)}{r^5}
                 - \frac{2\sqrt{\alpha/\pi}\,(2\alpha + 3/r^2)\, e^{-\alpha r^2}}{r^2}.

        :param vector: torch.tensor containing the vectors at which the
            potential is to be evaluated (2-D, one 3-vector per row).
        :raises ValueError: if ``smearing`` was not set.
        """
        if self.smearing is None:
            raise ValueError(
                "Cannot compute long-range contribution without specifying `smearing`."
            )

        alpha = 1 / (2 * self.smearing**2)
        r_mag = torch.norm(vector, dim=1, keepdim=True)  # (n, 1)
        r_sq = r_mag**2
        r_outer = torch.bmm(vector.unsqueeze(2), vector.unsqueeze(1))  # (n, 3, 3)

        # Use erf(x) directly instead of computing 1/r^n - erfc(x)/r^n.
        # For r << smearing, erfc(x) ~ 1, and that subtraction cancels
        # catastrophically.
        erf_term = torch.erf(torch.sqrt(alpha) * r_mag)
        # Gaussian term shared by B and C; it is evaluated only once.
        gauss = 2 * torch.sqrt(alpha / torch.pi) * torch.exp(-alpha * r_sq) / r_sq

        B = erf_term / r_mag**3 - gauss
        C = 3.0 * erf_term / r_mag**5 - (2 * alpha + 3 / r_sq) * gauss

        identity = torch.eye(3, dtype=r_outer.dtype, device=r_outer.device)
        return B.unsqueeze(-1) * identity.unsqueeze(0) - r_outer * C.unsqueeze(-1)

    def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
        r"""
        Fourier transform of the long-range part of the potential.

        Returns :math:`4\pi e^{-\sigma^2 k^2/2}/k^2`, and exactly zero where
        :math:`k^2 = 0`.

        :param k_sq: torch.tensor containing the squared lengths (2-norms) of
            the wave vectors k at which the Fourier-transformed potential is to
            be evaluated.
        :raises ValueError: if ``smearing`` was not set.
        """
        if self.smearing is None:
            raise ValueError(
                "Cannot compute long-range kernel without specifying `smearing`."
            )
        # Mask k_sq == 0 before dividing, so the unused branch of torch.where
        # produces no inf/NaN that could leak into gradients. See
        # https://github.com/jax-ml/jax/issues/1052
        # https://github.com/tensorflow/probability/blob/main/discussion/where-nan.pdf
        masked = torch.where(k_sq == 0, 1.0, k_sq)
        return torch.where(
            k_sq == 0,
            0.0,
            4 * torch.pi * torch.exp(-0.5 * self.smearing**2 * masked) / masked,
        )

    def self_contribution(self) -> torch.Tensor:
        r"""
        Self term of the long-range dipolar kernel.

        Returns :math:`\frac{4\pi}{3}(\alpha/\pi)^{3/2}` with
        :math:`\alpha = 1/(2\sigma^2)`. This equals the :math:`r \to 0` limit
        of the isotropic coefficient :math:`B(r)` in ``lr_from_dist``.

        :raises ValueError: if ``smearing`` was not set.
        """
        if self.smearing is None:
            raise ValueError(
                "Cannot compute long-range contribution without specifying `smearing`."
            )
        alpha = 1 / (2 * self.smearing**2)
        return 4 * torch.pi / 3 * torch.sqrt((alpha / torch.pi) ** 3)