4646from  ...utils .compat  import  (a2u , a2b , int_types , int32_types , int64_types , dict_types ,
4747                             float64_types , items_types , int32 , int64 , float64 )
4848from  ...utils .authinfo  import  query_authinfo 
49+ import  warnings 
4950
5051# pylint: disable=C0330 
5152
@@ -66,6 +67,7 @@ def _print_request(rtype, url, headers, data=None):
6667
6768def  _print_response (text ):
6869    ''' Print the response for debugging ''' 
70+     sys .stderr .write ("RESPONSE text: \n " )
6971    sys .stderr .write (a2u (text , 'utf-8' ))
7072    sys .stderr .write ('\n ' )
7173
@@ -206,7 +208,8 @@ class REST_CASConnection(object):
206208    username : string 
207209        The CAS username 
208210    password : string 
209-         The CAS password 
211+         The CAS password or an OAuth token 
212+         If an OAuth token is specified, do not specify username 
210213    soptions : string 
211214        The string containing connection options 
212215    error : REST_CASError 
@@ -291,13 +294,31 @@ def __init__(self, hostname, port, username, password, soptions, error):
291294        self ._error  =  error 
292295        self ._results  =  None 
293296
297+         allow_basic_auth  =  get_option ('cas.allow_basic_auth' )
298+         # add in the following when allow_basic_auth default is changed to False 
299+         # if allow_basic_auth: 
300+         #     logger.warning("allow_basic_auth has been deprecated and will be removed " 
301+         #                 "in a future release") 
302+         #     warnings.warn("allow_basic_auth has been deprecated and will be removed " 
303+         #                 "in a future release", category=FutureWarning) 
304+ 
294305        if  username  and  password :
295306            logger .debug ('Using Basic authentication' )
307+ #           if allow_basic_auth: 
308+ #               logger.warning("Basic authentication has been deprecated in servers that " 
309+ #                              "support OAuth tokens. It will be removed " 
310+ #                              "in a future release") 
311+ #               warnings.warn("Basic authentication has been deprecated in servers that " 
312+ #                             "support OAuth tokens. It will be removed " 
313+ #                             "in a future release.", category=FutureWarning) 
314+ 
296315            self ._auth  =  b'Basic '  +  base64 .b64encode (
297316                ('%s:%s'  %  (username , password )).encode ('utf-8' )).strip ()
317+             self ._isBearer  =  False 
298318        elif  password :
299319            logger .debug ('Using Bearer token authentication' )
300320            self ._auth  =  b'Bearer '  +  a2b (password ).strip ()
321+             self ._isBearer  =  True 
301322        else :
302323            raise  SWATError ('Either username and password, or OAuth token in the ' 
303324                            'password parameter must be specified.' )
@@ -310,9 +331,15 @@ def __init__(self, hostname, port, username, password, soptions, error):
310331            'Accept' : 'application/json' ,
311332            'Content-Type' : 'application/json' ,
312333            'Content-Length' : '0' ,
313-             'Authorization' : self ._auth ,
314334        })
315335
336+         if  allow_basic_auth  or  self ._isBearer :
337+             # user is using an OAuth token, or the allow_basic_auth option is set to 
338+             # allow Basic Auth 
339+             self ._req_sess .headers .update ({
340+                 'Authorization' : self ._auth ,
341+             })
342+ 
316343        self ._connect (session = session , locale = locale , wait_until_idle = False )
317344
318345    def  _connect (self , session = None , locale = None , wait_until_idle = True ):
@@ -358,6 +385,11 @@ def _connect(self, session=None, locale=None, wait_until_idle=True):
358385                            time .sleep (connection_retry_interval )
359386                            get_retries  +=  1 
360387                            continue 
388+ 
389+                         if  res .status_code  ==  401 :
390+                             if  self ._check_authorization_method (res ):
391+                                 continue 
392+ 
361393                        break 
362394
363395                else :
@@ -376,6 +408,11 @@ def _connect(self, session=None, locale=None, wait_until_idle=True):
376408                            time .sleep (connection_retry_interval )
377409                            put_retries  +=  1 
378410                            continue 
411+ 
412+                         if  res .status_code  ==  401 :
413+                             if  self ._check_authorization_method (res ):
414+                                 continue 
415+ 
379416                        break 
380417
381418                if  'tkhttp-id'  in  res .cookies :
@@ -420,15 +457,69 @@ def _connect(self, session=None, locale=None, wait_until_idle=True):
420457            except  KeyError :
421458                raise  SWATError (str (out ))
422459
423-             except  Exception  as  exc :
424-                 raise  SWATError (str (exc ))
425- 
426460            except  SWATError :
427461                raise 
428462
463+             except  Exception  as  exc :
464+                 raise  SWATError (str (exc ))
465+ 
429466        if  wait_until_idle :
430467            self ._wait_until_idle ()
431468
469+     def  _check_authorization_method (self , res ):
470+         ''' 
471+         Check whether Bearer auth is supported 
472+ 
473+         Notes 
474+         ----- 
475+         This method may modify the request session headers. 
476+ 
477+         Parameters 
478+         ---------- 
479+         res : requests.models.Response 
480+             The http response from the /cas/sessions request 
481+ 
482+         Returns 
483+         ------- 
484+         boolean 
485+             Whether or not the /cas/sessions should be 
486+             retried with the new Authorization header 
487+         ''' 
488+         logger .debug ('HTTP 401 error : checking Authorization header' )
489+         retry_auth  =  False 
490+ 
491+         if  'Authorization'  in  self ._req_sess .headers :
492+             # we've already sent an Authorization header and 
493+             # the server didn't like it, bail. 
494+             logger .debug ("Authorization header previously sent" )
495+         else :
496+             # < Viya 4 does not return correct info in www-authenticate 
497+             # Hack check : Is this a viya 4 server: 
498+             if  'tkhttp-id'  in  res .cookies :
499+                 # Does the server support Bearer auth ? 
500+                 wwwauth  =  res .headers .get ('WWW-Authenticate' )
501+                 if  wwwauth :
502+                     supported_auths  =  wwwauth .split (', ' )
503+                     for  supported_auth  in  supported_auths :
504+                         if  supported_auth .lower ().startswith ('bearer' ):
505+                             # OAuth is supported and the user 
506+                             # is using userid/password. 
507+                             raise  SWATError ('You must use an OAuth token ' 
508+                                             'to connect to this server.' )
509+             else :
510+                 logger .debug ("Server does not appear to be a viya4 server, " 
511+                              "allowing basic auth" )
512+ 
513+             # There was no www-authenticate, or OAuth is not 
514+             # supported on this server, or server is < viya4. 
515+             # Add the basic auth header and retry with basic auth 
516+             self ._req_sess .headers .update ({
517+                 'Authorization' : self ._auth ,
518+             })
519+             retry_auth  =  True 
520+ 
521+         return  retry_auth 
522+ 
432523    def  _set_next_connection (self ):
433524        ''' Iterate to the next available controller ''' 
434525        self ._host_index  +=  1 
0 commit comments