from typing import Optional
import torch
from ..lib import (
CubicSpline,
CubicSplineReciprocal,
compute_second_derivatives,
compute_spline_ft,
)
from .potential import Potential
class SplinePotential(Potential):
    r"""
    Potential built from a spline interpolation.

    The potential is assumed to have only a long-range part, but one can also
    add a short-range part if needed, by inheriting and redefining
    ``sr_from_dist``.
    The real-space potential is computed based on a cubic spline built at
    initialization time. The Fourier-domain kernel is computed numerically
    as a spline, too. Assumes the infinite-separation value of the
    potential to be zero.

    :param r_grid: radial grid for the real-space evaluation
    :param y_grid: potential values for the real-space evaluation; must have
        the same length as ``r_grid``
    :param k_grid: radial grid for the k-space evaluation;
        computed automatically from ``r_grid`` if absent (as :math:`2\pi/r`,
        reversed, when ``reciprocal`` is set; as a copy of ``r_grid`` otherwise).
    :param yhat_grid: potential values for the k-space evaluation;
        computed automatically from ``y_grid`` if absent. When given, it must
        have the same length as ``k_grid`` (or as ``r_grid`` if ``k_grid`` is
        absent, since ``k_grid`` is then derived from it).
    :param reciprocal: flag that determines if the splining should
        be performed on a :math:`1/r` axis; suitable to describe
        long-range potentials. ``r_grid`` should contain only
        strictly positive values.
    :param y_at_zero: value to be used for :math:`r\rightarrow 0`
        when using a reciprocal spline; if given, it is also the value returned
        by :meth:`self_contribution` (otherwise the spline evaluated at zero is
        used)
    :param yhat_at_zero: value to be used for :math:`k\rightarrow 0`
        in the k-space kernel
    :param smearing: The length scale associated with the switching between
        :math:`V_{\mathrm{SR}}(r)` and :math:`V_{\mathrm{LR}}(r)`
    :param exclusion_radius: A length scale that defines a *local environment* within
        which the potential should be smoothly zeroed out, as it will be described by a
        separate model.
    :raises ValueError: if ``y_grid`` and ``r_grid`` differ in length, if a
        user-supplied ``yhat_grid`` and ``k_grid`` differ in length, or if
        ``reciprocal`` is set and ``r_grid`` contains non-positive values.
    """

    def __init__(
        self,
        r_grid: torch.Tensor,
        y_grid: torch.Tensor,
        k_grid: Optional[torch.Tensor] = None,
        yhat_grid: Optional[torch.Tensor] = None,
        reciprocal: bool = False,
        y_at_zero: Optional[float] = None,
        yhat_at_zero: Optional[float] = None,
        smearing: Optional[float] = None,
        exclusion_radius: Optional[float] = None,
    ):
        super().__init__(
            smearing=smearing,
            exclusion_radius=exclusion_radius,
        )

        if len(y_grid) != len(r_grid):
            raise ValueError("Length of radial grid and value array mismatch.")

        self.register_buffer("r_grid", r_grid)
        self.register_buffer("y_grid", y_grid)

        # Real-space spline.
        if reciprocal:
            # a 1/r axis is undefined for r <= 0
            if torch.min(r_grid) <= 0.0:
                raise ValueError(
                    "Positive-valued radial grid is needed for reciprocal axis spline."
                )
            self._spline = CubicSplineReciprocal(r_grid, y_grid, y_at_zero=y_at_zero)
        else:
            self._spline = CubicSpline(r_grid, y_grid)

        if k_grid is None:
            # defaults to 2pi/r_grid_points if reciprocal, to r_grid if not;
            # the flip reverses the order produced by the reciprocal (assuming
            # r_grid is ascending, this keeps k_grid ascending too)
            if reciprocal:
                k_grid = torch.pi * 2 * torch.reciprocal(r_grid).flip(dims=[0])
            else:
                k_grid = r_grid.clone().detach()
        self.register_buffer("k_grid", k_grid)

        if yhat_grid is None:
            # numerically Fourier-transform the real-space spline onto k_grid
            yhat_grid = compute_spline_ft(
                k_grid,
                r_grid,
                y_grid,
                compute_second_derivatives(r_grid, y_grid),
            )
        elif len(yhat_grid) != len(k_grid):
            # same consistency requirement as for the real-space grid
            raise ValueError(
                "Length of k-space grid and kernel value array mismatch."
            )
        self.register_buffer("yhat_grid", yhat_grid)

        # the function is defined for k**2, so we define the grid accordingly
        if reciprocal:
            self._krn_spline = CubicSplineReciprocal(
                k_grid**2, yhat_grid, y_at_zero=yhat_at_zero
            )
        else:
            self._krn_spline = CubicSpline(k_grid**2, yhat_grid)

        # Cache the r -> 0 value: explicit if given, otherwise from the spline.
        if y_at_zero is None:
            self._y_at_zero = self._spline(
                torch.zeros(1, dtype=self.r_grid.dtype, device=self.r_grid.device)
            )
        else:
            self._y_at_zero = torch.tensor(
                y_at_zero, dtype=self.r_grid.dtype, device=self.r_grid.device
            )

        # Cache the k -> 0 value: explicit if given, otherwise from the kernel spline.
        if yhat_at_zero is None:
            self._yhat_at_zero = self._krn_spline(
                torch.zeros(1, dtype=self.k_grid.dtype, device=self.k_grid.device)
            )
        else:
            self._yhat_at_zero = torch.tensor(
                yhat_at_zero, dtype=self.k_grid.dtype, device=self.k_grid.device
            )

    def from_dist(self, dist: torch.Tensor) -> torch.Tensor:
        # sr_from_dist is identically zero unless a subclass overrides it, so by
        # default this reduces to the long-range spline
        return self.lr_from_dist(dist) + self.sr_from_dist(dist)

    def sr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
        """
        Short-range part of the range-separated potential.

        Always zero here; subclasses can override it to add a short-range term.

        :param dist: torch.tensor containing the distances at which the potential is to
            be evaluated.
        """
        # multiplying (rather than allocating) keeps dist's shape, dtype and device
        return 0.0 * dist

    def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
        return self._spline(dist)

    def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor:
        # the kernel spline was built on a k**2 axis, so it takes k_sq directly
        return self._krn_spline(k_sq)

    def self_contribution(self) -> torch.Tensor:
        return self._y_at_zero

    def background_correction(self) -> torch.Tensor:
        # Zero, but created with the grids' dtype/device: a plain torch.zeros(1)
        # would be float32 on CPU and clash with potentials living on a GPU.
        return torch.zeros(1, dtype=self.r_grid.dtype, device=self.r_grid.device)

    from_dist.__doc__ = Potential.from_dist.__doc__
    lr_from_dist.__doc__ = Potential.lr_from_dist.__doc__
    lr_from_k_sq.__doc__ = Potential.lr_from_k_sq.__doc__
    self_contribution.__doc__ = Potential.self_contribution.__doc__
    background_correction.__doc__ = Potential.background_correction.__doc__