{ "info": { "author": "", "author_email": "", "bugtrack_url": null, "classifiers": [ "Intended Audience :: Developers", "Intended Audience :: Education", "Intended Audience :: Science/Research", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", "Topic :: Scientific/Engineering :: Artificial Intelligence" ], "description": "# JAXnet [![Build Status](https://travis-ci.org/JuliusKunze/jaxnet.svg?branch=master)](https://travis-ci.org/JuliusKunze/jaxnet) [![PyPI](https://img.shields.io/pypi/v/jaxnet.svg)](https://pypi.python.org/pypi/jaxnet/#history)\n\nJAXnet is a deep learning library based on [JAX](https://github.com/google/jax).\nJAXnet's functional API provides unique benefits over TensorFlow2, Keras and PyTorch,\nwhile maintaining user-friendliness, modularity and scalability:\n- More robustness through immutable weights, no global compute graph.\n- GPU-compiled `numpy` code for networks, training loops, pre- and postprocessing.\n- Regularization and reparametrization of any module or whole networks in one line.\n- No global random state, flexible random key control.\n\nIf you already know stax, read [this](STAX.md).\n\n### Modularity\n\n```python\nfrom jaxnet import *\n\nnet = Sequential(Dense(1024), relu, Dense(1024), relu, Dense(4), logsoftmax)\n```\ncreates a neural net model from predefined modules.\n\n### Extensibility\n\nDefine your own modules using `@parametrized` functions. 
You can reuse other modules:\n\n```python\nfrom jax import numpy as np\n\n@parametrized\ndef loss(inputs, targets):\n    return -np.mean(net(inputs) * targets)\n```\n\nAll modules are composed in this way.\n[`jax.numpy`](https://github.com/google/jax#whats-supported) mirrors `numpy`,\nmeaning that if you know how to use `numpy`, you know most of JAXnet.\nCompare this to TensorFlow2/Keras:\n\n```python\nimport tensorflow as tf\nfrom tensorflow.keras import Sequential\nfrom tensorflow.keras.layers import Dense, Lambda\n\nnet = Sequential([Dense(1024, 'relu'), Dense(1024, 'relu'), Dense(4), Lambda(tf.nn.log_softmax)])\n\ndef loss(inputs, targets):\n    return -tf.reduce_mean(net(inputs) * targets)\n```\n\nNotice how `Lambda` layers are not needed in JAXnet.\n`relu` and `logsoftmax` are plain Python functions.\n\n### Immutable weights\n\nUnlike TensorFlow2/Keras, JAXnet has no global compute graph.\nModules like `net` and `loss` do not contain mutable weights.\nInstead, weights are contained in separate, immutable objects.\nThey are initialized with `init_parameters`, given a random key and example inputs:\n\n```python\nfrom jax.random import PRNGKey\n\ndef next_batch(): return np.zeros((3, 784)), np.zeros((3, 4))\n\nparams = loss.init_parameters(PRNGKey(0), *next_batch())\n\nprint(params.sequential.dense2.bias) # [0.00376661 0.01038619 0.00920947 0.00792002]\n```\n\nInstead of mutating weights inline, optimizers return updated versions of the weights.\nThese are returned as part of a new optimizer state and can be retrieved via `get_parameters`:\n\n```python\nopt = optimizers.Adam()\nstate = opt.init(params)\nfor _ in range(10):\n    state = opt.update(loss.apply, state, *next_batch()) # accelerate with jit=True\n\ntrained_params = opt.get_parameters(state)\n```\n\n`apply` evaluates a network:\n\n```python\ntest_loss = loss.apply(trained_params, *test_batch) # accelerate with jit=True\n```\n\n### GPU support and compilation\n\nJAX allows functional `numpy`/`scipy` 
code to be accelerated.\nMake it run on GPU by replacing your `numpy` import with `jax.numpy`.\nCompile a function by decorating it with [`jit`](https://github.com/google/jax#compilation-with-jit).\nThis frees your function from slow Python interpretation, parallelizes operations where possible and optimizes your compute graph.\nIt provides speed and scalability comparable to TensorFlow2 or PyTorch.\n\nBecause weights are immutable, whole training loops can be compiled and run on GPU ([demo](examples/mnist_vae.py#L96)).\n`jit` will make your training as fast as mutating weights inline, and weights will not leave the GPU.\nYou can write functional code without worrying about performance.\n\nYou can easily accelerate `numpy`/`scipy` pre-/postprocessing code in the same way ([demo](examples/mnist_vae.py#L61)).\n\n### Regularization and reparametrization\n\nIn JAXnet, regularizing a model can be done in one line ([demo](examples/wavenet.py#L167)):\n\n```python\nloss = L2Regularized(loss, scale=.1)\n```\n\n`loss` is now just another module that can be used as above.\nReparametrized layers are one-liners, too (see [API](API.md#regularization-and-reparametrization)).\nJAXnet allows regularizing or reparametrizing any module or subnetwork without changing its code.\nThis is possible because modules do not instantiate any variables.\nInstead, each module provides a function (`apply`) that takes parameters as an argument.\nThis function can be wrapped to build layers like `L2Regularized`.\n\nIn contrast, TensorFlow2/Keras/PyTorch have mutable variables baked into their model API. 
They therefore require:\n- Regularization arguments at the layer level, with separate code necessary for each layer.\n- Reparametrization arguments at the layer level, and separate implementations for [each](https://www.tensorflow.org/probability/api_docs/python/tfp/layers/DenseReparameterization) [layer](https://www.tensorflow.org/probability/api_docs/python/tfp/layers/Convolution1DReparameterization).\n\n### Random key control\n\nJAXnet has no global random state.\nRandom keys are provided explicitly, making code deterministic and independent of previously executed code by default.\nThis can help with debugging and is more flexible ([demo](examples/mnist_vae.py#L81)).\nRead more on random numbers in JAX [here](https://github.com/google/jax#random-numbers-are-different).\n\n### Step-by-step debugging\n\nJAXnet allows step-by-step debugging with concrete values like any plain Python function\n(when [`jit`](https://github.com/google/jax#compilation-with-jit) compilation is not used).\n\n## API and demos\n\nFind more details on the API [here](API.md).\n\nSee JAXnet in action in your browser:\n[MNIST Classifier](https://colab.research.google.com/drive/18kICTUbjqnfg5Lk3xFVQtUj6ahct9Vmv),\n[MNIST VAE](https://colab.research.google.com/drive/19web5SnmIFglLcnpXE34phiTY03v39-g),\n[OCR with RNNs (to be fixed)](https://colab.research.google.com/drive/1YuI6GUtMgnMiWtqoaPznwAiSCe9hMR1E),\n[ResNet](https://colab.research.google.com/drive/1q6yoK_Zscv-57ZzPM4qNy3LgjeFzJ5xN) and\n[WaveNet](https://colab.research.google.com/drive/111cKRfwYX4YFuPH3FF4V46XLfsPG1icZ).\n\n## Installation\n**This is a preview. 
Expect breaking changes!** Install with\n\n```\npip3 install jaxnet\n```\n\nTo use GPU, first install the [right version of jaxlib](https://github.com/google/jax#installation).\n\n## Questions\n\nPlease feel free to create an issue on GitHub.\n\n", "description_content_type": "text/markdown", "docs_url": null, "download_url": "", "downloads": { "last_day": -1, "last_month": -1, "last_week": -1 }, "home_page": "http://github.com/JuliusKunze/jaxnet", "keywords": "", "license": "", "maintainer": "", "maintainer_email": "", "name": "jaxnet", "package_url": "https://pypi.org/project/jaxnet/", "platform": "", "project_url": "https://pypi.org/project/jaxnet/", "project_urls": { "Homepage": "http://github.com/JuliusKunze/jaxnet" }, "release_url": "https://pypi.org/project/jaxnet/0.2.2/", "requires_dist": [ "jax (>=0.1.42)", "dill" ], "requires_python": "", "summary": "Neural Nets for JAX", "version": "0.2.2" }, "last_serial": 5774169, "releases": { "0.1": [ { "comment_text": "", "digests": { "md5": "59cb5843e1efe23f1bef7106a692a7cb", "sha256": "e5d9608fe4ef54c5c41e83e19996dcbc567a53531c43ea72b29addfe73fa2ee9" }, "downloads": -1, "filename": "jaxnet-0.1-py3.6.egg", "has_sig": false, "md5_digest": "59cb5843e1efe23f1bef7106a692a7cb", "packagetype": "bdist_egg", "python_version": "3.6", "requires_python": null, "size": 13920, "upload_time": "2019-08-11T12:58:02", "url": "https://files.pythonhosted.org/packages/90/79/618cd047e1d561449a443ae44a97b72e38db5e58400bc2bdee4a677efb07/jaxnet-0.1-py3.6.egg" }, { "comment_text": "", "digests": { "md5": "bb9d9db44da933fd0bf70e5b3216918c", "sha256": "e77f3b8a58538ee17186e99de54bc8596e9dedd612a7f0064f777beb94911bc6" }, "downloads": -1, "filename": "jaxnet-0.1-py3-none-any.whl", "has_sig": false, "md5_digest": "bb9d9db44da933fd0bf70e5b3216918c", "packagetype": "bdist_wheel", "python_version": "py3", "requires_python": null, "size": 11398, "upload_time": "2019-08-11T12:57:59", "url": 
"https://files.pythonhosted.org/packages/ae/84/9388ab22690f0fddc85fde89c8f3c3661490266974c579d89e1f9618453f/jaxnet-0.1-py3-none-any.whl" }, { "comment_text": "", "digests": { "md5": "3a7e492f3de266d1ebd185e57a585858", "sha256": "80ee897d1eae792d1b90f0eb6ca5d0a2b1cb336a4cb22338ca0acf60e30480e7" }, "downloads": -1, "filename": "jaxnet-0.1.tar.gz", "has_sig": false, "md5_digest": "3a7e492f3de266d1ebd185e57a585858", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 7365, "upload_time": "2019-08-11T12:58:04", "url": "https://files.pythonhosted.org/packages/c5/0b/d59603c742e3dfd879b5f545b040cb73c1d9f6849285b27d934c91573ef4/jaxnet-0.1.tar.gz" } ], "0.1.1": [ { "comment_text": "", "digests": { "md5": "8db1784149f25548669eebbb231b4472", "sha256": "50b7e1355f7c2dd9eaf48184979dd7b3170fe7613705b252c81f634222059a79" }, "downloads": -1, "filename": "jaxnet-0.1.1-py3-none-any.whl", "has_sig": false, "md5_digest": "8db1784149f25548669eebbb231b4472", "packagetype": "bdist_wheel", "python_version": "py3", "requires_python": null, "size": 11365, "upload_time": "2019-08-12T20:03:39", "url": "https://files.pythonhosted.org/packages/f7/32/55e76e8a4e64f510b62c282e85e64c25cd8b14d199eafa339fc1ab5000a8/jaxnet-0.1.1-py3-none-any.whl" }, { "comment_text": "", "digests": { "md5": "4da08703b7674d45d7e1b3e1adc6d5cf", "sha256": "da3fdbfa6660ad00711d15b40e737be02258564812f086283092f41e32608cb8" }, "downloads": -1, "filename": "jaxnet-0.1.1.tar.gz", "has_sig": false, "md5_digest": "4da08703b7674d45d7e1b3e1adc6d5cf", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 6867, "upload_time": "2019-08-12T20:03:41", "url": "https://files.pythonhosted.org/packages/49/4f/9e2568e107b5585fbe6c5f3e6cb372b0f0056ea357162a9bcab07c836ff9/jaxnet-0.1.1.tar.gz" } ], "0.1.2": [ { "comment_text": "", "digests": { "md5": "39ec922823e07515a42c9c1e5c782063", "sha256": "26e1233f5be479cc8adddf4eb0a0d488c1fb2b854f3034a2bd7bb594db322b2b" }, "downloads": -1, 
"filename": "jaxnet-0.1.2-py3-none-any.whl", "has_sig": false, "md5_digest": "39ec922823e07515a42c9c1e5c782063", "packagetype": "bdist_wheel", "python_version": "py3", "requires_python": null, "size": 12411, "upload_time": "2019-08-14T00:18:18", "url": "https://files.pythonhosted.org/packages/8a/12/6c0f67e5ce0b24d61126699d2b89f5a6dee2c42e07893189bec0ba1aed8c/jaxnet-0.1.2-py3-none-any.whl" }, { "comment_text": "", "digests": { "md5": "ee627744f2b1f91e185a7a664c00ea4e", "sha256": "6fc8ad35c7b4bdf318a649ed1af5df1cddd7ac1f2771199ef6409cae462de497" }, "downloads": -1, "filename": "jaxnet-0.1.2.tar.gz", "has_sig": false, "md5_digest": "ee627744f2b1f91e185a7a664c00ea4e", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 8540, "upload_time": "2019-08-14T00:18:20", "url": "https://files.pythonhosted.org/packages/8a/80/177d480fa583be531757292bddf53b131d4b514271363fd858e0c54b3002/jaxnet-0.1.2.tar.gz" } ], "0.1.3": [ { "comment_text": "", "digests": { "md5": "1fd10cb316da4c3b80d3764780b9ea46", "sha256": "e3a8252ffe5df1a4b769ec0e4fc45607b97526b088fd62fd2835836b8103bf8d" }, "downloads": -1, "filename": "jaxnet-0.1.3-py3-none-any.whl", "has_sig": false, "md5_digest": "1fd10cb316da4c3b80d3764780b9ea46", "packagetype": "bdist_wheel", "python_version": "py3", "requires_python": null, "size": 8820, "upload_time": "2019-08-14T22:39:18", "url": "https://files.pythonhosted.org/packages/b8/14/16b140c1fbc443ebe85ff2355542a1e10e85addd6bedb26f90bdd869a06c/jaxnet-0.1.3-py3-none-any.whl" }, { "comment_text": "", "digests": { "md5": "5eb15bb6729bd70689f35416bbf47fd4", "sha256": "0dce9b79f8a3dcf5431f672b0a3097ef42fd5f4083c9c1e77f66582ec00bdf94" }, "downloads": -1, "filename": "jaxnet-0.1.3.tar.gz", "has_sig": false, "md5_digest": "5eb15bb6729bd70689f35416bbf47fd4", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 8496, "upload_time": "2019-08-14T22:39:20", "url": 
"https://files.pythonhosted.org/packages/8e/f6/369eee1788c7e7fde3145ccb32a3b21426007a0b7c7382c35d43780f77ab/jaxnet-0.1.3.tar.gz" } ], "0.1.4": [ { "comment_text": "", "digests": { "md5": "e288c293175081ca1cce28e974e042a0", "sha256": "5cba3b80e616aed610e9b64af02067b0aa5ae9f30dda36581ea7bee274e2d845" }, "downloads": -1, "filename": "jaxnet-0.1.4-py3-none-any.whl", "has_sig": false, "md5_digest": "e288c293175081ca1cce28e974e042a0", "packagetype": "bdist_wheel", "python_version": "py3", "requires_python": null, "size": 9046, "upload_time": "2019-08-16T19:48:44", "url": "https://files.pythonhosted.org/packages/79/65/efbc58d301b1d17c55d02140a7f01801a09ca65aa017dab1b1697b321f3e/jaxnet-0.1.4-py3-none-any.whl" }, { "comment_text": "", "digests": { "md5": "8f7eb80e575fa32ad553de83f82cefc4", "sha256": "c5bffeda1889d110c406ea5d15a017bd567a6907c3bad17bf76d93e86d6e2562" }, "downloads": -1, "filename": "jaxnet-0.1.4.tar.gz", "has_sig": false, "md5_digest": "8f7eb80e575fa32ad553de83f82cefc4", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 8735, "upload_time": "2019-08-16T19:48:46", "url": "https://files.pythonhosted.org/packages/0f/27/881e3f24f97b5ade3cd0bb40eb93b5597633e3ac88bbd9109ddd9e02cf98/jaxnet-0.1.4.tar.gz" } ], "0.2.0": [ { "comment_text": "", "digests": { "md5": "70604e2ddeb0080f0dd38dfcfd5c8959", "sha256": "0c6b7ab9d4f45034884db7d85e2ddd0b62672ae11ce5db38aaa9a64c29410ffd" }, "downloads": -1, "filename": "jaxnet-0.2.0-py3-none-any.whl", "has_sig": false, "md5_digest": "70604e2ddeb0080f0dd38dfcfd5c8959", "packagetype": "bdist_wheel", "python_version": "py3", "requires_python": null, "size": 12620, "upload_time": "2019-08-24T22:41:12", "url": "https://files.pythonhosted.org/packages/ae/f5/22208df7a3eabf9afe972b0a92241cab1cc2fc24737372f6a178acab167e/jaxnet-0.2.0-py3-none-any.whl" }, { "comment_text": "", "digests": { "md5": "dfaddad78dd6720f9d9877a3e6f6ee87", "sha256": "36d0bed7101011164b7b63c5fb5442849fedbb150910195495a3a27b6d7c063f" 
}, "downloads": -1, "filename": "jaxnet-0.2.0.tar.gz", "has_sig": false, "md5_digest": "dfaddad78dd6720f9d9877a3e6f6ee87", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 14149, "upload_time": "2019-08-24T22:41:13", "url": "https://files.pythonhosted.org/packages/a9/54/50c5a49e9254da35b8333b56158f9d762b09e8f546a19c018b498efded01/jaxnet-0.2.0.tar.gz" } ], "0.2.1": [ { "comment_text": "", "digests": { "md5": "93d68bb2a7156209dfadeefbf2eaff47", "sha256": "6547cc79a3729d486168cb000d42532cf882c850468858495ec9e28d26b69f10" }, "downloads": -1, "filename": "jaxnet-0.2.1-py3-none-any.whl", "has_sig": false, "md5_digest": "93d68bb2a7156209dfadeefbf2eaff47", "packagetype": "bdist_wheel", "python_version": "py3", "requires_python": null, "size": 13710, "upload_time": "2019-08-27T09:53:12", "url": "https://files.pythonhosted.org/packages/4a/2d/baa736ceed15faf2eab45e973f50af5ecb239be050cee03459d0f28dd812/jaxnet-0.2.1-py3-none-any.whl" }, { "comment_text": "", "digests": { "md5": "f0d879bb24ec261011b2b0fd08b85be9", "sha256": "cf6d5988ab4a34d11da11da45405188b79a8616b1f84e6a27e43c944eb128f02" }, "downloads": -1, "filename": "jaxnet-0.2.1.tar.gz", "has_sig": false, "md5_digest": "f0d879bb24ec261011b2b0fd08b85be9", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 15174, "upload_time": "2019-08-27T09:53:14", "url": "https://files.pythonhosted.org/packages/e4/23/826c3d12f069a3a244ded2d25f62ef20703e64392313af709cc8dbdb4cee/jaxnet-0.2.1.tar.gz" } ], "0.2.2": [ { "comment_text": "", "digests": { "md5": "67c24e3a6273af6eae68f227bd771660", "sha256": "4d6b986a9f1e3eaa377ca1fe16d7d3f16610f03029570adea9f0403626d95219" }, "downloads": -1, "filename": "jaxnet-0.2.2-py3-none-any.whl", "has_sig": false, "md5_digest": "67c24e3a6273af6eae68f227bd771660", "packagetype": "bdist_wheel", "python_version": "py3", "requires_python": null, "size": 17931, "upload_time": "2019-09-03T06:24:40", "url": 
"https://files.pythonhosted.org/packages/b5/50/89d95e5c9cf3020c5b7518737ed5c513840ac7f4d32be6a55787ea075155/jaxnet-0.2.2-py3-none-any.whl" }, { "comment_text": "", "digests": { "md5": "28d2193a8645712db97956e4075b5f38", "sha256": "5bb38384c2eb43b476b5c67357c00fc5dcef48e93e497eaf1199f6ce77f6e766" }, "downloads": -1, "filename": "jaxnet-0.2.2.tar.gz", "has_sig": false, "md5_digest": "28d2193a8645712db97956e4075b5f38", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 15712, "upload_time": "2019-09-03T06:24:42", "url": "https://files.pythonhosted.org/packages/0d/27/77e409e467ac3a055c25f26533e668b0819e65878ca8ffe1bed35ab61df6/jaxnet-0.2.2.tar.gz" } ] }, "urls": [ { "comment_text": "", "digests": { "md5": "67c24e3a6273af6eae68f227bd771660", "sha256": "4d6b986a9f1e3eaa377ca1fe16d7d3f16610f03029570adea9f0403626d95219" }, "downloads": -1, "filename": "jaxnet-0.2.2-py3-none-any.whl", "has_sig": false, "md5_digest": "67c24e3a6273af6eae68f227bd771660", "packagetype": "bdist_wheel", "python_version": "py3", "requires_python": null, "size": 17931, "upload_time": "2019-09-03T06:24:40", "url": "https://files.pythonhosted.org/packages/b5/50/89d95e5c9cf3020c5b7518737ed5c513840ac7f4d32be6a55787ea075155/jaxnet-0.2.2-py3-none-any.whl" }, { "comment_text": "", "digests": { "md5": "28d2193a8645712db97956e4075b5f38", "sha256": "5bb38384c2eb43b476b5c67357c00fc5dcef48e93e497eaf1199f6ce77f6e766" }, "downloads": -1, "filename": "jaxnet-0.2.2.tar.gz", "has_sig": false, "md5_digest": "28d2193a8645712db97956e4075b5f38", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 15712, "upload_time": "2019-09-03T06:24:42", "url": "https://files.pythonhosted.org/packages/0d/27/77e409e467ac3a055c25f26533e668b0819e65878ca8ffe1bed35ab61df6/jaxnet-0.2.2.tar.gz" } ] }