from iann.data import AtomsData, replace_properties
import torch
from torch import nn
from typing import List, Optional
from torch import Tensor
def sinc_expansion(edge_dist: torch.Tensor, edge_channels: int, cutoff: float):
    """
    Expand edge distances in a sinc radial basis.

    For every distance ``d`` and every order ``n = 1 .. edge_channels`` this
    evaluates ``sin(n * pi * d / cutoff) / d``.  At ``d == 0`` the expression
    is replaced by its analytic limit ``n * pi / cutoff`` so that neither the
    output nor its gradient with respect to ``edge_dist`` becomes NaN.

    Parameters
    ----------
    edge_dist : torch.Tensor
        Distances to expand.  A trailing axis of size ``edge_channels`` is
        appended to its shape in the result.
    edge_channels : int
        Number of basis functions (orders ``1 .. edge_channels``).
    cutoff : float
        Cutoff distance ``d_cut`` used to scale the argument of the sine.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``edge_dist.shape + (edge_channels,)`` with the same
        dtype and device as ``edge_dist``.
    """
    # Orders n = 1, 2, ..., edge_channels on the same device/dtype as the input.
    orders = torch.arange(edge_channels, device=edge_dist.device, dtype=edge_dist.dtype) + 1
    dist = edge_dist.unsqueeze(-1)

    # Swap exact zeros for a harmless dummy value before dividing; the entries
    # computed from the dummy are discarded by the final ``where`` and, because
    # the dummy is a constant, they contribute no gradient either.
    at_origin = dist == 0
    safe_dist = torch.where(at_origin, torch.ones_like(dist), dist)

    # Same operation order as the plain formula so non-zero distances are
    # evaluated exactly as sin(d * n * pi / cutoff) / d.
    basis = torch.sin(safe_dist * orders * torch.pi / cutoff) / safe_dist

    # lim_{d -> 0} sin(n * pi * d / cutoff) / d = n * pi / cutoff
    limit = orders * torch.pi / cutoff
    return torch.where(at_origin, limit, basis)
def cosine_cutoff(edge_dist: torch.Tensor, cutoff: float):
    """
    Evaluate the cosine (Behler-Parrinello style) cutoff envelope.

    ``f(d) = 0.5 * (cos(pi * d / cutoff) + 1)`` for ``d < cutoff`` and
    ``f(d) = 0`` otherwise.

    Parameters
    ----------
    edge_dist : torch.Tensor
        Distances at which the envelope is evaluated.
    cutoff : float
        Distance at and beyond which the envelope is exactly zero.

    Returns
    -------
    torch.Tensor
        Envelope values with the same shape, dtype and device as ``edge_dist``.
    """
    within_cutoff = edge_dist < cutoff
    envelope = 0.5 * (torch.cos(torch.pi * edge_dist / cutoff) + 1)
    outside_value = torch.zeros_like(edge_dist)
    # ``where`` (rather than multiplying by a mask) keeps non-finite cosine
    # values for distances outside the cutoff from leaking into the result.
    return torch.where(within_cutoff, envelope, outside_value)
