초보 개발자의 이야기, 릿허브

3D point cloud landmark detection 논문 관련 백업 및 정리 본문

논문/💻 논문 구현

3D point cloud landmark detection 논문 관련 백업 및 정리

릿99 2025. 12. 18. 23:09
728x90
반응형
  1. dataloader에서 포인트/GT 정렬이 어떻게 보장되는지
  2. baseline에서 forward → shape 정리 → loss
  3. evaluation에서 indices로 원본 좌표 복원하는 핵심

1) dataloader_baseline.py 핵심: FPS로 points와 heatmap을 같은 인덱스로 정렬

아래 블록이 point cloud processing에서 제일 중요한 부분
포인트를 FPS로 6890개로 줄였으면, GT heatmap도 반드시 같은 ply_index로 줄여야 “포인트 i의 GT”가 유지

# (dataloader_baseline_annotated.py 중 핵심 구간)

ply_mesh = trimesh.load(ply_path, process=False)

# 원본 포인트 (N_raw, 3)
ply_point_o = ply_mesh.vertices

# FPS 인덱스 생성: (1, N_raw, 3) -> idx (1, 6890)
# - farthest point sampling은 random sampling보다 공간적으로 고르게 포인트를 뽑는 장점이 있음
ply_point = torch.tensor(ply_point_o).float().unsqueeze(0)       # (1, N_raw, 3)
ply_index = farthest_point_sample(ply_point, 6890).squeeze(0)    # (6890,)

# PLY 포인트 다운샘플: points와 GT를 동일 인덱스로 정렬하기 위한 기준
ply_points = ply_point_o[ply_index]                              # (6890, 3)

# GT heatmap 로드 (보통 (N_raw, num_landmarks) 형태)
heatmap = np.load(heatmap_path)

# GT heatmap도 동일한 ply_index로 다운샘플
# - 포인트 i의 heatmap 값이 의미를 갖도록 points와 정렬 유지
heatmap = heatmap[ply_index]                                     # (6890, num_landmarks)

# indices 저장: 다운샘플된 포인트가 원본 PLY의 어느 정점이었는지 기록
# - 평가 시 argmax 인덱스를 원본 좌표로 복원하는 데 사용
self.indices.append(ply_index)

return points, heatmap, ply_name, indices

2) baseline.py 핵심: model 출력 shape을 GT에 맞추고 MSE로 학습

이 코드가 heatmap regression 형태라서 loss가 MSE
그리고 모델 출력이 (B, C, N)인 경우가 많아서 (B, N, C)로 permute해서 GT랑 shape을 맞추는 패턴

# (baseline_annotated.py 중 test() 내부 핵심 구간)

ply, heatmap, ply_name, indices = data

# NOTE: 현재 코드는 Tensor -> numpy -> Tensor로 다시 변환.
#       시험 답안에서는 Dataset에서 이미 torch.Tensor로 주면 이 변환을 생략하고
#       ply = ply.to(device) 형태로 바로 옮기는 편이 효율적.
ply = ply.cpu().numpy()
heatmap = heatmap.cpu().numpy()

ply = torch.tensor(ply).cuda()
heatmap = torch.tensor(heatmap).cuda()

with torch.no_grad():
    # (B, N, 3) 포인트 입력 -> (B, C, N) 또는 (B, N, C) 형태의 heatmap 예측
    pred_heatmap = model(ply)

    # pred_heatmap 차원을 GT heatmap과 맞추기 위한 permute
    # 예: model output이 (B, C, N)이고 GT가 (B, N, C)인 경우
    pred_heatmap = pred_heatmap.permute(0, 2, 1)

    # heatmap regression: 연속값(gaussian heatmap)을 맞추므로 MSE 사용
    loss = F.mse_loss(pred_heatmap, heatmap)

3) baseline.py 핵심: 평가에서 indices로 원본 PLY 좌표 복원

heatmap의 argmax는 “다운샘플된 6890 포인트 안에서의 인덱스”
하지만 GT(.lnd) 좌표 비교나 시각화는 원본 PLY 좌표계에서 수행, indices를 이용해 원본 정점 인덱스로 복원하는 과정이 핵심

# (baseline_annotated.py 중 calculate_error_baseline() 핵심 구간)

# 각 landmark 채널에서 heatmap 값이 최대인 포인트 인덱스를 선택 (argmax)
max_index = np.argmax(pred_heatmap_single)

# IMPORTANT: max_index는 '다운샘플된 6890개 포인트' 기준 인덱스
#            indices (ply_index)를 통해 원본 PLY의 정점 인덱스로 복원해야
#            원본 좌표계에서 landmark 좌표를 올바르게 가져올 수 있음
pred_coordinate = np.asarray(ply.points)[index[max_index]]

'''
For compare baseline models
1. PointNet
2. PointNet++
3. Point Transformer
4. DGCNN
5. GTNet
6. PTv1
'''
import argparse
import numpy as np
import open3d as o3d
import trimesh
import os
...

import torch
import torch.nn.functional as F
from tqdm import tqdm

# device 처리 통일: 코드 전체에서 이 DEVICE만 사용
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

