Skip to content

Commit

Permalink
fix #207
Browse files Browse the repository at this point in the history
  • Loading branch information
domluna committed Mar 25, 2020
1 parent 9deeb5b commit cd92820
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 10 deletions.
22 changes: 13 additions & 9 deletions src/fst.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ function get_args(cst::CSTParser.EXPR)
cst.typ === CSTParser.Curly ||
cst.typ === CSTParser.Call
return get_args(cst.args[2:end])
elseif cst.typ === CSTParser.WhereOpCall
# get the arguments in B of `A where B`
return get_args(cst.args[3:end])
elseif cst.typ === CSTParser.Parameters ||
cst.typ === CSTParser.Braces ||
cst.typ === CSTParser.Vcat ||
Expand All @@ -124,22 +127,21 @@ end

function get_args(args::Vector{CSTParser.EXPR})
args0 = CSTParser.EXPR[]
for arg in args
CSTParser.ispunctuation(arg) && continue
if CSTParser.typof(arg) === CSTParser.Parameters
for j = 1:length(arg.args)
parg = arg[j]
for a in args
CSTParser.ispunctuation(a) && continue
if CSTParser.typof(a) === CSTParser.Parameters
for j = 1:length(a.args)
parg = a[j]
CSTParser.ispunctuation(parg) && continue
push!(args0, parg)
end
else
push!(args0, arg)
push!(args0, a)
end
end
args0
end

n_args(x) = length(get_args(x))
@inline n_args(x) = length(get_args(x))

function add_node!(t::FST, n::FST, s::State; join_lines = false, max_padding = -1)
if n.typ === SEMICOLON
Expand Down Expand Up @@ -213,7 +215,9 @@ function add_node!(t::FST, n::FST, s::State; join_lines = false, max_padding = -
add_node!(t, Semicolon(), s)
if length(n.nodes) > 0
multi_arg = n_args(t.ref[]) > 0
multi_arg ? add_node!(t, Placeholder(1), s) : add_node!(t, Whitespace(1), s)
nws = (t.typ === CSTParser.Curly || t.typ === CSTParser.WhereOpCall) && !s.opts.whitespace_typedefs ? 0 : 1
# @info "" t.typ nws s.opts.whitespace_typedefs
multi_arg ? add_node!(t, Placeholder(nws), s) : add_node!(t, Whitespace(nws), s)
end
end

Expand Down
1 change: 0 additions & 1 deletion src/pretty.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,6 @@ function p_whereopcall(ds::DefaultStyle, cst::CSTParser.EXPR, s::State)
# Used to mark where `B` starts.
add_node!(t, Placeholder(0), s)

nest = length(CSTParser.get_where_params(cst)) > 0
args = get_args(cst.args[3:end])
nest = length(args) > 0 && !(length(args) == 1 && unnestable_arg(args[1]))
add_braces =
Expand Down
47 changes: 47 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4697,4 +4697,51 @@ some_function(
@test fmt(str_, m = 92) == str
end

@testset "issue 207" begin
str_ = """
@traitfn function predict_ar(m::TGP, p::Int = 3, n::Int = 1; y_past = get_y(m)) where {T,TGP<:AbstractGP{T};IsMultiOutput{TGP}}
end"""

str = """
@traitfn function predict_ar(
m::TGP,
p::Int = 3,
n::Int = 1;
y_past = get_y(m),
) where {T,TGP<:AbstractGP{T};IsMultiOutput{TGP}} end"""
@test fmt(str_, m = 92) == str

str = """
@traitfn function predict_ar(
m::TGP,
p::Int = 3,
n::Int = 1;
y_past = get_y(m),
) where {T, TGP <: AbstractGP{T}; IsMultiOutput{TGP}} end"""
@test fmt(str_, m = 92, whitespace_typedefs = true) == str

str_ = """
@traitfn function predict_ar(m::TGP, p::Int = 3, n::Int = 1; y_past = get_y(m)) where C <: Union{T,TGP<:AbstractGP{T};IsMultiOutput{TGP}}
end"""

str = """
@traitfn function predict_ar(
m::TGP,
p::Int = 3,
n::Int = 1;
y_past = get_y(m),
) where {C<:Union{T,TGP<:AbstractGP{T};IsMultiOutput{TGP}}} end"""
@test fmt(str_, m = 92) == str

str = """
@traitfn function predict_ar(
m::TGP,
p::Int = 3,
n::Int = 1;
y_past = get_y(m),
) where {C <: Union{T, TGP <: AbstractGP{T}; IsMultiOutput{TGP}}} end"""
@test fmt(str_, m = 92, whitespace_typedefs = true) == str

end

end

0 comments on commit cd92820

Please sign in to comment.