[머신러닝을 위한 대수학 - 1] (텐서편) 과 이어집니다. 아래는 torch 내 einsum 함수의 예제 코드를 가볍게 작성해본 모습입니다. 우선 다음과 같이 세팅해줍니다. import torch
import torch.nn.functional as F
A = torch.tensor([
[1, 2],
[3, 4]
], dtype=torch.float32)
B = torch.tensor([
[5, 6],
[7, 8]
], dtype=torch.float32)
u = torch.tensor([1, 2], dtype=torch.float32)
v = torch.tensor([3, 4], dtype=torch.float32)
일반적인 행렬 곱 print("<torch.matmul 이용>")
print(torch.matmul(A, B))
print("<torch.einsum 이용>")
print(torch.einsum('ik,kj->ij', A, B))
아다마르 곱(Hadamard product) print("<torch.mul 이용>")
print(torch.mul(A, B))
print("<torch.einsum 이용>")
print(torch.einsum('ij,ij->ij', A, B))
쌍선형 형식(bilinear form) print(torch.einsum('i,ij,j->',u,B,v))
*2D CNN (위,아래 첨자 구분 X) def conv2d(x, w, b):
B, C_in, H_in, W_in = x.shape
C_out, _, K_h, K_w = w.shape
H_out = H_in - K_h + 1
W_out = W_in - K_w + 1
x_unfold = F.unfold(x, kernel_size=(K_h, K_w), stride=1, padding=0)
x_patches = x_unfold.view(B, C_in, K_h, K_w, H_out, W_out)
x_patches = x_patches.permute(0, 1, 4, 5, 2, 3)
y = torch.einsum('bcijmn,kcmn->bkij', x_patches, w)
y = y + b.view(1, -1, 1, 1)
y = F.relu(y)
return y
B = 2
C_in = 3
H_in, W_in = 5, 5
C_out = 2
K_h, K_w = 3, 3
x = torch.randn(B, C_in, H_in, W_in)
w = torch.randn(C_out, C_in, K_h, K_w)
b = torch.randn(C_out)
y = conv2d(x, w, b)
print(y.shape)
print(y)