@@ -252,6 +252,7 @@ def update_critic(
252
252
ent_coef_state : TrainState ,
253
253
encoder_state : RLTrainState ,
254
254
action_encoder_state : RLTrainState ,
255
+ z_state : jnp .ndarray ,
255
256
observations : np .ndarray ,
256
257
actions : np .ndarray ,
257
258
next_observations : np .ndarray ,
@@ -287,7 +288,7 @@ def update_critic(
287
288
# shape is (batch_size, 1)
288
289
target_q_values = rewards .reshape (- 1 , 1 ) + (1 - dones .reshape (- 1 , 1 )) * gamma * next_q_values
289
290
290
- z_state = encoder_state .apply_fn (encoder_state .target_params , observations )
291
+ # z_state is now supplied by the caller as an argument; it was previously computed here as encoder_state.apply_fn(encoder_state.target_params, observations)
291
292
z_state_action = action_encoder_state .apply_fn (action_encoder_state .target_params , z_state , actions )
292
293
293
294
def mse_loss (params , dropout_key ):
@@ -319,12 +320,13 @@ def update_actor(
319
320
ent_coef_state : TrainState ,
320
321
encoder_state : RLTrainState ,
321
322
action_encoder_state : RLTrainState ,
323
+ z_state : jnp .ndarray ,
322
324
observations : np .ndarray ,
323
325
key : jax .random .KeyArray ,
324
326
):
325
327
key , dropout_key , noise_key = jax .random .split (key , 3 )
326
328
327
- z_state = encoder_state .apply_fn (encoder_state .target_params , observations )
329
+ # z_state is now supplied by the caller as an argument; it was previously computed here as encoder_state.apply_fn(encoder_state.target_params, observations)
328
330
329
331
def actor_loss (params ):
330
332
dist = actor_state .apply_fn (params , observations , z_state )
@@ -432,6 +434,9 @@ def slice(x, step=i):
432
434
slice (data .next_observations ),
433
435
)
434
436
437
+ z_state = encoder_state .apply_fn (encoder_state .target_params , slice (data .observations ))
438
+
439
+
435
440
(
436
441
qf_state ,
437
442
(qf_loss_value , ent_coef_value ),
@@ -443,6 +448,7 @@ def slice(x, step=i):
443
448
ent_coef_state ,
444
449
encoder_state ,
445
450
action_encoder_state ,
451
+ z_state ,
446
452
slice (data .observations ),
447
453
slice (data .actions ),
448
454
slice (data .next_observations ),
@@ -462,6 +468,7 @@ def slice(x, step=i):
462
468
ent_coef_state ,
463
469
encoder_state ,
464
470
action_encoder_state ,
471
+ z_state ,
465
472
slice (data .observations ),
466
473
key ,
467
474
)
0 commit comments