PyTorch
PyTorch剪枝
尽管PyTorch自带了剪枝工具,但在灵活性上终究比不上自己手写的剪枝代码。以下是博主对模型剪枝的一次简单尝试。
代码:
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

device = torch.device('cpu')

# Load the MNIST training set (downloaded into ./MNIST/ if missing)
train_dataset = datasets.MNIST(root='./MNIST/',
                               train=True,  # training split
                               transform=transforms.ToTensor(),  # convert images to tensors
                               download=True)  # download the data if not present
# Load the MNIST test set
test_dataset = datasets.MNIST(root='./MNIST/',
                              train=False,  # test split
                              transform=transforms.ToTensor(),  # convert images to tensors
                              download=True)  # download the data if not present

# Batch size: number of images fed to the model per step
batch_size = 64

# Training data loader; shuffle=True reshuffles the training data every epoch
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
# Test data loader
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,  # samples per batch
                         shuffle=True)  # shuffle the data


class Net(nn.Module):
    """Two-conv + one-linear MNIST classifier with magnitude-based weight pruning.

    Each conv layer owns a uint8 mask buffer (1 = weight kept, 0 = pruned).
    On every forward pass the masks are recomputed so that the ``sparsity``
    fraction of each conv weight tensor with the smallest magnitude (already
    pruned weights counting as 0) is pruned. Pruned entries are zeroed in the
    stored parameter and excluded from the autograd graph.
    """

    def __init__(self):
        super().__init__()
        # Conv layer 1: 1 -> 16 channels, 5x5 kernels
        self.conv1_weight = nn.Parameter(nn.init.xavier_uniform_(torch.empty(16, 1, 5, 5)))
        self.conv1_bias = nn.Parameter(torch.zeros(16))
        # Conv layer 2: 16 -> 32 channels, 5x5 kernels
        self.conv2_weight = nn.Parameter(nn.init.xavier_uniform_(torch.empty(32, 16, 5, 5)))
        self.conv2_bias = nn.Parameter(torch.zeros(32))
        # Fully-connected classifier: 32*7*7 features -> 10 classes
        self.fc_weight = nn.Parameter(nn.init.xavier_uniform_(torch.empty(10, 32 * 7 * 7)))
        self.fc_bias = nn.Parameter(torch.randn(10))
        # NOTE: no per-parameter .to(device) here. Reassigning a registered
        # Parameter with a plain tensor raises TypeError on non-CPU devices;
        # calling model.to(device) moves parameters and buffers instead.
        # Fraction of each conv weight tensor that is pruned
        self.sparsity = 0.5
        # Pruning masks (buffers, so they are saved in state_dict and moved by .to())
        self.register_buffer('conv1_mask', torch.ones((16, 1, 5, 5), dtype=torch.uint8))
        self.register_buffer('conv2_mask', torch.ones((32, 16, 5, 5), dtype=torch.uint8))

    def _magnitude_mask(self, weight, mask):
        """Return a new uint8 mask that prunes the ``sparsity`` fraction of smallest |weight|.

        Entries already pruned by ``mask`` are treated as magnitude 0. The
        threshold is the element at index ``int(numel * sparsity)`` of the
        ascending-sorted magnitudes; entries >= threshold are kept.
        """
        w = weight.detach()
        w = torch.where(mask == 1, w, torch.zeros_like(w)).abs()  # zeros_like: same device/dtype
        sorted_vals, _ = torch.sort(w.view(-1))
        threshold = sorted_vals[int(sorted_vals.size(0) * self.sparsity)]
        return w.ge(threshold).to(torch.uint8)

    @staticmethod
    def _apply_mask(weight, mask):
        """Zero pruned entries of ``weight`` in place and return ``weight * mask``.

        The in-place zeroing keeps the stored parameter sparse. Using
        ``weight * mask`` inside the graph makes the gradient of pruned
        entries exactly zero, so the optimizer does not move them.
        """
        weight.data.mul_(mask)
        return weight * mask

    def forward(self, x):
        """Refresh the pruning masks, then classify a batch.

        Args:
            x: image batch; the shape comments below assume (N, 1, 28, 28) MNIST input.

        Returns:
            (N, 10) tensor of softmax class probabilities.
        """
        # Update the pruning masks from the current weight magnitudes
        self.conv1_mask = self._magnitude_mask(self.conv1_weight, self.conv1_mask)
        self.conv2_mask = self._magnitude_mask(self.conv2_weight, self.conv2_mask)

        # Conv layer 1 + pooling
        w1 = self._apply_mask(self.conv1_weight, self.conv1_mask)
        x = F.conv2d(input=x, weight=w1, bias=self.conv1_bias, stride=1, padding=2)  # 1,28,28 -> 16,28,28
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)  # -> 16,14,14

        # Conv layer 2 + pooling
        w2 = self._apply_mask(self.conv2_weight, self.conv2_mask)
        x = F.conv2d(input=x, weight=w2, bias=self.conv2_bias, stride=1, padding=2)  # 16,14,14 -> 32,14,14
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)  # -> 32,7,7

        # Flatten to (N, 32*7*7) so F.linear sees 2-D input matching fc_weight
        x = x.view(x.size(0), -1)
        x = F.linear(x, self.fc_weight, bias=self.fc_bias)
        return F.softmax(x, dim=1)


