Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonchinn178 committed Jun 6, 2024
1 parent f244b99 commit 0396deb
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,12 @@ a = do
[ c
| c <- d
]

trans =
[ x
| x <- xs,
then
reverse,
then
reverse
]
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,10 @@ barbaz x y z w = [
a = do
[ c
| c <- d ]

trans =
[ x
| x <- xs
, then reverse
, then reverse
]
109 changes: 73 additions & 36 deletions src/Ormolu/Printer/Meat/Declaration/Value.hs
Original file line number Diff line number Diff line change
Expand Up @@ -450,14 +450,14 @@ p_stmt' placer render = \case
space
sitcc $ p_hsLocalBinds binds
ParStmt {} ->
-- 'ParStmt' should always be eliminated in 'flattenStmts' already, such
-- 'ParStmt' should always be eliminated in 'gatherStmts' already, such
-- that it never occurs in 'p_stmt''. Consequently, handling it here
-- would be redundant.
notImplemented "ParStmt"
TransStmt {..} ->
-- 'TransStmt' only needs to account for render printing itself, since
-- pretty printing of relevant statements (e.g., in 'trS_stmts') is
-- handled through 'flattenStmts'.
-- handled through 'gatherStmts'.
case (trS_form, trS_by) of
(ThenForm, Nothing) -> do
txt "then"
Expand Down Expand Up @@ -756,7 +756,7 @@ p_hsExpr' isApp s = \case
(located' (sitcc . p_stmt))
stmts = init xs
yield = last xs
lists = flattenStmts stmts
lists = gatherStmts stmts
located yield p_stmt
breakpoint
txt "|"
Expand Down Expand Up @@ -867,33 +867,79 @@ p_hsExpr' isApp s = \case
space
located hswc_body p_hsType

-- | Flatten the set of statements in a list comprehension.
-- | Gather the set of statements in a list comprehension.
--
-- Concretely, expands all ParStmt constructors and extract out prefix
-- statements in TransStmt.
flattenStmts :: [ExprLStmt GhcPs] -> [[ExprLStmt GhcPs]]
flattenStmts = collect . map gatherStmt
-- For example, this code:
--
-- @
-- [ a + b + c + d
-- | a <- as, let b = a + a
-- | c <- cs
-- | d <- ds, then sort by f
-- ]
-- @
--
-- is parsed as roughly:
--
-- @
-- [ ParStmt
-- [ ParStmtBlock
-- [ BindStmt [| a <- as |]
-- , LetStmt [| let b = a + a |]
-- ]
-- , ParStmtBlock
-- [ BindStmt [| c <- cs |]
-- ]
-- , ParStmtBlock
-- [ TransStmt
-- [ BindStmt [| d <- ds |]
-- ]
-- [| then sort by f |]
-- ]
-- ]
-- , LastStmt [| a + b + c + d |]
-- ]
-- @
--
-- The final expression is parsed out in p_body, and the rest is passed
-- to this function. This function takes the above tree as input and
-- normalizes it into:
--
-- @
-- [ [ BindStmt [| a <- as |]
-- , LetStmt [| let b = a + a |]
-- ]
-- , [ BindStmt [| c <- cs |]
-- ]
-- , [ BindStmt [| d <- ds |]
-- , TransStmt [] [| then sortWith by f |]
-- ]
-- ]
-- @
--
-- Notes:
-- * The number of elements in the outer list is the number of pipes in
-- the comprehension; i.e. 1 unless -XParallelListComp is enabled
gatherStmts :: [ExprLStmt GhcPs] -> [[ExprLStmt GhcPs]]
gatherStmts = \case
-- When -XParallelListComp is enabled + list comprehension has
-- multiple pipes, input will have exactly 1 element, and it
-- will be ParStmt.
[L _ (ParStmt _ blocks _ _)] ->
[ concatMap collectNonParStmts stmts
| ParStmtBlock _ stmts _ _ <- blocks
]
-- Otherwise, list will not contain any ParStmt
stmts ->
[ concatMap collectNonParStmts stmts
]
where
-- Note: not exactly sure what this does...
--
-- From experimenting, it does this:
--
-- >>> collect [[[stmt1]], [[stmt2, stmt3]], [[stmt4], [stmt5]], [[stmt6]]]
-- [[stmt1, stmt2, stmt3, stmt4, stmt6], [stmt5]]
--
-- But it's not clear what the purpose is.
collect :: [[[ExprLStmt GhcPs]]] -> [[ExprLStmt GhcPs]]
collect = foldr (zipPrefixWith (<>)) []
collectNonParStmts = \case
L _ ParStmt {} -> unexpected "ParStmt"
stmt@(L _ TransStmt {trS_stmts}) -> concatMap collectNonParStmts trS_stmts ++ [stmt]
stmt -> [stmt]

gatherStmt :: ExprLStmt GhcPs -> [[ExprLStmt GhcPs]]
gatherStmt (L _ (ParStmt _ block _ _)) =
concatMap gatherStmtBlock block
gatherStmt (L s stmt@TransStmt {..}) =
collect $ map gatherStmt trS_stmts <> [[[L s stmt]]]
gatherStmt stmt = [[stmt]]

gatherStmtBlock :: ParStmtBlock GhcPs GhcPs -> [[ExprLStmt GhcPs]]
gatherStmtBlock (ParStmtBlock _ stmts _ _) = flattenStmts stmts
unexpected label = error $ "Unexpected " <> label <> "! Please file a bug."

p_patSynBind :: PatSynBind GhcPs GhcPs -> R ()
p_patSynBind PSB {..} = do
Expand Down Expand Up @@ -1298,15 +1344,6 @@ layoutToBraces = \case
SingleLine -> useBraces
MultiLine -> id

-- | Same as 'zipWith', except only works on lists of the same type, and
-- leaves extra elements at the end of the list.
zipPrefixWith :: (a -> a -> a) -> [a] -> [a] -> [a]
zipPrefixWith f = go
where
go [] ys = ys
go xs [] = xs
go (x : xs) (y : ys) = f x y : go xs ys

getGRHSSpan :: GRHS GhcPs (LocatedA body) -> SrcSpan
getGRHSSpan (GRHS _ guards body) =
combineSrcSpans' $ getLocA body :| map getLocA guards
Expand Down

0 comments on commit 0396deb

Please sign in to comment.