{ "info": { "author": "Simon Larsson", "author_email": "larssonsimon0@gmail.com", "bugtrack_url": null, "classifiers": [ "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3 :: Only", "Topic :: Scientific/Engineering" ], "description": "# Keras SWA - Stochastic Weight Averaging\n\n[![PyPI version](https://badge.fury.io/py/keras-swa.svg)](https://pypi.python.org/pypi/keras-swa/) \n[![License](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/simon-larsson/keras-swa/blob/master/LICENSE)\n\nThis is an implementation of SWA for Keras and TF-Keras.\n\n## Introduction\nStochastic weight averaging (SWA) is built upon the same principle as [snapshot ensembling](https://arxiv.org/abs/1704.00109) and [fast geometric ensembling](https://arxiv.org/abs/1802.10026). The idea is that averaging select stages of training can lead to better models. Whereas the two former methods average by sampling and ensembling models, SWA instead averages weights. This has been shown to give comparable improvements confined to a single model.\n\n[![Illustration](https://raw.githubusercontent.com/simon-larsson/keras-swa/master/swa_illustration.png)](https://raw.githubusercontent.com/simon-larsson/keras-swa/master/swa_illustration.png)\n\n## Paper\n - Title: Averaging Weights Leads to Wider Optima and Better Generalization\n - Link: https://arxiv.org/abs/1803.05407\n - Authors: Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov, Andrew Gordon Wilson\n - Repo: https://github.com/timgaripov/swa (PyTorch)\n\n## Installation\n\n pip install keras-swa\n\n### SWA API\n\nKeras callback object for SWA. \n\n### Arguments\n**start_epoch** - Starting epoch for SWA.\n\n**lr_schedule** - Learning rate schedule. `'manual'`, `'constant'` or `'cyclic'`.\n\n**swa_lr** - Learning rate used when averaging weights.\n\n**swa_lr2** - Upper bound of learning rate for the cyclic schedule.\n\n**swa_freq** - Frequency of weight averaging. 
Used with cyclic schedules.\n\n**batch_size** - Batch size the model is being trained with (only when using batch normalization).\n\n**verbose** - Verbosity mode, 0 or 1.\n\n### Batch Normalization\nFor models with batch normalization, the last epoch will be a forward pass only, i.e. the learning rate is set to zero. This is because batch normalization uses the running mean and variance of its preceding layer to normalize. SWA invalidates these statistics by suddenly changing the weights at the end of training. Therefore, the last epoch must be used to reset and recalculate the batch normalization running mean and variance for the updated weights. Batch normalization gamma and beta values are preserved.\n\n**When using manual schedule:** The SWA callback will set the learning rate to zero in the last epoch if batch normalization is used. This must not be undone by any external learning rate scheduler for SWA to work properly. \n\n### Learning Rate Schedules\nThe default schedule is `'manual'`, allowing the learning rate to be controlled by an external learning rate scheduler or the optimizer. In that case SWA only affects the final weights and, if batch normalization is used, the learning rate of the last epoch. 
The two predefined schedules, `'constant'` and `'cyclic'`, can be observed below.\n\n[![lr_schedules](https://raw.githubusercontent.com/simon-larsson/keras-swa/master/lr_schedules.png)](https://raw.githubusercontent.com/simon-larsson/keras-swa/master/lr_schedules.png)\n\n\n#### Example\n\nFor TensorFlow Keras (with constant LR)\n```python\nfrom sklearn.datasets import make_blobs\nfrom tensorflow.keras.utils import to_categorical\nfrom tensorflow.keras.models import Sequential\nfrom tensorflow.keras.layers import Dense\nfrom tensorflow.keras.optimizers import SGD\n\nfrom swa.tfkeras import SWA\n \n# make dataset\nX, y = make_blobs(n_samples=1000, \n centers=3, \n n_features=2, \n cluster_std=2, \n random_state=2)\n\ny = to_categorical(y)\n\n# build model\nmodel = Sequential()\nmodel.add(Dense(50, input_dim=2, activation='relu'))\nmodel.add(Dense(3, activation='softmax'))\n\nmodel.compile(loss='categorical_crossentropy', \n optimizer=SGD(learning_rate=0.1))\n\nepochs = 100\nstart_epoch = 75\n\n# define swa callback\nswa = SWA(start_epoch=start_epoch, \n lr_schedule='constant', \n swa_lr=0.01, \n verbose=1)\n\n# train\nmodel.fit(X, y, epochs=epochs, verbose=1, callbacks=[swa])\n```\n\nOr for Keras (with cyclic LR)\n```python\nfrom sklearn.datasets import make_blobs\nfrom keras.utils import to_categorical\nfrom keras.models import Sequential\nfrom keras.layers import Dense, BatchNormalization\nfrom keras.optimizers import SGD\n\nfrom swa.keras import SWA\n\n# make dataset\nX, y = make_blobs(n_samples=1000, \n centers=3, \n n_features=2, \n cluster_std=2, \n random_state=2)\n\ny = to_categorical(y)\n\n# build model\nmodel = Sequential()\nmodel.add(Dense(50, input_dim=2, activation='relu'))\nmodel.add(BatchNormalization())\nmodel.add(Dense(3, activation='softmax'))\n\nmodel.compile(loss='categorical_crossentropy', \n optimizer=SGD(learning_rate=0.1))\n\nepochs = 100\nstart_epoch = 75\n\n# define swa callback\nswa = SWA(start_epoch=start_epoch, \n lr_schedule='cyclic', \n 
swa_lr=0.001,\n swa_lr2=0.003,\n swa_freq=3,\n batch_size=32, # needed when using batch norm\n verbose=1)\n\n# train\nmodel.fit(X, y, batch_size=32, epochs=epochs, verbose=1, callbacks=[swa])\n```\n\nOutput\n```\nModel uses batch normalization. SWA will require last epoch to be a forward pass and will run with no learning rate\nEpoch 1/100\n1000/1000 [==============================] - 1s 547us/sample - loss: 0.5529\nEpoch 2/100\n1000/1000 [==============================] - 0s 160us/sample - loss: 0.4720\n...\nEpoch 74/100\n1000/1000 [==============================] - 0s 160us/sample - loss: 0.4249\n\nEpoch 00075: starting stochastic weight averaging\nEpoch 75/100\n1000/1000 [==============================] - 0s 164us/sample - loss: 0.4357\nEpoch 76/100\n1000/1000 [==============================] - 0s 165us/sample - loss: 0.4209\n...\nEpoch 99/100\n1000/1000 [==============================] - 0s 167us/sample - loss: 0.4263\n\nEpoch 00100: final model weights set to stochastic weight average\n\nEpoch 00100: reinitializing batch normalization layers\n\nEpoch 00100: running forward pass to adjust batch normalization\nEpoch 100/100\n1000/1000 [==============================] - 0s 166us/sample - loss: 0.4408\n```\n\n### Collaborators\n\n- [Simon Larsson](https://github.com/simon-larsson \"Github\")\n- [Alex Stoken](https://github.com/alexstoken \"Github\")", "description_content_type": "text/markdown", "docs_url": null, "download_url": "", "downloads": { "last_day": -1, "last_month": -1, "last_week": -1 }, "home_page": "https://github.com/simon-larsson/keras-swa", "keywords": "", "license": "MIT", "maintainer": "", "maintainer_email": "", "name": "keras-swa", "package_url": "https://pypi.org/project/keras-swa/", "platform": "", "project_url": "https://pypi.org/project/keras-swa/", "project_urls": { "Homepage": "https://github.com/simon-larsson/keras-swa" }, "release_url": "https://pypi.org/project/keras-swa/0.1.7/", "requires_dist": null, "requires_python": "", 
"summary": "Simple stochastic weight averaging callback for Keras.", "version": "0.1.7", "yanked": false, "yanked_reason": null }, "last_serial": 11573731, "releases": { "0.0.1": [ { "comment_text": "", "digests": { "md5": "bd376de5ca8a6eb2bb22b1917d68a004", "sha256": "96ae71e4266523d6e9d3bc22d2ae21e0259cc57835e96e8de687269deb58f66b" }, "downloads": -1, "filename": "keras-swa-0.0.1.tar.gz", "has_sig": false, "md5_digest": "bd376de5ca8a6eb2bb22b1917d68a004", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 2763, "upload_time": "2019-10-02T12:33:23", "upload_time_iso_8601": "2019-10-02T12:33:23.722782Z", "url": "https://files.pythonhosted.org/packages/c1/9e/bda0c7d265358cae43d8e34abe3fdd280b6032e3ddf350de0b56aae61359/keras-swa-0.0.1.tar.gz", "yanked": false, "yanked_reason": null } ], "0.0.2": [ { "comment_text": "", "digests": { "md5": "7a62b0769f811edabeaaf160ec5a9e7a", "sha256": "b61916e8f539e42da8512aef8e08c2efebac1cac30aa4cf80c04a7d144742b7e" }, "downloads": -1, "filename": "keras-swa-0.0.2.tar.gz", "has_sig": false, "md5_digest": "7a62b0769f811edabeaaf160ec5a9e7a", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 2941, "upload_time": "2019-10-03T12:04:37", "upload_time_iso_8601": "2019-10-03T12:04:37.173497Z", "url": "https://files.pythonhosted.org/packages/9c/f2/30d13ba8fd678fda68f1a97392e4edc42705e5867d4e92577c8b69f489f2/keras-swa-0.0.2.tar.gz", "yanked": false, "yanked_reason": null } ], "0.0.3": [ { "comment_text": "", "digests": { "md5": "7372b5c3d75d717670d0b428c92ded91", "sha256": "b2ce77496bfcc98ffa97dbeeedd1f319effe8f9b5ff0d76fe9ec63f686699929" }, "downloads": -1, "filename": "keras-swa-0.0.3.tar.gz", "has_sig": false, "md5_digest": "7372b5c3d75d717670d0b428c92ded91", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 3058, "upload_time": "2019-10-07T11:14:50", "upload_time_iso_8601": "2019-10-07T11:14:50.143256Z", "url": 
"https://files.pythonhosted.org/packages/aa/2f/bd653405f3fc2a209157e2d4c62328a8ed71278dce260e653a69019144a1/keras-swa-0.0.3.tar.gz", "yanked": false, "yanked_reason": null } ], "0.0.4": [ { "comment_text": "", "digests": { "md5": "0165707d4ac6bf5dfc991b3d9a2c93cf", "sha256": "eff1cd79265961655f91bc48db373bda78e5ab551267770cf24db7531b61edc7" }, "downloads": -1, "filename": "keras-swa-0.0.4.tar.gz", "has_sig": false, "md5_digest": "0165707d4ac6bf5dfc991b3d9a2c93cf", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 3106, "upload_time": "2019-10-14T09:26:51", "upload_time_iso_8601": "2019-10-14T09:26:51.473589Z", "url": "https://files.pythonhosted.org/packages/34/64/1cd6879cb8d6ce597c7e54d1bd6a17c294ecacaaa45e37466962cf90ad59/keras-swa-0.0.4.tar.gz", "yanked": false, "yanked_reason": null } ], "0.0.5": [ { "comment_text": "", "digests": { "md5": "f120833b361be07a45dae524a42edfae", "sha256": "6a1710d9c97f9d9735f75ea51aed94c6966caa93a8386e77f7c1d4dd7df953b2" }, "downloads": -1, "filename": "keras-swa-0.0.5.tar.gz", "has_sig": false, "md5_digest": "f120833b361be07a45dae524a42edfae", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 3057, "upload_time": "2019-10-14T16:11:23", "upload_time_iso_8601": "2019-10-14T16:11:23.166780Z", "url": "https://files.pythonhosted.org/packages/79/57/43fcc5373473b582e5b40b54d34885372a1756dc672dbfade4f37ac043ac/keras-swa-0.0.5.tar.gz", "yanked": false, "yanked_reason": null } ], "0.0.6": [ { "comment_text": "", "digests": { "md5": "fabe0ad5b6f0a793815066b24d7941f3", "sha256": "b679df949073ca7e64c6cbea9660b92485a0b6e60b0262acf927116b297538b5" }, "downloads": -1, "filename": "keras-swa-0.0.6.tar.gz", "has_sig": false, "md5_digest": "fabe0ad5b6f0a793815066b24d7941f3", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 3098, "upload_time": "2019-10-18T18:32:55", "upload_time_iso_8601": "2019-10-18T18:32:55.752191Z", "url": 
"https://files.pythonhosted.org/packages/db/f4/2aabf7096c60c58f3ff6bb086ea359a474c6bd96c741dbaeb9866f477fcf/keras-swa-0.0.6.tar.gz", "yanked": false, "yanked_reason": null } ], "0.1.0": [ { "comment_text": "", "digests": { "md5": "317ba548ee669118004c06d429e0e18b", "sha256": "25388dab93a12c4378bb313eb91218743cfe2a6b7c7f0dcdfc21739704de34fd" }, "downloads": -1, "filename": "keras-swa-0.1.0.tar.gz", "has_sig": false, "md5_digest": "317ba548ee669118004c06d429e0e18b", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 3164, "upload_time": "2019-10-18T18:36:15", "upload_time_iso_8601": "2019-10-18T18:36:15.306771Z", "url": "https://files.pythonhosted.org/packages/b9/c5/00b8e545f2b8cf79529acef3aa676e07528c7be03d39354aedf73eb249ba/keras-swa-0.1.0.tar.gz", "yanked": false, "yanked_reason": null } ], "0.1.1": [ { "comment_text": "", "digests": { "md5": "1ffbe2494caa54e9975e700dbbc9d2a3", "sha256": "293d8573233f86b2ac7e25afe22f3005b1c02c34c5882511ceff874856a37cb4" }, "downloads": -1, "filename": "keras_swa-0.1.1-py3-none-any.whl", "has_sig": false, "md5_digest": "1ffbe2494caa54e9975e700dbbc9d2a3", "packagetype": "bdist_wheel", "python_version": "py3", "requires_python": null, "size": 8019, "upload_time": "2019-10-19T07:08:07", "upload_time_iso_8601": "2019-10-19T07:08:07.293348Z", "url": "https://files.pythonhosted.org/packages/9e/ac/1ffe8397b29bc4e6004677921fb14fb4c6b6a89b75caabd7a0a32cc2fb6a/keras_swa-0.1.1-py3-none-any.whl", "yanked": false, "yanked_reason": null }, { "comment_text": "", "digests": { "md5": "65c6bb4e86b470c7f14d73ff14efb57b", "sha256": "062871e86885bf28c90869d5e2dbdee0d850029fb23fa832c81df3ea8da23f0f" }, "downloads": -1, "filename": "keras-swa-0.1.1.tar.gz", "has_sig": false, "md5_digest": "65c6bb4e86b470c7f14d73ff14efb57b", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 5121, "upload_time": "2019-10-19T07:08:08", "upload_time_iso_8601": "2019-10-19T07:08:08.710782Z", "url": 
"https://files.pythonhosted.org/packages/b5/8a/df128a0b7a14ce8414c499f1bfe0cb37d9d7d4f0987311e7d1f98ac6bcbe/keras-swa-0.1.1.tar.gz", "yanked": false, "yanked_reason": null } ], "0.1.2": [ { "comment_text": "", "digests": { "md5": "12c420f16c502652c66dfc2a38041f33", "sha256": "75cb1bbee7adad2094bb70f72490066c938e2d0609f280b21f0eb6ce19a40084" }, "downloads": -1, "filename": "keras-swa-0.1.2.tar.gz", "has_sig": false, "md5_digest": "12c420f16c502652c66dfc2a38041f33", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 5116, "upload_time": "2019-10-22T09:23:10", "upload_time_iso_8601": "2019-10-22T09:23:10.590780Z", "url": "https://files.pythonhosted.org/packages/2d/e0/56befd04516d8b2a690ff4b1aad1ddba95166b66f4d4ca0c47c48787828a/keras-swa-0.1.2.tar.gz", "yanked": false, "yanked_reason": null } ], "0.1.3": [ { "comment_text": "", "digests": { "md5": "0a18fed65b78f867482d85b7437968fb", "sha256": "cf84a30744b0a26a31526321bc0a28fd5031728bcb14c4eb77e6868a0a05eee6" }, "downloads": -1, "filename": "keras-swa-0.1.3.tar.gz", "has_sig": false, "md5_digest": "0a18fed65b78f867482d85b7437968fb", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 5519, "upload_time": "2019-12-14T05:50:47", "upload_time_iso_8601": "2019-12-14T05:50:47.798783Z", "url": "https://files.pythonhosted.org/packages/ba/45/cc2f034abbeddc0259733fecfcaf7bedc6f38e967d7b20617f04d3e77f75/keras-swa-0.1.3.tar.gz", "yanked": false, "yanked_reason": null } ], "0.1.4": [ { "comment_text": "", "digests": { "md5": "c692c8dc5a55b17e99e2395b2cda63b6", "sha256": "048d06f13719dcfd6432b6adb930affc166402d25708ac7d157f1ff34e9a3fa7" }, "downloads": -1, "filename": "keras-swa-0.1.4.tar.gz", "has_sig": false, "md5_digest": "c692c8dc5a55b17e99e2395b2cda63b6", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 6114, "upload_time": "2020-01-16T13:08:33", "upload_time_iso_8601": "2020-01-16T13:08:33.212258Z", "url": 
"https://files.pythonhosted.org/packages/5f/08/1ed4384b15b291499bb33b27d18bf67cf8323bf7f445338c46ab78f5252f/keras-swa-0.1.4.tar.gz", "yanked": false, "yanked_reason": null } ], "0.1.5": [ { "comment_text": "", "digests": { "md5": "68bab397361914c4f9c29bbfb296792a", "sha256": "15d0f85fa45c70845f842874a84ce15b5c010f6f09626eb8c9ff2eb263c3fe53" }, "downloads": -1, "filename": "keras-swa-0.1.5.tar.gz", "has_sig": false, "md5_digest": "68bab397361914c4f9c29bbfb296792a", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 77355, "upload_time": "2020-02-24T14:38:43", "upload_time_iso_8601": "2020-02-24T14:38:43.257794Z", "url": "https://files.pythonhosted.org/packages/91/40/2796068a5d40c9b7f8d2b0d5a830a45c3ff9909194931f50a87ec19f45d5/keras-swa-0.1.5.tar.gz", "yanked": false, "yanked_reason": null } ], "0.1.6": [ { "comment_text": "", "digests": { "md5": "eed99f0a3f6e50c903591e8977b83370", "sha256": "435ed50adbd70b48cc6dd0a234a388d751a3d16c0adabb495258ec5701b973b8" }, "downloads": -1, "filename": "keras_swa-0.1.6-py3-none-any.whl", "has_sig": false, "md5_digest": "eed99f0a3f6e50c903591e8977b83370", "packagetype": "bdist_wheel", "python_version": "py3", "requires_python": null, "size": 7636, "upload_time": "2021-05-02T18:15:33", "upload_time_iso_8601": "2021-05-02T18:15:33.344912Z", "url": "https://files.pythonhosted.org/packages/07/73/0256256dae8206e239e031a15a61e09f67412e4c176eed8b74c3b2e9cbfe/keras_swa-0.1.6-py3-none-any.whl", "yanked": false, "yanked_reason": null }, { "comment_text": "", "digests": { "md5": "2ebd447ddaa7e09332ed1a0d19bddb4a", "sha256": "92fb36be8522a62da985d30ff21b3fbb16ebe1fcefe440bc39627aa33d011e36" }, "downloads": -1, "filename": "keras-swa-0.1.6.tar.gz", "has_sig": false, "md5_digest": "2ebd447ddaa7e09332ed1a0d19bddb4a", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 7154, "upload_time": "2021-05-02T18:15:35", "upload_time_iso_8601": "2021-05-02T18:15:35.098095Z", "url": 
"https://files.pythonhosted.org/packages/90/d3/d8a9dbdc3a6c71f896f1ab8c9cfc76a714acb461052cbead15a5a69836cc/keras-swa-0.1.6.tar.gz", "yanked": false, "yanked_reason": null } ], "0.1.7": [ { "comment_text": "", "digests": { "md5": "42a54964ba532105a4e708be09cebc2d", "sha256": "d143ef034d1f30f1fd0a5a929f8a06ff644c8406b11a01b786e36a9cadc7a069" }, "downloads": -1, "filename": "keras-swa-0.1.7.tar.gz", "has_sig": false, "md5_digest": "42a54964ba532105a4e708be09cebc2d", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 76124, "upload_time": "2021-09-28T17:36:03", "upload_time_iso_8601": "2021-09-28T17:36:03.479188Z", "url": "https://files.pythonhosted.org/packages/29/5d/487813b1983c777eeea282779ca6f9663570b721139034afa310e35de129/keras-swa-0.1.7.tar.gz", "yanked": false, "yanked_reason": null } ] }, "urls": [ { "comment_text": "", "digests": { "md5": "42a54964ba532105a4e708be09cebc2d", "sha256": "d143ef034d1f30f1fd0a5a929f8a06ff644c8406b11a01b786e36a9cadc7a069" }, "downloads": -1, "filename": "keras-swa-0.1.7.tar.gz", "has_sig": false, "md5_digest": "42a54964ba532105a4e708be09cebc2d", "packagetype": "sdist", "python_version": "source", "requires_python": null, "size": 76124, "upload_time": "2021-09-28T17:36:03", "upload_time_iso_8601": "2021-09-28T17:36:03.479188Z", "url": "https://files.pythonhosted.org/packages/29/5d/487813b1983c777eeea282779ca6f9663570b721139034afa310e35de129/keras-swa-0.1.7.tar.gz", "yanked": false, "yanked_reason": null } ], "vulnerabilities": [] }