# Source code for iann.models.mace

from iann.data import AtomsData, replace_properties
import torch
from torch import nn
import abc,warnings
from ase.data import atomic_numbers
import math
from typing import Optional, Callable, Tuple
from typing import List, Dict, Union
import collections
import warnings
import logging
logging.getLogger("cuequivariance").setLevel(logging.WARNING)
warnings.filterwarnings("ignore", message="The TorchScript type system doesn't support instance-level annotations")
# warnings.filterwarnings("ignore", message="cuequivariance_ops_torch is not available")
# warnings.filterwarnings("ignore", message="Fused TP is not supported on CPU")

from e3nn import o3
from e3nn.nn import FullyConnectedNet
from e3nn.nn import Activation
from e3nn.util.codegen import CodeGenMixin
import opt_einsum_fx

import torch.compiler
if not hasattr(torch.compiler, "is_compiling"):
    torch.compiler.is_compiling = lambda: False

# Try to import cuEquivariance, fallback to e3nn if not available
try: 
    import cuequivariance as cue
    import cuequivariance_torch as cuet
    _HAS_CUEQUIVARIANCE = True
except (ImportError, SyntaxError, Exception) as e:
    _HAS_CUEQUIVARIANCE = False

def resolve_cuequivariance(use_cuequivariance: Optional[bool] = None) -> bool:
    """Decide whether the cuEquivariance backend should be used.

    Args:
        use_cuequivariance: ``None`` selects the backend automatically from
            the module-level availability flag; a truthy value requests
            cuEquivariance explicitly; a falsy (non-``None``) value disables it.

    Returns:
        ``True``/``False`` in the automatic case, otherwise the value that
        was passed in.

    Raises:
        ImportError: cuEquivariance was requested but could not be imported.
    """
    # Automatic mode: follow whatever the import probe found.
    if use_cuequivariance is None:
        if not _HAS_CUEQUIVARIANCE:
            logging.info("cuEquivariance not available - falling back to e3nn")
            return False
        logging.info("cuEquivariance detected - using optimized operations")
        return True

    # Explicitly disabled: nothing to check, hand the value straight back.
    if not use_cuequivariance:
        return use_cuequivariance

    # Explicitly enabled: the package must actually be importable.
    if not _HAS_CUEQUIVARIANCE:
        raise ImportError("cuEquivariance requested but not available")

    logging.info("cuEquivariance enabled - using optimized operations")
    return use_cuequivariance

# Registry mapping an activation name (string) to the callable implementing it.
# "silu" is a module instance, "tanh"/"abs" are plain torch functions, and the
# string key "None" maps to the Python value None (i.e. no activation).
activation_fn = {
    "silu": torch.nn.SiLU(),
    "tanh": torch.tanh,
    "abs": torch.abs,
    "None": None,
}

class Transform(torch.nn.Module, metaclass=abc.ABCMeta):
    """Abstract base for transforms that are implemented as torch modules.

    The class cannot be instantiated directly; concrete subclasses have to
    provide their own :meth:`forward`.
    """

    def __init__(self) -> None:
        super().__init__()

    @abc.abstractmethod
    def forward(self):
        """Apply the transform (must be overridden by subclasses)."""
        raise NotImplementedError

class TypeMapper(Transform):
    """Translate atomic numbers into type indices and back.

    Exactly one of ``species`` or ``symbol_to_type`` must be supplied.

    Args:
        species: Chemical symbols. They are ordered by atomic number and
            numbered ``0..len(species) - 1`` in that order.
        symbol_to_type: Explicit mapping from chemical symbol to a
            non-negative type index.

    Raises:
        TypeError: both ``species`` and ``symbol_to_type`` were given.
        ValueError: neither of them was given.

    Two buffers are registered: ``Z_to_index`` (atomic number -> type index,
    ``-1`` for species that are not mapped) and ``index_to_Z`` (type index ->
    atomic number).
    """

    def __init__(
        self,
        species: Optional[List[str]]=None,
        symbol_to_type: Optional[Dict[str, int]]=None,
    ) -> None:
        super().__init__()
        if species is not None:
            if symbol_to_type is not None:
                raise TypeError("Cannot give both `species` and `symbol_to_type`")
            # Sort the symbols by atomic number so the type index follows Z.
            numbers = [atomic_numbers[s] for s in species]
            species = [e[1] for e in sorted(zip(numbers, species))]
            symbol_to_type = {k: idx for idx, k in enumerate(species)}
        self.symbol_to_type = symbol_to_type

        if self.symbol_to_type is None:
            raise ValueError("`species` or `symbol_to_type` should be given!")

        for sym, type_idx in self.symbol_to_type.items():
            assert sym in atomic_numbers, f"Invalid chemical symbol {sym}"
            assert 0 <= type_idx, f"Invalid type number {type_idx}"

        # Lookup table with 119 entries, i.e. valid indices are Z = 0..118.
        # Unmapped atomic numbers keep the sentinel -1.
        Z_to_index = torch.full(size=(119,), fill_value=-1, dtype=torch.long)
        # Inverse table; indexed by type, so type indices are expected to be
        # smaller than the number of mapped symbols.
        index_to_Z = torch.zeros(size=(len(self.symbol_to_type),), dtype=torch.long)
        for sym, type_idx in self.symbol_to_type.items():
            Z_to_index[atomic_numbers[sym]] = type_idx
            index_to_Z[type_idx] = atomic_numbers[sym]

        self.register_buffer("Z_to_index", Z_to_index)
        self.register_buffer("index_to_Z", index_to_Z)

    def forward(self, data):
        """Store the type indices on ``data.atomic_types`` and return ``data``.

        The assignment is done in place on the object that was passed in.
        """
        data.atomic_types = self.transform(data.atomic_numbers)
        assert torch.all(data.atomic_types >= 0), "Provided data contains species not defined in TypeMapper!"
        return data

    def transform(self, numbers: torch.Tensor) -> torch.Tensor:
        """Map a tensor of atomic numbers to type indices.

        Species that are not mapped come back as ``-1``.

        Raises:
            ValueError: an atomic number is below 1 or beyond the end of the
                lookup table.
        """
        # The upper bound is taken from the table itself: the previous check
        # (`> 119`) let Z = 119 through, which then indexed past the end of
        # the 119-entry table. An empty tensor has no max/min, so it is
        # passed straight to the (empty) lookup.
        if numbers.numel() > 0 and (
            numbers.max() >= self.Z_to_index.numel() or numbers.min() < 1
        ):
            raise ValueError("Provided atomic numbers are not in the periodic table!")
        return self.Z_to_index[numbers]

    def untransform(self, types: torch.Tensor) -> torch.Tensor:
        """Map type indices back to atomic numbers."""
        return self.index_to_Z[types]

class OneHotAtomEncoding(torch.nn.Module):
    """Compute a one-hot floating point encoding of atoms' discrete atom types.

    Args:
        num_elements: Stored first, but always overwritten: with ``species``
            it becomes ``len(species)``, without it becomes 119.
        species: Chemical symbols. When given, a :class:`TypeMapper` converts
            atomic numbers to type indices; otherwise ``atomic_numbers - 1``
            is used directly as the class index.
        set_features: If ``True`` (default), ``node_feat`` will be set in
            addition to ``node_attr``.

    ``irreps_out`` is a dict with key ``'node_attr'`` (and ``'node_feat'``
    when ``set_features`` is true) describing ``num_elements`` even scalars.
    """

    def __init__(
        self,
        num_elements: Optional[int] = None,
        species: Optional[List[str]] = None,
        set_features: bool=True,
    ):
        super().__init__()
        self.num_elements = num_elements
        self.set_features = set_features
        self.species = species

        if self.species is not None:
            self.type_mapper = TypeMapper(self.species)
            self.num_elements = len(self.species)
        else:
            self.num_elements = 119
            self.type_mapper = None
        # output node feature irreps
        self.irreps_out = {
            'node_attr': o3.Irreps([(self.num_elements, (0, 1))])
        }
        if self.set_features:
            self.irreps_out['node_feat'] = self.irreps_out['node_attr']

    def forward(self, data: AtomsData):
        """Attach the one-hot encoding to ``data`` and return the result."""
        if self.type_mapper is not None:
            # TypeMapper.forward returns the data object (with `atomic_types`
            # set on it), not the index tensor, so the indices have to be
            # read back from the returned object before one-hot encoding.
            data = self.type_mapper(data)
            atomic_types = data.atomic_types
        else:
            atomic_types = data.atomic_numbers - 1

        # Cast to the dtype/device of the positions.
        onehot = torch.nn.functional.one_hot(
            atomic_types, num_classes=self.num_elements
        ).to(device=data.positions.device, dtype=data.positions.dtype)

        data = replace_properties(data, node_attr=onehot)

        if self.set_features:
            data = replace_properties(data, node_feat=onehot)
        return data
    
class AtomwiseLinear(torch.nn.Module):
    """Equivariant linear layer applied to the ``node_feat`` of a data object.

    Args:
        irreps_in: Irreps of the incoming node features.
        irreps_out: Irreps of the outgoing node features; falls back to
            ``irreps_in`` when omitted.
        use_cuequivariance: Build the layer with ``cuet.Linear`` instead of
            ``o3.Linear``.
    """

    def __init__(
        self,
        irreps_in: Optional[o3.Irreps]=None,
        irreps_out: Optional[o3.Irreps]=None,
        use_cuequivariance: bool = False,
    ):
        super().__init__()
        self.irreps_in: Optional[o3.Irreps] = irreps_in
        self.irreps_out = irreps_in if irreps_out is None else irreps_out
        self.use_cuequivariance = use_cuequivariance

        if not self.use_cuequivariance:
            self.linear = o3.Linear(
                irreps_in=self.irreps_in, irreps_out=self.irreps_out
            )
        else:
            self.linear = cuet.Linear(
                irreps_in=cue.Irreps(cue.O3, self.irreps_in),
                irreps_out=cue.Irreps(cue.O3, self.irreps_out),
                layout=cue.mul_ir
            )

    def forward(self, data: AtomsData):
        """Return ``data`` with ``node_feat`` replaced by its linear image.

        Raises:
            RuntimeError: ``data.node_feat`` is ``None``.
        """
        if data.node_feat is None:
            raise RuntimeError("node_feat must not be None")
        features = data.node_feat
        assert isinstance(features, torch.Tensor)
        return replace_properties(data, node_feat=self.linear(features))

class AtomwiseNonLinear(torch.nn.Module):
    """Two equivariant linear layers with an activation in between.

    Args:
        irreps_in: Irreps of the input features.
        MLP_irreps: Irreps of the hidden representation the activation acts on.
        gate: Activation handed to ``e3nn.nn.Activation`` as its single act.
        irreps_out: Irreps of the output (default ``"1x0e"``).
        use_cuequivariance: Build the linear layers with ``cuet.Linear``
            instead of ``o3.Linear``.
    """

    def __init__(
        self,
        irreps_in: o3.Irreps,
        MLP_irreps: o3.Irreps,
        gate: Optional[Callable],
        irreps_out: o3.Irreps=o3.Irreps("1x0e"),
        use_cuequivariance: bool = False,
    ):
        super().__init__()
        self.MLP_irreps = MLP_irreps
        self.use_cuequivariance = use_cuequivariance

        if not self.use_cuequivariance:
            self.linear_1 = o3.Linear(irreps_in=irreps_in, irreps_out=self.MLP_irreps)
            self.linear_2 = o3.Linear(
                irreps_in=self.MLP_irreps, irreps_out=irreps_out
            )
        else:
            hidden_cue = cue.Irreps(cue.O3, self.MLP_irreps)
            self.linear_1 = cuet.Linear(
                irreps_in=cue.Irreps(cue.O3, irreps_in),
                irreps_out=hidden_cue,
                layout=cue.mul_ir
            )
            self.linear_2 = cuet.Linear(
                irreps_in=cue.Irreps(cue.O3, self.MLP_irreps),
                irreps_out=cue.Irreps(cue.O3, irreps_out),
                layout=cue.mul_ir
            )

        self.non_linearity = Activation(irreps_in=self.MLP_irreps, acts=[gate])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply linear -> activation -> linear to ``x``."""
        hidden = self.linear_1(x)
        activated = self.non_linearity(hidden)
        return self.linear_2(activated)

class RadialBasis(torch.nn.Module, metaclass=abc.ABCMeta):
    """Abstract base class for radial basis modules."""

    @abc.abstractmethod
    def forward(self):
        """Evaluate the basis (to be provided by subclasses)."""

class BesselBasis(RadialBasis):
    r"""Radial Bessel Basis, as proposed in DimeNet: https://arxiv.org/abs/2003.03123

    Parameters
    ----------
    cutoff : float
        Cutoff radius

    num_basis : int
        Number of Bessel Basis functions

    trainable : bool
        Train the :math:`n \pi` part or not. When true the frequencies are an
        ``nn.Parameter``, otherwise they are registered as a buffer.
    """

    def __init__(self, cutoff: float, num_basis: int=8, trainable: bool=True):
        super().__init__()

        self.trainable = trainable
        self.num_basis = num_basis

        self.cutoff = float(cutoff)
        self.prefactor = 2.0 / self.cutoff
        # output edge dist irreps: `num_basis` even scalars
        self.irreps_out = o3.Irreps([(num_basis, o3.Irrep(0, 1))])

        # Initial frequencies pi, 2*pi, ..., num_basis*pi.
        frequencies = math.pi * torch.linspace(
            start=1.0, end=num_basis, steps=num_basis
        )
        if not self.trainable:
            self.register_buffer("bessel_weights", frequencies)
        else:
            self.bessel_weights = nn.Parameter(frequencies)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate Bessel Basis for input x.

        Parameters
        ----------
        x : torch.Tensor
            Input; a trailing axis of size ``num_basis`` is appended to it.
        """
        # Trailing axis so that x broadcasts against the frequency vector.
        radius = x.unsqueeze(-1)
        sine = torch.sin(self.bessel_weights * radius / self.cutoff)
        return self.prefactor * (sine / radius)

def _poly_cutoff(x: torch.Tensor, factor: float, p: float = 6.0) -> torch.Tensor:
    """Polynomial envelope evaluated at ``x * factor``.

    With ``u = x * factor`` the result is
    ``1 - (p+1)(p+2)/2 * u^p + p(p+2) * u^(p+1) - p(p+1)/2 * u^(p+2)``,
    multiplied by the mask ``u < 1`` so that it is zero from ``u = 1`` onwards.
    """
    scaled = x * factor

    # Coefficients of the three polynomial terms.
    c_p = (p + 1.0) * (p + 2.0) / 2.0
    c_p1 = p * (p + 2.0)
    c_p2 = p * (p + 1.0) / 2

    envelope = (
        1.0
        - (c_p * torch.pow(scaled, p))
        + (c_p1 * torch.pow(scaled, p + 1.0))
        - (c_p2 * torch.pow(scaled, p + 2.0))
    )

    inside = scaled < 1.0
    return envelope * inside

class CutoffFunction(torch.nn.Module, metaclass=abc.ABCMeta):
    """Abstract base class for cutoff (envelope) modules."""

    @abc.abstractmethod
    def forward(self):
        """Evaluate the cutoff function (to be provided by subclasses)."""

class PolynomialCutoff(CutoffFunction):
    r"""Polynomial cutoff, as proposed in DimeNet: https://arxiv.org/abs/2003.03123

    Parameters
    ----------
    cutoff : float
        Cutoff radius

    power : int
        Power used in envelope function (asserted to be at least 2)
    """

    def __init__(self, cutoff: float, power: float = 6):
        super().__init__()
        assert power >= 2.0
        self.p = float(power)
        # Distances are rescaled by 1/cutoff before the polynomial is applied.
        self._factor = 1.0 / float(cutoff)

    def forward(self, x):
        """
        Evaluate cutoff function.

        x: torch.Tensor, input distance
        """
        envelope = _poly_cutoff(x, self._factor, p=self.p)
        return envelope
    
