Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIR] Predictor call_model API for unsupported output types #26845

Merged
merged 10 commits into from
Jul 27, 2022

Conversation

amogkam
Copy link
Contributor

@amogkam amogkam commented Jul 21, 2022

Signed-off-by: Amog Kamsetty [email protected]

Adds a Public call_model API to TorchPredictor and TensorflowPredictor.

We also remove default support for List outputs, as it's ambiguous if the list represents the batch, or different output columns.

Supported output types are tensor and Dict[str, tensor].

To handle models with unsupported output types, users are expected to override call_model.

Example:

import torch

class SSDPredictor(TorchPredictor):
    def call_model(self, tensor):
        output = super().call_model(tensor)
        output = {k: torch.stack([d[k] for d in output]) for k in output[0]}
        return output

Documentation will be added if this API looks good.

Why are these changes needed?

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: Amog Kamsetty <[email protected]>
@amogkam amogkam marked this pull request as draft July 21, 2022 18:55
Signed-off-by: Amog Kamsetty <[email protected]>
Signed-off-by: Amog Kamsetty <[email protected]>
@ericl ericl added the @author-action-required The PR author is responsible for the next step. Remove tag to send back to the reviewer. label Jul 21, 2022
Signed-off-by: Amog Kamsetty <[email protected]>
…ictor-developer-prototype

Signed-off-by: Amog Kamsetty <[email protected]>
Signed-off-by: Amog Kamsetty <[email protected]>
Signed-off-by: Amog Kamsetty <[email protected]>
@amogkam amogkam removed the @author-action-required The PR author is responsible for the next step. Remove tag to send back to the reviewer. label Jul 27, 2022
@amogkam amogkam marked this pull request as ready for review July 27, 2022 01:40
@amogkam amogkam changed the title [WIP] Predictor Developer API prototype [AIR] Predictor call_model API for unsupported output types Jul 27, 2022
Signed-off-by: Amog Kamsetty <[email protected]>
Signed-off-by: Amog Kamsetty <[email protected]>
Copy link
Contributor

@clarkzinzow clarkzinzow left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, nice examples in the call_model() docstrings!

@ericl ericl merged commit 13d185c into ray-project:master Jul 27, 2022
Rohan138 pushed a commit to Rohan138/ray that referenced this pull request Jul 28, 2022
Stefan-1313 pushed a commit to Stefan-1313/ray_mod that referenced this pull request Aug 18, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants