Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion metatomic-torch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@ a changelog](https://keepachangelog.com/en/1.1.0/) format. This project follows

### Removed

Dropped support for deprecated Python 3.9, now requires 3.10 as minimum version
- We dropped support for Python 3.9, and now require at least Python 3.10

### Added

- `ModelOutput` now has a `description` field, to carry more information
about a given output.
- the `pick_output` function that can be used by simulation engines to pick the
correct output based on what's available inside a model and which variant (if
any) the user requested.

## [Version 0.1.5](https://github.com/metatensor/metatomic/releases/tag/metatomic-torch-v0.1.5) - 2025-10-06

Expand Down
11 changes: 10 additions & 1 deletion metatomic-torch/include/metatomic/torch/misc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <vector>
#include <torch/torch.h>

#include "metatomic/torch/model.hpp"
#include "metatomic/torch/system.hpp"

#include <torch/types.h>
Expand All @@ -26,6 +27,14 @@ METATOMIC_TORCH_EXPORT std::string pick_device(
torch::optional<std::string> desired_device = torch::nullopt
);

/// Pick the output matching the given `requested_output` from the model's
/// available `outputs`, taking the optional `desired_variant` into account.
METATOMIC_TORCH_EXPORT std::string pick_output(
std::string requested_output,
torch::Dict<std::string, ModelOutput> outputs,
torch::optional<std::string> desired_variant = torch::nullopt
);

// ===== File-based =====
void save(const std::string& path, const System& system);
System load_system(const std::string& path);
Expand All @@ -46,7 +55,7 @@ inline System load_system_buffer(const torch::Tensor& data) {
throw std::runtime_error("System pickle: expected 1D torch.uint8 buffer");
}
const uint8_t* ptr = t.data_ptr<uint8_t>();
const size_t n = static_cast<size_t>(t.numel());
const auto n = static_cast<size_t>(t.numel());
return load_system_buffer(ptr, n);
}

Expand Down
9 changes: 7 additions & 2 deletions metatomic-torch/include/metatomic/torch/model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ bool valid_quantity(const std::string& quantity);
void validate_unit(const std::string& quantity, const std::string& unit);


