Skip to content

Commit 42f1905

Browse files
authored
Fix _is_tensorflow_array. (matplotlib#30114)
The previous implementation was clearly wrong (the isinstance check would raise TypeError as the second argument would be a bool), but the tests didn't catch that because the bug led to _is_tensorflow_array returning False, then _unpack_to_numpy returning the original input, and then assert_array_equal implicitly converting `result` by calling `__array__` on it. Fix the test by explicitly checking that `result` is indeed a numpy array, and also fix _is_tensorflow_array with more restrictive exception catching (also applied to _is_torch_array, _is_jax_array, and _is_pandas_dataframe, while we're at it).
1 parent b18407b commit 42f1905

File tree

2 files changed

+43
-38
lines changed

2 files changed

+43
-38
lines changed

lib/matplotlib/cbook.py

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2331,42 +2331,56 @@ def _picklable_class_constructor(mixin_class, fmt, attr_name, base_class):
23312331

23322332

23332333
def _is_torch_array(x):
2334-
"""Check if 'x' is a PyTorch Tensor."""
2334+
"""Return whether *x* is a PyTorch Tensor."""
23352335
try:
2336-
# we're intentionally not attempting to import torch. If somebody
2337-
# has created a torch array, torch should already be in sys.modules
2338-
return isinstance(x, sys.modules['torch'].Tensor)
2339-
except Exception: # TypeError, KeyError, AttributeError, maybe others?
2340-
# we're attempting to access attributes on imported modules which
2341-
# may have arbitrary user code, so we deliberately catch all exceptions
2342-
return False
2336+
# We're intentionally not attempting to import torch. If somebody
2337+
# has created a torch array, torch should already be in sys.modules.
2338+
tp = sys.modules.get("torch").Tensor
2339+
except AttributeError:
2340+
return False # Module not imported or a nonstandard module with no Tensor attr.
2341+
return (isinstance(tp, type) # Just in case it's a very nonstandard module.
2342+
and isinstance(x, tp))
23432343

23442344

23452345
def _is_jax_array(x):
2346-
"""Check if 'x' is a JAX Array."""
2346+
"""Return whether *x* is a JAX Array."""
23472347
try:
2348-
# we're intentionally not attempting to import jax. If somebody
2349-
# has created a jax array, jax should already be in sys.modules
2350-
return isinstance(x, sys.modules['jax'].Array)
2351-
except Exception: # TypeError, KeyError, AttributeError, maybe others?
2352-
# we're attempting to access attributes on imported modules which
2353-
# may have arbitrary user code, so we deliberately catch all exceptions
2354-
return False
2348+
# We're intentionally not attempting to import jax. If somebody
2349+
# has created a jax array, jax should already be in sys.modules.
2350+
tp = sys.modules.get("jax").Array
2351+
except AttributeError:
2352+
return False # Module not imported or a nonstandard module with no Array attr.
2353+
return (isinstance(tp, type) # Just in case it's a very nonstandard module.
2354+
and isinstance(x, tp))
2355+
2356+
2357+
def _is_pandas_dataframe(x):
2358+
"""Check if *x* is a Pandas DataFrame."""
2359+
try:
2360+
# We're intentionally not attempting to import Pandas. If somebody
2361+
# has created a Pandas DataFrame, Pandas should already be in sys.modules.
2362+
tp = sys.modules.get("pandas").DataFrame
2363+
except AttributeError:
2364+
return False # Module not imported or a nonstandard module with no Array attr.
2365+
return (isinstance(tp, type) # Just in case it's a very nonstandard module.
2366+
and isinstance(x, tp))
23552367

23562368

23572369
def _is_tensorflow_array(x):
2358-
"""Check if 'x' is a TensorFlow Tensor or Variable."""
2370+
"""Return whether *x* is a TensorFlow Tensor or Variable."""
23592371
try:
2360-
# we're intentionally not attempting to import TensorFlow. If somebody
2361-
# has created a TensorFlow array, TensorFlow should already be in sys.modules
2362-
# we use `is_tensor` to not depend on the class structure of TensorFlow
2363-
# arrays, as `tf.Variables` are not instances of `tf.Tensor`
2364-
# (they both convert the same way)
2365-
return isinstance(x, sys.modules['tensorflow'].is_tensor(x))
2366-
except Exception: # TypeError, KeyError, AttributeError, maybe others?
2367-
# we're attempting to access attributes on imported modules which
2368-
# may have arbitrary user code, so we deliberately catch all exceptions
2372+
# We're intentionally not attempting to import TensorFlow. If somebody
2373+
# has created a TensorFlow array, TensorFlow should already be in
2374+
# sys.modules we use `is_tensor` to not depend on the class structure
2375+
# of TensorFlow arrays, as `tf.Variables` are not instances of
2376+
# `tf.Tensor` (they both convert the same way).
2377+
is_tensor = sys.modules.get("tensorflow").is_tensor
2378+
except AttributeError:
23692379
return False
2380+
try:
2381+
return is_tensor(x)
2382+
except Exception:
2383+
return False # Just in case it's a very nonstandard module.
23702384

23712385

23722386
def _unpack_to_numpy(x):
@@ -2421,15 +2435,3 @@ def _auto_format_str(fmt, value):
24212435
return fmt % (value,)
24222436
except (TypeError, ValueError):
24232437
return fmt.format(value)
2424-
2425-
2426-
def _is_pandas_dataframe(x):
2427-
"""Check if 'x' is a Pandas DataFrame."""
2428-
try:
2429-
# we're intentionally not attempting to import Pandas. If somebody
2430-
# has created a Pandas DataFrame, Pandas should already be in sys.modules
2431-
return isinstance(x, sys.modules['pandas'].DataFrame)
2432-
except Exception: # TypeError, KeyError, AttributeError, maybe others?
2433-
# we're attempting to access attributes on imported modules which
2434-
# may have arbitrary user code, so we deliberately catch all exceptions
2435-
return False

lib/matplotlib/tests/test_cbook.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,6 +1000,7 @@ def __array__(self):
10001000
torch_tensor = torch.Tensor(data)
10011001

10021002
result = cbook._unpack_to_numpy(torch_tensor)
1003+
assert isinstance(result, np.ndarray)
10031004
# compare results, do not check for identity: the latter would fail
10041005
# if not mocked, and the implementation does not guarantee it
10051006
# is the same Python object, just the same values.
@@ -1028,6 +1029,7 @@ def __array__(self):
10281029
jax_array = jax.Array(data)
10291030

10301031
result = cbook._unpack_to_numpy(jax_array)
1032+
assert isinstance(result, np.ndarray)
10311033
# compare results, do not check for identity: the latter would fail
10321034
# if not mocked, and the implementation does not guarantee it
10331035
# is the same Python object, just the same values.
@@ -1057,6 +1059,7 @@ def __array__(self):
10571059
tf_tensor = tensorflow.Tensor(data)
10581060

10591061
result = cbook._unpack_to_numpy(tf_tensor)
1062+
assert isinstance(result, np.ndarray)
10601063
# compare results, do not check for identity: the latter would fail
10611064
# if not mocked, and the implementation does not guarantee it
10621065
# is the same Python object, just the same values.

0 commit comments

Comments
 (0)