@@ -4196,6 +4196,11 @@ object Types {
4196
4196
case tycon : TypeRef if defn.isCompiletimeAppliedType(tycon.symbol) =>
4197
4197
extension (tp : Type ) def fixForEvaluation : Type =
4198
4198
tp.normalized.dealias match {
4199
+ // enable operations for constant singleton terms. E.g.:
4200
+ // ```
4201
+ // final val one = 1
4202
+ // type Two = one.type + one.type
4203
+ // ```
4199
4204
case tp : TermRef => tp.underlying
4200
4205
case tp => tp
4201
4206
}
@@ -4234,163 +4239,170 @@ object Types {
4234
4239
case ConstantType (Constant (n : String )) => Some (n)
4235
4240
case _ => None
4236
4241
}
4237
- def isConst : Option [Type ] = args.head.fixForEvaluation match {
4242
+
4243
+ def isConst (tp : Type ) : Option [Type ] = tp.fixForEvaluation match {
4238
4244
case ConstantType (_) => Some (ConstantType (Constant (true )))
4239
4245
case _ => Some (ConstantType (Constant (false )))
4240
4246
}
4247
+
4248
+ def expectArgsNum (expectedNum : Int ) : Unit =
4249
+ // We can use assert instead of a compiler type error because this error should not
4250
+ // occur since the type signature of the operation enforces the proper number of args.
4251
+ assert(args.length == expectedNum, s " Type operation expects $expectedNum arguments but found ${args.length}" )
4252
+
4241
4253
def natValue (tp : Type ): Option [Int ] = intValue(tp).filter(n => n >= 0 && n < Int .MaxValue )
4242
4254
4255
+ // Runs the op and returns the result as a constant type.
4256
+ // If the op throws an exception, then this exception is converted into a type error.
4257
+ def runConstantOp (op : => Any ): Type =
4258
+ val result = try {
4259
+ op
4260
+ } catch {
4261
+ case e : Throwable =>
4262
+ throw new TypeError (e.getMessage)
4263
+ }
4264
+ ConstantType (Constant (result))
4265
+
4243
4266
def constantFold1 [T ](extractor : Type => Option [T ], op : T => Any ): Option [Type ] =
4244
- extractor(args.head).map(a => ConstantType (Constant (op(a))))
4267
+ expectArgsNum(1 )
4268
+ extractor(args.head).map(a => runConstantOp(op(a)))
4245
4269
4246
4270
def constantFold2 [T ](extractor : Type => Option [T ], op : (T , T ) => Any ): Option [Type ] =
4247
4271
constantFold2AB(extractor, extractor, op)
4248
4272
4249
4273
def constantFold2AB [TA , TB ](extractorA : Type => Option [TA ], extractorB : Type => Option [TB ], op : (TA , TB ) => Any ): Option [Type ] =
4274
+ expectArgsNum(2 )
4250
4275
for {
4251
- a <- extractorA(args.head )
4252
- b <- extractorB(args.last )
4253
- } yield ConstantType ( Constant ( op(a, b) ))
4276
+ a <- extractorA(args( 0 ) )
4277
+ b <- extractorB(args( 1 ) )
4278
+ } yield runConstantOp( op(a, b))
4254
4279
4255
4280
def constantFold3 [TA , TB , TC ](
4256
4281
extractorA : Type => Option [TA ],
4257
4282
extractorB : Type => Option [TB ],
4258
4283
extractorC : Type => Option [TC ],
4259
4284
op : (TA , TB , TC ) => Any
4260
4285
): Option [Type ] =
4286
+ expectArgsNum(3 )
4261
4287
for {
4262
- a <- extractorA(args.head )
4288
+ a <- extractorA(args( 0 ) )
4263
4289
b <- extractorB(args(1 ))
4264
- c <- extractorC(args.last )
4265
- } yield ConstantType ( Constant ( op(a, b, c) ))
4290
+ c <- extractorC(args( 2 ) )
4291
+ } yield runConstantOp( op(a, b, c))
4266
4292
4267
4293
trace(i " compiletime constant fold $this" , typr, show = true ) {
4268
4294
val name = tycon.symbol.name
4269
4295
val owner = tycon.symbol.owner
4270
- val nArgs = args.length
4271
4296
val constantType =
4272
4297
if (defn.isCompiletime_S(tycon.symbol)) {
4273
- if (nArgs == 1 ) constantFold1(natValue, _ + 1 )
4274
- else None
4298
+ constantFold1(natValue, _ + 1 )
4275
4299
} else if (owner == defn.CompiletimeOpsAnyModuleClass ) name match {
4276
- case tpnme.Equals if nArgs == 2 => constantFold2(constValue, _ == _)
4277
- case tpnme.NotEquals if nArgs == 2 => constantFold2(constValue, _ != _)
4278
- case tpnme.ToString if nArgs == 1 => constantFold1(constValue, _.toString)
4279
- case tpnme.IsConst if nArgs == 1 => isConst
4300
+ case tpnme.Equals => constantFold2(constValue, _ == _)
4301
+ case tpnme.NotEquals => constantFold2(constValue, _ != _)
4302
+ case tpnme.ToString => constantFold1(constValue, _.toString)
4303
+ case tpnme.IsConst => isConst(args.head)
4280
4304
case _ => None
4281
4305
} else if (owner == defn.CompiletimeOpsIntModuleClass ) name match {
4282
- case tpnme.Abs if nArgs == 1 => constantFold1(intValue, _.abs)
4283
- case tpnme.Negate if nArgs == 1 => constantFold1(intValue, x => - x)
4306
+ case tpnme.Abs => constantFold1(intValue, _.abs)
4307
+ case tpnme.Negate => constantFold1(intValue, x => - x)
4284
4308
// ToString is deprecated for ops.int, and moved to ops.any
4285
- case tpnme.ToString if nArgs == 1 => constantFold1(intValue, _.toString)
4286
- case tpnme.Plus if nArgs == 2 => constantFold2(intValue, _ + _)
4287
- case tpnme.Minus if nArgs == 2 => constantFold2(intValue, _ - _)
4288
- case tpnme.Times if nArgs == 2 => constantFold2(intValue, _ * _)
4289
- case tpnme.Div if nArgs == 2 => constantFold2(intValue, {
4290
- case (_, 0 ) => throw new TypeError (" Division by 0" )
4291
- case (a, b) => a / b
4292
- })
4293
- case tpnme.Mod if nArgs == 2 => constantFold2(intValue, {
4294
- case (_, 0 ) => throw new TypeError (" Modulo by 0" )
4295
- case (a, b) => a % b
4296
- })
4297
- case tpnme.Lt if nArgs == 2 => constantFold2(intValue, _ < _)
4298
- case tpnme.Gt if nArgs == 2 => constantFold2(intValue, _ > _)
4299
- case tpnme.Ge if nArgs == 2 => constantFold2(intValue, _ >= _)
4300
- case tpnme.Le if nArgs == 2 => constantFold2(intValue, _ <= _)
4301
- case tpnme.Xor if nArgs == 2 => constantFold2(intValue, _ ^ _)
4302
- case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(intValue, _ & _)
4303
- case tpnme.BitwiseOr if nArgs == 2 => constantFold2(intValue, _ | _)
4304
- case tpnme.ASR if nArgs == 2 => constantFold2(intValue, _ >> _)
4305
- case tpnme.LSL if nArgs == 2 => constantFold2(intValue, _ << _)
4306
- case tpnme.LSR if nArgs == 2 => constantFold2(intValue, _ >>> _)
4307
- case tpnme.Min if nArgs == 2 => constantFold2(intValue, _ min _)
4308
- case tpnme.Max if nArgs == 2 => constantFold2(intValue, _ max _)
4309
- case tpnme.NumberOfLeadingZeros if nArgs == 1 => constantFold1(intValue, Integer .numberOfLeadingZeros(_))
4310
- case tpnme.ToLong if nArgs == 1 => constantFold1(intValue, _.toLong)
4311
- case tpnme.ToFloat if nArgs == 1 => constantFold1(intValue, _.toFloat)
4312
- case tpnme.ToDouble if nArgs == 1 => constantFold1(intValue, _.toDouble)
4309
+ case tpnme.ToString => constantFold1(intValue, _.toString)
4310
+ case tpnme.Plus => constantFold2(intValue, _ + _)
4311
+ case tpnme.Minus => constantFold2(intValue, _ - _)
4312
+ case tpnme.Times => constantFold2(intValue, _ * _)
4313
+ case tpnme.Div => constantFold2(intValue, _ / _)
4314
+ case tpnme.Mod => constantFold2(intValue, _ % _)
4315
+ case tpnme.Lt => constantFold2(intValue, _ < _)
4316
+ case tpnme.Gt => constantFold2(intValue, _ > _)
4317
+ case tpnme.Ge => constantFold2(intValue, _ >= _)
4318
+ case tpnme.Le => constantFold2(intValue, _ <= _)
4319
+ case tpnme.Xor => constantFold2(intValue, _ ^ _)
4320
+ case tpnme.BitwiseAnd => constantFold2(intValue, _ & _)
4321
+ case tpnme.BitwiseOr => constantFold2(intValue, _ | _)
4322
+ case tpnme.ASR => constantFold2(intValue, _ >> _)
4323
+ case tpnme.LSL => constantFold2(intValue, _ << _)
4324
+ case tpnme.LSR => constantFold2(intValue, _ >>> _)
4325
+ case tpnme.Min => constantFold2(intValue, _ min _)
4326
+ case tpnme.Max => constantFold2(intValue, _ max _)
4327
+ case tpnme.NumberOfLeadingZeros => constantFold1(intValue, Integer .numberOfLeadingZeros(_))
4328
+ case tpnme.ToLong => constantFold1(intValue, _.toLong)
4329
+ case tpnme.ToFloat => constantFold1(intValue, _.toFloat)
4330
+ case tpnme.ToDouble => constantFold1(intValue, _.toDouble)
4313
4331
case _ => None
4314
4332
} else if (owner == defn.CompiletimeOpsLongModuleClass ) name match {
4315
- case tpnme.Abs if nArgs == 1 => constantFold1(longValue, _.abs)
4316
- case tpnme.Negate if nArgs == 1 => constantFold1(longValue, x => - x)
4317
- case tpnme.Plus if nArgs == 2 => constantFold2(longValue, _ + _)
4318
- case tpnme.Minus if nArgs == 2 => constantFold2(longValue, _ - _)
4319
- case tpnme.Times if nArgs == 2 => constantFold2(longValue, _ * _)
4320
- case tpnme.Div if nArgs == 2 => constantFold2(longValue, {
4321
- case (_, 0L ) => throw new TypeError (" Division by 0" )
4322
- case (a, b) => a / b
4323
- })
4324
- case tpnme.Mod if nArgs == 2 => constantFold2(longValue, {
4325
- case (_, 0L ) => throw new TypeError (" Modulo by 0" )
4326
- case (a, b) => a % b
4327
- })
4328
- case tpnme.Lt if nArgs == 2 => constantFold2(longValue, _ < _)
4329
- case tpnme.Gt if nArgs == 2 => constantFold2(longValue, _ > _)
4330
- case tpnme.Ge if nArgs == 2 => constantFold2(longValue, _ >= _)
4331
- case tpnme.Le if nArgs == 2 => constantFold2(longValue, _ <= _)
4332
- case tpnme.Xor if nArgs == 2 => constantFold2(longValue, _ ^ _)
4333
- case tpnme.BitwiseAnd if nArgs == 2 => constantFold2(longValue, _ & _)
4334
- case tpnme.BitwiseOr if nArgs == 2 => constantFold2(longValue, _ | _)
4335
- case tpnme.ASR if nArgs == 2 => constantFold2(longValue, _ >> _)
4336
- case tpnme.LSL if nArgs == 2 => constantFold2(longValue, _ << _)
4337
- case tpnme.LSR if nArgs == 2 => constantFold2(longValue, _ >>> _)
4338
- case tpnme.Min if nArgs == 2 => constantFold2(longValue, _ min _)
4339
- case tpnme.Max if nArgs == 2 => constantFold2(longValue, _ max _)
4340
- case tpnme.NumberOfLeadingZeros if nArgs == 1 =>
4333
+ case tpnme.Abs => constantFold1(longValue, _.abs)
4334
+ case tpnme.Negate => constantFold1(longValue, x => - x)
4335
+ case tpnme.Plus => constantFold2(longValue, _ + _)
4336
+ case tpnme.Minus => constantFold2(longValue, _ - _)
4337
+ case tpnme.Times => constantFold2(longValue, _ * _)
4338
+ case tpnme.Div => constantFold2(longValue, _ / _)
4339
+ case tpnme.Mod => constantFold2(longValue, _ % _)
4340
+ case tpnme.Lt => constantFold2(longValue, _ < _)
4341
+ case tpnme.Gt => constantFold2(longValue, _ > _)
4342
+ case tpnme.Ge => constantFold2(longValue, _ >= _)
4343
+ case tpnme.Le => constantFold2(longValue, _ <= _)
4344
+ case tpnme.Xor => constantFold2(longValue, _ ^ _)
4345
+ case tpnme.BitwiseAnd => constantFold2(longValue, _ & _)
4346
+ case tpnme.BitwiseOr => constantFold2(longValue, _ | _)
4347
+ case tpnme.ASR => constantFold2(longValue, _ >> _)
4348
+ case tpnme.LSL => constantFold2(longValue, _ << _)
4349
+ case tpnme.LSR => constantFold2(longValue, _ >>> _)
4350
+ case tpnme.Min => constantFold2(longValue, _ min _)
4351
+ case tpnme.Max => constantFold2(longValue, _ max _)
4352
+ case tpnme.NumberOfLeadingZeros =>
4341
4353
constantFold1(longValue, java.lang.Long .numberOfLeadingZeros(_))
4342
- case tpnme.ToInt if nArgs == 1 => constantFold1(longValue, _.toInt)
4343
- case tpnme.ToFloat if nArgs == 1 => constantFold1(longValue, _.toFloat)
4344
- case tpnme.ToDouble if nArgs == 1 => constantFold1(longValue, _.toDouble)
4354
+ case tpnme.ToInt => constantFold1(longValue, _.toInt)
4355
+ case tpnme.ToFloat => constantFold1(longValue, _.toFloat)
4356
+ case tpnme.ToDouble => constantFold1(longValue, _.toDouble)
4345
4357
case _ => None
4346
4358
} else if (owner == defn.CompiletimeOpsFloatModuleClass ) name match {
4347
- case tpnme.Abs if nArgs == 1 => constantFold1(floatValue, _.abs)
4348
- case tpnme.Negate if nArgs == 1 => constantFold1(floatValue, x => - x)
4349
- case tpnme.Plus if nArgs == 2 => constantFold2(floatValue, _ + _)
4350
- case tpnme.Minus if nArgs == 2 => constantFold2(floatValue, _ - _)
4351
- case tpnme.Times if nArgs == 2 => constantFold2(floatValue, _ * _)
4352
- case tpnme.Div if nArgs == 2 => constantFold2(floatValue, _ / _)
4353
- case tpnme.Mod if nArgs == 2 => constantFold2(floatValue, _ % _)
4354
- case tpnme.Lt if nArgs == 2 => constantFold2(floatValue, _ < _)
4355
- case tpnme.Gt if nArgs == 2 => constantFold2(floatValue, _ > _)
4356
- case tpnme.Ge if nArgs == 2 => constantFold2(floatValue, _ >= _)
4357
- case tpnme.Le if nArgs == 2 => constantFold2(floatValue, _ <= _)
4358
- case tpnme.Min if nArgs == 2 => constantFold2(floatValue, _ min _)
4359
- case tpnme.Max if nArgs == 2 => constantFold2(floatValue, _ max _)
4360
- case tpnme.ToInt if nArgs == 1 => constantFold1(floatValue, _.toInt)
4361
- case tpnme.ToLong if nArgs == 1 => constantFold1(floatValue, _.toLong)
4362
- case tpnme.ToDouble if nArgs == 1 => constantFold1(floatValue, _.toDouble)
4359
+ case tpnme.Abs => constantFold1(floatValue, _.abs)
4360
+ case tpnme.Negate => constantFold1(floatValue, x => - x)
4361
+ case tpnme.Plus => constantFold2(floatValue, _ + _)
4362
+ case tpnme.Minus => constantFold2(floatValue, _ - _)
4363
+ case tpnme.Times => constantFold2(floatValue, _ * _)
4364
+ case tpnme.Div => constantFold2(floatValue, _ / _)
4365
+ case tpnme.Mod => constantFold2(floatValue, _ % _)
4366
+ case tpnme.Lt => constantFold2(floatValue, _ < _)
4367
+ case tpnme.Gt => constantFold2(floatValue, _ > _)
4368
+ case tpnme.Ge => constantFold2(floatValue, _ >= _)
4369
+ case tpnme.Le => constantFold2(floatValue, _ <= _)
4370
+ case tpnme.Min => constantFold2(floatValue, _ min _)
4371
+ case tpnme.Max => constantFold2(floatValue, _ max _)
4372
+ case tpnme.ToInt => constantFold1(floatValue, _.toInt)
4373
+ case tpnme.ToLong => constantFold1(floatValue, _.toLong)
4374
+ case tpnme.ToDouble => constantFold1(floatValue, _.toDouble)
4363
4375
case _ => None
4364
4376
} else if (owner == defn.CompiletimeOpsDoubleModuleClass ) name match {
4365
- case tpnme.Abs if nArgs == 1 => constantFold1(doubleValue, _.abs)
4366
- case tpnme.Negate if nArgs == 1 => constantFold1(doubleValue, x => - x)
4367
- case tpnme.Plus if nArgs == 2 => constantFold2(doubleValue, _ + _)
4368
- case tpnme.Minus if nArgs == 2 => constantFold2(doubleValue, _ - _)
4369
- case tpnme.Times if nArgs == 2 => constantFold2(doubleValue, _ * _)
4370
- case tpnme.Div if nArgs == 2 => constantFold2(doubleValue, _ / _)
4371
- case tpnme.Mod if nArgs == 2 => constantFold2(doubleValue, _ % _)
4372
- case tpnme.Lt if nArgs == 2 => constantFold2(doubleValue, _ < _)
4373
- case tpnme.Gt if nArgs == 2 => constantFold2(doubleValue, _ > _)
4374
- case tpnme.Ge if nArgs == 2 => constantFold2(doubleValue, _ >= _)
4375
- case tpnme.Le if nArgs == 2 => constantFold2(doubleValue, _ <= _)
4376
- case tpnme.Min if nArgs == 2 => constantFold2(doubleValue, _ min _)
4377
- case tpnme.Max if nArgs == 2 => constantFold2(doubleValue, _ max _)
4378
- case tpnme.ToInt if nArgs == 1 => constantFold1(doubleValue, _.toInt)
4379
- case tpnme.ToLong if nArgs == 1 => constantFold1(doubleValue, _.toLong)
4380
- case tpnme.ToFloat if nArgs == 1 => constantFold1(doubleValue, _.toFloat)
4377
+ case tpnme.Abs => constantFold1(doubleValue, _.abs)
4378
+ case tpnme.Negate => constantFold1(doubleValue, x => - x)
4379
+ case tpnme.Plus => constantFold2(doubleValue, _ + _)
4380
+ case tpnme.Minus => constantFold2(doubleValue, _ - _)
4381
+ case tpnme.Times => constantFold2(doubleValue, _ * _)
4382
+ case tpnme.Div => constantFold2(doubleValue, _ / _)
4383
+ case tpnme.Mod => constantFold2(doubleValue, _ % _)
4384
+ case tpnme.Lt => constantFold2(doubleValue, _ < _)
4385
+ case tpnme.Gt => constantFold2(doubleValue, _ > _)
4386
+ case tpnme.Ge => constantFold2(doubleValue, _ >= _)
4387
+ case tpnme.Le => constantFold2(doubleValue, _ <= _)
4388
+ case tpnme.Min => constantFold2(doubleValue, _ min _)
4389
+ case tpnme.Max => constantFold2(doubleValue, _ max _)
4390
+ case tpnme.ToInt => constantFold1(doubleValue, _.toInt)
4391
+ case tpnme.ToLong => constantFold1(doubleValue, _.toLong)
4392
+ case tpnme.ToFloat => constantFold1(doubleValue, _.toFloat)
4381
4393
case _ => None
4382
4394
} else if (owner == defn.CompiletimeOpsStringModuleClass ) name match {
4383
- case tpnme.Plus if nArgs == 2 => constantFold2(stringValue, _ + _)
4384
- case tpnme.Length if nArgs == 1 => constantFold1(stringValue, _.length)
4385
- case tpnme.Matches if nArgs == 2 => constantFold2(stringValue, _ matches _)
4386
- case tpnme.Substring if nArgs == 3 =>
4395
+ case tpnme.Plus => constantFold2(stringValue, _ + _)
4396
+ case tpnme.Length => constantFold1(stringValue, _.length)
4397
+ case tpnme.Matches => constantFold2(stringValue, _ matches _)
4398
+ case tpnme.Substring =>
4387
4399
constantFold3(stringValue, intValue, intValue, (s, b, e) => s.substring(b, e))
4388
4400
case _ => None
4389
4401
} else if (owner == defn.CompiletimeOpsBooleanModuleClass ) name match {
4390
- case tpnme.Not if nArgs == 1 => constantFold1(boolValue, x => ! x)
4391
- case tpnme.And if nArgs == 2 => constantFold2(boolValue, _ && _)
4392
- case tpnme.Or if nArgs == 2 => constantFold2(boolValue, _ || _)
4393
- case tpnme.Xor if nArgs == 2 => constantFold2(boolValue, _ ^ _)
4402
+ case tpnme.Not => constantFold1(boolValue, x => ! x)
4403
+ case tpnme.And => constantFold2(boolValue, _ && _)
4404
+ case tpnme.Or => constantFold2(boolValue, _ || _)
4405
+ case tpnme.Xor => constantFold2(boolValue, _ ^ _)
4394
4406
case _ => None
4395
4407
} else None
4396
4408
0 commit comments