EpiGRAF Source Code Analysis

1 Generation

1. Mapping Network

@persistence.persistent_class
class MappingNetwork(torch.nn.Module):
    def __init__(self,
        z_dim,                        # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                        # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                        # Intermediate latent (W) dimensionality.
        num_ws,                       # Number of intermediate latents to output, None = do not broadcast.
        num_layers = 2,               # Number of mapping layers.
        embed_features = None,        # Label embedding dimensionality, None = same as w_dim.
        layer_features = None,        # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation = 'lrelu',         # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier = 0.01,         # Learning rate multiplier for the mapping layers.
        w_avg_beta = 0.998,           # Decay for tracking the moving average of W during training, None = do not track.
        camera_cond = False,          # Camera conditioning.
        camera_raw_scalars = False,   # Use raw camera angles as input, or preprocess them with Fourier features?
        camera_cond_drop_p = 0.0,     # Camera conditioning dropout.
        camera_cond_noise_std = 0.0,  # Camera conditioning noise std.
        mean_camera_pose = None,      # Average camera pose for use at test time.
    ):
        super().__init__()
        if camera_cond:
            if camera_raw_scalars:
                self.camera_scalar_enc = ScalarEncoder1d(coord_dim=2, x_multiplier=0.0, const_emb_dim=0, use_raw=True)
            else:
                self.camera_scalar_enc = ScalarEncoder1d(coord_dim=2, x_multiplier=64.0, const_emb_dim=0)
            c_dim = c_dim + self.camera_scalar_enc.get_dim()
            assert self.camera_scalar_enc.get_dim() > 0
        else:
            self.camera_scalar_enc = None

        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta
        self.camera_cond_drop_p = camera_cond_drop_p
        self.camera_cond_noise_std = camera_cond_noise_std

        if embed_features is None:
            embed_features = w_dim
        if self.c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if self.c_dim > 0:
            self.embed = FullyConnectedLayer(self.c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

        if mean_camera_pose is not None:
            self.register_buffer('mean_camera_pose', mean_camera_pose)
        else:
            self.mean_camera_pose = None

    def forward(self, z, c, camera_angles: torch.Tensor=None, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        if (self.camera_scalar_enc is not None) and (not self.training) and (camera_angles is None):
            camera_angles = self.mean_camera_pose.unsqueeze(0).repeat(len(z), 1) # [batch_size, 3]

        if self.camera_scalar_enc is not None:
            # Using only yaw and pitch for conditioning (roll is always zero).
            camera_angles = camera_angles[:, [0, 1]] # [batch_size, 2]
            if self.training and self.camera_cond_noise_std > 0:
                camera_angles = camera_angles + self.camera_cond_noise_std * torch.randn_like(camera_angles) * camera_angles.std(dim=0, keepdim=True) # [batch_size, 2]
            camera_angles = camera_angles.sign() * ((camera_angles.abs() % (2.0 * np.pi)) / (2.0 * np.pi)) # [batch_size, 2]
            camera_angles_embs = self.camera_scalar_enc(camera_angles) # [batch_size, fourier_dim]
            camera_angles_embs = F.dropout(camera_angles_embs, p=self.camera_cond_drop_p, training=self.training) # [batch_size, fourier_dim]
            c = torch.zeros(len(camera_angles_embs), 0, device=camera_angles_embs.device) if c is None else c # [batch_size, c_dim]
            c = torch.cat([c, camera_angles_embs], dim=1) # [batch_size, c_dim + angle_emb_dim]

        # Embed, normalize, and concat inputs.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                x = normalize_2nd_moment(z.to(torch.float32))
            if self.c_dim > 0:
                misc.assert_shape(c, [None, self.c_dim])
                y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update moving average of W.
        if update_emas and self.w_avg_beta is not None:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x

    def extra_repr(self):
        return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
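
Compared to the vanilla StyleGAN2 mapping network, the main addition is camera conditioning: when camera_cond is on, yaw and pitch (roll is always zero) are optionally jittered with noise during training, wrapped into (-1, 1) via sign() * ((abs % 2π) / 2π), Fourier-encoded by ScalarEncoder1d, optionally dropped out with probability camera_cond_drop_p, and concatenated onto the label c. At test time, mean_camera_pose is substituted when no angles are given. Below is a minimal self-contained sketch of that angle preprocessing; it is not the repo's ScalarEncoder1d, and num_freqs=8 is a hypothetical choice for illustration:

import numpy as np
import torch

# A minimal sketch (not the repo's ScalarEncoder1d) of how camera angles are
# prepared for conditioning: yaw/pitch are wrapped into (-1, 1) and expanded
# into sin/cos Fourier features. num_freqs=8 is a hypothetical choice.
def fourier_encode_angles(camera_angles: torch.Tensor, num_freqs: int = 8) -> torch.Tensor:
    # camera_angles: [batch_size, 2] --- yaw and pitch in radians
    a = camera_angles.sign() * ((camera_angles.abs() % (2.0 * np.pi)) / (2.0 * np.pi))  # [batch_size, 2]
    freqs = 2.0 ** torch.arange(num_freqs, dtype=a.dtype)  # [num_freqs]
    args = np.pi * a.unsqueeze(-1) * freqs  # [batch_size, 2, num_freqs]
    return torch.cat([args.sin(), args.cos()], dim=-1).flatten(1)  # [batch_size, 4 * num_freqs]

angles = torch.tensor([[0.3, -1.2], [2.5, 0.0]])  # dummy yaw/pitch
print(fourier_encode_angles(angles).shape)  # torch.Size([2, 32])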

2. Synthesis Network

@persistence.persistent_class
class SynthesisBlocksSequence(torch.nn.Module):
    def __init__(self,
        w_dim,                  # Intermediate latent (W) dimensionality.
        in_resolution,          # Which resolution do we start with?
        out_resolution,         # Output image resolution.
        in_channels,            # Number of input channels.
        out_channels,           # Number of output channels.
        channel_base = 32768,   # Overall multiplier for the number of channels.
        channel_max = 512,      # Maximum number of channels in any layer.
        num_fp16_res = 4,       # Use FP16 for the N highest resolutions.
        **block_kwargs,         # Arguments for SynthesisBlock.
    ):
        assert in_resolution == 0 or (in_resolution >= 4 and math.log2(in_resolution).is_integer())
        assert out_resolution >= 4 and math.log2(out_resolution).is_integer()
        assert in_resolution < out_resolution

        super().__init__()

        self.w_dim = w_dim
        self.out_resolution = out_resolution
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_fp16_res = num_fp16_res

        in_resolution_log2 = 2 if in_resolution == 0 else (int(np.log2(in_resolution)) + 1)
        out_resolution_log2 = int(np.log2(out_resolution))
        self.block_resolutions = [2 ** i for i in range(in_resolution_log2, out_resolution_log2 + 1)]
        out_channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
        fp16_resolution = max(2 ** (out_resolution_log2 + 1 - num_fp16_res), 8)

        self.num_ws = 0
        for block_idx, res in enumerate(self.block_resolutions):
            cur_in_channels = out_channels_dict[res // 2] if block_idx > 0 else in_channels
            cur_out_channels = out_channels_dict[res]
            use_fp16 = (res >= fp16_resolution)
            is_last = (res == self.out_resolution)
            block = SynthesisBlock(cur_in_channels, cur_out_channels, w_dim=w_dim, resolution=res,
                img_channels=self.out_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
            self.num_ws += block.num_conv
            if is_last:
                self.num_ws += block.num_torgb
            setattr(self, f'b{res}', block)

    def forward(self, ws, x: torch.Tensor=None, **block_kwargs):
        block_ws = []
        with torch.autograd.profiler.record_function('split_ws'):
            misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
            ws = ws.to(torch.float32)
            w_idx = 0
            for res in self.block_resolutions:
                block = getattr(self, f'b{res}')
                block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
                w_idx += block.num_conv

        img = None
        for res, cur_ws in zip(self.block_resolutions, block_ws):
            block = getattr(self, f'b{res}')
            x, img = block(x, img, cur_ws, **block_kwargs)
        return img
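
This is essentially the StyleGAN2 synthesis trunk, generalized to start from an arbitrary in_resolution (0 presumably meaning there is no input tensor and synthesis starts from the first 4x4 block) and to emit out_channels feature maps rather than RGB; EpiGRAF uses it to produce the tri-plane features. The constructor also counts num_ws, the number of per-layer latents the mapping network must supply. A quick sketch of the resolution/channel schedule, reusing the same formulas as the constructor above:

import numpy as np

# A quick sketch of the schedule the constructor above builds, reusing its
# formulas (channel_base=32768, channel_max=512 are the defaults above).
def block_schedule(in_resolution: int, out_resolution: int,
                   channel_base: int = 32768, channel_max: int = 512):
    in_log2 = 2 if in_resolution == 0 else int(np.log2(in_resolution)) + 1
    out_log2 = int(np.log2(out_resolution))
    resolutions = [2 ** i for i in range(in_log2, out_log2 + 1)]
    channels = {res: min(channel_base // res, channel_max) for res in resolutions}
    return resolutions, channels

# Building a 256x256 output from scratch (in_resolution=0):
print(block_schedule(0, 256))
# ([4, 8, 16, 32, 64, 128, 256],
#  {4: 512, 8: 512, 16: 512, 32: 512, 64: 512, 128: 256, 256: 128})

So a 256x256 output built from scratch uses seven blocks, with channel counts capped at channel_max = 512 and halving at each resolution doubling once channel_base / res falls below the cap.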

3. tri_plane_renderer && TriPlaneMLP

Tri-plane rendering, plus the MLP network that decodes tri-plane features.

def tri_plane_renderer(x: torch.Tensor, coords: torch.Tensor, ray_d_world: torch.Tensor, mlp: Callable, scale: float=1.0) -> torch.Tensor:
    """
    Computes RGB\sigma values from a tri-plane representation + MLP

    x: [batch_size, feat_dim * 3, h, w]
    coords: [batch_size, h * w, num_steps, 3]
    ray_d_world: [batch_size, h * w, 3] --- ray directions in the world coordinate system
    mlp: additional transform to apply on top of features
    scale: additional scaling of the coordinates
    """
    assert x.shape[1] % 3 == 0, f"We use 3 planes: {x.shape}"
    coords = coords.view(coords.shape[0], -1, 3) # [batch_size, h * w * num_steps, 3]
    batch_size, raw_feat_dim, h, w = x.shape
    num_points = coords.shape[1]
    feat_dim = raw_feat_dim // 3
    misc.assert_shape(coords, [batch_size, None, 3])

    x = x.view(batch_size * 3, feat_dim, h, w) # [batch_size * 3, feat_dim, h, w]
    coords = coords / scale # [batch_size, num_points, 3]
    coords_2d = torch.stack([
        coords[..., [0, 1]], # z/y plane
        coords[..., [0, 2]], # z/x plane
        coords[..., [1, 2]], # y/x plane
    ], dim=1) # [batch_size, 3, num_points, 2]
    coords_2d = coords_2d.view(batch_size * 3, 1, num_points, 2) # [batch_size * 3, 1, num_points, 2]
    # assert ((coords_2d.min().item() >= -1.0 - 1e-8) and (coords_2d.max().item() <= 1.0 + 1e-8))
    x = F.grid_sample(x, grid=coords_2d, mode='bilinear', align_corners=True).view(batch_size, 3, feat_dim, num_points) # [batch_size, 3, feat_dim, num_points]
    x = x.permute(0, 1, 3, 2) # [batch_size, 3, num_points, feat_dim]
    x = mlp(x, coords, ray_d_world) # [batch_size, num_points, out_dim]

    return x
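
The renderer flattens all ray samples into one point list, projects every 3D coordinate onto three axis-aligned 2D planes, bilinearly samples a feature vector from each plane with F.grid_sample, and hands the three per-point features (plus coordinates and ray directions) to the MLP. A self-contained shape walk-through of just the projection-and-sampling step, on dummy tensors:

import torch
import torch.nn.functional as F

# Self-contained sketch of the tri-plane lookup on dummy tensors: project each
# 3D point onto the three axis-aligned planes and bilinearly sample a feature
# vector from each one. All sizes below are arbitrary.
batch_size, feat_dim, h, w = 2, 16, 32, 32
num_points = 768  # e.g. 64 rays * 12 steps, already flattened

planes = torch.randn(batch_size * 3, feat_dim, h, w)    # 3 feature planes per sample
coords = torch.rand(batch_size, num_points, 3) * 2 - 1  # points in [-1, 1]^3
coords_2d = torch.stack([
    coords[..., [0, 1]],  # first plane
    coords[..., [0, 2]],  # second plane
    coords[..., [1, 2]],  # third plane
], dim=1)                                                     # [batch_size, 3, num_points, 2]
coords_2d = coords_2d.view(batch_size * 3, 1, num_points, 2)  # grid_sample expects a 2D grid
feats = F.grid_sample(planes, coords_2d, mode='bilinear', align_corners=True)  # [batch_size * 3, feat_dim, 1, num_points]
feats = feats.view(batch_size, 3, feat_dim, num_points).permute(0, 1, 3, 2)
print(feats.shape)  # torch.Size([2, 3, 768, 16]) --- one feature per plane per point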
class TriPlaneMLP(nn.Module):
    def __init__(self, cfg: DictConfig, out_dim: int):
        super().__init__()
        self.cfg = cfg
        self.out_dim = out_dim

        if self.cfg.tri_plane.mlp.n_layers == 0:
            assert self.cfg.tri_plane.feat_dim == (self.out_dim + 1), f"Wrong dims: {self.cfg.tri_plane.feat_dim}, {self.out_dim}"
            self.model = nn.Identity()
        else:
            if self.cfg.tri_plane.get('posenc_period_len', 0) > 0:
                self.pos_enc = ScalarEncoder1d(3, x_multiplier=self.cfg.tri_plane.posenc_period_len, const_emb_dim=0)
            else:
                self.pos_enc = None

            backbone_input_dim = self.cfg.tri_plane.feat_dim + (0 if self.pos_enc is None else self.pos_enc.get_dim())
            backbone_out_dim = 1 + (self.cfg.tri_plane.mlp.hid_dim if self.cfg.tri_plane.has_view_cond else self.out_dim)
            self.dims = [backbone_input_dim] + [self.cfg.tri_plane.mlp.hid_dim] * (self.cfg.tri_plane.mlp.n_layers - 1) + [backbone_out_dim] # (n_hid_layers + 2)
            activations = ['lrelu'] * (len(self.dims) - 2) + ['linear']
            assert len(self.dims) > 2, f"We can't have just a linear layer here: nothing to modulate. Dims: {self.dims}"
            layers = [FullyConnectedLayer(self.dims[i], self.dims[i + 1], activation=a) for i, a in enumerate(activations)]
            self.model = nn.Sequential(*layers)

            if self.cfg.tri_plane.has_view_cond:
                self.ray_dir_enc = ScalarEncoder1d(coord_dim=3, const_emb_dim=0, x_multiplier=8, use_cos=False, use_raw=True)
                self.color_network = nn.Sequential(
                    FullyConnectedLayer(backbone_out_dim - 1 + self.ray_dir_enc.get_dim(), 32, activation='lrelu'),
                    FullyConnectedLayer(32, self.out_dim, activation='linear'),
                )
            else:
                self.ray_dir_enc = None
                self.color_network = None

    def forward(self, x: torch.Tensor, coords: torch.Tensor, ray_d_world: torch.Tensor) -> torch.Tensor:
        """
        Params:
            x: [batch_size, 3, num_points, feat_dim] --- volumetric features from tri-planes
            coords: [batch_size, num_points, 3] --- coordinates, assumed to be in [-1, 1]
            ray_d_world: [batch_size, h * w, 3] --- camera ray's view directions
        """
        batch_size, _, num_points, feat_dim = x.shape
        x = x.mean(dim=1).reshape(batch_size * num_points, feat_dim) # [batch_size * num_points, feat_dim]
        if self.pos_enc is not None:
            misc.assert_shape(coords, [batch_size, num_points, 3])
            pos_embs = self.pos_enc(coords.reshape(batch_size * num_points, 3)) # [batch_size * num_points, pos_emb_dim]
            x = torch.cat([x, pos_embs], dim=1) # [batch_size * num_points, feat_dim + pos_emb_dim]
        x = self.model(x) # [batch_size * num_points, backbone_out_dim]
        x = x.view(batch_size, num_points, self.dims[-1]) # [batch_size, num_points, backbone_out_dim]

        if self.color_network is not None:
            num_pixels, view_dir_emb = ray_d_world.shape[1], self.ray_dir_enc.get_dim()
            num_steps = num_points // num_pixels
            ray_dir_embs = self.ray_dir_enc(ray_d_world.reshape(-1, 3)) # [batch_size * h * w, view_dir_emb]
            ray_dir_embs = ray_dir_embs.reshape(batch_size, num_pixels, 1, view_dir_emb) # [batch_size, h * w, 1, view_dir_emb]
            ray_dir_embs = ray_dir_embs.repeat(1, 1, num_steps, 1) # [batch_size, h * w, num_steps, view_dir_emb]
            ray_dir_embs = ray_dir_embs.reshape(batch_size, num_points, view_dir_emb) # [batch_size, h * w * num_steps, view_dir_emb]
            density = x[:, :, [-1]] # [batch_size, num_points, 1]
            color_feats = F.leaky_relu(x[:, :, :-1], negative_slope=0.1) # [batch_size, num_points, backbone_out_dim - 1]
            color_feats = torch.cat([color_feats, ray_dir_embs], dim=2) # [batch_size, num_points, backbone_out_dim - 1 + view_dir_emb]
            color_feats = color_feats.view(batch_size * num_points, self.dims[-1] - 1 + view_dir_emb) # [batch_size * num_points, backbone_out_dim - 1 + view_dir_emb]
            colors = self.color_network(color_feats) # [batch_size * num_points, out_dim]
            colors = colors.view(batch_size, num_points, self.out_dim) # [batch_size, num_points, out_dim]
            y = torch.cat([colors, density], dim=2) # [batch_size, num_points, out_dim + 1]
        else:
            y = x

        misc.assert_shape(y, [batch_size, num_points, self.out_dim + 1])

        return y
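
With has_view_cond enabled, the backbone's last output channel is read off as density and the remaining channels as color features; a small two-layer head then predicts colors from those features concatenated with the encoded ray direction, echoing NeRF's view-dependent color branch. A minimal sketch of that split, where the dimensions are hypothetical and plain nn.Linear stands in for the repo's FullyConnectedLayer:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Minimal sketch of the view-conditioned head: split the backbone output into
# density (last channel) and color features, then predict colors from the
# features plus an encoded ray direction. hid_dim/view_emb/out_dim below are
# hypothetical, and nn.Linear stands in for the repo's FullyConnectedLayer.
batch_size, num_points, hid_dim, view_emb, out_dim = 2, 768, 64, 3, 3

backbone_out = torch.randn(batch_size, num_points, hid_dim + 1)  # [.., hid_dim + 1]
ray_dir_embs = torch.randn(batch_size, num_points, view_emb)     # pre-broadcast per point

density = backbone_out[:, :, [-1]]                        # [batch_size, num_points, 1]
color_feats = F.leaky_relu(backbone_out[:, :, :-1], 0.1)  # [batch_size, num_points, hid_dim]
color_in = torch.cat([color_feats, ray_dir_embs], dim=2)  # [batch_size, num_points, hid_dim + view_emb]
color_network = nn.Sequential(
    nn.Linear(hid_dim + view_emb, 32), nn.LeakyReLU(0.1),
    nn.Linear(32, out_dim),
)
colors = color_network(color_in)                          # [batch_size, num_points, out_dim]
y = torch.cat([colors, density], dim=2)                   # [batch_size, num_points, out_dim + 1]
print(y.shape)  # torch.Size([2, 768, 4]) --- RGB + density per point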
