@@ -66,6 +66,28 @@ def test_assess_tuning_resources(client):
assert isinstance(response, types.TuningResourceUsageAssessmentResult)


def test_assess_batch_prediction_resources(client):
response = client.datasets.assess_batch_prediction_resources(
dataset_name=DATASET,
model_name="gemini-2.5-flash-001",
template_config=types.GeminiTemplateConfig(
gemini_example=types.GeminiExample(
contents=[
{
"role": "user",
"parts": [{"text": "What is the capital of {name}?"}],
},
{
"role": "model",
"parts": [{"text": "{capital}"}],
},
],
),
),
)
assert isinstance(response, types.BatchPredictionResourceUsageAssessmentResult)


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
@@ -114,3 +136,26 @@ async def test_assess_tuning_resources_async(client):
),
)
assert isinstance(response, types.TuningResourceUsageAssessmentResult)


@pytest.mark.asyncio
async def test_assess_batch_prediction_resources_async(client):
response = await client.aio.datasets.assess_batch_prediction_resources(
dataset_name=DATASET,
model_name="gemini-2.5-flash-001",
template_config=types.GeminiTemplateConfig(
gemini_example=types.GeminiExample(
contents=[
{
"role": "user",
"parts": [{"text": "What is the capital of {name}?"}],
},
{
"role": "model",
"parts": [{"text": "{capital}"}],
},
],
),
),
)
assert isinstance(response, types.BatchPredictionResourceUsageAssessmentResult)
vertexai/_genai/datasets.py (128 additions, 0 deletions)
@@ -1054,6 +1054,70 @@ def assess_tuning_resources(
response["tuningResourceUsageAssessmentResult"],
)

def assess_batch_prediction_resources(
self,
*,
dataset_name: str,
model_name: str,
template_config: Optional[types.GeminiTemplateConfigOrDict] = None,
config: Optional[types.AssessDatasetConfigOrDict] = None,
) -> types.BatchPredictionResourceUsageAssessmentResult:
"""Assess the batch prediction resources required for a given model.

Args:
dataset_name:
Required. The name of the dataset to assess the batch prediction
resources.
model_name:
Required. The name of the model to assess the batch prediction
resources.
template_config:
Optional. The template config used to assemble the dataset
before assessing the batch prediction resources. If not provided,
the template config attached to the dataset will be used. Required
if no template config is attached to the dataset.
template_config:
Optional. The template config used to assemble the dataset
before assessing the batch prediction resources. If not provided, the
template config attached to the dataset will be used. Required
if no template config is attached to the dataset.
config:
Optional. A configuration for assessing the batch prediction
resources. If not provided, the default configuration will be
used.

Returns:
A types.BatchPredictionResourceUsageAssessmentResult object
representing the batch prediction resource usage assessment result.
It contains the following keys:
- token_count: The number of tokens in the dataset.
- audio_token_count: The number of audio tokens in the dataset.

"""
if isinstance(config, dict):
config = types.AssessDatasetConfig(**config)
elif not config:
config = types.AssessDatasetConfig()

operation = self._assess_multimodal_dataset(
name=dataset_name,
batch_prediction_resource_usage_assessment_config=types.BatchPredictionResourceUsageAssessmentConfig(
model_name=model_name,
),
gemini_request_read_config=types.GeminiRequestReadConfig(
template_config=template_config,
),
config=config,
)
response = self._wait_for_operation(
operation=operation,
timeout_seconds=config.timeout,
)
result = response["batchPredictionResourceUsageAssessmentResult"]
return _datasets_utils.create_from_response(
types.BatchPredictionResourceUsageAssessmentResult, result
)
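
For reviewers who want to exercise the new surface outside pytest, here is a minimal sketch. It is not part of this diff: the `vertexai.Client` entry point, the import path for `types`, the project/location values, and the dataset resource name are all assumptions or placeholders, and the `timeout` value is assumed to be in seconds based on the `timeout_seconds=` forwarding above.

import vertexai
from vertexai._genai import types

# Placeholder project, location, and dataset resource name.
client = vertexai.Client(project="my-project", location="us-central1")

result = client.datasets.assess_batch_prediction_resources(
    dataset_name="projects/my-project/locations/us-central1/datasets/123",
    model_name="gemini-2.5-flash-001",
    template_config=types.GeminiTemplateConfig(
        gemini_example=types.GeminiExample(
            contents=[
                {
                    "role": "user",
                    "parts": [{"text": "What is the capital of {name}?"}],
                },
            ],
        ),
    ),
    # AssessDatasetConfig.timeout bounds the operation polling above
    # (assumed to be seconds, per timeout_seconds in _wait_for_operation).
    config=types.AssessDatasetConfig(timeout=600),
)
print(result.token_count, result.audio_token_count)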


class AsyncDatasets(_api_module.BaseModule):

@@ -1875,3 +1939,67 @@ async def assess_tuning_resources(
types.TuningResourceUsageAssessmentResult,
response["tuningResourceUsageAssessmentResult"],
)

async def assess_batch_prediction_resources(
self,
*,
dataset_name: str,
model_name: str,
template_config: Optional[types.GeminiTemplateConfigOrDict] = None,
config: Optional[types.AssessDatasetConfigOrDict] = None,
) -> types.BatchPredictionResourceUsageAssessmentResult:
"""Assess the batch prediction resources required for a given model.

Args:
dataset_name:
Required. The name of the dataset to assess the batch prediction
resources.
model_name:
Required. The name of the model to assess the batch prediction
resources.
template_config:
Optional. The template config used to assemble the dataset
before assessing the batch prediction resources. If not provided,
the template config attached to the dataset will be used. Required
if no template config is attached to the dataset.
template_config:
Optional. The template config used to assemble the dataset
before assessing the batch prediction resources. If not provided, the
template config attached to the dataset will be used. Required
if no template config is attached to the dataset.
config:
Optional. A configuration for assessing the batch prediction
resources. If not provided, the default configuration will be
used.

Returns:
A types.BatchPredictionResourceUsageAssessmentResult object
representing the batch prediction resource usage assessment result.
It contains the following keys:
- token_count: The number of tokens in the dataset.
- audio_token_count: The number of audio tokens in the dataset.

"""
if isinstance(config, dict):
config = types.AssessDatasetConfig(**config)
elif not config:
config = types.AssessDatasetConfig()

operation = await self._assess_multimodal_dataset(
name=dataset_name,
batch_prediction_resource_usage_assessment_config=types.BatchPredictionResourceUsageAssessmentConfig(
model_name=model_name,
),
gemini_request_read_config=types.GeminiRequestReadConfig(
template_config=template_config,
),
config=config,
)
response = await self._wait_for_operation(
operation=operation,
timeout_seconds=config.timeout,
)
result = response["batchPredictionResourceUsageAssessmentResult"]
return _datasets_utils.create_from_response(
types.BatchPredictionResourceUsageAssessmentResult, result
)
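
Because the async variant awaits the long-running operation, it also makes it easy to fan assessments out concurrently, for example to compare token usage across models. A sketch under the same assumptions as the sync example above (`client` as constructed there, `DATASET` the same placeholder resource name the tests use, and a second model name chosen purely for illustration); it also assumes the dataset already has a template config attached, so `template_config` can be omitted.

import asyncio

async def compare_models():
    # Both calls poll long-running operations, so running them
    # concurrently saves wall-clock time. Model names are placeholders.
    return await asyncio.gather(
        client.aio.datasets.assess_batch_prediction_resources(
            dataset_name=DATASET,
            model_name="gemini-2.5-flash-001",
        ),
        client.aio.datasets.assess_batch_prediction_resources(
            dataset_name=DATASET,
            model_name="gemini-2.0-flash-001",
        ),
    )

flash_25_result, flash_20_result = asyncio.run(compare_models())
print(flash_25_result.token_count, flash_20_result.token_count)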