[AIR] Update TensorflowPredictor to new API #26215

Merged: amogkam merged 5 commits into ray-project:master from update-tensorflow-predictor-2 on Jul 8, 2022

@amogkam amogkam commented Jun 29, 2022

Updates TensorflowPredictor to use the new _predict_pandas API.

Also, as agreed offline, removes the extra configuration options from TensorflowPredictor (column selection, concatenation); these should now be handled by a Preprocessor.

Why are these changes needed?

Related issue number

Checks

  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

@krfricke krfricke left a comment

Generally LGTM! Couple of questions

python/ray/train/_internal/dl_predictor.py (outdated, resolved)
python/ray/train/tensorflow/tensorflow_predictor.py (outdated, resolved)
python/ray/train/_internal/dl_predictor.py (outdated, resolved)
@@ -121,36 +153,7 @@ def build_model(self):

predictions = predictor.predict(data)
Contributor

Can we also document the return type of predictions?

Contributor Author

Updated the docstring!

Contributor

Thanks. Just for my own understanding:
say the input is a pandas DataFrame with columns "age" and "salary". Will it be converted into multiple tensors, and should the model expect that?

Another question: if the input is an image, what does the predict function take, and how is everything converted?

Contributor Author

The behavior is the same regardless of whether the data is image, tabular, etc.

If it is a single-column DataFrame, it will be converted to a single tensor before being passed to the model.
If it is a multi-column DataFrame, it will be converted to a dict of tensors before being passed to the model.
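
The rule described in this reply can be illustrated with a rough sketch. This is not the code from the PR: `to_model_input` is a hypothetical helper, NumPy arrays stand in for TensorFlow tensors, and the `float32` dtype is an assumption made only for the example.

```python
import numpy as np
import pandas as pd


def to_model_input(df: pd.DataFrame, dtype=np.float32):
    """Hypothetical helper (not part of Ray) mimicking the conversion rule above."""
    if len(df.columns) == 1:
        # Single column: return one array, standing in for a single tensor.
        return df.iloc[:, 0].to_numpy(dtype=dtype)
    # Multiple columns: return one array per column, keyed by column name.
    return {name: df[name].to_numpy(dtype=dtype) for name in df.columns}


# One column in, one array out.
single = to_model_input(pd.DataFrame({"age": [25, 40]}))

# Two columns in, a dict with keys "age" and "salary" out.
multi = to_model_input(pd.DataFrame({"age": [25, 40], "salary": [50_000, 90_000]}))
```

Under this reading, a model fed a multi-column DataFrame would need to accept a dict of named inputs, which is why combining columns is left to a preprocessing step.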

@@ -520,6 +340,14 @@
" df = df.drop([\"trip_start_timestamp\"], axis=1)\n",
" return df\n",
"\n",
" def concat_for_tensor(dataframe):\n",
Contributor

Is this supposed to be written by the user?

Contributor Author

yes
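
For context, a user-written function of this kind might look roughly like the sketch below. The body of `concat_for_tensor` is not shown in the diff excerpt, so everything here is an assumption: the `FEATURE_COLUMNS` names, the output column name `"input"`, and the choice to store one `float32` array per row are all hypothetical.

```python
import numpy as np
import pandas as pd

# Hypothetical feature columns; the real notebook's column list is not shown here.
FEATURE_COLUMNS = ["trip_miles", "trip_seconds"]


def concat_for_tensor(dataframe: pd.DataFrame) -> pd.DataFrame:
    """Combine the feature columns into one column holding a vector per row."""
    # Shape (n_rows, n_features), cast to float32 for the model.
    features = dataframe[FEATURE_COLUMNS].to_numpy(dtype=np.float32)
    # list(features) yields one 1-D array per row, stored in a single column.
    return pd.DataFrame({"input": list(features)}, index=dataframe.index)


batch = pd.DataFrame({"trip_miles": [1.5, 3.0], "trip_seconds": [300, 720]})
out = concat_for_tensor(batch)
```

Because the result has a single column, the predictor would (per the discussion above) convert it to a single tensor. In the workflow this PR describes, such a function would be applied through a Preprocessor rather than through predictor options.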

@krfricke krfricke left a comment

LGTM

@amogkam amogkam merged commit cc43bcc into ray-project:master Jul 8, 2022
@amogkam amogkam deleted the update-tensorflow-predictor-2 branch July 8, 2022 20:04