Skip to content

Commit bbdf756

Browse files
committed
refine udf transport error messaging
1 parent e484cc8 commit bbdf756

File tree

4 files changed

+402
-6
lines changed

4 files changed

+402
-6
lines changed

src/query/expression/src/utils/udf_client.rs

Lines changed: 172 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::error::Error as StdError;
1516
use std::str::FromStr;
1617
use std::sync::Arc;
1718
use std::time::Duration;
@@ -20,8 +21,10 @@ use std::time::Instant;
2021
use arrow_array::RecordBatch;
2122
use arrow_flight::decode::FlightRecordBatchStream;
2223
use arrow_flight::encode::FlightDataEncoderBuilder;
24+
use arrow_flight::error::FlightError;
2325
use arrow_flight::flight_service_client::FlightServiceClient;
2426
use arrow_flight::FlightDescriptor;
27+
use arrow_schema::ArrowError;
2528
use arrow_select::concat::concat_batches;
2629
use databend_common_base::headers::HEADER_FUNCTION;
2730
use databend_common_base::headers::HEADER_FUNCTION_HANDLER;
@@ -49,6 +52,7 @@ use tonic::transport::channel::Channel;
4952
use tonic::transport::ClientTlsConfig;
5053
use tonic::transport::Endpoint;
5154
use tonic::Request;
55+
use tonic::Status;
5256

5357
use crate::types::DataType;
5458
use crate::variant_transform::contains_variant;
@@ -64,6 +68,21 @@ const UDF_KEEP_ALIVE_TIMEOUT_SEC: u64 = 20;
6468
// 4MB by default, we use 16G
6569
// max_encoding_message_size is usize::max by default
6670
const MAX_DECODING_MESSAGE_SIZE: usize = 16 * 1024 * 1024 * 1024;
71+
const TRANSPORT_ERROR_SNIPPETS: &[&str] = &[
72+
"h2 protocol error",
73+
"broken pipe",
74+
"connection reset",
75+
"error reading a body from connection",
76+
];
77+
78+
#[derive(Debug)]
79+
enum FlightDecodeIssue {
80+
TransportInterrupted,
81+
ServerStatus(String),
82+
SchemaMismatch,
83+
MalformedData,
84+
Other,
85+
}
6786

