55# SPDX-License-Identifier: Apache-2.0
66
77import copy
8- from unittest import mock
98
109import pytest
10+ from six .moves import mock
1111from sasctl import current_session
1212from sasctl .services import model_repository as mr
1313
@@ -22,56 +22,67 @@ def test_create_model():
2222 with mock .patch ('sasctl.core.requests.Session.request' ):
2323 current_session ('example.com' , USER , 'password' )
2424
25- TARGET = {'name' : MODEL_NAME ,
26- 'projectId' : PROJECT_ID ,
27- 'modeler' : USER ,
28- 'description' : 'model description' ,
29- 'function' : 'Classification' ,
30- 'algorithm' : 'Dummy Algorithm' ,
31- 'tool' : 'pytest' ,
32- 'champion' : True ,
33- 'role' : 'champion' ,
34- 'immutable' : True ,
35- 'retrainable' : True ,
36- 'scoreCodeType' : None ,
37- 'targetVariable' : None ,
38- 'trainTable' : None ,
39- 'classificationEventProbabilityVariableName' : None ,
40- 'classificationTargetEventValue' : None ,
41- 'location' : None ,
42- 'properties' : [{'name' : 'custom1' , 'value' : 123 },
43- {'name' : 'custom2' , 'value' : 'somevalue' }],
44- 'inputVariables' : [],
45- 'outputVariables' : [],
46- 'version' : '2' }
25+ TARGET = {
26+ 'name' : MODEL_NAME ,
27+ 'projectId' : PROJECT_ID ,
28+ 'modeler' : USER ,
29+ 'description' : 'model description' ,
30+ 'function' : 'Classification' ,
31+ 'algorithm' : 'Dummy Algorithm' ,
32+ 'tool' : 'pytest' ,
33+ 'champion' : True ,
34+ 'role' : 'champion' ,
35+ 'immutable' : True ,
36+ 'retrainable' : True ,
37+ 'scoreCodeType' : None ,
38+ 'targetVariable' : None ,
39+ 'trainTable' : None ,
40+ 'classificationEventProbabilityVariableName' : None ,
41+ 'classificationTargetEventValue' : None ,
42+ 'location' : None ,
43+ 'properties' : [
44+ {'name' : 'custom1' , 'value' : 123 },
45+ {'name' : 'custom2' , 'value' : 'somevalue' },
46+ ],
47+ 'inputVariables' : [],
48+ 'outputVariables' : [],
49+ 'version' : '2' ,
50+ }
4751
4852 # Passed params should be set correctly
4953 target = copy .deepcopy (TARGET )
50- with mock .patch ('sasctl._services.model_repository.ModelRepository.get_project' ) as get_project :
51- with mock .patch ('sasctl._services.model_repository.ModelRepository' '.get_model' ) as get_model :
52- with mock .patch ('sasctl._services.model_repository.ModelRepository.post' ) as post :
54+ with mock .patch (
55+ 'sasctl._services.model_repository.ModelRepository.get_project'
56+ ) as get_project :
57+ with mock .patch (
58+ 'sasctl._services.model_repository.ModelRepository' '.get_model'
59+ ) as get_model :
60+ with mock .patch (
61+ 'sasctl._services.model_repository.ModelRepository.post'
62+ ) as post :
5363 get_project .return_value = {'id' : PROJECT_ID }
5464 get_model .return_value = None
55- _ = mr .create_model (MODEL_NAME ,
56- PROJECT_NAME ,
57- description = target ['description' ],
58- function = target ['function' ],
59- algorithm = target ['algorithm' ],
60- tool = target ['tool' ],
61- is_champion = True ,
62- is_immutable = True ,
63- is_retrainable = True ,
64- properties = dict (custom1 = 123 , custom2 = 'somevalue' ))
65+ _ = mr .create_model (
66+ MODEL_NAME ,
67+ PROJECT_NAME ,
68+ description = target ['description' ],
69+ function = target ['function' ],
70+ algorithm = target ['algorithm' ],
71+ tool = target ['tool' ],
72+ is_champion = True ,
73+ is_immutable = True ,
74+ is_retrainable = True ,
75+ properties = dict (custom1 = 123 , custom2 = 'somevalue' ),
76+ )
6577 assert post .call_count == 1
6678 url , data = post .call_args
6779
6880 # dict isn't guaranteed to preserve order
6981 # so k/v pairs of properties=dict() may be
7082 # returned in a different order
71- assert sorted (target ['properties' ],
72- key = lambda d : d ['name' ]) \
73- == sorted (data ['json' ]['properties' ],
74- key = lambda d : d ['name' ])
83+ assert sorted (target ['properties' ], key = lambda d : d ['name' ]) == sorted (
84+ data ['json' ]['properties' ], key = lambda d : d ['name' ]
85+ )
7586
7687 target .pop ('properties' )
7788 data ['json' ].pop ('properties' )
@@ -80,12 +91,20 @@ def test_create_model():
8091 # Model dict w/ parameters already specified should be allowed
8192 # Explicit overrides should be respected.
8293 target = copy .deepcopy (TARGET )
83- with mock .patch ('sasctl._services.model_repository.ModelRepository.get_project' ) as get_project :
84- with mock .patch ('sasctl._services.model_repository.ModelRepository' '.get_model' ) as get_model :
85- with mock .patch ('sasctl._services.model_repository.ModelRepository.post' ) as post :
94+ with mock .patch (
95+ 'sasctl._services.model_repository.ModelRepository.get_project'
96+ ) as get_project :
97+ with mock .patch (
98+ 'sasctl._services.model_repository.ModelRepository' '.get_model'
99+ ) as get_model :
100+ with mock .patch (
101+ 'sasctl._services.model_repository.ModelRepository.post'
102+ ) as post :
86103 get_project .return_value = {'id' : PROJECT_ID }
87104 get_model .return_value = None
88- _ = mr .create_model (copy .deepcopy (target ), PROJECT_NAME , description = 'Updated Model' )
105+ _ = mr .create_model (
106+ copy .deepcopy (target ), PROJECT_NAME , description = 'Updated Model'
107+ )
89108 target ['description' ] = 'Updated Model'
90109 assert post .call_count == 1
91110 url , data = post .call_args
@@ -104,10 +123,12 @@ def test_copy_analytic_store():
104123
105124 MODEL_ID = 12345
106125 # Intercept calls to lookup the model & call the "copyAnalyticStore" link
107- with mock .patch ('sasctl._services.model_repository.ModelRepository'
108- '.get_model' ) as get_model :
109- with mock .patch ('sasctl._services.model_repository.ModelRepository'
110- '.request_link' ) as request_link :
126+ with mock .patch (
127+ 'sasctl._services.model_repository.ModelRepository' '.get_model'
128+ ) as get_model :
129+ with mock .patch (
130+ 'sasctl._services.model_repository.ModelRepository' '.request_link'
131+ ) as request_link :
111132
112133 # Return a dummy Model with a static id
113134 get_model .return_value = {'id' : MODEL_ID }
@@ -138,19 +159,17 @@ def test_get_model_by_name():
138159
139160 mock_responses = [
140161 # First response is for list_items/list_models
141- [
142- {'id' : 12345 , 'name' : MODEL_NAME },
143- {'id' : 67890 , 'name' : MODEL_NAME }
144- ],
145-
162+ [{'id' : 12345 , 'name' : MODEL_NAME }, {'id' : 67890 , 'name' : MODEL_NAME }],
146163 # Second response is mock GET for model details
147- {'id' : 12345 , 'name' : MODEL_NAME }
164+ {'id' : 12345 , 'name' : MODEL_NAME },
148165 ]
149166
150- with mock .patch ('sasctl._services.model_repository.ModelRepository.request' ) as request :
167+ with mock .patch (
168+ 'sasctl._services.model_repository.ModelRepository.request'
169+ ) as request :
151170 request .side_effect = mock_responses
152171
153172 with pytest .warns (Warning ):
154173 result = mr .get_model (MODEL_NAME )
155- assert result ['id' ]== 12345
156- assert result ['name' ] == MODEL_NAME
174+ assert result ['id' ] == 12345
175+ assert result ['name' ] == MODEL_NAME
0 commit comments