Python

[Python] tensor dimension 맞추기 문법, unsqueeze(0), squeeze, permute() 총정리!

도도걸만단 2025. 3. 9. 16:40
반응형

unsqueeze(0)의 역할과 차원 조정 문법 정리

1. unsqueeze(0)란?

torch.Tensor.unsqueeze(dim)

unsqueeze(dim)은 지정된 차원에 새로운 차원을 추가하는 PyTorch 함수

예제 : 1D 텐서 → 2D 텐서 변환

import torch

x = torch.tensor([1, 2, 3])  # [3]
print(x.shape)  # torch.Size([3])

# 첫 번째 차원(0번)에 새로운 차원 추가
x_unsqueezed = x.unsqueeze(0)
print(x_unsqueezed.shape)  # torch.Size([1, 3])

원래 텐서가 [3] (1차원)이었지만, unsqueeze(0)을 적용하여 [1, 3] (2차원)으로 변경됨.

2. unsqueeze()의 다양한 활용

(1) 이미지 데이터에서 배치 차원 추가

이미지 데이터(3D 텐서) → 배치 차원 추가 (4D 텐서)

image = torch.randn(3, 256, 256)  # (C, H, W) = (3, 256, 256)
image_batch = image.unsqueeze(0)  # (B, C, H, W) = (1, 3, 256, 256)

이미지는 일반적으로 (C, H, W) 형식이므로, unsqueeze(0)을 적용하여 (B, C, H, W)로 변환.
 DataLoader에서 배치 단위로 처리하려면 (B, C, H, W) 형식이 필요함.


(2) Disparity Map을 RGBD와 병합하기 위한 차원 맞추기

Disparity Map은 (H, W)이므로 unsqueeze(0)을 적용해 (1, H, W)로 변환 후 RGB 이미지 (3, H, W)와 병합

disp = torch.randn(256, 256)  # (H, W)
disp = disp.unsqueeze(0)  # (1, H, W)
rgbd = torch.cat([image, disp], dim=0)  # (4, H, W)

RGB(3채널) + Depth(1채널) = RGBD(4채널)로 변환!


3. unsqueeze()를 사용하는 이유

  • PyTorch에서 CNN, DataLoader, Loss Function들은 특정 차원 형식을 요구함.
  • unsqueeze()를 사용하면 원하는 차원을 쉽게 추가할 수 있음.

 4. squeeze() vs unsqueeze()

함수 역할 예제

unsqueeze(dim) 지정된 차원에 새로운 차원 추가 [3] → [1, 3]
squeeze(dim) 크기가 1인 차원 제거 [1, 3] → [3]

(예제)

x = torch.tensor([[1, 2, 3]])  # (1, 3)
x_squeezed = x.squeeze(0)  # (3) → 차원 축소
x_unsqueezed = x_squeezed.unsqueeze(0)  # 다시 (1, 3)으로 변환

 

squeeze(dim)

squeeze(dim)은 크기가 1인 차원을 제거하는 함수

import torch

tensor = torch.randn(1, 3, 256, 256)  # (B, C, H, W)
tensor_squeezed = tensor.squeeze(0)  # 배치 차원 제거 (C, H, W)
print(tensor_squeezed.shape)  # torch.Size([3, 256, 256])

 (B, C, H, W) → (C, H, W)
 배치 차원이 필요 없는 경우 squeeze(0)로 제거 가능

(2) 모든 크기 1 차원 제거 (squeeze())

tensor = torch.randn(1, 1, 256, 256)  # (B, C, H, W) = (1, 1, 256, 256)
tensor_squeezed = tensor.squeeze()  # 모든 차원에서 크기 1 제거
print(tensor_squeezed.shape)  # torch.Size([256, 256])

크기가 1인 차원은 모두 제거됨!

squeeze()로 모든 크기 1인 차원을 제거하면 어떤 점이 좋은가?

torch.Tensor.squeeze()는 크기가 1인 차원을 자동으로 제거하여 연산을 더 효율적으로 수행할 수 있도록 도와줌

  • 크기가 1인 차원은 데이터 손실 없이 제거 가능
  • 크기가 2 이상인 차원은 정보가 포함되어 있으므로 제거하면 데이터가 손실됨
  • 데이터 손실을 방지하면서 차원을 줄일 수 있는 안전한 방법

 

