AI/Computer Vision

[CV] conditional, unconditional image generation / ImageNet class label diffusion 어떻게 이용

도도걸만단 2025. 9. 8. 15:16
반응형

1) 두 가지 학습/샘플 방식

Unconditional

  • 모델이 아무 조건 없이  이미지를 생성함. “그럴듯한 ImageNet 스타일” 이미지를 뽑아냄.
     라벨(클래스)  전혀 쓰지 않음. 폴더 이름도 상관없고, 매핑도 불필요.
  • category X, random하게 생성됨.
  • 예전 generative model 은 unconditional을 많이 사용했음.
  • but 조건 없는 생성은 랜덤이고 제어 불가능 → conditional task가 요새는 더 많이 사용됨.
  • 모델은 훈련 데이터셋 분포를 학습하면서 이미지를 생성하는데, 조건이 없다면 훈련 중에 본 이미지와 통계적으로 유사한 새 이미지를 생성하게됨.

Class-conditional 

  • 모델에 클래스 조건을 넣고 “이 클래스처럼 보여야 해”라고 가이드.
     라벨이 필수. 학습 때도 라벨을 쓰고, 샘플링할 때 원하는 클래스 인덱스를 넣어 그 클래스로 생성함.
  • 조건부 이미지 생성 : 모델이 특정한 레이블에 따라 해당 클래스가 속하는 이미지를 생성하도록 함.
  • 조건을 인식하고 해당 조건에 맞게 이미지를 생성한다는 것은 모델이 데이터 분포와 그 클래스 특성을 이해하고 있고 분류하며 활용할 줄 아는 것을 의미함.
  • conditional : 클래스 레이블 뿐만 아니라 텍스트나 이미지도 조건이 될 수 있다. 그 중 이미지넷에서 class-conditional은 카테고리를 레이블로 주는 것을 의미.

 

2) ImageNet과 클래스명의 역할

  • ImageNet-1K에는 1000개 클래스가 있고, 각 클래스 폴더(보통 WordNet synset: n01440764 같은 이름)에 이미지가 들어 있음.
  • 모델이 직접 문자열 클래스명을 읽는 게 아니라, 폴더 이름(문자열)정수 인덱스(0~999) 로 바꿔서 씀.
    이 매핑은 보통 torchvision.datasets.ImageFolder가 알파벳순으로 자동 생성함.
  • 이름 자체가 중요한게 아니라, 이름이 어떤 정수 인덱스에 매핑됐는가라는 일관성이 중요함.
  • 클래스명은 사람에게 보여줄 라벨일 뿐, 모델은 정수 인덱스만 사용한다
  •  
  • 조건부 모델은 학습 때 이미지 ↔ 라벨(정수) 쌍을 배움.
  • 나중에 샘플링에서 “라벨 207(예: 골드피시)”을 넣으면 학습 때의 207번 클래스 분포를 따라 이미지를 그림.
  • 만약 특징추출 → 학습 → 샘플링 단계 중 하나라도 클래스→인덱스 매핑이 달라지면
    (예: 폴더가 추가/삭제되어 알파벳 정렬 순서가 바뀜)
    → 학습은 “207=골드피시”로 배웠는데
    → 샘플링은 “207=전혀 다른 클래스”로 해석 ⇒ 엉뚱한 이미지

3) DiT/fast-DiT에서 라벨이 들어가는 방식

  • 모델은 num_classes=C(ImageNet-1K면 1000)을 알고, 각 라벨을 임베딩으로 바꿔 네트워크에 주입합니다.
  • Classifier-Free Guidance(CFG) 를 쓰면 “라벨을 일부 확률로 빼고(null)”도 학습합니다.
    이때 null 라벨의 인덱스는 보통 C 로 둡니다.
    → 샘플러에서 y_null = num_classes(=1000) 같은 코드가 나오는 이유.
  • 샘플링 시 cfg_scale=1이면 “가이드 없음” (조건부/비조건부 동일 출력을 유도),
    >1이면 조건 신호를 더 강하게 씁니다.

