ProjectTo(::AbstractArray) does not infer #407

Open
mzgubic opened this issue Jul 22, 2021 · 3 comments
Labels
ProjectTo related to the projection functionality

Comments

@mzgubic
Member

mzgubic commented Jul 22, 2021

For example:

julia> @inferred ProjectTo(rand(3, 3))(Diagonal(rand(3)))
ERROR: return type Diagonal{Float64, Vector{Float64}} does not match inferred return type Union{Base.ReshapedArray{Float64, 2, Diagonal{Float64, Vector{Float64}}, Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}}, Diagonal{Float64, Vector{Float64}}}

The reason for the failure is this branch:

    dy = if axes(dx) == project.axes
        dx
    else
        for d in 1:max(M, length(project.axes))
            if size(dx, d) != length(get(project.axes, d, 1))
                throw(_projection_mismatch(project.axes, size(dx)))
            end
        end
        reshape(dx, project.axes)
    end

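which decides whether dy is dx itself or reshape(dx, ...) based on the values of the axes rather than their types. Inference cannot know which branch will be taken, so it returns a Union of both result types.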
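A minimal stand-alone sketch of the problem, using only Base and LinearAlgebra. project_by_value is a hypothetical helper (not the ChainRulesCore code) that mirrors the branch above: it compares the axes by value and reshapes otherwise. Since the condition is only known at run time, inference has to keep both possible return types.

```julia
using LinearAlgebra

# Hypothetical stand-in for the branch above, not ChainRulesCore's code:
# whether we return dx or reshape(dx, ...) depends on runtime axis *values*.
function project_by_value(target_axes::Tuple, dx::AbstractArray)
    if axes(dx) == target_axes
        return dx
    else
        # Check sizes dimension by dimension, treating missing dims as length 1.
        for d in 1:max(ndims(dx), length(target_axes))
            if size(dx, d) != length(get(target_axes, d, 1))
                throw(DimensionMismatch("cannot project size $(size(dx)) onto axes $target_axes"))
            end
        end
        return reshape(dx, target_axes)
    end
end

ax = (Base.OneTo(3), Base.OneTo(3))
dx = Diagonal(rand(3))
@show project_by_value(ax, dx)  # takes the first branch and returns dx itself

# Inference sees both branches, so the inferred return type is not concrete
# (a Union of the Diagonal and the ReshapedArray produced by reshape).
@show Base.return_types(project_by_value, (typeof(ax), typeof(dx)))
```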

We could instead branch on the number of axes, which is known from the types:

    dy = if length(axes(dx)) == length(project.axes)
        axes(dx) == project.axes || throw(_projection_mismatch(project.axes, size(dx)))
        dx
    else
        for d in 1:max(M, length(project.axes))
            if size(dx, d) != length(get(project.axes, d, 1))
                throw(_projection_mismatch(project.axes, size(dx)))
            end
        end
        reshape(dx, project.axes)
    end

This does infer, but it throws an error for arrays that have the right number of dimensions but different axes, which the current code would reshape. In practice, the only test in ChainRulesCore that fails is

        poffv = ProjectTo(OffsetArray(rand(3), 0:2))
        @test axes(poffv([1, 2, 3])) == (0:2,)

and all the rrule inference tests in JuliaDiff/ChainRules.jl#459 (comment) are fixed (except the one where InplaceableThunk inference fails).
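A sketch of the proposed change on a hypothetical stand-in, project_by_length (not the ChainRulesCore code). Both lengths being compared are fixed by the argument types, so the compiler can constant-fold the condition and infer only one branch. The tradeoff itself needs non-OneTo axes (e.g. from OffsetArrays), which this stdlib-only sketch does not construct.

```julia
using LinearAlgebra

# Hypothetical stand-in for the proposed branch, not ChainRulesCore's code.
# length(axes(dx)) and length(target_axes) are known from the types, so the
# condition is a compile-time constant and only one branch gets inferred.
function project_by_length(target_axes::Tuple, dx::AbstractArray)
    if length(axes(dx)) == length(target_axes)
        # Same number of dimensions: never reshape, only check the axes.
        axes(dx) == target_axes ||
            throw(DimensionMismatch("cannot project axes $(axes(dx)) onto $target_axes"))
        return dx
    else
        # Different number of dimensions: check sizes, then reshape.
        for d in 1:max(ndims(dx), length(target_axes))
            if size(dx, d) != length(get(target_axes, d, 1))
                throw(DimensionMismatch("cannot project size $(size(dx)) onto axes $target_axes"))
            end
        end
        return reshape(dx, target_axes)
    end
end

ax = (Base.OneTo(3), Base.OneTo(3))
dx = Diagonal(rand(3))
@show project_by_length(ax, dx)                              # dx itself
@show Base.return_types(project_by_length, (typeof(ax), typeof(dx)))

# A 3x1 matrix projected onto a single axis still takes the reshape branch.
col = rand(3, 1)
@show project_by_length((Base.OneTo(3),), col)               # a Vector{Float64}
```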

Do we want this tradeoff or not?

@mzgubic
Member Author

mzgubic commented Jul 22, 2021

Lyndon says "It is a union of <=4 elements. It is fine."

My reply:
I found a “solution” with a small tradeoff.

It is a union of <=4 elements. It is fine.

I’ve heard small unions are fine, but never really understood why. Does this mean that something (JIT compilation?) happens for each element of the union, rather than for all possible types?
My concern (which may be unjustified) is that having a few of these small unions throughout the program means the inferred union snowballs as functions call each other,

e.g. something like:

julia> inner(x) = x > 0 ? 3.0 : 3
inner (generic function with 1 method)

julia> wrapper(x::Int) = x > 0 ? rand(2, 3) : Diagonal(rand(2))
wrapper (generic function with 1 method)

julia> wrapper(x::Float64) = x > 0 ? "hello" : :world
wrapper (generic function with 2 methods)

julia> together(x) = wrapper(inner(x))
together (generic function with 1 method)

julia> @code_warntype together(2.0)
Variables
  #self#::Core.Const(together)
  x::Float64
Body::Any
1 ─ %1 = Main.inner(x)::Union{Float64, Int64}
│   %2 = Main.wrapper(%1)::Any
└──      return %2

where the inferred result is actually Any rather than a Union of the four possible return types.
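The same layers can also be inspected programmatically with Base.return_types, which lists the inferred return type for each matching method. This sketch redefines the three functions from the session; the exact type inferred for together may differ between Julia versions (the session above shows Any), so the only firm claim is that it is not concrete.

```julia
using LinearAlgebra

inner(x) = x > 0 ? 3.0 : 3
wrapper(x::Int) = x > 0 ? rand(2, 3) : Diagonal(rand(2))
wrapper(x::Float64) = x > 0 ? "hello" : :world
together(x) = wrapper(inner(x))

# Each layer on its own infers to a small Union...
println(only(Base.return_types(inner, (Float64,))))  # Union{Float64, Int64}
println(only(Base.return_types(wrapper, (Int,))))
println(only(Base.return_types(wrapper, (Float64,))))

# ...but the call in `together` is split over both union members, and the
# four possible results are merged into one (possibly widened) type.
println(only(Base.return_types(together, (Float64,))))
```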

@oxinabox
Member

I wonder if we actually only want the reshape when dx is an Array.
That is the case where we know we will never get a ReshapedArray out, but will just get another Array that references the same memory.
ReshapedArray is probably not our friend: it is still a view, but I don't know how consistently it is used or how well it plays with BLAS etc.
Though maybe not having the reshape always work would defeat the utility of the feature.
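A minimal sketch of that idea, using a hypothetical helper reshape_if_array (not part of ChainRulesCore): dispatch sends plain Arrays through reshape, which returns another Array sharing the same memory, and returns any other AbstractArray unchanged after a size check. Both methods infer concretely, but non-Array inputs can no longer be reshaped at all, which is the loss of utility mentioned above.

```julia
using LinearAlgebra

# Hypothetical helper, not ChainRulesCore API. For a plain Array, reshape
# returns another Array sharing dx's memory, so no ReshapedArray appears.
reshape_if_array(dx::Array, target_axes::Tuple) = reshape(dx, target_axes)

# Any other AbstractArray is returned unchanged after checking its size, so it
# is never wrapped; the price is that it cannot be reshaped at all.
function reshape_if_array(dx::AbstractArray, target_axes::Tuple)
    expected = map(length, target_axes)
    size(dx) == expected ||
        throw(DimensionMismatch("expected size $expected, got size $(size(dx))"))
    return dx
end

v = [1.0, 2.0, 3.0, 4.0]
m = reshape_if_array(v, (Base.OneTo(2), Base.OneTo(2)))  # a Matrix{Float64}
m[1, 1] = 10.0                                           # also changes v[1]

d = Diagonal([1.0, 2.0])
@show reshape_if_array(d, axes(d))  # d itself, still a Diagonal
```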

@oxinabox added the ProjectTo label (related to the projection functionality) on Jul 23, 2021
@mcabbott
Member

Surely this can be made to infer.

The trick of branching on length(axes(dx)) == length(project.axes) is clever, but I think what it will miss is that the reshape at present restores OffsetArrays, which tend to go missing: e.g. hcat(OffsetArray(rand(3), 0:2)) isa Matrix.


3 participants