-
Tensor 변환과 결합 - cat(), view(), squeeze(), unsqueeze()AI/pytorch 2021. 6. 28. 20:57
본 내용은 https://blog.naver.com/qbxlvnf11/221627488337 에서 대부분 발췌해 온 것임
Pytorch에서 텐서를 결합할 때는 cat() 함수를 활용,
dim 파라미터를 통해 기준을 설정
- torch 라이브러리 import
import torch
- 테스트용 텐서 생성
a = torch.ones(4, 3) b = torch.zeros(3, 3)
- 텐서 열 기준으로 concatenate (각 텐서의 열의 수가 같아야 함)
torch.cat([a, b], dim=0)
- 테스트용 텐서 생성
a = torch.ones(4, 3) b = torch.zeros(4, 4)
- 텐서 행 기준으로 concatenate (각 텐서의 행의 수가 같아야 함)
torch.cat([a, b], dim=1)
Pytorch에서 텐서의 shape를 변환하고 차원을 확장할 때는 view() 함수를 사용,
tensorflow의 reshape() 함수와 유사
- 테스트용 텐서 생성
a = torch.ones(4, 3)
- 텐서 shape 변경
a.view(3,4)
- 텐서 차원 확장
a.view(3,4,1)
squeeze() 함수와 unsqueeze() 함수 역시 차원을 변환하는 함수,
squeeze() 함수는 원소가 1인 차원을 제거,
unsqueeze() 함수는 인수로 받은 위치에 새로운 차원 삽입
- squeeze() 함수
a.view(3,4,1).squeeze()
- unsqueeze() 함수
a.unsqueeze(0)
a.unsqueeze(0).size()
a.unsqueeze(2)
a.unsqueeze(2).size()
'AI > pytorch' 카테고리의 다른 글
max, gather (0) 2021.06.28 Tensor 변환과 결합 - cat(), view(), squeeze(), unsqueeze() (0) 2021.06.28 squeeze(), unsqueeze()함수와 주의점 (0) 2021.06.28