Add dtype keyword to to_device #647

Closed
@asmeurer

Description

This is tangentially related to the discussion in #645. Right now to_device() doesn't have a dtype parameter, but it could be useful for it to have one. PyTorch's Tensor.to method does have one. I'm not completely clear about cupy, so @leofang would have to comment.

One reason is that certain devices might not support the array's existing dtype. We would also need to specify what happens when the target device doesn't support the requested dtype (likely it should be an error).

This code in scikit-learn is also an example of where this would be useful: https://github.com/scikit-learn/scikit-learn/pull/26315/files/42524bd42900d8ea5f4a334780387a72c6f9580d#diff-86c94a3ca33490c6190f488f5d40b01bf0fd29be36da0b4497ef0da1fda4148a. That code moves an array to a device and converts it to float32 at the same time. It would presumably be more efficient to do this in one step instead of two.
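To make the proposal concrete, here is a rough sketch of the semantics I have in mind. Everything in it is hypothetical: `ToyArray` is a stand-in that only tracks a dtype and a device, the `SUPPORTED` table is made up, `TypeError` is just one possible choice of error, and the existing `stream` keyword is left out for brevity.

```python
from dataclasses import dataclass

# Hypothetical table of the dtypes each device supports; a real library
# would answer this question itself.
SUPPORTED = {"cpu": {"float32", "float64"}, "gpu": {"float32"}}


@dataclass(frozen=True)
class ToyArray:
    """Stand-in for an array object; only tracks dtype and device."""
    data: tuple
    dtype: str
    device: str

    def to_device(self, device, /, *, dtype=None):
        # Proposed behaviour: optional dtype conversion in the same call.
        target = self.dtype if dtype is None else dtype
        if target not in SUPPORTED[device]:
            # Assumed error type; the issue only says this should likely fail.
            raise TypeError(f"device {device!r} does not support dtype {target!r}")
        return ToyArray(self.data, target, device)


x = ToyArray((1.0, 2.0), "float64", "cpu")
y = x.to_device("gpu", dtype="float32")  # one step: move and convert

try:
    x.to_device("gpu")  # float64 is unsupported on the toy "gpu" device
    failed = False
except TypeError:
    failed = True
```

With a keyword like this, the scikit-learn code above could request float32 as part of the transfer rather than converting in a separate call.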
