@@ -207,7 +207,9 @@ <h2 id="データセットクラスの用意">データセット・クラスの
207
207
208
208
< span class ="n "> n_labels</ span > < span class ="o "> =</ span > < span class ="n "> struct</ span > < span class ="p "> .</ span > < span class ="n "> unpack</ span > < span class ="p "> (</ span > < span class ="s "> '>i'</ span > < span class ="p "> ,</ span > < span class ="n "> fp</ span > < span class ="p "> .</ span > < span class ="n "> read</ span > < span class ="p "> (</ span > < span class ="mi "> 4</ span > < span class ="p "> ))[</ span > < span class ="mi "> 0</ span > < span class ="p "> ]</ span >
209
209
< span class ="n "> labels</ span > < span class ="o "> =</ span > < span class ="n "> struct</ span > < span class ="p "> .</ span > < span class ="n "> unpack</ span > < span class ="p "> (</ span > < span class ="s "> '>'</ span > < span class ="o "> +</ span > < span class ="s "> 'B'</ span > < span class ="o "> *</ span > < span class ="n "> n_labels</ span > < span class ="p "> ,</ span > < span class ="n "> fp</ span > < span class ="p "> .</ span > < span class ="n "> read</ span > < span class ="p "> (</ span > < span class ="n "> n_labels</ span > < span class ="p "> ))</ span >
210
- < span class ="n "> labels</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="p "> .</ span > < span class ="n "> asarray</ span > < span class ="p "> (</ span > < span class ="n "> labels</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="s "> 'uint8'</ span > < span class ="p "> )</ span >
210
+
211
+ < span class ="c1 "> # 誤差関数用にlongで表しておく
212
+ </ span > < span class ="n "> labels</ span > < span class ="o "> =</ span > < span class ="n "> np</ span > < span class ="p "> .</ span > < span class ="n "> asarray</ span > < span class ="p "> (</ span > < span class ="n "> labels</ span > < span class ="p "> ,</ span > < span class ="n "> dtype</ span > < span class ="o "> =</ span > < span class ="s "> 'int64'</ span > < span class ="p "> )</ span >
211
213
212
214
< span class ="k "> return</ span > < span class ="n "> labels</ span >
213
215
</ code > </ pre > </ div > </ div >
@@ -236,7 +238,7 @@ <h2 id="モジュールクラスの用意">モジュール・クラスの用意<
236
238
< span class ="nb "> super</ span > < span class ="p "> (</ span > < span class ="n "> Net</ span > < span class ="p "> ,</ span > < span class ="bp "> self</ span > < span class ="p "> ).</ span > < span class ="n "> __init__</ span > < span class ="p "> ()</ span >
237
239
238
240
< span class ="bp "> self</ span > < span class ="p "> .</ span > < span class ="n "> net</ span > < span class ="o "> =</ span > < span class ="n "> nn</ span > < span class ="p "> .</ span > < span class ="n "> Sequential</ span > < span class ="p "> (</ span >
239
- < span class ="n "> nn</ span > < span class ="p "> .</ span > < span class ="n "> Conv2d</ span > < span class ="p "> (</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="mi "> 6</ span > < span class ="p "> ,</ span > < span class ="n "> kernel_size</ span > < span class ="o "> =</ span > < span class ="mi "> 5</ span > < span class ="p "> ,</ span > < span class ="n "> stride</ span > < span class ="o "> =</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="n "> padding</ span > < span class ="o "> =</ span > < span class ="mi "> 2 </ span > < span class ="p "> ),</ span >
241
+ < span class ="n "> nn</ span > < span class ="p "> .</ span > < span class ="n "> Conv2d</ span > < span class ="p "> (</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="mi "> 6</ span > < span class ="p "> ,</ span > < span class ="n "> kernel_size</ span > < span class ="o "> =</ span > < span class ="mi "> 5</ span > < span class ="p "> ,</ span > < span class ="n "> stride</ span > < span class ="o "> =</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="n "> padding</ span > < span class ="o "> =</ span > < span class ="mi "> 0 </ span > < span class ="p "> ),</ span >
240
242
< span class ="n "> nn</ span > < span class ="p "> .</ span > < span class ="n "> MaxPool2d</ span > < span class ="p "> (</ span > < span class ="n "> kernel_size</ span > < span class ="o "> =</ span > < span class ="mi "> 2</ span > < span class ="p "> ,</ span > < span class ="n "> stride</ span > < span class ="o "> =</ span > < span class ="mi "> 2</ span > < span class ="p "> ),</ span >
241
243
< span class ="n "> nn</ span > < span class ="p "> .</ span > < span class ="n "> Sigmoid</ span > < span class ="p "> (),</ span >
242
244
< span class ="n "> nn</ span > < span class ="p "> .</ span > < span class ="n "> Conv2d</ span > < span class ="p "> (</ span > < span class ="mi "> 6</ span > < span class ="p "> ,</ span > < span class ="mi "> 16</ span > < span class ="p "> ,</ span > < span class ="n "> kernel_size</ span > < span class ="o "> =</ span > < span class ="mi "> 5</ span > < span class ="p "> ,</ span > < span class ="n "> stride</ span > < span class ="o "> =</ span > < span class ="mi "> 1</ span > < span class ="p "> ,</ span > < span class ="n "> padding</ span > < span class ="o "> =</ span > < span class ="mi "> 0</ span > < span class ="p "> ),</ span >
@@ -301,6 +303,9 @@ <h2 id="学習ループ">学習ループ</h2>
301
303
</ span > < span class ="n "> images</ span > < span class ="o "> =</ span > < span class ="n "> data</ span > < span class ="p "> [</ span > < span class ="s "> 'images'</ span > < span class ="p "> ]</ span >
302
304
< span class ="n "> labels</ span > < span class ="o "> =</ span > < span class ="n "> data</ span > < span class ="p "> [</ span > < span class ="s "> 'labels'</ span > < span class ="p "> ]</ span >
303
305
306
+ < span class ="c1 "> # トレーニングモードに変更
307
+ </ span > < span class ="n "> net</ span > < span class ="p "> .</ span > < span class ="n "> train</ span > < span class ="p "> ()</ span >
308
+
304
309
< span class ="c1 "> # 勾配の初期化
305
310
</ span > < span class ="n "> net</ span > < span class ="p "> .</ span > < span class ="n "> zero_grad</ span > < span class ="p "> ()</ span >
306
311
@@ -317,7 +322,7 @@ <h2 id="学習ループ">学習ループ</h2>
317
322
</ span > < span class ="n "> optim</ span > < span class ="p "> .</ span > < span class ="n "> step</ span > < span class ="p "> ()</ span >
318
323
</ code > </ pre > </ div > </ div >
319
324
320
- < p > より複雑なネットワークになればネットワークへのデータ転送や誤差の評価は複雑にはなるが、基本的な流れはほとんど変わらない。なお上記のコードに現れる < code class ="highlighter-rouge "> criterion</ code > は誤差を評価する損失関数で対数softmax関数を最終出力に用いた場合には < code class ="highlighter-rouge "> nn.NNLLoss </ code > (非負対数尤度, Non-Negative Likelihood)を用いる。</ p >
325
+ < p > より複雑なネットワークになればネットワークへのデータ転送や誤差の評価は複雑にはなるが、基本的な流れはほとんど変わらない。なお上記のコードに現れる < code class ="highlighter-rouge "> criterion</ code > は誤差を評価する損失関数で、ネットワークの最終出力に < code class =" highlighter-rouge " > log_softmax </ code > を用いた場合には < code class ="highlighter-rouge "> nn.NLLLoss </ code > (負の対数尤度, Negative Log-Likelihood)を用いる (最終出力に < code class =" highlighter-rouge " > softmax </ code > 系の関数を適用せず、生のスコア(logits)をそのまま出力する場合には、内部で log_softmax を計算する < code class =" highlighter-rouge " > nn.CrossEntropyLoss </ code > を使う) 。</ p >
321
326
322
327
< div class ="language-python highlighter-rouge "> < div class ="highlight "> < pre class ="highlight "> < code > < span class ="n "> criterion</ span > < span class ="o "> =</ span > < span class ="n "> nn</ span > < span class ="p "> .</ span > < span class ="n "> NLLLoss</ span > < span class ="p "> ()</ span >
323
328
</ code > </ pre > </ div > </ div >
@@ -330,7 +335,6 @@ <h2 id="学習の結果とネットワークの改良">学習の結果とネッ
330
335
331
336
< p > これ以外にも、様々な学習のテクニックがあるが、それらについては、ネット上にも多くの記事や実装があるので、各自調べてみてほしい。</ p >
332
337
333
-
334
338
</ div >
335
339
336
340
0 commit comments