대학 대항전 : 퍼즐 이미지 AI 경진대회

알고리즘 | 월간 데이콘 | 대학 대항전 | 비전 | 자기지도학습 | 기술 혁신 | 분류

  • moneyIcon 상금 : 인증서
  • 841명 마감

 

Jigsaw-ViT 모델

2024.01.17 14:13 1,518 조회 language

Jigsaw-ViT 모델을 사용하여 Public 점수 0.95955 까지 달성하였는데, 시도해온 방법 공유드립니다.

Jigsaw-ViT 모델에서 보조적으로 사용하는 Jigsaw Puzzle Solving Task가 이번 대회 Task와 동일하기 때문에, 해당 Task만 사용해 학습했습니다.
공개된 Pretrained 모델은 이미지를 14x14 패치로 나누어 퍼즐을 푸는 구조인데, 14x14 패치 중 일부가 4x4 패치의 여러 패치에 걸치는 문제가 있어서, 사용하는데 어려움이 있다 판단했습니다.

대신 deit3_base_patch16_384 pretrained 모델 기반으로 Jigsaw-ViT 모델 구조로 변경해 사용했습니다. 이 모델은 이미지를 24x24 패치로 분할해 퍼즐을 푸는 구조이기 때문에, 나중에 4x4 퍼즐을 맞출 때 영역이 일치하지 않는 문제가 발생하지 않습니다.

학습은 기존 Jigsaw-ViT 설정을 그대로 가져와 사용했습니다.
처음엔 논문대로 Mask Ratio 0.5를 사용해 학습했고, 해당 방법으로 리더보드 점수 0.86 정도를 달성했습니다.
시간은 Colab V100 기준 24시간 정도 소요된 것 같습니다.

Random Masking을 하면 일부 패치를 버리기 때문에 학습 속도가 빠르다는 장점이 있지만, 추론 시에는 모든 패치를 보기 때문에 학습과 추론 사이에 차이가 존재하게 됩니다.
그래서 Mask Ratio를 점점 낮춰가며 학습했고, 일관된 성능 향상이 있었습니다.

검증 데이터에서 잘못 예측한 이미지들을 살펴보면서 발견한 문제점은 배경이 단조로운 경우 동일한 패치가 많이 생겨서 예측을 잘 못하는 것 같다고 판단했습니다.
패치 개수를 줄여보는 등의 시도를 해보시면 좋을 것 같습니다.

참고
Jigsaw-ViT: Learning Jigsaw Puzzles in Vision Transformer: https://arxiv.org/abs/2207.11971
데이터셋 코드는 이전에 파이썬초보만 님이 올려주신 코드를 수정했습니다.

코드
로그인이 필요합니다
0 / 1000
파이썬초보만
2024.01.17 15:15

저도 masking하는 부분때문에 이번 대회와 약간 차이가 있다고 생각했는데
masking rate를 조절하는 방법이 있군요! 잘 읽었습니다

Oak_tree
2024.01.18 09:02

공유 감사합니다. 많이 배우고 갑니다.

UAI_1
2024.01.22 11:07

# Define combinations for 2x2 and 3x3 puzzles
    combinations_2x2 = [(i, j) for i in range(3) for j in range(3)]
    combinations_3x3 = [(i, j) for i in range(2) for j in range(2)]
혹시 여기에 2x2퍼즐에 for문이 3으로 들어가 있고, 
3x3퍼즐에 for문이 2로 들어가 있는데 왜 그런지 알 수 있을까요?

klne
2024.01.22 12:10

해당 부분은 공개되어 있는 평가산식 코드를 가져온 것인데요.
저도 구현이 틀렸나 싶어서 확인해봤는데, 각각 2x2, 3x3 퍼즐의 좌측 상단 위치를 의미하는 것 같네요.
2x2 퍼즐은 9개, 3x3 퍼즐은 4개가 존재하니 맞는 구현인 것 같습니다.

UAI_1
2024.01.22 17:05

답변 감사드립니다.
혹시 학습에 너무 오래 걸려서 약 1천건의 데이터만 사용해서 진행시켰더니 ['10', '4', '11', '11', '10', '11', '11', '10', '11', '4', '4', '10', '4', '11', '10', '11']처럼 중복값이 발생하는데 전체 데이터로 수행하면 이런 문제는 발생하지 않나요?

klne
2024.01.25 01:23

학습이 될수록 중복값이 줄어들긴 하지만, 완전히 없어지지는 않습니다...

정답을 구성하는 부분을 중복 값이 생기지 않도록 수정하시면 약간의 성능 향상이 있을수도 있습니다.

UAI_1
2024.01.29 15:44

친절한 답변 감사합니다.
올려주신 코드로 공부중인데, CNT_ROW.argmax(2) * 4 + CNT_COL.argmax(2) + 1
이 부분에서 행에 4를 곱한 이유를 알 수가 없었습니다.
행을 가만히 두고 열에 4를 곱해도 결과적으로는 완전히 같지는 않더라도 1~16의 값을 얻을 수 있지 않나요?

klne
2024.01.29 18:04

단순히 정답 형식이 행, 열 순서로 되어있어서 그렇습니다.
예를 들어 어떤 패치에서 (0행 0열부터 시작) 1행 3열이 가장 많이 카운트 되었을 경우, 해당 패치는 1*4+3+1=8번 위치에 존재한다고 예측해야겠죠.

UAI_1
2024.01.30 11:20

그 외에 4를 곱한 이유는 행과 열을 합쳐서 1~16의 값을 만들기 위함인거 같은데, 4를 곱하는 이유가 그 외에 있을까요?

co1dtype
2024.01.31 08:17

중복이 발생하지 않게 수정한다는 건 정확히 무슨 의미일까요? test에 대한 개수 추론 같아서 중복 제거를 해도 되는 지 궁금하네요.

ssung
2024.01.22 16:39

갈피를 못잡고 있었는데 공유 정말 감사합니다. 많이 배우고 가요!