pi-GAN Source Code Analysis

1. train.py

1. load_images(images, curriculum, device): loads the image data in batches.
2. z_sampler(shape, device, dist): draws random noise, e.g. the noise for 25 preview images at a time.
3. train(rank, world_size, opt): starts the actual training.

1) Initializing the generator and discriminator:

# generator: ImplicitGenerator3d, defined in generators.py
"""
returns: pixels, torch.cat([pitch, yaw], -1)
pixels: rgb_final
"""
generator = getattr(generators, metadata['generator'])(SIREN, metadata['latent_dim']).to(device)  # latent_dim = 256

"""
returns: prediction, latent, position
"""
# discriminator: ProgressiveEncoderDiscriminator, defined in discriminators.py
discriminator = getattr(discriminators, metadata['discriminator'])().to(device)  # the discriminator

# keep exponential moving averages of the generator's parameters
# decay: how strongly the history is weighted in the average
ema = ExponentialMovingAverage(generator.parameters(), decay=0.999)
ema2 = ExponentialMovingAverage(generator.parameters(), decay=0.9999)
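
For reference, ExponentialMovingAverage here comes from the torch_ema package. A minimal sketch of how such EMA objects are typically updated and read back (the model and loop below are stand-ins, not the actual pi-GAN training loop):

import torch
from torch_ema import ExponentialMovingAverage

model = torch.nn.Linear(256, 4)  # stand-in for the generator
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
ema = ExponentialMovingAverage(model.parameters(), decay=0.999)

for _ in range(10):
    loss = model(torch.randn(8, 256)).pow(2).mean()
    optimizer.zero_grad(); loss.backward(); optimizer.step()
    ema.update()  # fold the new weights into the running average after each step

# for evaluation / checkpoints, temporarily swap in the averaged weights
ema.store()                      # remember the raw weights
ema.copy_to(model.parameters())  # load the EMA weights into the model
# ... render sample images ...
ema.restore()                    # put the raw weights back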

2) Training the discriminator:

# TRAIN DISCRIMINATOR
with torch.cuda.amp.autocast():
    # Generate images for discriminator training
    with torch.no_grad():
        # random noise z [9, 256]
        z = z_sampler((real_imgs.shape[0], metadata['latent_dim']), device=device, dist=metadata['z_dist'])
        logging.info(f"train discriminator z.shape: {z.shape}")  # [9, 256]
        # real_imgs.shape[0] == z.shape[0] --> batch_size (9)
        split_batch_size = z.shape[0] // metadata['batch_split']
        gen_imgs_film = []  # leftover from a FiLM experiment, never filled in this excerpt
        gen_imgs = []
        gen_positions = []
        for split in range(metadata['batch_split']):  # loop batch_split times
            subset_z = z[split * split_batch_size:(split+1) * split_batch_size]  # subset_z [3, 256]

            # g_imgs --> pixels [batch_split, 3(RGB), img_size, img_size]
            # g_pos --> torch.cat([pitch, yaw], -1) [batch_size, 2]
            # --> generators.py --> ImplicitGenerator3d --> forward()
            g_imgs, g_pos = generator_ddp(subset_z, **metadata)
            gen_imgs.append(g_imgs)
            gen_positions.append(g_pos)
        gen_imgs = torch.cat(gen_imgs, axis=0)            # [9, 3, 32, 32] [batch_size, channels, img_size, img_size]
        gen_positions = torch.cat(gen_positions, axis=0)  # [9, 2] [batch_size, 2]

    real_imgs.requires_grad = True

    # prediction [batch_size, 1]: scores for the real images
    # --> discriminators.py --> ProgressiveEncoderDiscriminator --> forward()
    r_preds, _, _ = discriminator_ddp(real_imgs, alpha, **metadata)

if metadata['r1_lambda'] > 0:  # R1 gradient penalty on the real images
    # gradient of the real-image scores w.r.t. the real images
    grad_real = torch.autograd.grad(outputs=scaler.scale(r_preds.sum()), inputs=real_imgs,
                                    create_graph=True)
    inv_scale = 1. / scaler.get_scale()
    grad_real = [p * inv_scale for p in grad_real][0]  # gradient p [batch_size, 3, img_size, img_size]

with torch.cuda.amp.autocast():
    if metadata['r1_lambda'] > 0:
        grad_penalty = (grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2).mean()
        grad_penalty = 0.5 * metadata['r1_lambda'] * grad_penalty
    else:
        grad_penalty = 0

    # generated images gen_imgs:
    # gen_img_prediction [batch_size, 1], latent [batch_size, 256], position [batch_size, 2]
    # (leftover FiLM-variant call; gen_imgs_film is never populated above)
    # g_pred_latent_film, g_pred_position_film = discriminator_ddp(gen_imgs_film, alpha, **metadata)
    g_preds, g_pred_latent, g_pred_position = discriminator_ddp(gen_imgs, alpha, **metadata)
    # penalties on the generated images
    if metadata['z_lambda'] > 0 or metadata['pos_lambda'] > 0:
        # latent-code penalty
        latent_penalty = torch.nn.MSELoss()(g_pred_latent, z) * metadata['z_lambda']
        # camera-pose penalty
        position_penalty = torch.nn.MSELoss()(g_pred_position, gen_positions) * metadata['pos_lambda']
        identity_penalty = latent_penalty + position_penalty
    else:
        identity_penalty = 0
    # g_preds: scores for generated images, r_preds: scores for real images
    d_loss = torch.nn.functional.softplus(g_preds).mean() + \
             torch.nn.functional.softplus(-r_preds).mean() + grad_penalty + identity_penalty
    discriminator_losses.append(d_loss.item())
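
The loss above is the non-saturating logistic GAN loss plus an R1 gradient penalty on the real images and the pi-GAN identity penalties. A standalone sketch of just the GAN part, without AMP/GradScaler (the toy discriminator is only for illustration):

import torch
import torch.nn.functional as F

def d_loss_sketch(d, real_imgs, fake_imgs, r1_lambda=10.0):
    """Non-saturating logistic discriminator loss + R1 penalty (sketch)."""
    real_imgs = real_imgs.detach().requires_grad_(True)
    r_preds = d(real_imgs)           # [B, 1] scores for real images
    g_preds = d(fake_imgs.detach())  # [B, 1] scores for generated images

    # R1: squared L2 norm of the gradient of d(real) w.r.t. the real images
    grad_real, = torch.autograd.grad(r_preds.sum(), real_imgs, create_graph=True)
    grad_penalty = 0.5 * r1_lambda * grad_real.flatten(1).norm(2, dim=1).pow(2).mean()

    # softplus(g) pushes fake scores down, softplus(-r) pushes real scores up
    return F.softplus(g_preds).mean() + F.softplus(-r_preds).mean() + grad_penalty

toy_d = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 1))
loss = d_loss_sketch(toy_d, torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32))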

3) Training the generator:

# TRAIN GENERATOR
# random noise z.shape [batch_size, 256]
z = z_sampler((imgs.shape[0], metadata['latent_dim']), device=device, dist=metadata['z_dist'])
split_batch_size = z.shape[0] // metadata['batch_split']

for split in range(metadata['batch_split']):
    with torch.cuda.amp.autocast():
        subset_z = z[split * split_batch_size:(split+1) * split_batch_size]  # subset_z [batch_split, 256]

        # gen_imgs [batch_split, 3(RGB), img_size, img_size], gen_positions [batch_split, 2]
        gen_imgs, gen_positions = generator_ddp(subset_z, **metadata)  # images and camera poses from noise

        # g_preds [batch_split, 1], g_pred_latent [batch_split, 256], g_pred_position [batch_split, 2]
        # --> ProgressiveEncoderDiscriminator --> forward()
        g_preds, g_pred_latent, g_pred_position = discriminator_ddp(gen_imgs, alpha, **metadata)

        topk_percentage = max(0.99 ** (discriminator.step / metadata['topk_interval']), metadata['topk_v']) \
            if 'topk_interval' in metadata and 'topk_v' in metadata else 1

        topk_num = math.ceil(topk_percentage * g_preds.shape[0])

        g_preds = torch.topk(g_preds, topk_num, dim=0).values  # keep the topk_num highest scores

        if metadata['z_lambda'] > 0 or metadata['pos_lambda'] > 0:
            # discriminator latent-code penalty and camera-pose penalty
            latent_penalty = torch.nn.MSELoss()(g_pred_latent, subset_z) * metadata['z_lambda']
            position_penalty = torch.nn.MSELoss()(g_pred_position, gen_positions) * metadata['pos_lambda']
            identity_penalty = latent_penalty + position_penalty
        else:
            identity_penalty = 0

        g_loss = torch.nn.functional.softplus(-g_preds).mean() + identity_penalty
        generator_losses.append(g_loss.item())

    scaler.scale(g_loss).backward()
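
The torch.topk call implements a top-k trick: only the generator samples that the discriminator scores highest contribute to the loss, and the kept fraction decays from 1.0 towards the floor topk_v. A small numeric illustration (the step and curriculum values below are made up):

import math
import torch

g_preds = torch.tensor([[2.1], [0.3], [-1.5], [0.9]])  # fake-image scores, higher = "more real"

step, topk_interval, topk_v = 100000, 2000, 0.6
topk_percentage = max(0.99 ** (step / topk_interval), topk_v)  # 0.99**50 ≈ 0.605
topk_num = math.ceil(topk_percentage * g_preds.shape[0])       # ceil(2.42) = 3

kept = torch.topk(g_preds, topk_num, dim=0).values  # keeps 2.1, 0.9, 0.3 and drops -1.5
# only these scores enter softplus(-kept).mean(), so the worst samples are ignored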

2. siren.py

  1. sine_init(m): initializes the weights of a sine (SIREN) layer.
def sine_init(m):
    with torch.no_grad():
        if isinstance(m, nn.Linear):
            num_input = m.weight.size(-1)  # number of input features
            # uniform_() fills the weight tensor with samples from a uniform distribution
            # over [-sqrt(6 / num_input) / 30, sqrt(6 / num_input) / 30] (the SIREN initialization scheme)
            m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30)
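
In formula form, this is the SIREN uniform initialization for hidden layers (not He initialization):

    W ~ Uniform( -sqrt(6 / n_in) / 30,  +sqrt(6 / n_in) / 30 )

where n_in is the layer's input width; the 1/30 factor compensates for the frequencies γ, which are scaled to be around 30 (frequencies*15 + 30, see TALLSIREN below) before multiplying the linear output.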
  2. CustomMappingNetwork class: the mapping network that maps the random noise to frequencies and phase shifts.
# computes the frequencies γ and phase shifts β for all layers in a single pass
def forward(self, z):
    frequencies_offsets = self.network(z)  # [batch_size, 4608], 4608 = (8+1)*256*2
    frequencies = frequencies_offsets[..., :frequencies_offsets.shape[-1]//2]   # [batch_size, 2304]
    phase_shifts = frequencies_offsets[..., frequencies_offsets.shape[-1]//2:]  # [batch_size, 2304]

    return frequencies, phase_shifts
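
A quick shape check of this split, assuming z_dim = 256 and 8 hidden FiLM layers plus the color layer (the two-layer mapping MLP below is a simplified stand-in for CustomMappingNetwork, not its exact architecture):

import torch
import torch.nn as nn

z_dim, hidden_dim, num_film_layers = 256, 256, 8 + 1  # 8 hidden layers + 1 color layer
mapping = nn.Sequential(nn.Linear(z_dim, 256), nn.LeakyReLU(0.2),
                        nn.Linear(256, num_film_layers * hidden_dim * 2))

z = torch.randn(9, z_dim)
out = mapping(z)                                  # [9, 4608]
frequencies, phase_shifts = out.chunk(2, dim=-1)  # each [9, 2304] = 9 chunks of 256, one per layer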
  3. FiLMLayer class: computes one hidden layer.
class FiLMLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # the underlying linear layer
        self.layer = nn.Linear(input_dim, hidden_dim)

    # x: the input position coordinates
    def forward(self, x, freq, phase_shift):
        logging.info(f"FiLMLayer input_x.shape: {x.shape}")
        # linear layer; afterwards x.shape is [batch_size, num_rays*num_steps, hidden_dim]
        x = self.layer(x)
        logging.info(f"FiLMLayer output_x.shape: {x.shape}")
        # FiLM SIREN activation: sin(γx + β)
        freq = freq.unsqueeze(1).expand_as(x)                # freq.shape [batch_size, num_rays*num_steps, hidden_dim]
        phase_shift = phase_shift.unsqueeze(1).expand_as(x)  # phase_shift.shape [batch_size, num_rays*num_steps, hidden_dim]

        # [batch_size, num_rays*num_steps, hidden_dim], FiLM SIREN: sin(γ(Wx + b) + β)
        return torch.sin(freq * x + phase_shift)
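
A minimal usage sketch of one FiLM layer, assuming the class above (the shapes are illustrative): a batch of sampled 3D points is modulated by per-sample frequencies and phase shifts coming from the mapping network.

import torch

layer = FiLMLayer(input_dim=3, hidden_dim=256)

x = torch.randn(2, 64 * 12, 3)         # [batch, num_rays * num_steps, 3] sampled points
freq = torch.randn(2, 256) * 15 + 30   # per-sample γ, scaled as in TALLSIREN.forward_with_frequencies_phase_shifts
phase = torch.randn(2, 256)            # per-sample β

out = layer(x, freq, phase)            # [2, 768, 256] = sin(γ ⊙ (Wx + b) + β)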
  4. TALLSIREN model: the SIREN MLP; maps the sampled points through the MLP to RGB and density.
class TALLSIREN(nn.Module):
    """Primary SIREN architecture used in pi-GAN generators."""

    def __init__(self, input_dim=2, z_dim=100, hidden_dim=256, output_dim=1, device=None):
        super().__init__()
        self.device = device
        self.input_dim = input_dim    # 3 (x, y, z)
        self.z_dim = z_dim            # 256
        self.hidden_dim = hidden_dim  # 256
        self.output_dim = output_dim  # 4
        # TALLSIREN: input_dim: 3, output_dim: 4
        logging.info(f"TALLSIREN: input_dim.shape: {input_dim}, output_dim.shape: {output_dim}")
        self.network = nn.ModuleList([  # 8 FiLM SIREN layers: [3, 256], [256, 256], ..., [256, 256]
            FiLMLayer(input_dim, hidden_dim),
            FiLMLayer(hidden_dim, hidden_dim),
            FiLMLayer(hidden_dim, hidden_dim),
            FiLMLayer(hidden_dim, hidden_dim),
            FiLMLayer(hidden_dim, hidden_dim),
            FiLMLayer(hidden_dim, hidden_dim),
            FiLMLayer(hidden_dim, hidden_dim),
            FiLMLayer(hidden_dim, hidden_dim),
        ])
        self.final_layer = nn.Linear(hidden_dim, 1)  # [256, 1] sigma (density) output head

        self.color_layer_sine = FiLMLayer(hidden_dim + 3, hidden_dim)  # concatenates the ray direction d: [256+3, 256]
        # c(x, d): [256, 3] RGB output head, a plain linear layer + sigmoid
        self.color_layer_linear = nn.Sequential(nn.Linear(hidden_dim, 3), nn.Sigmoid())

        # mapping network, output_dim = (8+1)*256*2
        self.mapping_network = CustomMappingNetwork(z_dim, 256, (len(self.network) + 1)*hidden_dim*2)

        self.network.apply(frequency_init(25))             # initialize the 8 FiLMLayers
        self.final_layer.apply(frequency_init(25))          # initialize the sigma output head
        self.color_layer_sine.apply(frequency_init(25))     # initialize the extra RGB FiLM layer
        self.color_layer_linear.apply(frequency_init(25))   # initialize the RGB output head
        self.network[0].apply(first_layer_film_sine_init)   # special initialization for the first FiLMLayer

    # input -> transformed_points (generators.py -> coarse_output)
    def forward(self, input, z, ray_directions, **kwargs):
        # get the frequencies γ and phase shifts β from the mapping network
        frequencies, phase_shifts = self.mapping_network(z)
        return self.forward_with_frequencies_phase_shifts(input, frequencies, phase_shifts, ray_directions,
                                                          **kwargs)

    # the SIREN MLP: evaluates sin(γx + β) layer by layer and outputs RGB and sigma
    def forward_with_frequencies_phase_shifts(self, input, frequencies, phase_shifts, ray_directions, **kwargs):

        frequencies = frequencies*15 + 30
        # x.shape [batch_size, num_rays*num_steps, 3], input -> points
        x = input

        for index, layer in enumerate(self.network):  # the 8 hidden layers
            # layer == FiLMLayer(i), FiLM SIREN: sin(γx + β)
            # take one layer's γ and β at a time, (end - start) == 256
            start = index * self.hidden_dim
            end = (index+1) * self.hidden_dim
            # x.shape [batch_size, num_rays*num_steps, hidden_dim]
            x = layer(x, frequencies[..., start:end], phase_shifts[..., start:end])  # -> FiLMLayer.forward
        # after the 8 MLP layers, x has shape [batch_size, num_rays*num_steps, hidden_dim]
        logging.info(f"forward_with_frequencies_phase_shifts after x.shape: {x.shape}")
        sigma = self.final_layer(x)  # sigma [batch_size, num_rays*num_steps, 1], the density head

        # ray_directions d [batch_size, num_rays*num_steps, 3]

        # the last chunk of γ and β; input width 259 -> 256
        rbg = self.color_layer_sine(torch.cat([ray_directions, x], dim=-1),
                                    frequencies[..., -self.hidden_dim:], phase_shifts[..., -self.hidden_dim:])
        rbg = self.color_layer_linear(rbg)  # rgb [batch_size, num_rays*num_steps, 3], output head 256 -> 3

        # RGB and sigma are concatenated and then fed to volume rendering to produce the image
        return torch.cat([rbg, sigma], dim=-1)  # [batch_size, num_rays*num_steps, 4], rgb + sigma
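
A quick shape walkthrough of the whole SIREN backbone, assuming the classes and init helpers from siren.py are importable (the batch and sampling sizes are illustrative):

import torch

siren = TALLSIREN(input_dim=3, z_dim=256, hidden_dim=256, output_dim=4)

batch, num_rays, num_steps = 2, 32 * 32, 12
points = torch.randn(batch, num_rays * num_steps, 3)  # sampled 3D points
z = torch.randn(batch, 256)                           # latent codes
dirs = torch.randn(batch, num_rays * num_steps, 3)    # per-point viewing directions

out = siren(points, z, ray_directions=dirs)           # [2, 12288, 4]
rgb, sigma = out[..., :3], out[..., 3:]               # colors in [0, 1], raw densities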

3. generators.py

  1. ImplicitGenerator3d class: the generator model, used to render the images.
# instantiated with (SIREN, metadata['latent_dim'])
class ImplicitGenerator3d(nn.Module):
    def __init__(self, siren, z_dim, **kwargs):
        super().__init__()
        self.z_dim = z_dim  # 256 (latent_dim)
        self.siren = siren(output_dim=4, z_dim=self.z_dim, input_dim=3, device=None)  # build the SIREN -> siren.py
        self.epoch = 0
        self.step = 0

    def set_device(self, device):
        self.device = device
        self.siren.device = device

        self.generate_avg_frequencies()  # average frequencies and phase shifts, [1, 2304], 2304 = 256 * (8 + 1)


    def forward(self, z, img_size, fov, ray_start, ray_end, num_steps,
                h_stddev, v_stddev, h_mean, v_mean, hierarchical_sample,
                sample_dist=None, lock_view_dependence=False, **kwargs):
        """
        Generates images from a noise vector, rendering parameters, and camera distribution.
        Uses the hierarchical sampling scheme described in NeRF.
        """
        # z.shape: (3, 256) [batch_size / batch_split, z_dim], num_rays = img_size * img_size
        batch_size = z.shape[0]

        # Generate initial camera rays and sample points.
        # returns sample points, z_vals, and ray directions
        with torch.no_grad():
            # sample points, z_vals (depths), and ray directions (the positions x)
            points_cam, z_vals, rays_d_cam = get_initial_rays_trig(
                batch_size, num_steps, resolution=(img_size, img_size), device=self.device,
                fov=fov, ray_start=ray_start, ray_end=ray_end)
            # points_cam.shape: [batch_size, img_size*img_size, num_steps, 3(xyz)]
            # z_vals.shape: [batch_size, img_size*img_size, num_steps, 1]
            # rays_d_cam.shape: [batch_size, img_size*img_size, 3(xyz)]


            # transform_sampled_points samples a camera pose and maps camera-space coordinates to world space
            # returns the transformed sample points, z_vals (depths), ray directions, ray origins, pitch, yaw
            transformed_points, z_vals, transformed_ray_directions, transformed_ray_origins, pitch, yaw = \
                transform_sampled_points(points_cam, z_vals, rays_d_cam, h_stddev=h_stddev, v_stddev=v_stddev,
                                         h_mean=h_mean, v_mean=v_mean, device=self.device, mode=sample_dist)

            # transformed_points.shape: [batch_size, num_rays, num_steps, 3(xyz)]
            # z_vals.shape: [batch_size, num_rays, num_steps, 1]
            # transformed_ray_directions.shape: [batch_size, num_rays, 3(xyz)]
            # transformed_ray_origins.shape: [batch_size, num_rays, 3(xyz)]
            # pitch.shape: [batch_size, 1], yaw.shape: [batch_size, 1]

            # coordinates are now in world space instead of camera space
            # transformed_ray_directions_expanded: ray directions repeated once per sample along each ray

            # [batch_size, num_rays, 1, 3]
            transformed_ray_directions_expanded = torch.unsqueeze(transformed_ray_directions, -2)
            # [batch_size, num_rays, num_steps, 3]
            transformed_ray_directions_expanded = transformed_ray_directions_expanded.expand(-1, -1,
                                                                                             num_steps, -1)
            # [batch_size, num_rays*num_steps, 3]
            transformed_ray_directions_expanded = transformed_ray_directions_expanded.reshape(batch_size, img_size*img_size*num_steps, 3)

            # [batch_size, num_rays*num_steps, 3]
            transformed_points = transformed_points.reshape(batch_size, img_size*img_size*num_steps, 3)

            if lock_view_dependence:
                transformed_ray_directions_expanded = torch.zeros_like(transformed_ray_directions_expanded)
                transformed_ray_directions_expanded[..., -1] = -1

        # Model prediction on coarse points: run the SIREN MLP on the coarse samples
        """
        input:  transformed_points [batch_size, num_rays*num_steps, 3(xyz)], z [batch_size, 256],
                ray_directions [batch_size, num_rays*num_steps, 3(xyz)]
        output: rgb + sigma [batch_size, num_rays*num_steps, 4]
        """
        # -> siren.py, TALLSIREN.forward
        coarse_output = self.siren(transformed_points, z, ray_directions=transformed_ray_directions_expanded)

        # [batch_size, num_rays, num_steps, 4]: one prediction per sample point on every ray
        coarse_output = coarse_output.reshape(batch_size, img_size * img_size, num_steps, 4)

        # Re-sample fine points along the camera rays, as described in NeRF
        if hierarchical_sample:  # hierarchical (importance) sampling
            with torch.no_grad():
                # one sample point per step on every ray
                transformed_points = transformed_points.reshape(batch_size, img_size * img_size, num_steps, 3)
                # get the weights from fancy_integration; they drive the importance (fine) sampling
                _, _, weights = fancy_integration(coarse_output, z_vals, device=self.device,
                                                  clamp_mode=kwargs['clamp_mode'],
                                                  noise_std=kwargs['nerf_noise'])
                weights = weights.reshape(batch_size * img_size * img_size, num_steps) + 1e-5

                #### Start new importance sampling
                z_vals = z_vals.reshape(batch_size * img_size * img_size, num_steps)
                z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])
                z_vals = z_vals.reshape(batch_size, img_size * img_size, num_steps, 1)

                fine_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1],
                                         num_steps, det=False).detach()
                fine_z_vals = fine_z_vals.reshape(batch_size, img_size * img_size, num_steps, 1)

                # fine_points.shape [batch_size, num_rays, num_steps, 3]: the fine-network sample points
                fine_points = transformed_ray_origins.unsqueeze(2).contiguous() + \
                              transformed_ray_directions.unsqueeze(2).contiguous() * \
                              fine_z_vals.expand(-1, -1, -1, 3).contiguous()

                # [batch_size, num_rays*num_steps, 3]
                fine_points = fine_points.reshape(batch_size, img_size*img_size*num_steps, 3)


                if lock_view_dependence:
                    transformed_ray_directions_expanded = torch.zeros_like(transformed_ray_directions_expanded)
                    transformed_ray_directions_expanded[..., -1] = -1
                #### end new importance sampling

            # Model prediction on the re-sampled fine points
            """
            input:  fine_points [batch_size, num_rays*num_steps, 3], z [batch_size, 256],
                    ray_directions [batch_size, num_rays*num_steps, 3]
            output: fine_output [batch_size, num_rays, num_steps, 4] (rgb + sigma)
            """
            fine_output = self.siren(fine_points, z, ray_directions=transformed_ray_directions_expanded)
            # [batch_size, num_rays, num_steps, 4]
            fine_output = fine_output.reshape(batch_size, img_size * img_size, num_steps, 4)

            # Combine coarse and fine points; final outputs: all_z_vals, all_outputs

            # [batch_size, num_rays, num_steps*2, 4]
            all_outputs = torch.cat([fine_output, coarse_output], dim=-2)

            all_z_vals = torch.cat([fine_z_vals, z_vals], dim=-2)  # [batch_size, num_rays, num_steps*2, 1]

            _, indices = torch.sort(all_z_vals, dim=-2)


            all_z_vals = torch.gather(all_z_vals, -2, indices)  # [batch_size, num_rays, num_steps*2, 1]

            # [batch_size, num_rays, num_steps*2, 4]
            all_outputs = torch.gather(all_outputs, -2, indices.expand(-1, -1, -1, 4))
        else:
            all_outputs = coarse_output

            all_z_vals_film = z_vals  # leftover from a FiLM experiment, unused here
            all_z_vals = z_vals


        # Create images with NeRF-style volume rendering
        # outputs: rgb [batch_size, num_rays, 3], depth [batch_size, num_rays, 1],
        #          weights [batch_size, num_rays, num_steps*2, 1]

        pixels, depth, weights = fancy_integration(all_outputs, all_z_vals, device=self.device,
                                                   white_back=kwargs.get('white_back', False),
                                                   last_back=kwargs.get('last_back', False),
                                                   clamp_mode=kwargs['clamp_mode'],
                                                   noise_std=kwargs['nerf_noise'])
        # reshape back into an image: pixels.shape [batch_size, img_size, img_size, 3]

        pixels = pixels.reshape((batch_size, img_size, img_size, 3))
        pixels = pixels.permute(0, 3, 1, 2).contiguous() * 2 - 1  # [batch_size, 3, img_size, img_size], scaled to [-1, 1]

        logging.info(f"generators forward pixels.shape: {pixels.shape}")
        # pixels.shape: [batch_size, 3, img_size, img_size], pitch.shape: [batch_size, 1], yaw.shape: [batch_size, 1]
        return pixels, torch.cat([pitch, yaw], -1)
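
Putting the pieces together, a minimal call sketch for the generator (the camera and rendering parameters below are illustrative rather than the exact pi-GAN curriculum, and generate_avg_frequencies is assumed to exist as in the full class):

import math
import torch

gen = ImplicitGenerator3d(TALLSIREN, z_dim=256)
gen.set_device('cpu')

z = torch.randn(2, 256)
imgs, poses = gen(z, img_size=32, fov=12, ray_start=0.88, ray_end=1.12, num_steps=12,
                  h_stddev=0.3, v_stddev=0.155, h_mean=math.pi * 0.5, v_mean=math.pi * 0.5,
                  hierarchical_sample=True, sample_dist='gaussian',
                  clamp_mode='relu', nerf_noise=0.0)
# imgs: [2, 3, 32, 32] in [-1, 1]; poses: [2, 2] = (pitch, yaw)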

4. discriminators.py

  1. AddCoords: appends extra coordinate channels to the input feature map.
# appends coordinate channels to a given feature map
class AddCoords(nn.Module):
    """
    Source: https://github.com/mkocabas/CoordConv-pytorch/blob/master/CoordConv.py
    """

    def __init__(self, with_r=False):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):  # <-- CoordConv
        """
        Args:
            input_tensor --> x, shape (batch, channel, img_size, img_size)
        """
        logging.info(f"AddCoords input_tensor.shape: {input_tensor.shape}")
        batch_size, _, x_dim, y_dim = input_tensor.size()

        # build the x / y coordinate grids
        xx_channel = torch.arange(x_dim).repeat(1, y_dim, 1)                  # [1, img_size, img_size]
        yy_channel = torch.arange(y_dim).repeat(1, x_dim, 1).transpose(1, 2)  # [1, img_size, img_size]

        # normalize to [0, 1]
        xx_channel = xx_channel.float() / (x_dim - 1)
        yy_channel = yy_channel.float() / (y_dim - 1)

        # xx_channel: shape (1, img_size, img_size)
        # map to [-1, 1]
        xx_channel = xx_channel * 2 - 1
        yy_channel = yy_channel * 2 - 1

        # shape: (batch_size, 1, img_size, img_size)
        xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)
        yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3)

        ret = torch.cat([
            input_tensor,
            xx_channel.type_as(input_tensor),
            yy_channel.type_as(input_tensor)], dim=1)

        if self.with_r:
            rr = torch.sqrt(torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) +
                            torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2))
            ret = torch.cat([ret, rr], dim=1)

        # ret.shape: [batch_size, channel + 2 (the x and y grids), img_size, img_size]
        return ret
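
A typical CoordConv usage sketch, assuming the AddCoords module above: append the two coordinate channels, then convolve (the convolution sizes are illustrative).

import torch
import torch.nn as nn

add_coords = AddCoords(with_r=False)
conv = nn.Conv2d(in_channels=3 + 2, out_channels=16, kernel_size=3, padding=1)

x = torch.randn(4, 3, 64, 64)  # an RGB input
x = add_coords(x)              # [4, 5, 64, 64]: original channels + xx + yy grids
x = conv(x)                    # [4, 16, 64, 64]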
  2. ProgressiveDiscriminator: the progressively growing discriminator.
# progressively growing discriminator
class ProgressiveDiscriminator(nn.Module):
    """Implementation of a progressive growing discriminator with ResidualCoordConv blocks."""
    # each time a new resolution level is reached, another ResidualCoordConvBlock comes into use

    def __init__(self, **kwargs):
        super().__init__()
        self.epoch = 0
        self.step = 0
        self.layers = nn.ModuleList(
            [
                # inplanes, planes, downsample
                ResidualCoordConvBlock(16, 32, downsample=True),    # 512x512 -> 256x256 (each block contains two CoordConvs)
                ResidualCoordConvBlock(32, 64, downsample=True),    # 256x256 -> 128x128
                ResidualCoordConvBlock(64, 128, downsample=True),   # 128x128 -> 64x64
                ResidualCoordConvBlock(128, 256, downsample=True),  # 64x64 -> 32x32
                ResidualCoordConvBlock(256, 400, downsample=True),  # 32x32 -> 16x16
                ResidualCoordConvBlock(400, 400, downsample=True),  # 16x16 -> 8x8
                ResidualCoordConvBlock(400, 400, downsample=True),  # 8x8 -> 4x4
                ResidualCoordConvBlock(400, 400, downsample=True),  # 4x4 -> 2x2
            ])

        self.fromRGB = nn.ModuleList(
            [
                # output channels
                AdapterBlock(16),
                AdapterBlock(32),
                AdapterBlock(64),
                AdapterBlock(128),
                AdapterBlock(256),
                AdapterBlock(400),
                AdapterBlock(400),
                AdapterBlock(400),
                AdapterBlock(400)
            ])
        self.final_layer = nn.Conv2d(400, 1, 2)
        self.img_size_to_layer = {2:8, 4:7, 8:6, 16:5, 32:4, 64:3, 128:2, 256:1, 512:0}


    def forward(self, input, alpha, instance_noise=0, **kwargs):
        start = self.img_size_to_layer[input.shape[-1]]
        logging.info(f"ProgressiveDiscriminator input.shape: {input.shape}")
        x = self.fromRGB[start](input)
        logging.info(f"ProgressiveDiscriminator x.shape: {x.shape}")
        for i, layer in enumerate(self.layers[start:]):
            if i == 1:
                # fade in the new resolution level: blend the current features with the
                # downsampled input routed through the previous fromRGB adapter
                x = alpha * x + (1 - alpha) * self.fromRGB[start+1](F.interpolate(input, scale_factor=0.5,
                                                                                  mode='nearest'))
            x = layer(x)

        x = self.final_layer(x).reshape(x.shape[0], 1)
        logging.info(f"ProgressiveDiscriminator x_output.shape: {x.shape}")

        return x
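
The alpha argument is the fade-in weight for the newest resolution level: it ramps from 0 to 1 after each resolution jump, so the blend term (1 - alpha) * fromRGB[start+1](downsampled input) gradually disappears. A sketch of such a schedule (the variable names are illustrative, not the exact ones used in train.py):

steps_since_upsample = 1500   # steps trained at the current resolution
fade_steps = 10000            # how long the fade-in should take

alpha = min(1.0, steps_since_upsample / fade_steps)  # 0.15: still mostly the old low-res path
# once alpha reaches 1, only the full-resolution path through fromRGB[start] is used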

5. volumetric_rendering.py

  1. fancy_integration(): integrates the predicted rgb/sigma values along each ray into the final image.
def fancy_integration(rgb_sigma, z_vals, device, noise_std=0.5, last_back=False, white_back=False,
                      clamp_mode=None, fill_mode=None):
    """Performs NeRF volumetric rendering."""
    # rgb_sigma.shape: [batch_size, num_rays, num_steps, 4]
    # z_vals.shape: [batch_size, num_rays, num_steps, 1]

    # rgb_sigma is produced by the SIREN network
    rgbs = rgb_sigma[..., :3]    # rgb [batch_size, num_rays, num_steps, 3]
    sigmas = rgb_sigma[..., 3:]  # sigma [batch_size, num_rays, num_steps, 1]

    # deltas: distance between consecutive sample points along each ray
    deltas = z_vals[:, :, 1:] - z_vals[:, :, :-1]
    delta_inf = 1e10 * torch.ones_like(deltas[:, :, :1])  # the far plane, effectively at infinity
    deltas = torch.cat([deltas, delta_inf], -2)

    noise = torch.randn(sigmas.shape, device=device) * noise_std  # random noise [batch_size, num_rays, num_steps, 1]

    # compute the alphas
    if clamp_mode == 'softplus':
        alphas = 1 - torch.exp(-deltas * (F.softplus(sigmas + noise)))
    elif clamp_mode == 'relu':
        alphas = 1 - torch.exp(-deltas * (F.relu(sigmas + noise)))
    else:
        raise ValueError("Need to choose clamp mode")

    alphas_shifted = torch.cat([torch.ones_like(alphas[:, :, :1]), 1 - alphas + 1e-10], -2)

    # NeRF transmittance: weights = alphas * T_i (the volume rendering equation)
    weights = alphas * torch.cumprod(alphas_shifted, -2)[:, :, :-1]  # [batch_size, num_rays, num_steps, 1]
    weights_sum = weights.sum(2)  # [batch_size, num_rays, 1], total weight of each ray

    if last_back:
        weights[:, :, -1] += (1 - weights_sum)

    rgb_final = torch.sum(weights * rgbs, -2)      # [batch_size, num_rays, 3], the final predicted rgb
    depth_final = torch.sum(weights * z_vals, -2)  # [batch_size, num_rays, 1], the final predicted depth

    if white_back:
        rgb_final = rgb_final + 1 - weights_sum

    if fill_mode == 'debug':
        rgb_final[weights_sum.squeeze(-1) < 0.9] = torch.tensor([1., 0, 0], device=rgb_final.device)
    elif fill_mode == 'weight':
        rgb_final = weights_sum.expand_as(rgb_final)

    logging.info(f"fancy_integration output rgb_final.shape: {rgb_final.shape}")

    # the final predicted rgb values form the rendered image
    return rgb_final, depth_final, weights
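
In equation form, the block above is the discrete NeRF rendering quadrature (delta_i = deltas, sigma_i = sigmas, c_i = rgbs):

    alpha_i = 1 - exp(-sigma_i * delta_i)
    T_i     = prod_{j < i} (1 - alpha_j)        # the cumprod over alphas_shifted
    w_i     = T_i * alpha_i
    rgb_final   = sum_i w_i * c_i
    depth_final = sum_i w_i * z_i
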
  2. get_initial_rays_trig(): returns the sample points, sample depths (z_vals), and camera-space ray directions.
def get_initial_rays_trig(n, num_steps, device, fov, resolution, ray_start, ray_end):
    """Returns sample points, z_vals, and ray directions in camera space."""
    # ray_start: the near plane
    # ray_end:   the far plane

    W, H = resolution  # (img_size, img_size)
    # Create full screen NDC (-1 to +1) coords [x, y, 0, 1].
    # Y is flipped to follow image memory layouts.
    x, y = torch.meshgrid(torch.linspace(-1, 1, W, device=device),  # (W, H)
                          torch.linspace(1, -1, H, device=device))
    x = x.T.flatten()  # (H*W,) flattened
    y = y.T.flatten()
    z = -torch.ones_like(x, device=device) / np.tan((2 * math.pi * fov / 360) / 2)  # (H*W,) perspective projection

    # ray directions
    rays_d_cam = normalize_vecs(torch.stack([x, y, z], -1))  # (H*W, 3)

    # (H*W, num_steps, 1)
    z_vals = torch.linspace(ray_start, ray_end, num_steps, device=device).reshape(1, num_steps, 1).repeat(W*H, 1, 1)
    points = rays_d_cam.unsqueeze(1).repeat(1, num_steps, 1) * z_vals  # (H*W, num_steps, 3)

    points = torch.stack(n*[points])  # (n, H*W, num_steps, 3), n --> batch_size // batch_split
    z_vals = torch.stack(n*[z_vals])
    rays_d_cam = torch.stack(n*[rays_d_cam]).to(device)  # (n, H*W, 3)

    logging.info(f"get_initial_rays_trig's points.shape: {points.shape}, rays_d_cam.shape: {rays_d_cam.shape}")

    return points, z_vals, rays_d_cam  # TODO: debug these dimensions
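
For intuition, the z component above is the negative focal length implied by the field of view; a quick check with a hypothetical 12° fov:

import math

fov = 12                                     # degrees, full field of view
focal = 1 / math.tan(math.radians(fov) / 2)  # ≈ 9.514
# every ray direction is roughly (x, y, -9.514) before normalization,
# i.e. a narrow-fov pinhole camera looking down the -z axis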
  3. transform_sampled_points(): samples a camera position and maps the points from camera space to world space.
def transform_sampled_points(points, z_vals, ray_directions, device, h_stddev=1, v_stddev=1, h_mean=math.pi * 0.5,
                             v_mean=math.pi * 0.5, mode='normal'):
    """Samples a camera position and maps points in camera space to world space."""
    # n --> batch_size, num_rays --> H*W (pixels), num_steps --> num_samples
    n, num_rays, num_steps, channels = points.shape  # input points.shape: [batch_size, num_rays, num_steps, 3]

    # TODO: the points's dims
    points, z_vals = perturb_points(points, z_vals, ray_directions, device)

    # sample the camera origin, pitch, and yaw: camera_origin.shape [batch_size, 3],
    # pitch.shape [batch_size, 1], yaw.shape [batch_size, 1]
    camera_origin, pitch, yaw = sample_camera_positions(n=points.shape[0], r=1, horizontal_stddev=h_stddev,
                                                        vertical_stddev=v_stddev,
                                                        horizontal_mean=h_mean, vertical_mean=v_mean,
                                                        device=device, mode=mode)
    forward_vector = normalize_vecs(-camera_origin)

    cam2world_matrix = create_cam2world_matrix(forward_vector, camera_origin, device=device)

    points_homogeneous = torch.ones((points.shape[0], points.shape[1], points.shape[2], points.shape[3] + 1),
                                    device=device)
    points_homogeneous[:, :, :, :3] = points  # (n, num_rays, num_steps, 4)

    # should be n x 4 x 4 times n x r^2 x num_steps x 4 (the sample points)
    transformed_points = torch.bmm(cam2world_matrix,  # (n, num_rays, num_steps, 4)
                                   points_homogeneous.reshape(n, -1, 4).permute(0, 2, 1)) \
        .permute(0, 2, 1).reshape(n, num_rays, num_steps, 4)

    # ray directions: no homogeneous coordinates needed, since directions are translation-invariant
    transformed_ray_directions = torch.bmm(cam2world_matrix[..., :3, :3],  # (n, num_rays, 3(x,y,z))
                                           ray_directions.reshape(n, -1, 3).permute(0, 2, 1)) \
        .permute(0, 2, 1).reshape(n, num_rays, 3)

    # the origins do need the translation, so they go through homogeneous coordinates before cam2world
    homogeneous_origins = torch.zeros((n, 4, num_rays), device=device)  # (n, 4, num_rays)
    homogeneous_origins[:, 3, :] = 1
    transformed_ray_origins = torch.bmm(cam2world_matrix, homogeneous_origins).permute(0, 2, 1) \
        .reshape(n, num_rays, 4)[..., :3]  # (n, num_rays, 3(x,y,z))

    # return the transformed sample points, depths, ray directions, ray origins, pitch, and yaw
    return transformed_points[..., :3], z_vals, transformed_ray_directions, transformed_ray_origins, pitch, yaw
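
The homogeneous-coordinate handling above is the standard one: directions only use the 3x3 rotation block of cam2world, while points and ray origins also pick up the translation. A tiny numeric sketch with a hypothetical cam2world matrix:

import torch

# camera at (0, 0, 1), no rotation, purely for illustration
cam2world = torch.eye(4)
cam2world[:3, 3] = torch.tensor([0., 0., 1.])

direction = torch.tensor([0., 0., -1., 0.])  # w = 0: rotation only -> stays (0, 0, -1)
origin = torch.tensor([0., 0., 0., 1.])      # w = 1: picks up the translation -> (0, 0, 1)

print(cam2world @ direction)  # tensor([ 0.,  0., -1.,  0.])
print(cam2world @ origin)     # tensor([0., 0., 1., 1.])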

👉 Detailed code and annotations: https://github.com/SeulQxQ/pi-GAN-read


pi-GAN源码分析
http://seulqxq.top/posts/27470/
作者
SeulQxQ
发布于
2024年1月3日
许可协议