Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix AttributeError with None obs shape and tf #22428

Conversation

Fabien-Couthouis
Copy link
Contributor

Why are these changes needed?

When obs shape is None, calling v.value lead to AttributeError: 'NoneType' object has no attribute 'value' in ModelV2 with tensorflow. This fix adds a check for None value:

batch_dims = [
                v if isinstance(v, (int, type(None))) else v.value
                for v in obs.shape[:-1]
            ]

Related issue number

Closes #16286

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@Fabien-Couthouis
Copy link
Contributor Author

@GoingMyWay

rllib/models/modelv2.py Outdated Show resolved Hide resolved
@GoingMyWay
Copy link

@GoingMyWay

@Fabien-Couthouis Hi, I guess the same issue may happen for torch. I ran PyTorch variants of A3C and PPO and I got RuntimeError: input.size(-1) must be equal to input_size.. I am debugging the code. Hope I can fix this issue.

@GoingMyWay
Copy link

GoingMyWay commented Feb 17, 2022

@Fabien-Couthouis. I did not find the bug. After downgrading the ray from 1.10.0 to 1.4.1, the bug disappeared. 🤣

@Fabien-Couthouis
Copy link
Contributor Author

@GoingMyWay I am not sure the your issue is related to this one, you can open a new issue maybe

else:
return v.value

batch_dims = [get_value(v) for v in obs.shape[:-1]]
batch_dims = [-1 if v is None else v for v in batch_dims]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then you don't need this line anymore. please delete?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, v.value can still be None:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

obs = tf.placeholder(tf.float32, shape=[None, None, 128, 3], name="X")
print(obs.shape) 
>>> TensorShape([Dimension(None), Dimension(None), Dimension(128), Dimension(3)])

def get_value(v):
    if v is None:
        return -1
    elif isinstance(v, int):
        return v
    else:
        return v.value 

batch_dims = [get_value(v) for v in obs.shape[:-1]]
print(batch_dims)
>>> [None, None, 128]
batch_dims = [-1 if v is None else v for v in batch_dims]
print(batch_dims)
>>> [-1, -1, 128]

But still, this logic can be moved in get_value: bd0e751

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point. yeah, it probably should be

if v is None or v.value is None:
   return -1

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking v.value is None before isinstance(v, int) will lead to something like "ValueError, int type does not have a .value method" if v is an integer. Can change:

def get_value(v):
  if v is None:
      return -1
  elif isinstance(v, int):
      return v
  elif v.value is None:
      return -1
  else:
      return v.value

to:

def get_value(v):
  if isinstance(v, int):
      return v
  elif  v is None or v.value is None:
      return -1
  else:
      return v.value

But in general, I prefere to avoid unnecessary calls to .isinstance because this function can be quite slow.

Copy link
Member

@gjoliver gjoliver left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sven1977, can you help merge? thanks.

@sven1977 sven1977 merged commit e575ed3 into ray-project:master Mar 15, 2022
@Fabien-Couthouis Fabien-Couthouis deleted the fix/attributeError_with_None_obs_shape_and_tf branch March 16, 2022 14:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants