def _batch_prediction_template_config() -> types.GeminiTemplateConfig:
    """Build the prompt template shared by the sync and async assessment tests."""
    return types.GeminiTemplateConfig(
        gemini_example=types.GeminiExample(
            contents=[
                {
                    "role": "user",
                    "parts": [{"text": "What is the capital of {name}?"}],
                },
                {
                    "role": "model",
                    "parts": [{"text": "{capital}"}],
                },
            ],
        ),
    )


def test_assess_batch_prediction_resources(client):
    """Sync path: assessment returns a typed batch-prediction usage result."""
    response = client.datasets.assess_batch_prediction_resources(
        dataset_name=DATASET,
        model_name="gemini-2.5-flash-001",
        template_config=_batch_prediction_template_config(),
    )
    assert isinstance(response, types.BatchPredictionResourceUsageAssessmentResult)


@pytest.mark.asyncio
async def test_assess_batch_prediction_resources_async(client):
    """Async path: mirrors the sync test via client.aio."""
    response = await client.aio.datasets.assess_batch_prediction_resources(
        dataset_name=DATASET,
        model_name="gemini-2.5-flash-001",
        template_config=_batch_prediction_template_config(),
    )
    assert isinstance(response, types.BatchPredictionResourceUsageAssessmentResult)
assess_tuning_resources( response["tuningResourceUsageAssessmentResult"], ) + def assess_batch_prediction_resources( + self, + *, + dataset_name: str, + model_name: str, + template_config: Optional[types.GeminiTemplateConfigOrDict] = None, + config: Optional[types.AssessDatasetConfigOrDict] = None, + ) -> types.BatchPredictionResourceUsageAssessmentResult: + """Assess the batch prediction resources required for a given model. + + Args: + dataset_name: + Required. The name of the dataset to assess the batch prediction + resources. + model_name: + Required. The name of the model to assess the batch prediction + resources. + template_config: + Optional. The template config used to assemble the dataset + before assessing the batch prediction resources. If not provided, + the template config attached to the dataset will be used. Required + if no template config is attached to the dataset. + template_config: + Optional. The template config used to assemble the dataset + before assessing the batch prediction resources. If not provided, the + template config attached to the dataset will be used. Required + if no template config is attached to the dataset. + config: + Optional. A configuration for assessing the batch prediction + resources. If not provided, the default configuration will be + used. + + Returns: + A types.BatchPredictionResourceUsageAssessmentResult object + representing the batch prediction resource usage assessment result. + It contains the following keys: + - token_count: The number of tokens in the dataset. + - audio_token_count: The number of audio tokens in the dataset. 
+ + """ + if isinstance(config, dict): + config = types.AssessDatasetConfig(**config) + elif not config: + config = types.AssessDatasetConfig() + + operation = self._assess_multimodal_dataset( + name=dataset_name, + batch_prediction_resource_usage_assessment_config=types.BatchPredictionResourceUsageAssessmentConfig( + model_name=model_name, + ), + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=template_config, + ), + config=config, + ) + response = self._wait_for_operation( + operation=operation, + timeout_seconds=config.timeout, + ) + result = response["batchPredictionResourceUsageAssessmentResult"] + return _datasets_utils.create_from_response( + types.BatchPredictionResourceUsageAssessmentResult, result + ) + class AsyncDatasets(_api_module.BaseModule): @@ -1875,3 +1939,67 @@ async def assess_tuning_resources( types.TuningResourceUsageAssessmentResult, response["tuningResourceUsageAssessmentResult"], ) + + async def assess_batch_prediction_resources( + self, + *, + dataset_name: str, + model_name: str, + template_config: Optional[types.GeminiTemplateConfigOrDict] = None, + config: Optional[types.AssessDatasetConfigOrDict] = None, + ) -> types.BatchPredictionResourceUsageAssessmentResult: + """Assess the batch prediction resources required for a given model. + + Args: + dataset_name: + Required. The name of the dataset to assess the batch prediction + resources. + model_name: + Required. The name of the model to assess the batch prediction + resources. + template_config: + Optional. The template config used to assemble the dataset + before assessing the batch prediction resources. If not provided, + the template config attached to the dataset will be used. Required + if no template config is attached to the dataset. + template_config: + Optional. The template config used to assemble the dataset + before assessing the batch prediction resources. If not provided, the + template config attached to the dataset will be used. 
Required + if no template config is attached to the dataset. + config: + Optional. A configuration for assessing the batch prediction + resources. If not provided, the default configuration will be + used. + + Returns: + A types.BatchPredictionResourceUsageAssessmentResult object + representing the batch prediction resource usage assessment result. + It contains the following keys: + - token_count: The number of tokens in the dataset. + - audio_token_count: The number of audio tokens in the dataset. + + """ + if isinstance(config, dict): + config = types.AssessDatasetConfig(**config) + elif not config: + config = types.AssessDatasetConfig() + + operation = await self._assess_multimodal_dataset( + name=dataset_name, + batch_prediction_resource_usage_assessment_config=types.BatchPredictionResourceUsageAssessmentConfig( + model_name=model_name, + ), + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=template_config, + ), + config=config, + ) + response = await self._wait_for_operation( + operation=operation, + timeout_seconds=config.timeout, + ) + result = response["batchPredictionResourceUsageAssessmentResult"] + return _datasets_utils.create_from_response( + types.BatchPredictionResourceUsageAssessmentResult, result + )