关于pytorch中contiguous(内存连续性)的一些理解

作者: L-I-F
2025-5-19 0:0:0

之前在学习pytorch的时候一直对于这个概念云里雾里,于是画了一个小时理了理内存连续的概念,有了一些理解.
文中所说的c风格式存储,指的是一行一行存,先存第一行,接着把第二行放在第一行后面存储,以此类推.

1.内存连续性的定义


首先来看官方给的定义

Tensor is or will be allocated in dense non-overlapping memory. Strides represented by values in decreasing order.


翻译成中文就是

张量在密集且不重叠的内存中分配,stride 按递减顺序排列。


要解释明白这句话,首先就要说明一下stride是什么:
常常被表示为一个元组,元组的位置和tensor的维度一一对应,表示该维度加一需要走过的内存中的元素数量
举个例子,如果我有一个tenser是:

[ [1,2,3],
[4,5,6] ]

在内存中的存储是 1,2,3,4,5,6
那么他的stride是(3,1),容易理解的是,第二维每增加1,就要走过一个元素
比如1->2,从1走到2,走过了元素“1”,走过了一个元素
而对于第一维,要增加1,就是要从1走到4:
1->2->3->4,走过了“1”,“2”,“3”三个元素,所以stride对应位置是3
现在我们回到contiguous上面,stride的数字应该从前往后依次减小,也就是说,靠后的维数(比如第2维)要比靠前的维数(比如第1维)每增大一个值,走过的内存中的元素小
你可能会疑惑,在C风格的内存存储下,后面比前面的stride小不是很显然吗?难不成还有变大的?
有,还真有!
问题就出在stride定义时候强调的走过内存中的元素,注意是内存中的,不是你看到的
比如我对刚刚那个矩阵生成一个转置的视图(视图view,只改变你看到的形状,不改变原存储方式):

[ [1,4],
[2,5]
[3,6] ]

由于内存没有变,所以在内存中的存储依然是 1,2,3,4,5,6
此时第二个维数增加1,在内存中要走1->2->3->4,stride为3
但是第一个维数增加1,在内存中只需要走1->2,stride为1
stride是(1,3),是不是后面的比前面的大了?
但是好像这样的解释有点隔靴搔痒,没有触及本质。

2.违反contiguous的本质(我理解的)


有一个很简单的判断准则:对于一个你看见的tensor,如果你按照C风格存储方式所预期的stride不等于实际上的stride,就不连续。
因为只有按照c风格来存储,你用索引从小到大,维数从后到前的最正常的遍历方法,在内存中才是不跳跃,不反向的连续遍历,而按照C风格存储方式所预期的stride不等于实际上的stride,就意味着你不是按照C风格存储,就会导致不连续。
([0][0]->[0][1]->[1][0]->[1][1]->[2][0]->[2][1])
下面用3个不连续的例子帮你理解:
首先就用上面转置生成的视图举例子:

[ [1,4],
[2,5]
[3,6] ]


假设你不知道这个是转置生成的视图,按照C风格的存储方式,你预期的存储应该是 1,4,2,5,3,6,stripe应该是(3,1)
但是实际上是 1,2,3,4,5,6,stripe是(1 , 3)
这个例子下,在按照索引从小到大,维数从后到前遍历时候,会发生内存的跳跃(从1跳到4)
第二个例子是切片(pytorch中的切片返回的是视图)
就拿最简单的一维举例子:

[1,2,3,4,5,6]


①取切片[ : :2],生成[1,3,5],假设你不知道这个是一个视图,按照C风格的存储方式,你预期的存储应该是 1,3,5,stripe应该是(1)
但是实际上存储是 1,2,3,4,5,6,stripe是(2)
此时在按照索引从小到大,维数从后到前遍历时候,会发生跳跃
事实上pytorch不支持负步长,也大概不存在反转的视图(torch.flip()会直接产生副本),上面这个负步长的例子是基于我的理解推演出的例子
②如果取[ : :-1],生成[6,5,4,3,2,1],假设你不知道这个是一个视图,按照C风格的存储方式,你预期的存储应该是 6,5,4,3,2,1,stripe应该是(1)
但是实际上存储是 1,2,3,4,5,6,stripe是(-1)
此时在按照索引从小到大,维数从后到前遍历时候,会发生反向
第三个例子是expand函数,用法三言两语解释不清,简而言之是:生成一个将现有的数据在某个维度上“复制”几次的视图(注意:不是真的在内存中复制了,只是把同一个地方的数据反复用)
比如我有一个(1,4)的tensor,变量名为t:

[1,2,3,4]



现在用t.expand(4,4)“复制”数据生成一个(4,4)的视图:

[[1,2,3,4],
[1,2,3,4],
[1,2,3,4],
[1,2,3,4]]


假设你不知道这个是一个视图,按照C风格的存储方式,你预期的存储应该是 1,2,3,4,1,2,3,4,1,2,3,4,1,2,3,4,stride为(4,1)
但是实际上存储是 1,2,3,4,stripe是(0,1)
这里0代表着第一个维数加1(从第一行的1到第二行的1,一步也不用动,事实上确实是如此)
此时在按照索引从小到大,维数从后到前遍历时候,会在切换行的时候反向跳回第一个数字

3.总结


你可以用官方给出的定义来判断,也可以用我再第2点里提到的判断准则来判断,由于stride的性质,这两种判断方法实际上是等价的
而所谓的不连续,本质上就是视图和它的存储不满足C风格的对应关系,导致按照索引从小到大,维数从后到前遍历的时候,在内存中并不是一个一个一个一个元素连续遍历
要想转成连续(比如a是一个不连续的视图)

a = a.contiguous()



会强制拷贝一份a,并用c风格存到内存里。

评论