Source code for iann.plugins.converter

#!/usr/bin/env python3
"""
Script to export trained IANN models to TorchScript format for LAMMPS integration.

- Converts trained PyTorch models (single models or ensembles) to TorchScript.
- Wraps each model so its inputs/outputs match the LAMMPS C++ plugin call.
- Supports the PaiNN, NequIP, MACE and EquiformerV2 model types and raises a clear error for unknown types.
"""

from iann.data import AseDataReader, AtomsData
import argparse
import torch
from pathlib import Path
from typing import Dict
import warnings
# NOTE(review): silences the FutureWarning whose message mentions
# "weights_only=False" -- presumably the one torch.load emits because the
# converters below call it without an explicit weights_only argument.
# Warning filters are interpreter-wide, so this applies to every torch.load
# call in the process as soon as this module is imported.
warnings.filterwarnings("ignore", message=".*weights_only=False.*", category=FutureWarning)

class LAMMPSModelWrapper(torch.nn.Module):
    """A wrapper that adapts model inputs/outputs for LAMMPS.

    The LAMMPS C++ plugin calls the scripted module with plain tensors; this
    wrapper repacks them into an ``AtomsData`` tuple, runs the wrapped model and
    returns energy, forces and per-atom energies as a dictionary.
    """

    def __init__(self, model, compute_forces=True):
        """Store the model to wrap.

        Parameters
        ----------
        model : torch.nn.Module
            Trained model. It must expose a truthy ``compute_forces`` attribute,
            otherwise :meth:`forward` raises ``RuntimeError``.
        compute_forces : bool, optional
            If True (default), ``edge_vectors`` is marked as requiring gradients
            before the model is called.
        """
        super().__init__()
        self.model = model
        self.compute_forces = compute_forces

    def forward(self, num_atoms: torch.Tensor, atomic_numbers: torch.Tensor,
                positions: torch.Tensor, cell: torch.Tensor,
                edge_indices: torch.Tensor, edge_vectors: torch.Tensor,
                num_edges: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Forward input tensors matching the C++ plugin call.

        Parameters
        ----------
        num_atoms : torch.Tensor
            Number of atoms.
        atomic_numbers : torch.Tensor
            Atomic numbers.
        positions : torch.Tensor
            Positions.
        cell : torch.Tensor
            Cell.
        edge_indices : torch.Tensor
            Edge indices.
        edge_vectors : torch.Tensor
            Edge vectors; set to require gradients when ``compute_forces``.
        num_edges : torch.Tensor
            Edge count(s), passed through to ``AtomsData.num_edges``.

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary with keys 'energy', 'forces', 'atomic_energy'.

        Raises
        ------
        RuntimeError
            If the wrapped model has no truthy ``compute_forces`` attribute.
        """
        # The check depends only on the model's configuration, so fail before
        # paying for a full forward (and autograd) pass.
        if not hasattr(self.model, 'compute_forces') or not self.model.compute_forces:
            raise RuntimeError("Model did not return forces. Make sure compute_forces=True")

        # Reconstruct the NamedTuple internally
        model_inputs = AtomsData(
            num_atoms=num_atoms,
            atomic_numbers=atomic_numbers,
            positions=positions,
            cell=cell,
            edge_indices=edge_indices,
            edge_vectors=edge_vectors,
            num_edges=num_edges,
            energy=None,
            forces=None,
            image_indices=None,
        )
        if self.compute_forces:
            model_inputs.edge_vectors.requires_grad_()

        data = self.model(model_inputs)

        energy = data.energy
        forces = data.forces
        atomic_energy = data.atomic_energy
        # Presumably these fields are Optional[Tensor]; the asserts let
        # TorchScript refine them to Tensor for the Dict[str, Tensor] return.
        assert energy is not None
        assert forces is not None
        assert atomic_energy is not None

        return {'energy': energy, 'forces': forces, 'atomic_energy': atomic_energy}
class EnsembleLAMMPSModelWrapper(torch.nn.Module):
    """A wrapper that adapts ensemble model inputs/outputs for LAMMPS.

    Every member model is evaluated on the same inputs; the wrapper returns the
    ensemble mean of energy, forces and atomic energies together with their
    variances across members.
    """

    def __init__(self, models, compute_forces=True):
        """Store the ensemble members.

        Parameters
        ----------
        models : torch.nn.ModuleList
            Member models. Each must expose a truthy ``compute_forces``
            attribute, otherwise :meth:`forward` raises ``RuntimeError``.
        compute_forces : bool, optional
            If True (default), ``edge_vectors`` is marked as requiring gradients
            before the members are called.
        """
        super().__init__()
        self.models = models
        self.compute_forces = compute_forces

    def forward(self, num_atoms: torch.Tensor, atomic_numbers: torch.Tensor,
                positions: torch.Tensor, cell: torch.Tensor,
                edge_indices: torch.Tensor, edge_vectors: torch.Tensor,
                num_edges: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Forward pass that computes ensemble averages and variances.

        Parameters
        ----------
        num_atoms : torch.Tensor
            Number of atoms.
        atomic_numbers : torch.Tensor
            Atomic numbers.
        positions : torch.Tensor
            Positions.
        cell : torch.Tensor
            Cell.
        edge_indices : torch.Tensor
            Edge indices.
        edge_vectors : torch.Tensor
            Edge vectors; set to require gradients when ``compute_forces``.
        num_edges : torch.Tensor
            Edge count(s), passed through to ``AtomsData.num_edges``.

        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary with keys 'energy', 'forces', 'atomic_energy' (ensemble
            means) and 'energy_variance', 'forces_variance',
            'atomic_energy_variance'.

        Raises
        ------
        RuntimeError
            If any member has no truthy ``compute_forces`` attribute.
        """
        # Validate every member up front so a misconfigured model fails fast
        # instead of after the preceding members have already been evaluated.
        for model in self.models:
            if not hasattr(model, 'compute_forces') or not model.compute_forces:
                raise RuntimeError("Model did not return forces. Make sure compute_forces=True")

        # Reconstruct the NamedTuple internally
        model_inputs = AtomsData(
            num_atoms=num_atoms,
            atomic_numbers=atomic_numbers,
            positions=positions,
            cell=cell,
            edge_indices=edge_indices,
            edge_vectors=edge_vectors,
            num_edges=num_edges,
            energy=None,
            forces=None,
            image_indices=None,
        )
        if self.compute_forces:
            model_inputs.edge_vectors.requires_grad_()

        # Collect predictions from all models
        all_energies = []
        all_forces = []
        all_atomic_energies = []
        for model in self.models:
            data = model(model_inputs)
            energy = data.energy
            forces = data.forces
            atomic_energies = data.atomic_energy
            # Refine the (presumably Optional) fields to Tensor for TorchScript.
            assert energy is not None
            assert forces is not None
            assert atomic_energies is not None
            all_energies.append(energy)
            all_forces.append(forces)
            all_atomic_energies.append(atomic_energies)

        # Stack each quantity once; dim 0 indexes the ensemble member.
        energy_stack = torch.stack(all_energies)
        forces_stack = torch.stack(all_forces)
        atomic_energy_stack = torch.stack(all_atomic_energies)

        # NOTE: torch.var defaults to the unbiased (N-1) estimator, so a
        # single-member ensemble yields NaN variances.
        results = {
            'energy': torch.mean(energy_stack, dim=0),
            'forces': torch.mean(forces_stack, dim=0),
            'atomic_energy': torch.mean(atomic_energy_stack, dim=0),
            'energy_variance': torch.var(energy_stack, dim=0),
            'forces_variance': torch.var(forces_stack, dim=0),
            'atomic_energy_variance': torch.var(atomic_energy_stack, dim=0),
        }
        return results
# Alternative spellings accepted for ``model_type``: the CLI and docstrings use
# "equiformer2", while the dispatch below keys on "equiformerv2".
_MODEL_TYPE_ALIASES = {"equiformer2": "equiformerv2"}


def _infer_model_type(state_dict):
    """Guess the model type from the keys of a loaded checkpoint.

    An explicit ``"model_type"`` entry wins. Otherwise the first key found in
    the fixed order ``num_layers``, ``irreps``, ``correlation``, ``transformer``
    decides, so the heuristic is order-dependent.

    Raises:
        ValueError: If none of the known keys is present.
    """
    if "model_type" in state_dict:
        return state_dict["model_type"]
    for key, inferred in (
        ("num_layers", "painn"),
        ("irreps", "nequip"),
        ("correlation", "mace"),
        ("transformer", "equiformerv2"),
    ):
        if key in state_dict:
            return inferred
    raise ValueError("Could not determine model type, please provide model type explicitly!")


def _model_class(model_type):
    """Import and return the model class for *model_type* (case-insensitive).

    Raises:
        ValueError: If *model_type* is not a supported architecture.
    """
    key = model_type.lower()
    key = _MODEL_TYPE_ALIASES.get(key, key)
    # Import lazily: only the selected architecture's module is imported.
    if key == "painn":
        from iann.models.painn import PaiNN
        return PaiNN
    if key == "nequip":
        from iann.models.nequip import NequIP
        return NequIP
    if key == "mace":
        from iann.models.mace import MACE
        return MACE
    if key == "equiformerv2":
        from iann.models.equiformerV2 import EquiformerV2
        return EquiformerV2
    raise ValueError(f"Unknown model type: {model_type}")


def _load_model(model_path, model_type, compute_forces, model_kwargs):
    """Load one checkpoint on the CPU and rebuild its model in eval mode.

    Args:
        model_path (str): Path to the checkpoint.
        model_type (str or None): Architecture name; inferred when None.
        compute_forces (bool): Passed to the model constructor.
        model_kwargs (dict): Extra keyword arguments for the model constructor.

    Returns:
        tuple: ``(model, model_type)`` with the resolved model type.
    """
    print(f"Loading model from {model_path}")
    # NOTE(review): torch.load can unpickle arbitrary Python objects when
    # weights_only is False (the default in older PyTorch releases); only
    # convert checkpoints from trusted sources.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    if model_type is None:
        model_type = _infer_model_type(state_dict)
    model_cls = _model_class(model_type)
    model = model_cls(
        num_layers=state_dict.get("num_layers", 3),
        num_channels=state_dict.get("num_channels", 128),
        cutoff=state_dict.get("cutoff", 5.5),
        compute_forces=compute_forces,
        **model_kwargs,
    )
    model.load_state_dict(state_dict["model"])
    model.eval()
    return model, model_type


def _run_debug_example(wrapped_model, cutoff, compute_forces, atoms):
    """Run *wrapped_model* eagerly on *atoms* (default: a Pt fcc(100) slab)."""
    if atoms is None:
        from ase.build import fcc100
        atoms = fcc100('Pt', size=(4, 4, 3), a=5.5, vacuum=15.0)
    model_inputs = AseDataReader(cutoff, compute_forces=compute_forces)(atoms)
    return wrapped_model(
        model_inputs.num_atoms,
        model_inputs.atomic_numbers,
        model_inputs.positions,
        model_inputs.cell,
        model_inputs.edge_indices,
        model_inputs.edge_vectors,
        model_inputs.num_edges,
    )


def convert_model_for_lammps(model_path, model_type=None, output_path=None, debug=False, atoms=None, **kwargs):
    """Wrap a trained model in a TorchScript-compatible wrapper for LAMMPS.

    Args:
        model_path (str): Path to the trained model checkpoint.
        model_type (str, optional): Type of model (painn, nequip, mace,
            equiformerv2; "equiformer2" is accepted as an alias). Inferred from
            the checkpoint when None.
        output_path (str, optional): Path to save the exported model. Defaults
            to ``<checkpoint stem>_<model_type>_lammps.pt``.
        debug (bool): If True, run the wrapper once on ``atoms`` before scripting.
        atoms (ase.Atoms, optional): Structure for the debug run; a Pt fcc(100)
            slab is built when None.
        **kwargs: ``compute_forces`` (default True) configures the model, the
            wrapper and the debug data reader; all other keywords are forwarded
            to the model constructor.

    Returns:
        str: Path to the exported TorchScript model.
    """
    # Pop instead of get: leaving it in kwargs made the constructor call fail
    # with "got multiple values for keyword argument 'compute_forces'".
    compute_forces = kwargs.pop("compute_forces", True)
    raw_model, model_type = _load_model(model_path, model_type, compute_forces, kwargs)

    wrapped_model = LAMMPSModelWrapper(raw_model, compute_forces=compute_forces)
    wrapped_model.eval()

    if debug:
        print("Debug mode enabled. Running example test...")
        example_out = _run_debug_example(wrapped_model, raw_model.cutoff, compute_forces, atoms)
        print(f"Example test passed: Energy={example_out['energy']}, Forces shape={example_out['forces'].shape}")

    scripted_model = torch.jit.script(wrapped_model)
    if output_path is None:
        output_path = f"{Path(model_path).stem}_{model_type}_lammps.pt"
    torch.jit.save(scripted_model, output_path)
    print(f"Model exported to {output_path}")
    return output_path


def convert_models_for_lammps(model_paths, model_type=None, output_path=None, debug=False, atoms=None, **kwargs):
    """Convert multiple models to a single TorchScript model for LAMMPS with ensemble statistics.

    Args:
        model_paths (list): Paths to trained model checkpoints (a single str
            or Path is treated as a one-element list).
        model_type (str, optional): Type of model (painn, nequip, mace,
            equiformerv2; "equiformer2" is accepted as an alias). When None it
            is inferred from the first checkpoint and reused for the others.
        output_path (str, optional): Path to save the exported model. Defaults
            to ``ensemble_<model_type>_lammps.pt``.
        debug (bool): If True, run the wrapper once on ``atoms`` before scripting.
        atoms (ase.Atoms, optional): Structure for the debug run; a Pt fcc(100)
            slab is built when None.
        **kwargs: ``compute_forces`` (default True) configures the models, the
            wrapper and the debug data reader; all other keywords are forwarded
            to every model constructor.

    Returns:
        str: Path to the exported TorchScript model.

    Raises:
        ValueError: If ``model_paths`` is empty.
    """
    if isinstance(model_paths, (str, Path)):
        # A bare path would otherwise be iterated character by character.
        model_paths = [model_paths]
    model_paths = list(model_paths)
    if not model_paths:
        raise ValueError("model_paths must contain at least one checkpoint path")

    compute_forces = kwargs.pop("compute_forces", True)
    print(f"Loading {len(model_paths)} models for ensemble")

    models = torch.nn.ModuleList()
    for model_path in model_paths:
        raw_model, model_type = _load_model(model_path, model_type, compute_forces, kwargs)
        models.append(raw_model)

    wrapped_model = EnsembleLAMMPSModelWrapper(models, compute_forces=compute_forces)
    wrapped_model.eval()

    if debug:
        print("Debug mode enabled. Running example test...")
        example_out = _run_debug_example(wrapped_model, models[0].cutoff, compute_forces, atoms)
        print(
            f"Example test passed: Energy={example_out['energy']}, "
            f"Forces shape={example_out['forces'].shape}, "
            f"Energy variance={example_out['energy_variance']}, "
            f"Forces variance={example_out['forces_variance']}, "
            f"Atomic energy variance={example_out['atomic_energy_variance']}"
        )

    scripted_model = torch.jit.script(wrapped_model)
    if output_path is None:
        output_path = f"ensemble_{model_type}_lammps.pt"
    torch.jit.save(scripted_model, output_path)
    print(f"Ensemble model exported to {output_path}")
    return output_path
def main(argv=None):
    """Command-line entry point: export one checkpoint to TorchScript for LAMMPS.

    Args:
        argv (list of str, optional): Arguments to parse; ``sys.argv[1:]`` when None.
    """
    parser = argparse.ArgumentParser(description="Export IANN models to TorchScript for LAMMPS")
    parser.add_argument("--model_path", "-m", required=True, help="Path to the trained model checkpoint")
    parser.add_argument(
        "--model_type", "-t",
        choices=["painn", "nequip", "mace", "equiformer2", "equiformerv2"],
        required=True,
        help="Type of model to export",
    )
    parser.add_argument("--output", "-o", help="Output path for exported model")
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Run the wrapped model once on a test structure before exporting",
    )
    args = parser.parse_args(argv)

    # The converter dispatches on "equiformerv2"; map the historical CLI
    # spelling so "-t equiformer2" no longer fails with "Unknown model type".
    model_type = "equiformerv2" if args.model_type == "equiformer2" else args.model_type
    convert_model_for_lammps(args.model_path, model_type, args.output, debug=args.debug)


if __name__ == "__main__":
    main()