Source code for torchpme.utils.tuning.pme

import math
from typing import Optional

import torch

from ...lib import get_ns_mesh
from . import (
    _estimate_smearing_cutoff,
    _optimize_parameters,
    _validate_parameters,
)


def tune_pme(
    sum_squared_charges: float,
    cell: torch.Tensor,
    positions: torch.Tensor,
    smearing: Optional[float] = None,
    mesh_spacing: Optional[float] = None,
    cutoff: Optional[float] = None,
    interpolation_nodes: int = 4,
    exponent: int = 1,
    accuracy: float = 1e-3,
    max_steps: int = 50000,
    learning_rate: float = 0.1,
):
    r"""
    Find the optimal parameters for :class:`torchpme.PMECalculator`.

    The error formulas are given `elsewhere <https://doi.org/10.1063/1.470043>`_.
    Note the difference in notation between the parameters in the reference and
    ours:

    .. math::

        \alpha = \left(\sqrt{2}\,\mathrm{smearing} \right)^{-1}

    For the optimization we use the :class:`torch.optim.Adam` optimizer. By
    default, this function optimizes ``smearing``, ``mesh_spacing`` and
    ``cutoff`` based on the error formula given `elsewhere`_. You can restrict
    the optimization by passing one or more of these parameters to the
    function. For example, in typical ML workflows the cutoff is fixed, and one
    wants to optimize only ``smearing`` and ``mesh_spacing`` for minimal error
    at that fixed cutoff.

    :param sum_squared_charges: accumulated squared charges, must be positive
    :param cell: single tensor of shape (3, 3), describing the bounding
        box/unit cell of the system
    :param positions: single tensor of shape (``len(charges), 3``) containing
        the Cartesian positions of all point charges in the system.
    :param smearing: if its value is given, it will not be tuned, see
        :class:`torchpme.PMECalculator` for details
    :param mesh_spacing: if its value is given, it will not be tuned, see
        :class:`torchpme.PMECalculator` for details
    :param cutoff: if its value is given, it will not be tuned, see
        :class:`torchpme.PMECalculator` for details
    :param interpolation_nodes: The number ``n`` of nodes used in the
        interpolation per coordinate axis. The total number of interpolation
        nodes in 3D will be ``n^3``. In general, for ``n`` nodes, the
        interpolation will be performed by piecewise polynomials of degree
        ``n - 1`` (e.g. ``n = 4`` for cubic interpolation). Only the values
        ``3, 4, 5, 6, 7`` are supported.
    :param exponent: exponent :math:`p` in :math:`1/r^p` potentials
    :param accuracy: Recommended value for a balance between accuracy and speed
        is :math:`10^{-3}`. For more accurate results, use :math:`10^{-6}`.
    :param max_steps: maximum number of gradient descent steps
    :param learning_rate: learning rate for gradient descent
    :return: Tuple containing a float of the optimal smearing for the
        :class:`CoulombPotential`, a dictionary with the parameters for
        :class:`PMECalculator` and a float of the optimal cutoff value for the
        neighborlist computation.

    Example
    -------
    >>> import torch

    To allow reproducibility, we set the seed to a fixed value

    >>> _ = torch.manual_seed(0)
    >>> positions = torch.tensor(
    ...     [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], dtype=torch.float64
    ... )
    >>> charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64)
    >>> cell = torch.eye(3, dtype=torch.float64)
    >>> smearing, parameter, cutoff = tune_pme(
    ...     torch.sum(charges**2, dim=0), cell, positions, accuracy=1e-1
    ... )

    You can check the values of the parameters

    >>> print(smearing)
    0.6768985898318037
    >>> print(parameter)
    {'mesh_spacing': 0.6305733973385922, 'interpolation_nodes': 4}
    >>> print(cutoff)
    2.243154348782357

    You can pass one parameter to the function to tune only the others, for
    example, fixing the cutoff to 0.6

    >>> smearing, parameter, cutoff = tune_pme(
    ...     torch.sum(charges**2, dim=0), cell, positions, cutoff=0.6, accuracy=1e-1
    ... )

    You can check the values of the parameters, now the cutoff is fixed

    >>> print(smearing)
    0.22038829671671745
    >>> print(parameter)
    {'mesh_spacing': 0.5006356677116188, 'interpolation_nodes': 4}
    >>> print(cutoff)
    0.6
    """
    _validate_parameters(sum_squared_charges, cell, positions, exponent, accuracy)

    smearing_opt, cutoff_opt = _estimate_smearing_cutoff(
        cell=cell,
        smearing=smearing,
        cutoff=cutoff,
        accuracy=accuracy,
    )

    # We choose only one mesh as initial guess
    if mesh_spacing is None:
        ns_mesh_opt = torch.tensor(
            [1, 1, 1],
            device=cell.device,
            dtype=cell.dtype,
            requires_grad=True,
        )
    else:
        ns_mesh_opt = get_ns_mesh(cell, mesh_spacing)

    cell_dimensions = torch.linalg.norm(cell, dim=1)
    volume = torch.abs(cell.det())
    prefac = 2 * sum_squared_charges / math.sqrt(len(positions))
    interpolation_nodes = torch.tensor(interpolation_nodes, device=cell.device)

    def err_Fourier(smearing, ns_mesh):
        def H(ns_mesh):
            return torch.prod(1 / ns_mesh) ** (1 / 3)

        def RMS_phi(ns_mesh):
            return torch.linalg.norm(
                _compute_RMS_phi(cell, interpolation_nodes, ns_mesh, positions)
            )

        def log_factorial(x):
            return torch.lgamma(x + 1)

        def factorial(x):
            return torch.exp(log_factorial(x))

        return (
            prefac
            * torch.pi**0.25
            * (6 * (1 / 2**0.5 / smearing) / (2 * interpolation_nodes + 1)) ** 0.5
            / volume ** (2 / 3)
            * (2**0.5 / smearing * H(ns_mesh)) ** interpolation_nodes
            / factorial(interpolation_nodes)
            * torch.exp(
                interpolation_nodes * (torch.log(interpolation_nodes / 2) - 1) / 2
            )
            * RMS_phi(ns_mesh)
        )

    def err_real(smearing, cutoff):
        return (
            prefac
            / torch.sqrt(cutoff * volume)
            * torch.exp(-(cutoff**2) / 2 / smearing**2)
        )

    def loss(smearing, ns_mesh, cutoff):
        return torch.sqrt(
            err_Fourier(smearing, ns_mesh) ** 2 + err_real(smearing, cutoff) ** 2
        )

    params = [smearing_opt, ns_mesh_opt, cutoff_opt]
    _optimize_parameters(
        params=params,
        loss=loss,
        max_steps=max_steps,
        accuracy=accuracy,
        learning_rate=learning_rate,
    )

    return (
        float(smearing_opt),
        {
            "mesh_spacing": float(torch.min(cell_dimensions / ns_mesh_opt)),
            "interpolation_nodes": int(interpolation_nodes),
        },
        float(cutoff_opt),
    )
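
# A minimal usage sketch (illustrative, not part of this module): the tuned
# values are meant to be forwarded to a potential and a calculator. The exact
# constructor signatures below are assumptions based on the docstring above,
# which names :class:`CoulombPotential` and :class:`PMECalculator`.
#
#     import torchpme
#
#     smearing, parameter, cutoff = tune_pme(
#         torch.sum(charges**2, dim=0), cell, positions
#     )
#     potential = torchpme.CoulombPotential(smearing=smearing)
#     calculator = torchpme.PMECalculator(potential=potential, **parameter)
#     # ``cutoff`` sets the radius used for the neighbor-list computation.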

def _compute_RMS_phi(
    cell: torch.Tensor,
    interpolation_nodes: torch.Tensor,
    ns_mesh: torch.Tensor,
    positions: torch.Tensor,
) -> torch.Tensor:
    inverse_cell = torch.linalg.inv(cell)
    # Compute positions relative to the mesh basis vectors
    positions_rel = ns_mesh * torch.matmul(positions, inverse_cell)

    # Calculate positions and distances based on interpolation nodes
    even = interpolation_nodes % 2 == 0
    if even:
        # For Lagrange interpolation, when the number of interpolation nodes
        # is even, the relative position of a charge is the midpoint of the
        # two nearest gridpoints.
        positions_rel_idx = _Floor.apply(positions_rel)
    else:
        # For Lagrange interpolation, when the number of interpolation nodes
        # is odd, the relative position of a charge is the nearest gridpoint.
        positions_rel_idx = _Round.apply(positions_rel)

    # Calculate indices of mesh points on which the particle weights are
    # interpolated. For each particle, its weight is "smeared" onto `order**3`
    # mesh points, which can be achieved using meshgrid below.
    indices_to_interpolate = torch.stack(
        [
            (positions_rel_idx + i)
            for i in range(
                1 - (interpolation_nodes + 1) // 2,
                1 + interpolation_nodes // 2,
            )
        ],
        dim=0,
    )
    positions_rel = positions_rel[torch.newaxis, :, :]
    positions_rel += 1e-10 * torch.randn(
        positions_rel.shape, dtype=cell.dtype, device=cell.device
    )  # The noise helps the algorithm work for tiny systems (<100 atoms)
    return (
        torch.mean(
            (torch.prod(indices_to_interpolate - positions_rel, dim=0)) ** 2, dim=0
        )
        ** 0.5
    )


class _Floor(torch.autograd.Function):
    """Floor function with a non-zero gradient."""

    @staticmethod
    def forward(ctx, input):
        result = torch.floor(input)
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _Round(torch.autograd.Function):
    """Round function with a non-zero gradient."""

    @staticmethod
    def forward(ctx, input):
        result = torch.round(input)
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output
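
# ``_Floor`` and ``_Round`` above act as straight-through estimators: the
# forward pass snaps to integer mesh indices while the backward pass returns
# the incoming gradient unchanged, so ``ns_mesh_opt`` remains optimizable by
# Adam even though mesh sizes are integral. A minimal, self-contained check
# (illustrative only, not part of the library):
#
#     x = torch.tensor([1.3, 2.7], requires_grad=True)
#     y = _Floor.apply(x)  # forward: tensor([1., 2.])
#     y.sum().backward()  # backward: identity ("straight-through") gradient
#     assert torch.equal(x.grad, torch.ones_like(x))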