# Source code for iann.trainer.trainer

from iann.data import AseDataset, collate_atomsdata
import numpy as np
import math, time
import json, os, toml, sys
import argparse
import logging
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from datetime import timedelta
import warnings
warnings.filterwarnings("ignore", message=".*weights_only=False.*", category=FutureWarning)


# Absolute path of the directory that contains this module.
path = os.path.abspath(os.path.join(os.path.dirname(__file__)))

# Default configuration; any key can be overridden by the user config passed to Trainer.
DEFAULT_CONFIG = {
    # parameters for model
    "num_channels": 128, # number of channels in the model
    "num_layers": 3, # number of layers in the model
    "cutoff": 5.5, # cutoff radius passed to the dataset
    # parameters for trainer
    "device": None,      # override device, e.g. 'cpu' or 'cuda'
    "val_ratio": 0.1, # fraction of the dataset used for validation (rounded up)
    "batch_size": 12, # batch size
    "learning_rate": 0.0001, # initial learning rate
    "forces_weight": 0.9, # weight of the forces loss; the energy loss gets (1 - forces_weight); 0 disables forces
    "load_model": False, # False, or a checkpoint path ending in '.pt'; other truthy values load output_dir/output_model
    "max_steps": 1000000, # maximum number of steps
    "max_epochs": None,  # if set, max_steps is ignored and training is bounded by epochs instead
    "optimizer_type": "adam", # optimizer type: "adam", "sgd", "rmsprop", "adagrad", "adadelta", "adamax", "adamw"
    "max_grad_norm": None,    # gradient clipping norm (None disables clipping)
    "log_interval": 2000, # evaluate on the validation set and log every this many steps
    "stop_patience": 200, # patience for early stopping
    "scheduler_type": "LambdaLR", # scheduler type: "ReduceLROnPlateau", "LambdaLR", "CosineAnnealingLR", "CosineAnnealingWarmRestarts", "StepLR", "MultiStepLR", "ExponentialLR"
    # parameters for data
    "random_seed": 666, # random seed for reproducibility
    "save_split": False, # False, or a file name (".json" is appended) written under output_dir
    "load_split": False, # False, or a path (".json" is appended) of an existing split file to load
    "norm_data": False, # normalize data
    "norm_per_atom": False, # normalize data per atom
    # parameters for DDP (Parallelization)
    "dist_timeout": 600,  # timeout (seconds) for distributed operations
    "master_port": 12356, # port for distributed operations
    # parameters for output
    "output_dir": "output", # output directory
    "output_log": "output.log", # log file, created inside output_dir
    "output_model": "model.pt", # model checkpoint file, created inside output_dir
    "log_input": False, # log the user-supplied input config
    "debug": False, # debug mode
}

# Logging filter to inject rank into log records
class RankFilter(logging.Filter):
    """Logging filter that stamps every record with the process rank.

    The rank is exposed as ``record.rank`` so that a formatter can use
    ``%(rank)s``. The filter never rejects a record.
    """

    def __init__(self, rank):
        super().__init__()
        self.rank = rank

    def filter(self, record):
        """Attach the stored rank to ``record`` and let it through."""
        setattr(record, "rank", self.rank)
        return True

def setup_seed(seed):
    """Seed the NumPy and PyTorch RNGs and request deterministic cuDNN kernels.

    Args:
        seed (int): Seed applied to ``numpy.random``, the torch CPU generator
            and, when CUDA is available, every CUDA device generator.
    """
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_arguments(arg_list=None):
    """Parse the command-line arguments of the trainer.

    Args:
        arg_list (list[str], optional): Arguments to parse. When ``None``,
            argparse reads ``sys.argv``. Arguments may also be read from a
            file by prefixing its path with ``+``.

    Returns:
        argparse.Namespace: With attributes ``cfg``, ``model_type`` and
        ``dataset``; each is ``None`` when the option is not given.
    """
    parser = argparse.ArgumentParser(
        description="IANN arguments", fromfile_prefix_chars="+"
    )
    parser.add_argument(
        "--cfg",
        type=str,
        help="Path to a toml config file",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        # Must match the names accepted by Trainer, which spells the
        # EquiformerV2 model "equiformerv2" (not "equiformer2").
        choices=["painn", "nequip", "mace", "equiformerv2", "fastpot", "demo"],
        help="Type of model to use",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        help="Path to a dataset file",
    )

    return parser.parse_args(arg_list)

def forces_criterion(predicted, target, reduction="mean"):
    """Reduce the per-row L2 norm of the force error to a scalar.

    Args:
        predicted: Predicted forces; the norm is taken along ``dim=1``.
        target: Reference forces, same shape as ``predicted``.
        reduction (str): ``"mean"`` or ``"sum"`` over the per-row norms.

    Returns:
        torch.Tensor: Scalar tensor with the reduced error norm.

    Raises:
        ValueError: If ``reduction`` is neither ``"mean"`` nor ``"sum"``.
    """
    # Note: this is the plain (not squared) Euclidean norm of each row.
    row_norms = torch.linalg.norm(predicted - target, dim=1)
    if reduction not in ("mean", "sum"):
        raise ValueError("Reduction must be 'mean' or 'sum'")
    return row_norms.mean() if reduction == "mean" else row_norms.sum()

