{ "info": { "author": "Uber AI Labs", "author_email": "npradhan@uber.com", "bugtrack_url": null, "classifiers": [ "Intended Audience :: Developers", "Intended Audience :: Education", "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Operating System :: MacOS :: MacOS X", "Operating System :: POSIX :: Linux", "Programming Language :: Python :: 3.6" ], "description": "[![Build Status](https://travis-ci.com/pyro-ppl/numpyro.svg?branch=master)](https://travis-ci.com/pyro-ppl/numpyro)\n[![Documentation Status](https://readthedocs.org/projects/numpyro/badge/?version=latest)](https://numpyro.readthedocs.io/en/latest/?badge=latest)\n[![Latest Version](https://badge.fury.io/py/numpyro.svg)](https://pypi.python.org/pypi/numpyro)\n# NumPyro\n\nProbabilistic programming with NumPy powered by [JAX](https://github.com/google/jax) for autograd and JIT compilation to GPU/CPU.\n\n[Docs](https://numpyro.readthedocs.io/en/v0.1.0/) | [Examples](https://pyro.ai/numpyro/) | [Forum](https://forum.pyro.ai/)\n\n----------------------------------------------------------------------------------------------------\n\n## What is NumPyro?\n\nNumPyro is a small probabilistic programming library built on [JAX](https://github.com/google/jax). It essentially provides a NumPy backend for [Pyro](https://github.com/pyro-ppl/pyro), with some minor changes to the inference API and syntax. Since we use JAX, we get autograd and JIT compilation to GPU / CPU for free. This is an alpha release under active development, so beware of brittleness, bugs, and changes to the API as the design evolves.\n\nNumPyro is designed to be *lightweight* and focuses on providing a flexible substrate that users can build on:\n\n - **Pyro Primitives:** NumPyro programs can contain regular Python and NumPy code, in addition to [Pyro primitives](http://pyro.ai/examples/intro_part_i.html) like `sample` and `param`. 
The model code should look very similar to Pyro, except for some minor differences between the PyTorch and NumPy APIs. See [Examples](https://github.com/pyro-ppl/numpyro/#Examples).\n - **Inference algorithms:** NumPyro currently supports Hamiltonian Monte Carlo, including an implementation of the No U-Turn Sampler. One of the motivations for NumPyro was to speed up Hamiltonian Monte Carlo by JIT compiling the Verlet integration step, which includes multiple gradient computations. With JAX, we can compose `jit` and `grad` to compile the entire integration step into an XLA-optimized kernel. We also eliminate Python overhead by JIT compiling the entire tree-building stage in NUTS (this is possible using [Iterative NUTS](https://github.com/pyro-ppl/numpyro/wiki/Iterative-NUTS)). There is also a basic Variational Inference implementation for reparameterized distributions.\n - **Distributions:** The [numpyro.distributions](https://numpyro.readthedocs.io/en/latest/distributions.html) module provides distribution classes, constraints, and bijective transforms. The distribution classes wrap samplers implemented to work with JAX's [functional pseudo-random number generator](https://github.com/google/jax#random-numbers-are-different). The design of the distributions module largely follows that of [PyTorch](https://pytorch.org/docs/stable/distributions.html). A major subset of the API is implemented, including most of the common distributions available in PyTorch. As a result, Pyro and PyTorch users can rely on the same API and batching semantics as in `torch.distributions`. 
In addition to distributions, `constraints` and `transforms` are very useful when operating on distribution classes with bounded support.\n - **Effect handlers:** As in Pyro, primitives such as `sample` and `param` can be interpreted with side effects using effect handlers from the [numpyro.handlers](https://numpyro.readthedocs.io/en/latest/handlers.html) module, and these can be easily extended to implement custom inference algorithms and inference utilities.\n\n\n## Installation\n\n> **Limited Windows Support:** NumPyro is untested on Windows and requires building jaxlib from source. See this [JAX issue](https://github.com/google/jax/issues/438) for more details.\n\nTo install NumPyro with a CPU version of JAX, you can use pip:\n\n```\npip install numpyro\n```\n\nTo use NumPyro on the GPU, you will first need to [install](https://github.com/google/jax#installation) `jax` and `jaxlib` with CUDA support.\n\nYou can also install NumPyro from source:\n\n```\ngit clone https://github.com/pyro-ppl/numpyro.git\ncd numpyro\n# install jax/jaxlib first for CUDA support\npip install -e .[dev]\n```\n\n## Examples\n\nFor examples of specifying models and running inference in NumPyro, see:\n\n - [Bayesian Regression in NumPyro](https://nbviewer.jupyter.org/github/pyro-ppl/numpyro/blob/master/notebooks/source/bayesian_regression.ipynb) - Start here to get acquainted with writing a simple model in NumPyro, the MCMC inference API, effect handlers, and writing custom inference utilities.\n - [Time Series Forecasting](https://nbviewer.jupyter.org/github/pyro-ppl/numpyro/blob/master/notebooks/source/time_series_forecasting.ipynb) - Illustrates how to convert `for` loops in the model to JAX's `lax.scan` primitive for fast inference.\n - [Baseball example](https://github.com/pyro-ppl/numpyro/blob/master/examples/baseball.py) - Using NUTS for a simple hierarchical model. 
Compare this with the baseball example in [Pyro](https://github.com/pyro-ppl/pyro/blob/dev/examples/baseball.py).\n - [Hidden Markov Model](https://github.com/pyro-ppl/numpyro/blob/master/examples/hmm.py) in NumPyro, compared with [Stan](https://mc-stan.org/docs/2_19/stan-users-guide/hmms-section.html).\n - [Variational Autoencoder](https://github.com/pyro-ppl/numpyro/blob/master/examples/vae.py) - A simple example that uses Variational Inference. See the [Pyro implementation](https://github.com/pyro-ppl/pyro/blob/dev/examples/vae/vae.py) for comparison.\n - Other model examples can be found in the [examples](https://github.com/pyro-ppl/numpyro/tree/master/examples) folder.\n\nBy design, the API for model specification, including the distributions API, is largely the same as Pyro's. The interface for inference algorithms and other utility functions may deviate from Pyro in favor of a more *functional* style that works better with JAX; for example, there is no global parameter store or random state.\n\n## Future Work\n\nIn the near term, we plan to work on the following. 
Please open new issues for feature requests and enhancements:\n\n - Improving the robustness of inference on different models, profiling, and performance tuning.\n - More inference algorithms, particularly those that require second-order derivatives or use HMC.\n - Integration with [Funsor](https://github.com/pyro-ppl/funsor) to support inference algorithms with delayed sampling.\n - Supporting more distributions, extending the distributions API, and adding more samplers to JAX.\n - Other areas motivated by Pyro's research goals and application focus, and interest from the community.\n\n\n", "description_content_type": "text/markdown", "docs_url": null, "download_url": "", "downloads": { "last_day": -1, "last_month": -1, "last_week": -1 }, "home_page": "https://github.com/pyro-ppl/numpyro", "keywords": "probabilistic machine learning bayesian statistics", "license": "", "maintainer": "", "maintainer_email": "", "name": "numpyro", "package_url": "https://pypi.org/project/numpyro/", "platform": "", "project_url": "https://pypi.org/project/numpyro/", "project_urls": { "Homepage": "https://github.com/pyro-ppl/numpyro" }, "release_url": "https://pypi.org/project/numpyro/0.2.0/", "requires_dist": [ "jax (==0.1.44)", "jaxlib (==0.1.27)", "tqdm", "ipython ; extra == 'dev'", "sphinx ; extra == 'doc'", "sphinx-rtd-theme ; extra == 'doc'", "matplotlib ; extra == 'examples'", "flake8 ; extra == 'test'", "pytest (>=4.1) ; extra == 'test'" ], "requires_python": "", "summary": "Pyro PPL on Numpy", "version": "0.2.0" }, "last_serial": 5798605, "releases": { "0.1.0": [ { "comment_text": "", "digests": { "md5": "a6cb308c7ccee6f6f3ef329a25888d7a", "sha256": "3df3e54abfa195a51ceadc6d177ff629f79c5fe9c9a8f89547b3b1a73bd49840" }, "downloads": -1, "filename": "numpyro-0.1.0.tar.gz", "has_sig": false, "md5_digest": "a6cb308c7ccee6f6f3ef329a25888d7a", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 73756, "upload_time": "2019-06-01T04:15:31", "url": 
"https://files.pythonhosted.org/packages/46/b7/e4e0624e9e948b1f1c4176ef01ac8dc3487fdc78f897494bdfe15932e51b/numpyro-0.1.0.tar.gz" } ], "0.2.0": [ { "comment_text": "", "digests": { "md5": "93d947398501804cd7e21c5dcccd92fa", "sha256": "4f7ed0d83cca73d67e2368132a5dbbb0f84094942f2468173dc1b6c3770a5769" }, "downloads": -1, "filename": "numpyro-0.2.0-py3-none-any.whl", "has_sig": false, "md5_digest": "93d947398501804cd7e21c5dcccd92fa", "packagetype": "bdist_wheel", "python_version": "py3", "requires_python": null, "size": 95162, "upload_time": "2019-09-08T07:22:40", "url": "https://files.pythonhosted.org/packages/17/7c/323ec59c52d6c49defcba944167843437976cbc7261cff058721e74fc77f/numpyro-0.2.0-py3-none-any.whl" }, { "comment_text": "", "digests": { "md5": "26e76fef203630d51dfb292ca94b3e29", "sha256": "32b2dc6e0dc1c94a0b6590bfba51f605446162072ab73ca69e272dad1007aaf8" }, "downloads": -1, "filename": "numpyro-0.2.0.tar.gz", "has_sig": false, "md5_digest": "26e76fef203630d51dfb292ca94b3e29", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 101660, "upload_time": "2019-09-08T07:22:41", "url": "https://files.pythonhosted.org/packages/5a/fa/574b880cb719cd2b2310a6851ff334a2e0cbec1929b3871c16a37988e30f/numpyro-0.2.0.tar.gz" } ] }, "urls": [ { "comment_text": "", "digests": { "md5": "93d947398501804cd7e21c5dcccd92fa", "sha256": "4f7ed0d83cca73d67e2368132a5dbbb0f84094942f2468173dc1b6c3770a5769" }, "downloads": -1, "filename": "numpyro-0.2.0-py3-none-any.whl", "has_sig": false, "md5_digest": "93d947398501804cd7e21c5dcccd92fa", "packagetype": "bdist_wheel", "python_version": "py3", "requires_python": null, "size": 95162, "upload_time": "2019-09-08T07:22:40", "url": "https://files.pythonhosted.org/packages/17/7c/323ec59c52d6c49defcba944167843437976cbc7261cff058721e74fc77f/numpyro-0.2.0-py3-none-any.whl" }, { "comment_text": "", "digests": { "md5": "26e76fef203630d51dfb292ca94b3e29", "sha256": 
"32b2dc6e0dc1c94a0b6590bfba51f605446162072ab73ca69e272dad1007aaf8" }, "downloads": -1, "filename": "numpyro-0.2.0.tar.gz", "has_sig": false, "md5_digest": "26e76fef203630d51dfb292ca94b3e29", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 101660, "upload_time": "2019-09-08T07:22:41", "url": "https://files.pythonhosted.org/packages/5a/fa/574b880cb719cd2b2310a6851ff334a2e0cbec1929b3871c16a37988e30f/numpyro-0.2.0.tar.gz" } ] }