File tree 1 file changed +24
-1
lines changed
1 file changed +24
-1
lines changed Original file line number Diff line number Diff line change @@ -412,7 +412,30 @@ def test_ScalarLoop_while():
412
412
np .testing .assert_allclose (res [1 ], np .array (expected [1 ]))
413
413
414
414
415
- def test_ScalarLoop_Elemwise ():
415
+ def test_ScalarLoop_Elemwise_single_carries ():
416
+ n_steps = int64 ("n_steps" )
417
+ x0 = float64 ("x0" )
418
+ x = x0 * 2
419
+ until = x >= 10
420
+
421
+ scalarop = ScalarLoop (init = [x0 ], update = [x ], until = until )
422
+ op = Elemwise (scalarop )
423
+
424
+ n_steps = pt .scalar ("n_steps" , dtype = "int32" )
425
+ x0 = pt .vector ("x0" , dtype = "float32" )
426
+ state , done = op (n_steps , x0 )
427
+
428
+ f = FunctionGraph ([n_steps , x0 ], [state , done ])
429
+ args = [
430
+ np .array (10 ).astype ("int32" ),
431
+ np .arange (0 , 5 ).astype ("float32" ),
432
+ ]
433
+ compare_pytorch_and_py (
434
+ f , args , assert_fn = partial (np .testing .assert_allclose , rtol = 1e-6 )
435
+ )
436
+
437
+
438
+ def test_ScalarLoop_Elemwise_multi_carries ():
416
439
n_steps = int64 ("n_steps" )
417
440
x0 = float64 ("x0" )
418
441
x1 = float64 ("x1" )
You can’t perform that action at this time.
0 commit comments