6) evaluation & label ?

  • FID 는 라벨을 쓰지 않습니다(분포 비교).
  • Inception Score(IS) 도 GT 라벨을 직접 쓰진 않지만, Inception 분류기의 예측 분포를 이용.
  • 클래스 조건부 샘플링을 하려면 당연히 라벨 인덱스를 알아야 하죠(무슨 클래스로 뽑을지 지정).

7) unconditional로 하고 싶다면?

  • 가장 단순히는 라벨을 무시하고 학습(데이터로더에서 라벨을 쓰지 않는 버전)하면 됩니다.
  • DiT 계열에서 라벨 입력을 꼭 요구한다면, 모든 샘플에 같은 상수 라벨을 주거나, label-dropout=1.0(항상 null)로 학습하는 변형을 씁니다.
    (레포 코드 구조에 따라 옵션이 달라서, 제공 옵션에 맞춰 설정)

diffusion sampling code 분석 (ex. fast dit)


1) 스크립트의 목적

  • 사전학습/학습한 DiT로 **대량의 이미지를 병렬 생성(DDP)**하고
  • PNG로 저장한 뒤, FID 계산용 .npz(arr_0=이미지 텐서)를 만들어 ADM 평가 코드가 바로 쓰게 하는 것.

2) 핵심 인자(마지막 argparse)

  • --model: DiT 아키텍처 키 (예: DiT-XL/2, DiT-B/4 등).
  • --image-size: 256/512. VAE latent는 image_size/8로 계산됨.
  • --num-classes: 클래스-조건 개수(기본 1000=ImageNet-1k).
  • --vae: 디코더 선택(ema/mse) → stabilityai/sd-vae-ft-<vae>를 허깅페이스에서 로드.
  • --cfg-scale: CFG 스케일(>1이면 guidance 켬).
  • --num-sampling-steps: DDPM/DiT 샘플링 스텝 수(기본 250).
  • --per-proc-batch-size: GPU당 배치.
  • --num-fid-samples: 최종 저장할 이미지 총 수(기본 50k).
  • --ckpt: 쓸 체크포인트(.pt). 없으면 XL/2 공개 가중치 자동 다운로드.

3) 실행 순서(상단부터 한 줄씩 흐름)

(a) 기본 세팅

  • torch.backends.cuda.matmul.allow_tf32 = args.tf32 : TF32 허용(암페어↑ 속도↑, 수치 미세차 가능).
  • assert torch.cuda.is_available() : DDP라 GPU 필수.
  • torch.set_grad_enabled(False) : 샘플링이라 미분 끔.

(b) DDP 초기화

dist.init_process_group("nccl")
rank = dist.get_rank()
device = rank % torch.cuda.device_count()
seed = args.global_seed * dist.get_world_size() + rank
torch.manual_seed(seed); torch.cuda.set_device(device)
  • NCCL 백엔드로 프로세스그룹 생성.
  • 각 프로세스의 rank로 사용할 GPU시드를 정함(랭크마다 다른 시드).

(c) 모델/확산/VA E 로드

latent_size = args.image_size // 8
model = DiT_models[args.model](input_size=latent_size, num_classes=args.num_classes).to(device)
state_dict = find_model(args.ckpt or f"DiT-XL-2-{args.image_size}x{args.image_size}.pt")
model.load_state_dict(state_dict); model.eval()
diffusion = create_diffusion(str(args.num_sampling_steps))
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
  • DiTVAE latent 크기(H/8 × W/8)를 입력으로 받음.
  • find_model은 ckpt 경로가 없으면 공개 가중치 자동 다운로드.
  • create_diffusion(steps)로 샘플링 스케줄 준비.
  • Hugging Face에서 VAE 로드(ema/mse 중 택1).

⚠️ 주의: HF에서 각 rank가 동시에 VAE를 받으면 네트워크/디스크 병목이나 권한문제가 생길 수 있어. 보통 캐시 경로를 넉넉한 디스크로 지정(HF_HOME 등)하거나, rank 0만 받게 하고 dist.barrier() 이후 로드하는 패턴을 쓰기도 해.

