Skip to content

张量(Tensor)是有序的序列,因此可以使用索引访问元素,其索引方式与 numpy.array 基本一致。此外,PyTorch 还提供了一些函数用于索引操作

张量不仅具备列表和数组的基本功能,同时也可以表示向量、矩阵,甚至类似数据框的结构。因此,PyTorch 提供了完善的张量合并与变换操作,以支持多种计算需求。

张量的索引

张量的符号索引

啥是索引不多bb了吧

一维符号索引

索引和切片跟列表,np.array的方式一样,不多bb了

张量索引出来的结果是0维张量,不是单独的数,要转化成单独的数可以使用item()

image-20250227205155213

二维符号索引

也和np.array一样

image-20250227205710172

三维符号索引

一样的,就是多加个逗号

image-20250227210938572

张量的函数索引

在Pytorch中,还可以通过**select_index()**函数,通过指定index来对张量进行索引

torch.select_index():

python
torch.index_select(
	input: Tensor, 
	dim: int, 
	index: Tensor)

参数说明:

  • input:待索引的张量(Tensor)。

  • dim:沿着哪个维度进行索引(0 表示按行索引,1 表示按列索引,以此类推)。

  • index,表示要选取的索引值。

一维函数索引

image-20250228131110197

二维函数索引

image-20250228131435476

tensor.view()方法

tensor.view()是 PyTorch 中用于改变张量形状(维度)的方法,但不会更改张量的原始数据

,它不会创建新的数据副本,而是返回一个新的张量,但与原张量共享相同的内存。即原先的张量变了,新的张量也会跟着变。,比如后面的张量分片中,会返回原对象的视图,而不是新的对象。image-20250228132501981

tensor.storage().data_ptr():查看张量的内存地址

张量的切分

分块:torch.chunk()

torch.chunk() 是 PyTorch 中用于将张量分割成多个子张量的函数。它通过指定分割的数量将一个大张量分成多个小张量,每个小张量具有相同的大小(除了可能的剩余部分)。

当使用 torch.chunk() 分割张量时,如果张量不能均匀分割成指定的块数,不会报错,PyTorch 会尽量平均分配剩余的元素。具体来说,剩余的元素会被分配到最后几个块中,使得它们的大小相差不超过 1。

注意

python
torch.chunk(input, chunks, dim=0)

参数说明

  • input:要分割的输入张量。
  • chunks:分割的数量,即将张量分成多少份。
  • dim(可选):指定沿着哪个维度进行分割,默认是沿着第 0 维(即行方向)。
image-20250228133556234

拆分:torch.split()

torch.split() 是 PyTorch 中用于将张量按指定的大小分割成多个子张量的函数,与 torch.chunk() 类似,但它提供了更多的灵活性,尤其是在指定分割大小时。可以指定分割子张量的大小,不执着于均分

python
torch.split(tensor, split_size_or_sections, dim=0)

参数说明:

  • tensor:输入的张量。
  • split_size_or_sections:可以是一个整数或者一个列表/元组。
    • 如果是一个整数,表示将张量沿指定维度分割成大小为 split_size_or_sections 的块。
    • 如果是一个列表或元组,表示指定每个子张量的大小,这个列表的总和应该等于输入张量在指定维度的大小。
  • dim(可选):指定沿着哪个维度进行分割,默认是第 0 维(即行方向)。
image-20250228134544443

张量的合并

拼接函数torch.cat()

torch.cat() 是 PyTorch 中用于连接多个张量的函数,可以沿指定的维度将多个张量连接成一个大张量。

python
torch.cat(tensors, dim=0, out=None)

参数说明:

  • tensors:一个张量列表或元组,需要连接的张量们。这些张量必须在除连接维度之外的所有维度上具有相同的大小。
  • dim:指定沿着哪个维度进行连接,默认是第 0 维(即行方向)。如果 dim=1,则沿列方向连接。
  • out(可选):一个输出张量,可以将结果直接存放在这个张量中,默认为 None,即创建一个新的张量。
image-20250228135459120

堆叠函数torch.stack()

torch.stack() 是 PyTorch 中用于沿新维度连接多个张量的函数。与 torch.cat() 不同,torch.stack() 不只是简单地将张量拼接在现有维度上,而是创建一个新的维度,将多个张量堆叠在一起。

python
torch.stack(tensors, dim=0, out=None)

参数说明:

  • tensors:一个张量列表或元组,表示需要堆叠的张量。所有张量必须具有相同的形状(包括大小和维度)。
  • dim:指定新维度的位置,默认为 0。新的维度将会插入到指定的位置,所有张量会被堆叠到这个新维度中。
  • out(可选):一个输出张量,用于存放堆叠结果,默认为 None

返回值

返回一个新的张量,它是输入张量沿新维度堆叠的结果。

image-20250228135955128image-20250228140214700

张量的维度变换

之前通过torch.reshape()可以灵活地调整张量形状,但当想要时,可以使用torch.unsqueeze()torch.squeeze()来搞

python
torch.unsqueeze(input, dim) # 增加一个大小为1的新维度
torch.squeeze(input, dim)   # 删除大小为 1 的维度

参数

  • input:输入张量。
  • dim:指定插入新维度的位置。例如,dim=0 会在第0维插入一个新维度。
image-20250228141052782