논문 리뷰

[논문 리뷰] AdaMPI : Single-View View Synthesis in the Wild with Learned Adaptive Multiplane Images

도도걸만단 2025. 3. 7. 02:44
반응형

TMPI의 baseline model, adaptive depth plane placement

 

 

 

ACM SIGGRAPH 2022 [Submitted on 24 May 2022]

https://arxiv.org/abs/2205.11733

 

Single-View View Synthesis in the Wild with Learned Adaptive Multiplane Images

This paper deals with the challenging task of synthesizing novel views for in-the-wild photographs. Existing methods have shown promising results leveraging monocular depth estimation and color inpainting with layered depth representations. However, these

arxiv.org

 

https://yxuhan.github.io/AdaMPI/

 

AdaMPI

 

yxuhan.github.io

 

https://github.com/yxuhan/AdaMPI

 

GitHub - yxuhan/AdaMPI: [SIGGRAPH 2022] Single-View View Synthesis in the Wild with Learned Adaptive Multiplane Images

[SIGGRAPH 2022] Single-View View Synthesis in the Wild with Learned Adaptive Multiplane Images - yxuhan/AdaMPI

github.com


 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

3.4 Network Training

 

1. L_{vs}  (View Synthesis Loss)

목적: 합성된 새로운 시점의 이미지 및 깊이 맵을 Ground Truth와 최대한 유사하게 만들기

사용된 손실:

L1 Loss (weight: 1.0)

SSIM Loss (weight: 1.0)

Perceptual Loss (VGG Loss, weight: 0.1)

Focal Frequency Loss (weight: 10.0)

깊이 맵의 L1 Loss (weight: 1.0)

결론: 이미지 품질 개선을 위한 손실 함수

 

2. L_{reg}  (Regularization Loss)

목적:

MPI(Multiplane Image) 깊이 맵이 올바른 깊이 순서를 유지하도록 보정

깊이 맵이 각 평면에 최적 배치되도록 정렬

사용된 손실:

Rank Loss ( L_{rank} , weight: 100)

Assignment Loss ( L_{assign} , weight: 10)

 결론: 깊이 구조를 유지하기 위한 손실 함수

 

\( L_{\text{total}} = L_{\text{vs}} + L_{\text{reg}} \)

 

\( L_{\text{total}} = (1.0 \cdot L_{L1} + 1.0 \cdot L_{\text{SSIM}} + 0.1 \cdot L_{\text{VGG}} + 10.0 \cdot L_{\text{focal}} + 1.0 \cdot L_{\text{depth}}) + (100 \cdot L_{\text{rank}} + 10 \cdot L_{\text{assign}}) \)

 

코드상(?)

\(L_{\text{total}} = L_{\text{L1}} + L_{\text{SSIM}} + \lambda_{\text{disp}} L_{\text{depth}} + \lambda_{\text{occ}} (L_{\text{occ,SSIM}} + L_{\text{occ,RGB}}) + \lambda_{\text{occ,disp}} L_{\text{occ,depth}} + \lambda_{\text{focal}} L_{\text{focal}} + \lambda_{\text{smooth}} L_{\text{smooth}} \)

4 EXPERIMENTS  

4.1 Experimental Setup  

4.1.1 Datasets.

4.1.2 Baselines.

4.1.3 Metrics.

4.2 Comparison with Previous Methods

 

 

 

 

4.3 Ablation Study

4.3.1 Network Architecture.

4.3.2 Loss Function.

4.3.3 Training Data.

 

 


 

 

 

 

 

 


 

 

README.md: 프로젝트에 대한 전반적인 설명과 최신 업데이트 정보를 제공한다.

gen_3dphoto.py: 단일 이미지와 그에 대응하는 깊이 맵을 입력으로 받아 3D 사진을 생성하는 스크립트이다.

train.py: 모델 학습을 위한 메인 스크립트로, 데이터 로더 초기화, 분산 학습 설정, 학습 루프 등을 포함한다.

trainer.py: 학습 과정에서의 손실 계산, 모델 업데이트, 평가 등을 수행하는 SynthesisTask 클래스를 정의한다.

preprocess_data.py: 데이터셋 전처리를 위한 스크립트로, COCO 데이터셋의 이미지에 대해 단안 깊이 추정을 수행하여 깊이 맵을 생성한다.

download_data.sh: COCO 데이터셋을 다운로드하고 압축을 해제하는 스크립트이다.

config/ 디렉토리: 모델 학습에 필요한 하이퍼파라미터와 설정 값을 포함한 YAML 형식의 구성 파일들이 저장되어 있다.

doc/ 디렉토리: 프로젝트와 관련된 추가 문서들이 포함되어 있다.

images/ 디렉토리: 테스트용 예제 이미지와 그에 대응하는 깊이 맵이 저장되어 있다.

misc/ 디렉토리: 기타 부가적인 파일들이 저장되어 있다.

model/ 디렉토리: AdaMPI 모델의 주요 구성 요소인 MPIPredictor, DepthPredictionNetwork, Feature Mask Network 등을 정의하는 파이썬 파일들이 포함되어 있다.

utils/ 디렉토리: MPI 렌더링, 호모그래피 샘플링, 손실 함수 계산 등 다양한 유틸리티 함수들이 포함되어 있다.

warpback/ 디렉토리: Warp-Back 전략을 구현하는 코드와 COCO 데이터셋을 위한 데이터셋 클래스 등이 포함되어 있다.

 


Warp-Back Strategy란?

“원래 이미지에서 새로운 시점을 만들고, 그걸 다시 원래 이미지로 변환하면서 데이터셋을 구축하는 방법”

Warp-Back 전략단일 이미지 데이터만을 사용하여 대규모 스테레오 학습 데이터를 생성하는 방법입니다.

 

보통 새로운 시점의 이미지를 생성(View Synthesis)하려면 다중 시점(Multi-view) 데이터가 필요합니다. 하지만 현실에서는 다중 시점 데이터가 부족한 경우가 많습니다. 이를 해결하기 위해, Warp-Back 전략을 사용하면 단일 이미지 데이터셋(예: COCO)만으로도 스테레오 데이터 쌍을 생성할 수 있습니다.

 

1. Warp-Back Strategy의 동작 방식

 

(1) Depth-based Warping (깊이 기반 워핑)

입력 이미지의 깊이(Depth) 맵을 추정합니다. (DPT 같은 Monocular Depth Estimation 네트워크 활용)

이 깊이 정보를 사용하여 가상의 새로운 시점(View)을 생성합니다.

Mesh Renderer를 사용하여 새로운 시점에서 보이는 장면을 합성합니다.

이 과정에서 가려진 영역(occlusion)이 생기며, 보이지 않는 부분은 비어 있습니다.

 

(2) Back-Warping and Inpainting (되돌리기 & 채우기)

생성된 새로운 시점의 이미지를 원래 시점으로 다시 Warp(되돌리기) 합니다.

이 과정에서 가려진 영역의 홀(hole)이 원래 이미지에서 보이게 됩니다.

이러한 가려진 부분을 채우기 위해 Inpainting Network (채우기 네트워크) 를 훈련합니다.

기존 일반적인 인페인팅 네트워크는 랜덤한 마스크(구멍)에 대해 학습하지만, Warp-Back에서는 카메라 이동으로 인해 발생하는 홀만을 특화하여 학습합니다.

이를 통해 자연스럽고 일관된 배경 정보로 구멍을 채울 수 있습니다.

 

(3) Stereo Training Pair Generation (학습 데이터 생성)

학습을 위해 스테레오 데이터 쌍(원래 이미지와 새로운 시점의 이미지)을 생성합니다.

생성된 새로운 시점의 이미지에서 가려진 영역을 보완한 후, 원래 이미지와 함께 학습에 사용합니다.

이렇게 하면 실제 다중 시점 이미지 없이도 모델이 다양한 시점 변화를 학습할 수 있습니다.

 

2. Warp-Back Strategy의 핵심 장점

1. 다중 시점(Multi-View) 데이터 없이 학습 가능

기존의 View Synthesis 모델들은 여러 각도에서 촬영한 사진이 필요했지만, Warp-Back 전략을 사용하면 단일 이미지 데이터셋만으로도 충분히 학습이 가능함.