(d) CFG 여부/저장폴더 구성

using_cfg = args.cfg_scale > 1.0
folder = f"{model_string}-{ckpt_string}-size-{image_size}-vae-{vae}-cfg-{cfg}-seed-{seed0}"
os.makedirs(folder, exist_ok=True)  # rank 0만 수행
dist.barrier()
  • CFG 쓰면 배치가 2배로 커짐(아래 참조). 폴더는 랭크 공통 경로.

(e) 샘플 개수의 균등 분배(중요 수식)

n = args.per_proc_batch_size       # GPU당 배치
W = dist.get_world_size()          # GPU 수
global_batch = n * W
total_samples = ceil(num_fid / global_batch) * global_batch
samples_per_gpu = total_samples // W
iterations = samples_per_gpu // n
  • num_fid_samples가 W×n으로 딱 나눠떨어지지 않아도 올림해서 조금 더 뽑고 끝에 버리기 쉽게 맞춤.
  • 모든 GPU가 **동일 횟수(iterations)**를 돌도록 보장.

(f) 메인 루프: 노이즈 → DiT → VAE → PNG

  1. 노이즈/라벨 샘플링
  2. z = torch.randn(n, model.in_channels, latent, latent, device=device) y = torch.randint(0, args.num_classes, (n,), device=device)
  3. CFG 설정(있으면 배치×2)
    • Classifier-Free Guidance: 실제 라벨 배치 + “null 라벨” 배치를 합쳐 모델을 한 번에 추론, 두 출력을 가중합하여 강화.
    • 여기서 null 라벨=1000 고정:
      중요한 가정: DiT의 클래스 임베딩 테이블이 1000(=ImageNet 클래스 수) + 1(=null) 로 만들어졌다고 전제.
      → 만약 num_classes != 1000인 커스텀 데이터라면 버그 포인트가 됨(랭크가 죽을 수 있음). 그래서 깃허브 이슈에서 “[args.num_classes]로 바꿔라”는 말이 나온 것.
  4. if using_cfg: z = torch.cat([z, z], 0) # 배치 2배 y_null = torch.tensor([1000] * n, device=device) # ← Null-class 인덱스 y = torch.cat([y, y_null], 0) model_kwargs = dict(y=y, cfg_scale=args.cfg_scale) sample_fn = model.forward_with_cfg else: model_kwargs = dict(y=y) sample_fn = model.forward
  5. 확산 샘플링
    • p_sample_loop는 T→0로 역확산하여 latent를 생성.
    • CFG 사용 시 반쪽은 null이므로 결과에서 절반만 남김.
  6. samples = diffusion.p_sample_loop(sample_fn, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=False, device=device) if using_cfg: samples, _ = samples.chunk(2, dim=0) # null 절반 제거
  7. VAE 디코딩 & 후처리
    • 0.18215: Stable Diffusion VAE의 latent 스케일 팩터(표준). 학습·추론 시 latent ↔ image 사이의 정규화 상수.
    • 샘플을 [0,255] 8bit RGB로 변환.
  8. samples = vae.decode(samples / 0.18215).sample samples = torch.clamp(127.5 * samples + 128.0, 0, 255) \ .permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
  9. 전 세계 인덱싱으로 PNG 저장
    • 랭크 간 파일명 충돌 방지를 위해 전역 인덱스 공식을 사용:
      index = (해당 미니배치 내 i) * W + rank + (지금까지 전역으로 저장된 개수)
    • 모든 GPU가 서로 다른 파일 이름으로 저장.
  10. index = i * W + rank + total Image.fromarray(sample).save(f"{folder}/{index:06d}.png") total += global_batch_size (루프 밖에서 누적)

(g) 동기화 & .npz 만들기

dist.barrier()  # 모두 PNG 저장 완료 대기
if rank == 0:
    create_npz_from_sample_folder(folder, args.num_fid_samples)
