Replaced deprecated Zygote.ignore with ChainRulesCore.ignore_derivatives #287

Merged 1 commit on Nov 11, 2022
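
For context: `Zygote.ignore` was deprecated in favor of `ChainRulesCore.ignore_derivatives`, which Zygote supports through its ChainRulesCore integration. A minimal sketch of the migrated pattern (the `loss` function and input values are invented for illustration):

using ChainRulesCore, Zygote

function loss(x)
    # Integer bookkeeping has no useful derivative, so it is kept off the AD tape.
    idx = ChainRulesCore.ignore_derivatives() do
        argmax(x)
    end
    return x[idx]^2
end

Zygote.gradient(loss, [1.0, 3.0, 2.0])  # ([0.0, 6.0, 0.0],): gradient flows only through x[idx]^2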
2 changes: 2 additions & 0 deletions Project.toml
@@ -6,6 +6,7 @@ version = "0.4.1"
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compose = "a81c6b42-2e10-5240-aca2-a61377ecd94b"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -36,6 +37,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
Cairo = "^1"
ChainRulesCore = "0.8 - 1.15.6"
Compose = "< 0.9.2"
CUDA = "2 - 3"
DataStructures = "0.15.0 - 0.18.13"
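
To reproduce the dependency change locally, a sketch using the Pkg API (assuming Julia ≥ 1.8, where `Pkg.compat` is available):

using Pkg
Pkg.add(name = "ChainRulesCore", uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4")
Pkg.compat("ChainRulesCore", "0.8 - 1.15.6")  # mirrors the [compat] entry above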
2 changes: 1 addition & 1 deletion src/RL/nn_structures/cpnn.jl
@@ -39,7 +39,7 @@ function (nn::CPNN)(states::BatchedDefaultTrajectoryState)

# extract the feature(s) of the variable(s) we're working on
indices = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
indices = CartesianIndex.(zip(variableIdx, 1:batchSize))
end
variableFeatures = nodeFeatures[:, indices]
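
This hunk shows the pattern repeated throughout the PR: the branching-variable indices are plain integers, so they are built inside `ignore_derivatives` to keep AD from trying to differentiate the `CartesianIndex` construction. A standalone sketch with invented sizes (F=4 features, N=5 nodes, B=3 batch entries):

using ChainRulesCore, Zygote

nodeFeatures = rand(Float32, 4, 5, 3)  # F x N x B, stand-in values
variableIdx = [2, 1, 3]                # node of the branching variable, per batch entry
batchSize = length(variableIdx)

grads = Zygote.gradient(nodeFeatures) do nf
    indices = ChainRulesCore.ignore_derivatives() do
        CartesianIndex.(zip(variableIdx, 1:batchSize))  # (node, batch) pairs
    end
    sum(nf[:, indices])  # F x B slice; only these entries receive gradient
end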
4 changes: 2 additions & 2 deletions src/RL/nn_structures/fullfeaturedcpnn.jl
@@ -51,15 +51,15 @@ function (nn::FullFeaturedCPNN)(states::BatchedDefaultTrajectoryState)

# Extract the features corresponding to the variables
variableIndices = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
variableIndices = Flux.unsqueeze(CartesianIndex.(variableIdx, 1:batchSize), 1)
end
variableFeatures = nodeFeatures[:, variableIndices] # Fx1xB
variableFeatures = reshape(nn.nodeChain(RL.flatten_batch(variableFeatures)), :, 1, batchSize) # F'x1xB

# Extract the features corresponding to the values
valueIndices = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
valueIndices = CartesianIndex.(allValuesIdx, repeat(transpose(1:batchSize); outer=(actionSpaceSize, 1)))
end
valueFeatures = nodeFeatures[:, valueIndices] # FxAxB
2 changes: 1 addition & 1 deletion src/RL/nn_structures/heterogeneouscpnn.jl
@@ -30,7 +30,7 @@ function (nn::HeterogeneousCPNN)(states::BatchedHeterogeneousTrajectoryState)
globalFeatures = fg.gf
# extract the feature(s) of the variable(s) we're working on
indices = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
# Double check that we are extracting the right variable
indices = CartesianIndex.(zip(variableIdx, 1:batchSize))
end
6 changes: 3 additions & 3 deletions src/RL/nn_structures/heterogeneousfullfeaturedcpnn.jl
@@ -44,7 +44,7 @@ function (nn::HeterogeneousFullFeaturedCPNN)(states::BatchedHeterogeneousTraject
batchSize = length(variableIdx)
actionSpaceSize = size(states.fg.valnf, 2)
Mask = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
Mask = device(states) != Val{:cpu}() ? CUDA.zeros(Float32, 1, size(states.fg.varnf,2), actionSpaceSize, batchSize) : zeros(Float32, 1, size(states.fg.varnf,2), actionSpaceSize, batchSize) # this Mask will replace `repeat` using a broadcasted `+`
end
# chain working on the graph(s) with the GNNs
@@ -55,7 +55,7 @@ function (nn::HeterogeneousFullFeaturedCPNN)(states::BatchedHeterogeneousTraject

# Extract the features corresponding to the variables
variableIndices = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
variableIndices = Flux.unsqueeze(CartesianIndex.(variableIdx, 1:batchSize), 1)
end
branchingVariableFeatures = nn.varChain(variableFeatures) # Fx1xB
@@ -92,7 +92,7 @@ function (nn::HeterogeneousFullFeaturedCPNN)(states::BatchedHeterogeneousTraject
predictions = permutedims(predictions, [1,3,2,4])
#variableIndices = Flux.unsqueeze(CartesianIndex.(variableIdx, 1:batchSize), 1)|> gpu
variableIndices = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
variableIndices = device(states) != Val{:cpu}() ? CuArray(Flux.unsqueeze(CartesianIndex.(variableIdx, 1:batchSize), 1)) : Flux.unsqueeze(CartesianIndex.(variableIdx, 1:batchSize), 1)
end
output = dropdims(predictions[:,:, variableIndices], dims = tuple(findall(size(predictions[:,:, variableIndices]) .== 1)...))
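
The CPU/GPU ternaries above are another recurring pattern: constant buffers are allocated on whichever device holds the batch, with the allocation itself hidden from AD. A sketch, where `onGpu` stands in for the `device(states) != Val{:cpu}()` check and the sizes are invented:

using ChainRulesCore, CUDA

onGpu = CUDA.functional()
mask = ChainRulesCore.ignore_derivatives() do
    # Zero mask later combined via a broadcasted `+` instead of `repeat`
    onGpu ? CUDA.zeros(Float32, 1, 4, 5, 2) : zeros(Float32, 1, 4, 5, 2)
end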
18 changes: 9 additions & 9 deletions src/RL/nn_structures/heterogeneousvariableoutputcpnn.jl
@@ -57,19 +57,19 @@ function (nn::HeterogeneousVariableOutputCPNN)(state::BatchedHeterogeneousTrajec

# extract the feature(s) of the variable(s) we're working on
indices = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
indices = CartesianIndex.(zip(variableIdx, 1:batchSize))
end
variableFeatures = nothing
numPadded = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
variableFeatures = reshape(fg.varnf[:, indices], (:,1,batchSize)) # Fx1xB
# extract the feature(s) of the variable(s) we're working on
numPadded = [maxActionSpaceSize - actionSpaceSizes[i] for i in 1:batchSize] #number of padding zeros needed for each element of the batch
end

valueIndices = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
paddedPossibleValuesIdx = [append!(possibleValuesIdx[i], repeat([possibleValuesIdx[i][1]], numPadded[i])) for i in 1:batchSize]
paddedPossibleValuesIdx = mapreduce(identity, hcat, paddedPossibleValuesIdx) #convert from Vector to Matrix
#create a CartesianIndex matrix of size (maxActionSpaceSize x batch_size)
@@ -79,7 +79,7 @@ function (nn::HeterogeneousVariableOutputCPNN)(state::BatchedHeterogeneousTrajec
valueFeatures = fg.valnf[:, valueIndices] #FxAxB

f = size(valueFeatures, 1)
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
for i in 1:batchSize
for j in 1:numPadded[i]
valueFeatures[:,maxActionSpaceSize-j+1,i] = zeros(Float32, f)
@@ -92,21 +92,21 @@ function (nn::HeterogeneousVariableOutputCPNN)(state::BatchedHeterogeneousTrajec

variableOutput = nothing
valueOutput = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
variableOutput = reshape(chainOutput[:,1,:], (:,1,batchSize)) #F'x1xB
valueOutput = chainOutput[:,2:end,:] #F'xAxB
end

finalInput = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
finalInput = []
for i in 1:batchSize
singleFinalInput = vcat(repeat(variableOutput[:,:,i], 1, maxActionSpaceSize), valueOutput[:,:,i]) #one element of the batch
finalInput = isempty(finalInput) ? [singleFinalInput] : append!(finalInput, [singleFinalInput])
end
end
#finalInput: vector of matrices of size F'xA (total size BxF'xA)
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
f, a = size(finalInput[1])
finalInput = reshape(collect(Iterators.flatten(finalInput)), (f, maxActionSpaceSize, batchSize)) #!!TO TEST #convert vector of matrices into a 3-dimensional matrix

@@ -115,15 +115,15 @@ function (nn::HeterogeneousVariableOutputCPNN)(state::BatchedHeterogeneousTrajec
output = dropdims(nn.outputChain(finalInput); dims=1) #AxB

finalOutput = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
finalOutput = reshape(
Float32[-Inf32 for _ in 1:(size(fg.valnf,2)*size(fg.valnf,3))],
size(fg.valnf,2),
size(fg.valnf,3)
)
end

-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
for i in 1:batchSize
for j in 1:actionSpaceSizes[i]
#note that possibleValuesIdx is a Vector{Vector{Int64}} while output and finalOutput are Matrix{Int64}
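
The `-Inf` buffer built above implements action masking: entries for invalid actions stay at `-Inf`, so a downstream softmax assigns them zero probability. A self-contained sketch with invented sizes (A=5 actions, B=2 batch entries):

using ChainRulesCore

scores = rand(Float32, 5, 2)  # A x B, stand-in for `output`
actionSpaceSizes = [3, 5]     # number of valid actions per batch entry

mask = ChainRulesCore.ignore_derivatives() do
    m = fill(-Inf32, size(scores))
    for b in axes(scores, 2), a in 1:actionSpaceSizes[b]
        m[a, b] = 0f0         # valid entries pass through unchanged
    end
    m
end
masked = scores .+ mask       # invalid actions pinned to -Inf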
18 changes: 9 additions & 9 deletions src/RL/nn_structures/variableoutputcpnn.jl
@@ -34,19 +34,19 @@ function (nn::VariableOutputCPNN)(state::BatchedDefaultTrajectoryState)

# extract the feature(s) of the variable(s) we're working on
indices = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
indices = CartesianIndex.(zip(variableIdx, 1:batchSize))
end
variableFeatures = nothing
numPadded = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
variableFeatures = reshape(nodeFeatures[:, indices], (:,1,batchSize)) # Fx1xB
# extract the feature(s) of the variable(s) we're working on
numPadded = [maxActionSpaceSize - actionSpaceSizes[i] for i in 1:batchSize] #number of padding zeros needed for each element of the batch
end

valueIndices = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
paddedPossibleValuesIdx = [append!(possibleValuesIdx[i], repeat([possibleValuesIdx[i][1]], numPadded[i])) for i in 1:batchSize]
paddedPossibleValuesIdx = mapreduce(identity, hcat, paddedPossibleValuesIdx) #convert from Vector to Matrix
#create a CartesianIndex matrix of size (maxActionSpaceSize x batch_size)
@@ -56,7 +56,7 @@ function (nn::VariableOutputCPNN)(state::BatchedDefaultTrajectoryState)
valueFeatures = nodeFeatures[:, valueIndices] #FxAxB

f = size(valueFeatures, 1)
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
for i in 1:batchSize
for j in 1:numPadded[i]
valueFeatures[:,maxActionSpaceSize-j+1,i] = zeros(Float32, f)
@@ -69,35 +69,35 @@ function (nn::VariableOutputCPNN)(state::BatchedDefaultTrajectoryState)

variableOutput = nothing
valueOutput = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
variableOutput = reshape(chainOutput[:,1,:], (:,1,batchSize)) #F'x1xB
valueOutput = chainOutput[:,2:end,:] #F'xAxB
end

finalInput = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
finalInput = []
for i in 1:batchSize
singleFinalInput = vcat(repeat(variableOutput[:,:,i], 1, maxActionSpaceSize), valueOutput[:,:,i])
finalInput = isempty(finalInput) ? [singleFinalInput] : append!(finalInput, [singleFinalInput])
end
end
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
f, a = size(finalInput[1])
finalInput = reshape(collect(Iterators.flatten(finalInput)), (f, maxActionSpaceSize, batchSize)) #!!TO TEST #convert vector of matrices into a 3-dimensional matrix
end

output = dropdims(nn.outputChain(finalInput); dims=1) #AxB
finalOutput = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
finalOutput = reshape(
Float32[-Inf32 for _ in 1:(size(state.allValuesIdx,1)*size(state.allValuesIdx,2))],
size(state.allValuesIdx,1),
size(state.allValuesIdx,2)
)
end

-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
for i in 1:batchSize
for j in 1:actionSpaceSizes[i]
#the order of a vertex_id in allValuesIdx
2 changes: 1 addition & 1 deletion src/RL/representation/default/defaulttrajectorystate.jl
@@ -88,7 +88,7 @@ function Flux.functor(::Type{Vector{DefaultTrajectoryState}}, v)
allValuesIdx = zeros(Int, maxActions, batchSize)
end

-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
variableIdx = (state -> state.variableIdx).(v)
possibleValuesIdx = (state -> state.possibleValuesIdx).(v)
# TODO: this could probably be optimized
@@ -92,7 +92,7 @@ function Flux.functor(::Type{Vector{HeterogeneousTrajectoryState}}, v)
variableIdx = ones(Int, batchSize)
possibleValuesIdx = nothing

-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
variableIdx = (state -> state.variableIdx).(v)
possibleValuesIdx = (state -> state.possibleValuesIdx).(v)
end
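
As in the two `Flux.functor` methods above, gathering integer fields from a vector of states is pure bookkeeping and stays off the tape. A toy sketch (`ToyState` is invented to stand in for the trajectory-state structs):

using ChainRulesCore

struct ToyState
    variableIdx::Int
end
v = [ToyState(2), ToyState(1), ToyState(3)]

variableIdx = ChainRulesCore.ignore_derivatives() do
    (state -> state.variableIdx).(v)  # broadcast the getter over the batch -> [2, 1, 3]
end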
8 changes: 4 additions & 4 deletions src/RL/utils/geometricflux/graphconv.jl
@@ -41,7 +41,7 @@ end
function (g::GraphConv{<:AbstractMatrix,<:Any,meanPooling})(fgs::BatchedFeaturedGraph{Float32})
A, X = fgs.graph, fgs.nf
sumVal = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
sumVal = replace(reshape(mapslices(x -> sum(eachrow(x)), A, dims=[1, 2]), 1, :, size(A, 3)), 0=>1)
end
A = A ./ sumVal
@@ -58,7 +58,7 @@ function (g::GraphConv{<:AbstractMatrix,<:Any,meanPooling})(fg::FeaturedGraph)
A, X = fg.graph, fg.nf

sum = reshape(sum(eachrow(A)), 1, :)
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
replace(sum, 0=>1)
end
A = A ./ sum
@@ -79,7 +79,7 @@ This function operates the coordinate-wise max-Pooling technique along the neigh
function (g::GraphConv{<:AbstractMatrix,<:Any,maxPooling})(fgs::BatchedFeaturedGraph{Float32})
A, X = fgs.graph, fgs.nf

-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
B = repeat(collect(1:size(A,2)),1,size(A,2)).*A
B = collect.(zip(B,cat(repeat([1],size(A)[1:2]...),repeat([2], size(A)[1:2]...),dims= 3)))
filteredcol = mapslices( x -> map(z-> filter(y -> y[1]!=0,z),eachcol(x)), B, dims= [1,2])
@@ -114,7 +114,7 @@ This function operates the coordinate-wise max-Pooling technique along the neigh
"""
function (g::GraphConv{<:AbstractMatrix,<:Any, maxPooling})(fg::FeaturedGraph)
A, X = fg.graph, fg.nf
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
B = repeat(collect(1:size(A,2)),1,size(A,2)).*A
filteredcol = map(x-> filter(y -> y!=0,x),eachcol(B))
filteredemb = mapreduce(x->maximum(X[:,x], dims = 2), hcat,filteredcol)
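
For context on the mean-pooling branches above: aggregated neighbor features are divided by node degrees, with zero degrees replaced by one so isolated nodes do not cause division by zero. A toy sketch (adjacency and features invented):

using ChainRulesCore

A = Float32[0 1 1; 1 0 0; 1 0 0]  # toy 3-node adjacency matrix
X = rand(Float32, 4, 3)           # F x N node features

deg = ChainRulesCore.ignore_derivatives() do
    replace(reshape(sum(A, dims=1), 1, :), 0f0 => 1f0)  # node degrees, with 0 -> 1
end
meanAgg = (X * A) ./ deg          # F x N, mean over each node's neighbors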
12 changes: 6 additions & 6 deletions src/RL/utils/geometricflux/heterogeneousgraphconv.jl
@@ -65,7 +65,7 @@ function (g::HeterogeneousGraphConv{<:AbstractMatrix,<:Any,meanPooling})(fgs::Ba
X1, X2, X3 = original_fgs.varnf, original_fgs.connf, original_fgs.valnf

MatVar, MatCon, MatVal = nothing,nothing,nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
sumcontovar = sum(contovar, dims = 1)
sumvaltovar = sum(valtovar, dims = 1)
sumvartocon = sum(vartocon, dims = 1)
@@ -89,7 +89,7 @@
XX2 = g.σ.(g.weightscon ⊠ MatCon .+ g.biascon)
XX3 = g.σ.(g.weightsval ⊠ MatVal .+ g.biasval)

-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
return BatchedHeterogeneousFeaturedGraph{Float32}(
contovar,
valtovar,
@@ -109,7 +109,7 @@

MatVar, MatCon, MatVal = nothing,nothing,nothing
sumcontovar,sumvaltovar,sumvartocon,sumvartoval = nothing,nothing,nothing,nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
sumcontovar = sum(contovar, dims =1)
sumvaltovar = sum(valtovar, dims =1)
sumvartocon = sum(vartocon, dims =1)
@@ -134,7 +134,7 @@
XX1 = g.σ.(g.weightsvar * MatVar .+ g.biasvar)
XX2 = g.σ.(g.weightscon * MatCon .+ g.biascon)
XX3 = g.σ.(g.weightsval * MatVal.+ g.biasval)
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
return HeterogeneousFeaturedGraph(
contovar,
valtovar,
@@ -164,7 +164,7 @@ function (g::HeterogeneousGraphConv{<:AbstractMatrix,<:Any, maxPooling})(fgs::Ba
filteredembvaltovar = nothing
filteredembvartoval = nothing

-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
contovarIdx = repeat(collect(1:size(contovar,1)),1,size(contovar,2)).*contovar
valtovarIdx = repeat(collect(1:size(valtovar,1)),1,size(valtovar,2)).*valtovar
vartoconIdx = repeat(collect(1:size(vartocon,1)),1,size(vartocon,2)).*vartocon
@@ -215,7 +215,7 @@ function (g::HeterogeneousGraphConv{<:AbstractMatrix,<:Any, maxPooling})(fg::Het
filteredembvaltovar = nothing
filteredembvartocon = nothing
filteredembvartoval = nothing
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do

contovarIdx = device(contovar) != Val{:cpu}() ? CuArray(repeat(collect(1:size(contovar,1)),1,size(contovar,2))) : repeat(collect(1:size(contovar,1)),1,size(contovar,2))
valtovarIdx = device(valtovar) != Val{:cpu}() ? CuArray(repeat(collect(1:size(valtovar,1)),1,size(valtovar,2))) : repeat(collect(1:size(valtovar,1)),1,size(valtovar,2))
10 changes: 5 additions & 5 deletions src/RL/utils/geometricflux/heterogeneousgraphtransformer.jl
@@ -92,7 +92,7 @@ function (g::HeterogeneousGraphTransformer)(fgs::BatchedHeterogeneousFeaturedGra
temp_attention_vartocon = cat([(vartocon[:,:,i] .+ zeros(nvar, ncon, g.heads)) .* ATT_head_vartocon[:,:,:,i] for i in 1:batch_size]..., dims=4)
temp_attention_valtovar = cat([(valtovar[:,:,i] .+ zeros(nval, nvar, g.heads)) .* ATT_head_valtovar[:,:,:,i] for i in 1:batch_size]..., dims=4)
temp_attention_vartoval = cat([(vartoval[:,:,i] .+ zeros(nvar, nval, g.heads)) .* ATT_head_vartoval[:,:,:,i] for i in 1:batch_size]..., dims=4)
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
temp_attention_contovar = replace(temp_attention_contovar, 0.0 => -Inf)
temp_attention_vartocon = replace(temp_attention_vartocon, 0.0 => -Inf)
temp_attention_valtovar = replace(temp_attention_valtovar, 0.0 => -Inf)
@@ -102,7 +102,7 @@
attention_vartocon = softmax(temp_attention_vartocon; dims=2) # nvar x ncon x heads x batch
attention_valtovar = softmax(temp_attention_valtovar; dims=2) # nval x nvar x heads x batch
attention_vartoval = softmax(temp_attention_vartoval; dims=2) # nvar x nval x heads x batch
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
attention_contovar = replace(attention_contovar, NaN => 0)
attention_vartocon = replace(attention_vartocon, NaN => 0)
attention_valtovar = replace(attention_valtovar, NaN => 0)
@@ -128,7 +128,7 @@
new_H1 = permutedims(σ.(H_tilde_1) ⊠ g.a_lin_var + H1,[2,1,3]) # n x nvar x batch
new_H2 = permutedims(σ.(H_tilde_2) ⊠ g.a_lin_con + H2,[2,1,3]) # n x ncon x batch
new_H3 = permutedims(σ.(H_tilde_3) ⊠ g.a_lin_val + H3,[2,1,3]) # n x nval x batch
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
return BatchedHeterogeneousFeaturedGraph{Float32}(
contovar,
valtovar,
@@ -173,7 +173,7 @@ function (g::HeterogeneousGraphTransformer)(fg::HeterogeneousFeaturedGraph)
temp_attention_vartocon = (vartocon .+ zeros(nvar, ncon, g.heads)) .* ATT_head_vartocon
temp_attention_valtovar = (valtovar .+ zeros(nval, nvar, g.heads)) .* ATT_head_valtovar
temp_attention_vartoval = (vartoval .+ zeros(nvar, nval, g.heads)) .* ATT_head_vartoval
-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
temp_attention_contovar = replace(temp_attention_contovar, 0.0 => -Inf)
temp_attention_vartocon = replace(temp_attention_vartocon, 0.0 => -Inf)
temp_attention_valtovar = replace(temp_attention_valtovar, 0.0 => -Inf)
@@ -184,7 +184,7 @@
attention_valtovar = softmax(temp_attention_valtovar; dims=2) # nval x nvar x heads
attention_vartoval = softmax(temp_attention_vartoval; dims=2) # nvar x nval x heads

-Zygote.ignore() do
+ChainRulesCore.ignore_derivatives() do
attention_contovar = replace(attention_contovar, NaN => 0)
attention_vartocon = replace(attention_vartocon, NaN => 0)
attention_valtovar = replace(attention_valtovar, NaN => 0)
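
The paired `replace` calls above implement masked attention: scores for absent edges are sent to `-Inf` before the softmax, and rows with no edges at all (whose softmax comes back as NaN) are reset to zero afterwards. A forward-pass-only sketch with an invented edge indicator:

using ChainRulesCore
using Flux: softmax

att = rand(Float32, 3, 4)                   # raw attention scores
edges = Float32[1 1 0 0; 0 0 0 0; 1 0 1 0]  # toy edge indicator

pre = att .* edges
pre = ChainRulesCore.ignore_derivatives() do
    replace(pre, 0f0 => -Inf32)             # absent edges masked out
end
attn = softmax(pre; dims=2)
attn = ChainRulesCore.ignore_derivatives() do
    replace(attn, NaN32 => 0f0)             # fully masked rows: NaN -> 0
end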
4 changes: 2 additions & 2 deletions src/SeaPearl.jl
@@ -1,12 +1,12 @@
module SeaPearl

+using ChainRulesCore
+using Graphs
using Random
using ReinforcementLearning
-using Graphs
const RL = ReinforcementLearning

include("abstract_types.jl")

include("trailer.jl")
include("CP/CP.jl")
#include("MOI_wrapper/MOI_wrapper.jl")