Skip to content

Commit

Permalink
fix typehints conv1d (ivy-llc#9880)
Browse files Browse the repository at this point in the history
Co-authored-by: WilliamHirst [email protected]
  • Loading branch information
xoiga123 authored Jan 21, 2023
1 parent 54e4795 commit fbdd3ed
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 33 deletions.
12 changes: 6 additions & 6 deletions ivy/array/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,12 +303,12 @@ def multi_head_attention(
def conv1d(
self: ivy.Array,
filters: Union[ivy.Array, ivy.NativeArray],
strides: int,
strides: Union[int, Tuple[int]],
padding: str,
/,
*,
data_format: str = "NWC",
dilations: int = 1,
dilations: Union[int, Tuple[int]] = 1,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Expand All @@ -318,17 +318,17 @@ def conv1d(
Parameters
----------
x
Input image *[batch_size,w,d_in]*.
self
Input image *[batch_size,w,d_in]* or *[batch_size,d_in,w]*.
filters
Convolution filters *[fw,d_in,d_out]*.
strides
The stride of the sliding window for each dimension of input.
padding
SAME" or "VALID" indicating the algorithm, or list indicating the
"SAME" or "VALID" indicating the algorithm, or list indicating the
per-dimension paddings.
data_format
NWC" or "NCW". Defaults to "NWC".
"NWC" or "NCW". Defaults to "NWC".
dilations
The dilation factor for each dimension of input. (Default value = 1)
out
Expand Down
44 changes: 32 additions & 12 deletions ivy/container/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ def static_scaled_dot_product_attention(
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
) -> Union[ivy.Array, ivy.NativeArray, ivy.Container]:
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.scaled_dot_product_attention.
This method simply wraps the function, and so the docstring for
Expand All @@ -519,7 +519,7 @@ def static_scaled_dot_product_attention(
Parameters
----------
self
q
The queries input container. The shape of queries input array leaves should
be in *[batch_shape,num_queries,feat_dim]*. The queries input array leaves
should have the same size as keys and values.
Expand Down Expand Up @@ -592,7 +592,6 @@ def static_scaled_dot_product_attention(
[4.4, 5.6],
[4.4, 5.6]]])
}
"""
return ContainerBase.cont_multi_map_in_function(
"scaled_dot_product_attention",
Expand Down Expand Up @@ -621,9 +620,9 @@ def scaled_dot_product_attention(
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
) -> Union[ivy.Array, ivy.NativeArray, ivy.Container]:
) -> ivy.Container:
"""
ivy.Container method variant of ivy.scaled_dot_product_attention.
ivy.Container instance method variant of ivy.scaled_dot_product_attention.
This method simply wraps the function, and so the docstring for
ivy.scaled_dot_product_attention also applies to this method with minimal
changes.
Expand Down Expand Up @@ -702,7 +701,6 @@ def scaled_dot_product_attention(
[4.4, 5.6],
[4.4, 5.6]]])
}
"""
return self.static_scaled_dot_product_attention(
self,
Expand Down Expand Up @@ -801,12 +799,12 @@ def multi_head_attention(
def static_conv1d(
x: Union[ivy.Array, ivy.NativeArray, ivy.Container],
filters: Union[ivy.Array, ivy.NativeArray, ivy.Container],
strides: int,
strides: Union[int, Tuple[int]],
padding: str,
/,
*,
data_format: str = "NWC",
dilations: int = 1,
dilations: Union[int, Tuple[int]] = 1,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
Expand All @@ -833,6 +831,17 @@ def static_conv1d(
"NWC" or "NCW". Defaults to "NWC".
dilations
The dilation factor for each dimension of input. (Default value = 1)
key_chains
The key-chains to apply or not apply the method to. Default is ``None``.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
Expand Down Expand Up @@ -873,13 +882,13 @@ def static_conv1d(

def conv1d(
self: ivy.Container,
filters: Union[ivy.Array, ivy.NativeArray],
strides: int,
filters: Union[ivy.Array, ivy.NativeArray, ivy.Container],
strides: Union[int, Tuple[int]],
padding: str,
/,
*,
data_format: str = "NWC",
dilations: int = 1,
dilations: Union[int, Tuple[int]] = 1,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
Expand All @@ -893,7 +902,7 @@ def conv1d(
Parameters
----------
x
self
Input image *[batch_size,w, d_in]*.
filters
Convolution filters *[fw,d_in, d_out]*. (d_in must be the same as d from x)
Expand All @@ -906,6 +915,17 @@ def conv1d(
"NWC" or "NCW". Defaults to "NWC".
dilations
The dilation factor for each dimension of input. (Default value = 1)
key_chains
The key-chains to apply or not apply the method to. Default is ``None``.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
out
optional output container, for writing the result to. It must have a shape
that the inputs broadcast to.
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/backends/jax/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ def _conv_transpose_padding(k, s, padding, dilation, diff=0):
def conv1d(
x: JaxArray,
filters: JaxArray,
strides: int,
strides: Union[int, Tuple[int]],
padding: str,
/,
*,
data_format: Optional[str] = "NWC",
dilations: Optional[int] = 1,
data_format: str = "NWC",
dilations: Union[int, Tuple[int]] = 1,
out: Optional[JaxArray] = None,
) -> JaxArray:
strides = (strides,) if isinstance(strides, int) else strides
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/backends/numpy/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ def _add_dilations(x, dilations, axis):
def conv1d(
x: np.ndarray,
filters: np.ndarray,
strides: int,
strides: Union[int, Tuple[int]],
padding: str,
/,
*,
data_format: Optional[str] = "NWC",
dilations: Optional[int] = 1,
data_format: str = "NWC",
dilations: Union[int, Tuple[int]] = 1,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if isinstance(strides, (tuple, list)):
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/backends/tensorflow/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
def conv1d(
x: Union[tf.Tensor, tf.Variable],
filters: Union[tf.Tensor, tf.Variable],
strides: int,
strides: Union[int, Tuple[int]],
padding: str,
/,
*,
data_format: Optional[str] = "NWC",
dilations: Optional[int] = 1,
data_format: str = "NWC",
dilations: Union[int, Tuple[int]] = 1,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
if data_format == "NCW":
Expand Down
6 changes: 3 additions & 3 deletions ivy/functional/backends/torch/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ def _out_shape(x, strides, pad, dilations, filters):
def conv1d(
x: torch.Tensor,
filters: torch.Tensor,
strides: int,
strides: Union[int, Tuple[int]],
padding: str,
/,
*,
data_format: Optional[str] = "NWC",
dilations: Optional[int] = 1,
data_format: str = "NWC",
dilations: Union[int, Tuple[int]] = 1,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(strides, (tuple, list)):
Expand Down
10 changes: 7 additions & 3 deletions ivy/functional/ivy/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,12 +771,12 @@ def call_einops(t):
def conv1d(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
strides: int,
strides: Union[int, Tuple[int]],
padding: str,
/,
*,
data_format: Optional[str] = "NWC",
dilations: Optional[int] = 1,
data_format: str = "NWC",
dilations: Union[int, Tuple[int]] = 1,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""Computes a 1-D convolution given 3-D input x and filters arrays.
Expand Down Expand Up @@ -807,6 +807,10 @@ def conv1d(
ret
The result of the convolution operation.
Both the description and the type hints above assumes an array input for simplicity,
but this function is *nestable*, and therefore also accepts :class:`ivy.Container`
instances in place of any of the arguments.
Examples
--------
With :class:`ivy.Array` input:
Expand Down

0 comments on commit fbdd3ed

Please sign in to comment.