from iann.data import AtomsData, replace_properties
import torch
from torch import nn
from typing import List, Optional, Tuple
from torch import Tensor
from e3nn import o3
import math
import abc
class Interaction(nn.Module):
    """Message-passing layer that updates per-node scalar and vector features.

    Each edge produces a scalar message (sender-node MLP output multiplied by
    linear projections of the radial and angular edge embeddings).  The message
    is placed along the unit edge direction, summed onto the node addressed by
    column 0 of ``edge_indices``, and the pooled vectors are turned into gated
    residual updates for both the node features and the vector features.

    Parameters
    ----------
    embeddings : dict
        Mapping with the keys ``'node'``, ``'edge'`` and ``'angular'``; each
        value must expose ``num_channels`` (input width of that feature kind).
    num_channels : int
        Hidden width of the layer.  ``node_mlp`` consumes a concatenation of a
        ``num_channels``-wide norm and the node features and is declared with
        ``2 * num_channels`` inputs, so the node width has to equal
        ``num_channels``.
    """

    def __init__(self, embeddings: dict, num_channels: int):
        super().__init__()
        self.embeddings = embeddings
        self.node_dim = embeddings['node'].num_channels
        self.edge_dim = embeddings['edge'].num_channels
        self.sh_dim = embeddings['angular'].num_channels
        self.num_channels = num_channels

        width = self.num_channels
        # Sub-modules are built in a fixed order so that seeded parameter
        # initialisation stays reproducible.
        self.msg_mlp = nn.Sequential(
            nn.Linear(self.node_dim, width),
            nn.SiLU(),
            nn.Linear(width, width),
        )
        self.edge_linear = nn.Linear(self.edge_dim, width)
        self.angular_linear = nn.Linear(self.sh_dim, width)
        self.node_mlp = nn.Sequential(
            nn.Linear(2 * width, width),
            nn.SiLU(),
            nn.Linear(width, 3 * width),
        )
        self.linear_1 = nn.Linear(width, width)
        self.linear_2 = nn.Linear(width, width)

    def forward(self, node_features, edge_features, angular_features, edge_indices, edge_vectors, force_vector):
        """Run one round of message passing.

        Parameters
        ----------
        node_features : torch.Tensor
            Per-node features, 2-D (concatenated along ``dim=1`` below).
        edge_features : torch.Tensor
            Per-edge radial embedding fed to ``edge_linear``.
        angular_features : torch.Tensor
            Per-edge angular embedding fed to ``angular_linear``.
        edge_indices : torch.Tensor
            Two-column index tensor: column 1 selects the node whose features
            are gathered, column 0 the node the message is accumulated on.
        edge_vectors : torch.Tensor
            Per-edge displacement vectors; normalised by their length, so a
            zero-length edge would yield NaN.
        force_vector : torch.Tensor
            Per-node vector features with the Cartesian axis at ``dim=1``.
            The input tensor is not modified.

        Returns
        -------
        tuple of torch.Tensor
            Updated ``(node_features, force_vector)``, both contiguous.
        """
        receivers = edge_indices[:, 0]
        senders = edge_indices[:, 1]

        # Scalar message per edge: sender term modulated by the edge filters.
        sender_term = self.msg_mlp(node_features[senders])
        radial_term = self.edge_linear(edge_features)
        angular_term = self.angular_linear(angular_features)
        scalar_msg = sender_term * (angular_term * radial_term)

        # Lift the scalar message onto the unit edge direction.
        lengths = torch.linalg.norm(edge_vectors, dim=1)
        directions = (edge_vectors / lengths.unsqueeze(-1)).unsqueeze(-1)
        vector_msg = scalar_msg.unsqueeze(1) * directions

        # Sum incoming messages per receiving node, then add the current state.
        pooled = torch.zeros_like(force_vector, device=edge_indices.device)
        pooled.index_add_(0, receivers, vector_msg.contiguous())
        pooled.add_(force_vector)

        mixed_1 = self.linear_1(pooled)
        mixed_2 = self.linear_2(pooled)
        # Norms over the Cartesian axis give per-channel scalars.
        norm_1 = torch.linalg.norm(mixed_1, dim=1)
        norm_2 = torch.linalg.norm(mixed_2, dim=1)

        gates = self.node_mlp(torch.cat((norm_1, node_features), dim=1))
        gate_a, gate_b, gate_c = torch.split(gates, self.num_channels, dim=1)

        node_delta = gate_a * (norm_1 * norm_2) + gate_b
        vector_delta = gate_c.unsqueeze(1) * mixed_2 + force_vector

        new_nodes = node_features + node_delta
        new_vectors = force_vector + vector_delta
        return new_nodes.contiguous(), new_vectors.contiguous()
