@@ -63,6 +63,35 @@ def sklearn_model(train_data):
6363 return model
6464
6565
66+ @pytest .fixture
67+ def sklearn_pipeline (train_data ):
68+ from sklearn .pipeline import Pipeline
69+ from sklearn .ensemble import GradientBoostingClassifier
70+ from sklearn .preprocessing import StandardScaler
71+ from sklearn .impute import SimpleImputer
72+ from sklearn .compose import ColumnTransformer
73+
74+ X , y = train_data
75+
76+ numeric_transformer = Pipeline ([
77+ ('imputer' , SimpleImputer (strategy = 'median' )),
78+ ('scaler' , StandardScaler ())
79+ ])
80+
81+ preprocessor = ColumnTransformer ([
82+ ('num' , numeric_transformer , X .columns )
83+ ])
84+
85+ pipe = Pipeline ([
86+ ('preprocess' , preprocessor ),
87+ ('classifier' , GradientBoostingClassifier ())
88+ ])
89+
90+ pipe .fit (X , y )
91+
92+ return pipe
93+
94+
6695@pytest .fixture
6796def pickle_file (tmpdir_factory , sklearn_model ):
6897 """Returns the path to a file containing a pickled Scikit-Learn model """
@@ -215,6 +244,17 @@ def test_from_python_file(python_file):
215244 assert isinstance (p , PyMAS )
216245
217246
247+ def test_with_sklearn_pipeline (train_data , sklearn_pipeline ):
248+ from sasctl .utils .pymas import PyMAS , from_pickle
249+
250+ X , y = train_data
251+ p = from_pickle (pickle .dumps (sklearn_pipeline ),
252+ func_name = 'predict' ,
253+ input_types = X )
254+
255+ assert isinstance (p , PyMAS )
256+ assert len (p .variables ) > 4 # 4 input features in Iris data set
257+
218258@pytest .mark .usefixtures ('session' )
219259def test_publish_and_execute (tmpdir ):
220260 import pickle
0 commit comments