Skip to content

Add sigil for tensors #498

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Oct 3, 2021
129 changes: 129 additions & 0 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -8451,6 +8451,135 @@ defmodule Nx do
)
end

## Sigils

@doc """
A convenient `~M` sigil for building matrices (two-dimensional tensors).

## Examples

Before using sigils, you must first import them:

import Nx, only: [sigil_M: 2]

If you are using Elixir v1.13+, then you can write instead:

import Nx, only: :sigils

Then you use the sigil to create matrices. The sigil:

~M<
-1 0 0 1
0 2 0 0
0 0 3 0
0 0 0 4
>

Is equivalent to:

Nx.tensor([
[-1, 0, 0, 1],
[0, 2, 0, 0],
[0, 0, 3, 0],
[0, 0, 0, 4]
])

If the tensor has any float type, it defaults to f32.
Otherwise, it is s64. If you are using Elixir 1.13+,
you can specify the tensor type as a sigil modifier:

iex> import Nx
iex> ~M[0.1 0.2 0.3 0.4]f16
#Nx.Tensor<
f16[1][4]
[
[0.0999755859375, 0.199951171875, 0.300048828125, 0.39990234375]
]
>

"""
defmacro sigil_M({:<<>>, _meta, [string]}, modifiers) do
string
|> binary_to_numbers
|> numbers_to_tensor(modifiers)
end

@doc """
A convenient `~V` sigil for building vectors (one-dimensional tensors).

## Examples

Before using sigils, you must first import them:

import Nx, only: [sigil_V: 2]

Then you use the sigil to create vectors. The sigil:

~V[-1 0 0 1]

Is equivalent to:

Nx.tensor([-1, 0, 0, 1])

If the tensor has any float type, it defaults to f32.
Otherwise, it is s64. If you are using Elixir 1.13+,
you can specify the tensor type as a sigil modifier:

iex> import Nx
iex> ~V[0.1 0.2 0.3 0.4]f16
#Nx.Tensor<
f16[4]
[0.0999755859375, 0.199951171875, 0.300048828125, 0.39990234375]
>

"""
defmacro sigil_V({:<<>>, _meta, [string]}, modifiers) do
case binary_to_numbers(string) do
[numbers] ->
numbers_to_tensor(numbers, modifiers)

_ ->
raise ArgumentError, "must be one-dimensional"
end
end

defp numbers_to_tensor(numbers, modifiers) do
type =
case modifiers do
[unit | size] ->
Nx.Type.normalize!({List.to_atom([unit]), List.to_integer(size)})

[] ->
Nx.Type.infer(numbers)
end

{shape, binary} = flatten(numbers, type)

quote do
unquote(binary)
|> Nx.from_binary(unquote(type))
|> Nx.reshape(unquote(Macro.escape(shape)))
end
end

defp binary_to_numbers(string) do
for row <- String.split(string, "\n", trim: true) do
row
|> String.split(" ", trim: true)
|> Enum.map(fn str ->
module = if String.contains?(str, "."), do: Float, else: Integer

case module.parse(str) do
{number, ""} ->
number

_ ->
raise ArgumentError, "expected a numerical value for tensor, got #{str}"
end
end)
end
end

## Helpers

defp backend!(backend) when is_atom(backend),
Expand Down
64 changes: 63 additions & 1 deletion nx/test/nx_test.exs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
defmodule NxTest do
use ExUnit.Case, async: true

doctest Nx
doctest Nx, except: [sigil_M: 2, sigil_V: 2]

defp commute(a, b, fun) do
fun.(a, b)
Expand Down Expand Up @@ -1628,4 +1628,66 @@ defmodule NxTest do
)
end
end

describe "sigils" do
test "evaluates to tensor" do
import Nx

assert ~M[-1 2 3 4] == Nx.tensor([[-1, 2, 3, 4]])
assert ~M[1
2
3
4] == Nx.tensor([[1], [2], [3], [4]])
assert ~M[1.0 2 3
11 12 13] == Nx.tensor([[1.0, 2, 3], [11, 12, 13]])

assert ~V[4 3 2 1] == Nx.tensor([4, 3, 2, 1])
end

test "raises when vector has more than one dimension" do
assert_raise(
ArgumentError,
"must be one-dimensional",
fn ->
eval(~S[~V<0 0 0 1
1 0 0 0>])
end
)
end

if Version.match?(System.version(), ">= 1.13.0-dev") do
test "evaluates with proper type" do
assert eval("~M[1 2 3 4]f32") == Nx.tensor([[1, 2, 3, 4]], type: {:f, 32})
assert eval("~M[4 3 2 1]u8") == Nx.tensor([[4, 3, 2, 1]], type: {:u, 8})

assert eval("~V[0 1 0 1]u8") == Nx.tensor([0, 1, 0, 1], type: {:u, 8})
end

test "raises on invalid type" do
assert_raise(
ArgumentError,
"invalid numerical type: {:f, 8} (see Nx.Type docs for all supported types)",
fn ->
eval("~M[1 2 3 4]f8")
end
)
end

test "raises on non-numerical values" do
assert_raise(
ArgumentError,
"expected a numerical value for tensor, got x",
fn ->
eval("~V[1 2 x 4]u8")
end
)
end
end

defp eval(expresion) do
"import Nx; #{expresion}"
|> Code.eval_string()
|> elem(0)
end
end
end
5 changes: 4 additions & 1 deletion torchx/test/torchx/nx_doctest_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ defmodule Torchx.NxDoctestTest do
# to_batched_list - Shape mismatch due to unsupported options in some tests
to_batched_list: 3,
# window_mean - depends on window_sum which is not implemented
window_mean: 3
window_mean: 3,
# require Elixir 1.13+
sigil_M: 2,
sigil_V: 2
]

@rounding_error_doctests [
Expand Down