Skip to content

Commit d433ce8

Browse files
committed
Address review comments.
Main changes: * Apply common protection against wrong number of arguments. * Exception in an operation is converted into a type error.
1 parent 32043aa commit d433ce8

File tree

1 file changed

+126
-114
lines changed

1 file changed

+126
-114
lines changed

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 126 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -4196,6 +4196,11 @@ object Types {
41964196
case tycon: TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
41974197
extension (tp : Type) def fixForEvaluation : Type =
41984198
tp.normalized.dealias match {
4199+
//enable operations for constant singleton terms. E.g.:
4200+
//```
4201+
//final val one = 1
4202+
//type Two = one.type + one.type
4203+
//```
41994204
case tp : TermRef => tp.underlying
42004205
case tp => tp
42014206
}
@@ -4234,163 +4239,170 @@ object Types {
42344239
case ConstantType(Constant(n: String)) => Some(n)
42354240
case _ => None
42364241
}
4237-
def isConst : Option[Type] = args.head.fixForEvaluation match {
4242+
4243+
def isConst(tp : Type) : Option[Type] = tp.fixForEvaluation match {
42384244
case ConstantType(_) => Some(ConstantType(Constant(true)))
42394245
case _ => Some(ConstantType(Constant(false)))
42404246
}
4247+
4248+
def expectArgsNum(expectedNum : Int) : Unit =
4249+
//We can use assert instead of a compiler type error because this error should not
4250+
//occur since the type signature of the operation enforces the proper number of args.
4251+
assert(args.length == expectedNum, s"Type operation expects $expectedNum arguments but found ${args.length}")
4252+
42414253
def natValue(tp: Type): Option[Int] = intValue(tp).filter(n => n >= 0 && n < Int.MaxValue)
42424254

4255+
//Runs the op and returns the result as a constant type.
4256+
//If the op throws an exception, then this exception is converted into a type error.
4257+
def runConstantOp(op : => Any): Type =
4258+
val result = try {
4259+
op
4260+
} catch {
4261+
case e : Throwable =>
4262+
throw new TypeError(e.getMessage)
4263+
}
4264+
ConstantType(Constant(result))
4265+
42434266
def constantFold1[T](extractor: Type => Option[T], op: T => Any): Option[Type] =
4244-
extractor(args.head).map(a => ConstantType(Constant(op(a))))
4267+
expectArgsNum(1)
4268+
extractor(args.head).map(a => runConstantOp(op(a)))
42454269

42464270
def constantFold2[T](extractor: Type => Option[T], op: (T, T) => Any): Option[Type] =
42474271
constantFold2AB(extractor, extractor, op)
42484272

42494273
def constantFold2AB[TA, TB](extractorA: Type => Option[TA], extractorB: Type => Option[TB], op: (TA, TB) => Any): Option[Type] =
4274+
expectArgsNum(2)
42504275
for {
4251-
a <- extractorA(args.head)
4252-
b <- extractorB(args.last)
4253-
} yield ConstantType(Constant(op(a, b)))
4276+
a <- extractorA(args(0))
4277+
b <- extractorB(args(1))
4278+
} yield runConstantOp(op(a, b))
42544279

42554280
def constantFold3[TA, TB, TC](
42564281
extractorA: Type => Option[TA],
42574282
extractorB: Type => Option[TB],
42584283
extractorC: Type => Option[TC],
42594284
op: (TA, TB, TC) => Any
42604285
): Option[Type] =
4286+
expectArgsNum(3)
42614287
for {
4262-
a <- extractorA(args.head)
4288+
a <- extractorA(args(0))
42634289
b <- extractorB(args(1))
4264-
c <- extractorC(args.last)
4265-
} yield ConstantType(Constant(op(a, b, c)))
4290+
c <- extractorC(args(2))
4291+
} yield runConstantOp(op(a, b, c))
42664292