model = Net()
model.to(device)
# Loss function (MSE between softmax output and one-hot labels)
mse_loss = nn.MSELoss()
# Optimizer
LR = 0.01  # learning rate
optimizer = optim.SGD(model.parameters(), lr=LR)


def train_model():
    """Run one epoch of SGD over ``train_loader``."""
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Model prediction
        out = model(inputs)
        # One-hot encode the labels: (N,) -> (N, 1) index -> (N, 10)
        labels = labels.reshape(-1, 1)
        one_hot = torch.zeros(inputs.shape[0], 10, device=device).scatter(1, labels, 1)
        loss = mse_loss(out, one_hot)
        optimizer.zero_grad()  # clear old gradients
        loss.backward()        # compute gradients
        optimizer.step()       # update weights


def test_model():
    """Print classification accuracy over ``test_dataset``."""
    correct = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            out = model(inputs)
            # Index of the highest probability is the predicted class
            _, predicted = torch.max(out, 1)
            correct += (predicted == labels).sum().item()
    print("Test acc:{0}".format(correct / len(test_dataset)))


for epoch in range(20):
    print('epoch:', epoch)
    train_model()
    test_model()

print(model.conv1_weight * model.conv1_mask)
print(model.conv2_weight * model.conv2_mask)
print(torch.sum(model.conv1_mask))
print(torch.sum(model.conv2_mask))
流程简介
其中,主要流程为:
1.设置一个mask张量,值为0则表示对应权重张量相同位置的权值已经被剪枝,为1则表示还未被剪枝。
2.每次前向推理时,首先得到权重张量的一个拷贝,然后,根据mask张量,将已经被剪枝的权值设为0,然后对其求绝对值。
3.将得到的张量展开为一维张量并升序排列,取下标为k的元素(即第k+1小的值)作为剪枝阈值,其中k=int(权值数目*稀疏度)。
4.将所有小于该阈值的权值剪枝掉(即更新mask张量的值为0)
5.在进行卷积前,将权值张量首先乘以mask张量,使得被剪去的权值为0,然后再进行卷积。
(注:被剪枝的权值在前向传播中等于0。只有当掩码乘法被记录进计算图(即用weight*mask参与卷积)时,这些权值的梯度才恒为0;如果只是修改参数的.data,autograd看不到这次乘法,被剪枝的权值仍会得到非零梯度并被优化器更新,只是在下一次前向传播时又被重新置为0。)
更改
对全连接层进行剪枝,采用Bank-Balanced Sparsity(BBS)策略。这是一种介于非结构化剪枝与块剪枝之间的剪枝策略:把权值矩阵的每一行划分为若干个等长的bank(下面代码中bank大小为32),在每个bank内部按相同的稀疏度进行细粒度剪枝,从而使各bank保留的非零权值数目相同。详见论文:Efficient and Effective Sparse LSTM on FPGA with Bank-Balanced Sparsity
代码如下:
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

device = torch.device('cpu')

# Load the MNIST training set (downloaded into ./MNIST/ if missing)
train_dataset = datasets.MNIST(root='./MNIST/',
                               train=True,  # training split
                               transform=transforms.ToTensor(),  # convert images to tensors
                               download=True)  # download the data if not present
# Load the MNIST test set
test_dataset = datasets.MNIST(root='./MNIST/',
                              train=False,  # test split
                              transform=transforms.ToTensor(),  # convert images to tensors
                              download=True)  # download the data if not present

# Batch size: number of images fed to the model per step
batch_size = 64

# Training data loader; shuffle=True reshuffles the training data every epoch
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
# Test data loader
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,  # samples per batch
                         shuffle=True)  # shuffle the data


class Net(nn.Module):
    """MNIST CNN with magnitude pruning on the conv layers and
    Bank-Balanced Sparsity (BBS) pruning on the fully-connected layer.

    Masks are uint8 buffers (1 = kept, 0 = pruned), recomputed on every
    forward pass from the current weight magnitudes; already-pruned weights
    count as magnitude 0. Conv tensors are pruned globally; each FC row is
    split into consecutive banks of ``bank_size`` weights and every bank is
    pruned to the same sparsity.
    """

    def __init__(self):
        super().__init__()
        # Conv layer 1: 1 -> 16 channels, 5x5 kernels
        self.conv1_weight = nn.Parameter(nn.init.xavier_uniform_(torch.empty(16, 1, 5, 5)))
        self.conv1_bias = nn.Parameter(torch.zeros(16))
        # Conv layer 2: 16 -> 32 channels, 5x5 kernels
        self.conv2_weight = nn.Parameter(nn.init.xavier_uniform_(torch.empty(32, 16, 5, 5)))
        self.conv2_bias = nn.Parameter(torch.zeros(32))
        # Fully-connected classifier: 32*7*7 features -> 10 classes
        self.fc_weight = nn.Parameter(nn.init.xavier_uniform_(torch.empty(10, 32 * 7 * 7)))
        self.fc_bias = nn.Parameter(torch.randn(10))
        # NOTE: no per-parameter .to(device) here. Reassigning a registered
        # Parameter with a plain tensor raises TypeError on non-CPU devices;
        # calling model.to(device) moves parameters and buffers instead.
        # Fraction pruned in each conv tensor and in each FC bank
        self.sparsity = 0.5
        # Number of consecutive weights per bank within an FC row
        self.bank_size = 32
        # Pruning masks (buffers, so they are saved in state_dict and moved by .to())
        self.register_buffer('conv1_mask', torch.ones((16, 1, 5, 5), dtype=torch.uint8))
        self.register_buffer('conv2_mask', torch.ones((32, 16, 5, 5), dtype=torch.uint8))
        self.register_buffer('fc_mask', torch.ones((10, 32 * 7 * 7), dtype=torch.uint8))

    @staticmethod
    def _pruned_magnitude(weight, mask):
        """Return detached |weight| with already-pruned entries (mask == 0) set to 0."""
        w = weight.detach()
        return torch.where(mask == 1, w, torch.zeros_like(w)).abs()  # zeros_like: same device/dtype

    def _magnitude_mask(self, weight, mask):
        """Return a new uint8 mask pruning the ``sparsity`` fraction of smallest |weight|.

        The threshold is the element at index ``int(numel * sparsity)`` of the
        ascending-sorted magnitudes; entries >= threshold are kept.
        """
        w = self._pruned_magnitude(weight, mask)
        sorted_vals, _ = torch.sort(w.view(-1))
        threshold = sorted_vals[int(sorted_vals.size(0) * self.sparsity)]
        return w.ge(threshold).to(torch.uint8)

    def _bank_balanced_mask(self, weight, mask):
        """Return a new mask applying bank-balanced pruning to a 2-D weight.

        Each row is split into ``cols // bank_size`` consecutive banks. Within
        each bank the threshold is the element at index
        ``int(bank_size * sparsity)`` of the ascending-sorted magnitudes;
        entries >= threshold are kept. Columns past the last full bank keep
        their previous mask values. All banks are processed at once instead
        of in a Python double loop.
        """
        w = self._pruned_magnitude(weight, mask)
        rows, cols = w.shape
        n_banks = cols // self.bank_size
        covered = n_banks * self.bank_size
        banks = w[:, :covered].reshape(rows, n_banks, self.bank_size)
        sorted_vals, _ = torch.sort(banks, dim=-1)  # ascending within each bank
        k = int(self.bank_size * self.sparsity)
        thresholds = sorted_vals[..., k:k + 1]  # (rows, n_banks, 1), broadcasts over the bank
        new_mask = mask.clone()
        new_mask[:, :covered] = banks.ge(thresholds).reshape(rows, covered).to(mask.dtype)
        return new_mask

    @staticmethod
    def _apply_mask(weight, mask):
        """Zero pruned entries of ``weight`` in place and return ``weight * mask``.

        The in-place zeroing keeps the stored parameter sparse. Using
        ``weight * mask`` inside the graph makes the gradient of pruned
        entries exactly zero, so the optimizer does not move them.
        """
        weight.data.mul_(mask)
        return weight * mask

    def forward(self, x):
        """Refresh the pruning masks, then classify a batch.

        Args:
            x: image batch; the shape comments below assume (N, 1, 28, 28) MNIST input.

        Returns:
            (N, 10) tensor of softmax class probabilities.
        """
        # Update the pruning masks from the current weight magnitudes
        self.conv1_mask = self._magnitude_mask(self.conv1_weight, self.conv1_mask)
        self.conv2_mask = self._magnitude_mask(self.conv2_weight, self.conv2_mask)
        self.fc_mask = self._bank_balanced_mask(self.fc_weight, self.fc_mask)

        # Conv layer 1 + pooling
        w1 = self._apply_mask(self.conv1_weight, self.conv1_mask)
        x = F.conv2d(input=x, weight=w1, bias=self.conv1_bias, stride=1, padding=2)  # 1,28,28 -> 16,28,28
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)  # -> 16,14,14

        # Conv layer 2 + pooling
        w2 = self._apply_mask(self.conv2_weight, self.conv2_mask)
        x = F.conv2d(input=x, weight=w2, bias=self.conv2_bias, stride=1, padding=2)  # 16,14,14 -> 32,14,14
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2, stride=2)  # -> 32,7,7

        # Flatten to (N, 32*7*7) so F.linear sees 2-D input matching fc_weight
        x = x.view(x.size(0), -1)
        w3 = self._apply_mask(self.fc_weight, self.fc_mask)
        x = F.linear(x, w3, bias=self.fc_bias)
        return F.softmax(x, dim=1)


