| 
19 | 19 |     System,  | 
20 | 20 |     load_atomistic_model,  | 
21 | 21 |     pick_device,  | 
 | 22 | +    pick_output,  | 
22 | 23 |     register_autograd_neighbors,  | 
23 | 24 | )  | 
24 | 25 | 
 
  | 
@@ -166,50 +167,58 @@ def __init__(  | 
166 | 167 |                 f"found unexpected dtype in model capabilities: {capabilities.dtype}"  | 
167 | 168 |             )  | 
168 | 169 | 
 
  | 
169 |  | -        self._energy_key = "energy"  | 
170 |  | -        self._energy_uq_key = "energy_uncertainty"  | 
171 |  | -        self._nc_forces_key = "non_conservative_forces"  | 
172 |  | -        self._nc_stress_key = "non_conservative_stress"  | 
173 |  | - | 
174 |  | -        if variants:  | 
175 |  | -            if "energy" in variants:  | 
176 |  | -                self._energy_key += f"/{variants['energy']}"  | 
177 |  | -                self._energy_uq_key += f"/{variants['energy']}"  | 
178 |  | -                self._nc_forces_key += f"/{variants['energy']}"  | 
179 |  | -                self._nc_stress_key += f"/{variants['energy']}"  | 
180 |  | - | 
181 |  | -            if "energy_uncertainty" in variants:  | 
182 |  | -                if variants["energy_uncertainty"] is None:  | 
183 |  | -                    self._energy_uq_key = "energy_uncertainty"  | 
184 |  | -                else:  | 
185 |  | -                    self._energy_uq_key += f"/{variants['energy_uncertainty']}"  | 
186 |  | - | 
187 |  | -            if non_conservative:  | 
188 |  | -                if (  | 
189 |  | -                    "non_conservative_stress" in variants  | 
190 |  | -                    and "non_conservative_forces" in variants  | 
191 |  | -                    and (  | 
192 |  | -                        (variants["non_conservative_stress"] is None)  | 
193 |  | -                        != (variants["non_conservative_forces"] is None)  | 
194 |  | -                    )  | 
195 |  | -                ):  | 
196 |  | -                    raise ValueError(  | 
197 |  | -                        "if both 'non_conservative_stress' and "  | 
198 |  | -                        "'non_conservative_forces' are present in `variants`, they "  | 
199 |  | -                        "must either be both `None` or both not `None`."  | 
200 |  | -                    )  | 
 | 170 | +        # resolve the output keys to use based on the requested variants  | 
 | 171 | +        variants = variants or {}  | 
 | 172 | +        default_variant = variants.get("energy")  | 
 | 173 | + | 
 | 174 | +        resolved_variants = {  | 
 | 175 | +            key: variants.get(key, default_variant)  | 
 | 176 | +            for key in [  | 
 | 177 | +                "energy",  | 
 | 178 | +                "energy_uncertainty",  | 
 | 179 | +                "non_conservative_forces",  | 
 | 180 | +                "non_conservative_stress",  | 
 | 181 | +            ]  | 
 | 182 | +        }  | 
 | 183 | + | 
 | 184 | +        outputs = capabilities.outputs  | 
 | 185 | +        self._energy_key = pick_output("energy", outputs, resolved_variants["energy"])  | 
201 | 186 | 
 
  | 
202 |  | -                if "non_conservative_forces" in variants:  | 
203 |  | -                    if variants["non_conservative_forces"] is None:  | 
204 |  | -                        self._nc_forces_key = "non_conservative_forces"  | 
205 |  | -                    else:  | 
206 |  | -                        self._nc_forces_key += f"/{variants['non_conservative_forces']}"  | 
207 |  | - | 
208 |  | -                if "non_conservative_stress" in variants:  | 
209 |  | -                    if variants["non_conservative_stress"] is None:  | 
210 |  | -                        self._nc_stress_key = "non_conservative_stress"  | 
211 |  | -                    else:  | 
212 |  | -                        self._nc_stress_key += f"/{variants['non_conservative_stress']}"  | 
 | 187 | +        if uncertainty_threshold is not None:  | 
 | 188 | +            self._energy_uq_key = pick_output(  | 
 | 189 | +                "energy_uncertainty", outputs, resolved_variants["energy_uncertainty"]  | 
 | 190 | +            )  | 
 | 191 | +        else:  | 
 | 192 | +            self._energy_uq_key = "energy_uncertainty"  | 
 | 193 | + | 
 | 194 | +        if non_conservative:  | 
 | 195 | +            if (  | 
 | 196 | +                "non_conservative_stress" in variants  | 
 | 197 | +                and "non_conservative_forces" in variants  | 
 | 198 | +                and (  | 
 | 199 | +                    (variants["non_conservative_stress"] is None)  | 
 | 200 | +                    != (variants["non_conservative_forces"] is None)  | 
 | 201 | +                )  | 
 | 202 | +            ):  | 
 | 203 | +                raise ValueError(  | 
 | 204 | +                    "if both 'non_conservative_stress' and "  | 
 | 205 | +                    "'non_conservative_forces' are present in `variants`, they "  | 
 | 206 | +                    "must either be both `None` or both not `None`."  | 
 | 207 | +                )  | 
 | 208 | + | 
 | 209 | +            self._nc_forces_key = pick_output(  | 
 | 210 | +                "non_conservative_forces",  | 
 | 211 | +                outputs,  | 
 | 212 | +                resolved_variants["non_conservative_forces"],  | 
 | 213 | +            )  | 
 | 214 | +            self._nc_stress_key = pick_output(  | 
 | 215 | +                "non_conservative_stress",  | 
 | 216 | +                outputs,  | 
 | 217 | +                resolved_variants["non_conservative_stress"],  | 
 | 218 | +            )  | 
 | 219 | +        else:  | 
 | 220 | +            self._nc_forces_key = "non_conservative_forces"  | 
 | 221 | +            self._nc_stress_key = "non_conservative_stress"  | 
213 | 222 | 
 
  | 
214 | 223 |         if additional_outputs is None:  | 
215 | 224 |             self._additional_output_requests = {}  | 
 | 
0 commit comments