微比恩 > 信息聚合 > 秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

秒秒钟揪出张量形状错误,这个工具能防止 ML 模型训练白忙一场

2021-12-27 22:29:45来源: IT之家

模型吭哧吭哧训练了半天,结果发现张量形状定义错了,这一定没少让你抓狂吧。那么针对这种情况,是否存在较好的解决方法呢?这不最近,韩国首尔大学的研究者就开发出了一款“利器”—— PyTea。据研究人员介绍,它在训练模型前,能几秒内帮助你静态分析潜在的张量形状错误。那么 PyTea 是如何做到的,到底靠不靠谱,让我们一探究竟吧。PyTea 的出场方式为什么张量形状错误这么重要?神经网络涉及到一系列的矩阵计算,前面矩阵的列数必需匹配后面矩阵的行数,如果维度不匹配,那后面的运算就都无法运行了。上图代码就是一个典型的张量形状错误,[B x 120] * [80 x 10] 无法进行矩阵运算。无论是 PyTorch,TensorFlow 还是 Keras 在进行神经网络的训练时,大多都遵循图上的流程。首先定义一系列神经网络层(也就是矩阵),然后合成神经网络模块……那么为什么需要 PyTea 呢?以往我们都是在模型读取大量数据,开始训练,代码运

关注公众号