class RadialBasisEdgeEncoding(torch.nn.Module):
    """Embed edge lengths with a radial basis damped by a cutoff function.

    Args:
        basis: Radial basis module; its ``irreps_out`` is re-exported.
        cutoff_fn: Cutoff module applied to the same edge lengths.
    """

    def __init__(
        self,
        basis: RadialBasis,
        cutoff_fn: CutoffFunction,
    ):
        super().__init__()
        self.basis = basis
        self.cutoff_fn = cutoff_fn

        # output edge dist irreps are those of the basis
        self.irreps_out = self.basis.irreps_out

    def forward(self, data: AtomsData):
        """Return ``data`` with ``edge_dist_embedding`` set."""
        # Edge length = norm of each edge vector along dim 1.
        lengths = torch.linalg.norm(data.edge_vectors, dim=1)
        radial = self.basis(lengths)
        envelope = self.cutoff_fn(lengths)[:, None]
        return replace_properties(data, edge_dist_embedding=radial * envelope)

class SphericalHarmonicEdgeAttrs(torch.nn.Module):
    """Embed edge vectors with spherical harmonics.

    Args:
        edge_sh_irreps: Irreps of the spherical harmonics to evaluate.
        edge_sh_normalization: Normalization passed on to
            ``o3.SphericalHarmonics``.
        edge_sh_normalize: Normalize flag passed on to
            ``o3.SphericalHarmonics``.
    """

    def __init__(
        self,
        edge_sh_irreps: o3.Irreps,
        edge_sh_normalization: str = "component",
        edge_sh_normalize: bool = True,
    ):
        super().__init__()

        self.edge_sh_irreps = edge_sh_irreps
        # Positional order: (irreps, normalize, normalization).
        self.sh = o3.SphericalHarmonics(
            self.edge_sh_irreps,
            edge_sh_normalize,
            edge_sh_normalization,
        )
        # output edge diff irreps
        self.irreps_out = edge_sh_irreps

    def forward(self, data: AtomsData):
        """Return ``data`` with ``edge_diff_embedding`` set."""
        harmonics = self.sh(data.edge_vectors)
        return replace_properties(data, edge_diff_embedding=harmonics)

# Based on mir-group/nequip
def tp_out_irreps_with_instructions(
    irreps1: o3.Irreps, irreps2: o3.Irreps, target_irreps: o3.Irreps
) -> Tuple[o3.Irreps, List]:
    """Build output irreps and "uvu" instructions for a tensor product.

    Every product of an irrep from ``irreps1`` with one from ``irreps2`` that
    yields an irrep contained in ``target_irreps`` becomes one path. The
    output multiplicity of a path is the multiplicity of its ``irreps1`` entry.

    Returns:
        ``(irreps_out, instructions)`` where ``irreps_out`` is sorted and each
        instruction is ``(i_in1, i_in2, i_out, "uvu", True)`` with ``i_out``
        referring to the sorted irreps; instructions are ordered by ``i_out``.
    """
    # One entry per allowed path: (index in irreps1, index in irreps2, mul, ir_out).
    paths = []
    for idx1, (mul, ir_in) in enumerate(irreps1):
        for idx2, (_, ir_edge) in enumerate(irreps2):
            # ir_in * ir_edge enumerates |l1 - l2| <= l <= l1 + l2
            paths.extend(
                (idx1, idx2, mul, ir_out)
                for ir_out in ir_in * ir_edge
                if ir_out in target_irreps
            )

    # Sorting the output irreps lets them be simplified when they are fed to
    # the following o3.Linear.
    unsorted_out = o3.Irreps([(mul, ir_out) for _, _, mul, ir_out in paths])
    irreps_out, permut, _ = unsorted_out.sort()

    # Re-point every instruction at its slot in the sorted irreps.
    instructions = [
        (idx1, idx2, permut[slot], "uvu", True)
        for slot, (idx1, idx2, _, _) in enumerate(paths)
    ]
    instructions.sort(key=lambda instruction: instruction[2])

    return irreps_out, instructions

class reshape_irreps(torch.nn.Module):
    """Reshape flat irreps features into ``[batch, mul, dim]`` blocks.

    Each irrep's slice of width ``mul * dim`` is reshaped to
    ``[batch, mul, dim]`` and the blocks are concatenated along the last axis.
    """

    def __init__(self, irreps: o3.Irreps) -> None:
        super().__init__()
        self.irreps = o3.Irreps(irreps)
        # Per-irrep dimension and multiplicity, in irreps order.
        self.dims = [ir.dim for _, ir in self.irreps]
        self.muls = [mul for mul, _ in self.irreps]

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Reshape a 2-D ``[batch, features]`` tensor block by block."""
        batch, _ = tensor.shape
        blocks = []
        start = 0
        for mul, d in zip(self.muls, self.dims):
            stop = start + mul * d
            blocks.append(tensor[:, start:stop].reshape(batch, mul, d))
            start = stop
        return torch.cat(blocks, dim=-1)

def scatter_add(
    x: torch.Tensor, index: torch.Tensor, dim_size: int, dim: int = 0
) -> torch.Tensor:
    """Sum slices of ``x`` into a zero tensor at the positions in ``index``.

    The output has the shape of ``x`` except along ``dim``, where it has
    length ``dim_size``; slices of ``x`` sharing an index are added together.
    """
    sizes = list(x.size())
    sizes[dim] = dim_size
    accumulator = torch.zeros(sizes, dtype=x.dtype, device=x.device)
    # Out-of-place index_add: the zero tensor itself is left untouched.
    return accumulator.index_add(dim, index, x)

class RealAgnosticResidualInteractionBlock(torch.nn.Module):
    """Interaction block: linear -> edge tensor product -> sum -> linear.

    A skip tensor product of the node features with the node attributes is
    computed from the block's input features and returned next to the new
    features.

    Args:
        irreps_in: Mapping with the keys ``'node_feat'``, ``'node_attr'``,
            ``'edge_diff_embedding'`` and ``'edge_dist_embedding'`` giving the
            irreps of the corresponding inputs.
        target_irreps: Irreps of the block output (before reshaping).
        hidden_irreps: Irreps of the skip-connection output.
        avg_num_neighbors: Divisor applied to the aggregated features. When
            omitted it starts at 1 and may later be filled in through
            :meth:`datamodule`.
        use_cuequivariance: Build the equivariant layers with
            cuEquivariance instead of e3nn.
    """

    def __init__(
        self,
        irreps_in,
        target_irreps,
        hidden_irreps,
        avg_num_neighbors: Optional[float] = None,
        use_cuequivariance: bool = False,
    ) -> None:
        super().__init__()
        self.irreps_in = irreps_in
        self.target_irreps = target_irreps
        self.hidden_irreps = hidden_irreps
        # True when the divisor was supplied explicitly; `datamodule` then
        # leaves it alone.
        self._initialized = avg_num_neighbors is not None
        self.use_cuequivariance = use_cuequivariance
        avg_num_neighbors = torch.ones((1,)) if avg_num_neighbors is None else torch.tensor([avg_num_neighbors])
        self.register_buffer("avg_num_neighbors", avg_num_neighbors)

        # First linear
        if self.use_cuequivariance:
            self.linear_1 = cuet.Linear(
                irreps_in=cue.Irreps(cue.O3, self.irreps_in['node_feat']),
                irreps_out=cue.Irreps(cue.O3, self.irreps_in['node_feat']),
                layout=cue.mul_ir,
                internal_weights=True,
                shared_weights=True,
            )
        else:
            self.linear_1 = o3.Linear(
                self.irreps_in['node_feat'],
                self.irreps_in['node_feat'],
                internal_weights=True,
                shared_weights=True,
            )

        # Output irreps and instructions of the node x edge tensor product.
        irreps_mid, instructions = tp_out_irreps_with_instructions(
            self.irreps_in['node_feat'],
            self.irreps_in['edge_diff_embedding'],
            self.target_irreps,
        )

        # The tensor product takes its weights from outside (per edge).
        if self.use_cuequivariance:
            self.conv_tp = cuet.ChannelWiseTensorProduct(
                irreps_in1=cue.Irreps(cue.O3, self.irreps_in['node_feat']),
                irreps_in2=cue.Irreps(cue.O3, self.irreps_in['edge_diff_embedding']),
                layout=cue.mul_ir,
                filter_irreps_out=cue.Irreps(cue.O3, irreps_mid.simplify()),
                shared_weights=False,
                internal_weights=False,
            )
        else:
            self.conv_tp = o3.TensorProduct(
                self.irreps_in['node_feat'],
                self.irreps_in['edge_diff_embedding'],
                irreps_mid,
                instructions=instructions,
                shared_weights=False,
                internal_weights=False,
            )

        # Convolution weights: an MLP from the radial embedding to the
        # tensor-product weights.
        input_dim = self.irreps_in['edge_dist_embedding'].num_irreps
        self.conv_tp_weights = FullyConnectedNet(
            [input_dim] + 3 * [64] + [self.conv_tp.weight_numel],
            torch.nn.functional.silu,
        )

        # Linear
        irreps_mid = irreps_mid.simplify()
        self.irreps_out = self.target_irreps
        if self.use_cuequivariance:
            self.linear_2 = cuet.Linear(
                irreps_in=cue.Irreps(cue.O3, irreps_mid),
                irreps_out=cue.Irreps(cue.O3, self.irreps_out),
                layout=cue.mul_ir,
                internal_weights=True,
                shared_weights=True,
            )
        else:
            self.linear_2 = o3.Linear(
                irreps_mid, self.irreps_out, internal_weights=True, shared_weights=True
            )

        # Selector TensorProduct
        if self.use_cuequivariance:
            self.skip_tp = cuet.FullyConnectedTensorProduct(
                irreps_in1=cue.Irreps(cue.O3, self.irreps_in['node_feat']),
                irreps_in2=cue.Irreps(cue.O3, self.irreps_in['node_attr']),
                irreps_out=cue.Irreps(cue.O3, self.hidden_irreps),
                layout=cue.mul_ir,
            )
        else:
            self.skip_tp = o3.FullyConnectedTensorProduct(
                self.irreps_in['node_feat'],
                self.irreps_in['node_attr'],
                self.hidden_irreps,
            )
        self.reshape = reshape_irreps(self.irreps_out)

    def forward(
        self,
        node_feat,
        node_attr,
        edge_idx,
        edge_dist_embedding,
        edge_diff_embedding,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the interaction.

        Returns:
            ``(reshaped node features, skip connection)``.
        """
        # Skip connection is computed from the input features.
        sc = self.skip_tp(node_feat, node_attr)
        node_feat = self.linear_1(node_feat)
        tp_weights = self.conv_tp_weights(edge_dist_embedding)
        # NOTE(review): features are gathered with edge_idx[:, 0] and summed
        # back onto edge_idx[:, 0] as well, i.e. the same column is used for
        # both ends of the edge. Confirm against the edge_idx convention that
        # this is intended (a sender/receiver split would use both columns).
        edge_feat = self.conv_tp(
            node_feat[edge_idx[:, 0]],
            edge_diff_embedding,
            tp_weights,
        )
        node_feat = scatter_add(
            edge_feat, edge_idx[:, 0], dim_size=len(node_feat), dim=0
        )
        node_feat = self.linear_2(node_feat)
        node_feat = node_feat / self.avg_num_neighbors

        return (self.reshape(node_feat), sc)

    def datamodule(self, _datamodule):
        """Fill in ``avg_num_neighbors`` from a datamodule if it was not given.

        Does nothing when the value was supplied at construction time or when
        the datamodule reports ``None``.
        """
        if self._initialized:
            return
        avg_num_neigh = _datamodule._get_avg_num_neighbors()
        if avg_num_neigh is None:
            return
        # Build the replacement on the buffer's current device and dtype so a
        # module that was already moved/cast does not end up with a
        # mismatching divisor in forward().
        current = self.avg_num_neighbors
        self.avg_num_neighbors = torch.tensor(
            [avg_num_neigh], dtype=current.dtype, device=current.device
        )

