1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
class TriPlaneMLP(nn.Module):
    """Decodes aggregated tri-plane features into per-point colors + density.

    When `cfg.tri_plane.mlp.n_layers == 0`, tri-plane features are passed
    through unchanged (identity decoder). Otherwise a small fully-connected
    backbone is used, optionally with a positional encoding of the 3D
    coordinates and a view-direction-conditioned color head.
    """

    def __init__(self, cfg: DictConfig, out_dim: int):
        """
        Params:
            cfg: config; reads the `cfg.tri_plane.*` subtree.
            out_dim: number of output color channels. The module additionally
                predicts 1 density channel, so `forward` emits `out_dim + 1`.
        """
        super().__init__()
        self.cfg = cfg
        self.out_dim = out_dim

        if self.cfg.tri_plane.mlp.n_layers == 0:
            # Identity decoder: tri-plane features are used as-is, so they must
            # already have the final size (out_dim colors + 1 density channel).
            assert self.cfg.tri_plane.feat_dim == (self.out_dim + 1), f"Wrong dims: {self.cfg.tri_plane.feat_dim}, {self.out_dim}"
            self.model = nn.Identity()
            # BUGFIX: forward() unconditionally reads these attributes, but they
            # were previously only defined in the `else` branch, which made the
            # identity configuration crash with AttributeError.
            self.pos_enc = None
            self.ray_dir_enc = None
            self.color_network = None
            self.dims = [self.cfg.tri_plane.feat_dim]  # identity: output size == input size
        else:
            # Optional positional encoding of the 3D coordinates.
            if self.cfg.tri_plane.get('posenc_period_len', 0) > 0:
                self.pos_enc = ScalarEncoder1d(3, x_multiplier=self.cfg.tri_plane.posenc_period_len, const_emb_dim=0)
            else:
                self.pos_enc = None

            backbone_input_dim = self.cfg.tri_plane.feat_dim + (0 if self.pos_enc is None else self.pos_enc.get_dim())
            # With view conditioning, the backbone emits hidden color features
            # (decoded later by `color_network`); otherwise it emits colors
            # directly. The extra `+ 1` channel is the density.
            backbone_out_dim = 1 + (self.cfg.tri_plane.mlp.hid_dim if self.cfg.tri_plane.has_view_cond else self.out_dim)
            self.dims = [backbone_input_dim] + [self.cfg.tri_plane.mlp.hid_dim] * (self.cfg.tri_plane.mlp.n_layers - 1) + [backbone_out_dim]
            activations = ['lrelu'] * (len(self.dims) - 2) + ['linear']
            assert len(self.dims) > 2, f"We cant have just a linear layer here: nothing to modulate. Dims: {self.dims}"
            layers = [FullyConnectedLayer(self.dims[i], self.dims[i + 1], activation=a) for i, a in enumerate(activations)]
            self.model = nn.Sequential(*layers)

            if self.cfg.tri_plane.has_view_cond:
                # Raw + sine encoding of the 3-dim view direction, no constant embedding.
                self.ray_dir_enc = ScalarEncoder1d(coord_dim=3, const_emb_dim=0, x_multiplier=8, use_cos=False, use_raw=True)
                self.color_network = nn.Sequential(
                    FullyConnectedLayer(backbone_out_dim - 1 + self.ray_dir_enc.get_dim(), 32, activation='lrelu'),
                    FullyConnectedLayer(32, self.out_dim, activation='linear'),
                )
            else:
                self.ray_dir_enc = None
                self.color_network = None

    def forward(self, x: torch.Tensor, coords: torch.Tensor, ray_d_world: torch.Tensor) -> torch.Tensor:
        """
        Params:
            x: [batch_size, 3, num_points, feat_dim] --- volumetric features from tri-planes
            coords: [batch_size, num_points, 3] --- coordinates, assumed to be in [-1, 1]
            ray_d_world: [batch_size, h * w, 3] --- camera ray's view directions
        Returns:
            [batch_size, num_points, out_dim + 1] --- colors concatenated with density.
        """
        batch_size, _, num_points, feat_dim = x.shape
        # Average the three plane projections, then flatten points into the batch dim.
        x = x.mean(dim=1).reshape(batch_size * num_points, feat_dim)
        if self.pos_enc is not None:
            misc.assert_shape(coords, [batch_size, num_points, 3])
            pos_embs = self.pos_enc(coords.reshape(batch_size * num_points, 3))
            x = torch.cat([x, pos_embs], dim=1)
        x = self.model(x)
        x = x.view(batch_size, num_points, self.dims[-1])

        if self.color_network is not None:
            num_pixels, view_dir_emb = ray_d_world.shape[1], self.ray_dir_enc.get_dim()
            # Each pixel's ray direction is shared across its depth samples
            # (num_points is assumed to be num_pixels * num_steps).
            num_steps = num_points // num_pixels
            ray_dir_embs = self.ray_dir_enc(ray_d_world.reshape(-1, 3))  # [batch_size * num_pixels, view_dir_emb]
            ray_dir_embs = ray_dir_embs.reshape(batch_size, num_pixels, 1, view_dir_emb)
            ray_dir_embs = ray_dir_embs.repeat(1, 1, num_steps, 1)
            ray_dir_embs = ray_dir_embs.reshape(batch_size, num_points, view_dir_emb)
            density = x[:, :, [-1]]  # list index keeps the channel dim: [batch_size, num_points, 1]
            color_feats = F.leaky_relu(x[:, :, :-1], negative_slope=0.1)
            color_feats = torch.cat([color_feats, ray_dir_embs], dim=2)
            color_feats = color_feats.view(batch_size * num_points, self.dims[-1] - 1 + view_dir_emb)
            colors = self.color_network(color_feats)
            colors = colors.view(batch_size, num_points, self.out_dim)
            y = torch.cat([colors, density], dim=2)
        else:
            y = x
        misc.assert_shape(y, [batch_size, num_points, self.out_dim + 1])
        return y
|