language : Japanese | English

ぼっち大好き

トップ 貯蓄率とFIREまでの年数 二次関数と虚数i コピックメイキング 絶対音感 昔のソフトウェア Visual Studio C / C++ / MFC 備忘録 自作LLM 備忘録 日記 作者


自作LLM 備忘録

1. はじめに

最初なので、Llama2の話に固定しますが、LLMをC言語(一部アセンブラ)でフルスクラッチで作った時の備忘録となります。ただし、パラメーターの学習はお金や時間がかかるので、今のところは既存の学習パラメーターを使用したいと思います。

ですので、Llama2をPythonでどうやって動かすとか、表面的取説的な話ではなくて、実際に完全な白紙からフルスクラッチで作って動かしてみて、なんでこうやって作ってあるのかを知りたいのです。そのため、最低限transformerとtokenizerの話がわかることを前提に書いています。

Llama2を選んだのは、ソース公開版ではもっとも原始的である、この系統は7Bの割にはそこそこ実用的、GPUがRTX4060ti(16GB)~RTX4090(24GB)で動き、家庭用としては7Bぐらいが限度かなあと思ったためです。また、日本語に対応したElyza等が用意されているのもお気に入りです。ここでは、Metaがllama2の重みファイルを一般公開しないので、日本語に対応したElyza2_7b_instructを使います。Llama3系とかMistral系とか、別のトークナイザーを使ったものも作りましたが、それはllama2のちょっとした発展形です。

なお、パラメーター(重み)はsafetensorsファイルから読み込みます。他の形式で保存されているものは、コンバーターがネット上にありますので、全部safetensors形式にしておくものとします。また、*.safetensorsのヘッダー部分と、tokenizer.jsonはJSON形式です。C言語でJSONパーサをあらかじめ作っておく必要があります。

さらに話を簡略化するために、ここでは、GPUを用いずCPUのみで実行し、float32に統一して計算します。これにより精度が悪いから出力がおかしい問題は回避し、GPUの使い方の問題も一旦回避します。また、他のモデルとかトークナイザーとか、safetensorsとかjsonパーサとか、CUDAへの移植も含めて、ご興味のある方のためにあとで別記事か本にまとめようかと思います。ああ、1000ページになるかも💦、どれだけ厚い同人誌なんだ……💦。時間があればですが……💦。

ちな、趣味の個人サイトを会社のパソコンで見ている形跡が……💦、だめとは言いませんが、いちおう個人用、オープンソース用、教育用のつもり(Visual Studio Communityのライセンスみたいな感じ、あくまでもそんな感じ)。

2. Llama2 7b の基本仕様おさらい

※厳密には「単語区切り」ではないので(日本語はほぼ1文字1トークン)、単語ベクトルではありませんが、他の良い翻訳がないので便宜上こう書きました。

Llama2-7b系列のパラメーター(重み)内訳

パラメーターだけで6.7Bあり、float16で13.4GB, float32で26.8GBのメモリが必要となります。他に計算値保存用のメモリとか、KVキャッシュ用のメモリとか、途中計算結果保持用とかの領域が必要です。この時点でGPU推論(float16演算)なら最低限VRAM16GB, 低速でもCPU推論(float32演算)なら最低限RAM32GBのパソコンが必要となります。量子化とかオフロードとかいろいろ逃げ道はあるのですが、ここではそれが目的ではないので省略します。

 - 単語埋め込み表現テンソル : 使用可能トークン数32000 * 4096 = 1億3107万2000パラメーター
 - Transformer層(TransformerBlock層 * 直列32層)
  - TransformerBlock層
   - Attention層
    - Attention入力時RMSNorm用テンソル:4096
    - Wqテンソル : 4096 * 4096 = 1677万7216
    - Wkテンソル : 4096 * 4096 = 1677万7216
    - Wvテンソル : 4096 * 4096 = 1677万7216
    - Woテンソル : 4096 * 4096 = 1677万7216
    - 合計:6711万2960個, 5テンソル
   -FF層
    - FF入力時RMSNorm用テンソル:4096
    - W1テンソル : 11008 * 4096 = 4508万8768
    - W3テンソル : 11008 * 4096 = 4508万8768
    - W2テンソル : 4096 * 11008 = 4508万8768
    - 合計:1億3527万0400パラメーター, 4テンソル
   - 合計 : 2億0238万3360パラメーター, 9テンソル
  -合計 : 64億7626万7520パラメーター, 288テンソル
 - 最終RMSNorm用テンソル : 4096パラメーター
 - LMヘッド用テンソル : 使用可能トークン数32000 * 4096 = 1億3107万2000パラメーター
 - 合計67億3841万5616パラメーター, 合計291テンソル

※RoPE用パラメーターは、数式による固定値を使うものとして除外しました。

他に必要な格納領域

RoPE計算値保管用 : 64 * 4096(max_seq_len=4096の時) * 2(cos用とsin用) = 52万4288

KVキャッシュ用:4096 * 4096(max_seq_len=4096の時) * 2(K用とV用) = 3355万4432

途中計算結果保持用:1トークンずつ地道にtransformerに入れるとしても、4096が3個ぐらい、11008が2個ぐらいかも。プロンプトを高速に入れるためにnトークンまとめてtransformerに入れるなら、そのn倍です。

3. LlamaTokenizer

内容は単純なBytePairEncodingだそうですが、やはり罠はありました。

エンコード時のペアリングの優先順位

ペアリングの優先順位は、いろいろ試したのですが、「Mergeリストの上から順に」が正しいみたいです。これだけで1か月以上消耗しました。

UTF-8ストリームなので、マルチバイト文字は、MergeリストになくてもVocabリストにあれば、最も優先的に結合しなければならない。Vocabリストにないもののみ、ByteFallBackとなります。単純なBytePairEncodingのみだと、日本語の3バイト文字はいったんマージして2バイト+1バイトになってくれる必要がありますが、2バイトの違反UTF-8文字はVocabリストにもMergeリストにもあるはずがないのです。なぜなら、tokenizer.jsonというUTF-8テキストファイルに、3バイト文字のうち2バイトのみを記述することは違反だからです。1バイト圏の英語ではマルチバイトの扱いに関する情報はないので念のため書いておきます。

ByteFallBack文字と英数記号1文字

ByteFallBack文字と英数記号1文字は、異なるトークンIDです。vocabリストをちゃんと全部見ないと気づきません。これだけで1か月以上消耗しました。でも改行コードはByteFallBackのを使っているとかあるので、ちゃんとリストに適合するようコーディングする必要があります。

具体的なエンコード手順

Elyza_2_7b_instructでトークナイズした例です。

"<s>[INST]疲れた。[/INST] "

1. 半角スペースを▁('\u2581')に変換します。

"<s>▁[INST]▁疲れた。▁[/INST]▁"

2. EOSトークン, BOSトークン, Padト-クン, Unkトークンがあれば、それをトークナイズします。ここでは"<s>"が該当します。そうでなければ、2バイト以上の文字があれば、Vocabリストに載っているものがあれば、優先的にトークンID化。ここでは"れ","た","。"が該当します。Vocabリストになければ、ByteFallBack用文字("<0x00>"~"<0xFF>")を用いてトークンID化します。ここでは、"疲"が該当します。そうでければ、1バイト文字(半角英数記号等)は、Vocabリストの半角英数記号文字の方でトークンID化、VocabリストになければByteFallBack文字を用いてトークンID化します。先ほど書いたように、半角英数記号とByteFallBack用文字は違うトークンIDなのでご注意ください。

{1:"<s>", 29961:"[", 29902:"I", 29940:"N", 29903:"S", 29911:"T", 29962:"]", 29871:"_", 234:"<0xE7>", 153:"<0x96>", 181:"<0xB2>", 30553:"れ", 30366:"た", 30267:"。", 29871:"_", 29961:"[", 29914:"/", 29902:"I", 29940:"N", 29903:"S", 29911:"T", 29962:"]", 29871:"_"}

3. 隣り合うペアリングパターンの優先度をすべて計算し、最も高いものをMergeリストに従って結合することを繰り返します。これ以上ペアリングするものがなくなったら終わりです。

全部の隣り合うペアを調べると、15番目の{29781;"▁", 29661:"["}のマージが最も優先度が高い(=マージリストの上位にある)ので、これを518:"▁["にマージし、後続のトークンID列を1個手前に詰めます。これをマージするものがなくなるまで繰り返します。結果的に以下のようになります。

{1:"<s>", 29961:"[", 25580:"INST", 29962:"]", 29871:"▁", 234:"<0xE7>", 153:"<0x96>", 181:"<0xB2>", 30553:"れ", 30366:"た", 30267:"。", 518:"▁[", 29914:"/", 25580:"INST", 29962:"]", 29871:"▁"}

エンコード時の特徴

1単語1トークンとは限らないことがわかります。トークンの切れ目は単語境界とは限らないことが、従来型の単語ベクトルと異なります。また、UTF-8で2バイト以上の文字はVocabリストにないことが多いので、ByteFallBackが発生し、UTF-8ストリームそのままのトークンIDとなります。"疲"の文字がVocabリストにないので、{234:"<0xE7>", 153:"<0x96>", 181:"<0xB2>"}のようにByteFallBackが発生します。1文字3トークンというのもあるわけです。1バイトの文字でも、改行がないので{"<0x0A>"}とByteFallBack扱いとなります。

