diff --git a/Project.toml b/Project.toml index 22b3bad..02d961f 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "1.5.3" [deps] LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" @@ -35,10 +34,9 @@ GenericLinearAlgebra = "0.3.11" LaTeXStrings = "1" Latexify = "0.15, 0.16" LinearAlgebra = "1" -Pkg = "1.11.0" PrecompileTools = "1.1.1" Random = "1" Requires = "1" -StaticArrays = "1" +StaticArrays = "1.8.1" Symbolics = "0.1, 1, 2, 3, 4, 5, 6" julia = "1.6" diff --git a/ext/QuaternionicChainRulesCoreExt.jl b/ext/QuaternionicChainRulesCoreExt.jl index ade868e..57b7356 100644 --- a/ext/QuaternionicChainRulesCoreExt.jl +++ b/ext/QuaternionicChainRulesCoreExt.jl @@ -1,6 +1,5 @@ module QuaternionicChainRulesCoreExt -using Pkg using Quaternionic import Quaternionic: _sincu, _cossu using StaticArrays @@ -8,40 +7,6 @@ isdefined(Base, :get_extension) ? (using ChainRulesCore; import ChainRulesCore: rrule, rrule_via_ad, RuleConfig, ProjectTo) : (using ..ChainRulesCore; import ...ChainRulesCore: rrule, rrule_via_ad, RuleConfig, ProjectTo) - -## StaticArrays -# It's likely that StaticArrays will have its own ChainRulesCore extension someday, so we -# need to check if there is already a ProjectTo defined for SArray. If so, we'll use that. -# If not, we'll define one here. -staticarrays_info = Pkg.dependencies()[Base.UUID("90137ffa-7385-5640-81b9-e52037218182")] -if staticarrays_info.version < v"1.8.1" - # These are ripped from https://github.com/JuliaArrays/StaticArrays.jl/pull/1068 - function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::SArray) - dy = reshape(dx, axes(project.elements)) # allows for dx::OffsetArray - dz = ntuple(i -> project.elements[i](dy[i]), length(project.elements)) - return ChainRulesCore.project_type(project)(dz...) - end - function ProjectTo(x::SArray{S,T}) where {S, T} - return ProjectTo{SArray}(; - element=ChainRulesCore._eltype_projectto(T), - axes=axes(x), size=StaticArrays.Size(x) - ) - end - @inline _sarray_from_array(::Size{T}, dx::AbstractArray) where {T} = SArray{Tuple{T...}}(dx) - (project::ProjectTo{SArray})(dx::AbstractArray) = _sarray_from_array(project.size, dx) - function rrule(::Type{T}, x::Tuple) where {T <: SArray} - project_x = ProjectTo(x) - ∇Array(∂y) = (NoTangent(), project_x(∂y)) - return T(x), ∇Array - end - function rrule(::Type{T}, xs::Number...) where {T <: SVector} - project_x = ProjectTo(xs) - ∇Array(∂y) = (NoTangent(), project_x(∂y)...) - return T(xs...), ∇Array - end -end - - function rrule(::Type{QT}, arg::AbstractVector) where {QT<:AbstractQuaternion} AbstractQuaternion_pullback(Δquat) = (NoTangent(), components(unthunk(Δquat))) return QT(arg), AbstractQuaternion_pullback