pytorch:模型载存、Datasets、数据预处理
pytorch:模型载存、Datasets、数据预处理
模型载存
两种方法:
torch.save(obj, dir)
保存整个模型(结构+参数)
整个model对象
#保存模型
torch.save(model_object,'model.pkl')
#加载模型
model = torch.load('model.pkl')
只保存参数(推荐)
获取model的参数 - 词典
#仅保存模型参数
torch.save(model_object.state_dict(),'params.pkl')
#仅加载模型参数
model_object.load_state_dict(torch.load('params.pkl'))
#也可以指定
save_data = {
'model_state_dict':model.state_dict(),
'optimzer_state_dict':optimzer.state_dict(),
'loss':loss,
'epoch':epoch,
'args':args
···
}
- 当载存整个模型的时候:只需要在训练阶段定义网络结构,在测试代码当中不需要重新定义,直接根据文件加载网络结构+参数即可
- 当仅载存模型参数的时候,不仅需要在训练阶段定义网络结构,在测试代码中也需要,然后根据文件加载参数。
Datasets
为什么要有 Datasets?
- 方便数据集输入到模型当中
- 方便划分数据集:训练集、测试集、验证集
- ……
怎么定义 Datasets?
1. Datasets 类
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class MyDataset(Dataset):#继承Dataset类
def __init__(self,dir,augment=None, transform=None):
pass
#从datasets数据集中取读一条数据,index表示索引位置返回obj&label
def __getitem__(self, index):
pass
#返回数据集的总长度(训练集的总数)
def __len__(self):
pass
数据增强( augment):比如数据集原来有10张照片,我们将它颜倒、旋转、缩放、剪裁等扩充100张照片,因为无论一只猫是倒着还是正着,我们都能识别出它是猫,因此我们希望通过数据集的扩充,一方面增加数据集的大小,另一方面体现这个不变性。
2. 分析DataLoader
train_loader = DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True)
DataLoder(datasets,batch_size=128, shuffle = True)
- batch_size表示我们定义的batch大小(即每轮训练使用的批大小)
- shuffle表示是否打乱数据顺序(对于整个datasets里包含的所有数据)
Datasets的使用
datasets = MyDataset(dir = '../data/', argument = None)
dataloader = DataLoader(datasets, batch_size = 128, shuffle = True)
num_epoches = 10000
for epoch in range(num_epoches):
for img, label in dataloader:
#训练(后向传播)
pass
pass
数据预处理
首先区分训练阶段、测试阶段数据预处理有差异性也有一致性
差异性:数据増强一般只用于训练集以扩充训练集满足不变性,但是在测试/应用阶段不必进行增强
一致性:尽量保证训练、測试/应用时候数据的格式、属性是一致的
举几个例子(一致性):
- 去除唯一属性:id,对分类没什么用
- 缺失值处理:直接删除带缺失值的数据或者进行补全
补全技术是一个很有意思的问题,经常出现,常见的解决方式有:压缩感知、因果推理( Causal Inference)。