From 0e9b7158a9d8845c243c4332dc2c641904ffa1c9 Mon Sep 17 00:00:00 2001 From: Longcat2957 Date: Sun, 9 Jul 2023 05:35:42 +0000 Subject: [PATCH] test commit --- .gitignore | 9 ++ .../unicode_data/13.0.0/charmap.json.gz | Bin 0 -> 20988 bytes eval.py | 0 models/__init__.py | 1 + models/resnet.py | 124 +++++++++++++++ models/vgg.py | 65 ++++++++ train.py | 66 ++++++++ utils/__init__.py | 2 + utils/datasets.py | 148 ++++++++++++++++++ utils/logging.py | 0 utils/misc.py | 52 ++++++ 11 files changed, 467 insertions(+) create mode 100644 .gitignore create mode 100644 .hypothesis/unicode_data/13.0.0/charmap.json.gz create mode 100644 eval.py create mode 100644 models/__init__.py create mode 100644 models/resnet.py create mode 100644 models/vgg.py create mode 100644 train.py create mode 100644 utils/__init__.py create mode 100644 utils/datasets.py create mode 100644 utils/logging.py create mode 100644 utils/misc.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..75062a3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +/runs +/runs/* +*.pyc +__pycache__/ +*.py[cod] +*$py.class +.DS_store +venv +*.so diff --git a/.hypothesis/unicode_data/13.0.0/charmap.json.gz b/.hypothesis/unicode_data/13.0.0/charmap.json.gz new file mode 100644 index 0000000000000000000000000000000000000000..c377c96bbdce89cd27cb8a918af748b6d12c64d7 GIT binary patch literal 20988 zcmbUIWmFwa`0feg?ht~zOK^Ah;O_1Of;++8-8Sy-?iL8{x^Z{c&G4M_f8RN4&6ip8 z>Z6AJY6_!tXaVn^#4u`bvgNb8?+=IYRl%gcRw5a=Zl@FAq$AR~9JDXUC9uRMG#KVCn^30Q^6)3_orsp0W&M9!ooy zz1|Ns?Q?`a%8Lv?ADNyzhT8xmedxC5yk@u0&6YCWqLJ!Cy}HsbDK4#)4@6I6YpG4^ zeKfzF>>?n1_Ail@07Zos?KsSiT;A5J$Bm4ofK zMxP7SJny|ne9hC-qrO<~LC$f0{(8EyotK9z$rhbLHjQ^<5tTiQRy1pY)q)iE(!9U#6GMA+UCnLp#;;3Ux@mo(Uy_iQae*z zNN#%0H7wh7Wf12QljB{?q?q)(7os+~q?YVqi)i6{vtAu`kyMu(9QW}O9eb(;AvIHZ zZ@1lj6uo^Yhb^tecJ&eXPiB&0y(n$v;P)t97+G#Q%dc#v${34d*z9tQ_r|Kpxu)`F z@gjgXjOP*b6LH;UmB39#O8)hIUDx$fO4n(OCkEH)28)h0m5QgbmPCxisQ^zVF(75E zd=8USXd&CS055SdAGOFX?$1j4PF)%N#Gb)*C#q@3>q7#8-XLl>mmJIGbZ##{*r)Sj zUR25bOPUD|{ z0s&WvE<#5&wi(&5qgC3bRzHat?e}%2Q4*}F4GWC4BYcuoFw(r#}Az|(y?BjHgfRf z^liK`W$oJzi*nsk%_+eKD%}L{v6b_E*{-c(DmkREpM(940v%JTl{?&c+DF@bQ-AL& zU#eP?Ph}LlP49G{V+$=sCn~Coyr~bjyIj^$;Fnq5Kf=~z*Wv?JV#d~F5v1Qy&IUId z()yZ%YoMRb9ab$VdgOfFuQzAikA!!H#_<5nDNUp|KHAe|Ey<3H<&{76Tu&=lOt&sd z)b1wu`GM*fmpO)+jWs%swdOMO?e?Mr%^StO`)kWGEqemkt_WG;p*j_R|KzIC^t`|4Ef3CXDn+|_~PCl0g zm22t*D~J{`od*epxr3WucO2(Z>enTZq1Vi{1?nVB#HR3Tb+~Pn4GRAf79m?Vf;9sq5y2*0Xw2n*=Hn~172uy&>e3G6Ctd*`h zlzEpZJHlRcPjyN#xy#aE7TkPFFe$t{a23xLUt(V=N6*oj)=lFeVvK4>XQtI)1lz(} zJ+bgi)u^cYI-Ev2*)+H5<`UXIZ_fRGwfP6EF4Aa892v?HC7G3MV~pK*-Qpyj-F-d% zS+unKhWC*8{MoR`R>4?27hxk;Zqd_O+#K|lsH55lT{LrbTE37v_h+cnJLShppx)x8 z*lOL=b0>?2Op`$kotx}#(o{kn(u=inKtp?_a|l_V=a*rGJiF;nPU*7>iB-u+LZOJZyW zyjRQ~mGWPvpQ21DAWq|UB3%xCL*iz=!hC+I2Y*A-hgPasLp{T*ss5X@V~?%4%ePtb z>+YpRw{Op{pZZsx)gJvJakkRGio-Q96K%D7Vl=Y7cXa~BTM7v<<%%^cJ1uNZ?qwqn ziU5txUa2DLlZjpJZMS|Kv`q%X>qx%9`e^OSF20WC%I`K8VH=iK>hWLK6+XmRIwogc3X>n!gF#FSnKmW+UYivs zsZM^i^W>`li-XRf5O0qIBJI6mJR%{Mzv_(?dm#Ck*2?$A0S|mXe|H=HO$)cb9DWr$ zbj}}GjNKgcUP^H!TW))Jnq8vwrhyFkWVDI8w8t$tc!dbq7LT7CNAl|B;|!NyG_4}&|b%|K4Gt3bl@=iI%GrSj4h=J<^L^ZOF+GwAs> zukM$vPcs{doYbn}WZK(_S!WuCatwRE!QHZ-&RXb>=43oS?b+uQc-#q5`4x5U)BNTo zFQkO+Ty29bMaubGuW|B4^~*2d0bIjUTxZZ*W3#~qVVM)3&qu_~Yg0{S^Ui1S@;q-Of^>~FwOrj=~b!qU`J2n3@^ssS`KM0}hp#RFceT%TG)*b8QDdNU`lEAi#SY5;VpW0qU zee{iQR_VM-d2zD_UrfBNLd;i0cD@*=v0w9yuxh?E+qt6H89;H1k1=2z# zvK2oSaM_1bctB6u^VaRwuax7HKV2y9D+p}pBFCBjwXe@2C8%X?ROZBSX99EQPZ@2W zqk9kkkUdZQ^e%`#th+ZYmtWoSrPv|P=I!sjkBM7n?|4?Z$O+cbSWRcp`Hd-q}G8c%ZEzL$X& z!>@7OmgO+eNGV9N`*V?C-QrmBPFLId9d!KmB9coW38SR18g3=3z`1TzxC;V`>FV8SQXdGb4q;Y#4*YGy6mOJ*9Pjo_t}mW zCHLI<47)--1g_gnFmtXqI-qZqhkoeG!8W|-J8?Yl8qU8tx?g|RLI_g8_C8%nRsS;f zKDk`*!JBMx$a)aen5P_McJTVVLb(jP09e7{Y)FhPd-J{uWq&j2|`il>g>?JR>wXTX&4Scfpy;V&FjTFH-r$VrQ6q6o3Nt0sP4& z^5k(z5U2mRvDyE8k%@YVQqzm~*CV&Zf44`IThO;V;>HP|5b)=$GiHD2^X}u=Qg!L> z3J$3tA+E5$nv}`^f?!hr zmd!zXipy7u)r!kHnhaij!x56dV=UPdse6y!Jl$=&7H&(PIhN+xil%dI;}#!&dL~`| z?Cx#dd2>$7&=LBo`{r`t-5vkM=JMwUfY$e^f&&sIVEq>90;SDM;#Va{!y{ET0ida; zGYy<;*!_I@d1{*_6W#{3Cd|T6n7$Szn~ezBo$1)G9<}!|5hd@s9p-(;Y7o%0PyN`S z=Zy_0`oK64e%Y7i0a5Jh$h-ui6o@RYbvnarN?EoaLBdKjcL-gaqMTd40k+c8KDy}n zJ@5bCJY*|u46dH}CcX|G@pj?2UR+^*(6tE35GM8HI#UVp-1#;<{7qek1KG8JD2}CH zcfA1&zPElUufba}YlEMnm5STQmT!Xaz8E00Yt36blV2xE?*A63=Z`YsmzON}hY1`1 z+AgiHzUuc}_pb5Z8tVII0T@2#TaH=p`1HJ8Y#v;u3Fl$fE)?+ytxG4LWR8=bDq}y* zuT(B}PSp6`Wfb77TpwrgYeofW&h2?mwrT_Lggb%M$Cj@#XNWZ%P3wTZ46PZ2heqCw z!o?oNK$Kph%hJ1=OGUwjo`_j$$Oyo3Dnhqf9w$S{BU`g2eyY!aTrXSo+W(<%J}a3_dgL{yon1;FHa-+Vq9T@ z=9{-_Q)?W%>Ge)6yZusYEW6LNJ8~atEYG|i&pc6mS+}S=Oy3?bU#}h!(-0*9U2jPp zxjxHDyi`CIsSkpupfPk^|ANxSHy=Qe;7#vJ(9znJw%sPN_nd8<`<@jeO&g*Yis$+z z(I9+r?n`EyZK}fvAImpo~$@1%tdJFy(NUdeY2KHa;I6k$+ zu0F;w_0ZmInRm#`IsqOtV6Wn;kJt8||2j_N`jp^a@+5r@bWCmXWdrV>PY5ROjyCH% zLH2$HuGg100s@ntXLM}uDNP+~du4o+mO`1MQ(2=wkE9*gqGDHX%H0-u?GWohFP*KGw z7GE);{?w?H+V*&-18Q}up6J7w+p5{;v7O-VkcP9?)pb}#2FCt1G%)KpRGMDeK>j<* zxP?ol3w-k-sAgO*J0vEdIvP~MV4)RatV<#UXeHUbq?CKlwe&X^{WfOli`{d+UvWOl zXZi%cKR}>~@;<&)H%w=o?SDRTsSLUTH~jpIf?bEYI6cGF`$pfxqo?DkdHM&=bOoFR zi?u#Livaz4&Mm(V4v{?F#H`| z-;AOD)hrt=-ZUTHqFf@07oDoNJ=ncsjHa1LTp3F&<6t&pL4D^!odFf0yh&No`dcBg zaA8NhGMMC71n~zH(VxZNG5X{OSiuh|l2!E^ce@_+%+ln=!(a6BSm7WA&h&GA(x~%v zBejV|^MV8Dz>;T_Dj#t*tKHphazdnj-{eR%IpC72&y-M|t9@v&ta_D1O=HdSuQpWX zp6s&ebJQX@Bg4I!_+V{ekcI|)q;Vpv>u7^jZW$UkXm{PJSJ%da{lD_u!y7|z50LiZFJb-@m^j2;Go3VmwfWQBNk=3MNR;=#6A*N{4 zNPy*Ik3trtMLk|tJSnH&s88$TZB!ifJa1?UJW+1ol33Gb;}(=fuCj>Hd z{t6lgjK9kFxpMzR^TmF=UYPUZ)&R#5B&c)*lQs;1@bIt)Ljf$X931R66N3HT|14+m zbg98h3Ohs0a}*uO!2?95apvu2Dmfic7Su}51vXJYcS%0t?nrC55udNe*~$hX-qCS1 zS#Dq3Eb;4xxff2fg1>dh6c%*9#fw8Nt$Z%*B%M6M!mClhcO$)JV}naRa?P$&MD$QU zh}NdDI$P;-?YV4M*VPvE6WKk&illvasGJqFek2lUOlj=^N8HTpHoiw&6>5SP_u{4) zh>vC^8yia}6OE)mH!|}<-;t6eM$axXIF@N3rk#}&B*36}5LHTMQNEuk?$)@MTE&;> zP3tzeH(1>ZF^KgwyDuPo^FRJb5cf~$mbgb+WfH$i^c9_16ncPQ)@TLAu{?J%+|#X; zf}xsdVk@SFNn>-`i+N*PIjyx^1naKKCEt_(l@jZXo%%@}DPVB#E+B*s3i35<;aJTSs{4>iVGGvez@%4uVG)%!|KXA1b^RDa~pSiuU}W~Me) zo2`wtm|73Z7UWp##y=L?Pvz-b;m334pZCIe&n#^Ruv5%x=q1LRm&_tObjo~U-_MSX z{+Z=v*ZG4|Q%!MC)=N&JsDQr4$!v-;nGVTzI1>$jt%~_r4$eE1TXiL=|}B_Ye9mAc~5J+Q3}jkovsa&qft2DNon7B*GMtT8ro;La0dd#xKZ; zF-2mrIUcu>8u{eEhYI_N(e#pTr@)c=knwt*oI>wt90+b;ce^~y&B|#q5q_CKK4P(M zFm^PtZlrtG1w@E{i=yhGMhQ`>mn3xmTv&j z4M$D{S0z>Uy^ly#St@T!OgRh;ij4qVQc!{hKn508{@+pch-DUI(=NLIxpZB7%KYy! z`og?oa9P6a49O`@>)Dn0nwr4YZ3KXc%2dT zm2J8tgxtx<1OcfaaSIkVRnzZvjD;-_838v9p|F$>5iIF=LI0=F>|2lwk`TB{tu(bH z&P0tEnIM&nkg%ePc@%@aceErk!>F3Z=UYpafWN~Z(7-f!+tR~#TOm=6EKD8(SD-ng_A&{j^`t@t9N4EQM`0;|myFzvy{#R45 zrq@#Q z(0u$$VO0tqpXMMo3yaGDp9m}i3go3Kloi`^+oSb@@Z|VESQ}Z z329YZ8iOSmu^S6Wl1PLY^T(sMcesq07JO^17%&UpY@wx}nctlxa+0swofTdX~s#4*p-^nP>EfR??CgSo@Tee-c4 zU4xPL#zXTdA#VeikH%y3x%CnN`)O(Qt4yO`%@JX|ci#u&W9tTSqfbc!OVZwa{b=VP z=o~klpmx6JkIYeDle^KJ0P_WLx(k}THyo8d)VyGz40V6>QExot;D-~xqBjDSKJ)ROKJx`rx(g0Y`z)0{%)B6Es797jddk5tL(Ns=25s+EP{S)ZNgg9`dibr~ zUIdqh(8j0H=iry&+5j5)4}Kv3XUzQPKNKFhHowdi-Fni8 z!V}jj(8WLMgNL5c<8$yy=sy8yYj1ik3jF`q|EExIvwc~>vv>Z9??1t>5X75!?XH16 z{aT*aw)c*JGT&_fw;>>u9SNI5C$mU#wM==~zf%23KEyi9;qS9rhcA;x@dUM=TQ||rUR^}7%X>Fo+y-%@XFDDOM72lrSFRsn zdk-G}k1hXaA<=Y{+3zmGxYofNH~uY@(RJJ_kH@PUH}0(|AMMKX>Aed_AHp~$9CRxN zwh7k^QV8vW#W6nWL^ugInM&8J3mSTpZ1Ov24eH&Rso$CZ$SUA9SsNY?_g+L9J~+oS z*Rr&2G8O4JqqWVac4aIi4lp7CH13v+09h%H3 z7R5u#J(qia=r3X*g{bxozH;T}hXbj|qyc-~LkJmG|9+0dxj$GMo@g}zFI#xi~U(i-@S!!-i5i3f!x|}zqHOkDkd9Ntbf0UY~=sO%F$YxQ#6L& zU|{Tv@#7;O^`J1wT;VtzfyT-*?|6|?iJNM))wFm6xSB+$DdhSF&AkWZ zZ0%kI+A0r1rOl`anUEw5s$C|0VJAYFhHXDKsPxkG4lWB8rBMQLFkX zQ;8eea7&*sNn^&U@^=lxfM_Pk!iUqNAoDx|#<{@0p_XG;CAe<_{VcPVr=3uDh4EnlG_JnGI4qOWq0c?V=M550t$ z)f0_)2U@Fe;wY4Kev8@(D6&MnJJf+SVbR;?!(t=Tg8rs`jMJ$<8x-pyqS>`=Gi^lo z9!*DCrKY9g$TYsL;JUD81#2GYI*lS&%^?$CN&Nz9ELq66)t#xJ{v7@!{F_iodKQr)sQc$JA>8C->6h!8!EHe%UlE09bs z5Rb~7i3=t)pcP^Yt<+LE9V^&bS39FG3E=&mXgU)ew2h8sgT);|nmr^u{))D~F5Kk< zC&i3*KrFe2a?Og8V=OQGT4OQH;xR2P#ug0<25IT;nDT)&sQ&vCFy^*n(w{~xv>%Xl zQ)AZ!NXHHJxq$r9LW;Bl7VZcq!M*K(f7^`8UDyReHgiMrv`8;zZ^A|$8lpc?#lQt? zrRXyZAw`LXn0$r2TAAVFL{scJE9d^?yr(mE3fk5wwEfFsMPP4?m(Af(swHgaS}P*Q zY-Wv=o+d^BfHdmNujfN|%22@3xb9RiPx^h5BdVmpkgnQR22Zhog<%8n=F&Q6EG)WE zJv-CZTniPy>Hs39;nH1hNJ_*3%OA5lp~PQq4`JD4CgHf-Os&)0F`y2py|ss$G7==M zDxddbNOw}*pvROnw5N2H9vBprkrSMhxm#GGp+PV4MON#zI}vSX>70Ui{U!5e*hkE; z1CD(SMz5l!@c#4hN>twU#0Sw*YuC4%k^3{aQh*~d?fe;<&Byw}FugeSH`RCf>rS;< zGRb>~37syvZ`_|l^?#HdAD$G~e9E138lNunY)|y!I8T~G+JRu3zb2vauV1UVd=zn3 zR2hiLUS48(LWFUfCefe*3GNR1l)gvpP~k?m1v4A8TwaDwm`yZnDGjvF6oLz$b_z$5 z|4Bv%UB(Lwzk}CFcQhWM1GwD5%MpVI5rXRm_|en*MQ=(;tb=*T^PYln$T85Y?~>SQ z6YbX<`dz0s^q4F@04F$IfI1WG8U;JMu7!Zql zaN9@5j#+vMlH#Lk?Y&#xtYSPL2PeE?o2yFQ%yL$BF#WRxJaMucnY@>?K#6(t> zcHswaHhq3GrVS!ED+CGmuCE-zPhpxt?K>r=Bep$_c!5v(G$&dPL*pCI3oWi zKYMUsTS4?@%PRnsjI1;C``y58O^epA-Q17I!ZaNin~+l9-9s{q6Jt*F=@%Dbn{Md_ z#;X}q4UF?~^&{&4WKx`?01h^xC=I!VE!|q4WdND{dtbbPhJ6i(al|3_%_;VlJ$yRr zkIBDD@lNh2KYKMx`fmg;WOBevu&Q%z5#l5Do7`w|qh>HnT<*NMRMv-n!S=K{2=WRm z&uYOK>I>~=k^9dQIs#tK#xzS;ePTFIo~<3xHZ2w6W2g5Jw_-L!Vlzb&_yc(n}!W&s@P z{qJrze~o>X+%zZoI7G64o5hdSq$mpC>|5siHzRkY0=HTZzrv`)b=SJ{Dw?oI_M04- z)urrinK;e(@5*fq_?R&r1Dv+TeYf+j6sj8K8+f-(CyqQ+VBK2YbzVPoPpSfn*#k$m zzt!XW)sD?mf!X3~Hp{f=#&ne{0cX=|tCMOceQcG2qBH4sInAz`Mn6N}5;lFTa(tY2 zGl2}Hzwu*$PaFdGpCWX z>xWM*aX+8NwipW7ZwsyYiAlU?q(?q^M0>ED@SZ1A-ydOpZ;2R-_S5!l*;*dS+aB3ET7360--ddZ`bP1mCh2zBZnNM|w8g38 z!(uHv`WJQQeJpc)EDzblAVjlb#Z7wJ8;f72;kN87-R@Fk3{sQc%~HW(dt#{8Mp(%cg#(WRDPrX)k?caRvEEACT<&KRYN2pPHSq>MCiY0anT|%HPXdS?A33nskLQ zBfcxA39^JIv=Cc$AjyTeL1mvo4kT_EbDz)QoWVXm{4o@ocjoD@P6#%G<3$=Tq{X1X zA1^4x?d%QaELpC83^&hX%#V4BpbcFjCq(eQrEw8dKK6+dJ;Dr`A_r&#WWxNc!o{cN ziEw{@$NlMF#H;^tzdSwDKEk)GG`%;Ai#!V(BSp+E-o>`nUNXgDIIa5ePwJLU0-GHu zr2o3;rP<~Za`w=;FzvP6=<<~ro@ULN?IbFOsQEaKzGqk}jM0&W(7OS7heOeBeT%d^ zo!vW6`GPg01cc58LZd}uLhR!2hHnBn3cOVxBhxVBXhmXsvuG#j2St1(#G~;Om{32E zSTlne;m%;~C%nXe{0%4=2$k4Y8O8EN1Wf50U2}f(l~c__?;tqVZ5gcjjH;u@cKQ>h zd?wWmM#T`%j6@!DGNcZyo`Y426~W4Y5B;a+!7Y-W>=lG@u%|MznKKT5q-Z$#iJqzK zxzlwQE=30yhNGt`!X^HWwE)Q&C09*aP>Nw)mR1r^Xp`qeT+a`EtVFBO&nCXDR)*9( zrWT3B(p>I{Uto{WLb_C!(Y z8&R1Vx)E(c#q)=F1J;-*LDXBgHEnnZ&XR^O;tfM+F_LMPqufEg5@FO^ zSwS%g_O1c~7b!+@fqJ+F0B|IUV&_~s5o(f z)P~e~_!GfZncg2qwckj6>x{Av;%8>DVlVCp4zwD-y%f1$tbR|SgjM`mcPtrKN02Jt z$aFLrABKa?+c0&E{rdnt`hC-Y=9XkU<i63;#tsCZQIN&*W}J()ig;wf!$=@_*c-%45_WW#?r#a02w zIpCz4H!1?IX_R2X4$+&(!;UeZ&?7Q8n5apB+I?|Qf;v@7542k9Z$;>|=8~C^3oy`@ z*tS|Xt>OGP8G+qOGdSA}*hiN8>@pAHdQ#>%d1}&y!V3R}pS};{ivNO>$_Pi^anXbW zccf3it?}}Q?t%yv+pbr?s=Y^``8G%M#W4!lgH46aLXB9lVn}C*A^Zcnl@>y@K>H;I zUj)M$Oel9iT{iN=8xL78_K1v7RXy34*X=bp(rgMsVA@mmY7)hWTV4xvjR4W(#;^FS_@^Av?2wd%ocqd>AuLy>+N%RL5 zludXY*8bM&(qtJ?Q()c8YVI-UED7fEpv49Iibo1xLfY`j`=S_Z>DJ9YB))Z|CemyJ zsmei#UUHMSf`HbNDIS%ybh)3Q%Ir_egO;RZsyFBCQ*xTC0>6=j`i?fvrz-p zaH*{hD^8dhwG{Uz_=#$FS#Zc|*vSa(4*-nco<-(>L@@^!vGrQ76e z)i4U%P}X*Z(n(U4e!~XHVB1z`1(w*?M+_|I>+c%rhex@#b_{k!zsia-QAJRyy`v9B zb}fr{Q(QJP%FC@Rq1qFFe>B-xF*e1h&Lm3NFQXErWT~K-E8^Xd7gc)@U7z9K0>M?b}_)4Ed z-l8AYqM3t^c0gf>5i$6J)PK4E6fH2d z>i<&OA=42f|mAasL61uQ%`Z zO>K3!<`=%04tlw>0)0&>(@%v{&QhWv$4SDl5&%2mfOqffLI}CSI8MXzIPa-L#A5t$ zjUO-GrDE$;vEts3C=4Wv#P(1Z!Yp%n){?k~IHQ!H&8TCnfXNu6e z`2j7{)X4YNh`w{;fR#T>Ip4xpRy_N^lHm2}_v{KY`_I`#fTTU=wO7mM`pFL; zgx_>+LLFdw7{v;_;3eH)Ne%(lLVZQ#Mr)YNCkW|8(@A@tUi^w3kR>1BNhp7;i~v5;Zy?w!|K~61?eEQbq1_y=^}Ys$UHSiy;4^Gw49@)z)8K>GP>gdAjPmAnf<@r_ ze`1XXnGxT)1c2H3==Q->S7z@2po{YKc!2-yOX>d?VNKyvh)O6LmYz4E&?jj?q++9- zW8L9kYmpHC*u{3f*^J)fK_lz110nZe7J;pGJ~y1CDNQJL`cCF+&I z%VS4YN`#@g7?g>f=D+d)5ywpr-gxp0qK+QjJo6I(M)&R83qjvoB;H1QQzk z<1XKm4A+G9(CX>{#{2VJYJ!e}be?ogo7dWCb=R6_%J~fF0DiUk)^8Uc2a`Pty>Cx* zbv-ki;ZICt;teLhUR1Oj)LBr-96=Y;!rLx0q3*Mx?rMM&I-z-RuVYbO7iMKQbVWBa z%eFTAFt4o;=wrA*_iv;p!hNp8oqUP@3iY!hn=>Ka1tTT5)y!W%?Anmng)h zX`1!q#*I;Z?U2-4gEAL+oM<^(VT!|zKdAMm+s#T?wPC5%=2a5R$0=0z$Y(}EVzo}B zpEFS-HtB@>Gpod)X<+PLNOeJ5ztxmgB8Za1a3*QD1DQ@p7AJr)I2jYC1H_gN1L}FqbWwz*LXF-bL;sih`n8Z)oMe^%Q=Y8T)I&;K>fDzl~XOmeQ??>aWD39blP%6wVBwWti_I@Y1|AY(9{c<)*P{IrT+|H^nv zSaPZ-jo*Q?ue)jZ(8E3mpTfk$ZRR4l`8X# zo)>x%q$7v<96cax%hSBum(X5DU9Z6VFuYyyr)u~dGyCLWy-lbEu^Iag4ynhOQ}y_C z&a_DY<`3Y;;SP-orI^#K6gg-lj)X#fJ4YHa65U zUy-L;ljgF3Yvc)C9Shlt{X`vCQr{EL9u&tHxI)3f#GA-zz);0&ht9-QTr)9=AzLM! zAz;rePk)}vd4{+Q8IOC(#$ijxDICJ5;&n}{K1MR)oh7&83$-658Ma6_VY5OEweKI- z5+REOXp&CSmSSoPQZ>+^heqg%J+-n5!(fQaH5r73Y!ct(kE`Z{x%1NTyGACrd|6VR~W~bwf9~sVml7U<-lDRv@IOafw zVK>9ZG0QgFWE=U7GoD1hI7EksXZgxd!{l7ElhuX1?1 z1q62{j=WJ^q&L}snRP1F6E4DbmPzpjl<9|ny^-A;uK1gUuRN_QAGWq6D*3B7*%;w8 zeNWt&D)~+2+!&GdCP~(}j?* zw4tRzF5ZA7Y8TSu1M!bmscP@(`{)SWC?uqiP(%isD)=ATI1tqLwIoOFE-+P~`5!|WEpDJI~K zkSHjWe}Nk_t38@;%b*IJfNR~L)p!Jud(tTT_Hl3|Q^{aQ<|7q?pEi-6UxR&{=AQZy zFmas}bKh_Lm=TTeFUe~TQ_-u!aa26NpVn(2;z;megZdpsb@+p7Cckkt>?f6lH{6L) zi+^0z<|9=S=F6;0qVd+jPwG%fa zoye=6BxV@)kcUf*`X@g$VDIG2s82t_T8KG zN`|D3_%;{n?vkSy3gTji!Y6o9jJ*duG=3)2Sr;oswS*81tQ<&r?62A@w;=rYJHBP4 zweF9<4caP0?4^u2XPg<<#1%v*dc)a$u`c7V0;S z9cc)Uo&$luV)Cbfo5V1i>3mglQvjQq-U)j4oyc z4bf4_zDQ@26{Jx}<=!cADVj)B_hI2$hXlD->1BK5H=~xHP!NoZ z!5`%5ctX?VS`7m?No(540HLB;VpfKvmHqUU{V@QUzXbJ-4A;%yJby{8ir+=-ud3(7 zFWfQyaniv<v55d-qMA#2ny4w+`Bu8tY zQH3A%LdAJ<4@lQ4b$?Y<8{ldKr3jU*3yG0oTEpeUKR8f$QDbeo>0LmX zzl3oMc1-+e2$top!HPZzlPiFdD+o4y>IYNt@I``&85jx_G3l@K)!?J}G3#NhMN6Pb zH#%MPY5DcDQ%><8n{cr^bCyhk*C0XkDE7HNF7&uMig%9I+}^ST?xGATKas;bAZs0HcG7{mYT*p^RlIOvkIeSVvX8L z7#)9!nT~g{kDZ+kFq>|Htyrq9*hej193k<>H@R?B<$|M5rF%fJ&W)f0AGdAn0&(Y{RS z0gT3_jK5z?eicafr+?pPo+V9|OKm8U!$$AX$iypS5_%EtU5O!o~6@2!}C3s4A6A#L^+s+a$MCvAXG7f6P4y%CXwE; zQOh2bFrq8e3|KrP`{tdK5Xme#|6^P~)s7|P2tOCVM5#xx>o?^A(xIbXO zze=l3<%O1P*7kh3w{F~v=srqZ#CcTR6!f?a&ZPvWND`SHi`Vo)ICkSu&Eg^!7h|4 zoOqB6d6YkJ;L|ydO?-{1^UZ~C7K@%PkE|_PA{V|%4x;oJO2U*J6`@nOnaa3z(ODVy z)d|=DgwP}ul?@hIuq?XO;{xQ>3qw_EJC+g{{PKN7CQ=E4KFqDm5oS34*G&Lq8-qN^= z19OlH^BAMSnReJBYov~}MF2E@xF~8(d7SyWsr3#uP*c}ovAm7AA?SfA#yE{tpnY%S z?1!VOr^hT|*A`?##M#c(zE)XK|9EU;-zMi(dmLr0`;&N|kSXio@;3ay)jr;^F5a+i zoM0$_zSsmcyZk%xYfwFztAt;ra@P~Ff_!Ac2$xkOnLdQ&!dMG2S~TuodJ<@A7&U~J z<_s7nTqxgSVd+?6s@OoeY@uAp;FV8j4Q4)=CcPyyd(xRbs$Z8bILvTiOLG7){kuE= zyARO^m_J}D^jCKqAhOX^!K|B=p-6z1(22fU2c#;Z^*kdjPx=h!0}UqN23Bm zLQ#o>`S%F_g;;~~#sRlqLfa?sZ7FkN){cuk!dv zPzd<@%w71QYEw0q2W7JzV*RvNq;`UGSZ066T!_e6`OLZYkuu{mXU4VG&|f#tNhi^X zStdTxOnl~<_((MInQ7u9)x;|R_|X7(NKjSV`;>7}JDqa3JSk7hoAS2I{)YRT>~Ffi z$^NGMo9u78zsdfl`LifW!(?f+xlE^t0fDCr%N&>3x%iOn`EKzbfNI% zn)#R?`;npbWD#&!|M{p78MQlT77p8YgXY~ZVs~&_8c&mJs%_RD(S#w&Zm-TsJll`UpJF4vA{#Xp| zb1^v5S1;37FH}83W0`dA?R}q=)5SrcGngKF!ApfkGUuTYo^eU1J|5jWLR-~%p7{&) zM-=#0t}0%l)H}pwJ@v*Ukeg|ilg0{>%>^Q_8rn#A(=3w=Cm$OQkIy$RbG$jKLJ$W~ z9J;>$l75S7aJq(uRYwRHyy7PGPQ1{m_c;L?N~6kKh>XYicftI-un#3z_S|-Pixhd! zWAEXr;e1e~2UPx$u9n(3;?$gg#d8o|N(kFdK?_6EiR8T_l++#~h$)mj?ln-rahw4Xm}s>vXKOT@p|~J`y+1d5dI8 z6~|d4kEvNq$zi&5$AL9(!kRav3pV){8mB(tEgSG2EsOyION>TE{|1mY!jRw<>cUxd zL?OGpGaU)LT~WafyQK>W)xnC|#fjQCFzWm9O_Fd@?2UCRhwdc11XF~m0pz( zhy>v%P^|85wK4;ztz4#oypf?yrU687ZHWRv+KVG`A7bsFWBW*SbfL2+bA&-8R%`;7HG0i!0@7Pfp6O%qoe+94{8JvWw|9x-nWf zmlq}H(O6^UTy)-CBdRd3)Qmr?WPI(AAnXgT)SQd#{Uz;fdznkgTy7W5Y)QDGt=*P% z7N@zQleyGsUMj3SaHHDO!Ij$?JEA;(g@IJ}>|YT70GlkA0N*SLzMtF`|90*te;rD5l>afhlP6_uN^D|G}enh^VQmxlksWCj~SmI+vf`w z+IiO(JfGr`1(QIyIYa&8Abcar}1S)vZRJF3@mQ z^K$j5sn{YKydT4bX^{AY#)WD4eLn9A&Zk`9Dzq|KcxBjA*6Arze4FvzEq{2qNgBN+ zukZ*VIF_+aqS1T?%_+slx2afcYZjYTmv5sv^;(XIf4dMzB9Kvy;?D?&RO1Yhk4Yi+ zXJ=ztNH?Cp)9))k@;OU>-{tr;@%Z-p0-Rcnzcfns(KHPL_dwTqvyL zJzgPiA+EbUR`7nST>Fo{|4;fAP!*2xtBPG|PREQczXIrykNJ^7kSu0hEL|&}wLs5{ zR@)9Q184C>b+uF>)dc-@ivRk)Km$I9A)iCS=aBL_pb9+wJ_j{9&SxJ0Y_8^!Ud{p% zrXDk<9#Wp|Q;Vgmfb^1dZiMdD`cY^f_3QVL&;}ijz#3PP5`9S#i>=I7wEV zw5FA0mC2uyjGLOsi|i_}>ynHL51HANWRcx$k)3pr-F%U(=Z(fb1^&K0Um$7XE~)l$CH}nf0d|STGQF(IkX%fWT%T8J zf@jxf9|3%FXe{&hUa1?i=$lKbHDr=4m(*G;sa1_NgQR94O#;XnG+xQM(#P{jEXvjl z=!2;k7w4ou(929bH*G_hhO#{|Er~Jxnj!)FQaKpF+w0DDWOjFdhq@SDRWYt@2SmS% zRF@CEwql?UAjw?}Hee5FmIA|@4xA#ukZqG{#r_OCy`p?%o-%tC|d>wwsK`^U$`!; z>c-}RT2z?YTHUnP)k>*#9w;GNXh7VE7Xb@~{Ll%9b;@I%h-jqlPy+Lw%41F7+KnMh zU{g|8d^NH1Q-xjf{{DH4xEE)LlP$$U{oza|5j)i|3Naro(c z;x$+}(fZy(q!*2!Wbn!AImKb9_+L;VfA|~MUI0ajb1&?={ty0*o-rOUpBc)TxN-^7 zo#^+d!7Ypks=s@7=49OLn?csqldV+;1pljP#-K7LXUQ5ID#u=Lf-_x%O+ig@cT+0i~*{2$rRKC_E`^u^j& zJz2VNx^&>|+Pjj8ydTjDAaqe1O47NR;pPeZzUjAee?HITlYAqJ0!BhlI8ojq8A z$=Dchc{Fi)xGs_EFW+|gEf(Se&P?SOg4a-SRB*tB^A(Yoql|R%`#-Xc@1K|l>+?(TD;Lq5uVZ$Oe$}l@DMRSz7 zh3qWxZlHd(QThW;qag&twM`IvT+XDzck+dTt98h_9!cbwcrf6e9`a8QkR}atCZV3( z-G-!tqARMp0)M#~22gqJZq3C;st5%n^7M#YC7$u1;mNCH$nR%=JVJJB3O~c+yiA+HJuDn&JUXvvcI_R@E_6I*v;gV}UzM&7Unw__BTyTwkD$XCka@1V;lcl>+|VWekr z%-u~h`4sk2GdYatrTB?!`H9^L3D>9gA$oB&d@FUJN){Ah6<1xO97);;xaBt{L{-j0VnElt48D41GCdn>vOVLknnj``$V3Bi2@G zUvGqiHrFF(Eapy^263@|`L~||z%;7udg`oyz8!!?mM2Pd=RXRQ@d1%jAoENZEEw*#>vjC%2W|0YR42^Sv?qO%*taSMVfF zj_`o%QJWJ%ew`;F}M+H3jDKCr4#G-g_P&e-#fV z7J{;y7zlfVpSb*HP?NcBzF)DNW4z)l(Wy4Yqg@OAuQ2g!F^BQ&%Vulv9^Q7pV@OnY zTVJzEhPrxrO$)@+#&L;|{4_y?-}uxNl6E*1|Dz@iCoE)al_c^1N=+PjUI~+<(j=ef zrlWg!K7;QxGOIkSKWY4@r&+B$BF;M|(5M^)PD+pP?g+wsx`8rB#a52VsB4ec^W3>- z@>!f$B=ZRPeC(|2f)L);i5Ohm(7GTC%fQ~&6nJOr>|v+ap_i{(|0;xc4bRO-Pv^vX zpXeSQzbWVFj@1*S0{utAt`l&>+WPbV@#hH24nyJoNS8J0910FmU8)or;%-Z15G?PW iN@_}wj-?E)L}~BrHV9OurlVZupZ*_{ycpT-;{X6pS;b!f literal 0 HcmV?d00001 diff --git a/eval.py b/eval.py new file mode 100644 index 0000000..e69de29 diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..51b9340 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1 @@ +from .vgg import VGG \ No newline at end of file diff --git a/models/resnet.py b/models/resnet.py new file mode 100644 index 0000000..6b0c0ef --- /dev/null +++ b/models/resnet.py @@ -0,0 +1,124 @@ +import torch +import torch.nn as nn +from functools import partial + +class ConvBnAct(nn.Module): + def __init__( + self, in_feat:int, out_feat:int, kernel_size:int, stride:int, padding:int, + bn:bool=True, + act:bool=True + ) -> None: + super().__init__() + self.conv = nn.Conv2d( + in_feat, out_feat, kernel_size=kernel_size, stride=stride, padding=padding, bias=not bn + ) + if bn: + self.bn = nn.BatchNorm2d(out_feat) + else: + self.bn = None + if act: + self.act = nn.ReLU() + else: + self.act = None + def forward(self, x:torch.Tensor) -> torch.Tensor: + x = self.conv(x) + if self.bn is not None: + x = self.bn(x) + if self.act is not None: + x = self.act(x) + return x + +class downsample(nn.Module): + def __init__(self, feat): + super().__init__() + self.body = nn.Sequential( + nn.LazyConv2d(feat, kernel_size=1, stride=2, padding=0), + nn.BatchNorm2d(feat) + ) + def forward(self, x:torch.Tensor) -> torch.Tensor: + return self.body(x) + +class basicResBlock(nn.Module): + def __init__(self, in_feat, out_feat, half:bool): + super().__init__() + self.conv_1 = ConvBnAct(in_feat, out_feat, 3, 1, 1) + if not half: + self.conv_2 = ConvBnAct(out_feat, out_feat, 3, 1, 1, bn=True, act=False) + self.downsample = None + else: + # if half + self.conv_2 = ConvBnAct(out_feat, out_feat, 3, 2, 1, bn=True, act=False) + self.downsample = downsample(out_feat) + + self.final_act = nn.ReLU() + + def forward(self, x:torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv_1(x) + x = self.conv_2(x) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + return self.final_act(x) + +# if resnet depth >= 50 +class bottleneck(nn.Module): + pass + + +class RESNET18(nn.Module): + def __init__(self): + super().__init__() + self.stem = ConvBnAct(3, 64, 7, 2, 3) + self.stage1 = nn.Sequential( + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + basicResBlock(64, 64, False), + basicResBlock(64, 64, False) + ) + self.stage2 = nn.Sequential( + basicResBlock(64, 128, True), + basicResBlock(128, 128, False) + ) + self.stage3 = nn.Sequential( + basicResBlock(128, 256, True), + basicResBlock(256, 256, False), + ) + self.stage4 = nn.Sequential( + basicResBlock(256, 512, True), + basicResBlock(512, 512, False), + ) + self.head = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(512, 1000) + ) + + def forward(self, x:torch.Tensor) -> torch.Tensor: + x = self.stem(x) + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.stage4(x) + x = self.head(x) + return x + + +if __name__ == "__main__": + # random_input = torch.randn(1, 3, 224, 224) + # vanilla_conv = ConvBnAct(3, 6, 3, 1, 1, False, True) + # random_output = vanilla_conv(random_input) + # print(random_output.shape) + + # random_input = torch.randn(1, 64, 112, 112) + # first_resblock = basicResBlock(64, False) + # random_output = first_resblock(random_input) + # print(random_output.shape) + + # second_resblock = basicResBlock(64, True) + # random_output = second_resblock(random_output) + # print(random_output.shape) + random_input = torch.randn(1, 3, 224, 224) + + # test resnet18 + resnet_18 = RESNET18() + resnet_18_output = resnet_18(random_input) \ No newline at end of file diff --git a/models/vgg.py b/models/vgg.py new file mode 100644 index 0000000..872ee61 --- /dev/null +++ b/models/vgg.py @@ -0,0 +1,65 @@ +import torch +import torch.nn as nn + +# VGG 블록 정의 +class VGGBLock(nn.Module): + def __init__(self, num_convs, out_channels): + super().__init__() + layers = [] + for _ in range(num_convs): + layers.append( + nn.LazyConv2d(out_channels, kernel_size=3, stride=1, padding=1) + ) + layers.append( + nn.LazyBatchNorm2d() + ) + layers.append( + nn.ReLU() + ) + layers.append( + # kernel=2, stride=2 이므로 H, W 값이 각각 0.5*H, 0.5*W로 축소 + nn.MaxPool2d(kernel_size=2, stride=2) + ) + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + +# VGG 네트워크 정의 +class VGG(nn.Module): + # 논문에 제시된 기본 설계 + default_config = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512)) + + def __init__(self, cfg=None, num_classes=1000): + super().__init__() + # config파일이 따로 주어지지 않는다면 기본값을 사용 + cfg = self.default_config if cfg is None else cfg + conv_blks = [] + + # config의 내용에 따라 네트워크를 구성 + for (num_convs, out_channels) in cfg: + conv_blks.append(VGGBLock(num_convs, out_channels)) + + # Iterable unpacking + self.backbone = nn.Sequential( + *conv_blks + ) + + # 분류를 위한 헤드 부분 + self.head = nn.Sequential( + nn.Flatten(), + nn.LazyLinear(4096), nn.ReLU(), nn.Dropout(0.1), + nn.LazyLinear(4096), nn.ReLU(), nn.Dropout(0.1), + nn.LazyLinear(num_classes) + ) + + def forward(self, x): + feature = self.backbone(x) + preds = self.head(feature) + return preds + +if __name__ == "__main__": + net = VGG() + random_input = torch.randn(1, 3, 224, 224) + b = net(random_input) + print(b.shape) \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..5259f8e --- /dev/null +++ b/train.py @@ -0,0 +1,66 @@ +import os +import argparse +import shutil +import torch +import torch.nn as nn +from torch.optim import SGD, Adam +from tqdm import tqdm +from models import VGG +from utils import get_dataloader, get_current_datetime, AverageMeter, save_dict_as_json + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, default="vgg") +parser.add_argument("--imgsz", type=int, default=224) +parser.add_argument("--batch-size", type=int, default=32) +parser.add_argument("--num-workers", type=int, default=8) +parser.add_argument("--optimizer", type=str, default='sgd') +parser.add_argument("--lr", type=float, default=1e-3) +parser.add_argument("--epochs", type=int, default=300) +parser.add_argument("--save_interval", type=int, default=50) + +def prepare(opt): + SAVE_PATH = os.path.join("./runs", opt.model.upper()+"_"+get_current_datetime()) + if os.path.exists(SAVE_PATH): + shutil.rmtree(SAVE_PATH) + os.makedirs(SAVE_PATH) + else: + # os.path.exists(SAVE_PATH) == False + os.makedirs(SAVE_PATH) + TRAIN_CONFIG = os.path.join(SAVE_PATH, "config.json") + save_dict_as_json(TRAIN_CONFIG, vars(opt)) + + PTH_DIR = os.path.join(SAVE_PATH, "weights") + os.makedirs(PTH_DIR) + + return SAVE_PATH, PTH_DIR + +# optimizer 관련 코드 +def get_optimizer(name:str): + if name.lower() == 'sgd': + return SGD + elif name.lower() == 'adam': + return Adam + +if __name__ == "__main__": + if not os.path.exists("./runs"): + os.makedirs("./runs") + + opt = parser.parse_args() + # 모델 가중치 저장 디렉토리 + SAVE_PATH, PTH_DIR = prepare(opt) + + # 신경망 + net = eval(opt.model.upper())().cuda() + + # 손실함수 + criterion = nn.CrossEntropyLoss() + + # 최적화 + optimizer_type = get_optimizer(opt.optimizer) + optimizer = optimizer_type(net.parameters(), lr=opt.lr) + + # 데이터 파이프라인 + train_loader, val_loader = get_dataloader(opt.batch_size, opt.imgsz, opt.num_workers) + + for e in range(1, opt.epochs+1): + pass \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..b55f60a --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,2 @@ +from .datasets import get_dataloader +from .misc import get_current_datetime, AverageMeter, save_dict_as_json diff --git a/utils/datasets.py b/utils/datasets.py new file mode 100644 index 0000000..b403d93 --- /dev/null +++ b/utils/datasets.py @@ -0,0 +1,148 @@ +import os +import cv2 +import json +import torch +import pickle +import albumentations as A +import numpy as np + +from typing import Tuple +from torch.utils.data import Dataset, DataLoader +from albumentations.pytorch import ToTensorV2 +from functools import partial + +# 데이터셋의 절대경로 +IMAGENET_ROOT = "/workspace/data/ImageNet_2012_rename" + +def read_json(p:str) -> dict: + with open(p, 'r') as f: + data = f.read() + obj = json.loads(data) + return obj + +MY_JSON = partial(read_json, p=os.path.join(IMAGENET_ROOT, "idx2class.json")) + +def write_cache(p:str, data:list) -> None: + with open(p, 'wb') as file: + pickle.dump(data, file) + +def load_cache(p): + with open(p, 'rb') as f: + data = pickle.load(f) + return list(data) + +def read_img(p:str) -> np.ndarray: + return cv2.cvtColor(cv2.imread(p), cv2.COLOR_BGR2RGB) + + +class TrainDataset(Dataset): + idx2class = MY_JSON() + def __init__(self, root:str, imgsz:int=224): + super().__init__() + # default size + if imgsz is None: + imgsz = 224 + + self.root = os.path.join(root, 'train') + self.phase = 'train' + + # parsing + if not os.path.exists(os.path.join(root, self.phase+".pkl")): + self.data_list = self._parsing() + # write cache + write_cache(os.path.join(root, self.phase+".pkl"), self.data_list) + print(f"# write cahce file({self.phase})") + else: + print(f"# load cache file({self.phase})") + self.data_list = load_cache(os.path.join(root, self.phase+".pkl")) + + self.data_length = len(self.data_list) + + # data augmentation + self.transform = A.Compose([ + A.SmallestMaxSize(max_size=256), + A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5), + A.RandomCrop(height=imgsz, width=imgsz), + A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5), + A.RandomBrightnessContrast(p=0.5), + A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ToTensorV2() + ]) + + def _parsing(self) -> list: + data_list = [] + for dir in os.listdir(self.root): + # dir = "0000_safety pin" + class_index = int(dir.split('_')[0]) + for img_path in os.listdir(os.path.join(self.root, dir)): + abs_img_path = os.path.join(self.root, dir, img_path) + data_list.append((class_index, abs_img_path)) + return data_list + + + def __len__(self) -> int: + return self.data_length + + def __getitem__(self, idx) -> Tuple[torch.Tensor, int]: + class_index, abs_img_path = self.data_list[idx] + img_obj = read_img(abs_img_path) # img_obj:np.ndarray + img_tensor = self.transform(image=img_obj)["image"] + return img_tensor, class_index + +class ValidDataset(Dataset): + def __init__(self, root, imgsz:int=224): + super().__init__() + if imgsz is None: + imgsz = 224 + + self.root = os.path.join(root, 'val') + self.phase = 'val' + + # parsing + if not os.path.exists(os.path.join(root, self.phase+".pkl")): + self.data_list = self._parsing() + # write cache + write_cache(os.path.join(root, self.phase+".pkl"), self.data_list) + print(f"# write cahce file({self.phase})") + else: + print(f"# load cache file({self.phase})") + self.data_list = load_cache(os.path.join(root, self.phase+".pkl")) + + self.data_length = len(self.data_list) + + # data augmentation + self.transform = A.Compose([ + A.SmallestMaxSize(max_size=256), + A.CenterCrop(height=224, width=224), + A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + ToTensorV2() + ]) + + def _parsing(self) -> list: + data_list = [] + for img_path in os.listdir(self.root): + # img_path = "0000_020996.JPEG" + class_index = int(img_path.split("_")[0]) + abs_img_path = os.path.join(self.root, img_path) + data_list.append((class_index, abs_img_path)) + return data_list + + def __len__(self) -> int: + return self.data_length + + def __getitem__(self, idx) -> Tuple[torch.Tensor, int]: + class_index, abs_img_path = self.data_list[idx] + img_obj = read_img(abs_img_path) + img_tensor = self.transform(image=img_obj)["image"] + return img_tensor, class_index + + +def get_dataloader(batch_size:int=16, imgsz:int=224, num_workers:int=8) -> Tuple[DataLoader, DataLoader]: + _train_dataset = TrainDataset(IMAGENET_ROOT, imgsz) + _valid_dataset = ValidDataset(IMAGENET_ROOT, imgsz) + train_loader = DataLoader(_train_dataset, batch_size, True, num_workers=num_workers, persistent_workers=True) + val_loader = DataLoader(_valid_dataset, batch_size, True, num_workers=num_workers, persistent_workers=True) + return train_loader, val_loader + +if __name__ == "__main__": + _ = get_dataloader() \ No newline at end of file diff --git a/utils/logging.py b/utils/logging.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000..5641578 --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,52 @@ +import torch +import math +import json +import numpy as np +from typing import Union, Iterable +from datetime import datetime + +def save_dict_as_json(file_path, data): + with open(file_path, 'w') as file: + json.dump(data, file, indent=4) + +def get_current_datetime() -> str: + now = datetime.now() + return now.strftime("%Y%m%d%H%M")[2:] + +class AverageMeter: + """ + 평균값 및 표준편차를 계산하고 저장하는 클래스 + """ + def __init__(self) -> None: + self.reset() + + def reset(self) -> None: + """ + 초기화 메서드 + """ + self.count = 0 + self.sum = 0 + self.avg = 0 + self.var = 0 + self.std = 0 + + def update(self, val: Union[torch.Tensor, int, float, Iterable], n: int = 1) -> None: + """ + 값을 입력받아 평균값과 표준편차를 계산하고 저장하는 메서드 + + :param val: 입력값 + :param n: 입력값의 개수 + """ + if isinstance(val, torch.Tensor): + val = val.detach().cpu().item() + + if isinstance(val, Iterable): + val = np.mean(val) + + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + if self.count > 1: + self.var = ((self.count - 1) * self.var + (val - self.avg) ** 2) / self.count + self.std = math.sqrt(self.var) \ No newline at end of file