diff --git a/doc/source/whatsnew/v2.3.4.rst b/doc/source/whatsnew/v2.3.4.rst index 6e729c4bf2e2a..897cbacb03170 100644 --- a/doc/source/whatsnew/v2.3.4.rst +++ b/doc/source/whatsnew/v2.3.4.rst @@ -14,6 +14,7 @@ Bug fixes ^^^^^^^^^ - Bug in :meth:`DataFrame.__getitem__` returning modified columns when called with ``slice`` in Python 3.12 (:issue:`57500`) - Bug in :meth:`Series.str.replace` raising an error on valid group references (``\1``, ``\2``, etc.) on series converted to PyArrow backend dtype (:issue:`62653`) +- Bug in :meth:`~DataFrame.groupby` where ``filter`` did not handle ``None`` group keys correctly when ``dropna=False`` (:issue:`62501`) .. --------------------------------------------------------------------------- .. _whatsnew_234.contributors: diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 2c8ec599a19ef..4b8b7717ad7ee 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -637,7 +637,7 @@ def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]: return self._grouper.indices @final - def _get_indices(self, names): + def _get_index(self, name): """ Safe get multiple indices, translate keys for datelike to underlying repr. 
@@ -653,23 +653,24 @@ def get_converter(s): else: return lambda key: key - if len(names) == 0: - return [] + if isna(name): + return self.indices.get(np.nan, []) + if isinstance(name, tuple): + name = tuple(np.nan if isna(comp) else comp for comp in name) if len(self.indices) > 0: index_sample = next(iter(self.indices)) else: index_sample = None # Dummy sample - name_sample = names[0] if isinstance(index_sample, tuple): - if not isinstance(name_sample, tuple): + if not isinstance(name, tuple): msg = "must supply a tuple to get_group with multiple grouping keys" raise ValueError(msg) - if not len(name_sample) == len(index_sample): + if not len(name) == len(index_sample): try: # If the original grouper was a tuple - return [self.indices[name] for name in names] + return self.indices[name] except KeyError as err: # turns out it wasn't a tuple msg = ( @@ -679,23 +680,12 @@ def get_converter(s): raise ValueError(msg) from err converters = (get_converter(s) for s in index_sample) - names = ( - tuple(f(n) for f, n in zip(converters, name, strict=True)) - for name in names - ) - + name = tuple(f(n) for f, n in zip(converters, name, strict=True)) else: converter = get_converter(index_sample) - names = (converter(name) for name in names) - - return [self.indices.get(name, []) for name in names] + name = converter(name) - @final - def _get_index(self, name): - """ - Safe get index, translate keys for datelike to underlying repr. 
- """ - return self._get_indices([name])[0] + return self.indices.get(name, []) @final @cache_readonly diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index d86264cb95dc5..f6600f39bbc57 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -652,9 +652,25 @@ def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]: """dict {group name -> group indices}""" if len(self.groupings) == 1 and isinstance(self.result_index, CategoricalIndex): # This shows unused categories in indices GH#38642 - return self.groupings[0].indices - codes_list = [ping.codes for ping in self.groupings] - return get_indexer_dict(codes_list, self.levels) + result = self.groupings[0].indices + else: + codes_list = [ping.codes for ping in self.groupings] + result = get_indexer_dict(codes_list, self.levels) + if not self.dropna: + has_mi = isinstance(self.result_index, MultiIndex) + if not has_mi and self.result_index.hasnans: + result = { + np.nan if isna(key) else key: value for key, value in result.items() + } + elif has_mi: + # MultiIndex has no efficient way to tell if there are NAs + result = { + # error: "Hashable" has no attribute "__iter__" (not iterable) + tuple(np.nan if isna(comp) else comp for comp in key): value # type: ignore[attr-defined] + for key, value in result.items() + } + + return result @final @cache_readonly diff --git a/pandas/tests/groupby/test_filters.py b/pandas/tests/groupby/test_filters.py index 4fe3aac629513..c20fc9e3d62e7 100644 --- a/pandas/tests/groupby/test_filters.py +++ b/pandas/tests/groupby/test_filters.py @@ -606,3 +606,33 @@ def test_filter_consistent_result_before_after_agg_func(): grouper.sum() result = grouper.filter(lambda x: True) tm.assert_frame_equal(result, expected) + + +def test_filter_with_non_values(): + # GH 62501 + df = DataFrame( + [ + [1], + [None], + ], + columns=["a"], + ) + + result = df.groupby("a", dropna=False).filter(lambda x: True) + tm.assert_frame_equal(result, df) + + +def 
test_filter_with_non_values_multi_index(): + # GH 62501 + df = DataFrame( + [ + [1, 2], + [3, None], + [None, 4], + [None, None], + ], + columns=["a", "b"], + ) + + result = df.groupby(["a", "b"], dropna=False).filter(lambda x: True) + tm.assert_frame_equal(result, df)