Fix tensor devices for DARTS Trial #2273
Conversation
@sifa1024 You need to sign your commits with the email you used when signing the CLA.
Thank you for the fix! I left a small comment.
# Check device use cuda or cpu
use_cuda = list(range(torch.cuda.device_count()))
if use_cuda:
    print("Using CUDA")
device = torch.device("cuda" if use_cuda else "cpu")
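For comparison, a more conventional way to write this check is to use `torch.cuda.is_available()` directly rather than building a list of device indices. This is only a sketch, assuming PyTorch is installed; the `ImportError` fallback to the string `"cpu"` exists purely so the snippet stays self-contained.

```python
# Sketch of an idiomatic device check (assumption: PyTorch available;
# falls back to the plain string "cpu" if torch cannot be imported).
try:
    import torch
    use_cuda = torch.cuda.is_available()  # a plain bool reads clearer
    device = torch.device("cuda" if use_cuda else "cpu")
except ImportError:
    device = "cpu"

print(f"Using {device}")
```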
We identify device here:
katib/examples/v1beta1/trial-images/darts-cnn-cifar10/run_trial.py
Lines 86 to 100 in 4d3ea0c
if len(all_gpus) > 0:
    device = torch.device("cuda")
    torch.cuda.set_device(all_gpus[0])
    np.random.seed(2)
    torch.manual_seed(2)
    torch.cuda.manual_seed_all(2)
    torch.backends.cudnn.benchmark = True
    print(">>> Use GPU for Training <<<")
    print("Device ID: {}".format(torch.cuda.current_device()))
    print("Device name: {}".format(torch.cuda.get_device_name(0)))
    print("Device availability: {}\n".format(torch.cuda.is_available()))
else:
    device = torch.device("cpu")
    print(">>> Use CPU for Training <<<")
Can we just pass the device to the Architect class?
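To illustrate the suggested pattern: the entry point resolves the device once and injects it, so the class never probes CUDA itself. This is a hypothetical sketch only; `Architect` here is a stand-in for the real class in darts-cnn-cifar10 (which takes more arguments), and the device is modeled as a plain string instead of a `torch.device`.

```python
# Hypothetical sketch of the refactor under discussion: the caller
# detects the device once and passes it in; Architect only stores it.
class Architect:
    def __init__(self, model, device):
        self.model = model
        self.device = device  # injected, not re-detected inside the class

    def make_zeros(self):
        # Any tensors the architect creates would go on self.device;
        # modeled here as a (device, data) tuple instead of a torch tensor.
        return (self.device, [0.0, 0.0])

arch = Architect(model=None, device="cpu")
print(arch.make_zeros())  # ('cpu', [0.0, 0.0])
```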
Yes, we can. But is it a good idea to send the device name?
I think it's fine, since we don't need to invoke the torch API again to check whether a GPU is available.
OK, I will change it.
@andreyvelich Please check it, and thank you for your help.
Thanks, I restarted tests.
Signed-off-by: Chen Pin-Han <72907153+sifa1024@users.noreply.github.com>
/lgtm
/approve
/hold
for restarting failed Go Test / Unit Test (1.26.1) (pull_request)
.
@kubeflow/wg-automl-leads Could you restart CI?
[APPROVALNOTIFIER] This PR is APPROVED
This pull request has been approved by: sifa1024, tenzen-y
The full list of commands accepted by this bot can be found here. The pull request process is described here.
Needs approval from an approver in each of these files:
Approvers can indicate their approval by writing /approve in a comment.
@kubeflow/wg-automl-leads Could you approve CI again?
@sifa1024
@tenzen-y OK, I'm sorry. I found that a commit in my branch was not updated, so I just updated it.
@kubeflow/wg-automl-leads Could you restart /lgtm?
Thank you for your contribution @sifa1024!
What this PR does / why we need it:
If I use the original program, I get this error when running darts-gpu:
Which issue(s) this PR fixes
None. I created the pull request directly.
Checklist: