Skip to content

Commit f6304a8

Browse files
committed
Revert "Simplify Command API"
This reverts commit ab5ff9f.
1 parent be8467f commit f6304a8

11 files changed

+63
-62
lines changed

compiler/src/dotty/tools/dotc/ast/MainProxies.scala

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,10 @@ object MainProxies {
159159
* .withDocumentation("all my params y")
160160
* )
161161
*
162-
* val args0: Option[S] = cmd.parseArg[S](0, None)
163-
* val args1: Option[Seq[T]] = cmd.parseVararg[T]
162+
* val args0: () => S = cmd.argGetter[S](0, None)
163+
* val args1: () => Seq[T] = cmd.varargGetter[T]
164164
*
165-
* cmd.run(f(args0.get, args1.get*))
165+
* cmd.run(f(args0(), args1()*))
166166
* }
167167
* }
168168
*/
@@ -258,7 +258,7 @@ object MainProxies {
258258
/**
259259
* Creates a list of references and definitions of arguments.
260260
* The goal is to create the
261-
* `val args0: Option[S] = cmd.parseArg[S](0, None)`
261+
* `val args0: () => S = cmd.argGetter[S](0, None)`
262262
* part of the code.
263263
*/
264264
def argValDefs(mt: MethodType): List[ValDef] =
@@ -267,28 +267,28 @@ object MainProxies {
267267
val isRepeated = formal.isRepeatedParam
268268
val formalType = if isRepeated then formal.argTypes.head else formal
269269
val getterSym =
270-
if isRepeated then defn.MainAnnotationCommand_parseVararg
271-
else defn.MainAnnotationCommand_parseArg
270+
if isRepeated then defn.MainAnnotationCommand_varargGetter
271+
else defn.MainAnnotationCommand_argGetter
272272
val defaultValueGetterOpt = defaultValueSymbols.get(idx) match
273273
case None => ref(defn.NoneModule.termRef)
274274
case Some(dvSym) =>
275275
val value = unitToValue(ref(dvSym.termRef))
276276
Apply(ref(defn.SomeClass.companionModule.termRef), value)
277-
val parseArg0 = TypeApply(Select(Ident(nme.cmd), getterSym.name), TypeTree(formalType) :: Nil)
278-
val parseArg =
279-
if isRepeated then parseArg0
280-
else Apply(parseArg0, List(Literal(Constant(idx)), defaultValueGetterOpt))
277+
val argGetter0 = TypeApply(Select(Ident(nme.cmd), getterSym.name), TypeTree(formalType) :: Nil)
278+
val argGetter =
279+
if isRepeated then argGetter0
280+
else Apply(argGetter0, List(Literal(Constant(idx)), defaultValueGetterOpt))
281281

282-
ValDef(argName, TypeTree(), parseArg)
282+
ValDef(argName, TypeTree(), argGetter)
283283
end argValDefs
284284

285285

286286
/** Create a list of argument references that will be passed as argument to the main method.
287-
* `args0.get`, ...`argn.get*`
287+
* `args0`, ...`argn*`
288288
*/
289289
def argRefs(mt: MethodType): List[Tree] =
290290
for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
291-
val argRef = Select(Ident(nme.args ++ idx.toString), nme.get)
291+
val argRef = Apply(Ident(nme.args ++ idx.toString), Nil)
292292
if formal.isRepeatedParam then repeated(argRef) else argRef
293293
end argRefs
294294

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -856,8 +856,8 @@ class Definitions {
856856
@tu lazy val MainAnnotationParameterInfo_withAnnotations: Symbol = MainAnnotationParameterInfo.requiredMethod("withAnnotations")
857857
@tu lazy val MainAnnotationParameterAnnotation: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.ParameterAnnotation")
858858
@tu lazy val MainAnnotationCommand: ClassSymbol = requiredClass("scala.annotation.MainAnnotation.Command")
859-
@tu lazy val MainAnnotationCommand_parseArg: Symbol = MainAnnotationCommand.requiredMethod("parseArg")
860-
@tu lazy val MainAnnotationCommand_parseVararg: Symbol = MainAnnotationCommand.requiredMethod("parseVararg")
859+
@tu lazy val MainAnnotationCommand_argGetter: Symbol = MainAnnotationCommand.requiredMethod("argGetter")
860+
@tu lazy val MainAnnotationCommand_varargGetter: Symbol = MainAnnotationCommand.requiredMethod("varargGetter")
861861
@tu lazy val MainAnnotationCommand_run: Symbol = MainAnnotationCommand.requiredMethod("run")
862862

863863
@tu lazy val CommandLineParserModule: Symbol = requiredModule("scala.util.CommandLineParser")

library/src/scala/annotation/MainAnnotation.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ package scala.annotation
66
* The protocol of calls from compiler-main is as follows:
77
*
88
* - create a `command` with the command line arguments,
9-
* - for each parameter of user-main, a call to `command.parseArg`,
9+
* - for each parameter of user-main, a call to `command.argGetter`,
1010
* or `command.argsGetter` if is a final varargs parameter,
1111
* - a call to `command.run` with the closure of user-main applied to all arguments.
1212
*/
@@ -72,10 +72,10 @@ object MainAnnotation:
7272
trait Command[Parser[_], Result]:
7373

7474
/** The getter for the `idx`th argument of type `T` */
75-
def parseArg[T](idx: Int, defaultArgument: Option[() => T])(using Parser[T]): Option[T]
75+
def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using Parser[T]): () => T
7676

7777
/** The getter for a final varargs argument of type `T*` */
78-
def parseVararg[T](using Parser[T]): Option[Seq[T]]
78+
def varargGetter[T](using Parser[T]): () => Seq[T]
7979

8080
/** Run `program` if all arguments are valid,
8181
* or print usage information and/or error messages.

tests/run/main-annotation-homemade-annot-1.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ class mainAwait(timeout: Int = 2) extends MainAnnotation:
3737
// This is a toy example, it only works with positional args
3838
override def command(args: Array[String], commandName: String, docComment: String, parameterInfos: ParameterInfo*) =
3939
new Command[Parser, Result]:
40-
override def parseArg[T](idx: Int, defaultArgument: Option[() => T])(using p: Parser[T]): Option[T] =
41-
Some(p.fromString(args(idx)))
40+
override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: Parser[T]): () => T =
41+
() => p.fromString(args(idx))
4242

43-
override def parseVararg[T](using p: Parser[T]): Option[Seq[T]] =
44-
Some(for i <- ((parameterInfos.length-1) until args.length) yield p.fromString(args(i)))
43+
override def varargGetter[T](using p: Parser[T]): () => Seq[T] =
44+
() => for i <- ((parameterInfos.length-1) until args.length) yield p.fromString(args(i))
4545

4646
override def run(f: => Result): Unit = println(Await.result(f, Duration(timeout, SECONDS)))
4747
end mainAwait

tests/run/main-annotation-homemade-annot-2.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ class myMain(runs: Int = 3)(after: String*) extends MainAnnotation:
3737
override def command(args: Array[String], commandName: String, docComment: String, parameterInfos: ParameterInfo*) =
3838
new Command[Parser, Result]:
3939

40-
override def parseArg[T](idx: Int, defaultArgument: Option[() => T])(using p: Parser[T]): Option[T] =
41-
Some(p.fromString(args(idx)))
40+
override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: Parser[T]): () => T =
41+
() => p.fromString(args(idx))
4242

43-
override def parseVararg[T](using p: Parser[T]): Option[Seq[T]] =
44-
Some(for i <- (parameterInfos.length until args.length) yield p.fromString(args(i)))
43+
override def varargGetter[T](using p: Parser[T]): () => Seq[T] =
44+
() => for i <- (parameterInfos.length until args.length) yield p.fromString(args(i))
4545

4646
override def run(f: => Result): Unit =
4747
for (_ <- 1 to runs)

tests/run/main-annotation-homemade-annot-3.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ class mainNoArgs extends MainAnnotation:
1818

1919
override def command(args: Array[String], commandName: String, docComment: String, parameterInfos: ParameterInfo*) =
2020
new Command[Parser, Result]:
21-
override def parseArg[T](idx: Int, defaultArgument: Option[() => T])(using p: Parser[T]): Option[T] = None
21+
override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: Parser[T]): () => T = ???
2222

23-
override def parseVararg[T](using p: Parser[T]): Option[Seq[T]] = None
23+
override def varargGetter[T](using p: Parser[T]): () => Seq[T] = ???
2424

2525
override def run(f: => Result): Unit = f
2626
end command

tests/run/main-annotation-homemade-annot-4.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ class mainManyArgs(i1: Int, s2: String, i3: Int) extends MainAnnotation:
1818

1919
override def command(args: Array[String], commandName: String, docComment: String, parameterInfos: ParameterInfo*) =
2020
new Command[Parser, Result]:
21-
override def parseArg[T](idx: Int, optDefaultGetter: Option[() => T])(using p: Parser[T]): Option[T] = None
21+
override def argGetter[T](idx: Int, optDefaultGetter: Option[() => T])(using p: Parser[T]): () => T = ???
2222

23-
override def parseVararg[T](using p: Parser[T]): Option[Seq[T]] = None
23+
override def varargGetter[T](using p: Parser[T]): () => Seq[T] = ???
2424

2525
override def run(f: => Result): Unit = f
2626
end command

tests/run/main-annotation-homemade-annot-5.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ class mainManyArgs(o: Option[Int]) extends MainAnnotation:
2020

2121
override def command(args: Array[String], commandName: String, docComment: String, parameterInfos: ParameterInfo*) =
2222
new Command[Parser, Result]:
23-
override def parseArg[T](idx: Int, defaultArgument: Option[() => T])(using p: Parser[T]): Option[T] = None
23+
override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: Parser[T]): () => T = ???
2424

25-
override def parseVararg[T](using p: Parser[T]): Option[Seq[T]] = None
25+
override def varargGetter[T](using p: Parser[T]): () => Seq[T] = ???
2626

2727
override def run(f: => Result): Unit = f
2828
end command

tests/run/main-annotation-homemade-annot-6.check

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ command(
77
ParameterInfo(name="j", typeName="java.lang.String", hasDefault=true, isVarargs=false, documentation="", annotations=List())
88
)*
99
)
10-
parseArg(0, None)
11-
parseArg(1, Some(2))
10+
argGetter(0, None)
11+
argGetter(1, Some(2))
1212
run()
1313
foo(42, abc)
1414

@@ -21,8 +21,8 @@ command(
2121
ParameterInfo(name="rest", typeName="scala.Int", hasDefault=false, isVarargs=true, documentation="", annotations=List())
2222
)*
2323
)
24-
parseArg(0, None)
25-
parseVararg()
24+
argGetter(0, None)
25+
varargGetter()
2626
run()
2727
bar(List(42), 42, 42)
2828

tests/run/main-annotation-homemade-annot-6.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ class myMain extends MainAnnotation:
3535
| ${parameterInfos.map(paramInfoString).mkString("Seq(\n", ",\n", "\n )*")}
3636
|)""".stripMargin)
3737
new Command[Parser, Result]:
38-
override def parseArg[T](idx: Int, defaultArgument: Option[() => T])(using p: Parser[T]): Option[T] =
39-
println(s"parseArg($idx, ${defaultArgument.map(_())})")
40-
Some(p.make)
38+
override def argGetter[T](idx: Int, defaultArgument: Option[() => T])(using p: Parser[T]): () => T =
39+
println(s"argGetter($idx, ${defaultArgument.map(_())})")
40+
() => p.make
4141

42-
override def parseVararg[T](using p: Parser[T]): Option[Seq[T]] =
43-
println("parseVararg()")
44-
Some(Seq(p.make, p.make))
42+
override def varargGetter[T](using p: Parser[T]): () => Seq[T] =
43+
println("varargGetter()")
44+
() => Seq(p.make, p.make)
4545

4646
override def run(f: => Result): Unit =
4747
println("run()")

tests/run/main-annotation-newMain.scala

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ final class newMain extends MainAnnotation:
121121
private val errors = new mutable.ArrayBuffer[String]
122122

123123
/** Issue an error, and return an uncallable getter */
124-
private def error(msg: String): None.type =
124+
private def error(msg: String): () => Nothing =
125125
errors += msg
126-
None
126+
() => throw new AssertionError("trying to get invalid argument")
127127

128128
private inline def nameIsValid(name: String): Boolean =
129129
name.length > 1 // TODO add more checks for illegal characters
@@ -140,8 +140,10 @@ final class newMain extends MainAnnotation:
140140
case s => argMarker + s
141141
}
142142

143-
private def convert[T](argName: String, arg: String, p: Parser[T]): Option[T] =
144-
p.fromStringOption(arg).orElse(error(s"invalid argument for $argName: $arg"))
143+
private def convert[T](argName: String, arg: String, p: Parser[T]): () => T =
144+
p.fromStringOption(arg) match
145+
case Some(t) => () => t
146+
case None => error(s"invalid argument for $argName: $arg")
145147

146148
private def usage(): Unit =
147149
def argsUsage: Seq[String] =
@@ -228,27 +230,26 @@ final class newMain extends MainAnnotation:
228230
argDoc.append("\n").append(shiftedDoc)
229231
}
230232

