pytorch 질문

비회원
2023.02.08 15:41 1,137 조회
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(3,6,5)
    self.pool = nn.MaxPool2d(2,2)
    self.conv2 = nn.Conv2d(6,16,5)
    self.fc1 = nn.Linear(16*5*5, 120)
    self.fc2 = nn.Linear(120,84)
    self.fc3 = nn.Linear(84,10)
  
  def forward(self,x):
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = torch.flatten(x,1) #배치를 제외한 모든 차원을 평평하게 한다.
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
  
net = Net()


파이토치를 공부하던 중 갑자기 의문점이 생기는데,


self.conv1 = nn.Conv2d(3,6,5)

self.conv2 = nn.Conv2d(6,16,5)


혹시 여기 코드를 해석하면


self.conv1 = nn.Conv2d(입력 체널 =3(rgb)이므로, 출력 체널 = 6, 커널 사이즈 = 5)

self.conv2 = nn.Conv2d(입력체널 = 6(행렬곱을 위해 conv1의 출력체널과 맞추어 줌),출력체널 = 16,커널 사이즈 = 5)


이정도로 원래 알고 있었는데, 각 출력체널이 왜 6,16인지 모르겠습니다.


갑자기 든 의문이라, 혹시 공식같은게 있을까요?