Closed
Description
Hello,
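I am getting a vectorized typing error (`TypingError`) from numba when compiling the model below with nutpie; the traceback and my environment are included further down.
Any ideas on how to fix this?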
import pymc as pm
import nutpie
import numpy as np

# Observed temperatures the model is fitted to (3 observations).
test_data = np.array([642.29826899, 667.29826899, 692.29826899])

# Half-width of the symmetric uniform prior for each of the 17 parameters.
# Stored as float64 so that the prior bounds enter the compiled graph as
# floating-point constants rather than int64 constants.
# NOTE(review): integer-typed constants reaching numba-compiled elementwise
# code are a suspected trigger of the numba TypingError -- confirm.
perturbations = np.array(
    [288, 200, 288, 200, 200, 200, 200, 200, 1, 1, 1, 1, 1, 2, 200, 2, 200],
    dtype=np.float64,
)

# One row per parameter, a single output column: shape (17, 1).
partial_derivatives = np.array(
    [
        [-5.16130147e-02],
        [2.52964940e-01],
        [1.75868011e-01],
        [1.67508144e-01],
        [1.22967884e-01],
        [8.50826581e-02],
        [4.51845806e-02],
        [1.52296378e-02],
        [8.54567700e-03],
        [6.86547627e-03],
        [4.94782294e-03],
        [2.71454319e-03],
        [9.46222115e-04],
        [-1.27348193e00],
        [1.27300075e-01],
        [-4.70858227e00],
        [1.72071494e-01],
    ]
)

# Baseline temperature the linear perturbation model is added to: shape (1,).
init_temp = np.array([717.29826899])

with pm.Model() as model:
    # Uniform prior on each parameter (BCs), symmetric about zero with the
    # half-width taken from `perturbations`. Variable names params0..params16
    # are kept so the trace layout matches the original script.
    params = [
        pm.Uniform(f"params{idx}", -param, param, shape=1)
        for idx, param in enumerate(perturbations)
    ]
    params_arr = pm.math.concatenate(params, axis=0)

    # (17,) @ (17, 1) -> (1,), plus init_temp (1,) -> (1,); this broadcasts
    # against the 3 observations in the likelihood below.
    simulated_model = pm.Deterministic(
        "thermalmodel",
        pm.math.dot(params_arr, pm.math.constant(partial_derivatives))
        + pm.math.constant(init_temp),
    )

    # Float literals (2.0, float(nu)) for the same reason as `perturbations`.
    obs_sigma = pm.HalfCauchy("obs_sigma", beta=2.0) + pm.math.constant(6.3)
    observed = pm.StudentT(
        "observed",
        nu=float(len(test_data) - 1),
        mu=simulated_model,
        sigma=obs_sigma,
        observed=test_data,
    )

# The model is passed explicitly, so compilation and sampling do not need to
# run inside the model context.
compiled_model = nutpie.compile_pymc_model(model)
trace = nutpie.sample(
    compiled_model, draws=3000, tune=1000, chains=10, save_warmup=False
)
----------------
TypingError Traceback (most recent call last)
HOMEDIR\Dev\python\trame-htbctool\nutpie_bug_example.py in line 53
44 obs_sigma = pm.HalfCauchy(f"obs_sigma", beta=2) + pm.math.constant(6.3)
45 observed = pm.StudentT(
46 f"observed",
47 nu=len(test_data) - 1,
(...)
50 observed=test_data,
51 )
---> 53 compiled_model = nutpie.compile_pymc_model(model)
54 trace = nutpie.sample(
55 compiled_model, draws=3000, tune=1000, chains=10, save_warmup=False
56 )
File HOMEDIR\envs\nutpie_debug\lib\site-packages\nutpie\compile_pymc.py:121, in compile_pymc_model(model, **kwargs)
116 user_data = make_user_data(logp_fn_pt, shared_data)
118 logp_numba_raw, c_sig = _make_c_logp_func(
119 n_dim, logp_fn, user_data, shared_logp, shared_data
120 )
--> 121 logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
123 def expand_draw(x, seed, chain, draw, *, shared_data):
124 return expand_fn(x, **{name: shared_data[name] for name in shared_expand})[0]
File HOMEDIR\envs\nutpie_debug\lib\site-packages\numba\core\decorators.py:282, in cfunc.<locals>.wrapper(func)
...
File "..\..\..\envs\nutpie_debug\lib\site-packages\nutpie\compile_pymc.py", line 265:
def extract_shared(x, user_data_):
return inner(x)
Environment:
name: nutpie_debug
channels:
- conda-forge
dependencies:
- appdirs=1.4.4=pyh9f0ad1d_0
- arviz=0.14.0=pyhd8ed1ab_0
- asttokens=2.2.1=pyhd8ed1ab_0
- backcall=0.2.0=pyh9f0ad1d_0
- backports=1.0=pyhd8ed1ab_3
- backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
- blas=2.0=netlib
- brotli=1.0.9=hcfcfb64_8
- brotli-bin=1.0.9=hcfcfb64_8
- brotlipy=0.7.0=py310h8d17308_1005
- bzip2=1.0.8=h8ffe710_4
- ca-certificates=2022.12.7=h5b45459_0
- cachetools=5.3.0=pyhd8ed1ab_0
- certifi=2022.12.7=pyhd8ed1ab_0
- cffi=1.15.1=py310h628cb3f_3
- cftime=1.6.2=py310h9b08ddd_1
- charset-normalizer=2.1.1=pyhd8ed1ab_0
- cloudpickle=2.2.1=pyhd8ed1ab_0
- colorama=0.4.6=pyhd8ed1ab_0
- comm=0.1.2=pyhd8ed1ab_0
- cons=0.4.5=pyhd8ed1ab_0
- contourpy=1.0.7=py310h232114e_0
- cryptography=39.0.1=py310h6e82f81_0
- curl=7.87.0=h68f0423_0
- cycler=0.11.0=pyhd8ed1ab_0
- debugpy=1.6.6=py310h00ffb61_0
- decorator=5.1.1=pyhd8ed1ab_0
- etuples=0.3.8=pyhd8ed1ab_0
- executing=1.2.0=pyhd8ed1ab_0
- fastprogress=1.0.3=pyhd8ed1ab_0
- filelock=3.9.0=pyhd8ed1ab_0
- fonttools=4.38.0=py310h8d17308_1
- freetype=2.12.1=h546665d_1
- hdf4=4.2.15=h1b1b6ef_5
- hdf5=1.12.2=nompi_h57737ce_101
- idna=3.4=pyhd8ed1ab_0
- importlib-metadata=6.0.0=pyha770c72_0
- importlib_metadata=6.0.0=hd8ed1ab_0
- intel-openmp=2023.0.0=h57928b3_25922
- ipykernel=6.21.1=pyh025b116_0
- ipython=8.9.0=pyh08f2357_0
- jedi=0.18.2=pyhd8ed1ab_0
- jpeg=9e=h8ffe710_2
- jupyter_client=8.0.2=pyhd8ed1ab_0
- jupyter_core=5.2.0=py310h5588dad_0
- kiwisolver=1.4.4=py310h232114e_1
- krb5=1.20.1=heb0366b_0
- lcms2=2.14=ha5c8aab_1
- lerc=4.0.0=h63175ca_0
- libaec=1.0.6=h63175ca_1
- libblas=3.9.0=0_h8933c1f_netlib
- libbrotlicommon=1.0.9=hcfcfb64_8
- libbrotlidec=1.0.9=hcfcfb64_8
- libbrotlienc=1.0.9=hcfcfb64_8
- libcblas=3.9.0=0_h8933c1f_netlib
- libcurl=7.87.0=h68f0423_0
- libdeflate=1.17=hcfcfb64_0
- libffi=3.4.2=h8ffe710_5
- libhwloc=2.8.0=h039e092_1
- libiconv=1.17=h8ffe710_0
- liblapack=3.9.0=0_h8933c1f_netlib
- liblapacke=3.9.0=0_h8933c1f_netlib
- libnetcdf=4.8.1=nompi_h8c042bf_106
- libpng=1.6.39=h19919ed_0
- libpython=2.2=py310h5588dad_2
- libsodium=1.0.18=h8d14728_1
- libsqlite=3.40.0=hcfcfb64_0
- libssh2=1.10.0=h9a1e1f7_3
- libtiff=4.5.0=hf8721a0_2
- libwebp-base=1.2.4=h8ffe710_0
- libxcb=1.13=hcd874cb_1004
- libxml2=2.10.3=hc3477c8_0
- libzip=1.9.2=h519de47_1
- libzlib=1.2.13=hcfcfb64_4
- llvmlite=0.39.1=py310hb84602e_1
- logical-unification=0.4.5=pyhd8ed1ab_0
- m2w64-binutils=2.25.1=5
- m2w64-bzip2=1.0.6=6
- m2w64-crt-git=5.0.0.4636.2595836=2
- m2w64-gcc=5.3.0=6
- m2w64-gcc-ada=5.3.0=6
- m2w64-gcc-fortran=5.3.0=6
- m2w64-gcc-libgfortran=5.3.0=6
- m2w64-gcc-libs=5.3.0=7
- m2w64-gcc-libs-core=5.3.0=7
- m2w64-gcc-objc=5.3.0=6
- m2w64-gmp=6.1.0=2
- m2w64-headers-git=5.0.0.4636.c0ad18a=2
- m2w64-isl=0.16.1=2
- m2w64-libiconv=1.14=6
- m2w64-libmangle-git=5.0.0.4509.2e5a9a2=2
- m2w64-libwinpthread-git=5.0.0.4634.697f757=2
- m2w64-make=4.1.2351.a80a8b8=2
- m2w64-mpc=1.0.3=3
- m2w64-mpfr=3.1.4=4
- m2w64-pkg-config=0.29.1=2
- m2w64-toolchain=5.3.0=7
- m2w64-toolchain_win-64=2.4.0=0
- m2w64-tools-git=5.0.0.4592.90b8472=2
- m2w64-windows-default-manifest=6.4=3
- m2w64-winpthreads-git=5.0.0.4634.697f757=2
- m2w64-zlib=1.2.8=10
- matplotlib-base=3.6.3=py310h51140c5_0
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
- minikanren=1.0.3=pyhd8ed1ab_0
- mkl=2022.2.1=h6a75c08_19751
- mkl-service=2.4.0=py310h84a9c25_0
- msys2-conda-epoch=20160418=1
- multipledispatch=0.6.0=py_0
- munkres=1.1.4=pyh9f0ad1d_0
- nest-asyncio=1.5.6=pyhd8ed1ab_0
- netcdf4=1.6.2=nompi_py310h459bb5f_100
- numba=0.56.4=py310h19bcfe9_0
- numpy=1.23.5=py310h4a8f9c9_0
- nutpie=0.5.1=py310h96eb580_0
- openjpeg=2.5.0=ha2aaf27_2
- openssl=3.0.8=hcfcfb64_0
- packaging=23.0=pyhd8ed1ab_0
- pandas=1.5.3=py310h1c4a608_0
- parso=0.8.3=pyhd8ed1ab_0
- pickleshare=0.7.5=py_1003
- pillow=9.4.0=py310hdbb7713_1
- pip=23.0=pyhd8ed1ab_0
- platformdirs=3.0.0=pyhd8ed1ab_0
- pooch=1.6.0=pyhd8ed1ab_0
- prompt-toolkit=3.0.36=pyha770c72_0
- psutil=5.9.4=py310h8d17308_0
- pthread-stubs=0.4=hcd874cb_1001
- pthreads-win32=2.9.1=hfa6e2cd_3
- pure_eval=0.2.2=pyhd8ed1ab_0
- pycparser=2.21=pyhd8ed1ab_0
- pygments=2.14.0=pyhd8ed1ab_0
- pymc=5.0.2=hd8ed1ab_0
- pymc-base=5.0.2=pyhd8ed1ab_0
- pyopenssl=23.0.0=pyhd8ed1ab_0
- pyparsing=3.0.9=pyhd8ed1ab_0
- pysocks=1.7.1=pyh0701188_6
- pytensor=2.9.1=py310h53af72e_0
- pytensor-base=2.9.1=py310h00ffb61_0
- python=3.10.9=h4de0772_0_cpython
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python_abi=3.10=3_cp310
- pytz=2022.7.1=pyhd8ed1ab_0
- pywin32=304=py310h00ffb61_2
- pyzmq=25.0.0=py310hcd737a0_0
- requests=2.28.2=pyhd8ed1ab_0
- scipy=1.10.0=py310h578b7cb_2
- setuptools=67.1.0=pyhd8ed1ab_0
- six=1.16.0=pyh6c4a22f_0
- stack_data=0.6.2=pyhd8ed1ab_0
- tbb=2021.7.0=h91493d7_1
- tk=8.6.12=h8ffe710_0
- toolz=0.12.0=pyhd8ed1ab_0
- tornado=6.2=py310h8d17308_1
- traitlets=5.9.0=pyhd8ed1ab_0
- typing-extensions=4.4.0=hd8ed1ab_0
- typing_extensions=4.4.0=pyha770c72_0
- tzdata=2022g=h191b570_0
- ucrt=10.0.22621.0=h57928b3_0
- unicodedata2=15.0.0=py310h8d17308_0
- urllib3=1.26.14=pyhd8ed1ab_0
- vc=14.3=hb6edc58_10
- vs2015_runtime=14.34.31931=h4c5c07a_10
- wcwidth=0.2.6=pyhd8ed1ab_0
- wheel=0.38.4=pyhd8ed1ab_0
- win_inet_pton=1.1.0=pyhd8ed1ab_6
- xarray=2023.2.0=pyhd8ed1ab_0
- xarray-einstats=0.5.1=pyhd8ed1ab_0
- xorg-libxau=1.0.9=hcd874cb_0
- xorg-libxdmcp=1.1.3=hcd874cb_0
- xz=5.2.6=h8d14728_0
- zeromq=4.3.4=h0e60522_1
- zipp=3.13.0=pyhd8ed1ab_0
- zlib=1.2.13=hcfcfb64_4
- zstd=1.5.2=h12be248_6
prefix: C:\envs\nutpie_debug
Metadata
Metadata
Assignees
Labels
No labels