diff --git a/src/linear.jl b/src/linear.jl index 98e2b59..b9a5aea 100644 --- a/src/linear.jl +++ b/src/linear.jl @@ -47,28 +47,34 @@ end LinearLayer(in_dim::Int, out_dim::Int; feature_first = false, use_cache = true) = LinearLayer{feature_first}(in_dim, out_dim, use_cache, _make_reqfields()...) +_valtype(l::LinearLayer, x::AbstractArray, ps) = + promote_type(eltype(x), eltype(ps.W)) -(l::LinearLayer)(x::AbstractVector, ps, st) = begin - out = acquire!(st.pool, :A, (l.out_dim, ), eltype(x)); +function (l::LinearLayer)(x::AbstractVector, ps, st) + out = acquire!(st.pool, :A, (l.out_dim, ), _valtype(l, x, ps)) mul!(unwrap(out), ps.W, unwrap(x)); release!(x); return out, st end -(l::LinearLayer{true})(x::AbstractMatrix, ps, st) = begin - out = acquire!(st.pool, :bA, (l.out_dim, size(x, 2)), eltype(x)); +function (l::LinearLayer{true})(x::AbstractMatrix, ps, st) + out = acquire!(st.pool, :bA, (l.out_dim, size(x, 2)), _valtype(l, x, ps)); mul!(unwrap(out), ps.W, unwrap(x)); release!(x); return out, st end (l::LinearLayer{false})(x::AbstractMatrix, ps, st) = begin - out = acquire!(st.pool, :bA, (size(x, 1), l.out_dim), eltype(x)); + out = acquire!(st.pool, :bA, (size(x, 1), l.out_dim), _valtype(l, x, ps)); mul!(unwrap(out), unwrap(x), transpose(PtrArray(ps.W))); release!(x); return out, st end # Jerry: Maybe we should use Glorot Uniform if we have no idea about what we should use? -LuxCore.initialparameters(rng::AbstractRNG, l::LinearLayer) = ( W = randn(rng, l.out_dim, l.in_dim), ) -LuxCore.initialstates(rng::AbstractRNG, l::LinearLayer) = ( l.use_cache ? (pool = ArrayPool(FlexArrayCache), ) : (pool = ArrayPool(FlexArray), )) +LuxCore.initialparameters(rng::AbstractRNG, l::LinearLayer) = + ( W = randn(rng, l.out_dim, l.in_dim), ) + +LuxCore.initialstates(rng::AbstractRNG, l::LinearLayer) = + ( l.use_cache ? (pool = ArrayPool(FlexArrayCache), ) + : (pool = ArrayPool(FlexArray), )) # TODO: check whether we can do this without multiple dispatch on vec/mat without loss of performance function rrule(::typeof(LuxCore.apply), l::LinearLayer, x::AbstractVector, ps, st)