class CombinedPotential(Potential):
    """A potential that is a linear combination of multiple potentials.

    A class representing a combined potential that aggregates multiple individual
    potentials with weights for use in long-range (LR) and short-range (SR)
    interactions.

    The ``CombinedPotential`` class allows for flexible combination of potential
    functions with user-specified weights, which can be either fixed or trainable.

    :param potentials: List of potential objects, each implementing a compatible
        interface with methods ``from_dist``, ``lr_from_dist``, ``lr_from_k_sq``,
        ``self_contribution``, and ``background_correction``.
    :param initial_weights: Initial weights for combining the potentials. If
        provided, the length must match the number of potentials. If ``None``,
        weights are initialized to ones.
    :param learnable_weights: If ``True``, weights are trainable parameters,
        allowing optimization during training. If ``False``, weights are fixed.
    :param smearing: Smearing parameter of the combined, range-separated
        potential. Required when the combined potentials are range-separated;
        must be omitted when they are all direct.
    :param exclusion_radius: A length scale that defines a *local environment*
        within which the potential should be smoothly zeroed out, as it will be
        described by a separate model.

    :raises ValueError: If ``potentials`` is empty, if direct and
        range-separated potentials are mixed, if ``smearing`` is inconsistent
        with the component potentials, or if ``initial_weights`` has the wrong
        length.
    """

    def __init__(
        self,
        potentials: list[Potential],
        initial_weights: Optional[torch.Tensor] = None,
        learnable_weights: bool = True,
        smearing: Optional[float] = None,
        exclusion_radius: Optional[float] = None,
    ):
        super().__init__(
            smearing=smearing,
            exclusion_radius=exclusion_radius,
        )

        if len(potentials) == 0:
            raise ValueError("At least one potential must be provided")

        # Use explicit `is None` tests (rather than float truthiness) so that a
        # degenerate-but-representable smearing of 0.0 is not silently treated
        # as "no smearing". This also keeps mypy happy without misleading code.
        has_smearing = [pot.smearing is not None for pot in potentials]
        if any(has_smearing) and not all(has_smearing):
            raise ValueError(
                r"Cannot combine direct (`smearing=None`) and range-separated (`smearing=float`) potentials."
            )
        if all(has_smearing) and self.smearing is None:
            raise ValueError(
                r"You should specify a `smearing` when combining range-separated (`smearing=float`) potentials."
            )
        if not any(has_smearing) and self.smearing is not None:
            raise ValueError(
                r"Cannot specify `smearing` when combining direct (`smearing=None`) potentials."
            )

        if initial_weights is not None:
            if len(initial_weights) != len(potentials):
                raise ValueError(
                    "The number of initial weights must match the number of potentials being combined"
                )
        else:
            initial_weights = torch.ones(len(potentials))

        # ModuleList (rather than a plain Python list) so the sub-potentials are
        # registered as submodules — required for torchscript and for parameter
        # collection to see them.
        self.potentials = torch.nn.ModuleList(potentials)
        if learnable_weights:
            # Trainable combination coefficients.
            self.weights = torch.nn.Parameter(initial_weights)
        else:
            # Fixed coefficients: a buffer moves with .to()/.cuda() but is not
            # returned by .parameters(), so it is never optimized.
            self.register_buffer("weights", initial_weights)

    def self_contribution(self) -> torch.Tensor:
        """Weighted sum of the self-contributions of the combined potentials.

        Self-correction for 1/r^p potentials: stacks each component's
        ``self_contribution()`` along the last dimension and contracts it with
        the combination weights.
        """
        contributions = [pot.self_contribution() for pot in self.potentials]
        stacked = torch.stack(contributions, dim=-1)
        return torch.inner(self.weights, stacked)

    def background_correction(self) -> torch.Tensor:
        """Weighted sum of the background corrections of the combined potentials.

        "Charge neutrality" correction for 1/r^p potentials: stacks each
        component's ``background_correction()`` along the last dimension and
        contracts it with the combination weights.
        """
        corrections = [pot.background_correction() for pot in self.potentials]
        stacked = torch.stack(corrections, dim=-1)
        return torch.inner(self.weights, stacked)