from ase.io import Trajectory
import torch
from typing import List, Optional, NamedTuple, Dict, Any, Union
import asap3
import numpy as np
from scipy.spatial import distance_matrix
from ase.geometry import cell_to_cellpar
[docs]
class AtomsData(NamedTuple):
    """
    Immutable bundle of the tensors passed to (and returned by) a model.

    Parameters
    ----------
    num_atoms : torch.Tensor
        The number of atoms in the system.
    atomic_numbers : torch.Tensor
        The atomic numbers of the atoms in the system.
    positions : torch.Tensor
        The positions of the atoms in the system.
    cell : torch.Tensor
        The cell of the system.
    edge_indices : torch.Tensor
        The indices of the edges in the system.
    edge_vectors : torch.Tensor
        The vectors of the edges in the system.
    num_edges : torch.Tensor
        The number of edges in the system.
    global_attr : torch.Tensor
        The global attributes of the system. Required: this field has no
        default value.
    energy : Optional[torch.Tensor]
        The energy of the system.
    forces : Optional[torch.Tensor]
        The forces of the system.
    stress : Optional[torch.Tensor]
        The stress tensor of the system (3x3 matrix).
    image_indices : Optional[torch.Tensor]
        The image indices of the atoms in the system.
    atomic_energy : Optional[torch.Tensor]
        The atomic energy of the system.
    node_attr : Optional[torch.Tensor]
        The node attributes of the atoms in the system.
    node_feat : Optional[torch.Tensor]
        The node features of the atoms in the system.
    edge_dist_embedding : Optional[torch.Tensor]
        The edge distance embedding of the atoms in the system.
    edge_diff_embedding : Optional[torch.Tensor]
        The edge difference embedding of the atoms in the system.
    energy_variance : Optional[torch.Tensor]
        The energy variance of the system.
    forces_variance : Optional[torch.Tensor]
        The forces variance of the system.
    global_embedding : Optional[torch.Tensor]
        The global embedding of the system.
    """

    # Required fields (no default); their order defines positional construction.
    num_atoms: torch.Tensor
    atomic_numbers: torch.Tensor
    positions: torch.Tensor
    cell: torch.Tensor
    edge_indices: torch.Tensor
    edge_vectors: torch.Tensor
    num_edges: torch.Tensor
    global_attr: torch.Tensor

    # Optional labels / bookkeeping.
    energy: Optional[torch.Tensor] = None
    forces: Optional[torch.Tensor] = None
    stress: Optional[torch.Tensor] = None
    image_indices: Optional[torch.Tensor] = None
    atomic_energy: Optional[torch.Tensor] = None

    # nequip and mace
    node_attr: Optional[torch.Tensor] = None
    node_feat: Optional[torch.Tensor] = None
    edge_dist_embedding: Optional[torch.Tensor] = None
    edge_diff_embedding: Optional[torch.Tensor] = None

    # ensemble
    energy_variance: Optional[torch.Tensor] = None
    forces_variance: Optional[torch.Tensor] = None

    # global
    global_embedding: Optional[torch.Tensor] = None

    def to(self, device):
        """
        Return a new ``AtomsData`` whose tensor fields were moved with
        ``Tensor.to(device)``; non-tensor fields (e.g. ``None``) are kept as-is.
        """
        moved = {
            name: (item.to(device) if isinstance(item, torch.Tensor) else item)
            for name, item in zip(self._fields, self)
        }
        return self._replace(**moved)

    def contiguous(self):
        """
        Return a new ``AtomsData`` with every tensor field passed through
        ``Tensor.contiguous()``.

        Two-dimensional tensors are additionally copied into a freshly
        allocated buffer of the transposed shape and returned as the transpose
        of that buffer. Shape and values are unchanged, but the result has
        strides ``(1, rows)`` rather than row-major strides.
        """
        updated = {}
        for name, item in zip(self._fields, self):
            if not isinstance(item, torch.Tensor):
                updated[name] = item
                continue

            dense = item.contiguous()
            if dense.dim() == 2:
                rows, cols = dense.shape[0], dense.shape[1]
                # Allocate with the transposed shape, fill it from the
                # transposed view, then hand back the transpose of the buffer.
                buffer = torch.empty(
                    cols, rows, device=dense.device, dtype=dense.dtype
                )
                buffer.copy_(dense.t())
                dense = buffer.t()
            updated[name] = dense
        return self._replace(**updated)

    def keys(self):
        """Return the names of all fields whose value is not ``None``, in field order."""
        return [name for name, item in zip(self._fields, self) if item is not None]
