Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lib/ramble/ramble/expander.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,9 @@ def expand_var(
if var is None or var == "None":
return None if typed else "None"

if isinstance(var, (int, float, bool)):
return var if typed else str(var)

passthrough_setting = allow_passthrough

# If disable_passthrough is set, override allow_passthrough from caller
Expand Down
12 changes: 11 additions & 1 deletion lib/ramble/ramble/schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,19 @@ def _deprecated_properties(validator, deprecated, instance, schema):

yield jsonschema.ValidationError(msg)

return jsonschema.validators.extend(
ValidatorClass = jsonschema.validators.extend(
jsonschema.Draft4Validator, {"deprecatedProperties": _deprecated_properties}
)

import spack.util.spack_yaml as syaml

# Add syaml_bool to the accepted boolean types for validation
boolean_types = ValidatorClass.DEFAULT_TYPES.get("boolean", bool)
if not isinstance(boolean_types, tuple):
boolean_types = (boolean_types,)
ValidatorClass.DEFAULT_TYPES["boolean"] = boolean_types + (syaml.syaml_bool,)

return ValidatorClass


Validator = llnl.util.lang.Singleton(_make_validator)
40 changes: 40 additions & 0 deletions lib/ramble/ramble/test/end_to_end/env_var_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,43 @@ def test_auto_env_vars(make_workspace_from_config, mock_applications, mock_modif
assert "MY_AUTO_ENV_VAR_WL_DEFAULTS" not in data
assert "MY_AUTO_ENV_VAR_WG" not in data
assert "OBJ_AUTO_ENV_VAR" not in data


def test_env_var_bool_case(mock_applications, make_workspace_from_config):
test_config = """
ramble:
config:
shell: bash
variables:
mpi_command: 'mpirun -n {n_ranks} -ppn {processes_per_node}'
batch_submit: 'batch_submit {execute_experiment}'
partition: 'part1'
processes_per_node: '16'
n_threads: '1'
applications:
interleved-env-vars:
workloads:
test_wl:
experiments:
simple_test:
variables:
n_nodes: 1
env_vars:
set:
MY_VAR: TRUE
software:
packages: {}
environments: {}
"""
ws, ws_name = make_workspace_from_config(test_config)

workspace("setup", "--dry-run", global_args=["-w", ws_name])

experiment_root = ws.experiment_dir
exp_dir = os.path.join(experiment_root, "interleved-env-vars", "test_wl", "simple_test")
exp_script = os.path.join(exp_dir, "execute_experiment")

with open(exp_script, encoding="utf-8") as f:
data = f.read()

assert "MY_VAR=TRUE" in data
9 changes: 6 additions & 3 deletions lib/ramble/ramble/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,10 @@ def experiment_variant(self, name: str, value: Any):
default_var = self.default_variants[name]

# If the default value is a boolean, convert the experiment value to a boolean
if default_var and isinstance(default_var.default, bool):
if default_var and (
isinstance(default_var.default, bool)
or type(default_var.default).__name__ == "syaml_bool"
):
Comment on lines +192 to +195

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using type(default_var.default).__name__ == "syaml_bool" is a fragile string-based type check. It is more robust and idiomatic to import syaml_bool and use isinstance to check the type, which also correctly handles subclasses.

        from spack.util.spack_yaml import syaml_bool
        if default_var and isinstance(default_var.default, (bool, syaml_bool)):

if isinstance(value, str):
value = value.lower() == "true"

Expand Down Expand Up @@ -403,7 +406,7 @@ def __init__(
self._definitions = self._build_definitions()

def _build_definitions(self) -> tuple:
if isinstance(self.default, bool):
if isinstance(self.default, bool) or type(self.default).__name__ == "syaml_bool":

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using type(self.default).__name__ == "syaml_bool" is a fragile string-based type check. It is more robust and idiomatic to import syaml_bool and use isinstance to check the type.

Suggested change
if isinstance(self.default, bool) or type(self.default).__name__ == "syaml_bool":
from spack.util.spack_yaml import syaml_bool
if isinstance(self.default, (bool, syaml_bool)):

val_str = str(self.default)
return (
self._definition,
Expand All @@ -419,7 +422,7 @@ def copy(self):

def format_value(self, value: Any) -> str:
"""Format a value for this variant into Spack-like syntax"""
if isinstance(self.default, bool):
if isinstance(self.default, bool) or type(self.default).__name__ == "syaml_bool":

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using type(self.default).__name__ == "syaml_bool" is a fragile string-based type check. It is more robust and idiomatic to import syaml_bool and use isinstance to check the type.

Suggested change
if isinstance(self.default, bool) or type(self.default).__name__ == "syaml_bool":
from spack.util.spack_yaml import syaml_bool
if isinstance(self.default, (bool, syaml_bool)):

prefix = "+" if value else "~"
return f"{prefix}{self.name}"
else:
Expand Down
42 changes: 42 additions & 0 deletions lib/ramble/spack/util/spack_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,25 @@ class syaml_int(int):
__repr__ = int.__repr__


class syaml_bool(int):
def __new__(cls, val, string_val=None):
obj = super(syaml_bool, cls).__new__(cls, val)
obj.string_val = string_val
return obj

def __repr__(self):
return self.string_val if self.string_val else ('True' if self else 'False')

def __str__(self):
return self.string_val if self.string_val else ('True' if self else 'False')



#: mapping from syaml type -> primitive type
syaml_types = {
syaml_str: str,
syaml_int: int,
syaml_bool: bool,
syaml_dict: dict,
syaml_list: list,
}
Expand Down Expand Up @@ -121,6 +136,13 @@ class OrderedLineLoader(RoundTripLoader):
# and fill in with mappings later. We preserve this behavior.
#


def construct_yaml_bool(self, node):
value = super(OrderedLineLoader, self).construct_yaml_bool(node)
b = syaml_bool(value, string_val=node.value)
mark(b, node)
return b

def construct_yaml_str(self, node):
value = super(OrderedLineLoader, self).construct_yaml_str(node)
# There is no specific marker to indicate that we are parsing a key,
Expand Down Expand Up @@ -156,6 +178,8 @@ def construct_yaml_map(self, node):


# register above new constructors
OrderedLineLoader.add_constructor(
'tag:yaml.org,2002:bool', OrderedLineLoader.construct_yaml_bool)
OrderedLineLoader.add_constructor(
'tag:yaml.org,2002:map', OrderedLineLoader.construct_yaml_map)
OrderedLineLoader.add_constructor(
Expand All @@ -178,12 +202,23 @@ def ignore_aliases(self, _data):
"""Make the dumper NEVER print YAML aliases."""
return True

def represent_bool(self, data):
if hasattr(data, 'string_val') and data.string_val:
return self.represent_scalar('tag:yaml.org,2002:bool', data.string_val)
return super(SafeDumper, self).represent_bool(bool(data))

Comment on lines +205 to +209

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is a duplicate definition of represent_bool in OrderedLineDumper. Additionally, this first definition incorrectly uses super(SafeDumper, self) instead of super(OrderedLineDumper, self), which would raise a TypeError at runtime if it were ever called. Since it is shadowed by the correct definition further down, this duplicate should be removed entirely to avoid confusion and potential bugs.

def represent_data(self, data):
result = super(OrderedLineDumper, self).represent_data(data)
if data is None:
result.value = syaml_str("null")
return result


def represent_bool(self, data):
if hasattr(data, 'string_val') and data.string_val:
return self.represent_scalar('tag:yaml.org,2002:bool', data.string_val)
return super(OrderedLineDumper, self).represent_bool(bool(data))

def represent_str(self, data):
if hasattr(data, 'override') and data.override:
data = data + ':'
Expand All @@ -198,12 +233,19 @@ def ignore_aliases(self, _data):
"""Make the dumper NEVER print YAML aliases."""
return True

def represent_bool(self, data):
if hasattr(data, 'string_val') and data.string_val:
return self.represent_scalar('tag:yaml.org,2002:bool', data.string_val)
return super(SafeDumper, self).represent_bool(bool(data))


# Make our special objects look like normal YAML ones.
RoundTripDumper.add_representer(syaml_dict, RoundTripDumper.represent_dict)
RoundTripDumper.add_representer(syaml_list, RoundTripDumper.represent_list)
RoundTripDumper.add_representer(syaml_int, RoundTripDumper.represent_int)
RoundTripDumper.add_representer(syaml_bool, RoundTripDumper.represent_bool)
RoundTripDumper.add_representer(syaml_str, RoundTripDumper.represent_str)
OrderedLineDumper.add_representer(syaml_bool, OrderedLineDumper.represent_bool)
OrderedLineDumper.add_representer(syaml_str, OrderedLineDumper.represent_str)


Expand Down
Loading