2. 실제 촬영된 멀티뷰 데이터보다 더욱 다양한 시점 학습 가능

랜덤하게 카메라 이동을 시뮬레이션하여 다양한 각도에서 학습 가능.

3. 복잡한 3D 구조에 대한 일반화 성능 향상

단순한 Layered Depth 방식보다 더욱 정교한 다층 평면(Multiplane Image, MPI) 학습이 가능.

 

3. Warp-Back Strategy의 한계점

1. 깊이 추정 오류가 발생하면 품질 저하

Monocular Depth Estimation이 완벽하지 않기 때문에 깊이 정보가 부정확하면 왜곡이 발생할 수 있음.

2. View-Dependent Effect (반사, 투명체 등) 표현 어려움

이 방법은 물체의 고유한 반사 특성(예: 유리창, 반짝이는 금속 표면 등)을 잘 학습하지 못함.

3. Occluded Region(가려진 영역) 복원 한계

Warp-Back 과정에서 occlusion이 있는 영역을 복원할 때, 배경이 너무 복잡하면 부정확할 수 있음.

 


Code refactoring

stage1_dataset.py

AdaMPI에서 warp-back strategy의 첫 번째 단계, 즉 pseudo ground truth 생성 파이프라인의 시작점

 

한 장의 RGB 이미지 + 대응하는 Disparity(depth) 로부터 다음을 생성:

 

목적:

뷰포인트를 무작위로 이동(warp) 시켜서 novel view(가짜 뷰) 생성

그 view를 다시 원래 시점으로 warp-back

→ 이때 생기는 mask(occlusion hole) 를 이용해 disocclusion inpainting 학습용 데이터를 만드는 것이 목적

 

들어가는 입력

입력 항목설명

RGB 이미지 실제 이미지
Disparity Map (depth) DPT 같은 네트워크로 미리 뽑아둔 depth map
Camera Intrinsic 고정값으로 사용 (K 행렬)
Random Transform 무작위 회전/이동으로 카메라 뷰를 변경

 

처리 순서

1. RGB + Disparity → RGBD 텐서로 만들기

2. construct_mesh()로 3D mesh 생성

3. render_mesh()무작위 시점으로 warp

4. 그 결과를 다시 warp-back하여 원래 시점으로 복원

5. 복원 시 생기는 mask(=가려진 영역) 추출

 

결과로 생성되는 데이터

 

collect_data() 함수의 반환값:

Key 이름설명

rgb 원본 이미지
disp 원본 disparity
warp_rgb 무작위 시점으로 warp된 이미지
warp_disp 무작위 시점으로 warp된 disparity
warp_back_rgb warp-back된 이미지
warp_back_disp warp-back된 disparity
mask occlusion 영역 mask (== inpainting 학습용)

🖼️ 시각화 예시 (stage1 이미지)

 

save_image()로 저장되는 이미지들은 다음과 같이 나열돼요:

[원본 이미지]
[Disparity (3채널 반복)]
[Occlusion mask]
[Warp-back 이미지]
[Warp-back disparity]
[Warp 이미지]
[Warp disparity]

→ 한 장짜리 8칸짜리 grid 이미지로 저장됨 (nrow=bs)

 

최종 목적

 

이 데이터는 다음 단계인 inpainting network 학습 (stage2) 에서 사용됩니다.

 

즉, Stage1 = 학습용 synthetic 데이터 자동 생성기

Stage2에서 이 mask 영역을 복원하는 네트워크를 학습시켜요.

 

 흐름 요약

[RGB + Disparity]
      ↓
  Construct Mesh
      ↓
 Random Warp View
      ↓
  Warp-Back to Original View
      ↓
Generate Occlusion Mask + RGBD Pair → 저장

stage2_dataset.py

Stage2에서의 Warp(변환) 은 실제로 원본 이미지에 대해 3D 기하학적 변환 (camera extrinsic 변화) 를 적용해서 다른 시점에서 본 것처럼 만드는 과정

 

Warp가 어떻게 일어나는가?

cam_ext, cam_ext_inv = self.get_rand_ext(batch_size)

