This repository was archived by the owner on Jun 13, 2024. It is now read-only.

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation #27

@matte-esse

Description

The following code generates an error in some of the most recent versions of PyTorch:

"""
Update networks
"""
self.qf1_optimizer.zero_grad()
qf1_loss.backward()
self.qf1_optimizer.step()
self.qf2_optimizer.zero_grad()
qf2_loss.backward()
self.qf2_optimizer.step()
self.policy_optimizer.zero_grad()
policy_loss.backward()
self.policy_optimizer.step()

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
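
For context, the error happens because qf1_optimizer.step() and qf2_optimizer.step() update the Q-network weights in place, while the autograd graph for policy_loss still holds references to the old values of those weights. Below is a minimal sketch of that failure pattern with hypothetical toy networks (not the code from this repo), assuming a recent PyTorch version:

import torch

# Toy "policy" and "Q network" so that policy_loss must backprop through q's weights
q = torch.nn.Linear(2, 1)
policy = torch.nn.Linear(3, 2)
q_opt = torch.optim.SGD(q.parameters(), lr=0.1)
pi_opt = torch.optim.SGD(policy.parameters(), lr=0.1)

obs = torch.randn(8, 3)
actions = policy(obs)                 # graph: policy params -> actions
q_values = q(actions)                 # forward pass saves q.weight for backward
q_loss = (q_values - 1.0).pow(2).mean()
policy_loss = (-q_values).mean()      # shares the same forward graph

q_opt.zero_grad()
q_loss.backward(retain_graph=True)    # keep the shared graph alive
q_opt.step()                          # updates q.weight in place

pi_opt.zero_grad()
policy_loss.backward()                # RuntimeError: q.weight, needed here, was modified in place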

To solve it, it is necessary to move these lines

q_new_actions = torch.min(
    self.qf1(obs, new_obs_actions),
    self.qf2(obs, new_obs_actions),
)
policy_loss = (alpha * log_pi - q_new_actions).mean()

so that they sit between the Q-network gradient steps and the policy-network step, like so:

"""
Update networks
"""
self.qf1_optimizer.zero_grad()
qf1_loss.backward(retain_graph=True)
self.qf1_optimizer.step()

self.qf2_optimizer.zero_grad()
qf2_loss.backward(retain_graph=True)
self.qf2_optimizer.step()

q_new_actions = torch.min(
    self.qf1(obs, new_obs_actions),
    self.qf2(obs, new_obs_actions),
)
policy_loss = (alpha * log_pi - q_new_actions).mean()

self.policy_optimizer.zero_grad()
policy_loss.backward(retain_graph=True)
self.policy_optimizer.step()

Be aware that if you simply downgrade to an older version of PyTorch to make the error go away, the behaviour might not be what you expect: policy_loss was computed from Q-network parameters that have already been overwritten by the Q-network optimizer steps, i.e. from a network that no longer exists.
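
If a similar error shows up elsewhere and the offending line is not obvious, PyTorch's anomaly detection can point at the forward operation whose output was later modified in place (at the cost of slower execution):

import torch

# Enable before the training step; the backward error will then also print a
# traceback of the forward operation that produced the modified tensor.
torch.autograd.set_detect_anomaly(True)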
