|
19 | 19 | System, |
20 | 20 | load_atomistic_model, |
21 | 21 | pick_device, |
| 22 | + pick_output, |
22 | 23 | register_autograd_neighbors, |
23 | 24 | ) |
24 | 25 |
|
@@ -166,50 +167,58 @@ def __init__( |
166 | 167 | f"found unexpected dtype in model capabilities: {capabilities.dtype}" |
167 | 168 | ) |
168 | 169 |
|
169 | | - self._energy_key = "energy" |
170 | | - self._energy_uq_key = "energy_uncertainty" |
171 | | - self._nc_forces_key = "non_conservative_forces" |
172 | | - self._nc_stress_key = "non_conservative_stress" |
173 | | - |
174 | | - if variants: |
175 | | - if "energy" in variants: |
176 | | - self._energy_key += f"/{variants['energy']}" |
177 | | - self._energy_uq_key += f"/{variants['energy']}" |
178 | | - self._nc_forces_key += f"/{variants['energy']}" |
179 | | - self._nc_stress_key += f"/{variants['energy']}" |
180 | | - |
181 | | - if "energy_uncertainty" in variants: |
182 | | - if variants["energy_uncertainty"] is None: |
183 | | - self._energy_uq_key = "energy_uncertainty" |
184 | | - else: |
185 | | - self._energy_uq_key += f"/{variants['energy_uncertainty']}" |
186 | | - |
187 | | - if non_conservative: |
188 | | - if ( |
189 | | - "non_conservative_stress" in variants |
190 | | - and "non_conservative_forces" in variants |
191 | | - and ( |
192 | | - (variants["non_conservative_stress"] is None) |
193 | | - != (variants["non_conservative_forces"] is None) |
194 | | - ) |
195 | | - ): |
196 | | - raise ValueError( |
197 | | - "if both 'non_conservative_stress' and " |
198 | | - "'non_conservative_forces' are present in `variants`, they " |
199 | | - "must either be both `None` or both not `None`." |
200 | | - ) |
| 170 | + # resolve the output keys to use based on the requested variants |
| 171 | + variants = variants or {} |
| 172 | + default_variant = variants.get("energy") |
| 173 | + |
| 174 | + resolved_variants = { |
| 175 | + key: variants.get(key, default_variant) |
| 176 | + for key in [ |
| 177 | + "energy", |
| 178 | + "energy_uncertainty", |
| 179 | + "non_conservative_forces", |
| 180 | + "non_conservative_stress", |
| 181 | + ] |
| 182 | + } |
| 183 | + |
| 184 | + outputs = capabilities.outputs |
| 185 | + self._energy_key = pick_output("energy", outputs, resolved_variants["energy"]) |
201 | 186 |
|
202 | | - if "non_conservative_forces" in variants: |
203 | | - if variants["non_conservative_forces"] is None: |
204 | | - self._nc_forces_key = "non_conservative_forces" |
205 | | - else: |
206 | | - self._nc_forces_key += f"/{variants['non_conservative_forces']}" |
207 | | - |
208 | | - if "non_conservative_stress" in variants: |
209 | | - if variants["non_conservative_stress"] is None: |
210 | | - self._nc_stress_key = "non_conservative_stress" |
211 | | - else: |
212 | | - self._nc_stress_key += f"/{variants['non_conservative_stress']}" |
| 187 | + if uncertainty_threshold is not None: |
| 188 | + self._energy_uq_key = pick_output( |
| 189 | + "energy_uncertainty", outputs, resolved_variants["energy_uncertainty"] |
| 190 | + ) |
| 191 | + else: |
| 192 | + self._energy_uq_key = "energy_uncertainty" |
| 193 | + |
| 194 | + if non_conservative: |
| 195 | + if ( |
| 196 | + "non_conservative_stress" in variants |
| 197 | + and "non_conservative_forces" in variants |
| 198 | + and ( |
| 199 | + (variants["non_conservative_stress"] is None) |
| 200 | + != (variants["non_conservative_forces"] is None) |
| 201 | + ) |
| 202 | + ): |
| 203 | + raise ValueError( |
| 204 | + "if both 'non_conservative_stress' and " |
| 205 | + "'non_conservative_forces' are present in `variants`, they " |
| 206 | + "must either be both `None` or both not `None`." |
| 207 | + ) |
| 208 | + |
| 209 | + self._nc_forces_key = pick_output( |
| 210 | + "non_conservative_forces", |
| 211 | + outputs, |
| 212 | + resolved_variants["non_conservative_forces"], |
| 213 | + ) |
| 214 | + self._nc_stress_key = pick_output( |
| 215 | + "non_conservative_stress", |
| 216 | + outputs, |
| 217 | + resolved_variants["non_conservative_stress"], |
| 218 | + ) |
| 219 | + else: |
| 220 | + self._nc_forces_key = "non_conservative_forces" |
| 221 | + self._nc_stress_key = "non_conservative_stress" |
213 | 222 |
|
214 | 223 | if additional_outputs is None: |
215 | 224 | self._additional_output_requests = {} |
|
0 commit comments