一.元素类型的转换
- data.type(torch.DoubleTensor)
- data.double
1. data.type(torch.DoubleTensor)或data.type(torch.double)
- 作用:将张量转换为指定的类型类(
Tensor class) - 同类方法:
data.type(torch.IntTensor)或者data.type(torch.int)→ int32data.type(torch.LongTensor)或者data.type(torch.long)→ int64data.type(torch.FloatTensor)或者data.type(torch.float)→ float32
- 示例:
import torch
data = torch.tensor([1, 2, 3], dtype=torch.float32)
# 转换为 DoubleTensor(即 float64)
double_data = data.type(torch.DoubleTensor)
print(double_data.dtype) # 输出:torch.float64
2.data.double()
-
作用:直接将张量转换为
float64 (双精度浮点型),等价于data.type(torch.DoubleTensor) -
同类方法:
data.float()→ float32data.int()→ int32data.long()→ int64data.bool()→ bool
-
示例:
data = torch.tensor([1, 2, 3], dtype=torch.float32)
double_data = data.double()
print(double_data.dtype) # 输出:torch.float64
二.张量和ndarray的转换
张量与Numpy转换
1.张量 → Numpy数组:tensor.numpy()
原生特性
直接调用 tensor.numpy() 生成的数组与原张量共享底层内存,修改其中一方,另一方数据会同步变动
import torch
import numpy as np
t = torch.tensor([1,2,3])
arr = t.numpy()
t[0] = 100
print(arr) # 输出 [100 2 3],数据同步改变
用copy/clone切断内存共享(两种写法)
1. 先克隆张量再转数组(PyTorch侧拷贝)
tensor.clone() 会开辟全新内存复制张量,再转numpy就无关联:
t = torch.tensor([1,2,3])
arr = t.clone().numpy() # 先复制张量,再转数组
t[0] = 100
print(arr) # 输出 [1 2 3],不受原张量修改影响
2. numpy数组层面拷贝(Numpy侧拷贝)
拿到共享数组后调用 .copy() 生成独立数组副本:
t = torch.tensor([1,2,3])
arr_shared = t.numpy()
arr = arr_shared.copy() # 数组单独拷贝
t[0] = 100
print(arr) # 输出 [1 2 3]
2.Numpy数组 → 张量
1.torch.from_numpy(arr) :共享内存
原生转换共享内存,修改numpy数组,张量同步变化;
搭配 arr.copy() 拷贝数组后再转换,即可隔离:
arr = np.array([1,2,3])
t = torch.from_numpy(arr.copy()) # 先拷贝数组
arr[0] = 100
print(t) # tensor([1, 2, 3]),无联动修改
2.torch.tensor(arr) :天然独立内存
该方法会直接复制一份全新数据,无需额外copy,默认不共享内存:
arr = np.array([1,2,3])
t = torch.tensor(arr)
arr[0] = 100
print(t) # tensor([1, 2, 3])
3.转换方法内存对比表
| 转换操作 | 默认内存关系 | 实现数据隔离写法 |
|---|---|---|
tensor.numpy() | 共享内存 | tensor.numpy().copy() |
torch.from_numpy(arr) | 共享内存 | tensor.from_numpy(arr.copy()) |
torch.tensor(arr) | 完全独立 | 无需额外拷贝 |
三.标量张量和数字转换
对于只有一个元素的张量,使用item()函数将该值从张量中提取出来
#当张量只包含一个元素时可以通过item()函数提取出该值
data=torch.rensor([30,])
print(data.item())#输出30
data=torch.tensor(30)
print(data.item())#输出30

1119

被折叠的 条评论
为什么被折叠?



