99from tensorlayer .layers .utils import (get_variable_with_initializer )
1010from tensorlayer import logging
1111
12- __all__ = ['Module' , 'SequentialLayer' ]
12+ __all__ = ['Module' , 'SequentialLayer' , 'LayerList' ]
1313
1414_global_layer_name_dict = {}
1515Parameter_ = tf .Variable
@@ -606,19 +606,19 @@ def __init__(self, *args):
606606 def __getitem__ (self , index ):
607607 if isinstance (index , slice ):
608608 return self .__class__ (OrderedDict (list (self ._layers .items ())[index ]))
609- index = self . _valid_index (len (self ), index )
609+ index = _valid_index (len (self ), index )
610610 return list (self ._layers .values ())[index ]
611611
612612 def __setitem__ (self , index , layer ):
613- if self . _valid_module (layer ):
614- index = self . _valid_index (len (self ), index )
613+ if _valid_module (layer ):
614+ index = _valid_index (len (self ), index )
615615 key = list (self ._layers .keys ())[index ]
616616 self ._layers [key ] = layer
617617 self .layer_list = list (self ._layers .values ())
618618
619619 def __delitem__ (self , index ):
620620 if isinstance (index , int ):
621- index = self . _valid_index (len (self ), index )
621+ index = _valid_index (len (self ), index )
622622 key = list (self ._layers .keys ())[index ]
623623 del self ._layers [key ]
624624 elif isinstance (index , slice ):
@@ -633,7 +633,7 @@ def __len__(self):
633633 return len (self ._layers )
634634
635635 def append (self , layer ):
636- if self . _valid_module (layer ):
636+ if _valid_module (layer ):
637637 self ._layers [str (len (self ))] = layer
638638 self .layer_list = list (self ._layers .values ())
639639 return self
@@ -646,16 +646,133 @@ def forward(self, input_data):
646646 input_data = layer (input_data )
647647 return input_data
648648
649- def _valid_index (self , layer_num , index ):
650- if not isinstance (index , int ):
651- raise TypeError ("Index {} is not int type" )
652- if not - layer_num <= index < layer_num :
653- raise IndexError (
654- "Index should be a number in range [{}, {}), but got {}" .format (- layer_num , layer_num , index )
655- )
656- return index % layer_num
657-
658- def _valid_module (self , layer ):
659- if issubclass (layer .__class__ , Module ):
660- return True
661- raise TypeError ('Module {} is not subclass of Module' .format (layer ))
649+
650+ class LayerList (Module ):
651+ """
652+ Holds Modules in a list.
653+
654+ LayerList can be used like a regular Python list, support
655+ '__getitem__', '__setitem__', '__delitem__', '__len__', '__iter__' and '__iadd__',
656+ but module it contains are properly registered, and will be visible by all Modules methods.
657+
658+ Parameters
659+ ----------
660+ args : list
661+ List of subclass of Module.
662+ Methods
663+ ---------
664+ __init__()
665+ Initializing the Layer.
666+ insert()
667+ Inserts a given layer before a given index in the list.
668+ extend()
669+ Appends layers from a Python iterable to the end of the list.
670+ append()
671+ Appends a given layer to the end of the list.
672+
673+ Examples
674+ ---------
675+ Args:
676+ args (list, optional): List of subclass of Module.
677+
678+ Examples:
679+
680+ """
681+ def __init__ (self , * args , ** kwargs ):
682+ super (LayerList , self ).__init__ ()
683+ if len (args ) == 1 :
684+ self .extend (args [0 ])
685+
686+ def __getitem__ (self , index ):
687+ if isinstance (index , slice ):
688+ return self .__class__ (list (self ._layers .values ())[index ])
689+ if isinstance (index , int ):
690+ index = _valid_index (len (self ), index )
691+ return self ._layers [str (index )]
692+ raise TypeError ('Index {} is not int type or slice type' .format (index ))
693+
694+ def __setitem__ (self , index , layer ):
695+ if not isinstance (index , int ) and _valid_module (layer ):
696+ raise TypeError ('Index {} is not int type' .format (index ))
697+ index = _valid_index (len (self ), index )
698+ self ._layers [str (index )] = layer
699+
700+ def __delitem__ (self , index ):
701+ if isinstance (index , int ):
702+ index = _valid_index (len (self ), index )
703+ del self ._layers [str (index )]
704+ elif isinstance (index , slice ):
705+ keys = list (self ._layers .keys ())[index ]
706+ for key in keys :
707+ del self ._layers [key ]
708+ else :
709+ raise TypeError ('Index {} is not int type or slice type' .format (index ))
710+ temp_dict = OrderedDict ()
711+ for idx , layer in enumerate (self ._layers .values ()):
712+ temp_dict [str (idx )] = layer
713+ self ._layers = temp_dict
714+
715+ def __len__ (self ):
716+ return len (self ._layers )
717+
718+ def __iter__ (self ):
719+ return iter (self ._layers .values ())
720+
721+ def __iadd__ (self , layers ):
722+ self .extend (layers )
723+ return self
724+
725+ def insert (self , index , layer ):
726+ """
727+ Inserts a given layer before a given index in the list.
728+
729+ """
730+
731+ idx = _valid_index (len (self ), index )
732+ _valid_module (layer )
733+ length = len (self )
734+ while length > idx :
735+ self ._layers [str (length )] = self ._layers [str (length - 1 )]
736+ length -= 1
737+ self ._layers [str (idx )] = layer
738+
739+ def extend (self , layers ):
740+ """
741+ Appends layers from a Python iterable to the end of the list.
742+
743+ """
744+
745+ if not isinstance (layers , list ):
746+ raise TypeError ('Modules {} should be list of sublayers' .format (layers ))
747+ for layer in layers :
748+ if _valid_module (layer ):
749+ self ._layers [str (len (self ))] = layer
750+ return self
751+
752+ def append (self , layer ):
753+ """
754+ Appends a given layer to the end of the list.
755+
756+ """
757+
758+ if _valid_module (layer ):
759+ self ._layers [str (len (self ))] = layer
760+
761+ def forward (self , * inputs ):
762+ raise NotImplementedError
763+
764+
765+ def _valid_index (layer_num , index ):
766+ if not isinstance (index , int ):
767+ raise TypeError ("Index {} is not int type" )
768+ if not - layer_num <= index < layer_num :
769+ raise IndexError (
770+ "Index should be a number in range [{}, {}), but got {}" .format (- layer_num , layer_num , index )
771+ )
772+ return index % layer_num
773+
774+
775+ def _valid_module (layer ):
776+ if issubclass (layer .__class__ , Module ):
777+ return True
778+ raise TypeError ('Module {} is not subclass of Module' .format (layer ))
0 commit comments