[training] Tensorflow interface for MultiNode SGD #5440
Conversation
Test FAILed.
BTW, I don't think you need to fit the PyTorch API so closely. I think you should first get the Distributed TF example running in Ray, and then think about APIs afterwards.
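For reference, a minimal sketch of what a bare multi-worker TF 2.0 setup looks like outside of Ray, using tf.distribute.experimental.MultiWorkerMirroredStrategy; the hostnames and model below are placeholders, not code from this PR:

```python
import json
import os

import tensorflow as tf

# Every worker needs a TF_CONFIG describing the cluster and its own role.
# The addresses and the task index are placeholders for illustration only.
os.environ["TF_CONFIG"] = json.dumps({
    "cluster": {"worker": ["host1:12345", "host2:12345"]},
    "task": {"type": "worker", "index": 0},
})

strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

with strategy.scope():
    # Model and optimizer must be created inside the strategy scope so their
    # variables are mirrored across workers.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation="relu", input_shape=(784,)),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.SGD(),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"])
```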
… train_example works with cpu
Test FAILed.
Looks like there are problems with model save / loading.
Test FAILed.
logger = logging.getLogger(__name__)


class DistributedTensorFlowRunner(TensorFlowRunner):
Is the inheritance necessary here?
TensorFlowRunner's get_state and set_state methods can be used to get the current model, so I think inheritance is needed here.
I think you should just define get_state and set_state in this class, so we don't need a separate DistributedTensorFlowRunner class - just a regular TensorFlowRunner.
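A rough sketch of what that could look like, assuming the runner keeps self.model and self.epoch as in the snippets quoted below; the exact state dictionary keys are illustrative, not the PR's final API:

```python
class TensorFlowRunner(object):
    """Single runner class that also carries checkpointable state."""

    def get_state(self):
        # Capture everything needed to restore training on another worker.
        return {
            "epoch": self.epoch,
            "weights": self.model.get_weights(),
            "optimizer_weights": self.model.optimizer.get_weights(),
        }

    def set_state(self, state):
        # Restore model and optimizer state; assumes the optimizer has
        # already been built (see the discussion further down).
        self.epoch = state["epoch"]
        self.model.set_weights(state["weights"])
        self.model.optimizer.set_weights(state["optimizer_weights"])
```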
Test FAILed.
from ray.experimental.sgd.tensorflow.tensorflow_trainer import (
    TensorFlowTrainer, TensorFlowTrainable)

from ray.experimental.sgd.tests.tf_helper import (get_model, get_dataset)
For this, I think we should not use tf_helper and instead define get_model and get_dataset in this file. PyTorch should be updated in a separate PR.
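Something like the following inlined definitions is presumably what is meant; the layer sizes and batch size here are illustrative, not the actual helpers from tf_helper:

```python
import tensorflow as tf


def get_model():
    # Small Keras classifier; the architecture is only illustrative.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation="relu", input_shape=(784,)),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(
        optimizer=tf.keras.optimizers.SGD(),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"])
    return model


def get_dataset(batch_size=128):
    # MNIST, flattened and normalized, as a tf.data pipeline.
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
    return tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(len(x_train)).batch(batch_size)
```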
    - pip install -U tensorflow-gpu==2.0.0-beta1

file_mounts: {
    ~/run/: /Users/jichanchung/OneDrive/FF/190812_tf2_tune_trainable/190814_multinode_mnist_ray
remove?
        return stats

    def validate(self):
Validate should only be called on the test_dataset.
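In other words, something along these lines, assuming the runner keeps a separate self.test_dataset; the attribute and key names are assumptions, not the PR's actual code:

```python
def validate(self):
    # Evaluate only on the held-out test split, never on the training data.
    loss, accuracy = self.model.evaluate(self.test_dataset, verbose=0)
    return {"validation_loss": loss, "validation_accuracy": accuracy}
```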
@@ -0,0 +1,98 @@
from __future__ import absolute_import |
You should merge this with DistributedTensorflowRunner, as commented in that file
    def set_state(self, state):
        self.epoch = state["epoch"]
        if self.model.optimizer.weights == []:
I don't quite understand this issue or this comment - can you:
- provide a comment explaining exactly what the error is, and
- provide in the code a link to the Stack Overflow or TensorFlow GitHub issue that you found which suggested this workaround?
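For illustration, the workaround could be documented roughly like this (assuming tensorflow is imported as tf); the body of the if-branch, which applies dummy zero gradients to build the optimizer, is a guess at one common fix and the state keys are assumptions, not code taken from this PR:

```python
def set_state(self, state):
    self.epoch = state["epoch"]
    if self.model.optimizer.weights == []:
        # Workaround: a freshly compiled Keras model has an optimizer with no
        # weight variables yet, so optimizer.set_weights() below would raise.
        # Applying a dummy zero gradient forces the optimizer to create its
        # slot variables first.
        # TODO: link the upstream TensorFlow / Stack Overflow report that
        # motivated this workaround.
        zero_grads = [tf.zeros_like(w) for w in self.model.trainable_weights]
        self.model.optimizer.apply_gradients(
            zip(zero_grads, self.model.trainable_weights))
    self.model.set_weights(state["weights"])
    self.model.optimizer.set_weights(state["optimizer_weights"])
```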
@@ -0,0 +1,38 @@
from __future__ import absolute_import, division, print_function, unicode_literals |
These should be separate lines, and I don't think you need unicode_literals.
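That is, the header would become:

```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
```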
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf

NUM_TRAIN_SAMPLES = 60000
can we make this 512?
shuffle(NUM_TRAIN_SAMPLES) is used to shuffle the whole dataset. Do you mean we should only consider 512 data points from the MNIST dataset?
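To make the two readings concrete, a sketch of both options with illustrative names (this is not code from the PR):

```python
import tensorflow as tf

NUM_TRAIN_SAMPLES = 512  # smaller value for quicker CI runs

(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))

# Reading A: keep the full dataset but shuffle with a smaller buffer.
full_data_small_buffer = dataset.shuffle(NUM_TRAIN_SAMPLES).batch(128)

# Reading B: actually train on only 512 data points.
subset = dataset.take(NUM_TRAIN_SAMPLES).shuffle(NUM_TRAIN_SAMPLES).batch(128)
```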
@@ -0,0 +1,38 @@
from __future__ import absolute_import, division, print_function, unicode_literals |
can you copy these functions to the example?
… to tf runner; removed trainloss decreasing check from example and test
Test FAILed.
Test PASSed.
What do these changes do?
Creates a TensorFlow interface for MultiNode SGD.
TODO:
Linter
- Run scripts/format.sh to lint the changes in this PR.
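For context, a hypothetical end-to-end usage sketch of the interface added here, based only on the names imported in the test above; the constructor arguments and helper functions are assumptions, not the final API:

```python
import ray
from ray.experimental.sgd.tensorflow.tensorflow_trainer import TensorFlowTrainer
from ray.experimental.sgd.tests.tf_helper import get_model, get_dataset

ray.init()

# Hypothetical constructor signature -- the actual arguments are defined
# by this PR and may differ.
trainer = TensorFlowTrainer(
    model_creator=get_model,
    data_creator=get_dataset,
    num_replicas=2)

for _ in range(3):
    stats = trainer.train()   # one pass over the training data per call
    print(stats)

print(trainer.validate())     # evaluate on the test dataset
```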