一覺(jué)醒來(lái),超越 Transformer 和 Mamba 的新架構(gòu)誕生了?
斯坦福、UCSD、UC 伯克利和 Meta 的研究人員提出了一種全新架構(gòu),用機(jī)器學(xué)習(xí)模型取代 RNN 的隱藏狀態(tài)。
這個(gè)模型通過(guò)對(duì)輸入 token 進(jìn)行梯度下降來(lái)壓縮上下文,這種方法被稱為“測(cè)試時(shí)間訓(xùn)練層(Test-Time-Training layers,TTT)”。
TTT 層直接替代了注意力機(jī)制,解鎖了具有表現(xiàn)力記憶的線性復(fù)雜度架構(gòu),使我們能夠在上下文中訓(xùn)練包含數(shù)百萬(wàn)(未來(lái)可能是數(shù)十億)個(gè) token 的 LLM。
作者相信,這個(gè)研究了一年多的項(xiàng)目,將從根本上改變我們的語(yǔ)言模型方法。
而結(jié)果證明,TTT-Linear 和 TTT-MLP 直接趕超或擊敗了最強(qiáng)的 Transformer 和 Mamba!
作者之一的 Xiaolong Wang 驚喜地表示:不敢相信,我們真的做到了。
更令人興奮的是,雖然目前 TTT 只應(yīng)用于語(yǔ)言建模,但在未來(lái),它也可以用在長(zhǎng)視頻上,可謂前景遠(yuǎn)大。
在將來(lái),當(dāng)我們對(duì)長(zhǎng)視頻進(jìn)行建模時(shí),就可以對(duì)幀進(jìn)行密集采樣,而不是采樣 1FPS 了。這些密集幀對(duì) Transformer 是一種負(fù)擔(dān),但對(duì)于 TTT 層來(lái)說(shuō),這卻是一種福音!
一個(gè) 5 年多的想法,終于實(shí)現(xiàn)了
作者表示,在過(guò)去的 1.5 年里,團(tuán)隊(duì)一直在開發(fā)一種新的 LLM 架構(gòu),可以具有線性復(fù)雜度和更強(qiáng)的隱藏狀態(tài),用于長(zhǎng)上下文建模。
而這個(gè)測(cè)試時(shí)訓(xùn)練(TTT)的想法,已經(jīng)研究了超過(guò) 5 年。
Xiaolong 清晰記得,在剛開始做博士后時(shí),Alyosha 曾讓自己去找 Yu Sun 討論 TTT。這次會(huì)面,就是這項(xiàng)研究的起點(diǎn)。
序列模型會(huì)把歷史上下文存儲(chǔ)在一個(gè)隱藏狀態(tài)中。
像 Mamba 這樣的 RNN 層,會(huì)隨著時(shí)間的推移壓縮成一個(gè)固定大小的狀態(tài),它們雖然效率很高,但性能受限于其表達(dá)能力。
注意力機(jī)制有一個(gè) KV 緩存,它會(huì)隨著時(shí)間的推移不斷增長(zhǎng)。這個(gè)狀態(tài)不會(huì)壓縮任何歷史上下文,但隨著上下文長(zhǎng)度的增加,成本也會(huì)越來(lái)越高。
團(tuán)隊(duì)成員想:既然這樣,為什么不把上下文壓縮到模型的權(quán)重中 —— 就像 LLM 處理互聯(lián)網(wǎng)數(shù)據(jù)那樣呢?
這種「隱藏狀態(tài)模型」既能在時(shí)間上保持固定大小,又能大大增強(qiáng)表達(dá)能力。
研究人員使用了自監(jiān)督學(xué)習(xí)來(lái)更新隱藏狀態(tài)的權(quán)重,對(duì)每個(gè) token 進(jìn)行一次梯度下降。在處理一個(gè)序列時(shí),該狀態(tài)已經(jīng)在其上下文窗口中的 token 上「訓(xùn)練」過(guò)了。
值得注意的是,隱藏狀態(tài)只存在于端到端架構(gòu)中的一層。其他組件,比如 QKV 投影矩陣,是在預(yù)訓(xùn)練期間通過(guò)標(biāo)準(zhǔn)的交叉熵目標(biāo)函數(shù)學(xué)習(xí)的。
因此,端到端架構(gòu)實(shí)際上是在進(jìn)行元學(xué)習(xí),尋找壓縮上下文的最佳方式,以便更好地預(yù)測(cè)下一個(gè) token,也就是在「學(xué)習(xí)如何在測(cè)試時(shí)學(xué)習(xí)」。
結(jié)果顯示,與 Mamba 相比,TTT-Linear 具有更好的困惑度和更少的 FLOP(左),并且更好地利用了長(zhǎng)上下文(右)。
下圖顯示了批大小為 16 的情況下,隨著上下文長(zhǎng)度的變化,每個(gè) token 的前向時(shí)間(延遲)。所有模型的參數(shù)都是 1.3B(Mamba 為 1.4B)。
可以看到,隨著上下文長(zhǎng)度的增加,Transformer 每個(gè) token 的前向時(shí)間呈線性增長(zhǎng),但其他兩種方法的前向時(shí)間基本保持不變。
在 8k 上下文時(shí),TTT-Linear 比 Transformer 更快,與 Mamba 相當(dāng)。
RNN 的尷尬現(xiàn)實(shí)
2020 年,OpenAI 縮放定律論文表明 LSTM(RNN 的一種)無(wú)法像 Transformer 那樣進(jìn)行縮放,或有效地使用長(zhǎng)上下文。
真的是這樣嗎?
在這個(gè)項(xiàng)目中,研究人員重新評(píng)估了圖 2 中的這些發(fā)現(xiàn)。
在左側(cè),可以觀察到 Mamba(當(dāng)今最流行的 RNN 之一)的擴(kuò)展性與強(qiáng)大的 Transformer 類似,這是自 2020 年的 LSTM 以來(lái)顯示出的巨大進(jìn)步。
然而,在右側(cè),可以觀察到與 OpenAI 相同的 Mamba 問(wèn)題。
平均而言,序列中靠后的 token 應(yīng)該更容易預(yù)測(cè),因?yàn)樗鼈円愿嘈畔闂l件。
對(duì) Transformer 來(lái)說(shuō)確實(shí)如此,每個(gè) token 索引的平均復(fù)雜度在其 32k 上下文中不斷減少。相比之下,Mamba 在 16k 后就出現(xiàn)了同樣的情況。
對(duì)于現(xiàn)有的 RNN 來(lái)說(shuō),這個(gè)結(jié)果代表了一個(gè)尷尬的現(xiàn)實(shí) ——
一方面,RNN(相對(duì)于 Transformer)的主要優(yōu)勢(shì)就是它們的線性(相對(duì)于二次)復(fù)雜性。這種漸進(jìn)優(yōu)勢(shì)實(shí)際上只會(huì)在長(zhǎng)上下文中實(shí)現(xiàn)。
另一方面,一旦上下文足夠長(zhǎng),現(xiàn)有的 RNN(如 Mamba)就很難真正利用額外的條件信息。
長(zhǎng)上下文的困難是 RNN 層本質(zhì)上的問(wèn)題:與自注意力機(jī)制不同,RNN 層必須將上下文壓縮為固定大小的隱藏狀態(tài)。
作為一種壓縮啟發(fā)式,更新規(guī)則需要發(fā)現(xiàn)成千上萬(wàn)甚至數(shù)百萬(wàn)個(gè) token 之間的底層結(jié)構(gòu)和關(guān)系。
研究人員首先觀察到,自監(jiān)督學(xué)習(xí)可以將大量訓(xùn)練集壓縮為 LLM 等模型的權(quán)重,該模型通常表現(xiàn)出對(duì)其訓(xùn)練數(shù)據(jù)之間語(yǔ)義聯(lián)系的深刻理解,而這,恰恰是他們所需要的。
TTT 層
受此啟發(fā),研究人員設(shè)計(jì)了一類新的序列建模層,其中隱藏狀態(tài)是模型,更新規(guī)則是自監(jiān)督學(xué)習(xí)的一個(gè)步驟。
由于更新測(cè)試序列上隱藏狀態(tài)的過(guò)程,相當(dāng)于在測(cè)試時(shí)訓(xùn)練模型,因此此類新層稱為測(cè)試時(shí)訓(xùn)練(TTT)層。
研究人員引入兩個(gè)簡(jiǎn)單的實(shí)例:TTT-Linear 和 TTT-MLP,其中隱藏狀態(tài)分別是線性模型和兩層 MLP。TTT 層可以集成到任何網(wǎng)絡(luò)架構(gòu)中并進(jìn)行端到端優(yōu)化,類似于 RNN 層和自注意力。
實(shí)際運(yùn)行時(shí)間
TTT 層在 FLOP 方面已經(jīng)非常高效,研究人員則更進(jìn)一步地提出了兩項(xiàng)創(chuàng)新,使其在實(shí)際運(yùn)行時(shí)間內(nèi)也能保持高效。
首先,與在常規(guī)訓(xùn)練中對(duì) mini-batch 序列采取梯度步進(jìn)以實(shí)現(xiàn)更好的并行性類似,他們也在 TTT 中使用了 mini-batch 的 token。
其次,研究人員為每個(gè) TTT mini-batch 內(nèi)的操作開發(fā)了一種對(duì)偶形式,以更好地利用現(xiàn)代 GPU 和 TPU。這種對(duì)偶形式的輸出與原始實(shí)現(xiàn)相當(dāng),但訓(xùn)練速度卻快了 5 倍以上。
正如圖 3 所示,TTT-Linear 在 8k 上下文中比 Transformer 更快,并且與 Mamba 相當(dāng)。
Transformer 殺手 ——TTT
如圖 4 所示,所有的序列建模層,都可以從將歷史上下文存儲(chǔ)到隱藏狀態(tài)的角度來(lái)看待。
比如,RNN 層 —— 如 LSTM、RWKV 和 Mamba 層 —— 將上下文壓縮成一個(gè)固定大小的狀態(tài),這個(gè)狀態(tài)隨時(shí)間變化。
這種壓縮帶來(lái)了兩種結(jié)果:優(yōu)勢(shì)是處理效率高,因?yàn)槊總€(gè) token 的處理時(shí)間是恒定的。劣勢(shì)是在處理長(zhǎng)上下文時(shí),RNN 性能受限于隱藏狀態(tài)的「表達(dá)能力」。
自注意力機(jī)制(Self-attention)也可以從如上角度來(lái)理解。
不同之處在于,它的隱藏狀態(tài),通常稱為鍵值(KV)緩存是一個(gè)隨 t 增長(zhǎng)的線性 list。
它可以存儲(chǔ)所有的上下文,并且不會(huì)進(jìn)行壓縮,具有很好的表達(dá)能力,不過(guò)其處理時(shí)間隨上下文長(zhǎng)度線性增長(zhǎng)。
因此,為了在長(zhǎng)上下文中既保持效率,又具有表達(dá)能力,需要一個(gè)更好的「壓縮啟發(fā)式」(compression heuristic)方法。
具體來(lái)說(shuō),就需要將數(shù)百萬(wàn)個(gè) token 壓縮成一個(gè)能有效捕捉其底層結(jié)構(gòu)和關(guān)系的隱藏狀態(tài)。
TTT 隱藏狀態(tài)
研究人員的關(guān)鍵思想是,使用自監(jiān)督學(xué)習(xí)來(lái)將歷史上下文 x1,...,xt 壓縮成一個(gè)隱藏狀態(tài) St。方法是將上下文視為一個(gè)無(wú)標(biāo)簽數(shù)據(jù)集,而將狀態(tài)視為一個(gè)模型。
具體來(lái)說(shuō),隱藏狀態(tài) St 現(xiàn)在等同于一個(gè)模型 f 的權(quán)重 Wt,這個(gè)模型 f 可以是線性模型、小型神經(jīng)網(wǎng)絡(luò)或其他任何形式。輸出規(guī)則簡(jiǎn)單地表示為:zt=f(xt;wt)。
直觀講,輸出 token 就是由更新后權(quán)重 Wt 的模型 f 對(duì) xt 所做的預(yù)測(cè)。更新規(guī)則是在某個(gè)自監(jiān)督損失?上進(jìn)行的一步梯度下降:Wt=Wt-1-ηΔ?(Wt-1;xt)。其中學(xué)習(xí)率為 η。
從壓縮的角度來(lái)看,每種啟發(fā)式方法都需要決定記住 / 忘記哪些輸入。W 會(huì)記住那些產(chǎn)生大梯度的輸入 —— 直觀地說(shuō),就是那些使 W 學(xué)習(xí)很多的輸入。
?的一種選擇是重構(gòu) xt 本身。為了使學(xué)習(xí)問(wèn)題變得非平凡,作則首先將 xt 處理成一個(gè)被破壞的輸入
,然后優(yōu)化:
類似于去噪自編碼器,f 需要發(fā)現(xiàn) xt 各維度之間的相關(guān)性,以便從部分信息
中重構(gòu)出 xt。
如圖 5 所示,梯度下降能夠減少?,但無(wú)法將其降至零。
與其他 RNN 層和自注意力機(jī)制一樣,研究人員將輸入序列 x1,...,xT 映射到輸出序列 z1,...,zt 的算法可以被編程到序列建模層的前向傳播中,使用上述的隱藏狀態(tài)、更新規(guī)則和輸出規(guī)則。
即使在測(cè)試時(shí),新層仍然為每個(gè)輸入序列訓(xùn)練一個(gè)不同的權(quán)重序列 W1,...,Wt。因此,研究人員將其稱之為測(cè)試-時(shí)間訓(xùn)練層(TTT)。
使用 TTT 層訓(xùn)練神經(jīng)網(wǎng)絡(luò)
TTT 層的前向傳播,也有相應(yīng)的后向傳播。
TTT 層與 RNN 層、自注意力機(jī)制有著相同的接口,因此可以在任何更大的神經(jīng)網(wǎng)絡(luò)架構(gòu)中替換它們。
值得一提的是,訓(xùn)練帶有 TTT 層神經(jīng)網(wǎng)絡(luò)的方式,與訓(xùn)練任何其他 Transformer 模型相同。
可以使用相同的數(shù)據(jù)、方法和目標(biāo)(如下一個(gè) token 預(yù)測(cè))來(lái)優(yōu)化網(wǎng)絡(luò)其余部分的參數(shù)。
在此,研究人員將訓(xùn)練更大的神經(jīng)網(wǎng)絡(luò)稱為外循環(huán)(outer loop),而在每個(gè) TTT 層內(nèi)訓(xùn)練 W 稱為內(nèi)循環(huán)(inner loop)。
它們之間梯度計(jì)算的區(qū)別是,內(nèi)循環(huán)針對(duì)的是 W(即模型 f 的參數(shù)),外循環(huán)針對(duì)的是網(wǎng)絡(luò)其余部分的參數(shù) θrest。
TTT 學(xué)習(xí)自監(jiān)督任務(wù)
可以說(shuō),TTT 最重要的部分是自監(jiān)督任務(wù),因?yàn)樗鼪Q定了 W 從測(cè)試序列中學(xué)習(xí)的特征類型。
在這個(gè)任務(wù)的設(shè)計(jì)上,研究人員采取了更加端到端的方法 —— 直接優(yōu)化自監(jiān)督任務(wù)以實(shí)現(xiàn)下一個(gè) token 預(yù)測(cè)的最終目標(biāo)。
具體來(lái)說(shuō),研究者將自監(jiān)督任務(wù)的學(xué)習(xí),作為外循環(huán)的一部分。
從如上公式 3 中的簡(jiǎn)單重構(gòu)任務(wù)開始,添加了一些外循環(huán)參數(shù)來(lái)讓這個(gè)任務(wù)可學(xué)習(xí)。最新的自監(jiān)督損失是:
在內(nèi)循環(huán)中,只有 W 被優(yōu)化,因此作為?的參數(shù)寫出;θ 們是這個(gè)損失函數(shù)的「超參數(shù)」。在外循環(huán)中,θK,θV,θQ 與 θrest 一起被優(yōu)化,而 W 僅僅是一個(gè)隱藏狀態(tài),不是參數(shù)。
圖 6 用代碼說(shuō)明了這種區(qū)別,其中 θK 和 θV 被實(shí)現(xiàn)為 TTT 層的參數(shù),類似于自注意力中的 KV 參數(shù)。
總的來(lái)說(shuō),θK,θV,θQ 所有可能的選擇構(gòu)成了一系列多視圖重構(gòu)任務(wù),外循環(huán)可以被理解為從這個(gè)任務(wù)組中選擇一個(gè)具體任務(wù)。為了簡(jiǎn)單起見,研究人員在這里將所有視圖設(shè)計(jì)為線性投影。
mini-batch TTT 并行化
目前,開發(fā)的原生 TTT 層在浮點(diǎn)運(yùn)算(FLOP)次數(shù)方面已經(jīng)非常高效。
然而,其更新規(guī)則 Wt=Wt-1-ηΔ?(Wt-1;xt)無(wú)法實(shí)現(xiàn)并行化,因?yàn)?Wt 在兩個(gè)位置上依賴于 Wt-1:負(fù)號(hào)和 Δ?。
對(duì)此,研究人員提出了 mini-batch 梯度下降,用 b 表示 TTT 批大小。
研究中使用 Gt=Δ?(Wt';xt),其中 t'=t-mod(t,d)代表著前一個(gè) mini-batch 的最后一個(gè)時(shí)間步(或者第一個(gè) mini-batch 0),因此,可以一次并行 b 個(gè)梯度計(jì)算。
對(duì)偶形式
上面介紹的并行化是必要的,但對(duì)于「實(shí)際運(yùn)行時(shí)間」(wall-clock time)的效率來(lái)說(shuō)還不夠。
正如之前所述,可以對(duì)于 t = 1,...,b 進(jìn)行并行計(jì)算:
然而,現(xiàn)實(shí)中,是無(wú)法對(duì)單個(gè) matmul 來(lái)計(jì)算 GtS 所有的 b。
相反,需要 b 個(gè)外積來(lái)對(duì)其進(jìn)行一一計(jì)算。更糟糕的是,對(duì)于每個(gè)
,Gt 是 d×d,這會(huì)比大 d xt 產(chǎn)生更大的內(nèi)存占用和 I / O 成本。
為了解決這兩個(gè)問(wèn)題,研究人員觀察到:我們實(shí)際上并不需要具體化 G1,...,Gb,只要要我們可以在 mini-batch 結(jié)束時(shí)計(jì)算 Wb,并且輸出 token z1,...,zb(如上圖 7 所示)。
現(xiàn)在,就可以用上面簡(jiǎn)化的 TTT-Linear 情況來(lái)演示這些計(jì)算,表示 X = [x1,...,xb]:
所以 Wb 可以用 matmul 方便地計(jì)算出來(lái)。為了計(jì)算 Z = [z1,...,zb],我們知道:
表示
和矩陣
,可以得出:
如上過(guò)程,研究人員將其稱為「對(duì)偶形式」。
理論等價(jià)
前面已經(jīng)提到 f 可以是線性模型,也可以是神經(jīng)網(wǎng)絡(luò)。還有更新規(guī)則的三種變體:online GD、batch GD 和 mini-batch GD。
如下圖所示,在這些 2×3 組合中,每一種都會(huì)引起 TTT 層的不同實(shí)例化。
研究中,作者分別從 2 個(gè)定理證明了在這些誘導(dǎo)實(shí)例中,具有線性模型和 batch GD 的 TTT 層等同于線性注意力 —— 一個(gè)廣為人知的 RNN 層。
圖 10 總結(jié)了所有序列建模層的更廣泛范圍內(nèi) TTT 層的一般定義。
兩種變體
研究中,作者提出了 TTT 層的兩種變體 TTT-Linear 和 TTT-MLP,僅在 f 的實(shí)例化方面有所不同。
對(duì)于 TTT-Linear,flin(x)=Wx,其中 W 是平方。對(duì)于 TTT-MLP,fMLP 有兩層,類似于 Transfomer 的 MLP。
具體來(lái)說(shuō),隱藏維度是 4× 輸入維度,然后是 GELU 激活。為了在 TTT 期間獲得更好的穩(wěn)定性,f 始終包含層歸一化 (LN) 和殘差連接。
即,f(x)=x + LN(fres(x)),其中,fres 可以是 flin 或 fMLP。
實(shí)驗(yàn)
通過(guò)與兩個(gè)基線 Transformer 和 Mamba(現(xiàn)代 RNN)比較,研究人員評(píng)估了 TTT-Linear 和 TTT-MLP。
數(shù)據(jù)集
繼續(xù) Mamba 論文之后,研究人員在 Pile 上執(zhí)行了 2k 和 8k 上下文長(zhǎng)度的標(biāo)準(zhǔn)實(shí)驗(yàn),Pile 是一個(gè)用于訓(xùn)練開源 LLM 的流行文檔數(shù)據(jù)集。
主架構(gòu)
Transformer 和 Mamba 使用不同的,除非另有說(shuō)明,TTT-Linear 和 TTT-MLP 始終使用 Mamba 架構(gòu)。
短上下文:the Pile
在 2k 上下文中,TTT-Linear(M)、Mamba 和 Transformer 具有相當(dāng)?shù)男阅埽€條大部分重疊。
TTT-MLP(M)在較大的 FLOP 預(yù)算下表現(xiàn)稍差。盡管 TTT-MLP 在每個(gè)模型大小上,都比 TTT-Linear 具有更好的復(fù)雜度,但 FLOP 的額外成本抵消了這種優(yōu)勢(shì)。
在 8k 上下文中,TTT-Linear(M)和 TTT-MLP(M)的表現(xiàn)均明顯優(yōu)于 Mamba。即使是具有 Transformer 架構(gòu)的 TTT-MLP(T),性能也比 Mamba 略好。
另外,研究人員還觀察到了一個(gè)非常明顯的現(xiàn)象:隨著上下文長(zhǎng)度變長(zhǎng),TTT 層相對(duì)于 Mamba 的優(yōu)勢(shì)就更大了。
長(zhǎng)上下文:Books
為了評(píng)估長(zhǎng)上下文中的功能,研究人員使用了 Pile 的一個(gè)流行子集 ——Books,對(duì)從 1k 到 32k 以 2 個(gè)增量的上下文長(zhǎng)度進(jìn)行了實(shí)驗(yàn)。
根據(jù)上圖,可以觀察到 ——
在 Books 的 2k 上下文中,Pile 2k 的所有觀察結(jié)果仍然成立,唯一的例外是 Mamba 的表現(xiàn)略好于 TTT-Linear。
在 32k 上下文中,TTT-Linear(M)和 TTT-MLP(M)的性能均優(yōu)于 Mamba,與 Pile 8k 的觀察結(jié)果類似。即使具有 Transformer 架構(gòu)的 TTT-MLP(T),在 32k 上下文中的表現(xiàn)也比 Mamba 稍好。
在 1.3B 尺度上,TTT-MLP(T)僅比 TTT-MLP(M)稍差。由于缺之清晰的線性擬合,很難推導(dǎo)出經(jīng)驗(yàn)縮放定律。然而,TTT-MLP(T)的強(qiáng)勁趨勢(shì)表明,Transformer 架構(gòu)可能更適合超出評(píng)估的更大模型和更長(zhǎng)上下文。
上下文長(zhǎng)度作為超參數(shù)
雖然輸入序列的長(zhǎng)度由用戶確定,但語(yǔ)言模型處理輸入的上下文長(zhǎng)度可以由工程師確定。因此,上下文長(zhǎng)度也是一個(gè)可以選擇的超參數(shù)。
對(duì)于具有線性復(fù)雜度的 LLM,研究人員選擇了困惑度中的 argmin,因?yàn)槊總€(gè)上下文長(zhǎng)度都有相同的 FLOP。
從圖 13 中,可以觀察到以下結(jié)果 ——
- 性能最好的方法 TTT-Linear 和 TTT-MLP 的線幾乎完全重疊。Mamba 和 TF Finetune 的線在 10^20 FLOP 后也大部分重疊。
- TF Finetune 的性能明顯優(yōu)于 TF Pretrain,因?yàn)樗芤嬗陂L(zhǎng)上下文,而不會(huì)在訓(xùn)練 FLOP 中產(chǎn)生極大的成本。
- 對(duì)于所有從頭開始訓(xùn)練的方法(包括 TF 預(yù)訓(xùn)練),一旦上下文長(zhǎng)度變得太大,困惑度就會(huì)變得更糟。
從上圖可見,與 TTT-Linear 相比,TTT-MLP 在短上下文中表現(xiàn)稍差,但在長(zhǎng)上下文中表現(xiàn)更好。
這一觀察結(jié)果正符合研究人員的預(yù)期,即作為隱藏狀態(tài)的 MLP 比線性模型更具表現(xiàn)力。同樣,所有方法都具有與 Mamba 1.4B 相同的訓(xùn)練 FLOP。
實(shí)際運(yùn)行時(shí)間
LLM 訓(xùn)練和推理可以分解為前向、后向和生成。
由于前向(在訓(xùn)練和推理期間)和后向都可以并行化,因此研究人員使用對(duì)偶形式。生成新 token(也稱為解碼)本質(zhì)上是順序的,因此研究人員使用原始形式。
由于資源限制,這項(xiàng)實(shí)驗(yàn)是用 JAX 編寫并在 TPU 上運(yùn)行的。
然而,由于 Mamba(在 PyTorch、Triton 和 CUDA 中實(shí)現(xiàn))只能在 GPU 上運(yùn)行,因此為了公平比較,研究人員還重寫了方法,以在 GPU 上運(yùn)行。
具體來(lái)說(shuō),研究人員在 ThunderKittens 中編寫了一個(gè)用于前向的 GPU 內(nèi)核。從歷史上看,由于并行性和矩陣相乘的使用不當(dāng),RNN 在前向和后向過(guò)程中效率低下。
這個(gè)前向內(nèi)核的目標(biāo),是證明 mini-batch TTT 和這些問(wèn)題對(duì)偶形式的有效性。
圖 15 的左圖顯示了前向內(nèi)核批大小為 16 的延遲。所有模型參數(shù)均為 1.3B(Mamba 為 1.4B)。
對(duì)于 Transformer,每個(gè) token 的時(shí)間隨著上下文長(zhǎng)度的增加而線性增長(zhǎng),但對(duì)于其他方法則大致保持不變。
此外,研究人員在 Triton 中編寫了另一個(gè)用于生成的 GPU 內(nèi)核,并在圖 15 的右圖中對(duì)批大小為 512 的速度進(jìn)行了基準(zhǔn)測(cè)試。
可以看出,TTT-Linear 和 Mamba 的延遲幾乎相同,明顯小于 Transformer 和 TTT-MLP。
Mamba 之后,又看到 TTT 這么能打的新架構(gòu)誕生,少不了 AI 社區(qū)的熱議。
有網(wǎng)友稱,這會(huì)不會(huì)是最接近實(shí)時(shí)上下文的方法?很想聽聽大家的想法。這意味著 TTT 甚至在使用過(guò)程中,也能夠?qū)W習(xí)和適應(yīng),為長(zhǎng)上下文提供更好的性能,而不會(huì)產(chǎn)生通常與 Transformer 相關(guān)的高昂計(jì)算成本。
OpenAI 視頻生成研究人員對(duì)此表示,這項(xiàng)研究看起來(lái)很有趣。
如果 scaling law 依然存在,TTT 將帶來(lái)難以置信的影響。對(duì)于長(zhǎng)序列,Transformer 的計(jì)算成本往往很高,當(dāng)長(zhǎng)序列變得更長(zhǎng)時(shí),RNN 會(huì)遺忘。TTT 訓(xùn)練巧妙地利用神經(jīng)網(wǎng)絡(luò)解決 RNN 的不足。
作者介紹
論文最后,分別列出了這篇研究的作者貢獻(xiàn)。
其中的核心作者是,Yu Sun、Xinhao Li 和 Karan Dalal。
Yu Sun
Yu Sun 是斯坦福大學(xué)計(jì)算機(jī)專業(yè)的博士后,導(dǎo)師是 Carlos Guestrin、Tatsu Hashimoto 和 Sanmi Koyejo。
此前,他曾在加州大學(xué)伯克利分校完成了電子工程科學(xué)博士學(xué)位,導(dǎo)師是 Alyosha Efros 和 Moritz Hardt。他還在康奈爾大學(xué)拿到了學(xué)士學(xué)位。
個(gè)人主頁(yè)中,他介紹自己的研究重點(diǎn)是一種名為測(cè)試時(shí)間訓(xùn)練(test-time training)的算法框架。其核心思想是,每個(gè)測(cè)試實(shí)例都定義了自己的學(xué)習(xí)問(wèn)題,都有自己的泛化目標(biāo)。這通常使用自監(jiān)督學(xué)習(xí),為每個(gè)實(shí)例即時(shí)訓(xùn)練一個(gè)不同的模型來(lái)實(shí)現(xiàn)的。
在最新研究中,Yu Sun 與 Xinhao Li 在 2022 年 11 月共同啟動(dòng)了這一項(xiàng)目。自 2023 年 6 月起,Yu Sun 專職負(fù)責(zé)該項(xiàng)目。
他提出了項(xiàng)目的概念框架,設(shè)計(jì)了 mini-batch TTT 和對(duì)偶形式(dual form)。
Xinhao Li
Xinhao Li 是 UC San Diego 研二的學(xué)生,導(dǎo)師是 Xiaolong Wang 教授。他本人的研究興趣主要是深度學(xué)習(xí)和計(jì)算機(jī)視覺(jué)。
他在斯坦福大學(xué) Tatsunori Hashimoto 教授的團(tuán)隊(duì)中作為訪問(wèn)學(xué)生,與 Yu Sun 博士和其他導(dǎo)師朋友一起工作。在此之前,他曾在電子科技大學(xué)獲得了學(xué)士學(xué)位。
在 2024 年 3 月之前,Xinhao Li 是 TTT 早期代碼庫(kù)的主要貢獻(xiàn)者,這些代碼庫(kù)塑造了最新項(xiàng)目。
Karan Dalal
Karan Dalal 是 UC Berkeley 電子工程科學(xué)系的本科生。他于 2023 年 6 月全職加入該項(xiàng)目,與 Xinhao Li 合作共同領(lǐng)導(dǎo)了當(dāng)前代碼庫(kù)的開發(fā)工作。
參考資料:
https://x.com/karansdalal/status/1810338845659131940
https://x.com/xiaolonw/status/1810387662060269668
https://arxiv.org/abs/2407.04620
廣告聲明:文內(nèi)含有的對(duì)外跳轉(zhuǎn)鏈接(包括不限于超鏈接、二維碼、口令等形式),用于傳遞更多信息,節(jié)省甄選時(shí)間,結(jié)果僅供參考,IT之家所有文章均包含本聲明。