language : Japanese | English

ぼっち大好き

トップ 貯蓄率とFIREまでの年数 二次関数と虚数i コピックメイキング 絶対音感 昔のソフトウェア Visual Studio C / C++ / MFC 備忘録 自作LLM 備忘録 日記 作者


自作LLM 備忘録

はじめに

最初なので、llama2の話に固定しますが、C言語とアセンブラでフルスクラッチで自作したいと思ったのです。既にあるのになぜ車輪の再発明をするのかというと、プログラム初心者がとりあえずメモ帳を自作したいのと同じです。実際メモ帳を自作することで学ぶことは多いものです(GUIの取り扱い、メッセージマップの取り扱い、Unicodeの取り扱いと変換etc)。ただし、パラメーターの学習はお金や時間がかかるので、今のところは既存の学習パラメーターを使用したいと思います。

ですので、llama2をpythonでどうやって動かすとか、表面的取説的な話ではなくて、実際にフルスクラッチで作って動かしてみて、なんでこうやって作ってあるのかを知りたいのです。そのため、最低限transformerとtokenizerの話がわかることを前提に書いています。

llama2を選んだのは、ソース公開版ではもっとも原始的である、この系統は7Bの割にはそこそこ実用的、GPUがRTX4060ti(16GB)~RTX4090(24GB)で動き、家庭用としては7Bぐらいが限度かなあと思ったためです。また、日本語用にElyza等が用意されているのもお気に入りです。将来時間があればMistral系とか別のトークナイザーを使ったものも試したいですが、わたしの生きているうちには無理でしょう(年間休日60日×1時間×10年=600時間、休日に体調不良がない前提、会社にいるときに地震が来ない前提)。

なお、パラメーター(重み)はsafetensorsファイルから読み込みます。他の形式で保存されているものは、コンバーターがネット上にありますので、全部safetensors形式にしておくものとします。また、safetensorsのヘッダー部分と、tokenizer.jsonはJSON形式です。C言語でJSONパーサをあらかじめ作っておく必要があります。

llama2 7b の基本仕様おさらい

※厳密には「単語区切り」ではないので(日本語はほぼ1文字1トークン)、単語ベクトルではありませんが、他の良い翻訳がないので便宜上こう書きました。

Llama2-7b系列のパラメーター(重み)内訳

パラメーターだけで6.7Bあり、float16で13.4GB, float32で26.8GBのメモリが必要となります。他に計算値保存用のメモリとか、KVキャッシュ用のメモリとか、途中計算結果保持用とかの領域が必要です。この時点でGPU推論(float16演算)なら最低限VRAM16GB, 低速でもCPU推論(float32演算)なら最低限RAM32GBのパソコンが必要となります。量子化とかオフロードとかいろいろ逃げ道はあるのですが、ここではそれが目的ではないので省略します。

 - 単語埋め込み表現テンソル : 使用可能トークン数32000 * 4096 = 1億3107万2000パラメーター
 - Transformer層(TransformerBlock層 * 直列32層)
  - TransformerBlock層
   - Attention層
    - Attention入力時RMSNorm用テンソル:4096
    - Wqテンソル : 4096 * 4096 = 1677万7216
    - Wkテンソル : 4096 * 4096 = 1677万7216
    - Wvテンソル : 4096 * 4096 = 1677万7216
    - Woテンソル : 4096 * 4096 = 1677万7216
    - 合計:6711万2960個, 5テンソル
   -FF層
    - FF入力時RMSNorm用テンソル:4096
    - W1テンソル : 11008 * 4096 = 4508万8768
    - W3テンソル : 11008 * 4096 = 4508万8768
    - W2テンソル : 4096 * 11008 = 4508万8768
    - 合計:1億3527万0400パラメーター, 4テンソル
   - 合計 : 2億0238万3360パラメーター, 9テンソル
  -合計 : 64億7626万7520パラメーター, 288テンソル
 - 最終RMSNorm用テンソル : 4096パラメーター
 - LMヘッド用テンソル : 使用可能トークン数32000 * 4096 = 1億3107万2000パラメーター
 - 合計67億3841万5616パラメーター, 合計291テンソル

※RoPE用パラメーターは、数式による固定値を使うものとして除外しました。

他に必要な格納領域

RoPE計算値保管用 : 64 * 4096(max_seq_len=4096の時) * 2(cos用とsin用) = 52万4288

