Custom models with automatic differentiation
Authors: Michele Ceriotti @ceriottm
This example showcases how the main building blocks of torchpme, MeshInterpolator and KSpaceFilter, can be combined creatively to construct arbitrary models that incorporate long-range structural correlations.
None of the models presented here is particularly meaningful in itself, and their use in an ML setting (including the definition of an appropriate loss and its optimization) is left as an exercise to the reader.
from time import time
import ase
import torch
import torchpme
device = "cpu"
dtype = torch.float64
rng = torch.Generator()
rng.manual_seed(32)
<torch._C.Generator object at 0x7f8fb059cb70>
Generate a trial structure – a distorted rocksalt structure with perturbed positions and charges
structure = ase.Atoms(
positions=[
[0, 0, 0],
[3, 0, 0],
[0, 3, 0],
[3, 3, 0],
[0, 0, 3],
[3, 0, 3],
[0, 3, 3],
[3, 3, 3],
],
cell=[6, 6, 6],
symbols="NaClClNaClNaNaCl",
)
displacement = torch.normal(
mean=0.0, std=2.5e-1, size=(len(structure), 3), generator=rng
)
structure.positions += displacement.numpy()
charges = torch.tensor(
[[1.0], [-1.0], [-1.0], [1.0], [-1.0], [1.0], [1.0], [-1.0]],
dtype=dtype,
device=device,
)
charges += torch.normal(mean=0.0, std=1e-1, size=(len(charges), 1), generator=rng)
positions = torch.from_numpy(structure.positions).to(device=device, dtype=dtype)
cell = torch.from_numpy(structure.cell.array).to(device=device, dtype=dtype)
Autodifferentiation through the core torchpme classes
We begin by showing how it is possible to compute a function of the internal state for the core classes, and to differentiate with respect to the structural and input parameters.
Functions of the atom density
The construction of a “decorated atom density” through MeshInterpolator is easily differentiable.
We only need to request a gradient evaluation, evaluate the grid, and compute
a function of the grid points (again, this is a proof-of-principle example,
probably not very useful in practice).
positions.requires_grad_(True)
charges.requires_grad_(True)
cell.requires_grad_(True)
ns = torch.tensor([5, 5, 5])
interpolator = torchpme.lib.MeshInterpolator(
cell=cell, ns_mesh=ns, interpolation_nodes=3, method="Lagrange"
)
interpolator.compute_weights(positions)
mesh = interpolator.points_to_mesh(charges)
value = mesh.sum()
The gradients can be computed by just running backward on the end result. Because of the sum rules that apply to the interpolation scheme, the gradients with respect to positions and cell entries are zero, and the gradients relative to the charges are all 1.
# we keep the graph to compute another quantity
value.backward(retain_graph=True)
print(
f"""
Position gradients:
{positions.grad.T}
Cell gradients:
{cell.grad}
Charges gradients:
{charges.grad.T}
"""
)
Position gradients:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float64)
Cell gradients:
tensor([[-0., -0., -0.],
[-0., -0., -0.],
[-0., -0., -0.]], dtype=torch.float64)
Charges gradients:
tensor([[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000]],
dtype=torch.float64)
If we apply a non-linear function before summing, these sum rules apply only approximately.
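The code for this cell is not reproduced on this page; a minimal sketch of what it could look like is given below (the choice of torch.sin as the non-linearity is an assumption, any smooth element-wise function would do).

# reset the gradients accumulated by the previous backward pass
positions.grad.zero_()
charges.grad.zero_()
cell.grad.zero_()

# apply an element-wise non-linearity to the mesh values before summing;
# torch.sin is an arbitrary (assumed) choice
value2 = torch.sin(mesh).sum()
# retain the graph, as the interpolation weights are differentiated again below
value2.backward(retain_graph=True)

print(
    f"""
Position gradients:
{positions.grad.T}
Cell gradients:
{cell.grad}
Charges gradients:
{charges.grad.T}
"""
)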
Position gradients:
tensor([[-0.1328, -0.1042, 0.0396, 0.0984, -0.0198, -0.0627, -0.0130, 0.1468],
[-0.0235, -0.0561, 0.1602, -0.1015, 0.1001, -0.0472, 0.0776, 0.1489],
[ 0.0749, 0.0315, -0.0837, -0.0269, 0.1536, -0.0734, 0.0683, -0.1032]],
dtype=torch.float64)
Cell gradients:
tensor([[-0.0547, 0.0234, 0.0970],
[-0.1439, -0.1427, 0.0822],
[-0.0115, -0.1393, -0.0238]], dtype=torch.float64)
Charges gradients:
tensor([[0.7647, 0.5572, 0.7826, 0.8653, 0.7911, 0.8245, 0.8831, 0.8013]],
dtype=torch.float64)
Indirect functions of the weights
It is possible to have the atomic weights be a function of other quantities. For instance, pretend there is an external electric field along x, and that the weights should be proportional to the electrostatic energy at each atom position (NB: defining an electric field in a periodic setting is not so simple; this is just a toy example).
positions.grad.zero_()
charges.grad.zero_()
cell.grad.zero_()
weights = charges * positions[:, :1]
mesh3 = interpolator.points_to_mesh(weights)
value3 = mesh3.sum()
value3.backward()
print(
f"""
Position gradients:
{positions.grad.T}
Cell gradients:
{cell.grad}
Charges gradients:
{charges.grad.T}
"""
)
Position gradients:
tensor([[ 0.9216, -1.0203, -1.0203, 1.0234, -1.0189, 0.8147, 0.8905, -1.1390],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
dtype=torch.float64)
Cell gradients:
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=torch.float64)
Charges gradients:
tensor([[-0.4354, 2.5653, -0.1268, 2.9276, 0.0661, 3.2907, -0.0991, 3.1386]],
dtype=torch.float64)
Optimizable k-space filter
The operations in a KSpaceFilter can also be differentiated through.
A parametric k-space filter
We define a filter with multiple smearing parameters that are applied separately to multiple mesh channels.
class ParametricKernel(torch.nn.Module):
    def __init__(self, sigma: torch.Tensor, a0: torch.Tensor):
        super().__init__()
        self._sigma = sigma
        self._a0 = a0

    def kernel_from_k_sq(self, k_sq):
        filter = torch.stack([torch.exp(-k_sq * s**2 / 2) for s in self._sigma])
        filter[0, :] *= self._a0[0] / (1 + k_sq)
        filter[1, :] *= self._a0[1] / (1 + k_sq**3)
        return filter
We define two-channel weights (to get a two-channel mesh), and set up the parameters as optimizable quantities.
weights = torch.tensor(
[
[1.0, 1.0],
[-1.0, 1.0],
[-1.0, 1.0],
[1.0, 1.0],
[-1.0, 1.0],
[1.0, 1.0],
[1.0, 1.0],
[-1.0, 1.0],
],
dtype=dtype,
device=device,
)
torch.autograd.set_detect_anomaly(True)
sigma = torch.tensor([1.0, 0.5], dtype=dtype, device=device)
a0 = torch.tensor([1.0, 2.0], dtype=dtype, device=device)
positions = positions.detach()
cell = cell.detach()
positions.requires_grad_(True)
cell.requires_grad_(True)
weights = weights.detach()
sigma = sigma.detach()
a0 = a0.detach()
weights.requires_grad_(True)
sigma.requires_grad_(True)
a0.requires_grad_(True)
tensor([1., 2.], dtype=torch.float64, requires_grad=True)
Compute the mesh, apply the filter, and also complete the PME-like operation by evaluating the transformed mesh at the atom positions
interpolator = torchpme.lib.MeshInterpolator(cell, ns, 3, method="Lagrange")
interpolator.compute_weights(positions)
mesh = interpolator.points_to_mesh(weights)
kernel = ParametricKernel(sigma, a0)
kernel_filter = torchpme.lib.KSpaceFilter(cell, ns, kernel=kernel)
filtered = kernel_filter.forward(mesh)
filtered_at_positions = interpolator.mesh_to_points(filtered)
Compute a (rather arbitrary) function of the outputs, backpropagate, and then print the gradients. With this messy non-linear function, everything has nonzero gradients.
value = (charges * filtered_at_positions).sum()
value.backward()
print(
f"""
Value: {value}
Position gradients:
{positions.grad.T}
Cell gradients:
{cell.grad}
Weights gradients:
{weights.grad.T}
Param. a0:
{a0.grad}
Param. sigma:
{sigma.grad}
"""
)
Value: 0.09174260726413486
Position gradients:
tensor([[-7.9879e-03, 2.3547e-03, -3.0414e-04, 7.0431e-05, 5.3453e-03,
1.7170e-02, 2.4995e-04, -6.4724e-03],
[ 3.7694e-03, 7.8182e-03, -1.0497e-02, 1.5011e-02, 1.3578e-02,
8.5528e-03, -1.2111e-02, -1.5039e-02],
[-6.2817e-03, 1.6610e-03, -1.6096e-02, -4.1102e-03, 8.5094e-03,
9.8638e-03, -3.0888e-03, 7.3603e-03]], dtype=torch.float64)
Cell gradients:
tensor([[ 0.0402, -0.0084, -0.0087],
[ 0.0036, 0.0526, 0.0105],
[-0.0086, 0.0036, 0.0310]], dtype=torch.float64)
Weights gradients:
tensor([[ 0.0178, -0.0275, -0.0240, 0.0188, -0.0255, 0.0137, 0.0132, -0.0277],
[ 0.0126, -0.0279, -0.0229, 0.0140, -0.0295, 0.0061, 0.0028, -0.0317]],
dtype=torch.float64)
Param. a0:
tensor([ 0.1681, -0.0382], dtype=torch.float64)
Param. sigma:
tensor([-0.5573, 0.0097], dtype=torch.float64)
A torch module based on torchpme
It is also possible to combine all this in a custom torch.nn.Module, which is the first step towards designing a model training pipeline based on a custom torchpme model.
We start by defining a Yukawa-like potential, and a (rather contrived) model that combines a Fourier filter with a multi-layer perceptron to post-process the charges and the “potential”.
# Define the kernel
class SmearedCoulomb(torchpme.lib.KSpaceKernel):
    def __init__(self, sigma2):
        super().__init__()
        self._sigma2 = sigma2

    def kernel_from_k_sq(self, k_sq):
        # we use a mask to set to zero the Gamma-point filter
        mask = torch.ones_like(k_sq, dtype=torch.bool, device=k_sq.device)
        mask[..., 0, 0, 0] = False
        potential = torch.zeros_like(k_sq)
        potential[mask] = torch.exp(-k_sq[mask] * self._sigma2 * 0.5) / k_sq[mask]
        return potential
# Define the module
class KSpaceModule(torch.nn.Module):
    """A demonstrative model combining torchpme and a multi-layer perceptron"""

    def __init__(
        self, mesh_spacing: float = 0.5, sigma2: float = 1.0, hidden_sizes=None
    ):
        super().__init__()
        self._mesh_spacing = mesh_spacing

        # degree of smearing as an optimizable parameter
        self._sigma2 = torch.nn.Parameter(
            torch.tensor(sigma2, dtype=dtype, device=device)
        )

        dummy_cell = torch.eye(3, dtype=dtype)
        self._mesh_interpolator = torchpme.lib.MeshInterpolator(
            cell=dummy_cell,
            ns_mesh=torch.tensor([1, 1, 1]),
            interpolation_nodes=3,
            method="Lagrange",
        )
        self._kernel_filter = torchpme.lib.KSpaceFilter(
            cell=dummy_cell,
            ns_mesh=torch.tensor([1, 1, 1]),
            kernel=SmearedCoulomb(self._sigma2),
        )

        if hidden_sizes is None:  # default architecture
            hidden_sizes = [10, 10]

        # a neural network to process "charge and potential"
        last_size = 2  # input is charge and potential
        self._layers = torch.nn.ModuleList()
        for hidden_size in hidden_sizes:
            self._layers.append(
                torch.nn.Linear(last_size, hidden_size, dtype=dtype, device=device)
            )
            self._layers.append(torch.nn.Tanh())
            last_size = hidden_size
        self._output_layer = torch.nn.Linear(
            last_size, 1, dtype=dtype, device=device
        )  # outputs one value

    def forward(self, positions, cell, charges):
        # use a helper function to get the mesh size given the resolution ...
        ns_mesh = torchpme.lib.get_ns_mesh(cell, self._mesh_spacing)
        # ... which is then overridden here with a fixed mesh size
        ns_mesh = torch.tensor([4, 4, 4])

        self._mesh_interpolator.update(cell=cell, ns_mesh=ns_mesh)
        self._mesh_interpolator.compute_weights(positions)
        mesh = self._mesh_interpolator.points_to_mesh(charges)

        self._kernel_filter.update(cell, ns_mesh)
        mesh = self._kernel_filter.forward(mesh)

        pot = self._mesh_interpolator.mesh_to_points(mesh)
        x = torch.hstack([charges, pot])
        for layer in self._layers:
            x = layer(x)
        # Output layer
        x = self._output_layer(x)
        return x.sum()
Create an instance of the model and evaluate it.
my_module = KSpaceModule(sigma2=1.0, mesh_spacing=1.0, hidden_sizes=[10, 4, 10])
# (re-)initialize vectors
charges = charges.detach()
positions = positions.detach()
cell = cell.detach()
charges.requires_grad_(True)
positions.requires_grad_(True)
cell.requires_grad_(True)
value = my_module.forward(positions, cell, charges)
value.backward()
The gradients are computed, and they look reasonable!
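The print call producing the output below is not reproduced here; it mirrors the ones used above.

# print the value and the gradients of the custom module (same pattern as before)
print(
    f"""
Value: {value}
Position gradients:
{positions.grad.T}
Cell gradients:
{cell.grad}
Charges gradients:
{charges.grad.T}
"""
)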
Value: 2.317331764158612
Position gradients:
tensor([[-5.8288e-05, -4.5998e-05, -1.6066e-04, -1.2151e-04, -1.5586e-04,
7.9277e-05, -8.0956e-05, 2.0819e-04],
[-1.0582e-04, 5.6672e-05, 6.0971e-05, -2.6420e-05, -2.4710e-04,
-2.3369e-04, 1.2400e-04, 1.7713e-04],
[ 2.1109e-04, -1.3603e-04, 2.2981e-04, 3.4570e-05, -2.1192e-05,
1.2179e-04, -2.5561e-04, -1.4806e-04]], dtype=torch.float64)
Cell gradients:
tensor([[-1.6788e-03, 2.8773e-05, 6.9028e-05],
[ 7.5213e-05, -1.7471e-03, 6.9558e-05],
[-1.1720e-05, 1.0415e-04, -1.4248e-03]], dtype=torch.float64)
Charges gradients:
tensor([[0.0907, 0.0491, 0.0491, 0.0880, 0.0491, 0.0928, 0.0914, 0.0457]],
dtype=torch.float64)
… also on the MLP parameters!
for layer in my_module._layers:
    print(layer._parameters)
{'weight': Parameter containing:
tensor([[ 0.0427, 0.4937],
[-0.2863, 0.2163],
[-0.1467, 0.1987],
[-0.4445, -0.4371],
[ 0.0655, 0.6123],
[-0.7011, 0.4613],
[ 0.1286, -0.3070],
[ 0.5212, -0.4941],
[ 0.0126, 0.3867],
[ 0.0575, 0.0018]], dtype=torch.float64, requires_grad=True), 'bias': Parameter containing:
tensor([ 0.2709, -0.2062, 0.0294, 0.0188, 0.5238, 0.6004, -0.5211, -0.2365,
0.6442, -0.1731], dtype=torch.float64, requires_grad=True)}
{}
{'weight': Parameter containing:
tensor([[ 0.2117, 0.1724, 0.1401, -0.1087, 0.0799, -0.2246, 0.1223, 0.2900,
0.0260, -0.0147],
[ 0.0231, 0.2833, -0.2252, -0.2948, 0.1909, 0.1607, 0.2409, 0.2119,
0.1756, -0.0786],
[-0.2551, 0.2784, -0.1645, 0.1819, 0.0732, 0.0837, 0.0114, 0.0257,
0.1108, 0.2979],
[-0.2434, 0.0693, 0.1852, 0.1666, 0.2755, 0.2535, -0.0136, -0.2607,
-0.0189, 0.2182]], dtype=torch.float64, requires_grad=True), 'bias': Parameter containing:
tensor([-0.0236, -0.0401, -0.1675, 0.2271], dtype=torch.float64,
requires_grad=True)}
{}
{'weight': Parameter containing:
tensor([[ 0.4973, 0.3861, 0.3834, 0.4129],
[ 0.2016, 0.2200, -0.3462, 0.3720],
[-0.2122, -0.2532, 0.4511, 0.2761],
[ 0.1117, -0.1797, 0.0610, 0.0573],
[-0.2799, 0.2741, -0.2999, 0.1364],
[ 0.2040, 0.4799, 0.3296, -0.2455],
[-0.1693, 0.0848, -0.4465, 0.3113],
[ 0.1281, 0.1472, 0.4921, -0.4667],
[ 0.2560, 0.0359, 0.3709, 0.0521],
[ 0.2449, 0.1820, 0.4035, -0.0759]], dtype=torch.float64,
requires_grad=True), 'bias': Parameter containing:
tensor([-0.4317, 0.4716, -0.0125, 0.2594, 0.0538, -0.1705, 0.3507, 0.3776,
-0.3088, -0.2706], dtype=torch.float64, requires_grad=True)}
{}
It’s always good to run some gradcheck…
my_module.zero_grad()
check = torch.autograd.gradcheck(
    my_module,
    (
        torch.randn((16, 3), device=device, dtype=dtype, requires_grad=True),
        torch.randn((3, 3), device=device, dtype=dtype, requires_grad=True),
        torch.randn((16, 1), device=device, dtype=dtype, requires_grad=True),
    ),
)

if check:
    print("gradcheck passed for custom torch-pme module")
else:
    raise ValueError("gradcheck failed for custom torch-pme module")
gradcheck passed for custom torch-pme module
Jitting a custom module¶
The custom module can also be jitted!
old_cell_grad = cell.grad.clone()
jit_module = torch.jit.script(my_module)
jit_charges = charges.detach()
jit_positions = positions.detach()
jit_cell = cell.detach()
jit_cell.requires_grad_(True)
jit_charges.requires_grad_(True)
jit_positions.requires_grad_(True)
jit_value = jit_module.forward(jit_positions, jit_cell, jit_charges)
jit_value.backward()
Values match within machine precision
print(
f"""
Delta-Value: {value - jit_value}
Delta-Position gradients:
{positions.grad.T - jit_positions.grad.T}
Delta-Cell gradients:
{cell.grad - jit_cell.grad}
Delta-Charges gradients:
{charges.grad.T - jit_charges.grad.T}
"""
)
Delta-Value: 0.0
Delta-Position gradients:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float64)
Delta-Cell gradients:
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=torch.float64)
Delta-Charges gradients:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float64)
We can also evaluate the difference in execution time between the PyTorch and scripted versions of the module (depending on the system, the relative efficiency of the two evaluations could go either way, as this system is too small to make a difference!).
duration = 0.0
for _i in range(20):
    my_module.zero_grad()
    positions = positions.detach()
    cell = cell.detach()
    charges = charges.detach()
    duration -= time()
    value = my_module.forward(positions, cell, charges)
    value.backward()
    if device == "cuda":
        torch.cuda.synchronize()
    duration += time()
time_python = (duration) * 1e3 / 20
duration = 0.0
for _i in range(20):
    jit_module.zero_grad()
    positions = positions.detach()
    cell = cell.detach()
    charges = charges.detach()
    duration -= time()
    value = jit_module.forward(positions, cell, charges)
    value.backward()
    if device == "cuda":
        torch.cuda.synchronize()
    duration += time()
time_jit = (duration) * 1e3 / 20
print(f"Evaluation time:\nPytorch: {time_python}ms\nJitted: {time_jit}ms")
Evaluation time:
Pytorch: 7.775735855102539ms
Jitted: 9.124588966369629ms
Total running time of the script: (0 minutes 6.142 seconds)