6887
#[derive(Debug, Clone)]
6988
pub struct UDFFlightClient {
@@ -380,11 +399,7 @@ impl UDFFlightClient {
380399
let record_batch_stream = FlightRecordBatchStream::new_from_flight_data(
381400
flight_data_stream.map_err(|err| err.into()),
382401
)
383-
.map_err(|err| {
384-
ErrorCode::UDFDataError(format!(
385-
"Decode record batch failed on UDF function {func_name}: {err}"
386-
))
387-
});
402+
.map_err(|err| handle_flight_decode_error(func_name, err));
388403

389404
let batches: Vec<RecordBatch> = record_batch_stream.try_collect().await?;
390405
if batches.is_empty() {
@@ -399,6 +414,158 @@ impl UDFFlightClient {
399414
}
400415
}
401416

417+
fn handle_flight_decode_error(func_name: &str, err: FlightError) -> ErrorCode {
418+
let issue = classify_flight_error(&err);
419+
let err_text = err.to_string();
420+
421+
match issue {
422+
FlightDecodeIssue::TransportInterrupted => ErrorCode::UDFDataError(format!(
423+
"The user-defined function \"{func_name}\" stopped responding before it finished. Retry the query; if it keeps failing, ensure the UDF server is running or review its logs. (details: {err_text})"
424+
)),
425+
FlightDecodeIssue::ServerStatus(status) => ErrorCode::UDFDataError(format!(
426+
"The user-defined function \"{func_name}\" reported an error: {status}. Review the UDF server logs."
427+
)),
428+
FlightDecodeIssue::SchemaMismatch => ErrorCode::UDFDataError(format!(
429+
"The user-defined function \"{func_name}\" returned an unexpected schema. Ensure the UDF definition matches the server output. (details: {err_text})"
430+
)),
431+
FlightDecodeIssue::MalformedData => ErrorCode::UDFDataError(format!(
432+
"The user-defined function \"{func_name}\" returned data that Databend could not parse. Check the UDF implementation or its logs. (details: {err_text})"
433+
)),
434+
FlightDecodeIssue::Other => ErrorCode::UDFDataError(format!(
435+
"Decode record batch failed on UDF function {func_name}: {err_text}"
436+
)),
437+
}
438+
}
439+
440+
fn classify_flight_error(err: &FlightError) -> FlightDecodeIssue {
441+
match err {
442+
FlightError::Arrow(arrow_err) => classify_arrow_error(arrow_err),
443+
FlightError::Tonic(status) => classify_status(status),
444+
FlightError::ExternalError(source) => classify_external_error(source.as_ref()),
445+
FlightError::ProtocolError(_) | FlightError::DecodeError(_) => {
446+
FlightDecodeIssue::MalformedData
447+
}
448+
FlightError::NotYetImplemented(_) => FlightDecodeIssue::Other,
449+
}
450+
}
451+
452+
fn classify_arrow_error(err: &ArrowError) -> FlightDecodeIssue {
453+
match err {
454+
ArrowError::SchemaError(_) => FlightDecodeIssue::SchemaMismatch,
455+
ArrowError::ExternalError(source) => classify_external_error(source.as_ref()),
456+
ArrowError::IoError(message, _) => classify_error_message(message),
457+
ArrowError::ParseError(_)
458+
| ArrowError::InvalidArgumentError(_)
459+
| ArrowError::ComputeError(_)
460+
| ArrowError::JsonError(_)
461+
| ArrowError::CsvError(_)
462+
| ArrowError::IpcError(_)
463+
| ArrowError::CDataInterface(_)
464+
| ArrowError::ParquetError(_) => FlightDecodeIssue::MalformedData,
465+
_ => FlightDecodeIssue::Other,
466+
}
467+
}
468+
469+
fn classify_status(status: &Status) -> FlightDecodeIssue {
470+
classify_error_message(status.message())
471+
}
472+
473+
fn classify_external_error(error: &(dyn StdError + Send + Sync + 'static)) -> FlightDecodeIssue {
474+
if let Some(flight_err) = error.downcast_ref::<FlightError>() {
475+
classify_flight_error(flight_err)
476+
} else if let Some(arrow_err) = error.downcast_ref::<ArrowError>() {
477+
classify_arrow_error(arrow_err)
478+
} else if let Some(status) = error.downcast_ref::<Status>() {
479+
classify_status(status)
480+
} else if let Some(io_error) = error.downcast_ref::<std::io::Error>() {
481+
classify_error_message(&io_error.to_string())
482+
} else {
483+
classify_error_message(&error.to_string())
484+
}
485+
}
486+
487+
fn classify_error_message(message: &str) -> FlightDecodeIssue {
488+
if is_transport_error_message(message) {
489+
FlightDecodeIssue::TransportInterrupted
490+
} else {
491+
FlightDecodeIssue::ServerStatus(message.to_string())
492+
}
493+
}
494+
495+
pub fn is_transport_error_message(message: &str) -> bool {
496+
let lower = message.to_ascii_lowercase();
497+
TRANSPORT_ERROR_SNIPPETS
498+
.iter()
499+
.any(|snippet| lower.contains(snippet))
500+
}
501+
502+
#[cfg(test)]
503+
mod tests {
504+
use tonic::Code;
505+
506+
use super::*;
507+
508+
#[test]
509+
fn transport_error_returns_interrupt_hint() {
510+
let err = handle_flight_decode_error(
511+
"test_udf",
512+
FlightError::Tonic(Box::new(Status::new(
513+
Code::Internal,
514+
"h2 protocol error: error reading a body from connection",
515+
))),
516+
);
517+
let message = err.message();
518+
assert!(
519+
message.contains("stopped responding before it finished"),
520+
"unexpected transport hint: {message}"
521+
);
522+
}
523+
524+
#[test]
525+
fn server_status_is_preserved() {
526+
let err = handle_flight_decode_error(
527+
"test_udf",
528+
FlightError::Tonic(Box::new(Status::new(
529+
Code::Internal,
530+
"remote handler returned validation error",
531+
))),
532+
);
533+
let message = err.message();
534+
assert!(
535+
message.contains("reported an error: remote handler returned validation error"),
536+
"unexpected server status message: {message}"
537+
);
538+
}
539+
540+
#[test]
541+
fn schema_mismatch_detected() {
542+
let err = handle_flight_decode_error(
543+
"test_udf",
544+
FlightError::Arrow(ArrowError::SchemaError(
545+
"expected Int32, got Utf8".to_string(),
546+
)),
547+
);
548+
let message = err.message();
549+
assert!(
550+
message.contains("returned an unexpected schema"),
551+
"schema mismatch hint missing: {message}"
552+
);
553+
}
554+
555+
#[test]
556+
fn malformed_data_reported() {
557+
let err = handle_flight_decode_error(
558+
"test_udf",
559+
FlightError::Arrow(ArrowError::ParseError("bad payload".to_string())),
560+
);
561+
let message = err.message();
562+
assert!(
563+
message.contains("could not parse"),
564+
"malformed data hint missing: {message}"
565+
);
566+
}
567+
}
568+
402569
pub fn error_kind(message: &str) -> &str {
403570
let message = message.to_ascii_lowercase();
404571
if message.contains("timeout") || message.contains("timedout") {

src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use databend_common_catalog::table_context::TableContext;
2424
use databend_common_exception::ErrorCode;
2525
use databend_common_exception::Result;
2626
use databend_common_expression::udf_client::error_kind;
27+
use databend_common_expression::udf_client::is_transport_error_message;
2728
use databend_common_expression::udf_client::UDFFlightClient;
2829
use databend_common_expression::BlockEntry;
2930
use databend_common_expression::ColumnBuilder;
@@ -156,7 +157,7 @@ fn retry_on(err: &databend_common_exception::ErrorCode) -> bool {
156157
if err.code() == ErrorCode::U_D_F_DATA_ERROR {
157158
let message = err.message();
158159
// this means the server can't handle the request in 60s
159-
if message.contains("h2 protocol error") {
160+
if is_transport_error_message(&message) {
160161
return false;
161162
}
162163
}

src/query/service/tests/it/pipelines/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
mod executor;
1616
mod filter;
1717
mod transforms;
18+
mod udf_transport;

0 commit comments

Comments
 (0)