設(shè)置
  • 日夜間
    隨系統(tǒng)
    淺色
    深色
  • 主題色

秒秒鐘揪出張量形狀錯(cuò)誤,這個(gè)工具能防止 ML 模型訓(xùn)練白忙一場

量子位 2021/12/27 22:29:45 責(zé)編:江離

模型吭哧吭哧訓(xùn)練了半天,結(jié)果發(fā)現(xiàn)張量形狀定義錯(cuò)了,這一定沒少讓你抓狂吧。那么針對(duì)這種情況,是否存在較好的解決方法呢?

這不最近,韓國首爾大學(xué)的研究者就開發(fā)出了一款“利器”—— PyTea。

據(jù)研究人員介紹,它在訓(xùn)練模型前,能幾秒內(nèi)幫助你靜態(tài)分析潛在的張量形狀錯(cuò)誤。

那么 PyTea 是如何做到的,到底靠不靠譜,讓我們一探究竟吧。

PyTea 的出場方式

為什么張量形狀錯(cuò)誤這么重要?

神經(jīng)網(wǎng)絡(luò)涉及到一系列的矩陣計(jì)算,前面矩陣的列數(shù)必需匹配后面矩陣的行數(shù),如果維度不匹配,那后面的運(yùn)算就都無法運(yùn)行了。

圖片

上圖代碼就是一個(gè)典型的張量形狀錯(cuò)誤,[B x 120] * [80 x 10] 無法進(jìn)行矩陣運(yùn)算。

圖片

無論是 PyTorch,TensorFlow 還是 Keras 在進(jìn)行神經(jīng)網(wǎng)絡(luò)的訓(xùn)練時(shí),大多都遵循圖上的流程。

首先定義一系列神經(jīng)網(wǎng)絡(luò)層(也就是矩陣),然后合成神經(jīng)網(wǎng)絡(luò)模塊……

那么為什么需要 PyTea 呢?

以往我們都是在模型讀取大量數(shù)據(jù),開始訓(xùn)練,代碼運(yùn)行到錯(cuò)誤張量處,才可以發(fā)現(xiàn)張量形狀定義錯(cuò)誤。

由于模型可能十分復(fù)雜,訓(xùn)練數(shù)據(jù)非常龐大,所以發(fā)現(xiàn)錯(cuò)誤的時(shí)間成本會(huì)很高,有時(shí)候代碼放在后臺(tái)訓(xùn)練,出了問題都不知道……

PyTea 就可以有效幫我們避免這個(gè)問題,因?yàn)樗茉谶\(yùn)行模型代碼之前,就幫我們分析出形狀錯(cuò)誤。

圖片

網(wǎng)友們已經(jīng)在熱烈討論了。

PyTea 是如何運(yùn)作的,它能否有效地檢查出錯(cuò)誤呢?

圖片

受各種約束條件的影響,代碼可能的運(yùn)行路徑有很多,不同的數(shù)據(jù)會(huì)走向不同的路徑。

所以 PyTea 需要靜態(tài)掃描所有可能的運(yùn)行路徑,跟蹤張量變化,推斷出每個(gè)張量形狀精確而保守的范圍。

上圖就是 PyTea 的整體架構(gòu),一共分為翻譯語言,收集約束條件,求解器判斷和給出反饋四步。

圖片

首先 PyTea 將原始的 Python 代碼翻譯成一種內(nèi)核語言。PyTea 內(nèi)部表示法(PyTea IR)。

圖片

接著 PyTea 追蹤 PyTea IR 每個(gè)可能的執(zhí)行路徑,并收集有關(guān)張量形狀的約束條件。

判斷約束條件是否被滿足,分為線上分析和離線分析兩步

  • 線上分析 node.js(TypeScript / JavaScript):查找張量形狀數(shù)值上的不匹配和誤用 API 函數(shù)的情況。如果 PyTea 發(fā)現(xiàn)問題,就會(huì)停止在當(dāng)前位置,然后給用戶報(bào)錯(cuò)。

圖片

  • 離線分析 Z3 / Python:如果線上分析沒有問題,PyTea 將收集到的約束條件傳給 SMT(Satisfiability Modulo Theories)求解器 Z3,求解器負(fù)責(zé)查看每條路徑的約束條件是否都能被滿足,如果不能,返回給用戶第一條出錯(cuò)路徑的約束條件。

圖片

如果求解器過久沒有反應(yīng),PyTea 會(huì)返回不知道是否存在問題。

然而追蹤所有可能的路徑是指數(shù)級(jí)別的任務(wù),對(duì)于復(fù)雜的神經(jīng)網(wǎng)絡(luò)來說,一定會(huì)發(fā)生路徑爆炸這個(gè)問題。

圖片

比如說在這個(gè)例子中,網(wǎng)絡(luò)的最終結(jié)構(gòu)是由 24 個(gè)相同模塊塊構(gòu)成的(第 17 行),那么可能的路徑就有 16M 之多。

所以路徑爆炸是一定要處理的,PyTea 是怎么做的?

PyTea 選擇保守的地對(duì)路徑剪枝和超時(shí)判斷來處理這種路徑爆炸。

什么樣的路徑可以被剪枝?

PyTea 給出的答案是,如果該前饋函數(shù)不改變?nèi)种担⑶宜妮敵鲋挡皇芊种l件影響,對(duì)于每條路徑都是相等的,我們就可以忽略許多完全一致的路徑,來節(jié)約計(jì)算資源。

如果路徑剪枝還是不行,那么就只能按超時(shí)處理了。

原理就介紹這么多了,感覺還是值得一試的,現(xiàn)在代碼已經(jīng)在 GitHub 上面開源了,快去看看吧!

使用方法

依賴庫:

圖片

安裝方法:

圖片

運(yùn)行命令:

圖片

參考鏈接

[1]https://github.com/ropas/pytea

[2]https://arxiv.org/abs/2112.09037

廣告聲明:文內(nèi)含有的對(duì)外跳轉(zhuǎn)鏈接(包括不限于超鏈接、二維碼、口令等形式),用于傳遞更多信息,節(jié)省甄選時(shí)間,結(jié)果僅供參考,IT之家所有文章均包含本聲明。

相關(guān)文章

軟媒旗下網(wǎng)站: IT之家 最會(huì)買 - 返利返現(xiàn)優(yōu)惠券 iPhone之家 Win7之家 Win10之家 Win11之家

軟媒旗下軟件: 軟媒手機(jī)APP應(yīng)用 魔方 最會(huì)買 要知