
Commit

Handles gpu/cpu transfer in QPFunction's backward + replaces torch.Tensor by torch.empty or torch.tensor
oumayb authored and fabinsch committed Jan 19, 2024
1 parent 671e9b6 commit 7661158
Showing 1 changed file with 24 additions and 23 deletions.
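For context (a minimal sketch, not part of the commit; nBatch, nz and the values are illustrative): torch.Tensor(rows, cols) is the legacy constructor. Called with shape arguments it returns an uninitialized float32 CPU tensor, while called with data it copies that data, so the same name does two different things. torch.empty makes the "allocate now, fill later" intent explicit, and torch.tensor explicitly copies existing data, such as a NumPy array returned by the solver, inferring its dtype:

import torch
import numpy as np

nBatch, nz = 4, 3

# Legacy constructor: uninitialized float32 CPU tensor; easy to misread as zeros.
legacy = torch.Tensor(nBatch, nz)

# Explicit allocation of a buffer that will be filled in a loop.
zhats = torch.empty((nBatch, nz))

# Explicit copy of existing data (dtype inferred from the NumPy array).
x_np = np.zeros(nz)
zhats[0] = torch.tensor(x_np)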
bindings/python/proxsuite/torch/qplayer.py (24 additions, 23 deletions)
@@ -114,9 +114,9 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
             if ctx.cpu is not None:
                 ctx.cpu = max(1, int(ctx.cpu / 2))

-            zhats = torch.Tensor(nBatch, ctx.nz).type_as(Q)
-            lams = torch.Tensor(nBatch, ctx.neq).type_as(Q)
-            nus = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
+            zhats = torch.empty((nBatch, ctx.nz)).type_as(Q)
+            lams = torch.empty((nBatch, ctx.neq)).type_as(Q)
+            nus = torch.empty((nBatch, ctx.nineq)).type_as(Q)

             for i in range(nBatch):
                 qp = ctx.vector_of_qps.init_qp_in_place(ctx.nz, ctx.neq, ctx.nineq)
@@ -163,22 +163,23 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
                     ctx.vector_of_qps.get(i).solve()

             for i in range(nBatch):
-                zhats[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.x)
-                lams[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.y)
-                nus[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.z)
+                zhats[i] = torch.tensor(ctx.vector_of_qps.get(i).results.x)
+                lams[i] = torch.tensor(ctx.vector_of_qps.get(i).results.y)
+                nus[i] = torch.tensor(ctx.vector_of_qps.get(i).results.z)

             return zhats, lams, nus

         @staticmethod
         def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
+            device = dl_dzhat.device
             nBatch, dim, neq, nineq = ctx.nBatch, ctx.nz, ctx.neq, ctx.nineq
-            dQs = torch.Tensor(nBatch, ctx.nz, ctx.nz)
-            dps = torch.Tensor(nBatch, ctx.nz)
-            dGs = torch.Tensor(nBatch, ctx.nineq, ctx.nz)
-            dus = torch.Tensor(nBatch, ctx.nineq)
-            dls = torch.Tensor(nBatch, ctx.nineq)
-            dAs = torch.Tensor(nBatch, ctx.neq, ctx.nz)
-            dbs = torch.Tensor(nBatch, ctx.neq)
+            dQs = torch.empty(nBatch, ctx.nz, ctx.nz, device=device)
+            dps = torch.empty(nBatch, ctx.nz, device=device)
+            dGs = torch.empty(nBatch, ctx.nineq, ctx.nz, device=device)
+            dus = torch.empty(nBatch, ctx.nineq, device=device)
+            dls = torch.empty(nBatch, ctx.nineq, device=device)
+            dAs = torch.empty(nBatch, ctx.neq, ctx.nz, device=device)
+            dbs = torch.empty(nBatch, ctx.neq, device=device)

             ctx.cpu = os.cpu_count()
             if ctx.cpu is not None:
@@ -211,11 +212,11 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
             else:
                 for i in range(nBatch):
                     rhs = np.zeros(n_tot)
-                    rhs[:dim] = dl_dzhat[i]
+                    rhs[:dim] = dl_dzhat[i].cpu()
                     if dl_dlams != None:
-                        rhs[dim : dim + neq] = dl_dlams[i]
+                        rhs[dim : dim + neq] = dl_dlams[i].cpu()
                     if dl_dnus != None:
-                        rhs[dim + neq :] = dl_dnus[i]
+                        rhs[dim + neq :] = dl_dnus[i].cpu()
                     qpi = ctx.vector_of_qps.get(i)
                     proxsuite.proxqp.dense.compute_backward(
                         qp=qpi,
@@ -226,25 +227,25 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
                     )

             for i in range(nBatch):
-                dQs[i] = torch.Tensor(
+                dQs[i] = torch.tensor(
                     ctx.vector_of_qps.get(i).model.backward_data.dL_dH
                 )
-                dps[i] = torch.Tensor(
+                dps[i] = torch.tensor(
                     ctx.vector_of_qps.get(i).model.backward_data.dL_dg
                 )
-                dGs[i] = torch.Tensor(
+                dGs[i] = torch.tensor(
                     ctx.vector_of_qps.get(i).model.backward_data.dL_dC
                 )
-                dus[i] = torch.Tensor(
+                dus[i] = torch.tensor(
                     ctx.vector_of_qps.get(i).model.backward_data.dL_du
                 )
-                dls[i] = torch.Tensor(
+                dls[i] = torch.tensor(
                     ctx.vector_of_qps.get(i).model.backward_data.dL_dl
                 )
-                dAs[i] = torch.Tensor(
+                dAs[i] = torch.tensor(
                     ctx.vector_of_qps.get(i).model.backward_data.dL_dA
                 )
-                dbs[i] = torch.Tensor(
+                dbs[i] = torch.tensor(
                     ctx.vector_of_qps.get(i).model.backward_data.dL_db
                 )

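The device handling above follows one pattern: ProxSuite consumes and produces NumPy arrays, which live on the CPU, so incoming CUDA gradients are moved with .cpu() before being written into the right-hand side, while the outgoing gradient buffers are allocated with device=device so the caller gets them back on the device the gradients arrived on. A minimal sketch of that pattern, assuming a hypothetical CPU-only stand-in solve_backward_numpy in place of the ProxSuite call (shapes and names are illustrative):

import numpy as np
import torch

def solve_backward_numpy(rhs: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for the CPU-only ProxSuite backward computation.
    return 2.0 * rhs

def backward_like(dl_dzhat: torch.Tensor) -> torch.Tensor:
    device = dl_dzhat.device                        # remember where the gradients live
    nBatch, dim = dl_dzhat.shape
    dps = torch.empty(nBatch, dim, device=device)   # results allocated on that device
    for i in range(nBatch):
        rhs = np.zeros(dim)
        rhs[:dim] = dl_dzhat[i].cpu()               # NumPy is CPU-only, so transfer first
        dps[i] = torch.tensor(solve_backward_numpy(rhs))  # copied back onto dps's device by the assignment
    return dps

With a CUDA input, dps is returned on the same GPU without any extra transfer by the caller.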