class PainnMessage(nn.Module):
    """
    Message block: aggregates filtered neighbour features onto each node.

    Scalar node features are passed through an MLP, gathered per edge and
    multiplied by a distance-dependent filter.  The product is split into
    three equally sized gates that (a) scale the neighbour's vector features,
    (b) scale the normalised edge direction and (c) form the scalar message.
    Messages are summed per node and added residually to the node states.
    """

    def __init__(self, num_channels: int, edge_channels: int, cutoff: float):
        """
        Parameters
        ----------
        num_channels : int
            Width of the scalar/vector node features.
        edge_channels : int
            Number of radial basis functions used to expand edge distances.
        cutoff : float
            Cutoff distance used by the radial basis and the cutoff envelope.
        """
        super().__init__()
        self.edge_channels = edge_channels
        self.num_channels = num_channels
        self.cutoff = cutoff

        # Layers are created in this order (MLP first, then filter) so seeded
        # initialisation and state_dict keys stay stable.
        mlp_layers = [
            nn.Linear(num_channels, num_channels),
            nn.SiLU(),
            nn.Linear(num_channels, num_channels * 3),
        ]
        self.scalar_message_mlp = nn.Sequential(*mlp_layers)
        self.filter_layer = nn.Linear(edge_channels, num_channels * 3)

    def forward(self, node_scalar, node_vector, edge_indices, edge_vectors, edge_dist):
        """
        Compute one round of messages and return the updated node states.

        Column 0 of ``edge_indices`` is the node that accumulates the message,
        column 1 is the node whose features are gathered to build it.

        Returns
        -------
        tuple of torch.Tensor
            ``(node_scalar + summed scalar messages,
            node_vector + summed vector messages)``.
        """
        receivers = edge_indices[:, 0]
        senders = edge_indices[:, 1]

        # Distance filter: radial expansion -> linear layer -> cutoff envelope.
        radial_basis = sinc_expansion(edge_dist, self.edge_channels, self.cutoff)
        envelope = cosine_cutoff(edge_dist, self.cutoff).unsqueeze(-1)
        edge_filter = self.filter_layer(radial_basis) * envelope

        # Messages are built from the *sender* (j) features, not the receiver's.
        sender_features = self.scalar_message_mlp(node_scalar)[senders]
        gated = edge_filter * sender_features
        gate_state_vector, gate_edge_vector, message_scalar = torch.split(
            gated, self.num_channels, dim=1
        )

        # Vector message: gated sender vectors plus gated unit edge directions.
        unit_vectors = edge_vectors / edge_dist.unsqueeze(-1)
        directional = gate_edge_vector.unsqueeze(1) * unit_vectors.unsqueeze(-1)
        message_vector = node_vector[senders] * gate_state_vector.unsqueeze(1)
        message_vector = message_vector + directional

        # Accumulate per receiving node; the buffers are made contiguous
        # before the in-place index_add_.
        scalar_sum = torch.zeros_like(node_scalar).contiguous()
        vector_sum = torch.zeros_like(node_vector).contiguous()
        scalar_sum.index_add_(0, receivers, message_scalar)
        vector_sum.index_add_(0, receivers, message_vector)

        # Residual update of the node states.
        return node_scalar + scalar_sum, node_vector + vector_sum
class PainnUpdate(nn.Module):
    """
    Update block: mixes each node's scalar and vector features.

    Two linear maps ``U`` and ``V`` act on the channel axis of the vector
    features.  The norm of ``V v`` (taken over axis 1) is concatenated with the
    scalar features and fed to an MLP whose output is split into three gates
    ``a_vv``, ``a_sv`` and ``a_ss`` that drive the residual updates.
    """

    def __init__(self, num_channels: int):
        """
        Parameters
        ----------
        num_channels : int
            Width of the scalar/vector node features.
        """
        super().__init__()
        # NOTE(review): both maps are nn.Linear with the default bias, so the
        # same bias is added to every slice along axis 1 of the vector
        # features -- confirm this is intended.
        self.update_U = nn.Linear(num_channels, num_channels)
        self.update_V = nn.Linear(num_channels, num_channels)
        gate_layers = [
            nn.Linear(num_channels * 2, num_channels),
            nn.SiLU(),
            nn.Linear(num_channels, num_channels * 3),
        ]
        self.update_mlp = nn.Sequential(*gate_layers)

    def forward(self, node_scalar, node_vector):
        """
        Return the residually updated ``(node_scalar, node_vector)``.

        ``delta_v = a_vv * U v`` and ``delta_s = a_sv * <U v, V v> + a_ss``,
        where the inner product and the norm reduce over axis 1 of
        ``node_vector``.
        """
        # The gate width equals the size of the last axis of the vector features.
        gate_width = node_vector.shape[-1]

        projected_u = self.update_U(node_vector)
        projected_v = self.update_V(node_vector)

        # Gates are computed from the norm of V v and the scalar features.
        v_norm = torch.linalg.norm(projected_v, dim=1)
        gates = self.update_mlp(torch.cat((v_norm, node_scalar), dim=1))
        a_vv, a_sv, a_ss = torch.split(gates, gate_width, dim=1)

        uv_inner = (projected_u * projected_v).sum(dim=1)
        scalar_delta = a_sv * uv_inner + a_ss
        vector_delta = a_vv.unsqueeze(1) * projected_u

        return node_scalar + scalar_delta, node_vector + vector_delta
