Skip to content

Commit 44ac66a

Browse files
committed
Add quote patterns regression tests
1 parent 8aae979 commit 44ac66a

File tree

6 files changed

+130
-0
lines changed

6 files changed

+130
-0
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((`x₂`: scala.Int) => `x₂`.>(1)))
2+
Optimized: ls.filter(((x: scala.Int) => ((`x₂`: scala.Int) => `x₂`.<(3)).apply(x).&&(((`x₃`: scala.Int) => `x₃`.>(1)).apply(x))))
3+
Result: List(2)
4+
5+
Original: ls2.filter(((x: scala.Char) => x.<('c'))).filter(((`x₂`: scala.Char) => `x₂`.>('a')))
6+
Optimized: ls2.filter(((x: scala.Char) => ((`x₂`: scala.Char) => `x₂`.<('c')).apply(x).&&(((`x₃`: scala.Char) => `x₃`.>('a')).apply(x))))
7+
Result: List(b)
8+
9+
Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((`x₂`: scala.Int) => `x₂`.>(1))).filter(((`x₃`: scala.Int) => `x₃`.==(2)))
10+
Optimized: ls.filter(((x: scala.Int) => ((`x₂`: scala.Int) => `x₂`.<(3)).apply(x).&&(((`x₃`: scala.Int) => ((`x₄`: scala.Int) => `x₄`.>(1)).apply(`x₃`).&&(((`x₅`: scala.Int) => `x₅`.==(2)).apply(`x₃`))).apply(x))))
11+
Result: List(2)
12+
13+
Original: ls.map[scala.Long](((a: scala.Int) => a.toLong)).map[java.lang.String](((b: scala.Long) => b.toString()))
14+
Optimized: ls.map[java.lang.String](((x: scala.Int) => ((b: scala.Long) => b.toString()).apply(((a: scala.Int) => a.toLong).apply(x))))
15+
Result: List(1, 2, 3)
16+
17+
Original: ls.map[scala.Char](((a: scala.Int) => a.toChar)).map[java.lang.String](((b: scala.Char) => b.toString()))
18+
Optimized: ls.map[java.lang.String](((x: scala.Int) => ((b: scala.Char) => b.toString()).apply(((a: scala.Int) => a.toChar).apply(x))))
19+
Result: List(, , )
20+
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import scala.quoted._
2+
3+
object Macro {
4+
5+
inline def optimize[T](inline x: List[T]): List[T] = ${ Macro.impl[T]('x) }
6+
7+
def impl[T: Type](x: Expr[List[T]])(using Quotes): Expr[List[T]] = {
8+
val res = optimize(x)
9+
'{
10+
val result = $res
11+
val originalCode = ${Expr(x.show)}
12+
val optimizeCode = ${Expr(res.show)}
13+
println("Original: " + originalCode)
14+
println("Optimized: " + optimizeCode)
15+
println("Result: " + result)
16+
println()
17+
result
18+
}
19+
}
20+
21+
def optimize[T: Type](x: Expr[List[T]])(using Quotes): Expr[List[T]] = x match {
22+
case '{ ($ls: List[T]).filter($f).filter($g) } =>
23+
optimize('{ $ls.filter(x => $f(x) && $g(x)) })
24+
25+
case '{ type u; type v; ($ls: List[`u`]).map($f: `u` => `v`).map($g: `v` => T) } =>
26+
optimize('{ $ls.map(x => $g($f(x))) })
27+
28+
case _ => x
29+
}
30+
}
31+
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object Test {
2+
import Macro._
3+
4+
def main(args: Array[String]): Unit = {
5+
val ls = List(1, 2, 3)
6+
val ls2 = List('a', 'b', 'c')
7+
optimize(ls.filter(x => x < 3).filter(x => x > 1))
8+
optimize(ls2.filter(x => x < 'c').filter(x => x > 'a'))
9+
optimize(ls.filter(x => x < 3).filter(x => x > 1).filter(x => x == 2))
10+
optimize(ls.map(a => a.toLong).map(b => b.toString))
11+
optimize(ls.map(a => a.toChar).map(b => b.toString))
12+
}
13+
14+
}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((`x₂`: scala.Int) => `x₂`.>(1)))
2+
Optimized: ls.filter(((x: scala.Int) => ((`x₂`: scala.Int) => `x₂`.<(3)).apply(x).&&(((`x₃`: scala.Int) => `x₃`.>(1)).apply(x))))
3+
Result: List(2)
4+
5+
Original: ls2.filter(((x: scala.Char) => x.<('c'))).filter(((`x₂`: scala.Char) => `x₂`.>('a')))
6+
Optimized: ls2.filter(((x: scala.Char) => ((`x₂`: scala.Char) => `x₂`.<('c')).apply(x).&&(((`x₃`: scala.Char) => `x₃`.>('a')).apply(x))))
7+
Result: List(b)
8+
9+
Original: ls.filter(((x: scala.Int) => x.<(3))).filter(((`x₂`: scala.Int) => `x₂`.>(1))).filter(((`x₃`: scala.Int) => `x₃`.==(2)))
10+
Optimized: ls.filter(((x: scala.Int) => ((`x₂`: scala.Int) => `x₂`.<(3)).apply(x).&&(((`x₃`: scala.Int) => ((`x₄`: scala.Int) => `x₄`.>(1)).apply(`x₃`).&&(((`x₅`: scala.Int) => `x₅`.==(2)).apply(`x₃`))).apply(x))))
11+
Result: List(2)
12+
13+
Original: ls.map[scala.Long](((a: scala.Int) => a.toLong)).map[java.lang.String](((b: scala.Long) => b.toString()))
14+
Optimized: ls.map[java.lang.String](((x: scala.Int) => ((b: scala.Long) => b.toString()).apply(((a: scala.Int) => a.toLong).apply(x))))
15+
Result: List(1, 2, 3)
16+
17+
Original: ls.map[scala.Char](((a: scala.Int) => a.toChar)).map[java.lang.String](((b: scala.Char) => b.toString()))
18+
Optimized: ls.map[java.lang.String](((x: scala.Int) => ((b: scala.Char) => b.toString()).apply(((a: scala.Int) => a.toChar).apply(x))))
19+
Result: List(, , )
20+
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import scala.quoted._
2+
3+
object Macro {
4+
5+
inline def optimize[T](inline x: List[T]): List[T] = ${ Macro.impl[T]('x) }
6+
7+
def impl[T: Type](x: Expr[List[T]])(using Quotes): Expr[List[T]] = {
8+
val res = optimize(x)
9+
'{
10+
val result = $res
11+
val originalCode = ${Expr(x.show)}
12+
val optimizeCode = ${Expr(res.show)}
13+
println("Original: " + originalCode)
14+
println("Optimized: " + optimizeCode)
15+
println("Result: " + result)
16+
println()
17+
result
18+
}
19+
}
20+
21+
def optimize[T: Type](x: Expr[List[T]])(using Quotes): Expr[List[T]] = x match {
22+
case '{ ($ls: List[T]).filter($f).filter($g) } =>
23+
optimize('{ $ls.filter(x => $f(x) && $g(x)) })
24+
25+
case '{ ($ls: List[u]).map[v]($f).map[T]($g) } =>
26+
optimize('{ $ls.map(x => $g($f(x))) })
27+
28+
case _ => x
29+
}
30+
}
31+
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
object Test {
2+
import Macro._
3+
4+
def main(args: Array[String]): Unit = {
5+
val ls = List(1, 2, 3)
6+
val ls2 = List('a', 'b', 'c')
7+
optimize(ls.filter(x => x < 3).filter(x => x > 1))
8+
optimize(ls2.filter(x => x < 'c').filter(x => x > 'a'))
9+
optimize(ls.filter(x => x < 3).filter(x => x > 1).filter(x => x == 2))
10+
optimize(ls.map(a => a.toLong).map(b => b.toString))
11+
optimize(ls.map(a => a.toChar).map(b => b.toString))
12+
}
13+
14+
}

0 commit comments

Comments
 (0)