Skip to content

Commit eed30ab

Browse files
authored
Add tests for the count function (#1530)
* Add tests for the `count` function * Fixes to the tests for the `count` function after review * Fixes to the KDoc of the tests for the `count` function
1 parent 63ecf8a commit eed30ab

File tree

1 file changed

+291
-0
lines changed
  • core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/api

1 file changed

+291
-0
lines changed
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
package org.jetbrains.kotlinx.dataframe.api
2+
3+
import io.kotest.matchers.shouldBe
4+
import org.jetbrains.kotlinx.dataframe.DataColumn
5+
import org.jetbrains.kotlinx.dataframe.DataFrame
6+
import org.jetbrains.kotlinx.dataframe.DataRow
7+
import org.jetbrains.kotlinx.dataframe.nrow
8+
import org.junit.BeforeClass
9+
import org.junit.Test
10+
11+
/**
12+
* Tests the behavior of the [count] function across different [DataFrame] structures:
13+
*
14+
* - [DataColumn]: counting all elements or elements matching a predicate,
15+
* including behavior on empty columns and columns with `null` values.
16+
*
17+
* - [DataRow]: counting elements or elements matching a predicate,
18+
* including rows containing `null` values.
19+
*
20+
* - [DataFrame]: counting all rows or rows matching a predicate in this [DataFrame],
21+
* including behavior on empty DataFrames and DataFrames with `null` values.
22+
*
23+
* - [GroupBy]: counting rows per group in the [GroupBy], with and without a predicate,
24+
* including behavior on grouped empty [DataFrame] and groups with `null` values.
25+
*
26+
* - [Pivot]: counting rows in each group of the [Pivot],
27+
* including handling of `null` values and predicates.
28+
*
29+
* - [PivotGroupBy]: counting rows in each combined [pivot] + [groupBy] group,
30+
* with and without predicates, including handling of `null` values and predicates.
31+
*/
32+
class CountTests {
33+
34+
// region Test data
35+
36+
companion object {
37+
lateinit var df: DataFrame<*>
38+
lateinit var age: DataColumn<Int>
39+
lateinit var name: DataColumn<String>
40+
lateinit var grouped: GroupBy<*, *>
41+
lateinit var pivoted: Pivot<*>
42+
lateinit var emptyDf: DataFrame<*>
43+
lateinit var dfWithNulls: DataFrame<*>
44+
lateinit var ageWithNulls: DataColumn<Int?>
45+
lateinit var groupedWithNulls: GroupBy<*, *>
46+
lateinit var pivotWithNulls: Pivot<*>
47+
48+
@BeforeClass
49+
@JvmStatic
50+
fun setupTestData() {
51+
df = dataFrameOf(
52+
"name" to columnOf("Alice", "Bob", "Charlie"),
53+
"age" to columnOf(15, 20, 25),
54+
"group" to columnOf(1, 1, 2),
55+
)
56+
age = df["age"].cast()
57+
name = df["name"].cast()
58+
grouped = df.groupBy("group")
59+
pivoted = df.pivot("group")
60+
emptyDf = df.drop(df.nrow)
61+
dfWithNulls = df.append("Martin", null, null)
62+
ageWithNulls = dfWithNulls["age"].cast()
63+
groupedWithNulls = dfWithNulls.groupBy("group")
64+
pivotWithNulls = dfWithNulls.pivot("group")
65+
}
66+
}
67+
68+
// endregion
69+
70+
// region DataColumn
71+
72+
@Test
73+
fun `count on DataColumn`() {
74+
age.count() shouldBe 3
75+
age.count { it > 18 } shouldBe 2
76+
name.count { it.startsWith("A") } shouldBe 1
77+
}
78+
79+
@Test
80+
fun `count on empty DataColumn`() {
81+
emptyDf["name"].count() shouldBe 0
82+
emptyDf["name"].count { it == "Alice" } shouldBe 0
83+
}
84+
85+
@Test
86+
fun `count on DataColumn with nulls`() {
87+
ageWithNulls.count() shouldBe 4
88+
ageWithNulls.count { it == null } shouldBe 1
89+
}
90+
91+
// endregion
92+
93+
// region DataRow
94+
95+
@Test
96+
fun `count on DataRow`() {
97+
val row = df[0]
98+
row.count() shouldBe 3
99+
row.count { it is Number } shouldBe 2
100+
}
101+
102+
@Test
103+
fun `count on DataRow with nulls`() {
104+
val row = dfWithNulls[3]
105+
row.count() shouldBe 3
106+
row.count { it == null } shouldBe 2
107+
}
108+
109+
// endregion
110+
111+
// region DataFrame
112+
113+
@Test
114+
fun `count on DataFrame`() {
115+
df.count() shouldBe 3
116+
df.count { age > 18 } shouldBe 2
117+
df.count { it["name"] == "Alice" } shouldBe 1
118+
}
119+
120+
@Test
121+
fun `count on empty DataFrame`() {
122+
emptyDf.count() shouldBe 0
123+
}
124+
125+
@Test
126+
fun `count on DataFrame with nulls`() {
127+
dfWithNulls.count() shouldBe 4
128+
dfWithNulls.count { it["age"] != null } shouldBe 3
129+
}
130+
131+
// endregion
132+
133+
// region GroupBy
134+
135+
@Test
136+
fun `count on grouped DataFrame`() {
137+
val groupedCount = grouped.count()
138+
val expected = dataFrameOf(
139+
"group" to columnOf(1, 2),
140+
"count" to columnOf(2, 1),
141+
)
142+
groupedCount shouldBe expected
143+
}
144+
145+
@Test
146+
fun `count on grouped DataFrame with predicate`() {
147+
val groupedCount = grouped.count { "age"<Int>() > 18 }
148+
val expected = dataFrameOf(
149+
"group" to columnOf(1, 2),
150+
"count" to columnOf(1, 1),
151+
)
152+
groupedCount shouldBe expected
153+
}
154+
155+
// `emptyDf.groupBy("group").count()` results in a dataframe without the column `count`
156+
// Issue #1531
157+
@Test
158+
fun `count on empty grouped DataFrame`() {
159+
emptyDf.groupBy("group").count().count() shouldBe 0
160+
}
161+
162+
@Test
163+
fun `count on grouped DataFrame with nulls`() {
164+
val groupedWithNullsCount = groupedWithNulls.count()
165+
val expected = dataFrameOf(
166+
"group" to columnOf(1, 2, null),
167+
"count" to columnOf(2, 1, 1),
168+
)
169+
groupedWithNullsCount shouldBe expected
170+
}
171+
172+
@Test
173+
fun `count on grouped DataFrame with nulls and predicate`() {
174+
val groupedWithNullsCount = groupedWithNulls.count { it["age"] != null }
175+
val expected = dataFrameOf(
176+
"group" to columnOf(1, 2, null),
177+
"count" to columnOf(2, 1, 0),
178+
)
179+
groupedWithNullsCount shouldBe expected
180+
}
181+
182+
// endregion
183+
184+
// region Pivot
185+
186+
@Test
187+
fun `count on Pivot`() {
188+
val counted = pivoted.count()
189+
val expected = dataFrameOf(
190+
"1" to columnOf(2),
191+
"2" to columnOf(1),
192+
)[0]
193+
counted shouldBe expected
194+
}
195+
196+
@Test
197+
fun `count on Pivot with predicate`() {
198+
val counted = pivoted.count { "group"<Int>() != 1 }
199+
val expected = dataFrameOf(
200+
"1" to columnOf(0),
201+
"2" to columnOf(1),
202+
)[0]
203+
counted shouldBe expected
204+
}
205+
206+
@Test
207+
fun `count on Pivot with nulls`() {
208+
val counted = pivotWithNulls.count()
209+
val expected = dataFrameOf(
210+
"1" to columnOf(2),
211+
"2" to columnOf(1),
212+
"null" to columnOf(1),
213+
)[0]
214+
counted shouldBe expected
215+
}
216+
217+
@Test
218+
fun `count on Pivot with nulls and predicate`() {
219+
val counted = pivotWithNulls.count { it["age"] != null }
220+
val expected = dataFrameOf(
221+
"1" to columnOf(2),
222+
"2" to columnOf(1),
223+
"null" to columnOf(0),
224+
)[0]
225+
counted shouldBe expected
226+
}
227+
228+
// endregion
229+
230+
// region PivotGroupBy
231+
232+
@Test
233+
fun `count on PivotGroupBy`() {
234+
val pivotGrouped = pivoted.groupBy("age")
235+
val counted = pivotGrouped.count()
236+
val expected = dataFrameOf(
237+
"age" to columnOf(15, 20, 25),
238+
"group" to columnOf(
239+
"1" to columnOf(1, 1, 0),
240+
"2" to columnOf(0, 0, 1),
241+
),
242+
)
243+
counted shouldBe expected
244+
}
245+
246+
@Test
247+
fun `count on PivotGroupBy with predicate`() {
248+
val pivotGrouped = pivoted.groupBy("age")
249+
val counted = pivotGrouped.count { "name"<String>() == "Alice" }
250+
val expected = dataFrameOf(
251+
"age" to columnOf(15, 20, 25),
252+
"group" to columnOf(
253+
"1" to columnOf(1, 0, 0),
254+
"2" to columnOf(0, 0, 0),
255+
),
256+
)
257+
counted shouldBe expected
258+
}
259+
260+
@Test
261+
fun `count on PivotGroupBy with nulls`() {
262+
val pivotGrouped = pivotWithNulls.groupBy("age")
263+
val counted = pivotGrouped.count()
264+
val expected = dataFrameOf(
265+
"age" to columnOf(15, 20, 25, null),
266+
"group" to columnOf(
267+
"1" to columnOf(1, 1, 0, 0),
268+
"2" to columnOf(0, 0, 1, 0),
269+
"null" to columnOf(0, 0, 0, 1),
270+
),
271+
)
272+
counted shouldBe expected
273+
}
274+
275+
@Test
276+
fun `count PivotGroupBy with nulls and predicate`() {
277+
val pivotGrouped = pivotWithNulls.groupBy("age")
278+
val counted = pivotGrouped.count { it["age"] != null && "age"<Int>() > 15 }
279+
val expected = dataFrameOf(
280+
"age" to columnOf(15, 20, 25, null),
281+
"group" to columnOf(
282+
"1" to columnOf(0, 1, 0, 0),
283+
"2" to columnOf(0, 0, 1, 0),
284+
"null" to columnOf(0, 0, 0, 0),
285+
),
286+
)
287+
counted shouldBe expected
288+
}
289+
290+
// endregion
291+
}

0 commit comments

Comments
 (0)