Advanced neighbor list usage

Accurately calculating forces as derivatives from energy is crucial for predicting system dynamics as well as in training machine learning models. In systems where forces are derived from the gradients of the potential energy, it is essential that the distance calculations between particles are included in the computational graph. This ensures that the force computations respect the dependencies between particle positions and distances, allowing for accurate gradients during backpropagation.
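As a minimal illustration (a toy example, not part of torch-pme), the sketch below builds a single Coulomb-like pair energy E = q1 q2 / r from positions that require gradients. It then checks that the autograd force -dE/dr matches the analytic force. Only PyTorch is assumed.

```python
import torch

# Two point charges; the positions are leaf tensors tracked by autograd
pair_positions = torch.tensor(
    [[0.0, 0.0, 0.0], [1.5, 0.5, 0.0]], dtype=torch.float64, requires_grad=True
)
q1, q2 = 1.0, -1.0

# The distance is computed inside the graph, so the energy depends on the positions
r_vec = pair_positions[1] - pair_positions[0]
r = torch.linalg.norm(r_vec)
pair_energy = q1 * q2 / r

pair_energy.backward()
pair_forces = -pair_positions.grad  # F = -dE/dr

# Analytic force on atom 1: F_1 = q1 q2 r_vec / r^3, and F_0 = -F_1
analytic_f1 = q1 * q2 * r_vec.detach() / r.detach() ** 3
print(pair_forces)
print(analytic_f1)
```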


Visualization of the data flow used to compute the energy from the cell, positions and charges through a neighbor list calculator and the potential calculator. All operations on the red line have to be tracked to obtain correct derivatives with respect to the positions.
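To see why the red path matters, the toy sketch below (hypothetical values, PyTorch and NumPy only) computes the same pair distance twice. The first computation goes through NumPy, which cuts the graph. The second uses torch operations. Only the second energy depends on the positions as far as autograd is concerned.

```python
import numpy as np
import torch

two_atoms = torch.tensor(
    [[0.0, 0.0, 0.0], [1.2, 0.0, 0.0]], dtype=torch.float64, requires_grad=True
)

# Distance computed outside torch: to autograd the result is just a constant
raw = two_atoms.detach().numpy()
r_untracked = torch.tensor(np.linalg.norm(raw[1] - raw[0]))
energy_untracked = -1.0 / r_untracked
print(energy_untracked.requires_grad)  # False: no path back to the positions

# Distance computed with torch operations: the graph reaches the positions
r_tracked = torch.linalg.norm(two_atoms[1] - two_atoms[0])
energy_tracked = -1.0 / r_tracked
energy_tracked.backward()
print(two_atoms.grad)
```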

In this tutorial, we demonstrate two methods for maintaining differentiability when computing distances between particles. The first method manually recomputes distances within the computational graph using positions, cell information, and neighbor shifts, making it suitable for any neighbor list code.

The second method uses a backpropagable neighbor list from the vesin-torch library, which automatically ensures that the distance calculations remain differentiable.

Note

While both approaches yield the same result, a backpropagable neighbor list is generally preferred because it eliminates the need to recompute the distances manually. This simplifies your code and avoids computing the distances twice, which can also improve performance.

from typing import Optional

import ase
import chemiscope
import matplotlib.pyplot as plt
import numpy as np
import torch
import vesin
import vesin.torch

import torchpme
from torchpme.tuning import tune_pme

The test system

As a test system, we use a 2x2x2 supercell of a CsCl crystal in a cubic cell.

dtype = torch.float64
atoms_unitcell = ase.Atoms(
    symbols=["Cs", "Cl"],
    positions=np.array([(0, 0, 0), (0.5, 0.5, 0.5)]),
    cell=np.eye(3),
    pbc=True,
)
charges_unitcell = np.array([1.0, -1.0])

atoms = atoms_unitcell.repeat([2, 2, 2])
charges = np.tile(charges_unitcell, 2 * 2 * 2)
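The NumPy sketch below mimics what `repeat([2, 2, 2])` does to the positions. It assumes that ase appends one full copy of the unit cell per lattice translation, keeping the unit-cell atom order. Under that assumption, `np.tile` assigns the charges consistently and the supercell stays charge-neutral. The names are illustrative and do not touch the tutorial's variables.

```python
import numpy as np

# Unit cell as in the tutorial: Cs at the origin, Cl at the body center
unit_positions = np.array([(0.0, 0.0, 0.0), (0.5, 0.5, 0.5)])
unit_cell = np.eye(3)
unit_charges = np.array([1.0, -1.0])

# One copy of the unit cell per integer translation (n0, n1, n2)
translations = np.array(
    [(n0, n1, n2) for n0 in range(2) for n1 in range(2) for n2 in range(2)]
)
super_positions = np.concatenate([unit_positions + t @ unit_cell for t in translations])
super_charges = np.tile(unit_charges, len(translations))
super_cell = np.diag([2, 2, 2]) @ unit_cell

print(super_positions.shape, super_charges.sum())
```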

We now randomly displace the atoms from their initial positions, drawing the displacements from a Gaussian distribution with a standard deviation of 0.1 Å, to create non-zero forces.

atoms.rattle(stdev=0.1)

chemiscope.show(
    frames=[atoms],
    mode="structure",
    settings=chemiscope.quick_settings(structure_settings={"unitCell": True}),
)

Tune parameters

Based on our system, we first tune the PME parameters for an accurate computation. To do so, we convert the positions, charges and cell from NumPy arrays into torch tensors and compute the summed squared charges.
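The sketch below shows this conversion step. To keep it runnable on its own, it uses hypothetical stand-in arrays with the same shapes as the tutorial's 16-atom system. The names carry a `toy_` prefix so that the tutorial's variables are not overwritten. The `tune_pme` call itself is not shown, because its exact arguments depend on the torch-pme version.

```python
import numpy as np
import torch

# Stand-ins for atoms.positions, atoms.cell.array and the tiled charges
toy_positions_np = np.random.default_rng(0).uniform(0.0, 2.0, size=(16, 3))
toy_cell_np = 2.0 * np.eye(3)
toy_charges_np = np.tile([1.0, -1.0], 8)

toy_positions = torch.tensor(toy_positions_np, dtype=torch.float64)
toy_cell = torch.tensor(toy_cell_np, dtype=torch.float64)
# One charge channel per atom, matching the (n_atoms, 1) potential printed later
toy_charges = torch.tensor(toy_charges_np, dtype=torch.float64).unsqueeze(1)

# Summed squared charges, which are passed on to the tuning step
toy_sum_charges_sq = torch.sum(toy_charges**2, dim=0)
print(toy_charges.shape, toy_sum_charges_sq)
```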

The tuning found the following best values for our system.

print("smearing:", smearing)
print("PME parameters:", pme_params)
print("cutoff:", cutoff)
smearing: 1.1069526756106463
PME parameters: {'interpolation_nodes': 4, 'mesh_spacing': 0.5714285714285714}
cutoff: 4.4

Generic neighbor list

A common workflow is to obtain the neighbor pairs and their periodic shifts from a standard tool, such as the default (NumPy) version of the vesin neighbor list.

nl = vesin.NeighborList(cutoff=cutoff, full_list=False)
neighbor_indices, S = nl.compute(
    points=atoms.positions, box=atoms.cell.array, periodic=True, quantities="PS"
)
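The brute-force reference sketch below makes the meaning of the returned pairs and shifts concrete for a small periodic system. It uses a hypothetical helper, `brute_force_half_list`, which is not vesin's algorithm. The helper loops over periodic images and keeps each pair once: either `i < j`, or `i == j` with a lexicographically positive shift. This is one possible half-list convention. A pair `(i, j)` with shift `S` is separated by `positions[j] - positions[i] + S @ cell`, the same convention that the `distances` function in this tutorial uses.

```python
import itertools

import numpy as np


def brute_force_half_list(points, box, cutoff, n_images=2):
    """Hypothetical brute-force half neighbor list, for illustration only.

    `n_images` must be large enough that the cutoff sphere fits inside
    the searched block of periodic images.
    """
    found_pairs, found_shifts, found_distances = [], [], []
    image_range = range(-n_images, n_images + 1)
    for shift_tuple in itertools.product(image_range, repeat=3):
        shift = np.array(shift_tuple)
        for i in range(len(points)):
            for j in range(i, len(points)):
                # A self-pair (i, i) with shift S equals the one with -S:
                # skip S = 0 and keep only the lexicographically positive half
                if i == j and shift_tuple <= (0, 0, 0):
                    continue
                vector = points[j] - points[i] + shift @ box
                distance = np.linalg.norm(vector)
                if distance < cutoff:
                    found_pairs.append((i, j))
                    found_shifts.append(shift)
                    found_distances.append(distance)
    return np.array(found_pairs), np.array(found_shifts), np.array(found_distances)


# CsCl unit cell from above, with a short cutoff
unit_positions = np.array([(0.0, 0.0, 0.0), (0.5, 0.5, 0.5)])
pairs, shift_vectors, pair_distances = brute_force_half_list(unit_positions, np.eye(3), cutoff=1.1)
print(len(pairs))  # 14: 8 Cs-Cl pairs plus 3 Cs-Cs and 3 Cl-Cl image pairs
```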

We now define a function that (re-)computes the distances with torch operations, so that torch can track them.

def distances(
    positions: torch.Tensor,
    neighbor_indices: torch.Tensor,
    cell: Optional[torch.Tensor] = None,
    neighbor_shifts: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute pairwise distances."""
    atom_is = neighbor_indices[:, 0]
    atom_js = neighbor_indices[:, 1]

    pos_is = positions[atom_is]
    pos_js = positions[atom_js]

    distance_vectors = pos_js - pos_is

    if cell is not None and neighbor_shifts is not None:
        shifts = neighbor_shifts.type(cell.dtype)
        distance_vectors += shifts @ cell
    elif cell is not None and neighbor_shifts is None:
        raise ValueError("Provided `cell` but no `neighbor_shifts`.")
    elif cell is None and neighbor_shifts is not None:
        raise ValueError("Provided `neighbor_shifts` but no `cell`.")

    return torch.linalg.norm(distance_vectors, dim=1)
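As a quick check of the shift convention, consider two atoms near opposite faces of a unit cubic cell (hypothetical coordinates). With the shift `[-1, 0, 0]`, the recomputed distance is the short periodic one. Because the distance is built from torch operations, its gradient with respect to the positions is also available. The helper `periodic_distances` is a condensed, hypothetical copy of `distances` without the argument checks, so the snippet runs on its own.

```python
import torch


def periodic_distances(points, pairs, box, shifts):
    # Same formula as `distances`: r_ij = r_j - r_i + S_ij @ cell
    vectors = points[pairs[:, 1]] - points[pairs[:, 0]]
    vectors = vectors + shifts.to(box.dtype) @ box
    return torch.linalg.norm(vectors, dim=1)


unit_box = torch.eye(3, dtype=torch.float64)
edge_atoms = torch.tensor(
    [[0.1, 0.5, 0.5], [0.9, 0.5, 0.5]], dtype=torch.float64, requires_grad=True
)
edge_pairs = torch.tensor([[0, 1]])
edge_shifts = torch.tensor([[-1, 0, 0]])  # atom 1 taken from the neighboring image

edge_d = periodic_distances(edge_atoms, edge_pairs, unit_box, edge_shifts)
edge_d.sum().backward()
print(edge_d)  # about 0.2 (the periodic distance) rather than 0.8 inside the cell
print(edge_atoms.grad)  # +x unit vector for atom 0, -x unit vector for atom 1
```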

To use this function, we now enable the tracking of operations by setting the requires_grad property of the positions to True.
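The sketch below shows this step on a stand-in tensor. After the call, the positions are a leaf tensor whose `.grad` is filled by `backward()`, while every tensor derived from them carries a `grad_fn`. Calling `requires_grad_()` in place has the same effect as assigning `requires_grad = True`.

```python
import torch

demo_positions = torch.rand(4, 3, dtype=torch.float64)  # stand-in for the positions tensor
print(demo_positions.requires_grad)  # False

demo_positions.requires_grad_(True)  # same effect as `.requires_grad = True`
centered = demo_positions - demo_positions.mean(dim=0)
print(demo_positions.is_leaf, demo_positions.grad)  # True None: nothing backpropagated yet
print(centered.grad_fn is not None)  # True: the subtraction was recorded

(centered**2).sum().backward()
print(demo_positions.grad.shape)  # torch.Size([4, 3])
```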

Now we recompute the distances with this function and initialize a PMECalculator instance using a CoulombPotential to compute the potential.

tensor([[-0.8874],
        [ 1.2554],
        [-0.9274],
        [ 1.4907],
        [-0.7261],
        [ 0.7836],
        [-0.8839],
        [ 0.8997],
        [-0.9603],
        [ 1.0151],
        [-1.0800],
        [ 1.3655],
        [-1.0228],
        [ 0.6690],
        [-1.1581],
        [ 0.8331]], dtype=torch.float64, grad_fn=<MulBackward0>)

The energy is given by the scalar product of the potential with the charges.
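Written out with the charges q_i and the per-atom potentials V_i returned by the calculator, this is E = Σ_i q_i V_i. The sketch below uses hypothetical stand-in tensors shaped like the tutorial's `(n_atoms, 1)` charges and potential. It shows that `torch.sum(charges * potential)` and an explicit dot product agree.

```python
import torch

demo_charges = torch.tensor([[1.0], [-1.0], [1.0], [-1.0]], dtype=torch.float64)
# Stand-in for the calculator output; in the tutorial it depends on the positions
demo_potential = torch.tensor([[-0.9], [1.2], [-0.7], [0.8]], dtype=torch.float64)

demo_energy = torch.sum(demo_charges * demo_potential)
demo_energy_dot = torch.dot(demo_charges.flatten(), demo_potential.flatten())
print(demo_energy, demo_energy_dot)
```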

Finally, we can compute and print the forces in CGS units as erg/Å.
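A common way to validate autograd forces is a central finite difference, F ≈ -[E(x + h) - E(x - h)] / (2h) for each coordinate. The sketch below applies this check to a hypothetical pair energy (`toy_energy`) rather than to the PME calculator. The autograd forces are computed as the negative gradient after `backward()`.

```python
import torch


def toy_energy(points):
    # Hypothetical pair energy: Coulomb-like 1/r between opposite charges
    r = torch.linalg.norm(points[1] - points[0])
    return -1.0 / r


fd_positions = torch.tensor(
    [[0.0, 0.1, 0.0], [0.8, 0.3, -0.2]], dtype=torch.float64, requires_grad=True
)
toy_energy(fd_positions).backward()
ad_forces = -fd_positions.grad  # autograd forces

# Central finite differences on a detached copy of the positions
h = 1e-6
fd_forces = torch.zeros_like(ad_forces)
with torch.no_grad():
    for atom in range(fd_positions.shape[0]):
        for dim in range(3):
            shifted = fd_positions.detach().clone()
            shifted[atom, dim] += h
            e_plus = toy_energy(shifted)
            shifted[atom, dim] -= 2 * h
            e_minus = toy_energy(shifted)
            fd_forces[atom, dim] = -(e_plus - e_minus) / (2 * h)

print(torch.max(torch.abs(ad_forces - fd_forces)))
```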

tensor([[-0.9484, -1.4864,  0.9129],
        [ 0.1350, -0.6500,  0.4235],
        [-0.0230, -1.4964, -0.1962],
        [-0.3949, -0.3494,  0.1579],
        [-0.0038,  0.9212,  1.0006],
        [ 0.2718,  0.2613, -0.4679],
        [-1.2057,  0.9323, -1.0406],
        [-0.3698,  0.5807,  0.5612],
        [ 0.5696, -1.1072,  0.0541],
        [ 0.6476,  0.4007, -0.4058],
        [ 0.5322,  0.1812, -0.3541],
        [ 0.5632,  0.4293,  0.3243],
        [ 0.4923,  0.5976, -0.6239],
        [ 0.0734,  0.1513,  0.0368],
        [-0.3666,  0.8548, -0.5614],
        [ 0.0273, -0.2203,  0.1787]], dtype=torch.float64)

Backpropagable neighbor list

We now repeat the computation of the forces, but instead of using a generic neighbor list and our custom distances function, we directly use a neighbor list function that tracks the operations, as implemented by the vesin-torch library.

We first detach and clone the position tensor to start a new computational graph, and then compute new distances in a similar manner as above.
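The sketch below shows, with stand-in tensors, what detaching and cloning achieves. The copy is a new leaf tensor with its own graph. Gradients from the second computation therefore land in the copy's `.grad` and are not accumulated into the original one. `detach()` alone would still share memory with the original tensor, which is why `clone()` follows it.

```python
import torch

orig_positions = torch.tensor(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=torch.float64, requires_grad=True
)
torch.linalg.norm(orig_positions[1] - orig_positions[0]).backward()
first_grad = orig_positions.grad.clone()

# New leaf tensor: same values, no history, independent storage
copied_positions = orig_positions.detach().clone()
copied_positions.requires_grad_(True)

torch.linalg.norm(copied_positions[1] - copied_positions[0]).backward()

print(torch.equal(orig_positions.grad, first_grad))  # True: nothing accumulated
print(torch.equal(copied_positions.grad, first_grad))  # True: same values, separate tensor
```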

nl = vesin.torch.NeighborList(cutoff=cutoff, full_list=False)
neighbor_indices_new, d = nl.compute(
    points=positions_new, box=cell, periodic=True, quantities="Pd"
)

Following the same steps as above, we compute the forces.

tensor([[-0.9484, -1.4864,  0.9129],
        [ 0.1350, -0.6500,  0.4235],
        [-0.0230, -1.4964, -0.1962],
        [-0.3949, -0.3494,  0.1579],
        [-0.0038,  0.9212,  1.0006],
        [ 0.2718,  0.2613, -0.4679],
        [-1.2057,  0.9323, -1.0406],
        [-0.3698,  0.5807,  0.5612],
        [ 0.5696, -1.1072,  0.0541],
        [ 0.6476,  0.4007, -0.4058],
        [ 0.5322,  0.1812, -0.3541],
        [ 0.5632,  0.4293,  0.3243],
        [ 0.4923,  0.5976, -0.6239],
        [ 0.0734,  0.1513,  0.0368],
        [-0.3666,  0.8548, -0.5614],
        [ 0.0273, -0.2203,  0.1787]], dtype=torch.float64)

The forces are identical to those printed above. For a clearer comparison, we also plot the magnitude of the force on each atom for both methods.
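Besides the visual comparison, the agreement can also be checked numerically, for instance with `torch.testing.assert_close`. The sketch below is a miniature of this tutorial on a toy non-periodic cluster, not the CsCl system. It computes forces along two routes and asserts that they match. The first route takes pair indices from a plain NumPy list and recomputes the distances in torch. The second route uses a list that returns differentiable distances directly. Both helper functions are hypothetical stand-ins for the real neighbor lists.

```python
import numpy as np
import torch


def numpy_pairs(points, cutoff):
    # Generic (non-differentiable) half list: pair indices only, from plain arrays
    i, j = np.triu_indices(len(points), k=1)
    keep = np.linalg.norm(points[j] - points[i], axis=1) < cutoff
    return np.stack([i[keep], j[keep]], axis=1)


def torch_pairs_and_distances(points, cutoff):
    # Differentiable variant: the returned distances stay in the autograd graph
    i, j = torch.triu_indices(len(points), len(points), offset=1)
    d = torch.linalg.norm(points[j] - points[i], dim=1)
    keep = d < cutoff
    return torch.stack([i[keep], j[keep]], dim=1), d[keep]


def pair_energy_from(pair_idx, d, q):
    return torch.sum(q[pair_idx[:, 0]] * q[pair_idx[:, 1]] / d)


gen = torch.Generator().manual_seed(0)
cluster_charges = torch.tensor([1.0, -1.0, 1.0, -1.0, 1.0, -1.0], dtype=torch.float64)
x0 = torch.rand(6, 3, generator=gen, dtype=torch.float64) * 3.0
cluster_cutoff = 2.5

# Route 1: generic list, distances recomputed with torch operations
pos1 = x0.clone().requires_grad_(True)
pairs1 = torch.from_numpy(numpy_pairs(pos1.detach().numpy(), cluster_cutoff))
d1 = torch.linalg.norm(pos1[pairs1[:, 1]] - pos1[pairs1[:, 0]], dim=1)
pair_energy_from(pairs1, d1, cluster_charges).backward()

# Route 2: distances returned directly by the differentiable list
pos2 = x0.clone().requires_grad_(True)
pairs2, d2 = torch_pairs_and_distances(pos2, cluster_cutoff)
pair_energy_from(pairs2, d2, cluster_charges).backward()

torch.testing.assert_close(-pos1.grad, -pos2.grad)
print("forces agree")
```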

plt.plot(torch.linalg.norm(forces, dim=1), "o-", label="normal Neighborlist")
plt.plot(torch.linalg.norm(forces_new, dim=1), ".-", label="torch Neighborlist")
plt.legend()

plt.xlabel("atom index")
plt.ylabel(r"$|F|~/~\mathrm{erg\,Å^{-1}}$")

plt.show()
