Skip to content

Commit d43d960

Browse files
committed
refine udf transport error messaging
1 parent e484cc8 commit d43d960

File tree

4 files changed

+468
-6
lines changed

4 files changed

+468
-6
lines changed

src/query/expression/src/utils/udf_client.rs

Lines changed: 171 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::error::Error as StdError;
1516
use std::str::FromStr;
1617
use std::sync::Arc;
1718
use std::time::Duration;
@@ -20,8 +21,10 @@ use std::time::Instant;
2021
use arrow_array::RecordBatch;
2122
use arrow_flight::decode::FlightRecordBatchStream;
2223
use arrow_flight::encode::FlightDataEncoderBuilder;
24+
use arrow_flight::error::FlightError;
2325
use arrow_flight::flight_service_client::FlightServiceClient;
2426
use arrow_flight::FlightDescriptor;
27+
use arrow_schema::ArrowError;
2528
use arrow_select::concat::concat_batches;
2629
use databend_common_base::headers::HEADER_FUNCTION;
2730
use databend_common_base::headers::HEADER_FUNCTION_HANDLER;
@@ -49,6 +52,7 @@ use tonic::transport::channel::Channel;
4952
use tonic::transport::ClientTlsConfig;
5053
use tonic::transport::Endpoint;
5154
use tonic::Request;
55+
use tonic::Status;
5256

5357
use crate::types::DataType;
5458
use crate::variant_transform::contains_variant;
@@ -64,6 +68,21 @@ const UDF_KEEP_ALIVE_TIMEOUT_SEC: u64 = 20;
6468
// 4MB by default, we use 16G
6569
// max_encoding_message_size is usize::max by default
6670
const MAX_DECODING_MESSAGE_SIZE: usize = 16 * 1024 * 1024 * 1024;
71+
const TRANSPORT_ERROR_SNIPPETS: &[&str] = &[
72+
"h2 protocol error",
73+
"broken pipe",
74+
"connection reset",
75+
"error reading a body from connection",
76+
];
77+
78+
#[derive(Debug)]
79+
enum FlightDecodeIssue {
80+
TransportInterrupted,
81+
ServerStatus(String),
82+
SchemaMismatch,
83+
MalformedData,
84+
Other,
85+
}
6786