def replace_properties(
    data: AtomsData,
    energy: Optional[torch.Tensor] = None,
    forces: Optional[torch.Tensor] = None,
    stress: Optional[torch.Tensor] = None,
    image_indices: Optional[torch.Tensor] = None,
    atomic_energy: Optional[torch.Tensor] = None,
    node_attr: Optional[torch.Tensor] = None,
    node_feat: Optional[torch.Tensor] = None,
    edge_dist_embedding: Optional[torch.Tensor] = None,
    edge_diff_embedding: Optional[torch.Tensor] = None,
    energy_variance: Optional[torch.Tensor] = None,
    forces_variance: Optional[torch.Tensor] = None,
    global_embedding: Optional[torch.Tensor] = None,
) -> AtomsData:
    """
    Build a new AtomsData from ``data`` with selected optional properties
    overridden.

    Every keyword left as ``None`` falls back to the value already stored on
    ``data``; consequently an existing property cannot be reset to ``None``
    through this function. The structural fields (``num_atoms``,
    ``atomic_numbers``, ``positions``, ``cell``, ``edge_indices``,
    ``edge_vectors``, ``num_edges``, ``global_attr``) are always copied from
    ``data`` unchanged.
    """
    # Resolve each override first: an explicit argument wins, otherwise keep
    # what `data` already carries.
    if energy is None:
        energy = data.energy
    if forces is None:
        forces = data.forces
    if stress is None:
        stress = data.stress
    if image_indices is None:
        image_indices = data.image_indices
    if atomic_energy is None:
        atomic_energy = data.atomic_energy
    if node_attr is None:
        node_attr = data.node_attr
    if node_feat is None:
        node_feat = data.node_feat
    if edge_dist_embedding is None:
        edge_dist_embedding = data.edge_dist_embedding
    if edge_diff_embedding is None:
        edge_diff_embedding = data.edge_diff_embedding
    if energy_variance is None:
        energy_variance = data.energy_variance
    if forces_variance is None:
        forces_variance = data.forces_variance
    if global_embedding is None:
        global_embedding = data.global_embedding

    return AtomsData(
        num_atoms=data.num_atoms,
        atomic_numbers=data.atomic_numbers,
        positions=data.positions,
        cell=data.cell,
        edge_indices=data.edge_indices,
        edge_vectors=data.edge_vectors,
        num_edges=data.num_edges,
        global_attr=data.global_attr,
        energy=energy,
        forces=forces,
        stress=stress,
        image_indices=image_indices,
        atomic_energy=atomic_energy,
        node_attr=node_attr,
        node_feat=node_feat,
        edge_dist_embedding=edge_dist_embedding,
        edge_diff_embedding=edge_diff_embedding,
        energy_variance=energy_variance,
        forces_variance=forces_variance,
        global_embedding=global_embedding,
    )
