Running vid2vid on PyTorch 1.0.0 and CUDA 10 #82

Open
clarle opened this issue Dec 11, 2018 · 15 comments

Comments

@clarle

clarle commented Dec 11, 2018

I thought this might help a few people, though feel free to remove this if you think it doesn't belong in the issues tracker.

I was trying to run vid2vid on PyTorch 1.0.0 and CUDA 10 now that 1.0.0 is stable, and got it working by making a few changes to the downloaded flownet2_pytorch snapshot.

Instructions

  • First, install all of the dependencies needed for vid2vid, using CUDA 10 and PyTorch 1.0.0 instead of CUDA 9.x and PyTorch 0.4.x.

  • Clone the vid2vid repository.

  • Download the datasets with python scripts/download_datasets.py.

  • At this point, there are two options:

    • git clone the flownet2-pytorch repository (https://github.com/NVIDIA/flownet2-pytorch) directly into the models folder in vid2vid, naming the clone flownet2_pytorch (recommended)
  • Alternatively:

    • In download_flownet2.py, change if torch.__version__ == '0.4.1': to if torch.__version__ == '1.0.0':
    • Then run python scripts/download_flownet2.py. Expect some compilation errors and warnings, because the current version of flownet2 still targets the earlier version of PyTorch.
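  • Go to models/flownet2_pytorch and make the changes in this pull request:

    • NVIDIA/flownet2-pytorch#98 (https://github.com/NVIDIA/flownet2-pytorch/pull/98)
    • You don't need all of them. It is enough to add #include <ATen/cuda/CUDAContext.h> and to replace at::globalContext().getCurrentCUDAStream() with at::cuda::getCurrentCUDAStream() in each of the three sub-packages: channelnorm-cuda, correlation-cuda, and resample2d-cuda.
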
  • Rebuild flownet2 after making those changes with bash install.sh inside the models/flownet2_pytorch folder. It should now compile successfully with PyTorch 1.0.0.
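
The essential change from that pull request (adding #include <ATen/cuda/CUDAContext.h> and replacing at::globalContext().getCurrentCUDAStream() with at::cuda::getCurrentCUDAStream()) is mechanical, so it could be scripted before running install.sh. A minimal sketch, not part of vid2vid or flownet2-pytorch: patch_source and patch_tree are hypothetical helpers, and the list of file extensions is an assumption about where the call appears.

```python
from pathlib import Path

OLD_CALL = "at::globalContext().getCurrentCUDAStream()"
NEW_CALL = "at::cuda::getCurrentCUDAStream()"
INCLUDE = "#include <ATen/cuda/CUDAContext.h>"


def patch_source(text):
    """Return text with the stream call replaced and the include added once."""
    if OLD_CALL not in text:
        return text  # nothing to do, leave the file untouched
    patched = text.replace(OLD_CALL, NEW_CALL)
    if INCLUDE not in patched:
        lines = patched.split("\n")
        # Put the new include right after the last existing #include line,
        # or at the top of the file when there is none.
        last_include = max(
            (i for i, line in enumerate(lines)
             if line.lstrip().startswith("#include")),
            default=-1,
        )
        lines.insert(last_include + 1, INCLUDE)
        patched = "\n".join(lines)
    return patched


def patch_tree(root, suffixes=(".cu", ".cc", ".cpp")):
    """Patch every matching source file under root; return the changed paths."""
    changed = []
    for path in Path(root).rglob("*"):
        if path.is_file() and path.suffix in suffixes:
            original = path.read_text()
            patched = patch_source(original)
            if patched != original:
                path.write_text(patched)
                changed.append(path)
    return changed
```

Calling patch_tree("models/flownet2_pytorch/networks") from the vid2vid root would then cover the three sub-packages. It rewrites files in place, so it is worth running on a clean git checkout where the diff can be reviewed.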

You can now run vid2vid following the rest of the instructions. I was able to get this working with CUDA 10 and PyTorch 1.0.0 on an RTX 2080 Ti.
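
A side note on the version check in download_flownet2.py mentioned in the instructions: it compares against one exact string, so a build suffix or a later patch release would not match. A hedged sketch of a more tolerant comparison follows; version_tuple and is_at_least are hypothetical helpers, not part of vid2vid.

```python
import re


def version_tuple(version):
    """Turn '1.0.0' or '1.4.0+cu101' into a tuple such as (1, 0, 0)."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError("unrecognized version string: %r" % version)
    # A missing patch component is treated as 0.
    return tuple(int(part) if part is not None else 0
                 for part in match.groups())


def is_at_least(version, minimum):
    """Compare numerically, so that '1.10.0' counts as newer than '1.4.0'."""
    return version_tuple(version) >= version_tuple(minimum)
```

With this, the check could read is_at_least(torch.__version__, "1.0.0") instead of an exact string match.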

If there's any interest, I can also make a fork of vid2vid / flownet2 that does these things automatically, or a Docker image.

Thanks @jiapei100 for the original PR to flownet2. It saved me a lot of time digging through the new PyTorch API changes.

@francisr

I'm trying to use vid2vid with PyTorch 1.0.
I've applied the fixes from the pull request you mention, but when I try to run a script I get an undefined symbol error:

-------------- End ----------------
CustomDatasetDataLoader
dataset [PoseDataset] was created
vid2vid
Traceback (most recent call last):
  File "test.py", line 25, in <module>
    model = create_model(opt)
  File "/home/remi/git/vid2vid/models/models.py", line 7, in create_model
    from .vid2vid_model_G import Vid2VidModelG
  File "/home/remi/git/vid2vid/models/vid2vid_model_G.py", line 13, in <module>
    from . import networks
  File "/home/remi/git/vid2vid/models/networks.py", line 12, in <module>
    from .flownet2_pytorch.networks.resample2d_package.resample2d import Resample2d
  File "/home/remi/git/vid2vid/models/flownet2_pytorch/networks/resample2d_package/resample2d.py", line 3, in <module>
    import resample2d_cuda
ImportError: /home/remi/git/vid2vid/venv/lib/python3.5/site-packages/resample2d_cuda-0.0.0-py3.5-linux-x86_64.egg/resample2d_cuda.cpython-35m-x86_64-linux-gnu.so: undefined symbol: _ZN2at19UndefinedTensorImpl10_singletonE
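
A note on reading this error: the missing symbol is a mangled C++ name. Below is a minimal sketch of decoding it, assuming the simple nested-name form of the Itanium C++ ABI. demangle_nested_name is a hypothetical helper that handles only this form; a tool such as c++filt covers the general case.

```python
def demangle_nested_name(symbol):
    """Decode symbols of the form _ZN<len><name><len><name>...E."""
    if not symbol.startswith("_ZN") or not symbol.endswith("E"):
        raise ValueError("not a simple nested name: %r" % symbol)
    body = symbol[3:-1]
    parts = []
    pos = 0
    while pos < len(body):
        # Each component is a decimal length followed by that many characters.
        start = pos
        while pos < len(body) and body[pos].isdigit():
            pos += 1
        if pos == start:
            raise ValueError("expected a length at offset %d" % start)
        length = int(body[start:pos])
        name = body[pos:pos + length]
        if len(name) != length:
            raise ValueError("truncated component in %r" % symbol)
        parts.append(name)
        pos += length
    return "::".join(parts)


print(demangle_nested_name("_ZN2at19UndefinedTensorImpl10_singletonE"))
```

For the symbol above this prints at::UndefinedTensorImpl::_singleton, a name from PyTorch's own ATen library. An unresolved PyTorch symbol at import time usually points to the extension having been compiled against a different PyTorch build than the one being imported.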

@clarle
Author

clarle commented Dec 14, 2018

@francisr I ran into that issue when I had both CUDA 9 and CUDA 10 installed.

Make sure the symlink for /usr/local/cuda points to the correct CUDA version, and check which version nvcc is tied to. I found it easier to start off with a fresh image and install only CUDA 10.
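
The check described above can be scripted. This is a hedged sketch using only the standard library; parse_nvcc_release and describe_cuda_toolchain are hypothetical helper names, and /usr/local/cuda is the symlink path mentioned above.

```python
import os
import re
import shutil
import subprocess


def parse_nvcc_release(output):
    """Extract e.g. '10.1' from the text printed by `nvcc --version`."""
    match = re.search(r"release (\d+\.\d+)", output)
    return match.group(1) if match else None


def describe_cuda_toolchain(cuda_link="/usr/local/cuda"):
    """Report where the CUDA symlink points and which nvcc is on PATH."""
    nvcc = shutil.which("nvcc")
    release = None
    if nvcc is not None:
        result = subprocess.run(
            [nvcc, "--version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            universal_newlines=True,
        )
        release = parse_nvcc_release(result.stdout)
    return {
        # realpath resolves the symlink to the versioned install directory.
        "cuda_link_target": (os.path.realpath(cuda_link)
                             if os.path.exists(cuda_link) else None),
        "nvcc_path": nvcc,
        "nvcc_release": release,
    }
```

Comparing the nvcc_release value with torch.version.cuda shows whether the compiler that builds the flownet2 extensions matches the CUDA version PyTorch itself was built with.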

@francisr

I only have CUDA 9 installed on my system. But I gave up and used PyTorch 0.4.1 instead...

@anotherTK

anotherTK commented Mar 1, 2019

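Thanks for the guide. One more thing: if you git clone the flownet2-pytorch repository to flownet2_pytorch and want to train vid2vid, you need to add the following code to flownet2_pytorch/models.py.

# In class FlowNet2: make args optional and fill in defaults when it is missing.
def __init__(self, args=None, batchNorm=False, div_flow=20.):
    if args is None:
        args = MyDict()
        args.rgb_max = 1
        args.fp16 = False
        args.grads = {}
    # ... the rest of the original constructor body stays unchanged ...

# At module level: a dict subclass, so the attributes above can be assigned.
class MyDict(dict):
    pass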

@andrewkchan

andrewkchan commented May 5, 2019

One more thing: I used the git clone method and also had to change the imports in flownet2_pytorch/models.py back to relative imports, i.e. I had to undo the changes made in NVIDIA/flownet2-pytorch@44c8693

BTW this also works great for PyTorch 1.0 and CUDA 9, was finally able to get vid2vid running on V100 GPUs this way.

@ghost

ghost commented Jun 11, 2019

I thought this might help a few people, though feel free to remove this if you think it doesn't belong in the issues tracker.

I was trying to run vid2vid on PyTorch 1.0.0 and CUDA 10, now that 1.0.0 was stable. I was able to successfully do it by making a few changes to the downloaded flownet2_pytorch snapshot.

Instructions

* First, download all of the dependencies needed for vid2vid, substituting CUDA 10 and PyTorch 1.0.0 instead of CUDA 9.x and PyTorch 0.4.x.

* Clone the vid2vid repository.

* Continue to download the datasets with `python scripts/download_datasets.py`.

* At this point, there are two options:
  
  * `git clone` the [flownet2-pytorch](https://github.com/NVIDIA/flownet2-pytorch) repository directly into the `models` folder in vid2vid to `flownet2_pytorch` (recommended)

* Alternatively:
  
  * Modify the `download_flownet2.py` file from `if torch.__version__ == '0.4.1':` to `if torch.__version__ == '1.0.0':`
  * Then, run `python scripts/download_flownet2.py`.  You'll get some compilation errors and warnings as the current version of flownet2 is still designed for the earlier version of Torch.

* Go to `models/flownet2_pytorch` and make the changes in this pull request:
  
  * [NVIDIA/flownet2-pytorch#98](https://github.com/NVIDIA/flownet2-pytorch/pull/98)
  * You don't need to do all of them, only really adding `#include <ATen/cuda/CUDAContext.h>` and then replacing `at::globalContext().getCurrentCUDAStream()` with `at::cuda::getCurrentCUDAStream()` in each of the three sub-packages - `channelnorm-cuda`, `correlation-cuda`, and `resample2d-cuda`.

* Rebuild flownet2 after making those changes with `bash install.sh` inside the `models/flownet2_pytorch` folder.  It should now compile successfully with PyTorch 1.0.0.

You can now run vid2vid following the rest of the instructions. I was able to get this working on CUDA 10, PyTorch 1.0.0, on a RTX 2080 TI.

If there's any interest, I can make a fork of vid2vid / flownet2 that does these things automatically as well or a Docker image.

Thanks @jiapei100 for the original PR to flownet2 - saved me a lot of time from trying to dig through the new PyTorch API changes there.

Thanks! Could you please make a fork of vid2vid / flownet2 that does these things automatically, or a Docker image? I would really appreciate that.

@ghost

ghost commented Jun 12, 2019

Hey @clarle, thanks for the instructions.
I tried to install PyTorch 1.0.0, but apparently 'torchvision 0.3.0 has requirement torch>=1.1.0', so I wonder how you managed to work with PyTorch 1.0.0. Because of this incompatibility, I installed PyTorch 1.1.0 and CUDA 10 instead, went through your instructions, and made all the modifications; everything built fine with no errors. But when I run test.py I get "RuntimeError: cuda runtime error (11) : invalid argument at /pytorch/aten/src/THC/THCGeneral.cpp:383", which I believe is due to a PyTorch version incompatibility or something similar. Any ideas how I can fix it? I'd really appreciate any help. Thanks.

@zhuhaozh

I get a segmentation fault when executing
"flow1 = self.flowNet(data1)" in models.flownet.

Has anyone else run into this problem?

@sheiun

sheiun commented Aug 12, 2020

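Thanks for the guide. One more thing: if you git clone the flownet2-pytorch repository to flownet2_pytorch and want to train vid2vid, you need to add the following code to flownet2_pytorch/models.py.

# In class FlowNet2: make args optional and fill in defaults when it is missing.
def __init__(self, args=None, batchNorm=False, div_flow=20.):
    if args is None:
        args = MyDict()
        args.rgb_max = 1
        args.fp16 = False
        args.grads = {}
    # ... the rest of the original constructor body stays unchanged ...

# At module level: a dict subclass, so the attributes above can be assigned.
class MyDict(dict):
    pass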

Hi, I have a question about the parameter rgb_max = 1: why not 255?

Here is the relevant line in flownet2-pytorch: parser.add_argument("--rgb_max", type=float, default=255.)
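
One possible reading, offered as a sketch rather than an answer: assuming FlowNet2 mean-centres its input and then divides by rgb_max (an assumption about flownet2-pytorch that this thread does not confirm), rgb_max would need to match the value range of the images passed in. The normalize function below is a hypothetical stand-in for that step, not the library's code.

```python
def normalize(pixels, rgb_max):
    """Mean-centre the values, then scale by rgb_max (assumed behaviour)."""
    mean = sum(pixels) / len(pixels)
    return [(p - mean) / rgb_max for p in pixels]


# Raw 8-bit values with rgb_max=255 and unit-range values with rgb_max=1
# end up on the same scale.
print(normalize([0.0, 255.0], 255.0))
print(normalize([0.0, 1.0], 1.0))

# Unit-range values with rgb_max=255 would be squashed towards zero.
print(normalize([0.0, 1.0], 255.0))
```

Under that assumption, 1 would be consistent if vid2vid hands FlowNet2 images already scaled to a unit range, while the default of 255 suits raw 8-bit images.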

@karims

karims commented Oct 3, 2020

I still get this error:

                 from channelnorm_cuda.cc:1:
/usr/local/lib/python3.5/dist-packages/torch/include/c10/util/Optional.h: In instantiation of 'c10::optional<T>& c10::optional<T>::operator=(c10::optional<T>&&) [with T = torch::ExpandingArray<2ul, double>]':
/usr/local/lib/python3.5/dist-packages/torch/include/torch/csrc/api/include/torch/nn/options/pooling.h:428:8:   required from 'void torch::nn::Cloneable<Derived>::clone_(torch::nn::Module&, const c10::optional<c10::Device>&) [with Derived = torch::nn::FractionalMaxPool2dImpl]'
/usr/local/lib/python3.5/dist-packages/torch/include/torch/csrc/api/include/torch/optim/sgd.h:48:17:   required from here
/usr/local/lib/python3.5/dist-packages/torch/include/c10/util/Optional.h:396:23: error: passing 'const torch::ExpandingArray<2ul, double>' as 'this' argument discards qualifiers [-fpermissive]
       contained_val() = std::move(*rhs);
                       ^
In file included from /usr/local/lib/python3.5/dist-packages/torch/include/torch/csrc/api/include/torch/nn/options/conv.h:6:0,
                 from /usr/local/lib/python3.5/dist-packages/torch/include/torch/csrc/api/include/torch/nn/functional/conv.h:3,
                 from /usr/local/lib/python3.5/dist-packages/torch/include/torch/csrc/api/include/torch/nn/functional.h:4,
                 from /usr/local/lib/python3.5/dist-packages/torch/include/torch/csrc/api/include/torch/nn.h:4,
                 from /usr/local/lib/python3.5/dist-packages/torch/include/torch/csrc/api/include/torch/all.h:7,
                 from /usr/local/lib/python3.5/dist-packages/torch/include/torch/csrc/api/include/torch/torch.h:3,
                 from channelnorm_cuda.cc:1:
/usr/local/lib/python3.5/dist-packages/torch/include/torch/csrc/api/include/torch/expanding_array.h:23:7: note:   in call to 'torch::ExpandingArray<2ul, double>& torch::ExpandingArray<2ul, double>::operator=(const torch::ExpandingArray<2ul, double>&)'
 class ExpandingArray {
       ^
error: command 'x86_64-linux-gnu-gcc' failed with exit status 1

@xungeer29

Segmentation fault when execute
"flow1 = self.flowNet(data1)" in models.flownet

Anyone meets this problem?

I ran into this problem, too. Have you resolved it?

@tuanfeng

tuanfeng commented May 2, 2021

Here is the script I used to make it run:

nvidia-smi:
NVIDIA-SMI 460.32.03 Driver Version: 460.32.03 CUDA Version: 11.2

nvcc --version:
Cuda compilation tools, release 10.1, V10.1.105

conda create -n vid2vid python=3.6
conda activate vid2vid
git clone https://github.com/NVIDIA/vid2vid.git
cd vid2vid

(manual step: add 'edge = edge.bool()' at line 148 of models/base_model.py)

python scripts/download_datasets.py
pip install dominate requests
pip install numpy
add-apt-repository ppa:ubuntu-toolchain-r/test
apt update
apt install gcc-6 g++-6
update-alternatives --remove-all gcc
update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-6 80 --slave /usr/bin/g++ g++ /usr/bin/g++-6
update-alternatives --config gcc
pip install torch==1.4.0
python scripts/download_flownet2.py

(manual step: check in Python that import torch and import resample2d_cuda both succeed)

python scripts/street/download_models.py
pip install Pillow opencv-python scipy
pip install torchvision==0.5.0
apt install libgl1-mesa-glx
python test.py --name label2city_2048 --label_nc 35 --loadSize 2048 --n_scales_spatial 3 --use_instance --fg --use_single_G
python scripts/download_models_flownet2.py
python train.py --name label2city_512 --label_nc 35 --gpu_ids 0,1,2,3,4,5,6,7 --n_gpus_gen 6 --n_frames_total 6 --use_instance --fg
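
The import check in the script above could be sketched as a small Python helper. check_imports is a hypothetical name. resample2d_cuda appears in the traceback earlier in this thread, while channelnorm_cuda and correlation_cuda are assumed module names for the other two sub-packages.

```python
import importlib


def check_imports(module_names):
    """Try each import in order and record 'ok' or the error message."""
    results = {}
    for name in module_names:
        try:
            importlib.import_module(name)
        except Exception as exc:
            # A missing module and an undefined-symbol failure both land here.
            results[name] = "%s: %s" % (type(exc).__name__, exc)
        else:
            results[name] = "ok"
    return results
```

A call such as check_imports(["torch", "resample2d_cuda", "channelnorm_cuda", "correlation_cuda"]) lists torch first on purpose, because the compiled extensions link against PyTorch's libraries.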

@ry85

ry85 commented Sep 6, 2021

The setup below worked for me:

conda remove --name vid2vid --all
conda create -n vid2vid python=3.7
conda activate vid2vid

conda install pytorch==1.4.0 torchvision==0.5.0 cudatoolkit=10.1 -c pytorch

pip install dominate requests
pip install dlib
pip install scikit-image
pip install opencv-python
pip install numpy==1.16.1

Training (needs approx. 15 GB of GPU memory)

python train.py --name edge2face_512 --dataroot datasets/face/ --dataset_mode face --input_nc 15 --loadSize 512 --num_D 3 --gpu_ids 0 --n_gpus_gen 1 --n_frames_total 12

Testing

python test.py --name edge2face_512 --dataroot datasets/face/ --dataset_mode face --input_nc 15 --loadSize 512 --use_single_G

@moulimatsa

I tried this solution, but it didn't work for me.
My command for a single GPU is:
python train.py --name label2city_256_g2 --label_nc 35 --loadSize 256 --use_instance --fg --n_downsample_G 2 --num_D 1 --max_frames_per_gpu 6 --n_frames_total 6

Error:

------------ Options -------------
TTUR: False
add_face_disc: False
basic_point_only: False
batchSize: 1
beta1: 0.5
checkpoints_dir: ./checkpoints
continue_train: False
dataroot: datasets/Cityscapes/
dataset_mode: temporal
debug: False
densepose_only: False
display_freq: 100
display_id: 0
display_winsize: 512
feat_num: 3
fg: True
fg_labels: [26]
fineSize: 512
fp16: False
gan_mode: ls
gpu_ids: [0]
input_nc: 3
isTrain: True
label_feat: False
label_nc: 35
lambda_F: 10.0
lambda_T: 10.0
lambda_feat: 10.0
loadSize: 256
load_features: False
load_pretrain:
local_rank: 0
lr: 0.0002
max_dataset_size: inf
max_frames_backpropagate: 1
max_frames_per_gpu: 6
max_t_step: 1
model: vid2vid
nThreads: 2
n_blocks: 9
n_blocks_local: 3
n_downsample_E: 3
n_downsample_G: 2
n_frames_D: 3
n_frames_G: 3
n_frames_total: 6
n_gpus_gen: 1
n_layers_D: 3
n_local_enhancers: 1
n_scales_spatial: 1
n_scales_temporal: 2
name: label2city_256_g2
ndf: 64
nef: 32
netE: simple
netG: composite
ngf: 128
niter: 10
niter_decay: 10
niter_fix_global: 0
niter_step: 5
no_canny_edge: False
no_dist_map: False
no_first_img: False
no_flip: False
no_flow: False
no_ganFeat: False
no_html: False
no_vgg: False
norm: batch
num_D: 1
openpose_only: False
output_nc: 3
phase: train
pool_size: 1
print_freq: 100
random_drop_prob: 0.05
random_scale_points: False
remove_face_labels: False
resize_or_crop: scaleWidth
save_epoch_freq: 1
save_latest_freq: 1000
serial_batches: False
sparse_D: False
tf_log: False
use_instance: True
use_single_G: False
which_epoch: latest
-------------- End ----------------
CustomDatasetDataLoader
dataset [TemporalDataset] was created
#training videos = 6
vid2vid
---------- Networks initialized -------------

---------- Networks initialized -------------

create web directory ./checkpoints/label2city_256_g2/web...
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:4004: UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.
  "Default grid_sample and affine_grid behavior has changed "
Traceback (most recent call last):
  File "train.py", line 148, in <module>
    train()
  File "train.py", line 60, in train
    flow_ref, conf_ref = flowNet(real_B, real_B_prev) # reference flows and confidences
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/drive/MyDrive/Colab Notebooks/vid2vid/models/flownet.py", line 38, in forward
    flow, conf = self.compute_flow_and_conf(input_A, input_B)
  File "/content/drive/MyDrive/Colab Notebooks/vid2vid/models/flownet.py", line 54, in compute_flow_and_conf
    flow1 = self.flowNet(data1)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/drive/MyDrive/Colab Notebooks/vid2vid/models/flownet2_pytorch/models.py", line 105, in forward
    flownetc_flow2 = self.flownetc(x)[0]
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/drive/MyDrive/Colab Notebooks/vid2vid/models/flownet2_pytorch/networks/FlowNetC.py", line 89, in forward
    out_corr = self.corr(out_conv3a, out_conv3b) # False
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/drive/MyDrive/Colab Notebooks/vid2vid/models/flownet2_pytorch/networks/correlation_package/correlation.py", line 59, in forward
    result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2)
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/function.py", line 262, in __call__
    "Legacy autograd function with non-static forward method is deprecated. "
RuntimeError: Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd function with static forward method. (Example: https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function)
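
The RuntimeError above asks for a new-style autograd function. As a hedged illustration of that pattern only (ScaleFunction is a hypothetical toy, not the CorrelationFunction from flownet2-pytorch): options that a legacy function took in its constructor move into the static forward, state is kept on ctx, and the function is invoked through .apply instead of being instantiated.

```python
import torch


class ScaleFunction(torch.autograd.Function):
    """Hypothetical toy op: multiply the input by a constant factor."""

    @staticmethod
    def forward(ctx, x, factor):
        # Options that a legacy function received in __init__ are passed
        # here and stored on ctx for the backward pass.
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input; the non-tensor factor gets None.
        return grad_output * ctx.factor, None


x = torch.ones(3, requires_grad=True)
y = ScaleFunction.apply(x, 2.0)  # new style: call .apply, do not instantiate
y.sum().backward()
```

Applied to the traceback, the call CorrelationFunction(...)(input1, input2) would become a CorrelationFunction.apply(...) call once its forward and backward are rewritten as static methods; the argument order is up to whoever ports it. The setups reported as working earlier in this thread pin torch==1.4.0 instead.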

@digvijayad

@moulimatsa Were you able to solve this? I'm getting the same error.
