Skip to content
This repository has been archived by the owner on May 26, 2024. It is now read-only.

Commit

Permalink
Prepare for v0.2.2
Browse files Browse the repository at this point in the history
  • Loading branch information
poypoyan committed Feb 23, 2023
1 parent d221069 commit f8a7ee2
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 13 deletions.
2 changes: 1 addition & 1 deletion edhsmm/hsmm_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def fit(self, X, lengths=None, left_censor=0, right_censor=1):
self._dur_mstep(new_dur) # new durations
self._emission_mstep(X, emission_var) # new emissions
print(f"FIT{ self._print_name }: reestimation complete for loop { itera + 1 }.")
# return fitted edhsmm for joblib
# return fitted model for joblib
return self


Expand Down
117 changes: 105 additions & 12 deletions notebooks/EDHSMM (Import test).ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,20 @@
"FIT: reestimation complete for loop 2.\n",
"FIT: converged at loop 3.\n"
]
},
{
"data": {
"text/plain": [
"<edhsmm.hsmm_base.GaussianHSMM at 0x1f53dc7fd30>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# initialize HSMM and EM algorithm\n",
"# initialize HSMM\n",
"R = GaussianHSMM(n_states = 3, n_durations = 4)\n",
"my_init(R)\n",
"# sample observations (from hsmmlearn)\n",
Expand All @@ -74,6 +84,7 @@
" 10.49171453, -0.72812025, 0.57309517, 0.3420868, -1.35338431, 4.12587557,\n",
" 6.907117, 5.41243634])\n",
"obs = obs[:, None] # shape should be (n_samples, n_dim)\n",
"# EM algorithm\n",
"R.fit(obs)"
]
},
Expand Down Expand Up @@ -156,12 +167,23 @@
"FIT: reestimation complete for loop 2.\n",
"FIT: converged at loop 3.\n"
]
},
{
"data": {
"text/plain": [
"<edhsmm.hsmm_base.GaussianHSMM at 0x1f53ddb0460>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# initialize HSMM and EM algorithm\n",
"# initialize HSMM\n",
"S = GaussianHSMM(n_states = 3, n_durations = 4)\n",
"my_init(S)\n",
"# EM algorithm\n",
"S.fit(multi_obs, lengths=multi_len)"
]
},
Expand Down Expand Up @@ -244,6 +266,66 @@
"print(\"Covariance Matrices: [R]\\n\", R.covmat, \"\\nCovariance Matrices: [S]\\n\", S.covmat)"
]
},
{
"attachments": {
"fit_print.png": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAa8AAAB6CAYAAAAI79wpAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAAJcEhZcwAAEnQAABJ0Ad5mH3gAACihSURBVHhe7Z2/cty4sofhG583UFkOHHhfYWoUWAq1uQLJVUpt5RPYVROpyhtMLilVla1AuRVaG1g1r7AKTnDs0nPsxa8BkCBIgk0S5HCs/m7Nud4ZiGx2N/41QPSL//znP/8qQRAEQdgi/s/+f0EQBEHYGqTzEgRBELYO6bwEQRCErUM6L0EQBGHr2IrO68Wff6m///5L/fnihf1mWP7862/1919/2v/iYWTUf+c+Lf++DS9e/Kn+0vf4689x9LGt/HH2dXBbtGFq8kyFMevOEGR2tZ9N1Mtt12EXqPMqPXiFAqhBj5QJDRh+vp79QeXa8uKPM/VlMVNPt9fq279mYyS++4rrfj1Tf3gdmpNxE87z77eP6u3bt2p//726fbJf9iDW0P377zd1rW8yW9w/yw7M2Z/77E+//mv/NQxTk2fbSFl32tqiL2g7r46Uun2/T8+Az8dv42/gTt3+jEnXQZ0381qr1X5uAPp8/GZ/U+rbR/Pd/vtb9aT/LzOWLfPPxUn2d6Uy+nNy8Q+VawNmGJ+1Z+ysV9V/vzNXB2/MP1F2b2b+vc24yrdUN1EnhL7fUwf2ebQZ6bbhfLKL7w3B1OQR+nOIRufpQX1/tF8IbLhtXR2TDhu++XCsZuhUP93ZbyyvX6qdp1t1u95RL1/b7w731O7tautGHT6ms56rB93pv7tsHp0/Xt5o7czU4vOh/UYQhNH5+T/1j40KCTzatnVVTLbzwsOdHu0UwoUh33/opnvPNNyHe7vq4btRwu4rOx3TlMKZFVPTMGy6qJjBjRFTRjjw49sTdfEPryK48KFWQu3syz1/57AtnpvCs2adzT1/GJbh6KdTGW+t0z3LPWbj+r8RNq0q13Sf7JlcOfzbjgLDtdWY/6SSxxG714sXf6izr8aOhXJB6LwNxaWAr+rsj+J1ovIwdJha5pgeubYAXHukIKZDkOmxoX6lokkeELV7Ipu2beuqmO7MS8+kZupJd0iR+fjdD7VGw60rzfFueeoO5V4d/czCofv7K11+UTSGdp77xW4hxLla2x8tpoyKXmdTPH5/0FqaKduHD8POkboyCqDnD8OVHP2wymg7Yn1zbe9jPh+zwYsLu5mwtKotx4r/45mOf6n3uNbOXC2XL9UNXTfXJWReqvPsHk5mNxBIKQ/HV8HO0ZW6enmTl9HPsfyQD9Y4YGCIhnKxe6veZ/c7V+ogdyKWPAwdApJ5/kD3Ih0o/XctowVN/sO1RYq6jGu4RpsGuvj7ikaca9Om+pUKjjyt/LCnTfvidV4ztbj3RyvdR+spePNqV//vT/W/ir7L/Abu9MxjV+2dvlS6l1N+UTdzW68+5Y2I7u0/oWeaHdMoE6OID8fYDHJeOwLIy+QzQM6MZzQev6sHLYo/2/RJtc6CxsAtRJsOc1fhlhz9sHWIcLC+8lj7GdY3l9Zn9Fj94VrdPf5Pe1wOOh1fb5BZT/bVTharTgPHVzOedIdjw+hd5cnC8e8us3DXv//+oy4uzPp1G3madEhAZnsv3Oc7HLZF3UlVB1Ndxw1G8KGB7nqV/ffbkwt6zlY21dTVr1Tw2sOWftjDpimIbtjY+MLy0y/V1I7B0LszhAwfSYn/0zWHKrOduZUawv/+0t/atbI3B2qu69zPqh4y47V6qctgpOF37FfayFMidYNaZK1+eMuO//5zoU6yKT9HP0wd3l3rmcmOOrrCIKocxkpL0TfqfCDcZVsVUu4Nx1cdCdZXXsMY6x9ZA1WCLQ9Ph6HMj6ikrRrnVHVwxLrcxqbR+pUIjjw9/LC9Tfsz6Q0bukXW7hZnEENXUAxDuE8ejtg0m95+zdFPUxka/Z8gVIHw2lidWD3ouBaz4qCORtpbDGYfWeBiy0hVB6delwUek+28evfklSMcDYWm7EjHhjj8kBumzsUt9/9Vv2hGPG48lw1r9jgkHP2002GpEzsdX/eZH8RmKKng+GoiXHQiGuIZWJ7YkkA1qergiHV5RJuy4MjTQ+b2Nu3PdGde2IyhVTZ3L3J5UNijIXyCGdkNQrX+xgLdIH32XnhGPJfWDOYHtMhKv98v9NQ5J4/nLkZ78bENbw7mUcfCAixCI0OtX3L0012HJsxTmlXaQcewjZBp6PTIJluANzMx+meRnvJwfDUld9dmU4X/igVmZGdnZlF+SHmwKWd5FN9FHNLKfyK2GLMuj23TJljtYUeZu9i0LVn43ts4Qvm88h04+UJdSG3FxWKl9zIzoF1jtIf/Xa9wXh62KcpF36vyfQH9hl1UdjExlPvp9n1hLY8Wcb9cKRf2Rkjh+tUXs6PLu77ZhRPExr1nr/ydwMvafD3U6rniOpnsP6t1AZxc4XNz4fgGaNIP6KLDOrmNj5lt0YbcT5pscfn6s34m7DDV/1Yf6Do/td0/3R3S4EVfiBbPw3tAlnO1LPkG6COPs2nMV+tsHasLMahRKgzUyv4VlYf8Iq7DT3dvCnXLgTrmn0LB1U+T/zhitgDc63Bo0n9Mh4Bbv5pI4WOOqN2D9tIR2pRDeJ+cosyFe3p2mnQyyqyCdXSs3x3jsNjW2s/xBWEIOIMrYbuYkk0nvWGDtmnivQ09zd/ktv0pghGbHlzqEYp0XIIgPD+mvdtQgzjsu9Va7RydZnHY546LQ69XxTCPIAjCc2HSYUNBEARBqGLyMy9BEARBCJHOSxAEQdg6pPMSBEEQtg7pvARBEIStYys6L2wLD3PyDAlenmuTIgEYGfPDPtv+fRuw2xApLaZ44seUwHtwQ9uiDVOTZyqMWXeGILOr/WyiXm67DrtAnVfpwSsUQA16pExowPDT9T0telt+UTyeBN9R0rsgAZqTcRPOw8oh1YJYQ4f335DGAQn3nmMH5uzPffahDy2emjzbRsq609YWfUHbad63zA/7bXvSRApStz9jEfYrbezmzbyKp2fTx3uD+ttH851J9objO2xZW8bljaosoz/hESQc6H0mHPOyXlX//c5cuaMPUbZ4oO524irfUt1EnRD6Hipp3e9CqlxmqZiaPEJ/DtHoPJUT4QrN0PFQOFrL9hH7q3WrAfmkw4ZZ0jybfC8Dpxw/3arbtXcC8uGe2r1dbdWoI8R01jgTcl+9u2wenT9e3mjtFA9YFQRhZBLkWHuO0ITIP2KK8vlVH6pcxWQ7LzTkyOoZO6n4+w/ddNsHPdxDQkrT4PspTkrhzIowXBg2rToscoyYMsKBH1vkJnPhQ62E2tmXe/7OYVs8N4VnzTqbe/5wdMTRT6cy3lqne5Z7e+gqRmlV5Zrukz2TK4d/u1B0sLYa859U8jhi98KZcmdfjR0L5YLQeRuKIZty3rSoPAwdppY5pkeuLQDXHimI6RBkemyoX6lokgdE7T6AH3ZlujMvm9UTGZJrQdoUNNy60hzvlqfuUK45uNaEL/f3V7r8omgM7TzmdOw8xBkmHDRlcPJz/XU2hUkZPlODZgfZOVJX9qRwPH8YruToh1VG2xHrm8VkgXmSQBd2M2Fpc5J1VTlW/B/PdPxLvce1duZquXypbui6uS4h81KdZ/dwMruBQEp5OL4KkAWYTrV3ZfRzLD+0S3rnNvxQ9oXsfudKHeROxJKHoUNAMs8f6F6kA6X/rmW0oMl/uLZIUZdxDddo00AXf1/RiHNt2lS/UsGRp5Uf9rRpmZoUSDV4nddMLe790Ur30XoKYsnNzG/gTs88dtXe6UuleznlF3Uzt/UqP7iWDvpFzzQ7plEmnZB8jM0g57WznbxMPgPkzHhG4/G7Qooif7bpk2qdxU95YDpMkyiUox+2DinpXUUa8oFY31xan9E15uFa3dlcUA50Or7eIDPlfytl6+sHx1cznnSHY8PoXeXJwvE2bRBArquLCxPCaSNPkw4JyGzvlefU4tedVHUw1XXcYAQfGugi64X977cnF/ScrWyqqatfqeC1hy39sIdNqzj8jBQ9DRMWj+iGjY0vLD/9Uk3tGAy9O0PI8JGUiCyxVJntzK3UEPrZQllZiM1oACMNv2Ovzp2zOVI3qEWKyS5xWPJJFt7k6IepQ4p5j5X+v+gbdT4Q7oaqCin3huOrjgTrK5TMNZYhmi0PT4ehzO2zpKeqgyPW5TY2jdavRHDk6eGHfTPf0+YN7H2JTCRCJr1hQ7fI2t3iDGLoCophCPfJwxGbZtPbrzn6aSpDo38//f8onVg9pkIVB3U00t5iMPvIAhdbRqo6OPW6/NxAqNJ0XOXEmDEm23n17cmrRzgaCk3ZkY4NcfghN0ydi1vuTTr4YVPO94A1exwSjn7a6bDUiZ2Or/vMD2IzlFRwfDURLjoRDfEMLE9sSaCaVHVwxLo8ok1ZcOTpIXN7mxqwfkgz37rXoSJMd+aFzRhaZXP3IpcHhT0awieYkd0gVOtvLNANEvJguZg34rm0ZjA/oEVW+v3eT41uKruJ5y5Ge/GxDW8O5lHHwqgGoZGh1i85+umuQxPmKc0q7aBj2EbINHR6ZJMtwLvQRome8nB8NSV312ZThf+KBWZkZ2dmUX5IebApZ6kbqzbXaeU/EVuMWZfHtmkTrPawo8xdbArM5hldoRiZ8rPwvbdxhPJ55Ttw8oW6kNqKW3Fj2jVG7yv1S5aYh22KctH3eLmt4oHpN+yisouJodzh1JQWcZHW2oa9EVK4fvXF7Ojyrm924dhCDu/ZK38n8LI2Xw+1eq64TiZ7JCW3k6vtlNzB8Q3QpB/QRYd1chsfM9uiDbmfNNni8vVn/UzYYar/rT7QdX5qu3+6O6TBi74QLZ6H94As52pZ8g3QRx5n05iv1tk6VhdiUKNUGKiV/SsqD/lFXIef7t4U6pYDdcw/hYKrnyb/ccRsAbjX4dCk/5gOAbd+NZHCxxxRuwftpSO0aRN113H41yuU9ew06WSUWQXr6Fi/O8Zhsa21n+MLwhBwBlfCdjElm056wwZt08R7G3qav8lt+1MEIzY9uNSjKum4BEF4fkx7t6EGcdh3q7XaOTrN4rDPHReHXq+KYR5BEITnwqTDhoIgCIJQxeRnXoIgCIIQIp2XIAiCsHVI5yUIgiBsHdJ5CYIgCFvHVnRe2BYe5uQZEryk1yZFAjAy5od9tv37NmC3IVJaTPHEjymB9+CGtkUbpibPVBiz7gxBZlf72US93HYddoE6r9KDVyiAGvRImdCA4afre1r0tvyieDwJvqOkd0ECNCfjJpyHlUOKSajr8Hnw/hvSOLRJmf074ezPffahDy2emjzbRsq609YWfUHbad63zA/7bXPSRCpS6nBsug7qvJlX8fRs+nhvUFPKZiiHkr3hyBFb1pZxeaMqy+hPeAQJB3qfCce8rGsObdyZK3f0IcoWD9TdTuhoFhw343S5Wld2UtD3UEnrfhdS5TJLxdTkEfpziEbnqZwIV2jGDTSW6qZThzvpsGGWNM8m38vAKcdPt+p27Z2AfLindm9XWzfqCKFBgn/sCuW4qj5o9PHyRmuneMCqIAgjkyDH2nPDTExw/u2+enfZLRIx2c4LD4esnrGTir//0E23bdQP95CQ0ijBT3FSCmdWTE3DsKl/KKVjijFlFz7USqidfbnn7xy2xXNTeNass7nnD2eCHP10KuOtdbpnubeHrmJGWlWu6T7ZM7ly+LcLRQdrqzH/SSWPI3YvnCl39tXYsVAuCJ23oRieLudNi8rD0GFqmWN65NoCcO2RgpgOQabHhvqViiZ5QNTuiWyKtutjzzyM05152aye0ZTQSJuChltXmuPd8tQdyjUH15rw5f7+SpdfFI2hncecjp2HOMOEg6YMTn6uv85w1KQFsZiU4TM1aHaQnSN1ZU8Kx/OH4UqOflhltB2xvllMFpgnCXRhNxOWNidPV5Vjxf/xTMe/1Htca2eulsuX6oaum+sSMi/VeXYPJ7MbCKSUh+OrAFmA6VR7V0Y/x/JDu6R3bsMPZV/I7neu1EHuRCx5GDoEJPP8ge5FOlD671pGC5r8h2uLFHUZ13CNNg108fcVjTjXpk31KxUceVr5YU+b9sXrvGZqce+PVrqP1lMQS25mfgN3euaxq/ZOXyrdyym/qJu5rVf5wbV00C96ptkxjTLphORjbAapTz2dl8lngJwZTyoOPyNtRaQTf/yukKLIn236pFpn8VMUmA7TJArl6IetQ0p6V5GGfCDWN5fWZ/To4OFa3dlcUA50Or7eIDPlfytl6+sHx1cznnSHY8PoXeXJwvE2bRBArquLCxOubiNPkw4JyGzvlefU4tedVHUw1XXcYAQfGugi64X977cnF/ScrWyqqatfqeC1hy39sIdNUxDdsLHxheWnX6qpHYOhd2cIGT6SEpElliqznbmVGkI/WygrC7GZ+WCk4Xfs1blz0kKbN7AeHOlcHakb1CLFZJc4LPkkm/Jz9MPUIa3vjZX+v+gbdT4Q7vysCin3huOrjgTrK5TMNZYhmi0PT4ehzO2zpKeqgyPW5TY2jdavRHDk6eGHvTPfd2DSGzZ0i6zdLc4ghq6gGIZwnzwckRpM303HVU4WV8Wmt19z9NNUhkb/fvr/UTqxeszgoTioo5H2FoPZRxa42DJS1cGx67IwDJPtvHr35JUjHA2FpuxIx4Y4/JAbps7FLfcmHfywKeeLIKZOo8G6VwR8WLPHIeHop50OS53Y6bixdJD5QWyGkgqOrybCRSeiIZ6B5YktCVSTqg6OWJdHtCkLjjw9ZG5v0/5Md+aFzRhaZXP3IpcHhT0awieYkd0gVOtvLNANEvJguZg34rm0ZjA/oEVW+v3eT41uKruJ5y5GefHRLChrCZjZo98czKOOhRkcQiNDrV9y9NNdhybMU5pV2kHHsI2Qaej0yCZbgHdh3BI95eH4akrurs2mCv8VC8zIzs7MovyQ8mBTzlIPzNpcp5X/RGwxZl0e26ZNsNrDjjJ3sWlbsvC9t3GE8nnlO3DyhbqQ2opb0cjSrjHaw98vWWIetinKRd/jRd6Kxp1+wy4qu5gYyh2G4WgRF2mtbdgbIYXrV1/Mji7v+mYXji3k8J698ncCL2vz9BDKEgLZ/Lf3s/KRlNxOLm74MYTjG6BJP6CLDuvkNj5mtkUbcj9pssXl68/6mbDDVP9bfaDr/NS6/XR3SIMXfSHSc3gPyHKuliXfAH3kcb4R89U6W8fqQgxqlAoDtbKfRuUhv4jr8NPdm0p/Dv2Yq58m/3HEbAG41+HQpP/G9odZv5pI4WOOqN1r2qjQphzC++QUZS7c07PTpJNRZhWso2P97hiHxbbWfo4vCEPAGVwJ28WUbDrpDRu0TRPvbehp/ia37U8RjNj04FKPUKTjEgTh+THt3YYaxGHfrdZq5+g0i8M+d1wcer0qhnkEQRCeC5MOGwqCIAhCFZOfeQmCIAhCiHRegiAIwtYhnZcgCIKwdUjnJQiCIGwdW9F5YVt4mJNnSPDyXJsUCcDImB/22fbv24DdhkhpMcaJH0J7uvjPmIzpq4IwFNR5lZy5wqmpQkbK4IXZyt/tp+t7WvS2/KJ4PAm+o6R3QQI0J+MmGnVWDqkWZPqsaFjw/hvSOCDhnnRgzwvn+33sntpXx6DcRm3uwGZhGngzr+Lp2fTx3qCm9PRweEr2huM7bFlbxuWNqiyjP+ERJBzofSYc87KuOaB2Z67c0YcoWzxQdztxjdNS3UQbFuh7qKR1gjA1/Bxa+OgmRh1dfZEO7Bkz6bBhljTPJt/LwCnHT7fqdu2dgHy4p3ZvV1szkqzCdNY4E3JfvbtsTnHyeHmjtVM8YFUQngPG9ytOQBeeDZPtvNCQI6tn7KTi7z90021Pjz7cQ0JK0+D7KU5K4cyKMFwYkqg6LHKMdQKEAz+2yE3mwodaCbWzL/f8fY/XKoaNyyGbmJ5xHtrZVyNDoZwX9jX6LV/XzUR9+ZtsQb/Ttc3aoCsXhtrqQuFt7gXCMtWHjTYT06H77d4eOIuQcV52mPXgmDyOvnbvzJtXakvTkgmJmO7My2b1rE1/D5A2BQ23buCOdx9UWBQVxhxca8KX+/srXX5RrGC64TGnY+chzjDhoCmDk5/rr7MpTMrwmRoqO4jbHEIn9WfPf67UQX5Djp4BMtjSieyuzM6RWn6wAw2bAifM3eVSvtxcGuOybaGvfWVPN0e5MMQKmf1Ek/vW6DgdOztBm3Evjv9wwHWW6jy7hruX60hdWN6E5I2cruwQiRQ5Nm1l9/kD+Q+tsyltm57RgsNT3Yk/3arrsfNiCZPB67xmanHvj+b6j9b7EEtuZn4Dd3rmsav2Tl8q3cspv6ibua1X+cG1dNAvWpbZMY3w6YTkY2wGqU+zn5fJZ4CcGc9oPH5XSFHkzzZ9XKPXZc0RZKFbm2IGIC/SxYVZ6+ToOUM3Nu9tCBhlKJeajftU6RS6P0CmTZsQsq0t/DQNppM3yU3D6xJ31xRydnrk3CsvU+8/XLCm49so1M+Y8OpOS7tb/8lzarWvO+jgXdu0mD2p2/PcJ4XnR3TDRtcGLxlPv1TTyg8apd0ZQoaPVDGQJZYqvJ25lbLj+9lCWVmITUJEjB79jr06d87mGKqRo8SfsWzCHD07GhKIlmaRZB/dSGXD6za2KCboxAHPJzYkW9mAHp5SvqDcFxj3SpzFOgxjdg0/9oZj0x5275ol3d+0sf/+Qc2v7jc6wBY2y6Q3bOgWWTchcfxGaUiKYRr3SR+u6Uop23ACMLPIJrkjkGVytb0XhQyfyuHgdLbIow3IXu3P1Bxj2R0dlx/GxKdL+PG54HzFZUEXnh+T7by6js4yqkaAADsV3ajcpgz3Q24IhxS33Jt08MOmnO9B4tG/j5vJRkM8HD234A6xMrqfCe2tb/zQUBpbuLBh2DEVOy7GvVj+00z2N7EZ7phwbNrD7rElAUHgMt2Zl13An7sXuTwolNUQgspG8d4iPRoJ5MFy6xjZuoIdvdHvyNxMpQ15iGkxyReC3YaGusYCi+qYWXQNr9xdY4NAcTs+Gv+zM7Moz9FzK8juM7X3AZ1y8blS2yLWMXHuxfEfHqaj1L1gNoswMzH6ZxHbYQ45mGLVnY52x+7R5VF8FzEHrH9BP8XBjSELv4abeITfCsrnle+qqs/KW1uZKlL0w0G/0PtK/ZIl5qGUolz0vapOQ02/YWecXSAO5X66fV9Yy6NFd6S1tksZGI1fv/pidsV51zc7q7z1DuA9e+XvBF7W5uuhVs8V18lkj6TkdnKFz92GcqNcliWm5zo56+zYJHOTLTj+bMpUKDrw58Z7Mf2nCVNnzDZ4gGc/V8vK64Rl9V2jzxrC9dWmugNYdi+prxyebaJcL6qfuXDPirZJ+H2YdDLKrNEUJ6zENELYqsxvuIS88dctdaExdt//7NC4CmU4gytB6MqkN2zQ1lu81zLL33cRDJg5mPZXOq5UmBBsxQ46QRAmx6RnXg5OGOg54WakWiEyQ+hIddiwXYhXiCMzL2FItqLzEgRBEASfab/nJQiCIAgVSOclCIIgbB3SeQmCIAhbh3RegiAIwtaxFZ0XdoYNlbOoCrwQ2fbtfCNjfqjqkG/3Y7ch0pRM8cSPKYH34KZ00sLU5JkKY9adIcjsaj+bqJfbrsMuUOdVevAKBVCDHikTGjD8dH1Pi14cXRSPnMF3SFAYJrVzMm7CedyJ15SvCEf99CDUdfg8eP8NqTmQkPA5dmDO/txnH+LQYp+pybNtpKw7bW3RF7Sd5n3L/JzMTby+klKHY9LU1sXwZl7llCj+uxnfPprvTDI8vA9jy9oyLm9UZRn9qTrmpwl6nwnH4KxX1X+/M1fu6EOUbXsg6hSBMenIJKfL1bqyk4K+wwSLQpG+ucxSMzV5hP4cotGpyHwgNMNt6+qYdNgwS4RoExhm4OTqp1t1u/ZOtT7cU7u3q60adVRBgwT/hU6bJLHqINbHyxs6xNY/NFcQhJFpOCRcqKZNW1fFZDsvzKSQqTV2+vT3H7rptg96uIeElCYc46eoKIUzK2LBYdi0dPCCZooxZRc+1EqonX255+8ctsVzU3jWrLO55w9HRxz9dCrjrXW6Z7nHbFz/N0ZpVeWa7pM9kyuHf7tQdLC2GvOfVPI4YvfCaRVnX40dC+WC0HkbiiGbr8Xsx5qoPAwdppY5pkeuLQDXHimI6RBkemyoX6lokgdE7T6AH3ZlujMvm6kVGZJrQfoMNNy60hzvlqfuUK45uNaEL/f3V7r8omgM7Tz3i91CiDNMAmjK4DSm+usMh8noW7dOUso+PAQ7R+rKHkeF5w/DlRz9sMpoO2J9s5hnK0/86MJuJixtTievKseK/+OZjn+p97jWzlwtly/VDV031yVkXqrz7B5OZjcQSCkPx1cBMjvTSfOujH6O5Ydy2qAYbsMPZV/I7neu1EHuRCx5GDoEJPP8ge5FOlD671pGC5r8h2uLFHUZ13CNNg108fcVjTjXpk31KxUceVr5YU+blom3dSFe55VnlXWfrqP1FMQS1pnfwJ2eeeyqvdOXSvdyyi/qZm7rVX4eIh30i55pdkyjTDp77RibQc5rz7PLy+QzQM6MJxWHn5GKJNKJP35XSDvlzzZ9Uq2z+GksTIdpEoVy9MPWISUyHO9gXOSCMlrVNebhWt3ZXFkOdDq+3iAz5e8qZWDsB8dXM550h2PD6F3lycLxNm0QQP6yiwsTwmkjT5MOCchs75XnSePXnVR1MNV13GAEHxroIuuF/e+3Jxf0nK1sqqmrX6ngtYct/bCHTatobOsCohs2Nr6w/PRLNbVjMPTuDCHDR1IiMv9SZbYzt1JD6GeAZWUhNqMBjDT8jr06H1JaaEET68GRztWRukEtEiSF/OdCnbw9sTJx9MPUIcW8d9TRFQZR5TBWWoq+UecD4W6oqpBybzi+6kiwvkLJXGNZm9ny8HQYytw+S3qqOjhiXW5j02j9SgRHnh5+2DfzfZu2zjHpDRu6RdbuFmcQQ1dQDEO4Tx6OSA2m78aY1QkZQza9/Zqjn6YyNPo/QagC4bWxOrF6TIUqDupopL3FYPaRBS62jFR1cOy6LMRp29Y5Jtt59e3Jq0c4GgpN2ZGODXH4ITdMnYtb7k2K9iHTrocgpk6jwbpXBHxYs8ch4einnQ5LndjpeLp3ZH4Qm6GkguOriXDRiWiIZ2B5YksC1aSqgyPW5RFtyoIjTw+Z29vU0KqtC5juzAubMbTK5u5FLg8KezSETzAju0Go1t9YoBukz7qLdzFvxHNpzWB+QIus9Pu9n+7eVHYTz10MtgPIxywoawmY2aNNAsV6x8KoBqGRodYvOfrprkMT5inNKu2gY9hGyDR0emSTLcC70EaJnvJwfDUld9dmU4X/igVmZGdnZlF+SHmwKWepG6s212nlPxFbjFmXx7ZpE6z2sKPMXWwK2rR1Wfje2zhC+bzMRbDZJV+oC6mtuBU3pl1jV3P10DOxXx62KcpF3+PltooHpt+wi8ouJoZyh1NTWsRFwjzdSAKEFK5ffTE7urzrm104tpDDe/bK3wl+gsNQlhDI5r+9n5WPJPtzcrWdkjs4vgGa9AO66LBObuNjZlu0IfeTJltcvv6snwk7TPW/1Ycs9f+nu8NCks/wHpDlXC1LvgH6yON8I+ardbaO1YUY1CgVBmplP43KQ34R1+GnuzeV/hz6MVc/Tf7jiNkCcK/DoUn/je0Ps341kcLHHFG717RRoU2baNPWFcp6dpp0MsqsgnV0rN8d47DY1trP8QVhCDiDK2G7mJJNJ71hg7Zp4r0NPc3f5Lb9KYIRmx5c6lGVdFyCIDw/pr3bUIM47LvVWu0cnWZx2OeOi0OvV7xwpCAIwu/GpMOGgiAIglDF5GdegiAIghAinZcgCIKwdUjnJQiCIGwd0nkJgiAIW8dWdF7YFh7m5BkSvKTXJkUCMDLmh322/fs2YLchUlqMceKH0J4u/jMmY/qqIAwFdV4lZ65waqqQkTJ4Ybbyd/vp+p4WvS2/KB5Pgu8o6V2QAM3JuIlGnZVDikmo6/B58P4b0ji0SZkt/B443+9j95S+OhblNmpzBzYL08CbeRVPz6aP9wY1pWyGw1OyNxw5YsvaMi5vVGUZ/QmPIOFA7zPhmJd1zaGNO3Pljj5E2eKButsJOi46bsbpcrWu7KSg76GS1gnC1PBzaOGjmxh1dPVFOrBnzKTDhlnSPJt8LwOnHD/dqtu1dwLy4Z7avV1tzUiyDhok+MeuUI6r6oNGHy9vtHaKB6wKwnPA+H7FCejCs2GynRdmUsjqGTup+PsP3XTbRv1wDwkpzenjfoqTUjizIr4fhiT8QykdU1wncOFDrYTa2Zd7/r7HaxVDmeWQTUzPOA/t7KuRoVDOC/sa/Zav68JkvvxNtqDf6dpmbdCVC2evdaHwNvcCYZkq/+EQ06H77d4eOIvZeF52mPXgmDyOvnbvzJtXakvTkgmJmO7My2b1jKaERtoUNNy6gTvefVBhUVQYc3CtCV/u7690+UWxgumGx5yOnYc4w4SDpgxOfq6/znDUpAWxmJThMzVUdhC3OYRO6s+e/1ypg/yGHD0DZLClE9ldmZ0jtfxgBxo2BU6Yu8ulfLm5NMZl20Jf+8qebo5yYYgVMvuJJhGeBTjNOjtBm3Evjv9wwHWW6jy7hruX60hdWN6E5I2cruwQiRQ5Nm1l9/kD+Q+tsyltm57RgsNT3Yk/3arrsfNiCZPB67xmanHvj+b6j9b7EEtuZn4Dd3rmsav2Tl8q3cspv6ibua1X+cG1dNAvWpbZMY3w6YTkY2wGqU89nZfJZ4CcGU8qDj8jbUWkE3/8rpCiyJ9t+rhGr8uaI8hCtzbFDEBepIsLE9rk6DlDNzbvbQgYZSiXmo37VOkUuj9Apk2bELKtLfy0CqaTN8lNw+sSNjzr9Mi5V16Gn7q8Dqzp+DYK9TMmvLrT0u7Wf/KcWu3rDjp41zYtZk/q9jz3SeH5Ed2w0bXBS8bTL9WU3B6N0u4MIcNHqhjIEksV3s7cShMWP1soKwuxmflg9Oh37NW5c9JCmzdmqPvNjeNQjRwl/oxlE+bo2dGQQLQ0iyT76EYqG163sUUxQScOeD55e0J6rGxAD08pX1DuC4x7Jc5iHYYxu4Yfe8OxaQ+7d82S7m/a2H//oOZX9xsdYAubZdIbNnSLrJuQOH6jNCTFMI37pA/XOExYCx1XOVlcFXVhxT5gZpFNckcgy+Rqey8KGT6Vw8HpbJFHG5DR1Z+pOcayuxmoFAeQXcKPzwXnKy4LuvD8mGzn1XV0llE1AgTYqehG5TZluB9yQzikuOXepIPvmuK9CwiP0Ai/7hUBn8Sjfx83k42GeDh6bsEdYmV0PxPaW9/4oaE0tnBhw7BjKnZcjHux/KeZ7G9iM9wx4di0h91jSwKCwGW6My+7gD93L3J5UCirIQSVjeK9RXo0EsiD5dYxsnUFO3qj35G5mUob8hDTYpQXgs0GAC0BM3u029BQ11hgBoeZRdfwyt01NggUt+Oj8T87M4vyHD23guw+U3sf0CkXnyu1LWIdE+deHP/hYTpK3QtmswgXMi5hO8whB1OsutPR7tg9utQDs06+4YF6QtWkMLgxZOHXcBOP8FtB+bxMg4ldVfVZeWsrU0UjS6diXM3Vw/t+yRLzUEpRLvoeL/JWNO70G3bG2QXiUO4wDEeL7khrbZcyMBq/fvXF7Irzrm92VtlCDu/ZK38n8LI2Tw+hLCFhWCsrH0nJ7eTihh+rKDfK5WeK6blOzjo7NsncZAuOP2eDhJDAnxvvxfSfJkydMdvgAZ79XC0rrxOW1XeNPmsI11eb6g5g2b2kvnJ4tonwPnXPXLhnRdsk/D5MOhll1miKE1ZiGiFsVeY3XELe+OuWutAYu+9/dmhchTKcwZUgdGXSGzZo6y3ea5nl77sIBswcTPsrHVcqTAi2YgedIAiTY9IzLwcnDPSccDNSrRCZIXSkOmzID/EKzcjMSxiSrei8BEEQBMFn2u95CYIgCEIJpf4fSNZgJv2To+YAAAAASUVORK5CYIJiOjUcbed3hEI2mOieP9e52+9/teishSKcyZogbJpJHKCho9p4r2qevW8lGLCyMf25DISxMC7nkhOSgiC8WKZxmlSDvYn3q7WaHZ6kexPbjt2bWa+6uaG3lcym/NcmMLkQmQrCNjEJN6kgCIIg9MlkVoaCIAiC0BcyGAqCIAhbjlL/D3x62ORKCzCWAAAAAElFTkSuQmCC"
}
},
"cell_type": "markdown",
"metadata": {},
"source": [
"#### NEW! Support for parallelism with Joblib\n",
"* All methods now return something. For `fit()`, it returns `self`, the fitted model.\n",
"* There is a new attribute called \"name\". This is used in displayed messages for monitoring (see image below). This is helpful when models are run in parallel.\n",
"![fit_print.png](attachment:fit_print.png)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# initialize HSMM\n",
"R = GaussianHSMM(n_states = 3, n_durations = 4, name = \"Model 1\")\n",
"my_init(R)\n",
"S = GaussianHSMM(n_states = 3, n_durations = 4, name = \"Model 2\")\n",
"my_init(S)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.00000000e+00 5.00000000e-01 5.00000000e-01]\n",
" [2.56937871e-19 0.00000000e+00 1.00000000e+00]\n",
" [6.66666667e-01 3.33333333e-01 0.00000000e+00]]\n",
"[[0.00000000e+00 5.00000000e-01 5.00000000e-01]\n",
" [2.56937871e-19 0.00000000e+00 1.00000000e+00]\n",
" [6.66666667e-01 3.33333333e-01 0.00000000e+00]]\n"
]
}
],
"source": [
"from joblib import Parallel, delayed\n",
"\n",
"models = [R, S]\n",
"data = [obs, obs]\n",
"\n",
"# EM algorithm\n",
"[R, S] = Parallel(n_jobs=-1)(delayed(i.fit)(j) for i, j in zip(models, data))\n",
"\n",
"# check if models are updated\n",
"print(R.tmat)\n",
"print(S.tmat)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -264,7 +346,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand All @@ -284,7 +366,7 @@
" array([2, 2, 0, 0, 0, 0, 1, 1, 1, 2]))"
]
},
"execution_count": 9,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -295,7 +377,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 12,
"metadata": {},
"outputs": [
{
Expand All @@ -314,7 +396,7 @@
" array([2, 2, 0, 0, 0, 0, 1, 1, 1]))"
]
},
"execution_count": 10,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -332,7 +414,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -341,7 +423,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -367,7 +449,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 15,
"metadata": {},
"outputs": [
{
Expand All @@ -381,21 +463,32 @@
"FIT: reestimation complete for loop 5.\n",
"FIT: converged at loop 6.\n"
]
},
{
"data": {
"text/plain": [
"<edhsmm.hsmm_multinom.MultinomialHSMM at 0x1f53de0d810>"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# initialize HSMM and EM algorithm\n",
"# initialize HSMM\n",
"T = MultinomialHSMM(n_states = 3, n_durations = 4)\n",
"my_init_2(T)\n",
"# sample observations (made up by me)\n",
"obs = np.array([2, 2, 2, 2, 1, 1, 1, 0, 0, 2, 1, 1, 0, 0, 0])\n",
"obs = obs[:, None] # shape should be (n_samples, 1)\n",
"# EM algorithm\n",
"T.fit(obs)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 16,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -455,7 +548,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.2"
"version": "3.10.9"
}
},
"nbformat": 4,
Expand Down

0 comments on commit f8a7ee2

Please sign in to comment.