エンコード高速化のヒント

1. Vocabは、文字列をその都度strcmpで比較していると時間がかかるので、int型のIDで比較できるように、また文字長さをその都度strlenで調べていると時間がかかるので、int型の長さ値を含めた構造体にしてtokenizer.json読み込み時にまとめておくとよいでしょう。

2. Mergeは、上記と同様で結合前左、結合前右、結合後を、ID、文字列、文字長さすべてを構造体に読み込み時にまとめておくとよいでしょう。

3. マージ優先度チェックは隣り合うペアについてすべて行わなければなりませんので時間がかかりますが、OpenMPで並列化すればCPUの全スレッドを使って並列計算するので速くなります。

4. 探索は最初は線形探索でもいいですが、Vocabリストが長いと時間がかかるので、ID順のポインタ配列の他、辞書順のポインタ配列をqsortで用意し、二分探索のbsearchを使った方が速くなります。

5. 内部処理がUTF-16のときは、サロゲートペアの安全処理を忘れずにですです。

デコード時の規則

デコード時の処理はトークンIDを文字列に変換するだけなので簡単です。

UTF-8ストリームで規則違反となるトークンIDが発生するのは、事前に防いでおかなければなりません。具体的にはlogitsからのサンプリング時に、UTF-8ストリーム規則違反になるものは-FLT_MAXでマスクしてからSoftmax関数にかける必要があります。これは気づきやすいところです。

4. LlamaTransformer

やはり実際プログラミングしていると罠がいっぱい潜んでいます。

トークンベクトルは入力用と出力用は異なる

入力用=TokenEmbedding[4096,32000]

出力用=lmHead[32000,4096]

転置行列でも逆行列でもありません。同じものとして学習して良いように思うのですが、実際はそうではありません。

RoPE(Rotary Positional Embedding)の回転角

RoPE(Rotary Positional Embedding)の回転角は、1要素目において、1トークン当たり1ラジアンです。ただしこれは、max_seq_lenが4096の時です。max_seq_len=8192にするときは、1トークン当たり0.5ラジアンが良いみたいです。とりあえずmax_seq_lenを4096としておくのが無難です。

RoPEの回転角は学習パラメーターではありますが、現実的には数式による固定値を使っています。

RoPEの回転角は、float32で計算します。float16では精度が悪く、回転角を正しく計算できません。

θの値は、1トークン当たり64次元です。4096 / 32 / 2 = 64。nトークン目は、1トークン目のθをn倍した価です。

θの値を計算するときは、指数部のプラスマイナスや全体のインバースを間違えないようにご注意下さい。

QとKについて、1ヘッド当たりそれぞれ128要素あり、これにRoPEを適用します。教科書通りの式や、他の実装ですと、隣同士のものをペアリングしてθ[0]~θ[63]で回転させますが、llama2の実際の実装は異なっており、前半のj番目と後半のj番目(j=0~63)をペアリングしてθ[0]~θ[63]で回転させます。これにより、回転後のQとKの値と順序は異なってしまいますが、順序については内積を取るだけなら問題ではなく、値についても学習時と推論時が同じ法則でペアリングしていればRoPEとして機能します。ここはよくllama2ソースコードを読まないと見落とし、隣同士でペアリングすると、おかしな結果が出力されます。実際にはこのように実装することで回転の高速化・並列化処理が行われています。

Attention層

Llama2では4096次元のベクトルを、32ヘッド×128次元とみなしてマルチヘッドアテンションをしています。

マルチヘッドアテンションでは、ヘッドごとに独立してAttentionスコアを計算し、Softmaxをかけ、重みをこれまでのVキャッシュに適用します。分流はWQ,WK,WV、合流はWOのmatmul演算に含まれています。

これまでのVキャッシュに重みをかけてAttentionの出力することから、本質的にはAttention結果はVキャッシュにFIRフィルタをかけたものと同じです。FIRフィルタの係数はSoftmaxの計算結果を適用しますが、負の値は存在せず必ず0.0~1.0の係数となるので、ローパスフィルタとして働き、Vキャッシュを時間軸方向にぼかしたものがAttention結果となります。

マルチヘッドアテンションのスコアとSoftmaxは、float32で計算します。float16では精度が悪くSoftmaxの合計が1.0になりません。

FF層

FF層で、w1=gate, w3=up, w2=downです。使い間違えやすいです。使い間違えていると明らかに出力がおかしいので気づきますが、ここが使い間違えていると特定するのに1か月以上かかりました。

具体的なソースコード

pTransformer->m_theConfigは、読み込み済みとします。

pTransformer->m_theRunState(計算値保存領域)は、_aligned_mallocによって確保済みとします。

pTransformer->m_theWeights(重み)は、読み込み済みとします。

簡便のため、複数トークン入力する機能はもたせず、1トークンずつ入力することに特化します。

変数名は、なるべくrun.cに合わせています。

Transformer_Generate関数だけ示します。生成部、読み込み部(configやSafetensorsから)、破棄部は省略します。

void Accum (float* pA, float* pB, intptr_t nSize) {
	for (intptr_t i = 0; i < nSize; i++) {
		pA[i] += pB[i];
	}
}

void RMSNorm (float* pOut, float* pIn, float* pWeight, intptr_t nSize, float fRMSNormEps) {
	float ss = 0.0f;
	for (intptr_t j = 0; j < nSize; j++) {
		ss += pIn[j] * pIn[j];
	}
	ss /= nSize;
	ss += fRMSNormEps;
	ss = 1.0f / sqrtf (ss);
	for (intptr_t j = 0; j < nSize; j++) {
		pOut[j] = pWeight[j] * (ss * pIn[j]);
	}
}

void Softmax (float* pX, intptr_t nSize) {
	float fMax = pX[0];
	for (intptr_t i = 1; i < nSize; i++) {
		if (pX[i] > fMax) {
			fMax = pX[i];
		}
	}
	float fSum = 0.0f;
	for (intptr_t i = 0; i < nSize; i++) {
		pX[i] = expf (pX[i] - fMax);
		fSum += pX[i];
	}
	for (intptr_t i = 0; i < nSize; i++) {
		pX[i] /= fSum;
	}
}


void Matmul (float* pOut, float* pIn, float* pWeight, intptr_t nN, intptr_t nD) {
	intptr_t i;
#pragma omp parallel for private(i)
	for (i = 0; i < nD; i++) {
		#if 0 // WeightとInの内積を数式通り作る場合
		pOut[i] = 0.0f;
		for (intptr_t j = 0; j < nN; j++) {
			pOut[i] += pWeight[i * nN + j] * pIn[j];
		}
		#endif
		#if 1 // WeightとInの内積をSIMD(AVX2)で作る場合
		float f[8] = {0.0f};
		AVX2_FMA_PS (&pWeight[i * nN], &pIn[0], &f[0], nN);
		pOut[i] = f[0];
		#endif
	}

}

