Skip to content

Commit

Permalink
Merge pull request #488 from JuliaHealth/fix-spinspan
Browse files Browse the repository at this point in the history
Fix bugs related with `SpinRange` and flow
  • Loading branch information
pvillacorta authored Sep 27, 2024
2 parents 356c48a + bfbf17c commit 48a5bbe
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 33 deletions.
4 changes: 2 additions & 2 deletions KomaMRIBase/src/motion/motionlist/Motion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,11 +164,11 @@ Base.:(≈)(m1::Motion, m2::Motion) = (typeof(m1) == typeof(m2)) & reduce(&, [g
""" Motion sub-group """
function Base.getindex(m::Motion, p)
idx, spin_range = m.spins[p]
return Motion(m.action[idx], m.time, spin_range)
return idx !== nothing ? Motion(m.action[idx], m.time, spin_range) : nothing
end
function Base.view(m::Motion, p)
idx, spin_range = @view(m.spins[p])
return Motion(@view(m.action[idx]), m.time, spin_range)
return idx !== nothing ? Motion(@view(m.action[idx]), m.time, spin_range) : nothing
end

# Auxiliary functions
Expand Down
4 changes: 2 additions & 2 deletions KomaMRIBase/src/motion/motionlist/MotionList.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ MotionList(motions...) = length([motions]) > 0 ? MotionList([motions...]) : @err
function Base.getindex(mv::MotionList{T}, p) where {T<:Real}
motion_array_aux = Motion{T}[]
for m in mv.motions
push!(motion_array_aux, m[p])
m[p] !== nothing ? push!(motion_array_aux, m[p]) : nothing
end
return length(motion_array_aux) > 0 ? MotionList(motion_array_aux) : NoMotion{T}()
end
function Base.view(mv::MotionList{T}, p) where {T<:Real}
motion_array_aux = Motion{T}[]
for m in mv.motions
push!(motion_array_aux, @view(m[p]))
@view(m[p]) !== nothing ? push!(motion_array_aux, @view(m[p])) : nothing
end
return length(motion_array_aux) > 0 ? MotionList(motion_array_aux) : NoMotion{T}()
end
Expand Down
10 changes: 5 additions & 5 deletions KomaMRIBase/src/motion/motionlist/SpinSpan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ end
# Functions
function Base.getindex(spins::SpinRange, p)
idx = intersect_idx(spins.range, p)
return intersect_idx(p, spins.range), SpinRange(idx)
end
function Base.view(spins::SpinRange, p)
idx = intersect_idx(spins.range, p)
return intersect_idx(p, spins.range), SpinRange(idx)
l = length(idx)
intersect = l >= 1 ? intersect_idx(p, spins.range) : nothing
spin_range = l >= 2 ? SpinRange(idx) : (l == 1 ? SpinRange(idx[1]:idx[1]) : nothing)
return intersect, spin_range
end
Base.view(spins::SpinRange, p) = spins[p]
Base.:(==)(sr1::SpinRange, sr2::SpinRange) = sr1.range == sr2.range
Base.length(sr::SpinRange) = length(sr.range)
get_indexing_range(spins::SpinRange) = spins.range
Expand Down
5 changes: 2 additions & 3 deletions KomaMRIBase/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -555,14 +555,14 @@ end

simplemotion = MotionList(
Translate(0.05, 0.05, 0.0, Periodic(period=0.5, asymmetry=0.5)),
Rotate(0.0, 0.0, 90.0, TimeRange(t_start=0.05, t_end=0.5))
Rotate(0.0, 0.0, 90.0, TimeRange(t_start=0.05, t_end=0.5), SpinRange(1:3))
)

Ns = length(obj1)
Nt = 3
t_start = 0.0
t_end = 1.0
arbitrarymotion = MotionList(Path(0.01 .* rand(Ns, Nt), 0.01 .* rand(Ns, Nt), 0.01 .* rand(Ns, Nt), TimeRange(t_start, t_end)))
arbitrarymotion = MotionList(Path(0.01 .* rand(Ns, Nt), 0.01 .* rand(Ns, Nt), 0.01 .* rand(Ns, Nt), TimeRange(t_start, t_end), SpinRange(2:2:4)))

# Test phantom subset
obs1 = Phantom(
Expand Down Expand Up @@ -602,7 +602,6 @@ end
obs1.motion = arbitrarymotion
obs2.motion = arbitrarymotion[rng]
@test obs1[rng] == obs2
# @test @view(obs1[rng]) == obs2

# Test addition of phantoms
oba = Phantom(
Expand Down
9 changes: 9 additions & 0 deletions KomaMRICore/src/simulation/Flow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ function outflow_spin_reset!(
# Get spin state range affected by the spin span
idx = KomaMRIBase.get_indexing_range(spin_span)
spin_state_matrix = @view(spin_state_matrix[idx, :])
replace_by = replace_view(replace_by, idx)
# Obtain mask
mask = get_mask(action.spin_reset, ts)
# Modify spin state: reset and replace by initial value
Expand All @@ -53,6 +54,7 @@ function outflow_spin_reset!(
# Get spin state range affected by the spin span
idx = KomaMRIBase.get_indexing_range(spin_span)
M = @view(M[idx])
replace_by = replace_view(replace_by, idx)
# Obtain mask
mask = get_mask(action.spin_reset, ts)
mask = @view(mask[:, end])
Expand All @@ -72,6 +74,13 @@ function init_time(t, seq_t, add_t0)
return t
end

function replace_view(replace_by::AbstractArray, idx)
return @view(replace_by[idx])
end
function replace_view(replace_by, idx)
return replace_by
end

function get_mask(spin_reset, t::Real)
itp = KomaMRIBase.interpolate(spin_reset, KomaMRIBase.Gridded(KomaMRIBase.Constant{KomaMRIBase.Previous}()), Val(size(spin_reset, 1)), t)
return KomaMRIBase.resample(itp, t)
Expand Down
26 changes: 5 additions & 21 deletions KomaMRIPlots/src/ui/DisplayFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1062,16 +1062,8 @@ function plot_phantom_map(
end

function decimate_uniform_phantom(obj, num_points::Int)
dimx, dimy, dimz = KomaMRIBase.get_dims(obj)
ss = Int(ceil((length(obj) / num_points)^(1 / sum(KomaMRIBase.get_dims(obj)))))
ssx = dimx ? ss : 1
ssy = dimy ? ss : 1
ssz = dimz ? ss : 1
ix = sortperm(obj.x)[1:ssx:end]
iy = sortperm(obj.y)[1:ssy:end]
iz = sortperm(obj.z)[1:ssz:end]
idx = intersect(ix, iy, iz)
return obj[idx]
ss = Int(ceil(length(obj) / num_points))
return obj[1:ss:end]
end

if length(obj) > max_spins
Expand Down Expand Up @@ -1360,16 +1352,8 @@ function plot_phantom_map(
kwargs...,
)
function decimate_uniform_phantom(obj, num_points::Int)
dimx, dimy, dimz = KomaMRIBase.get_dims(obj)
ss = Int(ceil((length(obj) / num_points)^(1 / sum(KomaMRIBase.get_dims(obj)))))
ssx = dimx ? ss : 1
ssy = dimy ? ss : 1
ssz = dimz ? ss : 1
ix = sortperm(obj.x)[1:ssx:end]
iy = sortperm(obj.y)[1:ssy:end]
iz = sortperm(obj.z)[1:ssz:end]
idx = intersect(ix, iy, iz)
return obj[idx]
ss = Int(ceil(length(obj) / num_points))
return obj[1:ss:end]
end

if length(obj) > max_spins
Expand Down Expand Up @@ -1448,7 +1432,7 @@ function plot_phantom_map(
l.width = width
end
if view_2d
h = scatter(
h = scattergl(
x=obj.x*1e2,
y=obj.y*1e2,
mode="markers",
Expand Down
Binary file removed examples/2.phantoms/artery.phantom
Binary file not shown.

1 comment on commit 48a5bbe

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

KomaMRI Benchmarks

Benchmark suite Current: 48a5bbe Previous: 356c48a Ratio
MRI Lab/Bloch/CPU/2 thread(s) 230260983 ns 231918438 ns 0.99
MRI Lab/Bloch/CPU/4 thread(s) 130246287 ns 171093596 ns 0.76
MRI Lab/Bloch/CPU/8 thread(s) 141633208 ns 87065540 ns 1.63
MRI Lab/Bloch/CPU/1 thread(s) 400677945.5 ns 404792264.5 ns 0.99
MRI Lab/Bloch/GPU/CUDA 55535980 ns 55451684 ns 1.00
MRI Lab/Bloch/GPU/oneAPI 495029548 ns 505814210 ns 0.98
MRI Lab/Bloch/GPU/Metal 541225208 ns 539627125 ns 1.00
MRI Lab/Bloch/GPU/AMDGPU 34691883.5 ns 34748766 ns 1.00
Slice Selection 3D/Bloch/CPU/2 thread(s) 1153307006.5 ns 1007165039.5 ns 1.15
Slice Selection 3D/Bloch/CPU/4 thread(s) 605676119.5 ns 576725538 ns 1.05
Slice Selection 3D/Bloch/CPU/8 thread(s) 387857453 ns 331123481 ns 1.17
Slice Selection 3D/Bloch/CPU/1 thread(s) 2232808553 ns 2261304965.5 ns 0.99
Slice Selection 3D/Bloch/GPU/CUDA 101291917 ns 100941879.5 ns 1.00
Slice Selection 3D/Bloch/GPU/oneAPI 632996422 ns 649245415 ns 0.97
Slice Selection 3D/Bloch/GPU/Metal 552117729 ns 552172687.5 ns 1.00
Slice Selection 3D/Bloch/GPU/AMDGPU 59154320.5 ns 59069032 ns 1.00

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.