Skip to content

Commit 2533574

Browse files
isovectorTOTBWF
andauthored
Allow hole filling to deal with recursion (#472)
This PR enhances the "attempt to fill hole" code action, allowing it to implement self-recursive functions. The generated code ensures recursion occurs only on structurally-smaller values, and preserves the positional ordering of homomorphically destructed arguments. It's clever enough to implement foldr and nontrivial functor instances. Co-authored-by: TOTBWF <[email protected]>
1 parent cdf50a6 commit 2533574

35 files changed

+785
-244
lines changed

cabal.project

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,6 @@ package ghcide
2525

2626
write-ghc-environment-files: never
2727

28-
index-state: 2020-10-08T12:51:21Z
28+
index-state: 2020-10-16T04:00:00Z
2929

3030
allow-newer: data-tree-print:base

haskell-language-server.cabal

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,13 @@ executable haskell-language-server
9494
Ide.Plugin.Retrie
9595
Ide.Plugin.StylishHaskell
9696
Ide.Plugin.Tactic
97+
Ide.Plugin.Tactic.Auto
9798
Ide.Plugin.Tactic.CodeGen
9899
Ide.Plugin.Tactic.Context
99100
Ide.Plugin.Tactic.Debug
100101
Ide.Plugin.Tactic.GHC
101102
Ide.Plugin.Tactic.Judgements
103+
Ide.Plugin.Tactic.KnownStrategies
102104
Ide.Plugin.Tactic.Machinery
103105
Ide.Plugin.Tactic.Naming
104106
Ide.Plugin.Tactic.Range
@@ -156,9 +158,10 @@ executable haskell-language-server
156158
, transformers
157159
, unordered-containers
158160
, ghc-source-gen
159-
, refinery ^>=0.2
161+
, refinery ^>=0.3
160162
, ghc-exactprint
161163
, fingertree
164+
, generic-lens
162165

163166
if flag(agpl)
164167
build-depends: brittany

plugins/tactics/src/Ide/Plugin/Tactic.hs

Lines changed: 73 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
{-# LANGUAGE DeriveAnyClass #-}
22
{-# LANGUAGE DeriveGeneric #-}
33
{-# LANGUAGE LambdaCase #-}
4+
{-# LANGUAGE NumDecimals #-}
45
{-# LANGUAGE OverloadedStrings #-}
6+
{-# LANGUAGE PatternSynonyms #-}
57
{-# LANGUAGE ScopedTypeVariables #-}
68
{-# LANGUAGE TupleSections #-}
79
{-# LANGUAGE TypeApplications #-}
@@ -20,8 +22,12 @@ import Control.Monad.Trans
2022
import Control.Monad.Trans.Maybe
2123
import Data.Aeson
2224
import Data.Coerce
25+
import Data.Generics.Aliases (mkQ)
26+
import Data.Generics.Schemes (everything)
27+
import Data.List
2328
import qualified Data.Map as M
2429
import Data.Maybe
30+
import Data.Monoid
2531
import qualified Data.Set as S
2632
import qualified Data.Text as T
2733
import Data.Traversable
@@ -38,6 +44,7 @@ import qualified FastString
3844
import GHC.Generics (Generic)
3945
import GHC.LanguageExtensions.Type (Extension (LambdaCase))
4046
import Ide.Plugin (mkLspCommand)
47+
import Ide.Plugin.Tactic.Auto
4148
import Ide.Plugin.Tactic.Context
4249
import Ide.Plugin.Tactic.GHC
4350
import Ide.Plugin.Tactic.Judgements
@@ -50,6 +57,8 @@ import Ide.Types
5057
import Language.Haskell.LSP.Core (clientCapabilities)
5158
import Language.Haskell.LSP.Types
5259
import OccName
60+
import SrcLoc (containsSpan)
61+
import System.Timeout
5362

5463

5564
descriptor :: PluginId -> PluginDescriptor
@@ -250,12 +259,24 @@ judgementForHole state nfp range = do
250259
resulting_range <- liftMaybe $ toCurrentRange amapping $ realSrcSpanToRange rss
251260
(tcmod, _) <- MaybeT $ runIde state $ useWithStale TypeCheck nfp
252261
let tcg = fst $ tm_internals_ $ tmrModule tcmod
262+
tcs = tm_typechecked_source $ tmrModule tcmod
253263
ctx = mkContext
254264
(mapMaybe (sequenceA . (occName *** coerce))
255265
$ getDefiningBindings binds rss)
256266
tcg
257267
hyps = hypothesisFromBindings rss binds
258-
pure (resulting_range, mkFirstJudgement hyps goal, ctx, dflags)
268+
pure ( resulting_range
269+
, mkFirstJudgement
270+
hyps
271+
(isRhsHole rss tcs)
272+
(maybe
273+
mempty
274+
(uncurry M.singleton . fmap pure)
275+
$ getRhsPosVals rss tcs)
276+
goal
277+
, ctx
278+
, dflags
279+
)
259280

260281

261282

@@ -266,20 +287,26 @@ tacticCmd tac lf state (TacticParams uri range var_name)
266287
(range', jdg, ctx, dflags) <- judgementForHole state nfp range
267288
let span = rangeToRealSrcSpan (fromNormalizedFilePath nfp) range'
268289
pm <- MaybeT $ useAnnotatedSource "tacticsCmd" state nfp
269-
case runTactic ctx jdg
270-
$ tac
271-
$ mkVarOcc
272-
$ T.unpack var_name of
273-
Left err ->
274-
pure $ (, Nothing)
275-
$ Left
276-
$ ResponseError InvalidRequest (T.pack $ show err) Nothing
277-
Right res -> do
278-
let g = graft (RealSrcSpan span) res
279-
response = transform dflags (clientCapabilities lf) uri g pm
280-
pure $ case response of
281-
Right res -> (Right Null , Just (WorkspaceApplyEdit, ApplyWorkspaceEditParams res))
282-
Left err -> (Left $ ResponseError InternalError (T.pack err) Nothing, Nothing)
290+
x <- lift $ timeout 2e8 $
291+
case runTactic ctx jdg
292+
$ tac
293+
$ mkVarOcc
294+
$ T.unpack var_name of
295+
Left err ->
296+
pure $ (, Nothing)
297+
$ Left
298+
$ ResponseError InvalidRequest (T.pack $ show err) Nothing
299+
Right (_, ext) -> do
300+
let g = graft (RealSrcSpan span) ext
301+
response = transform dflags (clientCapabilities lf) uri g pm
302+
pure $ case response of
303+
Right res -> (Right Null , Just (WorkspaceApplyEdit, ApplyWorkspaceEditParams res))
304+
Left err -> (Left $ ResponseError InternalError (T.pack err) Nothing, Nothing)
305+
pure $ case x of
306+
Just y -> y
307+
Nothing -> (, Nothing)
308+
$ Left
309+
$ ResponseError InvalidRequest "timed out" Nothing
283310
tacticCmd _ _ _ _ =
284311
pure ( Left $ ResponseError InvalidRequest (T.pack "Bad URI") Nothing
285312
, Nothing
@@ -292,3 +319,34 @@ fromMaybeT def = fmap (fromMaybe def) . runMaybeT
292319
liftMaybe :: Monad m => Maybe a -> MaybeT m a
293320
liftMaybe a = MaybeT $ pure a
294321

322+
323+
------------------------------------------------------------------------------
324+
-- | Is this hole immediately to the right of an equals sign?
325+
isRhsHole :: RealSrcSpan -> TypecheckedSource -> Bool
326+
isRhsHole rss tcs = everything (||) (mkQ False $ \case
327+
TopLevelRHS _ _ (L (RealSrcSpan span) _) -> containsSpan rss span
328+
_ -> False
329+
) tcs
330+
331+
332+
------------------------------------------------------------------------------
333+
-- | Compute top-level position vals of a function
334+
getRhsPosVals :: RealSrcSpan -> TypecheckedSource -> Maybe (OccName, [OccName])
335+
getRhsPosVals rss tcs = getFirst $ everything (<>) (mkQ mempty $ \case
336+
TopLevelRHS name ps
337+
(L (RealSrcSpan span) -- body with no guards and a single defn
338+
(HsVar _ (L _ hole)))
339+
| containsSpan rss span -- which contains our span
340+
, isHole $ occName hole -- and the span is a hole
341+
-> First $ do
342+
patnames <- traverse getPatName ps
343+
pure (occName name, patnames)
344+
_ -> mempty
345+
) tcs
346+
347+
348+
349+
-- TODO(sandy): Make this more robust
350+
isHole :: OccName -> Bool
351+
isHole = isPrefixOf "_" . occNameString
352+
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
module Ide.Plugin.Tactic.Auto where
2+
3+
import Ide.Plugin.Tactic.Context
4+
import Ide.Plugin.Tactic.Judgements
5+
import Ide.Plugin.Tactic.KnownStrategies
6+
import Ide.Plugin.Tactic.Tactics
7+
import Ide.Plugin.Tactic.Types
8+
import Refinery.Tactic
9+
import Ide.Plugin.Tactic.Machinery (tracing)
10+
11+
12+
------------------------------------------------------------------------------
13+
-- | Automatically solve a goal.
14+
auto :: TacticsM ()
15+
auto = do
16+
jdg <- goal
17+
current <- getCurrentDefinitions
18+
traceMX "goal" jdg
19+
traceMX "ctx" current
20+
commit knownStrategies
21+
. tracing "auto"
22+
. localTactic (auto' 4)
23+
. disallowing
24+
$ fmap fst current
25+
Lines changed: 79 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,45 @@
1+
{-# LANGUAGE TupleSections #-}
12
{-# LANGUAGE FlexibleContexts #-}
23
module Ide.Plugin.Tactic.CodeGen where
34

4-
import Control.Monad.Except
5-
import Data.List
6-
import Data.Traversable
7-
import DataCon
8-
import Development.IDE.GHC.Compat
9-
import GHC.Exts
10-
import GHC.SourceGen.Binds
11-
import GHC.SourceGen.Expr
12-
import GHC.SourceGen.Overloaded
13-
import GHC.SourceGen.Pat
14-
import Ide.Plugin.Tactic.Judgements
15-
import Ide.Plugin.Tactic.Machinery
16-
import Ide.Plugin.Tactic.Naming
17-
import Ide.Plugin.Tactic.Types
18-
import Name
19-
import Type hiding (Var)
5+
import Control.Monad.Except
6+
import Control.Monad.State (MonadState)
7+
import Control.Monad.State.Class (modify)
8+
import Data.List
9+
import qualified Data.Map as M
10+
import qualified Data.Set as S
11+
import Data.Traversable
12+
import DataCon
13+
import Development.IDE.GHC.Compat
14+
import GHC.Exts
15+
import GHC.SourceGen.Binds
16+
import GHC.SourceGen.Expr
17+
import GHC.SourceGen.Overloaded
18+
import GHC.SourceGen.Pat
19+
import Ide.Plugin.Tactic.Judgements
20+
import Ide.Plugin.Tactic.Machinery
21+
import Ide.Plugin.Tactic.Naming
22+
import Ide.Plugin.Tactic.Types
23+
import Name
24+
import Type hiding (Var)
25+
26+
27+
useOccName :: MonadState TacticState m => Judgement -> OccName -> m ()
28+
useOccName jdg name =
29+
case M.lookup name $ jHypothesis jdg of
30+
Just{} -> modify $ withUsedVals $ S.insert name
31+
Nothing -> pure ()
2032

2133

2234
destructMatches
2335
:: (DataCon -> Judgement -> Rule)
2436
-- ^ How to construct each match
25-
-> (Judgement -> Judgement)
37+
-> ([(OccName, CType)] -> Judgement -> Judgement)
2638
-- ^ How to derive each match judgement
2739
-> CType
2840
-- ^ Type being destructed
2941
-> Judgement
30-
-> RuleM [RawMatch]
42+
-> RuleM (Trace, [RawMatch])
3143
destructMatches f f2 t jdg = do
3244
let hy = jHypothesis jdg
3345
g = jGoal jdg
@@ -37,18 +49,32 @@ destructMatches f f2 t jdg = do
3749
let dcs = tyConDataCons tc
3850
case dcs of
3951
[] -> throwError $ GoalMismatch "destruct" g
40-
_ -> for dcs $ \dc -> do
52+
_ -> fmap unzipTrace $ for dcs $ \dc -> do
4153
let args = dataConInstOrigArgTys' dc apps
4254
names <- mkManyGoodNames hy args
55+
let hy' = zip names $ coerce args
56+
dcon_name = nameOccName $ dataConName dc
4357

4458
let pat :: Pat GhcPs
45-
pat = conP (fromString $ occNameString $ nameOccName $ dataConName dc)
59+
pat = conP (fromString $ occNameString dcon_name)
4660
$ fmap bvar' names
47-
j = f2
48-
$ introducingPat (zip names $ coerce args)
61+
j = f2 hy'
62+
$ withPositionMapping dcon_name names
63+
$ introducingPat hy'
4964
$ withNewGoal g jdg
50-
sg <- f dc j
51-
pure $ match [pat] $ unLoc sg
65+
(tr, sg) <- f dc j
66+
modify $ withIntroducedVals $ mappend $ S.fromList names
67+
pure ( rose ("match " <> show dc <> " {" <>
68+
intercalate ", " (fmap show names) <> "}")
69+
$ pure tr
70+
, match [pat] $ unLoc sg
71+
)
72+
73+
74+
unzipTrace :: [(Trace, a)] -> (Trace, [a])
75+
unzipTrace l =
76+
let (trs, as) = unzip l
77+
in (rose mempty trs, as)
5278

5379

5480
-- | Essentially same as 'dataConInstOrigArgTys' in GHC,
@@ -66,24 +92,34 @@ dataConInstOrigArgTys' con ty =
6692

6793
destruct' :: (DataCon -> Judgement -> Rule) -> OccName -> Judgement -> Rule
6894
destruct' f term jdg = do
95+
when (isDestructBlacklisted jdg) $ throwError NoApplicableTactic
6996
let hy = jHypothesis jdg
7097
case find ((== term) . fst) $ toList hy of
7198
Nothing -> throwError $ UndefinedHypothesis term
72-
Just (_, t) ->
73-
fmap noLoc $ case' (var' term) <$>
74-
destructMatches f (destructing term) t jdg
99+
Just (_, t) -> do
100+
useOccName jdg term
101+
(tr, ms)
102+
<- destructMatches
103+
f
104+
(\cs -> setParents term (fmap fst cs) . destructing term)
105+
t
106+
jdg
107+
pure ( rose ("destruct " <> show term) $ pure tr
108+
, noLoc $ case' (var' term) ms
109+
)
75110

76111

77112
------------------------------------------------------------------------------
78113
-- | Combinator for performign case splitting, and running sub-rules on the
79114
-- resulting matches.
80115
destructLambdaCase' :: (DataCon -> Judgement -> Rule) -> Judgement -> Rule
81116
destructLambdaCase' f jdg = do
117+
when (isDestructBlacklisted jdg) $ throwError NoApplicableTactic
82118
let g = jGoal jdg
83119
case splitFunTy_maybe (unCType g) of
84120
Just (arg, _) | isAlgType arg ->
85-
fmap noLoc $ lambdaCase <$>
86-
destructMatches f id (CType arg) jdg
121+
fmap (fmap noLoc $ lambdaCase) <$>
122+
destructMatches f (const id) (CType arg) jdg
87123
_ -> throwError $ GoalMismatch "destructLambdaCase'" g
88124

89125

@@ -93,11 +129,21 @@ buildDataCon
93129
:: Judgement
94130
-> DataCon -- ^ The data con to build
95131
-> [Type] -- ^ Type arguments for the data con
96-
-> RuleM (LHsExpr GhcPs)
132+
-> RuleM (Trace, LHsExpr GhcPs)
97133
buildDataCon jdg dc apps = do
98134
let args = dataConInstOrigArgTys' dc apps
99-
sgs <- traverse (newSubgoal . flip withNewGoal jdg . CType) args
135+
dcon_name = nameOccName $ dataConName dc
136+
(tr, sgs)
137+
<- fmap unzipTrace
138+
$ traverse ( \(arg, n) ->
139+
newSubgoal
140+
. filterSameTypeFromOtherPositions dcon_name n
141+
. blacklistingDestruct
142+
. flip withNewGoal jdg
143+
$ CType arg
144+
) $ zip args [0..]
100145
pure
146+
. (rose (show dc) $ pure tr,)
101147
. noLoc
102148
. foldl' (@@)
103149
(HsVar noExtField $ noLoc $ Unqual $ nameOccName $ dataConName dc)
@@ -109,7 +155,9 @@ buildDataCon jdg dc apps = do
109155
var' :: Var a => OccName -> a
110156
var' = var . fromString . occNameString
111157

158+
112159
------------------------------------------------------------------------------
113160
-- | Like 'bvar', but works over standard GHC 'OccName's.
114161
bvar' :: BVar a => OccName -> a
115162
bvar' = bvar . fromString . occNameString
163+

plugins/tactics/src/Ide/Plugin/Tactic/Context.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ getFunBindId (AbsBinds _ _ _ abes _ _ _)
3434
getFunBindId _ = []
3535

3636

37-
getCurrentDefinitions :: MonadReader Context m => m [OccName]
38-
getCurrentDefinitions = asks $ fmap fst . ctxDefiningFuncs
37+
getCurrentDefinitions :: MonadReader Context m => m [(OccName, CType)]
38+
getCurrentDefinitions = asks $ ctxDefiningFuncs
3939

4040
getModuleHypothesis :: MonadReader Context m => m [(OccName, CType)]
4141
getModuleHypothesis = asks ctxModuleFuncs

0 commit comments

Comments
 (0)