squeeze(dim) vs squeeze() 차이

squeeze(dim)과 squeeze()는 동작 방식이 다름

squeeze() 모든 크기 1인 차원을 제거 (1, 3, 1, 256, 256) → (3, 256, 256)
squeeze(0) 0번 차원이 1이면 제거 (1, 3, 256, 256) → (3, 256, 256)
squeeze(1) 1번 차원이 1이면 제거 (3, 1, 256, 256) → (3, 256, 256)

squeeze(0), squeeze(1) 같은 경우 특정 차원만 제거 가능!
squeeze()는 모든 크기 1인 차원을 제거.


5. 차원 맞추는 주요 문법 모음

배치 차원 추가 tensor.unsqueeze(0) (C, H, W) → (1, C, H, W)
채널 차원 추가 tensor.unsqueeze(1) (B, H, W) → (B, 1, H, W)
채널 차원 제거 tensor.squeeze(1) (B, 1, H, W) → (B, H, W)
차원 변경 tensor.permute(2, 0, 1) (H, W, C) → (C, H, W)
크기 조정 F.interpolate(tensor, (H, W)) (C, H_old, W_old) → (C, H, W)

 


permute(), F.interpolate(), squeeze(), F의 역할?

5. permute(dim0, dim1, ...)

permute()는 PyTorch에서 텐서의 차원을 원하는 순서로 변경하는 함수

new_tensor = tensor.permute(dim0, dim1, dim2, ...)

현재 차원의 순서를 바꿀 때 사용됨!

(1) 예제: HWC → CHW 변환 (이미지 데이터)

import torch

tensor = torch.randn(256, 256, 3)  # (H, W, C)
tensor_permuted = tensor.permute(2, 0, 1)  # (C, H, W)
print(tensor.shape)  # torch.Size([256, 256, 3])
print(tensor_permuted.shape)  # torch.Size([3, 256, 256])
  • permute(2, 0, 1)을 적용하여 마지막 차원(C)을 첫 번째 위치로 이동.
  • 일반적인 이미지 처리에서는 permute()를 사용하여 채널 차원을 앞으로 이동.

6.  F.interpolate(input, size, mode)

PyTorch에서 이미지 또는 텐서를 리사이징(크기 조정)하는 함수
Interpolate = 보간(Interpolation) → 새로운 픽셀 값을 예측하여 크기를 조정하는 것

import torch
import torch.nn.functional as F

image = torch.randn(1, 3, 128, 128)  # (B, C, H, W) = (1, 3, 128, 128)
resized = F.interpolate(image, size=(256, 256), mode="bilinear")
print(resized.shape)  # torch.Size([1, 3, 256, 256])

이미지 크기를 (128, 128) → (256, 256)으로 변경
mode="bilinear": 선형 보간법을 사용하여 부드러운 크기 조정 수행

 

주요 모드 (Interpolation 방식)

"nearest" 가장 가까운 픽셀 값을 사용 (블록형 확대)
"bilinear" 선형 보간 (부드러운 확대)
"bicubic" 삼차 보간 (더 부드러운 확대)
"trilinear" 3D 데이터(볼륨) 확대 시 사용

 

(2) disp (Disparity Map)에서 F.interpolate() 적용

disp = torch.randn(1, 1, 256, 256)  # (B, 1, H, W)
disp_resized = F.interpolate(disp, size=(384, 512), mode="bilinear")
print(disp_resized.shape)  # torch.Size([1, 1, 384, 512])

Disparity Map도 크기 조정 가능!
✔ mode="bilinear"을 사용하여 연속적인 깊이 값을 부드럽게 조정


7.  F?

F는 torch.nn.functional의 alias(별칭)

import torch.nn.functional as F

# ReLU 활성화 함수
x = torch.tensor([-1.0, 0.0, 1.0])
relu_x = F.relu(x)  
print(relu_x)  # tensor([0., 0., 1.])

F.relu() → torch.nn.ReLU()와 동일한 기능

 

F.relu() ReLU 활성화 함수
F.softmax() Softmax 함수
F.interpolate() 텐서 크기 조정 (리사이징)
F.cross_entropy() 크로스 엔트로피 손실 함수

 

 

 

 

반응형