dist.barrier()
dist.destroy_process_group()
  • rank 0이 앞선 PNG들 중 앞에서부터 정확히 num_fid_samples장만 읽어 **.npz**로 묶음.

4) 자주 헷갈리는 포인트/함정

  1. Null-class=1000 상수
    • ImageNet-1k 전용 가정. 클래스 수가 다른 데이터에선 문제가 된다(임베딩 범위 밖 접근, 일부 랭크 크래시 → 나머지 NCCL 타임아웃).
    • 그런 경우엔 보통 임베딩을 num_classes+1로 만들고 null 인덱스를 num_classes로 맞춘다.
  2. CFG를 켜면 메모리 2배
    • 배치 두 배로 합쳐 추론 → OOM 한 랭크 발생 → 그 랭크만 죽음 → 나머지 랭크는 barrier/ALLGATHER에서 10분 대기 후 타임아웃 패턴이 나온다.
  3. VAE/HF 다운로드
    • 각 랭크가 동시에 받아 디스크/권한 오류가 나면 위와 같은 랭크 불일치 → NCCL 타임아웃으로 이어진다.
    • 캐시 경로를 큰 디스크로 지정하고, 가능하면 rank 0만 다운로드 + barrier 이후 로드가 안정적.
  4. 디스크 부족
    • PNG 50k장 + .npz는 수~수십 GB. 한 랭크가 I/O 에러로 먼저 죽으면 같은 증상이 난다.
  5. 샘플 수 배분 수학
    • num_fid_samples가 W*n으로 나누어 떨어지지 않으면 약간 더 뽑고 맨 뒤에서 버림. .npz 생성 시 앞에서부터 정확히 num_fid만 사용한다.

5) 이 코드가 “언제” 멈추거나 터질 수 있는 지점

  • 모델/체크포인트 로드 직후: 경로/권한/파일 손상.
  • VAE 로드 시점: HF 인증/네트워크/캐시·디스크.
  • 첫 번째 dist.barrier()(폴더 생성 뒤): 어떤 랭크가 위 단계에서 죽으면 여기서 나머지 랭크가 영원히 대기 → 10분 뒤 NCCL 타임아웃.
  • 루프 내부: OOM(특히 CFG)·디스크 꽉참 → 한 랭크만 크래시 → 다음 barrier() 에서 타임아웃.

6) 파라미터가 결과/안정성에 미치는 영향

  • per-proc-batch-size ↑ : 속도↑, 메모리↑. CFG 켜면 배치 실제 2배.
  • num-sampling-steps ↑ : 품질 대체로 약간↑, 시간↑. 너무 크면 오히려 과정 왜곡(overshoot) 이나 스케줄 불일치로 품질 저하가 생길 수 있음(네가 겪은 패턴).
  • cfg-scale : 너무 낮으면 조건 약함, 너무 높으면 과도한 강조로 깨짐/그림자 현상.
  • vae : 학습/추출/샘플링 VAE 일관성이 중요(ema↔mse 섞이면 색/질감 깨짐).

7) 한눈에 보는 전체 파이프라인

  1. DDP 초기화(rank, world_size, device, seed).
  2. DiT 로드(latent 크기 = H/8, num_classes 적용).
  3. 확산 스케줄 생성(sampling steps).
  4. VAE 로드(ema/mse).
  5. 샘플 총량 계산 → GPU별 균등 분배.
  6. 루프:
    • z ~ N(0, I), y ~ Uniform(0..C-1).
    • CFG면 (z,y) 복제 + null 레이블 합치기 → 한 번에 추론.
    • 역확산으로 latent 생성.
    • VAE 디코딩(+0.18215 스케일 보정) → RGB 8bit.
    • 전역 인덱스로 PNG 저장.
  7. barrier로 동기화 → rank 0이 앞에서부터 num_fid장 읽어 .npz 생성.
  8. 종료.

 

 


reference

https://flashsummit.tistory.com/24

반응형