Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor tactics to track hypothesis provenance #557

Merged
merged 25 commits into from
Oct 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 32 additions & 20 deletions plugins/tactics/src/Ide/Plugin/Tactic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ import Control.Monad.Trans
import Control.Monad.Trans.Maybe
import Data.Aeson
import Data.Coerce
import Data.Functor ((<&>))
import Data.Generics.Aliases (mkQ)
import Data.Generics.Schemes (everything)
import Data.List
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe
import Data.Monoid
Expand Down Expand Up @@ -214,7 +216,7 @@ filterBindingType
filterBindingType p tp dflags plId uri range jdg =
let hy = jHypothesis jdg
g = jGoal jdg
in fmap join $ for (M.toList hy) $ \(occ, CType ty) ->
in fmap join $ for (M.toList hy) $ \(occ, hi_type -> CType ty) ->
case p (unCType g) ty of
True -> tp occ ty dflags plId uri range jdg
False -> pure []
Expand Down Expand Up @@ -264,23 +266,28 @@ judgementForHole state nfp range = do
(mapMaybe (sequenceA . (occName *** coerce))
$ getDefiningBindings binds rss)
tcg
hyps = hypothesisFromBindings rss binds
ambient = M.fromList $ contextMethodHypothesis ctx
top_provs = getRhsPosVals rss tcs
local_hy = spliceProvenance top_provs
$ hypothesisFromBindings rss binds
cls_hy = contextMethodHypothesis ctx
pure ( resulting_range
, mkFirstJudgement
hyps
ambient
(local_hy <> cls_hy)
(isRhsHole rss tcs)
(maybe
mempty
(uncurry M.singleton . fmap pure)
$ getRhsPosVals rss tcs)
goal
, ctx
, dflags
)


spliceProvenance
:: Map OccName Provenance
-> Map OccName (HyInfo a)
-> Map OccName (HyInfo a)
spliceProvenance provs =
M.mapWithKey $ \name hi ->
overProvenance (maybe id const $ M.lookup name provs) hi