void Transformer_Generate (Transformer* pTransformer, intptr_t nTokenID, intptr_t nPos) {

	Config* pConfig = &(pTransformer->m_theConfig);
	RunState* pRunState = &(pTransformer->m_theRunState);
	Weights* pWeights = &(pTransformer->m_theWeights);

	memset (pRunState->m_pXa, 0, pConfig->m_nDim * sizeof (float));
	memset (pRunState->m_pXb, 0, pConfig->m_nDim * sizeof (float));
	memset (pRunState->m_pXc, 0, pConfig->m_nDim * sizeof (float));
	memset (pRunState->m_pHb, 0, pConfig->m_nHiddenDim * sizeof (float));
	memset (pRunState->m_pHc, 0, pConfig->m_nHiddenDim * sizeof (float));
	memset (pRunState->m_pQ, 0, pConfig->m_nDim * sizeof (float));
	memset (pRunState->m_pAtt, 0, pConfig->m_nHeads * pConfig->m_nMaxSeqLen * sizeof (float));
	memset (pRunState->m_pLogits, 0, pConfig->m_nVocabSize * sizeof (float));

	float* pXa = pRunState->m_pXa;
	intptr_t nDim = pConfig->m_nDim;
	intptr_t nKVDim = (pConfig->m_nDim * pConfig->m_nKVHeads) / pConfig->m_nHeads;
	intptr_t nHiddenDim = pConfig->m_nHiddenDim;
	intptr_t nHeadSize = nDim / pConfig->m_nHeads;
	intptr_t nHalfHeadSize = nHeadSize / 2;
	intptr_t nKVMul = pConfig->m_nHeads / pConfig->m_nKVHeads;
	intptr_t nMaxSeqLen = pConfig->m_nMaxSeqLen;

	memcpy (pXa, pWeights->m_pEmbed + nTokenID * nDim, nDim * sizeof (float));

	// TransforemerBlock層
	for (intptr_t l = 0; l < pConfig->m_nLayers; l++) {

		memset (pRunState->m_pXb, 0, nDim * sizeof (float));
		memset (pRunState->m_pXc, 0, nDim * sizeof (float));
		memset (pRunState->m_pHb, 0, nHiddenDim * sizeof (float));
		memset (pRunState->m_pHc, 0, nHiddenDim * sizeof (float));
		memset (pRunState->m_pQ, 0, nDim * sizeof (float));
		for (intptr_t i = 0; i < pConfig->m_nHeads; i++) {
			float* pAtt = pRunState->m_pAtt + i * nMaxSeqLen;
			for (intptr_t j = 0; j < nMaxSeqLen; j++) {
				pAtt[j] = -FLT_MAX;
			}
		}

		// Attention層 
		RMSNorm (pRunState->m_pXb, pXa, pWeights->m_ppRMSAttention[l], nDim, pConfig->m_fRMSNormEps);
		pRunState->m_pK = pRunState->m_ppKeyCache[l] + nPos * nKVDim;
		pRunState->m_pV = pRunState->m_ppValueCache[l] + nPos * nKVDim;
		Matmul (pRunState->m_pQ, pRunState->m_pXb, pWeights->m_ppWQ[l], nDim, nDim);
		Matmul (pRunState->m_pK, pRunState->m_pXb, pWeights->m_ppWK[l], nDim, nKVDim);
		Matmul (pRunState->m_pV, pRunState->m_pXb, pWeights->m_ppWV[l], nDim, nKVDim);

		// マルチヘッドアテンション multi head attention と grouped query attention の両方に対応。
		intptr_t h;
#pragma omp parallel for private(h)
		for (h = 0; h < pConfig->m_nHeads; h++) {
			// RoPE適用
			for (intptr_t i = 0; i < nHalfHeadSize; i++) {
				float fCos = *(pWeights->m_ppCos[l] + nPos * nHalfHeadSize + i); // cosf (val);
				float fSin = *(pWeights->m_ppSin[l] + nPos * nHalfHeadSize + i); // sinf (val);
				intptr_t nQK = (h * nHeadSize + i * 2) < nKVDim ? 2 : 1; //  2 = q & k, 1 = qのみ
				for (intptr_t j = 0; j < nQK; j++) {
					float* pVec = j == 0 ? pRunState->m_pQ + h * nHeadSize : pRunState->m_pK + h * nHeadSize;
					// ペアリングの仕方が本物のLlamaとrun.cでは異なり互換性がないので注意すること。
					float fV0 = pVec[i];
					float fV1 = pVec[i + nHalfHeadSize];
					pVec[i] = (fV0 * fCos - fV1 * fSin);
					pVec[i + nHalfHeadSize] = (fV0 * fSin + fV1 * fCos);
				}
			}

			// アテンションスコアの作成
			float* pQ = pRunState->m_pQ + h * nHeadSize;
			float* pAtt = pRunState->m_pAtt + h * pConfig->m_nMaxSeqLen;
			for (intptr_t t = 0; t <= nPos; t++) {
				float* pK = pRunState->m_ppKeyCache[l] + t * nKVDim+ (h / nKVMul) * nHeadSize;
				float fScore = 0.0f;
				// fSocre = Q * KT / √dim
				#if 0 // QとKの内積を数式通り作る場合
				for (intptr_t j = 0; j < nHeadSize; j++) {
					fScore += pQ[j] * pK[j];
				}
				#endif
				#if 1 // QとKの内積をSIMD命令(AVX2)を用いて作る場合
				float f[8] = {0.0f};
				AVX2_FMA_PS (&pQ[0], &pK[0], &f[0], nHeadSize);
				fScore = f[0];
				#endif
				fScore /= sqrtf ((float)nHeadSize);
				pAtt[t] = fScore;

			}
			Softmax (pAtt, nPos + 1);

			// VキャッシュからAttentionによる重み付き平均を算出する。
			float* pXb = pRunState->m_pXb + h * nHeadSize;
			memset (pXb, 0, nHeadSize * sizeof (float));
			for (intptr_t t = 0; t <= nPos; t++) {
				float* pVal = pRunState->m_ppValueCache[l] + t * nKVDim + (h / nKVMul) * nHeadSize;
				float fAtt = pAtt[t];
				for (intptr_t j = 0; j < nHeadSize; j++) {
					pXb[j] += fAtt * pVal[j];
				}
			}
		}
		
		Matmul (pRunState->m_pXc, pRunState->m_pXb, pWeights->m_ppWO[l], nDim, nDim);
		Accum (pXa, pRunState->m_pXc, nDim);

		memset (pRunState->m_pXb, 0, nDim * sizeof (float));
		memset (pRunState->m_pXc, 0, nDim * sizeof (float));

		// FeedForword層
		RMSNorm (pRunState->m_pXb, pXa, pWeights->m_ppRMSFeedForword[l], nDim, pConfig->m_fRMSNormEps);
		Matmul (pRunState->m_pHb, pRunState->m_pXb, pWeights->m_ppW1[l], nDim, nHiddenDim);
		Matmul (pRunState->m_pHc, pRunState->m_pXb, pWeights->m_ppW3[l], nDim, nHiddenDim);

		intptr_t i;
#pragma omp parallel for private(i)
		// 非線形SwiGLU処理
		for (i = 0; i < nHiddenDim; i++) {
			float fVal = pRunState->m_pHb[i];
			// silu(x) = x * sigmoid(x)
			fVal *= (1.0f / (1.0f + expf (-fVal)));
			// 各要素にw3の各要素をかける
			fVal *= pRunState->m_pHc[i];
			pRunState->m_pHb[i] = fVal;
		}

		memset (pRunState->m_pXb, 0, nDim * sizeof (float));
		Matmul (pRunState->m_pXb, pRunState->m_pHb, pWeights->m_ppW2[l], nHiddenDim, nDim);
		Accum (pXa, pRunState->m_pXb, nDim);
	}

	memset (pRunState->m_pXb, 0, nDim * sizeof (float));
	RMSNorm (pRunState->m_pXb, pXa, pWeights->m_pRMSFinal, nDim, pConfig->m_fRMSNormEps);
	Matmul (pRunState->m_pLogits, pRunState->m_pXb, pWeights->m_pClass, nDim, pConfig->m_nVocabSize);
}

