분석시각화 대회 코드 공유 게시물은
내용 확인 후
좋아요(투표) 가능합니다.
카메라 이미지 품질 향상 AI 경진대회
[pytorcyh] 더 빠른 학습을 가능하게 하는 amp 예제 코드 공유
학습하는데 computational cost가 높아서 어려움이 많은데, amp를 사용하면 더 큰 배치사이즈를 사용할 수 있거나, gpu 스펙에 따라 더 빠른 학습이 가능합니다.
아래와 같이 쉽게 사용할 수 있는데, 간단한 amp 사용 예시 공유합니다.
모두 좋은 결과 있으시길 바래요!
import torch.cuda.amp as amp scaler = amp.GradScaler() net = YourModel() # ------------------------ # loss # ------------------------ loss_fn = nn.L1Loss() # ------------------------ # Optimizer # ------------------------ optimizer = optim.Adam(net.parameters(), lr=args.start_lr, weight_decay=args.weight_decay) for epoch in range(1, args.epochs+1): for t, (images, targets) in enumerate(tqdm.tqdm(trainloader)): images = images.to(device=device, dtype=torch.float) targets = targets.to(device=device, dtype=torch.float) net.train() optimizer.zero_grad() if args.amp: with amp.autocast(): output = net(images) # loss loss = loss_fn(output, targets) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() else: output = net(images) # loss loss = loss_fn(output, targets) # update loss.backward() optimizer.step()
삭제된 댓글입니다
pretrained weight을 받아서 사용하시는 분들은 amp 사용을 조심하셔야 할듯합니다.
저는 amp로 학습하다 성능이 잘 안나와 원인을 찾던도중 amp를 제거하니 성능이 좋아졌습니다.
아마 float16으로 사용하는 과정에서 문제가 생기는 것 같습니다.
저는 현재 amp 계속 활용중이고, 성능에는 별영향이 없었습니다.
아마 amp에서 scale 해주는 과정에서 loss가 불안정할 수 있는데, gradient clipping이나 작은 loss로 시작하면 문제는 해결될 것 같습니다.
저도 그래서 clipping을 사용하고 lr을 낮춰보았는데 성능차이가 좁혀지진 않더라고요...
아무튼 답변 감사합니다!
데이콘(주) | 대표 김국진 | 699-81-01021
통신판매업 신고번호: 제 2021-서울영등포-1704호
서울특별시 영등포구 은행로 3 익스콘벤처타워 901호
이메일 dacon@dacon.io | 전화번호: 070-4102-0545
Copyright ⓒ DACON Inc. All rights reserved
👍🏻 광한님 감사합니다.