66import  neural_tangents 
77from  neural_tangents  import  stax 
88
9- from  pkg_resources  import  parse_version 
10- if  parse_version (neural_tangents .__version__ ) >=  parse_version ('0.5.0' ):
11-   from  neural_tangents ._src .utils  import  utils , dataclasses 
12-   from  neural_tangents ._src .stax .linear  import  _pool_kernel , Padding 
13-   from  neural_tangents ._src .stax .linear  import  _Pooling  as  Pooling 
14- else :
15-   from  neural_tangents .utils  import  utils , dataclasses 
16-   from  neural_tangents .stax  import  _pool_kernel , Padding , Pooling 
17- 
18- from  sketching  import  TensorSRHT2 , PolyTensorSRHT 
9+ # from pkg_resources import parse_version 
10+ # if parse_version(neural_tangents.__version__) >= parse_version('0.5.0'): 
11+ #   from neural_tangents._src.utils import utils, dataclasses 
12+ #   from neural_tangents._src.stax.linear import _pool_kernel, Padding 
13+ #   from neural_tangents._src.stax.linear import _Pooling as Pooling 
14+ # else: 
15+ #   from neural_tangents.utils import utils, dataclasses 
16+ #   from neural_tangents.stax import _pool_kernel, Padding, Pooling 
17+ from  neural_tangents ._src .utils  import  dataclasses 
18+ # from neural_tangents._src.utils.typing import Optional 
19+ from  typing  import  Optional 
20+ from  neural_tangents ._src .stax .linear  import  _pool_kernel , Padding 
21+ from  neural_tangents ._src .stax .linear  import  _Pooling  as  Pooling 
22+ 
23+ from  experimental .sketching  import  TensorSRHT2 , PolyTensorSRHT 
1924""" Implementation for NTK Sketching and Random Features """ 
2025
2126
@@ -50,13 +55,13 @@ def kappa1(x):
5055
5156@dataclasses .dataclass  
5257class  Features :
53-   nngp_feat : np .ndarray 
54-   ntk_feat : np .ndarray 
58+   nngp_feat : Optional [ np .ndarray ]  =   None 
59+   ntk_feat : Optional [ np .ndarray ]  =   None 
5560
5661  batch_axis : int  =  dataclasses .field (pytree_node = False )
5762  channel_axis : int  =  dataclasses .field (pytree_node = False )
5863
59-   replace  =  ...   # type: Callable[..., 'Features'] 
64+   replace  =  ... 
6065
6166
6267def  _inputs_to_features (x : np .ndarray ,
@@ -69,7 +74,7 @@ def _inputs_to_features(x: np.ndarray,
6974  nngp_feat  =  x  /  x .shape [channel_axis ]** 0.5 
7075  ntk_feat  =  np .empty ((), dtype = nngp_feat .dtype )
7176
72-   return  Features (nngp_feat = nngp_feat ,
77+   return  Features . replace (nngp_feat = nngp_feat ,
7378                  ntk_feat = ntk_feat ,
7479                  batch_axis = batch_axis ,
7580                  channel_axis = channel_axis )
0 commit comments