from iann.data import AtomsData, replace_properties
import torch
from torch import nn
import abc, math
from ase.data import atomic_numbers
from typing import Dict, List, Optional, Union, Callable
import warnings
import logging
logging.getLogger("cuequivariance").setLevel(logging.WARNING)
warnings.filterwarnings("ignore", message="The TorchScript type system doesn't support instance-level annotations")
warnings.filterwarnings("ignore", message="cuequivariance_ops_torch is not available")
warnings.filterwarnings("ignore", message="Fused TP is not supported on CPU")
from e3nn import o3
from e3nn.o3 import Linear, TensorProduct, FullyConnectedTensorProduct
from e3nn.nn import FullyConnectedNet, Gate, NormActivation
import torch.compiler
if not hasattr(torch.compiler, "is_compiling"):
torch.compiler.is_compiling = lambda: False
# Try to import cuEquivariance, fallback to e3nn if not available
try:
import cuequivariance as cue
import cuequivariance_torch as cuet
_HAS_CUEQUIVARIANCE = True
except (ImportError, SyntaxError, Exception) as e:
_HAS_CUEQUIVARIANCE = False
def resolve_cuequivariance(use_cuequivariance: Optional[bool] = None) -> bool:
    """Decide whether the cuEquivariance backend should be used.

    Parameters
    ----------
    use_cuequivariance : bool or None
        ``None`` auto-detects from the module-level ``_HAS_CUEQUIVARIANCE``
        flag; a truthy value requests cuEquivariance explicitly; a falsy
        value disables it.

    Returns
    -------
    bool
        The auto-detected flag when ``use_cuequivariance`` is ``None``,
        otherwise the value that was passed in.

    Raises
    ------
    ImportError
        If cuEquivariance is requested but was not importable.
    """
    if use_cuequivariance is None:
        # Auto-detect: follow whatever the import attempt found.
        if _HAS_CUEQUIVARIANCE:
            message = "cuEquivariance detected - using optimized operations"
        else:
            message = "cuEquivariance not available - falling back to e3nn"
        logging.info(message)
        return _HAS_CUEQUIVARIANCE

    if not use_cuequivariance:
        # Explicitly disabled: nothing to check or log.
        return use_cuequivariance

    if not _HAS_CUEQUIVARIANCE:
        raise ImportError("cuEquivariance requested but not available")

    logging.info("cuEquivariance enabled - using optimized operations")
    return use_cuequivariance
class Transform(torch.nn.Module, metaclass=abc.ABCMeta):
    """Abstract base class for transforms implemented as torch modules.

    Subclasses must override :meth:`forward`; the class itself cannot be
    instantiated because ``forward`` is abstract.
    """

    def __init__(self) -> None:
        super(Transform, self).__init__()

    @abc.abstractmethod
    def forward(self):
        """Apply the transform (must be provided by subclasses)."""
        raise NotImplementedError
class TypeMapper(Transform):
    """Map atomic numbers to type indices and back.

    Exactly one of ``species`` or ``symbol_to_type`` must be given.

    Parameters
    ----------
    species : list of str, optional
        Chemical symbols. They are de-duplicated, ordered by atomic number
        and assigned consecutive type indices starting at 0.
    symbol_to_type : dict of str to int, optional
        Explicit mapping from chemical symbol to type index.

    Raises
    ------
    TypeError
        If both ``species`` and ``symbol_to_type`` are given.
    ValueError
        If neither is given.
    """

    def __init__(
        self,
        species: Optional[List[str]] = None,
        symbol_to_type: Optional[Dict[str, int]] = None,
    ) -> None:
        super().__init__()
        if species is not None:
            if symbol_to_type is not None:
                raise TypeError("Cannot give both `species` and `symbol_to_type`")
            # De-duplicate before enumerating: a repeated symbol would
            # otherwise leave a type index that exceeds the table size.
            ordered = sorted(set(species), key=lambda sym: atomic_numbers[sym])
            symbol_to_type = {sym: idx for idx, sym in enumerate(ordered)}
        if symbol_to_type is None:
            raise ValueError("`species` or `symbol_to_type` should be given!")
        self.symbol_to_type = symbol_to_type

        for sym, type_idx in self.symbol_to_type.items():
            assert sym in atomic_numbers, f"Invalid chemical symbol {sym}"
            assert 0 <= type_idx, f"Invalid type number {type_idx}"

        # Lookup table indexed by atomic number: 119 slots, i.e. indices
        # 0..118. Unmapped atomic numbers keep the sentinel -1.
        Z_to_index = torch.full(size=(119,), fill_value=-1, dtype=torch.long)
        # Inverse table: type index -> atomic number.
        index_to_Z = torch.zeros(
            size=(len(self.symbol_to_type),), dtype=torch.long
        )
        for sym, type_idx in self.symbol_to_type.items():
            Z_to_index[atomic_numbers[sym]] = type_idx
            index_to_Z[type_idx] = atomic_numbers[sym]
        self.register_buffer("Z_to_index", Z_to_index)
        self.register_buffer("index_to_Z", index_to_Z)

    def forward(self, data: AtomsData) -> torch.Tensor:
        """Return the type index of every entry of ``data.atomic_numbers``.

        Fails an assertion if any atomic number has no registered type.
        """
        atomic_types = self.transform(data.atomic_numbers)
        assert torch.all(atomic_types >= 0), "Provided data contains species not defined in TypeMapper!"
        return atomic_types

    def transform(self, numbers: torch.Tensor) -> torch.Tensor:
        """Convert atomic numbers to type indices (-1 for unmapped species).

        Raises
        ------
        ValueError
            If any atomic number is below 1 or beyond the lookup table.
        """
        # Compare against the table length so the largest accepted atomic
        # number is always a valid index; an empty tensor has nothing to
        # check (and ``max``/``min`` would fail on it).
        if numbers.numel() > 0 and (
            numbers.max() >= self.Z_to_index.shape[0] or numbers.min() < 1
        ):
            raise ValueError("Provided atomic numbers are not in the periodic table!")
        return self.Z_to_index[numbers]

    def untransform(self, types: torch.Tensor) -> torch.Tensor:
        """Convert type indices back to atomic numbers."""
        return self.index_to_Z[types]
