Skip to content

Commit 7ca4e3d

Browse files
committed
Add pick_output function
1 parent c3e61d4 commit 7ca4e3d

File tree

9 files changed

+199
-63
lines changed

9 files changed

+199
-63
lines changed

metatomic-torch/include/metatomic/torch/misc.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <vector>
88
#include <torch/torch.h>
99

10+
#include "metatomic/torch/model.hpp"
1011
#include "metatomic/torch/system.hpp"
1112

1213
#include <torch/types.h>
@@ -26,6 +27,14 @@ METATOMIC_TORCH_EXPORT std::string pick_device(
2627
torch::optional<std::string> desired_device = torch::nullopt
2728
);
2829

30+
/// Pick the output for the given ``requested_output`` from the availabilities of the
31+
/// model's ``outputs``, according to the optional ``desired_variant``.
32+
METATOMIC_TORCH_EXPORT std::string pick_output(
33+
std::string requested_output,
34+
torch::Dict<std::string, ModelOutput> outputs,
35+
torch::optional<std::string> desired_variant = torch::nullopt
36+
);
37+
2938
// ===== File-based =====
3039
void save(const std::string& path, const System& system);
3140
System load_system(const std::string& path);
@@ -46,7 +55,7 @@ inline System load_system_buffer(const torch::Tensor& data) {
4655
throw std::runtime_error("System pickle: expected 1D torch.uint8 buffer");
4756
}
4857
const uint8_t* ptr = t.data_ptr<uint8_t>();
49-
const size_t n = static_cast<size_t>(t.numel());
58+
const auto n = static_cast<size_t>(t.numel());
5059
return load_system_buffer(ptr, n);
5160
}
5261

metatomic-torch/include/metatomic/torch/model.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,19 @@ bool valid_quantity(const std::string& quantity);
3939
void validate_unit(const std::string& quantity, const std::string& unit);
4040

4141

42-
/// Metadata of one of the quantity a model can compute
42+
/// Information about one of the quantity a model can compute
4343
class METATOMIC_TORCH_EXPORT ModelOutputHolder: public torch::CustomClassHolder {
4444
public:
4545
ModelOutputHolder() = default;
4646

4747
/// Initialize `ModelOutput` with the given data
4848
ModelOutputHolder(
49-
std::string description_,
5049
std::string quantity,
5150
std::string unit,
5251
bool per_atom_,
53-
std::vector<std::string> explicit_gradients_
54-
):
52+
std::vector<std::string> explicit_gradients_,
53+
std::string description_
54+
):
5555
description(std::move(description_)),
5656
per_atom(per_atom_),
5757
explicit_gradients(std::move(explicit_gradients_))

metatomic-torch/src/misc.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
#include <torch/torch.h>
22

3+
#include "metatomic/torch/model.hpp"
34
#include "metatomic/torch/version.h"
45
#include "metatomic/torch/misc.hpp"
56

7+
#include <algorithm>
68
#include <stdexcept>
79
#include <string>
810
#include <vector>
@@ -75,6 +77,66 @@ std::string pick_device(
7577
return selected_device;
7678
}
7779

