Skip to content

Commit 2c64342

Browse files
committed
Fixes to the tests for the count function after review
1 parent f388cdc commit 2c64342

File tree

1 file changed

+84
-28
lines changed
  • core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api

1 file changed

+84
-28
lines changed

core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api/count.kt

Lines changed: 84 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,73 @@
11
package org.jetbrains.kotlinx.dataframe.api
22

33
import io.kotest.matchers.shouldBe
4+
import org.jetbrains.kotlinx.dataframe.DataColumn
5+
import org.jetbrains.kotlinx.dataframe.DataFrame
6+
import org.jetbrains.kotlinx.dataframe.DataRow
47
import org.jetbrains.kotlinx.dataframe.nrow
8+
import org.junit.BeforeClass
59
import org.junit.Test
610

11+
/**
12+
* Tests the behavior of the `count` function across different DataFrame structures:
13+
*
14+
* - **[DataColumn]**: counting all elements or elements matching a predicate,
15+
* including behavior on empty columns and columns with `null` values.
16+
*
17+
* - **[DataRow]**: counting elements or elements matching a predicate,
18+
* including rows containing `null` values.
19+
*
20+
* - **[DataFrame]**: counting all rows or rows matching a predicate in a [DataFrame],
21+
* including behavior on empty data frames and data frames with `null` values.
22+
*
23+
* - **[GroupBy]**: counting rows per group in the [GroupBy], with and without a predicate,
24+
* including behavior on a grouped empty DataFrame and groups with `null` values.
25+
*
26+
* - **[Pivot]**: counting rows in each group of the [Pivot],
27+
* including handling of `null` values and predicates.
28+
*
29+
* - **[PivotGroupBy]**: counting rows in each combined [pivot] + [groupBy] group,
30+
* with and without predicates, including handling of empty and `null` groups.
31+
*/
732
class CountTests {
833

9-
// Test data
10-
11-
val df = dataFrameOf(
12-
"name" to columnOf("Alice", "Bob", "Charlie"),
13-
"age" to columnOf(15, 20, 25),
14-
"group" to columnOf(1, 1, 2),
15-
)
16-
val age = df["age"].cast<Int>()
17-
val name = df["name"].cast<String>()
18-
val grouped = df.groupBy("group")
19-
val pivot = df.pivot("group")
20-
21-
val emptyDf = df.drop(df.nrow)
34+
// region Test data
35+
36+
companion object {
37+
lateinit var df: DataFrame<*>
38+
lateinit var age: DataColumn<Int>
39+
lateinit var name: DataColumn<String>
40+
lateinit var grouped: GroupBy<*, *>
41+
lateinit var pivoted: Pivot<*>
42+
lateinit var emptyDf: DataFrame<*>
43+
lateinit var dfWithNulls: DataFrame<*>
44+
lateinit var ageWithNulls: DataColumn<Int?>
45+
lateinit var groupedWithNulls: GroupBy<*, *>
46+
lateinit var pivotWithNulls: Pivot<*>
47+
48+
@BeforeClass
49+
@JvmStatic
50+
fun setupTestData() {
51+
df = dataFrameOf(
52+
"name" to columnOf("Alice", "Bob", "Charlie"),
53+
"age" to columnOf(15, 20, 25),
54+
"group" to columnOf(1, 1, 2),
55+
)
56+
age = df["age"].cast()
57+
name = df["name"].cast()
58+
grouped = df.groupBy("group")
59+
pivoted = df.pivot("group")
60+
emptyDf = df.drop(df.nrow)
61+
dfWithNulls = df.append("Martin", null, null)
62+
ageWithNulls = dfWithNulls["age"].cast()
63+
groupedWithNulls = dfWithNulls.groupBy("group")
64+
pivotWithNulls = dfWithNulls.pivot("group")
65+
}
66+
}
2267

23-
val dfWithNulls = df.append("Martin", null, null)
24-
val ageWithNulls = dfWithNulls["age"].cast<Int?>()
25-
val groupedWithNulls = dfWithNulls.groupBy("group")
26-
val pivotWithNulls = dfWithNulls.pivot("group")
68+
// endregion
2769

28-
// DataColumn
70+
// region DataColumn
2971

3072
@Test
3173
fun `count on DataColumn`() {
@@ -46,13 +88,15 @@ class CountTests {
4688
ageWithNulls.count { it == null } shouldBe 1
4789
}
4890

49-
// DataRow
91+
// endregion
92+
93+
// region DataRow
5094

5195
@Test
5296
fun `count on DataRow`() {
5397
val row = df[0]
5498
row.count() shouldBe 3
55-
(row.count { it is Number }) shouldBe 2
99+
row.count { it is Number } shouldBe 2
56100
}
57101

58102
@Test
@@ -62,7 +106,9 @@ class CountTests {
62106
row.count { it == null } shouldBe 2
63107
}
64108

65-
// DataFrame
109+
// endregion
110+
111+
// region DataFrame
66112

67113
@Test
68114
fun `count on DataFrame`() {
@@ -82,7 +128,9 @@ class CountTests {
82128
dfWithNulls.count { it["age"] != null } shouldBe 3
83129
}
84130

85-
// GroupBy
131+
// endregion
132+
133+
// region GroupBy
86134

87135
@Test
88136
fun `count on grouped DataFrame`() {
@@ -104,6 +152,8 @@ class CountTests {
104152
groupedCount shouldBe expected
105153
}
106154

155+
// Known issue: `emptyDf.groupBy("group").count()` results in a dataframe without the `count` column.
156+
// Issue #1531
107157
@Test
108158
fun `count on empty grouped DataFrame`() {
109159
emptyDf.groupBy("group").count().count() shouldBe 0
@@ -129,11 +179,13 @@ class CountTests {
129179
groupedWithNullsCount shouldBe expected
130180
}
131181

132-
// Pivot
182+
// endregion
183+
184+
// region Pivot
133185

134186
@Test
135187
fun `count on Pivot`() {
136-
val counted = pivot.count()
188+
val counted = pivoted.count()
137189
val expected = dataFrameOf(
138190
"1" to columnOf(2),
139191
"2" to columnOf(1),
@@ -143,7 +195,7 @@ class CountTests {
143195

144196
@Test
145197
fun `count on Pivot with predicate`() {
146-
val counted = pivot.count { "group"<Int>() != 1 }
198+
val counted = pivoted.count { "group"<Int>() != 1 }
147199
val expected = dataFrameOf(
148200
"1" to columnOf(0),
149201
"2" to columnOf(1),
@@ -173,11 +225,13 @@ class CountTests {
173225
counted shouldBe expected
174226
}
175227

176-
// PivotGroupBy
228+
// endregion
229+
230+
// region PivotGroupBy
177231

178232
@Test
179233
fun `count on PivotGroupBy`() {
180-
val pivotGrouped = pivot.groupBy("age")
234+
val pivotGrouped = pivoted.groupBy("age")
181235
val counted = pivotGrouped.count()
182236
val expected = dataFrameOf(
183237
"age" to columnOf(15, 20, 25),
@@ -191,7 +245,7 @@ class CountTests {
191245

192246
@Test
193247
fun `count on PivotGroupBy with predicate`() {
194-
val pivotGrouped = pivot.groupBy("age")
248+
val pivotGrouped = pivoted.groupBy("age")
195249
val counted = pivotGrouped.count { "name"<String>() == "Alice" }
196250
val expected = dataFrameOf(
197251
"age" to columnOf(15, 20, 25),
@@ -232,4 +286,6 @@ class CountTests {
232286
)
233287
counted shouldBe expected
234288
}
289+
290+
// endregion
235291
}

0 commit comments

Comments
 (0)