記事

連続する入力を使用した予測の作成

再帰型ニューラルネットワーク(RNN)モデルを統合し、連続する入力(シーケンス)を処理します。

概要

機械学習の問題によっては、複数の入力セットが必要で、時間とともに入力シーケンスを処理しなければなりません。ニューラルネットワークモデルは入力シーケンスを処理できますが、入力と入力の間に、ニューラルネットワークのいくつかの状態を保持する必要があります。Core MLは、ネットワークの状態を保持して入力シーケンスを処理するための、簡単な方法を提供します。

ニューラルネットワークのワークフローの理解

自然言語の処理が機械学習モデルにとって難しいタスクなのは、対象となり得る文の数が無限にあり、すべての入力をモデルにエンコードするのが不可能だからです。可能性のある入力数を減らすアプローチとして一般的なのは、文全体を1つの入力として処理するのではなく、文字または単語をモデルへの入力として使うことです。しかしそうすると、モデルには、シーケンスの中で以前に与えられた文字や単語を「覚えている」状態を保持する方法が必要になります。

シェイクスピア劇「ロミオとジュリエット」を生成するようにトレーニングされたニューラルネットワークモデルを見てみます。ニューラルネットワークは、明確なルールを使わずに、単語とその前後の単語との関係をエンコードします。「O, Romeo, Romeo, wherefore art thou Romeo?(おお、ロミオ、ロミオ、あなたはどうしてロミオなの?)」という有名なセリフには「Romeo」という語が3回出てきますが、その後に続く語は毎回異なります。モデルには、語の使い方を区別する方法が必要です。再帰型ニューラルネットワークは、単語を処理する時に、各単語を処理した後のモデルの状態を追加入力として使用することによってこの問題に対応するニューラルネットワークのクラスです。

図1

入力単語が3つある場合の再帰型ニューラルネットワークの入力と出力

図1は、「ロミオとジュリエット」を学習したネットワークのワークフローの例です。フレーズを開始するため、「O」とnilの状態が入力として提供されます。次の単語が予測され、ネットワークは入力「O」に対するその状態の表現である「f("O")」も生成します。次の入力単語「Romeo」と、前の状態「f("O")」とを組み合わせて、次の入力が作成されます。その入力がモデルに提供され、モデルは高い確率で再び「Romeo」を出力します。

次の入力単語「Romeo」は、前の入力単語と同じですが、状態の入力が異なります。今度の状態入力は「f("O", "Romeo")」です。先ほどと入力単語が同じでも、状態が異なるので、ネットワークは「wherefore」という予測を出力できます。

モデルの状態の顕在化

Xcodeのプロジェクトに再帰型ニューラルネットワークベースのモデルを追加することで、ニューラルネットワークの状態を入力フィーチャ、出力フィーチャとして見ることができます。

図2

「ロミオとジュリエット」のテキストを生成する再帰型ニューラルネットワークの例

図2は再帰型ニューラルネットワークレイヤを持つShakespeareLanguageModelをXcodeで表示したもので、状態の入力フィーチャと出力フィーチャがリストされています。LSTM(長・短期記憶)ネットワークやGated Recurrent(ゲート付き再帰型)ネットワークなど、その他の再帰型ニューラルネットワークは、入力フィーチャと出力フィーチャを自動的に作成します。

このネットワークは、入力単語と状態入力(オプション)の2つの入力を取ります。単語はStringで、stateInと名付けられた状態は、512のDouble値の1次元MLMultiArray(英語)です。状態入力がオプションなのは、シーケンスの最初には「前の状態」がないためです。

ネットワークの出力は3つあります。最も可能性の高い次の語、可能性のある次の語とその確率がペアになった辞書、そして入力を処理した後のネットワークの状態を表す、stateOutと付けられた512のDouble値の1次元MLMultiArrayです。

