Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
QuantuMope committed Sep 24, 2024
1 parent fadd4ef commit edb31e1
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 19 deletions.
18 changes: 13 additions & 5 deletions alf/config_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def adjust_config_by_multi_process_divider(ddp_rank: int,
config1(
tag,
math.ceil(num_parallel_environments / multi_process_divider),
raise_if_used=False)
raise_if_used=False,
override_sole_init=True)

# Adjust the mini_batch_size. If the original configured value is 64 and
# there are 4 processes, it should mean that "jointly the 4 processes have
Expand All @@ -165,7 +166,8 @@ def adjust_config_by_multi_process_divider(ddp_rank: int,
config1(
tag,
math.ceil(mini_batch_size / multi_process_divider),
raise_if_used=False)
raise_if_used=False,
override_sole_init=True)

# If the termination condition is num_env_steps instead of num_iterations,
# we need to adjust it as well since each process only sees env steps taking
Expand All @@ -176,20 +178,26 @@ def adjust_config_by_multi_process_divider(ddp_rank: int,
config1(
tag,
math.ceil(num_env_steps / multi_process_divider),
raise_if_used=False)
raise_if_used=False,
override_sole_init=True)

tag = 'TrainerConfig.initial_collect_steps'
init_collect_steps = get_config_value(tag)
config1(
tag,
math.ceil(init_collect_steps / multi_process_divider),
raise_if_used=False)
raise_if_used=False,
override_sole_init=True)

# Only allow process with rank 0 to have evaluate. Enabling evaluation for
# other parallel processes is a waste as such evaluation does not offer more
# information.
if ddp_rank > 0:
config1('TrainerConfig.evaluate', False, raise_if_used=False)
config1(
'TrainerConfig.evaluate',
False,
raise_if_used=False,
override_sole_init=True)


def parse_config(conf_file, conf_params, create_env=True):
Expand Down
4 changes: 2 additions & 2 deletions alf/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
'get_operative_configs',
'import_config',
'load_config',
'override_config',
'override_sole_config',
'pre_config',
'reset_configs',
'validate_pre_configs',
Expand Down Expand Up @@ -147,7 +147,7 @@ def func(self, a, b):
override_sole_init)


def override_config(prefix_or_dict, **kwargs):
def override_sole_config(prefix_or_dict, **kwargs):
"""Wrapper function for configuring a config with override_sole_init=True.
This call allows a user to attempt to overwrite a config's value even
Expand Down
26 changes: 14 additions & 12 deletions alf/config_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,26 +172,26 @@ def sole_init_test_env(x):
alf.config("sole_init_test_env", x=0)
os.environ["ALF_SOLE_CONFIG"] = "0"

# Testing alf.override_config
alf.override_config("sole_init_test_prior", x=1)
alf.override_config("sole_init_test_after", x=1)
alf.override_config("sole_init_test_twice", x=1)
alf.override_config("sole_init_test_env", x=1)
# Testing alf.override_sole_config
alf.override_sole_config("sole_init_test_prior", x=1)
alf.override_sole_config("sole_init_test_after", x=1)
alf.override_sole_config("sole_init_test_twice", x=1)
alf.override_sole_config("sole_init_test_env", x=1)
self.assertEqual(alf.get_config_value("sole_init_test_prior.x"), 1)
self.assertEqual(alf.get_config_value("sole_init_test_after.x"), 1)
self.assertEqual(alf.get_config_value("sole_init_test_twice.x"), 1)
self.assertEqual(alf.get_config_value("sole_init_test_env.x"), 1)

# Test override_config doesn't overwrite for immutable values.
# Test override_sole_config doesn't overwrite for immutable values.
@alf.configurable
def override_on_immutable(x):
pass

alf.config("override_on_immutable", x=0, mutable=False)
alf.override_config("override_on_immutable", x=1)
alf.override_sole_config("override_on_immutable", x=1)
self.assertEqual(alf.get_config_value("override_on_immutable.x"), 0)

# Test override_config doesn't doesn't overwrite for immutable values.
# Test override_sole_config doesn't doesn't overwrite for immutable values.
@alf.configurable
def override_on_immutable_and_sole_init(x):
pass
Expand All @@ -201,7 +201,7 @@ def override_on_immutable_and_sole_init(x):
x=0,
sole_init=True,
mutable=False)
alf.override_config("override_on_immutable_and_sole_init", x=1)
alf.override_sole_config("override_on_immutable_and_sole_init", x=1)
self.assertEqual(
alf.get_config_value("override_on_immutable_and_sole_init.x"), 0)

Expand All @@ -212,7 +212,7 @@ def pre_config_before(x):

alf.pre_config({"pre_config_before.x": 0})
alf.config("pre_config_before", x=1)
alf.override_config("pre_config_before", x=1)
alf.override_sole_config("pre_config_before", x=1)
# sole_init starts to take effect for all calls AFTER the first call.
alf.config("pre_config_before", x=1, sole_init=True)
with self.assertRaises(RuntimeError) as context:
Expand All @@ -229,15 +229,17 @@ def pre_config_after(x):
with self.assertRaises(RuntimeError) as context:
alf.pre_config({"pre_config_after.x": 0})

# Test that calling override_config doesn't affect previous sole_init calls.
# Test that calling override_sole_config doesn't affect previous sole_init calls.
@alf.configurable
def override_no_affect_sole_init(x):
pass

alf.config("override_no_affect_sole_init", x=1, sole_init=True)
alf.override_config("override_no_affect_sole_init", x=2)
alf.override_sole_config("override_no_affect_sole_init", x=2)
with self.assertRaises(RuntimeError) as context:
alf.config("override_no_affect_sole_init", x=3)
self.assertEqual(
alf.get_config_value("override_no_affect_sole_init.x"), 2)

def test_repr_wrapper(self):
a = MyClass(1, 2)
Expand Down
24 changes: 24 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) 2024 Horizon Robotics and ALF Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import alf


@alf.configurable
def override_on_immutable(x):
pass


alf.config("override_on_immutable", x=0, mutable=False)
alf.override_sole_config("override_on_immutable", x=1)

0 comments on commit edb31e1

Please sign in to comment.