class RadialBasis(torch.nn.Module, metaclass=abc.ABCMeta):
    """Abstract base class for radial basis modules.

    Concrete subclasses must override :meth:`forward`; the class itself
    cannot be instantiated.
    """

    @abc.abstractmethod
    def forward(self):
        """Evaluate the basis (to be provided by subclasses)."""
class BesselBasis(RadialBasis):
    r"""Radial Bessel basis, as proposed in DimeNet: https://arxiv.org/abs/2003.03123

    For an input ``x`` the ``n``-th output is
    ``(2 / cutoff) * sin(w_n * x / cutoff) / x`` where the weights ``w_n``
    are initialised to :math:`n \pi` for ``n = 1 .. num_basis``.
    """

    def __init__(self, cutoff: float, num_basis: int=8, trainable: bool=True):
        r"""
        Parameters
        ----------
        cutoff : float
            Cutoff radius
        num_basis : int
            Number of Bessel basis functions
        trainable : bool
            If True the :math:`n \pi` weights are stored as a trainable
            parameter, otherwise as a buffer.
        """
        super().__init__()
        self.trainable = trainable
        self.num_basis = num_basis
        self.cutoff = float(cutoff)
        self.prefactor = 2.0 / self.cutoff
        # Irreps of the output: ``num_basis`` copies of Irrep(0, 1).
        self.irreps_out = o3.Irreps([(num_basis, o3.Irrep(0, 1))])

        initial_weights = torch.linspace(start=1.0, end=num_basis, steps=num_basis) * math.pi
        if not self.trainable:
            self.register_buffer("bessel_weights", initial_weights)
        else:
            self.bessel_weights = nn.Parameter(initial_weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the Bessel basis for input x.

        Parameters
        ----------
        x : torch.Tensor
            Input; a trailing axis of size ``num_basis`` is appended to its
            shape in the result.  The expression divides by ``x``, so a zero
            entry yields a non-finite value.
        """
        radii = x.unsqueeze(-1)
        numerator = torch.sin(self.bessel_weights * radii / self.cutoff)
        return self.prefactor * (numerator / radii)
def _poly_cutoff(x: torch.Tensor, factor: float, p: float = 6.0) -> torch.Tensor:
x = x * factor
out = 1.0
out = out - (((p + 1.0) * (p + 2.0) / 2.0) * torch.pow(x, p))
out = out + (p * (p + 2.0) * torch.pow(x, p + 1.0))
out = out - ((p * (p + 1.0) / 2) * torch.pow(x, p + 2.0))
return out * (x < 1.0)
class CutoffFunction(torch.nn.Module, metaclass=abc.ABCMeta):
    """Abstract base class for cutoff (envelope) modules.

    Concrete subclasses must override :meth:`forward`; the class itself
    cannot be instantiated.
    """

    @abc.abstractmethod
    def forward(self):
        """Evaluate the cutoff function (to be provided by subclasses)."""
class PolynomialCutoff(CutoffFunction):
    r"""Polynomial cutoff, as proposed in DimeNet: https://arxiv.org/abs/2003.03123"""

    def __init__(self, cutoff: float, power: float = 6):
        r"""
        Parameters
        ----------
        cutoff : float
            Cutoff radius; inputs are multiplied by ``1 / cutoff`` before the
            envelope is evaluated.
        power : float
            Power used in the envelope function; must be at least 2.
        """
        super().__init__()
        assert power >= 2.0
        self.p = float(power)
        # Store the reciprocal once so forward() only has to multiply.
        self._factor = 1.0 / float(cutoff)

    def forward(self, x):
        """
        Evaluate the cutoff function.

        x: torch.Tensor, input distance
        """
        return _poly_cutoff(x, self._factor, p=self.p)
class NodeEmbedding(nn.Module):
    """Node embedding module: a learned lookup table of per-atom features.

    Parameters
    ----------
    num_channels : int
        Width of each embedding vector.
    species : list of int or None
        When a list is given the table has ``len(species)`` rows; when it is
        ``None`` the table has 119 rows.
    """

    def __init__(self, num_channels: int, species: List[int]):
        super().__init__()
        self.species = species
        self.num_channels = num_channels
        self.num_embedding = 119 if self.species is None else len(self.species)
        self.atom_embedding = nn.Embedding(self.num_embedding, self.num_channels)

    def forward(self, data: AtomsData) -> AtomsData:
        """Store the embedding of ``data.atomic_numbers`` as ``node_attr``."""
        # NOTE(review): the table is indexed directly with data.atomic_numbers,
        # so with an explicit species list those values presumably have to be
        # indices below len(species) rather than raw atomic numbers - confirm
        # against the data pipeline.
        return replace_properties(data, node_attr=self.atom_embedding(data.atomic_numbers))
class EdgeEmbedding(nn.Module):
    """Edge embedding module: radial basis of the edge length times a cutoff.

    Parameters
    ----------
    basis : nn.Module
        Radial basis applied to edge lengths; must expose ``num_basis``,
        which becomes this module's ``num_channels``.
    cutoff_fn : nn.Module
        Cutoff function applied to edge lengths.
    """

    def __init__(self, basis: nn.Module, cutoff_fn: nn.Module):
        super().__init__()
        self.basis = basis
        self.cutoff_fn = cutoff_fn
        self.num_channels = basis.num_basis

    def forward(self, data: AtomsData) -> AtomsData:
        """Store the enveloped radial basis as ``edge_dist_embedding``."""
        lengths = torch.linalg.norm(data.edge_vectors, dim=1)
        radial = self.basis(lengths)
        # One envelope value per edge, broadcast over the basis axis.
        envelope = self.cutoff_fn(lengths)[:, None]
        return replace_properties(data, edge_dist_embedding=radial * envelope)
class AngularEmbedding(nn.Module):
    """Angular embedding module: spherical harmonics of the edge vectors.

    Parameters
    ----------
    edge_sh_irreps : o3.Irreps
        Irreps of the spherical-harmonics output; its ``dim`` becomes this
        module's ``num_channels``.
    edge_sh_normalize : bool
        Forwarded as the second argument of ``o3.SphericalHarmonics``.
    edge_sh_normalization : str
        Forwarded as the third argument of ``o3.SphericalHarmonics``.
    """

    def __init__(self, edge_sh_irreps: o3.Irreps, edge_sh_normalize: bool=True, edge_sh_normalization: str='component'):
        super().__init__()
        self.sh = o3.SphericalHarmonics(
            edge_sh_irreps,
            edge_sh_normalize,
            edge_sh_normalization,
        )
        self.num_channels = edge_sh_irreps.dim

    def forward(self, data: AtomsData) -> AtomsData:
        """Store the spherical harmonics of ``data.edge_vectors`` as ``edge_diff_embedding``."""
        return replace_properties(data, edge_diff_embedding=self.sh(data.edge_vectors))
class GlobalEmbedding(nn.Module):
    """Global embedding module: projects per-image attributes onto every atom.

    Parameters
    ----------
    batch_size : int
        Stored on the module; not used by the computation in this class.
    num_channels : int
        Width of the produced embedding.
    num_global_features : int
        Number of global attributes per image (default 8).  It sets both the
        input width of the linear layer and the row width used when
        reshaping ``data.global_attr``.
    """

    def __init__(self, batch_size: int, num_channels: int, num_global_features: int = 8):
        super().__init__()
        self.num_channels = num_channels
        self.batch_size = batch_size
        self.num_global_features = num_global_features
        self.global_embedding = nn.Linear(num_global_features, num_channels)

    def forward(self, data: AtomsData) -> AtomsData:
        """Store the per-atom global embedding as ``global_embedding``."""
        # One row of attributes per image ...
        per_image = data.global_attr.reshape(-1, self.num_global_features)
        # ... repeated once for each atom of that image.
        per_atom = torch.repeat_interleave(per_image, data.num_atoms, dim=0)
        return replace_properties(data, global_embedding=self.global_embedding(per_atom))
class FastPot(nn.Module):
    """
    FastPot: Fast Potential with high-order tensor features and equivariant message passing for fast and accurate potential energy surface prediction.
    """

    def __init__(
        self,
        num_layers=3,
        num_channels=128,
        norm_data=True,
        data_mean=[0.0],
        data_stddev=[1.0],
        norm_per_atom=True,
        **kwargs,
    ):
        """
        Initialize the FastPot model.

        Parameters
        ----------
        num_layers : int
            Number of message passing layers
        num_channels : int
            Number of feature channels
        norm_data : bool
            If True, predicted energies are rescaled by ``data_stddev[0]`` and
            shifted by ``data_mean[0]``.
        data_mean : sequence of float
            Only element 0 is read (the sequence is never mutated).
        data_stddev : sequence of float
            Only element 0 is read (the sequence is never mutated).
        norm_per_atom : bool
            If True, the mean shift is multiplied by the number of atoms of
            each image.
        **kwargs
            Optional settings read by name:

            cutoff : float
                Interaction cutoff distance (default 5.5)
            lmax : int
                Maximum spherical harmonic degree (default 7)
            species : list of int or None
                Passed to the node embedding (default None)
            num_basis : int
                Number of radial basis functions (default ``num_channels``)
            power : int
                Exponent of the polynomial cutoff (default 6)
            batch_size : int
                Stored on the model (default 12)
            forces_scale : float
                Stored on the model (default 1.0)
            compute_forces : bool
                If truthy, ``forward`` also returns forces (default False)
        """
        super().__init__()
        self.cutoff = kwargs.get('cutoff', 5.5)
        self.num_layers = num_layers
        self.num_channels = num_channels
        self.lmax = kwargs.get('lmax', 7)
        self.species = kwargs.get('species', None)
        self.num_basis: int = kwargs.get('num_basis', num_channels)
        self.power: int = kwargs.get('power', 6)
        self.batch_size = kwargs.get('batch_size', 12)
        self.forces_scale = kwargs.get('forces_scale', 1.0)

        self.embeddings = nn.ModuleDict()
        # node embedding
        self.embeddings['node'] = NodeEmbedding(self.num_channels, self.species)
        # edge embedding
        self.basis = BesselBasis(cutoff=self.cutoff, num_basis=self.num_basis)
        self.cutoff_fn = PolynomialCutoff(cutoff=self.cutoff, power=self.power)
        self.embeddings['edge'] = EdgeEmbedding(self.basis, self.cutoff_fn)
        # angular embedding
        self.edge_sh_irreps = o3.Irreps.spherical_harmonics(self.lmax, p=-1)
        self.embeddings['angular'] = AngularEmbedding(self.edge_sh_irreps)
        # NOTE: a 'global' embedding (GlobalEmbedding) is intentionally not
        # registered; Interaction only reads 'node', 'edge' and 'angular'.

        # Setup message-passing layers
        self.interaction_layers = nn.ModuleList([
            Interaction(self.embeddings, self.num_channels) for _ in range(self.num_layers)
        ])

        # Setup readout function: one scalar (atomic energy) per node
        self.readout_mlp = nn.Sequential(
            nn.Linear(self.num_channels, self.num_channels),
            nn.SiLU(),
            nn.Linear(self.num_channels, 1),
        )

        # Normalisation constants, stored as frozen parameters so that they
        # are saved in the state dict together with the weights.
        self.norm_data = torch.nn.Parameter(
            torch.tensor(norm_data), requires_grad=False
        )
        self.norm_per_atom = torch.nn.Parameter(
            torch.tensor(norm_per_atom), requires_grad=False
        )
        self.normalize_stddev = torch.nn.Parameter(
            torch.tensor(data_stddev[0]), requires_grad=False
        )
        self.data_mean = torch.nn.Parameter(
            torch.tensor(data_mean[0]), requires_grad=False
        )

        self.compute_forces = bool(kwargs.get('compute_forces', False))

    def forward(self, data: AtomsData):
        """
        Predict energies (and optionally forces) for a batch of structures.

        Parameters
        ----------
        data : AtomsData
            Input data containing atomic information.  The fields read here
            are ``num_atoms`` (atoms per image), ``positions``,
            ``edge_indices``, ``edge_vectors`` and whatever the embedding
            modules consume.  When ``compute_forces`` is enabled,
            ``edge_vectors`` must be part of the autograd graph (require
            gradients), because forces are obtained by differentiating the
            energy with respect to it.

        Returns
        -------
        AtomsData
            Output data with ``atomic_energy``, ``energy`` and optionally
            ``forces``.
        """
        num_atoms = data.num_atoms
        positions = data.positions
        edge_indices = data.edge_indices
        # Keep a handle on the very tensor the embeddings consume: the force
        # computation differentiates with respect to exactly this object.
        edge_vectors = data.edge_vectors

        # Each embedding module adds its features to ``data``
        for embedding in self.embeddings.values():
            data = embedding(data)

        node_features = data.node_attr  # learned per-atom embedding
        edge_features = data.edge_dist_embedding  # radial basis features
        angular_features = data.edge_diff_embedding  # spherical harmonics features

        # Vector features start at zero and follow the dtype of the scalar
        # features, so float64 models do not hit a dtype mismatch when
        # messages are accumulated with index_add_.
        force_vector = torch.zeros(
            (positions.shape[0], 3, self.num_channels),
            device=positions.device,
            dtype=node_features.dtype,
        )

        # Message passing iterations
        for layer in self.interaction_layers:
            node_features, force_vector = layer(
                node_features,
                edge_features,
                angular_features,
                edge_indices,
                edge_vectors,
                force_vector,
            )

        # Readout.  Only the trailing size-1 axis is squeezed: a bare
        # squeeze() would also drop the node axis for a single-atom batch.
        atomic_energy = self.readout_mlp(node_features).squeeze(-1).contiguous()

        # Aggregate atomic energies into one energy per image
        num_images = num_atoms.shape[0]
        image_idx = torch.repeat_interleave(
            torch.arange(num_images, device=edge_indices.device), num_atoms
        )
        energy = torch.zeros(num_images, device=num_atoms.device, dtype=atomic_energy.dtype)
        energy.index_add_(0, image_idx, atomic_energy)
        data = replace_properties(data, atomic_energy=atomic_energy)

        # Apply normalization
        if self.norm_data:
            energy = self.normalize_stddev * energy
            mean_shift = self.data_mean
            if self.norm_per_atom:
                # A per-atom mean scales with the number of atoms in each
                # image (not with its number of edges).
                mean_shift = num_atoms * mean_shift
            energy = energy + mean_shift
        data = replace_properties(data, energy=energy)

        # Force computation
        if self.compute_forces:
            outputs_list = torch.jit.annotate(List[Tensor], [energy])
            # Differentiate w.r.t. ``edge_vectors`` itself; a ``.contiguous()``
            # copy would be a new tensor that is not part of the graph.
            inputs_list = torch.jit.annotate(List[Tensor], [edge_vectors])
            grad_outputs_list = torch.jit.annotate(
                Optional[List[Optional[Tensor]]], [torch.ones_like(energy)]
            )
            dE_ddiff = torch.autograd.grad(
                outputs=outputs_list,
                inputs=inputs_list,
                grad_outputs=grad_outputs_list,
                retain_graph=True,
                create_graph=True,
            )[0]
            dE_ddiff = dE_ddiff.contiguous()

            # Each edge contributes +dE/d(edge_vector) to the node in column 0
            # and the opposite to the node in column 1.
            # NOTE(review): this sign convention presumably corresponds to
            # edge_vectors = r[col 1] - r[col 0] - confirm against the
            # neighbour-list code.
            i_forces = torch.zeros(
                positions.shape[0], 3, device=positions.device, dtype=dE_ddiff.dtype
            )
            j_forces = torch.zeros(
                positions.shape[0], 3, device=positions.device, dtype=dE_ddiff.dtype
            )
            i_forces.index_add_(0, edge_indices[:, 0], dE_ddiff)
            j_forces.index_add_(0, edge_indices[:, 1], -dE_ddiff)
            forces = (i_forces + j_forces).contiguous()
            data = replace_properties(data, forces=forces)

        return data