๋ณธ๋ฌธ ๋ฐ”๋กœ๊ฐ€๊ธฐ
Deep Learning/[์ฝ”๋“œ ๊ตฌํ˜„] DL Architecture ๊ตฌํ˜„

[UNet] Implementing copy and crop and the Full Architecture (PyTorch)

by ์ œ๋ฃฝ 2024. 2. 8.

๋‹ค๋ฅธ ๋ธ”๋กœ๊ทธ๋“ค์—์„œ๋Š” padding์„ 2๋กœ ๋งž์ถฐ์„œ, ๋”ฐ๋กœ copy and crop์„ ํ•˜์ง„ ์•Š์•˜์ง€๋งŒ, ํ•ด๋‹น ๋…ผ๋ฌธ์˜ ๊ทธ๋ฆผ์„ ๋ณด๊ณ  ์žฌํ˜„ํ•˜๊ธฐ ์œ„ํ•ด padding=0์œผ๋กœ ์„ค์ •ํ•ด์„œ ๋ณธ ์•„ํ‚คํ…์ฒ˜๋ฅผ ์ƒˆ๋กœ ๊ตฌ์ถ•ํ•˜์˜€๋‹ค.

(Figure: output feature-map sizes at each stage, given the input size and padding=0)
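
Before the code, it is worth checking the figure's numbers by hand: with padding=0, every 3×3 convolution shrinks each spatial side by 2, every 2×2 max-pool halves it, and every 2×2 stride-2 transposed convolution doubles it. A minimal sketch of that arithmetic, assuming the paper's 572×572 input (the helper lambdas are mine, not part of the implementation below):

conv3x3 = lambda s: s - 2   # 3x3 conv, padding=0: each side shrinks by 2
pool2x2 = lambda s: s // 2  # 2x2 max-pool: each side halves
up2x2   = lambda s: s * 2   # 2x2 transposed conv, stride 2: each side doubles

s = 572
for _ in range(4):                    # contracting path: conv, conv, pool
    s = pool2x2(conv3x3(conv3x3(s)))  # 284, 140, 68, 32
s = conv3x3(conv3x3(s))               # 28 at the bottom (con5_1, ex5_1)
for _ in range(4):                    # expansive path: up, conv, conv
    s = conv3x3(conv3x3(up2x2(s)))    # 52, 100, 196, 388
print(s)                              # 388 -> the final output is 388x388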

 

import torch.nn as nn
import torch.nn.functional as F
import torch
import os
import numpy as np

from torchvision import transforms

#https://sd118687.tistory.com/8 (reference on mirroring extrapolation)
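
# The mirroring above refers to the paper's overlap-tile strategy: the input
# tile is extrapolated at the border by reflection. A sketch of how that could
# look in PyTorch (illustrative only, not used in the model below):
#   x = F.pad(x, (92, 92, 92, 92), mode='reflect')  # 388x388 tile -> 572x572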

class UNet(nn.Module):
    #copy-and-crop helper
    #tensor layout: (batch_size, channels, height, width)
    def copy_and_crop(self, input, output):
        output_size = output.size()[2:]
        input_size = input.size()[2:]
        
        #Compare the (h, w) of the contracting and expansive features and
        #take the smaller value along each axis as the crop size.
        crop_size = [min(i, o) for i, o in zip(input_size, output_size)]
        
        #Pick the starting corner of the centered crop (see the figure).
        crop_start = [int((i - c) / 2) for i, c in zip(input_size, crop_size)]
        
        #Slicing preserves the number of dimensions (unlike integer indexing,
        #which drops one), so the result stays 4-D.
        cropped_input = input[:, :, crop_start[0]:crop_start[0] + crop_size[0], crop_start[1]:crop_start[1] + crop_size[1]]
        
        return cropped_input
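
    #A hand-checked example of the crop arithmetic (values taken from the
    #level-4 skip connection when the input is 572x572, per the figure):
    #  encoder map 64x64, upsampled decoder map 56x56
    #  crop_size  = [min(64, 56), min(64, 56)]       = [56, 56]
    #  crop_start = [(64 - 56) // 2, (64 - 56) // 2] = [4, 4]
    #  -> a centered 56x56 window, rows/cols 4..59 of the 64x64 map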
        
    def __init__(self):
        super(UNet, self).__init__()
        
        #Conv -> BatchNorm -> ReLU block used throughout the network.
        #e.g. act(1, 64) builds:
        #[Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1)), BatchNorm2d(64, ...), ReLU()]
        def act(in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=True):
            layers = []
            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias=bias)]
            layers += [nn.BatchNorm2d(num_features=out_channels)]
            layers += [nn.ReLU()]
            
            results = nn.Sequential(*layers)  #* unpacks the list as positional args
            
            return results
        
        #Contracting path
        self.con1_1 = act(in_channels=1, out_channels=64)
        self.con1_2 = act(in_channels=64, out_channels=64)
        
        self.con_pool1 = nn.MaxPool2d(kernel_size=2)  #stride defaults to kernel_size
        
        self.con2_1 = act(in_channels=64, out_channels=128)
        self.con2_2 = act(in_channels=128, out_channels=128)
        
        self.con_pool2 = nn.MaxPool2d(kernel_size=2)
        
        self.con3_1 = act(in_channels=128,out_channels=256)
        self.con3_2 = act(in_channels=256, out_channels=256)
        
        self.con_pool3 = nn.MaxPool2d(kernel_size=2)
        
        self.con4_1 = act(in_channels=256, out_channels=512)
        self.con4_2 = act(in_channels=512, out_channels=512)
        
        self.con_pool4 = nn.MaxPool2d(kernel_size=2)
        
        self.con5_1 = act(in_channels=512, out_channels=1024)
        
        
        #Expansive path
        self.ex5_1 = act(in_channels=1024, out_channels=512)  #1024 -> 512; the paper figure is a bit confusing here
        
        self.ex_unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,
                                             kernel_size=2, stride=2, padding=0, bias=True)
        
        self.ex4_2 = act(in_channels=512*2,out_channels=512)
        self.ex4_1 = act(in_channels=512, out_channels=256)
        
        self.ex_unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                              kernel_size=2, stride=2, padding=0, bias=True)
        
        self.ex3_2 = act(in_channels=256*2, out_channels=256)
        self.ex3_1 = act(in_channels=256, out_channels=128)
        
        self.ex_unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                              kernel_size=2, stride=2, padding=0, bias=True)
        
        self.ex2_2 = act(in_channels=128*2, out_channels=128)
        self.ex2_1 = act(in_channels=128, out_channels=64)
        
        self.ex_unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,
                                              kernel_size=2, stride=2, padding=0, bias=True)
        
        self.ex1_2 = act(in_channels=64*2, out_channels=64)  #note: exN_2 runs before exN_1 on the way up
        self.ex1_1 = act(in_channels=64, out_channels=64)
        
        
        self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0,
                            bias=True)  #1x1 conv mapping 64 channels to the single-class output
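        
        #Channel flow, matching the figure: 1 -> 64 -> 128 -> 256 -> 512 -> 1024
        #down the contracting path; each skip concatenation doubles the channels
        #again (hence in_channels=N*2 in exN_2), and the 1x1 conv maps 64 -> 1.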
        
        
    def forward(self,input):
        con1_1 = self.con1_1(input)
        con1_2 = self.con1_2(con1_1)
        con_pool1 = self.con_pool1(con1_2)
        
        con2_1 = self.con2_1(con_pool1)
        con2_2 = self.con2_2(con2_1)
        con_pool2 = self.con_pool2(con2_2)
        
        con3_1 = self.con3_1(con_pool2)
        con3_2 = self.con3_2(con3_1)
        con_pool3 = self.con_pool3(con3_2)
        
        con4_1 = self.con4_1(con_pool3)
        con4_2 = self.con4_2(con4_1)
        con_pool4 = self.con_pool4(con4_2)
        
        con5_1 = self.con5_1(con_pool4)
        
        ex5_1 = self.ex5_1(con5_1)
        
        #print(con1_1.size(), con1_2.size(), con_pool1.size(), con2_1.size(), con2_2.size(), con_pool2.size(), con3_1.size(), con3_2.size(), con_pool3.size(), con4_1.size(), con4_2.size(), con_pool4.size(), con5_1.size(), ex5_1.size())

        ex_unpool4 = self.ex_unpool4(ex5_1)
        cropped_con4_2 = self.copy_and_crop(con4_2,ex_unpool4)
        cat4= torch.cat((ex_unpool4,cropped_con4_2),dim=1)
        ex4_2 = self.ex4_2(cat4)
        ex4_1 = self.ex4_1(ex4_2)
        
        #print(ex_unpool4.size(),cropped_con4_2.size(), cat4.size(), ex4_2.size(),ex4_1.size() )
        
        ex_unpool3 = self.ex_unpool3(ex4_1)
        cropped_con3_2 = self.copy_and_crop(con3_2, ex_unpool3)
        cat3 = torch.cat((ex_unpool3, cropped_con3_2), dim=1)
        ex3_2 = self.ex3_2(cat3)
        ex3_1 = self.ex3_1(ex3_2)
        
        #print(ex_unpool3.size(),cropped_con3_2.size(), cat3.size(), ex3_2.size(),ex3_1.size() )
        
        ex_unpool2 = self.ex_unpool2(ex3_1)
        cropped_con2_2 = self.copy_and_crop(con2_2, ex_unpool2)
        cat2 = torch.cat((ex_unpool2, cropped_con2_2),dim=1)
        ex2_2 = self.ex2_2(cat2)
        ex2_1 = self.ex2_1(ex2_2)
        
        #print(ex_unpool2.size(),cropped_con2_2.size(), cat2.size(), ex2_2.size(),ex2_1.size() )
        
        ex_unpool1 = self.ex_unpool1(ex2_1)
        cropped_con1_2 = self.copy_and_crop(con1_2, ex_unpool1)
        cat1 = torch.cat((ex_unpool1, cropped_con1_2),dim=1)
        ex1_2 = self.ex1_2(cat1)
        ex1_1 = self.ex1_1(ex1_2)
                
        x = self.fc(ex1_1)
        
        #print(ex_unpool1.size(),cropped_con1_2.size(), cat1.size(),ex1_2.size(),ex1_1.size(),x.size())
        
        return x
        
        #Note: BCELoss requires the output and target to have the same size,
        #so the ground-truth mask must be cropped to match the network output.
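
A quick smoke test of the assembled network: with padding=0 throughout, a 572×572 single-channel input should come out at 388×388, matching the paper figure. The mask cropping and BCEWithLogitsLoss below are my own sketch of how training could hook in (the final layer emits logits, since there is no sigmoid):

model = UNet()
x = torch.randn(1, 1, 572, 572)      # (batch, channels, height, width)
y = model(x)
print(y.size())                      # torch.Size([1, 1, 388, 388])

# Crop the full-size mask down to the 388x388 output with the same helper,
# then apply a BCE-style loss on the logits.
mask = torch.randint(0, 2, (1, 1, 572, 572)).float()
mask = model.copy_and_crop(mask, y)  # centered 388x388 crop
loss = nn.BCEWithLogitsLoss()(y, mask)
print(loss.item())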