Source code for torchpme.potentials.combined

from typing import Optional

import torch

from .potential import Potential


[docs] class CombinedPotential(Potential): """ A potential that is a linear combination of multiple potentials. A class representing a combined potential that aggregates multiple individual potentials with weights for use in long-range (LR) and short-range (SR) interactions. The ``CombinedPotential`` class allows for flexible combination of potential functions with user-specified weights, which can be either fixed or trainable. :param potentials: List of potential objects, each implementing a compatible interface with methods `from_dist`, `lr_from_dist`, `lr_from_k_sq`, `self_contribution`, and `background_correction`. :param initial_weights: Initial weights for combining the potentials. If provided, the length must match the number of potentials. If `None`, weights are initialized to ones. :param learnable_weights: If `True`, weights are trainable parameters, allowing optimization during training. If `False`, weights are fixed. :param exclusion_radius: A length scale that defines a *local environment* within which the potential should be smoothly zeroed out, as it will be described by a separate model. """ def __init__( self, potentials: list[Potential], initial_weights: Optional[torch.Tensor] = None, learnable_weights: Optional[bool] = True, smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, ): super().__init__( smearing=smearing, exclusion_radius=exclusion_radius, ) smearings = [pot.smearing for pot in potentials] if not all(smearings) and any(smearings): raise ValueError( r"Cannot combine direct (`smearing=None`) and range-separated (`smearing=float`) potentials." ) if all(smearings) and not self.smearing: # this is very misleading, but it is the way the original code works, # otherwise mypy complains raise ValueError( r"You should specify a `smearing` when combining range-separated (`smearing=float`) potentials." ) if not any(smearings) and self.smearing: # this is very misleading, but it is the way the original code works, # otherwise mypy complai raise ValueError( r"Cannot specify `smearing` when combining direct (`smearing=None`) potentials." ) if initial_weights is not None: if len(initial_weights) != len(potentials): raise ValueError( "The number of initial weights must match the number of potentials being combined" ) else: initial_weights = torch.ones(len(potentials)) # for torchscript self.potentials = torch.nn.ModuleList(potentials) if learnable_weights: self.weights = torch.nn.Parameter(initial_weights) else: self.register_buffer("weights", initial_weights)
[docs] def from_dist(self, dist: torch.Tensor) -> torch.Tensor: potentials = [pot.from_dist(dist) for pot in self.potentials] potentials = torch.stack(potentials, dim=-1) return torch.inner(self.weights, potentials)
[docs] def sr_from_dist(self, dist: torch.Tensor) -> torch.Tensor: potentials = [pot.sr_from_dist(dist) for pot in self.potentials] potentials = torch.stack(potentials, dim=-1) return torch.inner(self.weights, potentials)
[docs] def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor: potentials = [pot.lr_from_dist(dist) for pot in self.potentials] potentials = torch.stack(potentials, dim=-1) return torch.inner(self.weights, potentials)
[docs] def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor: potentials = [pot.lr_from_k_sq(k_sq) for pot in self.potentials] potentials = torch.stack(potentials, dim=-1) return torch.inner(self.weights, potentials)
[docs] def self_contribution(self) -> torch.Tensor: # self-correction for 1/r^p potential potentials = [pot.self_contribution() for pot in self.potentials] potentials = torch.stack(potentials, dim=-1) return torch.inner(self.weights, potentials)
[docs] def background_correction(self) -> torch.Tensor: # "charge neutrality" correction for 1/r^p potential potentials = [pot.background_correction() for pot in self.potentials] potentials = torch.stack(potentials, dim=-1) return torch.inner(self.weights, potentials)
from_dist.__doc__ = Potential.from_dist.__doc__ sr_from_dist.__doc__ = Potential.sr_from_dist.__doc__ lr_from_dist.__doc__ = Potential.lr_from_dist.__doc__ lr_from_k_sq.__doc__ = Potential.lr_from_k_sq.__doc__ self_contribution.__doc__ = Potential.self_contribution.__doc__ background_correction.__doc__ = Potential.background_correction.__doc__