@@ -636,7 +636,11 @@ ModelInstanceState::Run(
636636 " error setting the binding dimension" );
637637
638638 TRITONSERVER_DataType datatype = batch_input.DataType ();
639- size_t total_byte_size = GetByteSize (datatype, shape);
639+ int64_t total_byte_size = 0 ;
640+ FAIL_ALL_AND_RETURN_IF_ERROR (
641+ payload_->requests_ , payload_->request_count_ , payload_->responses_ ,
642+ GetByteSize (datatype, shape, &total_byte_size),
643+ " error getting the batch input byte size" );
640644
641645 const char * dst_buffer;
642646 size_t dst_buffer_byte_size;
@@ -690,7 +694,12 @@ ModelInstanceState::Run(
690694 " '" )
691695 .c_str ());
692696
693- ragged_shape[0 ] += backend::GetElementCount (shape, dims_count);
697+ int64_t element_cnt = 0 ;
698+ FAIL_ALL_AND_RETURN_IF_ERROR (
699+ payload_->requests_ , payload_->request_count_ , payload_->responses_ ,
700+ backend::GetElementCount (shape, dims_count, &element_cnt),
701+ " error getting the input element count" );
702+ ragged_shape[0 ] += element_cnt;
694703 if (req_idx == 0 ) {
695704 datatype = temp_dt;
696705 }
@@ -702,7 +711,11 @@ ModelInstanceState::Run(
                 name, ragged_shape, citr->second, io_index, &input_dims),
             "error setting the binding dimension");
 
-        size_t total_byte_size = GetByteSize(datatype, ragged_shape);
+        int64_t total_byte_size = 0;
+        FAIL_ALL_AND_RETURN_IF_ERROR(
+            payload_->requests_, payload_->request_count_, payload_->responses_,
+            GetByteSize(datatype, ragged_shape, &total_byte_size),
+            "error getting the input byte size");
 
         payload_->collector_->ProcessTensor(
             name.c_str(), static_cast<char*>(io_binding_info.GetBuffer()),
@@ -758,17 +771,23 @@ ModelInstanceState::Run(
758771 " error setting the binding dimension" );
759772 }
760773
761- size_t total_byte_size = 0 ;
774+ int64_t total_byte_size = 0 ;
762775 if (io_binding_info.GetFormat ().is_linear_format_ ) {
763- total_byte_size = GetByteSize (datatype, batchn_shape);
776+ FAIL_ALL_AND_RETURN_IF_ERROR (
777+ payload_->requests_ , payload_->request_count_ , payload_->responses_ ,
778+ GetByteSize (datatype, batchn_shape, &total_byte_size),
779+ " error getting the batch input byte size" );
764780 // For input tensors with a linear IO format, the request has already
765781 // verified the byte size, so no further validation is needed here.
766782 } else {
767783 batchn_shape[io_binding_info.GetFormat ().vectorized_dim_ ] +=
768784 (io_binding_info.GetFormat ().components_per_element_ -
769785 (batchn_shape[io_binding_info.GetFormat ().vectorized_dim_ ] %
770786 io_binding_info.GetFormat ().components_per_element_ ));
771- total_byte_size = GetByteSize (datatype, batchn_shape);
787+ FAIL_ALL_AND_RETURN_IF_ERROR (
788+ payload_->requests_ , payload_->request_count_ , payload_->responses_ ,
789+ GetByteSize (datatype, batchn_shape, &total_byte_size),
790+ " error getting the batch input byte size" );
772791
773792 // Ensure the request data byte size matches the expected byte size for
774793 // non-linear IO format tensors
@@ -823,8 +842,13 @@ ModelInstanceState::Run(
     // Initialize additional entries in batch input
     if (io_binding_info.GetBatchInput() != nullptr) {
       const auto& batch_input = io_binding_info.GetBatchInput()->first;
-      const size_t total_byte_size = GetByteSize(
-          batch_input.DataType(), cuda_graph->input_dims_[input_idx]);
+      int64_t total_byte_size = 0;
+      FAIL_ALL_AND_RETURN_IF_ERROR(
+          payload_->requests_, payload_->request_count_, payload_->responses_,
+          GetByteSize(
+              batch_input.DataType(), cuda_graph->input_dims_[input_idx],
+              &total_byte_size),
+          "error getting the batch input byte size");
 
       auto& allocated_memory = io_binding_info.GetBatchInput()->second;
       TRITONSERVER_MemoryType mem_type = allocated_memory->MemoryType();
@@ -841,7 +865,7 @@ ModelInstanceState::Run(
             batch_input, input_buffer, total_byte_size,
             {{mem_type, mem_type_id}}, &dst_buffer, &dst_buffer_byte_size,
             &dst_memory_type, &dst_memory_type_id),
-        "error setting the bath input value");
+        "error setting the batch input value");
 
       if ((batch_input.BatchInputKind() !=
            BatchInput::Kind::BATCH_MAX_ELEMENT_COUNT_AS_SHAPE) &&
@@ -1067,8 +1091,10 @@ ModelInstanceState::Run(
             batchn_shape[0] = shape[0];
           }
 
-          const size_t tensor_element_cnt =
-              backend::GetElementCount(batchn_shape);
+          int64_t tensor_element_cnt = 0;
+          RESPOND_AND_SET_NULL_IF_ERROR(
+              &response,
+              backend::GetElementCount(batchn_shape, &tensor_element_cnt));
 
           TRITONSERVER_DataType dt = ConvertTrtTypeToDataType(
               engine_->getTensorDataType(name.c_str()));
@@ -1112,7 +1138,11 @@ ModelInstanceState::Run(
           // FIXME process reformat-free output, need to update output
           // process code to accept batch1_byte_size and request batch
           // size to break down output buffer properly.
-          size_t batch1_byte_size = GetByteSize(dt, batchn_shape);
+          int64_t batch1_byte_size = 0;
+          FAIL_ALL_AND_RETURN_IF_ERROR(
+              payload_->requests_, payload_->request_count_, payload_->responses_,
+              GetByteSize(dt, batchn_shape, &batch1_byte_size),
+              "error getting the batch byte size");
           if (support_batching_) {
             batch1_byte_size /= payload_->total_batch_size_;
           }
@@ -1371,7 +1401,9 @@ ModelInstanceState::GetRequestShapeValues(
             .c_str());
     }
 
-    int64_t element_cnt = backend::GetElementCount(shape, dims_count);
+    int64_t element_cnt = 0;
+    RETURN_IF_ERROR(
+        backend::GetElementCount(shape, dims_count, &element_cnt));
     if (support_batching_) {
       element_cnt /= shape[0];
     }
@@ -1481,7 +1513,10 @@ ModelInstanceState::EvaluateTensorRTContext(
       RETURN_IF_ERROR(TRITONBACKEND_InputProperties(
           repr_input, nullptr, nullptr, &shape, &dims_count, nullptr,
           nullptr));
-      shape_vec[0] += backend::GetElementCount(shape, dims_count);
+      int64_t element_cnt = 0;
+      RETURN_IF_ERROR(
+          backend::GetElementCount(shape, dims_count, &element_cnt));
+      shape_vec[0] += element_cnt;
     }
     auto err = ValidateDimension(
         shape_vec, citr->second.min_dims_[io_index],
@@ -2462,7 +2497,8 @@ ModelInstanceState::InitializeConfigShapeOutputBindings(
         context.context_->getTensorShape(io_name.c_str());
     std::vector<int64_t> dim_vec;
     DimsToDimVec(output_dim, &dim_vec);
-    int64_t byte_size = GetByteSize(dt, dim_vec);
+    int64_t byte_size = 0;
+    RETURN_IF_ERROR(GetByteSize(dt, dim_vec, &byte_size));
 
     max_byte_size = std::max(max_byte_size, byte_size);
   }
@@ -2691,13 +2727,13 @@ ModelInstanceState::InitializeExecuteInputBinding(
 
   int64_t byte_size = 0;
   if (io_binding_info.GetFormat().is_linear_format_) {
-    byte_size = GetByteSize(dt, maximum_dims);
+    RETURN_IF_ERROR(GetByteSize(dt, maximum_dims, &byte_size));
   } else {
     maximum_dims[io_binding_info.GetFormat().vectorized_dim_] +=
         (io_binding_info.GetFormat().components_per_element_ -
          (maximum_dims[io_binding_info.GetFormat().vectorized_dim_] %
           io_binding_info.GetFormat().components_per_element_));
-    byte_size = GetByteSize(dt, maximum_dims);
+    RETURN_IF_ERROR(GetByteSize(dt, maximum_dims, &byte_size));
   }
 
   if (byte_size == -1) {
@@ -3097,7 +3133,7 @@ ModelInstanceState::InitializeShapeInputBinding(
     std::vector<int64_t> dim_vec;
     DimsToDimVec(
         context.context_->getTensorShape(input_name.c_str()), &dim_vec);
-    byte_size = GetByteSize(dt, dim_vec);
+    RETURN_IF_ERROR(GetByteSize(dt, dim_vec, &byte_size));
   } else {
     auto component_count = GetElementCount(
         context.context_->getTensorStrides(input_name.c_str()));
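
Taken together, the hunks above apply one pattern throughout: the byte-size and element-count helpers now report failure through a returned status and write their result via an out-parameter, so a bad shape or an arithmetic overflow fails the affected requests instead of flowing onward as a bogus size. Below is a minimal, self-contained sketch of that pattern; the Status struct and ByteSizeOf helper are illustrative stand-ins, not the actual Triton backend API.

// Sketch only: 'Status' and 'ByteSizeOf' are hypothetical names illustrating
// the status-plus-out-parameter pattern used in the diff above.
#include <cstdint>
#include <limits>
#include <string>
#include <vector>

struct Status {
  bool ok = true;
  std::string msg;
};

// Writes the tensor byte size through 'byte_size'; problems (negative dims,
// overflow) are reported via the returned Status rather than a sentinel value.
Status
ByteSizeOf(
    int64_t element_size, const std::vector<int64_t>& shape,
    int64_t* byte_size)
{
  int64_t count = 1;
  for (int64_t d : shape) {
    if (d < 0) {
      return {false, "shape contains a dynamic (-1) dimension"};
    }
    if (d != 0 && count > std::numeric_limits<int64_t>::max() / d) {
      return {false, "element count overflows int64_t"};
    }
    count *= d;
  }
  if (element_size > 0 &&
      count > std::numeric_limits<int64_t>::max() / element_size) {
    return {false, "byte size overflows int64_t"};
  }
  *byte_size = count * element_size;
  return {};
}

Call sites then mirror the RETURN_IF_ERROR / FAIL_ALL_AND_RETURN_IF_ERROR usage in the hunks: check the returned status and only consume the out-parameter on success.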