model = Net()
model.to(device)
# Loss function (MSE between softmax output and one-hot labels)
mse_loss = nn.MSELoss()
# Optimizer
LR = 0.01  # learning rate
optimizer = optim.SGD(model.parameters(), lr=LR)


def train_model():
    """Run one epoch of SGD over ``train_loader``."""
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Model prediction
        out = model(inputs)
        # One-hot encode the labels: (N,) -> (N, 1) index -> (N, 10)
        labels = labels.reshape(-1, 1)
        one_hot = torch.zeros(inputs.shape[0], 10, device=device).scatter(1, labels, 1)
        loss = mse_loss(out, one_hot)
        optimizer.zero_grad()  # clear old gradients
        loss.backward()        # compute gradients
        optimizer.step()       # update weights


def test_model():
    """Print classification accuracy over ``test_dataset`` using the global ``model``."""
    correct = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            out = model(inputs)
            # Index of the highest probability is the predicted class
            _, predicted = torch.max(out, 1)
            correct += (predicted == labels).sum().item()
    print("Test acc:{0}".format(correct / len(test_dataset)))


# False: train for 20 epochs and save the state_dict (weights + masks) to mynet.pth.
# True: load mynet.pth and inspect the pruned model.
trained = True
if not trained:
    for epoch in range(20):
        print('epoch:', epoch)
        train_model()
        test_model()
    torch.save(model.state_dict(), "mynet.pth")
else:
    # Rebinds the global `model`, which test_model() reads
    model = Net()
    # Load the pre-trained parameters and masks onto the current device
    model.load_state_dict(torch.load("mynet.pth", map_location=device))
    model.to(device)
    test_model()
    print(torch.sum(model.conv1_mask))
    print(torch.sum(model.conv2_mask))
    print(torch.sum(model.fc_mask))
    # Kept-weight count per FC bank; with bank-balanced pruning these should
    # all be equal (barring ties in magnitude)
    wfc = model.fc_mask
    bank = model.bank_size
    for i in range(wfc.size(0)):
        for j in range(wfc.size(1) // bank):
            print(torch.sum(wfc[i, j * bank:(j + 1) * bank]))
    print(model.fc_weight)
# print(model.conv1_weight*model.conv1_mask)
# print(model.conv2_weight*model.conv2_mask)
# print(torch.sum(model.conv1_mask))
# print(torch.sum(model.conv2_mask))

发布评论