GNERF 源码解读(一)

主要需要了解生成器和辨别器是如何通过 patch 来进行训练的。对生成器(generator.py)、辨别器(discriminator)、光线采样(patch_sampler.py)、训练(train.py, trainer.py)等文件进行分析

1. 参数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
Generator params: 1162624, 
Discriminator params: 2642560
InversionNet params: 7168515O
ptimizable poses: 100

args: Namespace(
N_importance=64, N_samples=64, batch_size=12, begin_refine=12000, chunk=32768,dir_freq=4,
# 辨别器参数
discriminator={'type': 'rmsprop', 'lr': 0.0001, 'lr_anneal': 0.5, 'lr_anneal_every': 8000},
# azim是绕z轴旋转的角度, elev是绕y轴旋转的角度
azim_range=[0.0, 360.0], elev_range=[0.0, 90.0],
empty_cache_every=1000, epoch=-1, far=6.0, fc_depth=8, fc_dim=360, gan_type='standard',
# 生成器参数
generator={'type': 'rmsprop', 'lr': 0.0005, 'lr_anneal': 0.5, 'lr_anneal_every': 8000},
img_wh=[400, 400],
inv_net={'type': 'adam', 'lr': 0.0001, 'lr_anneal': 0.5, 'lr_anneal_every': 4000},
inv_size=64, it=-1, look_at_origin=True, max_scale=1.0, min_scale=0.04, name='GNeRF', ndc=False, ndf=64, near=2.0,
# patch_size = 16
num_epoch=40000, num_workers=6, patch_size=16, policy=['color', 'cutout'], pose_mode='3d',
progressive_end=20000, progressvie_training=True, psnr_best=-inf, radius=[4.0, 4.0], random_scale=True,
reg_param=10.0, reg_type='real', sample_every=1000, save_every=2000, scale_anneal=0.0002,
train_pose_params={'type': 'adam', 'lr': 0.005, 'lr_anneal': 0.5, 'lr_anneal_every': 4000},
val_pose_params={'type': 'adam', 'lr': 0.005, 'lr_anneal': 0.5, 'lr_anneal_every': 4000},
white_back=True, xyz_freq=10
)

1. N_importance, N_samples: 这些参数通常用于控制NeRF模型中的采样。
2. N_importance 和 N_samples 可能表示在体积渲染过程中用于积分的重要性采样点和均匀采样点的数量。
3. batch_size: 指定了每次训练迭代中使用的数据样本数量。
4. begin_refine: 可能指定了开始细化或优化模型的特定迭代次数。
5. chunk: 用于指定在训练或推理时处理的数据块的大小,影响内存使用和计算效率。
6. dir_freq: 方向频率,用于方向数据的傅里叶编码。
7. discriminator: 指定了GAN中鉴别器的参数,包括优化器类型(rmsprop),学习率(lr),学习率衰减(lr_anneal)和衰减频率(lr_anneal_every)。
8. azim_range, elev_range: 定义了摄像机的方位角(azim)和仰角(elev)的范围。这些参数控制视角的变化范围。
9. empty_cache_every: 可能指定了每多少迭代清空一次GPU内存缓存。
10. far, near: 设置了摄像机视锥的远近平面距离,影响渲染的深度范围。
11. fc_depth, fc_dim: 定义全连接层的深度和维度。
12. gan_type: 指定了GAN的类型,这里是“标准”类型。
13. generator: 指定了生成器的参数,包括优化器类型、学习率等。
14. img_wh: 图像的宽度和高度。
15. inv_net, inv_size: 指定了一个反演网络(E)的参数配置(ViT)。
16. it: 表示训练时间
17. look_at_origin, max_scale, min_scale: 控制了视点和缩放相关的参数。
18. name: 模型的名称。
19. ndc, ndf: 这些参数的具体含义可能依赖于模型的上下文。
20. num_epoch, num_workers: 分别指定了训练的总轮数和用于数据加载的工作线程数。
21. patch_size: 在处理图像时使用的补丁大小。
22. policy: 可能指定了数据增强策略。
23. pose_mode: 指定了姿态模型的类型。
24. progressive_end, progressvie_training: 指示是否使用渐进式训练及其结束时间。
25. psnr_best: 记录最好的峰值信噪比(PSNR)值。
26. radius, random_scale: 控制了摄像机位置和随机缩放的参数。
27. reg_param, reg_type: 正则化参数和类型。
28. sample_every, save_every: 控制样本采集和模型保存的频率。
29. scale_anneal: 缩放衰减参数。
30. train_pose_params, val_pose_params: 训练和验证时姿态参数的设置。
31. white_back, xyz_freq: 是否使用白色背景和空间频率。

2. 训练过程

  1. trainer.py
1
2
3
4
5
6
7
8
# 渐进式训练
if self.cfg.progressvie_training:
img_real = self.progressvie_training(img_real) # [N, 3, patch_size, patch_size]
val_imgs = self.progressvie_training(val_imgs_raw)

# update_intrinsic 更新相机内参参数,主要是来更新焦距
self.generator.ray_sampler.update_intrinsic(self.img_wh_curr / self.img_wh_end) # 根据 patch 调整焦距

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def progressvie_training(self, img):
if self.phase == 'A':
scale = 1.0 / self.cfg.begin_refine * self.it # 缩放尺度

self.img_wh_curr = self.start_img_wh + ((128.0 - self.start_img_wh) * scale).int() # 当前图像尺度
elif self.phase == 'ABAB':
img_scale_base = self.cfg.begin_refine / self.cfg.progressive_end
scale = img_scale_base + (1.0 - img_scale_base) /
(self.cfg.progressive_end - self.cfg.begin_refine) * (self.it - self.cfg.begin_refine)

self.img_wh_curr = 128 + ((self.img_wh_end - 128.0) /
(1.0 - img_scale_base) * (scale - img_scale_base)).int()
else:
return img

downsample_func = Resize((self.img_wh_curr[1], self.img_wh_curr[0]))# [400, 400] -> [16, 16] 下采样函数
img = downsample_func(img) # img [N, 3, 400(img_size), 400] -> [N, 3, 16(patch_size), 16]

return img # img [N, 3, 16, 16]

3. 光线采样

  1. ray_sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class RaySampler(object):
def __init__(self, near, far, azim_range, elev_range, radius, look_at_origin, ndc, intrinsics):
self.near = near
self.far = far
self.azim_range = azim_range # 方位角
self.elev_range = elev_range # 仰视角
self.radius = radius
self.look_at_origin = look_at_origin
self.up = (0., 0, 1)
self.ndc = ndc
self.scale = 1.0
self.start_intrinsics = intrinsics

# 更新摄像机的内参矩阵
def update_intrinsic(self, scale):
self.intrinsics = self.start_intrinsics.clone().detach()
self.intrinsics[:2] = self.intrinsics[:2] * scale[:, None] # 更新焦距

return self.intrinsics

def random_poses(self, nbatch, device='cpu'):
raes = torch.rand(nbatch, 3, device=device)

azims = raes[:, 0:1] * (self.azim_range[1] - self.azim_range[0]) + self.azim_range[0] # [batch, 1]
elevs = raes[:, 1:2] * (self.elev_range[1] - self.elev_range[0]) + self.elev_range[0] # [batch, 1]

azims = math.pi / 180.0 * azims
elevs = math.pi / 180.0 * elevs

cx = torch.cos(elevs) * torch.cos(azims)
cy = torch.cos(elevs) * torch.sin(azims)
cz = torch.sin(elevs)
T = torch.cat([cx, cy, cz], -1) # [N, 3]

radius = raes[:, 2:] * (self.radius[1] - self.radius[0]) + self.radius[0]

T = T * radius

if self.look_at_origin:
lookat = (0, 0, 0)
else:
xy = torch.randn((nbatch, 2), device=device) * self.radius[0] * 0.01
z = torch.zeros((nbatch, 1), device=device)

lookat = torch.cat((xy, z), dim=-1)

R = look_at_rotation(T, at=lookat, up=self.up, device=device) # [N, 3, 3]
RT = torch.cat((R, T[..., None]), -1) # [N, 3, 4]

return RT # [batch, 3, 4]

GNERF 源码解读(一)
http://seulqxq.top/posts/10285/
作者
SeulQxQ
发布于
2024年1月23日
许可协议