From 951b15febcb152426216622bfed10f10837c29a5 Mon Sep 17 00:00:00 2001 From: Zilin Xiao Date: Tue, 6 Sep 2022 15:47:39 +0000 Subject: [PATCH] docstring add #1 (Need discussion) and (TODO) flags added for further discussion --- ivy/stateful/module.py | 88 ++++++++++++++++++++++++++++-------------- 1 file changed, 58 insertions(+), 30 deletions(-) diff --git a/ivy/stateful/module.py b/ivy/stateful/module.py index d3feaddf88f5c..a2fd8fe02cc1f 100644 --- a/ivy/stateful/module.py +++ b/ivy/stateful/module.py @@ -19,6 +19,9 @@ class Module(abc.ABC): + """ + Module is a base class for deriving trainable modules. + """ def __init__( self, device=None, @@ -134,7 +137,7 @@ def new_fn(*a, with_grads=None, **kw): def _top_v_fn(self, depth=None, flatten_key_chains=False): """ - The method helps in visualising the top view of a nested network upto + Help in visualising the top view of a nested network upto a certain depth Parameters @@ -147,7 +150,9 @@ def _top_v_fn(self, depth=None, flatten_key_chains=False): Returns ------- - + ret + + """ if ivy.exists(self.top_v): if ivy.exists(depth): @@ -162,16 +167,18 @@ def _top_v_fn(self, depth=None, flatten_key_chains=False): def _top_mod_fn(self, depth=None): """ - + Find the top module at specific depth. Parameters ---------- depth - + The number of modules we want to trace back. Returns ------- - + ret + The module we want to track down. Return self if no top + module exists. """ if ivy.exists(self.top_mod): if ivy.exists(depth): @@ -183,11 +190,13 @@ def _top_mod_fn(self, depth=None): def track_submod_rets(self): """ Tracks the returns of the submodules if track_submod_returns - argument is set to True during call + argument is set to True during call (Need discussion) Returns ------- - + ret + Bool flag indicating whether the current module gets + tracked in the computation graph. """ if not ivy.exists(self.top_mod): return False @@ -207,12 +216,13 @@ def track_submod_rets(self): def check_submod_rets(self): """ - compares the submodule returns with the expected submodule - returns passed during call + Compares the submodule returns with the expected submodule + returns passed during call (Need discussion) Returns ------- - + ret + True if the top module has expected_submod_rets. """ if not ivy.exists(self.top_mod): return False @@ -224,10 +234,12 @@ def check_submod_rets(self): def track_submod_call_order(self): """ Tracks the order in which the submodules are called. - + (Need discussion) + Returns ------- - + ret + True if the current module allows call order tracking. """ if not ivy.exists(self.top_mod): return False @@ -247,11 +259,12 @@ def track_submod_call_order(self): def mod_depth(self): """ - + Return the depth of the current module. Returns ------- - the depth of the module in the network. + ret + The depth of the module in the network. Return 0 for root module. """ depth = 0 mod_above = self @@ -265,26 +278,29 @@ def mod_depth(self): def mod_height(self): """ - + Return the height of the current module. Returns ------- - The height of the network + ret + The height of the network. Return 0 for leaf module. """ return self.sub_mods().max_depth - 1 def _find_variables(self, obj=None): """ - + Find all interval varibles in obj. Return empty Container if obj is None. Parameters ---------- obj - The submodule whose internal variables are to be returned. + The submodule whose internal variables are to be returned. Default + is None. Returns ------- - The internal variables of the submodule passed in the argument. + ret + The internal variables of the submodule passed in the argument. """ vs = Container() # ToDo: add support for finding local variables, if/when JAX supports @@ -318,7 +334,7 @@ def _find_variables(self, obj=None): return vs @staticmethod - def _extract_v(v, keychain_mappings, orig_key_chain): + def _extract_v(v, keychain_mappings: dict, orig_key_chain): """ @@ -346,7 +362,7 @@ def _extract_v(v, keychain_mappings, orig_key_chain): def _wrap_call_methods(self, keychain_mappings, key="", obj=None): """ - + (TODO) Parameters ---------- @@ -389,17 +405,21 @@ def _wrap_call_methods(self, keychain_mappings, key="", obj=None): @staticmethod def _remove_duplicate_variables(vs, created): """ - + Remove duplicate variables in `vs` referring to `created`. Parameters ---------- vs + The container that needs to be pruned. created - + The container as the duplication reference. Returns ------- - + vs + The container after removing duplicate variables. + keychain_mappings + Dict storing those keys and ids being removed. """ created_ids = created.map(lambda x, kc: id(x)) vs_ids = vs.map(lambda x, kc: id(x)) @@ -441,7 +461,8 @@ def _create_variables(self, device, dtype): Returns ------- - An empty set + ret + An empty set. """ return {} @@ -451,9 +472,10 @@ def _build(self, *args, **kwargs) -> bool: Returns ------- - False or empty Container if the build only partially completed (i.e. some - child Modules have "on_call" build mode). Alternatviely, return True or a - container of the built variables if the module is built. + ret + False or empty Container if the build only partially completed (i.e. some + child Modules have "on_call" build mode). Alternatviely, return True or a + container of the built variables if the module is built. """ return True @@ -473,11 +495,13 @@ def _forward(self, *args, **kwargs): def _forward_with_tracking(self, *args, **kwargs): """ - Forward pass while optionally tracking submodule returns and call order + Forward pass while optionally tracking submodule returns + and call order. Returns ------- ret + Result of the forward pass of the layer. """ if self.track_submod_call_order(): self._add_submod_enter() @@ -498,11 +522,15 @@ def _call(self, *args, v=None, with_grads=None, **kwargs): Parameters ---------- v + Replace `v` of current layer when forwarding. Restore + after the forward finished. with_grads + Whether to forward with gradients. Returns ------- - + ret + Result of the forward pass of the layer. """ with_grads = ivy.with_grads(with_grads=with_grads) if not self._built: