[머신러닝을 위한 대수학 - 1.5] torch.einsum은 무적이다

2025.04.14 16:39 314 조회

[머신러닝을 위한 대수학 - 1] (텐서편)과 이어집니다.

아래는 torch 내 einsum 함수의 예제 코드를 가볍게 작성해본 모습입니다.


우선 다음과 같이 세팅해줍니다.

import torch
import torch.nn.functional as F


A = torch.tensor([
    [1, 2],
    [3, 4]
], dtype=torch.float32)

B = torch.tensor([
    [5, 6],
    [7, 8]
], dtype=torch.float32)

u = torch.tensor([1, 2], dtype=torch.float32)
v = torch.tensor([3, 4], dtype=torch.float32)



일반적인 행렬 곱

print("<torch.matmul 이용>")
print(torch.matmul(A, B))

print("<torch.einsum 이용>")
print(torch.einsum('ik,kj->ij', A, B))


아다마르 곱(Hadamard product)

print("<torch.mul 이용>")
print(torch.mul(A, B))

print("<torch.einsum 이용>")
print(torch.einsum('ij,ij->ij', A, B))


쌍선형 형식(bilinear form)

print(torch.einsum('i,ij,j->',u,B,v))


*2D CNN

(위,아래 첨자 구분 X)

def conv2d(x, w, b):
    B, C_in, H_in, W_in = x.shape
    C_out, _, K_h, K_w = w.shape


    H_out = H_in - K_h + 1
    W_out = W_in - K_w + 1
    x_unfold = F.unfold(x, kernel_size=(K_h, K_w), stride=1, padding=0)
    x_patches = x_unfold.view(B, C_in, K_h, K_w, H_out, W_out)
    x_patches = x_patches.permute(0, 1, 4, 5, 2, 3)


    y = torch.einsum('bcijmn,kcmn->bkij', x_patches, w)
    
    y = y + b.view(1, -1, 1, 1)
    y = F.relu(y)


    return y



B = 2
C_in = 3
H_in, W_in = 5, 5
C_out = 2
K_h, K_w = 3, 3

x = torch.randn(B, C_in, H_in, W_in)
w = torch.randn(C_out, C_in, K_h, K_w)
b = torch.randn(C_out)

y = conv2d(x, w, b)
print(y.shape)
print(y)