Fine tune là gì

1. Introduction

1.1 Fine-tuning là gì ?

Chắc hẳn hầu như ai thao tác cùng với những Model trong deep learning phần nhiều sẽ nghe/thân quen với khái niệm Transfer learning với Fine tuning. Khái niệm tổng quát: Transfer learning là tận dụng tối đa học thức học tập được từ một vụ việc nhằm áp dụng vào 1 vụ việc gồm tương quan không giống. Một ví dụ 1-1 giản: núm vày train 1 model mới hoàn toàn cho bài bác toán phân loại chó/mèo, bạn ta có thể tận dụng 1 Model đã làm được train ở ImageNet dataset cùng với hằng triệu ảnh. Pre-trained Model này sẽ được train tiếp bên trên tập dataphối chó/mèo, quy trình train này diễn ra nhanh hao rộng, công dụng thường tốt rộng. Có không ít kiểu Transfer learning, các bạn cũng có thể xem thêm trong bài này: Tổng thích hợp Transfer learning. Trong bài này, mình vẫn viết về 1 dạng transfer learning phổ biến: Fine-tuning.

Bạn đang xem: Fine tune là gì

Hiểu đơn giản dễ dàng, fine-tuning là các bạn lấy 1 pre-trained mã sản phẩm, tận dụng tối đa một trong những phần hoặc tổng thể các layer, thêm/sửa/xoá 1 vài ba layer/nhánh để tạo ra 1 Mã Sản Phẩm bắt đầu. Thường những layer đầu của model được freeze (đóng băng) lại - tức weight các layer này đã không trở nên chuyển đổi quý giá trong quy trình train. Lý vì chưng vày các layer này sẽ có khả năng trích xuất biết tin nấc trìu tượng tốt , kĩ năng này được học tự quy trình training trước đó. Ta freeze lại để tận dụng tối đa được tài năng này cùng giúp câu hỏi train ra mắt nkhô giòn rộng (mã sản phẩm chỉ đề nghị update weight sinh sống những layer cao). Có tương đối nhiều các Object detect mã sản phẩm được chế tạo dựa vào những Classifier Mã Sản Phẩm. VD Retimãng cầu mã sản phẩm (Object detect) được chế tạo cùng với backbone là Resnet.

*

1.2 Tại sao pytorch rứa vày Keras ?

Chủ đề nội dung bài viết từ bây giờ, bản thân đã trả lời fine-tuning Resnet50 - 1 pre-trained mã sản phẩm được hỗ trợ sẵn vào torchvision của pytorch. Tại sao là pytorch nhưng chưa phải Keras ? Lý vì chưng vị bài toán fine-tuning Model trong keras siêu đơn giản dễ dàng. Dưới đấy là 1 đoạn code minc hoạ đến câu hỏi chế tạo 1 Unet dựa trên Resnet trong Keras:

from tensorflow.keras import applicationsresnet = applications.resnet50.ResNet50()layer_3 = resnet.get_layer("activation_9").outputlayer_7 = resnet.get_layer("activation_21").outputlayer_13 = resnet.get_layer("activation_39").outputlayer_16 = resnet.get_layer("activation_48").output#Adding outputs decoder with encoder layersfcn1 = Conv2D(...)(layer_16)fcn2 = Conv2DTranspose(...)(fcn1)fcn2_skip_connected = Add()()fcn3 = Conv2DTranspose(...)(fcn2_skip_connected)fcn3_skip_connected = Add()()fcn4 = Conv2DTranspose(...)(fcn3_skip_connected)fcn4_skip_connected = Add()()fcn5 = Conv2DTranspose(...)(fcn4_skip_connected)Unet = Model(inputs = resnet.input đầu vào, outputs=fcn5)quý khách có thể thấy, fine-tuning Model trong Keras thực sự hết sức đơn giản và dễ dàng, dễ dàng làm, dễ hiểu. Việc add thêm những nhánh rất đơn giản bởi cú pháp đơn giản và dễ dàng. Trong pytorch thì trở lại, desgin 1 Mã Sản Phẩm Unet giống như sẽ khá vất vả với phức tạp. Người mới học tập vẫn chạm mặt trở ngại bởi bên trên mạng rất ít những khuyên bảo cho bài toán này. Vậy nên bài bác này bản thân sẽ chỉ dẫn chi tiết phương pháp fine-tune vào pytorch nhằm vận dụng vào bài xích tân oán Visual Saliency prediction

2. Visual Saliency prediction

2.1 What is Visual Saliency ?

*

Khi nhìn vào 1 tấm hình, mắt thông thường có Xu thế triệu tập quan sát vào 1 vài đơn vị chính. Ảnh trên đấy là 1 minh hoạ, màu sắc rubi được thực hiện nhằm biểu hiện cường độ ham mê. Saliency prediction là bài tân oán mô bỏng sự tập trung của mắt người Lúc quan lại gần cạnh 1 tấm hình. Cụ thể, bài xích tân oán đòi hỏi sản xuất 1 Model, model này nhấn hình họa đầu vào, trả về 1 mask mô phỏng cường độ đắm đuối. vì vậy, Mã Sản Phẩm dìm vào 1 input image cùng trả về 1 mask bao gồm kích cỡ tương đương.

Để rõ rộng về bài bác tân oán này, bạn có thể đọc bài: Visual Saliency Prediction with Contextual Encoder-Decoder Network.Datamix phổ cập nhất: SALICON DATASET

2.2 Unet

Note: Quý Khách hoàn toàn có thể làm lơ phần này ví như đang biết về Unet

Đây là một bài xích tân oán Image-to-Image. Để xử lý bài bác toán thù này, bản thân sẽ xây dựng dựng 1 model theo phong cách thiết kế Unet. Unet là 1 trong những kiến trúc được thực hiện nhiều trong bài tân oán Image-to-image như: semantic segmentation, tự động color, super resolution ... Kiến trúc của Unet tất cả điểm tương tự với kiến trúc Encoder-Decoder đối xứng, nhận thêm các skip connection từ Encode quý phái Decode tương ứng. Về cơ bạn dạng, các layer càng cao càng trích xuất công bố tại mức trìu tượng cao, điều này đồng nghĩa tương quan cùng với bài toán những lên tiếng nấc trìu tượng tốt nhỏng đường đường nét, màu sắc, độ phân giải... sẽ ảnh hưởng mất non đi vào quá trình lan truyền. Người ta thêm những skip-connection vào nhằm giải quyết và xử lý vụ việc này.

Với phần Encode, feature-bản đồ được downscale bởi những Convolution. trái lại, ở vị trí decode, feature-bản đồ được upscale bởi vì các Upsampling layer, vào bài này mình áp dụng các Convolution Transpose.

*

2.3 Resnet

Để giải quyết bài bác toán thù, bản thân sẽ xây dựng Mã Sản Phẩm Unet cùng với backbone là Resnet50. Quý khách hàng nên tò mò về Resnet nếu như không biết về kiến trúc này. Hãy quan liêu cạnh bên hình minh hoạ dưới đây. Resnet50 được chia thành những kăn năn phệ . Unet được phát hành cùng với Encoder là Resnet50. Ta vẫn lấy ra output của từng kân hận, chế tạo ra những skip-connection kết nối tự Encoder sang Decoder. Decoder được gây ra bởi vì những Convolution Transpose layer (xen kẹt trong đó là những lớp Convolution nhằm mục tiêu mục tiêu bớt số chanel của feature bản đồ -> sút con số weight cho model).

Theo ý kiến cá nhân, pytorch rất giản đơn code, dễ dàng nắm bắt hơn không hề ít so với Tensorflow 1.x hoặc ngang ngửa Keras. Tuy nhiên, Việc fine-tuning model vào pytorch lại cạnh tranh rộng không ít đối với Keras. Trong Keras, ta không phải thừa quyên tâm tới phong cách thiết kế, luồng cách xử lý của mã sản phẩm, chỉ cần mang ra các output trên 1 số ít layer nhất quyết có tác dụng skip-connection, ghép nối cùng tạo thành Mã Sản Phẩm bắt đầu.

*

Trong pytorch thì ngược chở lại, bạn phải phát âm được luồng xử lý cùng copy code phần nhiều layer mong mỏi giữ lại trong model bắt đầu. Hình trên là code của resnet50 trong torchvision. quý khách hàng hoàn toàn có thể xem thêm link: torchvision-resnet50. bởi vậy Lúc xây dừng Unet như bản vẽ xây dựng sẽ biểu thị trên, ta bắt buộc đảm bảo an toàn đoạn code từ bỏ Conv1 -> Layer4 không bị đổi khác. Hãy phát âm phần tiếp theo sau để làm rõ rộng.

Xem thêm: Cách Sửa Máy Tính Ko Lên Màn Hình Là Lỗi Gì ? Cách Khắc Phục Triệt Để

3. Code

Tất cả code của bản thân được gói gọn trong tệp tin notebook Salicon_main.ipynb. Bạn hoàn toàn có thể cài đặt về và run code theo link github: github/trungthanhnguyen0502 . Trong bài viết mình đã chỉ giới thiệu hầu như đoạn code bao gồm.

Import các package

import albumentations as Aimport numpy as npimport torchimport torchvisionimport torch.nn as nn import torchvision.transforms as Timport torchvision.models as modelsfrom torch.utils.data import DataLoader, Datasetimport ....

3.1 utils functions

Trong pytorch, tài liệu có trang bị từ dimension không giống cùng với Keras/TF/numpy. Đôi khi cùng với numpy tuyệt keras, hình ảnh gồm dimension theo lắp thêm trường đoản cú (batchkích cỡ,h,w,chanel)(batchkích cỡ, h, w, chanel)(batchkích cỡ,h,w,chanel). Thứ trường đoản cú vào Pytorch ngược chở lại là (batchform size,chanel,h,w)(batchsize, chanel, h, w)(batchsize,chanel,h,w). Mình sẽ xây dựng dựng 2 hàm toTensor với toNumpy nhằm biến đổi qua lại thân nhị format này.

def toTensor(np_array, axis=(2,0,1)): return torch.tensor(np_array).permute(axis)def toNumpy(tensor, axis=(1,2,0)): return tensor.detach().cpu().permute(axis).numpy() ## display one image in notebookdef plot_img(img): ... ## display multi imagedef plot_imgs(imgs): ...

3.2 Define model

3.2.1 Conv và Deconv

Mình sẽ xây dựng dựng 2 function trả về module Convolution với Convolution Transpose (Deconv)

def Deconv(n_input đầu vào, n_output, k_size=4, stride=2, padding=1): Tconv = nn.ConvTranspose2d( n_đầu vào, n_output, kernel_size=k_form size, stride=stride, padding=padding, bias=False) bloông xã = < Tconv, nn.BatchNorm2d(n_output), nn.LeakyReLU(inplace=True), > return nn.Sequential(*block) def Conv(n_đầu vào, n_output, k_size=4, stride=2, padding=0, bn=False, dropout=0): conv = nn.Conv2d( n_input đầu vào, n_output, kernel_size=k_kích cỡ, stride=stride, padding=padding, bias=False) block = < conv, nn.BatchNorm2d(n_output), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(dropout) > return nn.Sequential(*block)

3.2.2 Unet model

Init function: ta sẽ copy các layer phải giữ trường đoản cú resnet50 vào unet. Sau đó khởi tạo ra các Conv / Deconv layer và các layer quan trọng.

Forward function: bắt buộc đảm bảo an toàn luồng xử lý của resnet50 được giữ nguyên tương tự code gốc (trừ Fully-connected layer). Sau đó ta ghnghiền nối những layer lại theo phong cách thiết kế Unet vẫn diễn đạt vào phần 2.

Tạo model: cần load resnet50 cùng truyền vào Unet. Đừng quên Freeze những layer của resnet50 vào Unet.

Xem thêm: " Khóc Cho Nhau Một Lần Rồi Thôi Không Gặp Nhau Nữa Đâu Là Bài Gì

class Unet(nn.Module): def __init__(self, resnet): super().__init__() self.conv1 = resnet.conv1 self.bn1 = resnet.bn1 self.relu = resnet.relu self.maxpool = resnet.maxpool self.tanh = nn.Tanh() self.sigmoid = nn.Sigmoid() # get some layer from resnet khổng lồ make skip connection self.layer1 = resnet.layer1 self.layer2 = resnet.layer2 self.layer3 = resnet.layer3 self.layer4 = resnet.layer4 # convolution layer, use lớn reduce the number of channel => reduce weight number self.conv_5 = Conv(2048, 512, 1, 1, 0) self.conv_4 = Conv(1536, 512, 1, 1, 0) self.conv_3 = Conv(768, 256, 1, 1, 0) self.conv_2 = Conv(384, 128, 1, 1, 0) self.conv_1 = Conv(128, 64, 1, 1, 0) self.conv_0 = Conv(32, 1, 3, 1, 1) # deconvolution layer self.deconv4 = Deconv(512, 512, 4, 2, 1) self.deconv3 = Deconv(512, 256, 4, 2, 1) self.deconv2 = Deconv(256, 128, 4, 2, 1) self.deconv1 = Deconv(128, 64, 4, 2, 1) self.deconv0 = Deconv(64, 32, 4, 2, 1) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) skip_1 = x x = self.maxpool(x) x = self.layer1(x) skip_2 = x x = self.layer2(x) skip_3 = x x = self.layer3(x) skip_4 = x x5 = self.layer4(x) x5 = self.conv_5(x5) x4 = self.deconv4(x5) x4 = torch.cat(, dim=1) x4 = self.conv_4(x4) x3 = self.deconv3(x4) x3 = torch.cat(, dim=1) x3 = self.conv_3(x3) x2 = self.deconv2(x3) x2 = torch.cat(, dim=1) x2 = self.conv_2(x2) x1 = self.deconv1(x2) x1 = torch.cat(, dim=1) x1 = self.conv_1(x1) x0 = self.deconv0(x1) x0 = self.conv_0(x0) x0 = self.sigmoid(x0) return x0 device = torch.device("cuda")resnet50 = models.resnet50(pretrained=True)model = Unet(resnet50)mã sản phẩm.to(device)## Freeze resnet50"s layers in Unetfor i, child in enumerate(Mã Sản Phẩm.children()): if i 7: for param in child.parameters(): param.requires_grad = False

3.3 Dataset và Dataloader

Dataphối trả nhận 1 các mục những image_path và mask_dir, trả về image và mask khớp ứng.

Define MaskDataset

class MaskDataset(Dataset): def __init__(self, img_fns, mask_dir, transforms=None): self.img_fns = img_fns self.transforms = transforms self.mask_dir = mask_dir def __getitem__(self, idx): img_path = self.img_fns img_name = img_path.split("/")<-1>.split(".")<0> mask_fn = f"self.mask_dir/img_name.png" img = cv2.imread(img_path) mask = cv2.imread(mask_fn) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) if self.transforms: sample = "image": img, "mask": mask sample = self.transforms(**sample) img = sample<"image"> mask = sample<"mask"> # to lớn Tensor img = img/255.0 mask = np.expand_dims(mask, axis=-1)/255.0 mask = toTensor(mask).float() img = toTensor(img).float() return img, mask def __len__(self): return len(self.img_fns)Test dataset

img_fns = glob("./Salicon_dataset/image/train/*.jpg")mask_dir = "./Salicon_dataset/mask/train"train_transform = A.Compose(< A.Resize(width=256,height=256, p=1), A.RandomSizedCrop(<240,256>, height=256, width=256, p=0.4), A.HorizontalFlip(p=0.5), A.Rotate(limit=(-10,10), p=0.6),>)train_dataphối = MaskDataset(img_fns, mask_dir, train_transform)train_loader = DataLoader(train_dataphối, batch_size=4, shuffle=True, drop_last=True)# Test datasetimg, mask = next(iter(train_dataset))img = toNumpy(img)mask = toNumpy(mask)<:,:,0>img = (img*255.0).astype(np.uint8)mask = (mask*255.0).astype(np.uint8)heatmap_img = cv2.applyColorMap(mask, cv2.COLORMAP_JET)combine_img = cv2.addWeighted(img, 0.7, heatmap_img, 0.3, 0)plot_imgs(

3.4 Train model

Vì bài bác toán đơn giản và dễ dàng với để cho dễ hiểu, mình đang train theo cách đơn giản và dễ dàng nhất, không validate vào qúa trình train nhưng chỉ lưu lại Model sau 1 số epoch độc nhất vô nhị định

train_params = optimizer = torch.optim.Adam(train_params, lr=0.001, betas=(0.9, 0.99))epochs = 5Model.train()saved_dir = "model"os.makedirs(saved_dir, exist_ok=True)loss_function = nn.MSELoss(reduce="mean")for epoch in range(epochs): for imgs, masks in tqdm(train_loader): imgs_gpu = imgs.to(device) outputs = model(imgs_gpu) masks = masks.to(device) loss = loss_function(outputs, masks) loss.backward() optimizer.step()

3.5 Test model

img_fns = glob("./Salicon_dataset/image/val/*.jpg")mask_dir = "./Salicon_dataset/mask/val"val_transkhung = A.Compose(< A.Resize(width=256,height=256, p=1), A.HorizontalFlip(p=0.5),>)mã sản phẩm.eval()val_dataphối = MaskDataset(img_fns, mask_dir, val_transform)val_loader = DataLoader(val_dataphối, batch_size=4, shuffle=False, drop_last=True)imgs, mask_targets = next(iter(val_loader))imgs_gpu = imgs.to(device)mask_outputs = model(imgs_gpu)mask_outputs = toNumpy(mask_outputs, axis=(0,2,3,1))imgs = toNumpy(imgs, axis=(0,2,3,1))mask_targets = toNumpy(mask_targets, axis=(0,2,3,1))for i, img in enumerate(imgs): img = (img*255.0).astype(np.uint8) mask_output = (mask_outputs*255.0).astype(np.uint8) mask_target = (mask_targets*255.0).astype(np.uint8) heatmap_label = cv2.applyColorMap(mask_target, cv2.COLORMAP_JET) heatmap_pred = cv2.applyColorMap(mask_output, cv2.COLORMAP_JET) origin_img = cv2.addWeighted(img, 0.7, heatmap_label, 0.3, 0) predict_img = cv2.addWeighted(img, 0.7, heatmap_pred, 0.3, 0) result = np.concatenate((img,origin_img, predict_img),axis=1) plot_img(result)Kết trái thu được:

*

Đây là bài xích tân oán đơn giản và dễ dàng bắt buộc mình chú trọng vào quá trình cùng cách thức fine tuning trong pytorch rộng là đi sâu vào giải quyết và xử lý bài toán. Cảm ơn các bạn đang đọc

4. Reference

Dataset: salinhỏ.net

Code bài xích viết: https://github.com/trungthanhnguyen0502/-tienhieptruyenky.com-Visual-Saliency-prediction

Resnet50 torchvision code: torchvision-resnet

Bài viết cùng chủ đề Visual saliency: Visual Saliency Prediction with Contextual Encoder-Decoder Network!

Theo dõi các nội dung bài viết nâng cao về AI/Deep learning tại: Vietnam giới AI Link Sharing Community


Chuyên mục: Hỏi đáp