def get_norm_data(dataset, per_atom=True):
    """Compute the mean and standard deviation of the energies in ``dataset``.

    The statistics are accumulated on values shifted by the first sample's
    energy, which keeps the sum of squares small and reduces cancellation.

    Args:
        dataset: Iterable of samples exposing ``energy`` and ``num_atoms``.
        per_atom (bool): Divide each energy by ``num_atoms`` before
            accumulating.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(mean, stddev)``, each of shape
        ``(1,)`` in the current default dtype.

    Raises:
        ValueError: If ``dataset`` yields no samples.
    """
    shifted_sum = torch.zeros(1, dtype=torch.float64)
    shifted_sq_sum = torch.zeros(1, dtype=torch.float64)
    bias = None
    num_objects = 0
    for sample in dataset:
        x = torch.as_tensor(sample.energy, dtype=torch.float64).detach()
        if per_atom:
            x = x / sample.num_atoms
        if bias is None:
            # Clone so the bias never aliases the sample's own energy tensor.
            bias = x.clone()
        # Out-of-place subtraction: an in-place `x -= bias` would modify the
        # sample's energy (and the bias itself when both are the same tensor).
        shifted = x - bias
        shifted_sum += shifted
        shifted_sq_sum += shifted ** 2.0
        num_objects += 1

    if num_objects == 0:
        raise ValueError("Cannot compute normalization statistics of an empty dataset")

    shifted_mean = shifted_sum / num_objects
    # Rounding can push the variance marginally below zero; clamp to avoid NaN.
    x_var = torch.clamp(shifted_sq_sum / num_objects - shifted_mean ** 2.0, min=0.0)
    x_mean = shifted_mean + bias
    default_type = torch.get_default_dtype()

    return x_mean.type(default_type), torch.sqrt(x_var).type(default_type)

class EarlyStopping():
    """Count validation results that are worse than a reference loss.

    Each call whose ``val_loss`` exceeds ``best_loss`` by more than
    ``min_delta`` increments ``counter``. Once ``counter`` reaches
    ``patience`` the ``early_stop`` flag is set and stays set. The counter
    is never reset by a good result.
    """

    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, best_loss):
        """Register one validation result and return the stop flag."""
        worsened = (val_loss - best_loss) > self.min_delta
        if worsened:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

def split_data(dataset, config, rank):
    """Split ``dataset`` into training and validation subsets.

    Args:
        dataset: Indexable dataset with ``len()``.
        config (dict): Reads ``load_split``, ``save_split``, ``val_ratio``
            and ``output_dir``. ``".json"`` is appended to the split names.
        rank (int): Rank of the calling process. Only rank 0 logs and writes
            the split file.

    Returns:
        dict: Maps each split name (``"train"``/``"validation"`` for a fresh
        split, or whatever keys the loaded file holds) to a
        ``torch.utils.data.Subset`` of ``dataset``.
    """
    if config["load_split"]:
        # Reuse an existing split file
        with open(f"{config['load_split']}.json", "r") as fp:
            splits = json.load(fp)
        if rank == 0:
            logging.info(f"Loaded Existed Json Split File (load_split): {config['load_split']}")
    else:
        # Draw a random split from the global NumPy RNG; the validation size is rounded up
        num_validation = int(math.ceil(len(dataset) * config["val_ratio"]))
        shuffled = np.random.permutation(len(dataset))
        splits = {
            "train": shuffled[num_validation:].tolist(),
            "validation": shuffled[:num_validation].tolist(),
        }

    # Only rank 0 writes the file: several ranks writing the same path at
    # once can truncate each other's output and leave a corrupt JSON file.
    if config["save_split"] and rank == 0:
        output_dir = config["output_dir"]
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, f"{config['save_split']}.json"), "w") as f:
            json.dump(splits, f)
        logging.info(f"Saved Json Split File (save_split): {config['save_split']}")

    return {
        name: torch.utils.data.Subset(dataset, split_indices)
        for name, split_indices in splits.items()
    }