...
def test(model, loader):
    total_error = [0.0] * num_classes
    total_batches = 0

  
    # - model.eval()는 테스트 시작 시 1회만
    # - batch마다 numpy 변환 없이 바로 to(DEVICE)
    model.eval()

    for j, data in tqdm(enumerate(loader), total=len(loader)):
        ply, heatmap, ply_name, indices = data

        # numpy 왕복 제거: Dataset이 torch.Tensor를 주도록 dataloader 쪽도 정리
        # non_blocking=True는 pin_memory=True일 때 전송 최적화
        ply = ply.to(DEVICE, non_blocking=True)
        heatmap = heatmap.to(DEVICE, non_blocking=True)

        with torch.no_grad():
            pred_heatmap = model(ply)

            # 원본 코드의 의도 유지: 모델 출력 shape이 (B, C, N)인 경우가 많아 GT와 맞추기 위해 permute
            # GT heatmap이 (B, N, C)라면 아래 permute가 필요
            pred_heatmap = pred_heatmap.permute(0, 2, 1)

            loss = F.mse_loss(pred_heatmap, heatmap)

            # error 계산 함수가 numpy 기반이면 여기서만 cpu().numpy() 변환
            # 핵심: 입력 ply/heatmap을 처음부터 numpy로 바꿨다가 다시 tensor로 만들지 않는다
            pred_heatmap_np = pred_heatmap.detach().cpu().numpy()
            heatmap_np = heatmap.detach().cpu().numpy()
            indices_np = indices.detach().cpu().numpy()

            # error = calculate_error_baseline('test', ply_name, pred_heatmap_np, indices_np)
            # total_error[...] += error[...]
            # total_batches += 1
            ...

    ...
    return total_error


def main():
    ...
    # device 변수도 통일
    device = DEVICE

    ...
    TRAIN_DATASET = Baseline_SizeKoreaDataLoader_LT(
        ply_root=NM_PLY_TRAIN_DATA_PATH,
        heatmap_root=HEATMAP_TRAIN_DATA_PATH,
        sigma=sigma,
        mode="TRAIN"
    )
    TEST_DATASET = Baseline_SizeKoreaDataLoader_LT(
        ply_root=NM_PLY_TEST_DATA_PATH,
        heatmap_root=HEATMAP_TEST_DATA_PATH,
        sigma=sigma,
        mode="TEST"
    )

    # DataLoader 정리
    # - train: shuffle=True
    # - test: shuffle=False
    # - pin_memory: GPU 사용 시 host->device 전송 최적화
    # - num_workers는 환경에 따라 0~4 사이 조정 가능
    trainDataLoader = torch.utils.data.DataLoader(
        TRAIN_DATASET,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=torch.cuda.is_available()
    )
    testDataLoader = torch.utils.data.DataLoader(
        TEST_DATASET,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=torch.cuda.is_available()
    )

    ...
    model = model.to(DEVICE)

    ...
    for epoch in range(start_epoch, epochs):
        model.train()
        for i, data in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader)):
            ply, heatmap, ply_name, indices = data

            # numpy 왕복 제거
            ply = ply.to(DEVICE, non_blocking=True)
            heatmap = heatmap.to(DEVICE, non_blocking=True)

            pred_heatmap = model(ply)
            pred_heatmap = pred_heatmap.permute(0, 2, 1)

            loss = F.mse_loss(pred_heatmap, heatmap)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ...
        # test는 위에서 model.eval() 포함해서 test() 내부에서 처리
        test_error = test(model, testDataLoader)
        ...

    ...


if __name__ == '__main__':
    main()
import numpy as np
import os
from torch.utils.data import Dataset
import torch
import json
import trimesh
import open3d as o3d
import potpourri3d as pp3d
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.neighbors import NearestNeighbors
from pointnet_util import farthest_point_sample
import potpourri3d as pp3d


class Baseline_SizeKoreaDataLoader_LT(Dataset):
    '''K for landmark transformer neighbor'''
    def __init__(self, ply_root, heatmap_root, sigma, mode):
        self.ply_root = ply_root    # normalized_ply
        self.heatmap_root = heatmap_root
        self.sigma = sigma
        self.mode = mode

        self.points_list = []
        self.heatmap_list = []
        self.ply_names = []
        self.indices = []

        ...
        # 원본 코드에서 전처리/캐시 로딩 로직이 여기서 수행되는 구조 유지
        # points_list, heatmap_list, indices에 (가능하면) torch.Tensor로 저장해두면 제일 좋음
        ...

        data = {
            'points': self.points_list,
            'heatmaps': self.heatmap_list,
            'ply_names': self.ply_names,
            'indices': self.indices
        }
        torch.save(data, 'Baseline_SizeKorea_preprocessed_data_' + self.mode + '_sigma=' + str(self.sigma) + '.pth')

    def __len__(self):
        return len(self.points_list)

    def __getitem__(self, idx):
        points = self.points_list[idx]
        heatmap = self.heatmap_list[idx]
        ply_name = self.ply_names[idx]
        indices = self.indices[idx]

        # 핵심
        # - Dataset에서 torch.Tensor를 반환하면 학습 루프에서 numpy 왕복이 완전히 사라짐
        # - baseline.py에서 ply/heatmap은 바로 .to(DEVICE)만 수행하면 됨

        # points: (N, 3) float32
        if not torch.is_tensor(points):
            points = torch.tensor(points, dtype=torch.float32)
        else:
            points = points.to(dtype=torch.float32)

        # heatmap: (N, num_landmarks) float32
        if not torch.is_tensor(heatmap):
            heatmap = torch.tensor(heatmap, dtype=torch.float32)
        else:
            heatmap = heatmap.to(dtype=torch.float32)

        # indices: (N,) int64 (원본 ply 정점 인덱스 매핑)
        if not torch.is_tensor(indices):
            indices = torch.tensor(indices, dtype=torch.long)
        else:
            indices = indices.to(dtype=torch.long)

        return points, heatmap, ply_name, indices
728x90
반응형