Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions bigframes/ml/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from __future__ import annotations

from typing import Dict, List, Literal, Optional
from typing import Dict, List, Literal, Optional, Union

import bigframes_vendored.sklearn.ensemble._forest
import bigframes_vendored.xgboost.sklearn
Expand Down Expand Up @@ -78,6 +78,7 @@ def __init__(
tol: float = 0.01,
enable_global_explain: bool = False,
xgboost_version: Literal["0.9", "1.1"] = "0.9",
**kwargs: Union[str, str | int | bool | float | List[str]],
):
self.n_estimators = n_estimators
self.booster = booster
Expand All @@ -99,6 +100,7 @@ def __init__(
self.xgboost_version = xgboost_version
self._bqml_model: Optional[core.BqmlModel] = None
self._bqml_model_factory = globals.bqml_model_factory()
self._extra_bqml_options = kwargs

@classmethod
def _from_bq(
Expand All @@ -117,7 +119,7 @@ def _from_bq(
@property
def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
"""The model options as they will be set for BQML"""
return {
options = {
"model_type": "BOOSTED_TREE_REGRESSOR",
"data_split_method": "NO_SPLIT",
"early_stop": True,
Expand All @@ -139,6 +141,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
"enable_global_explain": self.enable_global_explain,
"xgboost_version": self.xgboost_version,
}
options.update(self._extra_bqml_options)
Copy link
Contributor

@GarrettWu GarrettWu Sep 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It may override existing options if the user passes in both.

We use sklearn terms instead of BQML terms for existing parameters, e.g. learning_rate instead of LEARN_RATE. Users may get confused looking at two sets of options at the same time.

It becomes hard for us to both stick with a sklearn-like experience and support BQML's other offerings.

Or maybe we can ask them to use only one set of parameters, and raise an error if there are conflicts?

return options # type: ignore

def _fit(
self,
Expand Down Expand Up @@ -237,6 +241,7 @@ def __init__(
tol: float = 0.01,
enable_global_explain: bool = False,
xgboost_version: Literal["0.9", "1.1"] = "0.9",
**kwargs: Union[str, str | int | bool | float | List[str]],
):
self.n_estimators = n_estimators
self.booster = booster
Expand All @@ -258,6 +263,7 @@ def __init__(
self.xgboost_version = xgboost_version
self._bqml_model: Optional[core.BqmlModel] = None
self._bqml_model_factory = globals.bqml_model_factory()
self._extra_bqml_options = kwargs

@classmethod
def _from_bq(
Expand All @@ -276,7 +282,7 @@ def _from_bq(
@property
def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
"""The model options as they will be set for BQML"""
return {
options = {
"model_type": "BOOSTED_TREE_CLASSIFIER",
"data_split_method": "NO_SPLIT",
"early_stop": True,
Expand All @@ -298,6 +304,8 @@ def _bqml_options(self) -> Dict[str, str | int | bool | float | List[str]]:
"enable_global_explain": self.enable_global_explain,
"xgboost_version": self.xgboost_version,
}
options.update(self._extra_bqml_options)
return options # type: ignore

def _fit(
self,
Expand Down
83 changes: 82 additions & 1 deletion tests/unit/ml/test_golden_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap
from unittest import mock

from google.cloud import bigquery
import pandas as pd
import pytest

import bigframes
from bigframes.ml import core, decomposition, linear_model
from bigframes.ml import core, decomposition, ensemble, linear_model
import bigframes.ml.core
import bigframes.pandas as bpd

Expand Down Expand Up @@ -286,3 +287,83 @@ def test_decomposition_mf_score_with_x(mock_session, bqml_model, mock_X):
"SELECT * FROM ML.EVALUATE(MODEL `model_project`.`model_dataset`.`model_id`,\n (input_X_sql_property))",
allow_large_results=True,
)


def test_xgb_classifier_kwargs_params_fit(
bqml_model_factory, mock_session, mock_X, mock_y
):
# Golden-SQL check: a keyword argument given to XGBClassifier
# (category_encoding_method) is expected to appear in the OPTIONS list of
# the CREATE MODEL statement issued by fit().
model = ensemble.XGBClassifier(category_encoding_method="LABEL_ENCODING")
# Swap in the fixture factory before fitting.
model._bqml_model_factory = bqml_model_factory
model.fit(mock_X, mock_y)

# Exactly one DDL call is expected; in the expected text the extra option
# follows xgboost_version and precedes INPUT_LABEL_COLS.
mock_session._start_query_ml_ddl.assert_called_once_with(
textwrap.dedent(
"""
CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`
OPTIONS(
model_type='BOOSTED_TREE_CLASSIFIER',
data_split_method='NO_SPLIT',
early_stop=True,
num_parallel_tree=1,
booster_type='gbtree',
tree_method='auto',
min_tree_child_weight=1,
colsample_bytree=1.0,
colsample_bylevel=1.0,
colsample_bynode=1.0,
min_split_loss=0.0,
max_tree_depth=6,
subsample=1.0,
l1_reg=0.0,
l2_reg=1.0,
learn_rate=0.3,
max_iterations=20,
min_rel_progress=0.01,
enable_global_explain=False,
xgboost_version='0.9',
category_encoding_method='LABEL_ENCODING',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do they want to use label encoding on features? This is actually wrong. https://stackoverflow.com/a/34346937

I talked with jiashangliu@, but BQML won't change the default.

INPUT_LABEL_COLS=['input_column_label'])
AS input_X_y_no_index_sql
"""
).strip()
)


def test_xgb_regressor_kwargs_params_fit(
bqml_model_factory, mock_session, mock_X, mock_y
):
# Golden-SQL check: a keyword argument given to XGBRegressor
# (category_encoding_method) is expected to appear in the OPTIONS list of
# the CREATE MODEL statement issued by fit().
model = ensemble.XGBRegressor(category_encoding_method="LABEL_ENCODING")
# Swap in the fixture factory before fitting.
model._bqml_model_factory = bqml_model_factory
model.fit(mock_X, mock_y)

# Exactly one DDL call is expected; in the expected text the extra option
# follows xgboost_version and precedes INPUT_LABEL_COLS.
mock_session._start_query_ml_ddl.assert_called_once_with(
textwrap.dedent(
"""
CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`
OPTIONS(
model_type='BOOSTED_TREE_REGRESSOR',
data_split_method='NO_SPLIT',
early_stop=True,
num_parallel_tree=1,
booster_type='gbtree',
tree_method='auto',
min_tree_child_weight=1,
colsample_bytree=1.0,
colsample_bylevel=1.0,
colsample_bynode=1.0,
min_split_loss=0.0,
max_tree_depth=6,
subsample=1.0,
l1_reg=0.0,
l2_reg=1.0,
learn_rate=0.3,
max_iterations=20,
min_rel_progress=0.01,
enable_global_explain=False,
xgboost_version='0.9',
category_encoding_method='LABEL_ENCODING',
INPUT_LABEL_COLS=['input_column_label'])
AS input_X_y_no_index_sql
"""
).strip()
)
19 changes: 18 additions & 1 deletion third_party/bigframes_vendored/xgboost/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,18 @@ class XGBRegressor(XGBModel, XGBRegressorBase):
tol (Optional[float]):
Minimum relative loss improvement necessary to continue training. Default to 0.01.
enable_global_explain (Optional[bool]):
Whether to compute global explanations using explainable AI to evaluate global feature importance to the model. Default to False.
Whether to compute global explanations using explainable AI to
evaluate global feature importance to the model. Default to False.
xgboost_version (Optional[str]):
Specifies the Xgboost version for model training. Default to "0.9". Possible values: "0.9", "1.1".
kwargs (dict):
Keyword arguments for the ``model_option_list`` of the boosted tree
BQML model. See
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-boosted-tree

For example, to set ``CATEGORY_ENCODING_METHOD`` to
``LABEL_ENCODING``, pass in the keyword argument
``category_encoding_method='LABEL_ENCODING'``.
"""


Expand Down Expand Up @@ -148,4 +157,12 @@ class XGBClassifier(XGBModel, XGBClassifierMixIn, XGBClassifierBase):
Whether to compute global explanations using explainable AI to evaluate global feature importance to the model. Default to False.
xgboost_version (Optional[str]):
Specifies the Xgboost version for model training. Default to "0.9". Possible values: "0.9", "1.1".
kwargs (dict):
Keyword arguments for the ``model_option_list`` of the boosted tree
BQML model. See
https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-boosted-tree

For example, to set ``CATEGORY_ENCODING_METHOD`` to
``LABEL_ENCODING``, pass in the keyword argument
``category_encoding_method='LABEL_ENCODING'``.
"""