get_rand_ext() 함수에서 x, y, z, a, b, c 방향으로 랜덤하게 변환을 생성해.

이 값들은 self.trans_range에 정의된 범위 내에서 무작위로 뽑혀.

 

trans_range={"x":0.2, "y":-1, "z":-1, "a":-1, "b":-1, "c":-1}

여기서:

x=0.2 → x축 방향으로 [-0.2, -0.1] 또는 [0.1, 0.2] 사이 랜덤 이동

나머지가 -1 → 변환 없음 (0 이동)

 

🧩 그래서 Warp는 실제로 무슨 의미?

원본 이미지(tgt_rgb)에서 생성한 RGBD를

랜덤 extrinsic으로 novel view로 warp한 게 warp_rgb

그걸 다시 inpainting한 게 src_rgb (즉, pseudo GT)

 

warp_rgb 원본 이미지를 랜덤 카메라 시점으로 warp한 것
src_rgb warp_rgb를 inpaint하여 복원한 pseudo GT
cam_ext, cam_ext_inv 카메라 extrinsic matrix (3x4), 이동·회전 포함
수치로 이동량 확인 cam_ext_inv의 마지막 열 (translation) 확인

 

 

✅ -1이면 왜 변환이 없는 거야?

trans_range={"x":0.2, "y":-1, "z":-1, "a":-1, "b":-1, "c":-1}

trans_range는 각 축(x, y, z)과 회전(a, b, c)에 대한 변환 범위를 뜻해.

코드에서 변환을 만들 때는 이 값을 기준으로 랜덤 이동/회전을 만들어.

if r < 0:
    return torch.zeros((batch_size, 1, 1))  # → 변환 없음 (0 이동)

즉, -1은 “변환 하지 마라”는 신호로 쓰이는 거야.

r이 음수이면, 무조건 0을 반환해서 이동이나 회전을 안 시킴.

 

✅ x=0.2는 왜 랜덤 이동이야?

다음 함수에서 확인할 수 있어:

def rand_tensor(self, r, l):
    rand = torch.rand((l, 1, 1))  # 0~1 사이 랜덤 숫자
    sign = 2 * (torch.randn_like(rand) > 0).float() - 1  # 부호 결정: + 또는 -
    return sign * (r / 2 + r / 2 * rand)

예를 들어 r=0.2일 때:

rand는 0.0 ~ 1.0 사이의 무작위 수

따라서 r/2 + r/2 * rand0.1 ~ 0.2 사이 값

여기에 부호가 랜덤하게 붙어 → -0.2 ~ -0.1 또는 +0.1 ~ +0.2로 이동

 

즉, ±0.1 ~ 0.2 범위에서 무작위로 이동하는 걸 의미해.

 

 

 

더보기
import sys
sys.path.append(".")
sys.path.append("..")
import os
import glob
import math
import numpy as np
from skimage.feature import canny
import torch
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader, default_collate
from torchvision.utils import save_image
from torchvision import transforms

 

from warpback.utils_ori import (
RGBDRenderer,
image_to_tensor,
disparity_to_tensor,
transformation_from_parameters,
)
from warpback.networks import get_edge_connect



class WarpBackStage2Dataset(Dataset):
def __init__(
self,
data_root,
width=384,
height=256,
depth_dir_name="pro_depth",
device="cuda", # device of mesh renderer
trans_range={"x":0.2, "y":-1, "z":-1, "a":-1, "b":-1, "c":-1}, # xyz for translation, abc for euler angle
ec_weight_dir="warpback/ecweight",
):
self.data_root = data_root
self.depth_dir_name = depth_dir_name
self.renderer = RGBDRenderer(device)
self.width = width
self.height = height
self.device = device
self.trans_range = trans_range
self.image_path_list = glob.glob(os.path.join(self.data_root, "*.jpg"))
self.image_path_list += glob.glob(os.path.join(self.data_root, "*.png"))

 

# get Stage-1 pretrained inpainting network
self.edge_model, self.inpaint_model, self.disp_model = get_edge_connect(ec_weight_dir)
self.edge_model = self.edge_model.to(self.device)
self.inpaint_model = self.inpaint_model.to(self.device)
self.disp_model = self.disp_model.to(self.device)

 

