Skip to content

Commit 007abae

Browse files
author
Ian Schweer
committed
Add single carry test
1 parent 893cd96 commit 007abae

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

tests/link/pytorch/test_basic.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,30 @@ def test_ScalarLoop_while():
412412
np.testing.assert_allclose(res[1], np.array(expected[1]))
413413

414414

415-
def test_ScalarLoop_Elemwise():
415+
def test_ScalarLoop_Elemwise_single_carries():
416+
n_steps = int64("n_steps")
417+
x0 = float64("x0")
418+
x = x0 * 2
419+
until = x >= 10
420+
421+
scalarop = ScalarLoop(init=[x0], update=[x], until=until)
422+
op = Elemwise(scalarop)
423+
424+
n_steps = pt.scalar("n_steps", dtype="int32")
425+
x0 = pt.vector("x0", dtype="float32")
426+
state, done = op(n_steps, x0)
427+
428+
f = FunctionGraph([n_steps, x0], [state, done])
429+
args = [
430+
np.array(10).astype("int32"),
431+
np.arange(0, 5).astype("float32"),
432+
]
433+
compare_pytorch_and_py(
434+
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
435+
)
436+
437+
438+
def test_ScalarLoop_Elemwise_multi_carries():
416439
n_steps = int64("n_steps")
417440
x0 = float64("x0")
418441
x1 = float64("x1")

0 commit comments

Comments
 (0)