diff --git a/bindings/python/proxsuite/torch/qplayer.py b/bindings/python/proxsuite/torch/qplayer.py index f82710196..52d9fbda6 100644 --- a/bindings/python/proxsuite/torch/qplayer.py +++ b/bindings/python/proxsuite/torch/qplayer.py @@ -114,9 +114,9 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_): if ctx.cpu is not None: ctx.cpu = max(1, int(ctx.cpu / 2)) - zhats = torch.Tensor(nBatch, ctx.nz).type_as(Q) - lams = torch.Tensor(nBatch, ctx.neq).type_as(Q) - nus = torch.Tensor(nBatch, ctx.nineq).type_as(Q) + zhats = torch.empty((nBatch, ctx.nz)).type_as(Q) + lams = torch.empty((nBatch, ctx.neq)).type_as(Q) + nus = torch.empty((nBatch, ctx.nineq)).type_as(Q) for i in range(nBatch): qp = ctx.vector_of_qps.init_qp_in_place(ctx.nz, ctx.neq, ctx.nineq) @@ -163,22 +163,23 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_): ctx.vector_of_qps.get(i).solve() for i in range(nBatch): - zhats[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.x) - lams[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.y) - nus[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.z) + zhats[i] = torch.tensor(ctx.vector_of_qps.get(i).results.x) + lams[i] = torch.tensor(ctx.vector_of_qps.get(i).results.y) + nus[i] = torch.tensor(ctx.vector_of_qps.get(i).results.z) return zhats, lams, nus @staticmethod def backward(ctx, dl_dzhat, dl_dlams, dl_dnus): + device = dl_dzhat.device nBatch, dim, neq, nineq = ctx.nBatch, ctx.nz, ctx.neq, ctx.nineq - dQs = torch.Tensor(nBatch, ctx.nz, ctx.nz) - dps = torch.Tensor(nBatch, ctx.nz) - dGs = torch.Tensor(nBatch, ctx.nineq, ctx.nz) - dus = torch.Tensor(nBatch, ctx.nineq) - dls = torch.Tensor(nBatch, ctx.nineq) - dAs = torch.Tensor(nBatch, ctx.neq, ctx.nz) - dbs = torch.Tensor(nBatch, ctx.neq) + dQs = torch.empty(nBatch, ctx.nz, ctx.nz, device=device) + dps = torch.empty(nBatch, ctx.nz, device=device) + dGs = torch.empty(nBatch, ctx.nineq, ctx.nz, device=device) + dus = torch.empty(nBatch, ctx.nineq, device=device) + dls = torch.empty(nBatch, ctx.nineq, device=device) + dAs = torch.empty(nBatch, ctx.neq, ctx.nz, device=device) + dbs = torch.empty(nBatch, ctx.neq, device=device) ctx.cpu = os.cpu_count() if ctx.cpu is not None: @@ -211,11 +212,11 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus): else: for i in range(nBatch): rhs = np.zeros(n_tot) - rhs[:dim] = dl_dzhat[i] + rhs[:dim] = dl_dzhat[i].cpu() if dl_dlams != None: - rhs[dim : dim + neq] = dl_dlams[i] + rhs[dim : dim + neq] = dl_dlams[i].cpu() if dl_dnus != None: - rhs[dim + neq :] = dl_dnus[i] + rhs[dim + neq :] = dl_dnus[i].cpu() qpi = ctx.vector_of_qps.get(i) proxsuite.proxqp.dense.compute_backward( qp=qpi, @@ -226,25 +227,25 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus): ) for i in range(nBatch): - dQs[i] = torch.Tensor( + dQs[i] = torch.tensor( ctx.vector_of_qps.get(i).model.backward_data.dL_dH ) - dps[i] = torch.Tensor( + dps[i] = torch.tensor( ctx.vector_of_qps.get(i).model.backward_data.dL_dg ) - dGs[i] = torch.Tensor( + dGs[i] = torch.tensor( ctx.vector_of_qps.get(i).model.backward_data.dL_dC ) - dus[i] = torch.Tensor( + dus[i] = torch.tensor( ctx.vector_of_qps.get(i).model.backward_data.dL_du ) - dls[i] = torch.Tensor( + dls[i] = torch.tensor( ctx.vector_of_qps.get(i).model.backward_data.dL_dl ) - dAs[i] = torch.Tensor( + dAs[i] = torch.tensor( ctx.vector_of_qps.get(i).model.backward_data.dL_dA ) - dbs[i] = torch.Tensor( + dbs[i] = torch.tensor( ctx.vector_of_qps.get(i).model.backward_data.dL_db )