Skip to content

Commit 1eb26cd

Browse files
author
jngaravitoc
committed
updating tests
1 parent 3396cc9 commit 1eb26cd

17 files changed

Lines changed: 781 additions & 290 deletions

EXPtools/basis/__init__.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,19 @@
1-
from .basis_utils import *
2-
from .makemodel import *
1+
from .basis_utils import (
2+
load_basis,
3+
check_basis_params,
4+
write_config,
5+
make_basis,
6+
)
7+
from .makemodel import (
8+
write_table,
9+
make_model,
10+
)
11+
12+
__all__ = [
13+
"load_basis",
14+
"check_basis_params",
15+
"write_config",
16+
"make_basis",
17+
"write_table",
18+
"make_model",
19+
]

EXPtools/basis/basis_utils.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,11 @@ def check_basis_params(basis_params):
9999
raise KeyError(f"Missing mandatory keyword arguments missing: {missing}")
100100
return True
101101
elif basis_params['basis_id'] == 'cylinder':
102-
mandatory_keys = ['acyl', 'hcyl', 'nmaxfid', 'lmaxfid',
103-
'mmax', 'nmax', 'ncylodd', 'ncylnx',
104-
'ncylny', 'rnum', 'pnum', 'tnum', 'vflag', 'logr', 'cachename']
102+
mandatory_keys = [
103+
'acyl', 'hcyl', 'nmaxfid', 'lmaxfid', 'mmax', 'nmax',
104+
'ncylodd', 'ncylnx', 'ncylny', 'rnum', 'pnum', 'tnum',
105+
'vflag', 'logr', 'cachename',
106+
]
105107
missing = [key for key in mandatory_keys if key not in basis_params]
106108
if missing:
107109
raise KeyError(f"Missing mandatory keyword arguments missing: {missing}")
@@ -110,26 +112,37 @@ def check_basis_params(basis_params):
110112
raise AttributeError(f"basis id {basis_params['basis_id']} not found. Please chose between sphereSL or cylinder")
111113

112114

113-
def write_config(basis_params):
115+
def write_config(
116+
basis_params,
117+
write_yaml=False,
118+
filename="basis_config.yaml",
119+
):
120+
114121
"""
115122
Create a YAML configuration file string for building a basis model.
116123
117124
Parameters
118125
----------
119-
basis_id : str
120-
Identifier of the basis model. Must be either 'sphereSL' or 'cylinder'.
121-
122-
**params : dict
123-
Additional keyword arguments required depending on the basis type:
124-
126+
basis_params : dict
127+
Dictionary containing basis configuration parameters. Must include
128+
the key ``'basis_id'`` and all required parameters for the chosen basis.
129+
125130
- For ``sphereSL``:
126131
['lmax', 'nmax', 'rmapping', 'modelname', 'cachename']
127132
128133
- For ``cylinder``:
129134
['acyl', 'hcyl', 'nmaxfid', 'lmaxfid', 'mmax', 'nmax',
130135
'ncylodd', 'ncylnx', 'ncylny', 'rnum', 'pnum', 'tnum',
131136
'vflag', 'logr', 'cachename']
137+
138+
write_config : bool, optional
139+
If True, write the YAML configuration to disk. Default is True.
132140
141+
filename : str, optional
142+
Output filename for the YAML configuration. Only used if
143+
``write_config=True``. Default is ``'basis_config.yaml'``.
144+
145+
133146
Returns
134147
-------
135148
str
@@ -144,10 +157,8 @@ def write_config(basis_params):
144157
ValueError
145158
If the model file does not contain valid radius data.
146159
"""
147-
print(basis_params)
148160
check_basis_params(basis_params)
149161

150-
151162
if basis_params['basis_id'] == "sphereSL":
152163
modelname = basis_params["modelname"]
153164
try:
@@ -168,13 +179,17 @@ def write_config(basis_params):
168179
"id": basis_id,
169180
"parameters": basis_params
170181
}
171-
172-
print(config_dict)
173-
174-
175-
return yaml.dump(config_dict, sort_keys=False)
182+
print('OK')
183+
yaml_str = yaml.dump(config_dict, sort_keys=False)
184+
print('here')
185+
if write_yaml:
186+
with open(filename, "w") as f:
187+
f.write(yaml_str)
188+
print('----')
189+
return yaml_str
176190

177-
def make_basis(R, D, Mtotal, basis_params, physical_units=True):
191+
192+
def make_basis(R, D, Mtotal, basis_params, physical_units=True, write_yaml=False):
178193
"""
179194
Construct a basis from a given radial density profile.
180195
@@ -188,8 +203,9 @@ def make_basis(R, D, Mtotal, basis_params, physical_units=True):
188203
Total mass normalization (default is 1.0).
189204
basis_params : dict
190205
basis parameters e.g., basis_id, nmax, lmax
191-
For the descriptions of the basis_params please see the EXP description:
192-
https://github.com/EXP-code/EXP-docs/blob/93207da758d34cf10092650a840cdb23180b859a/topics/yamlconfig.rst#L229
206+
For the descriptions of the basis_params please see the EXP docs:
207+
https://github.com/EXP-code/EXP-docs/blob/93207da758d34cf10092650a84
208+
0cdb23180b859a/topics/yamlconfig.rst#L229
193209
194210
195211
Returns
@@ -204,10 +220,6 @@ def make_basis(R, D, Mtotal, basis_params, physical_units=True):
204220
- It then builds a basis either spherical (`sphereSL`) or cylindrical using `EXPtools.make_config`
205221
and returns the corresponding `pyEXP` basis object.
206222
207-
TODO:
208-
-----
209-
- check cache values are consitent with basis_params
210-
- informce float for rmapping
211223
"""
212224

213225
if "modelname" not in basis_params.keys():
@@ -221,8 +233,8 @@ def make_basis(R, D, Mtotal, basis_params, physical_units=True):
221233
output_filename=basis_params['modelname'],
222234
physical_units=physical_units
223235
)
224-
225-
config = write_config(basis_params)
236+
print('Done making model')
237+
config = write_config(basis_params, write_yaml)
226238

227239
basis = pyEXP.basis.Basis.factory(config)
228240
return basis

EXPtools/coefficients/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from . import coefficients
1+
from .coefficients import compute_exp_coefs, compute_exp_coefs_parallel
22

3-
__all__ = ["compute_exp_coefs", "compute_exp_coeds_parallel"]
3+
__all__ = ["compute_exp_coefs", "compute_exp_coefs_parallel"]

0 commit comments

Comments
 (0)