diff --git a/metatomic-torch/CHANGELOG.md b/metatomic-torch/CHANGELOG.md index 108a0575..b0ea893f 100644 --- a/metatomic-torch/CHANGELOG.md +++ b/metatomic-torch/CHANGELOG.md @@ -19,7 +19,15 @@ a changelog](https://keepachangelog.com/en/1.1.0/) format. This project follows ### Removed -Dropped support for deprecated Python 3.9, now requires 3.10 as minimum version +- We dropped support for Python 3.9, and now require at least Python 3.10 + +### Added + +- `ModelOutput` now has a `description` field, to carry more information + about a given output. +- the `pick_output` function that can be used by simulation engines to pick the + correct output based on what's available inside a model and which variant (if + any) the user requested. ## [Version 0.1.5](https://github.com/metatensor/metatomic/releases/tag/metatomic-torch-v0.1.5) - 2025-10-06 diff --git a/metatomic-torch/include/metatomic/torch/misc.hpp b/metatomic-torch/include/metatomic/torch/misc.hpp index cd889823..ebeff887 100644 --- a/metatomic-torch/include/metatomic/torch/misc.hpp +++ b/metatomic-torch/include/metatomic/torch/misc.hpp @@ -7,6 +7,7 @@ #include #include +#include "metatomic/torch/model.hpp" #include "metatomic/torch/system.hpp" #include @@ -26,6 +27,14 @@ METATOMIC_TORCH_EXPORT std::string pick_device( torch::optional desired_device = torch::nullopt ); +/// Pick the output for the given `requested_output` from the availabilities of the +/// model's `outputs`, according to the optional `desired_variant`. 
+METATOMIC_TORCH_EXPORT std::string pick_output( + std::string requested_output, + torch::Dict outputs, + torch::optional desired_variant = torch::nullopt +); + // ===== File-based ===== void save(const std::string& path, const System& system); System load_system(const std::string& path); @@ -46,7 +55,7 @@ inline System load_system_buffer(const torch::Tensor& data) { throw std::runtime_error("System pickle: expected 1D torch.uint8 buffer"); } const uint8_t* ptr = t.data_ptr(); - const size_t n = static_cast(t.numel()); + const auto n = static_cast(t.numel()); return load_system_buffer(ptr, n); } diff --git a/metatomic-torch/include/metatomic/torch/model.hpp b/metatomic-torch/include/metatomic/torch/model.hpp index 5896ad8b..f2f9bdeb 100644 --- a/metatomic-torch/include/metatomic/torch/model.hpp +++ b/metatomic-torch/include/metatomic/torch/model.hpp @@ -39,7 +39,7 @@ bool valid_quantity(const std::string& quantity); void validate_unit(const std::string& quantity, const std::string& unit); -/// Description of one of the quantity a model can compute +/// Information about one of the quantities a model can compute class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder { public: ModelOutputHolder() = default; @@ -49,8 +49,10 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder std::string quantity, std::string unit, bool per_atom_, - std::vector explicit_gradients_ + std::vector explicit_gradients_, + std::string description_ ): + description(std::move(description_)), per_atom(per_atom_), explicit_gradients(std::move(explicit_gradients_)) { @@ -60,6 +62,9 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder ~ModelOutputHolder() override = default; + /// description of this output, defaults to an empty string if not set by the user + std::string description; + /// quantity of the output (e.g. energy, dipole, …). If this is an empty /// string, no unit conversion will be performed. 
const std::string& quantity() const { diff --git a/metatomic-torch/src/misc.cpp b/metatomic-torch/src/misc.cpp index 006fc78a..8c3b3e79 100644 --- a/metatomic-torch/src/misc.cpp +++ b/metatomic-torch/src/misc.cpp @@ -1,8 +1,10 @@ #include +#include "metatomic/torch/model.hpp" #include "metatomic/torch/version.h" #include "metatomic/torch/misc.hpp" +#include #include #include #include @@ -75,6 +77,66 @@ std::string pick_device( return selected_device; } +std::string pick_output( + std::string requested_output, + torch::Dict outputs, + torch::optional desired_variant +) { + std::vector matching_keys; + bool has_exact = false; + + for (const auto& output: outputs) { + const auto& key = output.key(); + + // match either exact `requested_output` or `requested_output/` + if (key == requested_output + || (key.size() > requested_output.size() + && key.compare(0, requested_output.size(), requested_output) == 0 + && key[requested_output.size()] == '/')) { + matching_keys.emplace_back(key); + + if (key == requested_output) { + has_exact = true; + } + } + } + + if (matching_keys.empty()) { + C10_THROW_ERROR(ValueError, + "output '" + requested_output + "' not found in outputs" + ); + } + + if (desired_variant != torch::nullopt) { + const auto& output = requested_output + "/" + desired_variant.value(); + auto it = std::find(matching_keys.begin(), matching_keys.end(), output); + if (it != matching_keys.end()) { + return *it; + } + C10_THROW_ERROR(ValueError, + "variant '" + desired_variant.value() + "' for output '" + requested_output + + "' not found in outputs" + ); + } else if (has_exact) { + return requested_output; + } else { + std::ostringstream oss; + oss << "output '" << requested_output << "' has no default variant and no `desired_variant` was given. 
Available variants are:"; + + size_t maxlen = 0; + for (const auto& key: matching_keys) { + maxlen = std::max(key.size(), maxlen); + } + + for (const auto& key: matching_keys) { + auto description = outputs.at(key)->description; + std::string padding(maxlen - key.size(), ' '); + oss << "\n - '" << key << "'" << padding << ": " << description; + } + C10_THROW_ERROR(ValueError, oss.str()); + } +} + static bool ends_with(const std::string& s, const std::string& suff) { return s.size() >= suff.size() && s.compare(s.size() - suff.size(), suff.size(), suff) == 0; diff --git a/metatomic-torch/src/model.cpp b/metatomic-torch/src/model.cpp index 8be6ca35..e9452c2c 100644 --- a/metatomic-torch/src/model.cpp +++ b/metatomic-torch/src/model.cpp @@ -74,6 +74,7 @@ static nlohmann::json model_output_to_json(const ModelOutputHolder& self) { result["unit"] = self.unit(); result["per_atom"] = self.per_atom; result["explicit_gradients"] = self.explicit_gradients; + result["description"] = self.description; return result; } @@ -96,6 +97,7 @@ static ModelOutput model_output_from_json(const nlohmann::json& data) { } auto result = torch::make_intrusive(); + if (data.contains("quantity")) { if (!data["quantity"].is_string()) { throw std::runtime_error("'quantity' in JSON for ModelOutput must be a string"); @@ -125,6 +127,16 @@ static ModelOutput model_output_from_json(const nlohmann::json& data) { ); } + if (data.contains("description")) { + if (!data["description"].is_string()) { + throw std::runtime_error("'description' in JSON for ModelOutput must be a string"); + } + result->description = data["description"]; + } else { + // backward compatibility + result->description = ""; + } + return result; } @@ -147,12 +159,14 @@ std::unordered_set KNOWN_OUTPUTS = { }; void ModelCapabilitiesHolder::set_outputs(torch::Dict outputs) { - std::unordered_map> variants; + + std::unordered_map> variants; for (const auto& it: outputs) { const auto& name = it.key(); if (KNOWN_OUTPUTS.find(name) != 
KNOWN_OUTPUTS.end()) { // known output, nothing to do + variants[name].push_back(name); continue; } @@ -180,7 +194,7 @@ void ModelCapabilitiesHolder::set_outputs(torch::Dict ); } - variants[base].insert(variant); + variants[base].push_back(name); continue; } @@ -204,14 +218,20 @@ void ModelCapabilitiesHolder::set_outputs(torch::Dict ); } - // ensure each variant has a defined default base output + // check descriptions for each variant group for (const auto& kv : variants) { const auto& base = kv.first; - if (outputs.find(base) == outputs.end()) { - C10_THROW_ERROR(ValueError, - "Output variants for '" + base + "' were defined (e.g., '" + - base + "/" + *kv.second.begin() + "') but no default '" + base + "' was provided." - ); + const auto& all_names = kv.second; + + if (all_names.size() > 1) { + for (const auto& name : all_names) { + if (outputs.at(name)->description.empty()) { + TORCH_WARN( + "'", base, "' defines ", all_names.size(), " output variants and '", name, "' has an empty description. ", + "Consider adding meaningful descriptions helping users to distinguish between them." 
+ ); + } + } } } diff --git a/metatomic-torch/src/register.cpp b/metatomic-torch/src/register.cpp index 3dbf5234..90916250 100644 --- a/metatomic-torch/src/register.cpp +++ b/metatomic-torch/src/register.cpp @@ -123,14 +123,22 @@ TORCH_LIBRARY(metatomic, m) { m.class_("ModelOutput") .def( - torch::init>(), + torch::init< + std::string, + std::string, + bool, + std::vector, + std::string + >(), DOCSTRING, { torch::arg("quantity") = "", torch::arg("unit") = "", torch::arg("per_atom") = false, - torch::arg("explicit_gradients") = std::vector() + torch::arg("explicit_gradients") = std::vector(), + torch::arg("description") = "", } ) + .def_readwrite("description", &ModelOutputHolder::description) .def_property("quantity", &ModelOutputHolder::quantity, &ModelOutputHolder::set_quantity) .def_property("unit", &ModelOutputHolder::unit, &ModelOutputHolder::set_unit) .def_readwrite("per_atom", &ModelOutputHolder::per_atom) @@ -212,6 +220,7 @@ TORCH_LIBRARY(metatomic, m) { // standalone functions m.def("version() -> str", version); m.def("pick_device(str[] model_devices, str? requested_device = None) -> str", pick_device); + m.def("pick_output(str requested_output, Dict(str, __torch__.torch.classes.metatomic.ModelOutput) outputs, str? 
desired_variant = None) -> str", pick_output); m.def("read_model_metadata(str path) -> __torch__.torch.classes.metatomic.ModelMetadata", read_model_metadata); m.def("unit_conversion_factor(str quantity, str from_unit, str to_unit) -> float", unit_conversion_factor); diff --git a/metatomic-torch/tests/misc.cpp b/metatomic-torch/tests/misc.cpp index 6301db7d..4725bb63 100644 --- a/metatomic-torch/tests/misc.cpp +++ b/metatomic-torch/tests/misc.cpp @@ -1,3 +1,4 @@ +#include #include #include "metatomic/torch.hpp" @@ -61,3 +62,35 @@ TEST_CASE("Pick device") { std::vector supported_devices_cpu = {"cpu"}; CHECK_THROWS_WITH(metatomic_torch::pick_device(supported_devices_cpu, "cuda"), StartsWith("failed to find requested device")); } + + +TEST_CASE("Pick variant") { + auto output_base = torch::make_intrusive(); + output_base->description = "my awesome energy"; + output_base->set_quantity("energy"); + + auto variantA = torch::make_intrusive(); + variantA->set_quantity("energy"); + variantA->description = "Variant A of the output"; + + auto variantfoo = torch::make_intrusive(); + variantfoo->set_quantity("energy"); + variantfoo->description = "Variant foo of the output"; + + auto outputs = torch::Dict(); + outputs.insert("energy", output_base); + outputs.insert("energy/A", variantA); + outputs.insert("energy/foo", variantfoo); + + CHECK(metatomic_torch::pick_output("energy", outputs) == "energy"); + CHECK(metatomic_torch::pick_output("energy", outputs, "A") == "energy/A"); + CHECK_THROWS_WITH(metatomic_torch::pick_output("foo", outputs), StartsWith("output 'foo' not found in outputs")); + CHECK_THROWS_WITH(metatomic_torch::pick_output("energy", outputs, "C"), StartsWith("variant 'C' for output 'energy' not found in outputs")); + + (void)outputs.erase("energy"); + const auto *err = "output 'energy' has no default variant and no `desired_variant` was given. 
" + "Available variants are:\n" + " - 'energy/A' : Variant A of the output\n" + " - 'energy/foo': Variant foo of the output"; + CHECK_THROWS_WITH(metatomic_torch::pick_output("energy", outputs), StartsWith(err)); +} diff --git a/metatomic-torch/tests/models.cpp b/metatomic-torch/tests/models.cpp index b51a9328..7a22ab1f 100644 --- a/metatomic-torch/tests/models.cpp +++ b/metatomic-torch/tests/models.cpp @@ -62,6 +62,7 @@ TEST_CASE("Models metadata") { SECTION("ModelOutput") { // save to JSON auto output = torch::make_intrusive(); + output->description = "my awesome energy"; output->set_quantity("energy"); output->set_unit("kJ / mol"); output->per_atom = false; @@ -69,6 +70,7 @@ TEST_CASE("Models metadata") { const auto* expected = R"({ "class": "ModelOutput", + "description": "my awesome energy", "explicit_gradients": [ "baz", "not.this-one_" @@ -104,8 +106,6 @@ TEST_CASE("Models metadata") { StartsWith("unknown unit 'unknown' for length") ); - #if TORCH_VERSION_MAJOR >= 2 && TORCH_VERSION_MINOR >= 0 - struct WarningHandler: public torch::WarningHandler { virtual ~WarningHandler() override = default; void process(const torch::Warning& warning) override { @@ -120,7 +120,6 @@ TEST_CASE("Models metadata") { output->set_quantity("unknown"), torch::WarningUtils::set_warning_handler(old_handler); - #endif } SECTION("ModelEvaluationOptions") { @@ -142,6 +141,7 @@ TEST_CASE("Models metadata") { "outputs": { "output_1": { "class": "ModelOutput", + "description": "", "explicit_gradients": [], "per_atom": false, "quantity": "", @@ -149,6 +149,7 @@ TEST_CASE("Models metadata") { }, "output_2": { "class": "ModelOutput", + "description": "", "explicit_gradients": [], "per_atom": true, "quantity": "something", @@ -235,6 +236,7 @@ TEST_CASE("Models metadata") { "outputs": { "tests::bar": { "class": "ModelOutput", + "description": "", "explicit_gradients": [ "\u00b5-\u03bb" ], @@ -295,6 +297,7 @@ TEST_CASE("Models metadata") { auto capabilities_variants = torch::make_intrusive(); 
auto output_variant = torch::make_intrusive(); output_variant->per_atom = true; + output_variant->description = "variant output"; auto outputs_variant = torch::Dict(); outputs_variant.insert("energy", output_variant); @@ -308,16 +311,6 @@ TEST_CASE("Models metadata") { CHECK(stored.find("energy") != stored.end()); CHECK(stored.find("energy/PBE0") != stored.end()); - auto capabilities_no_default = torch::make_intrusive(); - auto output_no_default = torch::make_intrusive(); - auto outputs_no_default = torch::Dict(); - outputs_no_default.insert("energy/PBE0", output_no_default); // missing "energy" - - CHECK_THROWS_WITH( - capabilities_no_default->set_outputs(outputs_no_default), - Contains("no default 'energy' was provided") - ); - auto capabilities_non_standard = torch::make_intrusive(); auto output_non_standard = torch::make_intrusive(); auto outputs_non_standard = torch::Dict(); @@ -382,6 +375,25 @@ TEST_CASE("Models metadata") { capabilities_non_standard->set_outputs(outputs_non_standard), Contains("Invalid name for model output") ); + + // check for variant description warning + struct WarningHandler: public torch::WarningHandler { + virtual ~WarningHandler() override = default; + void process(const torch::Warning& warning) override { + CHECK(warning.msg() == "'energy' defines 3 output variants and 'energy/foo' has an empty description. 
" + "Consider adding meaningful descriptions helping users to distinguish between them."); + } + }; + + auto* old_handler = torch::WarningUtils::get_warning_handler(); + auto check_expected_warning = WarningHandler(); + torch::WarningUtils::set_warning_handler(&check_expected_warning); + + auto output_variant_no_desc = torch::make_intrusive(); + outputs_variant.insert("energy/foo", output_variant_no_desc); + capabilities_variants->set_outputs(outputs_variant); + + torch::WarningUtils::set_warning_handler(old_handler); } SECTION("ModelMetadata") { diff --git a/python/metatomic_torch/metatomic/torch/__init__.py b/python/metatomic_torch/metatomic/torch/__init__.py index de2116ca..ce1f551b 100644 --- a/python/metatomic_torch/metatomic/torch/__init__.py +++ b/python/metatomic_torch/metatomic/torch/__init__.py @@ -17,6 +17,7 @@ check_atomistic_model, load_model_extensions, pick_device, + pick_output, read_model_metadata, register_autograd_neighbors, unit_conversion_factor, @@ -40,6 +41,7 @@ register_autograd_neighbors = torch.ops.metatomic.register_autograd_neighbors unit_conversion_factor = torch.ops.metatomic.unit_conversion_factor pick_device = torch.ops.metatomic.pick_device + pick_output = torch.ops.metatomic.pick_output from .model import ( # noqa: F401 AtomisticModel, diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index b5b2e4dd..346e8929 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -19,6 +19,7 @@ System, load_atomistic_model, pick_device, + pick_output, register_autograd_neighbors, ) @@ -166,50 +167,59 @@ def __init__( f"found unexpected dtype in model capabilities: {capabilities.dtype}" ) - self._energy_key = "energy" - self._energy_uq_key = "energy_uncertainty" - self._nc_forces_key = "non_conservative_forces" - self._nc_stress_key = "non_conservative_stress" - - if variants: - if "energy" 
in variants: - self._energy_key += f"/{variants['energy']}" - self._energy_uq_key += f"/{variants['energy']}" - self._nc_forces_key += f"/{variants['energy']}" - self._nc_stress_key += f"/{variants['energy']}" - - if "energy_uncertainty" in variants: - if variants["energy_uncertainty"] is None: - self._energy_uq_key = "energy_uncertainty" - else: - self._energy_uq_key += f"/{variants['energy_uncertainty']}" - - if non_conservative: - if ( - "non_conservative_stress" in variants - and "non_conservative_forces" in variants - and ( - (variants["non_conservative_stress"] is None) - != (variants["non_conservative_forces"] is None) - ) - ): - raise ValueError( - "if both 'non_conservative_stress' and " - "'non_conservative_forces' are present in `variants`, they " - "must either be both `None` or both not `None`." - ) + # resolve the output keys to use based on the requested variants + variants = variants or {} + default_variant = variants.get("energy") + + resolved_variants = { + key: variants.get(key, default_variant) + for key in [ + "energy", + "energy_uncertainty", + "non_conservative_forces", + "non_conservative_stress", + ] + } + + outputs = capabilities.outputs + self._energy_key = pick_output("energy", outputs, resolved_variants["energy"]) - if "non_conservative_forces" in variants: - if variants["non_conservative_forces"] is None: - self._nc_forces_key = "non_conservative_forces" - else: - self._nc_forces_key += f"/{variants['non_conservative_forces']}" - - if "non_conservative_stress" in variants: - if variants["non_conservative_stress"] is None: - self._nc_stress_key = "non_conservative_stress" - else: - self._nc_stress_key += f"/{variants['non_conservative_stress']}" + has_energy_uq = any("energy_uncertainty" in key for key in outputs.keys()) + if has_energy_uq and uncertainty_threshold is not None: + self._energy_uq_key = pick_output( + "energy_uncertainty", outputs, resolved_variants["energy_uncertainty"] + ) + else: + self._energy_uq_key = 
"energy_uncertainty" + + if non_conservative: + if ( + "non_conservative_stress" in variants + and "non_conservative_forces" in variants + and ( + (variants["non_conservative_stress"] is None) + != (variants["non_conservative_forces"] is None) + ) + ): + raise ValueError( + "if both 'non_conservative_stress' and " + "'non_conservative_forces' are present in `variants`, they " + "must either be both `None` or both not `None`." + ) + + self._nc_forces_key = pick_output( + "non_conservative_forces", + outputs, + resolved_variants["non_conservative_forces"], + ) + self._nc_stress_key = pick_output( + "non_conservative_stress", + outputs, + resolved_variants["non_conservative_stress"], + ) + else: + self._nc_forces_key = "non_conservative_forces" + self._nc_stress_key = "non_conservative_stress" if additional_outputs is None: self._additional_output_requests = {} diff --git a/python/metatomic_torch/metatomic/torch/documentation.py b/python/metatomic_torch/metatomic/torch/documentation.py index 6c99b750..83360c12 100644 --- a/python/metatomic_torch/metatomic/torch/documentation.py +++ b/python/metatomic_torch/metatomic/torch/documentation.py @@ -247,7 +247,7 @@ def __ne__(self, other: "NeighborListOptions") -> bool: class ModelOutput: - """Description of one of the quantity a model can compute.""" + """Information about one of the quantities a model can compute.""" def __init__( self, @@ -255,6 +255,7 @@ def __init__( unit: str = "", per_atom: bool = False, explicit_gradients: List[str] = [], # noqa B006 + description: str = "", ): pass @@ -286,6 +287,13 @@ def unit(self) -> str: :py:class:`TensorMap`. """ + @property + def description(self) -> str: + """ + A description of this output. Especially recommended for non-standard outputs + and variants of the same output. + """ + class ModelCapabilities: """Description of a model capabilities, i.e. 
everything a model can do.""" @@ -536,3 +544,18 @@ def pick_device(model_devices: List[str], desired_device: Optional[str]) -> str: :param desired_device: user-provided desired device. If ``None`` or not available, the first available device from ``model_devices`` will be picked. """ + + +def pick_output( + requested_output: str, + outputs: Dict[str, ModelOutput], + desired_variant: Optional[str] = None, +) -> str: + """ + Pick the output for the given ``requested_output`` from the availabilities of the + model's ``outputs``, according to the optional ``desired_variant``. + + :param requested_output: name of the output to pick a variant for + :param outputs: all available outputs from the model + :param desired_variant: if provided, try to pick this specific variant + """ diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index 31b9e665..384669c3 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -579,18 +579,23 @@ def test_additional_outputs(atoms): ) model = AtomisticModel(MultipleOutputModel().eval(), ModelMetadata(), capabilities) - atoms.calc = MetatomicCalculator(model, check_consistency=True) + atoms.calc = MetatomicCalculator( + model, + check_consistency=True, + uncertainty_threshold=None, + ) assert atoms.get_potential_energy() == 0.0 assert atoms.calc.additional_outputs == {} atoms.calc = MetatomicCalculator( model, - check_consistency=True, additional_outputs={ "test::test": ModelOutput(per_atom=False), "another::one": ModelOutput(per_atom=False), }, + check_consistency=True, + uncertainty_threshold=None, ) assert atoms.get_potential_energy() == 0.0