class AseDataReader:
    """
    Callable that converts an ASE Atoms object into an AtomsData.

    Parameters
    ----------
    cutoff : float
        The cutoff radius for the neighbor list.
    compute_forces : bool
        When true, the edge vectors of the returned data are flagged with
        ``requires_grad_()``.

    Returns
    -------
    atoms_data : AtomsData
        Calling the reader on an atoms object returns the AtomsData object.
    """

    def __init__(self, cutoff=5.0, compute_forces=False):
        self.cutoff = cutoff
        self.compute_forces = compute_forces

    def __call__(self, atoms):
        """Build the AtomsData for ``atoms`` (graph, global attributes, optional labels)."""
        num_atoms = torch.tensor([atoms.get_global_number_of_atoms()])
        atomic_numbers = torch.tensor(atoms.numbers, dtype=torch.long)
        positions = torch.tensor(atoms.positions, dtype=torch.float32)
        cell = torch.tensor(atoms.cell[:], dtype=torch.float32)

        # Any periodic direction -> asap3 neighbor list; otherwise the dense
        # distance-matrix fallback.
        if atoms.pbc.any():
            build_edges = self.get_neighborlist
        else:
            build_edges = self.get_neighborlist_simple
        pairs, displacements = build_edges(atoms)

        # Ensure correct shapes for empty edge lists (e.g., single atom case)
        if pairs.shape[0] == 0:
            pairs = pairs.reshape(0, 2)
            displacements = displacements.reshape(0, 3)

        edge_indices = torch.from_numpy(pairs)
        edge_vectors = torch.from_numpy(displacements).float()
        if self.compute_forces:
            edge_vectors.requires_grad_()
        num_edges = torch.tensor([edge_indices.shape[0]])

        global_attr = self.get_global_attr(atoms)

        # Labels are optional: an AttributeError or RuntimeError while reading
        # or converting them leaves the corresponding field as None.
        try:
            energy = torch.tensor([atoms.get_potential_energy()], dtype=torch.float32)
        except (AttributeError, RuntimeError):
            energy = None
        try:
            forces = torch.tensor(
                atoms.get_forces(apply_constraint=False), dtype=torch.float32
            )
        except (AttributeError, RuntimeError):
            forces = None

        # Return as AtomsData object for TorchScript compatibility
        atoms_data = AtomsData(
            num_atoms=num_atoms,
            atomic_numbers=atomic_numbers,
            positions=positions,
            cell=cell,
            edge_indices=edge_indices,
            edge_vectors=edge_vectors,
            num_edges=num_edges,
            energy=energy,
            forces=forces,
            image_indices=None,
            global_attr=global_attr,
        )
        return atoms_data.contiguous()

    def get_neighborlist(self, atoms):
        """
        Collect edges from an ``asap3.FullNeighborList`` built with ``self.cutoff``.

        Returns ``(edge_indices, edge_vectors)`` as numpy arrays of shape
        ``(E, 2)`` and ``(E, 3)``; both are empty when no atom has a neighbor.
        """
        neighbor_list = asap3.FullNeighborList(self.cutoff, atoms)

        sources = []
        targets = []
        vectors = []
        for atom_idx in range(len(atoms)):
            # Third return value of get_neighbors is not used here.
            neighbor_ids, neighbor_vecs, _ = neighbor_list.get_neighbors(atom_idx)
            sources.extend([atom_idx] * len(neighbor_ids))
            targets.append(neighbor_ids)
            vectors.append(neighbor_vecs)

        # Handle case when there are no edges (e.g., single atom)
        if not any(len(ids) != 0 for ids in targets):
            return (
                np.empty((0, 2), dtype=np.int64),
                np.empty((0, 3), dtype=np.float64),
            )

        edge_indices = np.stack((sources, np.concatenate(targets)), axis=1)
        edge_vectors = np.concatenate(vectors)
        return edge_indices, edge_vectors

    def get_neighborlist_simple(self, atoms):
        """
        Collect edges from the full pairwise distance matrix of the positions.

        Every ordered pair ``(i, j)`` with ``i != j`` and distance strictly
        below ``self.cutoff`` becomes an edge with vector ``pos[j] - pos[i]``.
        """
        coords = atoms.get_positions()
        within_cutoff = distance_matrix(coords, coords) < self.cutoff
        np.fill_diagonal(within_cutoff, False)  # no self-edges
        edge_indices = np.argwhere(within_cutoff)

        # Handle case when there are no edges (e.g., single atom)
        if edge_indices.shape[0] == 0:
            return edge_indices, np.empty((0, 3), dtype=np.float64)

        origin = edge_indices[:, 0]
        target = edge_indices[:, 1]
        return edge_indices, coords[target] - coords[origin]

    def get_global_attr(self, atoms):
        """
        Return a float32 tensor of eight standardized cell descriptors:
        ``a, b, c, alpha, beta, gamma, volume, density`` (total mass / volume).

        The standardization uses the mean and standard deviation taken over
        these eight entries themselves.
        """
        # NOTE(review): density divides by atoms.get_volume(); presumably this
        # is undefined for atoms without a full 3D cell - confirm how
        # non-periodic inputs are expected to behave here.
        a, b, c, alpha, beta, gamma = cell_to_cellpar(atoms.cell)
        volume = atoms.get_volume()
        density = atoms.get_masses().sum() / volume
        descriptors = torch.tensor(
            [a, b, c, alpha, beta, gamma, volume, density],
            dtype=torch.float32,
        )
        # Standardization
        centered = descriptors - descriptors.mean()
        return centered / (descriptors.std() + 1e-8)
