001package org.opengion.penguin.math.statistics;
002
003import java.util.Arrays;
004import java.util.Collections;
005import java.util.List;
006
007/**
008 * 多項ロジスティック回帰の実装です。
009 * 確率的勾配降下法(SGD)を利用します。
010 * 
011 * ロジスティック回帰はn次元の情報からどのグループに所属するかの予測値を得るための手法の一つです。
012 * 
013 * 実装は
014 * http://nbviewer.jupyter.org/gist/mitmul/9283713
015 * https://yusugomori.com/projects/deep-learning/
016 * を参考にしています。
017 */
018public class HybsLogisticRegression {
019        private final int n_N;          // データ個数
020        private final int n_in;         // データ次元
021        private final int n_out;        // ラベル種別数
022
023        // 写像変数ベクトル f(x) = Wx + b
024        private double[][] vW;
025        private double[] vb;
026
027        /**
028         * コンストラクタ。
029         * 
030         * 学習もしてしまう。
031         * 
032         * xはデータセット各行がn次元の説明変数となっている。
033         * trainはそれに対する{0,1,0},{1,0,0}のようなラベルを示すベクトルとなる。
034         * 学習率は通常、0.1程度を設定する。
035         * このロジックではループ毎に0.95をかけて徐々に学習率が下がるようにしている。
036         * 全データを利用すると時間がかかる場合があるので、確率的勾配降下法を利用しているが、
037         * 選択個数はデータに対する割合を与える。
038         * データ個数が少ない場合は1をセットすればよい。
039         * 
040         * @param data データセット配列
041         * @param label データに対応したラベルを示す配列
042         * @param learning_rate 学習係数(0から1の間の数値)
043         * @param loop 学習のループ回数(ミニバッチを作る回数)
044         * @param minibatch_rate 全体に対するミニバッチの割合(0から1の間の数値)
045         * 
046         */
047        public HybsLogisticRegression(final double data[][], final int label[][], final double learning_rate ,final int loop, final double minibatch_rate ) {
048        //      List<Integer> indexList; //シャッフル用
049
050                this.n_N = data.length;
051                this.n_in = data[0].length;
052                this.n_out = label[0].length; // ラベル種別
053
054                vW = new double[n_out][n_in];
055                vb = new double[n_out];
056
057                // 確率勾配に利用するための配列インデックス配列
058                final Integer[] random_index = new Integer[n_N]; //プリミティブ型だとasListできないため
059                for( int i=0; i<n_N; i++) {
060                        random_index[i] = i; 
061                }
062                final List<Integer> indexList = Arrays.asList( random_index );
063
064                double localRate = learning_rate;
065                for(int epoch=0; epoch<loop; epoch++) {
066                        Collections.shuffle( indexList );
067        //              random_index = indexList.toArray(new Integer[indexList.size()]);
068
069                        //random_indexの先頭からn_N*minibatch_rate個のものを対象に学習をかける(ミニバッチ)
070                        for(int i=0; i< n_N * minibatch_rate; i++) {
071        //                      final int idx = random_index[i];
072                                final int idx = indexList.get(i);
073                                train(data[idx], label[idx], localRate);
074                        }
075                    localRate *= 0.95; //徐々に学習率を下げて振動を抑える。
076                }
077        }
078
079        /**
080         * データを与えて学習をさせます。
081         * パラメータの1行を与えています。
082         * 
083         * 0/1のロジスティック回帰の場合は
084         * ラベルc(0or1)が各xに対して与えられている時
085         * s(x)=σ(Wx+b)=1/(1+ exp(-Wx-b))として、
086         * 確率の対数和L(W,b)の符号反転させたものの偏導関数
087         * ∂L/∂w=-∑x(c-s(x))
088         * ∂L/∂b=-∑=(c-s(x))
089         * が最小になるようなW,bの値をパラメータを変えながら求める。
090         * というのが実装になる。(=0を求められないため)
091         * 多次元の場合はシグモイド関数σ(x)の代わりにソフトマックス関数π(x)を利用して
092         * 拡張したものとなる。(以下はソフトマックス関数利用)
093         * 
094         * @param in_x 1行分のデータ
095         * @param in_y xに対するラベル
096         * @param lr 学習率
097         * @return 差分配列
098         */
099        private double[] train( final double[] in_x, final int[] in_y, final double lr ) {
100                final double[] p_y_given_x = new double[n_out];
101                final double[] dy          = new double[n_out];
102
103                for(int i=0; i<n_out; i++) {
104                        p_y_given_x[i] = 0;
105                        for(int j=0; j<n_in; j++) {
106                                p_y_given_x[i] += vW[i][j] * in_x[j];
107                        }
108                        p_y_given_x[i] += vb[i];
109                }
110                softmax( p_y_given_x );
111
112                // 勾配の平均で更新?
113                for(int i=0; i<n_out; i++) {
114                        dy[i] = in_y[i] - p_y_given_x[i]; 
115
116                        for(int j=0; j<n_in; j++) {
117                                vW[i][j] += lr * dy[i] * in_x[j] / n_N;
118                        }
119
120                        vb[i] += lr * dy[i] / n_N;
121                }
122
123                return dy;
124        }
125
126        /**
127         * ソフトマックス関数。
128         * π(xi) = exp(xi)/Σexp(x)
129         * @param in_x 変数X
130         */
131        private void softmax( final double[] in_x ) {
132                // double max = 0.0;
133                double sum = 0.0;
134
135                // for(int i=0; i<n_out; i++) {
136                //      if(max < x[i]) {
137                //              max = x[i];
138                //      }
139                // }
140
141                for(int i=0; i<n_out; i++) {
142                        //x[i] = Math.exp(x[i] - max); // maxとの差分を取ると利点があるのか分からなかった
143                        in_x[i] = Math.exp(in_x[i]);
144                        sum += in_x[i];
145                }
146
147                for(int i=0; i<n_out; i++) {
148                        in_x[i] /= sum;
149                }
150        }
151
152        /**
153         * 写像式 Wx+b のW、係数ベクトル。
154         * @return 係数ベクトル
155         */
156        public double[][] getW() {
157                return vW;
158        }
159
160        /**
161         * 写像式 Wx + bのb、バイアス。
162         * @return バイアスベクトル
163         */
164        public double[] getB() {
165                return vb;
166        }
167
168        /**
169         * 出来た予測式に対して、データを入力してyを出力する。
170         * (yは各ラベルに対する確率分布となる)
171         * @param in_x 予測したいデータ
172         * @return 予測結果
173         */
174        public double[] predict(final double[] in_x) {
175                final double[] out_y = new double[n_out];
176
177                for(int i=0; i<n_out; i++) {
178                        out_y[i] = 0.;
179                        for(int j=0; j<n_in; j++) {
180                                out_y[i] += vW[i][j] * in_x[j];
181                        }
182                        out_y[i] += vb[i];
183                }
184
185                softmax(out_y);
186
187                return out_y;
188        }
189
190        //************** ここまでが本体 **************
191        /**
192         * ここからテスト用mainメソッド 。
193         *
194         * @param args 引数
195         */
196        public static void main( final String[] args ) {
197                // 3つの分類で分ける
198                final double[][] train_X = {
199                                {-2.0, 2.0}
200                                ,{-2.1, 1.9}
201                                ,{-1.8, 2.1}
202                                ,{0.0, 0.0}
203                                ,{0.2, -0.2}
204                                ,{-0.1, 0.1}
205                                ,{2.0, -2.0}
206                                ,{2.2, -2.1}
207                                ,{1.9, -2.0}
208                };
209
210                final int[][] train_Y = {
211                                {1, 0, 0}
212                                ,{1, 0, 0}
213                                ,{1, 0, 0}
214                                ,{0, 1, 0}
215                                ,{0, 1, 0}
216                                ,{0, 1, 0}
217                                ,{0, 0, 1}
218                                ,{0, 0, 1}
219                                ,{0, 0, 1}
220                };
221
222                 // test data
223                final double[][] test_X = {
224                                {-2.5, 2.0}
225                                ,{0.1, -0.1}
226                                ,{1.5,-2.5}
227                };
228
229                final double[][] test_Y = new double[test_X.length][train_Y[0].length];
230
231                final HybsLogisticRegression hlr = new HybsLogisticRegression( train_X, train_Y, 0.1, 500, 1 );
232
233                // テスト
234                // このデータでは2番目の条件には入りにくい?
235                for(int i=0; i<test_X.length; i++) {
236                         test_Y[i] = hlr.predict(test_X[i]);
237                         System.out.print( Arrays.toString(test_Y[i]) );
238                }
239        }
240}
241