Skip to content

Commit

Permalink
Fix time derivative of SymTensorValues
Browse files Browse the repository at this point in the history
  • Loading branch information
Antoinemarteau committed Sep 19, 2024
1 parent 78fcbf1 commit 9f6eb95
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/ODEs/TimeDerivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,15 +121,20 @@ function _time_derivative(T::Type{<:Real}, f, t, x)
ForwardDiff.derivative(partial, t)
end

function _time_derivative(T::Type{<:VectorValue}, f, t, x)
function _time_derivative(T::Type{<:MultiValue}, f, t, x)
partial(t) = get_array(f(t)(x))
VectorValue(ForwardDiff.derivative(partial, t))
T(ForwardDiff.derivative(partial, t))
end

function _time_derivative(T::Type{<:TensorValue}, f, t, x)
partial(t) = get_array(f(t)(x))
TensorValue(ForwardDiff.derivative(partial, t))
end
#function _time_derivative(T::Type{<:VectorValue}, f, t, x)
# partial(t) = get_array(f(t)(x))
# VectorValue(ForwardDiff.derivative(partial, t))
#end
#
#function _time_derivative(T::Type{<:TensorValue}, f, t, x)
# partial(t) = get_array(f(t)(x))
# TensorValue(ForwardDiff.derivative(partial, t))
#end

##########################################
# Specialisation for `TimeSpaceFunction` #
Expand Down
36 changes: 36 additions & 0 deletions test/ODEsTests/TimeDerivativesTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ForwardDiff

using Gridap
using Gridap.ODEs
using Gridap.TensorValues

# First time derivative, scalar-valued
f1(t) = x -> 5 * x[1] * x[2] + x[2]^2 * t^3
Expand Down Expand Up @@ -73,6 +74,41 @@ for (f, ∂tf) in ((f1, ∂tf1),)
@test ∂t(F)(tv)(xv) ∂tf(tv)(xv)
end

# First time derivative, symmetric tensor-valued
f1(t) = x -> SymTensorValue(x[1] * t, x[1] * x[2], x[2] * t^2)
∂tf1(t) = x -> SymTensorValue(x[1], zero(x[1]), 2 * x[2] * t)

for (f, ∂tf) in ((f1, ∂tf1),)
dtf(t) = x -> SymTensorValue(ForwardDiff.derivative(t -> get_array(f(t)(x)), t))

tv = rand(Float64)
xv = Point(rand(Float64, 2)...)
@test ∂t(f)(tv)(xv) ∂tf(tv)(xv)
@test ∂t(f)(tv)(xv) dtf(tv)(xv)

F = TimeSpaceFunction(f)
@test F(tv)(xv) f(tv)(xv)
@test ∂t(F)(tv)(xv) ∂tf(tv)(xv)
end

# First time derivative, symmetric traceless tensor-valued
f1(t) = x -> SymTracelessTensorValue(x[1] * t, x[2] * t^2)
∂tf1(t) = x -> SymTracelessTensorValue(x[1], 2 * x[2] * t)

for (f, ∂tf) in ((f1, ∂tf1),)
dtf(t) = x -> SymTensorValue(ForwardDiff.derivative(t -> get_array(f(t)(x)), t))

tv = rand(Float64)
xv = Point(rand(Float64, 2)...)
@test ∂t(f)(tv)(xv) ∂tf(tv)(xv)
@test ∂t(f)(tv)(xv) dtf(tv)(xv)

F = TimeSpaceFunction(f)
@test F(tv)(xv) f(tv)(xv)
@test ∂t(F)(tv)(xv) ∂tf(tv)(xv)
end


# Spatial derivatives
ft(t) = x -> x[1]^2 * t + x[2]
f = TimeSpaceFunction(ft)
Expand Down

0 comments on commit 9f6eb95

Please sign in to comment.