diff --git a/docs/src/engines/ase.rst b/docs/src/engines/ase.rst index da8bce2d..8de8b324 100644 --- a/docs/src/engines/ase.rst +++ b/docs/src/engines/ase.rst @@ -23,6 +23,10 @@ Supported model outputs :py:meth:`ase.Atoms.get_forces`, …); - arbitrary outputs can be computed for any :py:class:`ase.Atoms` using :py:meth:`MetatomicCalculator.run_model`; +- for non-equivariant architectures like + `PET `_, + rotatonally-averaged energies, forces, and stresses can be computed using + :py:class:`metatomic.torch.ase_calculator.SymmetrizedCalculator`. How to install the code ^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/src/torch/reference/ase.rst b/docs/src/torch/reference/ase.rst index f217a3b9..ddb1f49a 100644 --- a/docs/src/torch/reference/ase.rst +++ b/docs/src/torch/reference/ase.rst @@ -17,3 +17,7 @@ not just the energy, through the .. autoclass:: metatomic.torch.ase_calculator.MetatomicCalculator :show-inheritance: :members: + +.. autoclass:: metatomic.torch.ase_calculator.SymmetrizedCalculator + :show-inheritance: + :members: diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index b5b2e4dd..0b917f69 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -2,7 +2,7 @@ import os import pathlib import warnings -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import metatensor.torch import numpy as np @@ -32,7 +32,6 @@ all_properties as ALL_ASE_PROPERTIES, ) - FilePath = Union[str, bytes, pathlib.PurePath] LOGGER = logging.getLogger(__name__) @@ -850,3 +849,450 @@ def _full_3x3_to_voigt_6_stress(stress): (stress[0, 1] + stress[1, 0]) / 2.0, ] ) + + +class SymmetrizedCalculator(ase.calculators.calculator.Calculator): + r""" + Take a MetatomicCalculator and average its predictions to make it (approximately) + equivariant. + + The default is to average over a quadrature of the orthogonal group O(3) composed + this way: + + - Lebedev quadrature of the unit sphere (S^2) + - Equispaced sampling of the unit circle (S^1) + - Both proper and improper rotations are taken into account by including the + inversion operation (if ``include_inversion=True``) + + :param base_calculator: the MetatomicCalculator to be symmetrized + :param l_max: the maximum spherical harmonic degree that the model is expected to + be able to represent. This is used to choose the quadrature order. If ``0``, + no rotational averaging will be performed (it can be useful to average only over + the space group, see ``apply_group_symmetry``). + :param batch_size: number of rotated systems to evaluate at once. If ``None``, all + systems will be evaluated at once (this can lead to high memory usage). + :param include_inversion: if ``True``, the inversion operation will be included in + the averaging. This is required to average over the full orthogonal group O(3). + :param apply_space_group_symmetry: if ``True``, the results will be averaged over + discrete space group of rotations for the input system. The group operations are + computed with spglib, and the average is performed after the O(3) averaging + (if any). This has no effect for non-periodic systems. + :param store_rotational_std: if ``True``, the results will contain the standard + deviation over the different rotations for each property (e.g., ``energy_std``). + :param \*\*kwargs: additional arguments passed to the ASE Calculator constructor + """ + + implemented_properties = ["energy", "forces", "stress"] + + def __init__( + self, + base_calculator: MetatomicCalculator, + *, + l_max: int = 3, + batch_size: Optional[int] = None, + include_inversion: bool = True, + apply_space_group_symmetry: bool = False, + store_rotational_std: bool = False, + **kwargs: Any, + ) -> None: + try: + from scipy.integrate import lebedev_rule # noqa: F401 + except ImportError as e: + raise ImportError( + "scipy is required to use the SO3AveragedCalculator, please install " + "it with `pip install scipy` or `conda install scipy`" + ) from e + + super().__init__(**kwargs) + + self.base_calculator = base_calculator + if l_max > 131: + raise ValueError( + f"l_max={l_max} is too large, the maximum supported value is 131" + ) + self.l_max = l_max + self.include_inversion = include_inversion + + if l_max > 0: + lebedev_order, n_inplane_rotations = _choose_quadrature(l_max) + self.quadrature_rotations, self.quadrature_weights = _get_quadrature( + lebedev_order, n_inplane_rotations, include_inversion + ) + else: + # no quadrature + self.quadrature_rotations = np.array([np.eye(3)]) + self.quadrature_weights = np.array([1.0]) + + self.batch_size = ( + batch_size if batch_size is not None else len(self.quadrature_rotations) + ) + + self.store_rotational_std = store_rotational_std + self.apply_space_group_symmetry = apply_space_group_symmetry + + def calculate( + self, atoms: ase.Atoms, properties: List[str], system_changes: List[str] + ) -> None: + """ + Perform the calculation for the given atoms and properties. + + :param atoms: the :py:class:`ase.Atoms` on which to perform the calculation + :param properties: list of properties to compute, among ``energy``, ``forces``, + and ``stress`` + :param system_changes: list of changes to the system since the last call to + ``calculate`` + """ + super().calculate(atoms, properties, system_changes) + + compute_forces_and_stresses = "forces" in properties or "stress" in properties + + if len(self.quadrature_rotations) > 0: + rotated_atoms_list = _rotate_atoms(atoms, self.quadrature_rotations) + batches = [ + rotated_atoms_list[i : i + self.batch_size] + for i in range(0, len(rotated_atoms_list), self.batch_size) + ] + results: Dict[str, np.ndarray] = {} + for batch in batches: + try: + batch_results = self.base_calculator.compute_energy( + batch, compute_forces_and_stresses + ) + for key, value in batch_results.items(): + results.setdefault(key, []) + results[key].extend( + [value] if isinstance(value, float) else value + ) + except torch.cuda.OutOfMemoryError as e: + raise RuntimeError( + "Out of memory error encountered during rotational averaging. " + "Please reduce the batch size or use lower rotational " + "averaging parameters. This can be done by setting the " + "`batch_size` and `l_max` parameters while initializing the " + "calculator." + ) from e + + self.results.update( + _compute_rotational_average( + results, + self.quadrature_rotations, + self.quadrature_weights, + self.store_rotational_std, + ) + ) + + if self.apply_space_group_symmetry: + # Apply the discrete space group of the system a posteriori + Q_list, P_list = _get_group_operations(atoms) + self.results.update(_average_over_group(self.results, Q_list, P_list)) + + +def _choose_quadrature(L_max: int) -> Tuple[int, int]: + """ + Choose a Lebedev quadrature order and number of in-plane rotations to integrate + spherical harmonics up to degree ``L_max``. + + :param L_max: maximum spherical harmonic degree + :return: (lebedev_order, n_inplane_rotations) + """ + available = [ + 3, + 5, + 7, + 9, + 11, + 13, + 15, + 17, + 19, + 21, + 23, + 25, + 27, + 29, + 31, + 35, + 41, + 47, + 53, + 59, + 65, + 71, + 77, + 83, + 89, + 95, + 101, + 107, + 113, + 119, + 125, + 131, + ] + # pick smallest order >= L_max + n = min(o for o in available if o >= L_max) + # minimal gamma count + K = 2 * L_max + 1 + return n, K + + +def _rotate_atoms(atoms: ase.Atoms, rotations: List[np.ndarray]) -> List[ase.Atoms]: + """ + Create a list of copies of ``atoms``, rotated by each of the given ``rotations``. + + :param atoms: the :py:class:`ase.Atoms` to be rotated + :param rotations: (N, 3, 3) array of orthogonal matrices + :return: list of N :py:class:`ase.Atoms`, each rotated by the corresponding matrix + """ + rotated_atoms_list = [] + has_cell = atoms.cell is not None and atoms.cell.rank > 0 + for rot in rotations: + new_atoms = atoms.copy() + new_atoms.positions = new_atoms.positions @ rot.T + if has_cell: + new_atoms.set_cell( + new_atoms.cell.array @ rot.T, scale_atoms=False, apply_constraint=False + ) + new_atoms.wrap() + rotated_atoms_list.append(new_atoms) + return rotated_atoms_list + + +def _get_quadrature(lebedev_order: int, n_rotations: int, include_inversion: bool): + """ + Lebedev(S^2) x uniform angle quadrature on SO(3). + If include_inversion=True, extend to O(3) by adding inversion * R. + + :param lebedev_order: order of the Lebedev quadrature on the unit sphere + :param n_rotations: number of in-plane rotations per Lebedev node + :param include_inversion: if ``True``, include the inversion operation in the + quadrature + :return: (N, 3, 3) array of orthogonal matrices, and (N,) array of weights + associated to each matrix + """ + from scipy.integrate import lebedev_rule + + # Lebedev nodes (X: (3, M)) + X, w = lebedev_rule(lebedev_order) # w sums to 4*pi + x, y, z = X + alpha = np.arctan2(y, x) # (M,) + beta = np.arccos(z) # (M,) + # beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + + K = int(n_rotations) + gamma = np.linspace(0.0, 2 * np.pi, K, endpoint=False) # (K,) + + Rot = _rotations_from_angles(alpha, beta, gamma) + R_so3 = Rot.as_matrix() # (N, 3, 3) + + # SO(3) Haar–probability weights: w_i/(4*pi*K), repeated over gamma + w_so3 = np.repeat(w / (4 * np.pi * K), repeats=gamma.size) # (N,) + + if not include_inversion: + return R_so3, w_so3 + + # Extend to O(3) by appending inversion * R + P = -np.eye(3) + R_o3 = np.concatenate([R_so3, P @ R_so3], axis=0) # (2N, 3, 3) + w_o3 = np.concatenate([0.5 * w_so3, 0.5 * w_so3], axis=0) + + return R_o3, w_o3 + + +def _rotations_from_angles(alpha, beta, gamma): + from scipy.spatial.transform import Rotation + + # Build all combinations (alpha_i, beta_i, gamma_j) + A = np.repeat(alpha, gamma.size) # (N,) + B = np.repeat(beta, gamma.size) # (N,) + G = np.tile(gamma, alpha.size) # (N,) + + # Compose ZYZ rotations in SO(3) + Rot = ( + Rotation.from_euler("z", A) + * Rotation.from_euler("y", B) + * Rotation.from_euler("z", G) + ) + + return Rot + + +def _compute_rotational_average(results, rotations, weights, store_std): + R = rotations + B = R.shape[0] + w = weights + w = w / w.sum() + + def _wreshape(x): + return w.reshape((B,) + (1,) * (x.ndim - 1)) + + def _wmean(x): + return np.sum(_wreshape(x) * x, axis=0) + + def _wstd(x): + mu = _wmean(x) + return np.sqrt(np.sum(_wreshape(x) * (x - mu) ** 2, axis=0)) + + out = {} + + # Energy (B,) + if "energy" in results: + E = np.asarray(results["energy"], dtype=float) + out["energy"] = _wmean(E) + if store_std: + out["energy_rot_std"] = _wstd(E) + + # Forces (B,N,3) from rotated structures: back-rotate with F' R + if "forces" in results: + F = np.asarray(results["forces"], dtype=float) # (B,N,3) + F_back = F @ R # F' R + out["forces"] = _wmean(F_back) # (N,3) + if store_std: + out["forces_rot_std"] = _wstd(F_back) # (N,3) + + # Stress (B,3,3) from rotated structures: back-rotate with R^T S' R + if "stress" in results: + S = np.asarray(results["stress"], dtype=float) # (B,3,3) + RT = np.swapaxes(R, 1, 2) + S_back = RT @ S @ R # R^T S' R + out["stress"] = _wmean(S_back) # (3,3) + if store_std: + out["stress_rot_std"] = _wstd(S_back) # (3,3) + + return out + + +def _get_group_operations( + atoms: ase.Atoms, symprec: float = 1e-6, angle_tolerance: float = -1.0 +) -> Tuple[List[np.ndarray], List[np.ndarray]]: + """ + Extract point-group rotations Q_g (Cartesian, 3x3) and the corresponding + atom-index permutations P_g (N x N) induced by the space-group operations. + Returns Q_list, Cartesian rotation matrices of the point group, + and P_list, permutation matrices mapping original indexing -> indexing after (R,t), + + :param atoms: input structure + :param symprec: tolerance for symmetry finding + :param angle_tolerance: tolerance for symmetry finding (in degrees). If less than 0, + a value depending on ``symprec`` will be chosen automatically by spglib. + :return: List of rotation matrices and permutation matrices. + + """ + try: + import spglib + except ImportError as e: + raise ImportError( + "spglib is required to use the SymmetrizedCalculator with " + "`apply_group_symmetry=True`. Please install it with " + "`pip install spglib` or `conda install -c conda-forge spglib`" + ) from e + + # Lattice with column vectors a1,a2,a3 (spglib expects (cell, frac, Z)) + A = atoms.cell.array.T # (3,3) + frac = atoms.get_scaled_positions() # (N,3) in [0,1) + numbers = atoms.numbers + N = len(atoms) + + data = spglib.get_symmetry_dataset( + (atoms.cell.array, frac, numbers), + symprec=symprec, + angle_tolerance=angle_tolerance, + ) + + if data is None: + # No symmetry found + return [], [] + R_frac = data.rotations # (n_ops, 3,3), integer + t_frac = data.translations # (n_ops, 3) + Z = numbers + + # Match fractional coords modulo 1 within a tolerance, respecting chemical species + def _match_index(x_new, frac_ref, Z_ref, Z_i, tol=1e-6): + d = np.abs(frac_ref - x_new) # (N,3) + d = np.minimum(d, 1.0 - d) # periodic distance + # Mask by identical species + mask = Z_ref == Z_i + if not np.any(mask): + raise RuntimeError("No matching species found while building permutation.") + # Choose argmin over max-norm within species + idx = np.where(mask)[0] + j = idx[np.argmin(np.max(d[idx], axis=1))] + + # Sanity check + if np.max(d[j]) > tol: + pass + return j + + Q_list, P_list = [], [] + seen = set() + Ainv = np.linalg.inv(A) + + for Rf, tf in zip(R_frac, t_frac, strict=False): + # Cartesian rotation: Q = A Rf A^{-1} + Q = A @ Rf @ Ainv + # Deduplicate rotations (point group) by rounding + key = tuple(np.round(Q.flatten(), 12)) + if key in seen: + continue + seen.add(key) + + # Build the permutation P from i to j + P = np.zeros((N, N), dtype=int) + new_frac = (frac @ Rf.T + tf) % 1.0 # images after (Rf,tf) + for i in range(N): + j = _match_index(new_frac[i], frac, Z, Z[i]) + P[j, i] = 1 # column i maps to row j + + Q_list.append(Q.astype(float)) + P_list.append(P) + + return Q_list, P_list + + +def _average_over_group( + results: dict, Q_list: List[np.ndarray], P_list: List[np.ndarray] +) -> dict: + """ + Apply the point-group projector in output space. + + :param results: Must contain 'energy' (scalar), and/or 'forces' (N,3), and/or + 'stress' (3,3). These are predictions for the current structure in the reference + frame. + :param Q_list: Rotation matrices of the point group, from + :py:func:`_get_group_operations` + :param P_list: Permutation matrices of the point group, from + :py:func:`_get_group_operations` + :return out: Projected quantities. + """ + m = len(Q_list) + if m == 0: + return results # nothing to do + + out = {} + # Energy: unchanged by the projector (scalar) + if "energy" in results: + out["energy"] = float(results["energy"]) + + # Forces: (N,3) row-vectors; projector: (1/|G|) \sum_g P_g^T F Q_g + if "forces" in results: + F = np.asarray(results["forces"], float) + if F.ndim != 2 or F.shape[1] != 3: + raise ValueError(f"'forces' must be (N,3), got {F.shape}") + acc = np.zeros_like(F) + for Q, P in zip(Q_list, P_list, strict=False): + acc += P.T @ (F @ Q) + out["forces"] = acc / m + + # Stress: (3,3); projector: (1/|G|) \sum_g Q_g^T S Q_g + if "stress" in results: + S = np.asarray(results["stress"], float) + if S.shape != (3, 3): + raise ValueError(f"'stress' must be (3,3), got {S.shape}") + # S = 0.5 * (S + S.T) # symmetrize just in case + acc = np.zeros_like(S) + for Q in Q_list: + acc += Q.T @ S @ Q + S_pg = acc / m + out["stress"] = S_pg + + return out diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py new file mode 100644 index 00000000..f3e4884d --- /dev/null +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -0,0 +1,396 @@ +import numpy as np +import pytest +from ase import Atoms +from ase.build import bulk, molecule + +from metatomic.torch.ase_calculator import SymmetrizedCalculator, _get_quadrature + + +def _body_axis_from_atoms(atoms: Atoms) -> np.ndarray: + """ + Return the normalized vector connecting the two farthest atoms. + + :param atoms: Atomic configuration. + :return: Normalized 3D vector defining the body axis. + """ + pos = atoms.get_positions() + if len(pos) < 2: + return np.array([0.0, 0.0, 1.0]) + d2 = np.sum((pos[:, None, :] - pos[None, :, :]) ** 2, axis=-1) + i, j = np.unravel_index(np.argmax(d2), d2.shape) + b = pos[j] - pos[i] + nrm = np.linalg.norm(b) + return b / nrm if nrm > 0 else np.array([0.0, 0.0, 1.0]) + + +def _legendre_0_1_2_3(c: float) -> tuple[float, float, float, float]: + """ + Compute Legendre polynomials P0..P3(c). + + :param c: Cosine between the body axis and the lab z-axis. + :return: Tuple (P0, P1, P2, P3). + """ + P0 = 1.0 + P1 = c + P2 = 0.5 * (3 * c * c - 1.0) + P3 = 0.5 * (5 * c * c * c - 3 * c) + return P0, P1, P2, P3 + + +class MockAnisoCalculator: + """ + Deterministic, rotation-dependent mock for testing SymmetrizedCalculator. + + Components: + - Energy: E_true + a1*P1 + a2*P2 + a3*P3 + - Forces: F_true + (b1*P1 + b2*P2 + b3*P3)*ẑ + optional tensor L=2 term + - Stress: p_iso*I + (c2*P2 + c3*P3)*D + + :param a: Coefficients for Legendre P0..P3 in the energy. + :param b: Coefficients for P1..P3 in the forces (spurious vector parts). + :param c: Coefficients for P2,P3 in the stress (spurious deviators). + :param p_iso: Isotropic (true) part of the stress tensor. + :param tensor_forces: If True, add L=2 tensor-coupled force term. + :param tensor_amp: Amplitude of the tensor-coupled force component. + """ + + def __init__( + self, + a: tuple[float, float, float, float] = (0.0, 0.0, 0.0, 0.0), + b: tuple[float, float, float] = (0.0, 0.0, 0.0), + c: tuple[float, float] = (0.0, 0.0), + p_iso: float = 1.0, + tensor_forces: bool = False, + tensor_amp: float = 0.5, + ) -> None: + self.a0, self.a1, self.a2, self.a3 = a + self.b1, self.b2, self.b3 = b + self.c2, self.c3 = c + self.p_iso = p_iso + self.tensor_forces = tensor_forces + self.tensor_amp = tensor_amp + + def compute_energy( + self, + batch: list[Atoms], + compute_forces_and_stresses: bool = False, + ) -> dict[str, list[np.ndarray | float]]: + """ + Compute deterministic, rotation-dependent properties for each batch entry. + + :param batch: List of atomic configurations. + :param compute_forces_and_stresses: Unused flag for API compatibility. + :return: Dictionary with lists of energies, forces, and stresses. + """ + out: dict[str, list[np.ndarray | float]] = { + "energy": [], + "forces": [], + "stress": [], + } + zhat = np.array([0.0, 0.0, 1.0]) + D = np.diag([1.0, -1.0, 0.0]) + + for atoms in batch: + pos = atoms.get_positions() + b = _body_axis_from_atoms(atoms) + c = float(np.dot(b, zhat)) + P0, P1, P2, P3 = _legendre_0_1_2_3(c) + + # Energy + E_true = float(np.sum(pos**2)) + E = E_true + self.a0 * P0 + self.a1 * P1 + self.a2 * P2 + self.a3 * P3 + + # Forces + F_true = pos.copy() + F_spur = (self.b1 * P1 + self.b2 * P2 + self.b3 * P3) * zhat[None, :] + F = F_true + F_spur + + if self.tensor_forces: + # Build rotation R such that R ẑ = b + v = np.cross(zhat, b) + s = np.linalg.norm(v) + cth = np.dot(zhat, b) + if s < 1e-15: + R = np.eye(3) if cth > 0 else -np.eye(3) + else: + vx = np.array( + [[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]] + ) + R = np.eye(3) + vx + vx @ vx * ((1 - cth) / (s**2)) + T = R @ D @ R.T + F_tensor = self.tensor_amp * (T @ zhat) + F = F + F_tensor[None, :] + + # Stress + S = self.p_iso * np.eye(3) + (self.c2 * P2 + self.c3 * P3) * D + + out["energy"].append(E) + out["forces"].append(F) + out["stress"].append(S) + return out + + +@pytest.fixture +def dimer() -> Atoms: + """ + Create a small asymmetric geometry with a well-defined body axis. + + :return: ASE Atoms object with the H2 molecule. + """ + return Atoms("H2", positions=[[0, 0, 0], [0.3, 0.2, 1.0]]) + + +def test_quadrature_normalization() -> None: + """Verify normalization and determinant signs of the quadrature.""" + R, w = _get_quadrature(lebedev_order=11, n_rotations=5, include_inversion=True) + assert np.isclose(np.sum(w), 1.0) + dets = np.linalg.det(R) + assert np.all(np.isin(np.round(dets).astype(int), [-1, 1])) + + +@pytest.mark.parametrize("Lmax, expect_removed", [(0, False), (3, True)]) +def test_energy_L_components_removed( + dimer: Atoms, Lmax: int, expect_removed: bool +) -> None: + """ + Verify that spurious energy components vanish once rotational averaging is applied. + For Lmax>0, all use the same minimal Lebedev rule (order=3). + """ + a = (1.0, 1.0, 1.0, 1.0) + base = MockAnisoCalculator(a=a) + calc = SymmetrizedCalculator(base, l_max=Lmax) + dimer.calc = calc + e = dimer.get_potential_energy() + E_true = float(np.sum(dimer.positions**2)) + if expect_removed: + assert np.isclose(e, E_true + a[0], atol=1e-10) + else: + assert not np.isclose(e, E_true + a[0], atol=1e-10) + + +def test_force_backrotation_exact(dimer: Atoms) -> None: + """ + Check that forces are back-rotated exactly when no spurious terms are present. + + :param dimer: Test atomic structure. + """ + base = MockAnisoCalculator(b=(0, 0, 0)) + calc = SymmetrizedCalculator(base, l_max=3) + dimer.calc = calc + F = dimer.get_forces() + assert np.allclose(F, dimer.positions, atol=1e-12) + + +def test_tensorial_L2_force_cancellation(dimer: Atoms) -> None: + """ + Tensor-coupled (L=2) force components must vanish under O(3) averaging. + + Since the minimal Lebedev order used internally is 3, all quadratures + integrate L=2 components exactly; we only check for correct cancellation. + """ + base = MockAnisoCalculator(tensor_forces=True, tensor_amp=1.0) + + for Lmax in [1, 2, 3]: + calc = SymmetrizedCalculator(base, l_max=Lmax) + dimer.calc = calc + F = dimer.get_forces() + assert np.allclose(F, dimer.positions, atol=1e-10) + + +def test_stress_isotropization(dimer: Atoms) -> None: + """ + Check that stress deviatoric parts (L=2,3) vanish under full O(3) averaging. + + :param dimer: Test atomic structure. + """ + base = MockAnisoCalculator(c=(1.0, 1.0), p_iso=5.0) + calc = SymmetrizedCalculator(base, l_max=3, include_inversion=True) + dimer.calc = calc + S = dimer.get_stress(voigt=False) + iso = np.trace(S) / 3.0 + assert np.allclose(S, np.eye(3) * iso, atol=1e-10) + assert np.isclose(iso, 5.0, atol=1e-10) + + +def test_cancellation_vs_Lmax(dimer: Atoms) -> None: + """ + Residual anisotropy must vanish once rotational averaging is applied. + All quadratures with Lmax>0 are equivalent (Lebedev order=3). + """ + a = (0.0, 0.0, 1.0, 1.0) + base = MockAnisoCalculator(a=a) + E_true = float(np.sum(dimer.positions**2)) + + # No averaging + calc0 = SymmetrizedCalculator(base, l_max=0) + dimer.calc = calc0 + e0 = dimer.get_potential_energy() + + # Averaged + calc3 = SymmetrizedCalculator(base, l_max=3) + dimer.calc = calc3 + e3 = dimer.get_potential_energy() + + assert not np.isclose(e0, E_true, atol=1e-10) + assert np.isclose(e3, E_true, atol=1e-10) + + +def test_joint_energy_force_consistency(dimer: Atoms) -> None: + """ + Combined test: both energy and forces are consistent and invariant. + + :param dimer: Test atomic structure. + """ + base = MockAnisoCalculator(a=(1, 1, 1, 1), b=(0, 0, 0)) + calc = SymmetrizedCalculator(base, l_max=3) + dimer.calc = calc + e = dimer.get_potential_energy() + f = dimer.get_forces() + assert np.isclose(e, np.sum(dimer.positions**2) + 1.0, atol=1e-10) + assert np.allclose(f, dimer.positions, atol=1e-12) + + +def test_rotate_atoms_preserves_geometry(tmp_path): + """Check that _rotate_atoms applies rotations correctly and preserves distances.""" + from scipy.spatial.transform import Rotation + + from metatomic.torch.ase_calculator import _rotate_atoms + + # Build simple cubic cell with 2 atoms along x + atoms = Atoms("H2", positions=[[0, 0, 0], [1, 0, 0]], cell=np.eye(3)) + R = Rotation.from_euler("z", 90, degrees=True).as_matrix()[None, ...] # 90° about z + + rotated = _rotate_atoms(atoms, R)[0] + # Positions should now align along y + assert np.allclose( + rotated.positions[1] - rotated.positions[0], [0, 1, 0], atol=1e-12 + ) + # Cell rotated + assert np.allclose(rotated.cell[0], [0, 1, 0], atol=1e-12) + # Distances preserved + d0 = atoms.get_distance(0, 1) + d1 = rotated.get_distance(0, 1) + assert np.isclose(d0, d1, atol=1e-12) + + +def test_choose_quadrature_rules(): + """Check that _choose_quadrature selects appropriate rules.""" + from metatomic.torch.ase_calculator import _choose_quadrature + + for L in [0, 5, 17, 50]: + lebedev_order, n_gamma = _choose_quadrature(L) + assert lebedev_order >= L + assert n_gamma == 2 * L + 1 + + +def test_get_quadrature_properties(): + """Check properties of the quadrature returned by _get_quadrature.""" + from metatomic.torch.ase_calculator import _get_quadrature + + R, w = _get_quadrature(lebedev_order=11, n_rotations=5, include_inversion=False) + assert np.isclose(np.sum(w), 1.0) + assert np.allclose([np.dot(r.T, r) for r in R], np.eye(3), atol=1e-12) + assert np.allclose(np.linalg.det(R), 1.0, atol=1e-12) + + R_inv, w_inv = _get_quadrature( + lebedev_order=11, n_rotations=5, include_inversion=True + ) + assert len(R_inv) == 2 * len(R) + dets = np.linalg.det(R_inv) + assert np.all(np.isin(np.sign(dets).astype(int), [-1, 1])) + assert np.isclose(np.sum(w_inv), 1.0) + + +def test_compute_rotational_average_identity(): + """Check that _compute_rotational_average produces correct averages.""" + from metatomic.torch.ase_calculator import _compute_rotational_average + + R = np.repeat(np.eye(3)[None, :, :], 3, axis=0) + w = np.ones(3) / 3 + results = { + "energy": np.array([1.0, 2.0, 3.0]), + "forces": np.array([[[1, 0, 0]], [[0, 1, 0]], [[0, 0, 1]]]), + "stress": np.array([np.eye(3), 2 * np.eye(3), 3 * np.eye(3)]), + } + out = _compute_rotational_average(results, R, w, False) + assert np.isclose(out["energy"], np.mean(results["energy"])) + assert np.allclose(out["forces"], np.mean(results["forces"], axis=0)) + assert np.allclose(out["stress"], np.mean(results["stress"], axis=0)) + + out = _compute_rotational_average(results, R, w, True) + assert "energy_rot_std" in out + assert "forces_rot_std" in out + assert "stress_rot_std" in out + + +def test_average_over_fcc_group(): + """ + Check that averaging over the space group of an FCC crystal + produces an isotropic (scalar) stress tensor. + """ + from metatomic.torch.ase_calculator import ( + _average_over_group, + _get_group_operations, + ) + + # FCC conventional cubic cell (4 atoms) + atoms = bulk("Cu", "fcc", cubic=True) + + energy = 0.0 + forces = np.random.normal(0, 1, (4, 3)) + forces -= np.mean(forces, axis=0) # Ensure zero net force + + # Create an intentionally anisotropic stress + stress = np.array([[10.0, 1.0, 0.0], [1.0, 5.0, 0.0], [0.0, 0.0, 1.0]]) + results = {"energy": energy, "forces": forces, "stress": stress} + + Q_list, P_list = _get_group_operations(atoms) + out = _average_over_group(results, Q_list, P_list) + + # Energy must be unchanged + assert np.isclose(out["energy"], energy) + + # Forces must average to zero by symmetry + F_pg = out["forces"] + assert np.allclose(F_pg, np.zeros_like(F_pg)) + + S_pg = out["stress"] + + # The averaged stress must be isotropic: S_pg = (trace/3)*I + iso = np.trace(S_pg) / 3.0 + assert np.allclose(S_pg, np.eye(3) * iso, atol=1e-8) + + +def test_space_group_average_non_periodic(): + """ + Check that averaging over the space group of a non-periodic system leaves the + results unchanged. + """ + from metatomic.torch.ase_calculator import ( + _average_over_group, + _get_group_operations, + ) + + # Methane molecule (Td symmetry) + atoms = molecule("CH4") + + energy = 0.0 + forces = np.random.normal(0, 1, (4, 3)) + forces -= np.mean(forces, axis=0) # Ensure zero net force + + results = {"energy": energy, "forces": forces} + + Q_list, P_list = _get_group_operations(atoms) + + # Check that the operation lists are empty + assert len(Q_list) == 0 + assert len(P_list) == 0 + + out = _average_over_group(results, Q_list, P_list) + + # Energy must be unchanged + assert np.isclose(out["energy"], energy) + + # Forces must be unchanged + F_pg = out["forces"] + assert np.allclose(F_pg, forces) diff --git a/tox.ini b/tox.ini index 86a3b3e8..795382c7 100644 --- a/tox.ini +++ b/tox.ini @@ -150,6 +150,9 @@ deps = # for metatensor-lj-test setuptools-scm cmake + # for symmetrized calculator + scipy + spglib changedir = python/metatomic_torch commands =