# Records used by `_wigner_nj` to describe how a basis tensor was obtained.
# _TP:    `op` is the irrep triple (ir_left, ir, ir_out) and `args` the pair
#         (path of the left factor, _INPUT of the right factor).
# _INPUT: `tensor` is the position of the input in the list of irreps and
#         `start`/`stop` bound the slice of that input.
_TP = collections.namedtuple("_TP", "op, args")
_INPUT = collections.namedtuple("_INPUT", "tensor, start, stop")

def _wigner_nj(
    irrepss: List[o3.Irreps],
    normalization: str = "component",
    filter_ir_mid=None,
    dtype=None,
):
    """Recursively couple a list of irreps and return the coupling tensors.

    Args:
        irrepss: Irreps (or anything ``o3.Irreps`` accepts) to be coupled,
            one entry per input.
        normalization: ``"component"`` scales each coupling by
            ``sqrt(ir_out.dim)``, ``"norm"`` by
            ``sqrt(ir_left.dim * ir.dim)``; any other value applies no scaling.
        filter_ir_mid: Optional collection of irreps; coupled irreps that are
            not in it are dropped at every recursion level.
        dtype: dtype of the created tensors.

    Returns:
        A list of ``(ir, path, tensor)`` triples. ``path`` is an ``_INPUT``
        for a single input and a ``_TP`` otherwise. For more than one input
        the list is sorted by ``ir``; for a single input it follows the order
        of the irreps.
    """
    irrepss = [o3.Irreps(irreps) for irreps in irrepss]
    if filter_ir_mid is not None:
        filter_ir_mid = [o3.Irrep(ir) for ir in filter_ir_mid]

    # Base case: one input. Every copy of every irrep gets the matching rows
    # of the identity matrix as its tensor.
    if len(irrepss) == 1:
        (irreps,) = irrepss
        ret = []
        e = torch.eye(irreps.dim, dtype=dtype)
        i = 0
        for mul, ir in irreps:
            for _ in range(mul):
                sl = slice(i, i + ir.dim)
                ret += [(ir, _INPUT(0, sl.start, sl.stop), e[sl])]
                i += ir.dim
        return ret

    # Recursive case: couple the result for all-but-the-last input with each
    # irrep of the last input.
    *irrepss_left, irreps_right = irrepss
    ret = []
    for ir_left, path_left, C_left in _wigner_nj(
        irrepss_left,
        normalization=normalization,
        filter_ir_mid=filter_ir_mid,
        dtype=dtype,
    ):
        # `i` is the offset of the current irrep block inside irreps_right.
        i = 0
        for mul, ir in irreps_right:
            for ir_out in ir_left * ir:
                if filter_ir_mid is not None and ir_out not in filter_ir_mid:
                    continue

                C = o3.wigner_3j(ir_out.l, ir_left.l, ir.l, dtype=dtype)
                if normalization == "component":
                    C *= ir_out.dim**0.5
                if normalization == "norm":
                    C *= ir_left.dim**0.5 * ir.dim**0.5

                # Contract the left coupling tensor with the 3j symbol over
                # the left irrep's index, then restore the per-input axes.
                C = torch.einsum("jk,ijl->ikl", C_left.flatten(1), C)
                C = C.reshape(
                    ir_out.dim, *(irreps.dim for irreps in irrepss_left), ir.dim
                )
                # One result per copy `u` of the right irrep: the coupling is
                # written into that copy's slice of a zero tensor spanning the
                # whole right input.
                for u in range(mul):
                    E = torch.zeros(
                        ir_out.dim,
                        *(irreps.dim for irreps in irrepss_left),
                        irreps_right.dim,
                        dtype=dtype,
                    )
                    sl = slice(i + u * ir.dim, i + (u + 1) * ir.dim)
                    E[..., sl] = C
                    ret += [
                        (
                            ir_out,
                            _TP(
                                op=(ir_left, ir, ir_out),
                                args=(
                                    path_left,
                                    _INPUT(len(irrepss_left), sl.start, sl.stop),
                                ),
                            ),
                            E,
                        )
                    ]
            i += mul * ir.dim
    return sorted(ret, key=lambda x: x[0])