class PaiNN(nn.Module):
    """
    PaiNN message-passing model predicting energies and, optionally, forces
    and stress.

    Atomic numbers are embedded into scalar node features, refined by
    ``num_layers`` rounds of :class:`PainnMessage` / :class:`PainnUpdate`,
    mapped to one atomic energy per node by a small MLP and summed per image.
    """

    def __init__(
        self,
        num_layers=3,
        num_channels=128,
        norm_data=True,
        data_mean=[0.0],
        data_stddev=[1.0],
        norm_per_atom=True,
        **kwargs,
    ):
        """
        Initialize the PaiNN model.

        Parameters
        ----------
        num_layers : int
            Number of message/update rounds.
        num_channels : int
            Width of the scalar/vector node features.
        norm_data : bool
            If True, the summed energy is scaled by ``data_stddev[0]`` and
            shifted by ``data_mean[0]`` in :meth:`forward`.
        data_mean, data_stddev : sequence of float
            Only the first element of each is used.  The list defaults are
            read, never mutated.
        norm_per_atom : bool
            If True, the mean shift is multiplied by the number of atoms of
            each image before it is added.
        **kwargs
            ``num_embedding`` (default 119), ``cutoff`` (default 5.5),
            ``edge_channels`` (default 20), ``compute_forces`` (default False)
            and ``compute_stress`` (default False).
        """
        super().__init__()
        num_embedding = kwargs.get('num_embedding', 119)  # number of all elements
        self.cutoff = kwargs.get('cutoff', 5.5)
        self.num_layers = num_layers
        self.num_channels = num_channels
        self.edge_channels = kwargs.get('edge_channels', 20)

        # Setup atom embeddings
        self.atom_embedding = nn.Embedding(num_embedding, num_channels)

        # Setup message-passing layers
        self.message_layers = nn.ModuleList(
            [
                PainnMessage(self.num_channels, self.edge_channels, self.cutoff)
                for _ in range(self.num_layers)
            ]
        )
        self.update_layers = nn.ModuleList(
            [
                PainnUpdate(self.num_channels)
                for _ in range(self.num_layers)
            ]
        )

        # Setup readout function
        self.readout_mlp = nn.Sequential(
            nn.Linear(self.num_channels, self.num_channels),
            nn.SiLU(),
            nn.Linear(self.num_channels, 1),
        )

        # Normalisation constants are stored as frozen parameters so they are
        # saved and restored with the model's state_dict.
        self.norm_data = torch.nn.Parameter(
            torch.tensor(norm_data), requires_grad=False
        )
        self.norm_per_atom = torch.nn.Parameter(
            torch.tensor(norm_per_atom), requires_grad=False
        )
        self.normalize_stddev = torch.nn.Parameter(
            torch.tensor(data_stddev[0]), requires_grad=False
        )
        self.data_mean = torch.nn.Parameter(
            torch.tensor(data_mean[0]), requires_grad=False
        )

        self.compute_forces = kwargs.get('compute_forces', False)
        self.compute_stress = kwargs.get('compute_stress', False)

        # Re-lay-out the freshly initialised parameters (values are unchanged).
        self.reset_parameters()

    def reset_parameters(self):
        """
        Re-store trainable parameters without changing their values.

        Every trainable 2-D parameter is copied into column-major storage
        (a transposed view of a row-major ``(cols, rows)`` buffer); trainable
        parameters with more than two dimensions are simply copied.  1-D and
        frozen parameters are left untouched.
        """
        with torch.no_grad():
            for param in self.parameters():
                if not param.requires_grad or param.dim() < 2:
                    continue
                if param.dim() == 2:
                    # Column-major layout: fill a (cols, rows) buffer with the
                    # transpose and expose it through a transposed view.
                    transposed = torch.empty(
                        param.shape[1],
                        param.shape[0],
                        device=param.device,
                        dtype=param.dtype,
                    )
                    transposed.copy_(param.data.t())
                    param.data = transposed.t()
                else:
                    param.data = param.data.clone()

    def _make_contiguous(self, tensor: Optional[torch.Tensor]) -> torch.Tensor:
        """
        Return ``tensor`` with the memory layout used throughout this model.

        2-D tensors are copied into column-major storage and returned as a
        transposed view, so the result is *not* row-major contiguous; tensors
        of any other rank are returned via ``.contiguous()``.  Values are
        never changed.

        Raises
        ------
        ValueError
            If ``tensor`` is None.
        """
        if tensor is None:
            raise ValueError("tensor is None in _make_contiguous")
        if tensor.dim() == 2:
            # Create a new tensor with column-major layout
            new_tensor = torch.empty(
                tensor.shape[1], tensor.shape[0], device=tensor.device, dtype=tensor.dtype
            )
            new_tensor.copy_(tensor.t())
            return new_tensor.t()
        return tensor.contiguous()

    def forward(self, data: AtomsData):
        """
        Predict energies (and optionally forces and stress) for a batch.

        Parameters
        ----------
        data : AtomsData
            Input data for the model.  The fields ``num_atoms``,
            ``positions``, ``edge_indices``, ``atomic_numbers`` and
            ``edge_vectors`` are read, plus ``cell`` when stress is requested.
            Forces are obtained by differentiating the energy with respect to
            ``edge_vectors``, so that tensor must take part in autograd
            (presumably the caller sets ``requires_grad`` -- TODO confirm).

        Returns
        -------
        AtomsData
            ``data`` with ``atomic_energy`` and ``energy`` replaced, and
            ``forces`` / ``stress`` when the corresponding flags are enabled.

        Raises
        ------
        ValueError
            If stress is requested without forces, or if ``data.cell`` is
            neither ``(num_images, 3, 3)`` nor ``(3, 3)``.
        """
        num_atoms = data.num_atoms
        positions = data.positions
        edge_indices = data.edge_indices
        atomic_numbers = data.atomic_numbers
        edge_vectors = data.edge_vectors
        edge_dist = torch.linalg.norm(edge_vectors, dim=1)
        num_images = num_atoms.shape[0]

        # Initialize node states.  The vector state follows the dtype of the
        # embedding so the model also works when cast away from float32.
        node_scalar = self._make_contiguous(self.atom_embedding(atomic_numbers))
        node_vector = torch.zeros(
            (positions.shape[0], 3, self.num_channels),
            device=positions.device,
            dtype=node_scalar.dtype,
        )

        # Message passing iterations
        for message_layer, update_layer in zip(self.message_layers, self.update_layers):
            node_scalar, node_vector = message_layer(
                node_scalar, node_vector, edge_indices, edge_vectors, edge_dist
            )
            node_scalar = self._make_contiguous(node_scalar)
            node_vector = self._make_contiguous(node_vector)
            node_scalar, node_vector = update_layer(node_scalar, node_vector)

        # Readout: one scalar (atomic energy) per node.
        node_scalar = self._make_contiguous(self.readout_mlp(node_scalar))
        node_scalar = node_scalar.squeeze(-1)

        # Map every atom to the index of the image it belongs to.
        image_idx = torch.arange(num_images, device=edge_indices.device)
        image_idx = torch.repeat_interleave(image_idx, num_atoms)

        # Sum atomic energies per image.
        energy = torch.zeros(num_images, device=num_atoms.device, dtype=node_scalar.dtype)
        energy.index_add_(0, image_idx, node_scalar)

        # Atomic energies are stored before the (de-)normalisation below.
        atomic_energy = node_scalar
        data = replace_properties(data, atomic_energy=atomic_energy)

        # Apply (de-)norm_data
        if self.norm_data:
            normalizer = self.normalize_stddev
            energy = self._make_contiguous(normalizer * energy)
            mean_shift = self.data_mean
            if self.norm_per_atom:
                # A per-atom mean must be scaled by the number of atoms in
                # each image (not by its number of edges).
                mean_shift = self._make_contiguous(num_atoms * mean_shift)
            energy = self._make_contiguous(energy + mean_shift)
        data = replace_properties(data, energy=energy)

        if self.compute_forces:
            # TorchScript requires explicit list types for grad arguments
            outputs_list = torch.jit.annotate(List[Tensor], [energy])
            inputs_list = torch.jit.annotate(List[Tensor], [edge_vectors])
            grad_outputs_list = torch.jit.annotate(
                Optional[List[Optional[Tensor]]], [torch.ones_like(energy)]
            )
            dE_ddiff = torch.autograd.grad(
                outputs=outputs_list,
                inputs=inputs_list,
                grad_outputs=grad_outputs_list,
                retain_graph=True,
                create_graph=True,
            )[0]
            dE_ddiff = self._make_contiguous(dE_ddiff)

            # Scatter the edge gradients onto both end points of every edge:
            # +dE/d(edge) on the node in column 0, -dE/d(edge) on column 1.
            i_forces = torch.zeros(
                positions.shape[0], 3, device=positions.device, dtype=dE_ddiff.dtype
            )
            j_forces = torch.zeros(
                positions.shape[0], 3, device=positions.device, dtype=dE_ddiff.dtype
            )
            i_forces.index_add_(0, edge_indices[:, 0], dE_ddiff)
            j_forces.index_add_(0, edge_indices[:, 1], -dE_ddiff)
            forces = self._make_contiguous(i_forces + j_forces)
            data = replace_properties(data, forces=forces)

        if self.compute_stress:
            # Stress = -(1/V) * sum over edges of (edge_vector outer dE_ddiff)
            if not self.compute_forces:
                raise ValueError("compute_forces must be True to compute stress")

            cell = data.cell
            if cell.dim() == 3 and cell.shape[0] == num_images:
                # Batched cells: (N, 3, 3)
                cells_per_image = cell
            elif cell.dim() == 2 and cell.shape == (3, 3):
                # Single cell: (3, 3) - expand to batch size
                cells_per_image = cell.unsqueeze(0).expand(num_images, -1, -1)
            else:
                raise ValueError("cell must be (N, 3, 3) or (3, 3)")
            volumes = torch.abs(torch.det(cells_per_image))  # (N,)

            # Per-edge contribution, flattened to (num_edges, 9) so it can be
            # accumulated per image with a single index_add_.
            stress_contrib = -torch.einsum('ij,ik->ijk', edge_vectors, dE_ddiff)
            stress_contrib_flat = stress_contrib.reshape(stress_contrib.shape[0], -1)
            stress_flat = torch.zeros(
                num_images, 9, device=energy.device, dtype=stress_contrib.dtype
            )
            # Each edge is assigned to the image of its column-0 atom.
            stress_flat.index_add_(0, image_idx[edge_indices[:, 0]], stress_contrib_flat)
            stress_per_image = stress_flat.view(num_images, 3, 3)

            # Divide by the cell volume; images whose volume is (numerically)
            # zero get a zero stress instead of inf/NaN.
            valid_volume_mask = (volumes > 1e-10).view(-1, 1, 1)
            stress_per_image = torch.where(
                valid_volume_mask,
                stress_per_image / volumes.view(-1, 1, 1),
                torch.zeros_like(stress_per_image),
            )

            # Symmetrize stress tensor for all images at once
            stress_per_image = (stress_per_image + stress_per_image.transpose(-1, -2)) / 2.0

            # A single image yields a plain (3, 3) tensor rather than (1, 3, 3).
            if num_images == 1:
                stress = stress_per_image[0]
            else:
                stress = stress_per_image
            stress = self._make_contiguous(stress)
            data = replace_properties(data, stress=stress)

        return data