class Trainer:
    """Trainer class for training interatomic neural network models"""

    # Supported model types: model_type -> (module path, class name).
    # The model module is imported lazily in _create_model.
    _MODEL_CLASSES = {
        "painn": ("iann.models.painn", "PaiNN"),
        "nequip": ("iann.models.nequip", "NequIP"),
        "mace": ("iann.models.mace", "MACE"),
        "equiformerv2": ("iann.models.equiformerV2", "EquiformerV2"),
        "fastpot": ("iann.models.fastpot", "FastPot"),
        "demo": ("iann.models.demo", "Demo"),
    }

    def __init__(self, model="painn", config=None, distributed=True, rank=None, world_size=None):
        """
        Initialize the trainer with a model type and optional config

        Args:
            model (str): Model type ("painn", "nequip", "mace", "equiformerv2",
                "fastpot" or "demo"; "equiformer2" is accepted as an alias of
                "equiformerv2")
            config (dict, optional): Configuration overrides
            distributed (bool, optional): Whether to use distributed training
            rank (int, optional): Rank of the current process
            world_size (int, optional): Total number of processes

        Raises:
            ValueError: If the model type is unknown
        """
        # Initialize configuration with defaults
        self.config = DEFAULT_CONFIG.copy()

        # Update with user config if provided
        if config:
            self.config.update(config)
        self.input_config = config

        # Set model type
        self.model_type = model.lower()
        if self.model_type == "equiformer2":
            # Older docs and the CLI advertised this spelling
            self.model_type = "equiformerv2"
        if self.model_type not in self._MODEL_CLASSES:
            raise ValueError(f"Unknown model type: {self.model_type}")

        # Auto-detect SLURM to avoid spawning when SLURM is already managing processes
        self._under_slurm = 'SLURM_JOB_ID' in os.environ

        # Track whether rank was provided explicitly
        self.distributed = distributed
        self._explicit_rank = rank is not None

        # Auto-detect SLURM rank/world_size if not provided
        if self.distributed:
            if self._explicit_rank:
                self.rank = rank
            else:
                self.rank = int(os.environ.get('SLURM_PROCID', 0))
            if world_size is not None:
                self.world_size = world_size
            else:
                self.world_size = int(os.environ.get('SLURM_NTASKS', 1))
        else:
            self.rank = 0
            self.world_size = 1

        # Device from the config override (e.g. 'cpu' or 'cuda:0'). Without an
        # override pick CUDA when available: the device must never stay None,
        # because later setup code reads self.device.type.
        if self.config.get("device"):
            self.device = torch.device(self.config["device"])
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.current_node = None

        # Initialize other attributes
        self.model = None
        self.optimizer = None
        self.scheduler = None
        self.criterion = None
        self.early_stop = None
        self.dataset = None
        self.datasplits = None
        self.train_loader = None
        self.val_loader = None
        self.train_sampler = None
        self.val_sampler = None
        self.data_mean = None
        self.data_stddev = None
        self.optimizer_type = self.config["optimizer_type"]
        self.scheduler_type = self.config["scheduler_type"]

        # Create output directory
        os.makedirs(self.config["output_dir"], exist_ok=True)

        # Logging is configured here, now that rank is set
        self._setup_logging()

        # Set random seed
        setup_seed(self.config["random_seed"])

    def _wants_cuda(self):
        """Return True when CUDA is available and the device is not overridden to a non-CUDA type"""
        return torch.cuda.is_available() and (self.device is None or self.device.type == 'cuda')

    def _local_cuda_index(self):
        """Return the CUDA device index for this process (SLURM_LOCALID, else rank modulo GPU count)"""
        if 'SLURM_LOCALID' in os.environ:
            return int(os.environ['SLURM_LOCALID'])
        return self.rank % torch.cuda.device_count()

    @staticmethod
    def _cpu_brand():
        """Return a CPU description for logging; falls back to the stdlib when py-cpuinfo is unavailable"""
        try:
            import cpuinfo
            return cpuinfo.get_cpu_info()['brand_raw']
        except (ImportError, KeyError):
            import platform
            return platform.processor() or "unknown"

    @staticmethod
    def _expand_node_range(node_range):
        """Expand the bracket part of a SLURM node list (e.g. "034-036,040") into sorted node suffixes"""
        node_numbers = []
        for item in node_range.split(','):
            if '-' in item:
                start, end = item.split('-', 1)
                # Preserve leading zeros by using the width of the original string
                width = len(start)
                node_numbers.extend(f"{i:0{width}d}" for i in range(int(start), int(end) + 1))
            else:
                node_numbers.append(item)
        return sorted(node_numbers)

    def _setup_distributed(self):
        """Initialize distributed training environment

        Derives the master address (first SLURM node, or MASTER_ADDR, or
        localhost), exports MASTER_ADDR/MASTER_PORT, selects the device of
        this rank and initializes the process group.

        Returns:
            str: Name of the node this rank is assumed to run on (used for logging)
        """
        if 'SLURM_JOB_NODELIST' in os.environ:
            node_list = os.environ['SLURM_JOB_NODELIST']
            # Handle different node list formats
            if '[' in node_list and ']' in node_list:
                base_name = node_list.split('[')[0]
                node_range = node_list.split('[')[1].split(']')[0]
                node_numbers = self._expand_node_range(node_range)
                master_addr = f"{base_name}{node_numbers[0]}"  # Use first node as master
                node_index = self.rank % len(node_numbers)  # Wrap around if rank > num_nodes
                current_node = f"{base_name}{node_numbers[node_index]}"
            else:
                master_addr = node_list.split(',')[0]
                current_node = master_addr
        else:
            master_addr = os.environ.get('MASTER_ADDR', 'localhost')
            current_node = master_addr

        # Use fixed port for simplicity
        master_port = self.config["master_port"]
        self.master_addr = master_addr
        self.master_port = master_port
        os.environ['MASTER_ADDR'] = master_addr
        os.environ['MASTER_PORT'] = str(master_port)

        if self.rank == 0:
            logging.info(f"PyTorch version: {torch.__version__}")
            logging.info(f"Node List: {os.environ.get('SLURM_JOB_NODELIST', 'N/A')}")
            if torch.cuda.is_available():
                logging.info(f"World Size (number of GPUs): {self.world_size}")
            else:
                logging.info(f"World Size (number of CPUs): {self.world_size}")
            logging.info(f"Master Address: {master_addr}")
            logging.info(f"Master Port: {master_port}")

        # Stagger the ranks so their log lines come out in rank order
        time.sleep(self.rank * 0.1 + 0.1)

        # Set device and self.device
        if self._wants_cuda():
            local_rank = self._local_cuda_index()
            torch.cuda.set_device(local_rank)
            self.device = torch.device(f"cuda:{local_rank}")
            logging.info(f"Process {self.rank} using device {self.device} on {current_node}. GPU architecture: {torch.cuda.get_device_name()}")
        else:
            self.device = torch.device("cpu")
            logging.info(f"Process {self.rank} using device {self.device} on {current_node}, CPU architecture: {self._cpu_brand()}")

        # Choose backend based on device: NCCL for GPUs, Gloo for CPU
        backend = "nccl" if self.device is not None and self.device.type.startswith("cuda") else "gloo"
        dist.init_process_group(
            backend,
            rank=self.rank,
            world_size=self.world_size,
            timeout=timedelta(seconds=self.config["dist_timeout"])
        )

        # Wait for all ranks to finish logging
        time.sleep((self.world_size - self.rank) * 0.1 + 0.1)

        self.current_node = current_node
        return current_node

    def _cleanup_distributed(self):
        """Clean up distributed environment

        Note: in distributed mode this destroys the process group and then
        terminates the current process with exit status 0.
        """
        if self.distributed:
            dist.destroy_process_group()
            sys.exit(0)  # Clean and Pythonic exit

    def _setup_data(self, dataset_path):
        """Setup dataset and dataloaders

        Args:
            dataset_path (str): Path handed to AseDataset as ``ase_db``
        """
        if self.rank == 0:
            logging.info(f"Loading data from {os.path.abspath(dataset_path)}")

        # Load dataset
        self.dataset = AseDataset(
            ase_db=dataset_path,
            cutoff=self.config["cutoff"],
            compute_forces=bool(self.config["forces_weight"]),
        )

        # Split data
        self.datasplits = split_data(self.dataset, self.config, self.rank)

        if self.distributed:
            # Setup distributed samplers
            self.train_sampler = torch.utils.data.distributed.DistributedSampler(
                self.datasplits["train"],
                num_replicas=self.world_size,
                rank=self.rank,
                shuffle=True
            )
            self.val_sampler = torch.utils.data.distributed.DistributedSampler(
                self.datasplits["validation"],
                num_replicas=self.world_size,
                rank=self.rank,
                shuffle=False
            )
        else:
            # Use regular samplers
            self.train_sampler = torch.utils.data.RandomSampler(self.datasplits["train"])
            self.val_sampler = torch.utils.data.SequentialSampler(self.datasplits["validation"])

        # Setup dataloaders
        self.train_loader = torch.utils.data.DataLoader(
            self.datasplits["train"],
            self.config["batch_size"],
            sampler=self.train_sampler,
            collate_fn=collate_atomsdata,
            num_workers=0,
            pin_memory=True,
        )
        self.val_loader = torch.utils.data.DataLoader(
            self.datasplits["validation"],
            self.config["batch_size"],
            sampler=self.val_sampler,
            collate_fn=collate_atomsdata,
            num_workers=0,
            pin_memory=True,
        )

        if self.rank == 0:
            logging.info('Dataset size: {}, training set size: {}, validation set size: {}'.format(
                len(self.dataset),
                len(self.datasplits["train"]),
                len(self.datasplits["validation"]),
            ))

        # Compute data normalization if needed
        if self.config["norm_data"] or self.config["norm_per_atom"]:
            if self.rank == 0:
                logging.info("Computing energy mean and variance of the dataset: ")
            self.data_mean, self.data_stddev = get_norm_data(
                self.datasplits["train"],
                per_atom=self.config["norm_per_atom"],
            )
            if self.rank == 0:
                logging.info(f"Mean of energy: {self.data_mean.item():.4f}, standard deviation of energy: {self.data_stddev.item():.4f}")
        else:
            self.data_mean = torch.tensor([0.0])
            self.data_stddev = torch.tensor([1.0])

        if self.rank == 0:
            if bool(self.config['forces_weight']):
                logging.info("Compute forces: True")
            else:
                logging.info("Compute forces: False")

    def _create_model(self):
        """Create model based on model_type

        Returns:
            The instantiated (not yet device-placed) model

        Raises:
            ValueError: If the model type is unknown
            ImportError: If the model implementation cannot be imported
        """
        import importlib

        if self.model_type not in self._MODEL_CLASSES:
            raise ValueError(f"Unknown model type: {self.model_type}")

        # Every remaining config entry is forwarded to the model as a keyword
        # argument; the keys below are passed explicitly instead.
        model_params = {
            "compute_forces": bool(self.config["forces_weight"]),
        }
        model_params.update(self.config)
        for key in ("num_layers", "num_channels", "norm_data", "norm_per_atom", "device"):
            model_params.pop(key)

        module_name, class_name = self._MODEL_CLASSES[self.model_type]
        try:
            model_cls = getattr(importlib.import_module(module_name), class_name)
        except (ImportError, AttributeError) as err:
            raise ImportError(f"{class_name} is not available") from err

        if self.model_type == "equiformerv2":
            # EquiformerV2 is the only model that is handed the device
            model_params["device"] = self.device

        # Statistics are only handed to the model when norm_data is enabled
        use_norm = self.config["norm_data"]
        return model_cls(
            num_layers=self.config["num_layers"],
            num_channels=self.config["num_channels"],
            norm_data=use_norm,
            data_mean=self.data_mean.tolist() if use_norm else [0.0],
            data_stddev=self.data_stddev.tolist() if use_norm else [1.0],
            norm_per_atom=self.config["norm_per_atom"],
            **model_params
        )

    def _setup_model(self):
        """Setup model, optimizer, and scheduler

        Raises:
            ValueError: If the optimizer or scheduler type is unknown
        """
        # Create model and move it to the device
        self.model = self._create_model()
        self.model = self.model.to(self.device)

        # Wrap with DDP if distributed
        if self.distributed:
            if self.device.type == 'cpu':
                self.model = DDP(self.model)
            else:
                # Index of the GPU selected for this rank in _setup_distributed
                local_rank = self._local_cuda_index()
                if 'SLURM_LOCALID' in os.environ:
                    find_unused_parameters = (
                        self.model_type == "equiformerv2" and bool(self.config["forces_weight"])
                    )
                    self.model = DDP(
                        self.model,
                        device_ids=[local_rank],
                        gradient_as_bucket_view=True,  # Memory efficiency
                        broadcast_buffers=False,  # Broadcast buffers may cause stuck for MACE or NequiIP
                        static_graph=False,  # Dynamic graph is False by default
                        find_unused_parameters=find_unused_parameters,
                    )
                else:
                    self.model = DDP(self.model, device_ids=[local_rank])

        # Log model info
        total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        if self.rank == 0:
            logging.info(f"Total trainable parameters: {total_params}")
        total_memory = sum(p.element_size() * p.numel() for p in self.model.parameters())
        if self.rank == 0:
            logging.info(f"Total memory of the model: {total_memory / 1024**2:.2f} MB")

        # Setup optimizer
        learning_rate = self.config["learning_rate"]
        # Optimizers that only need the learning rate
        lr_only_optimizers = {
            "adam": torch.optim.Adam,
            "rmsprop": torch.optim.RMSprop,
            "adagrad": torch.optim.Adagrad,
            "adadelta": torch.optim.Adadelta,
            "adamax": torch.optim.Adamax,
        }
        if self.optimizer_type in lr_only_optimizers:
            self.optimizer = lr_only_optimizers[self.optimizer_type](
                self.model.parameters(), lr=learning_rate
            )
        elif self.optimizer_type == "adamw":
            self.optimizer = torch.optim.AdamW(
                self.model.parameters(), lr=learning_rate,
                weight_decay=self.config.get("weight_decay", 1e-2)
            )
        elif self.optimizer_type == "sgd":
            self.optimizer = torch.optim.SGD(
                self.model.parameters(), lr=learning_rate,
                momentum=self.config.get("momentum", 0.9)
            )
        else:
            raise ValueError(f"Unknown optimizer type: {self.optimizer_type}")

        self.criterion = torch.nn.MSELoss()

        # Add gradient clipping for stability
        if self.config.get("max_grad_norm") is not None:
            self.max_grad_norm = self.config.get("max_grad_norm")

        # Setup scheduler
        if self.scheduler_type == "ReduceLROnPlateau":
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min', factor=0.5, patience=10)
        elif self.scheduler_type == "LambdaLR":
            scheduler_fn = lambda step: 0.96 ** (step / 100000)
            self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, scheduler_fn)
        elif self.scheduler_type == "CosineAnnealingLR":
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=self.config["max_steps"])
        elif self.scheduler_type == "CosineAnnealingWarmRestarts":
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer, T_0=10, T_mult=2)
        elif self.scheduler_type == "StepLR":
            self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=0.5)
        elif self.scheduler_type == "MultiStepLR":
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[10, 20, 30], gamma=0.5)
        elif self.scheduler_type == "ExponentialLR":
            self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.95)
        else:
            raise ValueError(f"Unknown scheduler type: {self.scheduler_type}")

        # Setup early stopping
        self.early_stop = EarlyStopping(patience=self.config["stop_patience"])

        # Initialize steps
        self.init_steps = 0

        # Load model if needed
        if self.config["load_model"]:
            self._load_model()

    def _load_model(self):
        """Load model from checkpoint

        Uses config["load_model"] when it is a path ending in '.pt', otherwise
        output_dir/output_model. A missing file is logged and load_model is
        reset to False. Only the model weights, the step counter and the
        scheduler state are restored.
        """
        load_model = self.config['load_model']
        # load_model may be a plain boolean (e.g. `load_model = true` in toml)
        if isinstance(load_model, str) and load_model.endswith('.pt'):
            best_model = load_model
        else:
            best_model = os.path.join(self.config["output_dir"], self.config["output_model"])

        if not os.path.exists(best_model):
            if self.rank == 0:
                logging.info(f"No model found at {best_model}")
            self.config["load_model"] = False
            return

        if self.rank == 0:
            logging.info(f"Loading model from {best_model}")

        # self.device already is this rank's device (cuda:<local_rank> or cpu)
        state_dict = torch.load(best_model, map_location=self.device)

        if self.distributed:
            self.model.module.load_state_dict(state_dict["model"])
        else:
            self.model.load_state_dict(state_dict["model"])

        if state_dict["step"] > 0:
            self.init_steps = state_dict["step"]
            self.scheduler.load_state_dict(state_dict["scheduler"])

    def _save_model(self, filename, total_steps, best_val_loss):
        """Save model checkpoint

        Args:
            filename (str): File name inside the output directory
            total_steps (int): Step counter stored in the checkpoint
            best_val_loss (float): Validation loss stored in the checkpoint;
                a NaN value aborts the process with exit status 1
        """
        # Check for NaN values and exit if found
        if math.isnan(best_val_loss):
            if self.rank == 0:
                logging.error("NaN values detected in best_val_loss. Exiting training.")
            if self.distributed:
                # Tear down directly: _cleanup_distributed() would exit with
                # status 0 and hide the failure from the job scheduler.
                dist.destroy_process_group()
            sys.exit(1)

        model_state = self.model.module.state_dict() if self.distributed else self.model.state_dict()
        torch.save(
            {
                "model_type": self.model_type,
                "model": model_state,
                "optimizer": self.optimizer.state_dict(),
                "scheduler": self.scheduler.state_dict(),
                "step": total_steps,
                "best_val_loss": best_val_loss,
                "num_channels": self.config["num_channels"],
                "num_layers": self.config["num_layers"],
                "cutoff": self.config["cutoff"],
                "compute_forces": self.config["forces_weight"],
            },
            os.path.join(self.config["output_dir"], filename),
        )

    def eval_model(self):
        """Evaluate model on validation set

        Returns:
            dict: "energy_mae", "energy_rmse", "forces_mae", "forces_rmse"
            (force metrics are 0 when forces are disabled) and
            "sqrt(total_loss)", the square root of the sample-weighted mean
            of the weighted validation loss.
        """
        model = self.model
        was_training = model.training
        # Set model to evaluation mode
        model.eval()

        forces_weight = self.config["forces_weight"]
        use_forces = bool(forces_weight)

        energy_running_ae = 0.0
        energy_running_se = 0.0
        forces_running_c_ae = 0.0
        forces_running_c_se = 0.0
        running_loss = 0.0
        count = 0
        forces_count = 0

        for batch in self.val_loader:
            device_batch = batch.to(self.device)
            out = model(device_batch)

            batch_size = device_batch.energy.shape[0]
            count += batch_size

            energy_loss = self.criterion(out.energy, device_batch.energy).detach().cpu().numpy()
            if use_forces:
                forces_count += device_batch.forces.shape[0]
                forces_loss = forces_criterion(out.forces, device_batch.forces).detach().cpu().numpy()
            else:
                forces_loss = 0.0

            # Same weighting as the training loss
            total_loss = forces_weight * forces_loss + (1 - forces_weight) * energy_loss
            running_loss += total_loss * batch_size

            # energy errors
            energy_targets = device_batch.energy.detach().cpu().numpy()
            energy_outputs = out.energy.detach().cpu().numpy()
            energy_running_ae += np.sum(np.abs(energy_targets - energy_outputs), axis=0)
            energy_running_se += np.sum(np.square(energy_targets - energy_outputs), axis=0)

            # force errors
            if use_forces:
                forces_targets = device_batch.forces.detach().cpu().numpy()
                forces_outputs = out.forces.detach().cpu().numpy()
                forces_diff = forces_targets - forces_outputs
                forces_running_c_ae += np.sum(np.abs(forces_diff))
                forces_running_c_se += np.sum(np.square(forces_diff))

        # Restore the model's original training state
        if was_training:
            model.train()

        energy_mae = energy_running_ae / count
        energy_rmse = np.sqrt(energy_running_se / count)

        if use_forces:
            # Per-component metrics: forces_count rows times 3 components
            forces_mae = forces_running_c_ae / (forces_count * 3)
            forces_rmse = np.sqrt(forces_running_c_se / (forces_count * 3))
        else:
            forces_mae = 0
            forces_rmse = 0

        total_loss = running_loss / count

        evaluation = {
            "energy_mae": energy_mae,
            "energy_rmse": energy_rmse,
            "forces_mae": forces_mae,
            "forces_rmse": forces_rmse,
            "sqrt(total_loss)": np.sqrt(total_loss),
        }

        return evaluation

    def _setup_logging(self):
        """Setup logging with dynamic rank injection

        Replaces all handlers of the root logger with a file handler
        (output_dir/output_log, append mode) and a stream handler, both
        tagging records with this trainer's rank.
        """
        # Clear existing handlers; close them so file descriptors are not
        # leaked when logging is configured more than once.
        root = logging.getLogger()
        for h in root.handlers[:]:
            try:
                h.flush()
            except Exception:
                pass
            root.removeHandler(h)
            try:
                h.close()
            except Exception:
                pass

        root.setLevel(logging.DEBUG)

        # Format: include injected %(rank)s
        fmt = "%(asctime)s [RANK%(rank)s] [%(levelname)-5.5s] %(message)s"
        datefmt = "%Y-%m-%d %H:%M:%S"
        formatter = logging.Formatter(fmt, datefmt)

        # Create and configure file handler
        log_file = os.path.join(self.config["output_dir"], self.config["output_log"])
        file_handler = logging.FileHandler(log_file, mode="a")
        file_handler.setFormatter(formatter)
        file_handler.addFilter(RankFilter(self.rank))

        # Create and configure stream handler
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        stream_handler.addFilter(RankFilter(self.rank))

        # Attach handlers
        root.addHandler(file_handler)
        root.addHandler(stream_handler)

    def train(self, dataset_path=None):
        """
        Train the model on a dataset

        Args:
            dataset_path (str): Path to the dataset file; falls back to
                config["dataset"] when omitted

        Raises:
            ValueError: If no dataset path is available
        """
        # Configure logging first, using the detected or provided rank
        self._setup_logging()

        # Determine dataset path
        if dataset_path is None:
            dataset_path = self.config.get("dataset")
        if not dataset_path:
            raise ValueError("Dataset path must be provided via train() argument or in config['dataset']")

        # If using local multi-GPU (not under SLURM) and rank not explicitly set, spawn worker processes
        if (self.distributed and self.world_size > 1 and not self._explicit_rank
                and self.rank == 0 and not self._under_slurm):
            mp.spawn(
                process_function,
                args=(self.world_size, self.model_type, self.config, dataset_path),
                nprocs=self.world_size,
                join=True
            )
            return  # Return after spawning processes - only the spawned children continue

        # Setup distributed training if enabled
        if self.distributed:
            self._setup_distributed()
        else:
            # Single-GPU or CPU mode
            import platform
            try:
                node_name = platform.node()
            except Exception:
                node_name = "unknown"
            use_cuda = self._wants_cuda()
            if use_cuda:
                # Keep an explicitly configured GPU index; default to cuda:0
                if self.device is None or self.device.index is None:
                    self.device = torch.device("cuda:0")
            else:
                self.device = torch.device("cpu")
            logging.info(f"PyTorch version: {torch.__version__}")
            logging.info(f"Running in single-{'GPU' if use_cuda else 'CPU'} Node {node_name}")
            if use_cuda:
                logging.info(f"Hardware architecture: {torch.cuda.get_device_name(self.device)}")
            else:
                logging.info(f"Hardware architecture: {self._cpu_brand()}")

        # Setup dataset and dataloaders
        self._setup_data(dataset_path)

        # Setup model, optimizer, and scheduler
        self._setup_model()

        # determine max_steps or max_epochs
        if self.config["max_epochs"]:
            self.config["max_steps"] = None
            max_epochs = self.config["max_epochs"]
        else:
            # Step-bounded run: the step check inside the loop ends training
            max_epochs = int(self.config["max_steps"])

        # Log detailed model configuration and setup
        if self.rank == 0:
            logging.info("---------------- Configuration Settings ----------------")
            # Log for model
            logging.info(f"Model Type (model): {self.model_type}")
            logging.info(f"Number of Channels (num_channels): {self.config['num_channels']}")
            logging.info(f"Number of Layers (num_layers): {self.config['num_layers']}")
            logging.info(f"Cutoff Radius (cutoff): {self.config['cutoff']}")
            # Log for trainer
            logging.info(f"Validation Ratio (val_ratio): {self.config['val_ratio']}")
            logging.info(f"Batch Size (batch_size): {self.config['batch_size']}")
            logging.info(f"Learning Rate (learning_rate): {self.config['learning_rate']}")
            logging.info(f"Forces Weight (forces_weight): {self.config['forces_weight']}")
            if self.config["max_epochs"]:
                logging.info(f"Max Epochs (max_epochs): {max_epochs}")
            else:
                logging.info(f"Max Steps (max_steps): {self.config['max_steps']}")
            logging.info(f"Optimizer Type (optimizer_type): {self.optimizer_type}")
            if self.config['max_grad_norm']:
                logging.info(f"Gradient Clipping Norm (max_grad_norm): {self.config['max_grad_norm']}")
            logging.info(f"Log Interval (log_interval): {self.config['log_interval']}")
            logging.info(f"Early Stopping Patience (stop_patience): {self.config['stop_patience']}")
            logging.info(f"Scheduler Type (scheduler_type): {self.config['scheduler_type']}")
            # Log for data
            logging.info(f"Random Seed (random_seed): {self.config['random_seed']}")
            if self.config['save_split']:
                logging.info(f"Save Split File Name (save_split): {self.config['save_split']}")
            if self.config['load_split']:
                logging.info(f"Load Split File Name (load_split): {self.config['load_split']}")
            if self.config['norm_data'] and not self.config['norm_per_atom']:
                logging.info(f"Data Normalization (norm_data): {self.config['norm_data']}")
            if self.config['norm_per_atom']:
                logging.info(f"Data Normalization per Atom (norm_per_atom): {self.config['norm_per_atom']}")
            if not self.config['norm_data'] and not self.config['norm_per_atom']:
                logging.info("Data Normalization (norm_data): False")
            # Log for DDP
            logging.info(f"Distributed Training (distributed): {self.distributed}")
            if self.distributed:
                logging.info(f"Master Port (master_port): {self.config['master_port']}")
                logging.info(f"Distributed Timeout (dist_timeout) (s): {self.config['dist_timeout']}")
            # Log for output
            logging.info(f"Output Directory (output_dir): {self.config['output_dir']}")
            logging.info(f"Output Log File (output_log): {self.config['output_log']}")
            logging.info(f"Output Model File (output_model): {self.config['output_model']}")
            # Log input config after logging is set up
            if self.config["log_input"]:
                logging.info(f"Input config: {self.input_config}")
            # To do: log all default model parameters
            if self.config['debug']:
                logging.debug(f"Debug Mode (debug): {self.config['debug']}")
                logging.debug(f"All parameters: {self.config}")  # all defaults except model-internal defaults

        # Initialize counters
        local_steps = 0
        total_steps = 0
        running_loss = 0.0
        running_loss_count = 0
        training_time = 0.0
        prev_loss = None
        best_val_loss = np.inf

        forces_weight = self.config["forces_weight"]

        if self.rank == 0:
            logging.info("---------------------- Training ------------------------")

        # Training loop: every one of the max_epochs epochs is run; the
        # epoch-limit handling happens after the loop.
        for epoch in range(max_epochs):
            if hasattr(self.train_sampler, "set_epoch"):
                # For distributed sampler, set epoch for shuffling
                self.train_sampler.set_epoch(epoch)

            self.model.train()

            for batch in self.train_loader:
                train_start_time = time.time()
                device_batch = batch.to(self.device)

                # Forward pass and backward pass
                self.optimizer.zero_grad()
                out = self.model(device_batch)

                # Calculate losses
                energy_loss = self.criterion(out.energy, device_batch.energy)
                if bool(forces_weight):
                    forces_loss = forces_criterion(out.forces, device_batch.forces)
                else:
                    forces_loss = 0.0

                # Total loss
                total_loss = forces_weight * forces_loss + (1 - forces_weight) * energy_loss
                total_loss.backward()

                # Apply gradient clipping
                if self.config.get("max_grad_norm") is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.get("max_grad_norm"))

                self.optimizer.step()

                # Update running loss
                batch_size = device_batch.energy.shape[0]
                running_loss += total_loss.detach().cpu().numpy() * batch_size
                running_loss_count += batch_size
                training_time += time.time() - train_start_time

                if self.distributed:
                    total_steps = local_steps * self.world_size + self.rank + self.init_steps
                else:
                    total_steps = local_steps + self.init_steps

                # Log training progress
                # NOTE(review): in distributed runs total_steps differs per rank, so
                # ranks reach this branch (and the early-stopping decision) at
                # different steps - confirm this is intended.
                if (total_steps % self.config["log_interval"] == 0) or ((total_steps + 1) == self.config["max_steps"]):
                    eval_start = time.time()
                    train_loss = running_loss / running_loss_count
                    running_loss = 0.0
                    running_loss_count = 0

                    # Evaluate model
                    eval_dict = self.eval_model()
                    eval_formatted = ", ".join(
                        ["{}={:.3f}".format(k, v) for (k, v) in eval_dict.items()]
                    )

                    # Exponentially smoothed validation loss
                    eval_loss = np.square(eval_dict["sqrt(total_loss)"])
                    smooth_loss = eval_loss if prev_loss is None else 0.9 * eval_loss + 0.1 * prev_loss
                    prev_loss = smooth_loss
                    eval_time = (time.time() - eval_start) / 60

                    if self.rank == 0:
                        log_msg = (
                            f"step={total_steps}, {eval_formatted}, "
                            f"sqrt(train_loss)={math.sqrt(train_loss):.3f}, "
                            f"patience={self.early_stop.counter:3d}, "
                            f"training time={training_time/60:.3f} min, "
                            f"eval time={eval_time:.3f} min"
                        )
                        logging.info(log_msg)
                    training_time = 0

                    if self.scheduler_type == "ReduceLROnPlateau":
                        self.scheduler.step(smooth_loss)

                    # Check early stopping and save model
                    # NOTE(review): best_val_loss is overwritten after every
                    # evaluation that does not trigger the stop, i.e. it tracks the
                    # previous value rather than the minimum - confirm intent.
                    if not self.early_stop(math.sqrt(smooth_loss), best_val_loss):
                        best_val_loss = math.sqrt(smooth_loss)
                        self._save_model(self.config["output_model"], total_steps, best_val_loss)
                    else:
                        logging.info(f"Early stopping, training complete")
                        if self.distributed:
                            self._cleanup_distributed()
                        return

                # Count steps
                local_steps += 1
                if self.scheduler_type != "ReduceLROnPlateau":
                    self.scheduler.step()  # update learning rate based on scheduler type

                if bool(self.config["max_steps"]) and total_steps >= self.config["max_steps"]:
                    logging.info(f"Maximum steps {self.config['max_steps']} reached, training complete")
                    self._save_model("exit_model.pth", total_steps, best_val_loss)
                    if self.distributed:
                        self._cleanup_distributed()
                    return

        # All epochs have been run
        logging.info(f"Reached maximum epochs ({max_epochs}), stopping training!")
        self._save_model("final_model.pth", total_steps, best_val_loss)

        # Clean up distributed training
        if self.distributed:
            self._cleanup_distributed()
def main():
    """Command-line entry point

    Reads the CLI arguments and the optional toml config, then runs
    training in one of three modes: under SLURM (one Trainer per task),
    on several local GPUs (one spawned process per GPU), or in a single
    process.
    """
    args = get_arguments()

    config = {}
    if args.cfg:
        with open(args.cfg, 'r') as f:
            config = toml.load(f)

    if not args.model_type:
        args.model_type = "painn"

    # Check if we're running under SLURM
    slurm_distributed = 'SLURM_JOB_NUM_NODES' in os.environ or 'SLURM_NTASKS' in os.environ

    if slurm_distributed:
        # Get SLURM environment variables for distributed training
        world_size = int(os.environ.get('SLURM_NTASKS', 1))
        rank = int(os.environ.get('SLURM_PROCID', 0))

        # Set device based on local rank
        if 'SLURM_LOCALID' in os.environ and torch.cuda.is_available():
            torch.cuda.set_device(int(os.environ['SLURM_LOCALID']))

        # Create trainer with distributed settings
        trainer = Trainer(
            model=args.model_type,
            config=config,
            distributed=True,
            rank=rank,
            world_size=world_size
        )

        # Train the model
        trainer.train(args.dataset)
    else:
        # Detect if we have multiple GPUs
        num_gpus = torch.cuda.device_count()

        if num_gpus > 1:
            # Multi-GPU training on a single node
            mp.spawn(
                process_function,
                args=(num_gpus, args.model_type, config, args.dataset),
                nprocs=num_gpus,
                join=True
            )
        else:
            # Single-GPU or CPU training
            trainer = Trainer(model=args.model_type, config=config, distributed=False)
            trainer.train(args.dataset)


def process_function(rank, world_size, model_type, config, dataset_path):
    """Function to be spawned in each process for multi-GPU training

    Args:
        rank (int): Index of the spawned process, used as its rank
        world_size (int): Total number of spawned processes
        model_type (str): Model type forwarded to Trainer
        config (dict): Configuration forwarded to Trainer
        dataset_path (str): Dataset path forwarded to Trainer.train
    """
    # Set device for this process; skipped on CPU-only hosts, where
    # torch.cuda.set_device would raise.
    if torch.cuda.is_available():
        torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize environment variables
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'

    trainer = Trainer(model=model_type, config=config, distributed=True, rank=rank, world_size=world_size)
    trainer.train(dataset_path)


if __name__ == "__main__":
    main()