def U_matrix_real(
    irreps_in: Union[str, o3.Irreps],
    irreps_out: Union[str, o3.Irreps],
    correlation: int,
    normalization: str = "component",
    filter_ir_mid=None,
    dtype=None,
):
    """Stack the coupling tensors of ``correlation`` copies of ``irreps_in``.

    The tensors returned by ``_wigner_nj`` are grouped by output irrep; those
    whose irrep is contained in ``irreps_out`` are stacked along a new last
    axis.

    Args:
        irreps_in: Irreps that are coupled with themselves.
        irreps_out: Irreps to keep.
        correlation: Number of copies of ``irreps_in`` to couple.
        normalization: Passed on to ``_wigner_nj``.
        filter_ir_mid: Passed on to ``_wigner_nj``; replaced by a fixed list
            when ``correlation == 4`` (see below).
        dtype: Passed on to ``_wigner_nj``.

    Returns:
        A flat list alternating irrep and stacked tensor:
        ``[ir_a, U_a, ir_b, U_b, ...]``.
    """
    irreps_out = o3.Irreps(irreps_out)
    irrepss = [o3.Irreps(irreps_in)] * correlation
    # NOTE(review): for correlation == 4 a caller-supplied filter_ir_mid is
    # overwritten by this list (l = 0..11 with parity (-1)**l).
    if correlation == 4:
        filter_ir_mid = [
            o3.Irrep(0, 1),
            o3.Irrep(1, -1),
            o3.Irrep(2, 1),
            o3.Irrep(3, -1),
            o3.Irrep(4, 1),
            o3.Irrep(5, -1),
            o3.Irrep(6, 1),
            o3.Irrep(7, -1),
            o3.Irrep(8, 1),
            o3.Irrep(9, -1),
            o3.Irrep(10, 1),
            o3.Irrep(11, -1),
        ]
    wigners = _wigner_nj(irrepss, normalization, filter_ir_mid, dtype)
    current_ir = wigners[0][0]
    out = []
    stack = torch.tensor([])
    for ir, _, base_o3 in wigners:
        if ir in irreps_out and ir == current_ir:
            # Same irrep as the running group: append along the last axis.
            stack = torch.cat((stack, base_o3.squeeze().unsqueeze(-1)), dim=-1)
            last_ir = current_ir
        elif ir in irreps_out and ir != current_ir:
            # A new wanted irrep starts: flush the previous group (if any)
            # and begin a new stack.
            if len(stack) != 0:
                out += [last_ir, stack]
            stack = base_o3.squeeze().unsqueeze(-1)
            current_ir, last_ir = ir, ir
        else:
            # Irrep is not requested: only advance the group marker.
            current_ir = ir
    # NOTE(review): `last_ir` is only assigned inside the loop; if no irrep
    # from `wigners` is in `irreps_out` this line raises because it is unbound.
    out += [last_ir, stack]
    return out

# Batch size of the random example tensors that Contraction hands to
# opt_einsum_fx when optimizing its einsum graphs.
BATCH_EXAMPLE = 10
# Letters Contraction uses to build the leading subscripts of its einsum equations.
ALPHABET = ["w", "x", "v", "n", "z", "r", "t", "y", "u", "o", "p", "s"]

