1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15+ use std:: error:: Error as StdError ;
1516use std:: str:: FromStr ;
1617use std:: sync:: Arc ;
1718use std:: time:: Duration ;
@@ -20,8 +21,10 @@ use std::time::Instant;
2021use arrow_array:: RecordBatch ;
2122use arrow_flight:: decode:: FlightRecordBatchStream ;
2223use arrow_flight:: encode:: FlightDataEncoderBuilder ;
24+ use arrow_flight:: error:: FlightError ;
2325use arrow_flight:: flight_service_client:: FlightServiceClient ;
2426use arrow_flight:: FlightDescriptor ;
27+ use arrow_schema:: ArrowError ;
2528use arrow_select:: concat:: concat_batches;
2629use databend_common_base:: headers:: HEADER_FUNCTION ;
2730use databend_common_base:: headers:: HEADER_FUNCTION_HANDLER ;
@@ -49,6 +52,7 @@ use tonic::transport::channel::Channel;
4952use tonic:: transport:: ClientTlsConfig ;
5053use tonic:: transport:: Endpoint ;
5154use tonic:: Request ;
55+ use tonic:: Status ;
5256
5357use crate :: types:: DataType ;
5458use crate :: variant_transform:: contains_variant;
@@ -64,6 +68,21 @@ const UDF_KEEP_ALIVE_TIMEOUT_SEC: u64 = 20;
6468// 4MB by default, we use 16G
6569// max_encoding_message_size is usize::max by default
6670const MAX_DECODING_MESSAGE_SIZE : usize = 16 * 1024 * 1024 * 1024 ;
71+ const TRANSPORT_ERROR_SNIPPETS : & [ & str ] = & [
72+ "h2 protocol error" ,
73+ "broken pipe" ,
74+ "connection reset" ,
75+ "error reading a body from connection" ,
76+ ] ;
77+
78+ #[ derive( Debug ) ]
79+ enum FlightDecodeIssue {
80+ TransportInterrupted ,
81+ ServerStatus ( String ) ,
82+ SchemaMismatch ,
83+ MalformedData ,
84+ Other ,
85+ }
6786
6887#[ derive( Debug , Clone ) ]
6988pub struct UDFFlightClient {
@@ -380,11 +399,7 @@ impl UDFFlightClient {
380399 let record_batch_stream = FlightRecordBatchStream :: new_from_flight_data (
381400 flight_data_stream. map_err ( |err| err. into ( ) ) ,
382401 )
383- . map_err ( |err| {
384- ErrorCode :: UDFDataError ( format ! (
385- "Decode record batch failed on UDF function {func_name}: {err}"
386- ) )
387- } ) ;
402+ . map_err ( |err| handle_flight_decode_error ( func_name, err) ) ;
388403
389404 let batches: Vec < RecordBatch > = record_batch_stream. try_collect ( ) . await ?;
390405 if batches. is_empty ( ) {
@@ -399,6 +414,90 @@ impl UDFFlightClient {
399414 }
400415}
401416
417+ fn handle_flight_decode_error ( func_name : & str , err : FlightError ) -> ErrorCode {
418+ let issue = classify_flight_error ( & err) ;
419+ let err_text = err. to_string ( ) ;
420+
421+ match issue {
422+ FlightDecodeIssue :: TransportInterrupted => ErrorCode :: UDFDataError ( format ! (
423+ "The user-defined function \" {func_name}\" stopped responding before it finished. Retry the query; if it keeps failing, ensure the UDF server is running or review its logs. (details: {err_text})"
424+ ) ) ,
425+ FlightDecodeIssue :: ServerStatus ( status) => ErrorCode :: UDFDataError ( format ! (
426+ "The user-defined function \" {func_name}\" reported an error: {status}. Review the UDF server logs."
427+ ) ) ,
428+ FlightDecodeIssue :: SchemaMismatch => ErrorCode :: UDFDataError ( format ! (
429+ "The user-defined function \" {func_name}\" returned an unexpected schema. Ensure the UDF definition matches the server output. (details: {err_text})"
430+ ) ) ,
431+ FlightDecodeIssue :: MalformedData => ErrorCode :: UDFDataError ( format ! (
432+ "The user-defined function \" {func_name}\" returned data that Databend could not parse. Check the UDF implementation or its logs. (details: {err_text})"
433+ ) ) ,
434+ FlightDecodeIssue :: Other => ErrorCode :: UDFDataError ( format ! (
435+ "Decode record batch failed on UDF function {func_name}: {err_text}"
436+ ) ) ,
437+ }
438+ }
439+
440+ fn classify_flight_error ( err : & FlightError ) -> FlightDecodeIssue {
441+ match err {
442+ FlightError :: Arrow ( arrow_err) => classify_arrow_error ( arrow_err) ,
443+ FlightError :: Tonic ( status) => classify_status ( status) ,
444+ FlightError :: ExternalError ( source) => classify_external_error ( source. as_ref ( ) ) ,
445+ FlightError :: ProtocolError ( _) | FlightError :: DecodeError ( _) => {
446+ FlightDecodeIssue :: MalformedData
447+ }
448+ FlightError :: NotYetImplemented ( _) => FlightDecodeIssue :: Other ,
449+ }
450+ }
451+
452+ fn classify_arrow_error ( err : & ArrowError ) -> FlightDecodeIssue {
453+ match err {
454+ ArrowError :: SchemaError ( _) => FlightDecodeIssue :: SchemaMismatch ,
455+ ArrowError :: ExternalError ( source) => classify_external_error ( source. as_ref ( ) ) ,
456+ ArrowError :: IoError ( message, _) => classify_error_message ( message) ,
457+ ArrowError :: ParseError ( _)
458+ | ArrowError :: InvalidArgumentError ( _)
459+ | ArrowError :: ComputeError ( _)
460+ | ArrowError :: JsonError ( _)
461+ | ArrowError :: CsvError ( _)
462+ | ArrowError :: IpcError ( _)
463+ | ArrowError :: CDataInterface ( _)
464+ | ArrowError :: ParquetError ( _) => FlightDecodeIssue :: MalformedData ,
465+ _ => FlightDecodeIssue :: Other ,
466+ }
467+ }
468+
469+ fn classify_status ( status : & Status ) -> FlightDecodeIssue {
470+ classify_error_message ( status. message ( ) )
471+ }
472+
473+ fn classify_external_error ( error : & ( dyn StdError + Send + Sync + ' static ) ) -> FlightDecodeIssue {
474+ if let Some ( flight_err) = error. downcast_ref :: < FlightError > ( ) {
475+ classify_flight_error ( flight_err)
476+ } else if let Some ( arrow_err) = error. downcast_ref :: < ArrowError > ( ) {
477+ classify_arrow_error ( arrow_err)
478+ } else if let Some ( status) = error. downcast_ref :: < Status > ( ) {
479+ classify_status ( status)
480+ } else if let Some ( io_error) = error. downcast_ref :: < std:: io:: Error > ( ) {
481+ classify_error_message ( & io_error. to_string ( ) )
482+ } else {
483+ classify_error_message ( & error. to_string ( ) )
484+ }
485+ }
486+
487+ fn classify_error_message ( message : & str ) -> FlightDecodeIssue {
488+ if is_transport_error_message ( message) {
489+ FlightDecodeIssue :: TransportInterrupted
490+ } else {
491+ FlightDecodeIssue :: ServerStatus ( message. to_string ( ) )
492+ }
493+ }
494+
495+ pub fn is_transport_error_message ( message : & str ) -> bool {
496+ let lower = message. to_ascii_lowercase ( ) ;
497+ TRANSPORT_ERROR_SNIPPETS
498+ . iter ( )
499+ . any ( |snippet| lower. contains ( snippet) )
500+ }
402501pub fn error_kind ( message : & str ) -> & str {
403502 let message = message. to_ascii_lowercase ( ) ;
404503 if message. contains ( "timeout" ) || message. contains ( "timedout" ) {
@@ -418,3 +517,70 @@ pub fn error_kind(message: &str) -> &str {
418517 "Other"
419518 }
420519}
520+
521+ #[ cfg( test) ]
522+ mod tests {
523+ use tonic:: Code ;
524+
525+ use super :: * ;
526+
527+ #[ test]
528+ fn transport_error_returns_interrupt_hint ( ) {
529+ let err = handle_flight_decode_error (
530+ "test_udf" ,
531+ FlightError :: Tonic ( Box :: new ( Status :: new (
532+ Code :: Internal ,
533+ "h2 protocol error: error reading a body from connection" ,
534+ ) ) ) ,
535+ ) ;
536+ let message = err. message ( ) ;
537+ assert ! (
538+ message. contains( "stopped responding before it finished" ) ,
539+ "unexpected transport hint: {message}"
540+ ) ;
541+ }
542+
543+ #[ test]
544+ fn server_status_is_preserved ( ) {
545+ let err = handle_flight_decode_error (
546+ "test_udf" ,
547+ FlightError :: Tonic ( Box :: new ( Status :: new (
548+ Code :: Internal ,
549+ "remote handler returned validation error" ,
550+ ) ) ) ,
551+ ) ;
552+ let message = err. message ( ) ;
553+ assert ! (
554+ message. contains( "reported an error: remote handler returned validation error" ) ,
555+ "unexpected server status message: {message}"
556+ ) ;
557+ }
558+
559+ #[ test]
560+ fn schema_mismatch_detected ( ) {
561+ let err = handle_flight_decode_error (
562+ "test_udf" ,
563+ FlightError :: Arrow ( ArrowError :: SchemaError (
564+ "expected Int32, got Utf8" . to_string ( ) ,
565+ ) ) ,
566+ ) ;
567+ let message = err. message ( ) ;
568+ assert ! (
569+ message. contains( "returned an unexpected schema" ) ,
570+ "schema mismatch hint missing: {message}"
571+ ) ;
572+ }
573+
574+ #[ test]
575+ fn malformed_data_reported ( ) {
576+ let err = handle_flight_decode_error (
577+ "test_udf" ,
578+ FlightError :: Arrow ( ArrowError :: ParseError ( "bad payload" . to_string ( ) ) ) ,
579+ ) ;
580+ let message = err. message ( ) ;
581+ assert ! (
582+ message. contains( "could not parse" ) ,
583+ "malformed data hint missing: {message}"
584+ ) ;
585+ }
586+ }
0 commit comments