Skip to content

Commit

Permalink
feat: Updated black formatter to latest version and modified files …
Browse files Browse the repository at this point in the history
…accordingly (#28282)
  • Loading branch information
Sai-Suraj-27 authored Feb 14, 2024
1 parent d161925 commit d006c3d
Show file tree
Hide file tree
Showing 35 changed files with 743 additions and 596 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
- id: ruff
args: [ --fix ]
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.1.1
rev: 24.2.0
hooks:
- id: black
language_version: python3
Expand Down
80 changes: 41 additions & 39 deletions ivy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,45 +930,47 @@ def __deepcopy__(self, memo):


# defines ivy.globals attribute
globals_vars = GlobalsDict({
"backend_stack": backend_stack,
"default_device_stack": device.default_device_stack,
"valid_dtypes": valid_dtypes,
"valid_numeric_dtypes": valid_numeric_dtypes,
"valid_int_dtypes": valid_int_dtypes,
"valid_uint_dtypes": valid_uint_dtypes,
"valid_complex_dtypes": valid_complex_dtypes,
"valid_devices": valid_devices,
"invalid_dtypes": invalid_dtypes,
"invalid_numeric_dtypes": invalid_numeric_dtypes,
"invalid_int_dtypes": invalid_int_dtypes,
"invalid_float_dtypes": invalid_float_dtypes,
"invalid_uint_dtypes": invalid_uint_dtypes,
"invalid_complex_dtypes": invalid_complex_dtypes,
"invalid_devices": invalid_devices,
"array_significant_figures_stack": array_significant_figures_stack,
"array_decimal_values_stack": array_decimal_values_stack,
"warning_level_stack": warning_level_stack,
"queue_timeout_stack": general.queue_timeout_stack,
"array_mode_stack": general.array_mode_stack,
"inplace_mode_stack": general.inplace_mode_stack,
"soft_device_mode_stack": device.soft_device_mode_stack,
"shape_array_mode_stack": general.shape_array_mode_stack,
"show_func_wrapper_trace_mode_stack": general.show_func_wrapper_trace_mode_stack,
"min_denominator_stack": general.min_denominator_stack,
"min_base_stack": general.min_base_stack,
"tmp_dir_stack": general.tmp_dir_stack,
"precise_mode_stack": general.precise_mode_stack,
"nestable_mode_stack": general.nestable_mode_stack,
"exception_trace_mode_stack": general.exception_trace_mode_stack,
"default_dtype_stack": data_type.default_dtype_stack,
"default_float_dtype_stack": data_type.default_float_dtype_stack,
"default_int_dtype_stack": data_type.default_int_dtype_stack,
"default_uint_dtype_stack": data_type.default_uint_dtype_stack,
"nan_policy_stack": nan_policy_stack,
"dynamic_backend_stack": dynamic_backend_stack,
"cython_wrappers_stack": cython_wrappers_stack,
})
globals_vars = GlobalsDict(
{
"backend_stack": backend_stack,
"default_device_stack": device.default_device_stack,
"valid_dtypes": valid_dtypes,
"valid_numeric_dtypes": valid_numeric_dtypes,
"valid_int_dtypes": valid_int_dtypes,
"valid_uint_dtypes": valid_uint_dtypes,
"valid_complex_dtypes": valid_complex_dtypes,
"valid_devices": valid_devices,
"invalid_dtypes": invalid_dtypes,
"invalid_numeric_dtypes": invalid_numeric_dtypes,
"invalid_int_dtypes": invalid_int_dtypes,
"invalid_float_dtypes": invalid_float_dtypes,
"invalid_uint_dtypes": invalid_uint_dtypes,
"invalid_complex_dtypes": invalid_complex_dtypes,
"invalid_devices": invalid_devices,
"array_significant_figures_stack": array_significant_figures_stack,
"array_decimal_values_stack": array_decimal_values_stack,
"warning_level_stack": warning_level_stack,
"queue_timeout_stack": general.queue_timeout_stack,
"array_mode_stack": general.array_mode_stack,
"inplace_mode_stack": general.inplace_mode_stack,
"soft_device_mode_stack": device.soft_device_mode_stack,
"shape_array_mode_stack": general.shape_array_mode_stack,
"show_func_wrapper_trace_mode_stack": general.show_func_wrapper_trace_mode_stack,
"min_denominator_stack": general.min_denominator_stack,
"min_base_stack": general.min_base_stack,
"tmp_dir_stack": general.tmp_dir_stack,
"precise_mode_stack": general.precise_mode_stack,
"nestable_mode_stack": general.nestable_mode_stack,
"exception_trace_mode_stack": general.exception_trace_mode_stack,
"default_dtype_stack": data_type.default_dtype_stack,
"default_float_dtype_stack": data_type.default_float_dtype_stack,
"default_int_dtype_stack": data_type.default_int_dtype_stack,
"default_uint_dtype_stack": data_type.default_uint_dtype_stack,
"nan_policy_stack": nan_policy_stack,
"dynamic_backend_stack": dynamic_backend_stack,
"cython_wrappers_stack": cython_wrappers_stack,
}
)

_default_globals = copy.deepcopy(globals_vars)

Expand Down
183 changes: 102 additions & 81 deletions ivy/data_classes/container/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,15 +1367,17 @@ def cont_flatten_key_chain(
if below_depth and num_keys > below_depth:
pre_keys = flat_keys[0:below_depth]
del flat_keys[0:below_depth]
return "/".join([
k
for k in [
"/".join(pre_keys),
replacement.join(flat_keys),
"/".join(post_keys),
return "/".join(
[
k
for k in [
"/".join(pre_keys),
replacement.join(flat_keys),
"/".join(post_keys),
]
if k
]
if k
])
)

@staticmethod
def cont_trim_key(key, max_length):
Expand Down Expand Up @@ -1697,16 +1699,18 @@ def cont_all_true(
Boolean, whether all entries are boolean True.
"""
return bool(
np.prod([
v
for k, v in self.cont_as_bools(
assert_is_bool,
key_chains,
to_apply,
prune_unapplied,
map_sequences,
).cont_to_iterator()
])
np.prod(
[
v
for k, v in self.cont_as_bools(
assert_is_bool,
key_chains,
to_apply,
prune_unapplied,
map_sequences,
).cont_to_iterator()
]
)
)

def cont_all_false(
Expand Down Expand Up @@ -1742,16 +1746,18 @@ def cont_all_false(
Boolean, whether all entries are boolean False.
"""
return not bool(
np.sum([
v
for k, v in self.cont_as_bools(
assert_is_bool,
key_chains,
to_apply,
prune_unapplied,
map_sequences,
).cont_to_iterator()
])
np.sum(
[
v
for k, v in self.cont_as_bools(
assert_is_bool,
key_chains,
to_apply,
prune_unapplied,
map_sequences,
).cont_to_iterator()
]
)
)

def cont_slice_via_key(self, slice_key):
Expand Down Expand Up @@ -1843,11 +1849,15 @@ def cont_unstack_conts(self, axis, keepdims=False, dim_size=None):
if keepdims:
# noinspection PyTypeChecker
return [
self[(
slice(i, i + 1, 1)
if axis == 0
else tuple([slice(None, None, None)] * axis + [slice(i, i + 1, 1)])
)]
self[
(
slice(i, i + 1, 1)
if axis == 0
else tuple(
[slice(None, None, None)] * axis + [slice(i, i + 1, 1)]
)
)
]
for i in range(dim_size)
]
# noinspection PyTypeChecker
Expand Down Expand Up @@ -3690,10 +3700,12 @@ def _pre_pad_alpha_line(str_in):
padded = True
return "\\n" + indent_str + indented_key_str + str_in

leading_str_to_keep = ", ".join([
_pre_pad_alpha_line(s) if s[0].isalpha() and i != 0 else s
for i, s in enumerate(leading_str_to_keep.split(", "))
])
leading_str_to_keep = ", ".join(
[
_pre_pad_alpha_line(s) if s[0].isalpha() and i != 0 else s
for i, s in enumerate(leading_str_to_keep.split(", "))
]
)
local_indent_str = "" if padded else indent_str
leading_str = leading_str_to_keep.split("\\n")[-1].replace('"', "")
remaining_str = array_str_in_split[1]
Expand All @@ -3710,23 +3722,25 @@ def _pre_pad_alpha_line(str_in):
uniform_indent_wo_overflow_list = list(
filter(None, uniform_indent_wo_overflow.split("\\n"))
)
uniform_indent = "\n".join([
(
local_indent_str + extra_indent + " " + s
if (
s[0].isnumeric()
or s[0] == "-"
or s[0:3] == "..."
or max(ss in s[0:6] for ss in ["nan, ", "inf, "])
)
else (
indent_str + indented_key_str + s
if (not s[0].isspace() and s[0] != '"')
else s
uniform_indent = "\n".join(
[
(
local_indent_str + extra_indent + " " + s
if (
s[0].isnumeric()
or s[0] == "-"
or s[0:3] == "..."
or max(ss in s[0:6] for ss in ["nan, ", "inf, "])
)
else (
indent_str + indented_key_str + s
if (not s[0].isspace() and s[0] != '"')
else s
)
)
)
for s in uniform_indent_wo_overflow_list
])
for s in uniform_indent_wo_overflow_list
]
)
indented = uniform_indent
# 10 dimensions is a sensible upper bound for the number in a single array
for i in range(2, 10):
Expand Down Expand Up @@ -3832,14 +3846,16 @@ def _align_arrays(str_in):
def _add_newline(str_in):
str_in_split = str_in.split("\n")
str_split_size = len(str_in_split)
return "\n".join([
(
("\n" * self._print_line_spacing + ss)
if i == (str_split_size - 1)
else ss
)
for i, ss in enumerate(str_in_split)
])
return "\n".join(
[
(
("\n" * self._print_line_spacing + ss)
if i == (str_split_size - 1)
else ss
)
for i, ss in enumerate(str_in_split)
]
)

json_dumped_str = '":'.join(
[_add_newline(s) for s in json_dumped_str.split('":')]
Expand All @@ -3850,9 +3866,12 @@ def _add_newline(str_in):
json_dumped_str = (
json_dumped_str_split[0]
+ ", "
+ ", ".join([
"'".join(ss.split("'")[1:]) for ss in json_dumped_str_split[1:]
])
+ ", ".join(
[
"'".join(ss.split("'")[1:])
for ss in json_dumped_str_split[1:]
]
)
)
json_dumped_str = (
json_dumped_str.replace(":shape", ", shape")
Expand All @@ -3863,24 +3882,26 @@ def _add_newline(str_in):
# color keys
json_dumped_str_split = json_dumped_str.split('":')
split_size = len(json_dumped_str_split)
json_dumped_str = '":'.join([
(
' "'.join(
sub_str.split(' "')[:-1]
+ [
termcolor.colored(
ivy.Container.cont_trim_key(
sub_str.split(' "')[-1], self._key_length_limit
),
self._default_key_color,
)
]
json_dumped_str = '":'.join(
[
(
' "'.join(
sub_str.split(' "')[:-1]
+ [
termcolor.colored(
ivy.Container.cont_trim_key(
sub_str.split(' "')[-1], self._key_length_limit
),
self._default_key_color,
)
]
)
if i < split_size - 1
else sub_str
)
if i < split_size - 1
else sub_str
)
for i, sub_str in enumerate(json_dumped_str_split)
])
for i, sub_str in enumerate(json_dumped_str_split)
]
)
# remove quotation marks, shape tuple, and color other elements of the dict
ret = (
json_dumped_str.replace('"', "")
Expand Down
10 changes: 6 additions & 4 deletions ivy/data_classes/factorized_tensor/cp_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ def to_unfolded(self, mode):
return ivy.CPTensor.cp_to_unfolded(self, mode)

def cp_copy(self):
return CPTensor((
ivy.copy_array(self.weights),
[ivy.copy_array(self.factors[i]) for i in range(len(self.factors))],
))
return CPTensor(
(
ivy.copy_array(self.weights),
[ivy.copy_array(self.factors[i]) for i in range(len(self.factors))],
)
)

def mode_dot(self, matrix_or_vector, mode, keep_dim=False, copy=True):
"""N-mode product of a CP tensor and a matrix or vector at the
Expand Down
10 changes: 6 additions & 4 deletions ivy/data_classes/factorized_tensor/tucker_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,12 @@ def to_unfolded(self, mode):
return TuckerTensor.tucker_to_unfolded(self, mode)

def tucker_copy(self):
return TuckerTensor((
deepcopy(self.core),
[deepcopy(self.factors[i]) for i in range(len(self.factors))],
))
return TuckerTensor(
(
deepcopy(self.core),
[deepcopy(self.factors[i]) for i in range(len(self.factors))],
)
)

def to_vec(self):
return TuckerTensor.tucker_to_vec(self)
Expand Down
Loading

0 comments on commit d006c3d

Please sign in to comment.