# Source code for torchpme.calculators.calculator_dipole

from typing import Optional

import torch
from torch import profiler

from .._utils import _validate_parameters
from ..lib import generate_kvectors_for_ewald
from ..potentials import PotentialDipole


class CalculatorDipole(torch.nn.Module):
    """
    Base calculator for interacting dipoles in the torch interface.

    :param potential: a :class:`PotentialDipole` class object containing the functions
        that are necessary to compute the various components of the potential, as
        well as the parameters that determine the behavior of the potential itself.
    :param full_neighbor_list: parameter indicating whether the neighbor information
        will come from a full (True) or half (False, default) neighbor list.
    :param prefactor: electrostatics prefactor; see :ref:`prefactors` for details and
        common values.
    :param lr_wavelength: the wavelength of the long-range part of the potential.
        Must be given if and only if ``potential.smearing`` is set.
    :raises TypeError: if ``potential`` is not a :class:`PotentialDipole`.
    :raises AssertionError: if exactly one of ``lr_wavelength`` and
        ``potential.smearing`` is ``None``.
    """

    def __init__(
        self,
        potential: PotentialDipole,
        full_neighbor_list: bool = False,
        prefactor: float = 1.0,
        lr_wavelength: Optional[float] = None,
    ):
        super().__init__()

        if not isinstance(potential, PotentialDipole):
            raise TypeError(
                f"Potential must be an instance of PotentialDipole, got {type(potential)}"
            )

        self.potential = potential
        self.lr_wavelength = lr_wavelength

        # The k-space part (used iff a smearing is set) needs a wavelength, and a
        # purely real-space calculation must not get one. This is raised explicitly
        # instead of via ``assert`` so the check survives ``python -O``; the
        # exception type stays AssertionError for backward compatibility.
        if (self.lr_wavelength is None) != (self.potential.smearing is None):
            raise AssertionError(
                "Either both `lr_wavelength` and `smearing` must be set or both must be None"
            )

        self.full_neighbor_list = full_neighbor_list
        self.prefactor = prefactor

    def _compute_rspace(
        self,
        dipoles: torch.Tensor,
        neighbor_indices: torch.Tensor,
        neighbor_vectors: torch.Tensor,
    ) -> torch.Tensor:
        """
        Real-space part of the per-atom potential.

        :param dipoles: per-atom dipoles, indexed by the atom indices stored in
            ``neighbor_indices``.
        :param neighbor_indices: pairs ``(i, j)``; column 0 holds ``i``, column 1
            holds ``j``.
        :param neighbor_vectors: pair vectors handed to the potential.
        :return: tensor shaped like ``dipoles`` holding, for each atom ``i``, half of
            ``sum_j V(r_ij) @ mu_j``.
        """
        # Pair interaction tensors V(r_ij). With a smearing only the short-range
        # part is evaluated here; the remainder comes from ``_compute_kspace``.
        with profiler.record_function("compute bare potential"):
            if self.potential.smearing is None:
                potentials_bare = self.potential.from_dist(neighbor_vectors)
            else:
                potentials_bare = self.potential.sr_from_dist(neighbor_vectors)

        atom_is = neighbor_indices[:, 0]
        atom_js = neighbor_indices[:, 1]

        # Batched matrix-vector product V(r_ij) @ mu_j for every listed pair.
        # NOTE(review): torch.bmm with a (n_pairs, 3, 1) right operand implies
        # potentials_bare is (n_pairs, 3, 3) — confirm against PotentialDipole.
        with profiler.record_function("compute real potential"):
            contributions_is = torch.bmm(
                potentials_bare, dipoles[atom_js].unsqueeze(-1)
            ).squeeze(-1)

        # Accumulate every pair contribution onto its atom i (an atom can appear
        # as i in many pairs, hence the scatter-add).
        with profiler.record_function("assign potential"):
            potential = torch.zeros_like(dipoles)
            potential.index_add_(0, atom_is, contributions_is)
            # A half neighbor list stores each pair once, so the reversed pair's
            # contribution must be added to atom j explicitly. It reuses
            # potentials_bare, i.e. it relies on V(r_ji) == V(r_ij).
            if not self.full_neighbor_list:
                contributions_js = torch.bmm(
                    potentials_bare, dipoles[atom_is].unsqueeze(-1)
                ).squeeze(-1)
                potential.index_add_(0, atom_js, contributions_js)

        # Factor 1/2 from the definition V_i = 1/2 sum_j mu_j v(r_ij) (see forward)
        return potential / 2

    def _compute_kspace(
        self,
        dipoles: torch.Tensor,
        cell: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        """
        Reciprocal-space (long-range) part of the per-atom potential.

        Only reached when ``potential.smearing`` is set, in which case ``__init__``
        guarantees that ``lr_wavelength`` is set too.

        :param dipoles: per-atom dipoles, one row per entry of ``positions``.
        :param cell: unit cell; ``cell[i]`` is the i-th basis vector.
        :param positions: Cartesian positions of the atoms.
        :return: tensor shaped like ``dipoles`` with the long-range contribution.
        """
        # k-space cutoff from the requested real-space resolution
        k_cutoff = 2 * torch.pi / self.lr_wavelength

        # Number of times each basis vector of the reciprocal space can be scaled
        # until the cutoff is reached
        basis_norms = torch.linalg.norm(cell, dim=1)
        ns_float = k_cutoff * basis_norms / 2 / torch.pi
        ns = torch.ceil(ns_float).long()

        kvectors = generate_kvectors_for_ewald(ns=ns, cell=cell)
        knorm_sq = torch.sum(kvectors**2, dim=1)

        # Fourier-space kernel of the long-range potential. Handling of a possible
        # k = 0 term is left to generate_kvectors_for_ewald / lr_from_k_sq
        # (NOTE(review): not visible here — confirm k = 0 is excluded).
        G = self.potential.lr_from_k_sq(knorm_sq)

        volume = torch.abs(cell.det())

        # cos(k.(r_i - r_j)) = cos(k.r_i) cos(k.r_j) + sin(k.r_i) sin(k.r_j), so the
        # sum over j can be done once per k-vector ("structure factor") instead of
        # forming all pair phases: cost is linear in the number of atoms per k.
        trig_args = kvectors @ (positions.T)  # [k, i]
        c = torch.cos(trig_args)  # [k, i]
        s = torch.sin(trig_args)  # [k, i]
        sc = torch.stack([c, s], dim=0)  # [2 "f", k, i]
        mu_k = dipoles @ kvectors.T  # [i, k], projection k . mu_i
        sc_summed_G = torch.einsum("fki, ik, k->fk", sc, mu_k, G)
        potential_lr = torch.einsum("fk, fki, kc->ic", sc_summed_G, sc, kvectors)

        potential_lr = potential_lr / volume
        potential_lr = potential_lr - dipoles * self.potential.self_contribution()
        potential_lr = potential_lr + self.potential.background_correction(
            volume
        ) * dipoles.sum(dim=0)

        # Factor 1/2 from the definition V_i = 1/2 sum_j mu_j v(r_ij) (see forward)
        return potential_lr / 2

    def forward(
        self,
        dipoles: torch.Tensor,
        cell: torch.Tensor,
        positions: torch.Tensor,
        neighbor_indices: torch.Tensor,
        neighbor_vectors: torch.Tensor,
    ) -> torch.Tensor:
        r"""
        Compute the potential "energy".

        It is calculated as:

        .. math::

            V_i = \frac{1}{2} \sum_{j} \boldsymbol{\mu_j} \, \mathbf{v}(\mathbf{r_{ij}})

        where :math:`\mathbf{v}(\mathbf{r})` is the pair potential defined by the
        ``potential`` parameter, and :math:`\boldsymbol{\mu_j}` are atomic "dipoles".

        If the ``smearing`` of the ``potential`` is not set, the calculator evaluates
        only the real-space part of the potential. Otherwise, it also evaluates the
        long-range part with a reciprocal-space (Fourier-domain) sum, see
        ``_compute_kspace``. The result is multiplied by ``prefactor``.

        :param dipoles: torch.tensor of shape ``(len(positions), 3)`` containing the
            atomic dipoles.
        :param cell: torch.tensor of shape ``(3, 3)``, where ``cell[i]`` is the i-th
            basis vector of the unit cell
        :param positions: torch.tensor of shape ``(N, 3)`` containing the Cartesian
            coordinates of the ``N`` particles within the supercell.
        :param neighbor_indices: torch.tensor with the ``i,j`` indices of neighbors
            for which the potential should be computed in real space.
        :param neighbor_vectors: torch.tensor with the pair vectors of the neighbors
            for which the potential should be computed in real space.
        :return: per-atom potential, shaped like ``dipoles``.
        """
        # TODO: _validate_parameters to allow also dipoles. Temporarily pass the
        # distance tensor.
        _validate_parameters(
            charges=dipoles,
            cell=cell,
            positions=positions,
            neighbor_indices=neighbor_indices,
            neighbor_distances=neighbor_vectors.norm(dim=-1),
            smearing=self.potential.smearing,
        )

        # Compute short-range (SR) part using a real space sum
        potential_sr = self._compute_rspace(
            dipoles=dipoles,
            neighbor_indices=neighbor_indices,
            neighbor_vectors=neighbor_vectors,
        )

        if self.potential.smearing is None:
            return self.prefactor * potential_sr

        # Compute long-range (LR) part using a Fourier / reciprocal space sum
        potential_lr = self._compute_kspace(
            dipoles=dipoles,
            cell=cell,
            positions=positions,
        )

        return self.prefactor * (potential_sr + potential_lr)