class Contraction(torch.nn.Module):
    """Contraction of node features for a single output irrep.

    For every order ``nu = 1..correlation`` the last tensor returned by
    ``U_matrix_real`` is registered as buffer ``U_matrix_{nu}``. For each
    order an einsum is traced with ``torch.fx`` and optimized with
    ``opt_einsum_fx``, and a weight of shape
    ``(num_elements, num_params, num_channels)`` is created.

    Args:
        irreps_in: Irreps of the input features; the multiplicity of ``0e``
            is taken as the number of channels.
        irrep_out: Irreps object holding the output irrep.
        correlation: Highest order of the contraction.
        internal_weights: If false, the weights created here are replaced by
            entries of ``weights``.
        num_elements: Size of the element axis of the weights.
        weights: External weights, only read when ``internal_weights`` is
            false; ``weights[:-1]`` and ``weights[-1]`` are used.
    """

    def __init__(
        self,
        irreps_in: o3.Irreps,
        irrep_out: o3.Irreps,
        correlation: int,
        internal_weights: bool = True,
        num_elements: Optional[int] = None,
        weights: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()

        self.num_channels = irreps_in.count((0, 1))
        # One copy of every irrep of the input (multiplicities dropped).
        self.coupling_irreps = o3.Irreps([irrep.ir for irrep in irreps_in])
        self.correlation = correlation
        dtype = torch.get_default_dtype()
        for nu in range(1, correlation + 1):
            # U_matrix_real returns [ir, U, ir, U, ...]; keep the last tensor.
            U_matrix = U_matrix_real(
                irreps_in=self.coupling_irreps,
                irreps_out=irrep_out,
                correlation=nu,
                dtype=dtype,
            )[-1]
            self.register_buffer(f"U_matrix_{nu}", U_matrix)

        # Tensor contraction equations
        self.contractions_weighting = torch.nn.ModuleList()
        self.contractions_features = torch.nn.ModuleList()

        # Create weight for product basis
        self.weights = torch.nn.ParameterList([])

        # Orders are processed from the highest down to 1; the highest order
        # gets the "main" graph, the others a weighting/features graph pair.
        for i in range(correlation, 0, -1):
            # Shape definitions
            num_params = self.U_tensors(i).size()[-1]
            num_equivariance = 2 * irrep_out.lmax + 1
            num_ell = self.U_tensors(i).size()[-2]

            if i == correlation:
                # Equation = leading letters + fixed part + the same letters
                # as output. min(lmax, 1) is 0 for a scalar output and 1
                # otherwise.
                parse_subscript_main = (
                    [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
                    + ["ik,ekc,bci,be -> bc"]
                    + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1) - 1)]
                )
                graph_module_main = torch.fx.symbolic_trace(
                    lambda x, y, w, z: torch.einsum(
                        "".join(parse_subscript_main), x, y, w, z
                    )
                )

                # Optimizing the contractions
                # Example inputs are random tensors of the expected shapes;
                # squeeze(0) removes the leading axis when it has size 1
                # (lmax == 0).
                self.graph_opt_main = opt_einsum_fx.optimize_einsums_full(
                    model=graph_module_main,
                    example_inputs=(
                        torch.randn(
                            [num_equivariance] + [num_ell] * i + [num_params]
                        ).squeeze(0),
                        torch.randn((num_elements, num_params, self.num_channels)),
                        torch.randn((BATCH_EXAMPLE, self.num_channels, num_ell)),
                        torch.randn((BATCH_EXAMPLE, num_elements)),
                    ),
                )
                # Parameters for the product basis
                w = torch.nn.Parameter(
                    torch.randn((num_elements, num_params, self.num_channels))
                    / num_params
                )
                self.weights_max = w
            else:
                # Generate optimized contractions equations
                parse_subscript_weighting = (
                    [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
                    + ["k,ekc,be->bc"]
                    + [ALPHABET[j] for j in range(i + min(irrep_out.lmax, 1))]
                )
                parse_subscript_features = (
                    ["bc"]
                    + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
                    + ["i,bci->bc"]
                    + [ALPHABET[j] for j in range(i - 1 + min(irrep_out.lmax, 1))]
                )

                # Symbolic tracing of contractions
                graph_module_weighting = torch.fx.symbolic_trace(
                    lambda x, y, z: torch.einsum(
                        "".join(parse_subscript_weighting), x, y, z
                    )
                )
                graph_module_features = torch.fx.symbolic_trace(
                    lambda x, y: torch.einsum("".join(parse_subscript_features), x, y)
                )

                # Optimizing the contractions
                graph_opt_weighting = opt_einsum_fx.optimize_einsums_full(
                    model=graph_module_weighting,
                    example_inputs=(
                        torch.randn(
                            [num_equivariance] + [num_ell] * i + [num_params]
                        ).squeeze(0),
                        torch.randn((num_elements, num_params, self.num_channels)),
                        torch.randn((BATCH_EXAMPLE, num_elements)),
                    ),
                )
                graph_opt_features = opt_einsum_fx.optimize_einsums_full(
                    model=graph_module_features,
                    example_inputs=(
                        torch.randn(
                            [BATCH_EXAMPLE, self.num_channels, num_equivariance]
                            + [num_ell] * i
                        ).squeeze(2),
                        torch.randn((BATCH_EXAMPLE, self.num_channels, num_ell)),
                    ),
                )
                self.contractions_weighting.append(graph_opt_weighting)
                self.contractions_features.append(graph_opt_features)
                # Parameters for the product basis
                w = torch.nn.Parameter(
                    torch.randn((num_elements, num_params, self.num_channels))
                    / num_params
                )
                self.weights.append(w)
        # External weights replace the parameters created above.
        if not internal_weights:
            self.weights = weights[:-1]
            self.weights_max = weights[-1]

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Contract features ``x`` with element attributes ``y``.

        The example inputs used at construction time have shape
        ``[batch, num_channels, num_ell]`` for ``x`` and
        ``[batch, num_elements]`` for ``y``. The result is flattened to two
        dimensions, keeping the first axis.
        """
        out = self.graph_opt_main( # the first layer for A features
            self.U_tensors(self.correlation), 
            self.weights_max,
            x,
            y,
        )
        # Remaining orders, from correlation - 1 down to 1: weight the U
        # tensor, add the running result, then contract with the features.
        for i, (weight, contract_weights, contract_features) in enumerate(
            zip(self.weights, self.contractions_weighting, self.contractions_features)
        ):
            c_tensor = contract_weights(  # other layers for A features
                self.U_tensors(self.correlation - i - 1), 
                weight,
                y,
            )
            c_tensor = c_tensor + out 
            out = contract_features(c_tensor, x) # B features
        # Collapse all axes after the first into one.
        resize_shape = torch.prod(torch.tensor(out.shape[1:]))
        return out.view(out.shape[0], resize_shape)

    def U_tensors(self, nu: int):
        """Return the registered buffer ``U_matrix_{nu}``."""
        return dict(self.named_buffers())[f"U_matrix_{nu}"]

class SymmetricContraction(CodeGenMixin, torch.nn.Module):
    """Bundle one :class:`Contraction` per entry of ``irreps_out``.

    Args:
        irreps_in: Irreps of the input features.
        irreps_out: Irreps of the output; each entry gets its own contraction.
        correlation: Correlation order. Any value that is not a tuple is
            applied to every output irrep.
        irrep_normalization: ``"component"``, ``"norm"`` or ``"none"``
            (``None`` means ``"component"``). Only validated here.
        path_normalization: ``"element"``, ``"path"`` or ``"none"``
            (``None`` means ``"element"``). Only validated here.
        internal_weights: Forwarded to every contraction; ``None`` becomes
            ``True``.
        shared_weights: Stored and forwarded to every contraction as its
            ``weights`` argument.
        num_elements: Forwarded to every contraction.
    """

    def __init__(
        self,
        irreps_in: o3.Irreps,
        irreps_out: o3.Irreps,
        correlation: Union[int, Dict[str, int]],
        irrep_normalization: str = "component",
        path_normalization: str = "element",
        internal_weights: Optional[bool] = None,
        shared_weights: Optional[bool] = None,
        num_elements: Optional[int] = None,
    ) -> None:
        super().__init__()

        # Resolve the normalization defaults, then validate them.
        irrep_normalization = (
            "component" if irrep_normalization is None else irrep_normalization
        )
        path_normalization = (
            "element" if path_normalization is None else path_normalization
        )
        assert irrep_normalization in ["component", "norm", "none"]
        assert path_normalization in ["element", "path", "none"]

        self.irreps_in = o3.Irreps(irreps_in)
        self.irreps_out = o3.Irreps(irreps_out)

        # Broadcast a single correlation value to every output entry.
        if not isinstance(correlation, tuple):
            correlation = dict.fromkeys(self.irreps_out, correlation)

        # Checked before the internal_weights default below is applied.
        assert shared_weights or not internal_weights

        self.internal_weights = True if internal_weights is None else internal_weights
        self.shared_weights = shared_weights

        self.contractions = torch.nn.ModuleList(
            [
                Contraction(
                    irreps_in=self.irreps_in,
                    irrep_out=o3.Irreps(str(entry.ir)),
                    correlation=correlation[entry],
                    internal_weights=self.internal_weights,
                    num_elements=num_elements,
                    weights=self.shared_weights,
                )
                for entry in self.irreps_out
            ]
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        """Run every contraction on ``(x, y)`` and join along the last axis."""
        pieces = []
        for contraction in self.contractions:
            pieces.append(contraction(x, y))
        return torch.cat(pieces, dim=-1)


class EquivariantProductBasisBlock(torch.nn.Module):
    """Symmetric contraction of node features followed by a linear update.

    The block owns a ``SymmetricContraction`` (``node_feats_irreps`` ->
    ``target_irreps``) and a linear layer mapping ``target_irreps`` onto
    itself. The linear layer is ``cuet.Linear`` when ``use_cuequivariance`` is
    true and ``o3.Linear`` otherwise.
    """

    def __init__(
        self,
        node_feats_irreps: o3.Irreps,
        target_irreps: o3.Irreps,
        correlation: int,
        use_sc: bool = True,
        num_elements: Optional[int] = None,
        use_cuequivariance: bool = False,
    ) -> None:
        """Create the contraction and the linear update layer.

        Parameters
        ----------
        node_feats_irreps
            Irreps handed to ``SymmetricContraction`` as ``irreps_in``.
        target_irreps
            Irreps produced by the contraction and kept by the linear layer.
        correlation
            Forwarded to ``SymmetricContraction``.
        use_sc
            Whether ``forward`` may add the skip-connection tensor.
        num_elements
            Forwarded to ``SymmetricContraction``.
        use_cuequivariance
            Selects the cuEquivariance linear layer instead of the e3nn one.
        """
        super().__init__()

        self.use_sc = use_sc
        self.use_cuequivariance = use_cuequivariance

        # Registered before ``linear`` so the submodule order is unchanged.
        self.symmetric_contractions = SymmetricContraction(
            irreps_in=node_feats_irreps,
            irreps_out=target_irreps,
            correlation=correlation,
            num_elements=num_elements,
        )

        # Update linear
        if not self.use_cuequivariance:
            self.linear = o3.Linear(
                target_irreps,
                target_irreps,
                internal_weights=True,
                shared_weights=True,
            )
        else:
            self.linear = cuet.Linear(
                irreps_in=cue.Irreps(cue.O3, target_irreps),
                irreps_out=cue.Irreps(cue.O3, target_irreps),
                layout=cue.mul_ir,
                internal_weights=True,
                shared_weights=True,
            )

    def forward(
        self,
        node_feats: torch.Tensor,
        sc: Optional[torch.Tensor],
        node_attrs: torch.Tensor,
    ) -> torch.Tensor:
        """Contract, apply the linear layer, and optionally add ``sc``.

        ``sc`` is added to the linear output only when ``use_sc`` is set and
        ``sc`` is not ``None``.
        """
        contracted = self.symmetric_contractions(node_feats, node_attrs)
        updated = self.linear(contracted)
        if self.use_sc and sc is not None:  # Update layer
            updated = updated + sc
        return updated

class MACE(nn.Module):
    """
    A class to set up the MACE model.

    The model chains four embedding modules (one-hot species, radial basis,
    spherical harmonics, chemical embedding), ``num_layers`` triples of
    interaction / product / readout blocks, and an ``AtomwiseReduce`` that
    turns per-atom energies into per-image energies.
    """

    def __init__(
        self,
        num_layers: int = 3,
        num_channels: int = 64,
        norm_data: bool = False,
        norm_per_atom: bool = False,
        data_stddev: float = 1.0,
        data_mean: float = 0.0,
        **kwargs,
    ) -> None:
        """
        Initialize the MACE model.

        Parameters
        ----------
        num_layers : int
            Number of interaction / product / readout triples.
        num_channels : int
            Multiplicity used when ``hidden_irreps`` is generated here.
        norm_data : bool
            If true, ``forward`` rescales the energy with ``data_stddev`` and
            shifts it with ``data_mean``.
        norm_per_atom : bool
            If true, the mean shift is scaled per image (see ``forward``).
        data_stddev, data_mean : float
            De-normalisation constants, stored as non-trainable parameters.
        **kwargs
            Optional settings read with ``kwargs.get``:
            ``use_cuequivariance`` (None), ``cutoff`` (5.5),
            ``hidden_irreps`` (None), ``edge_sh_irreps`` (None),
            ``node_irreps`` (None), ``MLP_irreps`` (None),
            ``avg_num_neighbors`` (None), ``lmax`` (2), ``parity`` (True),
            ``num_basis`` (8), ``power`` (6), ``gate`` ('silu'),
            ``correlation`` (3), ``species`` (None) and ``compute_forces``.
        """
        super().__init__()
        self.use_cuequivariance = resolve_cuequivariance(kwargs.get('use_cuequivariance', None))

        self.cutoff: float = kwargs.get('cutoff', 5.5)
        self.hidden_irreps: Union[o3.Irreps, str, None] = kwargs.get('hidden_irreps', None)
        self.edge_sh_irreps: Union[o3.Irreps, str, None] = kwargs.get('edge_sh_irreps', None)
        self.node_irreps: Union[o3.Irreps, str, None] = kwargs.get('node_irreps', None)
        self.MLP_irreps: Union[o3.Irreps, str, None] = kwargs.get('MLP_irreps', None)
        self.avg_num_neighbors: Optional[float] = kwargs.get('avg_num_neighbors', None)
        self.lmax: int = kwargs.get('lmax', 2)
        self.parity: bool = kwargs.get('parity', True)
        self.num_basis: int = kwargs.get('num_basis', 8)
        self.power: int = kwargs.get('power', 6)
        self.gate: Union[str, Callable] = kwargs.get('gate', 'silu')

        # A single integer is broadcast to one correlation order per layer.
        self.correlation: Union[int, List[int]] = kwargs.get('correlation', 3)
        if isinstance(self.correlation, int):
            self.correlation = [self.correlation] * num_layers

        # Without an explicit species list, reserve 119 element slots.
        species: List[str] = kwargs.get('species', None)
        num_elements = len(species) if species else 119

        # hidden feature irreps
        if self.hidden_irreps is None:
            parities = (1, -1) if self.parity else (1,)
            self.hidden_irreps = o3.Irreps(
                [
                    (num_channels, o3.Irrep(degree, p))
                    for p in parities
                    for degree in range(self.lmax + 1)
                ]
            )
        elif isinstance(self.hidden_irreps, str):
            self.hidden_irreps = o3.Irreps(self.hidden_irreps)

        # MACE prohibits some irreps (such as 0o and 1e) from being used.
        # NOTE(review): the nesting of this filter was reconstructed from a
        # flattened listing; it is applied here to user-supplied irreps as
        # well as to the generated default - confirm against the repository.
        forbidden_ir = ['0o', '1e', '2o', '3e', '4o']
        self.hidden_irreps = o3.Irreps(
            [irrep for irrep in self.hidden_irreps if str(irrep.ir) not in forbidden_ir]
        )
        self.num_channels = self.hidden_irreps.count(o3.Irrep(0, 1))

        ## handling irreps
        # chemical embedding irreps
        if self.node_irreps is None:
            self.node_irreps = o3.Irreps([(self.num_channels, o3.Irrep(0, 1))])
        elif isinstance(self.node_irreps, str):
            self.node_irreps = o3.Irreps(self.node_irreps)

        # edge sphere harmonic irreps
        if self.edge_sh_irreps is None:
            self.edge_sh_irreps = o3.Irreps.spherical_harmonics(
                self.lmax, p=-1 if self.parity else 1
            )
        elif isinstance(self.edge_sh_irreps, str):
            self.edge_sh_irreps = o3.Irreps(self.edge_sh_irreps)

        # MLP_irreps
        if self.MLP_irreps is None:
            self.MLP_irreps = o3.Irreps([(max(1, self.num_channels // 2), o3.Irrep(0, 1))])
        elif isinstance(self.MLP_irreps, str):
            self.MLP_irreps = o3.Irreps(self.MLP_irreps)

        # Embedding modules; ``forward`` applies them in insertion order.
        self.embeddings = nn.ModuleDict()
        self.embeddings['onehot_embedding'] = OneHotAtomEncoding(
            num_elements=num_elements, species=species
        )
        self.embeddings['radial_basis'] = RadialBasisEdgeEncoding(
            basis=BesselBasis(cutoff=self.cutoff, num_basis=self.num_basis),
            cutoff_fn=PolynomialCutoff(cutoff=self.cutoff, power=self.power),
        )
        self.embeddings['sphere_harmonics'] = SphericalHarmonicEdgeAttrs(
            edge_sh_irreps=self.edge_sh_irreps
        )

        self.irreps_in = {
            'edge_diff_embedding': self.embeddings.sphere_harmonics.irreps_out,
            'edge_dist_embedding': self.embeddings.radial_basis.irreps_out,
        }
        self.irreps_in.update(self.embeddings.onehot_embedding.irreps_out)

        self.embeddings['chemical_embedding'] = AtomwiseLinear(
            irreps_in=self.irreps_in['node_attr'],
            irreps_out=self.node_irreps,
            use_cuequivariance=self.use_cuequivariance,
        )
        self.irreps_in['node_feat'] = self.embeddings.chemical_embedding.irreps_out

        interaction_irreps = (self.edge_sh_irreps * self.num_channels).sort()[0].simplify()

        self.interactions = torch.nn.ModuleList()
        self.products = torch.nn.ModuleList()
        self.readouts = torch.nn.ModuleList()

        gate_fn = activation_fn[self.gate] if isinstance(self.gate, str) else self.gate

        # interaction blocks
        for i in range(num_layers):
            is_last = i == num_layers - 1
            # The last layer keeps only the first hidden irrep entry.
            hidden_irreps_out = str(self.hidden_irreps[0]) if is_last else self.hidden_irreps
            if i > 0:
                # From the second layer on the node features are the hidden irreps.
                self.irreps_in['node_feat'] = self.hidden_irreps

            inter = RealAgnosticResidualInteractionBlock(
                irreps_in=self.irreps_in,
                target_irreps=interaction_irreps,
                hidden_irreps=hidden_irreps_out,
                avg_num_neighbors=self.avg_num_neighbors,
                use_cuequivariance=self.use_cuequivariance,
            )
            self.interactions.append(inter)

            prod = EquivariantProductBasisBlock(
                node_feats_irreps=inter.target_irreps if i == 0 else interaction_irreps,
                target_irreps=hidden_irreps_out,
                correlation=self.correlation[i],
                num_elements=num_elements,
                use_sc=True,
                use_cuequivariance=self.use_cuequivariance,
            )
            self.products.append(prod)

            if is_last:
                readout = AtomwiseNonLinear(
                    irreps_in=hidden_irreps_out,
                    MLP_irreps=self.MLP_irreps,
                    gate=gate_fn,
                    use_cuequivariance=self.use_cuequivariance,
                )
            elif self.use_cuequivariance:
                readout = cuet.Linear(
                    irreps_in=cue.Irreps(cue.O3, hidden_irreps_out),
                    irreps_out=cue.Irreps(cue.O3, '1x0e'),
                    layout=cue.mul_ir,
                )
            else:
                readout = o3.Linear(irreps_in=hidden_irreps_out, irreps_out=o3.Irreps('1x0e'))
            self.readouts.append(readout)

        self.atomwise_reduce = AtomwiseReduce(output_key='energy')

        # Normalisation constants
        self.norm_data = torch.nn.Parameter(torch.tensor(norm_data), requires_grad=False)
        self.norm_per_atom = torch.nn.Parameter(torch.tensor(norm_per_atom), requires_grad=False)
        self.data_stddev = torch.nn.Parameter(torch.tensor(data_stddev), requires_grad=False)
        self.data_mean = torch.nn.Parameter(torch.tensor(data_mean), requires_grad=False)

        # The gradient module only exists when forces were requested.
        self.compute_forces = False
        if kwargs.get('compute_forces', False):
            self.compute_forces = True
            self.gradient_output = GradientOutput(model_outputs=['forces'])

    def forward(self, data: AtomsData):
        """
        Parameters
        ----------
        data : AtomsData
            Input data for the model.

        Returns
        -------
        AtomsData
            Output data after applying the model. ``atomic_energy`` and
            ``energy`` are set; the gradient module is applied as well when
            ``compute_forces`` is enabled.
        """
        for m in self.embeddings.values():
            data = m(data)

        node_es_list = []
        node_feat = data.node_feat
        node_attr = data.node_attr
        edge_indices = data.edge_indices
        edge_dist_embedding = data.edge_dist_embedding
        edge_diff_embedding = data.edge_diff_embedding
        assert node_feat is not None
        assert node_attr is not None
        assert edge_dist_embedding is not None
        assert edge_diff_embedding is not None

        for interaction, product, readout in zip(
            self.interactions, self.products, self.readouts
        ):
            node_feat, sc = interaction(
                node_feat,
                node_attr,
                edge_indices,
                edge_dist_embedding,
                edge_diff_embedding,
            )
            node_feat = product(
                node_feats=node_feat,
                sc=sc,
                node_attrs=node_attr,
            )
            # Every layer contributes one flattened per-atom energy term.
            node_es_list.append(readout(node_feat).view(-1))

        # Per-atom energy is the sum of the per-layer readouts.
        node_es = torch.sum(torch.stack(node_es_list, dim=0), dim=0)
        data = replace_properties(data, atomic_energy=node_es)
        data = self.atomwise_reduce(data)

        # de-normalization
        if self.norm_data:
            normalizer = self.data_stddev
            energy = normalizer * data.energy
            mean_shift = self.data_mean
            if self.norm_per_atom:
                # Scale the mean by each image's atom count. The previous
                # code multiplied by ``len(data.edge_indices)``, which is an
                # edge count and a single scalar for the whole batch.
                # Assumes ``data.num_atoms`` holds per-image atom counts, as
                # ``AtomwiseReduce`` uses it.
                mean_shift = data.num_atoms * mean_shift
            energy = energy + mean_shift
            data = replace_properties(data, energy=energy)

        if self.compute_forces:
            data = self.gradient_output(data)
        return data

    def get_optimization_info(self):
        """Get information about optimization status"""
        return {
            "cuequivariance_available": _HAS_CUEQUIVARIANCE,
            "optimization_enabled": self.use_cuequivariance,
            "performance_boost": "2-5x speedup" if self.use_cuequivariance else "No optimization",
        }
class AtomwiseReduce(nn.Module):
    """Accumulate per-atom energies into one energy value per image."""

    def __init__(
        self,
        output_key: str = "energy",
        per_atom_output: bool = False,
        aggregation_mode: str = "sum",  # should be sum or mean
    ) -> None:
        """Store the reduction settings.

        Parameters
        ----------
        output_key : str
            Accepted for interface compatibility; this class does not read it
            and always writes the result to the ``energy`` property.
        per_atom_output : bool
            If true, ``forward`` re-attaches ``atomic_energy`` to the result.
        aggregation_mode : str
            ``"mean"`` divides the summed energy by ``data.num_atoms``; any
            other value leaves the plain sum.
        """
        super().__init__()
        self.aggregation_mode = aggregation_mode
        self.per_atom_output = per_atom_output

    def forward(self, data: AtomsData):
        """Sum ``data.atomic_energy`` per image and store it as ``energy``.

        When ``data.image_indices`` is ``None`` every atom is assigned to
        image 0.
        """
        # One accumulator slot per entry of ``num_atoms``, in the dtype and on
        # the device of the edge vectors.
        y = torch.zeros_like(
            data.num_atoms,
            dtype=data.edge_vectors.dtype,
            device=data.edge_vectors.device,
        )
        atomic_energy = data.atomic_energy
        assert atomic_energy is not None

        image_indices = data.image_indices
        if image_indices is None:
            image_indices = torch.zeros_like(atomic_energy, dtype=torch.long)
        assert image_indices is not None

        y.index_add_(0, image_indices.reshape(-1), atomic_energy.reshape(-1))
        if self.aggregation_mode == "mean":
            y = y / data.num_atoms

        data = replace_properties(data, energy=y)
        if self.per_atom_output:
            assert atomic_energy is not None
            data = replace_properties(data, atomic_energy=atomic_energy)
        return data


class GradientOutput(torch.nn.Module):
    """Derive forces from the gradient of the energy w.r.t. edge vectors."""

    def __init__(
        self,
        grad_on_edge_diff: bool = True,
        grad_on_positions: bool = False,
        model_outputs: Optional[List[str]] = None,
        update_callback: Optional[Callable] = None,  # Add a callback parameter
    ) -> None:
        """Store the gradient settings.

        Parameters
        ----------
        grad_on_edge_diff : bool
            Enables the edge-vector gradient path in ``forward``.
        grad_on_positions : bool
            Stored only; not read anywhere in this class.
        model_outputs : list of str, optional
            Requested outputs; ``None`` means ``['forces']``. The list is
            copied so that ``update_model_outputs`` can never modify a list
            shared with the caller or with other instances.
        update_callback : callable, optional
            Called without arguments after ``update_model_outputs``.
        """
        super().__init__()
        self.grad_on_edge_diff = grad_on_edge_diff
        self.grad_on_positions = grad_on_positions
        self.update_callback = update_callback
        # A fresh list per instance: the previous mutable default was mutated
        # in place by ``update_model_outputs`` and leaked between instances.
        self.model_outputs = ['forces'] if model_outputs is None else list(model_outputs)

    def update_model_outputs(self, outputs: Union[List[str], str]):
        """Add one output name or a list of names, then fire the callback."""
        if isinstance(outputs, str):
            self.model_outputs.append(outputs)
        else:
            self.model_outputs.extend(outputs)
        if self.update_callback:
            self.update_callback()

    def forward(self, data: AtomsData, training: bool = True,) -> AtomsData:
        """Attach ``forces`` to ``data`` when they are requested.

        ``training`` is passed to ``torch.autograd.grad`` as both
        ``retain_graph`` and ``create_graph``. ``data`` is returned unchanged
        when ``grad_on_edge_diff`` is false or ``'forces'`` is not in
        ``model_outputs``.
        """
        if self.grad_on_edge_diff:
            energy = data.energy
            edge_vectors = data.edge_vectors
            positions = data.positions
            edge_indices = data.edge_indices
            assert energy is not None

            if 'forces' in self.model_outputs:
                outputs_list = torch.jit.annotate(List[torch.Tensor], [energy])
                inputs_list = torch.jit.annotate(List[torch.Tensor], [edge_vectors])
                grad_outputs_list = torch.jit.annotate(
                    Optional[List[Optional[torch.Tensor]]], [torch.ones_like(energy)]
                )
                dE_ddiff = torch.autograd.grad(
                    outputs=outputs_list,
                    inputs=inputs_list,
                    grad_outputs=grad_outputs_list,
                    retain_graph=training,
                    create_graph=training,
                )[0]
                assert dE_ddiff is not None

                # ``index_add_`` needs matching dtypes, so the buffers follow
                # the gradient's dtype instead of a fixed float32.
                i_forces = torch.zeros(
                    positions.shape[0], 3, device=positions.device, dtype=dE_ddiff.dtype
                )
                j_forces = torch.zeros(
                    positions.shape[0], 3, device=positions.device, dtype=dE_ddiff.dtype
                )
                # Column 0 receives +dE/d(edge), column 1 receives -dE/d(edge).
                i_forces.index_add_(0, edge_indices[:, 0], dE_ddiff)
                j_forces.index_add_(0, edge_indices[:, 1], -dE_ddiff)
                forces = i_forces + j_forces
                data = replace_properties(data, forces=forces)
        return data