Rewrite scalar dot as multiplication #1205

Open
@ricardoV94

Description

In #1178 we rewrote batched dots that are really just multiplications into plain multiplication, but left core dots untouched because those use BLAS operations (whether BLAS is worth it there is a separate question). There is one case where it is definitely not worth it: scalar multiplication.

The following graph should definitely be simplified:

import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(1, 1))
y = pt.tensor("y", shape=(1, 1))
out = x @ y
pytensor.function([x, y], out).dprint()
CGer{non-destructive} [id A] 2
 ├─ [[0.]] [id B]
 ├─ 1.0 [id C]
 ├─ DropDims{axis=1} [id D] 1
 │  └─ x [id E]
 └─ DropDims{axis=0} [id F] 0
    └─ y [id G]

Or, without the BLAS rewrites:

pytensor.function([x, y], out, mode="FAST_COMPILE").dprint()
Dot22 [id A] 0
 ├─ x [id B]
 └─ y [id C]

Both graphs should be rewritten to a plain mul, because mul can be fused with other Elemwise operations (and calling BLAS for a scalar product is the silliest thing ever).
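To make the equivalence concrete, here is a minimal NumPy sketch (not the PyTensor rewrite itself) checking that a dot of two `(1, 1)` operands gives the same result as elementwise multiplication. The helper name `scalar_dot_as_mul` is made up for illustration and is not part of PyTensor.

```python
import numpy as np


def scalar_dot_as_mul(x, y):
    """Hypothetical helper: compute x @ y for (1, 1) operands with a plain multiply."""
    # With a single element on the contracted axis, the dot's sum has one term,
    # so the product x[0, 0] * y[0, 0] is the whole result.
    assert x.shape == (1, 1) and y.shape == (1, 1)
    return x * y


rng = np.random.default_rng(0)
x = rng.normal(size=(1, 1))
y = rng.normal(size=(1, 1))

dot_result = x @ y                      # goes through the matmul machinery
mul_result = scalar_dot_as_mul(x, y)    # elementwise, no reduction needed

print(np.allclose(dot_result, mul_result))
```

The same reasoning is what would justify a graph rewrite here: when both static shapes are known to be `(1, 1)`, the reduction in the dot is over a single element, so replacing it with an elementwise product should not change the output value or shape.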
