DASCHOOL! Fall Special Discount
분석시각화 대회 코드 공유 게시물은
내용 확인 후
좋아요(투표) 가능합니다.
[머신러닝을 위한 대수학 - 1.5] torch.einsum은 무적이다
[머신러닝을 위한 대수학 - 1] (텐서편)과 이어집니다.
아래는 torch 내 einsum 함수의 예제 코드를 가볍게 작성해본 모습입니다.
우선 다음과 같이 세팅해줍니다.
import torch import torch.nn.functional as F A = torch.tensor([ [1, 2], [3, 4] ], dtype=torch.float32) B = torch.tensor([ [5, 6], [7, 8] ], dtype=torch.float32) u = torch.tensor([1, 2], dtype=torch.float32) v = torch.tensor([3, 4], dtype=torch.float32)

print("<torch.matmul 이용>")
print(torch.matmul(A, B))
print("<torch.einsum 이용>")
print(torch.einsum('ik,kj->ij', A, B))

print("<torch.mul 이용>")
print(torch.mul(A, B))
print("<torch.einsum 이용>")
print(torch.einsum('ij,ij->ij', A, B))

print(torch.einsum('i,ij,j->',u,B,v))

(위,아래 첨자 구분 X)
def conv2d(x, w, b):
B, C_in, H_in, W_in = x.shape
C_out, _, K_h, K_w = w.shape
H_out = H_in - K_h + 1
W_out = W_in - K_w + 1
x_unfold = F.unfold(x, kernel_size=(K_h, K_w), stride=1, padding=0)
x_patches = x_unfold.view(B, C_in, K_h, K_w, H_out, W_out)
x_patches = x_patches.permute(0, 1, 4, 5, 2, 3)
y = torch.einsum('bcijmn,kcmn->bkij', x_patches, w)
y = y + b.view(1, -1, 1, 1)
y = F.relu(y)
return y
B = 2
C_in = 3
H_in, W_in = 5, 5
C_out = 2
K_h, K_w = 3, 3
x = torch.randn(B, C_in, H_in, W_in)
w = torch.randn(C_out, C_in, K_h, K_w)
b = torch.randn(C_out)
y = conv2d(x, w, b)
print(y.shape)
print(y)
DACON Co.,Ltd | CEO Kookjin Kim | 699-81-01021
Mail-order-sales Registration Number: 2021-서울영등포-1704
Business Providing Employment Information Number: J1204020250004
#901, Eunhaeng-ro 3, Yeongdeungpo-gu, Seoul 07237
E-mail dacon@dacon.io |
Tel. 070-4102-0545
Copyright ⓒ DACON Inc. All rights reserved