.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "examples/12-padding-example.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_examples_12-padding-example.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_examples_12-padding-example.py:

Batched Ewald Computation with Padding
======================================

This example demonstrates how to compute Ewald potentials for a batch of
systems with different numbers of atoms using padding. The idea is to pad
atomic positions, charges, and neighbor lists to the same length and to use
masks so that padded entries are ignored during the computation.

Note that batching systems of varying sizes in this way can increase the
computational cost during model training, since padded atoms are included in
the batched operations even though they do not contribute physically.

.. GENERATED FROM PYTHON SOURCE LINES 14-25

.. code-block:: Python

    import time

    import torch
    import vesin
    from torch.nn.utils.rnn import pad_sequence

    import torchpme

    dtype = torch.float64
    cutoff = 4.4

.. GENERATED FROM PYTHON SOURCE LINES 26-27

Example: a batch of 5 systems of different sizes

.. GENERATED FROM PYTHON SOURCE LINES 27-91

.. code-block:: Python

    systems = [
        {
            "symbols": ("Cs", "Cl"),
            "positions": torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=dtype),
            "charges": torch.tensor([[1.0], [-1.0]], dtype=dtype),
            "cell": torch.eye(3, dtype=dtype) * 3.0,
            "pbc": torch.tensor([True, True, True]),
        },
        {
            "symbols": ("Na", "Cl", "Cl"),
            "positions": torch.tensor(
                [(0, 0, 0), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)], dtype=dtype
            ),
            "charges": torch.tensor([[1.0], [-1.0], [-1.0]], dtype=dtype),
            "cell": torch.eye(3, dtype=dtype) * 4.0,
            "pbc": torch.tensor([True, True, True]),
        },
        {
            "symbols": ("K", "Br", "Br", "K"),
            "positions": torch.tensor(
                [(0, 0, 0), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25), (0.75, 0.75, 0.75)],
                dtype=dtype,
            ),
            "charges": torch.tensor([[1.0], [-1.0], [-1.0], [1.0]], dtype=dtype),
            "cell": torch.eye(3, dtype=dtype) * 5.0,
            "pbc": torch.tensor([True, True, True]),
        },
        {
            "symbols": ("Mg", "O", "O", "Mg", "O"),
            "positions": torch.tensor(
                [
                    (0, 0, 0),
                    (0.5, 0.5, 0.5),
                    (0.25, 0.25, 0.25),
                    (0.75, 0.75, 0.75),
                    (0.1, 0.1, 0.1),
                ],
                dtype=dtype,
            ),
            "charges": torch.tensor([[2.0], [-2.0], [-2.0], [2.0], [-2.0]], dtype=dtype),
            "cell": torch.eye(3, dtype=dtype) * 6.0,
            "pbc": torch.tensor([True, True, True]),
        },
        {
            "symbols": ("Al", "O", "O", "Al", "O", "O"),
            "positions": torch.tensor(
                [
                    (0, 0, 0),
                    (0.5, 0.5, 0.5),
                    (0.25, 0.25, 0.25),
                    (0.75, 0.75, 0.75),
                    (0.1, 0.1, 0.1),
                    (0.9, 0.9, 0.9),
                ],
                dtype=dtype,
            ),
            "charges": torch.tensor(
                [[3.0], [-2.0], [-2.0], [3.0], [-2.0], [-2.0]], dtype=dtype
            ),
            "cell": torch.eye(3, dtype=dtype) * 7.0,
            "pbc": torch.tensor([True, True, True]),
        },
    ]

.. GENERATED FROM PYTHON SOURCE LINES 92-93

Compute neighbor lists for each system

.. GENERATED FROM PYTHON SOURCE LINES 93-120

.. code-block:: Python

    i_list, j_list, d_list, pos_list, charges_list, cell_list, periodic_list = (
        [],
        [],
        [],
        [],
        [],
        [],
        [],
    )

    nl = vesin.NeighborList(cutoff=cutoff, full_list=False)

    for sys in systems:
        neighbor_indices, neighbor_distances = nl.compute(
            points=sys["positions"],
            box=sys["cell"],
            periodic=sys["pbc"][0],
            quantities="Pd",
        )

        i_list.append(torch.tensor(neighbor_indices[:, 0], dtype=torch.int64))
        j_list.append(torch.tensor(neighbor_indices[:, 1], dtype=torch.int64))
        d_list.append(torch.tensor(neighbor_distances, dtype=dtype))
        pos_list.append(sys["positions"])
        charges_list.append(sys["charges"])
        cell_list.append(sys["cell"])
        periodic_list.append(sys["pbc"])

.. GENERATED FROM PYTHON SOURCE LINES 121-122

Pad positions, charges, and neighbor lists

.. GENERATED FROM PYTHON SOURCE LINES 122-140

.. code-block:: Python

    max_atoms = max(pos.shape[0] for pos in pos_list)
    pos_batch = pad_sequence(pos_list, batch_first=True)
    charges_batch = pad_sequence(charges_list, batch_first=True)
    cell_batch = torch.stack(cell_list)
    periodic_batch = torch.stack(periodic_list)
    i_batch = pad_sequence(i_list, batch_first=True, padding_value=0)
    j_batch = pad_sequence(j_list, batch_first=True, padding_value=0)
    d_batch = pad_sequence(d_list, batch_first=True, padding_value=0.0)

    # Masks for ignoring padded atoms and neighbor entries
    node_mask = (
        torch.arange(max_atoms)[None, :]
        < torch.tensor([p.shape[0] for p in pos_list])[:, None]
    )
    pair_mask = (
        torch.arange(i_batch.shape[1])[None, :]
        < torch.tensor([len(i) for i in i_list])[:, None]
    )

.. GENERATED FROM PYTHON SOURCE LINES 141-142

Initialize Ewald calculator

.. GENERATED FROM PYTHON SOURCE LINES 142-148

.. code-block:: Python

    calculator = torchpme.EwaldCalculator(
        torchpme.CoulombPotential(smearing=0.5),
        lr_wavelength=4.0,
    )
    calculator.to(dtype=dtype)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    EwaldCalculator(
      (potential): CoulombPotential()
    )

.. GENERATED FROM PYTHON SOURCE LINES 149-150

Compute potentials in a batched manner using vmap

.. GENERATED FROM PYTHON SOURCE LINES 150-166

.. code-block:: Python

    kvectors = torchpme.lib.compute_batched_kvectors(
        lr_wavelength=calculator.lr_wavelength, cells=cell_batch
    )

    potentials_batch = torch.vmap(calculator.forward)(
        charges_batch,
        cell_batch,
        pos_batch,
        torch.stack((i_batch, j_batch), dim=-1),
        d_batch,
        periodic_batch,
        node_mask,
        pair_mask,
        kvectors,
    )

.. GENERATED FROM PYTHON SOURCE LINES 167-169

.. code-block:: Python

    print("Batched potentials shape:", potentials_batch.shape)
    print(potentials_batch)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Batched potentials shape: torch.Size([5, 6, 1])
    tensor([[[-0.8460],
             [ 0.8460],
             [ 0.0000],
             [ 0.0000],
             [ 0.0000],
             [ 0.0000]],

            [[-1.2799],
             [ 0.4120],
             [ 0.8102],
             [ 0.0000],
             [ 0.0000],
             [ 0.0000]],

            [[-1.3172],
             [ 0.8011],
             [ 0.8011],
             [-1.3172],
             [ 0.0000],
             [ 0.0000]],

            [[-7.0317],
             [ 1.2284],
             [-0.8886],
             [-2.7510],
             [ 3.0087],
             [ 0.0000]],

            [[-7.7262],
             [ 1.5084],
             [-0.3403],
             [-5.9061],
             [ 5.2142],
             [ 4.6411]]], dtype=torch.float64)

.. GENERATED FROM PYTHON SOURCE LINES 170-171

Compare performance of batched vs. looped computation

.. GENERATED FROM PYTHON SOURCE LINES 171-204

.. code-block:: Python

    n_iter = 100

    t0 = time.perf_counter()
    for _ in range(n_iter):
        _ = torch.vmap(calculator.forward)(
            charges_batch,
            cell_batch,
            pos_batch,
            torch.stack((i_batch, j_batch), dim=-1),
            d_batch,
            periodic_batch,
            node_mask,
            pair_mask,
            kvectors,
        )
    t_batch = (time.perf_counter() - t0) / n_iter

    t0 = time.perf_counter()
    for _ in range(n_iter):
        for k in range(len(pos_list)):
            _ = calculator.forward(
                charges_list[k],
                cell_list[k],
                pos_list[k],
                torch.stack((i_list[k], j_list[k]), dim=-1),
                d_list[k],
                periodic_list[k],
            )
    t_loop = (time.perf_counter() - t0) / n_iter

    print(f"Average time per batched call: {t_batch:.6f} s")
    print(f"Average time per loop call: {t_loop:.6f} s")
    print("Batched is faster" if t_batch < t_loop else "Loop is faster")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Average time per batched call: 0.001507 s
    Average time per loop call: 0.004710 s
    Batched is faster

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.633 seconds)

.. _sphx_glr_download_examples_12-padding-example.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 12-padding-example.ipynb <12-padding-example.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 12-padding-example.py <12-padding-example.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 12-padding-example.zip <12-padding-example.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
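As a closing standalone illustration, the padding-and-masking pattern used in
this example can be reduced to a few lines of plain ``torch``, independent of
``torchpme`` and ``vesin``. The two toy "systems" below (2 and 3 atoms) are
made up purely for illustration:

.. code-block:: Python

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Two toy systems with 2 and 3 atoms, each with 3 coordinates per atom
    pos_list = [torch.zeros(2, 3), torch.ones(3, 3)]

    # Pad along the atom dimension: both entries become (max_atoms, 3),
    # with padded rows filled with zeros
    pos_batch = pad_sequence(pos_list, batch_first=True)  # shape (2, 3, 3)

    # Boolean mask that is True for real atoms and False for padded rows
    lengths = torch.tensor([p.shape[0] for p in pos_list])
    node_mask = torch.arange(pos_batch.shape[1])[None, :] < lengths[:, None]

Any per-atom quantity computed on ``pos_batch`` can then be multiplied by
``node_mask`` (or summed under it) so that padded rows contribute nothing,
which is exactly how the ``node_mask`` and ``pair_mask`` tensors are used
above.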