[RLlib] By-pass Evaluation workers when doing OPE #30135
Conversation
at a high level, I would expect to see some updates to the resource planning util here
https://github.com/gjoliver/ray/blob/master/rllib/offline/dataset_reader.py#L23
we reserve a bit of CPU resources for train dataset here.
maybe we should reserve some for the eval dataset as well, if we want to do things this way.
other than that, this seems like a much better approach than our existing single threaded eval logic.
nice!
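For reference, a rough sketch of what that reservation could look like, modeled on the train-dataset resource logic quoted further down in this thread (the helper names and config keys here are assumptions for illustration, not code from this PR):

# Hypothetical sketch only: reserve read-task CPUs for the eval dataset too,
# mirroring what dataset_reader.py already does for the train dataset.
DEFAULT_NUM_CPUS_PER_TASK = 0.5  # assumed default

def _read_task_bundles(input_config, parallelism):
    cpus_per_task = input_config.get(
        "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
    )
    return [{"CPU": cpus_per_task} for _ in range(parallelism)]

def default_resource_request_with_eval(config):
    # Train dataset reader CPUs (existing behavior).
    bundles = _read_task_bundles(
        config["input_config"], config["input_config"].get("parallelism", 1)
    )
    # Eval dataset reader CPUs (the suggestion above).
    eval_cfg = config.get("evaluation_config") or {}
    if eval_cfg.get("input") == "dataset":
        eval_input_cfg = eval_cfg.get("input_config", {})
        bundles += _read_task_bundles(
            eval_input_cfg, eval_input_cfg.get("parallelism", 1)
        )
    return bundles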
rllib/algorithms/algorithm.py
Outdated
# worker anymore
logger.info("Creating evaluation dataset ...")
self.evaluation_dataset, _ = get_dataset_and_shards(
    self.evaluation_config, num_workers=0
comment that num_workers=0 is quite important. this will return the un-sharded dataset as-is.
already did
rllib/algorithms/algorithm.py
Outdated
@@ -616,7 +616,8 @@ def setup(self, config: AlgorithmConfig) -> None:
# Evaluation WorkerSet setup.
# User would like to setup a separate evaluation worker set.
if self.config.evaluation_num_workers > 0 or self.config.evaluation_interval:
# Note: We skipp workerset creation if we need to do offline evaluation
typo: skip
Dumb question: Is there any chance someone might want to do both? offline estimation on a dataset (one run through the complete dataset per iteration) AND have n eval rollout workers take some shot at an actual env? Just for research/debugging reasons.
What I'm saying is that we might want to consider separating the concepts of RolloutWorker/WorkerSet entirely from the concept of a separate dataset reader for offline data intake. Just some things to think about for the future.
There could be a use-case like that. Right now we don't support it. If we want to support it in the future we should create a separate project and re-define the evaluation config to completely isolate the two (or even more) eval flows.
Do we even support accepting both today? afaik the input config is like the following:
# for offline
{
"input": "dataset",
"input_config": {"format": xxx, "path": [path1, path2]}
}
# for online
{
"input": "sampler",
"input_config": {"env": "CartPole-v0"}
}
Yeah, let's consolidate the entire input config API. I feel like there should be a single config method: input_setup or whatever name, which covers both our current environment() method as well as rollouts() (and offline_data() 😵💫 !).
To answer your question: Yeah, we have the mixed input, but it's not well explored/tested/configurable:
>>> MixedInput({
... "sampler": 0.4,
... "/tmp/experiences/*.json": 0.4,
... "s3://bucket/expert.json": 0.2,
... }, ioctx)
yeah this will fall back to the old json reader I think which we want to get rid of anyway.
rllib/algorithms/algorithm.py
Outdated
@@ -616,7 +616,8 @@ def setup(self, config: AlgorithmConfig) -> None:
# Evaluation WorkerSet setup.
# User would like to setup a separate evaluation worker set.
if self.config.evaluation_num_workers > 0 or self.config.evaluation_interval:
# Note: We skipp workerset creation if we need to do offline evaluation
if not self.config.get("off_policy_estimation_methods") and (self.config.evaluation_num_workers > 0 or self.config.evaluation_interval) :
self.config.off_policy_estimation_methods
fixed.
rllib/algorithms/algorithm.py
Outdated
@@ -2734,6 +2746,12 @@ def _run_one_evaluation(
Returns:
    The results dict from the evaluation call.
"""

if self.config["off_policy_estimation_methods"]:
self.config.off_policy_estimation_methods
On my point above:
We should validate in the config whether the user is trying to kind of do both: off_policy_estimation PLUS configuring some evaluation_duration stuff. This way we can error out and say: For now, you can only do one of these types of evaluation: off-policy w/ dataset OR via environment rollouts.
ack.
fixed.
If off_policy_estimation is specified and the input is not a dataset we should raise an error.
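A minimal sketch of the validation being asked for, written as a free function over a config-like object with AlgorithmConfig-style attribute names (illustrative only, not the exact check that landed in this PR):

# Illustrative only: reject ambiguous evaluation setups early.
def validate_evaluation_setup(config) -> None:
    if config.off_policy_estimation_methods:
        # OPE needs an offline dataset to run on.
        if config.input_ != "dataset":
            raise ValueError(
                "off_policy_estimation_methods requires input='dataset'."
            )
        # For now, dataset-based OPE and environment-rollout evaluation
        # are treated as mutually exclusive.
        if config.evaluation_interval and config.evaluation_config.get("env"):
            raise ValueError(
                "For now, you can only do one type of evaluation: "
                "off-policy estimation on a dataset OR environment rollouts."
            )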
rllib/algorithms/algorithm.py
Outdated
@@ -2817,6 +2835,19 @@ def _run_one_training_iteration_and_evaluation_in_parallel(

return results, train_iter_ctx

def _run_offline_evaluation(self, train_future = None):
docstring
fixed.
@@ -37,7 +38,7 @@ def __init__(
Args:
    policy: Policy to evaluate.
    gamma: Discount factor of the environment.
    model: The ModelConfigDict for self.q_model, defaults to:
    model_config: The ModelConfigDict for self.q_model, defaults to:
Nice cleanup.
ack.
rllib/offline/offline_evaluator.py
Outdated
n_parallelism: int = os.cpu_count(),
) -> Dict[str, Any]:
    """Calculates the estmiate of the metrics based on the given offline dataset.
Describe more here how this happens: We run once(?) through the entire dataset?
done
Args:
    batch: The batch to remove the time dimension from.
Returns:
add empty line
Not sure I understand: Do we assume [B x T x d], [T x B x d], or just [T x d]?
each row in the dataset is assumed to be [T, d]; there are N rows defining N episodes. When you put them in batches you'll get [B, T, d]. If it's a bandit setting you want [B, d]. I'll add more docstring around this.
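For concreteness, a tiny numpy illustration of the shapes described above (values are arbitrary):

import numpy as np

T, d, B = 5, 3, 4                  # episode length, obs dim, batch size
episode_row = np.zeros((T, d))     # one dataset record = one episode: [T, d]
batched = np.zeros((B, T, d))      # a batch of such rows: [B, T, d]
bandit_batch = batched[:, 0, :]    # bandits have T == 1, so we keep [B, d]
assert bandit_batch.shape == (B, d)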
added more docstrings.
def remove_time_dim(batch: pd.DataFrame) -> pd.DataFrame:
    """Removes the time dimension from the given sub-batch of the dataset.

    RLlib assumes each record in the dataset is a single episode.
This information shouldn't go here, b/c it has nothing to do with the utility function per se.
got it. no information leakage from caller :)
    SampleBatch.NEXT_OBS,
    SampleBatch.DONES,
]:
    batch[k] = batch[k].apply(lambda x: x[0])
What if the time dim > 1? Then we lose data?
yes. This should be used only in the bandit setting. I have to think about the non-bandit setting to see what is the best way to consolidate both.
The solution is to apply this only in the bandit setting for now and fall back to the old slow setup in the non-bandit case. The user can specify a bandit setting via setting ope_split_batch_by_episode=False.
Returns:
    The batch with the time dimension removed.
"""
for k in [
I know it's probably fine for now, but what about possibly other fields? Nested structures?
I don't think nested structures are supported in our offline RL stack anyway. So the assumption right now is that you have one key obs that is the flattened version of your input, i.e. if you need one-hot encoding, etc. you should do it on your dataset beforehand.
I resolved an issue with keys that don't exist in the batch. But I think that's a separate problem from what you were concerned with.
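Putting the quoted snippets together, the bandit-only time-dim removal amounts to roughly the following sketch, which also guards against keys missing from the batch as discussed above (not necessarily the exact code merged here):

import pandas as pd
from ray.rllib.policy.sample_batch import SampleBatch

def remove_time_dim(batch: pd.DataFrame) -> pd.DataFrame:
    """Drops the (length-1) time axis from a bandit sub-batch."""
    for k in [
        SampleBatch.OBS,
        SampleBatch.ACTIONS,
        SampleBatch.REWARDS,
        SampleBatch.NEXT_OBS,
        SampleBatch.DONES,
    ]:
        # Only touch keys that actually exist in this dataset.
        if k in batch:
            batch[k] = batch[k].apply(lambda x: x[0])
    return batch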
    policy_state,
    estimator_class,
):
    """Computes importance sampling weights for the given batch of samples."""
proper docstring
fixed.
from typing import Any, Dict, Type, TYPE_CHECKING
import numpy as np

from ray.rllib.utils.numpy import convert_to_numpy
sort imports
fixed.
@@ -56,6 +65,7 @@ def __init__(
    epsilon_greedy: The probability by which we act acording to a fully random
        policy during deployment. With 1-epsilon_greedy we act
        according the target policy.
    normalize_weights: Whether to normalize the importance sampling
More description of what this arg does.
ack.
fixed.
dsize = dataset.count()
batch_size = max(dsize // n_parallelism, 1)
# step 1: clean the dataset and remove the time dimension for bandits
Should we really auto-infer here that this is used for Bandits only? I feel like the time-dim removal doesn't belong here and should be handled by the client of this class.
yeah it doesn't belong here. It now happens in algorithm.py only if the problem is offline bandits.
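For context, the parallel estimation pass being referenced reads roughly like this self-contained toy (the per-batch function is a placeholder for the real importance-weight computation):

import pandas as pd
import ray

ray.init(ignore_reinit_error=True)

# Toy stand-in for the offline (bandit) dataset: one row per time step.
dataset = ray.data.from_pandas(
    pd.DataFrame({"obs": [0.1, 0.2, 0.3, 0.4], "rewards": [1.0, 0.0, 1.0, 1.0]})
)
n_parallelism = 2
dsize = dataset.count()
batch_size = max(dsize // n_parallelism, 1)

def per_batch_estimate(batch: pd.DataFrame) -> pd.DataFrame:
    # Placeholder for the real per-batch OPE computation.
    batch["weights"] = 1.0
    return batch

# One full pass over the dataset per evaluation, in parallel pandas batches.
estimates = dataset.map_batches(
    per_batch_estimate, batch_size=batch_size, batch_format="pandas"
)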
dsize = dataset.count()
batch_size = max(dsize // n_parallelism, 1)
# step 1: clean the dataset and remove the time dimension for bandits
see my comment below on why we shouldn't do the time-dim removal here.
@@ -270,12 +270,6 @@ def validate(self) -> None:
    f"Try setting config.rollouts(rollout_fragment_length={self.n_step})."
)

if self.model["custom_model"]:
apparently we never called this validate() method before when custom_model was provided.
are those red warnings related? I can't see anything, man.
They could be related. But I don't understand the error message. I updated those docstrings tbh but am not sure what the proper format is and what the complaint means.
nothing major.
# Dataset should be in form of one episode per row. in case of bandits each
# row is just one time step. To make the computation more efficient later
# we remove the time dimension here.
parallelism = self.evaluation_config.evaluation_num_workers or 1
we do have a parallelism parameter in evaluation_config. we don't want to use it in place of evaluation_num_workers?
it reads a little weird because on one hand, we say we are not gonna create evaluation workers, while on the other hand, evaluation_num_workers actually plays a critical role here.
The current users (you know who :)) get parallelism by defining the number of evaluation workers. I would love to actually change that at some point. To avoid further confusion I am getting rid of parallelism. If you look in validate() I am overriding the parallelism in evaluation_config.
self.evaluation_dataset = None
if (
    self.evaluation_config.off_policy_estimation_methods
    and not self.evaluation_config.ope_split_batch_by_episode
I have lost track a little bit. what if there are off_policy_estimation_methods and ope_split_batch_by_episode is True? how is that mode handled?
kind of feel like we need a single function def _setup_evaluation(self), and in there, a simple list of the different online/offline evaluation types that we may have, like:
def _setup_evaluation(self):
if (
self.evaluation_config.off_policy_estimation_methods and
not self.evaluation_config.ope_split_batch_by_episode
):
self._setup_offline_bandit_eval() # DS based offline bandit data eval.
elif self.config.evaluation_num_workers > 0:
assert self.config.evaluation_config.env, ...
self._setup_evaluation_workers() # Online eval.
else:
# What do we do here?
if ope is specified but not split_batch_by_episode we fall back to the old behavior. i.e. use evaluation workers that each have a shard of the dataset. It is not clear how I could use ray dataset map_batches to do processing on an episode level. In the future we can do that consolidation but till then let's just go back to the old behavior.
ok. but how do I tell this is the behavior here?
like we are gonna get to this if statement of off_policy_estimation_methods and not ope_split_batch_by_episode, and if it doesn't fit, where do we fall to? how do I know the evaluation workers are guaranteed to be there?
to be clear, I am not saying the behavior is bad, I am just saying maybe we can write the code in a way that makes this flow clearer. wdyt?
yep. Let me explain what happens:
if this condition is true, you'll set self.evaluation_dataset to a dataset, otherwise it's None. In self.evaluate(), if self.evaluation_dataset is None you fall back to the old behavior, which is using evaluation_workers. Now the question is: have we actually created evaluation_workers when they were necessary? The answer is yes. It's in the setup() function where you call self._should_create_evaluation_rollout_workers(self.evaluation_config). How do you think this should be re-written so that it is clearer?
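A condensed sketch of the two code paths being described, written as free functions over a hypothetical algo object (get_dataset_and_shards is the util from rllib/offline/dataset_reader.py; the other helper names are placeholders):

from ray.rllib.offline.dataset_reader import get_dataset_and_shards

def setup_evaluation(algo):
    cfg = algo.evaluation_config
    if cfg.off_policy_estimation_methods and not cfg.ope_split_batch_by_episode:
        # Bandit-style OPE: read the full, un-sharded dataset directly.
        algo.evaluation_dataset, _ = get_dataset_and_shards(cfg, num_workers=0)
    else:
        algo.evaluation_dataset = None
    if algo._should_create_evaluation_rollout_workers(cfg):
        # Old path: rollout workers, each holding a dataset shard (or an env).
        algo.evaluation_workers = algo._make_evaluation_worker_set()  # placeholder

def evaluate(algo):
    if algo.evaluation_dataset is not None:
        return {"evaluation": algo._run_offline_evaluation()}
    # Otherwise fall back to the old behavior using algo.evaluation_workers.
    return algo._evaluate_with_workers()  # placeholder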
ok, fine for now.
personally I will feel a lot safer if this whole thing is a single if...else...; then I can tell that all cases are covered.
right now creation of self.evaluation_workers is conditioned on self._should_create_evaluation_rollout_workers(), and creation of evaluation_dataset is conditioned on off_policy_estimation_methods and not ope_split_batch_by_episode.
it's just hard to tell if these two are mutually exclusive or not. you get the idea.
rllib/algorithms/algorithm.py
Outdated
@@ -803,6 +825,11 @@ def evaluate(
# Call the `_before_evaluate` hook.
self._before_evaluate()

if self.evaluation_dataset is not None:
    eval_results = {}
    eval_results["evaluation"] = self._run_offline_evaluation()
is this basically:
return {
"evaluation": self._run_offline_evaluation()
}
?
yes. fixed.
rllib/algorithms/algorithm.py
Outdated
# training on the offline dataset, since rollout workers have already
# claimed it.
# Another Note (Kourosh): dataset reader will not use placement groups so
# what ever we specify here won't matter because dataset won't even use it.
whatever
fixed.
# evaluation_num_workers for backward compatibility and num_cpus gets
# set to num_cpus_per_worker from rollout worker. User only needs to
# set evaluation_num_workers.
self.input_config["parallelism"] = self.evaluation_num_workers or 1
ok, so we should simply use "parallelism" in our code?
in code, yes. from the user's side it's evaluation_num_workers.
check(is_ste, ope_results["is"]["v_gain_ste"])

def test_dr_on_estimate_on_dataset(self):
    # TODO (Kourosh): How can we unittest this without querying into the model?
what do you mean? we can do hacky stuff in a unit test.
I guess my point was that when you do this you don't have a ground truth to compare against. We need to write this test in a way where the ground truth is independent of the implementation details. I don't know how to do that yet.
oh
cpus_per_task = input_config.get(
    "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK
)
return [{"CPU": cpus_per_task} for _ in range(parallelism)]
ok
rllib/offline/feature_importance.py
Outdated
For each feature in the dataset, the importance is computed by applying
perturbations to each feature and computing the difference between the
perturbed prediction and the reference prediction. The importance is
typo: the "is" at the end is not necessary.
fixed.
Args:
    dataset: The dataset to use for the computation. The dataset should have `obs`
        and `actions` columns. Each record should be flat d-dimensional array.
mention that we only perturb / study features in the last dimension?
we compute it on all dims. I don't understand this comment.
line 253: n_features = ds.take(1)[0][SampleBatch.OBS].shape[-1]?
oh that's just getting the number of features.
but that's how many features we study right? when we perturb, we only shuffle along the last dim?
all of them. for each feature dim, you need to shuffle that entire column in the dataset and do batch inference.
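In other words, the scheme being discussed is plain permutation feature importance; a minimal numpy sketch (not RLlib's actual implementation):

import numpy as np

def permutation_importance(predict_fn, obs: np.ndarray) -> np.ndarray:
    """Shuffle one obs column at a time and measure the change in predictions.

    predict_fn is assumed to map an [N, d] obs array to an [N] prediction array.
    """
    reference = predict_fn(obs)
    n_features = obs.shape[-1]
    deltas = np.zeros(n_features)
    for i in range(n_features):
        perturbed = obs.copy()
        np.random.shuffle(perturbed[:, i])  # shuffle the entire column i
        deltas[i] = np.abs(predict_fn(perturbed) - reference).mean()
    return deltas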
    The modified dataset that contains a `delta` column which is the absolute
    difference between the expected output and the output due to the perturbation.
"""
perturbed_ds = dataset.map_batches(
going over the same ds in parallel? this actually works? amazing.
actually good point. I have to verify this. I forgot to update the unittest for feature_importance. added a todo for now. We should get this merged and I'll add the unittest in a separate PR after the release timeline.
looks good now. I feel like the way we do offline evaluation now is finally right :)
Why are these changes needed?
This PR covers two things:
1. Off-policy estimation (OPE) is now run directly on the evaluation dataset, by-passing the evaluation WorkerSet.
2. Users don't need to set evaluation_duration_unit = "episode" and evaluation_duration = int(validation_len * num_workers // batch_size) anymore. As long as you give us a dataset, the evaluation will be done on that entire dataset after each training iteration is done (even if the dataset size is not an integer multiple of the batch size // number of workers).

Related issue number
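For illustration, the kind of config this targets might look as follows (values and the algorithm choice are placeholders, not part of this PR):

from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.offline.estimators import ImportanceSampling

config = (
    DQNConfig()
    .offline_data(
        input_="dataset",
        input_config={"format": "json", "paths": ["/path/to/bandit_data.json"]},
    )
    .evaluation(
        evaluation_interval=1,
        evaluation_num_workers=2,  # reused as the OPE read parallelism
        off_policy_estimation_methods={"is": {"type": ImportanceSampling}},
        ope_split_batch_by_episode=False,  # bandit data: one timestep per row
    )
)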
Checks
I've signed every commit (git commit -s) in this PR.
I've run scripts/format.sh to lint the changes in this PR.