[docs]
class AseDataset(torch.utils.data.Dataset):
    """
    Dataset that turns stored ASE Atoms objects into AtomsData on access.

    Parameters
    ----------
    ase_db : str or indexable collection of ase.Atoms
        A string is opened with ``ase.io.Trajectory``; any other object is
        used as-is and must support ``len()`` and integer indexing.
    cutoff : float
        The cutoff radius for the neighbor list.
    compute_forces : bool
        Forwarded to the ``AseDataReader`` that converts each item.
    **kwargs
        Accepted but not used by this class.

    Returns
    -------
    atoms_data : AtomsData
        Indexing the dataset returns the AtomsData object for that entry.
    """

    def __init__(self, ase_db, cutoff=5.0, compute_forces=False, **kwargs):
        # A string is a trajectory path; anything else is already a collection.
        self.db = Trajectory(ase_db) if isinstance(ase_db, str) else ase_db
        self.cutoff = cutoff
        self.atoms_reader = AseDataReader(cutoff, compute_forces)

    def __len__(self):
        return len(self.db)

    def __getitem__(self, idx):
        return self.atoms_reader(self.db[idx])
def cat_tensors(tensors: List[torch.Tensor]):
    """
    Join a list of tensors into one.

    The first tensor decides the operation: if it has at least one dimension
    the list is concatenated with ``torch.cat``; if it is zero-dimensional the
    list is combined with ``torch.stack``.
    """
    first_is_scalar = tensors[0].dim() == 0
    combine = torch.stack if first_is_scalar else torch.cat
    return combine(tensors)
def collate_atomsdata(atoms_data: List["AtomsData"], pin_memory=True) -> "AtomsData":
    """
    Merge a list of per-structure AtomsData objects into one batched AtomsData.

    Parameters
    ----------
    atoms_data : list of AtomsData
        The samples to batch. Must be non-empty; the set of fields that are
        batched is taken from the non-None fields of the first sample.
    pin_memory : bool
        Accepted for backward compatibility. It currently has no effect: this
        function does not pin any tensor.

    Returns
    -------
    AtomsData
        Each batched field is the ``torch.cat`` of that field over the
        samples (samples where the field is None are skipped).
        ``image_indices`` holds, for every atom, the index of the sample it
        came from, and ``edge_indices`` are shifted so they address atoms in
        the concatenated batch.
    """
    field_names = atoms_data[0].keys()
    batched = AtomsData(**{
        name: torch.cat([
            getattr(sample, name)
            for sample in atoms_data
            if getattr(sample, name) is not None
        ])
        for name in field_names
    })

    # Sample index of every atom: sample i is repeated num_atoms[i] times.
    image_indices = torch.repeat_interleave(
        torch.arange(len(atoms_data)), batched.num_atoms, dim=0
    )
    batched = batched._replace(image_indices=image_indices)

    # Shift edge indices (because of batching): every edge of sample i is
    # offset by the number of atoms in samples 0..i-1.
    if batched.edge_indices is not None:
        # Exclusive cumulative sum of the atom counts.
        node_offset = torch.cumsum(batched.num_atoms, dim=0) - batched.num_atoms
        edge_offset = torch.repeat_interleave(node_offset, batched.num_edges)
        batched = batched._replace(
            edge_indices=batched.edge_indices + edge_offset.unsqueeze(-1)
        )

    return batched