ご無沙汰しております。システム開発部の potter です。
Jubatus のバースト検知を使ってみた話以来となるので久々の投稿となります。
最近運用チームから開発チームに異動したのですが、Scala を使うプロジェクトがありそうということでその勉強を始めました。
今回はそんな Scala 歴 2 週間足らずの私が人生で初めて書いた Scala プログラムを晒してしまおうという、ちょっぴり M な香りのする企画となります。 (ただし晒しているのは Scala の先輩であるシステム開発部の uraura 先生にいくつかアドバイスを頂いて修正したプログラムになります。Special tanks to uraura!! )
作成したプログラムの内容
単純パーセプトロンによる識別精度を n-fold cross validation により評価するプログラムを作成しました。1
使用する入力データは下記のような3カラムのファイルとなります。
1 2 3 4 5 6 7 |
0.11,0.42,1 0.13,0.43,1 0.13,0.64,1 0.14,0.15,-1 0.14,0.43,1 0.16,0.24,-1 ・・・ |
- 1列目:特徴ベクトルの1次元目
- 2列目:特徴ベクトルの2次元目
- 3列目:正解ラベル
入力データの特徴ベクトルをグラフにプロットすると下記のようになります。
データ数は100です。特徴ベクトルの 1 次元目を横軸に、特徴ベクトルの 2 次元目を縦軸にとり、正解ラベルが -1 のデータは青い点で、正解ラベルが 1 のデータは赤い点で示しています。
今回は 10-fold cross validation としましたので、これらのデータを10グループに分割します。
その中の9グループ(すなわち90個のデータ)を学習データとして用い、データを分類する識別直線を学習します。
その後残りの1グループ(すなわち10個のデータ)を評価データとして用い、その識別直線で正しく識別できるかを評価します。
学習データ・評価データの選び方は10通りあるので、この学習・評価を10回繰り返して総合的な精度を算出します。
ソースコード
今回は sbt プロジェクトとしてプログラムを作成しました。ソースのフォルダ階層は以下の通りです。2
各プログラムのソースは以下の通りです。
build.sbt
1 2 3 4 5 6 7 8 9 10 |
name := "perceptron_scala" version := "1.0" scalaVersion := "2.11.5" libraryDependencies ++= Seq( "com.github.scala-incubator.io" %% "scala-io-core" % "0.4.3", "com.github.scala-incubator.io" %% "scala-io-file" % "0.4.3" ) |
Main.scala
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import scala.io.Source import scala.annotation.tailrec object Main { /** 学習係数 */ val LEARN_COEFFICIENT = 0.01f /** 学習の繰返回数上限値 */ val LOOP_MAX = 1000 /** クロスバリデーションの分割数 */ val CROSS_VALIDATION_SPLIT_COUNT = 10 /** 特徴ベクトルの次元数 */ val FEATURE_VECTOR_DIMENSION = 2 /** 重みベクトルの初期値 */ val INITIAL_WEIGHT_VALUE = 0.0f /** * クロスバリデーションに用いるデータ * @param featureVector 特徴ベクトル * @param label 正解ラベル * @param group クロスバリデーションのグループ */ case class Data(featureVector: List[Float], label: Int, group: Int) /** * 1回分のクロスバリデーション結果 * @param evaluationGroup 評価データのグループ名 * @param evaluationCount 評価データ数 * @param correctCount 評価において識別が成功したデータ数 * @param accuracyRate 精度 * @param weight 学習結果の重み係数 */ case class ValidationResult(evaluationGroup: String, evaluationCount: Int, correctCount: Int, accuracyRate: Float, weight: Seq[Float]) /** * 同値の値で構成される指定したサイズのリストを作成する * @param listSize リストの要素数 * @param elementValue 要素の値 * @tparam A リスト要素の型 * @return */ @tailrec def createSameValueList[A](listSize: Int, elementValue: A, output : List[A] = Nil): List[A] ={ if(listSize <= 0) output else createSameValueList(listSize - 1, elementValue, elementValue :: output) } def main(args : Array[String]): Unit = { //データ読込・クロスバリデーション用のグループラベル付与 val data=Source.fromFile("src/main/resources/data.csv").getLines.zipWithIndex.toList.map{case (line, index) => val items = line split "," val featureVector = 1.0f :: //先頭の1.0fはバイアス項の要素 (0 to FEATURE_VECTOR_DIMENSION - 1).map{itemIndex => items(itemIndex).toFloat}.toList val label = items(FEATURE_VECTOR_DIMENSION).toInt val group = index % CROSS_VALIDATION_SPLIT_COUNT + 1 //簡単のため今回はランダムには分割しない Data(featureVector, label, group) } //クロスバリデーション val result : Seq[ValidationResult] = (1 to CROSS_VALIDATION_SPLIT_COUNT).map{group => //学習データと評価データを取得 val trainingData = data.filter(_.group != group). map{item => Perceptron.TrainingData(item.featureVector, item.label)} val evaluationData = data.filter(_.group == group). map{item => (item.featureVector, item.label)} //学習 val initialWeight = createSameValueList(FEATURE_VECTOR_DIMENSION + 1, INITIAL_WEIGHT_VALUE) val trainedWeight = Perceptron.train(initialWeight, trainingData, LEARN_COEFFICIENT, LOOP_MAX) //評価 val evaluationCount = evaluationData.size val correctCount = evaluationData.count{case(featureVector,label) => Perceptron.predict(trainedWeight, featureVector) == label } val accuracyRate = correctCount.toFloat / evaluationCount.toFloat //結果 ValidationResult(group.toString, evaluationCount, correctCount, accuracyRate, trainedWeight) } //total での精度を計算 val (evaluationCounts, correctCounts) = result.map{r => (r.evaluationCount, r.correctCount)}.unzip val totalEvaluationCount = evaluationCounts.foldLeft(0)(_ + _) val totalCorrectCount = correctCounts.foldLeft(0)(_ + _) val totalAccuracyRate = totalCorrectCount.toFloat / totalEvaluationCount.toFloat //結果出力 println("(evaluationGroup,accuracyRate,weight)") result.foreach{r => println(r.evaluationGroup + "," + r.accuracyRate + ",[" + r.weight.mkString(",") + "]") } println ("------------------------------------------") println ("total accuracy rate : " + totalAccuracyRate) } } |
Perceptron.scala
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import scala.util.control.Breaks object Perceptron { /** * 学習データ * @param featureVector 特徴ベクトル * @param label 正解ラベル */ case class TrainingData(featureVector: Seq[Float], label: Int) /** * 未知データのラベルを予測する * @param w 重みベクトル * @param x 未知データの特徴ベクトル * @return ラベル (1 or -1) */ def predict(w : Seq[Float], x : Seq[Float]) = if(w.zip(x).map(item => item._1 * item._2).sum >= 0.0f) 1 else -1; /** * 学習データの集合を入力して学習後の重みベクトルを取得する * @param initialWeight 学習前の重み係数 * @param data 学習データ集合(特徴ベクトルと正解ラベルのリスト) * @param coefficient 学習係数 * @param loopMax 学習繰返回数上限 * @return 学習後の重み係数 */ def train(initialWeight : Seq[Float], data : Seq[TrainingData], coefficient : Float, loopMax : Int) : Seq[Float] = { var loop = 0; var weight = initialWeight val breaks = new Breaks() import breaks.{break, breakable} breakable { while (loop < loopMax) { loop += 1; var converged = true data.foreach{d => if (predict(weight, d.featureVector) * d.label < 0) { //予測結果と正解ラベルが異なる場合は重み係数の更新 weight = weight.zip(d.featureVector).map{case (w,x) => (w + coefficient * d.label * x) } converged = false } } if (converged) { break } } } weight } } |
実行結果の確認
プログラムを実行すると、標準出力に下記の通り出力されます。
1 2 3 4 5 6 7 8 9 10 11 12 13 |
(evaluationGroup,accuracyRate,weight) 1,0.9,[-0.02,-0.06260001,0.10020004] 2,1.0,[-0.01,-0.029599985,0.0496] 3,1.0,[-0.02,-0.06849995,0.11019996] 4,1.0,[-0.01,-0.028699988,0.04909999] 5,0.8,[-0.01,-0.018399999,0.039199997] 6,1.0,[-0.01,-0.028899983,0.048899986] 7,1.0,[-0.01,-0.031399984,0.051899962] 8,1.0,[-0.01,-0.029699987,0.049699992] 9,1.0,[-0.01,-0.031599984,0.051999986] 10,1.0,[-0.02,-0.06879996,0.1104999] ------------------------------------------ TotalAccuracyRate : 0.97 |
妥当な結果が得られているのか確認するため、 evaluationGroup = 1 の結果について見てみましょう。
まず、 evaluationGroup = 1 の結果では下図の 90 個の学習データによる学習によって(-0.02,-0.06260001,0.10020004)という重みベクトルが得られています。
この重みベクトルは、下図の緑色の識別直線でデータを識別することを意味します。
識別直線付近の際どいデータもありますが、学習データについては誤分類が発生しないような識別直線が得られ、学習が収束していることがわかります。
次に、上記の学習で得られた識別直線と評価データをプロットしたのが下図になります。
ご覧の通り、赤い点が一つだけ識別直線の下に配置されていることがわかります。
このデータを誤識別してしまったために accuracyRate=0.9 となっていることがわかります。
Scala 力をアップするために
今回、構文すら良くわかっていない状態からプログラムを書き始めたのですが、 やはり最初はどのように記載するのが Scala らしいプログラムなのかといったことも良くわからず、試行錯誤でした。
そのような中で、私と同じような Scala 初学者がどういったことに気を付ければ Scala 力をアップして行けるのか、 私なりに感じたことを記載させて頂きます。
map や filter などの予め用意されている高階関数の使い方を覚えよう
Scala はオブジェクト指向型と関数型の 2 つのパラダイムを兼ね備えた言語となります。 後者の関数型のパラダイムに由来する部分ですが、Scala の関数は「第一級の値」として扱われます。 これは Int や String を関数の引数や戻り値にできるように、関数を関数の引数や戻り値にすることができることを意味します。
map や filter といったよく使われる関数も引数に関数を取る関数(高階関数)となります。 これらの関数を有効に使えるようになってくると、プログラムもすっきりしてくるし、関数を引数に取るという考え方にも慣れてくるように思います。
個人的には、これらの関数を使う際は引数となっている関数の入力(引数)と出力(戻り値)の型を意識すると理解しやすいように思いました。
immutable な変数を使って処理を書くことに慣れよう
一度値を設定した変数に再代入することができないことを immutable であると言います。 Scala では val で定義した変数は immutable な変数、 var で定義した変数は mutable な変数となります。
Scala では(初期設定された値が変更されないことが保証されるので)なるべく mutable な変数は用いずに処理を書くことが推奨されます。
そのような中で Java などとは少し異なる方法で処理の実装を検討する場面が出てきます。
今回のプログラムでも、 createSameValueList を、末尾再帰の再帰関数3として定義しています。関数型プログラミングに馴染みのなかった私などは下記のような実装をしたくなってしまうところですが、この場合 mutable な変数を用いることになってしまうので上述したソース内の実装としました。
1 2 3 4 5 6 7 |
def createSameValueList[A](listSize: Int, elementValue: A): List[A] ={ var list = List[A]() while (list.size < listSize){ list = elementValue :: list } list } |
その他、 totalEvaluationCount と totalCorrectCount についても、 mutable な変数として定義して for 文で足し込んで行くという実装も出来そうですが、 foldLeft 関数を用いて畳み込むことで mutable な変数の利用を避けています。
ただ必ずしも immutable な変数のみで処理を記載しなければならないというわけではありませんし、 mutable な変数を用いなければならない(又は用いた方が良い)局面もあるかと思います。
immutable な変数での処理記述のパターンを理解し、どういった時に immutable に拘るべきなのか、 どういった時には mutable の変数を用いるべきなのかを判断できるようになった方が良いかと思いますが、このあたりは少し経験が必要なのかなと感じました。
なお、余談となりますが、最初に末尾再帰で createSameValueList 関数を実装しようとした時に下記のような実装をしていました。
1 2 3 |
def createSameValueList[A](listSize: Int, elementValue: A): List[A] ={ if(listSize <= 0) Nil else elementValue :: createSameValueList(listSize - 1, elementValue) } |
しかしこれは末尾再帰になっていません。 createSameValueList の評価の後に elementValue をリストに追加する :: が評価されるためです。ソースの最後に記載しているからといって末尾再帰になるわけではなく、あくまで関数内で最後に評価されないと末尾再帰にならないということですね。
ここは uraura 先生からの有り難いマサカリによって気づいた所です。同時に 「 @tailrec を指定して末尾再帰になっていなかったらコンパイルエラーになるようにしようね」というアドバイスも頂きました。
Scala 特有の構文を覚えよう
Scala にはプログラムの見通しを良くするために便利な構文が存在しています。
例えば私は最初 DTO として case classを定義できることや、クロージャで tuple を入力とする際 case (x, y)のように記述できることを知らなかったのですが、これらを使うことでだいぶプログラムの見通しが良くなったと感じています。このあたりも uraura 先生に頂いたアドバイスで改善した所です。
こういった有効な構文を 1 つずつ覚えて使えるようになっていくと、 Scala 力がアップしていくのかなと思いました。
参考書籍
今回は下記のような書籍を参考にして勉強・実装を進めました。
- Guide to ScalaーScalaプログラミング入門
- ざっと流し読んで Scala の全体像がどんな感じなのか知るために用いました。また、最初の方しかやっていませんが、対話型シェル(REPL)にサンプルソースを実際に打ち込んでみることで Scala の動作を確認しました。
- Scala逆引きレシピ
- プログラムを書きながらわからないところを調べるのに用いました。また本投稿とはあまり関係ありませんが、sbt に関する情報はよくまとまっていてわかりやすかったです。
- 関数プログラミング実践入門 ──簡潔で、正しいコードを書くために
- プログラム例は Haskell ですが、関数型言語の考え方を知るのに参考になりました。
まとめ
- Scala を用いて単純パーセプトロンによる識別を n-fold cross validation により評価するプログラムを作成しました。
- 作成したプログラムの実行結果によって学習・評価がどのように実施されているかを確認しました。
- Scala 初学者が Scala 力をアップさせるために意識した方が良いことについて考えてみました。
今回自分にとって未経験の言語にチャレンジしてみたわけですが、やはり新しいことにチャレンジするというのは楽しいものですし、とりわけ Scala は勉強していて楽しい言語だなと感じました。
またプログラムを勉強する際、何かしらのプログラムを実際に作成してみる事は大きな勉強になると改めて思いました。今回パーセプトロンのプログラムを実装してみて、ようやく知識を肉付けて行くための 1 つの軸ができたかなと思います。
とは言え今回のプログラムでは出てきていない概念もありますし、勉強することはまだまだありますので、さらに精進していこうと思った次第です。
ではでは、今回はこのあたりで失礼いたします!
-
本投稿では単純パーセプトロンのアルゴリズムについての説明は行いません。インターネット上に参考になる資料が既に多数存在しますが、例えばこちらで公開されている資料などは参考になるかと思います。 ↩
-
今回は intellij の Scala プラグインにおける sbt テンプレートでプロジェクトを作成しました。 ↩
-
末尾再帰とは関数の最後で自分自身を呼び出す再帰のことを意味します。Scala においては再帰関数を定義する際は末尾再帰で記述することが推奨されています。参考書籍を読んで理解したのみですが、末尾再帰にすることで呼び出し前のスタック状態を保持する必要が無くなってコンパイラによる最適化時に単純なループに変換することができるようになり、スタックオーバーフローのリスクや再帰による実行時のオーバーヘッドがなくなるとのことです。 ↩