KVキャッシュ用:4096 * 4096(max_seq_len=4096の時) * 2(K用とV用) = 3355万4432

途中計算結果保持用:1トークンずつ地道にtransformerに入れるとしても、4096が3個ぐらい、11008が2個ぐらいかも。プロンプトを高速に入れるためにnトークンまとめてtransformerに入れるなら、そのn倍です。

llama2 Tokenizer

内容は単純なBytePairEncodingだそうですが、やはり罠はありました。

エンコード時のペアリングの優先順位

いろいろ試したのですが、「Mergeリストの上から順に」が正しいみたいです。これだけで1か月以上消耗しました。

UTF-8ストリームなので、マルチバイト文字は、Mergeリストになくてもvocabリストにあれば、最も優先的に結合しなければならない。ないもののみ、ByteFallBackとなります。単純なBytePairEncodingのみだと、日本語の3バイト文字はいったん2バイト+1バイトになってくれる必要がありますが、2バイトの違反UTF-8文字はvocabリストにもMergeリストにもあるはずがないのです。1バイト圏の英語ではマルチバイトの扱いに関する情報はないので念のため書いておきます。

ByteFallBack文字と英数記号1文字

ByteFallBack文字と英数記号1文字は、異なるトークンIDです。vocabリストをちゃんと全部見ないと気づきません。これだけで1か月以上消耗しました。でも改行コードはByteFallBackのを使っているとかあるので、ちゃんとリストに適合するようコーディングする必要があります。

デコード時の規則

UTF-8ストリームで規則違反となるトークンIDが発生するのは、事前に防いでおかなければなりません。具体的にはlogitからのサンプリング時に、UTF-8ストリーム規則違反になるものは-FLT_MAXでマスクしてからSoftmax関数にかける必要があります。これは気づきやすいところです。

llama2 Transformer

やはり実際プログラミングしていると罠がいっぱい潜んでいます。

トークンベクトルは入力用と出力用は異なる

入力用=TokenEmbedding[4096,32000]

出力用=lmHead[32000,4096]

転置行列でも逆行列でもありません。同じものとして学習して良いように思うのですが、実際はそうではありません。

RoPE(Rotary Positional Embedding)の回転角

RoPE(Rotary Positional Embedding)の回転角は、1要素目において、1トークン当たり1ラジアンです。ただしこれは、max_seq_lenが4096の時です。max_seq_len=8192にするときは、1トークン当たり0.5ラジアンが良いみたいです。とりあえずmax_seq_lenを4096としておくのが無難です。

RoPEの回転角は学習パラメーターではありますが、現実的には数式による固定値を使っています。

RoPEの回転角は、float32で計算します。float16では精度が悪く、回転角を正しく計算できません。

θの値は、1トークン当たり64次元です。4096 / 32 / 2 = 64。nトークン目は、1トークン目のθをn倍した価です。

θの値を計算するときは、指数部のプラスマイナスや全体のインバースを間違えないようにご注意下さい。

QとKについて、1ヘッド当たりそれぞれ128要素あり、これにRoPEを適用します。教科書通りの式や、他の実装ですと、隣同士のものをペアリングしてθ[0]~θ[63]で回転させますが、llama2の実際の実装は異なっており、前半のj番目と後半のj番目(j=0~63)をペアリングしてθ[0]~θ[63]で回転させます。これにより、回転後のQとKの値と順序は異なってしまいますが、順序については内積を取るだけなら問題ではなく、値についても学習時と推論時が同じ法則でペアリングしていればRoPEとして機能します。ここはよくllama2ソースコードを読まないと見落とし、隣同士でペアリングすると、おかしな結果が出力されます。実際にはこのように実装することで回転の高速化・並列化処理が行われています。

Attention層

マルチヘッドアテンションでは、ヘッドごとに独立してAttentionスコアを計算し、Softmaxをかけ、重みをこれまでのVキャッシュに適用します。分流はWQ,WK,WV、合流はWOのmatmul演算に含まれています。

マルチヘッドアテンションのスコアとSoftmaxは、float32で計算します。float16では精度が悪くSoftmaxの合計が1.0になりません。

FF層

FF層で、w1=gate, w3=up, w2=downです。使い間違えやすいです。使い間違えていると明らかに出力がおかしいので気づきますが、ここが使い間違えていると特定するのに1か月以上かかりました。

全般