231-
232233
println(argDoc)
233234
}
234235
end explain
235236

236-
private def getAliases(paramInfo: ParameterInfo): Seq[String] =
237-
paramInfo.annotations.collect{ case a: Alias => a }.flatMap(_.aliases)
237+
private def getAliases(paramInfos: ParameterInfo): Seq[String] =
238+
paramInfos.annotations.collect{ case a: Alias => a }.flatMap(_.aliases)
238239

239-
private def getAlternativeNames(paramInfo: ParameterInfo): Seq[String] =
240-
getAliases(paramInfo).filter(nameIsValid(_))
240+
private def getAlternativeNames(paramInfos: ParameterInfo): Seq[String] =
241+
getAliases(paramInfos).filter(nameIsValid(_))
241242

242-
private def getShortNames(paramInfo: ParameterInfo): Seq[Char] =
243-
getAliases(paramInfo).filter(shortNameIsValid(_)).map(_(0))
243+
private def getShortNames(paramInfos: ParameterInfo): Seq[Char] =
244+
getAliases(paramInfos).filter(shortNameIsValid(_)).map(_(0))
244245

245-
private def getInvalidNames(paramInfo: ParameterInfo): Seq[String | Char] =
246-
getAliases(paramInfo).filter(name => !nameIsValid(name) && !shortNameIsValid(name))
246+
private def getInvalidNames(paramInfos: ParameterInfo): Seq[String | Char] =
247+
getAliases(paramInfos).filter(name => !nameIsValid(name) && !shortNameIsValid(name))
247248

