Skip to content

Commit d215ecc

Browse files
committed Jun 7, 2023
Avoid recomputation of z_state
1 parent c92d66f commit d215ecc

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed
 

‎sbx/sac7/sac7.py

+9 -2
Original file line number | Diff line number | Diff line change
@@ -252,6 +252,7 @@ def update_critic(
252252
ent_coef_state: TrainState,
253253
encoder_state: RLTrainState,
254254
action_encoder_state: RLTrainState,
255+
z_state: jnp.ndarray,
255256
observations: np.ndarray,
256257
actions: np.ndarray,
257258
next_observations: np.ndarray,
@@ -287,7 +288,7 @@ def update_critic(
287288
# shape is (batch_size, 1)
288289
target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values
289290

290-
z_state = encoder_state.apply_fn(encoder_state.target_params, observations)
291+
# z_state = encoder_state.apply_fn(encoder_state.target_params, observations)
291292
z_state_action = action_encoder_state.apply_fn(action_encoder_state.target_params, z_state, actions)
292293

293294
def mse_loss(params, dropout_key):
@@ -319,12 +320,13 @@ def update_actor(
319320
ent_coef_state: TrainState,
320321
encoder_state: RLTrainState,
321322
action_encoder_state: RLTrainState,
323+
z_state: jnp.ndarray,
322324
observations: np.ndarray,
323325
key: jax.random.KeyArray,
324326
):
325327
key, dropout_key, noise_key = jax.random.split(key, 3)
326328

327-
z_state = encoder_state.apply_fn(encoder_state.target_params, observations)
329+
# z_state = encoder_state.apply_fn(encoder_state.target_params, observations)
328330

329331
def actor_loss(params):
330332
dist = actor_state.apply_fn(params, observations, z_state)
@@ -432,6 +434,9 @@ def slice(x, step=i):
432434
slice(data.next_observations),
433435
)
434436

437+
z_state = encoder_state.apply_fn(encoder_state.target_params, slice(data.observations))
438+
439+
435440
(
436441
qf_state,
437442
(qf_loss_value, ent_coef_value),
@@ -443,6 +448,7 @@ def slice(x, step=i):
443448
ent_coef_state,
444449
encoder_state,
445450
action_encoder_state,
451+
z_state,
446452
slice(data.observations),
447453
slice(data.actions),
448454
slice(data.next_observations),
@@ -462,6 +468,7 @@ def slice(x, step=i):
462468
ent_coef_state,
463469
encoder_state,
464470
action_encoder_state,
471+
z_state,
465472
slice(data.observations),
466473
key,
467474
)

0 commit comments

Comments
 (0)