/// Description of one of the quantity a model can compute
/// Information about one of the quantities a model can compute
class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder {
public:
ModelOutputHolder() = default;
Expand All @@ -49,8 +49,10 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder
std::string quantity,
std::string unit,
bool per_atom_,
std::vector<std::string> explicit_gradients_
std::vector<std::string> explicit_gradients_,
std::string description_
):
description(std::move(description_)),
per_atom(per_atom_),
explicit_gradients(std::move(explicit_gradients_))
{
Expand All @@ -60,6 +62,9 @@ class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder

~ModelOutputHolder() override = default;

/// description of this output, defaults to an empty string if not set by the user
std::string description;

/// quantity of the output (e.g. energy, dipole, …). If this is an empty
/// string, no unit conversion will be performed.
const std::string& quantity() const {
Expand Down
62 changes: 62 additions & 0 deletions metatomic-torch/src/misc.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#include <torch/torch.h>

#include "metatomic/torch/model.hpp"
#include "metatomic/torch/version.h"
#include "metatomic/torch/misc.hpp"

#include <algorithm>
#include <stdexcept>
#include <string>
#include <vector>
Expand Down Expand Up @@ -75,6 +77,66 @@ std::string pick_device(
return selected_device;
}

std::string pick_output(
    std::string requested_output,
    torch::Dict<std::string, ModelOutput> outputs,
    torch::optional<std::string> desired_variant
) {
    // Collect every key that is either exactly `requested_output` or a
    // variant of it, i.e. `requested_output/<variant>`.
    auto candidates = std::vector<std::string>();
    auto found_default = false;

    const auto prefix_len = requested_output.size();
    for (const auto& entry: outputs) {
        const auto& name = entry.key();

        if (name == requested_output) {
            found_default = true;
            candidates.push_back(name);
        } else if (name.size() > prefix_len
            && name[prefix_len] == '/'
            && name.compare(0, prefix_len, requested_output) == 0
        ) {
            candidates.push_back(name);
        }
    }

    if (candidates.empty()) {
        C10_THROW_ERROR(ValueError,
            "output '" + requested_output + "' not found in outputs"
        );
    }

    // An explicitly requested variant must exist among the candidates.
    if (desired_variant.has_value()) {
        auto wanted = requested_output + "/" + desired_variant.value();
        if (std::find(candidates.begin(), candidates.end(), wanted) != candidates.end()) {
            return wanted;
        }
        C10_THROW_ERROR(ValueError,
            "variant '" + desired_variant.value() + "' for output '" + requested_output +
            "' not found in outputs"
        );
    }

    // No variant requested: fall back to the exact (default) output if the
    // model defines one.
    if (found_default) {
        return requested_output;
    }

    // Only variants exist and none was requested: build an error message
    // listing all variants with their descriptions, aligned on the longest key.
    std::ostringstream message;
    message << "output '" << requested_output << "' has no default variant and no `desired_variant` was given. Available variants are:";

    size_t widest = 0;
    for (const auto& name: candidates) {
        widest = std::max(name.size(), widest);
    }

    for (const auto& name: candidates) {
        std::string padding(widest - name.size(), ' ');
        message << "\n - '" << name << "'" << padding << ": " << outputs.at(name)->description;
    }
    C10_THROW_ERROR(ValueError, message.str());
}


static bool ends_with(const std::string& s, const std::string& suff) {
return s.size() >= suff.size() && s.compare(s.size() - suff.size(), suff.size(), suff) == 0;
Expand Down
36 changes: 28 additions & 8 deletions metatomic-torch/src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ static nlohmann::json model_output_to_json(const ModelOutputHolder& self) {
result["unit"] = self.unit();
result["per_atom"] = self.per_atom;
result["explicit_gradients"] = self.explicit_gradients;
result["description"] = self.description;

return result;
}
Expand All @@ -96,6 +97,7 @@ static ModelOutput model_output_from_json(const nlohmann::json& data) {
}

auto result = torch::make_intrusive<ModelOutputHolder>();

if (data.contains("quantity")) {
if (!data["quantity"].is_string()) {
throw std::runtime_error("'quantity' in JSON for ModelOutput must be a string");
Expand Down Expand Up @@ -125,6 +127,16 @@ static ModelOutput model_output_from_json(const nlohmann::json& data) {
);
}

if (data.contains("description")) {
if (!data["description"].is_string()) {
throw std::runtime_error("'description' in JSON for ModelOutput must be a string");
}
result->description = data["description"];
} else {
// backward compatibility
result->description = "";
}

return result;
}

Expand All @@ -147,12 +159,14 @@ std::unordered_set<std::string> KNOWN_OUTPUTS = {
};

void ModelCapabilitiesHolder::set_outputs(torch::Dict<std::string, ModelOutput> outputs) {
std::unordered_map<std::string, std::unordered_set<std::string>> variants;

std::unordered_map<std::string, std::vector<std::string>> variants;

for (const auto& it: outputs) {
const auto& name = it.key();
if (KNOWN_OUTPUTS.find(name) != KNOWN_OUTPUTS.end()) {
// known output, nothing to do
variants[name].push_back(name);
continue;
}

Expand Down Expand Up @@ -180,7 +194,7 @@ void ModelCapabilitiesHolder::set_outputs(torch::Dict<std::string, ModelOutput>
);
}

variants[base].insert(variant);
variants[base].push_back(name);
continue;
}

Expand All @@ -204,14 +218,20 @@ void ModelCapabilitiesHolder::set_outputs(torch::Dict<std::string, ModelOutput>
);
}

// ensure each variant has a defined default base output
// check descriptions for each variant group
for (const auto& kv : variants) {
const auto& base = kv.first;
if (outputs.find(base) == outputs.end()) {
C10_THROW_ERROR(ValueError,
"Output variants for '" + base + "' were defined (e.g., '" +
base + "/" + *kv.second.begin() + "') but no default '" + base + "' was provided."
);
const auto& all_names = kv.second;

if (all_names.size() > 1) {
for (const auto& name : all_names) {
if (outputs.at(name)->description.empty()) {
TORCH_WARN(
"'", base, "' defines ", all_names.size(), " output variants and '", name, "' has an empty description. ",
"Consider adding meaningful descriptions helping users to distinguish between them."
);
}
}
}
}

Expand Down
13 changes: 11 additions & 2 deletions metatomic-torch/src/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,22 @@ TORCH_LIBRARY(metatomic, m) {

m.class_<ModelOutputHolder>("ModelOutput")
.def(
torch::init<std::string, std::string, bool, std::vector<std::string>>(),
torch::init<
std::string,
std::string,
bool,
std::vector<std::string>,
std::string
>(),
DOCSTRING, {
torch::arg("quantity") = "",
torch::arg("unit") = "",
torch::arg("per_atom") = false,
torch::arg("explicit_gradients") = std::vector<std::string>()
torch::arg("explicit_gradients") = std::vector<std::string>(),
torch::arg("description") = "",
}
)
.def_readwrite("description", &ModelOutputHolder::description)
.def_property("quantity", &ModelOutputHolder::quantity, &ModelOutputHolder::set_quantity)
.def_property("unit", &ModelOutputHolder::unit, &ModelOutputHolder::set_unit)
.def_readwrite("per_atom", &ModelOutputHolder::per_atom)
Expand Down Expand Up @@ -212,6 +220,7 @@ TORCH_LIBRARY(metatomic, m) {
// standalone functions
m.def("version() -> str", version);
m.def("pick_device(str[] model_devices, str? requested_device = None) -> str", pick_device);
m.def("pick_output(str requested_output, Dict(str, __torch__.torch.classes.metatomic.ModelOutput) outputs, str? desired_variant = None) -> str", pick_output);

m.def("read_model_metadata(str path) -> __torch__.torch.classes.metatomic.ModelMetadata", read_model_metadata);
m.def("unit_conversion_factor(str quantity, str from_unit, str to_unit) -> float", unit_conversion_factor);
Expand Down
33 changes: 33 additions & 0 deletions metatomic-torch/tests/misc.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include <c10/util/intrusive_ptr.h>
#include <torch/torch.h>

#include "metatomic/torch.hpp"
Expand Down Expand Up @@ -61,3 +62,35 @@ TEST_CASE("Pick device") {
std::vector<std::string> supported_devices_cpu = {"cpu"};
CHECK_THROWS_WITH(metatomic_torch::pick_device(supported_devices_cpu, "cuda"), StartsWith("failed to find requested device"));
}


// Exercises metatomic_torch::pick_output: exact-name lookup, explicit variant
// selection, both error paths (unknown output, unknown variant), and the
// "no default variant" error listing with its per-variant descriptions.
TEST_CASE("Pick variant") {
// default output, registered under the plain name "energy"
auto output_base = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
output_base->description = "my awesome energy";
output_base->set_quantity("energy");

// variant registered as "energy/A"
auto variantA = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
variantA->set_quantity("energy");
variantA->description = "Variant A of the output";

// variant registered as "energy/foo"
auto variantfoo = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
variantfoo->set_quantity("energy");
variantfoo->description = "Variant foo of the output";

auto outputs = torch::Dict<std::string, metatomic_torch::ModelOutput>();
outputs.insert("energy", output_base);
outputs.insert("energy/A", variantA);
outputs.insert("energy/foo", variantfoo);

// no variant requested => the exact/default output wins
CHECK(metatomic_torch::pick_output("energy", outputs) == "energy");
// explicit variant request resolves to "<output>/<variant>"
CHECK(metatomic_torch::pick_output("energy", outputs, "A") == "energy/A");
// unknown output name and unknown variant both raise ValueError
CHECK_THROWS_WITH(metatomic_torch::pick_output("foo", outputs), StartsWith("output 'foo' not found in outputs"));
CHECK_THROWS_WITH(metatomic_torch::pick_output("energy", outputs, "C"), StartsWith("variant 'C' for output 'energy' not found in outputs"));

// remove the default: only variants remain, so picking without a
// `desired_variant` must fail and list the variants with descriptions,
// padded so the ':' separators line up on the longest key
(void)outputs.erase("energy");
const auto *err = "output 'energy' has no default variant and no `desired_variant` was given. "
"Available variants are:\n"
" - 'energy/A' : Variant A of the output\n"
" - 'energy/foo': Variant foo of the output";
CHECK_THROWS_WITH(metatomic_torch::pick_output("energy", outputs), StartsWith(err));
}
Loading
Loading