class OneHotAtomEncoding(torch.nn.Module):
    """Compute a one-hot floating point encoding of atoms' discrete atom types.

    Args:
        num_elements: Accepted for interface compatibility; the stored value
            is always replaced by ``len(species)`` or, without species, 119.
        species: Optional list of chemical symbols. When given, a
            ``TypeMapper`` converts atomic numbers to type indices;
            otherwise ``atomic_number - 1`` is used as the class index.
        set_features: If ``True`` (default), ``node_feat`` will be set in
            addition to ``node_attr``.
    """

    def __init__(
        self,
        num_elements: Optional[int] = None,
        species: Optional[List[str]] = None,
        set_features: bool = True,
    ):
        super().__init__()
        self.set_features = set_features
        self.species = species
        if species is None:
            self.type_mapper = None
            self.num_elements = 119
        else:
            self.type_mapper = TypeMapper(species)
            self.num_elements = len(species)

        # output node feature irreps: `num_elements` even scalars
        scalar_irreps = o3.Irreps([(self.num_elements, o3.Irrep(0, 1))])
        self.irreps_out = {'node_attr': scalar_irreps}
        if set_features:
            self.irreps_out['node_feat'] = scalar_irreps

    def forward(self, data: AtomsData) -> AtomsData:
        """Attach the one-hot encoding to ``data`` and return the result."""
        if self.type_mapper is None:
            atomic_types = data.atomic_numbers - 1
        else:
            atomic_types = self.type_mapper(data)

        encoded = torch.nn.functional.one_hot(
            atomic_types, num_classes=self.num_elements
        )
        # Match the device and floating dtype of the positions.
        encoded = encoded.to(
            device=data.positions.device, dtype=data.positions.dtype
        )

        data = replace_properties(data, node_attr=encoded)
        if self.set_features:
            data = replace_properties(data, node_feat=encoded)
        return data
class AtomwiseLinear(torch.nn.Module):
    """Equivariant linear layer applied to ``data.node_feat``.

    Args:
        irreps_in: Irreps of the incoming node features.
        irreps_out: Irreps of the produced node features.
        use_cuequivariance: Build the layer with ``cuet.Linear`` instead of
            ``o3.Linear``.
    """

    def __init__(
        self,
        irreps_in: o3.Irreps,
        irreps_out: o3.Irreps,
        use_cuequivariance: bool = False,
    ):
        super().__init__()
        self.irreps_in: o3.Irreps = irreps_in
        self.irreps_out: o3.Irreps = irreps_out
        self.use_cuequivariance = use_cuequivariance

        if not self.use_cuequivariance:
            self.linear = o3.Linear(
                irreps_in=self.irreps_in, irreps_out=self.irreps_out
            )
        else:
            self.linear = cuet.Linear(
                irreps_in=cue.Irreps(cue.O3, self.irreps_in),
                irreps_out=cue.Irreps(cue.O3, self.irreps_out),
                layout=cue.mul_ir
            )

    def forward(self, data: AtomsData) -> AtomsData:
        """Replace ``data.node_feat`` by its linear transform.

        Raises:
            RuntimeError: If ``data.node_feat`` is ``None``.
        """
        features = data.node_feat
        if features is None:
            raise RuntimeError("node_feat must not be None")
        assert isinstance(features, torch.Tensor)
        return replace_properties(data, node_feat=self.linear(features))
class RadialBasis(torch.nn.Module, metaclass=abc.ABCMeta):
    """Abstract base class for radial basis modules."""

    @abc.abstractmethod
    def forward(self):
        """Evaluate the basis (must be provided by subclasses)."""
        return None
