Skip to content

Implement infer_shape automatically from gufunc_signature #1257

Open
@ricardoV94

Description

@ricardoV94

Description

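For Ops that define a `gufunc_signature`, we can implement `infer_shape` automatically. Any output core dimension that also appears among an input's core dimensions can be read directly from that input's shape. The hand-written `infer_shape` below is one example: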

class Cholesky(Op):
    """Op with core signature ``(m,m)->(m,m)``: one square matrix in, one same-shape matrix out.

    Presumably computes a Cholesky factorization (inferred from the name;
    the computation itself lives outside this snippet).

    Parameters
    ----------
    lower : bool
        Stored as ``self.lower``; presumably selects the lower- vs
        upper-triangular factor -- confirm against ``perform``.
    check_finite : bool
        Stored as ``self.check_finite``.
    on_error : {"raise", "nan"}
        Stored as ``self.on_error``. Any other value raises ``ValueError``.
    overwrite_a : bool
        When True, ``destroy_map`` is set to ``{0: [0]}``, declaring that
        output 0 may overwrite (destroy) input 0.
    """

    # TODO: LAPACK wrapper with in-place behavior, for solve also
    __props__ = ("lower", "check_finite", "on_error", "overwrite_a")
    gufunc_signature = "(m,m)->(m,m)"

    def __init__(
        self,
        *,
        lower: bool = True,
        check_finite: bool = True,
        on_error: Literal["raise", "nan"] = "raise",
        overwrite_a: bool = False,
    ):
        self.lower = lower
        self.check_finite = check_finite
        if on_error not in ("raise", "nan"):
            # Fixed: the message previously contained a stray doubled quote (""nan").
            raise ValueError('on_error must be one of "raise" or "nan"')
        self.on_error = on_error
        self.overwrite_a = overwrite_a
        if self.overwrite_a:
            # Output 0 is allowed to reuse/destroy the buffer of input 0.
            self.destroy_map = {0: [0]}

    def infer_shape(self, fgraph, node, shapes):
        """Return the output shape, which equals the input shape.

        Follows directly from ``gufunc_signature = "(m,m)->(m,m)"``: every
        output dimension is an input dimension. The issue proposes deriving
        this method automatically from the signature.
        """
        return [shapes[0]]

We already do this inside the `Blockwise` wrapper:

# The output dim is the same as another input dim
if dim_name in core_dims:
core_out_shape.append(core_dims[dim_name])

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions