
A Simple, Easy-to-Follow Hands-On PyTorch Example: the VGG Deep Network

 Updated: August 27, 2019, 08:33:17   Author: 青盞
This article presents a simple, easy-to-follow hands-on PyTorch example of the VGG deep network. The example code is explained in detail and should be a useful reference for study or work; readers who need it can follow along with the walkthrough below.

The model is VGG and the dataset is CIFAR-10. Walk through this code once and you will have a good picture of how the whole PyTorch training pipeline works.


Define the model:

'''VGG11/13/16/19 in Pytorch.'''
import torch
import torch.nn as nn
from torch.autograd import Variable


cfg = {
  'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
  'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
  'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
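# Each cfg entry is either the output-channel count of a 3x3 conv + BatchNorm + ReLU block,
# or 'M' for a 2x2 max-pool; _make_layers below turns the list into an nn.Sequential.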

# The model must inherit from nn.Module
class VGG(nn.Module):
# Initialize the layers:
  def __init__(self, vgg_name):
    super(VGG, self).__init__()
    self.features = self._make_layers(cfg[vgg_name])
    self.classifier = nn.Linear(512, 10)

# The forward pass of the model; computation follows exactly this sequence
  def forward(self, x):
    out = self.features(x)
    out = out.view(out.size(0), -1)
    out = self.classifier(out)
    return out

  def _make_layers(self, cfg):
    layers = []
    in_channels = 3
    for x in cfg:
      if x == 'M':
        layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
      else:
        layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
              nn.BatchNorm2d(x),
              nn.ReLU(inplace=True)]
        in_channels = x
    layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
    return nn.Sequential(*layers)

# net = VGG('VGG11')
# x = torch.randn(2,3,32,32)
# print(net(Variable(x)).size())
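The commented-out lines above are a quick shape sanity check in the old Variable-based style. On PyTorch 0.4 and later the Variable wrapper is no longer needed, so a minimal sketch of the same check with plain tensors looks like this: build the network, push a fake CIFAR-sized batch through it, and confirm the output has one score per class.

import torch

net = VGG('VGG16')
x = torch.randn(2, 3, 32, 32)  # a fake batch: two 32x32 RGB images
out = net(x)
print(out.size())              # expected: torch.Size([2, 10])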

Define the training procedure:

'''Train CIFAR10 with PyTorch.'''
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import os
import argparse

from models import *
from utils import progress_bar
from torch.autograd import Variable

# Parse command-line arguments
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
args = parser.parse_args()

use_cuda = torch.cuda.is_available()
best_acc = 0 # best test accuracy
start_epoch = 0 # start from epoch 0 or last checkpoint epoch

# Load the dataset and preprocess it first
print('==> Preparing data..')
# Image preprocessing and augmentation
transform_train = transforms.Compose([
  transforms.RandomCrop(32, padding=4),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
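# The normalization constants above (means 0.4914/0.4822/0.4465, stds 0.2023/0.1994/0.2010) are
# per-channel statistics of the CIFAR-10 training set. As a sketch, they can be re-derived with
# the get_mean_and_std helper from the utilities listed later in this article:
# from utils import get_mean_and_std
# raw_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
#                     transform=transforms.ToTensor())
# print(get_mean_and_std(raw_trainset)) # should come out close to the hard-coded values above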

# Resume training an existing model or build a new one
if args.resume:
  # Load checkpoint.
  print('==> Resuming from checkpoint..')
  assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
  checkpoint = torch.load('./checkpoint/ckpt.t7')
  net = checkpoint['net']
  best_acc = checkpoint['acc']
  start_epoch = checkpoint['epoch']
else:
  print('==> Building model..')
  net = VGG('VGG16')
  # net = ResNet18()
  # net = PreActResNet18()
  # net = GoogLeNet()
  # net = DenseNet121()
  # net = ResNeXt29_2x64d()
  # net = MobileNet()
  # net = MobileNetV2()
  # net = DPN92()
  # net = ShuffleNetG2()
  # net = SENet18()

# Use the GPU if one is available
if use_cuda:
  # move param and buffer to GPU
  net.cuda()
  # wrap in DataParallel to use all available GPUs
  net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
  # speed up slightly
  cudnn.benchmark = True


# Define the loss function and the optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
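# A note on the loss (illustration with made-up tensors): nn.CrossEntropyLoss applies
# log-softmax internally and expects raw logits of shape (N, 10) plus integer class labels
# of shape (N,). That is why the VGG forward pass ends with a plain Linear layer and no softmax.
# fake_logits = torch.randn(4, 10)
# fake_labels = torch.tensor([3, 0, 9, 1])
# print(criterion(fake_logits, fake_labels)) # a single scalar loss value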

# Training phase
def train(epoch):
  print('\nEpoch: %d' % epoch)
  # switch to train mode
  net.train()
  train_loss = 0
  correct = 0
  total = 0
  # Iterate over mini-batches of data
  for batch_idx, (inputs, targets) in enumerate(trainloader):
    # Move the data to the GPU
    if use_cuda:
      inputs, targets = inputs.cuda(), targets.cuda()
    # Zero the optimizer's gradients first
    optimizer.zero_grad()
    # Wrapping in Variable marks these tensors as part of the computation graph; this is where the graph starts (its leaf variables)
    inputs, targets = Variable(inputs), Variable(targets)
    # Model output (forward pass)
    outputs = net(inputs)
    # Compute the loss; this is the end point of the graph
    loss = criterion(outputs, targets)
    # Backpropagate to compute the gradients
    loss.backward()
    # Update the parameters
    optimizer.step()
    # Note: to accumulate the loss, do not add the loss tensor itself; use loss.data[0]. The loss is part of the computation graph, so adding it directly would keep the total loss attached to the graph and the graph would keep growing.
    train_loss += loss.data[0]
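    # (On PyTorch 0.4 and later, the idiomatic equivalent is train_loss += loss.item().)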
    # Accuracy bookkeeping
    _, predicted = torch.max(outputs.data, 1)
    total += targets.size(0)
    correct += predicted.eq(targets.data).cpu().sum()

    progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

# Testing phase
def test(epoch):
  global best_acc
  # Switch to evaluation mode first
  net.eval()
  test_loss = 0
  correct = 0
  total = 0
  for batch_idx, (inputs, targets) in enumerate(testloader):
    if use_cuda:
      inputs, targets = inputs.cuda(), targets.cuda()
    inputs, targets = Variable(inputs, volatile=True), Variable(targets)
    outputs = net(inputs)
    loss = criterion(outputs, targets)
    # loss is a Variable; if you add it directly (+= loss), the graph will grow bigger and bigger.
    test_loss += loss.data[0]
    _, predicted = torch.max(outputs.data, 1)
    total += targets.size(0)
    correct += predicted.eq(targets.data).cpu().sum()

    progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
      % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

  # Save a checkpoint if this is the best test accuracy so far.
  acc = 100.*correct/total
  if acc > best_acc:
    print('Saving..')
    state = {
      'net': net.module if use_cuda else net,
      'acc': acc,
      'epoch': epoch,
    }
    if not os.path.isdir('checkpoint'):
      os.mkdir('checkpoint')
    torch.save(state, './checkpoint/ckpt.t7')
    best_acc = acc

# Run training and evaluation
for epoch in range(start_epoch, start_epoch+200):
  train(epoch)
  test(epoch)
  # Free unused cached GPU memory
  torch.cuda.empty_cache()

Run:

Train a new model:
python main.py --lr=0.01
Resume training from a saved checkpoint:
python main.py --resume --lr=0.01
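
One design note on the checkpoint format: the script stores the whole nn.Module object under 'net', which ties the saved file to the exact class definition it was created from. A more portable pattern, shown here only as a sketch rather than as part of the article's code, is to save the parameters with state_dict() and load them back into a freshly built model:

# saving (inside test(), sketch): store only the weights plus the bookkeeping fields
state = {
  'net': net.module.state_dict() if use_cuda else net.state_dict(),
  'acc': acc,
  'epoch': epoch,
}
torch.save(state, './checkpoint/ckpt.t7')

# resuming (sketch): rebuild the model, then load the weights back in
net = VGG('VGG16')
checkpoint = torch.load('./checkpoint/ckpt.t7')
net.load_state_dict(checkpoint['net'])
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']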

Some utilities:

'''Some helper functions for PyTorch, including:
  - get_mean_and_std: calculate the mean and std value of dataset.
  - msr_init: net parameter initialization.
  - progress_bar: progress bar mimic xlua.progress.
'''
import os
import sys
import time
import math

import torch
import torch.nn as nn
import torch.nn.init as init


def get_mean_and_std(dataset):
  '''Compute the mean and std value of dataset.'''
  dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
  mean = torch.zeros(3)
  std = torch.zeros(3)
  print('==> Computing mean and std..')
  for inputs, targets in dataloader:
    for i in range(3):
      mean[i] += inputs[:,i,:,:].mean()
      std[i] += inputs[:,i,:,:].std()
  mean.div_(len(dataset))
  std.div_(len(dataset))
  return mean, std

def init_params(net):
  '''Init layer parameters.'''
  for m in net.modules():
    if isinstance(m, nn.Conv2d):
      init.kaiming_normal(m.weight, mode='fan_out')
      if m.bias is not None:
        init.constant(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
      init.constant(m.weight, 1)
      init.constant(m.bias, 0)
    elif isinstance(m, nn.Linear):
      init.normal(m.weight, std=1e-3)
      if m.bias is not None:
        init.constant(m.bias, 0)


_, term_width = os.popen('stty size', 'r').read().split()
term_width = int(term_width)

TOTAL_BAR_LENGTH = 65.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
  global last_time, begin_time
  if current == 0:
    begin_time = time.time() # Reset for new bar.

  cur_len = int(TOTAL_BAR_LENGTH*current/total)
  rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

  sys.stdout.write(' [')
  for i in range(cur_len):
    sys.stdout.write('=')
  sys.stdout.write('>')
  for i in range(rest_len):
    sys.stdout.write('.')
  sys.stdout.write(']')

  cur_time = time.time()
  step_time = cur_time - last_time
  last_time = cur_time
  tot_time = cur_time - begin_time

  L = []
  L.append(' Step: %s' % format_time(step_time))
  L.append(' | Tot: %s' % format_time(tot_time))
  if msg:
    L.append(' | ' + msg)

  msg = ''.join(L)
  sys.stdout.write(msg)
  for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
    sys.stdout.write(' ')

  # Go back to the center of the bar.
  for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
    sys.stdout.write('\b')
  sys.stdout.write(' %d/%d ' % (current+1, total))

  if current < total-1:
    sys.stdout.write('\r')
  else:
    sys.stdout.write('\n')
  sys.stdout.flush()

def format_time(seconds):
  days = int(seconds / 3600/24)
  seconds = seconds - days*3600*24
  hours = int(seconds / 3600)
  seconds = seconds - hours*3600
  minutes = int(seconds / 60)
  seconds = seconds - minutes*60
  secondsf = int(seconds)
  seconds = seconds - secondsf
  millis = int(seconds*1000)

  f = ''
  i = 1
  if days > 0:
    f += str(days) + 'D'
    i += 1
  if hours > 0 and i <= 2:
    f += str(hours) + 'h'
    i += 1
  if minutes > 0 and i <= 2:
    f += str(minutes) + 'm'
    i += 1
  if secondsf > 0 and i <= 2:
    f += str(secondsf) + 's'
    i += 1
  if millis > 0 and i <= 2:
    f += str(millis) + 'ms'
    i += 1
  if f == '':
    f = '0ms'
  return f
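
Note that init_params is defined here but never called from the training script above, so the layers keep PyTorch's default initialization. If you want to use it, a minimal sketch is to call it right after building the network:

net = VGG('VGG16')
init_params(net)  # Kaiming init for conv weights, constants for BatchNorm, a small normal for Linear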

That is all for this article. I hope it helps with your learning, and I hope you will keep supporting 腳本之家.
