Custom models with automatic differentiation

Authors:

Michele Ceriotti @ceriottm

This example showcases how the main building blocks of torchpme, MeshInterpolator and KSpaceFilter, can be combined creatively to construct arbitrary models that incorporate long-range structural correlations.

None of the models presented here is likely to be very meaningful, and their use in an ML setting (including the definition of an appropriate loss and its optimization) is left as an exercise to the reader.

from time import time

import ase
import torch

import torchpme

device = "cpu"
dtype = torch.float64
rng = torch.Generator()
rng.manual_seed(32)

Generate a trial structure – a distorted rocksalt structure with perturbed positions and charges

structure = ase.Atoms(
    positions=[
        [0, 0, 0],
        [3, 0, 0],
        [0, 3, 0],
        [3, 3, 0],
        [0, 0, 3],
        [3, 0, 3],
        [0, 3, 3],
        [3, 3, 3],
    ],
    cell=[6, 6, 6],
    symbols="NaClClNaClNaNaCl",
)

displacement = torch.normal(
    mean=0.0, std=2.5e-1, size=(len(structure), 3), generator=rng
)
structure.positions += displacement.numpy()

charges = torch.tensor(
    [[1.0], [-1.0], [-1.0], [1.0], [-1.0], [1.0], [1.0], [-1.0]],
    dtype=dtype,
    device=device,
)
charges += torch.normal(mean=0.0, std=1e-1, size=(len(charges), 1), generator=rng)
positions = torch.from_numpy(structure.positions).to(device=device, dtype=dtype)
cell = torch.from_numpy(structure.cell.array).to(device=device, dtype=dtype)

Autodifferentiation through the core torchpme classes

We begin by showing how it is possible to compute a function of the internal state of the core classes, and to differentiate it with respect to the structural and input parameters.

Functions of the atom density

The construction of a “decorated atom density” with MeshInterpolator can easily be differentiated through. We only need to request gradient evaluation, compute the mesh, and evaluate a function of the grid points (again, this is a proof-of-principle example, probably not very useful in practice).

positions.requires_grad_(True)
charges.requires_grad_(True)
cell.requires_grad_(True)

ns = torch.tensor([5, 5, 5])
interpolator = torchpme.lib.MeshInterpolator(
    cell=cell, ns_mesh=ns, interpolation_nodes=3, method="Lagrange"
)
interpolator.compute_weights(positions)
mesh = interpolator.points_to_mesh(charges)

value = mesh.sum()

The gradients can be computed by simply running backward on the end result. Because of the sum rules that apply to the interpolation scheme (the interpolation weights of each atom sum to one), the gradients with respect to the positions and cell entries are zero, and the gradients with respect to the charges are all equal to 1.

# we keep the graph to compute another quantity
value.backward(retain_graph=True)

print(
    f"""
Position gradients:
{positions.grad.T}

Cell gradients:
{cell.grad}

Charges gradients:
{charges.grad.T}
"""
)
Position gradients:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float64)

Cell gradients:
tensor([[-0., -0., -0.],
        [-0., -0., -0.],
        [-0., -0., -0.]], dtype=torch.float64)

Charges gradients:
tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
       dtype=torch.float64)
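
As a quick sanity check of this sum rule (a small addition, not part of the original example), one can verify that the interpolated mesh conserves the total charge:

# the interpolation weights of each atom sum to one, so summing the mesh
# recovers the total charge; this is why d(value)/d(q_i) = 1 for every atom
print(torch.allclose(mesh.sum(), charges.sum()))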

If we apply a non-linear function before summing, these sum rules apply only approximately.

positions.grad.zero_()
charges.grad.zero_()
cell.grad.zero_()

value2 = torch.sin(mesh).sum()
value2.backward(retain_graph=True)

print(
    f"""
Position gradients:
{positions.grad.T}

Cell gradients:
{cell.grad}

Charges gradients:
{charges.grad.T}
"""
)
Position gradients:
tensor([[-0.1328, -0.1042,  0.0396,  0.0984, -0.0198, -0.0627, -0.0130,  0.1468],
        [-0.0235, -0.0561,  0.1602, -0.1015,  0.1001, -0.0472,  0.0776,  0.1489],
        [ 0.0749,  0.0315, -0.0837, -0.0269,  0.1536, -0.0734,  0.0683, -0.1032]],
       dtype=torch.float64)

Cell gradients:
tensor([[-0.0547,  0.0234,  0.0970],
        [-0.1439, -0.1427,  0.0822],
        [-0.0115, -0.1393, -0.0238]], dtype=torch.float64)

Charges gradients:
tensor([[0.7647, 0.5572, 0.7826, 0.8653, 0.7911, 0.8245, 0.8831, 0.8013]],
       dtype=torch.float64)

Indirect functions of the weights

It is also possible to have the atomic weights be a function of other quantities. For instance, suppose there is an external electric field along x, and that the weights should be proportional to the electrostatic energy of each atom in that field (NB: defining an electric field in a periodic setting is not so simple; this is just a toy example).

positions.grad.zero_()
charges.grad.zero_()
cell.grad.zero_()

weights = charges * positions[:, :1]
mesh3 = interpolator.points_to_mesh(weights)

value3 = mesh3.sum()
value3.backward()

print(
    f"""
Position gradients:
{positions.grad.T}

Cell gradients:
{cell.grad}

Charges gradients:
{charges.grad.T}
"""
)
Position gradients:
tensor([[ 0.9216, -1.0203, -1.0203,  1.0234, -1.0189,  0.8147,  0.8905, -1.1390],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]],
       dtype=torch.float64)

Cell gradients:
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)

Charges gradients:
tensor([[-0.4354,  2.5653, -0.1268,  2.9276,  0.0661,  3.2907, -0.0991,  3.1386]],
       dtype=torch.float64)

Optimizable k-space filter

The operations in a KSpaceFilter can also be differentiated through.

A parametric k-space filter

We define a filter with multiple smearing parameters, which are applied separately to the different mesh channels.

class ParametricKernel(torch.nn.Module):
    def __init__(self, sigma: torch.Tensor, a0: torch.Tensor):
        super().__init__()
        self._sigma = sigma
        self._a0 = a0

    def kernel_from_k_sq(self, k_sq):
        filter = torch.stack([torch.exp(-k_sq * s**2 / 2) for s in self._sigma])
        filter[0, :] *= self._a0[0] / (1 + k_sq)
        filter[1, :] *= self._a0[1] / (1 + k_sq**3)
        return filter

We define two-channel weights (so as to obtain a two-channel mesh), and declare the filter parameters as optimizable quantities.

weights = torch.tensor(
    [
        [1.0, 1.0],
        [-1.0, 1.0],
        [-1.0, 1.0],
        [1.0, 1.0],
        [-1.0, 1.0],
        [1.0, 1.0],
        [1.0, 1.0],
        [-1.0, 1.0],
    ],
    dtype=dtype,
    device=device,
)

torch.autograd.set_detect_anomaly(True)
sigma = torch.tensor([1.0, 0.5], dtype=dtype, device=device)
a0 = torch.tensor([1.0, 2.0], dtype=dtype, device=device)

positions = positions.detach()
cell = cell.detach()
positions.requires_grad_(True)
cell.requires_grad_(True)

weights = weights.detach()
sigma = sigma.detach()
a0 = a0.detach()
weights.requires_grad_(True)
sigma.requires_grad_(True)
a0.requires_grad_(True)

We compute the mesh, apply the filter, and complete the PME-like operation by evaluating the transformed mesh at the atom positions.

interpolator = torchpme.lib.MeshInterpolator(cell, ns, 3, method="Lagrange")
interpolator.compute_weights(positions)
mesh = interpolator.points_to_mesh(weights)

kernel = ParametricKernel(sigma, a0)
kernel_filter = torchpme.lib.KSpaceFilter(cell, ns, kernel=kernel)

filtered = kernel_filter.forward(mesh)

filtered_at_positions = interpolator.mesh_to_points(filtered)

We compute a (rather arbitrary) function of the outputs, backpropagate, and then print the gradients. With this messy non-linear function, everything has nonzero gradients.
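
The scalar that was backpropagated to produce the numbers below is not reproduced on this page. As a minimal sketch (an assumption, not the original reduction), any non-linear scalar function of the atom-centered values would do, e.g.

# hypothetical reduction, standing in for the (unspecified) scalar used below
value = torch.sin(filtered_at_positions).sum()
value.backward()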

print(
    f"""
Value: {value}

Position gradients:
{positions.grad.T}

Cell gradients:
{cell.grad}

Weights gradients:
{weights.grad.T}

Param. a0:
{a0.grad}

Param. sigma:
{sigma.grad}
"""
)
Value: 0.09174260726413486

Position gradients:
tensor([[-7.9879e-03,  2.3547e-03, -3.0414e-04,  7.0431e-05,  5.3453e-03,
          1.7170e-02,  2.4995e-04, -6.4724e-03],
        [ 3.7694e-03,  7.8182e-03, -1.0497e-02,  1.5011e-02,  1.3578e-02,
          8.5528e-03, -1.2111e-02, -1.5039e-02],
        [-6.2817e-03,  1.6610e-03, -1.6096e-02, -4.1102e-03,  8.5094e-03,
          9.8638e-03, -3.0888e-03,  7.3603e-03]], dtype=torch.float64)

Cell gradients:
tensor([[ 0.0402, -0.0084, -0.0087],
        [ 0.0036,  0.0526,  0.0105],
        [-0.0086,  0.0036,  0.0310]], dtype=torch.float64)

Weights gradients:
tensor([[ 0.0178, -0.0275, -0.0240,  0.0188, -0.0255,  0.0137,  0.0132, -0.0277],
        [ 0.0126, -0.0279, -0.0229,  0.0140, -0.0295,  0.0061,  0.0028, -0.0317]],
       dtype=torch.float64)

Param. a0:
tensor([ 0.1681, -0.0382], dtype=torch.float64)

Param. sigma:
tensor([-0.5573,  0.0097], dtype=torch.float64)
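
As a minimal sketch of how such a filter could actually be fitted (this is an addition to the original example: the optimization target below is a made-up placeholder, and all other names reuse objects defined above), the same operations can be wrapped in a small optimization loop:

optimizer = torch.optim.Adam([sigma, a0], lr=1e-2)
# made-up reference values for the atom-centered outputs
target = torch.zeros_like(filtered_at_positions)
for _ in range(10):
    optimizer.zero_grad()
    interpolator.compute_weights(positions)
    mesh = interpolator.points_to_mesh(weights)
    # refresh the k-space filter so it uses the current parameter values
    kernel_filter.update(cell, ns)
    predicted = interpolator.mesh_to_points(kernel_filter.forward(mesh))
    loss = ((predicted - target) ** 2).sum()
    loss.backward()
    optimizer.step()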

A torch module based on torchpme

It is also possible to combine all of this into a custom torch.nn.Module, which is the first step towards designing a training pipeline based on a custom torchpme model.

We start by defining a Yukawa-like potential, and a (rather contrived) model that combines a Fourier filter with a multi-layer perceptron to post-process the charges and the “potential”.

# Define the kernel
class SmearedCoulomb(torchpme.lib.KSpaceKernel):
    def __init__(self, sigma2):
        super().__init__()
        self._sigma2 = sigma2

    def kernel_from_k_sq(self, k_sq):
        # we use a mask to set to zero the Gamma-point filter
        mask = torch.ones_like(k_sq, dtype=torch.bool, device=k_sq.device)
        mask[..., 0, 0, 0] = False
        potential = torch.zeros_like(k_sq)
        potential[mask] = torch.exp(-k_sq[mask] * self._sigma2 * 0.5) / k_sq[mask]
        return potential


# Define the module
class KSpaceModule(torch.nn.Module):
    """A demonstrative model combining torchpme and a multi-layer perceptron"""

    def __init__(
        self, mesh_spacing: float = 0.5, sigma2: float = 1.0, hidden_sizes=None
    ):
        super().__init__()
        self._mesh_spacing = mesh_spacing

        # degree of smearing as an optimizable parameter
        self._sigma2 = torch.nn.Parameter(
            torch.tensor(sigma2, dtype=dtype, device=device)
        )

        dummy_cell = torch.eye(3, dtype=dtype)
        self._mesh_interpolator = torchpme.lib.MeshInterpolator(
            cell=dummy_cell,
            ns_mesh=torch.tensor([1, 1, 1]),
            interpolation_nodes=3,
            method="Lagrange",
        )
        self._kernel_filter = torchpme.lib.KSpaceFilter(
            cell=dummy_cell,
            ns_mesh=torch.tensor([1, 1, 1]),
            kernel=SmearedCoulomb(self._sigma2),
        )

        if hidden_sizes is None:  # default architecture
            hidden_sizes = [10, 10]

        # a neural network to process "charge and potential"
        last_size = 2  # input is charge and potential
        self._layers = torch.nn.ModuleList()
        for hidden_size in hidden_sizes:
            self._layers.append(
                torch.nn.Linear(last_size, hidden_size, dtype=dtype, device=device)
            )
            self._layers.append(torch.nn.Tanh())
            last_size = hidden_size
        self._output_layer = torch.nn.Linear(
            last_size, 1, dtype=dtype, device=device
        )  # outputs one value

    def forward(self, positions, cell, charges):
        # use a helper function to get the mesh size given the resolution
        ns_mesh = torchpme.lib.get_ns_mesh(cell, self._mesh_spacing)
        # ... but for this toy example we override it with a small fixed mesh
        ns_mesh = torch.tensor([4, 4, 4])

        self._mesh_interpolator.update(cell=cell, ns_mesh=ns_mesh)
        self._mesh_interpolator.compute_weights(positions)
        mesh = self._mesh_interpolator.points_to_mesh(charges)

        self._kernel_filter.update(cell, ns_mesh)
        mesh = self._kernel_filter.forward(mesh)
        pot = self._mesh_interpolator.mesh_to_points(mesh)

        x = torch.hstack([charges, pot])
        for layer in self._layers:
            x = layer(x)
        # Output layer
        x = self._output_layer(x)
        return x.sum()

We create an instance of the model and evaluate it.

my_module = KSpaceModule(sigma2=1.0, mesh_spacing=1.0, hidden_sizes=[10, 4, 10])

# (re-)initialize vectors

charges = charges.detach()
positions = positions.detach()
cell = cell.detach()
charges.requires_grad_(True)
positions.requires_grad_(True)
cell.requires_grad_(True)

value = my_module.forward(positions, cell, charges)
value.backward()

The gradients are computed, and look reasonable!

print(
    f"""
Value: {value}

Position gradients:
{positions.grad.T}

Cell gradients:
{cell.grad}

Charges gradients:
{charges.grad.T}
"""
)
Value: 2.317331764158612

Position gradients:
tensor([[-5.8288e-05, -4.5998e-05, -1.6066e-04, -1.2151e-04, -1.5586e-04,
          7.9277e-05, -8.0956e-05,  2.0819e-04],
        [-1.0582e-04,  5.6672e-05,  6.0971e-05, -2.6420e-05, -2.4710e-04,
         -2.3369e-04,  1.2400e-04,  1.7713e-04],
        [ 2.1109e-04, -1.3603e-04,  2.2981e-04,  3.4570e-05, -2.1192e-05,
          1.2179e-04, -2.5561e-04, -1.4806e-04]], dtype=torch.float64)

Cell gradients:
tensor([[-1.6788e-03,  2.8773e-05,  6.9028e-05],
        [ 7.5213e-05, -1.7471e-03,  6.9558e-05],
        [-1.1720e-05,  1.0415e-04, -1.4248e-03]], dtype=torch.float64)

Charges gradients:
tensor([[0.0907, 0.0491, 0.0491, 0.0880, 0.0491, 0.0928, 0.0914, 0.0457]],
       dtype=torch.float64)

… also on the MLP parameters!
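
The snippet that produced the listing below is not included on this page; as a minimal sketch (an assumption, not necessarily the original code), one can print the parameter dictionaries of the hidden layers:

for layer in my_module._layers:
    # Linear layers print their weight and bias; Tanh layers print an empty dict
    print(dict(layer.named_parameters()))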

{'weight': Parameter containing:
tensor([[ 0.0427,  0.4937],
        [-0.2863,  0.2163],
        [-0.1467,  0.1987],
        [-0.4445, -0.4371],
        [ 0.0655,  0.6123],
        [-0.7011,  0.4613],
        [ 0.1286, -0.3070],
        [ 0.5212, -0.4941],
        [ 0.0126,  0.3867],
        [ 0.0575,  0.0018]], dtype=torch.float64, requires_grad=True), 'bias': Parameter containing:
tensor([ 0.2709, -0.2062,  0.0294,  0.0188,  0.5238,  0.6004, -0.5211, -0.2365,
         0.6442, -0.1731], dtype=torch.float64, requires_grad=True)}
{}
{'weight': Parameter containing:
tensor([[ 0.2117,  0.1724,  0.1401, -0.1087,  0.0799, -0.2246,  0.1223,  0.2900,
          0.0260, -0.0147],
        [ 0.0231,  0.2833, -0.2252, -0.2948,  0.1909,  0.1607,  0.2409,  0.2119,
          0.1756, -0.0786],
        [-0.2551,  0.2784, -0.1645,  0.1819,  0.0732,  0.0837,  0.0114,  0.0257,
          0.1108,  0.2979],
        [-0.2434,  0.0693,  0.1852,  0.1666,  0.2755,  0.2535, -0.0136, -0.2607,
         -0.0189,  0.2182]], dtype=torch.float64, requires_grad=True), 'bias': Parameter containing:
tensor([-0.0236, -0.0401, -0.1675,  0.2271], dtype=torch.float64,
       requires_grad=True)}
{}
{'weight': Parameter containing:
tensor([[ 0.4973,  0.3861,  0.3834,  0.4129],
        [ 0.2016,  0.2200, -0.3462,  0.3720],
        [-0.2122, -0.2532,  0.4511,  0.2761],
        [ 0.1117, -0.1797,  0.0610,  0.0573],
        [-0.2799,  0.2741, -0.2999,  0.1364],
        [ 0.2040,  0.4799,  0.3296, -0.2455],
        [-0.1693,  0.0848, -0.4465,  0.3113],
        [ 0.1281,  0.1472,  0.4921, -0.4667],
        [ 0.2560,  0.0359,  0.3709,  0.0521],
        [ 0.2449,  0.1820,  0.4035, -0.0759]], dtype=torch.float64,
       requires_grad=True), 'bias': Parameter containing:
tensor([-0.4317,  0.4716, -0.0125,  0.2594,  0.0538, -0.1705,  0.3507,  0.3776,
        -0.3088, -0.2706], dtype=torch.float64, requires_grad=True)}
{}

It’s always good practice to run a gradcheck.

my_module.zero_grad()
check = torch.autograd.gradcheck(
    my_module,
    (
        torch.randn((16, 3), device=device, dtype=dtype, requires_grad=True),
        torch.randn((3, 3), device=device, dtype=dtype, requires_grad=True),
        torch.randn((16, 1), device=device, dtype=dtype, requires_grad=True),
    ),
)
if check:
    print("gradcheck passed for custom torch-pme module")
else:
    raise ValueError("gradcheck failed for custom torch-pme module")
gradcheck passed for custom torch-pme module

Jitting a custom module

The custom module can also be jitted!

old_cell_grad = cell.grad.clone()
jit_module = torch.jit.script(my_module)

jit_charges = charges.detach()
jit_positions = positions.detach()
jit_cell = cell.detach()
jit_cell.requires_grad_(True)
jit_charges.requires_grad_(True)
jit_positions.requires_grad_(True)

jit_value = jit_module.forward(jit_positions, jit_cell, jit_charges)
jit_value.backward()

The values match to within machine precision.

print(
    f"""
Delta-Value: {value - jit_value}

Delta-Position gradients:
{positions.grad.T - jit_positions.grad.T}

Delta-Cell gradients:
{cell.grad - jit_cell.grad}

Delta-Charges gradients:
{charges.grad.T - jit_charges.grad.T}
"""
)
Delta-Value: 0.0

Delta-Position gradients:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float64)

Delta-Cell gradients:
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]], dtype=torch.float64)

Delta-Charges gradients:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float64)

We can also compare the execution time of the PyTorch and scripted versions of the module (depending on the system, the relative efficiency of the two evaluations could go either way, as this system is too small for the difference to matter!).

duration = 0.0
for _i in range(20):
    my_module.zero_grad()
    positions = positions.detach()
    cell = cell.detach()
    charges = charges.detach()
    duration -= time()
    value = my_module.forward(positions, cell, charges)
    value.backward()
    if device == "cuda":
        torch.cuda.synchronize()
    duration += time()
time_python = (duration) * 1e3 / 20

duration = 0.0
for _i in range(20):
    jit_module.zero_grad()
    positions = positions.detach()
    cell = cell.detach()
    charges = charges.detach()
    duration -= time()
    value = jit_module.forward(positions, cell, charges)
    value.backward()
    if device == "cuda":
        torch.cuda.synchronize()
    duration += time()
time_jit = (duration) * 1e3 / 20
print(f"Evaluation time:\nPytorch: {time_python}ms\nJitted:  {time_jit}ms")
Evaluation time:
Pytorch: 7.775735855102539ms
Jitted:  9.124588966369629ms