同じfloat32でも、値によって掛け算や足し算の演算速度が変わります。あまりにも絶対値の小さすぎる高精細な値は、演算時間が何倍もかかります。CPUは時間がかかっても可能な限り精度を高く計算しようとするみたいです。コンパイラスイッチで浮動小数点オプションfp:精度優先とfp:速度優先があるのでfp:速度優先に切り替えたのですが、あまり効果がありません。絶対値1e-5以下は演算速度が遅く、悪い値です。負の√防止用以外には使わないで、ゼロにしてしまった方が良いかもしれません。

テンソル掛け算(matmul)の高速化

単純で効果的なのが、行方向ループはOpenMPでマルチスレッド化して高速化、列方向ループはAVX2で8個まとめてvfmaaddする高速化、AVX512が使えるなら16個まとめてvfmaaddする高速化です。AVX512が使えるならbfloat16もサポートするので、bfloat16ベースにするのも高速化になります。AVX512で両方合わせると、理論上はAVX2より4倍高速になると思います。

家庭用CPUでAVX512が使えるのは2024年8月に発売されたAMD Ryzenの第5世代モデルです。第4世代モデルもサポートしていますが内部レジスタは256bitで、要するにエミュレーションなので高速ではありません。第5世代のは内部レジスタが512bitなので本物です。具体的には、AMD Ryzen 9950X (\119,800)が最も適したCPUとなるでしょう。全スレッド100%のフル稼働となりますので、CPUクーラーは高級なのを使った方が良さげですね💦そこまでするなら、CPUを安物にしてもグラボにお金をかけた方がが幸せになれるかもですね💦

スレッド数の多さを重視するならThreadRipperもありますが、64コア128スレッドのを使っても16コア32スレッドの4倍なので、9950Xに比べるとコストパフォーマンスと置場と電気代の問題がいまいちです。理論値と現実値は違うのはありますが。将来ThreadRipperの第5世代が出たらすごそうですね💦高くて誰も買えませんね💦

なお、用意されたC言語用マクロでコンパイルするとVFMADD213PSが使われてしまい、レジスタ移動が発生して低速になります。高速化するためにVFMADD231PSを使いたいものです。列方向ループは完全にアセンブラで書いて、アセンブラの関数を呼び出すと効率的です。最近はインラインアセンブラは廃止されています。専用のasmファイルに関数を書く必要があります。

VEX.256.66.0F38.W0 98 /r VFMADD132PS ymm1, ymm2, ymm3/m256 A V/V FMA Multiply packed single precision floating-point values from ymm1 and ymm3/mem, add to ymm2 and put result in ymm1.
VEX.256.66.0F38.W0 A8 /r VFMADD213PS ymm1, ymm2, ymm3/m256 A V/V FMA Multiply packed single precision floating-point values from ymm1 and ymm2, add to ymm3/mem and put result in ymm1.
VEX.256.66.0F38.0 B8 /r VFMADD231PS ymm1, ymm2, ymm3/m256 A V/V FMA Multiply packed single precision floating-point values from ymm2 and ymm3/mem, add to ymm1 and put result in ymm1.

▼matmulの列方向ループはアセンブラで書き、AVX2のvfmadd231を使って高速化しました。float* aとfloat* bをnNum個積和演算してymm0に累積し、resultに格納します。

.code

; void AVX2_FMA_PS (float* a, float* b, float* result, INT_PTR nNum);
AVX2_FMA_PS PROC PUBLIC

; 引数のレジスタ
; rcx : a
; rdx : b
; r8 : result
; r9 : nNum

; ymm0 - ymm15 256bits = 32bytesレジスタ = float32 * 8個用レジスタ
; 結果レジスタの0初期化
    vxorps      ymm0, ymm0, ymm0

col_loop_begin:

; 列がなくなったら終了
    test        r9, r9
    jz          col_loop_end

    vmovups     ymm1, [rcx]
    vmovups     ymm2, [rdx]

    vfmadd231ps ymm0,ymm1,ymm2

; ポインタを進める
    add         rcx, 32
    add         rdx, 32

; カウントを減らす
    sub         r9,   8

; 列ループへ
    jmp         col_loop_begin

col_loop_end:

    vmovups     [r8], ymm0
    ret

AVX2_FMA_PS ENDP

END