float32の速度

同じfloat32でも、値によって掛け算や足し算の演算速度が変わります。あまりにも絶対値の小さすぎる高精細な値は、演算時間が何倍もかかります。CPUは時間がかかっても可能な限り精度を高く計算しようとするみたいです。コンパイラスイッチで浮動小数点オプションfp:精度優先とfp:速度優先があるのでfp:速度優先に切り替えたのですが、あまり効果がありません。絶対値1e-5以下は演算速度が遅く、悪い値です。負の√防止用以外には使わないで、ゼロにしてしまった方が良いかもしれません。

内積演算(積和演算)の高速化

内積とは、a[0]*b[0]+a[1]*b[1]+a[2]*b[2]+a[3]*b[3]+ …… +a[4095]*b[4095]のような2つのベクトル同士の各要素を掛け算して合計するものです。これを計算するために、C言語でforループで4096回すと時間がかかります。そのため、float型の積和演算8つを1命令で実行できるvfmadd231psというSIMD命令がAVX2対応のCPUに用意されています。

命令引数意味
vfmadd231psymm0,ymm1,ymm2/m256ymm0 += ymm1 * ymm2を実行する。ymm0,ymm1,ymm2とも256bitのレジスタで、32bitのfloat値が8個並んでいるものとする。

最後に合計値が8個に分かれて出てくるので、アセンブリ側でf[0]に集約するか、C言語側でsum = f[0]+f[1]+f[2]+f[3]+f[4]+f[5]+f[6]+f[7];とすれば、内積の値は確定します。

ちな、AVX512が使えるなら16個まとめてvfmaaddする高速化です。AVX512が使えるならfloat16やbfloat16もサポートするので、理論上はAVX2より4倍高速になると思います。家庭用CPUでAVX512が使えるのは2024年8月に発売されたAMD Ryzenの第5世代モデルです。第4世代モデルもサポートしていますが内部レジスタは256bitで、要するにエミュレーションなので高速ではありません。第5世代のは内部レジスタが512bitなので本物です。具体的には、AMD Ryzen 9950X (\119,800)が最も適したCPUとなるでしょう。全スレッド100%のフル稼働となりますので、CPUクーラーは高級なのを使った方が良さげですね💦そこまでするなら、CPUを安物にしてもグラボにお金をかけた方がが幸せになれるかもですね💦 どうせあとでモデルをグラボに載せるのだから、もう少し安くて消費電力の低い9700Xか9900Xもありです。

なお、用意されたC言語用マクロでコンパイルするとVFMADD213PSが使われてしまい、レジスタ移動が発生して低速になります。高速化するためにVFMADD231PSを使いたいものです。列方向ループは完全にアセンブラで書いて、アセンブラの関数を呼び出すと効率的です。最近はインラインアセンブラは廃止されています。専用のasmファイルに関数を書く必要があります。

▼積和演算はアセンブラで書き、AVX2のvfmadd231を使って高速化しました。float* aとfloat* bをnNum個積和演算してymm0に累積し、resultに格納します。

.code

