
Commit b091dc6

fix: struct writer for lance v2.1+
1 parent 6db8f08

File tree

4 files changed (+135, -2 lines)


lance-spark-base_2.12/src/main/scala/com/lancedb/lance/spark/arrow/LanceArrowFieldWriter.scala

Lines changed: 6 additions & 1 deletion
@@ -42,10 +42,15 @@ abstract private[arrow] class LanceArrowFieldWriter {
 
   def write(input: SpecializedGetters, ordinal: Int): Unit = {
     if (input.isNullAt(ordinal)) {
-      setNull()
+      writeNull()
     } else {
       setValue(input, ordinal)
+      count += 1
     }
+  }
+
+  private[arrow] def writeNull(): Unit = {
+    setNull()
     count += 1
   }
 
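With this refactor, write() advances count on both paths: through the new writeNull() for null slots, and right after setValue() for non-null slots, so other writers can drive the null path directly via writeNull() without going through a row accessor. As a rough illustration only, and assuming the base class exposes count and abstract setNull/setValue in the style of Spark's own ArrowFieldWriter, a concrete field writer would just fill the current slot and leave the count bookkeeping to the code above (ExampleIntWriter is hypothetical, not part of this commit):

// Hypothetical sketch, not in the repository: a minimal concrete writer built on the
// write()/writeNull() contract shown above (mirroring Spark's ArrowFieldWriter layout).
import org.apache.arrow.vector.IntVector
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters

private[arrow] class ExampleIntWriter(valueVector: IntVector) extends LanceArrowFieldWriter {
  // Null path: mark the current slot null; the shared writeNull() then advances count.
  override def setNull(): Unit = valueVector.setNull(count)

  // Value path: copy the row's int into the current slot; write() then advances count.
  override def setValue(input: SpecializedGetters, ordinal: Int): Unit =
    valueVector.setSafe(count, input.getInt(ordinal))
}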

lance-spark-base_2.12/src/main/scala/com/lancedb/lance/spark/arrow/LanceArrowWriter.scala

Lines changed: 11 additions & 1 deletion
@@ -388,8 +388,18 @@ private[arrow] class MapWriter(
 private[arrow] class StructWriter(
     val valueVector: StructVector,
     val children: Array[LanceArrowFieldWriter]) extends LanceArrowFieldWriter {
-  override def setNull(): Unit = {}
+  override def setNull(): Unit = {
+    // mark the parent struct as null for this row
+    valueVector.setNull(count)
+    var i = 0
+    while (i < children.length) {
+      children(i).writeNull()
+      i += 1
+    }
+  }
   override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    // mark the parent struct as defined (not null) before writing children
+    valueVector.setIndexDefined(count)
     val struct = input.getStruct(ordinal, children.length)
     var i = 0
     while (i < children.length) {
lance-spark-base_2.12/src/test/java/com/lancedb/lance/spark/write/LanceArrowWriterTest.java

Lines changed: 54 additions & 0 deletions
@@ -15,9 +15,12 @@
 
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
 import org.apache.arrow.vector.VectorSchemaRoot;
 import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.complex.StructVector;
 import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.FieldType;
 import org.apache.arrow.vector.types.pojo.Schema;
@@ -33,6 +36,7 @@
 import java.util.concurrent.atomic.AtomicLong;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class LanceArrowWriterTest {
   @Test
@@ -107,4 +111,54 @@ public void test() throws Exception {
       arrowWriter.close();
     }
   }
+
+  @Test
+  public void propagatesStructNullsToChildren() {
+    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
+      Field childField =
+          new Field(
+              "child", FieldType.nullable(new ArrowType.Int(32, true)), Collections.emptyList());
+      Field structField =
+          new Field(
+              "struct_col",
+              FieldType.nullable(ArrowType.Struct.INSTANCE),
+              Collections.singletonList(childField));
+      Schema arrowSchema = new Schema(Collections.singletonList(structField));
+
+      StructType childType =
+          new StructType(
+              new StructField[] {
+                DataTypes.createStructField("child", DataTypes.IntegerType, true)
+              });
+      StructType sparkSchema =
+          new StructType(
+              new StructField[] {DataTypes.createStructField("struct_col", childType, true)});
+
+      try (VectorSchemaRoot root = VectorSchemaRoot.create(arrowSchema, allocator)) {
+        com.lancedb.lance.spark.arrow.LanceArrowWriter structWriter =
+            com.lancedb.lance.spark.arrow.LanceArrowWriter$.MODULE$.create(root, sparkSchema);
+
+        InternalRow[] rows =
+            new InternalRow[] {
+              new GenericInternalRow(new Object[] {new GenericInternalRow(new Object[] {1})}),
+              new GenericInternalRow(new Object[] {null}),
+              new GenericInternalRow(new Object[] {new GenericInternalRow(new Object[] {3})})
+            };
+
+        for (InternalRow row : rows) {
+          structWriter.write(row);
+        }
+        structWriter.finish();
+
+        StructVector structVector = (StructVector) root.getVector("struct_col");
+        IntVector childVector = (IntVector) structVector.getChild("child");
+
+        assertEquals(rows.length, structVector.getValueCount());
+        assertEquals(rows.length, childVector.getValueCount());
+        assertTrue(structVector.isNull(1));
+        assertEquals(1, childVector.get(0));
+        assertEquals(3, childVector.get(2));
+      }
+    }
+  }
 }
StructWriterSuite.scala (new file, package com.lancedb.lance.spark.arrow)

Lines changed: 64 additions & 0 deletions

@@ -0,0 +1,64 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.lancedb.lance.spark.arrow
+
+import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.{IntVector, VectorSchemaRoot}
+import org.apache.arrow.vector.complex.StructVector
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.types.{DataTypes, StructType}
+import org.apache.spark.sql.util.LanceArrowUtils
+import org.scalatest.funsuite.AnyFunSuite
+
+import java.time.ZoneId
+
+class StructWriterSuite extends AnyFunSuite {
+
+  test("struct null rows advance child writers") {
+    val childType = new StructType().add("child", DataTypes.IntegerType, nullable = true)
+    val schema = new StructType().add("struct_col", childType, nullable = true)
+
+    val allocator = new RootAllocator(Long.MaxValue)
+    try {
+      val arrowSchema = LanceArrowUtils.toArrowSchema(
+        schema,
+        ZoneId.systemDefault().getId,
+        errorOnDuplicatedFieldNames = true)
+      val root = VectorSchemaRoot.create(arrowSchema, allocator)
+      try {
+        val writer = LanceArrowWriter.create(root, schema)
+        val rows = Seq(
+          new GenericInternalRow(Array[Any](new GenericInternalRow(Array[Any](1)))),
+          new GenericInternalRow(Array[Any](null)),
+          new GenericInternalRow(Array[Any](new GenericInternalRow(Array[Any](3)))))
+
+        rows.foreach(writer.write)
+        writer.finish()
+
+        val structVector = root.getVector("struct_col").asInstanceOf[StructVector]
+        val childVector = structVector.getChild("child").asInstanceOf[IntVector]
+
+        assert(structVector.getValueCount === rows.length)
+        assert(childVector.getValueCount === rows.length)
+        assert(structVector.isNull(1))
+        assert(childVector.get(0) === 1)
+        assert(childVector.get(2) === 3)
+      } finally {
+        root.close()
+      }
+    } finally {
+      allocator.close()
+    }
+  }
+}

0 commit comments