xgboostのコードリーディング(その3)

前回(xgboostのコードリーディング(その2) - threecourse’s blog)の続きで、一旦これで完結のつもりです。 前回同様、あくまで私の理解であり、正確性の保証は無いのでご注意下さい。

今回は、もはやxgboostではないのですが、lightgbmのEFB(Exclusive Feature Bundling)という機能でなぜ速くなるのかを説明します。 コードリーディングの発端の発端はここで、この機能でなぜ速くなる理由がわからず気持ち悪かったことにあります。 Kaggle本の執筆でも、途中まではEFBを説明に加えていたのですが、理由を説明できないにもかかわらず記載するのは良くないと思い、削りました。

EFB(Exclusive Feature Bundling)の概要

EFBとは、排他な特徴量をまとめて計算することにより計算量を削減する方法。
排他な特徴量とは、互いに(ほぼ)同時に0以外の値をとらない特徴量で、まとめ方も含めて以下の図のようなイメージとなる。

f:id:threecourse:20191031140356j:plain

ここで、例えばone-hot encodingによる特徴量があれば、それをlabel encodingのように一つの特徴量に戻して分岐の計算を行う・・・と思ってしまうがそうではなく、 あくまでhistogram-basedのヒストグラムの作成でEFBによりバンドルされた特徴量を利用することで計算量を削減し、分岐の計算においては、バンドルされたものでなく元の特徴量を使って計算する。

(参考)
LightGBM: A Highly Efficient Gradient Boosting Decision Tree
NIPS2017読み会 LightGBM: A Highly Efficient Gradient Boosting Decision T…
What makes LightGBM lightning fast? - Abhishek Sharma - Medium

EFBによる計算量の削減

lightgbmは、前回説明したhistogram-basedアルゴリズムにより計算が行われる。ここでは、分岐の作成の部分のみ説明する。
それぞれの葉における分岐の作成では、以下の2つの処理を行う。

  1. ヒストグラムを作成する
  2. 最も良く分割できる分岐の計算を行う

ここで、以下のような流れとなる。

  1. ヒストグラムを作成するときに、
    EFBによって計算量をO(データ数 x 特徴量の数)からO(データ数 x バンドルの数)に落とす
  2. 最も良く分割できる分岐の計算を行うときに、
    バンドルにした特徴量を元に戻して計算する。この計算量はO(ビンの数 x 特徴量の数)である

ヒストグラムを作成する部分は計算量の支配的な部分であり、その部分の計算量を落とすことができる。

コード(メモ)

lightgbmの計算の流れは以下のとおり:

1. python-package/lightgbm/engine.pyのtrainメソッド

学習において、決定木の作成ごとに、BoosterクラスのUpdateメソッドを通して、c++のメソッド_LIB.LGBM_BoosterUpdateOneIterを呼び出す

2. src/c_api.cppのLGBM_BoosterUpdateOneIterメソッド

GBDTクラス(=Boosterクラスの派生クラス)のTrainOneIterを呼び出す。

3. src/boosting/gbdt.cppのGBDTクラスのTrainOneIterメソッド
  • (各データの勾配を求める処理は省略)
  • SerialTreeLearnerクラス(=TreeLearnerクラスの派生クラス)のTrainメソッドを呼び出し、決定木を作成する
4. src/treelearner/serial_tree_learner.cppのSerialTreeLearnerクラスのTrainメソッド

作成する葉の数だけ、SerialTreeLearnerクラスのFindBestSplitsメソッドを繰り返す。
FindBestSplitsメソッドでは、以下の処理を行う。

  1. SerialTreeLearner::ConstructHistogramsメソッドにより、ヒストグラムを作成する
  2. SerialTreeLearner::FindBestSplitsFromHistogramsメソッドにより、最も良い分岐を作成する
5. src/treelearner/serial_tree_learner.cppのSerialTreeLearnerクラスのConstructHistogramsメソッド

DatasetクラスのConstructHistogramsメソッドを呼び出す。 DatasetクラスのConstructHistogramsメソッドでは、特徴量のバンドルごとに、DenseBinクラスのConstructHistogramメソッドを呼び出してヒストグラムのアップデートを行う

6. src/treelearner/serial_tree_learner.cppのSerialTreeLearnerクラスのFindBestSplitsFromHistogramsメソッド

特徴量のバンドルからそれぞれの特徴量に戻し、特徴量ごとにFeatureHistogramクラスのFindBestThresholdメソッドを実行する