diff --git a/docs/changelog/136732.yaml b/docs/changelog/136732.yaml new file mode 100644 index 0000000000000..34aa46d6e218c --- /dev/null +++ b/docs/changelog/136732.yaml @@ -0,0 +1,5 @@ +pr: 136732 +summary: Address `CompoundRetrieverBuilder` Failure Handling +area: Search +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java index 8042be444292d..51968af61ee7e 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilder.java @@ -19,6 +19,7 @@ import org.elasticsearch.action.search.MultiSearchResponse; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.action.search.TransportMultiSearchAction; import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.QueryBuilder; @@ -174,9 +175,13 @@ public void onResponse(MultiSearchResponse items) { } } else { assert item.getResponse() != null; - var rankDocs = getRankDocs(item.getResponse()); - innerRetrievers.get(i).retriever().setRankDocs(rankDocs); - topDocs.add(rankDocs); + if (item.getResponse().getFailedShards() > 0) { + statusCode = handleShardFailures(item.getResponse(), statusCode, failures); + } else { + var rankDocs = getRankDocs(item.getResponse()); + innerRetrievers.get(i).retriever().setRankDocs(rankDocs); + topDocs.add(rankDocs); + } } } if (false == failures.isEmpty()) { @@ -212,6 +217,26 @@ public void onFailure(Exception e) { return rankDocsRetrieverBuilder; } + static int handleShardFailures(SearchResponse response, int statusCode, List failures) { + ShardSearchFailure[] shardFailures = response.getShardFailures(); + for (ShardSearchFailure shardFailure : shardFailures) { + if (shardFailure != null) { + int shardFailureStatusCode = ExceptionsHelper.status(shardFailure.getCause()).getStatus(); + failures.add( + new ElasticsearchStatusException( + "failed to retrieve data from shard [" + + shardFailure.shardId() + + "] with message: " + + shardFailure.getCause().getMessage(), + RestStatus.fromCode(shardFailureStatusCode) + ) + ); + statusCode = Math.max(shardFailureStatusCode, statusCode); + } + } + return statusCode; + } + @Override public final QueryBuilder topDocsQuery() { throw new IllegalStateException("Should not be called, missing a rewrite?"); diff --git a/server/src/test/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilderTests.java b/server/src/test/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilderTests.java new file mode 100644 index 0000000000000..6345581a3ac31 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/search/retriever/CompoundRetrieverBuilderTests.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.search.retriever; + +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.ShardSearchFailure; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.test.ESTestCase; +import org.mockito.Mockito; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.mockito.Mockito.when; + +public class CompoundRetrieverBuilderTests extends ESTestCase { + public void testShardFailureHandling() { + SearchResponse response = Mockito.mock(SearchResponse.class); + ShardSearchFailure[] shardFailures = new ShardSearchFailure[2]; + shardFailures[0] = new ShardSearchFailure( + new IOException("some shard failed"), // 500 + new SearchShardTarget("1", new ShardId("1", "1", 1), "foo") + ); + shardFailures[1] = new ShardSearchFailure( + new IOException("some second shard failed"), // 500 + new SearchShardTarget("2", new ShardId("2", "2", 2), "bar") + ); + when(response.getShardFailures()).thenReturn(shardFailures); + + int priorStatusCode = randomIntBetween(200, 600); + List failures = new ArrayList<>(); + int shardFailureStatusCode = CompoundRetrieverBuilder.handleShardFailures(response, priorStatusCode, failures); + + assertEquals(2, failures.size()); + assertEquals("failed to retrieve data from shard [1] with message: some shard failed", failures.get(0).getMessage()); + assertEquals("failed to retrieve data from shard [2] with message: some second shard failed", failures.get(1).getMessage()); + assertEquals(Math.max(priorStatusCode, 500), shardFailureStatusCode); + } +}