From fd18e1cf3475e4c461e3965b161342e70d6d126d Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Mon, 22 Jan 2018 06:59:41 -0800 Subject: [PATCH] New files --- layers/.attention.py.swp | Bin 0 -> 12288 bytes layers/.tacotron.py.swo | Bin 0 -> 28672 bytes layers/.tacotron.py.swp | Bin 0 -> 20480 bytes layers/__init__.py | 0 layers/attention.py | 86 ++++++++++++ layers/tacotron.py | 283 +++++++++++++++++++++++++++++++++++++++ models/.tacotron.py.swo | Bin 0 -> 12288 bytes models/.tacotron.py.swp | Bin 0 -> 12288 bytes models/__init__.py | 0 models/tacotron.py | 50 +++++++ utils/text/__init__.py | 78 +++++++++++ utils/text/cleaners.py | 91 +++++++++++++ utils/text/cmudict.py | 65 +++++++++ utils/text/numbers.py | 71 ++++++++++ utils/text/symbols.py | 24 ++++ 15 files changed, 748 insertions(+) create mode 100644 layers/.attention.py.swp create mode 100644 layers/.tacotron.py.swo create mode 100644 layers/.tacotron.py.swp create mode 100644 layers/__init__.py create mode 100644 layers/attention.py create mode 100644 layers/tacotron.py create mode 100644 models/.tacotron.py.swo create mode 100644 models/.tacotron.py.swp create mode 100644 models/__init__.py create mode 100644 models/tacotron.py create mode 100644 utils/text/__init__.py create mode 100644 utils/text/cleaners.py create mode 100644 utils/text/cmudict.py create mode 100644 utils/text/numbers.py create mode 100644 utils/text/symbols.py diff --git a/layers/.attention.py.swp b/layers/.attention.py.swp new file mode 100644 index 0000000000000000000000000000000000000000..e724ca72aae107f590b851b4c864a438e3b90182 GIT binary patch literal 12288 zcmeI2O>Epm6vwAUe6%#B0uB@j7?O(@HcnCyEs^%nA_Wj6LRAS8Rcp<9XE&zy%rYLQ zNhN|1H^hNY0VfV{2BdO86%sc%BDf*WJ(Lp?Lcopx%-C!1W>b|TFqVGWwf*MJoBw;yMTe!PB&f&sP$(MMn z^Mx}kmRd=%kg{#A@&%?f7dnw59B$u;LvQ01@Cw|R0)_6+pO_(!99vp61LkHQq<7u< z{EdzER$c+GfLFjP;1%!+cm=!yUIDMbe?kF0+(%x;O7@qlSua2LO?;N8{=+Na74Qmp z1-t@Y0k42pz$@St@CtYZyaHZD@E&*< zNN^fZaP=@D-+>Ro2Velsf>m%YxH?0~PvC3tCg9*1a1`7F{yId+=ioE&4#+?XE`SqY z1{?zW!Oyo7@)h_3^ugoc0dN@1fJ5NgK|(Hr*T8dN0sMIYF@kTw8{ibU5BzZ(A>V)< z@CrBy4uS*V*INns5?lr!gV(`v@EAA>uH8b&r{EK?1J=MQcnBN;KSIkNz!mT|cnd_p zXxsz-@CtYZ{uc_+@=#nCO3;|6X(S6hEOeWW-gd&*G#Q8t?>?cimdEHW?bMcG{&K4U{q zo6a}nP5RC}$*GW<~frUR!@qiQE~ znxjuh0YmDBT16=nJvvl;Gm%B!NNia)^_-2Q47hfA>noW>DHj91kq1tzs!XfPbyiK5 zYz0@hXd_=X44ZRYbO|&QR>x#UOl4SKq_TkbT}D$~nE{iQ9Wx^$N3m63$0D^CTg!H` z9`Dn>R2P}*1ty<2rWiFU7|#h59jJU*-{HrQ5Q zYlC9QRWRxlFDN#|1-Cm}VH71I(NPrC88pGzV=~mvd_R>az}ZCfXIl-SMoyekRR|5S z>OI&damw<%Ze3+Sfe_)U>=h}8V<%(Ts&`DnNEwsO-r1zLb?)qPdJgMdvUiXwii>TU zqUtFZMaC6&Sy`=BDV8DI0<20awB1Xf^|(^2dr$LxsnrU*+i=(lFX7$UO0L=2**dVX zT3V(|DYi}1nGJ9FS{C?KgqT0d3YN)#}ds>gv(r0*{5zt5>u)F zXscH88bwQq>t0?xiJFbKey&vkHy>_QH#D_v8vdr;FQH~)BdB+v>p~Cj+LNui?siYp zrfB~o^(LK_^I6Z4eQ()y#%UfWADw3HE~Jk89D4|lrZi@AsHmJPv}ZgA=Ghk7jj6V- zlB^fATqj~MvPXNeUPJr<2wJq$p~qZ1F^-zWV$SRJxJtWjU6=vpuqlxhVoS;*2SW zIy3feBXhgmTHjM0+kSS6ZS#&VtB^8@#{<_tebLDo?HLqZ$f05z5oCK literal 0 HcmV?d00001 diff --git a/layers/.tacotron.py.swo b/layers/.tacotron.py.swo new file mode 100644 index 0000000000000000000000000000000000000000..c637f4479218ebeffd233f9b79def064bac02fbd GIT binary patch literal 28672 zcmeI4dyr&TS-|g-Y?KIS<$+iNb#^9Yx|5#XncW1E4IQ()F*|X?8a9&+Fsxm7y6^N{ zr{{Ka@9o{$P?#9aA6i&Q5Uap9AxH!bhCie#qQnQVBE;}06!i~Z1tP_xtQ3h*=%yYD^sJic?zcfRkOuXkns)U7X9FKu7r@%kap`^&EduYUXNi@mr0 zwC4r!XfPVejh$Z`!`{NqDa_rUH5u+7$D^MKdTFwM$?uKQcogj)`sac;*`M0)?eV$Y z>q*E@J%VP2k?o!O@@z2weaK3pfOB~VMCmOw3mS^~8MY6;X5 zs3lNK;Cn#=>3FYq4-L7{Hs<;EeNW~6BK!T0ZPHy}uLt%$wC^sTXW8#RXx~3vxxd$b zztO&bx^n*;|6^XAZSLF4^;<20S^~8MY6;X5s3lNKpq4-_fm#B!1ZoM?5~w9GMFM`` z^Cr*rypMsL|2zNu@A>d|@D8{Su7j7tkHA-+=XsC818^FeZ~=VW{gq+R$Wlq_)lsavDvVU@M}wd-45CI7 z3|HH($T*6UX7hCE8Fz2j)o~AyS$8e$_k$>*Ld(-@91|7G8n1V*xv{eHK*p?$JCG0 zFdC?IEg<*TPJ*b%DLmbjy_UMJ`WgT_MB6aW9e7=jPV0X_b|JHP)k{QO^o zx4r7ScExf!A0;%_62?$ehcn}SHtx{II?#MeC9cx z_a1mR{3^T|9)L}FCCtM!Kj3*EfIop>gg#sY2jKv`0KR&$=beXlz?sDOtf4^-}OHoECmvGfRt0(&ODjX_3GOLcoqsdq_$;M=uhT~!2Gz_^# zK`%%WKR%Z|lNF)DL`9=ioft*Ip9nVA{S6=vSE)k_NmxwT6POv7BQ z52>FpxpJey!@>_b+ROXPn z+1Mj)T&l1gv_ z+Qr?`>0~Ztmn7HXNN%Z^nbI4wk~^K%Kx$?X${fmyii30#M@HHAR@Md$qYPR`DY&_p z%irn>yF8fDW}nIHRi5KhbDTb#r3Q%P=#j-^4c)27I}1!vJ1Vhjkh8u%Mb4%!c{hcT z$aKf<(PFj>%kIT&GS1{{pWj)&)`HFDuI$U+XsH#k*1D@yc(%h)vO@jnH7kr{_=xncI#AK4l5DdwxV}}}H@7yqw@wXkr~eOa)-@sLwQ#U@ z)<0KlEyLGcrX6N(*kr7_@}Rsn)zuv%-Ne%u_=7zdBw>Hz4@HAwZ;D1Kce=Dji<9xn zp1_vE>9S2@J&D1)&n7nFEG&MK&1O4g^CB6K65L=EVBTM|sL|fKN=MqE9>w7xjHpwx zhJ8?|Y_+>8oh4<^j57=Us&vfNOgQFiw=hjBHPsE?V5uIDMnj$5@}ZWJe@poei*-9} z>kAHtX*>>0wn<|PT=Apz#Q6XJgn#`C{BZIAUHt!7@bUi@WFO%7;bDm2Fx&)J!c+M6 zABTtFUN{9m49}3)C*VnV6y69YLCSa;G+-}0LAj5>J#Ydphx7RHZ-J9A2Q9b+zJ?F~ z1iS|x0P*jC1A7M#bI{|Nj6tib@T zffvADxDY;tPyc>+0LCzaSHfLzJzNGKbpC!}Et2tuebHNMV~#YIv!R!d6>WJv`Ha6z%QnadajsHP1&Cly}4_!V>hW>_%Eq(WxJ_3}KnPIS4q)g0g5 zQZ3KA+PZi2Pjv<3O4tu$qXhk7WmR)a3e`3CT%#=5jvd|3vPVny|%vd&;T>E~K9 zDD5Y|*A&%ycR0>h@={sui>!e-&K-oGm($^K&!v#}09(M$Q*(FNs8jjf=+dAnJ8iVekm${LF6Y-L#%p4K?)(jDfU z08t+ok7G+`79}p9U4HAb#YIWEJUW>y%0*46p-3znR|Yl`K*`H+{Q&1SLDsRNG3Zd?=BgdOd9Yc#ms0+HgJI#SdjYQV`wU!haX zIggs)66NOCoU%%Tt-kc^C|*zKs(7nKR(xV%uCEp9XoyXXO4V}HZKWR4NVTC|_e07r zlp6(;7=y|dLk1gSh4s~F)wx~ThuP8RT4Y2oQ9+}0Wm+t>=V|d$J}hMA(S4DGgN;#W zc1_w#&f&@vF1my=#mqt0VmZ}z%6`j^iY>~Fi#u$&c}cq%I7@9&mFBc=H5?ARjpEj> z&OjQ5aIDH@bcpWEbzr5q#X^Sm$S)s_a%d;b-(7Zk?AlTCZ4KW_1zk70Rcgnc)=gfC zidM7?Zdt1Jtz0)0lVChl?v~v*$wFRPHby!ytnCl2Xkqch*W{`!_;1Fo|W%pJ@-a8}(i8V*NiWn-M(sALqc^6g65Z(a4- zqeoFXS28Qs6<>X~+SQ~F)cT>(aL6M_EJr*Nu*ddtD@L5x>RGy z=<4#v3@5YaxXmaM5mxp9O_sJ9#Z5`4uuO*K-6}a-QtfQFRE4Z&8~{!|d-AqPyk+0E zwPm}NQ+O#zyLhgQC@NfR`;%0TP&ivKHE@9vp8c*L^=4Ub#qK?%^nVDZ&;K9Chy66Z zt@!`$_xFE`fB#2tKl}oW;4U}?i!cMvgQxNLAAz?)4_4q7_$vPW!>|l;-v1)_cl`Tb zhqpll*TGZx_V0%`!6F=kD_|DBh)@5g@B#QuI1A$M_h1FY@4pQG8{hss+z;#UDv)#i z7s8Li=ke=52j}4pun5QC3Ydi%_&V*^GC&OVtCm15fm#B!1ZoNVzm$NOLH2M`ybU50 zv_q+0))6P8+4eb+Z+c<@g$Yp-{qAU$7#WExw8@P#GT%Nh?Yz_Eh`Hwv!$Bl2OkoSv z-o!#PyV1l;_SS8#l$;;B%(vJGYU6gL;b1bF6wku)d@I@HH@4*DNt@-BAK&ObRi#e_ z^DQoWP-4T5#L6ak)mB=EX=SmJ`7IYC`h6!Q+$D7Bat2Ll|b*k1o0r~&*b*J%Q+aDN*U)tHuYR6 zgbwv6SI|8)y*f6d1s&0_&4^*3xJ=}oM~sLI4vAbRvZAB-&!XkCd*!_d?Pe%O;xy`i$f8N|C`V_P}7i{;9iqAA5m zQ#)SeG*X+=?LzWA+BjJb`wS7Ss5a;>UIgSNY{uu9IWZXR-3#M(&L(!A8!#GW`lDI z(c3MxC4jxD10&1rbyd>+<{D_X#LEFgSNJ`CNsa zT?>hRl~9uWTD7QByna?OiFKOK+ngygeUOtE>$DlltfIzDd=t5t57iYjD~L+lq^0m% z5sJKK)e`5MX7y8zY+Y(8qD=A-Ds%pS5$Dbk=hNc<-(d}&FW}q%2Yenr1oy&en1gTN z-~R=ig`?oXKj7bg7Cs0M!RtZ%{+r;(;YPRup26pT5*~#6UWxB)JP$MN}p30?&= zun#`NUdVglES!O5cs^VX|H9e+hv0qiDBK6Da1*>3UIZ7xUicjH{BsbwE{bgHS1p0> z1qm4SVopx-6?u7^%dCke1j@#t*kp3-Qh2zUipZuqTBVLw5xcVRGs8MsrH)pqqg9Ls z)jqCKo83BErH)q7ze2L-$GbipuA^0^o&ndEa=%m2Dt1R~ khfoApbTNLB00kXlng#K+?5~xwZaZqM?j7T*Ou&fuzbgW+1^@s6 literal 0 HcmV?d00001 diff --git a/layers/.tacotron.py.swp b/layers/.tacotron.py.swp new file mode 100644 index 0000000000000000000000000000000000000000..b59bc2151b24c05fd411613bbe5f56decc73edaa GIT binary patch literal 20480 zcmeHPeT-aH6@M*MEQ-jN5+i!sP3a7secjpexiHY(BDAzLP1&}{xDIb;?#@2;&3of} zZ`y5*t)G!Z2|{8t!Ndd=1Vkf<5&eTd41}VAPzd3}4;rKY7=&o5LHzoi`|<93Z)RtA zCH`aPh2PHZ+;h)8=iGbGJ?Gq4rbqT}yPjQQUufZZuVp>@D8DKA@Q19`QTj-*?L$q+}I9-TX;Q=HcmSAAPxh6qvsstVYD$;f7|VYOW`-58CYOofhAg? z6VD7^dYZLqv^+vRhSsiOXPvociJb>z0}Bi+FtEVD0s{*SEHJRZzybpc3@q^fV1c-O zl63@%tW?U(sLvI-&y&>ObGc7Vce48aD)sq-`qb^Zzy7>RHRyIlKKK||U|@lP1qK!v zSYTj*fdvK@7+7Fnfq?}E78qDyV1fUE1)PRueF1zgqXGc#|8@TV#_5*zJK!L(;Gr;$NJApfZuK@dj?Z6OlVzp&G58MNM4&cBzuo+kb zy!B3u349m07q|zw9k>D51grp_JH@ho0XzkK8E6960VQAs@UwSV);zEaSOuI6JiE%W zZUa6BoCW*}#gGSp7>Ix(uo5_m;>VM~gTO4%1TF;51y%tk1J7d~egfPHw1EJ)5!eUN z9GwZA0X&7m$x+}*T|i;_8DfQr5_A1p!);x#y~HMKu3xk)-8{)EY~GICL%c9tPJK`9 z^Jq(+MfhmCMaUK#Zr+YV$B)`U#0%vT8!cLzu{)EJlT|7#;@*rc^%_g4(!*+mvLv2$jVS~>wFg)Od zjRFZqt<+I{);#VvusD@tIx9Oy zlJoUJqvP=%ZWI?1gNU))A+IS?Nx8%Ixf2%Dn>~;@st%W@jEd8)6IHHpyol$P9NB{m zLJJbcOED-h2ios<$O2JXP|<}nA@!(Vi3Ocl{pQq*q~9U8ZA2Cw@;};6qv86^3XE^Y zsi3=+rmRlT#O%b%*&-HBuOa7n=yOkuU)fY*Q5?Dr4!#K^VKTY!YAFR5wy_b2CZa0( zhurH-u?tuMD=Tg1A{%DqBHO^KwA7}OLA6bZ!FKeMOJNFxDG{n;m?g7z#6T@)zSa%` z54%Jxuj`%pUE?s&=1S4r}w98-2C>=PW!j_4x z*FpxkjiOzHYjNXRO>rQv>un6+!yCSwh=iL?)4 zO(%vggiq;&A@^g}!MMiJ4Kw95n9t(_K?nnE0CD6tI*vzf)0fUD-J?9*5Dqo)I;X}ekx@XOzd&1|t=GqJlQOq9}Nk@+oTv|6dM#_x>9ah3OwSpG*VW(9Y zu`j!<#CSWZl=<)`qv#O3Ch)xA0M02Eb3Y2g8T_?AVvx`2crhGh=(Ko@cV{gzW}hGG zSdt%bBkK55pA>JH6S@x@(Uftf=Je}%;z|9@=_xeVCb0BKkyu-al_dEBYfSeMpT_Vy z)S|EZcu_qxODUnxgTyQ)wL)_O^5OYGC6@fU zFLc_78IYONKyeV(XPFwq_fucmPA3kUq0`WPwN%r_g@6A0nP(H2)u^e{yE?n@MGW(AOb>Q2T%ak1FL}($nT#4egr%W z+y>kPoDZx6P67UaeE;{rW5DNuj{-LW47eNkEbtj%8nA(X zBlmw4I1KCq_5u?C171SD|03{H;0~Y)i~|2i<^Ou+(`PQrql7^^1CF`pVJ-BLP!&z8 zKz+C`+u)xlJdy#3&ks;70UM!ctJuLfj)vieT(Hr=`9%9tFS~5|wS4x;jv{f@0N~%v zeu!;DAPbi-6hQigFgKOcFQ%NZjC3XlBc(U&p)CdJuHzwChj|u|qZzOnI*)UsPu8?- zy1F3CFQ~qdaQZ>u%i%brBTou6xxz;xqNTg(PuUNWDF*vpIeBS*$P+?o*s^Lh#y%q50Totgy%%JAk*Ttl&9-5RRWw~<9ZKKebuug|3X@w>`4r}a zT#`Bz!OAwZ;xawucVq18EDdvQZj<%B;i!#d6#qI6yG@8f^UDfO%=muC8>Y%$(FCKU zVa{NA1tVJ&O($!LuEY?g8gq_^%_tRF>3_C{AQ~rgse4UE%xO4nWZTBbuvb=PVGS~2 z)7{T~R^WCM%0UHqub5W0)PliQpn`o}->d;`8Q7dEkZ^juZN}2G-yJ39%s{mT^7bNh z&aZ17PbX6&7b&VkN|73kNse@>>}NV2AyGbFG{-M06C(y|+vtGL(9oC@{RF>0K7$^OH6U!;$EUeN|yYz z^qmrsT`QlyToWhuRsH+hSNVG5s+f9IL-$ojb^|12H7JPPfWo2|*;wN?DiyW1Q=h{@ z9Pv+d;L5mJVj3uEFIMHhWYPMP;U( zGW9uVgQ)p!Xd@&6XLS>k)`L<2^)-CFT*K;boMhCIo_V%G$BUNu0zVLDbr*fJ7D$R_ z*cJL3C_OkP!`v{kHEIjV{hvBUiOs^-e?xGW^*Oq_i)T3tLbvJSB8)||ICE3NQht)2 zuH0nN7}?Su=}Fd+-zLV9f2OZza*~SGV(vzuZ*{E1=)XymNvRlKY2OtLa}1g-bWYE|pYlKg6~KKG-YCxM!x@T;9I}};6|VXYyjQ|{0lk$KY`bQqrfA;HvxJFU>rCVpgO>7$n(Dr>;hH* z7VvZA^IrmVo_{ue{|$2ayMT`amjHi4-u@c!6z~n;F5r{E9$*p}2Q~vm;Lod(9|QLS zbzmEC32-*>cjWPJ0#5*s19b0y3vfQL4p<4$y}t##1RIbY9s+&{90ooGj05X|bAYpf zH(;;Vf%}0&z!Y#fFbu2&jw9!P6}TI?6}TCw0aOb(30Mg{ft>%-z!6{)I3M^6a{U*8 zuK|<5I=}*sBiE;Q0&WNBT>!fO|0{C;6Z-BSWXpSK6VM||j1(A97>KnOaoNzb2^Bir z)WirSkPmpeSJ60!)VFVkO>e=nqc55NB7}=HOLdn-WNflba$H7rnec70=-a(-WA7PW zUYBK1YtmjTF+1UXmz!qOV?8g*Mfd&hMbI32jINs&Wv&Ukc`V#QCG*8Lr0%`20GgtE z|<#D^M3|M@8E~{#>*uJ#ixxG#a&F2DcDYmsI5Ps&T{~r8`Z* zF2ynlksZ%%`gB2=DrcI5LIx=)w$|qqSBz|4 zqTlsnW|bPN$)6@@CQ7}qh(!LLdM(KiB+kVYnAG%UO(a=>T3bb==2^jrYo<9xZmt}n zs@s^STgD8jMV-@6kF3cLd@D_&f6p*Nur7<7sap;?WL6>Sybz@}N~&G5FzGRMr=$nE zP{t$Q{!3DkbjMEs3Rf4!`ZG|E|l&6w!4=Y(IiF#8coRj`wrCqkRSkgu_khj#h9++k_ zYlw{0nwhGBW!ka|W3h-zV@7(?B-01{)OM$LM2#~Mj#O2ot1-RSnDr6LQH>a1bYRxS zr4zj`Yre}5mASY}1fwfF{C}`WHhL1ni7QMqtaeNAN;!WdYklA%f5{(xmSBq?5L%J; oeE(>OP2#db9%HPK955x8yuzIe7jnyU3HP+xAjw*Q(B>BPKPswawg3PC literal 0 HcmV?d00001 diff --git a/layers/__init__.py b/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/layers/attention.py b/layers/attention.py new file mode 100644 index 00000000..fee9a873 --- /dev/null +++ b/layers/attention.py @@ -0,0 +1,86 @@ +import torch +from torch.autograd import Variable +from torch import nn +from torch.nn import functional as F + + +class BahdanauAttention(nn.Module): + def __init__(self, dim): + super(BahdanauAttention, self).__init__() + self.query_layer = nn.Linear(dim, dim, bias=False) + self.tanh = nn.Tanh() + self.v = nn.Linear(dim, 1, bias=False) + + def forward(self, query, processed_memory): + """ + Args: + query: (batch, 1, dim) or (batch, dim) + processed_memory: (batch, max_time, dim) + """ + if query.dim() == 2: + # insert time-axis for broadcasting + query = query.unsqueeze(1) + # (batch, 1, dim) + processed_query = self.query_layer(query) + + # (batch, max_time, 1) + alignment = self.v(self.tanh(processed_query + processed_memory)) + + # (batch, max_time) + return alignment.squeeze(-1) + + +def get_mask_from_lengths(memory, memory_lengths): + """Get mask tensor from list of length + + Args: + memory: (batch, max_time, dim) + memory_lengths: array like + """ + mask = memory.data.new(memory.size(0), memory.size(1)).byte().zero_() + for idx, l in enumerate(memory_lengths): + mask[idx][:l] = 1 + return ~mask + + +class AttentionWrapper(nn.Module): + def __init__(self, rnn_cell, attention_mechanism, + score_mask_value=-float("inf")): + super(AttentionWrapper, self).__init__() + self.rnn_cell = rnn_cell + self.attention_mechanism = attention_mechanism + self.score_mask_value = score_mask_value + + def forward(self, query, attention, cell_state, memory, + processed_memory=None, mask=None, memory_lengths=None): + if processed_memory is None: + processed_memory = memory + if memory_lengths is not None and mask is None: + mask = get_mask_from_lengths(memory, memory_lengths) + + # Concat input query and previous attention context + cell_input = torch.cat((query, attention), -1) + + # Feed it to RNN + cell_output = self.rnn_cell(cell_input, cell_state) + + # Alignment + # (batch, max_time) + alignment = self.attention_mechanism(cell_output, processed_memory) + + if mask is not None: + mask = mask.view(query.size(0), -1) + alignment.data.masked_fill_(mask, self.score_mask_value) + + # Normalize attention weight + alignment = F.softmax(alignment, dim=0) + + # Attention context vector + # (batch, 1, dim) + attention = torch.bmm(alignment.unsqueeze(1), memory) + + # (batch, dim) + attention = attention.squeeze(1) + + return cell_output, attention, alignment + diff --git a/layers/tacotron.py b/layers/tacotron.py new file mode 100644 index 00000000..00db03d1 --- /dev/null +++ b/layers/tacotron.py @@ -0,0 +1,283 @@ +# coding: utf-8 +import torch +from torch.autograd import Variable +from torch import nn + +from .attention import BahdanauAttention, AttentionWrapper +from .attention import get_mask_from_lengths + +class Prenet(nn.Module): + def __init__(self, in_dim, sizes=[256, 128]): + super(Prenet, self).__init__() + in_sizes = [in_dim] + sizes[:-1] + self.layers = nn.ModuleList( + [nn.Linear(in_size, out_size) + for (in_size, out_size) in zip(in_sizes, sizes)]) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.5) + + def forward(self, inputs): + for linear in self.layers: + inputs = self.dropout(self.relu(linear(inputs))) + + return inputs + + +class BatchNormConv1d(nn.Module): + def __init__(self, in_dim, out_dim, kernel_size, stride, padding, + activation=None): + super(BatchNormConv1d, self).__init__() + self.conv1d = nn.Conv1d(in_dim, out_dim, + kernel_size=kernel_size, + stride=stride, padding=padding, bias=False) + # Following tensorflow's default parameters + self.bn = nn.BatchNorm1d(out_dim, momentum=0.99, eps=1e-3) + self.activation = activation + + def forward(self, x): + x = self.conv1d(x) + if self.activation is not None: + x = self.activation(x) + return self.bn(x) + + +class Highway(nn.Module): + def __init__(self, in_size, out_size): + super(Highway, self).__init__() + self.H = nn.Linear(in_size, out_size) + self.H.bias.data.zero_() + self.T = nn.Linear(in_size, out_size) + self.T.bias.data.fill_(-1) + self.relu = nn.ReLU() + self.sigmoid = nn.Sigmoid() + + def forward(self, inputs): + H = self.relu(self.H(inputs)) + T = self.sigmoid(self.T(inputs)) + return H * T + inputs * (1.0 - T) + + +class CBHG(nn.Module): + """CBHG module: a recurrent neural network composed of: + - 1-d convolution banks + - Highway networks + residual connections + - Bidirectional gated recurrent units + """ + + def __init__(self, in_dim, K=16, projections=[128, 128]): + super(CBHG, self).__init__() + self.in_dim = in_dim + self.relu = nn.ReLU() + self.conv1d_banks = nn.ModuleList( + [BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1, + padding=k // 2, activation=self.relu) + for k in range(1, K + 1)]) + self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1) + + in_sizes = [K * in_dim] + projections[:-1] + activations = [self.relu] * (len(projections) - 1) + [None] + self.conv1d_projections = nn.ModuleList( + [BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, + padding=1, activation=ac) + for (in_size, out_size, ac) in zip( + in_sizes, projections, activations)]) + + self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False) + self.highways = nn.ModuleList( + [Highway(in_dim, in_dim) for _ in range(4)]) + + self.gru = nn.GRU( + in_dim, in_dim, 1, batch_first=True, bidirectional=True) + + def forward(self, inputs, input_lengths=None): + # (B, T_in, in_dim) + x = inputs + + # Needed to perform conv1d on time-axis + # (B, in_dim, T_in) + if x.size(-1) == self.in_dim: + x = x.transpose(1, 2) + + T = x.size(-1) + + # (B, in_dim*K, T_in) + # Concat conv1d bank outputs + x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1) + assert x.size(1) == self.in_dim * len(self.conv1d_banks) + x = self.max_pool1d(x)[:, :, :T] + + for conv1d in self.conv1d_projections: + x = conv1d(x) + + # (B, T_in, in_dim) + # Back to the original shape + x = x.transpose(1, 2) + + if x.size(-1) != self.in_dim: + x = self.pre_highway(x) + + # Residual connection + x += inputs + for highway in self.highways: + x = highway(x) + + if input_lengths is not None: + x = nn.utils.rnn.pack_padded_sequence( + x, input_lengths, batch_first=True) + + # (B, T_in, in_dim*2) + self.gru.flatten_parameters() + outputs, _ = self.gru(x) + + if input_lengths is not None: + outputs, _ = nn.utils.rnn.pad_packed_sequence( + outputs, batch_first=True) + + return outputs + + +class Encoder(nn.Module): + def __init__(self, in_dim): + super(Encoder, self).__init__() + self.prenet = Prenet(in_dim, sizes=[256, 128]) + self.cbhg = CBHG(128, K=16, projections=[128, 128]) + + def forward(self, inputs, input_lengths=None): + inputs = self.prenet(inputs) + return self.cbhg(inputs, input_lengths) + + +class Decoder(nn.Module): + def __init__(self, memory_dim, r): + super(Decoder, self).__init__() + self.memory_dim = memory_dim + self.r = r + self.prenet = Prenet(memory_dim * r, sizes=[256, 128]) + # attetion RNN + self.attention_rnn = AttentionWrapper( + nn.GRUCell(256 + 128, 256), + BahdanauAttention(256) + ) + + self.memory_layer = nn.Linear(256, 256, bias=False) + + # concat and project context and attention vectors + # (prenet_out + attention context) -> output + self.project_to_decoder_in = nn.Linear(512, 256) + + # decoder RNNs + self.decoder_rnns = nn.ModuleList( + [nn.GRUCell(256, 256) for _ in range(2)]) + + self.proj_to_mel = nn.Linear(256, memory_dim * r) + self.max_decoder_steps = 200 + + def forward(self, decoder_inputs, memory=None, memory_lengths=None): + """ + Decoder forward step. + + If decoder inputs are not given (e.g., at testing time), as noted in + Tacotron paper, greedy decoding is adapted. + + Args: + decoder_inputs: Encoder outputs. (B, T_encoder, dim) + memory: Decoder memory. i.e., mel-spectrogram. If None (at eval-time), + decoder outputs are used as decoder inputs. + memory_lengths: Encoder output (memory) lengths. If not None, used for + attention masking. + """ + B = decoder_inputs.size(0) + + processed_memory = self.memory_layer(decoder_inputs) + if memory_lengths is not None: + mask = get_mask_from_lengths(processed_memory, memory_lengths) + else: + mask = None + + # Run greedy decoding if memory is None + greedy = memory is None + + if memory is not None: + # Grouping multiple frames if necessary + if memory.size(-1) == self.memory_dim: + memory = memory.view(B, memory.size(1) // self.r, -1) + assert memory.size(-1) == self.memory_dim * self.r,\ + " !! Dimension mismatch {} vs {} * {}".format(memory.size(-1), + self.memory_dim, self.r) + T_decoder = memory.size(1) + + # go frames - 0 frames tarting the sequence + initial_input = Variable( + decoder_inputs.data.new(B, self.memory_dim * self.r).zero_()) + + # Init decoder states + attention_rnn_hidden = Variable( + decoder_inputs.data.new(B, 256).zero_()) + decoder_rnn_hiddens = [Variable( + decoder_inputs.data.new(B, 256).zero_()) + for _ in range(len(self.decoder_rnns))] + current_attention = Variable( + decoder_inputs.data.new(B, 256).zero_()) + + # Time first (T_decoder, B, memory_dim) + if memory is not None: + memory = memory.transpose(0, 1) + + outputs = [] + alignments = [] + + t = 0 + current_input = initial_input + while True: + if t > 0: + current_input = outputs[-1] if greedy else memory[t - 1] + # Prenet + current_input = self.prenet(current_input) + + # Attention RNN + attention_rnn_hidden, current_attention, alignment = self.attention_rnn( + current_input, current_attention, attention_rnn_hidden, + decoder_inputs, processed_memory=processed_memory, mask=mask) + + # Concat RNN output and attention context vector + decoder_input = self.project_to_decoder_in( + torch.cat((attention_rnn_hidden, current_attention), -1)) + + # Pass through the decoder RNNs + for idx in range(len(self.decoder_rnns)): + decoder_rnn_hiddens[idx] = self.decoder_rnns[idx]( + decoder_input, decoder_rnn_hiddens[idx]) + # Residual connectinon + decoder_input = decoder_rnn_hiddens[idx] + decoder_input + + output = decoder_input + + # predict mel vectors from decoder vectors + output = self.proj_to_mel(output) + + outputs += [output] + alignments += [alignment] + + t += 1 + + if greedy: + if t > 1 and is_end_of_frames(output): + break + elif t > self.max_decoder_steps: + print("Warning! doesn't seems to be converged") + break + else: + if t >= T_decoder: + break + + assert greedy or len(outputs) == T_decoder + + # Back to batch first + alignments = torch.stack(alignments).transpose(0, 1) + outputs = torch.stack(outputs).transpose(0, 1).contiguous() + + return outputs, alignments + + +def is_end_of_frames(output, eps=0.2): + return (output.data <= eps).all() diff --git a/models/.tacotron.py.swo b/models/.tacotron.py.swo new file mode 100644 index 0000000000000000000000000000000000000000..b4cfd7c58104b01c73bb3f9d66bfc10d8e7c0613 GIT binary patch literal 12288 zcmeI2O>Epm6vro&uL5lJo zRpkIjZasknhf08i;070j#Dxz52e@$KP^q9wz#&%-Nc?Tj*xprX5(y!J#?r65<9Rdl z-hZC`ys1~(ryn^^?{j7ejy;6@c!fQCxUrACd4>?i(?J?#9oIhM$nRX+qUb$eOxPOo z^abWi(OUNWRPr=w#c9AI(UPj|4A-|ZpRq9t7zMVdz)1EF&F>;J-A+3j(YWgXz2o*% zTO4f6i~>dhqkvJsC}0#Y3K#{90!D%Vg936mMP5Z9*QsE->T_!1Q#H+pQNSo*6fg=H z1&jhl0i%FXz$jo8FbWt2i~|2b1-yWey*CpwwG+58F)8J3|eHpwD-U9;M59UE06Fz@Hd>6s@;8Sn`JPuBP8Swi} z@DIEJd~g)(2OBpM@)1bD6JQ_MxPg!_!582>SOiA^0l!>N$j{(2a1jK+2aDiNuoGO_ zfq21N;1pN}ZO{UH!4B{>{uuKWxB#93PlE%1g8krMJ*YZ3lOvwcm_@yGB!?p@Y#N0L z^LW+vqHvJJERh1vs$<-Sx>T^J??j%Ey`sw+w`guMa+nI4Flp5sX;!P%&SI%r%Xv9u zYu2Jo4|VLqg{(6O<7&MI&H9(^eAqRPe7xLJ6G?h}z%un>7N>l@7qMg@SA{#?tlX*vVbe-stz^LR`K?km6~;Yf90z$d zS?1CBLYU`7ny_*>-os%FW$UW;VS??~ewU_UztD9gSWhhxd*a1~N{$IBbOh?Ol}`&@ z#-^`aui^@m@r^46mea*6EC{e0YQkD}EW)NRc0~9Jv)aY%0qfI#%GW#|SXp)|7vPH8 zD>qYU!m<2ZvliWCUSq8!aURPrkvvRD}@ww;;!b}jSu?3xhhju@vx24${fY4!KV(98F)7Y1kD+V^Vi z!g=j#A@q`NTDYJru^z9ak!aeq^g4^kCVJI;iiuRkL=$Jm&DD)6+ecU5`_LFSQewkG zeXF`hL&h!r&RcO8YC64Mn1r&|v#Rj4kCWcG$2?q|$hp!$f6*D`cG}%J?6Zr1YRVP) zn-9x!y<9RAVL~A{&#fv_HKU5fop0O3$IV+%9fyS^Kb{7oD07VTc^a2pxXZ6IE*#V= z?dbNa$Tr7rcctyxoBFm{ealis44H2#oWM<>exLp4E1bXTanF~y$ZAy}t8!JxCY7!- zEGkOAbEe*9&!}bPF!j#VyRxWlq2B4&*l8FKQ!eRA&hR_I)(=^>elWB12fdbilnp_n!86=&eMIC@=D%q(Sj>0gq)1G=CwO F`~}aao>l+= literal 0 HcmV?d00001 diff --git a/models/.tacotron.py.swp b/models/.tacotron.py.swp new file mode 100644 index 0000000000000000000000000000000000000000..57a89620f998d6d711a44fffd0e88e16a9774b11 GIT binary patch literal 12288 zcmeI2&yUI_Lon%I(>DYlJW<79y)U?tSIiY@57Y?=G9wYzD$n5Qiqv7`pQaA-_YvLEk}NLi^BVXaYS5{q-;*H=$pk z&!Ii&97LgKpbh9x#QqcXBXk400G)>(hw8lU!*y{XI|7b?Bj5-)0*=6cNPwD4GF?he zbCyM-)J3V4PqUP>Sa$6=OBXz6Tq}GbG1saM4XI+;Jji0Dqo&JSwHW@($U-R1nf7{t z47=HSPcc=m6?~MkLvPoor^kL1Le-h1c{ksHR{i^aJ#4M#ZKe+5f^`!O=p_U~5edcd z%^6cFDg=zraG0|PNyl=*DtbA~gOwNE*g-^SgVrt^ktsaRyD$Yb@vPKRR+ zh0VY`_G-7_vf{oAy~``*s|(i*Y~AHkmLymWD80il?+3M)cufz{83z3kTO8gp?Tyji~@b+w|T=qWn!>0Gi8S2L)`n##twS9j09 zZio2vjd0tqimW;#oC>bOcPD3GXjk`IFTdK2DG0bZXG+6Vx=(wBP@5X^Fx_ja)MFOj zUE*u|f(PX-#dtYUbFG2>U@Iydk0xia&LRHv z$Tjs_536duow9u-gaS^Uoxoz zdyJi?c_E~xmn6f!>D!|7ZK3*g;Ws`?on{!xuIs?8r5$aWjhPm5c3^ur?>!9SQj3L* d6We 1 and s[0] == '@': + s = '{%s}' % s[1:] + result += s + return result.replace('}{', ' ') + + +def _clean_text(text, cleaner_names): + for name in cleaner_names: + cleaner = getattr(cleaners, name) + if not cleaner: + raise Exception('Unknown cleaner: %s' % name) + text = cleaner(text) + return text + + +def _symbols_to_sequence(symbols): + return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)] + + +def _arpabet_to_sequence(text): + return _symbols_to_sequence(['@' + s for s in text.split()]) + + +def _should_keep_symbol(s): + return s in _symbol_to_id and s is not '_' and s is not '~' diff --git a/utils/text/cleaners.py b/utils/text/cleaners.py new file mode 100644 index 00000000..fe0a46a2 --- /dev/null +++ b/utils/text/cleaners.py @@ -0,0 +1,91 @@ +#-*- coding: utf-8 -*- + + +''' +Cleaners are transformations that run over the input text at both training and eval time. + +Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +hyperparameter. Some cleaners are English-specific. You'll typically want to use: + 1. "english_cleaners" for English text + 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using + the Unidecode library (https://pypi.python.org/pypi/Unidecode) + 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update + the symbols in symbols.py to match your data). +''' + +import re +from unidecode import unidecode +from .numbers import normalize_numbers + + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r'\s+') + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), +]] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, ' ', text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + '''Pipeline for non-English text that transliterates to ASCII.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + '''Pipeline for English text, including number and abbreviation expansion.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_numbers(text) + text = expand_abbreviations(text) + text = collapse_whitespace(text) + return text diff --git a/utils/text/cmudict.py b/utils/text/cmudict.py new file mode 100644 index 00000000..6673546b --- /dev/null +++ b/utils/text/cmudict.py @@ -0,0 +1,65 @@ +#-*- coding: utf-8 -*- + + +import re + + +valid_symbols = [ + 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', + 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', + 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', + 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', + 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', + 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', + 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' +] + +_valid_symbol_set = set(valid_symbols) + + +class CMUDict: + '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' + + def __init__(self, file_or_path, keep_ambiguous=True): + if isinstance(file_or_path, str): + with open(file_or_path, encoding='latin-1') as f: + entries = _parse_cmudict(f) + else: + entries = _parse_cmudict(file_or_path) + if not keep_ambiguous: + entries = {word: pron for word, + pron in entries.items() if len(pron) == 1} + self._entries = entries + + def __len__(self): + return len(self._entries) + + def lookup(self, word): + '''Returns list of ARPAbet pronunciations of the given word.''' + return self._entries.get(word.upper()) + + +_alt_re = re.compile(r'\([0-9]+\)') + + +def _parse_cmudict(file): + cmudict = {} + for line in file: + if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): + parts = line.split(' ') + word = re.sub(_alt_re, '', parts[0]) + pronunciation = _get_pronunciation(parts[1]) + if pronunciation: + if word in cmudict: + cmudict[word].append(pronunciation) + else: + cmudict[word] = [pronunciation] + return cmudict + + +def _get_pronunciation(s): + parts = s.strip().split(' ') + for part in parts: + if part not in _valid_symbol_set: + return None + return ' '.join(parts) diff --git a/utils/text/numbers.py b/utils/text/numbers.py new file mode 100644 index 00000000..4ce2d389 --- /dev/null +++ b/utils/text/numbers.py @@ -0,0 +1,71 @@ +#-*- coding: utf-8 -*- + +import inflect +import re + + +_inflect = inflect.engine() +_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') +_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') +_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') +_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') +_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') +_number_re = re.compile(r'[0-9]+') + + +def _remove_commas(m): + return m.group(1).replace(',', '') + + +def _expand_decimal_point(m): + return m.group(1).replace('.', ' point ') + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split('.') + if len(parts) > 2: + return match + ' dollars' # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + return '%s %s' % (dollars, dollar_unit) + elif cents: + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s' % (cents, cent_unit) + else: + return 'zero dollars' + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return 'two thousand' + elif num > 2000 and num < 2010: + return 'two thousand ' + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + ' hundred' + else: + return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') + else: + return _inflect.number_to_words(num, andword='') + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/utils/text/symbols.py b/utils/text/symbols.py new file mode 100644 index 00000000..cd873a64 --- /dev/null +++ b/utils/text/symbols.py @@ -0,0 +1,24 @@ +#-*- coding: utf-8 -*- + + +''' +Defines the set of symbols used in text input to the model. + +The default is a set of ASCII characters that works well for English or text that has been run +through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. +''' +from Tacotron.utils.text import cmudict + +_pad = '_' +_eos = '~' +_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? ' + +# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): +_arpabet = ['@' + s for s in cmudict.valid_symbols] + +# Export all symbols: +symbols = [_pad, _eos] + list(_characters) + _arpabet + + +if __name__ == '__main__': + print(symbols)