実務でRandomForestを使ったときに聞かれたこと
Machine Learning Advent Calendar 2012 の 21 日目の記事です。
私は普段は受託のデータ解析を仕事にしてます。過去に何度か実務でRandomForestを利用する機会がありましたので今日は以前顧客にプレゼンをした時に、質問された内容とその回答を紹介したいと思います。普段は機械学習・データマイニングを実務の立場利用しており、手法そのものの専門家ではないので、間違いなどが有りましたらご指摘ください。
さてRandomForestは有名なアルゴリズムですので、ご存じの方も多いとは思いますが、CARTの開発者でもある、Leo Breimanが2001年に提案した決定木を用いた集団学習アルゴリズムの1つです。一言で言えば、大量の決定木を作成して、それぞれの決定木が出した答えを多数決し、最も支持の多かったクラスに分類する手法です。(回帰の場合は平均を返します)
RandmoForestをご存じない方はid:hamadakoichiさんの資料が非常にわかりやすいです。
さて本題。過去によくあった質問は以下の内容です。
- なぜRandomForestは精度が高くなるのか?
- バギングとの違いは何か?
- パラメータチューニングはどうすればよいか?
上2つは同じ質問とも取れますが、知ってる範囲で回答していきたいと思います。
なぜRandomForestは精度が高くなるのか?
RandomForestに限らず決定木のような弱学習器は集団学習に向いているとよく言われます。これを理解するためにはバイアスーバリアンスの観点から説明ができます。
バイアス-バリアンス - 機械学習の「朱鷺の杜Wiki」
バイアスーバリアンス理論によると汎化誤差は次のように分解されます。
汎化誤差=バイアス+バリアンス+ノイズ
ここでバイアスはモデルの表現力に由来する誤差、バリアンスはデータセットの選び方に由来する誤差、ノイズは本質的に減らせない誤差です。
決定木はアルゴリズムの性質上、モデルが学習データからうける影響が大きくバリアンスが高い学習モデルになります。RandomForestやbaggingなどの集団学習アルゴリズムはこのバリアンスを低減させることで精度を向上を図ります。
ちなみに高精度で有名なアルゴリズムであるSVMであまり集団学習の話を聞かないのは、SVMが低バリアンスのモデルだからです。
baggingとの違いは何か?
baggingは集団学習アルゴリズムの一種で、ブートストラップサンプリングで抽出したデータセットを多数作成し、各々のデータセットに対して学習した識別器の多数決でクラスを分類する方法です。RandomForestとbaggingの違いはRandomForestが特徴量のサンプリングも行なっている点です。
確率変数同士が相関を保つ場合、平均の分散は以下の式で表現されます。
ここで、は生成した決定木の数、は分散、は変数間の相関です。baggingで抽出した決定木はデータによってはブートストラップサンプリングで作成した各々の決定木同士の相関が高いことがあります。これに対して使用する特徴量が違う木をたくさん生成しているRandomForestは決定木間の相関が低い為、上の式の第二項が小さくなりbaggingよりもバリアンスが下がり、baggingよりRandomForestの方が生成する多様性が高くなります。
パラメータチューニングはどうすればよいか?
RandomForestの主要なパラメータは次の2つです。
- 作成する決定木の数
- 1つ1つの決定木を作成する際に使用する特徴量の数
作成する決定木の数を決定する方法は簡単です。予測に用いる木の数を増やしていき結果が安定する数を利用すればよいだけです。上述の様にRandomForestでは決定木間の相関を低下させるために、決定木を作成するときに使用する特徴量もサンプリングします。この時いくつの特徴量を使用するかはパラメータとなっていて、決定木の場合は特徴量の数がNの時√Nが推奨値となっています。
しかしながら、実際のところは最適な特徴量数はデータ依存です。特徴量が多い場合や、意味のある特徴量が全体の中で少ない場合は推奨値よりも大きめの値を設定したほうが良い結果が得られる傾向がありますので、グリッドサーチで決定することをお勧めします。
実装
Machine Learning Advent Calendar のコメント欄に「なにか実装します」と書いてしまったことを後悔しつつRandomForestのコア部分をPythonで実装してみました。Out-Of-Bugは間に合わなかったので無しです。1つ1つの木はscikit-learnというライブラリを用いて計算しています。(ちなみにscikit-learnにはRandomForestも実装されています)
#!/usr/bin/env python # -*- coding: utf-8 -*- ''' Created on 2012/12/21 @author: shakezo_ ''' from sklearn.datasets import load_iris from sklearn import tree from sklearn import cross_validation from sklearn.cross_validation import train_test_split import numpy as np def feature_sampling(data,feature_num,mtry): partial_data = [] arr = np.arange(feature_num) np.random.shuffle(arr) for d in data: partial_data.append(d[arr[0:mtry]]) return [partial_data,arr[0:mtry]] def predict(clf_list,data): #多数決によるモデルの決定 predict_dic ={} for clf in clf_list: input = data[clf[1][1]] model = clf[0] pid =int(model.predict(input)[0]) predict_dic[pid] = predict_dic.get(pid,0) + 1 #多数決でクラスを決定 max_count = 0 max_id =-1 for k,v in predict_dic.iteritems(): if v>max_count: max_count = v max_id = k return max_id if __name__ == '__main__': target_names = {} #irisデータセットを取得 iris = load_iris() #ターゲットを取得 for i,name in enumerate(iris.target_names): target_names[i] = name #データ分割 x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.33, random_state=42) #parameter tree_num = 500; train_num = int(len(x_train)*(2.0/3)) test_num = len(x_train)-train_num feature_num = len(x_train[0]) mtry = 2 #ブートストラップサンプリング data_list = [] target_list = [] input_data_list = [] clf_list = [] bs = cross_validation.Bootstrap(len(x_train),n_bootstraps=tree_num,train_size=train_num,test_size=test_num, random_state=0) #ランダムフォレストの実行 #使用する特徴量とデータをサンプリングして決定木を構築 for train_index, test_index in bs: data = x_train[train_index] target = y_train[train_index] data_list.append(data) target_list.append(target) #特徴量の選択 input_data = feature_sampling(data,feature_num,mtry) input_data_list.append(input_data) #決定木の作成 clf = tree.DecisionTreeClassifier() clf = clf.fit(input_data[0], target) #作成した木とデータを追加 clf_list.append([clf,input_data]) #データの予測 predict_id_list = [] #test_data_list = iris.data correct_num = 0 for i,data in enumerate(x_test): pid=predict(clf_list,data) predict_id_list.append(pid) if pid == y_test[i]: correct_num += 1 #Accuracy print "Accuracy = " ,correct_num/float(len(x_test))
結果
Accuracy = 0.98
それでは。