class BesselBasis(RadialBasis):
    r"""Radial Bessel Basis, as proposed in DimeNet: https://arxiv.org/abs/2003.03123

    Parameters
    ----------
    cutoff : float
        Cutoff radius
    num_basis : int
        Number of Bessel Basis functions
    trainable : bool
        Train the :math:`n \pi` part or not.
    """

    def __init__(self, cutoff: float, num_basis: int = 8, trainable: bool = True):
        super().__init__()
        self.trainable = trainable
        self.num_basis = num_basis
        self.cutoff = float(cutoff)
        self.prefactor = 2.0 / self.cutoff

        # output edge dist irreps: `num_basis` even scalars
        self.irreps_out = o3.Irreps([(num_basis, o3.Irrep(0, 1))])

        # Frequencies n * pi for n = 1 .. num_basis.
        frequencies = math.pi * torch.linspace(
            start=1.0, end=num_basis, steps=num_basis
        )
        if not self.trainable:
            self.register_buffer("bessel_weights", frequencies)
        else:
            self.bessel_weights = nn.Parameter(frequencies)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate Bessel Basis for input x.

        Parameters
        ----------
        x : torch.Tensor
            Input. The expression divides by ``x``, so entries equal to
            zero do not yield finite values.

        Returns
        -------
        torch.Tensor
            ``x`` with one trailing dimension added for the basis functions.
        """
        radius = x.unsqueeze(-1)
        sine = torch.sin(self.bessel_weights * radius / self.cutoff)
        return self.prefactor * (sine / radius)
class CutoffFunction(torch.nn.Module, metaclass=abc.ABCMeta):
    """Abstract base class for cutoff (envelope) modules."""

    @abc.abstractmethod
    def forward(self):
        """Evaluate the cutoff (must be provided by subclasses)."""
        return None
def _poly_cutoff(x: torch.Tensor, factor: float, p: float = 6.0) -> torch.Tensor:
    """Evaluate the polynomial envelope on ``x * factor``.

    With ``u = x * factor`` the result is
    ``1 - (p+1)(p+2)/2 * u**p + p(p+2) * u**(p+1) - p(p+1)/2 * u**(p+2)``
    wherever ``u < 1`` and zero elsewhere.
    """
    scaled = x * factor
    coeff_p = (p + 1.0) * (p + 2.0) / 2.0
    coeff_p1 = p * (p + 2.0)
    coeff_p2 = p * (p + 1.0) / 2
    envelope = (
        1.0
        - (coeff_p * torch.pow(scaled, p))
        + (coeff_p1 * torch.pow(scaled, p + 1.0))
        - (coeff_p2 * torch.pow(scaled, p + 2.0))
    )
    # The boolean mask zeroes everything at or beyond the cutoff.
    inside = scaled < 1.0
    return envelope * inside
class PolynomialCutoff(CutoffFunction):
    r"""Polynomial cutoff, as proposed in DimeNet: https://arxiv.org/abs/2003.03123

    Parameters
    ----------
    cutoff : float
        Cutoff radius
    power : float
        Power used in envelope function; must be at least 2.
    """

    def __init__(self, cutoff: float, power: float = 6):
        super().__init__()
        assert power >= 2.0
        self.p = float(power)
        # Distances are multiplied by 1 / cutoff before the polynomial is
        # evaluated.
        self._factor = 1.0 / float(cutoff)

    def forward(self, x):
        """
        Evaluate cutoff function.

        x: torch.Tensor, input distance
        """
        envelope = _poly_cutoff(x, self._factor, p=self.p)
        return envelope
class RadialBasisEdgeEncoding(torch.nn.Module):
    """Embed edge lengths with a radial basis multiplied by a cutoff.

    Args:
        basis: Radial basis module; its ``irreps_out`` becomes this
            module's ``irreps_out``.
        cutoff_fn: Cutoff module evaluated on the same edge lengths.
    """

    def __init__(
        self,
        basis: RadialBasis,
        cutoff_fn: CutoffFunction,
    ):
        super().__init__()
        self.basis = basis
        self.cutoff_fn = cutoff_fn
        # output edge dist irreps
        self.irreps_out = self.basis.irreps_out

    def forward(self, data: AtomsData) -> AtomsData:
        """Store the embedding as ``edge_dist_embedding`` on ``data``."""
        lengths = torch.linalg.norm(data.edge_vectors, dim=1)
        radial = self.basis(lengths)
        envelope = self.cutoff_fn(lengths).unsqueeze(1)
        return replace_properties(data, edge_dist_embedding=radial * envelope)
class SphericalHarmonicEdgeAttrs(torch.nn.Module):
    """Embed edge vectors with spherical harmonics.

    Args:
        edge_sh_irreps: Irreps of the spherical harmonics to compute; also
            exposed as ``irreps_out``.
        edge_sh_normalization: Normalization name forwarded to
            ``o3.SphericalHarmonics``.
        edge_sh_normalize: Normalize flag forwarded to
            ``o3.SphericalHarmonics``.
    """

    def __init__(
        self,
        edge_sh_irreps: o3.Irreps,
        edge_sh_normalization: str = "component",
        edge_sh_normalize: bool = True,
    ):
        super().__init__()
        self.edge_sh_irreps = edge_sh_irreps
        # output edge diff irreps
        self.irreps_out = edge_sh_irreps
        self.sh = o3.SphericalHarmonics(
            self.edge_sh_irreps, edge_sh_normalize, edge_sh_normalization
        )

    def forward(self, data: AtomsData) -> AtomsData:
        """Store the harmonics of ``data.edge_vectors`` as ``edge_diff_embedding``."""
        harmonics = self.sh(data.edge_vectors)
        return replace_properties(data, edge_diff_embedding=harmonics)
def scatter_add(
    x: torch.Tensor, index: torch.Tensor, dim_size: int, dim: int = 0
) -> torch.Tensor:
    """Sum slices of ``x`` into a zero tensor at the positions given by ``index``.

    The output has the shape of ``x`` except that dimension ``dim`` has
    length ``dim_size``; ``x`` itself is not modified.
    """
    out_shape = list(x.shape)
    out_shape[dim] = dim_size
    # new_zeros inherits dtype and device from x.
    return x.new_zeros(out_shape).index_add(dim, index, x)
def tp_path_exists(irreps_in1, irreps_in2, ir_out):
    """Return True if some pair of input irreps can produce ``ir_out``.

    Both inputs are converted to simplified ``o3.Irreps``; the result is
    True as soon as ``ir_out`` appears in the product ``ir1 * ir2`` for one
    irrep of each input.
    """
    first = o3.Irreps(irreps_in1).simplify()
    second = o3.Irreps(irreps_in2).simplify()
    target = o3.Irrep(ir_out)
    return any(
        target in ir1 * ir2
        for _, ir1 in first
        for _, ir2 in second
    )
class ConvNetLayer(torch.nn.Module):
    """Convolution block: linear -> weighted tensor product -> aggregate -> linear.

    Node features are passed through a linear layer, combined with the edge
    spherical-harmonic embedding through a tensor product whose weights come
    from a fully connected net on the edge distance embedding, summed per
    node, scaled by ``1 / sqrt(avg_num_neighbors)`` and passed through a
    second linear layer. An optional self-connection is added at the end.
    """

    def __init__(
        self,
        irreps_in,
        irreps_out,
        invariant_layers: int = 1,
        invariant_neurons: int = 8,
        avg_num_neighbors: Optional[float] = None,
        use_sc: bool = True,
        # Read-only default: only ``["e"]`` is looked up, never mutated.
        nonlinearity_scalars: Dict[str, str] = {"e": "ssp"},
        use_cuequivariance: bool = False,
    ) -> None:
        """
        Convolution Block.

        :param irreps_in: Mapping of input irreps; the keys ``'node_feat'``,
            ``'edge_diff_embedding'`` and ``'edge_dist_embedding'`` are read,
            plus ``'node_attr'`` when ``use_sc`` is True.
        :param irreps_out: Irreps of the node features produced by the block.
        :param invariant_layers: Number of hidden layers of the radial
            network, default = 1
        :param invariant_neurons: Number of neurons per hidden layer of the
            radial network, default = 8
        :param avg_num_neighbors: Number of neighbors to divide by (as a
            square root), default None => no normalization until
            :meth:`datamodule` supplies a value.
        :param use_sc: bool, use self-connection or not
        :param nonlinearity_scalars: Mapping whose ``"e"`` entry names the
            activation of the radial network (``"ssp"`` or ``"silu"``).
        :param use_cuequivariance: Build the layers with cuEquivariance
            modules instead of e3nn ones.
        """
        super().__init__()

        if avg_num_neighbors is not None:
            self._initialized = True
            avg_num_neigh = torch.tensor([avg_num_neighbors])
        else:
            # Neutral value; may be replaced later through `datamodule`.
            self._initialized = False
            avg_num_neigh = torch.ones((1,))
        self.register_buffer("avg_num_neighbors", avg_num_neigh)

        self.use_sc = use_sc
        self.use_cuequivariance = use_cuequivariance

        feature_irreps_in = irreps_in['node_feat']
        feature_irreps_out = irreps_out
        edge_diff_irreps = irreps_in['edge_diff_embedding']
        edge_dist_irreps = irreps_in['edge_dist_embedding']

        # - Build modules -
        if self.use_cuequivariance:
            self.linear_1 = cuet.Linear(
                irreps_in=cue.Irreps(cue.O3, feature_irreps_in),
                irreps_out=cue.Irreps(cue.O3, feature_irreps_in),
                layout=cue.mul_ir,
                internal_weights=True,
                shared_weights=True,
            )
        else:
            self.linear_1 = o3.Linear(
                irreps_in=feature_irreps_in,
                irreps_out=feature_irreps_in,
                internal_weights=True,
                shared_weights=True,
            )

        # Enumerate every (node irrep, edge irrep) -> output irrep path
        # whose output irrep is requested in `feature_irreps_out`.
        irreps_mid = []
        instructions = []
        for i, (mul, ir_in) in enumerate(feature_irreps_in):
            for j, (_, ir_edge) in enumerate(edge_diff_irreps):
                for ir_out in ir_in * ir_edge:
                    if ir_out in feature_irreps_out:
                        k = len(irreps_mid)
                        irreps_mid.append((mul, o3.Irrep(ir_out)))
                        instructions.append((i, j, k, "uvu", True))

        # We sort the output irreps of the tensor product so that we can simplify them
        # when they are provided to the second o3.Linear
        irreps_mid = o3.Irreps(irreps_mid)
        irreps_mid, p, _ = irreps_mid.sort()

        # Permute the output indexes of the instructions to match the sorted irreps:
        instructions = [
            (i_in1, i_in2, p[i_out], mode, train)
            for i_in1, i_in2, i_out, mode, train in instructions
        ]

        if self.use_cuequivariance:
            tp = cuet.ChannelWiseTensorProduct(
                irreps_in1=cue.Irreps(cue.O3, feature_irreps_in),
                irreps_in2=cue.Irreps(cue.O3, edge_diff_irreps),
                layout=cue.mul_ir,
                filter_irreps_out=cue.Irreps(cue.O3, irreps_mid.simplify()),
                shared_weights=False,
                internal_weights=False,
            )
        else:
            tp = o3.TensorProduct(
                feature_irreps_in,
                edge_diff_irreps,
                irreps_mid,
                instructions,
                shared_weights=False,
                internal_weights=False,
            )

        # init_irreps already confirmed that the edge embedding is all invariant scalars
        # The radial network maps the distance embedding to one weight per
        # tensor-product weight element.
        self.fc = FullyConnectedNet(
            [edge_dist_irreps.num_irreps]
            + invariant_layers * [invariant_neurons]
            + [tp.weight_numel],
            {
                "ssp": ShiftedSoftPlus,
                "silu": torch.nn.functional.silu,
            }[nonlinearity_scalars["e"]],
        )

        self.tp = tp

        if self.use_cuequivariance:
            self.linear_2 = cuet.Linear(
                irreps_in=cue.Irreps(cue.O3, irreps_mid.simplify()),
                irreps_out=cue.Irreps(cue.O3, feature_irreps_out),
                layout=cue.mul_ir,
                internal_weights=True,
                shared_weights=True,
            )
        else:
            self.linear_2 = o3.Linear(
                irreps_in=irreps_mid.simplify(),
                irreps_out=feature_irreps_out,
                internal_weights=True,
                shared_weights=True,
            )

        self.sc = None
        if self.use_sc:
            if self.use_cuequivariance:
                self.sc = cuet.FullyConnectedTensorProduct(
                    irreps_in1=cue.Irreps(cue.O3, feature_irreps_in),
                    irreps_in2=cue.Irreps(cue.O3, irreps_in['node_attr']),
                    irreps_out=cue.Irreps(cue.O3, feature_irreps_out),
                    layout=cue.mul_ir,
                )
            else:
                self.sc = o3.FullyConnectedTensorProduct(
                    feature_irreps_in,
                    irreps_in['node_attr'],
                    feature_irreps_out,
                )

    def forward(self, data: AtomsData) -> AtomsData:
        """
        Evaluate interaction Block with ResNet (self-connection).

        Reads ``node_feat``, ``node_attr``, ``edge_dist_embedding``,
        ``edge_diff_embedding`` and ``edge_indices`` from ``data``.

        :param data: Input data; the four embeddings must not be None.
        :return: ``data`` with ``node_feat`` replaced by the block output.
        """
        node_feat = data.node_feat
        assert node_feat is not None
        node_attr = data.node_attr
        assert node_attr is not None
        edge_dist_embedding = data.edge_dist_embedding
        assert edge_dist_embedding is not None
        edge_diff_embedding = data.edge_diff_embedding
        assert edge_diff_embedding is not None
        edge_indices = data.edge_indices

        # Per-edge tensor-product weights from the distance embedding.
        weight = self.fc(edge_dist_embedding)

        x = node_feat
        if self.sc is not None:
            sc = self.sc(x, node_attr)

        x = self.linear_1(x)
        # Features are gathered at column 1 of `edge_indices` and the
        # resulting messages are summed at column 0.
        edge_features = self.tp(
            x[edge_indices[:, 1]], edge_diff_embedding, weight
        )
        x = scatter_add(edge_features, edge_indices[:, 0], dim_size=len(x), dim=0)

        x = x.div(self.avg_num_neighbors**0.5)
        x = self.linear_2(x)

        if self.sc is not None:
            x = x + sc

        node_feat = x
        data = replace_properties(data, node_feat=node_feat)
        return data

    def datamodule(self, _datamodule):
        """Take ``avg_num_neighbors`` from a datamodule if none was given.

        Does nothing when a value was passed to the constructor or when
        ``_datamodule._get_avg_num_neighbors()`` returns None.
        """
        if self._initialized:
            return
        avg_num_neigh = _datamodule._get_avg_num_neighbors()
        if avg_num_neigh is not None:
            # Build the replacement on the buffer's current device and with
            # its dtype; a default CPU tensor would no longer match the
            # features once the module has been moved to another device.
            current = self.avg_num_neighbors
            self.avg_num_neighbors = torch.tensor(
                [avg_num_neigh], dtype=current.dtype, device=current.device
            )
def ShiftedSoftPlus(x):
    """Softplus shifted by ``log(2)`` so that an input of zero maps to zero.

    The shift is a Python float rather than a freshly allocated float32
    tensor: no tensor is created on each call, and float64 inputs are not
    offset by a float32-rounded ``log(2)``.
    """
    return torch.nn.functional.softplus(x) - math.log(2.0)
# NOTE: this repeats the `tp_path_exists` defined earlier in this file with
# the same logic; being later, this is the definition that stays bound.
def tp_path_exists(irreps_in1, irreps_in2, ir_out):
    """Return True if some pair of input irreps can produce ``ir_out``.

    Both inputs are converted to simplified ``o3.Irreps``; the result is
    True as soon as ``ir_out`` appears in the product ``ir1 * ir2`` for one
    irrep of each input.
    """
    first = o3.Irreps(irreps_in1).simplify()
    second = o3.Irreps(irreps_in2).simplify()
    target = o3.Irrep(ir_out)
    return any(
        target in ir1 * ir2
        for _, ir1 in first
        for _, ir2 in second
    )
# Activation functions addressable by the short names used in the
# nonlinearity configuration dictionaries.
acts = dict(
    abs=torch.abs,
    tanh=torch.tanh,
    ssp=ShiftedSoftPlus,
    silu=torch.nn.functional.silu,
)
class InteractionLayer(torch.nn.Module):
    """Convolution followed by an equivariant nonlinearity.

    Args:
        irreps_in: Mapping of input irreps; ``'node_feat'`` and
            ``'edge_diff_embedding'`` are read here and the whole mapping is
            handed to the convolution.
        feature_irreps_hidden: Requested hidden irreps. Only those reachable
            from the current node features and the edge embedding through a
            tensor product are kept.
        convolution: Class used to build the convolution.
        convolution_kwargs: Extra keyword arguments for the convolution.
            ``irreps_in`` / ``irreps_out`` entries are ignored. The mapping
            is copied and never modified.
        resnet: Add the input features to the output; only honoured when
            the layer's output irreps equal its input irreps.
        nonlinearity_type: ``"gate"`` or ``"norm"``.
        nonlinearity_scalars: Activation names for even (``"e"``) and odd
            (``"o"``) scalars.
        nonlinearity_gates: Activation names for even and odd gates.
        use_cuequivariance: Forwarded to the convolution.
    """

    def __init__(
        self,
        irreps_in,
        feature_irreps_hidden,
        convolution=ConvNetLayer,
        convolution_kwargs: Optional[dict] = None,
        resnet: bool = False,
        nonlinearity_type: str = "gate",
        # Read-only defaults: these mappings are only looked up.
        nonlinearity_scalars: Dict[str, str] = {"e": "ssp", "o": "tanh"},
        nonlinearity_gates: Dict[str, str] = {"e": "ssp", "o": "abs"},
        use_cuequivariance: bool = False,
    ):
        super().__init__()
        # initialization
        self.use_cuequivariance = use_cuequivariance
        assert nonlinearity_type in ("gate", "norm")

        # make the nonlin dicts from parity ints instead of convenience strs
        nonlinearity_scalars_dict = {
            1: nonlinearity_scalars["e"],
            -1: nonlinearity_scalars["o"],
        }
        nonlinearity_gates_dict = {
            1: nonlinearity_gates["e"],
            -1: nonlinearity_gates["o"],
        }

        self.feature_irreps_hidden = o3.Irreps(feature_irreps_hidden)
        self.resnet = resnet
        self.irreps_out = irreps_in.copy()
        self.irreps_in = irreps_in

        edge_diff_irreps = self.irreps_in['edge_diff_embedding']
        irreps_layer_out_prev = self.irreps_in['node_feat']

        # Split the reachable hidden irreps into scalars (l == 0) and
        # higher-order irreps (l > 0).
        irreps_scalars = o3.Irreps(
            [
                (mul, o3.Irrep(ir))
                for mul, ir in self.feature_irreps_hidden
                if ir.l == 0
                and tp_path_exists(irreps_layer_out_prev, edge_diff_irreps, ir)
            ]
        )
        irreps_gated = o3.Irreps(
            [
                (mul, o3.Irrep(ir))
                for mul, ir in self.feature_irreps_hidden
                if ir.l > 0
                and tp_path_exists(irreps_layer_out_prev, edge_diff_irreps, ir)
            ]
        )
        irreps_layer_out = (irreps_scalars + irreps_gated).simplify()

        if nonlinearity_type == "gate":
            # Gates are even scalars when reachable, odd scalars otherwise.
            ir = (
                "0e"
                if tp_path_exists(irreps_layer_out_prev, edge_diff_irreps, "0e")
                else "0o"
            )
            irreps_gates = o3.Irreps([(mul, o3.Irrep(ir)) for mul, _ in irreps_gated])

            # TO DO, it's not that safe to directly use the
            # dictionary
            equivariant_nonlin = Gate(
                irreps_scalars=irreps_scalars,
                act_scalars=[
                    acts[nonlinearity_scalars_dict[ir.p]] for _, ir in irreps_scalars
                ],
                irreps_gates=irreps_gates,
                act_gates=[acts[nonlinearity_gates_dict[ir.p]] for _, ir in irreps_gates],
                irreps_gated=irreps_gated,
            )
            # The convolution has to produce what the gate consumes.
            conv_irreps_out = equivariant_nonlin.irreps_in.simplify()
        else:
            conv_irreps_out = irreps_layer_out.simplify()
            equivariant_nonlin = NormActivation(
                irreps_in=conv_irreps_out,
                # norm is an even scalar, so use nonlinearity_scalars[1]
                scalar_nonlinearity=acts[nonlinearity_scalars_dict[1]],
                normalize=True,
                epsilon=1e-8,
                bias=False,
            )

        self.equivariant_nonlin = equivariant_nonlin

        # A residual connection needs matching input and output irreps.
        if irreps_layer_out == irreps_layer_out_prev and resnet:
            self.resnet = True
        else:
            self.resnet = False

        # override defaults for irreps:
        # Work on a copy so the caller's dictionary (typically shared by
        # every layer of a model) is left untouched.
        conv_kwargs = dict(convolution_kwargs) if convolution_kwargs else {}
        conv_kwargs.pop("irreps_in", None)
        conv_kwargs.pop("irreps_out", None)
        self.conv = convolution(
            irreps_in=self.irreps_in,
            irreps_out=conv_irreps_out,
            nonlinearity_scalars=nonlinearity_scalars,
            use_cuequivariance=use_cuequivariance,
            **conv_kwargs,
        )

        # output node feature irreps
        self.irreps_out['node_feat'] = self.equivariant_nonlin.irreps_out

    def forward(self, data: AtomsData):
        """Run convolution and nonlinearity, adding the input if ``resnet``."""
        # save old features for resnet
        node_feat = data.node_feat
        assert node_feat is not None
        old_node_feat = node_feat.clone()

        # run convolution
        data = self.conv(data)

        # do nonlinearity
        node_feat = data.node_feat
        assert node_feat is not None
        node_feat = self.equivariant_nonlin(node_feat)
        data = replace_properties(data, node_feat=node_feat)

        if self.resnet:
            node_feat = node_feat + old_node_feat
            data = replace_properties(data, node_feat=node_feat)
        return data
class NequIP(torch.nn.Module):
    """
    A class to set up the NequIP model.

    The model chains the embeddings (one-hot atom types, radial basis,
    spherical harmonics, chemical embedding), ``num_layers`` interaction
    layers and a two-layer linear readout producing one scalar per atom,
    which is then reduced to an energy per image.
    """

    def __init__(
        self,
        num_layers: int,
        num_channels: int = 64,
        norm_data: bool = False,
        norm_per_atom: bool = False,
        data_stddev: float = 1.0,
        data_mean: float = 0.0,
        **kwargs,
    ) -> None:
        """
        Initialize the NequIP model.

        Parameters
        ----------
        num_layers : int
            Number of interaction layers.
        num_channels : int
            Multiplicity used for the default node, hidden and readout irreps.
        norm_data : bool
            If True, the energy is de-normalized in :meth:`forward`.
        norm_per_atom : bool
            If True, the mean shift is scaled by the atom count of each image.
        data_stddev, data_mean : float
            Scale and shift used for de-normalization.
        **kwargs
            Optional settings read with defaults: ``cutoff``,
            ``hidden_irreps``, ``edge_sh_irreps``, ``node_irreps``,
            ``MLP_irreps``, ``lmax``, ``parity``, ``num_basis``, ``power``,
            ``resnet``, ``nonlinearity_type``, ``nonlinearity_scalars``,
            ``nonlinearity_gates``, ``convolution_kwargs``,
            ``use_cuequivariance``, ``species`` and ``compute_forces``.
        """
        super().__init__()
        self.cutoff: float = kwargs.get('cutoff', 5.5)
        self.hidden_irreps: Union[o3.Irreps, str, None] = kwargs.get('hidden_irreps', None)
        self.edge_sh_irreps: Union[o3.Irreps, str, None] = kwargs.get('edge_sh_irreps', None)
        self.node_irreps: Union[o3.Irreps, str, None] = kwargs.get('node_irreps', None)
        self.MLP_irreps: Union[o3.Irreps, str, None] = kwargs.get('MLP_irreps', None)
        self.lmax: int = kwargs.get('lmax', 2)
        self.parity: bool = kwargs.get('parity', True)
        self.num_basis: int = kwargs.get('num_basis', 8)
        self.power: int = kwargs.get('power', 6)
        self.resnet: bool = kwargs.get('resnet', False)
        self.nonlinearity_type: str = kwargs.get('nonlinearity_type', "gate")
        self.nonlinearity_scalars: Dict[str, str] = kwargs.get('nonlinearity_scalars', {"e": "ssp", "o": "tanh"})
        self.nonlinearity_gates: Dict[str, str] = kwargs.get('nonlinearity_gates', {"e": "ssp", "o": "abs"})
        self.convolution_kwargs: dict = kwargs.get('convolution_kwargs', {})
        self.use_cuequivariance: bool = resolve_cuequivariance(kwargs.get('use_cuequivariance', None))

        species: List[str] = kwargs.get('species', None)
        if bool(species):
            num_elements = len(species)
        else:
            num_elements = 119

        ## handling irreps
        # Each irreps setting may be None (use the default), a string (parse
        # it) or an already constructed object (keep it as is).
        # chemical embedding irreps
        if self.node_irreps is None:
            self.node_irreps = o3.Irreps([(num_channels, o3.Irrep(0, 1))])
        elif isinstance(self.node_irreps, str):
            self.node_irreps = o3.Irreps(self.node_irreps)

        # edge sphere harmonic irreps
        if self.edge_sh_irreps is None:
            self.edge_sh_irreps = o3.Irreps.spherical_harmonics(self.lmax, p=-1 if self.parity else 1)
        elif isinstance(self.edge_sh_irreps, str):
            self.edge_sh_irreps = o3.Irreps(self.edge_sh_irreps)

        # hidden feature irreps
        if self.hidden_irreps is None:
            self.hidden_irreps = o3.Irreps(
                [
                    (num_channels, o3.Irrep(l, p))
                    for p in ((1, -1) if self.parity else (1,))
                    for l in range(self.lmax + 1)
                ]
            )
        elif isinstance(self.hidden_irreps, str):
            self.hidden_irreps = o3.Irreps(self.hidden_irreps)

        # MLP_irreps
        if self.MLP_irreps is None:
            self.MLP_irreps = o3.Irreps([(max(1, num_channels // 2), o3.Irrep(0, 1))])
        elif isinstance(self.MLP_irreps, str):
            self.MLP_irreps = o3.Irreps(self.MLP_irreps)

        # Embeddings are applied in insertion order by `forward`.
        self.embeddings = nn.ModuleDict()
        self.embeddings['onehot_embedding'] = OneHotAtomEncoding(num_elements=num_elements, species=species)
        self.embeddings['radial_basis'] = RadialBasisEdgeEncoding(
            basis=BesselBasis(cutoff=self.cutoff, num_basis=self.num_basis),
            cutoff_fn=PolynomialCutoff(cutoff=self.cutoff, power=self.power),
        )
        self.embeddings['sphere_harmonics'] = SphericalHarmonicEdgeAttrs(edge_sh_irreps=self.edge_sh_irreps)

        # Running record of the irreps available to the next module.
        self.irreps_in = {
            'edge_diff_embedding': self.embeddings.sphere_harmonics.irreps_out,
            'edge_dist_embedding': self.embeddings.radial_basis.irreps_out,
        }
        self.irreps_in.update(self.embeddings.onehot_embedding.irreps_out)

        self.embeddings['chemical_embedding'] = AtomwiseLinear(
            irreps_in=self.irreps_in['node_attr'],  # from OneHotAtomEncoding
            irreps_out=self.node_irreps,
            use_cuequivariance=self.use_cuequivariance,
        )
        self.irreps_in['node_feat'] = self.embeddings.chemical_embedding.irreps_out

        self.interactions = nn.ModuleList()
        for _ in range(num_layers):
            interaction = InteractionLayer(
                irreps_in=self.irreps_in,
                feature_irreps_hidden=self.hidden_irreps,
                convolution_kwargs=self.convolution_kwargs,
                resnet=self.resnet,
                nonlinearity_type=self.nonlinearity_type,
                nonlinearity_scalars=self.nonlinearity_scalars,
                nonlinearity_gates=self.nonlinearity_gates,
                use_cuequivariance=self.use_cuequivariance,
            )
            self.interactions.append(interaction)
            # The next layer consumes what this one produces.
            self.irreps_in.update(interaction.irreps_out)

        # Readout: node features -> MLP_irreps -> a single even scalar.
        if self.use_cuequivariance:
            self.readout_mlp = nn.Sequential(
                cuet.Linear(
                    irreps_in=cue.Irreps(cue.O3, self.irreps_in['node_feat']),
                    irreps_out=cue.Irreps(cue.O3, self.MLP_irreps),
                    layout=cue.mul_ir
                ),
                cuet.Linear(
                    irreps_in=cue.Irreps(cue.O3, self.MLP_irreps),
                    irreps_out=cue.Irreps(cue.O3, '1x0e'),
                    layout=cue.mul_ir
                ),
            )
        else:
            self.readout_mlp = nn.Sequential(
                o3.Linear(
                    irreps_in=self.irreps_in['node_feat'],
                    irreps_out=self.MLP_irreps,
                ),
                o3.Linear(
                    irreps_in=self.MLP_irreps,
                    irreps_out=o3.Irreps('1x0e'),
                ),
            )

        # Normalisation constants (stored as non-trainable parameters)
        self.norm_data = torch.nn.Parameter(torch.tensor(norm_data), requires_grad=False)
        self.norm_per_atom = torch.nn.Parameter(torch.tensor(norm_per_atom), requires_grad=False)
        self.data_stddev = torch.nn.Parameter(torch.tensor(data_stddev), requires_grad=False)
        self.data_mean = torch.nn.Parameter(torch.tensor(data_mean), requires_grad=False)

        self.atomwise_reduce = AtomwiseReduce(output_key='energy')

        # The gradient module is only created when forces are requested.
        self.compute_forces = bool(kwargs.get('compute_forces', False))
        if self.compute_forces:
            self.gradient_output = GradientOutput(model_outputs=['forces'])

    def forward(self, data: AtomsData):
        """
        Parameters
        ----------
        data : AtomsData
            Input data for the model.

        Returns
        -------
        AtomsData
            Output data after applying the model.
        """
        for m in self.embeddings.values():
            data = m(data)

        for m in self.interactions:
            data = m(data)

        node_feat = data.node_feat
        assert node_feat is not None
        atomic_energy = self.readout_mlp(node_feat).reshape(-1)
        assert atomic_energy is not None
        data = replace_properties(data, atomic_energy=atomic_energy)
        data = self.atomwise_reduce(data)

        # de-normalization
        if self.norm_data:
            normalizer = self.data_stddev
            energy = normalizer * data.energy
            mean_shift = self.data_mean
            if self.norm_per_atom:
                # A per-atom mean scales with the atom count of each image
                # (`data.num_atoms`), not with the number of edges.
                mean_shift = data.num_atoms * mean_shift
            energy = energy + mean_shift
            data = replace_properties(data, energy=energy)

        if self.compute_forces:
            data = self.gradient_output(data)
        return data

    def get_optimization_info(self):
        """Get information about optimization status"""
        return {
            "cuequivariance_available": _HAS_CUEQUIVARIANCE,
            "optimization_enabled": self.use_cuequivariance,
            "performance_boost": "2-5x speedup" if self.use_cuequivariance else "No optimization"
        }
class AtomwiseReduce(nn.Module):
    """Reduce per-atom energies to one energy per image.

    Args:
        output_key: Accepted but not used; the result is always stored as
            ``energy``.
        per_atom_output: If True, ``atomic_energy`` is written back to the
            data as well.
        aggregation_mode: ``"sum"`` (default) or ``"mean"``; ``"mean"``
            divides the summed energy by ``data.num_atoms``.
    """

    def __init__(
        self,
        output_key: str = "energy",
        per_atom_output: bool = False,
        aggregation_mode: str = "sum",  # should be sum or mean
    ) -> None:
        super().__init__()
        self.per_atom_output = per_atom_output
        self.aggregation_mode = aggregation_mode

    def forward(self, data: AtomsData):
        """Accumulate ``data.atomic_energy`` per image into ``energy``."""
        # Accumulator shaped like `num_atoms`, with the dtype and device of
        # the edge vectors.
        total = torch.zeros_like(
            data.num_atoms,
            dtype=data.edge_vectors.dtype,
            device=data.edge_vectors.device
        )
        atomic_energy = data.atomic_energy
        assert atomic_energy is not None

        image_indices = data.image_indices
        if image_indices is None:
            # No image assignment: every atom goes to index 0.
            image_indices = torch.zeros_like(atomic_energy, dtype=torch.long)
        assert image_indices is not None

        flat_indices = image_indices.reshape(-1)
        flat_energy = atomic_energy.reshape(-1)
        total.index_add_(0, flat_indices, flat_energy)

        if self.aggregation_mode == "mean":
            total = total / data.num_atoms

        data = replace_properties(data, energy=total)
        if self.per_atom_output:
            assert atomic_energy is not None
            data = replace_properties(data, atomic_energy=atomic_energy)
        return data
class GradientOutput(torch.nn.Module):
    """Derive gradient-based outputs (forces) from the predicted energy.

    Args:
        grad_on_edge_diff: Differentiate the energy with respect to
            ``data.edge_vectors``. Nothing is computed when False.
        grad_on_positions: Stored on the module; not read by ``forward``.
        model_outputs: Names of the outputs to compute; defaults to
            ``['forces']``. The list is copied, so later updates never touch
            the caller's list or other instances.
        update_callback: Optional callable invoked after
            :meth:`update_model_outputs` changed the output list.
    """

    def __init__(
        self,
        grad_on_edge_diff: bool = True,
        grad_on_positions: bool = False,
        model_outputs: Optional[List[str]] = None,
        update_callback: Optional[Callable] = None,  # Add a callback parameter
    ) -> None:
        # TODO: define a set for allowed model outputs
        super().__init__()
        self.grad_on_edge_diff = grad_on_edge_diff
        self.grad_on_positions = grad_on_positions
        self.update_callback = update_callback
        # Own copy: `update_model_outputs` mutates this list in place.
        self.model_outputs = ['forces'] if model_outputs is None else list(model_outputs)

    def update_model_outputs(self, outputs: Union[List[str], str]):
        """Append one output name, or extend by a list of names."""
        if isinstance(outputs, str):
            self.model_outputs.append(outputs)
        else:
            self.model_outputs.extend(outputs)

        # update parent model
        if self.update_callback:
            self.update_callback()

    def forward(self, data: AtomsData, training: bool = True,):
        """Compute the requested gradient outputs and attach them to ``data``.

        ``training`` is passed as both ``retain_graph`` and ``create_graph``
        to ``torch.autograd.grad``.
        """
        if self.grad_on_edge_diff:
            energy = data.energy
            edge_vectors = data.edge_vectors
            forces_dim = int(torch.sum(data.num_atoms))
            edge_indices = data.edge_indices
            assert energy is not None

            if 'forces' in self.model_outputs:
                grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(energy)]  # for model deploy
                dE_ddiff = torch.autograd.grad(
                    [energy,],
                    [edge_vectors,],
                    grad_outputs=grad_outputs,
                    retain_graph=training,
                    create_graph=training,
                )
                dE_ddiff = torch.zeros_like(data.positions) if dE_ddiff is None else dE_ddiff[0]  # for torch.jit.script
                assert dE_ddiff is not None

                # diff = R_j - R_i, so -dE/dR_j = -dE/ddiff, -dE/R_i = dE/ddiff
                # The buffers take the gradient's dtype: index_add_ needs
                # matching dtypes, so a fixed float32 buffer would fail for
                # float64 inputs.
                i_forces = torch.zeros(
                    (forces_dim, 3), device=edge_vectors.device, dtype=dE_ddiff.dtype
                )
                j_forces = torch.zeros_like(i_forces)
                i_forces.index_add_(0, edge_indices[:, 0], dE_ddiff)
                j_forces.index_add_(0, edge_indices[:, 1], -dE_ddiff)
                forces = i_forces + j_forces
                data = replace_properties(data, forces=forces)
        return data