3737import org .apache .spark .sql .DataFrame ;
3838import org .apache .spark .sql .Row ;
3939import org .apache .spark .sql .SQLContext ;
40+ import org .apache .spark .sql .types .DataType ;
4041import org .apache .spark .sql .types .StructType ;
42+ import org .slf4j .Logger ;
43+ import org .slf4j .LoggerFactory ;
4144
4245import java .io .IOException ;
4346import java .io .PrintWriter ;
5558@ Description ("Executes user-provided Spark code written in Scala that performs RDD to RDD transformation" )
5659public class ScalaSparkCompute extends SparkCompute <StructuredRecord , StructuredRecord > {
5760
61+ private static final Logger LOG = LoggerFactory .getLogger (ScalaSparkCompute .class );
62+
5863 private static final String CLASS_NAME_PREFIX = "co.cask.hydrator.plugin.spark.dynamic.generated.UserSparkCompute$" ;
64+ private static final Class <?> DATAFRAME_TYPE = getDataFrameType ();
5965 private static final Class <?>[][] ACCEPTABLE_PARAMETER_TYPES = new Class <?>[][] {
6066 { RDD .class , SparkExecutionPluginContext .class },
6167 { RDD .class },
62- { DataFrame . class , SparkExecutionPluginContext .class },
63- { DataFrame . class }
68+ { DATAFRAME_TYPE , SparkExecutionPluginContext .class },
69+ { DATAFRAME_TYPE }
6470 };
6571
6672 private final ThreadLocal <SQLContext > sqlContextThreadLocal = new InheritableThreadLocal <>();
@@ -102,7 +108,7 @@ public void configurePipeline(PipelineConfigurer pipelineConfigurer) throws Ille
102108 Method method = getTransformMethod (interpreter .getClassLoader (), className );
103109
104110 // If the method takes DataFrame, make sure it has input schema
105- if (method .getParameterTypes ()[0 ].equals (DataFrame . class ) && stageConfigurer .getInputSchema () == null ) {
111+ if (method .getParameterTypes ()[0 ].equals (DATAFRAME_TYPE ) && stageConfigurer .getInputSchema () == null ) {
106112 throw new IllegalArgumentException ("Missing input schema for transformation using DataFrame" );
107113 }
108114
@@ -119,7 +125,7 @@ public void initialize(SparkExecutionPluginContext context) throws Exception {
119125 interpreter = context .createSparkInterpreter ();
120126 interpreter .compile (generateSourceClass (className ));
121127 method = getTransformMethod (interpreter .getClassLoader (), className );
122- isDataFrame = method .getParameterTypes ()[0 ].equals (DataFrame . class );
128+ isDataFrame = method .getParameterTypes ()[0 ].equals (DATAFRAME_TYPE );
123129 takeContext = method .getParameterTypes ().length == 2 ;
124130
125131 // Input schema shouldn't be null
@@ -154,18 +160,18 @@ public JavaRDD<StructuredRecord> transform(SparkExecutionPluginContext context,
154160 StructType rowType = DataFrames .toDataType (inputSchema );
155161 JavaRDD <Row > rowRDD = javaRDD .map (new RecordToRow (rowType ));
156162
157- DataFrame dataFrame = sqlContext .createDataFrame (rowRDD , rowType );
158- DataFrame result = (DataFrame ) (takeContext ?
159- method .invoke (null , dataFrame , context ) : method .invoke (null , dataFrame ));
163+ Object dataFrame = sqlContext .createDataFrame (rowRDD , rowType );
164+ Object result = takeContext ? method .invoke (null , dataFrame , context ) : method .invoke (null , dataFrame );
160165
161166 // Convert the DataFrame back to RDD<StructureRecord>
162167 Schema outputSchema = context .getOutputSchema ();
163168 if (outputSchema == null ) {
164169 // If there is no output schema configured, derive it from the DataFrame
165170 // Otherwise, assume the DataFrame has the correct schema already
166- outputSchema = DataFrames .toSchema (result . schema ( ));
171+ outputSchema = DataFrames .toSchema (( DataType ) invokeDataFrameMethod ( result , " schema" ));
167172 }
168- return result .toJavaRDD ().map (new RowToRecord (outputSchema ));
173+ //noinspection unchecked
174+ return ((JavaRDD <Row >) invokeDataFrameMethod (result , "toJavaRDD" )).map (new RowToRecord (outputSchema ));
169175 }
170176
171177 private String generateSourceClass (String className ) {
@@ -251,7 +257,7 @@ private Method getTransformMethod(ClassLoader classLoader, String className) {
251257 Type [] parameterTypes = method .getGenericParameterTypes ();
252258
253259 // The first parameter should be of type RDD[StructuredRecord] if it takes RDD
254- if (!parameterTypes [0 ].equals (DataFrame . class )) {
260+ if (!parameterTypes [0 ].equals (DATAFRAME_TYPE )) {
255261 validateRDDType (parameterTypes [0 ],
256262 "The first parameter of the 'transform' method should have type as 'RDD[StructuredRecord]'" );
257263 }
@@ -264,8 +270,8 @@ private Method getTransformMethod(ClassLoader classLoader, String className) {
264270
265271 // The return type of the method must be RDD[StructuredRecord] if it takes RDD
266272 // Or it must be DataFrame if it takes DataFrame
267- if (parameterTypes [0 ].equals (DataFrame . class )) {
268- if (!method .getReturnType ().equals (DataFrame . class )) {
273+ if (parameterTypes [0 ].equals (DATAFRAME_TYPE )) {
274+ if (!method .getReturnType ().equals (DATAFRAME_TYPE )) {
269275 throw new IllegalArgumentException ("The return type of the 'transform' method should be 'DataFrame'" );
270276 }
271277 } else {
@@ -388,4 +394,26 @@ public StructuredRecord call(Row row) throws Exception {
388394 return DataFrames .fromRow (row , schema );
389395 }
390396 }
397+
398+ @ Nullable
399+ private static Class <?> getDataFrameType () {
400+ // For Spark1, it has the DataFrame class
401+ // For Spark2, there is no more DataFrame class, and it becomes Dataset<Row>
402+ try {
403+ return ScalaSparkCompute .class .getClassLoader ().loadClass ("org.apache.spark.sql.DataFrame" );
404+ } catch (ClassNotFoundException e ) {
405+ try {
406+ return ScalaSparkCompute .class .getClassLoader ().loadClass ("org.apache.spark.sql.Dataset" );
407+ } catch (ClassNotFoundException e1 ) {
408+ LOG .warn ("Failed to determine the type of Spark DataFrame. " +
409+ "DataFrame is not supported in the ScalaSparkCompute plugin." );
410+ return null ;
411+ }
412+ }
413+ }
414+
415+ private static <T > T invokeDataFrameMethod (Object dataFrame , String methodName ) throws Exception {
416+ //noinspection unchecked
417+ return (T ) dataFrame .getClass ().getMethod (methodName ).invoke (dataFrame );
418+ }
391419}
0 commit comments