26 changes: 18 additions & 8 deletions src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
@@ -20,6 +20,7 @@ import java.io.{IOException, OutputStream}
import java.nio.ByteBuffer
import java.sql.{Date, Timestamp}
import java.util.HashMap
import java.util.concurrent.TimeUnit

import scala.collection.immutable.Map

@@ -43,9 +44,10 @@ private[avro] class AvroOutputWriter(
context: TaskAttemptContext,
schema: StructType,
recordName: String,
recordNamespace: String) extends OutputWriter {
recordNamespace: String,
timeUnit: TimeUnit) extends OutputWriter {

private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace)
private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace, timeUnit)
// copy of the old conversion logic after api change in SPARK-19085
private lazy val internalRowConverter =
CatalystTypeConverters.createToScalaConverter(schema).asInstanceOf[InternalRow => Row]
@@ -89,7 +91,8 @@ private[avro] class AvroOutputWriter(
private def createConverterToAvro(
dataType: DataType,
structName: String,
recordNamespace: String): (Any) => Any = {
recordNamespace: String,
timeUnit: TimeUnit): (Any) => Any = {
dataType match {
case BinaryType => (item: Any) => item match {
case null => null
@@ -99,14 +102,19 @@
FloatType | DoubleType | StringType | BooleanType => identity
case _: DecimalType => (item: Any) => if (item == null) null else item.toString
case TimestampType => (item: Any) =>
if (item == null) null else item.asInstanceOf[Timestamp].getTime
if (item == null) null else {
timeUnit.convert(item.asInstanceOf[Timestamp].getTime, TimeUnit.MILLISECONDS)
}
case DateType => (item: Any) =>
if (item == null) null else item.asInstanceOf[Date].getTime
if (item == null) null else {
timeUnit.convert(item.asInstanceOf[Date].getTime, TimeUnit.MILLISECONDS)
}
case ArrayType(elementType, _) =>
val elementConverter = createConverterToAvro(
elementType,
structName,
SchemaConverters.getNewRecordNamespace(elementType, recordNamespace, structName))
SchemaConverters.getNewRecordNamespace(elementType, recordNamespace, structName),
timeUnit)
(item: Any) => {
if (item == null) {
null
@@ -126,7 +134,8 @@
val valueConverter = createConverterToAvro(
valueType,
structName,
SchemaConverters.getNewRecordNamespace(valueType, recordNamespace, structName))
SchemaConverters.getNewRecordNamespace(valueType, recordNamespace, structName),
timeUnit)
(item: Any) => {
if (item == null) {
null
@@ -146,7 +155,8 @@
createConverterToAvro(
field.dataType,
field.name,
SchemaConverters.getNewRecordNamespace(field.dataType, recordNamespace, field.name)))
SchemaConverters.getNewRecordNamespace(field.dataType, recordNamespace, field.name),
timeUnit))
(item: Any) => {
if (item == null) {
null
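For reference, the writer change above swaps the raw `getTime` call (always epoch milliseconds) for a `TimeUnit.convert` into the configured target unit. A minimal sketch of the same conversion in isolation, with illustrative values (not part of the patch):

```scala
import java.sql.Timestamp
import java.util.concurrent.TimeUnit

val ts = new Timestamp(1451948400000L)

// Old behaviour: always epoch milliseconds.
val millis: Long = ts.getTime  // 1451948400000L

// New behaviour with timeUnit = MICROSECONDS: scale the epoch milliseconds
// into the requested unit before handing the value to Avro.
val micros: Long = TimeUnit.MICROSECONDS.convert(ts.getTime, TimeUnit.MILLISECONDS)
// micros == 1451948400000000L
```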
src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala
@@ -16,6 +16,8 @@

package com.databricks.spark.avro

import java.util.concurrent.TimeUnit

import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
@@ -24,7 +26,8 @@ import org.apache.spark.sql.types.StructType
private[avro] class AvroOutputWriterFactory(
schema: StructType,
recordName: String,
recordNamespace: String) extends OutputWriterFactory {
recordNamespace: String,
timeUnit: TimeUnit) extends OutputWriterFactory {

override def getFileExtension(context: TaskAttemptContext): String = {
".avro"
@@ -34,6 +37,6 @@ private[avro] class AvroOutputWriterFactory(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new AvroOutputWriter(path, context, schema, recordName, recordNamespace)
new AvroOutputWriter(path, context, schema, recordName, recordNamespace, timeUnit)
}
}
5 changes: 4 additions & 1 deletion src/main/scala/com/databricks/spark/avro/DefaultSource.scala
@@ -18,6 +18,7 @@ package com.databricks.spark.avro

import java.io._
import java.net.URI
import java.util.concurrent.TimeUnit
import java.util.zip.Deflater

import scala.util.control.NonFatal
@@ -114,6 +115,8 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister {
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
val recordName = options.getOrElse("recordName", "topLevelRecord")
val timeUnitName = options.getOrElse("timeUnit", "milliseconds")
val timeUnit = TimeUnit.valueOf(timeUnitName.toUpperCase)
val recordNamespace = options.getOrElse("recordNamespace", "")
val build = SchemaBuilder.record(recordName).namespace(recordNamespace)
val outputAvroSchema = SchemaConverters.convertStructToAvro(dataSchema, build, recordNamespace)
@@ -145,7 +148,7 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister {
log.error(s"unsupported compression codec $unknown")
}

new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace)
new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace, timeUnit)
}

override def buildReader(
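The `DefaultSource` change reads the target unit from a new write option, `timeUnit` (default `milliseconds`), and resolves it with `TimeUnit.valueOf` after upper-casing, so any `java.util.concurrent.TimeUnit` constant name is accepted case-insensitively. A hedged usage sketch (`df` and the output path are assumptions; the `.avro` shorthand comes from the package's implicits, as used in the test below):

```scala
import com.databricks.spark.avro._

// Write timestamp/date columns as microseconds since the epoch
// instead of the default milliseconds.
df.write
  .option("timeUnit", "microseconds")
  .avro("/tmp/events")

// Equivalent without the implicit helper:
df.write
  .format("com.databricks.spark.avro")
  .option("timeUnit", "microseconds")
  .save("/tmp/events")
```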
20 changes: 20 additions & 0 deletions src/test/scala/com/databricks/spark/avro/AvroSuite.scala
@@ -322,6 +322,26 @@ class AvroSuite extends FunSuite with BeforeAndAfterAll {
}
}

test("Timeunit conversion") {
TestUtils.withTempDir { dir =>
val schema = StructType(Seq(
StructField("float", FloatType, true),
StructField("timestamp", TimestampType, true)
))
TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
val rdd = spark.sparkContext.parallelize(Seq(
Row(1f, null),
Row(2f, new Timestamp(1451948400000L)),
Row(3f, new Timestamp(1460066400500L))
))
val df = spark.createDataFrame(rdd, schema)
df.write.option("timeUnit", "microseconds").avro(dir.toString)
assert(spark.read.avro(dir.toString).count == rdd.count)
assert(spark.read.avro(dir.toString).select("timestamp").collect().map(_(0)).toSet ==
Array(null, 1451948400000000L, 1460066400500000L).toSet)
}
}

test("Array data types") {
TestUtils.withTempDir { dir =>
val testSchema = StructType(Seq(
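Since the converted values land in the Avro file as plain longs, the test reads them back as raw microsecond counts. A reader that wants `Timestamp`s again has to scale back down; a minimal sketch, assuming the file was written with `timeUnit = microseconds` (the helper name is ours, not part of the patch):

```scala
import java.sql.Timestamp
import java.util.concurrent.TimeUnit

// Hypothetical helper: turn a stored microsecond value back into a Timestamp
// (sub-millisecond precision is truncated).
def microsToTimestamp(micros: Long): Timestamp =
  new Timestamp(TimeUnit.MILLISECONDS.convert(micros, TimeUnit.MICROSECONDS))

// e.g. microsToTimestamp(1451948400000000L) == new Timestamp(1451948400000L)
```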