Skip to content

Commit

Permalink
ENH: recession_bands now supports subplots
Browse files Browse the repository at this point in the history
  • Loading branch information
sglyon committed May 1, 2017
1 parent 23ac07b commit 5330d71
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/PlotlyJS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ end

# include the rest of the core parts of the package
include("display.jl")
include("util.jl")
include("json.jl")
include("subplots.jl")
include("api.jl")
Expand Down
114 changes: 98 additions & 16 deletions src/recession_bands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,20 @@ writedlm(out_path, [start_dates end_dates])
function _recession_band_shapes(p::Plot; kwargs...)
# data is 2 column matrix, first column is start date, second column
# is end_date
data = map(Date, readdlm(
recession_data = map(Date, readdlm(
joinpath(dirname(dirname(@__FILE__)), "recession_dates.tsv")
))::Matrix{Date}

# need to over the x attribute of each trace and get the most extreme
min_date = Date(500_000)
max_date = Date(-500_000)
dates_changed = false
for t in p.data
# keep track of which traces have dates for x
n_trace = length(p.data)
x_has_dates = fill(false, n_trace)

# need to over the x attribute of each trace and get the most extreme dates
# for each trace
min_date = fill(Date(500_000), n_trace)
max_date = fill(Date(-500_000), n_trace)
dates_changed = fill(false, n_trace)
for (i, t) in enumerate(p.data)
t_x = t[:x]
if !isempty(t_x)
maybe_dates = to_date(t_x)
Expand All @@ -96,23 +101,100 @@ function _recession_band_shapes(p::Plot; kwargs...)
else
dates = get(maybe_dates)
if typeof(dates) <: AbstractArray
min_date = min(min_date, minimum(dates))
max_date = max(max_date, maximum(dates))
dates_changed = true
min_date[i] = min(min_date[i], minimum(dates))
max_date[i] = max(max_date[i], maximum(dates))
dates_changed[i] = true
x_has_dates[i] = true
end
end
end
end
if !dates_changed
if !any(dates_changed)
return Nullable{Vector{Shape}}()
end
ix1 = searchsortedfirst(data[:, 1], min_date)
ix2 = searchsortedlast(data[:, 2], max_date)
Nullable{Vector{Shape}}(
rect(data[ix1:ix2, 1], data[ix1:ix2, 2], 0, 1;
fillcolor="#E6E6E6", opacity=0.4, line_width=0, layer="below",
xref="x", yref="paper", kwargs...)

# map traces into xaxis and yaxis
xmap = trace_map(p, :x)
ymap = trace_map(p, :y)

# WANT: 1 set of shapes for each xaxis with dates
n_xaxis = length(unique(xmap))

band_kwargs = attr(
fillcolor="#E6E6E6", opacity=0.4, line_width=0, layer="below",
yref="paper", kwargs...
)

# if we only have one xaxis, only add one set of shapes
if n_xaxis == 1
ix1 = minimum(searchsortedfirst(recession_data[:, 1], m) for m in min_date)
ix2 = maximum(searchsortedlast(recession_data[:, 2], m) for m in max_date)
if isempty(ix1:ix2)
return Nullable{Vector{Shape}}()
end

out = rect(
recession_data[ix1:ix2, 1],
recession_data[ix1:ix2, 2],
0,
1,
band_kwargs.fields;
xref="x"
)
return Nullable{Vector{Shape}}(out)
end

# otherwise we need to loop over the traces and add one set of shapes
# for each x axis
out = Vector{Shape}()

# first determine the min_date and max_date for each axis, based on that
# info for each trace and the xmap.
min_ax_date = fill(Date(500_000), n_xaxis)
max_ax_date = fill(Date(-500_000), n_xaxis)
for i in 1:n_trace
ix = xmap[i]
min_ax_date[ix] = minimum(min_ax_date[ix], min_date[ix])
max_ax_date[ix] = maximum(max_ax_date[ix], max_date[ix])
end

# now loop through the traces and add one set of shapes per axis that
# appears. We loop over traces so that we can be sure to get the correct
# yaxis domain for each trace
already_added = fill(false, n_xaxis)
for i in 1:n_trace
if x_has_dates[i]
# move on if we have bands on this xax
ix = xmap[i]
already_added[ix] && continue

# get starting and ending date
ix1 = searchsortedfirst(recession_data[:, 1], min_ax_date[ix])
ix2 = searchsortedlast(recession_data[:, 2], max_ax_date[ix])
isempty(ix1:ix2) && continue # range was bogus, move on

# otherwise extract the yaxis details
yax = p.layout[Symbol("yaxis", ymap[i])]

isempty(yax) && continue # don't have this axis somehow... move on

# don't have info on where this axis lives ... move on
!haskey(yax, :domain) && continue
ybounds = yax[:domain]

# we are in business, build the shapes and append them
append!(out,
rect(recession_data[ix1:ix2, 1],
recession_data[ix1:ix2, 2],
ybounds[1],
ybounds[2],
band_kwargs.fields;
xref=string("x", xmap[i]))
)
already_added[xmap[i]] = true
end
end
Nullable(out)
end

function add_recession_bands!(p::Plot; kwargs...)
Expand Down
2 changes: 2 additions & 0 deletions src/traces_layouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,8 @@ function Base.merge!(hf1::HasFields, hf2::HasFields)
hf1
end

Base.haskey(hf::HasFields, k::Symbol) = haskey(hf.fields, k)

Base.merge{T<:HasFields}(hf1::T, hf2::T) =
merge!(deepcopy(hf1), hf2)

Expand Down
20 changes: 20 additions & 0 deletions src/util.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
trace_map(p::Plot, axis::Symbol=:x)
Return an array of `length(p.data)` that maps each element of `p.data` into an
integer for which number axis of kind `axis` that trace belogs to. `axis` can
either be `x` or `y`. If `x` is given, return the integer for which x-axis the
trace belongs to. Similar for `y`.
"""
function trace_map(p::Plot, axis=:x)
out = fill(1, length(p.data))
ax_key = axis == :x ? :xaxis : :yaxis
for (i, t) in enumerate(p.data)
if haskey(t, ax_key)
out[i] = parse(Int, t[ax_key][2:end])
end
end
out
end

trace_map(p::SyncPlot, axis=:x) = trace_map(p.plot, axis)

0 comments on commit 5330d71

Please sign in to comment.