카메라 이미지 품질 향상 AI 경진대회

[pytorcyh] 더 빠른 학습을 가능하게 하는 amp 예제 코드 공유

2021.07.15 20:25 6,121 조회

학습하는데 computational cost가 높아서 어려움이 많은데, amp를 사용하면 더 큰 배치사이즈를 사용할 수 있거나, gpu 스펙에 따라 더 빠른 학습이 가능합니다.

아래와 같이 쉽게 사용할 수 있는데, 간단한 amp 사용 예시 공유합니다.

모두 좋은 결과 있으시길 바래요!

import torch.cuda.amp as amp
  
scaler = amp.GradScaler()
net = YourModel()

# ------------------------
# loss
# ------------------------
loss_fn = nn.L1Loss()

# ------------------------
# Optimizer
# ------------------------
optimizer = optim.Adam(net.parameters(), lr=args.start_lr, weight_decay=args.weight_decay)

for epoch in range(1, args.epochs+1):

    for t, (images, targets) in enumerate(tqdm.tqdm(trainloader)):
      images = images.to(device=device, dtype=torch.float)
      targets = targets.to(device=device, dtype=torch.float)

      net.train()
      optimizer.zero_grad()

      if args.amp:
        with amp.autocast():
          output = net(images)
          # loss
          loss = loss_fn(output, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

      else:
        output = net(images)

        # loss
        loss = loss_fn(output, targets)

        # update
        loss.backward()
        optimizer.step()

로그인이 필요합니다
0 / 1000
도비콘
2021.07.15 22:42

 👍🏻 광한님 감사합니다.

INVIS
2021.07.26 14:45

삭제된 댓글입니다

INVI
2021.07.28 20:12

pretrained weight을 받아서 사용하시는 분들은 amp 사용을 조심하셔야 할듯합니다.
저는 amp로 학습하다 성능이 잘 안나와 원인을 찾던도중 amp를 제거하니 성능이 좋아졌습니다.
아마 float16으로 사용하는 과정에서 문제가 생기는 것 같습니다.

Team
2021.07.26 15:15

저는 현재 amp 계속 활용중이고, 성능에는 별영향이 없었습니다.
아마 amp에서 scale 해주는 과정에서 loss가 불안정할 수 있는데, gradient clipping이나 작은 loss로 시작하면 문제는 해결될 것 같습니다.

INVI
2021.07.26 15:25

저도 그래서 clipping을 사용하고 lr을 낮춰보았는데 성능차이가 좁혀지진 않더라고요...
아무튼 답변 감사합니다!