# set intrinsics
self.K = torch.tensor([
[0.58, 0, 0.5],
[0, 0.58, 0.5],
[0, 0, 1]
]).to(device)

 

def __len__(self):
return len(self.image_path_list)

 

def __getitem__(self, idx):
image_path = self.image_path_list[idx]
image_name = os.path.splitext(os.path.basename(image_path))[0]
disp_path = os.path.join(self.data_root, self.depth_dir_name, "%s.png" % image_name)
 
image = image_to_tensor(image_path, unsqueeze=False) # [3,h,w]
disp = disparity_to_tensor(disp_path, unsqueeze=False) # [1,h,w]
 
# do some data augmentation, ensure the rgbd spatial resolution is (self.height, self.width)
image, disp = self.preprocess_rgbd(image, disp)
 
return image, disp, image_name # image name ms add
 
def preprocess_rgbd(self, image, disp):
# NOTE
# (1) here we directly resize the image to the target size (self.height, self.width)
# a better way is to first crop a random patch from the image according to the height-width ratio
# then resize this patch to the target size
# (2) another suggestion is, add some code to filter the depth map to reduce artifacts around
# depth discontinuities
image = F.interpolate(image.unsqueeze(0), (self.height, self.width), mode="bilinear").squeeze(0)
disp = F.interpolate(disp.unsqueeze(0), (self.height, self.width), mode="bilinear").squeeze(0)
return image, disp
 
def get_rand_ext(self, bs):
x, y, z = self.trans_range['x'], self.trans_range['y'], self.trans_range['z']
a, b, c = self.trans_range['a'], self.trans_range['b'], self.trans_range['c']
cix = self.rand_tensor(x, bs)
ciy = self.rand_tensor(y, bs)
ciz = self.rand_tensor(z, bs)
aix = self.rand_tensor(math.pi / a, bs)
aiy = self.rand_tensor(math.pi / b, bs)
aiz = self.rand_tensor(math.pi / c, bs)
 
axisangle = torch.cat([aix, aiy, aiz], dim=-1) # [b,1,3]
translation = torch.cat([cix, ciy, ciz], dim=-1)
 
cam_ext = transformation_from_parameters(axisangle, translation) # [b,4,4]
cam_ext_inv = torch.inverse(cam_ext) # [b,4,4]
return cam_ext[:, :-1], cam_ext_inv[:, :-1]
 
def rand_tensor(self, r, l):
'''
return a tensor of size [l], where each element is in range [-r,-r/2] or [r/2,r]
'''
if r < 0: # we can set a negtive value in self.trans_range to avoid random transformation
return torch.zeros((l, 1, 1))
rand = torch.rand((l, 1, 1))
sign = 2 * (torch.randn_like(rand) > 0).float() - 1
return sign * (r / 2 + r / 2 * rand)

 

def inpaint(self, image, disp, mask):
image_gray = transforms.Grayscale()(image)
edge = self.get_edge(image_gray, mask)
 
mask_hole = 1 - mask

 

# inpaint edge
edge_model_input = torch.cat([image_gray, edge, mask_hole], dim=1) # [b,4,h,w]
edge_inpaint = self.edge_model(edge_model_input) # [b,1,h,w]

 

# inpaint RGB
inpaint_model_input = torch.cat([image + mask_hole, edge_inpaint], dim=1)
image_inpaint = self.inpaint_model(inpaint_model_input)
image_merged = image * (1 - mask_hole) + image_inpaint * mask_hole
 
# inpaint Disparity
disp_model_input = torch.cat([disp + mask_hole, edge_inpaint], dim=1)
disp_inpaint = self.disp_model(disp_model_input)
disp_merged = disp * (1 - mask_hole) + disp_inpaint * mask_hole

 

return image_merged, disp_merged

 

