Description
In #1178 we rewrote batched dots that are just multiplications into elementwise multiplications, but left core dots alone because those use BLAS operations (whether that is worth it is a question of its own). There is one case, however, that is definitely not worth it: scalar multiplication, i.e. a dot whose inputs both have shape `(1, 1)`.
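A quick NumPy check of why such a rewrite is safe: when the contracted dimension has length 1, each output entry of the dot is a single product, so broadcasting `*` gives the same result. NumPy is used here purely for illustration; this says nothing about how PyTensor would implement the rewrite.

```python
import numpy as np

rng = np.random.default_rng(0)

# Same static shapes as in the graph below: (1, 1) @ (1, 1).
x = rng.normal(size=(1, 1))
y = rng.normal(size=(1, 1))

# The contracted (inner) dimension has length 1, so the "dot" is one multiplication.
assert np.allclose(x @ y, x * y)

# The same holds more generally: (m, 1) @ (1, n) is an outer product,
# which broadcasting multiplication computes directly.
a = rng.normal(size=(3, 1))
b = rng.normal(size=(1, 4))
assert np.allclose(a @ b, a * b)
```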
The following graph should definitely be simplified:
```python
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(1, 1))
y = pt.tensor("y", shape=(1, 1))
out = x @ y
pytensor.function([x, y], out).dprint()
```
```
CGer{non-destructive} [id A] 2
 ├─ [[0.]] [id B]
 ├─ 1.0 [id C]
 ├─ DropDims{axis=1} [id D] 1
 │  └─ x [id E]
 └─ DropDims{axis=0} [id F] 0
    └─ y [id G]
```
Or, without the BLAS-specific rewrites:
```python
pytensor.function([x, y], out, mode="FAST_COMPILE").dprint()
```

```
Dot22 [id A] 0
 ├─ x [id B]
 └─ y [id C]
```
Both should just be a `mul`, because that can be fused with other Elemwise operations (and calling BLAS for a single multiplication is the silliest thing ever).
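A minimal sketch of the condition such a rewrite might check, assuming only static shapes are available (with `None` marking a dimension unknown at compile time). The helpers `dot_is_elemwise_mul` and `dot_or_mul` are hypothetical names for illustration, not part of PyTensor's API, and NumPy arrays stand in for the graph.

```python
from typing import Optional, Tuple

import numpy as np

# Hypothetical static-shape type: None marks a dimension unknown at compile time.
Shape = Tuple[Optional[int], ...]


def dot_is_elemwise_mul(x_shape: Shape, y_shape: Shape) -> bool:
    """Hypothetical check: True when a 2-D dot x @ y, with x of shape (m, k)
    and y of shape (k, n), contracts over k that is statically known to be 1."""
    if len(x_shape) != 2 or len(y_shape) != 2:
        return False
    # Conservative: require k == 1 to be known on both sides.
    return x_shape[1] == 1 and y_shape[0] == 1


def dot_or_mul(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Numeric stand-in for the rewritten graph: broadcasting mul when allowed."""
    if dot_is_elemwise_mul(x.shape, y.shape):
        return x * y
    return x @ y


# Usage
print(dot_is_elemwise_mul((1, 1), (1, 1)))        # the case from this issue
print(dot_is_elemwise_mul((None, 1), (1, None)))  # outer product with unknown m, n
print(dot_is_elemwise_mul((2, 3), (3, 2)))        # a genuine matrix product
```

Requiring both sides to be statically 1 is deliberately conservative: if only one side were known, a broadcasting `mul` could accept inputs that the dot would reject with a shape error.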