Skip to content

Commit

Permalink
docstring add #1
Browse files Browse the repository at this point in the history
(Need discussion) and (TODO) flags added for further discussion
  • Loading branch information
MrZilinXiao committed Sep 6, 2022
1 parent c2b4e25 commit 951b15f
Showing 1 changed file with 58 additions and 30 deletions.
88 changes: 58 additions & 30 deletions ivy/stateful/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@


class Module(abc.ABC):
"""
Module is a base class for deriving trainable modules.
"""
def __init__(
self,
device=None,
Expand Down Expand Up @@ -134,7 +137,7 @@ def new_fn(*a, with_grads=None, **kw):

def _top_v_fn(self, depth=None, flatten_key_chains=False):
"""
The method helps in visualising the top view of a nested network upto
Help in visualising the top view of a nested network upto
a certain depth
Parameters
Expand All @@ -147,7 +150,9 @@ def _top_v_fn(self, depth=None, flatten_key_chains=False):
Returns
-------
ret
"""
if ivy.exists(self.top_v):
if ivy.exists(depth):
Expand All @@ -162,16 +167,18 @@ def _top_v_fn(self, depth=None, flatten_key_chains=False):

def _top_mod_fn(self, depth=None):
"""
Find the top module at specific depth.
Parameters
----------
depth
The number of modules we want to trace back.
Returns
-------
ret
The module we want to track down. Return self if no top
module exists.
"""
if ivy.exists(self.top_mod):
if ivy.exists(depth):
Expand All @@ -183,11 +190,13 @@ def _top_mod_fn(self, depth=None):
def track_submod_rets(self):
"""
Tracks the returns of the submodules if track_submod_returns
argument is set to True during call
argument is set to True during call (Need discussion)
Returns
-------
ret
Bool flag indicating whether the current module gets
tracked in the computation graph.
"""
if not ivy.exists(self.top_mod):
return False
Expand All @@ -207,12 +216,13 @@ def track_submod_rets(self):

def check_submod_rets(self):
"""
compares the submodule returns with the expected submodule
returns passed during call
Compares the submodule returns with the expected submodule
returns passed during call (Need discussion)
Returns
-------
ret
True if the top module has expected_submod_rets.
"""
if not ivy.exists(self.top_mod):
return False
Expand All @@ -224,10 +234,12 @@ def check_submod_rets(self):
def track_submod_call_order(self):
"""
Tracks the order in which the submodules are called.
(Need discussion)
Returns
-------
ret
True if the current module allows call order tracking.
"""
if not ivy.exists(self.top_mod):
return False
Expand All @@ -247,11 +259,12 @@ def track_submod_call_order(self):

def mod_depth(self):
"""
Return the depth of the current module.
Returns
-------
the depth of the module in the network.
ret
The depth of the module in the network. Return 0 for root module.
"""
depth = 0
mod_above = self
Expand All @@ -265,26 +278,29 @@ def mod_depth(self):

def mod_height(self):
"""
Return the height of the current module.
Returns
-------
The height of the network
ret
The height of the network. Return 0 for leaf module.
"""
return self.sub_mods().max_depth - 1

def _find_variables(self, obj=None):
"""
Find all interval varibles in obj. Return empty Container if obj is None.
Parameters
----------
obj
The submodule whose internal variables are to be returned.
The submodule whose internal variables are to be returned. Default
is None.
Returns
-------
The internal variables of the submodule passed in the argument.
ret
The internal variables of the submodule passed in the argument.
"""
vs = Container()
# ToDo: add support for finding local variables, if/when JAX supports
Expand Down Expand Up @@ -318,7 +334,7 @@ def _find_variables(self, obj=None):
return vs

@staticmethod
def _extract_v(v, keychain_mappings, orig_key_chain):
def _extract_v(v, keychain_mappings: dict, orig_key_chain):
"""
Expand Down Expand Up @@ -346,7 +362,7 @@ def _extract_v(v, keychain_mappings, orig_key_chain):

def _wrap_call_methods(self, keychain_mappings, key="", obj=None):
"""
(TODO)
Parameters
----------
Expand Down Expand Up @@ -389,17 +405,21 @@ def _wrap_call_methods(self, keychain_mappings, key="", obj=None):
@staticmethod
def _remove_duplicate_variables(vs, created):
"""
Remove duplicate variables in `vs` referring to `created`.
Parameters
----------
vs
The container that needs to be pruned.
created
The container as the duplication reference.
Returns
-------
vs
The container after removing duplicate variables.
keychain_mappings
Dict storing those keys and ids being removed.
"""
created_ids = created.map(lambda x, kc: id(x))
vs_ids = vs.map(lambda x, kc: id(x))
Expand Down Expand Up @@ -441,7 +461,8 @@ def _create_variables(self, device, dtype):
Returns
-------
An empty set
ret
An empty set.
"""
return {}

Expand All @@ -451,9 +472,10 @@ def _build(self, *args, **kwargs) -> bool:
Returns
-------
False or empty Container if the build only partially completed (i.e. some
child Modules have "on_call" build mode). Alternatviely, return True or a
container of the built variables if the module is built.
ret
False or empty Container if the build only partially completed (i.e. some
child Modules have "on_call" build mode). Alternatviely, return True or a
container of the built variables if the module is built.
"""
return True

Expand All @@ -473,11 +495,13 @@ def _forward(self, *args, **kwargs):

def _forward_with_tracking(self, *args, **kwargs):
"""
Forward pass while optionally tracking submodule returns and call order
Forward pass while optionally tracking submodule returns
and call order.
Returns
-------
ret
Result of the forward pass of the layer.
"""
if self.track_submod_call_order():
self._add_submod_enter()
Expand All @@ -498,11 +522,15 @@ def _call(self, *args, v=None, with_grads=None, **kwargs):
Parameters
----------
v
Replace `v` of current layer when forwarding. Restore
after the forward finished.
with_grads
Whether to forward with gradients.
Returns
-------
ret
Result of the forward pass of the layer.
"""
with_grads = ivy.with_grads(with_grads=with_grads)
if not self._built:
Expand Down

0 comments on commit 951b15f

Please sign in to comment.