def get_edge(self, image_gray, mask):
image_gray_np = image_gray.squeeze(1).cpu().numpy() # [b,h,w]
mask_bool_np = np.array(mask.squeeze(1).cpu(), dtype=np.bool_) # [b,h,w]
edges = []
for i in range(mask.shape[0]):
cur_edge = canny(image_gray_np[i], sigma=2, mask=mask_bool_np[i])
edges.append(torch.from_numpy(cur_edge).unsqueeze(0)) # [1,h,w]
edge = torch.cat(edges, dim=0).unsqueeze(1).float() # [b,1,h,w]
return edge.to(self.device)

 

def collect_data(self, batch):
batch = default_collate(batch)
image, disp , names = batch # ms add names
image = image.to(self.device)
disp = disp.to(self.device)
rgbd = torch.cat([image, disp], dim=1) # [b,4,h,w]
b = image.shape[0]

 

cam_int = self.K.repeat(b, 1, 1) # [b,3,3]
cam_ext, cam_ext_inv = self.get_rand_ext(b) # [b,3,4]
cam_ext = cam_ext.to(self.device)
cam_ext_inv = cam_ext_inv.to(self.device)
 
# warp to a random novel view and inpaint the holes
# as the source view (input view) to the single-view view synthesis method
mesh = self.renderer.construct_mesh(rgbd, cam_int)
warp_image, warp_disp, warp_mask = self.renderer.render_mesh(mesh, cam_int, cam_ext)
 
with torch.no_grad():
src_image, src_disp = self.inpaint(warp_image, warp_disp, warp_mask)

 

return {
"src_rgb": src_image,
"src_disp": src_disp,
"tgt_rgb": image,
"tgt_disp": disp,
"warp_rgb": warp_image,
"warp_disp": warp_disp,
"cam_int": cam_int, # src and tgt view share the same intrinsic
"cam_ext": cam_ext_inv,
"names" : names
}
'''
src_rgb, src_disp 복원된 source 뷰 (inpainted)
tgt_rgb, tgt_disp 원본 target 뷰
cam_int, cam_ext camera intrinsic/extrinsic
'''

 

if __name__ == "__main__":
bs = 8
data = WarpBackStage2Dataset(
data_root="warpback/testdata2",
)
loader = DataLoader(
dataset=data,
batch_size=bs,
shuffle=True,
collate_fn=data.collect_data,
)
for idx, batch in enumerate(loader):
src_rgb, src_disp = batch["src_rgb"], batch["src_disp"]
tgt_rgb, tgt_disp = batch["tgt_rgb"], batch["tgt_disp"]
warp_rgb, warp_disp = batch["warp_rgb"], batch["warp_disp"]
 
## ms add
names = batch["names"]
 
visual = torch.cat([
warp_rgb,
warp_disp.repeat(1, 3, 1, 1),
src_rgb,
src_disp.repeat(1, 3, 1, 1),
tgt_rgb,
tgt_disp.repeat(1, 3, 1, 1),
], dim=0)
# save_image(visual, "warpback/testdata-result/stage2-%03d.jpg" % idx, nrow=bs)
# breakpoint()
# ms add
 
for i in range(src_rgb.size(0)):
pseudo_img = src_rgb[i]
save_path = os.path.join('/mnt/data/minsun3054/AdaMPI/warpback/pseudoGT', f"pseudo_{names[i]}.jpg")
save_image(pseudo_img, save_path)
 
for j in range(tgt_rgb.size(0)):
img = tgt_rgb[j]
save_path = os.path.join('/mnt/data/minsun3054/AdaMPI/warpback/pseudoGT', f"tgt_rgb_{names[j]}.jpg")
save_image(img, save_path)
 
# save_image(src_rgb, "warpback/pseudoGT/")
'''
1️⃣ warp_rgb 원본 tgt → 임의의 novel view로 warp한 결과 랜덤 시점에서 본 이미지
2️⃣ warp_disp 위 warp 결과에 대응하는 disparity (깊이)
3️⃣ src_rgb warp_rgb를 inpainting으로 보정한 결과 pseudo-source 이미지!!!!!!!!!!!!!!!!!!!!!!
4️⃣ src_disp 보정된 source의 depth
5️⃣ tgt_rgb 원본 이미지 (ground truth view) 기준 이미지
6️⃣ tgt_disp 원본 이미지의 DPT에서 나온 depth
'''

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

반응형