80+
std::string pick_output(
81+
std::string requested_output,
82+
torch::Dict<std::string, ModelOutput> outputs,
83+
torch::optional<std::string> desired_variant
84+
) {
85+
std::vector<std::string> matching_keys;
86+
bool has_exact = false;
87+
88+
for (const auto& output: outputs) {
89+
const auto& key = output.key();
90+
91+
// match either exact name or "name/variant"
92+
if (key == requested_output
93+
|| (key.size() > requested_output.size()
94+
&& key.compare(0, requested_output.size(), requested_output) == 0
95+
&& key[requested_output.size()] == '/')) {
96+
matching_keys.emplace_back(key);
97+
98+
if (key == requested_output) {
99+
has_exact = true;
100+
}
101+
}
102+
}
103+
104+
if (matching_keys.empty()) {
105+
C10_THROW_ERROR(ValueError,
106+
"output '" + requested_output + "' not found in outputs"
107+
);
108+
}
109+
110+
if (desired_variant != torch::nullopt) {
111+
const auto& output = requested_output + "/" + desired_variant.value();
112+
auto it = std::find(matching_keys.begin(), matching_keys.end(), output);
113+
if (it != matching_keys.end()) {
114+
return *it;
115+
}
116+
C10_THROW_ERROR(ValueError,
117+
"variant '" + desired_variant.value() + "' for output '" + requested_output +
118+
"' not found in outputs"
119+
);
120+
} else if (has_exact) {
121+
return requested_output;
122+
} else {
123+
std::ostringstream oss;
124+
oss << "output '" << requested_output << "' has no default variant and no `desired_variant` was given. Available variants are:";
125+
126+
size_t maxlen = 0;
127+
for (const auto& key: matching_keys) {
128+
maxlen = std::max(key.size(), maxlen);
129+
}
130+
131+
for (const auto& key: matching_keys) {
132+
auto description = outputs.at(key)->description;
133+
std::string padding(maxlen - key.size(), ' ');
134+
oss << "\n - '" << key << "'" << padding << ": " << description;
135+
}
136+
C10_THROW_ERROR(ValueError, oss.str());
137+
}
138+
}
139+
78140

