From 2cb3c1b01ccea2e8e3fc528ed568836e3c655e81 Mon Sep 17 00:00:00 2001 From: Marc Suchard Date: Mon, 11 Dec 2023 13:59:07 -0800 Subject: [PATCH] initial infrastructure for BEAGLE-based BASTA --- lib/beagle.jar | Bin 26150 -> 26081 bytes src/beagle/basta/BastaFactory.java | 77 ++++++++++++++ src/beagle/basta/BastaJNIImpl.java | 42 ++++++++ src/beagle/basta/BastaJNIWrapper.java | 44 ++++++++ src/beagle/basta/BeagleBasta.java | 15 +++ .../coalescent/basta/BastaLikelihood.java | 4 +- .../basta/BastaLikelihoodDelegate.java | 4 +- .../basta/BeagleBastaLikelihoodDelegate.java | 100 ++++++++++++++++++ .../basta/GenericBastaLikelihoodDelegate.java | 35 +++++- .../StructuredCoalescentLikelihoodParser.java | 22 ++-- .../treedatalikelihood/BufferIndexHelper.java | 4 +- 11 files changed, 332 insertions(+), 15 deletions(-) create mode 100644 src/beagle/basta/BastaFactory.java create mode 100644 src/beagle/basta/BastaJNIImpl.java create mode 100644 src/beagle/basta/BastaJNIWrapper.java create mode 100644 src/beagle/basta/BeagleBasta.java create mode 100644 src/dr/evomodel/coalescent/basta/BeagleBastaLikelihoodDelegate.java diff --git a/lib/beagle.jar b/lib/beagle.jar index 7c311997b90cdee5b355ac989da4d2d4a90b5630..4af9ac97f83fdbda1320d259a81f5bcbd46d67e8 100644 GIT binary patch delta 6183 zcmZvAcQD-VzdegYNTT=NSzQpKMenS-(d#N9dM9~D4b~D?FVV{qorn@;wdhfzBt*2R z!K$m@&*%4>@4a(p?laG8&g+l!dS;$~UUSYV8^v82!F>$T#=|GV!6C%KNw$MSlOOW{ z@BhIw$;|ujaBy(){wd0TN>|lbQBYG)O;}e^Pg70R$XG~MO&2aoj20nPJAZ!Li-(k`&eJSr2MOF{3#Do2MHUViS2{IT3=+Mxna|BzrTZFm)w6D%x6V}-zps=` z(ZJI$3v~2js97ILrKxVZ#7$lm#wWCeLmR1ZBOX8a_ZQ8L1Wn?|0une;=pU=Q)7#yf zmC8WA?s4vaODb4pMYJeyzrl@y}m7Iy5I*u=3K9StfhWA(%y{Fjv)Y4`>U`L zKjKb}XU7;*+Z4iWpoa{WySIN{<}nSG^0b~FSpWG>x_BIO4w-n}P)YclHrEj9UjoRQ_vJ@`eoqGQuBe~v5 zYtA!c%8Yz1J(al|c)bvfNZmS%Mg+f$Mqq$M(7@9oVgkyElo#Oq+w~_fiGQt_o z*5C=ouvXc|nG4I-<7lI!!7#}xd-%7l>0~V9xSsjN8-54XLBncc4o0rJ0?w$i4fV~ZublOzfH&qoz zx_J$m-~{{Eg3sbOU)QCOW-Tz{iZru@`nYTtWe~>vTx5Ev@#MY=izbqT^)GHGJK6n; zYqI;T*V{Jb*JOO*2y@_P+zg5$8=kwh@WmPGd@%ltN#-Y2VTN$`dg++*_UcUiiH-(i z#JkFT0z3{0W56eRB zojiP%VJVR_Qi=+j`TM{;)H(=%(& z!jnPJ(ygXms19}cXg3Al;zzm`5Iops&ZwG8T|vh{x@E)CaPU~jnS`{}ARL<=$R1DY z;Biy3?xhJ(W$A7>sm-DhyNpgT zy*EBXG#hgf1}DiF147)-Ky+!W6^{(mDo6kB5}h0+9TWei0onZ-K(MNso4vI0 zdjH6Z7&30Dw!}|39mq1L3mf$Ck0R_F?j{mPNA+zrxK#kM>`i z!3E75AW2WRCGXsaxtW6`DF_Pq=ifb9&M*7=bMm=gP!H)}8fD}P5FL<8o9#Xl=Ux3s zT8kWs!6Ls??L~6ZhAvWy0=MQavGZOBAYk0XEpJd=9^L7ya*ykMku!tbGdZ~v+LOJ~ z!IDi|^(tp9ySa?{mzln2=C5a>WXx}7c$un)>rycniSbt|WI+^)l3}&^Z$3QbbyT)I zkV=VT#W>dF+Xbhh+=zz$NY@P|?>E278J?1}QEkJ8w(&vR&I6ms?=Hy}qM#Bq7~te6 z?mL7A<3iPRIeZjeX>3&ipKoA%_&krDnXD-A$&L97K}AderlOeo zK$#Hf&F7QaxkX$;Of*76Bf=t41SkF+GQjfS!ovJZ{M_-NYM6G>i(R}gU#4@XM)A}L5ar(|nh=lc{#pU;e0}tHGf?a~BwUs3h?U%%l&&K2 zEt))PfJiQ!ds0M<-kY19s9F`YNUSxl5zF&3M}fS3&Cx3Nv}l zs3+w!M}NSXL|`vUc~4S$S-X__1Pn)n)4s9$_dVT%u@8GS$WIM_0lxZ#;1;Q`;W+$5 zRN+)4p;(9V)@8H0a-xx^9gHhbJY<4O7Qm5|gQbGwZ!e85J-Tkw2KBeLC$m{}^bW;R z3X0^qA3h_}R-?ewB}sNI!Bslgm`{b`J@(ZoCyM4(7ZO&&dWQzYl#E&_IJ$Hsc~(y` zWkCYJwh&K+vyntzm;JwFCZ%&hyeujffoam?l!-=d3KoFx^^~^N#^XC z{|<*YQruO`{%O4cex`HdjLp9eHUhD_KJ_aIc6A)U(wV>p{=11!?6jQIMaa zK|s7C@yV&GdNGB4U%gy1VLJV2byS;X1h^8+|NUEvrC_+y#&bt6R&oY?H6;VBD)Jdf zNU{3#yvsVLiKI3}Jxw=6in7095O}9lPbTaF$A4s;5^w0@D5V!&O)g#?|3Gz6i7!tp z8(+}460CKDdg0J?#1Nfxrd3riC(ix&jb3ZRuOjK*|m!Vo-2+a%!(s#8P95nWB&zhnr%XxWQG@2 zl9pD))e0ppuNq`vHuB;Qtoy7$qs^qQa2vC~X6q_|ZA0~4F zr32;>4|BO7lgN6p3uXM5Kj+jR2P91KIdNqY5MOnFTp1;09pshf>EUO6*FTya6YUgc zJI8aV%nbDuT!Mae?~hb(ItZb4iVV>mOZ_GYZkN@2i#$Ow&mVo) z8T;z?rz6Ya8xi1(RQp9)*xwj;a!!nI4oxb~8(A-J9wLw>%W6vhAbFA9e6$XuJrxLG z8q9p!oa>#~pCzpxCAoiCxaOv+KKsEuz|ZGhtip0%W72uJWD--)c(rw$&^~hVoGCXJ zmeKUr>63N(A%W%l!$hXgGnWh3#WYRRW9tv<%DqYTMN)t~k8)gE;?3eDJTZSwHIAbi zLgIBQTiG?>WbUQ&sJ+oi>c_|A`a_i;uG8LoxrA~cn~!r~d86>a_vwY39TvZ8cUa33Hiu5i4vi>^)suqC?-ll@5y5*88FEsX%+? zZYT^S3!bDLwNQKV+01|FyDF?+CuqDqV~3#DJm@Zz#!H}@90r|!SO_7vSeJzvwi`^#+$bgj1329s^^8T=^R|0NgH&Eu7p9lE<6 zsvs|VdII}U*jFE4jINJyU>dxasLN;}^~nctUz;DQhp%}ghDz^m2EBD48SG`el2Vgm z-U!NdNLYK2*vmM`Yk!Kh?j^N>9mf4sve`Hc<+}atM(E*>_ok}UidX@cMo$Q>?TOyj z-V5c*Eb$&P8nDET_Or%Y<9@H*5x4R;@y^UZpCUE1IAGtUgUBdP5Ncc2X#E*<2Ac)w zTMQLk(FtQJYr$2))`iEenZZ+Tg(U>c)+s|N?l!AT)Hd;-7ZVZnxj(X9v5{#>);L`U z?AYOa+u%==8oX6I{l|n!?`48!9u8p1dgppHt)kK2B!XCH{T;e$!@3|GGg#(#pK9n= z_9LlIsdV9oJH!Qh#^b{1-Ccg_6kA&RS-&?w<&u^h0N9hVJMng8x%=in@uL-{kEy96<&t9Z# zTjmBe2NC4{2LURp^-)Gb6SMm?h`B0Lr9Bm4CEy^4#MoyxhSGH12dc#f6XeG6mW zpCL?}WgTB+3p&=o9lLzI&dbZYMk@vzuM;p&{@_Ir(WH`yad&5uKr0T(Hbe*99Y@#R zP|25t<%)_q2*rcL3@<1hA`*fDrGYx0vQbb}07f^S+moykOnDji-eI3H%rQCu^U z1`g9}?Wu55TS~g zBDH2%QLy!|5{BoLxgCU~EZ(&=?sqJWJA4Q{!hawot!6o?FUW0xdYKf)=D{o-(k_uxTw{I{v3C9`u{9o_sd(Zex%KQvH;1z!+Jod`el^n&sN2 z1zxb_-wrUDf&;C*sA`MU{^MM&{QzhKsw5@0+Oe@JMGza9XitSOojOkUj#pLGr@qmp zdZ}tXVhS`D5l&ii)AOP>qCcfqRHYZDP)16sj;7gLpd#o#5j(ZJ^>Av(Ys|A7*04Xk zc9v~7)fsGTuVaHZJt@j6B9mxrZ!AeMuox18a)oLc?tS93Ue~Ip(Qk8qPe4qAYErb` z^V@{&GdM0!lYLp#Jkbg)x+xvOnvAaoBSJBZ-2mV#9{A0!{F`s5@85PVUjWZ)OT;4BAk8x*HmE>7P@Z@F~&Qw<&+#|6|Y8iZv~V zA9w{~V&7_C(#*iT=k@N|EvEtFSVmRi>yxszz#=c%>5p{5-@E|Gfh;hZLqUdSe=rY|F9{^50o<$y3c_XDA_f!K)-y;p~0d zS=gX$&al3$sH9e&6_p`QS$Ce5wHnQ@>83tYsP(4ylqcvl)9^6$kyVOuqU+R9f)B#g zo6~x7`QCHBE88bwv#8DE{Scveq83SYHsCgbX;ru_25X3F|FL$~gMrOOuK8$^+0?6s zjhXaO8Dtf^Of;92Qose_wHAV|={s`ao-oYKeILw)_vF4G|FLx`Lrwor-M|aJ6HDp( zKgN}<1;X%r@e31$#w+3lpym=@S8IFbRyw|DGhcX>X0a8;9wr^LZL3xgamN>gPy1&V zXhcbV0^|Q16C)k$o<90tX*aUm&Pa0oe_Rzr8y7^0_isG@|IJqmt;<^&9u7{@eH_GT!k9tl7RU;i1_*0~ SG_z(#k~)#zdu;dbL;nSSc(MTi delta 6288 zcmZX2byU<*v^5~zARU5q2};KRf--arGjt9xfYROkq(ed&7*I-5q*Db3knR*jK#&xM zP93D4zVE&D)_dPw>#Vc?IQOh`{*HF9fX=MHyIkjfnDb5?OUYbOf*QY4&w^1dgv=P=~;3ZL#&`1-S@dT$x4^6Fv zYt*?;C|dj8yMLbQ-~7llI19ZU0`BCB5c}_ah;ed5wT&7XaU&^d|Dm{R#dlKFI5?lV zad230kW}_;$N?%x(x9JKx92yz*QZyFGStL`z(A{Z9Vvyn3?*Nsfe=hw2>Fn$v81Hj z@VK=Rv0Rl#x+<@dXWlw zn@QTpz{ofUyTYr1GD9q_!F$!|H-~j~#48B;&Ul-mlcsdg$d-5*SCu@*BpfhUxS|Je ztJ9_Ty@stXN`1O`dm=lF){L%L_Irbcb{#3TmI^-C3S%j1`~aWk@sy2eq-C&Z?e6d? z&u@=aDT2=GwSc3x%yhVsqKXqALJur3+$%Ro4%qZ@tWXhz&aY`|Yl(lm(ZB&jI3=xb zq4YriK($Ok{s$};OGXa}rc#BVfTUxS-s(|s<7p;z-yMbdBBaz4)C4!H2@sbl*Z_K@ zx_^*1)l|JRF~-BaQq1F6kVDB2`OPX<`%xPslHjGSMYYn|vDXdO^1bZtNB4C1W0lo$*K>9m zq=Sboxk*R8;oQZKZDdG+(OW+{7%A&@O3~-Q%A(JfAM4V*b6nIT)}1|6;7D)Z8w5B4 zl7QEf967`OG@@RQPBy%WjsV zqSq%4j{{+=n@=97;>hPIBK_g+(G#53XO*h(z3m3I--D+dp%lQ$fM;`s0dr0Lc7BjL z_RI9JQ-kr;?b|yeZajFdCthPW;uO zotgFToc8R#w!kTlJFvCaO}8s+jaMP#z|p#19Le{0iI-}uVN2uVIpMV7;)nM|QNm?}LXVt**S;4Y%yt9{6ZNKesr@yKGpd@BhYsxscOZR~ z^&S<+l zbK0VCpZ$EX&Hc<9zeMq~3C|i%rt%tRmh$y(hB6m#9mEAtNR#(d^+~&{*!FvGXK!h~ z&EBbv>+FOe8{HHnV3jM6$|p&9a%6DsBd+jTKg*E}*7~a&ZMI+O9o6!6;CI9|%v^U& z@3hbS%$T3w1@oZ_SNmZ0wU|U^UVhFMX+=%{)<&cB{K|z?AA--AJ`voF`m?t)T z{I_#N7J}=u29!r!y^Oyz1}VRkgW0;k@c1}q&}VxX`tf=fVqI(+yWMcSYj7WXo9*lv z@5tiJQ-1E&hId?~Bcb9&K6Y@T6mU^PAmAigd_W(KQB3XERxdu*oocUb8Kg}-vmHSHy!8V5nSMVGH6 zyBajYrwdv)93X*lgG1&XZ&hfVVgmClrF#piYHT@py1uQK#b2DxsXpnA4FMQE{VjGu zN+&eG^0>tXBIpzsx2G_e!hyF=Vv2lyaQvLjbcla zCIVpJ`4d0q5I3xbT%VnyRqnH{1rMa@;3x3=T$7I!>sUcGq$eVDQ+#XP7GbD^3!!O> zPt&?B)9szbnb^71K|XXv3lRS! zz_%$~?bGkA=G@!Mc*63nfzcnqxaYk?`iVz(^CtM$I@GS7P|)UYlaXPw)UI-fNIcG7 zYKPgJotJh*2ieoe0K*U6ZMtQ{4y6z?Q04S9{&<+%onmn|^?|qnCH@D83D(>irHr^j1yW zIgQ*LoZj3tVGE+GwW7-jEurydqmrYPOP!Db?fjwMpE`lI_4Ve8n2{y|>ns zN8`}(zS7uIyKZ_Tqtaa)KM&gl0igTT4b> z&S#0+w*m)VLcCZ1$$|cHIO6~0a1YYkP;DeQIO?=GIE;U}+kX)w)xZ1*a7BP!;ZCCV zu5cfpWD_D!HWTKX0^drDAz_#1kLeGzQ6?#DQX{<5lpp9U-hDAcg^hg_Qex$*)f2{4 zeX04>jx08?<2-vTfZPB1ad{y@uAikwJPE=XH&*-=IYQtG^RPjzIBsnAZF?AhWiB(% zSZM3v<<0GN?cVLLj+^iHz#NXP*!*goI9~JJvPXGR0ly?N!xE)5>G1pYIxFs64*2-eJ3E5N4!{47k=P$90+4lOyf z9(TeCBuA~6Ld75TNRei)XJmY!9N&JXxiGpi5N1f5TarJ<3rfR&3zTF_=7-K);E`Ap z2hbjUOd;gvhO<-czo$*rr_Uwqp$gD?i2DprRo=ydFL+Ub%PEP#-sx*ySwKJkcs2$U ztebt((}XYa3h3%PAUQs$8kv!$@SuP=?IST zHTl}IOS0E{C9f+Z2~0~6-*w-j+wGf38#|sxW_6F$i2%}Saj02HupCJ)a%J;~)N&-L zI44M?_S$(sljfk0RO-#;EG?SQ`ejSfMbMWwJd;*-iR>-{*6gfn=d@_tgPs}GqmPas z0}174a(F%!{OLvS*K-kbp)((UDY42Pe6tWbuH%%Ous`zmfLj>kY5 zSK*4@S*e zL&<5>u+$OK>t&BnZhnmG@=pyCBN~V~xF@qElB*mY%XDAvZw+_vBs%Iw>(e-w*?mRt zmfa+)0R_a2$Mx9~Rd{6q(_oa*lIhGuf+xL23ji5hT(X>Xj&yYY92Ut>JOde9%F(#(c>LCHQ+8oqCmft(*?Tul(*?tftThwXw~mF_blq$)k(Hq& z#v5cMw&ExI*_C`li?;_CzR0lUeZRkU7;v<=K34d^R=v(g*rX`u(&6kHhu?0V*R3tP z8c5~6q7l_7uXOqC{y;gZO?>G2a?h!=GnUzyZeWY&jp@O+0rs$MEWP!6*=cb;vNolfp-(>bnJ>{OX#{Uwvkztp<8_bImbUu;E3q zBVM&b<(4dKy;3whLZ6_0f@a&n8eQE)4k!lrl=V?bE!vt}jQj4XQue6Havdc^Y7CR? z71sMUtEQz-`ObNGyjLD*ZKe4wG8 zRj)|zZb-7@*HJS^;}I5uJ<@M?bXvO!l~kazTR|)n-oKnL`Yy=x_xFlo?-a^z+4hGn zj|PqBCRfwQeQ<#OG(&=+^WAADWlv6h)lTeuLgBJhCP%QRA|}>~DEIum(C_DJC%LTMXs6>^Ez? zH5{>%veot3^ zrGxji$WRLzI%eI1F>A5@6a`&mUO;P|BNSN_6%5UiN4fiOnuYG zq-?kCWvXEn(N<)#+E;(zav8dL87*Hp{u|eMx8|<9g29x*R4DU-9M!P@W@gr^5e0qL zs$abQZ_6|Ci1s0HX7eG0sShBG=S^0rb>Ld6Zwi3g=PIp?DdFx?FJ-|MQ0cP+og za<)vS80@Pub_X;1bIyFKV^#5oi>X$VedHDpWqVBg$&)B@N0mc0kL-SQFe4?mt0$Zo zUQ9+4O~|-LoMWlMCE5Bga&n`aWNGnknrmEJyC#a=&V>vSat{~~6!H+QHC*;4SoWtG zp5F?b+MynjdjY&K9Go2gq^f%R4agh69tE^ok^Q(OAVWxm=Z`?JcO!-q{YGv*? zP=n!ed3ps-8WZ#J;JJ5{)^ofPRGL>R`CyJ`4Q`RRI414fyI4yaU*it9I;a8SZm0=; zAt`avYgA+v4NX1T`r$&ibW6}!DwuqS_@C1yYBceyQT{c(^5hP3nn7kF( z_Byy9IP6k3gU_OGXU@?0?Su#YySNPU3S&09Y#Anzn%n>VG z!wkc4MRXu-Bf-tnABvK{oc>rOb^eTO%IzPu ztggIEHCn;PVvM}aelHsW^Pxja2U^mZFxvV|SAU}hg4=OORne4I@UpmGy}Oq(J3c#4|{+p4fcX*r8lV555X2oIYc9UiCErzbbMh|=ZyIcbvWIJfmkXSq&Hj0k8 zO>&w`#02^?N;{i(S;fuxW5@TH z{Xp~{ji=^qWUC0}A_mBE`IAbTR_q*6u;%|g4)A_}}~KcAMK zhc3L|k~jz!1?A;W$e)KUTTQ4zR=@rf=ml)*o?NNTQQHXlnS^{-lI_?F_OuigToe8f zU)z%SiA%_bZodW9z57|eFVModYRQNri z*Hn!T(jSEI|9R`kdvFwf~W52HLoIba(!rlKc0Pk&|#1Tq5Ky{4s7G zQp8Rdw+|U?r-?^Gf#j#SC;We_|F^^ax9fkMl89s^rM)1>zZw3oHvFHOKS*Qyr+7Pp z$Sy(V|F8AGF9rXvqzwA2Ct3dc&fi=T;f7qXSI2uVjeH@^fz)si!oyQU&VvMzd4J(u sRpgEe7jpM6k69 diff --git a/src/beagle/basta/BastaFactory.java b/src/beagle/basta/BastaFactory.java new file mode 100644 index 0000000000..e150ae4ba0 --- /dev/null +++ b/src/beagle/basta/BastaFactory.java @@ -0,0 +1,77 @@ +package beagle.basta; + +import beagle.*; + +import java.util.logging.Logger; + +public class BastaFactory extends BeagleFactory { + + public static BeagleBasta loadBastaInstance( + int tipCount, + int partialsBufferCount, + int compactBufferCount, + int stateCount, + int patternCount, + int eigenBufferCount, + int matrixBufferCount, + int categoryCount, + int scaleBufferCount, + int[] resourceList, + long preferenceFlags, + long requirementFlags) { + + getBeagleJNIWrapper(); + if (BeagleJNIWrapper.INSTANCE != null) { + + getBastaJNIWrapper(); + if (BastaJNIWrapper.INSTANCE != null) { + + try { + BeagleBasta beagle = new BastaJNIImpl( + tipCount, + partialsBufferCount, + compactBufferCount, + stateCount, + patternCount, + eigenBufferCount, + matrixBufferCount, + categoryCount, + scaleBufferCount, + resourceList, + preferenceFlags, + requirementFlags + ); + + // In order to know that it was a CPU instance created, we have to let BEAGLE + // to make the instance and then override it... + + InstanceDetails details = beagle.getDetails(); + + if (details != null) // If resourceList/requirements not met, details == null here + return beagle; + + } catch (BeagleException beagleException) { + Logger.getLogger("beagle").info(" " + beagleException.getMessage()); + } + } else { + throw new RuntimeException("No acceptable BEAGLE-BASTA library plugin found. " + + "Make sure that BEAGLE-BASTA is properly installed or try changing resource requirements."); + } + } + + throw new RuntimeException("No acceptable BEAGLE library plugins found. " + + "Make sure that BEAGLE is properly installed or try changing resource requirements."); + } + + private static BastaJNIWrapper getBastaJNIWrapper() { + if (BastaJNIWrapper.INSTANCE == null) { + try { + BastaJNIWrapper.loadBastaLibrary(); + } catch (UnsatisfiedLinkError ule) { + System.err.println("Failed to load BEAGLE-BASTA library: " + ule.getMessage()); + } + } + + return BastaJNIWrapper.INSTANCE; + } +} diff --git a/src/beagle/basta/BastaJNIImpl.java b/src/beagle/basta/BastaJNIImpl.java new file mode 100644 index 0000000000..0cdf27d899 --- /dev/null +++ b/src/beagle/basta/BastaJNIImpl.java @@ -0,0 +1,42 @@ +package beagle.basta; + +import beagle.BeagleException; +import beagle.BeagleJNIImpl; + +public class BastaJNIImpl extends BeagleJNIImpl implements BeagleBasta { + + public BastaJNIImpl(int tipCount, + int partialsBufferCount, + int compactBufferCount, + int stateCount, + int patternCount, + int eigenBufferCount, + int matrixBufferCount, + int categoryCount, + int scaleBufferCount, + final int[] resourceList, + long preferenceFlags, + long requirementFlags) { + super(tipCount, partialsBufferCount, compactBufferCount, stateCount, patternCount, eigenBufferCount, + matrixBufferCount, categoryCount, scaleBufferCount, resourceList, preferenceFlags, requirementFlags); + + } + + @Override + public void updateBastaPartials(int[] operations, int operationCount, int populationSizeIndex) { + int errCode = BastaJNIWrapper.INSTANCE.updateBastaPartials(instance, operations, operationCount, + populationSizeIndex); + if (errCode != 0) { + throw new BeagleException("updateBastaPartials", errCode); + } + } + + @Override + public void accumulateBastaPartials(int[] operations, int operationCount, int[] segments, int segmentCount) { + int errCode = BastaJNIWrapper.INSTANCE.accumulateBastaPartials(instance,operations, operationCount, + segments, segmentCount); + if (errCode != 0) { + throw new BeagleException("accumulateBastaPartials", errCode); + } + } +} diff --git a/src/beagle/basta/BastaJNIWrapper.java b/src/beagle/basta/BastaJNIWrapper.java new file mode 100644 index 0000000000..9e83d371ba --- /dev/null +++ b/src/beagle/basta/BastaJNIWrapper.java @@ -0,0 +1,44 @@ +package beagle.basta; + +public class BastaJNIWrapper { + + private static final String LIBRARY_NAME = getPlatformSpecificLibraryName(); + + private BastaJNIWrapper() { } + + public native int updateBastaPartials(int instance, + final int[] operations, + int operationCount, + int populationSizeIndex); + + public native int accumulateBastaPartials(int instance, + final int[] operations, + int operationCount, + final int[] segments, + int segmentCount); + + private static String getPlatformSpecificLibraryName() { + String osName = System.getProperty("os.name").toLowerCase(); + String osArch = System.getProperty("os.arch").toLowerCase(); + if (osName.startsWith("windows")) { + if (osArch.equals("x86") || osArch.equals("i386")) return "hmsbeagle-basta32"; + if (osArch.startsWith("amd64") || osArch.startsWith("x86_64")) return "hmsbeagle-basta64"; + } + return "hmsbeagle-jni-basta"; + } + + public static void loadBastaLibrary() throws UnsatisfiedLinkError { + String path = ""; + if (System.getProperty("beagle.library.path") != null) { + path = System.getProperty("beagle.library.path"); + if (path.length() > 0 && !path.endsWith("/")) { + path += "/"; + } + } + + System.loadLibrary(path + LIBRARY_NAME); + INSTANCE = new BastaJNIWrapper(); + } + + public static BastaJNIWrapper INSTANCE; +} diff --git a/src/beagle/basta/BeagleBasta.java b/src/beagle/basta/BeagleBasta.java new file mode 100644 index 0000000000..de90653b1a --- /dev/null +++ b/src/beagle/basta/BeagleBasta.java @@ -0,0 +1,15 @@ +package beagle.basta; + +import beagle.Beagle; + +public interface BeagleBasta extends Beagle { + + void updateBastaPartials(final int[] operations, + int operationCount, + int populationSizeIndex); + + void accumulateBastaPartials(final int[] operations, + int operationCount, + final int[] segments, + int segmentCount); +} diff --git a/src/dr/evomodel/coalescent/basta/BastaLikelihood.java b/src/dr/evomodel/coalescent/basta/BastaLikelihood.java index f70481f42e..08500b57b7 100644 --- a/src/dr/evomodel/coalescent/basta/BastaLikelihood.java +++ b/src/dr/evomodel/coalescent/basta/BastaLikelihood.java @@ -275,10 +275,10 @@ protected void acceptState() { } // nothing to do private double calculateLogLikelihood() { // update eigen-decomposition - likelihoodDelegate.setEigenDecomposition(0, substitutionModel.getEigenDecomposition()); // TODO do conditionally and double-buffer + likelihoodDelegate.updateEigenDecomposition(0, substitutionModel.getEigenDecomposition(), false); // TODO do conditionally and double-buffer // update population sizes - likelihoodDelegate.setPopulationSizes(0, popSizeParameter.getParameterValues()); // TODO do conditionally and double-buffer + likelihoodDelegate.updatePopulationSizes(0, popSizeParameter.getParameterValues(), false); // TODO do conditionally and double-buffer // update operations on tree treeTraversalDelegate.dispatchTreeTraversalCollectBranchAndNodeOperations(); diff --git a/src/dr/evomodel/coalescent/basta/BastaLikelihoodDelegate.java b/src/dr/evomodel/coalescent/basta/BastaLikelihoodDelegate.java index 03c1c4c087..d5516b4c87 100644 --- a/src/dr/evomodel/coalescent/basta/BastaLikelihoodDelegate.java +++ b/src/dr/evomodel/coalescent/basta/BastaLikelihoodDelegate.java @@ -69,11 +69,11 @@ default void getPartials(int index, double[] partials) { throw new RuntimeException("Not yet implemented"); } - default void setEigenDecomposition(int index, EigenDecomposition decomposition) { + default void updateEigenDecomposition(int index, EigenDecomposition decomposition, boolean flip) { throw new RuntimeException("Not yet implemented"); } - default void setPopulationSizes(int index, double[] sizes) { + default void updatePopulationSizes(int index, double[] sizes, boolean flip) { throw new RuntimeException("Not yet implemented"); } diff --git a/src/dr/evomodel/coalescent/basta/BeagleBastaLikelihoodDelegate.java b/src/dr/evomodel/coalescent/basta/BeagleBastaLikelihoodDelegate.java new file mode 100644 index 0000000000..8a3b430708 --- /dev/null +++ b/src/dr/evomodel/coalescent/basta/BeagleBastaLikelihoodDelegate.java @@ -0,0 +1,100 @@ +package dr.evomodel.coalescent.basta; + +import beagle.Beagle; +import beagle.basta.BeagleBasta; +import beagle.basta.BastaFactory; +import dr.evolution.tree.Tree; +import dr.evomodel.substmodel.EigenDecomposition; +import dr.evomodel.treedatalikelihood.BufferIndexHelper; + +import java.util.List; + +/** + * @author Marc A. Suchard + */ +public class BeagleBastaLikelihoodDelegate extends BastaLikelihoodDelegate.AbstractBastaLikelihoodDelegate { + + private final BeagleBasta beagle; + + private final BufferIndexHelper eigenBufferHelper; + private final OffsetBufferIndexHelper populationSizesBufferHelper; + + public BeagleBastaLikelihoodDelegate(String name, + Tree tree, + int stateCount) { + super(name, tree, stateCount); + + beagle = BastaFactory.loadBastaInstance(1, 1, 1, 16, + 1, 1, 1, 1, + 1, null, 0L, 0L); + + eigenBufferHelper = new BufferIndexHelper(1, 0); + populationSizesBufferHelper = new OffsetBufferIndexHelper(1, 0, 0); + + double[] tmp = new double[16]; + beagle.setPartials(0, tmp); + } + + @Override + protected void computeBranchIntervalOperations(List intervalStarts, + List branchIntervalOperations) { + + } + + @Override + protected void computeTransitionProbabilityOperations(List matrixOperations) { + + } + + @Override + protected double computeCoalescentIntervalReduction(List intervalStarts, + List branchIntervalOperations) { + return 0; + } + + @Override + public void setPartials(int index, double[] partials) { + beagle.setPartials(index, partials); + } + + @Override + public void getPartials(int index, double[] partials) { + assert index >= 0; + assert partials != null; + + beagle.getPartials(index, Beagle.NONE, partials); + } + + @Override + public void updateEigenDecomposition(int index, EigenDecomposition decomposition, boolean flip) { + if (flip) { + eigenBufferHelper.flipOffset(0); + } + + beagle.setEigenDecomposition( + eigenBufferHelper.getOffsetIndex(0), + decomposition.getEigenVectors(), + decomposition.getInverseEigenVectors(), + decomposition.getEigenValues()); + } + + @Override + public void updatePopulationSizes(int index, double[] sizes, boolean flip) { + if (flip) { + populationSizesBufferHelper.flipOffset(0); + } + + beagle.setPartials(populationSizesBufferHelper.getOffsetIndex(0), + sizes); + } + + class OffsetBufferIndexHelper extends BufferIndexHelper { + + public OffsetBufferIndexHelper(int maxIndexValue, int minIndexValue, int bufferSetNumber) { + super(maxIndexValue, minIndexValue, bufferSetNumber); + } + + @Override + protected int computeOffset(int offset) { return offset; } + } +} diff --git a/src/dr/evomodel/coalescent/basta/GenericBastaLikelihoodDelegate.java b/src/dr/evomodel/coalescent/basta/GenericBastaLikelihoodDelegate.java index 87acd46cdb..182148a599 100644 --- a/src/dr/evomodel/coalescent/basta/GenericBastaLikelihoodDelegate.java +++ b/src/dr/evomodel/coalescent/basta/GenericBastaLikelihoodDelegate.java @@ -1,5 +1,8 @@ package dr.evomodel.coalescent.basta; +import beagle.Beagle; +import beagle.basta.BeagleBasta; +import beagle.basta.BastaFactory; import dr.evolution.tree.Tree; import dr.evomodel.substmodel.EigenDecomposition; import dr.math.matrixAlgebra.WrappedVector; @@ -31,6 +34,34 @@ public GenericBastaLikelihoodDelegate(String name, int stateCount) { super(name, tree, stateCount); + BeagleBasta basta = BastaFactory.loadBastaInstance(1, 1, 1, 16, 1, 1, 1, 1, + 1, null, 0L, 0L); + +// Beagle basta = BeagleFactory.loadBeagleInstance(10, 10, 0, 16, 1, 1, 1, 1, +// 1, null, 0L, 0L); + + +// beagle = BeagleFactory.loadBeagleInstance( +// tipCount, +// numPartials, +// compactPartialsCount, +// stateCount, +// patternCount, +// evolutionaryProcessDelegate.getEigenBufferCount(), +// numMatrices, +// categoryCount, +// numScaleBuffers, // Always allocate; they may become necessary +// resourceList, +// preferenceFlags, +// requirementFlags +// ); + + int cumulativeBufferIndex = Beagle.NONE; + /* No need to rescale partials */ + + double[] tmp = new double[16]; + basta.setPartials(0, tmp); + this.partials = new double[maxNumCoalescentIntervals * tree.getNodeCount() * stateCount]; // TODO much too large this.matrices = new double[maxNumCoalescentIntervals * stateCount * stateCount]; // TODO much too small (except for strict-clock this.coalescent = new double[maxNumCoalescentIntervals]; @@ -154,12 +185,12 @@ public void setPartials(int index, double[] partials) { } @Override - public void setEigenDecomposition(int index, EigenDecomposition decomposition) { + public void updateEigenDecomposition(int index, EigenDecomposition decomposition, boolean flip) { decompositions[index] = decomposition; } @Override - public void setPopulationSizes(int index, double[] sizes) { + public void updatePopulationSizes(int index, double[] sizes, boolean flip) { assert sizes.length == stateCount; System.arraycopy(sizes, 0, this.sizes, index * stateCount, stateCount); diff --git a/src/dr/evomodel/coalescent/basta/StructuredCoalescentLikelihoodParser.java b/src/dr/evomodel/coalescent/basta/StructuredCoalescentLikelihoodParser.java index 67bf03079c..f042cfe738 100644 --- a/src/dr/evomodel/coalescent/basta/StructuredCoalescentLikelihoodParser.java +++ b/src/dr/evomodel/coalescent/basta/StructuredCoalescentLikelihoodParser.java @@ -59,6 +59,7 @@ public class StructuredCoalescentLikelihoodParser extends AbstractXMLObjectParse public static final Boolean USE_OLD_CODE = false; private static final boolean USE_DELEGATE = true; + private static final boolean USE_BEAGLE = false; public String getParserName() { return STRUCTURED_COALESCENT; @@ -114,14 +115,19 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { generalSubstitutionModel, subIntervals, includeSubtree, excludeSubtrees); } else { if (USE_DELEGATE) { - return new BastaLikelihood("name", - treeModel, patternList, generalSubstitutionModel, popSizes, branchRateModel, - (threads != 1) ? - new ParallelBastaLikelihoodDelegate("name", treeModel, - generalSubstitutionModel.getDataType().getStateCount(), threads) : - new GenericBastaLikelihoodDelegate("name", treeModel, - generalSubstitutionModel.getDataType().getStateCount()), - subIntervals, true); + final BastaLikelihoodDelegate delegate; + if (USE_BEAGLE) { + delegate = new BeagleBastaLikelihoodDelegate("name", treeModel, + generalSubstitutionModel.getDataType().getStateCount()); + } else { + delegate = (threads != 1) ? + new ParallelBastaLikelihoodDelegate("name", treeModel, + generalSubstitutionModel.getDataType().getStateCount(), threads) : + new GenericBastaLikelihoodDelegate("name", treeModel, + generalSubstitutionModel.getDataType().getStateCount()); + } + return new BastaLikelihood("name", treeModel, patternList, generalSubstitutionModel, + popSizes, branchRateModel, delegate, subIntervals, true); } else { return new FasterStructuredCoalescentLikelihood(treeModel, branchRateModel, popSizes, patternList, dataType, tag, generalSubstitutionModel, subIntervals, includeSubtree, excludeSubtrees, diff --git a/src/dr/evomodel/treedatalikelihood/BufferIndexHelper.java b/src/dr/evomodel/treedatalikelihood/BufferIndexHelper.java index 8d6e19c9c5..31b33b335b 100644 --- a/src/dr/evomodel/treedatalikelihood/BufferIndexHelper.java +++ b/src/dr/evomodel/treedatalikelihood/BufferIndexHelper.java @@ -58,9 +58,11 @@ public BufferIndexHelper(int maxIndexValue, int minIndexValue, int bufferSetNumb storedIndexOffsets = new int[doubleBufferCount]; indexOffsetsFlipped = new boolean[doubleBufferCount]; - this.constantOffset = bufferSetNumber * getBufferCount(); + this.constantOffset = computeOffset(bufferSetNumber); } + protected int computeOffset(int bufferSetNumber) { return bufferSetNumber * getBufferCount(); } + public int getBufferCount() { return 2 * doubleBufferCount + minIndexValue; }