|
| 1 | +package dotty.tools.pc |
| 2 | + |
| 3 | +import java.nio.file.Paths |
| 4 | +import java.util as ju |
| 5 | + |
| 6 | +import scala.jdk.CollectionConverters.* |
| 7 | +import scala.meta.pc.OffsetParams |
| 8 | + |
| 9 | +import dotty.tools.dotc.ast.tpd |
| 10 | +import dotty.tools.dotc.core.Contexts.Context |
| 11 | +import dotty.tools.dotc.core.Flags |
| 12 | +import dotty.tools.dotc.interactive.Interactive |
| 13 | +import dotty.tools.dotc.interactive.InteractiveDriver |
| 14 | +import dotty.tools.dotc.util.SourceFile |
| 15 | +import dotty.tools.dotc.util.SourcePosition |
| 16 | +import org.eclipse.lsp4j as l |
| 17 | +import dotty.tools.pc.utils.InteractiveEnrichments.* |
| 18 | +import dotty.tools.pc.utils.TermNameInference.* |
| 19 | + |
| 20 | +/** |
| 21 | + * Facilitates the code action that converts a wildcard lambda to a lambda with named parameters |
| 22 | + * e.g. |
| 23 | + * |
| 24 | + * List(1, 2).map(<<_>> + 1) => List(1, 2).map(i => i + 1) |
| 25 | + */ |
| 26 | +final class PcConvertToNamedLambdaParameters( |
| 27 | + driver: InteractiveDriver, |
| 28 | + params: OffsetParams |
| 29 | +): |
| 30 | + import PcConvertToNamedLambdaParameters._ |
| 31 | + |
| 32 | + def convertToNamedLambdaParameters: ju.List[l.TextEdit] = { |
| 33 | + val uri = params.uri |
| 34 | + val filePath = Paths.get(uri) |
| 35 | + driver.run( |
| 36 | + uri, |
| 37 | + SourceFile.virtual(filePath.toString, params.text), |
| 38 | + ) |
| 39 | + given newctx: Context = driver.localContext(params) |
| 40 | + val pos = driver.sourcePosition(params) |
| 41 | + val trees = driver.openedTrees(uri) |
| 42 | + val treeList = Interactive.pathTo(trees, pos) |
| 43 | + // Extractor for a lambda function (needs context, so has to be defined here) |
| 44 | + val LambdaExtractor = Lambda(using newctx) |
| 45 | + // select the most inner wildcard lambda |
| 46 | + val firstLambda = treeList.collectFirst { |
| 47 | + case LambdaExtractor(params, rhsFn) if params.forall(isWildcardParam) => |
| 48 | + params -> rhsFn |
| 49 | + } |
| 50 | + |
| 51 | + firstLambda match { |
| 52 | + case Some((params, lambda)) => |
| 53 | + // avoid names that are either defined or referenced in the lambda |
| 54 | + val namesToAvoid = allDefAndRefNamesInTree(lambda) |
| 55 | + // compute parameter names based on the type of the parameter |
| 56 | + val computedParamNames: List[String] = |
| 57 | + params.foldLeft(List.empty[String]) { (acc, param) => |
| 58 | + val name = singleLetterNameStream(param.tpe.typeSymbol.name.toString()) |
| 59 | + .find(n => !namesToAvoid.contains(n) && !acc.contains(n)) |
| 60 | + acc ++ name.toList |
| 61 | + } |
| 62 | + if computedParamNames.size == params.size then |
| 63 | + val paramReferenceEdits = params.zip(computedParamNames).flatMap { (param, paramName) => |
| 64 | + val paramReferencePosition = findParamReferencePosition(param, lambda) |
| 65 | + paramReferencePosition.toList.map { pos => |
| 66 | + val position = pos.toLsp |
| 67 | + val range = new l.Range( |
| 68 | + position.getStart(), |
| 69 | + position.getEnd() |
| 70 | + ) |
| 71 | + new l.TextEdit(range, paramName) |
| 72 | + } |
| 73 | + } |
| 74 | + val paramNamesStr = computedParamNames.mkString(", ") |
| 75 | + val paramDefsStr = |
| 76 | + if params.size == 1 then paramNamesStr |
| 77 | + else s"($paramNamesStr)" |
| 78 | + val defRange = new l.Range( |
| 79 | + lambda.sourcePos.toLsp.getStart(), |
| 80 | + lambda.sourcePos.toLsp.getStart() |
| 81 | + ) |
| 82 | + val paramDefinitionEdits = List( |
| 83 | + new l.TextEdit(defRange, s"$paramDefsStr => ") |
| 84 | + ) |
| 85 | + (paramDefinitionEdits ++ paramReferenceEdits).asJava |
| 86 | + else |
| 87 | + List.empty.asJava |
| 88 | + case _ => |
| 89 | + List.empty.asJava |
| 90 | + } |
| 91 | + } |
| 92 | + |
| 93 | +end PcConvertToNamedLambdaParameters |
| 94 | + |
| 95 | +object PcConvertToNamedLambdaParameters: |
| 96 | + val codeActionId = "ConvertToNamedLambdaParameters" |
| 97 | + |
| 98 | + class Lambda(using Context): |
| 99 | + def unapply(tree: tpd.Block): Option[(List[tpd.ValDef], tpd.Tree)] = tree match { |
| 100 | + case tpd.Block((ddef @ tpd.DefDef(_, tpd.ValDefs(params) :: Nil, _, body: tpd.Tree)) :: Nil, tpd.Closure(_, meth, _)) |
| 101 | + if ddef.symbol == meth.symbol => |
| 102 | + params match { |
| 103 | + case List(param) => |
| 104 | + // lambdas with multiple wildcard parameters are represented as a single parameter function and a block with wildcard valdefs |
| 105 | + Some(multipleUnderscoresFromBody(param, body)) |
| 106 | + case _ => Some(params -> body) |
| 107 | + } |
| 108 | + case _ => None |
| 109 | + } |
| 110 | + end Lambda |
| 111 | + |
| 112 | + private def multipleUnderscoresFromBody(param: tpd.ValDef, body: tpd.Tree)(using Context): (List[tpd.ValDef], tpd.Tree) = body match { |
| 113 | + case tpd.Block(defs, expr) if param.symbol.is(Flags.Synthetic) => |
| 114 | + val wildcardParamDefs = defs.collect { |
| 115 | + case valdef: tpd.ValDef if isWildcardParam(valdef) => valdef |
| 116 | + } |
| 117 | + if wildcardParamDefs.size == defs.size then wildcardParamDefs -> expr |
| 118 | + else List(param) -> body |
| 119 | + case _ => List(param) -> body |
| 120 | + } |
| 121 | + |
| 122 | + def isWildcardParam(param: tpd.ValDef)(using Context): Boolean = |
| 123 | + param.name.toString.startsWith("_$") && param.symbol.is(Flags.Synthetic) |
| 124 | + |
| 125 | + def findParamReferencePosition(param: tpd.ValDef, lambda: tpd.Tree)(using Context): Option[SourcePosition] = |
| 126 | + var pos: Option[SourcePosition] = None |
| 127 | + object FindParamReference extends tpd.TreeTraverser: |
| 128 | + override def traverse(tree: tpd.Tree)(using Context): Unit = |
| 129 | + tree match |
| 130 | + case ident @ tpd.Ident(_) if ident.symbol == param.symbol => |
| 131 | + pos = Some(tree.sourcePos) |
| 132 | + case _ => |
| 133 | + traverseChildren(tree) |
| 134 | + FindParamReference.traverse(lambda) |
| 135 | + pos |
| 136 | + end findParamReferencePosition |
| 137 | + |
| 138 | + def allDefAndRefNamesInTree(tree: tpd.Tree)(using Context): List[String] = |
| 139 | + object FindDefinitionsAndRefs extends tpd.TreeAccumulator[List[String]]: |
| 140 | + override def apply(x: List[String], tree: tpd.Tree)(using Context): List[String] = |
| 141 | + tree match |
| 142 | + case tpd.DefDef(name, _, _, _) => |
| 143 | + super.foldOver(x :+ name.toString, tree) |
| 144 | + case tpd.ValDef(name, _, _) => |
| 145 | + super.foldOver(x :+ name.toString, tree) |
| 146 | + case tpd.Ident(name) => |
| 147 | + super.foldOver(x :+ name.toString, tree) |
| 148 | + case _ => |
| 149 | + super.foldOver(x, tree) |
| 150 | + FindDefinitionsAndRefs.foldOver(Nil, tree) |
| 151 | + end allDefAndRefNamesInTree |
| 152 | + |
| 153 | +end PcConvertToNamedLambdaParameters |
0 commit comments