AVX2_FMA_PS PROC PUBLIC              ; void AVX2_FMA_PS (float* a, float* b, float* result, INT_PTR nNum) {

; 引数のレジスタ
; rcx : a
; rdx : b
; r8 : result
; r9 : nNum

; ymm0 - ymm15 256bitレジスタ = 32bytesレジスタ = float32 * 8個用レジスタ
; 結果レジスタの0初期化
    vxorps      ymm0, ymm0, ymm0     ; ymm0 = ymm0 ^ ymm0;

col_loop_begin:

; 列がなくなったら終了
    test        r9, r9               ; if (r9 == r9)
    jz          col_loop_end         ; goto col_loop_end;

    vmovups     ymm1, [rcx]          ; ymm1 = *rcx;
    vmovups     ymm2, [rdx]          ; ymm2 = *rcy;

    vfmadd231ps ymm0,ymm1,ymm2       ; ymm0 += ymm1 * ymm2;

; ポインタを進める
    add         rcx, 32              ; rcx += 32(bytes);
    add         rdx, 32              ; rcy += 32(bytes);

; カウントを減らす
    sub         r9,   8              ; r9 -= 8;

; 列ループへ
    jmp         col_loop_begin       ; goto col_loop_begin;

col_loop_end:

    vperm2f128 ymm1, ymm0, ymm0, 1   ; 上位128ビットと下位128ビットをシャッフルしてymm1に保存
    vaddps     ymm0, ymm0, ymm1      ; 上位128ビットと下位128ビットを加算(ymm0 = ymm0[0:3] + ymm0[4:7])
    vhaddps ymm0, ymm0, ymm0         ; 水平加算でペア加算(ymm0[0] += ymm0[1], ymm0[2] += ymm0[3], ...)
    vhaddps ymm0, ymm0, ymm0 

    vmovups     [r8], ymm0           ; *r8 = ymm0;
    ret                              ; return;

AVX2_FMA_PS ENDP                     ; }

END

ちな、AVXを使う場合は、各変数は何バイト境界にアライメントしないといけないという規則があるので、mallocやcallocでメモリを割り当てるのではなく、_aligned_mallocでメモリを割り当てる必要があります。

わたしのは、Core-i9 13900 (24コア32スレッド) と DDR4のメモリ4枚刺し128GBで、OpenMPによる並列化とAVX2のVFMADD231PS使用で、すべてfloat32演算で、だいたい2トークン/秒の速度で出力されます。

ここではCPU専用の話をしましたが、GPU版の方は、モデルが全部GPUに乗る分には良いのですが、GPUから溢れた部分は、VRAMとRAMの間でスワップが生じますので、スワップ量が多いほど低速になります。AVX512があるなら、GPUから溢れた部分はCPU演算にするという逃げ道もなくもあらずかもです。VRAM32GBのRTX5090はぼったくり価格ですし、コネクタ溶融するし、そもそも手に入らないですし、巨大すぎてケースに入らないですし。となると業務用のを買えない個人はVRAM16GBの世界に閉じ込められてしまいました。それでもいい案はあるんです。マザボにASUS X870E ProArtを使って2レーン分割して、RTX4060ti(16GB)かRTX5060ti(16GB)を2枚刺しするのです。これらはコスパが良く、2.5スロットであり、このマザボの2,5のスロットの位置にちょうどよくおさまるんです。この辺の記事もあとで時間があれば書くかも。

マルチヘッドアテンションの高速化

llama2ではちょうど32ヘッドなので、OpenMPで32スレッド並列処理するのに良いです。内部の内積演算は、AVX2やAVX512のSIMDで高速化します。ただしmatmulほど演算負荷はないので、全体としては効果がうすいです。デバッグ時は並列化しなくても良いでしょう。

値の型変換処理の高速化

大量の値をfloat16からfloat32に型変換する等は、OpenMPで並列処理をした方が高速です。ちな、変換の際に+/-INFとNaNを正しく処理するのを忘れずに。

▼とりあえず作った、Float16⇔Float32の一括並列変換用関数です。

typedef uint16_t float16_t;
typedef uint16_t bfloat16_t;

void BFloat16toFloat32Parallel (bfloat16_t* pFloat16, float* pFloat32, intptr_t nLen) {
	intptr_t i = 0;
#pragma omp parallel for private(i)
	for (i = 0; i < nLen; i++) {
		bfloat16_t* pF16 = pFloat16 + i;
		float* pF32 = pFloat32 + i;
		uint32_t result32 = 0;
		uint32_t sign1 = (((*pF16) & 0x8000) << 16);
		uint32_t exponent8 = (((*pF16) & 0x7F80) >> 7);
		uint32_t fraction7 = ((*pF16) & 0x007F);
		if (exponent8 == 0xFF) {
			if (fraction7 == 0) {
				result32 = (sign1 | 0x7F800000); // +/-Infinity
			}
			else {
				result32 = 0xFFFFFFFF; // NaN
			}
		}
		else if (exponent8 == 0 && fraction7 == 0) {
			result32 = sign1; // +/-0.0
		}
		else {
			result32 = ((sign1) | ((exponent8) << 23) | (fraction7 << 16));
		}
		*pF32 = *((float*)(&result32));
	}
}