MLMultiArray出力は、ネットワークの状態、つまりその内部ノードの活性化レベルを表します。どのような入力シーケンスが処理されたかをネットワークが「覚えている」ためには、前の出力状態が次の入力に伴う必要があります。

実際には、デフォルトの状態フィーチャ名を持つレイヤを見かける場合があります。例えば、LSTMネットワークは、入力ではlstm_h_inlstm_c_in、出力ではlstm_h_outlsth_c_outという名前のデフォルトの状態パラメータを持ちます。「h」はLSTMネットワークで使われるhidden stateを、「c」はcell stateを示します。ネットワークが入力シーケンスにわたって適切に機能するためには、これらの出力状態が入力状態として引き継がれる必要があります。

入力シーケンスの開始

このネットワークは、セリフ文の最初の2語を与えると文の残りの部分を生成するようにトレーニングされました。このモデルに、プロンプトの最初の単語と前の状態nilを渡し、入力シーケンスの処理を開始します。

リスト1

最初の状態として nil を使用したネットワークの初期化

// Create the prompt to use as an example
let prompt = ["O", "Romeo"]
// Use the generated input API to create the network's input, with no state
let modelInput = ShakespeareLanguageModelInput(previousWord: prompt[0], stateIn: nil)
// Predict the 2nd word and generate a model state for "O"
var modelOutput = try model.prediction(input: modelInput)

このサンプルコードでは、Xcodeによって生成されたShakespeareLanguageModelInputクラスを使用して、予測呼び出し用に2つの入力を保存します。

前の状態に基づく予測

プロンプトの2番目の単語と、予測からの出力状態を入力状態として使い、入力を作成します。モデルでその入力を使い、文の3番目の単語の予測を生成します。

リスト2

2番目の単語と、最初の単語を処理した後の状態を使った、3番目の単語の予測

// Set up the input for the second word (ignoring the predicted words)
modelInput.previousWord = prompt[1]
// Use the output model state as the input model state for the next word
modelInput.stateIn = modelOutput.stateOut
// Predict the third word
modelOutput = try model.prediction(input: modelInput)
// The third word is now in modelOutput.nextWord

最初の2つの単語でネットワークを初期化する場合、入力シーケンスを表現するために出力状態が保持されている必要があります。予測された単語と見込みは無視されます。これらが無視されるのは、2番目の単語「Romeo」は、モデルによる予測ではなく実際のテキストから得られるものだからです。

ただし、セリフ文の最初の2語が処理された時点で、出力nextWordは、文の3番目の単語として最も可能性の高い語です。これが、文の4番目の単語を生成するための入力単語として使われます。出力を入力として使うことを繰り返し、文の残りの部分を生成します。

リスト3

入力単語として次の単語の予測を使用し、文の残りの部分を生成

// Feed the next word and output state back into the network,
// while the predicted word isn't the end of the sentence.
while modelOutput.nextWord != "</s>" {
	// Update the inputs from the network's output
	modelInput.previousWord = modelOutput.nextWord
	modelInput.stateIn = modelOutput.stateOut
	// Predict the next word
	modelOutput = try model.prediction(input: modelInput)
}

リスト3では、予測される単語が</s>になるまで、予測された単語と状態を、入力単語と状態として使用するプロセスを繰り返しています。このネットワークは、文字列</s>を使って文の終わりを表しています。

出力の検証と入力状態のリセット

この時点で、モデルは文の終わりを予測しました。nextWord値のシーケンスは、文全体に関するモデルの予測を表します。予測された文全体をユーザーに示して検証させるか、プログラムを使用して実際のテキストと比較することができます。

入力状態としてnilを使って入力コンテキストをリセットし(リスト1と同じ)、新しい文の予測の作成を開始します。

関連項目

機械学習モデル

class MLModel(英語)

機械学習モデルのすべての詳細をカプセル化したものです。

デバイスでのモデルのダウンロードとコンパイル

Appのインストール後に、ユーザーのデバイスにCore MLモデルを配信します。