79141
static bool ends_with(const std::string& s, const std::string& suff) {
80142
return s.size() >= suff.size() && s.compare(s.size() - suff.size(), suff.size(), suff) == 0;

metatomic-torch/src/register.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,18 +124,18 @@ TORCH_LIBRARY(metatomic, m) {
124124
m.class_<ModelOutputHolder>("ModelOutput")
125125
.def(
126126
torch::init<
127-
std::string,
128127
std::string,
129128
std::string,
130129
bool,
131-
std::vector<std::string>
130+
std::vector<std::string>,
131+
std::string
132132
>(),
133133
DOCSTRING, {
134-
torch::arg("description") = "",
135134
torch::arg("quantity") = "",
136135
torch::arg("unit") = "",
137136
torch::arg("per_atom") = false,
138-
torch::arg("explicit_gradients") = std::vector<std::string>()
137+
torch::arg("explicit_gradients") = std::vector<std::string>(),
138+
torch::arg("description") = "",
139139
}
140140
)
141141
.def_readwrite("description", &ModelOutputHolder::description)
@@ -220,6 +220,7 @@ TORCH_LIBRARY(metatomic, m) {
220220
// standalone functions
221221
m.def("version() -> str", version);
222222
m.def("pick_device(str[] model_devices, str? requested_device = None) -> str", pick_device);
223+
m.def("pick_output(str requested_output, Dict(str, __torch__.torch.classes.metatomic.ModelOutput) outputs, str? desired_variant = None) -> str", pick_output);
223224

224225
m.def("read_model_metadata(str path) -> __torch__.torch.classes.metatomic.ModelMetadata", read_model_metadata);
225226
m.def("unit_conversion_factor(str quantity, str from_unit, str to_unit) -> float", unit_conversion_factor);

metatomic-torch/tests/misc.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <c10/util/intrusive_ptr.h>
12
#include <torch/torch.h>
23

34
#include "metatomic/torch.hpp"
@@ -61,3 +62,35 @@ TEST_CASE("Pick device") {
6162
std::vector<std::string> supported_devices_cpu = {"cpu"};
6263
CHECK_THROWS_WITH(metatomic_torch::pick_device(supported_devices_cpu, "cuda"), StartsWith("failed to find requested device"));
6364
}
65+
66+
67+
TEST_CASE("Pick variant") {
68+
auto output_base = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
69+
output_base->description = "my awesome energy";
70+
output_base->set_quantity("energy");
71+
72+
auto variantA = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
73+
variantA->set_quantity("energy");
74+
variantA->description = "Variant A of the output";
75+
76+
auto variantfoo = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
77+
variantfoo->set_quantity("energy");
78+
variantfoo->description = "Variant foo of the output";
79+
80+
auto outputs = torch::Dict<std::string, metatomic_torch::ModelOutput>();
81+
outputs.insert("energy", output_base);
82+
outputs.insert("energy/A", variantA);
83+
outputs.insert("energy/foo", variantfoo);
84+
85+
CHECK(metatomic_torch::pick_output("energy", outputs) == "energy");
86+
CHECK(metatomic_torch::pick_output("energy", outputs, "A") == "energy/A");
87+
CHECK_THROWS_WITH(metatomic_torch::pick_output("foo", outputs), StartsWith("output 'foo' not found in outputs"));
88+
CHECK_THROWS_WITH(metatomic_torch::pick_output("energy", outputs, "C"), StartsWith("variant 'C' for output 'energy' not found in outputs"));
89+
90+
(void)outputs.erase("energy");
91+
const auto *err = "output 'energy' has no default variant and no `desired_variant` was given. "
92+
"Available variants are:\n"
93+
" - 'energy/A' : Variant A of the output\n"
94+
" - 'energy/foo': Variant foo of the output";
95+
CHECK_THROWS_WITH(metatomic_torch::pick_output("energy", outputs), StartsWith(err));
96+
}

python/metatomic_torch/metatomic/torch/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
check_atomistic_model,
1818
load_model_extensions,
1919
pick_device,
20+
pick_output,
2021
read_model_metadata,
2122
register_autograd_neighbors,
2223
unit_conversion_factor,
@@ -40,6 +41,7 @@
4041
register_autograd_neighbors = torch.ops.metatomic.register_autograd_neighbors
4142
unit_conversion_factor = torch.ops.metatomic.unit_conversion_factor
4243
pick_device = torch.ops.metatomic.pick_device
44+
pick_output = torch.ops.metatomic.pick_output
4345

4446
from .model import ( # noqa: F401
4547
AtomisticModel,

python/metatomic_torch/metatomic/torch/ase_calculator.py

Lines changed: 52 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
System,
2020
load_atomistic_model,
2121
pick_device,
22+
pick_output,
2223
register_autograd_neighbors,
2324
)
2425

@@ -166,50 +167,58 @@ def __init__(
166167
f"found unexpected dtype in model capabilities: {capabilities.dtype}"
167168
)
168169

169-
self._energy_key = "energy"
170-
self._energy_uq_key = "energy_uncertainty"
171-
self._nc_forces_key = "non_conservative_forces"
172-
self._nc_stress_key = "non_conservative_stress"
173-
174-
if variants:
175-
if "energy" in variants:
176-
self._energy_key += f"/{variants['energy']}"
177-
self._energy_uq_key += f"/{variants['energy']}"
178-
self._nc_forces_key += f"/{variants['energy']}"
179-
self._nc_stress_key += f"/{variants['energy']}"
180-
181-
if "energy_uncertainty" in variants:
182-
if variants["energy_uncertainty"] is None:
183-
self._energy_uq_key = "energy_uncertainty"
184-
else:
185-
self._energy_uq_key += f"/{variants['energy_uncertainty']}"
186-
187-
if non_conservative:
188-
if (
189-
"non_conservative_stress" in variants
190-
and "non_conservative_forces" in variants
191-
and (
192-
(variants["non_conservative_stress"] is None)
193-
!= (variants["non_conservative_forces"] is None)
194-
)
195-
):
196-
raise ValueError(
197-
"if both 'non_conservative_stress' and "
198-
"'non_conservative_forces' are present in `variants`, they "
199-
"must either be both `None` or both not `None`."
200-
)
170+
# resolve the output keys to use based on the requested variants
171+
variants = variants or {}
172+
default_variant = variants.get("energy")
173+
174+
resolved_variants = {
175+
key: variants.get(key, default_variant)
176+
for key in [
177+
"energy",
178+
"energy_uncertainty",
179+
"non_conservative_forces",
180+
"non_conservative_stress",
181+
]
182+
}
183+
184+
outputs = capabilities.outputs
185+
self._energy_key = pick_output("energy", outputs, resolved_variants["energy"])
201186

202-
if "non_conservative_forces" in variants:
203-
if variants["non_conservative_forces"] is None:
204-
self._nc_forces_key = "non_conservative_forces"
205-
else:
206-
self._nc_forces_key += f"/{variants['non_conservative_forces']}"
207-
208-
if "non_conservative_stress" in variants:
209-
if variants["non_conservative_stress"] is None:
210-
self._nc_stress_key = "non_conservative_stress"
211-
else:
212-
self._nc_stress_key += f"/{variants['non_conservative_stress']}"
187+
if uncertainty_threshold is not None:
188+
self._energy_uq_key = pick_output(
189+
"energy_uncertainty", outputs, resolved_variants["energy_uncertainty"]
190+
)
191+
else:
192+
self._energy_uq_key = "energy_uncertainty"
193+
194+
if non_conservative:
195+
if (
196+
"non_conservative_stress" in variants
197+
and "non_conservative_forces" in variants
198+
and (
199+
(variants["non_conservative_stress"] is None)
200+
!= (variants["non_conservative_forces"] is None)
201+
)
202+
):
203+
raise ValueError(
204+
"if both 'non_conservative_stress' and "
205+
"'non_conservative_forces' are present in `variants`, they "
206+
"must either be both `None` or both not `None`."
207+
)
208+
209+
self._nc_forces_key = pick_output(
210+
"non_conservative_forces",
211+
outputs,
212+
resolved_variants["non_conservative_forces"],
213+
)
214+
self._nc_stress_key = pick_output(
215+
"non_conservative_stress",
216+
outputs,
217+
resolved_variants["non_conservative_stress"],
218+
)
219+
else:
220+
self._nc_forces_key = "non_conservative_forces"
221+
self._nc_stress_key = "non_conservative_stress"
213222

214223
if additional_outputs is None:
215224
self._additional_output_requests = {}

python/metatomic_torch/metatomic/torch/documentation.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -247,24 +247,17 @@ def __ne__(self, other: "NeighborListOptions") -> bool:
247247

248248

249249
class ModelOutput:
250-
"""Description of one of the quantity a model can compute."""
250+
"""Information about one of the quantity a model can compute."""
251251

252252
def __init__(
253253
self,
254-
description: str = "",
255254
quantity: str = "",
256255
unit: str = "",
257256
per_atom: bool = False,
258257
explicit_gradients: List[str] = [], # noqa B006
258+
description: str = "",
259259
):
260260
pass
261-
262-
@property
263-
def description(self) -> str:
264-
"""
265-
A description of this output. Especially recommended for non-standard outputs
266-
and variants of the one unit.
267-
"""
268261

269262
@property
270263
def quantity(self) -> str:
@@ -294,6 +287,13 @@ def unit(self) -> str:
294287
:py:class:`TensorMap`.
295288
"""
296289

290+
@property
291+
def description(self) -> str:
292+
"""
293+
A description of this output. Especially recommended for non-standard outputs
294+
and variants of the one unit.
295+
"""
296+
297297

298298
class ModelCapabilities:
299299
"""Description of a model capabilities, i.e. everything a model can do."""
@@ -544,3 +544,18 @@ def pick_device(model_devices: List[str], desired_device: Optional[str]) -> str:
544544
:param desired_device: user-provided desired device. If ``None`` or not available,
545545
the first available device from ``model_devices`` will be picked.
546546
"""
547+
548+
549+
def pick_output(
550+
requested_output: str,
551+
outputs: Dict[str, ModelOutput],
552+
desired_variant: Optional[str] = None,
553+
) -> str:
554+
"""
555+
Pick the output for the given ``requested_output`` from the availabilities of the
556+
model's ``outputs``, according to the optional ``desired_variant``.
557+
558+
:param requested_output: name of the output to pick a variant for
559+
:param outputs: all available outputs from the model
560+
:param desired_variant: if provided, try to pick this specific variant
561+
"""

0 commit comments

Comments
 (0)