diff --git a/src/midst_toolkit/models/clavaddpm/clustering.py b/src/midst_toolkit/models/clavaddpm/clustering.py
index e9a8212d..09fd4963 100644
--- a/src/midst_toolkit/models/clavaddpm/clustering.py
+++ b/src/midst_toolkit/models/clavaddpm/clustering.py
@@ -2,7 +2,7 @@
 import os
 import pickle
 
-from collections import OrderedDict, defaultdict
+from collections import defaultdict
 from logging import INFO, WARNING
 from pathlib import Path
 from typing import Any
@@ -233,7 +233,7 @@ def _pair_clustering(
 
     cluster_labels = _get_cluster_labels(cluster_data, clustering_method, num_clusters)
 
-    child_group_data = _get_group_data(sorted_child_data, foreign_key_index)
+    child_group_data = group_data_by_id(sorted_child_data, foreign_key_index, sort_by_column_value=True)
     child_group_lengths = np.array([len(group) for group in child_group_data], dtype=int)
 
     if clustering_method == ClusteringMethod.VARIATIONAL:
@@ -313,12 +313,12 @@ def _merge_parent_data_with_child_data(
         child_data: Numpy array of the child data. Should be sorted by the foreign key.
         parent_data: Numpy array of the parent data. Should be sorted by the parent primary key.
         parent_primary_key_index: Index of the parent primary key.
-        foreign_key_index: Index of the foreign key to the child data.
+        foreign_key_index: Index of the foreign key in the child data.
 
     Returns:
         Numpy array of the parent data merged for each group of the child group data.
     """
-    child_group_data_dict = _group_data_by_group_id(child_data, foreign_key_index)
+    child_group_data_dict = group_data_by_group_id_as_dict(child_data, foreign_key_index)
 
     group_lengths = []
     unique_group_ids = parent_data[:, parent_primary_key_index]
@@ -669,47 +669,55 @@ def _get_categorical_and_numerical_columns(
     return numerical_columns, categorical_columns
 
 
-def _group_data_by_group_id(
-    np_data: np.ndarray,
-    group_id_index: int,
+def group_data_by_group_id_as_dict(
+    data_to_be_grouped: np.ndarray, column_index_to_group_by: int
 ) -> dict[int, list[np.ndarray]]:
     """
-    Collects the data in each group by group id and returns it as a dictionary.
+    Group rows in a numpy array by their values in the column specified by ``column_index_to_group_by`` into a
+    dictionary. Returns a dict where keys are values from the column to group by and values are lists of
+    corresponding rows (groups).
 
     Args:
-        np_data: Numpy array of the data.
-        group_id_index: The index of the data that contains the group id.
+        data_to_be_grouped: Numpy array of the data to be grouped.
+        column_index_to_group_by: Column index by which the data should be grouped.
 
     Returns:
-        Dictionary of group data by group id.
+        Dictionary of group data where the keys are values from the column to group by and the values
+        are lists of full rows from ``data_to_be_grouped`` that share the specified column value.
""" - group_data_by_group_id = OrderedDict[int, list[np.ndarray]]() - - for i in range(len(np_data)): - group_id = _parse_numpy_number_as_int(np_data[i, group_id_index]) - - if group_id not in group_data_by_group_id: - group_data_by_group_id[group_id] = [] + grouped_data_dict: defaultdict[int, list[np.ndarray]] = defaultdict(list) + num_rows = len(data_to_be_grouped) + for row in range(num_rows): + row_id = _parse_numpy_number_as_int(data_to_be_grouped[row, column_index_to_group_by]) + grouped_data_dict[row_id].append(data_to_be_grouped[row]) - group_data_by_group_id[group_id].append(np_data[i]) + return grouped_data_dict - return group_data_by_group_id - -def _get_group_data(np_data: np.ndarray, group_id_index: int) -> np.ndarray: +def group_data_by_id( + data_to_be_grouped: np.ndarray, column_index_to_group_by: int, sort_by_column_value: bool = False +) -> np.ndarray: """ - Collects the data in each group by group id and returns it as a numpy array. + Group rows in a numpy array that share values in the column specified by ``column_index_to_group_by``. + Returns an array of arrays where each sub-array contains full rows sharing identical values in the grouping column. Args: - np_data: Numpy array of the data. - group_id_index: The index of the data that contains the group id. + data_to_be_grouped: Numpy array of the data to be grouped. + column_index_to_group_by: Column index by which the data should be grouped. + sort_by_column_value: Whether or not the returned groups are sorted by the values in the column the index + ``column_index_to_group_by``. Defaults to False. Returns: - Numpy array of the data ordered by group id. + Numpy array of the data grouped by values in the column with index ``column_index_to_group_by``. The returned + array has dtype=object since groups may have different lengths. 
""" - group_data_by_group_id = _group_data_by_group_id(np_data, group_id_index) - group_data_list = [np.array(group_data) for group_data in group_data_by_group_id.values()] - return np.array(group_data_list, dtype=object) + grouped_data_by_group_id = group_data_by_group_id_as_dict(data_to_be_grouped, column_index_to_group_by) + if sort_by_column_value: + grouped_data = [(key, np.array(group_data)) for key, group_data in grouped_data_by_group_id.items()] + grouped_data_list = [data for _, data in sorted(grouped_data)] + else: + grouped_data_list = [np.array(group_data) for group_data in grouped_data_by_group_id.values()] + return np.array(grouped_data_list, dtype=object) def _parse_numpy_number_as_int(number: np.number) -> int: diff --git a/tests/unit/models/clavaddpm/test_clustering.py b/tests/unit/models/clavaddpm/test_clustering.py index f88555dd..8d684796 100644 --- a/tests/unit/models/clavaddpm/test_clustering.py +++ b/tests/unit/models/clavaddpm/test_clustering.py @@ -5,6 +5,8 @@ _min_max_normalize_sklearn, _quantile_normalize_sklearn, get_normalized_numerical_columns, + group_data_by_group_id_as_dict, + group_data_by_id, ) from midst_toolkit.models.clavaddpm.enumerations import DataAndKeyNormalizationType @@ -81,3 +83,97 @@ def test_get_normalized_numerical_columns() -> None: ) unset_all_random_seeds() + + +def test_group_data_by_id() -> None: + set_all_random_seeds(42) + data_array_with_one_foreign_keys = np.hstack( + (np.random.randn(10, 3), np.random.randint(0, 3, (10, 1)).astype(float), np.random.randn(10, 1)) + ) + data_array_with_foreign_key_in_front = np.hstack( + (np.random.randint(0, 2, (10, 1)).astype(float), np.random.randn(10, 3), np.random.randn(10, 1)) + ) + + grouped_data = group_data_by_id(data_array_with_one_foreign_keys, 3) + assert len(grouped_data) == 3 + assert len(grouped_data[0]) == 4 + assert len(grouped_data[1]) == 2 + assert len(grouped_data[2]) == 4 + assert np.allclose( + grouped_data[0], + np.array( + [ + [0.49671415, -0.1382643, 0.64768854, 2.0, 2.77831304], + [1.52302986, -0.23415337, -0.23413696, 2.0, 1.19363972], + [0.54256004, -0.46341769, -0.46572975, 2.0, 0.88176104], + [0.24196227, -1.91328024, -1.72491783, 2.0, -1.00908534], + ] + ), + atol=1e-6, + ) + assert np.allclose( + grouped_data[1], + np.array( + [ + [1.57921282, 0.76743473, -0.46947439, 0.0, 0.21863832], + [-0.90802408, -1.4123037, 1.46564877, 0.0, 0.77370042], + ], + ), + atol=1e-6, + ) + + grouped_data = group_data_by_id(data_array_with_foreign_key_in_front, 0, sort_by_column_value=True) + # Because the first column is non-unique, we get proper groups. 
+    assert len(grouped_data) == 2
+    assert len(grouped_data[0]) == 9
+    assert len(grouped_data[1]) == 1
+    assert np.allclose(
+        grouped_data[1],
+        np.array([[1.0, -0.676922, 0.61167629, 1.03099952, 1.47789404]]),
+        atol=1e-6,
+    )
+    assert np.allclose(
+        grouped_data[0],
+        np.array(
+            [
+                [0.0, 0.93128012, -0.83921752, -0.30921238, -0.51827022],
+                [0.0, 0.33126343, 0.97554513, -0.47917424, -0.8084936],
+                [0.0, -0.18565898, -1.10633497, -1.19620662, -0.50175704],
+                [0.0, 0.81252582, 1.35624003, -0.07201012, 0.91540212],
+                [0.0, 1.0035329, 0.36163603, -0.64511975, 0.32875111],
+                [0.0, 0.36139561, 1.53803657, -0.03582604, -0.5297602],
+                [0.0, 1.56464366, -2.6197451, 0.8219025, 0.51326743],
+                [0.0, 0.08704707, -0.29900735, 0.09176078, 0.09707755],
+                [0.0, -1.98756891, -0.21967189, 0.35711257, 0.96864499],
+            ]
+        ),
+        atol=1e-6,
+    )
+    unset_all_random_seeds()
+
+
+def test_group_data_by_group_id_as_dict() -> None:
+    set_all_random_seeds(42)
+    data_array_with_one_foreign_key = np.hstack(
+        (np.random.randn(10, 3), np.random.randint(0, 3, (10, 1)).astype(float), np.random.randn(10, 1))
+    )
+    data_array_with_foreign_key_in_front = np.hstack(
+        (np.random.randint(0, 2, (10, 1)).astype(float), np.random.randn(10, 3), np.random.randn(10, 1))
+    )
+
+    grouped_data = group_data_by_group_id_as_dict(data_array_with_one_foreign_key, 3)
+    assert len(grouped_data) == 3
+    assert len(grouped_data[2]) == 4
+    assert len(grouped_data[0]) == 2
+    assert np.allclose(grouped_data[0][0], np.array([1.57921282, 0.76743473, -0.46947439, 0.0, 0.21863832]), atol=1e-6)
+    assert np.allclose(grouped_data[0][1], np.array([-0.90802408, -1.4123037, 1.46564877, 0.0, 0.77370042]), atol=1e-6)
+    assert np.allclose(
+        grouped_data[2][1], np.array([1.52302986, -0.23415337, -0.23413696, 2.0, 1.19363972]), atol=1e-6
+    )
+
+    grouped_data = group_data_by_group_id_as_dict(data_array_with_foreign_key_in_front, 0)
+    # The first column takes only the values 0 and 1, so the dictionary holds exactly two groups.
+    assert len(grouped_data) == 2
+    assert len(grouped_data[0]) == 9
+    assert len(grouped_data[1]) == 1
+    unset_all_random_seeds()
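
A minimal usage sketch of the two helpers promoted to public API in this patch (illustrative only, not part of the diff; the toy array below is made up rather than drawn from the seeded test data):

import numpy as np

from midst_toolkit.models.clavaddpm.clustering import group_data_by_group_id_as_dict, group_data_by_id

# Toy table: column 0 is a feature, column 1 plays the role of a foreign key.
data = np.array(
    [
        [0.1, 2.0],
        [0.2, 0.0],
        [0.3, 2.0],
        [0.4, 1.0],
    ]
)

# Dict form: keys are the integer-parsed column values, with groups in first-seen order.
groups_by_key = group_data_by_group_id_as_dict(data, 1)
assert sorted(groups_by_key) == [0, 1, 2]
assert len(groups_by_key[2]) == 2  # rows [0.1, 2.0] and [0.3, 2.0]

# Array form: one object-dtype entry per group. With sort_by_column_value=True the
# groups come back ordered by key (0, 1, 2) instead of by first appearance (2, 0, 1).
groups = group_data_by_id(data, 1, sort_by_column_value=True)
assert [len(group) for group in groups] == [1, 1, 2]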