tacticCmd :: (OccName -> TacticsM ()) -> CommandFunction TacticParams
tacticCmd tac lf state (TacticParams uri range var_name)
Expand Down Expand Up @@ -334,17 +341,22 @@ isRhsHole rss tcs = everything (||) (mkQ False $ \case

------------------------------------------------------------------------------
-- | Compute top-level position vals of a function
getRhsPosVals :: RealSrcSpan -> TypecheckedSource -> Maybe (OccName, [OccName])
getRhsPosVals rss tcs = getFirst $ everything (<>) (mkQ mempty $ \case
TopLevelRHS name ps
(L (RealSrcSpan span) -- body with no guards and a single defn
(HsVar _ (L _ hole)))
| containsSpan rss span -- which contains our span
, isHole $ occName hole -- and the span is a hole
-> First $ do
patnames <- traverse getPatName ps
pure (occName name, patnames)
_ -> mempty
getRhsPosVals :: RealSrcSpan -> TypecheckedSource -> Map OccName Provenance
getRhsPosVals rss tcs
= M.fromList
$ join
$ maybeToList
$ getFirst
$ everything (<>) (mkQ mempty $ \case
TopLevelRHS name ps
(L (RealSrcSpan span) -- body with no guards and a single defn
(HsVar _ (L _ hole)))
| containsSpan rss span -- which contains our span
, isHole $ occName hole -- and the span is a hole
-> First $ do
patnames <- traverse getPatName ps
pure $ zip patnames $ [0..] <&> TopLevelArgPrv name
_ -> mempty
) tcs


Expand Down
2 changes: 1 addition & 1 deletion plugins/tactics/src/Ide/Plugin/Tactic/Auto.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ auto = do
commit knownStrategies
. tracing "auto"
. localTactic (auto' 4)
. disallowing
. disallowing RecursiveCall
$ fmap fst current

29 changes: 12 additions & 17 deletions plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ useOccName jdg name =

------------------------------------------------------------------------------
-- | Doing recursion incurs a small penalty in the score.
penalizeRecursion :: MonadState TacticState m => m ()
penalizeRecursion = modify $ field @"ts_recursion_penality" +~ 1
countRecursiveCall :: TacticState -> TacticState
countRecursiveCall = field @"ts_recursion_count" +~ 1


------------------------------------------------------------------------------
Expand All @@ -57,14 +57,14 @@ addUnusedTopVals vals = modify $ field @"ts_unused_top_vals" <>~ vals
destructMatches
:: (DataCon -> Judgement -> Rule)
-- ^ How to construct each match
-> ([(OccName, CType)] -> Judgement -> Judgement)
-- ^ How to derive each match judgement
-> Maybe OccName
-- ^ Scrutinee
-> CType
-- ^ Type being destructed
-> Judgement
-> RuleM (Trace, [RawMatch])
destructMatches f f2 t jdg = do
let hy = jHypothesis jdg
destructMatches f scrut t jdg = do
let hy = jEntireHypothesis jdg
g = jGoal jdg
case splitTyConApp_maybe $ unCType t of
Nothing -> throwError $ GoalMismatch "destruct" g
Expand All @@ -76,11 +76,7 @@ destructMatches f f2 t jdg = do
let args = dataConInstOrigArgTys' dc apps
names <- mkManyGoodNames hy args
let hy' = zip names $ coerce args
dcon_name = nameOccName $ dataConName dc

let j = f2 hy'
$ withPositionMapping dcon_name names
$ introducingPat hy'
j = introducingPat scrut dc hy'
$ withNewGoal g jdg
(tr, sg) <- f dc j
modify $ withIntroducedVals $ mappend $ S.fromList names
Expand Down Expand Up @@ -142,14 +138,14 @@ destruct' f term jdg = do
let hy = jHypothesis jdg
case find ((== term) . fst) $ toList hy of
Nothing -> throwError $ UndefinedHypothesis term
Just (_, t) -> do
Just (_, hi_type -> t) -> do
useOccName jdg term
(tr, ms)
<- destructMatches
f
(\cs -> setParents term (fmap fst cs) . destructing term)
(Just term)
t
jdg
$ disallowing AlreadyDestructed [term] jdg
pure ( rose ("destruct " <> show term) $ pure tr
, noLoc $ case' (var' term) ms
)
Expand All @@ -165,7 +161,7 @@ destructLambdaCase' f jdg = do
case splitFunTy_maybe (unCType g) of
Just (arg, _) | isAlgType arg ->
fmap (fmap noLoc $ lambdaCase) <$>
destructMatches f (const id) (CType arg) jdg
destructMatches f Nothing (CType arg) jdg
_ -> throwError $ GoalMismatch "destructLambdaCase'" g


Expand All @@ -178,12 +174,11 @@ buildDataCon
-> RuleM (Trace, LHsExpr GhcPs)
buildDataCon jdg dc apps = do
let args = dataConInstOrigArgTys' dc apps
dcon_name = nameOccName $ dataConName dc
(tr, sgs)
<- fmap unzipTrace
$ traverse ( \(arg, n) ->
newSubgoal
. filterSameTypeFromOtherPositions dcon_name n
. filterSameTypeFromOtherPositions dc n
. blacklistingDestruct
. flip withNewGoal jdg
$ CType arg
Expand Down
9 changes: 6 additions & 3 deletions plugins/tactics/src/Ide/Plugin/Tactic/Context.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import Bag
import Control.Arrow
import Control.Monad.Reader
import Data.List
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (mapMaybe)
import Data.Set (Set)
import qualified Data.Set as S
Expand All @@ -33,9 +35,10 @@ mkContext locals tcg = Context

------------------------------------------------------------------------------
-- | Find all of the class methods that exist from the givens in the context.
contextMethodHypothesis :: Context -> [(OccName, CType)]
contextMethodHypothesis :: Context -> Map OccName (HyInfo CType)
contextMethodHypothesis ctx
= excludeForbiddenMethods
= M.fromList
. excludeForbiddenMethods
. join
. concatMap
( mapMaybe methodHypothesis
Expand All @@ -51,7 +54,7 @@ contextMethodHypothesis ctx
-- | Many operations are defined in typeclasses for performance reasons, rather
-- than being a true part of the class. This function filters out those, in
-- order to keep our hypothesis space small.
excludeForbiddenMethods :: [(OccName, CType)] -> [(OccName, CType)]
excludeForbiddenMethods :: [(OccName, a)] -> [(OccName, a)]
excludeForbiddenMethods = filter (not . flip S.member forbiddenMethods . fst)
where
forbiddenMethods :: Set OccName
Expand Down
Loading