ちな、AVXを使う場合は、各変数は何バイト境界にアライメントしないといけないという規則があるので、mallocやcallocでメモリを割り当てるのではなく、_aligned_mallocでメモリを割り当てる必要があります。

わたしのは、Core-i9 13900 (24コア32スレッド) と DDR4のメモリ4枚刺し128GBで、OpenMPによる並列化とAVX2のVFMADD231PS使用で、すべてfloat32演算で、だいたい2トークン/秒の速度で出力されます。

時に、既存のpythonのコードを使っても、全部GPUに乗る分には良いのですが、GPUから溢れた部分はCPU演算となります。その際、1スレッドしか使っていない動きをしているっぽいので、とても低速です。そのうち改良されるかもですし、そのままかもです。

マルチヘッドアテンションの高速化

llama2ではちょうど32ヘッドなので、OpenMPで32スレッド並列処理するのに良いです。内部の内積演算は、AVX2やAVX512で高速化します。ただしmatmulほど演算負荷はないので、全体としては効果がうすいです。デバッグ時は並列化しなくても良いでしょう。

値の型変換処理の高速化

大量の値をfloat16からfloat32に型変換する等は、OpenMPで並列処理をした方が高速です。ちな、変換の際に+/-INFとNaNを正しく処理するのを忘れずに。

▼とりあえず作った、Float16⇔Float32の一括変換用関数です。

typedef UINT16 float16;

void Float16toFloat32 (float16* pFloat16, float* pFloat32, INT_PTR nLen) {
	INT_PTR i = 0;
	#pragma omp parallel for
	for (i = 0; i < nLen; i++) {
		float16* pF16 = pFloat16 + i;
		float* pF32 = pFloat32 + i;
		UINT32 result32 = 0;
		UINT32 sign1 = (((*pF16) & 0x8000) << 16);
		UINT32 exponent5 = (((*pF16) & 0x7C00) >> 10);
		UINT32 fraction10 = ((*pF16) & 0x03FF);
		if (exponent5 == 0x1F) {
			if (fraction10 == 0) {
				result32 = (sign1 | 0x7F800000); // +/-Infinity
			}
			else {
				result32 = 0xFFFFFFFF; // NaN
			}
			*pF32 = *((float*)(&result32));
		}
		else if (exponent5 == 0 && fraction10 == 0) {
			result32 = sign1; // +/-0.0
			*pF32 = *((float*)(&result32));
			// Debug
			if (*pF32 != 0.0f && *pF32 != 0.0f) {
				_tcprintf (_T ("Error : Float16toFloat32 abnormal value.\n"));
			}
		}
		else {
			result32 = ((sign1) | ((exponent5 - 15 + 127) << 23) | (fraction10 << 13));
			*pF32 = *((float*)(&result32));
			// Debug
			if (*pF32 < -65504 || *pF32 > 65504) {
				_tcprintf (_T ("Error : Float16toFloat32 abnormal value.\n"));
			}
		}
	}
}

void Float32toFloat16 (float* pFloat32, float16* pFloat16, INT_PTR nLen) {
	INT_PTR i = 0;
	#pragma omp parallel for
	for (i = 0; i < nLen; i++) {
		float* pF32 = pFloat32 + i;
		float16* pF16 = pFloat16 + i;
		UINT16 result16 = 0;
		UINT16 sign1 = (((*((UINT32*)pF32)) & 0x8000000) >> 16);
		UINT32 exponent8 = (((*((UINT32*)pF32)) & 0x7C800000) >> 23);
		UINT32 fraction23 = (((*((UINT32*)pF32)) & 0x007FFFFF));
		if (exponent8 >= 0x1F) {
			if (fraction23 == 0) {
				result16 = (sign1 | 0x7C00); // +/-Infinity
			}
			else {
				result16 = 0xFFFF; // NaN
			}
		}
		else if (exponent8 == 0 && fraction23 == 0) {
			result16 = sign1; // +/-0.0
		}
		else {
			result16 = ((sign1) | (UINT16)((exponent8 - 127 + 15) << 10) | (UINT16)(fraction23 >> 13));
		}
		*pF16 = *((float16*)(&result16));
	}
}

終わりに

結局RTX4090は既製品あるいは完成品を動かす用、CPUとメインメモリは開発用とデバッグ用とVisualStudio用(3個ぐらい起動している)。フルパワー時にファンが音立てて回っているのが犬みたいでかわいいです。


(C)2000-2025 くず All rights reserved.