42674293
trace(i"compiletime constant fold $this", typr, show = true) {
42684294
val name = tycon.symbol.name
42694295
val owner = tycon.symbol.owner
4270-
val nArgs = args.length
42714296
val constantType =
42724297
if (defn.isCompiletime_S(tycon.symbol)) {
4273-
if (nArgs == 1) constantFold1(natValue, _ + 1)
4274-
else None
4298+
constantFold1(natValue, _ + 1)
42754299
} else if (owner == defn.CompiletimeOpsAnyModuleClass) name match {
4276-
case tpnme.Equals if nArgs == 2 => constantFold2(constValue, _ == _)
4277-
case tpnme.NotEquals if nArgs == 2 => constantFold2(constValue, _ != _)
4278-
case tpnme.ToString if nArgs == 1 => constantFold1(constValue, _.toString)
4279-
case tpnme.IsConst if nArgs == 1 => isConst
4300+
case tpnme.Equals => constantFold2(constValue, _ == _)
4301+
case tpnme.NotEquals => constantFold2(constValue, _ != _)
4302+
case tpnme.ToString => constantFold1(constValue, _.toString)
4303+
case tpnme.IsConst => isConst(args.head)
42804304
case _ => None
42814305
} else if (owner == defn.CompiletimeOpsIntModuleClass) name match {
4282-
case tpnme.Abs if nArgs == 1 => constantFold1(intValue, _.abs)
4283-
case tpnme.Negate if nArgs == 1 => constantFold1(intValue, x => -x)
4306+
case tpnme.Abs => constantFold1(intValue, _.abs)
4307+
case tpnme.Negate => constantFold1(intValue, x => -x)
42844308
//ToString is deprecated for ops.int, and moved to ops.any
4285-
case tpnme.ToString if nArgs == 1 => constantFold1(intValue, _.toString)
4286-
case tpnme.Plus if nArgs == 2 => constantFold2(intValue, _ + _)
4287-
case tpnme.Minus if nArgs == 2 => constantFold2(intValue, _ - _)
4288-
case tpnme.Times if nArgs == 2 => constantFold2(intValue, _ * _)
4289-
case tpnme.Div if nArgs == 2 => constantFold2(intValue, {
4290-
case (_, 0) => throw new TypeError("Division by 0")
4291-
case (a, b) => a / b
4292-
})
4293-
case tpnme.Mod if nArgs == 2 => constantFold2(intValue, {
4294-
case (_, 0) => throw new TypeError("Modulo by 0")
4295-
case (a, b) => a % b
4296-
})
4297-
case tpnme.Lt if nArgs == 2 => constantFold2(intValue, _ < _)
4298-
case tpnme.Gt if nArgs == 2 => constantFold2(intValue, _ > _)
4299-
case tpnme.Ge if nArgs == 2 => constantFold2(intValue, _ >= _)
4300-
case tpnme.Le if nArgs == 2 => constantFold2(intValue, _ <= _)
4301-
case tpnme.Xor if nArgs == 2 => constantFold2(intValue, _ ^ _)
4302-
case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(intValue, _ & _)
4303-
case tpnme.BitwiseOr if nArgs == 2 => constantFold2(intValue, _ | _)
4304-
case tpnme.ASR if nArgs == 2 => constantFold2(intValue, _ >> _)
4305-
case tpnme.LSL if nArgs == 2 => constantFold2(intValue, _ << _)
4306-
case tpnme.LSR if nArgs == 2 => constantFold2(intValue, _ >>> _)
4307-
case tpnme.Min if nArgs == 2 => constantFold2(intValue, _ min _)
4308-
case tpnme.Max if nArgs == 2 => constantFold2(intValue, _ max _)
4309-
case tpnme.NumberOfLeadingZeros if nArgs == 1 => constantFold1(intValue, Integer.numberOfLeadingZeros(_))
4310-
case tpnme.ToLong if nArgs == 1 => constantFold1(intValue, _.toLong)
4311-
case tpnme.ToFloat if nArgs == 1 => constantFold1(intValue, _.toFloat)
4312-
case tpnme.ToDouble if nArgs == 1 => constantFold1(intValue, _.toDouble)
4309+
case tpnme.ToString => constantFold1(intValue, _.toString)
4310+
case tpnme.Plus => constantFold2(intValue, _ + _)
4311+
case tpnme.Minus => constantFold2(intValue, _ - _)
4312+
case tpnme.Times => constantFold2(intValue, _ * _)
4313+
case tpnme.Div => constantFold2(intValue, _ / _)
4314+
case tpnme.Mod => constantFold2(intValue, _ % _)
4315+
case tpnme.Lt => constantFold2(intValue, _ < _)
4316+
case tpnme.Gt => constantFold2(intValue, _ > _)
4317+
case tpnme.Ge => constantFold2(intValue, _ >= _)
4318+
case tpnme.Le => constantFold2(intValue, _ <= _)
4319+
case tpnme.Xor => constantFold2(intValue, _ ^ _)
4320+
case tpnme.BitwiseAnd => constantFold2(intValue, _ & _)
4321+
case tpnme.BitwiseOr => constantFold2(intValue, _ | _)
4322+
case tpnme.ASR => constantFold2(intValue, _ >> _)
4323+
case tpnme.LSL => constantFold2(intValue, _ << _)
4324+
case tpnme.LSR => constantFold2(intValue, _ >>> _)
4325+
case tpnme.Min => constantFold2(intValue, _ min _)
4326+
case tpnme.Max => constantFold2(intValue, _ max _)
4327+
case tpnme.NumberOfLeadingZeros => constantFold1(intValue, Integer.numberOfLeadingZeros(_))
4328+
case tpnme.ToLong => constantFold1(intValue, _.toLong)
4329+
case tpnme.ToFloat => constantFold1(intValue, _.toFloat)
4330+
case tpnme.ToDouble => constantFold1(intValue, _.toDouble)
43134331
case _ => None
43144332
} else if (owner == defn.CompiletimeOpsLongModuleClass) name match {
4315-
case tpnme.Abs if nArgs == 1 => constantFold1(longValue, _.abs)
4316-
case tpnme.Negate if nArgs == 1 => constantFold1(longValue, x => -x)
4317-
case tpnme.Plus if nArgs == 2 => constantFold2(longValue, _ + _)
4318-
case tpnme.Minus if nArgs == 2 => constantFold2(longValue, _ - _)
4319-
case tpnme.Times if nArgs == 2 => constantFold2(longValue, _ * _)
4320-
case tpnme.Div if nArgs == 2 => constantFold2(longValue, {
4321-
case (_, 0L) => throw new TypeError("Division by 0")
4322-
case (a, b) => a / b
4323-
})
4324-
case tpnme.Mod if nArgs == 2 => constantFold2(longValue, {
4325-
case (_, 0L) => throw new TypeError("Modulo by 0")
4326-
case (a, b) => a % b
4327-
})
4328-
case tpnme.Lt if nArgs == 2 => constantFold2(longValue, _ < _)
4329-
case tpnme.Gt if nArgs == 2 => constantFold2(longValue, _ > _)
4330-
case tpnme.Ge if nArgs == 2 => constantFold2(longValue, _ >= _)
4331-
case tpnme.Le if nArgs == 2 => constantFold2(longValue, _ <= _)
4332-
case tpnme.Xor if nArgs == 2 => constantFold2(longValue, _ ^ _)
4333-
case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(longValue, _ & _)
4334-
case tpnme.BitwiseOr if nArgs == 2 => constantFold2(longValue, _ | _)
4335-
case tpnme.ASR if nArgs == 2 => constantFold2(longValue, _ >> _)
4336-
case tpnme.LSL if nArgs == 2 => constantFold2(longValue, _ << _)
4337-
case tpnme.LSR if nArgs == 2 => constantFold2(longValue, _ >>> _)
4338-
case tpnme.Min if nArgs == 2 => constantFold2(longValue, _ min _)
4339-
case tpnme.Max if nArgs == 2 => constantFold2(longValue, _ max _)
4340-
case tpnme.NumberOfLeadingZeros if nArgs == 1 =>
4333+
case tpnme.Abs => constantFold1(longValue, _.abs)
4334+
case tpnme.Negate => constantFold1(longValue, x => -x)
4335+
case tpnme.Plus => constantFold2(longValue, _ + _)
4336+
case tpnme.Minus => constantFold2(longValue, _ - _)
4337+
case tpnme.Times => constantFold2(longValue, _ * _)
4338+
case tpnme.Div => constantFold2(longValue, _ / _)
4339+
case tpnme.Mod => constantFold2(longValue, _ % _)
4340+
case tpnme.Lt => constantFold2(longValue, _ < _)
4341+
case tpnme.Gt => constantFold2(longValue, _ > _)
4342+
case tpnme.Ge => constantFold2(longValue, _ >= _)
4343+
case tpnme.Le => constantFold2(longValue, _ <= _)
4344+
case tpnme.Xor => constantFold2(longValue, _ ^ _)
4345+
case tpnme.BitwiseAnd => constantFold2(longValue, _ & _)
4346+
case tpnme.BitwiseOr => constantFold2(longValue, _ | _)
4347+
case tpnme.ASR => constantFold2(longValue, _ >> _)
4348+
case tpnme.LSL => constantFold2(longValue, _ << _)
4349+
case tpnme.LSR => constantFold2(longValue, _ >>> _)
4350+
case tpnme.Min => constantFold2(longValue, _ min _)
4351+
case tpnme.Max => constantFold2(longValue, _ max _)
4352+
case tpnme.NumberOfLeadingZeros =>
43414353
constantFold1(longValue, java.lang.Long.numberOfLeadingZeros(_))
4342-
case tpnme.ToInt if nArgs == 1 => constantFold1(longValue, _.toInt)
4343-
case tpnme.ToFloat if nArgs == 1 => constantFold1(longValue, _.toFloat)
4344-
case tpnme.ToDouble if nArgs == 1 => constantFold1(longValue, _.toDouble)
4354+
case tpnme.ToInt => constantFold1(longValue, _.toInt)
4355+
case tpnme.ToFloat => constantFold1(longValue, _.toFloat)
4356+
case tpnme.ToDouble => constantFold1(longValue, _.toDouble)
43454357
case _ => None
43464358
} else if (owner == defn.CompiletimeOpsFloatModuleClass) name match {
4347-
case tpnme.Abs if nArgs == 1 => constantFold1(floatValue, _.abs)
4348-
case tpnme.Negate if nArgs == 1 => constantFold1(floatValue, x => -x)
4349-
case tpnme.Plus if nArgs == 2 => constantFold2(floatValue, _ + _)
4350-
case tpnme.Minus if nArgs == 2 => constantFold2(floatValue, _ - _)
4351-
case tpnme.Times if nArgs == 2 => constantFold2(floatValue, _ * _)
4352-
case tpnme.Div if nArgs == 2 => constantFold2(floatValue, _ / _)
4353-
case tpnme.Mod if nArgs == 2 => constantFold2(floatValue, _ % _)
4354-
case tpnme.Lt if nArgs == 2 => constantFold2(floatValue, _ < _)
4355-
case tpnme.Gt if nArgs == 2 => constantFold2(floatValue, _ > _)
4356-
case tpnme.Ge if nArgs == 2 => constantFold2(floatValue, _ >= _)
4357-
case tpnme.Le if nArgs == 2 => constantFold2(floatValue, _ <= _)
4358-
case tpnme.Min if nArgs == 2 => constantFold2(floatValue, _ min _)
4359-
case tpnme.Max if nArgs == 2 => constantFold2(floatValue, _ max _)
4360-
case tpnme.ToInt if nArgs == 1 => constantFold1(floatValue, _.toInt)
4361-
case tpnme.ToLong if nArgs == 1 => constantFold1(floatValue, _.toLong)
4362-
case tpnme.ToDouble if nArgs == 1 => constantFold1(floatValue, _.toDouble)
4359+
case tpnme.Abs => constantFold1(floatValue, _.abs)
4360+
case tpnme.Negate => constantFold1(floatValue, x => -x)
4361+
case tpnme.Plus => constantFold2(floatValue, _ + _)
4362+
case tpnme.Minus => constantFold2(floatValue, _ - _)
4363+
case tpnme.Times => constantFold2(floatValue, _ * _)
4364+
case tpnme.Div => constantFold2(floatValue, _ / _)
4365+
case tpnme.Mod => constantFold2(floatValue, _ % _)
4366+
case tpnme.Lt => constantFold2(floatValue, _ < _)
4367+
case tpnme.Gt => constantFold2(floatValue, _ > _)
4368+
case tpnme.Ge => constantFold2(floatValue, _ >= _)
4369+
case tpnme.Le => constantFold2(floatValue, _ <= _)
4370+
case tpnme.Min => constantFold2(floatValue, _ min _)
4371+
case tpnme.Max => constantFold2(floatValue, _ max _)
4372+
case tpnme.ToInt => constantFold1(floatValue, _.toInt)
4373+
case tpnme.ToLong => constantFold1(floatValue, _.toLong)
4374+
case tpnme.ToDouble => constantFold1(floatValue, _.toDouble)
43634375
case _ => None
43644376
} else if (owner == defn.CompiletimeOpsDoubleModuleClass) name match {
4365-
case tpnme.Abs if nArgs == 1 => constantFold1(doubleValue, _.abs)
4366-
case tpnme.Negate if nArgs == 1 => constantFold1(doubleValue, x => -x)
4367-
case tpnme.Plus if nArgs == 2 => constantFold2(doubleValue, _ + _)
4368-
case tpnme.Minus if nArgs == 2 => constantFold2(doubleValue, _ - _)
4369-
case tpnme.Times if nArgs == 2 => constantFold2(doubleValue, _ * _)
4370-
case tpnme.Div if nArgs == 2 => constantFold2(doubleValue, _ / _)
4371-
case tpnme.Mod if nArgs == 2 => constantFold2(doubleValue, _ % _)
4372-
case tpnme.Lt if nArgs == 2 => constantFold2(doubleValue, _ < _)
4373-
case tpnme.Gt if nArgs == 2 => constantFold2(doubleValue, _ > _)
4374-
case tpnme.Ge if nArgs == 2 => constantFold2(doubleValue, _ >= _)
4375-
case tpnme.Le if nArgs == 2 => constantFold2(doubleValue, _ <= _)
4376-
case tpnme.Min if nArgs == 2 => constantFold2(doubleValue, _ min _)
4377-
case tpnme.Max if nArgs == 2 => constantFold2(doubleValue, _ max _)
4378-
case tpnme.ToInt if nArgs == 1 => constantFold1(doubleValue, _.toInt)
4379-
case tpnme.ToLong if nArgs == 1 => constantFold1(doubleValue, _.toLong)
4380-
case tpnme.ToFloat if nArgs == 1 => constantFold1(doubleValue, _.toFloat)
4377+
case tpnme.Abs => constantFold1(doubleValue, _.abs)
4378+
case tpnme.Negate => constantFold1(doubleValue, x => -x)
4379+
case tpnme.Plus => constantFold2(doubleValue, _ + _)
4380+
case tpnme.Minus => constantFold2(doubleValue, _ - _)
4381+
case tpnme.Times => constantFold2(doubleValue, _ * _)
4382+
case tpnme.Div => constantFold2(doubleValue, _ / _)
4383+
case tpnme.Mod => constantFold2(doubleValue, _ % _)
4384+
case tpnme.Lt => constantFold2(doubleValue, _ < _)
4385+
case tpnme.Gt => constantFold2(doubleValue, _ > _)
4386+
case tpnme.Ge => constantFold2(doubleValue, _ >= _)
4387+
case tpnme.Le => constantFold2(doubleValue, _ <= _)
4388+
case tpnme.Min => constantFold2(doubleValue, _ min _)
4389+
case tpnme.Max => constantFold2(doubleValue, _ max _)
4390+
case tpnme.ToInt => constantFold1(doubleValue, _.toInt)
4391+
case tpnme.ToLong => constantFold1(doubleValue, _.toLong)
4392+
case tpnme.ToFloat => constantFold1(doubleValue, _.toFloat)
43814393
case _ => None
43824394
} else if (owner == defn.CompiletimeOpsStringModuleClass) name match {
4383-
case tpnme.Plus if nArgs == 2 => constantFold2(stringValue, _ + _)
4384-
case tpnme.Length if nArgs == 1 => constantFold1(stringValue, _.length)
4385-
case tpnme.Matches if nArgs == 2 => constantFold2(stringValue, _ matches _)
4386-
case tpnme.Substring if nArgs == 3 =>
4395+
case tpnme.Plus => constantFold2(stringValue, _ + _)
4396+
case tpnme.Length => constantFold1(stringValue, _.length)
4397+
case tpnme.Matches => constantFold2(stringValue, _ matches _)
4398+
case tpnme.Substring =>
43874399
constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e))
43884400
case _ => None
43894401
} else if (owner == defn.CompiletimeOpsBooleanModuleClass) name match {
4390-
case tpnme.Not if nArgs == 1 => constantFold1(boolValue, x => !x)
4391-
case tpnme.And if nArgs == 2 => constantFold2(boolValue, _ && _)
4392-
case tpnme.Or if nArgs == 2 => constantFold2(boolValue, _ || _)
4393-
case tpnme.Xor if nArgs == 2 => constantFold2(boolValue, _ ^ _)
4402+
case tpnme.Not => constantFold1(boolValue, x => !x)
4403+
case tpnme.And => constantFold2(boolValue, _ && _)
4404+
case tpnme.Or => constantFold2(boolValue, _ || _)
4405+
case tpnme.Xor => constantFold2(boolValue, _ ^ _)
43944406
case _ => None
43954407
} else None
43964408

0 commit comments

Comments
 (0)