Remove value function from SAC #64

Open
keiohta opened this issue Dec 2, 2019 · 4 comments

@keiohta
Owner

keiohta commented Dec 2, 2019

Recent SAC implementations do not explicitly learn a state value function; instead, they estimate it from the soft Q-function.
Follow the recent implementation in the authors' repo.
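
For context, the soft state value can be written as V(s) = E_{a~π}[min(Q1(s,a), Q2(s,a)) − α log π(a|s)], so the critics' Bellman target no longer needs a separate value network. A minimal sketch of that target computation (plain TensorFlow; the function and argument names here are illustrative, not tf2rl's API):

    import tensorflow as tf

    def soft_q_target(rewards, not_dones, next_states, actor, qf1_target, qf2_target,
                      alpha, discount):
        # a' ~ pi(.|s') and log pi(a'|s') from the current policy
        next_actions, next_logp = actor(next_states)
        # Clipped double-Q estimate of the soft state value V(s')
        next_q = tf.minimum(qf1_target(next_states, next_actions),
                            qf2_target(next_states, next_actions))
        soft_next_v = next_q - alpha * next_logp
        # Bellman backup; stop_gradient so only the online critics receive gradients
        return tf.stop_gradient(rewards + not_dones * discount * soft_next_v)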

@HesNobi

HesNobi commented Jun 7, 2021

Hi,
I have made these changes regarding this issue. I haven't forked your code yet, so I am going to copy-paste the code here.
Thanks for your time.

    @tf.function
    def _train_body(self, states, actions, next_states, rewards, dones, weights):
        print("[DEBUG] initializing {_train_body SAC}")
        with tf.device(self.device):
            assert len(dones.shape) == 2
            assert len(rewards.shape) == 2
            rewards = tf.squeeze(rewards, axis=1)
            dones = tf.squeeze(dones, axis=1)

            not_dones = 1. - tf.cast(dones, dtype=tf.float32)

            with tf.GradientTape(persistent=True) as tape:
                # Compute loss of critic Q
                current_q1 = self.qf1(states, actions)
                current_q2 = self.qf2(states, actions)
                sample_next_actions, next_logp = self.actor(next_states)
                next_q1_target = self.qf1_target(next_states, sample_next_actions)
                next_q2_target = self.qf2_target(next_states, sample_next_actions)
                next_q_min_target = tf.minimum(next_q1_target, next_q2_target)
                soft_next_q = next_q_min_target - self.alpha * next_logp
                target_q = tf.stop_gradient(rewards + not_dones * self.discount * soft_next_q)

                td_loss_q1 = tf.reduce_mean((target_q - current_q1) ** 2)
                td_loss_q2 = tf.reduce_mean((target_q - current_q2) ** 2)
                td_loss_q = td_loss_q1 + td_loss_q2

                sample_actions, logp = self.actor(states)  # Resample actions to update V
                current_q1_policy = self.qf1(states, sample_actions)
                current_q2_policy = self.qf2(states, sample_actions)
                current_min_q_policy = tf.minimum(current_q1_policy, current_q2_policy)

                # Compute loss of policy
                policy_loss = tf.reduce_mean(self.alpha * logp - current_min_q_policy)

                # Compute loss of temperature parameter for entropy
                if self.auto_alpha:
                    alpha_loss = -tf.reduce_mean((self.log_alpha * tf.stop_gradient(logp + self.target_alpha)))

            trainable_var = self.qf1.trainable_variables + self.qf2.trainable_variables
            q_grad = tape.gradient(td_loss_q, trainable_var)
            self.qf_optimizer.apply_gradients(zip(q_grad, trainable_var))

            # Polyak-average the target critics toward the online critics
            update_target_variables(self.qf1_target.weights, self.qf1.weights, tau=self.tau)
            update_target_variables(self.qf2_target.weights, self.qf2.weights, tau=self.tau)

            actor_grad = tape.gradient(policy_loss, self.actor.trainable_variables)
            self.actor_optimizer.apply_gradients(zip(actor_grad, self.actor.trainable_variables))

            if self.auto_alpha:
                alpha_grad = tape.gradient(alpha_loss, [self.log_alpha])
                self.alpha_optimizer.apply_gradients(zip(alpha_grad, [self.log_alpha]))
                self.alpha.assign(tf.exp(self.log_alpha))

            del tape

        return (target_q, policy_loss, td_loss_q2, td_loss_q1,
                tf.reduce_min(logp), tf.reduce_max(logp), tf.reduce_mean(logp))
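
For completeness, the snippet assumes a few attributes that are not shown (self.log_alpha, self.alpha, self.target_alpha, and the optimizers). A minimal, hypothetical constructor fragment that would make them available (illustrative defaults, not tf2rl's actual __init__):

    import numpy as np
    import tensorflow as tf

    # Hypothetical __init__ fragment; attribute names follow the snippet above,
    # but the defaults and structure are only an assumption.
    self.auto_alpha = True
    self.target_alpha = -np.prod(action_dim)  # common heuristic: target entropy = -|A|
    self.log_alpha = tf.Variable(0., dtype=tf.float32)
    self.alpha = tf.Variable(1., dtype=tf.float32, trainable=False)
    self.alpha_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
    self.qf_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
    self.actor_optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)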

@keiohta
Owner Author

keiohta commented Jun 7, 2021

Hi, thanks for working on this issue. Actually, I modified this part just last week and applied the changes to the latest master.
https://github.com/keiohta/tf2rl/blob/master/tf2rl/algos/sac.py
I'm sorry I didn't announce the changes on this issue page.

@keiohta
Owner Author

keiohta commented Jun 7, 2021

Oh, I found that CriticV is not removed even though the SAC class no longer uses it. I'll fix it later.

@HesNobi

HesNobi commented Jun 7, 2021

> Hi, thanks for working on this issue. Actually, I modified this part just last week and applied the changes to the latest master.
> https://github.com/keiohta/tf2rl/blob/master/tf2rl/algos/sac.py
> I'm sorry I didn't announce the changes on this issue page.

It was my pleasure.