248-
def parseArg[T](idx: Int, optDefaultGetter: Option[() => T])(using p: Parser[T]): Option[T] =
249+
override def argGetter[T](idx: Int, optDefaultGetter: Option[() => T])(using p: Parser[T]): () => T =
249250
val name = parameterInfos(idx).name
250-
251251
argKinds += (if optDefaultGetter.nonEmpty then ArgumentKind.OptionalArgument else ArgumentKind.SimpleArgument)
252+
val parameterInfo = nameToParameterInfo(name)
252253

253254
byNameArgs.get(name) match {
254255
case Some(Nil) =>
@@ -264,20 +265,20 @@ final class newMain extends MainAnnotation:
264265
if positionalArgs.length > 0 then
265266
convert(name, positionalArgs.dequeue, p)
266267
else if optDefaultGetter.nonEmpty then
267-
optDefaultGetter.map(_())
268+
optDefaultGetter.get
268269
else
269270
error(s"missing argument for $name")
270271
}
271-
end parseArg
272+
end argGetter
272273

273-
def parseVararg[T](using p: Parser[T]): Option[Seq[T]] =
274-
argKinds += ArgumentKind.VarArgument
274+
override def varargGetter[T](using p: Parser[T]): () => Seq[T] =
275275
val name = parameterInfos.last.name
276+
argKinds += ArgumentKind.VarArgument
276277

277278
val byNameGetters = byNameArgs.getOrElse(name, Seq()).map(arg => convert(name, arg, p))
278279
val positionalGetters = positionalArgs.removeAll.map(arg => convert(name, arg, p))
279280
// First take arguments passed by name, then those passed by position
280-
Some(byNameGetters.flatten ++ positionalGetters.flatten)
281+
() => (byNameGetters ++ positionalGetters).map(_())
281282

282283
override def run(f: => Result): Unit =
283284
// Check aliases unicity

0 commit comments

Comments
 (0)