
Commit a8891a2

Merge pull request #11 from hydrator/feature/spark2-compatible
Spark2 doesn't have the DataFrame class (as Java class) anymore
2 parents: 5c24ce3 + 1098317

2 files changed (+41 −13 lines)
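For context: in Spark 2.x there is no org.apache.spark.sql.DataFrame Java class anymore; DataFrame became a Scala type alias for Dataset<Row>, so any compile-time Java reference to it breaks. A minimal standalone probe (illustrative, not part of this commit) showing the classpath difference the patch keys off:

    public class SparkVersionProbe {
      public static void main(String[] args) {
        try {
          // Present as a real Java class in Spark 1.x only.
          Class.forName("org.apache.spark.sql.DataFrame");
          System.out.println("Spark 1.x classpath: DataFrame class found");
        } catch (ClassNotFoundException e) {
          // In Spark 2.x, DataFrame is only a Scala type alias for Dataset<Row>.
          System.out.println("Spark 2.x classpath: fall back to org.apache.spark.sql.Dataset");
        }
      }
    }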

pom.xml

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 
   <groupId>co.cask.hydrator</groupId>
   <artifactId>dynamic-spark</artifactId>
-  <version>2.0.2</version>
+  <version>2.0.3-SNAPSHOT</version>
 
   <properties>
     <!-- properties for script build step that creates the config files for the artifacts -->

src/main/java/co/cask/hydrator/plugin/spark/dynamic/ScalaSparkCompute.java

Lines changed: 40 additions & 12 deletions
@@ -37,7 +37,10 @@
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.StructType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.io.PrintWriter;
@@ -55,12 +58,15 @@
 @Description("Executes user-provided Spark code written in Scala that performs RDD to RDD transformation")
 public class ScalaSparkCompute extends SparkCompute<StructuredRecord, StructuredRecord> {
 
+  private static final Logger LOG = LoggerFactory.getLogger(ScalaSparkCompute.class);
+
   private static final String CLASS_NAME_PREFIX = "co.cask.hydrator.plugin.spark.dynamic.generated.UserSparkCompute$";
+  private static final Class<?> DATAFRAME_TYPE = getDataFrameType();
   private static final Class<?>[][] ACCEPTABLE_PARAMETER_TYPES = new Class<?>[][] {
     { RDD.class, SparkExecutionPluginContext.class },
     { RDD.class },
-    { DataFrame.class, SparkExecutionPluginContext.class},
-    { DataFrame.class }
+    { DATAFRAME_TYPE, SparkExecutionPluginContext.class},
+    { DATAFRAME_TYPE }
   };
 
   private final ThreadLocal<SQLContext> sqlContextThreadLocal = new InheritableThreadLocal<>();
@@ -102,7 +108,7 @@ public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws Ille
     Method method = getTransformMethod(interpreter.getClassLoader(), className);
 
     // If the method takes DataFrame, make sure it has input schema
-    if (method.getParameterTypes()[0].equals(DataFrame.class) && stageConfigurer.getInputSchema() == null) {
+    if (method.getParameterTypes()[0].equals(DATAFRAME_TYPE) && stageConfigurer.getInputSchema() == null) {
       throw new IllegalArgumentException("Missing input schema for transformation using DataFrame");
     }
 
@@ -119,7 +125,7 @@ public void initialize(SparkExecutionPluginContext context) throws Exception {
     interpreter = context.createSparkInterpreter();
     interpreter.compile(generateSourceClass(className));
     method = getTransformMethod(interpreter.getClassLoader(), className);
-    isDataFrame = method.getParameterTypes()[0].equals(DataFrame.class);
+    isDataFrame = method.getParameterTypes()[0].equals(DATAFRAME_TYPE);
     takeContext = method.getParameterTypes().length == 2;
 
     // Input schema shouldn't be null
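A detail worth noting (my reading, not stated in the commit): if neither Spark class can be loaded, DATAFRAME_TYPE is null, yet the equals checks above stay safe, because x.equals(null) is specified to return false for any non-null x. The DataFrame-based signatures then simply never match. A dependency-free sketch:

    public class NullSafeEquals {
      public static void main(String[] args) {
        Class<?> dataFrameType = null;          // as when neither Spark class is on the classpath
        Class<?> declaredParam = String.class;  // stand-in for a user method's first parameter type
        // Object.equals(null) returns false, so no NPE here:
        System.out.println(declaredParam.equals(dataFrameType)); // prints "false"
      }
    }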
@@ -154,18 +160,18 @@ public JavaRDD<StructuredRecord> transform(SparkExecutionPluginContext context,
     StructType rowType = DataFrames.toDataType(inputSchema);
     JavaRDD<Row> rowRDD = javaRDD.map(new RecordToRow(rowType));
 
-    DataFrame dataFrame = sqlContext.createDataFrame(rowRDD, rowType);
-    DataFrame result = (DataFrame) (takeContext ?
-                                      method.invoke(null, dataFrame, context) : method.invoke(null, dataFrame));
+    Object dataFrame = sqlContext.createDataFrame(rowRDD, rowType);
+    Object result = takeContext ? method.invoke(null, dataFrame, context) : method.invoke(null, dataFrame);
 
     // Convert the DataFrame back to RDD<StructureRecord>
     Schema outputSchema = context.getOutputSchema();
     if (outputSchema == null) {
       // If there is no output schema configured, derive it from the DataFrame
       // Otherwise, assume the DataFrame has the correct schema already
-      outputSchema = DataFrames.toSchema(result.schema());
+      outputSchema = DataFrames.toSchema((DataType) invokeDataFrameMethod(result, "schema"));
     }
-    return result.toJavaRDD().map(new RowToRecord(outputSchema));
+    //noinspection unchecked
+    return ((JavaRDD<Row>) invokeDataFrameMethod(result, "toJavaRDD")).map(new RowToRecord(outputSchema));
   }
 
   private String generateSourceClass(String className) {
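The reflective calls work because both Spark 1's DataFrame and Spark 2's Dataset<Row> expose schema() and toJavaRDD(); looking the method up on the runtime class of the object sidesteps the compile-time type entirely. A dependency-free sketch of the same dispatch pattern (names here are illustrative, not from the commit):

    import java.lang.reflect.Method;

    public class ReflectiveCall {
      // Same shape as invokeDataFrameMethod later in this diff: resolve a
      // no-arg method on the runtime class of the target and invoke it.
      @SuppressWarnings("unchecked")
      static <T> T invokeNoArg(Object target, String methodName) throws Exception {
        Method m = target.getClass().getMethod(methodName);
        return (T) m.invoke(target);
      }

      public static void main(String[] args) throws Exception {
        // Stand-in target: String.length() plays the role of DataFrame.schema().
        int len = ReflectiveCall.<Integer>invokeNoArg("hello", "length");
        System.out.println(len); // prints 5
      }
    }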
@@ -251,7 +257,7 @@ private Method getTransformMethod(ClassLoader classLoader, String className) {
     Type[] parameterTypes = method.getGenericParameterTypes();
 
     // The first parameter should be of type RDD[StructuredRecord] if it takes RDD
-    if (!parameterTypes[0].equals(DataFrame.class)) {
+    if (!parameterTypes[0].equals(DATAFRAME_TYPE)) {
       validateRDDType(parameterTypes[0],
                       "The first parameter of the 'transform' method should have type as 'RDD[StructuredRecord]'");
     }
@@ -264,8 +270,8 @@ private Method getTransformMethod(ClassLoader classLoader, String className) {
 
     // The return type of the method must be RDD[StructuredRecord] if it takes RDD
     // Or it must be DataFrame if it takes DataFrame
-    if (parameterTypes[0].equals(DataFrame.class)) {
-      if (!method.getReturnType().equals(DataFrame.class)) {
+    if (parameterTypes[0].equals(DATAFRAME_TYPE)) {
+      if (!method.getReturnType().equals(DATAFRAME_TYPE)) {
         throw new IllegalArgumentException("The return type of the 'transform' method should be 'DataFrame'");
       }
     } else {
@@ -388,4 +394,26 @@ public StructuredRecord call(Row row) throws Exception {
       return DataFrames.fromRow(row, schema);
     }
   }
+
+  @Nullable
+  private static Class<?> getDataFrameType() {
+    // For Spark1, it has the DataFrame class
+    // For Spark2, there is no more DataFrame class, and it becomes Dataset<Row>
+    try {
+      return ScalaSparkCompute.class.getClassLoader().loadClass("org.apache.spark.sql.DataFrame");
+    } catch (ClassNotFoundException e) {
+      try {
+        return ScalaSparkCompute.class.getClassLoader().loadClass("org.apache.spark.sql.Dataset");
+      } catch (ClassNotFoundException e1) {
+        LOG.warn("Failed to determine the type of Spark DataFrame. " +
+                   "DataFrame is not supported in the ScalaSparkCompute plugin.");
+        return null;
+      }
+    }
+  }
+
+  private static <T> T invokeDataFrameMethod(Object dataFrame, String methodName) throws Exception {
+    //noinspection unchecked
+    return (T) dataFrame.getClass().getMethod(methodName).invoke(dataFrame);
+  }
 }
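The probe in getDataFrameType() generalizes naturally if other plugins need the same trick. A hypothetical helper (not part of this commit; names are mine) that returns the first class a loader can resolve:

    public final class ClassProbe {
      private ClassProbe() { }

      // Returns the first candidate class the given loader can load,
      // or null when none is present on the classpath.
      public static Class<?> firstAvailable(ClassLoader loader, String... names) {
        for (String name : names) {
          try {
            return loader.loadClass(name);
          } catch (ClassNotFoundException e) {
            // fall through and try the next candidate
          }
        }
        return null;
      }
    }

With this, DATAFRAME_TYPE would be firstAvailable(ScalaSparkCompute.class.getClassLoader(), "org.apache.spark.sql.DataFrame", "org.apache.spark.sql.Dataset"), keeping the Spark 1 class as the first choice.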
