Vectorized Typing error from Numba #37

Closed
pymc-devs/pytensor
#218
@giiyms

Hello,

I am getting a vectorized typing error (a Numba TypingError) when compiling the PyMC model below with nutpie.compile_pymc_model.

Any ideas on how to fix this?

import pymc as pm
import nutpie
import numpy as np


test_data = np.array([642.29826899, 667.29826899, 692.29826899])
perturbations = np.array(
    [288, 200, 288, 200, 200, 200, 200, 200, 1, 1, 1, 1, 1, 2, 200, 2, 200]
)

partial_derivatives = np.array(
    [
        [-5.16130147e-02],
        [2.52964940e-01],
        [1.75868011e-01],
        [1.67508144e-01],
        [1.22967884e-01],
        [8.50826581e-02],
        [4.51845806e-02],
        [1.52296378e-02],
        [8.54567700e-03],
        [6.86547627e-03],
        [4.94782294e-03],
        [2.71454319e-03],
        [9.46222115e-04],
        [-1.27348193e00],
        [1.27300075e-01],
        [-4.70858227e00],
        [1.72071494e-01],
    ]
)

init_temp = np.array([717.29826899])

with pm.Model() as model:
    # Assuming uniform priors for BCs
    params = [
        pm.Uniform(f"params{idx}", -param, param, shape=1)
        for idx, param in enumerate(perturbations)
    ]
    params_arr = pm.math.concatenate(params, axis=0)
    simulated_model = pm.Deterministic(
        f"thermalmodel",
        pm.math.dot(params_arr, pm.math.constant(partial_derivatives))
        + pm.math.constant(init_temp),
    )
    obs_sigma = pm.HalfCauchy(f"obs_sigma", beta=2) + pm.math.constant(6.3)
    observed = pm.StudentT(
        f"observed",
        nu=len(test_data) - 1,
        mu=simulated_model,
        sigma=obs_sigma,
        observed=test_data,
    )


compiled_model = nutpie.compile_pymc_model(model)

trace = nutpie.sample(
    compiled_model, draws=3000, tune=1000, chains=10, save_warmup=False
)
----------------
TypingError                               Traceback (most recent call last)
HOMEDIR\Dev\python\trame-htbctool\nutpie_bug_example.py in line 53
     44     obs_sigma = pm.HalfCauchy(f"obs_sigma", beta=2) + pm.math.constant(6.3)
     45     observed = pm.StudentT(
     46         f"observed",
     47         nu=len(test_data) - 1,
   (...)
     50         observed=test_data,
     51     )
---> 53 compiled_model = nutpie.compile_pymc_model(model)
     54 trace = nutpie.sample(
     55     compiled_model, draws=3000, tune=1000, chains=10, save_warmup=False
     56 )

File HOMEDIR\envs\nutpie_debug\lib\site-packages\nutpie\compile_pymc.py:121, in compile_pymc_model(model, **kwargs)
    116 user_data = make_user_data(logp_fn_pt, shared_data)
    118 logp_numba_raw, c_sig = _make_c_logp_func(
    119     n_dim, logp_fn, user_data, shared_logp, shared_data
    120 )
--> 121 logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
    123 def expand_draw(x, seed, chain, draw, *, shared_data):
    124     return expand_fn(x, **{name: shared_data[name] for name in shared_expand})[0]

File HOMEDIR\envs\nutpie_debug\lib\site-packages\numba\core\decorators.py:282, in cfunc.<locals>.wrapper(func)
...
File "..\..\..\envs\nutpie_debug\lib\site-packages\nutpie\compile_pymc.py", line 265:
        def extract_shared(x, user_data_):
            return inner(x)
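
For reference, the deterministic term in the model above is a linear map from 17 concatenated parameters to a single value, which is then broadcast against the three observations. The following is a minimal NumPy-only sketch of those shapes, not a fix for the error. It assumes random placeholder sensitivities in place of the reported partial_derivatives values and takes one random draw from each Uniform prior; the seed is hypothetical.

```python
import numpy as np

# Hypothetical seed; any value works, only the shapes matter here.
rng = np.random.default_rng(0)

perturbations = np.array(
    [288, 200, 288, 200, 200, 200, 200, 200, 1, 1, 1, 1, 1, 2, 200, 2, 200]
)
n_params = len(perturbations)

# Placeholder sensitivities (assumption): same (17, 1) shape as in the report.
partial_derivatives = rng.normal(size=(n_params, 1))
init_temp = np.array([717.29826899])
test_data = np.array([642.29826899, 667.29826899, 692.29826899])

# One draw per Uniform(-p, p) prior, each of shape (1,), then concatenated.
params = [rng.uniform(-p, p, size=1) for p in perturbations]
params_arr = np.concatenate(params, axis=0)  # shape (17,)

# (17,) @ (17, 1) -> (1,), plus init_temp of shape (1,).
simulated = params_arr @ partial_derivatives + init_temp

# The (1,) mean broadcasts against the (3,) observations.
residuals = test_data - simulated

print(params_arr.shape, simulated.shape, residuals.shape)
# prints (17,) (1,) (3,)
```

So the likelihood sees one shared mean for all three data points, and the 17 shape-1 priors enter the graph only through the concatenation.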

Environment:

name: nutpie_debug
channels:
  - conda-forge
dependencies:
  - appdirs=1.4.4=pyh9f0ad1d_0
  - arviz=0.14.0=pyhd8ed1ab_0
  - asttokens=2.2.1=pyhd8ed1ab_0
  - backcall=0.2.0=pyh9f0ad1d_0
  - backports=1.0=pyhd8ed1ab_3
  - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
  - blas=2.0=netlib
  - brotli=1.0.9=hcfcfb64_8
  - brotli-bin=1.0.9=hcfcfb64_8
  - brotlipy=0.7.0=py310h8d17308_1005
  - bzip2=1.0.8=h8ffe710_4
  - ca-certificates=2022.12.7=h5b45459_0
  - cachetools=5.3.0=pyhd8ed1ab_0
  - certifi=2022.12.7=pyhd8ed1ab_0
  - cffi=1.15.1=py310h628cb3f_3
  - cftime=1.6.2=py310h9b08ddd_1
  - charset-normalizer=2.1.1=pyhd8ed1ab_0
  - cloudpickle=2.2.1=pyhd8ed1ab_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - comm=0.1.2=pyhd8ed1ab_0
  - cons=0.4.5=pyhd8ed1ab_0
  - contourpy=1.0.7=py310h232114e_0
  - cryptography=39.0.1=py310h6e82f81_0
  - curl=7.87.0=h68f0423_0
  - cycler=0.11.0=pyhd8ed1ab_0
  - debugpy=1.6.6=py310h00ffb61_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - etuples=0.3.8=pyhd8ed1ab_0
  - executing=1.2.0=pyhd8ed1ab_0
  - fastprogress=1.0.3=pyhd8ed1ab_0
  - filelock=3.9.0=pyhd8ed1ab_0
  - fonttools=4.38.0=py310h8d17308_1
  - freetype=2.12.1=h546665d_1
  - hdf4=4.2.15=h1b1b6ef_5
  - hdf5=1.12.2=nompi_h57737ce_101
  - idna=3.4=pyhd8ed1ab_0
  - importlib-metadata=6.0.0=pyha770c72_0
  - importlib_metadata=6.0.0=hd8ed1ab_0
  - intel-openmp=2023.0.0=h57928b3_25922
  - ipykernel=6.21.1=pyh025b116_0
  - ipython=8.9.0=pyh08f2357_0
  - jedi=0.18.2=pyhd8ed1ab_0
  - jpeg=9e=h8ffe710_2
  - jupyter_client=8.0.2=pyhd8ed1ab_0
  - jupyter_core=5.2.0=py310h5588dad_0
  - kiwisolver=1.4.4=py310h232114e_1
  - krb5=1.20.1=heb0366b_0
  - lcms2=2.14=ha5c8aab_1
  - lerc=4.0.0=h63175ca_0
  - libaec=1.0.6=h63175ca_1
  - libblas=3.9.0=0_h8933c1f_netlib
  - libbrotlicommon=1.0.9=hcfcfb64_8
  - libbrotlidec=1.0.9=hcfcfb64_8
  - libbrotlienc=1.0.9=hcfcfb64_8
  - libcblas=3.9.0=0_h8933c1f_netlib
  - libcurl=7.87.0=h68f0423_0
  - libdeflate=1.17=hcfcfb64_0
  - libffi=3.4.2=h8ffe710_5
  - libhwloc=2.8.0=h039e092_1
  - libiconv=1.17=h8ffe710_0
  - liblapack=3.9.0=0_h8933c1f_netlib
  - liblapacke=3.9.0=0_h8933c1f_netlib
  - libnetcdf=4.8.1=nompi_h8c042bf_106
  - libpng=1.6.39=h19919ed_0
  - libpython=2.2=py310h5588dad_2
  - libsodium=1.0.18=h8d14728_1
  - libsqlite=3.40.0=hcfcfb64_0
  - libssh2=1.10.0=h9a1e1f7_3
  - libtiff=4.5.0=hf8721a0_2
  - libwebp-base=1.2.4=h8ffe710_0
  - libxcb=1.13=hcd874cb_1004
  - libxml2=2.10.3=hc3477c8_0
  - libzip=1.9.2=h519de47_1
  - libzlib=1.2.13=hcfcfb64_4
  - llvmlite=0.39.1=py310hb84602e_1
  - logical-unification=0.4.5=pyhd8ed1ab_0
  - m2w64-binutils=2.25.1=5
  - m2w64-bzip2=1.0.6=6
  - m2w64-crt-git=5.0.0.4636.2595836=2
  - m2w64-gcc=5.3.0=6
  - m2w64-gcc-ada=5.3.0=6
  - m2w64-gcc-fortran=5.3.0=6
  - m2w64-gcc-libgfortran=5.3.0=6
  - m2w64-gcc-libs=5.3.0=7
  - m2w64-gcc-libs-core=5.3.0=7
  - m2w64-gcc-objc=5.3.0=6
  - m2w64-gmp=6.1.0=2
  - m2w64-headers-git=5.0.0.4636.c0ad18a=2
  - m2w64-isl=0.16.1=2
  - m2w64-libiconv=1.14=6
  - m2w64-libmangle-git=5.0.0.4509.2e5a9a2=2
  - m2w64-libwinpthread-git=5.0.0.4634.697f757=2
  - m2w64-make=4.1.2351.a80a8b8=2
  - m2w64-mpc=1.0.3=3
  - m2w64-mpfr=3.1.4=4
  - m2w64-pkg-config=0.29.1=2
  - m2w64-toolchain=5.3.0=7
  - m2w64-toolchain_win-64=2.4.0=0
  - m2w64-tools-git=5.0.0.4592.90b8472=2
  - m2w64-windows-default-manifest=6.4=3
  - m2w64-winpthreads-git=5.0.0.4634.697f757=2
  - m2w64-zlib=1.2.8=10
  - matplotlib-base=3.6.3=py310h51140c5_0
  - matplotlib-inline=0.1.6=pyhd8ed1ab_0
  - minikanren=1.0.3=pyhd8ed1ab_0
  - mkl=2022.2.1=h6a75c08_19751
  - mkl-service=2.4.0=py310h84a9c25_0
  - msys2-conda-epoch=20160418=1
  - multipledispatch=0.6.0=py_0
  - munkres=1.1.4=pyh9f0ad1d_0
  - nest-asyncio=1.5.6=pyhd8ed1ab_0
  - netcdf4=1.6.2=nompi_py310h459bb5f_100
  - numba=0.56.4=py310h19bcfe9_0
  - numpy=1.23.5=py310h4a8f9c9_0
  - nutpie=0.5.1=py310h96eb580_0
  - openjpeg=2.5.0=ha2aaf27_2
  - openssl=3.0.8=hcfcfb64_0
  - packaging=23.0=pyhd8ed1ab_0
  - pandas=1.5.3=py310h1c4a608_0
  - parso=0.8.3=pyhd8ed1ab_0
  - pickleshare=0.7.5=py_1003
  - pillow=9.4.0=py310hdbb7713_1
  - pip=23.0=pyhd8ed1ab_0
  - platformdirs=3.0.0=pyhd8ed1ab_0
  - pooch=1.6.0=pyhd8ed1ab_0
  - prompt-toolkit=3.0.36=pyha770c72_0
  - psutil=5.9.4=py310h8d17308_0
  - pthread-stubs=0.4=hcd874cb_1001
  - pthreads-win32=2.9.1=hfa6e2cd_3
  - pure_eval=0.2.2=pyhd8ed1ab_0
  - pycparser=2.21=pyhd8ed1ab_0
  - pygments=2.14.0=pyhd8ed1ab_0
  - pymc=5.0.2=hd8ed1ab_0
  - pymc-base=5.0.2=pyhd8ed1ab_0
  - pyopenssl=23.0.0=pyhd8ed1ab_0
  - pyparsing=3.0.9=pyhd8ed1ab_0
  - pysocks=1.7.1=pyh0701188_6
  - pytensor=2.9.1=py310h53af72e_0
  - pytensor-base=2.9.1=py310h00ffb61_0
  - python=3.10.9=h4de0772_0_cpython
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python_abi=3.10=3_cp310
  - pytz=2022.7.1=pyhd8ed1ab_0
  - pywin32=304=py310h00ffb61_2
  - pyzmq=25.0.0=py310hcd737a0_0
  - requests=2.28.2=pyhd8ed1ab_0
  - scipy=1.10.0=py310h578b7cb_2
  - setuptools=67.1.0=pyhd8ed1ab_0
  - six=1.16.0=pyh6c4a22f_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tbb=2021.7.0=h91493d7_1
  - tk=8.6.12=h8ffe710_0
  - toolz=0.12.0=pyhd8ed1ab_0
  - tornado=6.2=py310h8d17308_1
  - traitlets=5.9.0=pyhd8ed1ab_0
  - typing-extensions=4.4.0=hd8ed1ab_0
  - typing_extensions=4.4.0=pyha770c72_0
  - tzdata=2022g=h191b570_0
  - ucrt=10.0.22621.0=h57928b3_0
  - unicodedata2=15.0.0=py310h8d17308_0
  - urllib3=1.26.14=pyhd8ed1ab_0
  - vc=14.3=hb6edc58_10
  - vs2015_runtime=14.34.31931=h4c5c07a_10
  - wcwidth=0.2.6=pyhd8ed1ab_0
  - wheel=0.38.4=pyhd8ed1ab_0
  - win_inet_pton=1.1.0=pyhd8ed1ab_6
  - xarray=2023.2.0=pyhd8ed1ab_0
  - xarray-einstats=0.5.1=pyhd8ed1ab_0
  - xorg-libxau=1.0.9=hcd874cb_0
  - xorg-libxdmcp=1.1.3=hcd874cb_0
  - xz=5.2.6=h8d14728_0
  - zeromq=4.3.4=h0e60522_1
  - zipp=3.13.0=pyhd8ed1ab_0
  - zlib=1.2.13=hcfcfb64_4
  - zstd=1.5.2=h12be248_6
prefix: C:\envs\nutpie_debug
