摘要:import torch import os, glob import warnings import random, csv from PIL import Image from torch import optim, nn import torch.nn.functional as F from torchvision import transforms from torchvision.models import resnet18 from torch.utils.data import Dataset, DataLoader warnings.filterwarnings('ignore') from matplotlib import pyplot as plt class Pokemon(Dataset): def __init__(self, root, resize, model): super(Pokemon, self).__init__() self.root = root self.resize = resize self.name2label = {} # 將文件夾的名字映射爲label(數字) for name in sorted(os.listdir(os.path.join(root))): if not os.path.isdir(os.path.join(root, name)): continue self.name2label[name] = len(self.name2label.keys()) # image, label self.images, self.labels = self.load_csv('images.csv') if model == 'train': # 60% self.images = self.images[:int(0.6*len(self.images))] self.labels = self.labels[:int(0.6*len(self.labels))] elif model == 'val': # 20% self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))] self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))] else: # 20% self.images = self.images[int(0.8*len(self.images)):] self.labels = self.labels[int(0.8*len(self.labels)):] def load_csv(self, filename): if not os.path.exists(os.path.join(self.root, filename)): images = [] for name in self.name2label.keys(): images += glob.glob(os.path.join(self.root, name, '*.png')) images += glob.glob(os.path.join(self.root, name, '*.jpg')) images += glob.glob(os.path.join(self.root, name, '*.jpeg')) random.shuffle(images) with open(os.path.join(self.root, filename), mode='w', newline='') as f: writer = csv.writer(f) for img in images: # pokemon\\bulbasaur\\00000000.png name = img.split(os.sep)[-2] # bulbasaur label = self.name2label[name] # pokemon\\bulbasaur\\00000000.png 0 writer.writerow([img, label]) print('writen into csv file:', filename) # read csv file images, labels = [], [] with open(os.path.join(self.root, filename)) as f: reader = csv.reader(f) for row in reader: 
image, label = row label = int(label) images.append(image) labels.append(label) assert len(images) == len(labels) return images, labels def __len__(self): return len(self.images) def __getitem__(self, idx): # idx [0~len(images)] # self.images, self.labels # pokemon\\bulbasaur\\00000000.png 0 img, label = self.images[idx], self.labels[idx] tf = transforms.Compose([ lambda x:Image.open(x).convert('RGB'), # string path => image data transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))), transforms.RandomRotation(15), transforms.CenterCrop(self.resize), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) img = tf(img) label = torch.tensor(label) return img, label class Flatten(nn.Module): def __init__(self): super(Flatten, self).__init__() def forward(self, x): shape = torch.prod(torch.tensor(x.shape[1:])).item() return x.view(-1, shape) batchsz = 32 lr = 1e-3 epochs = 10 device = torch.device('cuda') torch.manual_seed(1234) train_db = Pokemon('pokemon', 224, model='train') val_db = Pokemon('pokemon', 224, model='val') test_db = Pokemon('pokemon', 224, model='test') train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=2) val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2) test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2) def evalute(model, loader): correct = 0 total = len(loader.dataset) for x,y in loader: with torch.no_grad(): logits = model(x) pred = logits.argmax(dim=1) correct += torch.eq(pred, y).sum().float().item() return correct / total def main(): trained_model = resnet18(pretrained=True) model = nn.Sequential(*list(trained_model.children())[:-1],# [b, 512, 1, 1] Flatten(), # [b, 512, 1, 1] => [b, 512] nn.Linear(512, 5) ) optimizer = optim.Adam(model.parameters(), lr=lr) criteon = nn.CrossEntropyLoss() best_acc, best_epoch = 0, 0 for epoch in range(epochs): for step, (x, y) in enumerate(train_loader): # x:[b, 3, 224, 224], y:[b] logits = 
model(x) loss = criteon(logits, y) optimizer.zero_grad() loss.backward() optimizer.step() if epoch % 2 == 0: val_acc = evalute(model, val_loader) if val_acc > best_acc: best_epoch = epoch best_acc = val_acc torch.save(model.state_dict(), 'best.mdl') print('best acc:', best_acc, 'best_epoch', best_epoch) model.load_state_dict(torch.load('best.mdl')) print('loaded from ckt。import torch import os, glob import warnings import random, csv from PIL import Image from torch import optim, nn import torch.nn.functional as F from torchvision import transforms from torch.utils.data import Dataset, DataLoader warnings.filterwarnings('ignore') class Pokemon(Dataset): def __init__(self, root, resize, model): super(Pokemon, self).__init__() self.root = root self.resize = resize self.name2label = {} # 將文件夾的名字映射爲label(數字) for name in sorted(os.listdir(os.path.join(root))): if not os.path.isdir(os.path.join(root, name)): continue self.name2label[name] = len(self.name2label.keys()) # image, label self.images, self.labels = self.load_csv('images.csv') if model == 'train': # 60% self.images = self.images[:int(0.6*len(self.images))] self.labels = self.labels[:int(0.6*len(self.labels))] elif model == 'val': # 20% self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))] self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))] else: # 20% self.images = self.images[int(0.8*len(self.images)):] self.labels = self.labels[int(0.8*len(self.labels)):] def load_csv(self, filename): if not os.path.exists(os.path.join(self.root, filename)): images = [] for name in self.name2label.keys(): images += glob.glob(os.path.join(self.root, name, '*.png')) images += glob.glob(os.path.join(self.root, name, '*.jpg')) images += glob.glob(os.path.join(self.root, name, '*.jpeg')) random.shuffle(images) with open(os.path.join(self.root, filename), mode='w', newline='') as f: writer = csv.writer(f) for img in images: # pokemon\\bulbasaur\\00000000.png name = 
img.split(os.sep)[-2] # bulbasaur label = self.name2label[name] # pokemon\\bulbasaur\\00000000.png 0 writer.writerow([img, label]) print('writen into csv file:', filename) # read csv file images, labels = [], [] with open(os.path.join(self.root, filename)) as f: reader = csv.reader(f) for row in reader: image, label = row label = int(label) images.append(image) labels.append(label) assert len(images) == len(labels) return images, labels def __len__(self): return len(self.images) def __getitem__(self, idx): # idx [0~len(images)] # self.images, self.labels # pokemon\\bulbasaur\\00000000.png 0 img, label = self.images[idx], self.labels[idx] tf = transforms.Compose([ lambda x:Image.open(x).convert('RGB'), # string path => image data transforms.Resize((int(self.resize*1.25), int(self.resize*1.25))), transforms.RandomRotation(15), transforms.CenterCrop(self.resize), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) img = tf(img) label = torch.tensor(label) return img, label class ResBlk(nn.Module): def __init__(self, ch_in, ch_out, stride=1): super(ResBlk, self).__init__() self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1) self.bn1 = nn.BatchNorm2d(ch_out) self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1) self.bn2 = nn.BatchNorm2d(ch_out) self.extra = nn.Sequential() if ch_out。

Pokemon Dataset

通過網絡上收集寶可夢的圖片,製作圖像分類數據集。我收集了5種寶可夢,分別是皮卡丘,超夢,傑尼龜,小火龍,妙蛙種子

數據集鏈接: https://pan.baidu.com/s/1Kept7FF88lb8TqPZMD_Yxw 提取碼:1sdd

一共有1168張寶可夢的圖片,其中皮卡丘234張,超夢239張,傑尼龜223張,小火龍238張,妙蛙種子234張

每個目錄由神奇寶貝名字命名,對應目錄下是該神奇寶貝的圖片,圖片的格式有jpg,png,jpeg三種

數據集的劃分如下(訓練集60%,驗證集20%,測試集20%)。這個比例不是針對每一類提取,而是針對總體的1168張

Load Data

在PyTorch中定義數據集主要涉及到兩個類:Dataset和DataLoader

Dataset類

Dataset類是PyTorch中所有數據集加載類都應該繼承的父類,它的兩個特殊方法 __len__() 和 __getitem__() 必須被重載,否則將觸發錯誤提示

其中 __len__ 應該返回數據集的樣本數量,而 __getitem__() 實現通過索引返回樣本數據的功能

首先看一個自定義Dataset的例子

class NumbersDataset(Dataset):
    """Toy dataset whose samples are consecutive integers.

    Args:
        training: When truthy the samples are 1..1000, otherwise 1001..1500.
    """

    def __init__(self, training=True):
        first, last = (1, 1000) if training else (1001, 1500)
        self.samples = [*range(first, last + 1)]

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.samples[idx]

然後需要對圖片做Preprocessing

  1. Image Resize:224*224 for ResNet18
  2. Data Augmentation:Rotate & Crop
  3. Normalize:Mean & std
  4. ToTensor

首先我們在 __init__() 函數里將name->label,這裏的name就是文件夾的名字,然後拆分數據集,按照6:2:2的比例

class Pokemon(Dataset):
    """Dataset over a directory tree with one sub-directory per class.

    Args:
        root: Directory whose sub-directories are the classes.
        resize: Target image side length; stored for later use.
        model: Split selector: ``'train'`` (first 60%), ``'val'`` (next 20%),
            anything else selects the last 20%.
    """

    def __init__(self, root, resize, model):
        super().__init__()

        self.root, self.resize = root, resize

        # Sub-directories, in sorted order, receive labels 0, 1, 2, ...
        class_dirs = [
            entry
            for entry in sorted(os.listdir(os.path.join(root)))
            if os.path.isdir(os.path.join(root, entry))
        ]
        self.name2label = {entry: index for index, entry in enumerate(class_dirs)}

        # Full (path, label) listing; load_csv reads self.root / self.name2label.
        self.images, self.labels = self.load_csv('images.csv')

        def pick(seq):
            """Return the portion of ``seq`` that belongs to the chosen split."""
            lo, hi = int(0.6*len(seq)), int(0.8*len(seq))
            if model == 'train':
                return seq[:lo]
            if model == 'val':
                return seq[lo:hi]
            return seq[hi:]

        self.images = pick(self.images)
        self.labels = pick(self.labels)

其中 load_csv() 函數的作用是將所有的圖片名(名字裏包含完整的路徑)以及label都存到csv文件裏,例如,有一個圖片的路徑是 pokemon\\bulbasaur\\00000000.png ,對應的label是0,那麼csv就會寫入一行 pokemon\\bulbasaur\\00000000.png, 0 ,總共寫入了1167行(有一張圖片既不是png,也不是jpg和jpeg,找不到,算了)。 load_csv() 函數具體如下所示

def load_csv(self, filename):
    """Return ``(images, labels)`` read from ``<self.root>/<filename>``.

    If the CSV file does not exist yet it is created first: every ``*.png``,
    ``*.jpg`` and ``*.jpeg`` file below each class directory is collected,
    the list is shuffled, and one ``path,label`` row is written per image.

    Args:
        filename: Name of the CSV file, relative to ``self.root``.

    Returns:
        A pair of equally long lists: image paths (str) and labels (int).
    """
    csv_path = os.path.join(self.root, filename)

    if not os.path.exists(csv_path):
        found = []
        for class_name in self.name2label:
            for pattern in ('*.png', '*.jpg', '*.jpeg'):
                found += glob.glob(os.path.join(self.root, class_name, pattern))

        random.shuffle(found)
        with open(csv_path, mode='w', newline='') as out:
            writer = csv.writer(out)
            for path in found:
                # The parent directory name (second-to-last component) is the class.
                class_name = path.split(os.sep)[-2]
                writer.writerow([path, self.name2label[class_name]])
            print('writen into csv file:', filename)

    # Read the listing back so a cached file and a fresh one are handled alike.
    images, labels = [], []
    with open(csv_path) as src:
        for image, label in csv.reader(src):
            label = int(label)
            images.append(image)
            labels.append(label)
    assert len(images) == len(labels)
    return images, labels

然後是 __len__() 函數的代碼

def __len__(self):
    """Return the number of images in this split."""
    sample_count = len(self.images)
    return sample_count

最後是 __getitem__() 函數的代碼,這個比較複雜,因爲我們現在只有圖片的string path(字符串形式的路徑),要先轉成三通道的image data,這個利用PIL庫中的 Image.open(path).convert('RGB') 函數可以完成。圖片讀取出來以後,要經過一系列的transforms,具體代碼如下

def __getitem__(self, idx):
    """Load, transform and return the sample at position ``idx``.

    Args:
        idx: Index into ``self.images`` / ``self.labels``.

    Returns:
        ``(img, label)``: the transformed image tensor and the label wrapped
        in a tensor.
    """
    path = self.images[idx]      # e.g. pokemon\\bulbasaur\\00000000.png
    target = self.labels[idx]    # e.g. 0
    enlarged = int(self.resize*1.25)

    def load_rgb(x):
        # string path => three-channel image data
        return Image.open(x).convert('RGB')

    pipeline = transforms.Compose([
        load_rgb,
        # Resize larger than the target so the rotated image can be cropped back.
        transforms.Resize((enlarged, enlarged)),
        transforms.RandomRotation(15),
        transforms.CenterCrop(self.resize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    return pipeline(path), torch.tensor(target)

Normalize的參數是PyTorch推薦的,直接寫上就可以了

DataLoader類

Dataset類是讀入數據集並對讀入的數據進行了索引,但是光有這個功能是不夠的,在實際加載數據集的過程中,我們的數據量往往都很大,因此還需要以下幾個功能:

  1. 每次讀入一些批次:batch_size
  2. 可以對數據進行隨機讀取,打亂數據的順序(shuffling)
  3. 可以並行加載數據集(利用多核處理器加快載入數據的效率)

爲此,就需要DataLoader類了,它裏面常用的參數有:

  • batch_size:每個batch的大小
  • shuffle:是否進行shuffle操作
  • num_workers:加載數據的時候使用幾個進程

DataLoader這個類並不需要我們自己設計代碼,只需要利用它讀取我們設計好的Dataset的子類即可

# Training split of the 'pokemon' directory, images returned at 224x224.
db = Pokemon(root='pokemon', resize=224, model='train')
# Shuffled batches of 32 samples, loaded by 4 worker processes.
lodder = DataLoader(dataset=db, batch_size=32, shuffle=True, num_workers=4)

完整代碼如下:

import torch
import os, glob
import random, csv
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class Pokemon(Dataset):
    """Image-classification dataset with one sub-directory per class.

    Every sub-directory of ``root`` is a class; its label is its position in
    the sorted list of sub-directory names. The shuffled list of
    ``(image path, label)`` pairs is cached in ``<root>/images.csv`` and split
    by position into 60% train / 20% val / 20% test.

    Args:
        root: Directory that contains one sub-directory per class.
        resize: Side length in pixels of the square image returned by
            ``__getitem__``.
        model: Name of the split to expose: ``'train'``, ``'val'``, or any
            other value for the test split. (It names a split, not a network.)
    """

    def __init__(self, root, resize, model):
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize
        # Remembered so __getitem__ applies random augmentation to training
        # samples only.
        self.mode = model

        # Map each class directory name to an integer label.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        images, labels = self.load_csv('images.csv')

        # Compute the split boundaries once and slice both lists identically.
        total = len(images)
        train_end, val_end = int(0.6 * total), int(0.8 * total)
        if model == 'train':    # first 60%
            bounds = slice(None, train_end)
        elif model == 'val':    # next 20%
            bounds = slice(train_end, val_end)
        else:                   # last 20%
            bounds = slice(val_end, None)
        self.images = images[bounds]
        self.labels = labels[bounds]

    def load_csv(self, filename):
        """Return ``(images, labels)`` from ``<root>/<filename>``, creating it if absent.

        When the file does not exist, every ``*.png`` / ``*.jpg`` / ``*.jpeg``
        file in each class directory is collected, shuffled and written as one
        ``path,label`` row per image. The file is then read back, so a cached
        listing keeps the same order (and therefore the same split) across runs.
        """
        csv_path = os.path.join(self.root, filename)

        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:  # e.g. pokemon\\bulbasaur\\00000000.png
                    name = img.split(os.sep)[-2]  # parent directory = class name
                    writer.writerow([img, self.name2label[name]])
                print('writen into csv file:', filename)

        images, labels = [], []
        with open(csv_path) as f:
            for row in csv.reader(f):
                if not row:
                    # Tolerate blank lines, e.g. in a hand-edited listing.
                    continue
                image, label = row
                images.append(image)
                labels.append(int(label))
        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image tensor, label tensor)`` for the sample at ``idx``.

        The image is opened as RGB, resized to 1.25x ``resize``, centre-cropped
        to ``resize`` and normalized. A random rotation of up to 15 degrees is
        applied only for the ``'train'`` split so that validation and test
        results are deterministic.
        """
        path, label = self.images[idx], self.labels[idx]
        enlarged = int(self.resize * 1.25)

        # Built per call rather than stored on the instance: the pipeline
        # holds a lambda, which would presumably make the dataset unpicklable
        # for DataLoader worker processes.
        steps = [
            lambda x: Image.open(x).convert('RGB'),  # string path => image data
            transforms.Resize((enlarged, enlarged)),
        ]
        if self.mode == 'train':
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        tf = transforms.Compose(steps)

        return tf(path), torch.tensor(label)

# Training split of the 'pokemon' directory, images returned at 224x224.
db = Pokemon(root='pokemon', resize=224, model='train')
# Shuffled batches of 32 samples, loaded by 8 worker processes.
lodder = DataLoader(dataset=db, batch_size=32, shuffle=True, num_workers=8)

Build Model

用PyTorch搭建ResNet其實在我之前的文章已經講過了,這裏直接拿來用,修改一下里面的參數就行了

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlk(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers added to a shortcut.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
        stride: Stride of the first convolution; the spatial size shrinks by
            this factor.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)

        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Identity shortcut by default. A 1x1 projection is needed whenever
        # the main path changes the tensor shape: a different channel count
        # OR a stride other than 1 (which shrinks height/width).
        self.extra = nn.Sequential()
        if ch_out != ch_in or stride != 1:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """Return ``relu(shortcut(x) + main_path(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # short cut
        out = self.extra(x) + out
        out = F.relu(out)

        return out
        
class ResNet18(nn.Module):
    """Compact residual CNN: a stem convolution, four ``ResBlk`` stages, a linear head.

    Note that this is a small custom network, not the torchvision
    ``resnet18`` architecture.

    Args:
        num_class: Number of output classes (size of the logits vector).
    """

    def __init__(self, num_class):
        super(ResNet18, self).__init__()

        # Stem: [b, 3, h, w] => [b, 16, h/3, w/3] (approximately)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16),
        )
        # followed 4 blocks; each one also shrinks h and w by its stride

        # [b, 16, h, w] => [b, 32, h/3, w/3]
        self.blk1 = ResBlk(16, 32, stride=3)
        # [b, 32, h, w] => [b, 64, h/3, w/3]
        self.blk2 = ResBlk(32, 64, stride=3)
        # [b, 64, h, w] => [b, 128, h/2, w/2]
        self.blk3 = ResBlk(64, 128, stride=2)
        # [b, 128, h, w] => [b, 256, h/2, w/2]
        self.blk4 = ResBlk(128, 256, stride=2)

        # The head expects a 256-channel 3x3 feature map.
        self.outlayer = nn.Linear(256*3*3, num_class)

    def forward(self, x):
        """Return class logits of shape ``[b, num_class]`` for image batch ``x``."""
        x = F.relu(self.conv1(x))

        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # Bring the feature map to the 3x3 grid the head was sized for. A
        # 224x224 input already yields 3x3, so this is a no-op there; other
        # input sizes no longer fail in the linear layer.
        x = F.adaptive_avg_pool2d(x, (3, 3))

        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x

Train and Test

訓練的時候,嚴格按照Training和Test的邏輯,就是在訓練epoch的過程中,間斷的做一次validation,然後看一下當前的validation accuracy是不是最高的,如果是最高的,就把當前的模型參數保存起來。training完以後,加載最好的模型,再做testing。這就是非常嚴格的訓練邏輯。代碼如下:

# Hyper-parameters: batch size, Adam learning rate, number of epochs.
batchsz, lr, epochs = 32, 1e-3, 10
# NOTE(review): `device` is not referenced by the code shown in this section.
device = torch.device('cuda')
torch.manual_seed(1234)

# One dataset per split, all reading the 'pokemon' directory at 224x224.
train_db, val_db, test_db = (
    Pokemon('pokemon', 224, model=split) for split in ('train', 'val', 'test')
)
# Only the training loader shuffles; evaluation order does not matter.
train_loader = DataLoader(dataset=train_db, batch_size=batchsz, shuffle=True, num_workers=2)
val_loader, test_loader = (
    DataLoader(dataset=split_db, batch_size=batchsz, num_workers=2)
    for split_db in (val_db, test_db)
)


def evalute(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    The model is switched to evaluation mode for the duration of the call, so
    layers such as BatchNorm and Dropout behave deterministically and running
    statistics are not updated by evaluation data. Its previous
    training/eval mode is restored afterwards.

    Args:
        model: Module mapping a batch ``x`` to logits of shape ``[b, classes]``.
        loader: Iterable of ``(x, y)`` batches with a ``dataset`` attribute
            whose length is the total number of samples.

    Returns:
        Accuracy as a float in ``[0, 1]``; ``0.0`` for an empty dataset.
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)
    return correct / total

def main():
    """Train ``ResNet18`` on the training split and report test accuracy.

    Every second epoch the model is scored on the validation loader; whenever
    the score improves, the weights are saved to ``'best.mdl'``. After
    training, the best saved weights (if any were saved during this run) are
    reloaded before the test split is evaluated.
    """
    model = ResNet18(5)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    # Tracks whether THIS run wrote 'best.mdl', so a missing or stale file
    # from an earlier run is never loaded by mistake.
    saved = False
    for epoch in range(epochs):
        for x, y in train_loader:
            # x:[b, 3, 224, 224], y:[b]
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')
                saved = True

    print('best acc:', best_acc, 'best_epoch', best_epoch)

    if saved:
        model.load_state_dict(torch.load('best.mdl'))
        print('loaded from ckt!')
    else:
        print('no checkpoint was saved in this run; testing the final weights')

    test_acc = evalute(model, test_loader)
    print('test_acc:', test_acc)

截至到目前爲止,能完整運行的代碼如下:

import torch
import os, glob
import warnings
import random, csv
from PIL import Image
from torch import optim, nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
warnings.filterwarnings('ignore')


class Pokemon(Dataset):
    """Image-classification dataset with one sub-directory per class.

    Every sub-directory of ``root`` is a class; its label is its position in
    the sorted list of sub-directory names. The shuffled list of
    ``(image path, label)`` pairs is cached in ``<root>/images.csv`` and split
    by position into 60% train / 20% val / 20% test.

    Args:
        root: Directory that contains one sub-directory per class.
        resize: Side length in pixels of the square image returned by
            ``__getitem__``.
        model: Name of the split to expose: ``'train'``, ``'val'``, or any
            other value for the test split. (It names a split, not a network.)
    """

    def __init__(self, root, resize, model):
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize
        # Remembered so __getitem__ applies random augmentation to training
        # samples only.
        self.mode = model

        # Map each class directory name to an integer label.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        images, labels = self.load_csv('images.csv')

        # Compute the split boundaries once and slice both lists identically.
        total = len(images)
        train_end, val_end = int(0.6 * total), int(0.8 * total)
        if model == 'train':    # first 60%
            bounds = slice(None, train_end)
        elif model == 'val':    # next 20%
            bounds = slice(train_end, val_end)
        else:                   # last 20%
            bounds = slice(val_end, None)
        self.images = images[bounds]
        self.labels = labels[bounds]

    def load_csv(self, filename):
        """Return ``(images, labels)`` from ``<root>/<filename>``, creating it if absent.

        When the file does not exist, every ``*.png`` / ``*.jpg`` / ``*.jpeg``
        file in each class directory is collected, shuffled and written as one
        ``path,label`` row per image. The file is then read back, so a cached
        listing keeps the same order (and therefore the same split) across runs.
        """
        csv_path = os.path.join(self.root, filename)

        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:  # e.g. pokemon\\bulbasaur\\00000000.png
                    name = img.split(os.sep)[-2]  # parent directory = class name
                    writer.writerow([img, self.name2label[name]])
                print('writen into csv file:', filename)

        images, labels = [], []
        with open(csv_path) as f:
            for row in csv.reader(f):
                if not row:
                    # Tolerate blank lines, e.g. in a hand-edited listing.
                    continue
                image, label = row
                images.append(image)
                labels.append(int(label))
        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image tensor, label tensor)`` for the sample at ``idx``.

        The image is opened as RGB, resized to 1.25x ``resize``, centre-cropped
        to ``resize`` and normalized. A random rotation of up to 15 degrees is
        applied only for the ``'train'`` split so that validation and test
        results are deterministic.
        """
        path, label = self.images[idx], self.labels[idx]
        enlarged = int(self.resize * 1.25)

        # Built per call rather than stored on the instance: the pipeline
        # holds a lambda, which would presumably make the dataset unpicklable
        # for DataLoader worker processes.
        steps = [
            lambda x: Image.open(x).convert('RGB'),  # string path => image data
            transforms.Resize((enlarged, enlarged)),
        ]
        if self.mode == 'train':
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        tf = transforms.Compose(steps)

        return tf(path), torch.tensor(label)

class ResBlk(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers added to a shortcut.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
        stride: Stride of the first convolution; the spatial size shrinks by
            this factor.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)

        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Identity shortcut by default. A 1x1 projection is needed whenever
        # the main path changes the tensor shape: a different channel count
        # OR a stride other than 1 (which shrinks height/width).
        self.extra = nn.Sequential()
        if ch_out != ch_in or stride != 1:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """Return ``relu(shortcut(x) + main_path(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # short cut
        out = self.extra(x) + out
        out = F.relu(out)

        return out
        
class ResNet18(nn.Module):
    """Compact residual CNN: a stem convolution, four ``ResBlk`` stages, a linear head.

    Note that this is a small custom network, not the torchvision
    ``resnet18`` architecture.

    Args:
        num_class: Number of output classes (size of the logits vector).
    """

    def __init__(self, num_class):
        super(ResNet18, self).__init__()

        # Stem: [b, 3, h, w] => [b, 16, h/3, w/3] (approximately)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16),
        )
        # followed 4 blocks; each one also shrinks h and w by its stride

        # [b, 16, h, w] => [b, 32, h/3, w/3]
        self.blk1 = ResBlk(16, 32, stride=3)
        # [b, 32, h, w] => [b, 64, h/3, w/3]
        self.blk2 = ResBlk(32, 64, stride=3)
        # [b, 64, h, w] => [b, 128, h/2, w/2]
        self.blk3 = ResBlk(64, 128, stride=2)
        # [b, 128, h, w] => [b, 256, h/2, w/2]
        self.blk4 = ResBlk(128, 256, stride=2)

        # The head expects a 256-channel 3x3 feature map.
        self.outlayer = nn.Linear(256*3*3, num_class)

    def forward(self, x):
        """Return class logits of shape ``[b, num_class]`` for image batch ``x``."""
        x = F.relu(self.conv1(x))

        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)

        # Bring the feature map to the 3x3 grid the head was sized for. A
        # 224x224 input already yields 3x3, so this is a no-op there; other
        # input sizes no longer fail in the linear layer.
        x = F.adaptive_avg_pool2d(x, (3, 3))

        x = x.view(x.size(0), -1)
        x = self.outlayer(x)

        return x
    
# Hyper-parameters: batch size, Adam learning rate, number of epochs.
batchsz, lr, epochs = 32, 1e-3, 10
# NOTE(review): `device` is not referenced by the code shown in this section.
device = torch.device('cuda')
torch.manual_seed(1234)

# One dataset per split, all reading the 'pokemon' directory at 224x224.
train_db, val_db, test_db = (
    Pokemon('pokemon', 224, model=split) for split in ('train', 'val', 'test')
)
# Only the training loader shuffles; evaluation order does not matter.
train_loader = DataLoader(dataset=train_db, batch_size=batchsz, shuffle=True, num_workers=2)
val_loader, test_loader = (
    DataLoader(dataset=split_db, batch_size=batchsz, num_workers=2)
    for split_db in (val_db, test_db)
)


def evalute(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    The model is switched to evaluation mode for the duration of the call, so
    layers such as BatchNorm and Dropout behave deterministically and running
    statistics are not updated by evaluation data. Its previous
    training/eval mode is restored afterwards.

    Args:
        model: Module mapping a batch ``x`` to logits of shape ``[b, classes]``.
        loader: Iterable of ``(x, y)`` batches with a ``dataset`` attribute
            whose length is the total number of samples.

    Returns:
        Accuracy as a float in ``[0, 1]``; ``0.0`` for an empty dataset.
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)
    return correct / total

def main():
    """Train ``ResNet18`` on the training split and report test accuracy.

    Every second epoch the model is scored on the validation loader; whenever
    the score improves, the weights are saved to ``'best.mdl'``. After
    training, the best saved weights (if any were saved during this run) are
    reloaded before the test split is evaluated.
    """
    model = ResNet18(5)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    # Tracks whether THIS run wrote 'best.mdl', so a missing or stale file
    # from an earlier run is never loaded by mistake.
    saved = False
    for epoch in range(epochs):
        for x, y in train_loader:
            # x:[b, 3, 224, 224], y:[b]
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')
                saved = True

    print('best acc:', best_acc, 'best_epoch', best_epoch)

    if saved:
        model.load_state_dict(torch.load('best.mdl'))
        print('loaded from ckt!')
    else:
        print('no checkpoint was saved in this run; testing the final weights')

    test_acc = evalute(model, test_loader)
    print('test_acc:', test_acc)

# Run the training pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()

Transfer Learning

運行上面的代碼,基本上最終test accuracy可以達到0.88左右。如果想要提升的話,就需要使用更多工程上的tricks或者調參

當然還有一種方法,就是遷移學習,我們先看下面這張圖,這張圖展示的問題在於,當數據很少的情況下(第一張圖),模型訓練的結果可能會有很多情況(第二張圖),當然最終輸出就一個結果。然而這個結果可能test accuracy並不高。就比方說我們的pokemon圖片,只有1000多張,算是一個比較少的數據集了,但是由於pokemon和ImageNet都是圖片,它們可能存在某些共性。那我們能不能用ImageNet的一些train好的模型,拿來幫助我們解決一下特定的圖片分類任務,這就是Transfer Learning,也就是在A任務上train好一個分類器,再transfer到B上去

我個人理解Transfer Learning的作用是這樣的,我們都知道神經網絡初始化參數非常重要,有時候初始化不好,可能就會導致最終效果非常差。現在我們用一個在A任務上已經訓練好了的網絡,相當於幫你做了一個很好的初始化,你在這個網絡的基礎上,去做B任務,如果這兩個任務比較接近的話,誇張一點說,這個網絡的訓練可能就只需要微調一下,就能在B任務上顯示出非常好的效果

下圖展示的是一個真實的Transfer Learning的過程,左邊是已經training好的網絡,我們利用這個網絡的公有部分,吸取它的common knowledge, 然後把最後一層去掉,換成我們需要的

先上核心代碼

import torch.nn as nn
from torchvision.models import resnet18

class Flatten(nn.Module):
    """Flatten every dimension after the batch dimension: ``[b, ...] => [b, n]``."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """Return ``x`` reshaped to ``[x.size(0), product of remaining dims]``.

        ``Tensor.flatten`` is used instead of ``view(-1, n)`` so that empty
        batches and non-contiguous inputs are handled as well.
        """
        return x.flatten(1)

# Backbone with pretrained weights (downloaded on first use).
trained_model = resnet18(pretrained=True)
# Reuse every child module except the last one, then attach a new 5-way head.
model = nn.Sequential(
    *[*trained_model.children()][:-1],            # [b, 512, 1, 1]
    Flatten(),                                    # [b, 512, 1, 1] => [b, 512]
    nn.Linear(in_features=512, out_features=5),   # [b, 512] => [b, 5]
)

PyTorch中有已經訓練好的各種規格的resnet,第一次使用需要下載。我們不要resnet18的最後一層,所以要用 list(trained_model.children())[:-1] 把除了最後一層以外的所有層都取出來,保存在list中,然後用 * 將其list展開,之後接一個我們自定義的Flatten層,作用是將output打平,打平以後才能送到Linear層去

上面幾行代碼就實現了Transfer Learning,而且不需要我們自己實現resnet,完整代碼如下

import torch
import os, glob
import warnings
import random, csv
from PIL import Image
from torch import optim, nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.models import resnet18
from torch.utils.data import Dataset, DataLoader
warnings.filterwarnings('ignore')
from matplotlib import pyplot as plt


class Pokemon(Dataset):
    """Image-classification dataset with one sub-directory per class.

    Every sub-directory of ``root`` is a class; its label is its position in
    the sorted list of sub-directory names. The shuffled list of
    ``(image path, label)`` pairs is cached in ``<root>/images.csv`` and split
    by position into 60% train / 20% val / 20% test.

    Args:
        root: Directory that contains one sub-directory per class.
        resize: Side length in pixels of the square image returned by
            ``__getitem__``.
        model: Name of the split to expose: ``'train'``, ``'val'``, or any
            other value for the test split. (It names a split, not a network.)
    """

    def __init__(self, root, resize, model):
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize
        # Remembered so __getitem__ applies random augmentation to training
        # samples only.
        self.mode = model

        # Map each class directory name to an integer label.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        images, labels = self.load_csv('images.csv')

        # Compute the split boundaries once and slice both lists identically.
        total = len(images)
        train_end, val_end = int(0.6 * total), int(0.8 * total)
        if model == 'train':    # first 60%
            bounds = slice(None, train_end)
        elif model == 'val':    # next 20%
            bounds = slice(train_end, val_end)
        else:                   # last 20%
            bounds = slice(val_end, None)
        self.images = images[bounds]
        self.labels = labels[bounds]

    def load_csv(self, filename):
        """Return ``(images, labels)`` from ``<root>/<filename>``, creating it if absent.

        When the file does not exist, every ``*.png`` / ``*.jpg`` / ``*.jpeg``
        file in each class directory is collected, shuffled and written as one
        ``path,label`` row per image. The file is then read back, so a cached
        listing keeps the same order (and therefore the same split) across runs.
        """
        csv_path = os.path.join(self.root, filename)

        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))

            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:  # e.g. pokemon\\bulbasaur\\00000000.png
                    name = img.split(os.sep)[-2]  # parent directory = class name
                    writer.writerow([img, self.name2label[name]])
                print('writen into csv file:', filename)

        images, labels = [], []
        with open(csv_path) as f:
            for row in csv.reader(f):
                if not row:
                    # Tolerate blank lines, e.g. in a hand-edited listing.
                    continue
                image, label = row
                images.append(image)
                labels.append(int(label))
        return images, labels

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image tensor, label tensor)`` for the sample at ``idx``.

        The image is opened as RGB, resized to 1.25x ``resize``, centre-cropped
        to ``resize`` and normalized. A random rotation of up to 15 degrees is
        applied only for the ``'train'`` split so that validation and test
        results are deterministic.
        """
        path, label = self.images[idx], self.labels[idx]
        enlarged = int(self.resize * 1.25)

        # Built per call rather than stored on the instance: the pipeline
        # holds a lambda, which would presumably make the dataset unpicklable
        # for DataLoader worker processes.
        steps = [
            lambda x: Image.open(x).convert('RGB'),  # string path => image data
            transforms.Resize((enlarged, enlarged)),
        ]
        if self.mode == 'train':
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        tf = transforms.Compose(steps)

        return tf(path), torch.tensor(label)
    
class Flatten(nn.Module):
    """Flatten every dimension after the batch dimension: ``[b, ...] => [b, n]``."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """Return ``x`` reshaped to ``[x.size(0), product of remaining dims]``.

        ``Tensor.flatten`` is used instead of ``view(-1, n)`` so that empty
        batches and non-contiguous inputs are handled as well.
        """
        return x.flatten(1)
    
# Hyper-parameters: batch size, Adam learning rate, number of epochs.
batchsz, lr, epochs = 32, 1e-3, 10
# NOTE(review): `device` is not referenced by the code shown in this section.
device = torch.device('cuda')
torch.manual_seed(1234)

# One dataset per split, all reading the 'pokemon' directory at 224x224.
train_db, val_db, test_db = (
    Pokemon('pokemon', 224, model=split) for split in ('train', 'val', 'test')
)
# Only the training loader shuffles; evaluation order does not matter.
train_loader = DataLoader(dataset=train_db, batch_size=batchsz, shuffle=True, num_workers=2)
val_loader, test_loader = (
    DataLoader(dataset=split_db, batch_size=batchsz, num_workers=2)
    for split_db in (val_db, test_db)
)


def evalute(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    The model is switched to evaluation mode for the duration of the call, so
    layers such as BatchNorm and Dropout behave deterministically and running
    statistics are not updated by evaluation data. Its previous
    training/eval mode is restored afterwards.

    Args:
        model: Module mapping a batch ``x`` to logits of shape ``[b, classes]``.
        loader: Iterable of ``(x, y)`` batches with a ``dataset`` attribute
            whose length is the total number of samples.

    Returns:
        Accuracy as a float in ``[0, 1]``; ``0.0`` for an empty dataset.
    """
    total = len(loader.dataset)
    if total == 0:
        return 0.0

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        # Hand the model back in the mode the caller left it in.
        model.train(was_training)
    return correct / total

def main():
    """Fine-tune a pretrained ``resnet18`` on the training split and report test accuracy.

    The final layer of the pretrained network is replaced by a new 5-way
    linear head. Every second epoch the model is scored on the validation
    loader; whenever the score improves, the weights are saved to
    ``'best.mdl'``. After training, the best saved weights (if any were saved
    during this run) are reloaded before the test split is evaluated.
    """
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1],  # [b, 512, 1, 1]
                          Flatten(),  # [b, 512, 1, 1] => [b, 512]
                          nn.Linear(512, 5)  # [b, 512] => [b, 5]
                          )
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    # Tracks whether THIS run wrote 'best.mdl', so a missing or stale file
    # from an earlier run is never loaded by mistake.
    saved = False
    for epoch in range(epochs):
        for x, y in train_loader:
            # x:[b, 3, 224, 224], y:[b]
            logits = model(x)
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')
                saved = True

    print('best acc:', best_acc, 'best_epoch', best_epoch)

    if saved:
        model.load_state_dict(torch.load('best.mdl'))
        print('loaded from ckt!')
    else:
        print('no checkpoint was saved in this run; testing the final weights')

    test_acc = evalute(model, test_loader)
    print('test_acc:', test_acc)

# Run the training pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()

最終test accuracy在0.94左右,比我們自己從0開始訓練效果好了很多

相關文章