From b7f9694cb985f0cdf916ee77c26e9710c2147a3f Mon Sep 17 00:00:00 2001 From: Dario Coscia Date: Thu, 7 Sep 2023 17:13:51 +0200 Subject: [PATCH] Spectral Convolution Addition * Implementing 1D/2D/3D spectral conv * Implementing tests for 1D/2D/3d spectral conv --- lightning_logs/version_0/hparams.yaml | 1 + .../version_1/checkpoints/epoch=4-step=5.ckpt | Bin 0 -> 15514 bytes lightning_logs/version_1/hparams.yaml | 1 + .../checkpoints/epoch=14-step=15.ckpt | Bin 0 -> 15514 bytes lightning_logs/version_2/hparams.yaml | 1 + .../version_3/checkpoints/epoch=4-step=5.ckpt | Bin 0 -> 15770 bytes lightning_logs/version_3/hparams.yaml | 1 + pina/model/layers/__init__.py | 6 +- pina/model/layers/spectral.py | 320 +++++++++++++++++- tests/test_layers/test_spectral_conv.py | 43 +++ 10 files changed, 367 insertions(+), 6 deletions(-) create mode 100644 lightning_logs/version_0/hparams.yaml create mode 100644 lightning_logs/version_1/checkpoints/epoch=4-step=5.ckpt create mode 100644 lightning_logs/version_1/hparams.yaml create mode 100644 lightning_logs/version_2/checkpoints/epoch=14-step=15.ckpt create mode 100644 lightning_logs/version_2/hparams.yaml create mode 100644 lightning_logs/version_3/checkpoints/epoch=4-step=5.ckpt create mode 100644 lightning_logs/version_3/hparams.yaml create mode 100644 tests/test_layers/test_spectral_conv.py diff --git a/lightning_logs/version_0/hparams.yaml b/lightning_logs/version_0/hparams.yaml new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/lightning_logs/version_0/hparams.yaml @@ -0,0 +1 @@ +{} diff --git a/lightning_logs/version_1/checkpoints/epoch=4-step=5.ckpt b/lightning_logs/version_1/checkpoints/epoch=4-step=5.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..d3ca30fbed22de8cf25d2c88d1657d79ea0b946f GIT binary patch literal 15514 zcmb_@2Uru!_ck3t1Vlt^h=>ZPbduf8&Y-BEf=E?V1Or5Z(sq+zjS4C%DmLs56uV+? zkXh{AtD<7T-rKcY%T+9Vvm5_Px#f5D%kyM1dv@M)-Z|6G%x*#g^$m1%>ebV!^Rd*i z(vh<9@}vxzTY^+6bxj|Z?C)o;GoZ$2yq=oTaf&j;a9MgEE?uM(5f;-ty3 z3Z*Qa(F-(;YKmp)Bb8}bW2fY#M7c6GDK#-RL&hqS(oz|HL4&8OyQ?>25M_$XDwI;C zEH)u2Udb57$EPJH%i@)|q{6fhn;>Il3El7=Q^v?Sj4^hKYK;N0sWKIdEvCwpuCXa; z39@8YZR+kiT*ez@>ctD@U1L?sq-2F@EGvssB_$`sDrKpPG&VNF)08oBa*8q*tf#Tk zL>XhMg_!y$r%9FJ7>%VsPB^#XOX z8Oy11o?t^IV-?0UifZ-GEaQ@-3dUNFDacLbS`!{jV|jx>Yb9eNsGh_$3Dlj$*lN`_ zRWi-OnCAandC&i=e2W^2J(-sB2EAwLDj7R1Y3ppJ73Pjf+iSJARx%D@Oq+iu?dd9z zcC4Y>i)pJR-R>Ld_B?4Pty*U#;}XVn_-E4I|3|uG4aMF}C#_;vUa@W<=Io~B?51GukI44s}MSNY;u}H!T70}u5v-f50Gm8mFcEtx@+%P)x7hq zCo}$PCO})j_Ad))yE@ZD&Ggh3YxWPtT-}*oY9{c1DCofishQsYLqSg_Sk3hL9}0Rg zA!;V{e<svNGI$ z(Q0OpcB)gRDW%CwjGBqnUb8Z3!boN?p62-e5ucWlo-9*hrBW=>PA5{9zmEG>0;`P2 z8nGlk&`>25iC3|#ELF)Qs2Q1)d~l$yT#CPO_!}RntBPuX^(HA|lcfr!mOnFOiYlsQ zt=i&hC8&~R><`X~YDVrP$4rxWrG{9kp-PTz4E=6HrHblUQ;k%mOiPg}lj393(v?Xm zNuzK+Ew#EYhpCz5DEI#fC8+qpCq>Pq3Qqc_TE;SIY9`$&3a>(V&c}-dqh?rvv7#np zWp&29n@7pmwAe&en!wr{qI#SJy zit_w-h_<496@Juqw3-wWt?>q*)U$H8!Z>Q81 z?4)F=3U^jg&GxAcRR^ zmYSI@7?~r-8s`cc^Zo*gHO~7Bv7pfmQ%}di1&%GWO6RMY1%lFrywXL2(#3y)#!8pe z731Efoe-qS$#K&7VG3rcnpqY#8e8ej2Pj?TviM=?X-TQd(d|>xQj>6e(VnsoPD_<_ zw8w%38NOJGBP$bR>?|}QiWMdf$zFkyCY0oywC#+wZ-u0mBnK9G8Q+; zH`Qupxtdv_4RgL3W%7fgEHSA#-kCzs6sw3&V`WT1aFj(&dAd|7XI4h>3!Yn~0!ME+ zHe{31;?orINm93NGQ}`uTDn_E&)_a@Kcc+kv_ys5w?#45!wt_;->R)}6G#K^PW-q5 zKCW1ynDQQiJOF<{mR)`Y4L(^g)=tW`7Xq8ecf!<6#a zYF-g+ub2A@t;*lSZdJ*c4K-p35wlSnOKifigtN0YhS;oRio%#Jf;)vtTUT~2{IYG${SoT*`0%IK^TVPSB{T;g!r{IP>5gmYq=j#G6%z$Lqb^oO+~DNVGh^O zLNZ6RT8@6#a!jk`xK_&vC37;2Ii(8JRWqlnQ`f^4_i$&()j?K;%Fm^)f(cdJpmr)2JjF%N1}^JX4?i>n_AcK8_eu{K71qP6u@$vg{Vp4aB* z!~FgY_~JY8r55-~3w*6){s?2uX%-cF7Ma(-b$@kSr{;6a>gfSm$lN2+RT3>t; zOif}w^S=0^HS<-;RE4oRc!N`$BEjnNQr`2d9?ogOtUk_SJZpe|1r@9z&M8?VoQAQ+ zI8_B2sM&frQ-?LlnqaX&oHfO{AKcmc_)^K5;WUhGz!$372|V3db6zU2vA{W=4ZkH| zE%AlmIJP0qX&G7JoRV#X(=gT=r?sl)>A^PUrGhdWocpe<3BJ%Ov&A{BvZgqvWSikM zjBSonRiL4oZGkhjtLN#-w!~s>_3Uu2TJ>7tOC@WM(=fI*U#M2~yjTZb%I`_oHaI6R zn#4NdEbr{L_*cuO9nLA)_Bah=op4&aYTm3fFBKHI;M_Mw9dMRc)Diz`6?MWnCF_dQ zFxCyHcsN#o@#g(bnlq!lksSK_mQ zI2;*EgkpTGlaMrI2tG(iRVE2HXyFEbjxj`q19KL9M!s)zw z6s#}4F^KJqQ#>`}eyA6yA7~J0C{syQQek~bX7NJyEjm}QU4-I(I90Pn8??ZiTK*gIw`Wb}I!-mtn%vy|{TyqZ zG}SnE`hGEfWd+yq!Xap#Q^{Qzc7XesybRvSdF!W^Gtiq;Y!Pc7E}l$J7Bz*FHY$N*RufK@Vo_R|Fb6ix zbM$R*|0m>qU^PMRo{t^IfB(q`jx~MY;qPbvuRd7t!q|7|JBL!!{zjTDdCR!3(=Kz} zCSHf0r~SA&;mu2XTr(~`IL%ygJ7631)Njj;Q2a)lX1+wv+eUJ!6OL+T{E-RWgRFe@ z4!8AfT)GtTd?%x(86iB^f8=YS<5i0EJ`&b&>T!-|jm7Lfn`y^kz( zd~4anPLljro7E|FHgtgIqS=xzGfI)pypG%#%~<-*#!u+ok72< zYa)DuM35a~2Uo_ipw<2gblzq@e30e{i^yjhyZhNt5@si~cU}s+9ZSX{sQ-oVZZD zi_ErcFL~?|M3onw5Wjb;C)q{;if@bKxtSh6!{6&C`2TAAd;enlZLwXi)L%C^zQ@@J}c8a>=Vn;1?x2jT$f6k1nR&K{SCR^Y7$(I6QhlBv)8(74iUx~19zUWs<6dp?|0a%OryN=jV|<`qBDBsS9p5ns$;LcewN>>YE_ z^_$1h#_(6j`$jqHq1lJ#E9=4hgPEw=_UXvV#hhOJdOg}~5=u`Axe3;m*w7c<#U=IU zk4C#bW&zdW-J1M%`{^Nfu7N&#$J5gihM-FMQY5{vL3UxQ(T;0h(As?@{Jeb^YV3O) z?eNj3w>^A}Mn1`;3ke&TKhBbV-EMG6ANfXfvGZCM`qf z7o0}}lXSS=2Opw4?V57;sW`aaz8pP1`3RjlFp`dUGUMC_Z>6v8C!k5uSlZ2NRms#> zXOVRDWzd+iLC3nc0j?Tvn7Fu{zPpt}4eIG}8JC)KFE`fX0w1=8;hw{}t?yofwu}LU zu92E?@wRZ)v0a+w8+&rEp3|E6!3QM^hur~QqkczU%-RE&6Fw+uR{;0<^()Q3f#<;1 zL1xh4XnVMVX$rS(WRYv+PEb|pgp3-CkY3OYNu$a^XjXhfI^yPB@CaU&Y@ObqBud;C z>27xhMtzet6=Sk#SNoaZ4^0QUVto^|X+#K;_ZW_-$S9QRx)7<3-U4SQ_e0skhM|FH zk4tW;r=eAbZRwu(c7e5X%OqQ!8kZzn2cxm0qrlb~3pDQ@ETJzdcY+6>BWN#YfD&(w zMYb`EkwL4eDAeYr#`tao@}`u9$$UU2VbA zTcIV}`?v$@`mvHu&lAu^IIzU@$WBV{fi9SI%mno->q4Kf^3ya|$WZhRBaOqY7j(wZ zR^awLKk#;g8z_z$LazT1pGuxwq)nv<32 zPe&tY(HEkgpG9DQ+Za68>Vux5N2oI2mb=tI!nw5Rz>UbR1hx#p1LT>;fXGX*+b2RngI+Wq(uc;Dy0Nrlspv#t7 z;L#T~kfks%GjAI==e#p4(hyMBsz~G3Y!ICM*bqr4&f>;I?$)rmFD3hDw}H{mT)5w_ zr-HmiBau^SE;l;D2`&HH7`lJ$1)C5m*fAgqIvjNaE{aT~tN1`KB#c4ns+5wG7Og?p zll>)T%fnFaj*ODoyNW2I%ontL)JIKbTQ~YaM<0zt>kt$<;(f^g-3q$%?nWRW!4p)q zwF9xGELh*j17!z7X$|{%7EvVMZxP-DDc|02u*Zu1dQG6!Ky=fB_|4f!AACB zNkY{~^y$g6lAK1Tse+XaKtE`W(q0DB0VRDkV?&0cq0<^_iq7eRO|#rUX;uV??jZpg zFp+aNT@NMctvJ)<4jh@Hq0Rb+bHC;4gY0!vVcL{)bm)Tek{csiaBn>vpng{=ik>qR z9Cn-xTkLEJj+{5=9P|9R$R)$L_XWK;e!Yyo+ipo`9KiG6e~MkY){0$x{QZ1?C3fNC z80D8q)Ig?zq~4m(WQ(zuRQSDS;^KL+RR4mt;t^B!Q8)h#7u#le`sD0>M}8F>P;ph` ziSf%v5SedxQT0<^iJiWjBJD3z)cG?x)crxds3QhBM8bGm(w3-C*c7!RbVj!Hx!EI; z*jl#Br`eKIL{WjMPw<-^Udww{iM!dD5*v3+5!owOi)J1@Oe`kML_LmQ5qIhpKonT# z5+}W6MBadNFqR%^dFNU+OI4kgXh!}&v|G|olHm&9h9h zbyM{t9S5uywNmaSlfv4Ia+8$a*Izs*vX4{|WtT^Zht{7e-fq8x{FwAkR9SkII5kE@ zK2DuPKAapx2DN%c4(m6bN{mvGdB*d|7TrGi9A0&TToLkB^!Z|YswCA`)V|*yZ%Lbh zICTSgtn3)M zQu9Exxr-H5_;jfFkHE3i^Sv)b>2JFDT%Axr-ru{4To~(2l&6KZcDYK{g zG>k*!fUxuA;%gy9RJ4-lot{T^^e7kGoG&0d#8^?Gt3xTj2sf&5^L3(8jDoD^8%c~_ zFq~-AY?seo$GwE-0DBQTvklo!Qsgt%z{-2p@gO49q90*3r%beUqmlS>W>a$Y%>km$ zD>@RtrLQFZaCks0AGd{Yn$>}H=)8~2nY)qrQgoW=l{?o*e8qws|2R-IdW(`Qge!eI z1v_}ZNnS?GvfE5deCi_Jr0XT_+le9b(Hc?H4HJpt7O#lw+3uuWcztqzda-!(mR`i@ z_LSH$DORkiBE7%fSRy{mKJlKFW<>ny2fVwF-0ykHuY>r(4oaN2d9RO`{TZJ-NT1m7 zz{e+LuBG_J;t68fe!o~$zEJ!)KZlssqcLgU!kh?GbRu#+3cSr1$O-@T3w-uD9wx3{ z-r_ygAk}Np&@bYde>NaA7iWl+O{a@ms*ezf4UI*%5#?f&hLG5AkV`x)O(dG_97n}) zCX&zhjHvO;+fw5Wwh}d&-;Y}DGDa;4 zqCM1vPhZ8UuE)r#I~^%U&n(K09zZE>KPIoY97CO2KZNw{KZSG(7$QouIz%phZXrHV z)sE_$G*Yy|;gR=+E`7-`-z0LV*+a3^_k%d(t^+k{Oq@6^-J9H#T}Woyz9RioO31c{ zT*=?;wo zuZ!sAkp86DzSw8R982$`!PAL*a~2TIn8xBSt?k7=J`v=(le0x1&J80x&t4|3=(i+` zKD;Jg$6TfSFASCJm-nTVoSI5*A0t-Y*hOu!HzqV`%_Yz8E5xFv=Y7E50BQr>pECV? zi7?B0NUV6>MDn(ODAD?Jd#Z|FL_MN!Qwy?-D9I%`;n3tW5qxqYeU4G zNFzq9p$Y$k6_l+`isZwZe$GP$1FKW!d zfz)BWD}?Um`^4eaO(hPdVZ`y{&Q!Bci>cgWcPS&2Ez~(wKpYw1N_JV%iV!sigyp%` zJ{Qs^6Ho1r`n=A2NL=bE@)@VM(2MA2O)UP>nlQMxT$G_+FS34phX_zK7xn7#QG6gW zo>+f%4Y6ruCed7cj9Qr&A?baXro3GTP@8kJ#OuDSrus(f5iOpWNnTpzi|gHA>oaPe zh)PK8O6@h~h^n!th^Hp?B;)eLM7o156}D|U<$dNi>WSVWVZT@SbWOjXS-(X8AJ(s5 z-T#SyasPLB=LWc{)fV(5zaJ;+=)|pi62~2%IgQI{u>d}8bPI`NO}IAAdm^K_P2fDY znp^xP4lW*E4_@)rg+=OXu)pUDFyT-I++c{n&W7iOc!1oL`{rl%&%{scpFf)RmxS35 zA%jkDqT2Ncr(BkmQjdv-lHwcAl2rddN_f!U$aue$1{?B<)%NwpZp2_|E>MabvoCKq41Fj{YxR#u!raA zo{nNJ-bli6v)XXWA34BjrzU_Si)PVxey2g@?2*7pZ2~*qG>7dvxzgnUy70r-a2P#n zC+#q4EpXd^2vq)g3G^!42=+<-L?>2v=enB`+{|&W(U?u6P}Z4`Kqu%J7@p(+ts5D` zi0PGJ!)+0lv&oD@EqcRa8L>cjRCB0v=sKEH7>3HqHlug%)hHsqAv`el1g zM29vz3^rbU2Rif}z;)d^5MF*C&*`k#gTk^%be+rtc@|-)&%PnRxgwe?yJ4$Y-tR2x z?KxE99-X45 z_Y9BVRv|Y|@8}lExgpEA5 zgT04Mxm)vWxi=FogLR=TP{$>k>AU9o+_fI#fxO2!^!`LXxb!T5Qxx7nImiWOcdw6z zO(o#7@E~;UXc6*U>%pnyongh4xyazj3c9CZ5@%0@gZ|#3@XUrHu<)7gTC2t7_ zgzQJrFFf!g)s~!E90(POQ_!RI)*#(#3Fz3}8l>+D;!1rdqKSr=(EUjwZZ|$VRvc{1 zUG#{B!8ru5{=5{}8Jy9mR$K+nowCsN?0oQkP!@O)rNdpYwn8&~+i*)fEn(&k72u3p zgTY-KpvuJ%ju>5aNW-^VW8)bK{&x4|2bW_$wo;2z%7Zghpc`a5tI+g_ra z_tD&g+zeztVj(yh91Rl9CS?C?KHE%9x68#oj>hlhD&9mpC7M-M} z`|Lo+wxiK`vMp$I+J3T5ILSUK+}^o2`121bm|%PuPP}I_yS+&d~7c8b9v!gMCwL(EFre{S<)g8VT)QfQ+p+ZonBT_+EGGXQk4Zpsy}DF#6=0)fp=U2eMnC6N1-152_>=sr8$K|@sgvB1PO_Pc_`&*4A(`q3RSiYfJJ2uQJ(^9PTt&}TmE&w z=H}$n2*~W=D1{eSa^Di0P~!6~(36YX8!9|6nSOTM`k#oOJn{3AU*7*eDqIN9hkxOwHumzh z*UjOo^m><0IFQL%vYw@jZO?Ek#;=F7){9EbjS1h!Awi|}`&Obk4bw}T+8zOuN4A2c z`};%LN;fz&bPBh|wpfT?hW8F@_Y>>CCw><5pP&C29QKm@UiMnjAU2tHxN0RaU$=}N zIzm^{@Yqng%rajR{w_@7kTZ(57&w;Bf4h{vWAcDJ)#DPq)_5ZsL55Q5aV6C0^byq8 zgT@lG?1@7B;#)qG`x*W}^Zez%#V>W9UwPtZCcpIj%A&r$@cit9=LDypcz)*jtItoY z)m4Cx8FxkwNcwmCb>cT%{a=gUKgXisV;o`o@y=*^?i+M2r~uu3`vpDCG2(!;GaR`;0)0J^i<}o7 z*9;Hv1YIgTIhXf$K<~0m!1iq_82dU5KIv%z7RH6c1Ka$;i5sd70DmFGtfsDaPo)!tThcRT1cj%;AoWvw^?jDOeMc1?=Gg;Awl7{{2&Z zkY5pn=0~kTCmy^7E{%QQacT-UcI+HzzITA8pkJ;A_-zC$J2^vAo)3b|t_X9(2P21( z`S`g$|NZ;(adVwo??2$}RA8(z}sGTr2EFuG82Np@X*)N5{yBT~G6fb6GZ`F&R5Wqb!YxuUm78SEu61 z0C7`l@39u-b*oe4!@gC-hs^2Z)3l}H_huHvp-$Tf_xupDsP|3sq_2W1Xlh2rph7aj zy-4J{u#woV*Ff^?puW`eVSc1XaV-P#&hkJ{Mk?}PBKApiK-Ti3?cFVUm#gR39n)dYv&e~Hwj?Nb~4_ujU2tPK{`23J4ss%bf|4p>|p zT>Y%6CU|h$+Lnph;J>GB->xJf5H)*+p08TPExJFADM zHNp0-wZV0I=igWPb{%jX+SP-^nwBHnYFn<;JF7={HNh9_fa~LFE4XGgf#vRub- z^>C&p*50EQwvOHEVM$Hw?b_HngQ9v!Q4<@1{|;1RMAaD+)tjj%_F`>pogqottq%#Jh20t$!FMq}`))ci>spV=#Tw+Rs znlVa9$Vg39B&bnKm1!3?QNbz_gV7mN#@I88G4Y6PiU4uxiYyjcOjoGA;?gn_6{%jj z(#LC%LNLhGN)XL^#bv2eQdOpLtRg-uB{eZltw>j8uyF(ZOqtpq9?tYA!a z5Yv#<47oZQ!I99EF^knhZG4zIu?FZjAv23H5A{$p7F|Onq4h)c$1|3p`k{K`8LM2S zK(MZwsTajq$2$Hq%lH(zim9(e6qKe)oe5vYMrj^eU(GZSRcA8|L-n$mMmn{P)r@Tv zWB1R_`~B&>eHF!iOcSMf_;fur~%exFga|I_*u>jB^y@^3SCGyhPH?s_1Ub zxavr|eIxBIkoM52^;9#>qnH-|OxpiX(k-hf_Gen@6nhDZ^+FM6ZyjeJopN6_;}^xW z{%6kqI?n!8R0l8tNHq$FR8*bKU;<}crQW0W@IwmHB1i=rICh-jkQ5?aUv4W zcvYH3dTN+n;_0|}xjI3qP{pa0aoB1uB1 zWtP|^A&X@d>1rlX!zetI{X_MXa`cTy--J-TtXOlTH$@efDp#p>{F$T)S+Py3d5hdi z%t}?TKR73A7^R02F-;MaS|X*ESxRK%x9>KxvSM3SRU^+*XQav1DG6~Ind+3Zl%c4e zkzP5L12jx(tk3_15>@=*lcr(P#UOoCtzel94U_2+i&h~t=cC1f(J-vYSXGs=x-w%S z%tIAyMqDy0Ph`-@94JBteTSgP=^P#W9VP@bD=SMC>sgg}nmj!Vts8&m>kthyG}iCm zAv#9|7k+pvpw!F#7VhXu@>LefA8?w8fLtxDo4NwoyS~=D^jJ4?SwzY7T?9F zHwj%8Igx^9NQG{E=4zOU;`p2-TB$~OvWA%=TABJ6%urXI_7`GAI7{qCY6Th);v_I# z!^{wk%v2(cvqX)v{{o6M&iMQ5ag+;@$!TLDrS*}SsXhISqT>sly*u*!hp<-lyvnl_q2@k6qH}MlTHyC>57(4 zNRX&NNA77U>F(%Iffkh^ap`etR-TeB&RQyT)~C`LapDq$zDTHRhDNU<4yjkL=z@Gx ztznjEn0#HD^UWwz5D{yINW}}m6qBYnRYC@<#6H zY)VE#hAJUN?j5X94NzxfdPjDOXyg4O%S+8jR(XG06yyB7(Jb}NZI!o38uO69=YHv?e+~Kg?Tq)d@6n^=dk2~SU57O`HAaX87Ci@%)lxzoZ%4+rLud3G)PQk_lAn8>^_|n`_;^WDCVG;p?^~mr(&2x zRkRSyVV#yE-?bdoX*s6Ta$L=vh+ zNm_a0drl4xuIrm zMlrXtLeakFb`5H+nL9dacPml5r)KU)F%PO!^JmJwrP$@79U)zPs7qHL>1;h#Gf$$J zr`7odFwedLpMM9w&;eiSfUnfd>nP?;R;Yo7d0T^|l&R2>d{>F&A8O`(6!W1vNsRfZ zi^V6=R5tTjh{YG3nXhVAf;Ku)tR5<>v&30_K`I2FH9$2To;5^Oq=GdYaukM+7JAESW7`Fvav!n zfsL?1VC$j-F?MV{RMRoCMm05CAC*z84Jxa-=I6^c5Tv5AhN$*kStE3yQ`Q*ObjoZ| zP0iY&GK#fFWmc$>hHZi>)!p;+V;zuK=bj^~RdUY>9je);sElHrg+|rfYt6a{QelU} zHbXU$Q8w#}szR{c(65e-JF2N!4^&36o~W$unm^lIkcx_0pxQS@Em2ia)C&FT6nUYV zn)OCy6zhXZw33O7qnwqep`Ja$+UCHGJIW5MsQY06*(h8&%6B^O&hfKZWGno{P>bCE3 ziHZdI5Ee(LLnY!ICQnld=^slVQLeZnNRtm1@>K;(iXtKPPD1OLJVD9IkR(E!w-V%; zYILU{9_7iD*i6WL5>t|r&~1fub&7Zbh$n=*8cA6wO=s!v=oGX}rwAz;%LXFh4>yfb z<4hquDmDn6=*zZ2C7PyDKhzI32sI2fQe?@qTduBR+ltNGp;E&J3!qp* zSCU+~_rSJCf>1s2mWsIRh)Y4Mxrz-D#1hfzc)1GYSxD9a$@;P#QTd&fT=ZC6Zp4+A z?IhrIn-a7j$7i4=yfYHziU!2AA}%2#D_zZo3JtAd#eHP50;NN2TmnjC6p3+zQq;iVQWW71G_+r#af<3g5`Tz4K%tv8{6F$=fHS zZG&p}o&LNDmHz4~e>W4-*6;|YJu0Q>N0t)eBniI=+Sfk^&SW;rmZsH>>KkZ%A9Gzw!D{?|b!1Z)FrXWwhI#aTk`z zI?tcWns39M!l1xfZ_ZC0kx*2M}KU9zBJZjXoh{T+G#Ek~)+X&8!Wx$Pse_oNEOdDgD*<=Qj6x*;%@9L(k=~rL>|m5Fq135#M{gZ&O^+%=!;+98~a>_}oJEZ*0fl+0d7 z*$tHfeXm@q+57o{r+YBF*Z+FMvI6}=oI~>e7d71jsybqMu8V4E|=Wv$yM#8Z1 zLqYWINbcmRcihqTW!$_mui@PaU#{h~FK}|he6Z-$32yQmf6iBSTE@OR1HbfZAUhoG z0BUt!3X-fJQG?Xu;LZ=TL7Ti;+})zF&~(f!;BDKRJFjjGJupvL(A!dF#K}+Ny1}r5c`zEhuq#aD<|!Q#mRu8rqKLv@^nHL%)Eg z;FvO*o9+HkJND@I-G1gQG;cm|;sVUwlM`9ve=?<$7>ea+Ywbx*Bj^&1O-Tdu-soUfwQy8!?X> zOl;u#E=iDSOuVSPm0nc9YzMF+p-5Id+Jiq-Q39;~@PwxiKj5hBYg(gnGv4-XNA2C2 zLEyl?ZD4_46X4^%0XA_e1&dYJxx>jjVEV`sFrjWX_u4E1My%@q8?KV^SBIbG*m}>v zD%(l0PvKE&?VcV~8|)r&f5-vwxU3_21of?6W>cZ@%LiP+Bo%laF%2FIpF!z8%m#Q- zE^PCpCpB)3DKLCh8#FI7qVZ?J)G5!&pfFhV0y^>MeTDJwTQxjO2<3zPPGo9-3rY$wJ?;@~w=26+DvpvC9 zl|5|MDUrI_&W4)mj%jav4gg`hx9x8KY9r^G{((CRytvMB4=Lh>0la<7n95~%Dx5Tj zt(2o>M8OQI_Sx0c?|oI^SfmTJKCvU+@Z&pZlkWmv$1I^_mkhYhy_?aJ_Ws(Jml?Qg z=?7STi-7f(p9b;wAH#Mx9#K-0QV_rKA*>g)f$DCS2~tW%gSP$S>Edx~sMh_SQCosG za(mn9QESH)$v#iucE4PDgv!UFcE=s+#0^`T&F#$DrCq*%6g8!EI#<|zu59R;PLxT= zBJK&PCtGFfLf!5}QwdKELF4OtWDR{n>3~HBd>ef~;BpXBqkQeTN!jkS%{oR~zGpm) zy7L~Me(DeJEjSP2cfWw;D_>AEhaU$a_20mR21}`xU>FFVHVFj3N~9}#mQdrnM1ws& zPIL2jNT|l%2C~oj`w9Q3=HT*{&g6R2+T3vyQ~0tXiHq-gmI_-_!Ce{lN|u^c3uu~O z;hNt1AX{P7ivoMvQNMS~1f}+~$Q{Rm`M#STfH|A%z_D-rIl7}6)pD>C-=vkhtms}h z5bpLBO#1>r-Gt+iY<~}I?0t*tSbhn9Fu4K_cH_9Vr3rAz0u_Ad6~`xgMRUS>`EnTB zCJ3eM!uOwYmv+^1mw=GAfxnWw2ziX7(RR`{#6&i=%NJr%8#D6IoBCL5?JTk)7+{ko ztt791P!Q)%nBkrYW~609Z*qjuDExkpQN%8<&E&OnZg~Cu`-x76e95W%j*)#PdXdK$ z&&A7|zQwQXe}wUyjqr%81_5VchT!W;yaLjPhKZBlRyP5joniMC7ue(ucAS#N6x4 zq=nn8$-~|d&(9x*8GQVW2wXas*xd0tHs{VG{A92_`P+*n*n!*b#Ow2UL|}hUa>uZx zM6&_;MB3}O#H81caGQaK#NrDE(&CIU#J=X<(xaYvM4qM%{s;Y8x*@Crk6vj<47{$u zGIpln?VBtm60;q!6HPYab!R#ermZ#;(M8RO=~rCI9nZ7K*vLTyVKJ2`&+8-QN0^XY z{$lC0whhUKwRhoe{qivQZYU9~E+o#c{sRj)u^^_#s%QcS7nPNkHT(c9qV7-kPa{4yW=bS5^ zYS5DCBTpfJk1fNpnvEp31U@D8xtbY`#jdX#{=T^tVZOG#01>pY-i$J`$gn&wHf|OzLJoKS(Bax%gJNc z9Y}tJ9`UTW3vsDJ4M z&$tyqyA>ZRUb0 zuMG-VluS$CEZTy#b+HSWHe`?V;uat5QN?ELN+S}J$K~P0MbGie{4=aCYl`0+b<{s> z&2W5zzkNWr`?K*i9yVB7OlV7(c&RZe+(x)$ z^}`>RDM-oq8D!Li-|;u)B}Cp?8(B=r6#UK=XHqX`Ho2~mtIW`U9{Gy0A;w>eAyT_V z;nOYB@s!|R0q$>J;KkOn0%r6!B$PcaVzFBbr60!j!};bj@G~1qrQ7Tp;Jsd$6Q$1z zrR@(j#)JAz#GUrK6E}&Oc_zB{E&tPqY!_TZ= z(*K9`>sR-G*e~w?+IpDL((JF?)+_lC-^@Z&uoRxDxCoDheFZ)KXaz4v_JG@4j^QR& zlu@PcjrsT6cYzNZT>)LY3-I%Ir#pF-QwJCO(A)ehz;8Dkcp)DcPj=q>Gy7-!C-%?F z!=q${Ur8c%+cvW6Rf?SaW*1qEo6266c*#_s3y8Vno01PlMc}H*{=_w(FxjV^QZflz zk|##|PTX#`PWJAeIg#9Jqio!*<;3mL1u}tu(z6pKKf|B+3I6}>`ddhXs^(DWmI?Js zDb%qA?Wk0P7JSbYA-toX7hmSylAg0L5?u7VN^QQA0=mCg2=Z&0(=$@^XvL)iRLiO5 zpwO`!9bi}olm=`Ez9nnHE(0Fy>opB5ay91XT*dfOvtT~C&lMOjr3=KHlz}<1cfj=* zjp)`}-hmyet6*V2f=|5t1=@3gw0m?eSkuauKG*FntTQbW=3{eU!^siwZn7io>UA0F zrLBY?r`(lAC0_;O_8ZX_Z{qn4ObYE{nZ_?OFM*>M4uW?VOa*J(+QSQ-e89bnUHHPh zX50e092!mOrp?Kaz_)!m(M{Gm)5Ws}1Aj$xNQVc~b7mX|?>m=*Xvd~>sdF2+e_0Fq z?5-xfUquLiD(VTCSLYz;A9kJUG<6|0bE_U5^}?5T=tgmdd!JL;{zKrDtT6uK?3-}L zhe(yj&USq{KzSoJ~VZ0PRX=Dh3&Tj=XWWnI}0-B#3)(?ix-U+X- zZOgmtG2_X*ru?d~Ui93^Oc2#(3fTU#9-OdfH`p)Fg`*E@fm2#IDDr#&*|%1(#W`!+)#a}f7+q`dP`G8C1@7xVfmY?+=prXe-a2bDh&$4l4|(j&56w#C=PWm-S2wh!EhF9dn1y#h zxE9cGl0EO-k^=$p_S!{3g}@W9M;pwu<>AM^a7A28x~xn z!SEKms_-M#xlsu<=J%e^-#!YA3pU{&woBrCoa^&tj_UyvJp+s`!uhQIAGGbpT%{cR z9r=V%TgtQIE%)q13)*N+EdP9!Gkre$6P5n;1lK@WOg)^t7^baw&LvK2N>3Zy4cyDl z09!6O!A1*TfFSGc{9!_$KGm}+xV!Waw9S$8#-aVeAA0-XotO=9P4kYF!~9@=l3i0s zdH3f7G!Af6!8UNkSPeqP_oVZ;J_a4$bb_x8j)HxI$AF}X1Nc!pLl=TUL)Raa48E}!^4liMVzN@?e3Y8gj&!z z!3y7Ez&GwfpTM=Ce$PWN`ZPw5W|o3Q@D7*xwK?77$y&IxRvKRvF$pG)mC?5-3TpOQ z@WbnQ@%ubSb9HAFL&qoHbX>7FUuXR`$`cL-%x5F|S@TuEnl+}6Z>J#*=8i;uwl9d_|0xmGZtV>th2sxQ^;bO;^RVlp2!{0O~5t+!{5_j#CZuD_>6 z}jU74ZvYN1D#2 zGB#+)rpIzAc0~mF!f6K8VVj$*g~wJIpWcp2uaiJ+7?(#ik?6^;NQ}Vz_%ib9Xg$*F zKnB?_|1!BZW2LO)0(UWg@hdmC{u%y1^Zez%v2XRoCKZ?~hDt4DxtF>S>0>z`}* z@+bJC+mZjBf7Zok;xkV*U6J+^*B@;mf9?89B%?+O{r^J?&6S_9kG7A$YX84Hzpl~$ z(U$Mm>~|?$3{w~M0NzEP$aHowRk19OI^i-FTG_Ya-p5$-^_C6flCLbJqPv^`3&A9K z(8Z0?yK$K-+_a6FJ2)J0$Xo8;) zSc}W890+K;ikNi2Hd(*cLE@R03%PFKbA0EbvBZpJud#!gow&V48ve3gr@gdeElp3f4VBab7Li(sPd`)53lOz%I~fS?@9gm3`Oto zR7QLIRL53+kEAL#6TMkm1?yWKTlswu;T7Z`AE)}&u{8`=egUH@cyyFMEzhn2_OA}Ed=sxKI1>F!q$=%z>fpbRwpFw0U}<%5 z<=b6VEzhn2#;SuW-)pK0?(bUNGF~10_q3hds)Nx;tU4|#XYEyc=j`g>8kf0phFq0( zWUFe1{c7~i%IRrUu#;DHaE;#i_b%V60j@#2a*|lpayRelmTUCR${Ai&@cA0x8oje} zN>w%3-F&K9u3@-xI#U(v7Sb!?46Q8}fkitUE}4pe1C)ff|%FH=?Q z`RdpjL!$C>Sryx@bv3&+MnvVMrYbhmzZ$m2fT+BPRK38N0Y9+cmZr#1V`JfOW5UP5 v0CoCrAG>P5N$8ol@F_(qgb%_AZ{15`{hVxB(}1X*hz%;y0-rZ!ihN7 literal 0 HcmV?d00001 diff --git a/lightning_logs/version_2/hparams.yaml b/lightning_logs/version_2/hparams.yaml new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/lightning_logs/version_2/hparams.yaml @@ -0,0 +1 @@ +{} diff --git a/lightning_logs/version_3/checkpoints/epoch=4-step=5.ckpt b/lightning_logs/version_3/checkpoints/epoch=4-step=5.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..fde95e8db6a0f8d4c52a3aabc6f8b3d3b95d54ec GIT binary patch literal 15770 zcmb_j2UrwI(_W&03Wy4*2xbsTl6Pmi2N5wW3L+{daDgQ#X=WD@bxkNi5pzbxoB$Og zveSd%L@}UZ&e`+y%!=8+XVy;%clUj#|2)t1_SAG$y;Ys6dwPSUI=Ui}fq|&rWhycg z$(YF4xG{3)D49~`lsr1#-`7|)wAM9QTSbfPLiI4aJUJ;cme!K$hBm}cG4V+eviNX? zQl3m}OZ7rqU|MpTG6_p`jE{?nRVKzI#)OZNGm5yRL|R9X;OgY!f)3t=M~Zhk%D@s@DycSyuv7)kw>J&#YcrJ<%xVgEo=sNVPI( z)9hFtU=t;67ECt{wfSe15pgmF-7FSEh&77UsOU;Jk8Lb%rlidU(bH)Qsn&GbQX{s7 zl5QDHxB6$}UH@aeRV~4;bnDp0edlN?X=@GWW*M{%#*RU^(MWHrq}v74?f)6HtCIk_ zLoL~EbVm*7PTxS=@}TWBVmm8o`(V1uKZADv4`_#4g57CHjbJBUu$B~KcGh5a(FpIV zq+Nq)w|~a$uEFeHOSA{=fkk6?cnYGY(_TDgZ;e>7l9mM1M2dN>fu>y_nAT`=6?Wdx9Xx^CBe)Fv* z)BY;Dr>21AUl!0bb-I^|?yV`-@*j#hxzK%7wDf-{*p&`Y(S84ig06Hw6&?6L6m+A5 zRCNFUp`be*tfB{K3bv|ijC`z;k%dRgWy%yruAoCy^gvC?*8f}*TRuod4;IR%C?bTq ziBC#OR?wj;dWc;pKMY34DRFfrCDTJy^f0?vJrzAXv^k~=!<+ENsm9neLPduOufro` z%E(x`B3v0OQwqagvOJ0om+C2mp|1wFP%u21Ns7VP74%3IEz=C3$|R*Mo{mt_k(y^l zE{jT|qwwI#x4OurgyeX+5(|}Mie`|KG5mA%w-Q)nB$kLNF;cx0p-5y3!^jhrbgYVw zvx|+AYQ@U&9gXi8sa8s8V=OmL5gsp7C^h)$QCTUWyc28SRbaECQsU*z56q)gbi7?G zMw-A2)x$#dQev@=#P2#%QbHYSi;<-$lM-agxXAFNWMy1J+&G+1O04P2Bo&<;>heEv z2_k;rp;a^^IO&^cIYTQ{w9+mV&uw@n!BY_Kyki8!skIT0t%;a-^Ef$^6duFKqUbag zJx*X5|D6SUPGjhV?`*s?Q&LhCp`B|3Pmm?1;MwT!Y*ne~iJ`9loke3PZ^932C#mSk z06er!cW8qNQF8obn-G_1iy!59 zs!9t_49CN1T%s^)DezmD8gGPwfG2jvL`^Z=d*$I+yqv)m@=dggUZbMdYJ#V4O6hg| zLQOHK2;P}OC>E}WOk(8p`hKA%wdu(+Wh}iRl%GbO2P$w(h$BuWE-5le5g8|Q_LD0{ zE0dC)gL?Pt=KLdej8BSDIDeaJ!`+~VGZ`*l=P8cx?BkMzljJ}(Dcz-QpEHzjg*S-QjTk+oX|)) zsiaQ@)0HVwEfsybCUk9Vcvlztj0XGJ8rZ9p^toXAeC!XrLfNkL1zZ-#+7|^?)9Fim z2zXhe=8BTO8cbh{{ee#4aiykx3--GE#b)mV_e|(G99}9Z;$oGjR^8Hhz>#34{7EC{{i_e38@s082cg9y5#@8Cg zH%j_#FkPJ@)lt#!>Vfp6-)lgAr~&e$lKvD-f36GCi~geV#aBVqbVh_1N!T5X7XB5K zG1@q%WOQ&E%;@5@E?jR$kEimr!06-LH$e?>mKS7ze>H*(aZbq?;WU_Oh|?6Qu8L`d zGxaPGGmSA>FoH40xgT~gCiqdwnBp{;X~Gw(-5^|D7&D$KplOP8JQ{w{!8F4Uf)knM zIHy5mj&n-J0;j=@B~I&@=Gv8M!BYicEphIu&i~~;Vn&!?p@>D^P6V81TWjo-?J|^e-Z-s2)qhYFx~?F z$qD{C*^_!!ed@A#W|s^`jOoJiT#ltm_(HHpTJ9Z}9EqS*vn3E7C6AP)F$8`s6$yix zEJ48se+-F9*}{4tK{l3;R^<#3NPOTOjb|}gWGn+QrJpc>Maq)#QbZPkqhv}b#>YKT zanaFut0GYuC%o_xUhq35(J44cXS#joj>D6593P-Dj4vksFjF5FPUhXCV7lWMBN#uN z;$a%MLv1_^>Pq$GDY6uqFxSL0cvAZosVkTsLUDhbs+gWUQz$PhTE_2ZFugEAswHgG z2y2e;I6RjtnBF{DBp4kbQ{X5I)B0fA2u6z2@2F&g#lmzW%(F}Y&!<_E;ORIb2~XgC zF)3S6AOscRkx40uN~RxQP(M^yo5sj-FvNsM;vhyI6`mTWjQtjXFo6OuzZBCfOt4pi zFlhwSAE!T1V8fk)qzw#`Paixu5igthJLqqBvy4UhwRW?dUHpBu@m^NYo+4B&c7n5( zWx>+%+LWqlB^xh&hbp`a*(=jVq8*D5!NE|6i$7%y8AmftX?hPnbnXQ>e-SizEJa@X zO}OJ79>b;=7eLRAk!We`bX0yzpOZ9v2a}UeqFw$CVdL2s(dKFqHO{>l-t(uR|BwuJ zzfos&=lw=Vq>h5A#lzS!7dT`)?J$y-Oh==m3DjhBF7!Fo9dvm)3U!}!mz}n7C@gbq zj{MXnT-n)dq?rB|j-4=%txjx09dvk%j#<2C^DD=q_TAfnp>5UZ&4@x+IA%TS^=Bbm zS^W|)9jZ`YkByM`SV-G+-RIc7{QXZJYgXH1UHyGc>+Jpw9ojJCIl7n^3X6-kkR8uX zK)%kCQO^09@J+%;Bx$(@DkmCH-9inhbzAqS?;0FO6RUK&HkTHo-Ag>lwJy5osp%aU zIBp$kpJtBQEh-W4bt<(y7Q|!wM|>tCw>tQ6l{(i0SZlYk{)0%A@5B2<-fq=);r92ASMP8$J`yCQr2;$le=h1t)|Xz>qKNNZuB`6>}T>jQ?DJ z!hhaYQR49@KN5SkZ}1*j+5z~?luDl7Y)afW&?efadXtUFU?QlfyO;6B#*%we+{iZ< z1`#`qBwpTYzlv{;Ywoo~w%dDLhhdV>>sv@JMfiL1_*Ynu-TyQA-F^c9uhzf&FV_Ep z$T^TaG>7Z)X)%?$>j5XZ{{YVODWL+|WpRIOAIJ6HevWE1suDiW2%(B@Z{#x64?u-c z8?G9@=4Q1?K_}GbK>5cFTtMVRYCgM|dL6U}bxxO44-$eo-zgc?t}i)gRpdC%oo5pg6nX}nLMkd^aI0?1EBNN&0T?b#?+(1QGJ_9|;;~?fON$s9z&t3O*{LT)@+dQInzjLz_R$}mn!Se$Z0v*l+YCf8H=MYMU?8=@ouYzz52QLf z-{R=m$0*+m_qi+m8&GfCzM|r)LpYxVmOI|;B6a)^fU2#csJHUT+`N%JsfM#kP^@3G zvUcU`(P6(#v}pe#E^~7s6&bsl(%sgQ3P@bT-R(P%S{u2M%hxZ5EBv-oOOCYQ=EE6W zo<|{S6R&k(Tva|=w{$ynyuAROF?a!(OZ};( z@lU`F{|c~mb``vo@d?obZ=raj(NJw^$OSrUQTAC@oTZC1OpLUqB<_n*U&H6<1#Cto z-_6fZV%lL z=|erOW|YrXCoU+{hP$_SAl2dSO)hg-Ic2}|4p)$+M_s-1l&UHS<-8n9xc9nM)HL9U zuJntgg50KZQ%d}(OvIsy1Dll{MvD;IJ{O%!U&uYS-9qK6ueZFMDozOD!t)_cWvn4snOfh9f%L0a4gK3H)R1_<>2Iu{>~HVl?v+RipdsB=fyfI@W4Ut-6sPW85F`D+0Rf*4s_u{HtmJIp6UDax|9h`d#m6e ztr?W;+7j-K;au*{Yzb_*zZ}_e+hFHO73jB1w}5&2S=a{!AlfDm1vO{kjnda@{o^5U zb}@tI+jfNSgVgMnQeAG`>v(qR#Z1=NITK|V&4VsyX2KS2ys5M6H*wP@MU{1Lf)2)<;I^I}f!0Q@Ku@e( zIWzNAs`1mlRP4KCE>C=u%1kQd?ha(B7P&9cCcEpL*sK%9elDio#&6^9&OZcaCJ*2i z%KWI8CZjm}Rq^oF8bcpqRT{jWcoCKlOrj24S;ZXe^oz&F(frtxDI0zFfL%G{)`cYlxBy*zb@o>ow6Q5oe=RwxI5@wl9 zrfj)QoL}d;T-$B-TuI6vYHRN`T;iHi>aodXBq}_>#hDV+=e6^xZH7f$E5loWtL)Al znmUr|o@L4Paw&kOw_a1zT_WLV<9KKkG>&q6mc^}TI)&rs%SE?X2Ok{X^Y=f+F5TpfNw>9BY)eguex&jN=Q=nH^B{AA^IvDlJ1cXd{Ma~?u4OB842>W$K zq0oR#p9Fn#P1Hx zAde0nN%{t)k#_wSNj|oCN<5!?jMy^RhRp5lDv_QyC1VT+5hqU0AqHJBA}zk?darxo zN*Km2Cia}_D*63*tfbYwG;-q@3-LnBt=@OlD~PxY(}|~U^N3l?T9UIJTY(WLo{%Ho zHUt%ky@?Ucy8sWzM}+#*S#odVx8y;2jHGL;=HyuGj^ru(#^k!L%Osz>OacXOW=RaW zt&-i1Mw1tZ7m}kaa!9l7h2+k+Wk5%|92}Zf0pQt2Wd5c*pucMoNfu;-pu`;@y;+Xr zaQ+d}@6#vp<-;Xp{N`!I)LC}$+|dC<-qD6c*gz-Ha7r{7xvM9bo8U<{i0lJ&C$OcFud4iECR(hRh3PJp19W|B{fgGk1AI=Q|{TQaM{h>!;E05_i8k%a!4 zE7>++9~r&nHCeFkEV;3I16i?f75LIO0ZjLt2G}-b1lzm-oDuaW^(PMmYWvaPZdfN` zgX3)S{e?N?{Ja%pw^^?x4f=cpZ_P}Jq+tO>MZ{(DZG)AY_`OqBl_s9Zr z*NZ~&szf!R-unj;HDM(Y>=iAU9N!+~s76XWM;k~MQFdhCj94<}r7yWZc^lcH1^2|Js*gYYc z(Cd{?1X>;u2iq4C<73uIPy!@=yD*bzs(S1d=dxX5U-&>$K6Rkv%C?zej{-w-@y2W3 zRg8&vMne`dAYGa+3kO(nPQ>?kqMTMZ08A;|$xitt&NLRx7*CWl;JOOA%O$Tqjiz-8}D z;M-&dI23-D7=K$0oCk)GAInogd*1}mdi^4zWb9V5f7=D*^N#1qm5vuBHiQ|RQ20{v z_U29^WWhCZ*%m)w)W{GV%=?|ZsW=0Yj245fsVl&fqQ}In_{+fO3QhK|$O0c4@|rMZ4xyu<9mEM~E|2$O?`Y+LzB4Sr)z#}k%Tec8{X{GH)UrQ2;E)~2&0kmg@Ioot z^zkb6dwv{NRJ;<_|9)-aX8p|i+50Ef&&>*Q$gtdvlI@K*icjXbNxXH-#jtU>SUsa9 zp_P~6JtMA{ctn}ZyZm!+@$e1X#H(9)NTS}X_fE1aVp}Fk40fw;eg~e?3h(vHH zy$^a7?*pqlxx?#M4A553uBgk>8*J0Eu5esx2yz{HmDO%bAX|MCc2A}aTJ_EzO>EE+ zz1kfDrbpVue@28knZ5}pL+22h>O6|_$AxXsv1y!e$cQriF#SiW1mT`vje1I?7daz zK+i!HrSA7VS>x9QpkrkhHsqZivYhploc7fQE+}5X%997cw0rt&6EiUafdy>a;JW}0 zXh_a79FA^UHN=z1E7oGDH9T_60w(XXf`$HO@R`LPBK2MZVbLKMbXs_u-IEs#yLq){ z_5G}&=N!xZAzed(ctH$VT?)~_51r6^2Ef*BRs&P74(#H7MnE()jqG#k1AAkh81*dC zMO!DEBBJelkbg4;B)o}+N5@V9*9jZ6Ce9dKp4Ay1ekVm9sm?I3usK>{YYNu<5e)Ac zK4MQMHie3{V)#0IAoBWC2Za@G0aQr{y5SZFXYI3v+JVo(QfV4^Z$QBI&La3MZ6=t# z@Hy+w90p&-r$GAbi{!?3*CZGF_lIR0U#PubzGm$go56+0hQas+kJ-)_e*?vCdMNB& z0W0qRhCHi7!5a$?vIX)E@SR&P)Df;#-!80TtxFA2W!Kf9L^6~;7-6BVl(`|PelWaU zrH97k^#w((P6OLfRqPi+AB}i<1PqQytd9|M>@SCEA|bA1$Q2RlEY!>d43?9cBUu$qTNW8zN?VkdBq=9 zK5T%N4>o~i9uvTk7VY2_Hvqg2omap2Xn~^IJOFN1z0u@Z3UEEc0~I^?q1TTdgFxL& zg%?SOgl|hM}2SL*dl$uj)+41>n-AMZn!S3OV;|g$gP^v$o9~!4|zXsEKQL6yh)& zjJwqyg&b`J)6Tde?aO<}iL&#o&$2A_yj{&vZ&O#;|GqB!*0v?9%FzeUS9zA|ed!3R zGBUt0_H-%vRtn9$4N<@I!(s7i5qy-T3sdv;&@Qh~NpxC2M1E-kjg4GTVB;w8^r9ZT z@yBCU5~!oT5F8Gdr@msvD?hV?d+8yaqK5FzqhPqf?g)FYa6D+Zz#n|w-XB)q*aIv& zc*6-(Lg4*L&%t>60VpHHOI_OkGH`lr3DMzJu*tGZ>OhCX;L~w?!pV(0ox^eTAfY=&Atxx#)v-wx(gZ3G?8_&|@BI#62{fLt1nU>E8*plY4Q zu;NuH+&aPx_6biU)$MK2z-&MCDy|QT?EIeGKG88UaOu({D z-blZP0;J9igY#z@KoalQ0o-3p_PkXex~!*;^9@2;gD(Tr&JT_V%vBSsEx=1%66Tfn z6XGWj@-*Nl;wM-9z2ukI|B3BS`>an7D>LWeC9>3_3?^4#TA{1DQiDMF6?LcMhDFQ ziTU3Ze~bCo-+!7f-9#-LeE@$;ae&u`Yz5<+6qA>p3?c7`O{w*RR#L0mni2&wv%!8H zQ(|xOeWKkGLvo#@5qRep0O=9kDDGqsm3E^LEVyV3a*#otVR>HF1FiJwqseG98z{KWS&*I)g9qU<c=Kos!R@GgG z`qKMAg9ht?L;NzLXhL5&BR3FHZ!OUFJtbg`#h+l)v~w&??*^Xh-NCWFEuhUhYv_?a z9T@KuLH$vCfyieZ`*_j`Fmu2$p+C2XnQZgl+mErxy>={$UFq_#W6|Y1B=EQKZ`gtI zh3sHgJ-Ehd5IoAYgk{TJ(DEq6Hci{gJ{z+RRDQa~7NrNl(F$ACthf=3UMfZ*FC9VV zZ_VNQDZ^0|jyJQDoe;mCI}%WA{}cG}cI3auf3@ccme*PAI4t!i>K`v5f35ySqREqZ z{U5rhuKa|4ynOss{r~3sb-nhFmwdmbe_KisJGlh1H{cOO^(aB5*@MvP+2x#lX$<+$ zJ%vf8Haw>Qn`5bD|Tg-(domnYs+L-X90nIy~pxigTcL zyGgL{pdA;uGX(m@i{a;0c2w46U2+*yN;0Zb$kHqyc$eG4H)jIN%1z6NUNcJ7gyDO1 zbogEMhfGV>)xI6v(9I1*r0)RHrSH&?S4MD2NO!_uGv{b#aFz)Oc@ae*>R%p`h~cMi@HTfhG(*{^dF1RIi)JLr!1ctsS+ zO`0q|o2N?>Eq!6H!wHfbSu-UQd)y@+U05S_y(aTsxaB=Lud};&Mpe0Y?;DrFLG9MW zf=42v2t|-LC-0Jw+p;A`yrRhY_8y*3Hq7z5v~~w!8q{5Eb1zil>$e%|KFIPYZnMbC z8SjyfO?m8{T^5?|(J4mX<|9t;XTg~%9ru>Uq zSKWWO)#a{vR)#;*_2W8)&l}a`c6O=DTk{-CZC>0oYV&rj%Ukn23V&Go$CY1?x1Qpf z2S92wH^cvAsRg=gJkeIm-+8u=Xa{B zSzMR7=BlA~>uhFQhq+!g*No6>qdx6em$zQ)tQj8HX3lr2%UrK@{=LZ~o$E5!gIzO7 zt*!Z7J?46?vu1=@n>oa#uI75Jvt~$F+u7&pGS^dFGn}f;o8PsL;(B^(hDEh`N8&#p z)#?@XIz`QprZ(@ny1e!JM9pfd%^TuYM{m6@QFD5&&3mpcZ@nH-b84y0>*G;JZ@mst za}udt!IM3I@P3<$f~5NT{2$!-%UTCF`fnHiS4IB%&u=t;>=s{Vu?YUceEbPwkvHH| U5r0M=)AWQ=tuPl~UQ_n}0IUV#nE(I) literal 0 HcmV?d00001 diff --git a/lightning_logs/version_3/hparams.yaml b/lightning_logs/version_3/hparams.yaml new file mode 100644 index 0000000..0967ef4 --- /dev/null +++ b/lightning_logs/version_3/hparams.yaml @@ -0,0 +1 @@ +{} diff --git a/pina/model/layers/__init__.py b/pina/model/layers/__init__.py index 7038c6d..71d781c 100644 --- a/pina/model/layers/__init__.py +++ b/pina/model/layers/__init__.py @@ -1,7 +1,11 @@ __all__ = [ 'ContinuousConvBlock', - 'ResidualBlock' + 'ResidualBlock', + 'SpectralConvBlock1D', + 'SpectralConvBlock2D', + 'SpectralConvBlock3D' ] from .convolution_2d import ContinuousConvBlock from .residual import ResidualBlock +from .spectral import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D diff --git a/pina/model/layers/spectral.py b/pina/model/layers/spectral.py index e685164..465ff18 100644 --- a/pina/model/layers/spectral.py +++ b/pina/model/layers/spectral.py @@ -1,16 +1,326 @@ import torch import torch.nn as nn from ...utils import check_consistency +import warnings - -class SpectralConvBlock(nn.Module): +######## 1D Spectral Convolution ########### +class SpectralConvBlock1D(nn.Module): """ - Implementation of spectral convolution block. + Implementation of Spectral Convolution Block for one + dimensional tensor. """ - def __init__(self): + def __init__(self, input_numb_fields, output_numb_fields, n_modes): + """ + TODO + + :param input_numb_fields: _description_ + :type input_numb_fields: _type_ + :param output_numb_fields: _description_ + :type output_numb_fields: _type_ + :param n_modes: _description_ + :type n_modes: _type_ + """ super().__init__() + # check type consistency + check_consistency(input_numb_fields, int) + check_consistency(output_numb_fields, int) + + # assign variables + self._modes = n_modes + self._input_channels = input_numb_fields + self._output_channels = output_numb_fields + + # scaling factor + scale = (1. / (self._input_channels * self._output_channels)) + self._weights = nn.Parameter(scale * torch.rand(self._input_channels, + self._output_channels, + self._modes, + dtype=torch.cfloat)) + + def _compute_mult1d(self, input, weights): + """ + Compute the matrix multiplication of the input + with the linear kernel weights. + + :param input: The input tensor, expect of size + [batch, input_numb_fields, x]. + :type input: torch.Tensor + :param weights: The kernel weights, expect of + size [input_numb_fields, output_numb_fields, x]. + :type weights: torch.Tensor + :return: The matrix multiplication of the input + with the linear kernel weights. + :rtype: torch.Tensor + """ + return torch.einsum("bix,iox->box", input, weights) def forward(self, x): - pass \ No newline at end of file + """ + Forward computation for Spectral Convolution. + + :param x: The input tensor, expect of size + [batch, input_numb_fields, x]. + :type x: torch.Tensor + :return: The output tensor obtained from the + spectral convolution of size [batch, output_numb_fields, x]. + :rtype: torch.Tensor + """ + batch_size = x.shape[0] + + # if x.shape[-1] // 2 + 1 < self._modes: + # raise RuntimeError('Number of modes is too high, decrease number of modes.') + + # Compute Fourier transform of the input + x_ft = torch.fft.rfft(x) + + # Multiply relevant Fourier modes + out_ft = torch.zeros(batch_size, + self._output_channels, + x.size(-1) // 2 + 1, + device=x.device, + dtype=torch.cfloat) + out_ft[:, :, :self._modes] = self._compute_mult1d(x_ft[:, :, :self._modes], self._weights) + + # Return to physical space + return torch.fft.irfft(out_ft, n=x.size(-1)) + + +######## 2D Spectral Convolution ########### +class SpectralConvBlock2D(nn.Module): + """ + Implementation of spectral convolution block for two + dimensional tensor. + """ + + def __init__(self, input_numb_fields, output_numb_fields, n_modes): + super().__init__() + + # check type consistency + check_consistency(input_numb_fields, int) + check_consistency(output_numb_fields, int) + if not isinstance(n_modes, (tuple, list)): + raise ValueError('expected n_modes to be a list or tuple of len two, ' + 'with each entry corresponding to the number of modes ' + 'for each dimension ') + if len(n_modes) != 2: + raise ValueError('expected n_modes to be a list or tuple of len two, ' + 'with each entry corresponding to the number of modes ' + 'for each dimension ') + check_consistency(n_modes, int) + + + # assign variables + self._modes = n_modes + self._input_channels = input_numb_fields + self._output_channels = output_numb_fields + + # scaling factor + scale = (1. / (self._input_channels * self._output_channels)) + self._weights1 = nn.Parameter(scale * torch.rand(self._input_channels, + self._output_channels, + self._modes[0], + self._modes[1], + dtype=torch.cfloat)) + self._weights2 = nn.Parameter(scale * torch.rand(self._input_channels, + self._output_channels, + self._modes[0], + self._modes[1], + dtype=torch.cfloat)) + + def _compute_mult2d(self, input, weights): + """ + Compute the matrix multiplication of the input + with the linear kernel weights. + + :param input: The input tensor, expect of size + [batch, input_numb_fields, x, y]. + :type input: torch.Tensor + :param weights: The kernel weights, expect of + size [input_numb_fields, output_numb_fields, x, y]. + :type weights: torch.Tensor + :return: The matrix multiplication of the input + with the linear kernel weights. + :rtype: torch.Tensor + """ + return torch.einsum("bixy,ioxy->boxy", input, weights) + + def forward(self, x): + """ + Forward computation for Spectral Convolution. + + :param x: The input tensor, expect of size + [batch, input_numb_fields, x]. + :type x: torch.Tensor + :return: The output tensor obtained from the + spectral convolution of size [batch, output_numb_fields, x]. + :rtype: torch.Tensor + """ + + batch_size = x.shape[0] + + # Compute Fourier transform of the input + x_ft = torch.fft.rfft2(x) + + # Multiply relevant Fourier modes + out_ft = torch.zeros(batch_size, + self._output_channels, + x.size(-2), + x.size(-1)//2 + 1, + device=x.device, + dtype=torch.cfloat) + out_ft[:, :, :self._modes[0], :self._modes[1]] = self._compute_mult2d(x_ft[:, :, :self._modes[0], :self._modes[1]], + self._weights1) + out_ft[:, :, -self._modes[0]:, :self._modes[1]:] = self._compute_mult2d(x_ft[:, :, -self._modes[0]:, :self._modes[1]], + self._weights2) + + # Return to physical space + return torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1))) + + +######## 2D Spectral Convolution ########### +class SpectralConvBlock3D(nn.Module): + """ + Implementation of spectral convolution block for two + dimensional tensor. + """ + + def __init__(self, input_numb_fields, output_numb_fields, n_modes): + """ + TODO + + :param input_numb_fields: _description_ + :type input_numb_fields: _type_ + :param output_numb_fields: _description_ + :type output_numb_fields: _type_ + :param n_modes: _description_ + :type n_modes: _type_ + :raises ValueError: _description_ + :raises ValueError: _description_ + """ + super().__init__() + + # check type consistency + check_consistency(input_numb_fields, int) + check_consistency(output_numb_fields, int) + if not isinstance(n_modes, (tuple, list)): + raise ValueError('expected n_modes to be a list or tuple of len three, ' + 'with each entry corresponding to the number of modes ' + 'for each dimension ') + if len(n_modes) != 3: + raise ValueError('expected n_modes to be a list or tuple of len three, ' + 'with each entry corresponding to the number of modes ' + 'for each dimension ') + check_consistency(n_modes, int) + + # assign variables + self._modes = n_modes + self._input_channels = input_numb_fields + self._output_channels = output_numb_fields + + # scaling factor + scale = (1. / (self._input_channels * self._output_channels)) + self._weights1 = nn.Parameter(scale * torch.rand(self._input_channels, + self._output_channels, + self._modes[0], + self._modes[1], + self._modes[2], + dtype=torch.cfloat)) + self._weights2 = nn.Parameter(scale * torch.rand(self._input_channels, + self._output_channels, + self._modes[0], + self._modes[1], + self._modes[2], + dtype=torch.cfloat)) + self._weights3 = nn.Parameter(scale * torch.rand(self._input_channels, + self._output_channels, + self._modes[0], + self._modes[1], + self._modes[2], + dtype=torch.cfloat)) + self._weights4 = nn.Parameter(scale * torch.rand(self._input_channels, + self._output_channels, + self._modes[0], + self._modes[1], + self._modes[2], + dtype=torch.cfloat)) + + def _compute_mult3d(self, input, weights): + """ + Compute the matrix multiplication of the input + with the linear kernel weights. + + :param input: The input tensor, expect of size + [batch, input_numb_fields, x, y]. + :type input: torch.Tensor + :param weights: The kernel weights, expect of + size [input_numb_fields, output_numb_fields, x, y]. + :type weights: torch.Tensor + :return: The matrix multiplication of the input + with the linear kernel weights. + :rtype: torch.Tensor + """ + return torch.einsum("bixyz,ioxyz->boxyz", input, weights) + + def forward(self, x): + """ + Forward computation for Spectral Convolution. + + :param x: The input tensor, expect of size + [batch, input_numb_fields, x]. + :type x: torch.Tensor + :return: The output tensor obtained from the + spectral convolution of size [batch, output_numb_fields, x]. + :rtype: torch.Tensor + """ + + batch_size = x.shape[0] + + # Compute Fourier transform of the input + x_ft = torch.fft.rfftn(x, dim=[-3, -2, -1]) + + # Multiply relevant Fourier modes + out_ft = torch.zeros(batch_size, + self._output_channels, + x.size(-3), + x.size(-2), + x.size(-1)//2 + 1, + device=x.device, + dtype=torch.cfloat) + + slice0 = (slice(None), + slice(None), + slice(self._modes[0]), + slice(self._modes[1]), + slice(self._modes[2]), + ) + out_ft[slice0] = self._compute_mult3d(x_ft[slice0], self._weights1) + + slice1 = (slice(None), + slice(None), + slice(self._modes[0]), + slice(-self._modes[1], None), + slice(self._modes[2]), + ) + out_ft[slice1] = self._compute_mult3d(x_ft[slice1], self._weights2) + + slice2 = (slice(None), + slice(None), + slice(-self._modes[0], None), + slice(self._modes[1]), + slice(self._modes[2]), + ) + out_ft[slice2] = self._compute_mult3d(x_ft[slice2], self._weights3) + + slice3 = (slice(None), + slice(None), + slice(-self._modes[0], None), + slice(-self._modes[1], None), + slice(self._modes[2]), + ) + out_ft[slice3] = self._compute_mult3d(x_ft[slice3], self._weights4) + + # Return to physical space + return torch.fft.irfftn(out_ft, s=(x.size(-3), x.size(-2), x.size(-1))) + diff --git a/tests/test_layers/test_spectral_conv.py b/tests/test_layers/test_spectral_conv.py new file mode 100644 index 0000000..db9f399 --- /dev/null +++ b/tests/test_layers/test_spectral_conv.py @@ -0,0 +1,43 @@ +from pina.model.layers import SpectralConvBlock1D, SpectralConvBlock2D, SpectralConvBlock3D +import torch + +input_numb_fields = 3 +output_numb_fields = 4 +batch = 5 + +def test_constructor_1d(): + SpectralConvBlock1D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=5) + +def test_forward_1d(): + sconv = SpectralConvBlock1D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=4) + x = torch.rand(batch, input_numb_fields, 10) + sconv(x) + + +def test_constructor_2d(): + SpectralConvBlock2D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=[5, 4]) + +def test_forward_2d(): + sconv = SpectralConvBlock2D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=[5, 4]) + x = torch.rand(batch, input_numb_fields, 10, 10) + sconv(x) + +def test_constructor_3d(): + SpectralConvBlock3D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=[5, 4, 4]) + +def test_forward_3d(): + sconv = SpectralConvBlock3D(input_numb_fields=input_numb_fields, + output_numb_fields=output_numb_fields, + n_modes=[5, 4, 4]) + x = torch.rand(batch, input_numb_fields, 10, 10, 10) + sconv(x)