RFC: add API to enforce a specified dimensionality #494

Open
@arogozhnikov

Previous discussion: #97
The resolution there was to ignore the case of data-dependent (unknown) dimensionalities.

I suggest handling this case by adding a new function to the API with a specific interface.

This makes it possible to support variable dimensionality in a large number of practical cases.
Additionally, it gives partial control over dimension sizes that are currently unknown.

Example 1:

x, [d1, d2, d3] = xp.enforce_shape(x, [None, None, None])

This enforces the dimensionality to be 3 and fails for inputs of any other dimensionality.
If the dimensionality is not yet known, it fails at runtime at this operation.
The returned d1, d2 and d3 are integers or symbols, never None.

The returned x is either the same tensor or a tensor with shape information attached.
The main reason for returning a tensor is that the operation then cannot be pruned from the graph when the resulting shape is not used to compute the result.

Example 2:

x, [d1, d2, d3] = xp.enforce_shape(x, [d1_input, None, 4])

This expects the input to have 3 dimensions: the first axis must match d1_input (which may be a symbol), and the last axis must have length 4.

Example 3:

x, [d1, d2, (axes, n_elements), d3] = xp.enforce_shape(x, [1, None, ..., 3])

This expects the input to have 3 or more dimensions. The first axis has length 1, the second any length, and the last length 3.

Note that for the ellipsis a tuple of two objects is returned: axes corresponds to x.shape[2:-1], and n_elements is the product of the elements of axes.

axes is either a tuple of Union[int, Symbol] (if the dimensionality is known)
or a Symbol for a tuple of unknown size (if the dimensionality is unknown), the same type as is used to represent an unknown shape.

n_elements is either an int (if the sizes of all axes are known) or a symbol for an axis.
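
To make the ellipsis part concrete, here is a small sketch of what axes and n_elements would be for an array whose shape is fully known. NumPy stands in for xp, and the example shape is an assumption chosen for illustration. The empty-ellipsis case at the end follows from the usual convention that an empty product is 1; the proposal does not state this explicitly.

```python
import math
import numpy as np

x = np.zeros((1, 5, 7, 2, 3))

# Pattern [1, None, ..., 3] names two leading axes and one trailing axis,
# so the ellipsis covers x.shape[2:-1].
axes = x.shape[2:-1]
n_elements = math.prod(axes)
print(axes, n_elements)

# With exactly 3 dimensions the ellipsis matches no axes at all.
y = np.zeros((1, 5, 3))
empty_axes = y.shape[2:-1]
empty_n_elements = math.prod(empty_axes)  # empty product
print(empty_axes, empty_n_elements)
```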

Signature

Tensor, List[Union[int, None, AxisSymbol, ellipsis]] -> Tuple[Tensor, List[Union[int, AxisSymbol, EllipsisTuple]]] 
EllipsisTuple = Tuple[TupleSymbol, AxisSymbol] 
              | Tuple[Tuple[SizeOrSymbol, ...], SizeOrSymbol] 
SizeOrSymbol = Union[int, AxisSymbol]

I see this as a good intermediate point: sufficient to deal with multiple practical cases while requiring framework developers to maintain very little.

Downstream package dev perspective

  • There are cases when a dimensionality should be enforced (e.g. 1d/2d). This is covered by the proposed function.
  • There are cases when axes can be split into those that matter for the operation and those that do not.
    There is a generic way to deal with these cases: first reduce to a fixed dimensionality, then convert back afterwards.
    Example: the last dimension is RGB and should be converted to HSV, assuming we have a function that implements the conversion for a 2d array.
x_rgb, [(axes, n_elements), last_axis] = xp.enforce_shape(x_rgb, [..., None])
x_rgb_2d = xp.reshape(x_rgb, (n_elements, last_axis))
x_hsv_2d = convert_rgb_to_hsv(x_rgb_2d) # or any other function that works with 'fixed' dimensionality
result = xp.reshape(x_hsv_2d, axes + (last_axis,))
  • If there is a need to e.g. constrain the last or first dimension, this now becomes possible with graph-based frameworks too. The standard way today is to add validations like tf.assert(...) and bind them into the graph, which can't be done in a cross-framework manner.
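
The reduce-then-restore pattern from the second bullet can be tried out today on eager arrays. The sketch below is only an illustration: NumPy stands in for xp, plain shape unpacking stands in for xp.enforce_shape(x_rgb, [..., None]), the 2d converter is a hypothetical row-by-row implementation built on the standard-library colorsys module, and the wrapper name rgb_to_hsv_any_rank is made up for this example.

```python
import colorsys
import math
import numpy as np


def convert_rgb_to_hsv(rgb_2d):
    """Hypothetical 2d-only converter: each row is an (r, g, b) triple in [0, 1]."""
    return np.array([colorsys.rgb_to_hsv(*row) for row in rgb_2d])


def rgb_to_hsv_any_rank(x_rgb):
    # Stand-in for: x_rgb, [(axes, n_elements), last_axis] = xp.enforce_shape(x_rgb, [..., None])
    *leading, last_axis = x_rgb.shape
    axes = tuple(leading)
    n_elements = math.prod(axes)

    # Collapse all leading axes, apply the fixed-dimensionality function, restore the shape.
    x_rgb_2d = np.reshape(x_rgb, (n_elements, last_axis))
    x_hsv_2d = convert_rgb_to_hsv(x_rgb_2d)
    return np.reshape(x_hsv_2d, axes + (last_axis,))


image = np.zeros((4, 5, 3))
image[..., 0] = 1.0  # every pixel is pure red
hsv = rgb_to_hsv_any_rank(image)
print(hsv.shape)
```

The same wrapper accepts a single pixel of shape (3,) or a batch of images of shape (n, h, w, 3), because only the last axis is treated as significant.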

Frameworks perspective

While the function is novel, it does not introduce any new entities: symbols for axes and symbols for shapes (like TensorShape) should exist anyway.
The function's behavior overlaps with and extends tf.Tensor.set_shape by supporting an ellipsis.

For frameworks without data-dependent shapes, this is straightforward to implement on top of the already exposed shape property.
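
As a rough illustration of that last point, here is a minimal eager-only sketch built purely on the shape property, with NumPy standing in for xp. It is not a reference implementation. It assumes that every size is a known Python int, so AxisSymbol and TupleSymbol are not modelled, and it raises ValueError on a mismatch, which is a choice made for this sketch rather than part of the proposal.

```python
import math
import numpy as np


def enforce_shape(x, pattern):
    """Eager-only sketch: validate x.shape against pattern, return (x, sizes).

    Pattern entries: int (exact size), None (any size), ... (zero or more axes).
    Symbols for unknown sizes are not modelled here.
    """
    shape = tuple(x.shape)
    pattern = list(pattern)
    ellipsis_at = [i for i, p in enumerate(pattern) if p is Ellipsis]
    if len(ellipsis_at) > 1:
        raise ValueError("pattern may contain at most one ellipsis")

    if ellipsis_at:
        head, tail = pattern[:ellipsis_at[0]], pattern[ellipsis_at[0] + 1:]
        if len(shape) < len(head) + len(tail):
            raise ValueError(
                f"expected at least {len(head) + len(tail)} dimensions, got {len(shape)}"
            )
    else:
        head, tail = pattern, []
        if len(shape) != len(head):
            raise ValueError(f"expected {len(head)} dimensions, got {len(shape)}")

    tail_start = len(shape) - len(tail)
    head_sizes = shape[:len(head)]
    middle = shape[len(head):tail_start]  # axes matched by the ellipsis, if any
    tail_sizes = shape[tail_start:]

    # Compare every explicitly specified axis with the actual size.
    checks = list(zip(range(len(head)), head, head_sizes))
    checks += list(zip(range(tail_start, len(shape)), tail, tail_sizes))
    for axis, want, got in checks:
        if want is not None and want != got:
            raise ValueError(f"axis {axis}: expected size {want}, got {got}")

    sizes = list(head_sizes)
    if ellipsis_at:
        sizes.append((middle, math.prod(middle)))
    sizes.extend(tail_sizes)
    return x, sizes


x = np.zeros((1, 5, 7, 2, 3))
x, [d1, d2, (axes, n_elements), d3] = enforce_shape(x, [1, None, ..., 3])
print(d1, d2, axes, n_elements, d3)
```

A graph-based framework would additionally need to return symbols where sizes are unknown and to register the check as part of the graph, which this sketch does not attempt.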

Labels: API extension, Needs Discussion, RFC