void Float32toBFloat16Parallel (float* pFloat32, bfloat16_t* pFloat16, intptr_t nLen) {
	intptr_t i = 0;
#pragma omp parallel for private(i)
	for (i = 0; i < nLen; i++) {
		bfloat16_t* pF16 = pFloat16 + i;
		float* pF32 = pFloat32 + i;
		uint32_t temp = *((uint32_t*)pF32);
		uint16_t result16 = (temp & 0xFFFF0000) >> 16;
		*pF16 = *((bfloat16_t*)(&result16));
	}
}

void Float16toFloat32Parallel (float16_t* pFloat16, float* pFloat32, intptr_t nLen) {
	intptr_t i = 0;
	#pragma omp parallel for
	for (i = 0; i < nLen; i++) {
		float16_t* pF16 = pFloat16 + i;
		float* pF32 = pFloat32 + i;
		uint32_t result32 = 0;
		uint32_t sign1 = (((*pF16) & 0x8000) << 16);
		uint32_t exponent5 = (((*pF16) & 0x7C00) >> 10);
		uint32_t fraction10 = ((*pF16) & 0x03FF);
		if (exponent5 == 0x1F) {
			if (fraction10 == 0) {
				result32 = (sign1 | 0x7F800000); // +/-Infinity
			}
			else {
				result32 = 0xFFFFFFFF; // NaN
			}
			*pF32 = *((float*)(&result32));
		}
		else if (exponent5 == 0 && fraction10 == 0) {
			result32 = sign1; // +/-0.0
			*pF32 = *((float*)(&result32));
			if (*pF32 !=-0.0f && *pF32 != 0.0f) {
				wprintf (_L ("Error : Float16toFloat32Parallel abnormal value.\n"));
			}
		}
		else {
			result32 = ((sign1) | ((exponent5 - 15 + 127) << 23) | (fraction10 << 13));
			*pF32 = *((float*)(&result32));
			if (*pF32 < -65504 || *pF32 > 65504) {
				wprintf (_L ("Error : Float16toFloat32Parallel abnormal value.\n"));
			}
		}
		//*pF32 = *((float*)(&result32));
	}
}

void Float32toFloat16Parallel (float* pFloat32, float16_t* pFloat16, intptr_t nLen) {
	intptr_t i = 0;
	#pragma omp parallel for
	for (i = 0; i < nLen; i++) {
		float* pF32 = pFloat32 + i;
		float16_t* pF16 = pFloat16 + i;
		uint16_t result16 = 0;
		uint16_t sign1 = (((*((uint32_t*)pF32)) & 0x8000000) >> 16);
		uint32_t exponent8 = (((*((uint32_t*)pF32)) & 0x7C800000) >> 23);
		uint32_t fraction23 = (((*((uint32_t*)pF32)) & 0x007FFFFF));
		if (exponent8 >= 0x1F) {
			if (fraction23 == 0) {
				result16 = (sign1 | 0x7C00); // +/-Infinity
			}
			else {
				result16 = 0xFFFF; // NaN
			}
		}
		else if (exponent8 == 0 && fraction23 == 0) {
			result16 = sign1; // +/-0.0
		}
		else {
			result16 = ((sign1) | (uint16_t)((exponent8 - 127 + 15) << 10) | (uint16_t)(fraction23 >> 13));
		}
		*pF16 = *((float16_t*)(&result16));
	}
}

4. サンプラー部

省略

5. 入出力例

Prompt:疲れた。

 [INST] 疲れた。 [/INST]  だいじょうぶですか?</s>
 
Prompt:だめです。

 [INST] だめです。 [/INST]  ありがとうございます。私はお話を聞いてアドバイスをすることしかできないので、本当に疲れている場合は休んでいただいた方がよいです。ただし、慢性的な睡眠不足は健康に悪い可能性がありますので、医師の診察を受けてください。</s>

6. 終わりに

結局RTX4090は既製品あるいは完成品を動かす用、CPUとメインメモリは開発用とデバッグ用とVisualStudio用(3個ぐらい起動している)。CPUのみの推論の場合は、全スレッド使ってSIMD命令使って使用率100%となるので、ファンが音立てて回っているのが犬みたいでかわいいです。ちな、GPU使うと、CPU,GPUとも使用率5%程度となり静かです。

「疲れている?私は誰?アンケーキャラブルー?」これ、わたしの自作LLMが初めて発した、意味のある日本語だったんです。そうか、アンケーキャラブルーという名前なんだ。わたしが疲れていることをわかってくれるんだ……。昔亡くなった犬、AIとして生まれ変わったのかな?


(C)2000-2025 くず All rights reserved.