CUDNNError: CUDNN_STATUS_BAD_PARAM (code 3) while training lstm neural network on GPU #1360

Closed
VoSiLk opened this issue Oct 16, 2020 · 17 comments

Comments

@VoSiLk

VoSiLk commented Oct 16, 2020

CUDNNError: CUDNN_STATUS_BAD_PARAM occurs during training on the GPU. A single evaluation of the loss works, and training on the CPU works as well. X is a vector of Array{Float32,2}(6,50) and Y is Array{Float32,2}(1,1).

Julia version and packages:
Julia v1.5.2
Flux v0.11.1
CUDA v1.3.3

function prepare_time_series_data_4_flux(X,Y)
	num_batches = length(X)
	x = Vector{Array{Float32}}(undef, num_batches)
	y = Vector{Array{Float32}}(undef, num_batches)
	for i in 1:length(X)
		x[i] = Float32.(reshape(X[i], size(X[i],2), size(X[i], 1)))
		y[i] = Float32.(reshape(Y[i], size(Y[i], 2), size(Y[i], 1)))
	end

	return x, y
end

gpu_or_cpu  = gpu

X, Y = prepare_time_series_data_4_flux(X_train_batched, Y_train_batched)
X = X |> gpu_or_cpu
Y = Y |> gpu_or_cpu
data = (X, vcat(Y...))

X, Y = prepare_time_series_data_4_flux(X_val_batched, Y_val_batched)
X = X |> gpu_or_cpu
Y = Y |> gpu_or_cpu
data_test = (X, vcat(Y...))

opt = ADAM(0.001, (0.9, 0.999))

function loss(X,Y)
    Flux.reset!(model)
    mse_val = sum(abs2.(Y.-vcat(model.(X)...)[:, end]))
    return mse_val
end

model = Chain(LSTM(6, 70), LSTM(70, 70), LSTM(70, 70), Dense(70, 1, relu)) |> gpu_or_cpu
ps = Flux.params(model)
Flux.reset!(model)
loss(data...)

@time Flux.train!(loss, ps, [data], opt)


FluxML/NNlib.jl#237

@jeremiedb
Contributor

jeremiedb commented Oct 18, 2020

This seems related to the same issue as #1114.
This PR was proposed to CuArrays but not merged: JuliaGPU/CuArrays.jl#706.

The following PR to Adapt.jl, JuliaGPU/Adapt.jl#24, was presumed to have fixed the problem, but at the time I had only checked vanilla RNN and GRU, not LSTM.

Maybe the patch proposed in JuliaGPU/CuArrays.jl#706 could still be brought to the following lines in CUDA.jl:
https://github.com/JuliaGPU/CUDA.jl/blob/fc690e20a90a1211f91d561c3bfc010957381c12/lib/cudnn/rnn.jl#L189
https://github.com/JuliaGPU/CUDA.jl/blob/fc690e20a90a1211f91d561c3bfc010957381c12/lib/cudnn/rnn.jl#L200
However, as previously discussed, it was thought that the root cause was likely elsewhere (maybe a further fix to Adapt.jl?)

bors bot added a commit that referenced this issue Nov 7, 2020
1367: RNN update to drop CUDNN, fix LSTM bug and output type stability r=CarloLucibello a=jeremiedb

PR related to #1114 #1360 #1365 

Some experiment for RNN handling. 

Hidden state of each cell structure was dropped as they weren't needed (AFAIK, only needed for size inference for CUDNN, but bias size could be used as a substitute to cells' `h` there as well). 

Looked to drop dependence on CUDNN entirely, so it's a pure Flux/CUDA.jl. File `src/cuda/curnnjl` no longer used. No  modifications were made to the cell computations. Initial test seems to show decent performance, but yet to benchmark. 

Pending issue: despite having completely dropped the CUDNN dependency, there still seems to be an instability issue when running on GPU. This is illustrated in the test at lines 1-50 of file `test\rnn-test-jdb.jl`. If that test runs on CPU, it goes through the 100 iterations without problems. However, the same test on GPU will throw NaNs after a couple dozen iterations.
My only hypothesis so far: when performing the iteration over the sequence through `m.(x)` or `map(rnn, x)`, is the order of execution safe? I.e., is it possible that there isn't a `sync()` on the CUDA side between those sequence steps, which may mess up the state?

### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [ ] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).


Co-authored-by: jeremiedb <[email protected]>
Co-authored-by: jeremie.db <[email protected]>
@jeremiedb
Contributor

Could you test from master? It should now be fixed.

@VoSiLk
Author

VoSiLk commented Nov 10, 2020

I trained for one epoch and it worked. The loss decreased and no error occurred, but the training was quite slow.

@jeremiedb
Contributor

jeremiedb commented Nov 10, 2020

What do you mean by slow?

Maybe validate the data shapes. The example below uses 100 batches, a batch size of 256, and a sequence length of 20, with 6 data features as per your model definition. A single epoch takes 2.8 sec on a GTX 1660 GPU.

feat = 6
batch_size = 256
num_batches = 100
seq_len = 20

X = [[rand(Float32, feat, batch_size) for i in 1:seq_len] for batch in 1:num_batches];
Y = [rand(Float32, batch_size, seq_len) ./ 10  for batch in 1:num_batches];

X = X |> gpu;
Y = Y |> gpu;
data = zip(X, Y);

opt = ADAM(0.001, (0.9, 0.999))

function loss(X,Y)
    Flux.reset!(model)
    mse_val = sum(abs2.(Y .- Flux.stack(model.(X), 2)))
    return mse_val
end

model = Chain(LSTM(6, 70), LSTM(70, 70), LSTM(70, 70), Dense(70, 1, relu), x -> reshape(x, :)) |> gpu
ps = Flux.params(model)
Flux.reset!(model)

@time Flux.train!(loss, ps, data, opt)

@VoSiLk
Author

VoSiLk commented Nov 11, 2020

OK, with your code one epoch takes 5.792389 seconds for me on an NVIDIA Quadro P1000. It is also the first time that my GPU is between 8 and 9 times faster than my CPU. Even when I trained with Keras, my GPU was only slightly faster than my CPU. I've read that this depends on the amount of data and the size of the model. I mainly do regression and time series forecasting for embedded systems, so the amount of data and the size of the model are relatively small compared to the models and data of an image classification task.

Can you tell me what exactly you mean by "Maybe validate the data shapes"?

When I compare your data shape of X with mine, I see that you have another dimension, but I don't understand it completely. My data is a vector of batches, each of shape 6 x 50. The first dimension (6) holds the features and the second dimension (50) holds the past time steps. The loss is just the MSE of the model output after the 50 time steps compared to the target value at time step 50 of each batch (Y[ct_batch * 50 + 50], where [...]). For my data shape, the CPU and GPU speeds are similar.

Furthermore, is it better to use data = zip(X, Y); than to just make a tuple data = (X,Y)?

@jeremiedb
Contributor

For the data shapes, I was referring to ensuring that the training data is organized as a proper iterator. You can refer to the docs about that: https://fluxml.ai/Flux.jl/stable/training/training/.
Which additional dimension isn't clear? For each batch, the features are a vector of length seq_length (50 in your case), with each element containing a matrix of 6 x batch_size. In addition, your dataset is made up of a number of minibatches, resulting in a vector of batches.
If you have 1000 sequences of length 50 and the batch size is 10, then num_batches = 100.
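
As a rough sketch of that layout, assuming the raw data is a plain vector of 1000 matrices of size 6 x 50 (features x timesteps), the batches could be assembled with Base Julia as below. The names `sequences`, `make_batch` and `batches` are only illustrative and are not part of Flux.

```julia
feat, seq_len, n_seqs, batch_size = 6, 50, 1000, 10

# Hypothetical raw data: one (feat x seq_len) matrix per independent sequence.
sequences = [rand(Float32, feat, seq_len) for _ in 1:n_seqs]

# Illustrative helper: turn a group of sequences into a Vector of length seq_len
# whose t-th element is a (feat x group size) matrix holding timestep t.
make_batch(seqs) = [hcat([s[:, t] for s in seqs]...) for t in 1:size(first(seqs), 2)]

# Split the sequence indices into consecutive groups of batch_size.
batches = [make_batch(sequences[idx]) for idx in Iterators.partition(1:n_seqs, batch_size)]

num_batches = length(batches)  # n_seqs / batch_size
```

Each element of `batches` then has the "vector of seq_len matrices of features x batch_size" shape described above.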

The data provided to train! is the entire dataset. The train! function iterates over each element of that iterator, and each of those elements is passed to the loss function. I'd recommend taking a toy example such as X=1:3, Y=11:13 to see how (X,Y) differs from zip(X,Y) and get a hands-on feel for their respective behavior. Zip returns an x,y pair at each iteration, which is what you want. Similarly, data = [(x1,y1), (x2,y2), (x3,y3)] will also return a pair of batch and target for each of the 3 batches within the data.
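
The toy comparison suggested above can be tried in plain Julia, without Flux. This is only a sketch of the iteration behavior; `tuple_items` and `zip_items` are illustrative names.

```julia
X = 1:3
Y = 11:13

# Iterating the tuple (X, Y) yields only two items: all of X, then all of Y.
tuple_items = [d for d in (X, Y)]

# Iterating zip(X, Y) yields one (x, y) pair per observation.
zip_items = collect(zip(X, Y))

println(length(tuple_items), " items from the tuple")
println(length(zip_items), " pairs from zip")
```

An iterator that produces one (input, target) pair per step, like `zip_items` here, is the shape of data the comment above recommends passing to train!.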

@jeremiedb
Contributor

As regards the original issue, CUDNN_STATUS_BAD_PARAM, I guess the issue can be closed?

@VoSiLk
Author

VoSiLk commented Nov 12, 2020

Yes, it is fine that you have closed the issue. Thanks for your explanation of zip. I will try it.

> What additional dimension isn't clear? For each batch, features are a vector of seq_length (50 in your case), each containing a matrix of 6 x batch_size. Plus, your dataset is made out of a number of minibatches, resulting in a vector of batches.
> If you have 1000 sequences of length 50, and batch size is 10, then num_batches = 100.

I think the batch size is what confuses me. In my scenario I would like to evaluate the model over the last 50 time steps, so that the model state and the model output reach my target value after those 50 steps. So the sequence length (the last 50 time steps) and the features define the shape of my batch. I would then do this for every element of my target starting at 51, which gives num_batches = N-50. Transferred to your code, I think my batch size is 1. In my view of the data, the matrix would be defined by the sequence length and the features, because then each matrix would be independent of the others. If I have understood your example correctly, I have to evaluate all 20 matrices of the data before evaluating the loss. What does the batch size in your example define? I thought the batch size is a hyperparameter that defines the number of samples to work through before updating the internal model parameters. If so, model.(X) would first take the first column of all 20 matrices in the vector, then the next column, and so on until the batch size is reached. Is that how it works?

@jeremiedb
Contributor

jeremiedb commented Nov 15, 2020

Flux's RNN design seems to have been a frequent cause of confusion, so I'll try to provide some clarification that will hopefully be reusable.

Starting from the basics, here is the classic RNN depiction from Colah's blog:
[image: unrolled RNN diagram from Colah's blog]

Here, x0 to x4 represent a single sequence of 5 items, and h0 to h4 their respective outputs.

In Flux, we represent such a model with, for example:

m = Chain(LSTM(6, 70), Dense(70, 1), x -> reshape(x, :))

We can apply a single step from a given sequence comprising 6 features with:

x = rand(6)
julia> m(x)
1-element Array{Float32,1}:
 0.028398542

The m(x) operation would be represented by the x0 -> A -> h0 in Colah's diagram.
If we perform this operation a second time, it will be equivalent to x1 -> A -> h1 since the model m has stored the state resulting from the x0 step:

x = rand(6)
julia> m(x)
1-element Array{Float32,1}:
 0.07381232

Now, instead of getting a single timestep at a time, we can get the full h0 to h4 sequence in a single step by broadcasting the model over a sequence of data. This is what is presented in the docs here: https://fluxml.ai/Flux.jl/stable/models/recurrence/#Sequences-1

seq = [rand(6) for i = 1:5]
julia> m.(seq)
5-element Array{Array{Float32,1},1}:
 [-0.17945863]
 [-0.20863166]
 [-0.20693761]
 [-0.21654114]
 [-0.18843849]

If for some reason one wants to exclude the first 3 timesteps of the chain for the computation of the loss, that can be handled through:

function loss(seq,y)
  sum((Flux.stack(m.(seq)[4:end],1) .- y) .^ 2)
end

y=rand(2)
julia> loss(seq, y)
1.7021208968648693

Such a model means that only h3 and h4 are used to compute the loss, hence the target y being of length 2.

Note that in your use case, if the first 50 timesteps are to be ignored, there's nothing stopping you from applying the model over 60 timesteps, for example, so that the gradients are calculated over 10 data points (51 to 60).

Alternatively, the "warmup" of the sequence could be performed once, followed by regular training in which all steps of the sequence are considered for the gradient update:

function loss(seq,y)
  sum((Flux.stack(m.(seq),1) .- y) .^ 2)
end

seq_init = [rand(6) for i = 1:3]
seq_1 = [rand(6) for i = 1:5]
seq_2 = [rand(6) for i = 1:5]

y1 = rand(5)
y2 = rand(5)

X = [seq_1, seq_2]
Y = [y1, y2]
dt = zip(X,Y)

Flux.reset!(m)
m.(seq_init)

ps = params(m)
opt= ADAM(1e-3)
Flux.train!(loss, ps, dt, opt)

In the previous example, a warmup period of length 3 has been applied. The model's state is first reset, then the warmup sequence goes through the model, resulting in a warmup state. The model can then be trained for 1 epoch, where 2 batches are provided (seq_1 and seq_2) and where all the timesteps are considered (we no longer use a subset of m.(seq) in the loss function).

In this scenario, it is important to note that a single continuous sequence is considered. Since the model state is not reset between the 2 batches, the state of the model is maintained, which only makes sense in the context where seq_1 is the continuation of seq_init and so on.

Batch size would be 1 here, as there's only a single sequence within each batch. If the model were to be trained on multiple independent sequences, then these sequences could be added to the input data as a second dimension. For example, in a language model, each batch would contain multiple independent sentences. In such a scenario, a single batch would be of the shape:

seq = [rand(6, 4) for i = 1:5]

This would mean that we have 4 sentences (or samples), each with 6 features (let's say a very small embedding) and each with a length of 5 (5 words per sentence). Calling m(seq[1]) would still represent x0 -> h0 in Colah's diagram, that is, the first-word output for each of the 4 independent sentences (a vector of length 4), since we have now introduced a batch size of 4.
Also, in a language model, since a batch isn't typically the continuation of the previous batch's sentences, we need to reset the model state for each batch, which results in the following loss function:

function loss(seq,y)
  Flux.reset!(m)
  sum((Flux.stack(m.(seq),1) .- y) .^ 2)
end

Hope this brings some clarification. I think one source of ambiguity with RNNs is that the data is 3-dimensional: features, sequence length and samples. In Flux, those 3 dimensions are provided through a vector of length seq length, containing matrices of [features X samples]. I think a language model with multiple sentences being trained simultaneously is the best example of why it is relevant to have both multiple time steps and multiple samples. So, in a given mini-batch training step, the gradients are computed over both multiple time steps and multiple samples.
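
As an illustration of that layout, here is a minimal sketch in Base Julia that converts a dense features x timesteps x samples array into the vector-of-matrices form described above. The array `data3d` and its dimension order are assumptions made for the example, not something Flux requires of the raw data.

```julia
feat, seq_len, n_samples = 6, 5, 4

# Hypothetical dense array laid out as features x timesteps x samples.
data3d = rand(Float32, feat, seq_len, n_samples)

# One (feat x n_samples) matrix per timestep, in time order.
seq = [data3d[:, t, :] for t in 1:seq_len]

println(length(seq))    # number of timesteps
println(size(seq[1]))   # (features, samples)
```

The result has the same shape as `seq = [rand(6, 4) for i = 1:5]` from the example above.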

@VoSiLk
Author

VoSiLk commented Nov 19, 2020

Thank you very much for this comprehensive answer and example. This helps me a lot.

@luboshanus

Hi, the last example is nice and illustrative. I tried to work from the code in the post quoted below. @jeremiedb uses Flux.stack() in the loss function, and when going through the loss I just don't get the dimensions. The loss works, and I understand that the size of Y[1] is batch_size x seq_len, but Flux.stack still does not reduce the dimension of the model's output. The subtraction .- then creates a much bigger array.

julia> Y[1]
256×20 Array{Float32,2}:
 0.0884711   0.00465298  0.0696806     0.0348138   0.0749678   0.0218575
...

julia> Flux.stack(model.(X[1]),2)
1×20×256 Array{Float32,3}:
[:, :, 1] =
 0.0674857  0.0597369  0.0714865    0.0679729  0.057861  0.0686841
...

julia> Y[1] .- Flux.stack(model.(X[1]), 2)
256×20×256 Array{Float32,3}:
[:, :, 1] =
  0.0209854   -0.0550839   -0.00180598      0.0171068   -0.0468266

The 256x20x256 result does not look right to me. It might also be related to the speed, since it is a big array...

So my questions are:
Q1) Is this really the way to compute the loss for such sequences? (I wish to have a batch size greater than 1.)
Q2) (Maybe a silly one.) If I have a time series and its 6 features, I don't understand the dimension of seq_len here. My only intuition tells me that the full time series length equals seq_len x batch_size?

Also, if I remove the seq_len dimension, then the number of batches times batch_size gives me the full sample of a time series. Is this correct?

Thanks a lot for the clarification.

What do you mean by slow?

Maybe validate the data shapes. The example below uses 100 batches, a batch size of 256, and a sequence length of 20, with 6 data features as per your model definition. A single epoch takes 2.8 sec on a GTX 1660 GPU.

feat = 6
batch_size = 256
num_batches = 100
seq_len = 20

X = [[rand(Float32, feat, batch_size) for i in 1:seq_len] for batch in 1:num_batches];
Y = [rand(Float32, batch_size, seq_len) ./ 10  for batch in 1:num_batches];

X = X |> gpu;
Y = Y |> gpu;
data = zip(X, Y);

opt = ADAM(0.001, (0.9, 0.999))

function loss(X,Y)
    Flux.reset!(model)
    mse_val = sum(abs2.(Y .- Flux.stack(model.(X), 2)))
    return mse_val
end
.....

@jeremiedb
Contributor

In the first model, there was a missing ingredient: x -> reshape(x, :), so it should read:

model = Chain(LSTM(6, 70), LSTM(70, 70), LSTM(70, 70), Dense(70, 1, relu), x -> reshape(x, :)) |> gpu

You should then get a matching 256x20 dimensions, which matches Y:

julia> Flux.stack(model.(X[1]), 2)
256×20 CUDA.CuArray{Float32,2}:
  1. The above is a correct way to compute the loss AFAIK.
  2. If you have 6 features, then the data should be organized as in the discussed example: X = [[rand(Float32, feat, batch_size) for i in 1:seq_len] for batch in 1:num_batches];. If you have 6 features, where would they fit in [seq, batch_size]? seq_length is the number of time steps. At each timestep, you measure your 6 features. 20 timesteps, 6 features = 20 X 6. And if you have 256 sensors capturing those 6 features during 20 timesteps, then you have a Vector of length 20 containing 6 X 256 matrices.
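
The shape mismatch discussed here can be reproduced without Flux or a GPU. The sketch below uses random arrays as a stand-in for the model outputs, and plain `cat`/`hcat` as an assumed equivalent of stacking along dimension 2; it is meant only to show how broadcasting produces the 256x20x256 array and why flattening each output fixes it.

```julia
batch_size, seq_len = 256, 20
Y = rand(Float32, batch_size, seq_len)

# Stand-in for model.(X[1]) without the final reshape:
# one 1 x batch_size matrix per timestep.
outs = [rand(Float32, 1, batch_size) for _ in 1:seq_len]

# Stacking along a new second dimension gives 1 x seq_len x batch_size.
stacked = cat([reshape(o, 1, 1, :) for o in outs]...; dims=2)

# Broadcasting against the 2-D target silently expands both operands.
bad_diff = Y .- stacked

# With each output flattened to a vector (what x -> reshape(x, :) does),
# the stacked result is batch_size x seq_len and lines up with Y.
flat = hcat([reshape(o, :) for o in outs]...)
good_diff = Y .- flat

println(size(bad_diff))
println(size(good_diff))
```

Because broadcasting does not raise an error in the first case, checking `size` on the stacked output before computing the loss is a cheap safeguard.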

@CarloLucibello
Member

@jeremiedb you should consider adding your comment of ten days ago somewhere in the docs, it's a very nice explanation

@luboshanus

@jeremiedb Thanks for the explanation about the loss and the data shape. I got it.

I would like to ask about the batches you are describing. Let's assume I have only one realization of a time series (one sensor). My time series has a length 256, features that explain it are 6 and then I want to have batches (mini-batches ?) here.

feat = 6, seq_len = 256, batch_size = 32 (mini-batch), num_batches = 8 (= seq_len/batch_size) so, in this case, I do not need the parameter seq_len to generate my data set? Is it correct?

Is it just [rand(feat, batch_size) for i in 1:num_batches]? It covers the whole sample period.

I'm just trying to figure out how to properly prepare time-series data, such as macroeconomic data, for an estimation similar to vector autoregression. I want to capture the time effect using an LSTM, but which batch size lets me see the propagation of, and dependence on, the past? Thanks a lot.

@jeremiedb
Contributor

Is it just [rand(feat, batch_size) for i in 1:num_batches]?

I think you skipped a key notion in the explanation of the data format: consecutive timesteps are handled through a Vector of length seq_len, so: [rand(feat, 1) for i in 1:seq_len]. The batch_size refers to independent observations.

If the data shape is feat x 32, then this will be treated as 32 independent signal streams, just as with a vanilla NN using Dense(6, 1). If you feed it data of size 6 X 32, the network treats each of the 32 observations/samples independently. It's the very same for an RNN model: if you call your RNN model m on 6 x 32 data, you get 32 independent outputs for a single timestep.
The recurrence is obtained through successive calls to the RNN model m, hence the use of broadcasting, m.(x), to roll over all the timesteps of a given batch.

What happens is essentially equivalent to:

for i in 1:seq_len
 m(X[i])
end

At each step of the iteration, m stores the updated state of the RNN following a single timestep update. If you had batch_size = 32, that state would track the 32 independent states of the independent observations. If you have 1 stream, then batch size = 1. It is best to play with a toy model to get a sense of the behavior.
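
One way to play with this without Flux is a toy stateful callable. The sketch below is not Flux's implementation: `ToyRecur` and `reset!` are made-up stand-ins that only mimic the "call the model, state is updated" behavior, and the check relies on broadcasting over a plain Vector visiting the elements in order.

```julia
# Hypothetical stand-in for a recurrent layer: keeps a running scalar state.
mutable struct ToyRecur
    state::Float64
end

# Each call mixes the previous state with the new input and stores the result.
(m::ToyRecur)(x) = (m.state = 0.5 * m.state + sum(x); m.state)

# Illustrative reset, playing the role of resetting the model state.
reset!(m::ToyRecur) = (m.state = 0.0; m)

seq = [fill(Float64(t), 6) for t in 1:5]  # 5 timesteps, 6 features each
m = ToyRecur(0.0)

# Roll over the sequence by broadcasting ...
out_broadcast = m.(seq)

# ... and again with an explicit loop, starting from a fresh state.
reset!(m)
out_loop = Float64[]
for x in seq
    push!(out_loop, m(x))
end

println(out_broadcast == out_loop)
```

Skipping the `reset!` call before the loop makes the second pass start from the state left by the first one, which is the situation a reset at the start of the loss function is meant to avoid.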

@KronosTheLate
Contributor

I am also getting "caused by: CUDNNError: CUDNN_STATUS_BAD_PARAM (code 3)" with some frequency when training a large number of networks, all consisting of a block of convolutional layers followed by a block of dense layers. So no RNNs. However, I am very uncertain about what causes the problem - the networks are all validated to be valid configurations by running a single prediction.

@ToucheSir
Member

Different cuDNN calls are used on the backwards pass, so validating with inference only is unfortunately not enough. Just to clarify, are these consistent or sporadic errors? Could you put together an MWE and open an issue?
