Skip to content

Utilities Module

Utility functions for device management, I/O, and D3 correction.

Device Utilities

Device management utilities for MACE inference

get_device

get_device(device: DeviceType = 'auto') -> str

Get the appropriate device for computation.

Parameters:

Name Type Description Default
device DeviceType

Device specification ("auto", "cpu", or "cuda")

'auto'

Returns:

Type Description
str

Device string ("cpu" or "cuda")

Raises:

Type Description
ValueError

If CUDA is requested but not available

Examples:

>>> device = get_device("auto")  # Auto-detect
>>> device = get_device("cuda")  # Force CUDA
>>> device = get_device("cpu")   # Force CPU
Source code in src/mace_inference/utils/device.py
def get_device(device: DeviceType = "auto") -> str:
    """
    Resolve a device specification to a concrete torch device name.

    Args:
        device: One of "auto", "cpu" or "cuda". "auto" picks "cuda" when
            torch reports CUDA as available and otherwise falls back to "cpu".

    Returns:
        Device string ("cpu" or "cuda")

    Raises:
        ValueError: If ``device`` is not one of the accepted values, or if
            "cuda" is requested but CUDA is not available.

    Examples:
        >>> device = get_device("auto")  # Auto-detect
        >>> device = get_device("cuda")  # Force CUDA
        >>> device = get_device("cpu")   # Force CPU
    """
    if device == "cpu":
        return "cpu"
    if device not in ("auto", "cuda"):
        raise ValueError(f"Invalid device: {device}. Must be 'auto', 'cpu', or 'cuda'")

    # Only "auto" and "cuda" remain; both want the GPU whenever one is usable.
    if torch.cuda.is_available():
        return "cuda"
    if device == "auto":
        return "cpu"
    raise ValueError(
        "CUDA requested but not available. "
        "Install GPU version of PyTorch or use device='auto' or device='cpu'"
    )

validate_device

validate_device(device: str) -> None

Validate that the specified device is available.

Parameters:

Name Type Description Default
device str

Device string to validate

required

Raises:

Type Description
ValueError

If device is invalid or unavailable

Source code in src/mace_inference/utils/device.py
def validate_device(device: str) -> None:
    """
    Check that ``device`` names a supported device that is usable here.

    Args:
        device: Device string to validate ("cpu" or "cuda")

    Raises:
        ValueError: If ``device`` is neither "cpu" nor "cuda", or if it is
            "cuda" and torch reports that CUDA is not available.
    """
    if device == "cpu":
        return
    if device != "cuda":
        raise ValueError(f"Invalid device: {device}")
    if not torch.cuda.is_available():
        raise ValueError("CUDA device requested but not available")

get_device_info

get_device_info() -> dict

Get information about available compute devices.

Returns:

Type Description
dict

Dictionary with device information

Source code in src/mace_inference/utils/device.py
def get_device_info() -> dict:
    """
    Collect a summary of the compute devices torch can see.

    Returns:
        Dictionary with keys "cuda_available", "cuda_count" and
        "pytorch_version". When CUDA is available it also contains
        "cuda_version" and "devices", a list of per-GPU dicts with
        "index", "name" and "memory" (total memory formatted as GB,
        using 1 GB = 1e9 bytes).
    """
    cuda_available = torch.cuda.is_available()
    info = {
        "cuda_available": cuda_available,
        "cuda_count": torch.cuda.device_count() if cuda_available else 0,
        "pytorch_version": torch.__version__,
    }
    if not cuda_available:
        return info

    info["cuda_version"] = torch.version.cuda
    devices = []
    for idx in range(info["cuda_count"]):
        total_gb = torch.cuda.get_device_properties(idx).total_memory / 1e9
        devices.append(
            {
                "index": idx,
                "name": torch.cuda.get_device_name(idx),
                "memory": f"{total_gb:.2f} GB",
            }
        )
    info["devices"] = devices
    return info

get_device

def get_device(device: DeviceType = "auto") -> str

Determine the device to use for computation.

Parameters:

  • device (DeviceType): Device specification, one of:
  • "auto": Use CUDA if available, otherwise CPU
  • "cpu": Force CPU
  • "cuda": Use the CUDA GPU (raises ValueError if CUDA is not available)

Any other value, including indexed devices such as "cuda:0" and "mps", raises ValueError.

Returns: str - "cuda" or "cpu"

Example:

from mace_inference.utils import get_device

device = get_device("auto")  # "cuda" if torch.cuda.is_available(), else "cpu"
print(f"Using device: {device}")  # e.g., "cuda" or "cpu"

I/O Utilities

I/O utilities for structure handling

load_structure

load_structure(filepath: Union[str, Path], index: int = -1) -> Atoms

Load atomic structure from file.

Parameters:

Name Type Description Default
filepath Union[str, Path]

Path to structure file (CIF, POSCAR, XYZ, etc.)

required
index int

Frame index for trajectory files (-1 for last frame)

-1

Returns:

Type Description
Atoms

ASE Atoms object

Examples:

>>> atoms = load_structure("structure.cif")
>>> atoms = load_structure("trajectory.traj", index=0)  # First frame
Source code in src/mace_inference/utils/io.py
def load_structure(filepath: Union[str, Path], index: int = -1) -> Atoms:
    """
    Read an atomic structure from disk with ASE.

    Args:
        filepath: Path to structure file (CIF, POSCAR, XYZ, etc.)
        index: Frame index for trajectory files (-1 for last frame)

    Returns:
        ASE Atoms object

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.

    Examples:
        >>> atoms = load_structure("structure.cif")
        >>> atoms = load_structure("trajectory.traj", index=0)  # First frame
    """
    path = Path(filepath)
    if path.exists():
        return read(str(path), index=index)
    # Fail early with a clear message instead of whatever error the reader would give.
    raise FileNotFoundError(f"Structure file not found: {path}")

save_structure

save_structure(atoms: Atoms, filepath: Union[str, Path], format: str = None, **kwargs) -> None

Save atomic structure to file.

Parameters:

Name Type Description Default
atoms Atoms

ASE Atoms object

required
filepath Union[str, Path]

Output file path

required
format str

File format (auto-detected from extension if None)

None
**kwargs

Additional arguments passed to ase.io.write

{}

Examples:

>>> save_structure(atoms, "output.cif")
>>> save_structure(atoms, "output.xyz", format="xyz")
Source code in src/mace_inference/utils/io.py
def save_structure(
    atoms: Atoms,
    filepath: Union[str, Path],
    format: Union[str, None] = None,
    **kwargs
) -> None:
    """
    Save atomic structure to file.

    Missing parent directories of ``filepath`` are created before writing.

    Args:
        atoms: ASE Atoms object
        filepath: Output file path
        format: File format (auto-detected from extension if None)
        **kwargs: Additional arguments passed to ase.io.write

    Examples:
        >>> save_structure(atoms, "output.cif")
        >>> save_structure(atoms, "output.xyz", format="xyz")
    """
    # ``format`` shadows the builtin, but it is part of the public keyword API
    # (mirroring ase.io.write), so the name is kept.
    filepath = Path(filepath)
    # Create the output directory tree so writing into a new directory works.
    filepath.parent.mkdir(parents=True, exist_ok=True)

    write(str(filepath), atoms, format=format, **kwargs)

parse_structure_input

parse_structure_input(structure: Union[str, Path, Atoms, List[Atoms]]) -> Union[Atoms, List[Atoms]]

Parse various structure input formats.

Parameters:

Name Type Description Default
structure Union[str, Path, Atoms, List[Atoms]]

Structure file path, Atoms object, or list of Atoms

required

Returns:

Type Description
Union[Atoms, List[Atoms]]

Atoms object or list of Atoms objects

Examples:

>>> atoms = parse_structure_input("structure.cif")
>>> atoms = parse_structure_input(existing_atoms)
Source code in src/mace_inference/utils/io.py
def parse_structure_input(
    structure: Union[str, Path, Atoms, List[Atoms]]
) -> Union[Atoms, List[Atoms]]:
    """
    Normalize the accepted structure inputs to Atoms or a list of Atoms.

    Paths (str or Path) are loaded with ``load_structure`` using its default
    frame index; Atoms objects and lists of Atoms are returned as-is (the
    same object, not a copy).

    Args:
        structure: Structure file path, Atoms object, or list of Atoms

    Returns:
        Atoms object or list of Atoms objects

    Raises:
        ValueError: If a list contains anything other than Atoms objects.
        TypeError: If ``structure`` is none of the supported types.

    Examples:
        >>> atoms = parse_structure_input("structure.cif")
        >>> atoms = parse_structure_input(existing_atoms)
    """
    if isinstance(structure, (str, Path)):
        return load_structure(structure)
    if isinstance(structure, Atoms):
        return structure
    if not isinstance(structure, list):
        raise TypeError(
            f"Invalid structure input type: {type(structure)}. "
            "Expected str, Path, Atoms, or List[Atoms]"
        )
    if any(not isinstance(item, Atoms) for item in structure):
        raise ValueError("All elements in list must be ASE Atoms objects")
    return structure

create_supercell

create_supercell(atoms: Atoms, supercell_matrix: Union[List[int], ndarray, int]) -> Atoms

Create a supercell from the input structure.

Parameters:

Name Type Description Default
atoms Atoms

Input ASE Atoms object

required
supercell_matrix Union[List[int], ndarray, int]

Supercell size (e.g., [2, 2, 2] or 2 for isotropic)

required

Returns:

Type Description
Atoms

Supercell Atoms object

Examples:

>>> supercell = create_supercell(atoms, [2, 2, 2])
>>> supercell = create_supercell(atoms, 2)  # Same as [2, 2, 2]
Source code in src/mace_inference/utils/io.py
def create_supercell(
    atoms: Atoms,
    supercell_matrix: Union[List[int], np.ndarray, int]
) -> Atoms:
    """
    Create a supercell by repeating the input structure along its cell vectors.

    Args:
        atoms: Input ASE Atoms object
        supercell_matrix: Number of repetitions along each cell vector. This is
            either a sequence or 1-D array of three positive integers
            (e.g. [2, 2, 2]) or a single positive integer for an isotropic
            supercell. Python ints and NumPy integer scalars are both accepted.
            A general 3x3 transformation matrix is not supported here.

    Returns:
        Supercell Atoms object

    Raises:
        ValueError: If ``supercell_matrix`` is not three positive integers
            (or a single positive integer).

    Examples:
        >>> supercell = create_supercell(atoms, [2, 2, 2])
        >>> supercell = create_supercell(atoms, 2)  # Same as [2, 2, 2]
    """
    import numbers

    # numbers.Integral also matches NumPy integer scalars such as np.int64,
    # which isinstance(x, int) rejects.
    if isinstance(supercell_matrix, numbers.Integral):
        supercell_matrix = [supercell_matrix] * 3

    reps = np.asarray(supercell_matrix)
    if reps.ndim != 1:
        # Catches non-integer scalars and 3x3 matrices; a 3x3 matrix would
        # otherwise pass the length check below and fail deep inside ASE.
        raise ValueError(
            "Supercell matrix must be a sequence of 3 repetition counts, "
            f"got an array of shape {reps.shape}; for a general 3x3 "
            "transformation use ase.build.make_supercell"
        )
    if len(reps) != 3:
        raise ValueError(f"Supercell matrix must have 3 elements, got {len(reps)}")
    if not np.issubdtype(reps.dtype, np.integer) or (reps < 1).any():
        raise ValueError(
            f"Supercell repetitions must be positive integers, got {reps.tolist()}"
        )

    # Hand plain Python ints (not NumPy scalars) to ASE's repetition operator.
    return atoms * tuple(int(n) for n in reps)

atoms_to_dict

atoms_to_dict(atoms: Atoms) -> dict

Convert Atoms object to dictionary (for JSON serialization).

Parameters:

Name Type Description Default
atoms Atoms

ASE Atoms object

required

Returns:

Type Description
dict

Dictionary representation

Source code in src/mace_inference/utils/io.py
def atoms_to_dict(atoms: Atoms) -> dict:
    """
    Convert Atoms object to dictionary (for JSON serialization).

    Array data is converted with ``tolist()``, so the result holds plain
    Python lists, and it can be turned back into Atoms with ``dict_to_atoms``.

    Args:
        atoms: ASE Atoms object

    Returns:
        Dictionary with keys "symbols", "positions", "cell", "pbc" and
        "numbers". "cell" is None only when the structure is non-periodic
        along every axis and its cell is all zeros; otherwise it is the
        3x3 cell as nested lists.
    """
    cell = atoms.get_cell()
    # Keep the cell whenever it carries information: any periodic system, and
    # non-periodic systems that still define a box (e.g. a molecule in vacuum).
    # Dropping the latter made the atoms_to_dict -> dict_to_atoms round trip lossy.
    keep_cell = atoms.pbc.any() or np.asarray(cell).any()
    return {
        "symbols": atoms.get_chemical_symbols(),
        "positions": atoms.get_positions().tolist(),
        "cell": cell.tolist() if keep_cell else None,
        "pbc": atoms.pbc.tolist(),
        "numbers": atoms.numbers.tolist(),
    }

dict_to_atoms

dict_to_atoms(data: dict) -> Atoms

Convert dictionary to Atoms object.

Parameters:

Name Type Description Default
data dict

Dictionary with structure data

required

Returns:

Type Description
Atoms

ASE Atoms object

Source code in src/mace_inference/utils/io.py
def dict_to_atoms(data: dict) -> Atoms:
    """
    Rebuild an ASE Atoms object from a structure dictionary.

    Args:
        data: Dictionary with structure data. "symbols" and "positions" are
            required (a KeyError is raised if either is missing); "cell" is
            optional (None if absent) and "pbc" defaults to
            [False, False, False]. Any other keys, such as "numbers", are
            ignored.

    Returns:
        ASE Atoms object
    """
    symbols = data["symbols"]
    positions = data["positions"]
    cell = data.get("cell")
    pbc = data.get("pbc", [False, False, False])
    return Atoms(symbols=symbols, positions=positions, cell=cell, pbc=pbc)

load_structure

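def load_structure(filepath: Union[str, Path], index: int = -1) -> Atoms

Load atomic structure from file.

Parameters:

  • filepath (str | Path): Path to structure file (CIF, POSCAR, XYZ, etc.)
  • index (int): Frame index for trajectory files (-1 for last frame)

Returns: Atoms - ASE Atoms object

Raises: FileNotFoundError - if the file does not exist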

save_structure

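def save_structure(atoms: Atoms, filepath: Union[str, Path], format: Optional[str] = None, **kwargs) -> None

Save atomic structure to file, creating missing parent directories first.

Parameters:

  • atoms (Atoms): Structure to save
  • filepath (str | Path): Output file path
  • format (str, optional): File format (auto-detected from the extension if None)
  • **kwargs: Additional arguments passed to ase.io.write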

D3 Correction

D3 dispersion correction utilities

create_d3_calculator

create_d3_calculator(device: str = 'cpu', damping: str = 'bj', xc: str = 'pbe')

Create a DFT-D3 dispersion correction calculator.

Parameters:

Name Type Description Default
device str

Compute device ("cpu" or "cuda")

'cpu'
damping str

Damping function ("zero", "bj", "zerom", "bjm")

'bj'
xc str

Exchange-correlation functional (e.g., "pbe", "b3lyp")

'pbe'

Returns:

Type Description

TorchDFTD3Calculator instance

Raises:

Type Description
ImportError

If torch-dftd is not installed

Examples:

>>> d3_calc = create_d3_calculator(device="cuda", damping="bj", xc="pbe")
Source code in src/mace_inference/utils/d3_correction.py
def create_d3_calculator(device: str = "cpu", damping: str = "bj", xc: str = "pbe"):
    """
    Create a DFT-D3 dispersion correction calculator.

    The returned calculator provides only the D3 term; combine it with a
    MACE calculator via ``create_combined_calculator``.

    Args:
        device: Compute device ("cpu" or "cuda")
        damping: Damping function ("zero", "bj", "zerom", "bjm")
        xc: Exchange-correlation functional (e.g., "pbe", "b3lyp")

    Returns:
        TorchDFTD3Calculator instance

    Raises:
        ImportError: If torch-dftd is not installed. The original import
            error is chained as ``__cause__``.

    Examples:
        >>> d3_calc = create_d3_calculator(device="cuda", damping="bj", xc="pbe")
    """
    # Imported lazily so torch-dftd stays an optional dependency.
    try:
        from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
    except ImportError as exc:
        # Chain explicitly so the traceback shows the real import failure as
        # the cause rather than as an error "during handling" of this one.
        raise ImportError(
            "torch-dftd is required for D3 dispersion correction. "
            "Install with: pip install torch-dftd or pip install mace-inference[d3]"
        ) from exc

    return TorchDFTD3Calculator(device=device, damping=damping, xc=xc)

create_combined_calculator

create_combined_calculator(mace_calculator, enable_d3: bool = False, d3_device: Optional[str] = None, d3_damping: str = 'bj', d3_xc: str = 'pbe')

Create a combined MACE + D3 calculator using ASE's SumCalculator.

Parameters:

Name Type Description Default
mace_calculator

MACE calculator instance

required
enable_d3 bool

Whether to enable D3 correction

False
d3_device Optional[str]

Device for D3 calculator (defaults to MACE device)

None
d3_damping str

D3 damping function

'bj'
d3_xc str

D3 exchange-correlation functional

'pbe'

Returns:

Type Description

Combined calculator (or just MACE calculator if D3 disabled)

Examples:

>>> from mace.calculators import mace_mp
>>> mace_calc = mace_mp(model="medium", device="cuda")
>>> combined = create_combined_calculator(mace_calc, enable_d3=True)
Source code in src/mace_inference/utils/d3_correction.py
def create_combined_calculator(
    mace_calculator,
    enable_d3: bool = False,
    d3_device: Optional[str] = None,
    d3_damping: str = "bj",
    d3_xc: str = "pbe"
):
    """
    Create a combined MACE + D3 calculator using ASE's SumCalculator.

    Args:
        mace_calculator: MACE calculator instance
        enable_d3: Whether to enable D3 correction
        d3_device: Device for the D3 calculator. If None, the ``device``
            attribute of ``mace_calculator`` is used (converted to ``str``),
            or "cpu" when that attribute is missing or None.
        d3_damping: D3 damping function
        d3_xc: D3 exchange-correlation functional

    Returns:
        ``mace_calculator`` unchanged if ``enable_d3`` is False, or if ASE's
        SumCalculator cannot be imported (a warning is emitted in that case).
        Otherwise, a SumCalculator of the MACE and D3 calculators.

    Raises:
        ImportError: If D3 is enabled but torch-dftd is not installed
            (raised by ``create_d3_calculator``).

    Examples:
        >>> from mace.calculators import mace_mp
        >>> mace_calc = mace_mp(model="medium", device="cuda")
        >>> combined = create_combined_calculator(mace_calc, enable_d3=True)
    """
    if not enable_d3:
        return mace_calculator

    try:
        from ase.calculators.mixing import SumCalculator
    except ImportError:
        # Deliberate best-effort fallback: keep running with MACE alone.
        # stacklevel=2 attributes the warning to the caller, not this module.
        warnings.warn(
            "ASE SumCalculator not available, D3 correction disabled",
            stacklevel=2,
        )
        return mace_calculator

    if d3_device is None:
        # Follow the MACE calculator's device so both terms run on the same
        # hardware. The attribute may be a torch.device rather than a str
        # (NOTE(review): confirm for all MACE calculator types), so normalize it.
        mace_device = getattr(mace_calculator, "device", None)
        d3_device = "cpu" if mace_device is None else str(mace_device)

    d3_calc = create_d3_calculator(device=d3_device, damping=d3_damping, xc=d3_xc)

    return SumCalculator([mace_calculator, d3_calc])

check_d3_available

check_d3_available() -> bool

Check if torch-dftd is installed.

Source code in src/mace_inference/utils/d3_correction.py
def check_d3_available() -> bool:
    """
    Report whether the optional torch-dftd package can be imported.

    Returns:
        True if ``import torch_dftd`` succeeds, False if it raises ImportError.
    """
    try:
        import torch_dftd  # noqa: F401
    except ImportError:
        return False
    else:
        return True

create_d3_calculator

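def create_d3_calculator(
    device: str = "cpu",
    damping: str = "bj",
    xc: str = "pbe"
)

Create a standalone DFT-D3 dispersion calculator (torch-dftd). It does not include MACE; to add D3 on top of a MACE calculator, use create_combined_calculator.

Parameters:

  • device (str): Compute device ("cpu" or "cuda")
  • damping (str): Damping function ("zero", "bj", "zerom", or "bjm")
  • xc (str): Exchange-correlation functional (e.g., "pbe", "b3lyp")

Returns: TorchDFTD3Calculator - the D3-only calculator

Raises: ImportError - if torch-dftd is not installed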

Example:

from mace.calculators import mace_mp
from mace_inference.utils.d3_correction import (
    create_combined_calculator,
    create_d3_calculator,
)

# create_d3_calculator builds only the D3 dispersion term; its first
# parameter is the compute device, not a base calculator.
d3_calc = create_d3_calculator(device="cpu", damping="bj", xc="pbe")

# To add D3 on top of MACE, use create_combined_calculator, which sums the
# two calculators with ASE's SumCalculator.
mace_calc = mace_mp(model="medium", device="cpu")
combined_calc = create_combined_calculator(
    mace_calc, enable_d3=True, d3_damping="bj", d3_xc="pbe"
)

D3_AVAILABLE

D3_AVAILABLE: bool

Boolean indicating if torch-dftd is installed.

from mace_inference.utils.d3_correction import D3_AVAILABLE

# Gate optional D3 features on whether torch-dftd is installed.
if D3_AVAILABLE:
    print("D3 correction is available")
else:
    print("Install torch-dftd for D3 support")