diff --git a/ivy/stateful/module.py b/ivy/stateful/module.py index a2fd8fe02cc1f..59b6d69e40117 100644 --- a/ivy/stateful/module.py +++ b/ivy/stateful/module.py @@ -22,6 +22,7 @@ class Module(abc.ABC): """ Module is a base class for deriving trainable modules. """ + def __init__( self, device=None, @@ -138,21 +139,21 @@ def new_fn(*a, with_grads=None, **kw): def _top_v_fn(self, depth=None, flatten_key_chains=False): """ Help in visualising the top view of a nested network upto - a certain depth - + a certain depth (Need Discussion, I don't think `v` stands for visualization. Any clues?) + Parameters ---------- depth depth upto which we want to visualise flatten_key_chains - If True, returns a flattened view of the structure. Default is - False - + If set True, will return a flat (depth-1) container, + with all nested key-chains flattened. Default is False. + Returns ------- ret - - + + """ if ivy.exists(self.top_v): if ivy.exists(depth): @@ -168,16 +169,16 @@ def _top_v_fn(self, depth=None, flatten_key_chains=False): def _top_mod_fn(self, depth=None): """ Find the top module at specific depth. - + Parameters ---------- depth The number of modules we want to trace back. - + Returns ------- ret - The module we want to track down. Return self if no top + The module we want to track down. Return current layer if no top module exists. """ if ivy.exists(self.top_mod): @@ -191,12 +192,12 @@ def track_submod_rets(self): """ Tracks the returns of the submodules if track_submod_returns argument is set to True during call (Need discussion) - + Returns ------- ret - Bool flag indicating whether the current module gets - tracked in the computation graph. + True if the current module gets tracked in the computation + graph. 
""" if not ivy.exists(self.top_mod): return False @@ -218,7 +219,7 @@ def check_submod_rets(self): """ Compares the submodule returns with the expected submodule returns passed during call (Need discussion) - + Returns ------- ret @@ -260,7 +261,7 @@ def track_submod_call_order(self): def mod_depth(self): """ Return the depth of the current module. - + Returns ------- ret @@ -279,7 +280,7 @@ def mod_depth(self): def mod_height(self): """ Return the height of the current module. - + Returns ------- ret @@ -290,7 +291,7 @@ def mod_height(self): def _find_variables(self, obj=None): """ Find all interval varibles in obj. Return empty Container if obj is None. - + Parameters ---------- obj @@ -336,15 +337,15 @@ def _find_variables(self, obj=None): @staticmethod def _extract_v(v, keychain_mappings: dict, orig_key_chain): """ - - + + Parameters ---------- v keychain_mappings orig_key_chain - - + + Returns ------- ret_cont @@ -363,14 +364,14 @@ def _extract_v(v, keychain_mappings: dict, orig_key_chain): def _wrap_call_methods(self, keychain_mappings, key="", obj=None): """ (TODO) - + Parameters ---------- keychain_mappings key obj - - + + Returns ------- None @@ -406,14 +407,14 @@ def _wrap_call_methods(self, keychain_mappings, key="", obj=None): def _remove_duplicate_variables(vs, created): """ Remove duplicate variables in `vs` referring to `created`. - + Parameters ---------- vs The container that needs to be pruned. created The container as the duplication reference. - + Returns ------- vs @@ -458,7 +459,7 @@ def _create_variables(self, device, dtype): ---------- device The device string, specifying the device on which to create the variables. - + Returns ------- ret @@ -469,7 +470,7 @@ def _create_variables(self, device, dtype): def _build(self, *args, **kwargs) -> bool: """ Build the internal layers and variables for this module. Overridable. 
- + Returns ------- ret @@ -486,7 +487,7 @@ def _forward(self, *args, **kwargs): """ Forward pass of the layer, called after handling the optional input variables. - + Raises ------ NotImplementedError @@ -497,7 +498,7 @@ def _forward_with_tracking(self, *args, **kwargs): """ Forward pass while optionally tracking submodule returns and call order. - + Returns ------- ret @@ -518,7 +519,7 @@ def _call(self, *args, v=None, with_grads=None, **kwargs): """ The forward pass of the layer, treating layer instance as callable function. - + Parameters ---------- v @@ -526,7 +527,7 @@ def _call(self, *args, v=None, with_grads=None, **kwargs): after the forward finished. with_grads Whether to forward with gradients. - + Returns ------- ret @@ -565,18 +566,23 @@ def _call(self, *args, v=None, with_grads=None, **kwargs): def sub_mods(self, show_v=True, depth=None, flatten_key_chains=False): """ - - + Return a container composed of all submodules. + Parameters ---------- show_v + If set True, will return values of all submodule variables. + Default is True. depth + How many layers we step in before beginning enumerating submodules. + None for current layer. Default is None. flatten_key_chains - - + If set True, will return a flat (depth-1) container, + with all nested key-chains flattened. Default is False. Returns ------- - + ret + A container composed of all submodules. """ if self._sub_mods: if ivy.exists(depth): @@ -604,13 +610,15 @@ def sub_mods(self, show_v=True, depth=None, flatten_key_chains=False): def show_v_in_top_v(self, depth=None): """ - - + Show sub containers from the perspective of value of top layer. + Will give prompt if either of `v` and `top_v` is not initialized. + Parameters ---------- depth - - + The number of modules we want to step in. None for the value of + current module. Default is None. 
+ Returns ------- None @@ -626,16 +634,22 @@ def show_v_in_top_v(self, depth=None): def v_with_top_v_key_chains(self, depth=None, flatten_key_chains=False): """ - - + Show current layer from the perspective of value of top layer. + Will give prompt if either of `v` and `top_v` is not initialized. + Parameters ---------- depth + The number of modules we want to step in. None for the value of + current module. Default is None. flatten_key_chains - + If set True, will return a flat (depth-1) container, + with all nested key-chains flattened. Default is False. + Returns ------- ret + """ if ivy.exists(self.top_v) and ivy.exists(self.v): kc = self.top_v(depth).find_sub_container(self.v) @@ -655,16 +669,19 @@ def v_with_top_v_key_chains(self, depth=None, flatten_key_chains=False): def mod_with_top_mod_key_chain(self, depth=None, flatten_key_chain=False): """ - - + (TODO) + Parameters ---------- depth - flatten_key_chain - + + flatten_key_chain + If set True, will return a flat (depth-1) container, + with all nested key-chains flattened. Default is False. + Returns ------- - + """ if not ivy.exists(self.top_mod) or depth == 0: return self.__repr__() @@ -690,14 +707,22 @@ def show_mod_in_top_mod( self, upper_depth=None, lower_depth=None, flatten_key_chains=False ): """ - - + Show lower submodules in the top module. `upper_depth` and `lower_depth` + are for controlling the coverage of upper and lower modules. + Will give prompt if no top module found. + Parameters ---------- upper_depth + How many modules it tracks up as upper module. None for current module. + Default is None. Will be truncated to mod_depth. lower_depth - flatten_key_chain - + How many modules it tracks down. None for current module. + Default is None. Will be truncated to mod_height. + flatten_key_chains + If set True, will return a flat (depth-1) container, + with all nested key-chains flattened. Default is False. 
+ Returns ------- None @@ -727,16 +752,22 @@ def _set_submod_flags( expected_submod_rets, ): """ - - + Set flags of the submodule. + Parameters ---------- track_submod_rets + If True, will track the returns of submodules. submod_depth + The depth of tracked submodules. submods_to_track + If given, will only track submodules in `submods_to_track`. track_submod_call_order + If True, will track the call order of submodules. expected_submod_rets - + If given, will raise exception if submodule returns are + different from expected returns. + Returns ------- None @@ -753,8 +784,8 @@ def _set_submod_flags( def _unset_submod_flags(self): """ - - + Unset flags of the submodule. + Returns ------- None @@ -767,15 +798,17 @@ def _unset_submod_flags(self): def get_mod_key(self, top_mod=None): """ - - + Get the key of current module. + Parameters ---------- top_mod - + Explicitly indicate the top module. None for the top + module of current module. Default is None. + Returns ------- - + A string of the current module key. """ if top_mod is None: top_mod = self.top_mod() @@ -792,12 +825,13 @@ def get_mod_key(self, top_mod=None): def _add_submod_ret(self, ret): """ - - + Add returns in the submodule return of the top module. + Parameters ---------- ret - + The return you want to add. + Returns ------- None @@ -813,8 +847,9 @@ def _add_submod_ret(self, ret): def _check_submod_ret(self): """ - - + Check submodule returns with expected submodule returns. + Raise AssertionError if returns are not close enough. + Returns ------- None @@ -859,11 +894,12 @@ def _check_submod_ret(self): # noinspection PyProtectedMember def _is_submod_leaf(self): """ - checks if the submodule is the leaf node of the network. - + Checks if the submodule is the leaf node of the network. + Returns ------- - + ret + True if the submodule is the leaf node of the network. 
""" submod_depth = self.top_mod()._submod_depth submods_to_track = self.top_mod()._submods_to_track @@ -875,8 +911,8 @@ def _is_submod_leaf(self): def _add_submod_enter(self): """ - - + (TODO) + Returns ------- None @@ -934,6 +970,7 @@ def __call__( *args, v=None, with_grads=None, + # consider remove unused parameters? stateful=None, arg_stateful_idxs=None, kwarg_stateful_idxs=None, @@ -945,21 +982,34 @@ def __call__( **kwargs ): """ - - + Forward an input through current module. + Parameters ---------- v + If given, use this container as internal varibles temporarily. + Default is None. with_grads + If True, forward this pass with gradients. + + ### (TODO) Unused Parameters Below ### stateful arg_stateful_idxs kwarg_stateful_idxs + ### (TODO) Unused Parameters Above ### + track_submod_rets + If True, will track the returns of submodules. submod_depth + The depth of tracked submodules. submods_to_track + If given, will only tracks submodules in `submods_to_track`. track_submod_call_order - expected_submo_rets - + If True, will tracks the call order of submodules. + expected_submod_rets + If given, will raise exception if submodule returns are + different from expected returns. + Returns ------- ret @@ -991,7 +1041,7 @@ def save_weights(self, weights_path): ---------- weights_path The hdf5 file for saving the weights. - + Returns ------- None @@ -1002,16 +1052,21 @@ def save_weights(self, weights_path): def build(self, *args, from_call=False, device=None, dtype=None, **kwargs): """ Build the internal layers and variables for this module. - + Parameters ---------- from_call + If True, denote that this build is triggered by calling. Otherwise + triggered by initializing the module. Default is False. device + The device we want to build module on. None for default device. + Default is None. dtype - + The data type for building the module. Default is None. Returns ------- - + ret + True for successfully built a module. 
""" self._dev = ivy.default(device, self._dev) # return False if not from_call but build_mode is on_call @@ -1094,10 +1149,11 @@ def build(self, *args, from_call=False, device=None, dtype=None, **kwargs): def show_structure(self): """ Prints the structure of the layer network. - + Returns ------- this_repr + String of the stucture of the module. """ this_repr = termcolor.colored(object.__repr__(self), "green") sub_mod_repr = self.sub_mods(False).__repr__()