This repository was archived by the owner on Jun 13, 2024. It is now read-only.
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation #27
Open
Description
The following code raises an error in some of the most recent versions of PyTorch:
oac-explore/trainer/trainer.py
Lines 146 to 159 in cbc0333
"""
Update networks
"""
self.qf1_optimizer.zero_grad()
qf1_loss.backward()
self.qf1_optimizer.step()
self.qf2_optimizer.zero_grad()
qf2_loss.backward()
self.qf2_optimizer.step()
self.policy_optimizer.zero_grad()
policy_loss.backward()
self.policy_optimizer.step()
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
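A minimal sketch of how this error can arise, assuming a recent PyTorch release where optimizer steps are tracked as in-place modifications. The `nn.Linear` modules, shapes and losses are hypothetical stand-ins, not the repository's networks. The policy loss is built first, so its graph saves the Q-network weights; the Q optimizer then overwrites those weights in place before `policy_loss.backward()` runs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the policy and one Q-network.
policy = nn.Linear(3, 2)
qf = nn.Linear(3 + 2, 1)
qf_optimizer = torch.optim.SGD(qf.parameters(), lr=0.1)
policy_optimizer = torch.optim.SGD(policy.parameters(), lr=0.1)

obs = torch.randn(4, 3)
new_obs_actions = policy(obs)

# Q loss on detached actions, so its graph is separate from the policy's.
q_pred = qf(torch.cat([obs, new_obs_actions.detach()], dim=1))
qf_loss = (q_pred - 1.0).pow(2).mean()

# Policy loss built BEFORE the Q update: its graph saves qf's weights.
policy_loss = -qf(torch.cat([obs, new_obs_actions], dim=1)).mean()

qf_optimizer.zero_grad()
qf_loss.backward()
qf_optimizer.step()  # modifies qf's weights in place

raised = False
message = ""
try:
    policy_optimizer.zero_grad()
    policy_loss.backward()  # needs the pre-update qf weights
except RuntimeError as err:
    raised = True
    message = str(err)

print(raised, message[:90])
```

On older releases that do not perform this check, the same script would presumably run without an error and use the updated weights in the backward pass.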
To fix this, it is necessary to move these lines
oac-explore/trainer/trainer.py
Lines 120 to 124 in cbc0333
q_new_actions = torch.min(
    self.qf1(obs, new_obs_actions),
    self.qf2(obs, new_obs_actions),
)
policy_loss = (alpha * log_pi - q_new_actions).mean()
between the Q-network gradient steps and the step on the policy network, like so:
"""
Update networks
"""
self.qf1_optimizer.zero_grad()
qf1_loss.backward(retain_graph=True)
self.qf1_optimizer.step()
self.qf2_optimizer.zero_grad()
qf2_loss.backward(retain_graph=True)
self.qf2_optimizer.step()
q_new_actions = torch.min(
self.qf1(obs, new_obs_actions),
self.qf2(obs, new_obs_actions),
)
policy_loss = (alpha * log_pi - q_new_actions).mean()
self.policy_optimizer.zero_grad()
policy_loss.backward(retain_graph=True)
self.policy_optimizer.step()

Be aware that simply using an older version of PyTorch to avoid this error may not give the behaviour you expect: policy_loss was computed based on a network which no longer exists, since the Q-network weights it used have already been overwritten by the optimizer steps by the time policy_loss.backward() runs.
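The reordering can be checked on toy networks. This is a sketch under stated assumptions, not the repository's trainer: the `nn.Linear` modules, the `q` helper, the random batch and the random `q_target` are all hypothetical, and the `alpha * log_pi` entropy term is left out for brevity. The point it shows is that the policy loss is built only after both Q-networks have been stepped, so its graph saves the current weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the policy and the two Q-networks.
policy = nn.Linear(3, 2)
qf1 = nn.Linear(3 + 2, 1)
qf2 = nn.Linear(3 + 2, 1)
policy_optimizer = torch.optim.SGD(policy.parameters(), lr=0.1)
qf1_optimizer = torch.optim.SGD(qf1.parameters(), lr=0.1)
qf2_optimizer = torch.optim.SGD(qf2.parameters(), lr=0.1)


def q(net, obs, actions):
    """Evaluate a toy Q-network on concatenated (obs, actions)."""
    return net(torch.cat([obs, actions], dim=1))


# Made-up batch; q_target stands in for the detached Bellman target.
obs = torch.randn(4, 3)
replay_actions = torch.randn(4, 2)
q_target = torch.randn(4, 1)

new_obs_actions = policy(obs)
qf1_loss = (q(qf1, obs, replay_actions) - q_target).pow(2).mean()
qf2_loss = (q(qf2, obs, replay_actions) - q_target).pow(2).mean()

qf1_optimizer.zero_grad()
qf1_loss.backward()
qf1_optimizer.step()

qf2_optimizer.zero_grad()
qf2_loss.backward()
qf2_optimizer.step()

# Policy loss built AFTER the Q updates, from the current Q weights.
q_new_actions = torch.min(
    q(qf1, obs, new_obs_actions),
    q(qf2, obs, new_obs_actions),
)
policy_loss = (-q_new_actions).mean()  # entropy term omitted in this sketch

weights_before = policy.weight.detach().clone()
policy_optimizer.zero_grad()
policy_loss.backward()
policy_optimizer.step()
policy_updated = not torch.equal(weights_before, policy.weight.detach())
```

In this toy setup each loss has its own graph and is backpropagated once, so plain `backward()` is enough. Whether `retain_graph=True` is needed in trainer.py depends on how the losses there share intermediate tensors, which the excerpt above does not show.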