Skip to content

Commit 43118de

Browse files
pc: Implement code action 'convert to named lambda parameters' (#22799)
Associated Metals PR: scalameta/metals#6669 Add a code action that converts a wildcard lambda to a lambda with named parameters. It supports Scala 3 only at the moment. e.g. ![converttonamedlambdaparameters](https://github.com/user-attachments/assets/93561630-b626-4cff-8248-806e4c32744a) ![converttonamedlambdaparameters1](https://github.com/user-attachments/assets/292ad14f-be8b-4535-b4e4-45819f14da9e)
1 parent 15f40a9 commit 43118de

File tree

4 files changed

+422
-21
lines changed

4 files changed

+422
-21
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package dotty.tools.pc
2+
3+
import java.nio.file.Paths
4+
import java.util as ju
5+
6+
import scala.jdk.CollectionConverters.*
7+
import scala.meta.pc.OffsetParams
8+
9+
import dotty.tools.dotc.ast.tpd
10+
import dotty.tools.dotc.core.Contexts.Context
11+
import dotty.tools.dotc.core.Flags
12+
import dotty.tools.dotc.interactive.Interactive
13+
import dotty.tools.dotc.interactive.InteractiveDriver
14+
import dotty.tools.dotc.util.SourceFile
15+
import dotty.tools.dotc.util.SourcePosition
16+
import org.eclipse.lsp4j as l
17+
import dotty.tools.pc.utils.InteractiveEnrichments.*
18+
import dotty.tools.pc.utils.TermNameInference.*
19+
20+
/**
21+
* Facilitates the code action that converts a wildcard lambda to a lambda with named parameters
22+
* e.g.
23+
*
24+
* List(1, 2).map(<<_>> + 1) => List(1, 2).map(i => i + 1)
25+
*/
26+
final class PcConvertToNamedLambdaParameters(
27+
driver: InteractiveDriver,
28+
params: OffsetParams
29+
):
30+
import PcConvertToNamedLambdaParameters._
31+
32+
def convertToNamedLambdaParameters: ju.List[l.TextEdit] = {
33+
val uri = params.uri
34+
val filePath = Paths.get(uri)
35+
driver.run(
36+
uri,
37+
SourceFile.virtual(filePath.toString, params.text),
38+
)
39+
given newctx: Context = driver.localContext(params)
40+
val pos = driver.sourcePosition(params)
41+
val trees = driver.openedTrees(uri)
42+
val treeList = Interactive.pathTo(trees, pos)
43+
// Extractor for a lambda function (needs context, so has to be defined here)
44+
val LambdaExtractor = Lambda(using newctx)
45+
// select the most inner wildcard lambda
46+
val firstLambda = treeList.collectFirst {
47+
case LambdaExtractor(params, rhsFn) if params.forall(isWildcardParam) =>
48+
params -> rhsFn
49+
}
50+
51+
firstLambda match {
52+
case Some((params, lambda)) =>
53+
// avoid names that are either defined or referenced in the lambda
54+
val namesToAvoid = allDefAndRefNamesInTree(lambda)
55+
// compute parameter names based on the type of the parameter
56+
val computedParamNames: List[String] =
57+
params.foldLeft(List.empty[String]) { (acc, param) =>
58+
val name = singleLetterNameStream(param.tpe.typeSymbol.name.toString())
59+
.find(n => !namesToAvoid.contains(n) && !acc.contains(n))
60+
acc ++ name.toList
61+
}
62+
if computedParamNames.size == params.size then
63+
val paramReferenceEdits = params.zip(computedParamNames).flatMap { (param, paramName) =>
64+
val paramReferencePosition = findParamReferencePosition(param, lambda)
65+
paramReferencePosition.toList.map { pos =>
66+
val position = pos.toLsp
67+
val range = new l.Range(
68+
position.getStart(),
69+
position.getEnd()
70+
)
71+
new l.TextEdit(range, paramName)
72+
}
73+
}
74+
val paramNamesStr = computedParamNames.mkString(", ")
75+
val paramDefsStr =
76+
if params.size == 1 then paramNamesStr
77+
else s"($paramNamesStr)"
78+
val defRange = new l.Range(
79+
lambda.sourcePos.toLsp.getStart(),
80+
lambda.sourcePos.toLsp.getStart()
81+
)
82+
val paramDefinitionEdits = List(
83+
new l.TextEdit(defRange, s"$paramDefsStr => ")
84+
)
85+
(paramDefinitionEdits ++ paramReferenceEdits).asJava
86+
else
87+
List.empty.asJava
88+
case _ =>
89+
List.empty.asJava
90+
}
91+
}
92+
93+
end PcConvertToNamedLambdaParameters
94+
95+
object PcConvertToNamedLambdaParameters:
96+
val codeActionId = "ConvertToNamedLambdaParameters"
97+
98+
class Lambda(using Context):
99+
def unapply(tree: tpd.Block): Option[(List[tpd.ValDef], tpd.Tree)] = tree match {
100+
case tpd.Block((ddef @ tpd.DefDef(_, tpd.ValDefs(params) :: Nil, _, body: tpd.Tree)) :: Nil, tpd.Closure(_, meth, _))
101+
if ddef.symbol == meth.symbol =>
102+
params match {
103+
case List(param) =>
104+
// lambdas with multiple wildcard parameters are represented as a single parameter function and a block with wildcard valdefs
105+
Some(multipleUnderscoresFromBody(param, body))
106+
case _ => Some(params -> body)
107+
}
108+
case _ => None
109+
}
110+
end Lambda
111+
112+
private def multipleUnderscoresFromBody(param: tpd.ValDef, body: tpd.Tree)(using Context): (List[tpd.ValDef], tpd.Tree) = body match {
113+
case tpd.Block(defs, expr) if param.symbol.is(Flags.Synthetic) =>
114+
val wildcardParamDefs = defs.collect {
115+
case valdef: tpd.ValDef if isWildcardParam(valdef) => valdef
116+
}
117+
if wildcardParamDefs.size == defs.size then wildcardParamDefs -> expr
118+
else List(param) -> body
119+
case _ => List(param) -> body
120+
}
121+
122+
def isWildcardParam(param: tpd.ValDef)(using Context): Boolean =
123+
param.name.toString.startsWith("_$") && param.symbol.is(Flags.Synthetic)
124+
125+
def findParamReferencePosition(param: tpd.ValDef, lambda: tpd.Tree)(using Context): Option[SourcePosition] =
126+
var pos: Option[SourcePosition] = None
127+
object FindParamReference extends tpd.TreeTraverser:
128+
override def traverse(tree: tpd.Tree)(using Context): Unit =
129+
tree match
130+
case ident @ tpd.Ident(_) if ident.symbol == param.symbol =>
131+
pos = Some(tree.sourcePos)
132+
case _ =>
133+
traverseChildren(tree)
134+
FindParamReference.traverse(lambda)
135+
pos
136+
end findParamReferencePosition
137+
138+
def allDefAndRefNamesInTree(tree: tpd.Tree)(using Context): List[String] =
139+
object FindDefinitionsAndRefs extends tpd.TreeAccumulator[List[String]]:
140+
override def apply(x: List[String], tree: tpd.Tree)(using Context): List[String] =
141+
tree match
142+
case tpd.DefDef(name, _, _, _) =>
143+
super.foldOver(x :+ name.toString, tree)
144+
case tpd.ValDef(name, _, _) =>
145+
super.foldOver(x :+ name.toString, tree)
146+
case tpd.Ident(name) =>
147+
super.foldOver(x :+ name.toString, tree)
148+
case _ =>
149+
super.foldOver(x, tree)
150+
FindDefinitionsAndRefs.foldOver(Nil, tree)
151+
end allDefAndRefNamesInTree
152+
153+
end PcConvertToNamedLambdaParameters

presentation-compiler/src/main/dotty/tools/pc/ScalaPresentationCompiler.scala

+26-21
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ case class ScalaPresentationCompiler(
6161
CodeActionId.ImplementAbstractMembers,
6262
CodeActionId.ExtractMethod,
6363
CodeActionId.InlineValue,
64-
CodeActionId.InsertInferredType
64+
CodeActionId.InsertInferredType,
65+
PcConvertToNamedLambdaParameters.codeActionId
6566
).asJava
6667

6768
def this() = this("", None, Nil, Nil)
@@ -82,26 +83,30 @@ case class ScalaPresentationCompiler(
8283
codeActionPayload: Optional[T]
8384
): CompletableFuture[ju.List[TextEdit]] =
8485
(codeActionId, codeActionPayload.asScala) match
85-
case (
86-
CodeActionId.ConvertToNamedArguments,
87-
Some(argIndices: ju.List[_])
88-
) =>
89-
val payload =
90-
argIndices.asScala.collect { case i: Integer => i.toInt }.toSet
91-
convertToNamedArguments(params, payload)
92-
case (CodeActionId.ImplementAbstractMembers, _) =>
93-
implementAbstractMembers(params)
94-
case (CodeActionId.InsertInferredType, _) =>
95-
insertInferredType(params)
96-
case (CodeActionId.InlineValue, _) =>
97-
inlineValue(params)
98-
case (CodeActionId.ExtractMethod, Some(extractionPos: OffsetParams)) =>
99-
params match {
100-
case range: RangeParams =>
101-
extractMethod(range, extractionPos)
102-
case _ => failedFuture(new IllegalArgumentException(s"Expected range parameters"))
103-
}
104-
case (id, _) => failedFuture(new IllegalArgumentException(s"Unsupported action id $id"))
86+
case (
87+
CodeActionId.ConvertToNamedArguments,
88+
Some(argIndices: ju.List[_])
89+
) =>
90+
val payload =
91+
argIndices.asScala.collect { case i: Integer => i.toInt }.toSet
92+
convertToNamedArguments(params, payload)
93+
case (CodeActionId.ImplementAbstractMembers, _) =>
94+
implementAbstractMembers(params)
95+
case (CodeActionId.InsertInferredType, _) =>
96+
insertInferredType(params)
97+
case (CodeActionId.InlineValue, _) =>
98+
inlineValue(params)
99+
case (CodeActionId.ExtractMethod, Some(extractionPos: OffsetParams)) =>
100+
params match {
101+
case range: RangeParams =>
102+
extractMethod(range, extractionPos)
103+
case _ => failedFuture(new IllegalArgumentException(s"Expected range parameters"))
104+
}
105+
case (PcConvertToNamedLambdaParameters.codeActionId, _) =>
106+
compilerAccess.withNonInterruptableCompiler(List.empty[l.TextEdit].asJava, params.token) {
107+
access => PcConvertToNamedLambdaParameters(access.compiler(), params).convertToNamedLambdaParameters
108+
}(params.toQueryContext)
109+
case (id, _) => failedFuture(new IllegalArgumentException(s"Unsupported action id $id"))
105110

106111
private def failedFuture[T](e: Throwable): CompletableFuture[T] =
107112
val f = new CompletableFuture[T]()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package dotty.tools.pc.utils
2+
3+
/**
4+
* Helpers for generating variable names based on the desired types.
5+
*/
6+
object TermNameInference {
7+
8+
/** Single character names for types. (`Int` => `i`, `i1`, `i2`, ...) */
9+
def singleLetterNameStream(typeName: String): LazyList[String] = {
10+
sanitizeInput(typeName).fold(saneNamesStream) { typeName1 =>
11+
val firstCharStr = typeName1.headOption.getOrElse('x').toLower.toString
12+
numberedStreamFromName(firstCharStr)
13+
}
14+
}
15+
16+
/** Names only from upper case letters (`OnDemandSymbolIndex` => `odsi`, `odsi1`, `odsi2`, ...) */
17+
def shortNameStream(typeName: String): LazyList[String] = {
18+
sanitizeInput(typeName).fold(saneNamesStream) { typeName1 =>
19+
val upperCases = typeName1.filter(_.isUpper).map(_.toLower)
20+
val name = if (upperCases.isEmpty) typeName1 else upperCases
21+
numberedStreamFromName(name)
22+
}
23+
}
24+
25+
/** Names from lower case letters (`OnDemandSymbolIndex` => `onDemandSymbolIndex`, `onDemandSymbolIndex1`, ...) */
26+
def fullNameStream(typeName: String): LazyList[String] = {
27+
sanitizeInput(typeName).fold(saneNamesStream) { typeName1 =>
28+
val withFirstLower =
29+
typeName1.headOption.map(_.toLower).getOrElse('x').toString + typeName1.drop(1)
30+
numberedStreamFromName(withFirstLower)
31+
}
32+
}
33+
34+
/** A lazy list of names: a, b, ..., z, aa, ab, ..., az, ba, bb, ... */
35+
def saneNamesStream: LazyList[String] = {
36+
val letters = ('a' to 'z').map(_.toString)
37+
def computeNext(acc: String): String = {
38+
if (acc.last == 'z')
39+
computeNext(acc.init) + letters.head
40+
else
41+
acc.init + letters(letters.indexOf(acc.last) + 1)
42+
}
43+
def loop(acc: String): LazyList[String] =
44+
acc #:: loop(computeNext(acc))
45+
loop("a")
46+
}
47+
48+
private def sanitizeInput(typeName: String): Option[String] =
49+
val typeName1 = typeName.filter(_.isLetterOrDigit)
50+
Option.when(typeName1.nonEmpty)(typeName1)
51+
52+
private def numberedStreamFromName(name: String): LazyList[String] = {
53+
val rest = LazyList.from(1).map(name + _)
54+
name #:: rest
55+
}
56+
}

0 commit comments

Comments
 (0)