Skip to content

Commit 9f16092

Browse files
committed
update action test
1 parent f33473e commit 9f16092

File tree

2 files changed

+110
-10
lines changed

2 files changed

+110
-10
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBasicLicenseIT.java

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,6 @@ public void testUpdateModel_RestrictedWithBasicLicense() throws Exception {
5757
sendRestrictedRequest("PUT", endpoint, requestBody);
5858
}
5959

60-
public void testPerformInference_RestrictedWithBasicLicense() throws Exception {
61-
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
62-
var requestBody = """
63-
{
64-
"input": ["washing", "machine"]
65-
}
66-
""";
67-
sendRestrictedRequest("POST", endpoint, requestBody);
68-
}
69-
7060
public void testGetServices_NonRestrictedWithBasicLicense() throws Exception {
7161
var endpoint = "_inference/_services";
7262
sendNonRestrictedRequest("GET", endpoint, null, 200, false);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.action;
9+
10+
import org.elasticsearch.ElasticsearchSecurityException;
11+
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.action.support.ActionFilters;
13+
import org.elasticsearch.action.support.PlainActionFuture;
14+
import org.elasticsearch.client.internal.Client;
15+
import org.elasticsearch.cluster.ClusterState;
16+
import org.elasticsearch.cluster.project.TestProjectResolvers;
17+
import org.elasticsearch.cluster.service.ClusterService;
18+
import org.elasticsearch.common.bytes.BytesArray;
19+
import org.elasticsearch.core.TimeValue;
20+
import org.elasticsearch.inference.InferenceService;
21+
import org.elasticsearch.inference.InferenceServiceRegistry;
22+
import org.elasticsearch.inference.TaskType;
23+
import org.elasticsearch.inference.UnparsedModel;
24+
import org.elasticsearch.license.MockLicenseState;
25+
import org.elasticsearch.tasks.Task;
26+
import org.elasticsearch.test.ESTestCase;
27+
import org.elasticsearch.threadpool.ThreadPool;
28+
import org.elasticsearch.transport.TransportService;
29+
import org.elasticsearch.xcontent.XContentType;
30+
import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction;
31+
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
32+
import org.junit.Before;
33+
34+
import java.util.Map;
35+
import java.util.Optional;
36+
37+
import static org.hamcrest.Matchers.is;
38+
import static org.mockito.ArgumentMatchers.any;
39+
import static org.mockito.ArgumentMatchers.anyString;
40+
import static org.mockito.Mockito.doAnswer;
41+
import static org.mockito.Mockito.mock;
42+
import static org.mockito.Mockito.when;
43+
44+
public class TransportUpdateInferenceModelActionTests extends ESTestCase {
45+
46+
private MockLicenseState licenseState;
47+
private TransportUpdateInferenceModelAction action;
48+
private ThreadPool threadPool;
49+
private ModelRegistry mockModelRegistry;
50+
private InferenceServiceRegistry mockInferenceServiceRegistry;
51+
52+
@Before
53+
public void createAction() throws Exception {
54+
super.setUp();
55+
threadPool = mock(ThreadPool.class);
56+
mockModelRegistry = mock(ModelRegistry.class);
57+
mockInferenceServiceRegistry = mock(InferenceServiceRegistry.class);
58+
licenseState = MockLicenseState.createMock();
59+
action = new TransportUpdateInferenceModelAction(
60+
mock(TransportService.class),
61+
mock(ClusterService.class),
62+
threadPool,
63+
mock(ActionFilters.class),
64+
licenseState,
65+
mockModelRegistry,
66+
mockInferenceServiceRegistry,
67+
mock(Client.class),
68+
TestProjectResolvers.DEFAULT_PROJECT_ONLY
69+
);
70+
71+
}
72+
73+
public void testLicenseCheck_NotAllowed() {
74+
mocks("enterprise_licensed_service", false);
75+
76+
var listener = new PlainActionFuture<UpdateInferenceModelAction.Response>();
77+
78+
String requestBody = "{\"service_settings\": {\"api_key\": \"<API_KEY>\"}}";
79+
80+
action.masterOperation(
81+
mock(Task.class),
82+
new UpdateInferenceModelAction.Request(
83+
"model-id",
84+
new BytesArray(requestBody),
85+
XContentType.JSON,
86+
TaskType.TEXT_EMBEDDING,
87+
TimeValue.timeValueSeconds(1)
88+
),
89+
ClusterState.EMPTY_STATE,
90+
listener
91+
);
92+
93+
var exception = expectThrows(ElasticsearchSecurityException.class, () -> listener.actionGet(TimeValue.timeValueSeconds(5)));
94+
assertThat(exception.getMessage(), is("current license is non-compliant for [inference]"));
95+
}
96+
97+
private void mocks(String serviceName, boolean isAllowed) {
98+
doAnswer(invocationOnMock -> {
99+
ActionListener<UnparsedModel> listener = invocationOnMock.getArgument(1);
100+
listener.onResponse(new UnparsedModel("model_id", TaskType.COMPLETION, serviceName, Map.of(), Map.of()));
101+
return Void.TYPE;
102+
}).when(mockModelRegistry).getModelWithSecrets(anyString(), any());
103+
104+
var mockService = mock(InferenceService.class);
105+
when(mockService.name()).thenReturn(serviceName);
106+
when(mockInferenceServiceRegistry.getService(anyString())).thenReturn(Optional.of(mockService));
107+
108+
when(licenseState.isAllowed(any())).thenReturn(isAllowed);
109+
}
110+
}

0 commit comments

Comments
 (0)