6887
#[derive(Debug, Clone)]
6988
pub struct UDFFlightClient {
@@ -380,11 +399,7 @@ impl UDFFlightClient {
380399
let record_batch_stream = FlightRecordBatchStream::new_from_flight_data(
381400
flight_data_stream.map_err(|err| err.into()),
382401
)
383-
.map_err(|err| {
384-
ErrorCode::UDFDataError(format!(
385-
"Decode record batch failed on UDF function {func_name}: {err}"
386-
))
387-
});
402+
.map_err(|err| handle_flight_decode_error(func_name, err));
388403

389404
let batches: Vec<RecordBatch> = record_batch_stream.try_collect().await?;
390405
if batches.is_empty() {
@@ -399,6 +414,90 @@ impl UDFFlightClient {
399414
}
400415
}
401416

417+
fn handle_flight_decode_error(func_name: &str, err: FlightError) -> ErrorCode {
418+
let issue = classify_flight_error(&err);
419+
let err_text = err.to_string();
420+
421+
match issue {
422+
FlightDecodeIssue::TransportInterrupted => ErrorCode::UDFDataError(format!(
423+
"The user-defined function \"{func_name}\" stopped responding before it finished. Retry the query; if it keeps failing, ensure the UDF server is running or review its logs. (details: {err_text})"
424+
)),
425+
FlightDecodeIssue::ServerStatus(status) => ErrorCode::UDFDataError(format!(
426+
"The user-defined function \"{func_name}\" reported an error: {status}. Review the UDF server logs."
427+
)),
428+
FlightDecodeIssue::SchemaMismatch => ErrorCode::UDFDataError(format!(
429+
"The user-defined function \"{func_name}\" returned an unexpected schema. Ensure the UDF definition matches the server output. (details: {err_text})"
430+
)),
431+
FlightDecodeIssue::MalformedData => ErrorCode::UDFDataError(format!(
432+
"The user-defined function \"{func_name}\" returned data that Databend could not parse. Check the UDF implementation or its logs. (details: {err_text})"
433+
)),
434+
FlightDecodeIssue::Other => ErrorCode::UDFDataError(format!(
435+
"Decode record batch failed on UDF function {func_name}: {err_text}"
436+
)),
437+
}
438+
}
439+
440+
fn classify_flight_error(err: &FlightError) -> FlightDecodeIssue {
441+
match err {
442+
FlightError::Arrow(arrow_err) => classify_arrow_error(arrow_err),
443+
FlightError::Tonic(status) => classify_status(status),
444+
FlightError::ExternalError(source) => classify_external_error(source.as_ref()),
445+
FlightError::ProtocolError(_) | FlightError::DecodeError(_) => {
446+
FlightDecodeIssue::MalformedData
447+
}
448+
FlightError::NotYetImplemented(_) => FlightDecodeIssue::Other,
449+
}
450+
}
451+
452+
fn classify_arrow_error(err: &ArrowError) -> FlightDecodeIssue {
453+
match err {
454+
ArrowError::SchemaError(_) => FlightDecodeIssue::SchemaMismatch,
455+
ArrowError::ExternalError(source) => classify_external_error(source.as_ref()),
456+
ArrowError::IoError(message, _) => classify_error_message(message),
457+
ArrowError::ParseError(_)
458+
| ArrowError::InvalidArgumentError(_)
459+
| ArrowError::ComputeError(_)
460+
| ArrowError::JsonError(_)
461+
| ArrowError::CsvError(_)
462+
| ArrowError::IpcError(_)
463+
| ArrowError::CDataInterface(_)
464+
| ArrowError::ParquetError(_) => FlightDecodeIssue::MalformedData,
465+
_ => FlightDecodeIssue::Other,
466+
}
467+
}
468+
469+
fn classify_status(status: &Status) -> FlightDecodeIssue {
470+
classify_error_message(status.message())
471+
}
472+
473+
fn classify_external_error(error: &(dyn StdError + Send + Sync + 'static)) -> FlightDecodeIssue {
474+
if let Some(flight_err) = error.downcast_ref::<FlightError>() {
475+
classify_flight_error(flight_err)
476+
} else if let Some(arrow_err) = error.downcast_ref::<ArrowError>() {
477+
classify_arrow_error(arrow_err)
478+
} else if let Some(status) = error.downcast_ref::<Status>() {
479+
classify_status(status)
480+
} else if let Some(io_error) = error.downcast_ref::<std::io::Error>() {
481+
classify_error_message(&io_error.to_string())
482+
} else {
483+
classify_error_message(&error.to_string())
484+
}
485+
}
486+
487+
fn classify_error_message(message: &str) -> FlightDecodeIssue {
488+
if is_transport_error_message(message) {
489+
FlightDecodeIssue::TransportInterrupted
490+
} else {
491+
FlightDecodeIssue::ServerStatus(message.to_string())
492+
}
493+
}
494+
495+
pub fn is_transport_error_message(message: &str) -> bool {
496+
let lower = message.to_ascii_lowercase();
497+
TRANSPORT_ERROR_SNIPPETS
498+
.iter()
499+
.any(|snippet| lower.contains(snippet))
500+
}
402501
pub fn error_kind(message: &str) -> &str {
403502
let message = message.to_ascii_lowercase();
404503
if message.contains("timeout") || message.contains("timedout") {
@@ -418,3 +517,70 @@ pub fn error_kind(message: &str) -> &str {
418517
"Other"
419518
}
420519
}
520+
521+
#[cfg(test)]
522+
mod tests {
523+
use tonic::Code;
524+
525+
use super::*;
526+
527+
#[test]
528+
fn transport_error_returns_interrupt_hint() {
529+
let err = handle_flight_decode_error(
530+
"test_udf",
531+
FlightError::Tonic(Box::new(Status::new(
532+
Code::Internal,
533+
"h2 protocol error: error reading a body from connection",
534+
))),
535+
);
536+
let message = err.message();
537+
assert!(
538+
message.contains("stopped responding before it finished"),
539+
"unexpected transport hint: {message}"
540+
);
541+
}
542+
543+
#[test]
544+
fn server_status_is_preserved() {
545+
let err = handle_flight_decode_error(
546+
"test_udf",
547+
FlightError::Tonic(Box::new(Status::new(
548+
Code::Internal,
549+
"remote handler returned validation error",
550+
))),
551+
);
552+
let message = err.message();
553+
assert!(
554+
message.contains("reported an error: remote handler returned validation error"),
555+
"unexpected server status message: {message}"
556+
);
557+
}
558+
559+
#[test]
560+
fn schema_mismatch_detected() {
561+
let err = handle_flight_decode_error(
562+
"test_udf",
563+
FlightError::Arrow(ArrowError::SchemaError(
564+
"expected Int32, got Utf8".to_string(),
565+
)),
566+
);
567+
let message = err.message();
568+
assert!(
569+
message.contains("returned an unexpected schema"),
570+
"schema mismatch hint missing: {message}"
571+
);
572+
}
573+
574+
#[test]
575+
fn malformed_data_reported() {
576+
let err = handle_flight_decode_error(
577+
"test_udf",
578+
FlightError::Arrow(ArrowError::ParseError("bad payload".to_string())),
579+
);
580+
let message = err.message();
581+
assert!(
582+
message.contains("could not parse"),
583+
"malformed data hint missing: {message}"
584+
);
585+
}
586+
}

src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use databend_common_catalog::table_context::TableContext;
2424
use databend_common_exception::ErrorCode;
2525
use databend_common_exception::Result;
2626
use databend_common_expression::udf_client::error_kind;
27+
use databend_common_expression::udf_client::is_transport_error_message;
2728
use databend_common_expression::udf_client::UDFFlightClient;
2829
use databend_common_expression::BlockEntry;
2930
use databend_common_expression::ColumnBuilder;
@@ -156,7 +157,7 @@ fn retry_on(err: &databend_common_exception::ErrorCode) -> bool {
156157
if err.code() == ErrorCode::U_D_F_DATA_ERROR {
157158
let message = err.message();
158159
// this means the server can't handle the request in 60s
159-
if message.contains("h2 protocol error") {
160+
if is_transport_error_message(&message) {
160161
return false;
161162
}
162163
}

src/query/service/tests/it/pipelines/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
mod executor;
1616
mod filter;
1717
mod transforms;
18+
mod udf_transport;

0 commit comments

Comments
 (0)