class MlpBarlowTwinsActor(nn.Module):
    """Actor that encodes a proprioceptive history into a Barlow-Twins-style
    latent plus an estimated 3-D velocity, then feeds both — together with the
    current observation frame — into an MLP policy head.

    Args:
        num_prop: size of a single proprioceptive observation frame.
        num_hist: number of history frames consumed by the encoder.
        mlp_encoder_dims: hidden-layer sizes of the history encoder.
        actor_dims: hidden-layer sizes of the policy MLP.
        latent_dim: dimensionality of the projected latent ``z``.
        num_actions: action-space dimensionality (policy output size).
        activation: activation spec forwarded to ``mlp_batchnorm_factory``.
    """

    def __init__(self, num_prop, num_hist, mlp_encoder_dims, actor_dims,
                 latent_dim, num_actions, activation):
        super().__init__()
        # Kept so forward() can slice the history window consistently with
        # the encoder's input size (num_prop * num_hist).
        self.num_hist = num_hist
        # History encoder over the flattened (num_hist * num_prop) window.
        self.mlp_encoder = nn.Sequential(
            *mlp_batchnorm_factory(
                input_dims=num_prop * num_hist,
                hidden_dims=mlp_encoder_dims,
                activation=activation,
            )
        )
        # Projection head producing the Barlow-Twins latent z.
        self.latent_layer = nn.Sequential(
            nn.Linear(mlp_encoder_dims[-1], 32),
            nn.BatchNorm1d(32),
            nn.ELU(),
            nn.Linear(32, latent_dim),
        )
        # Auxiliary head: 3-D velocity estimate from the encoder features.
        self.vel_layer = nn.Linear(mlp_encoder_dims[-1], 3)
        # Policy head consuming [vel | z | current obs] -> action mean.
        self.actor = nn.Sequential(
            *mlp_factory(
                input_dims=latent_dim + num_prop + 3,
                out_dims=num_actions,
                hidden_dims=actor_dims,
            )
        )

    def forward(self, actor_obs):
        """Compute the action mean.

        Args:
            actor_obs: tensor of shape ``(batch, time, num_prop)`` — assumed
                to be a time-ordered observation stack with the current frame
                last (TODO confirm against the caller).

        Returns:
            Action-mean tensor of shape ``(batch, num_actions)``.
        """
        # Current observation is the most recent frame.
        obs = actor_obs[:, -1, :]
        # Most recent num_hist frames feed the encoder. The original code
        # hard-coded [:, 1:6, :] (i.e. num_hist == 5 over a 6-frame stack);
        # -num_hist: is identical in that configuration and stays consistent
        # with the encoder's input size for any num_hist.
        obs_hist = actor_obs[:, -self.num_hist:, :]
        b = obs_hist.size(0)
        latent = self.mlp_encoder(obs_hist.reshape(b, -1))
        z = self.latent_layer(latent)
        vel = self.vel_layer(latent)
        actor_input = torch.cat([vel, z, obs], dim=-1)
        mean = self.actor(actor_input)
        return mean