From 2e8fa71626d2da179aa73c8d63d3ecdf575948e0 Mon Sep 17 00:00:00 2001
From: oumayb
Date: Fri, 19 Jan 2024 14:53:03 +0100
Subject: [PATCH] Handles gpu/cpu transfer in QPFunction's backward + replaces torch.Tensor by torch.empty or torch.tensor

---
 bindings/python/proxsuite/torch/qplayer.py | 95 +++++++++++-----------
 1 file changed, 48 insertions(+), 47 deletions(-)

diff --git a/bindings/python/proxsuite/torch/qplayer.py b/bindings/python/proxsuite/torch/qplayer.py
index f82710196..8ca81a00b 100644
--- a/bindings/python/proxsuite/torch/qplayer.py
+++ b/bindings/python/proxsuite/torch/qplayer.py
@@ -50,41 +50,41 @@ def QPFunction(
         Solve the QP problem.
 
         Args:
-            Q (torch.Tensor): Batch of quadratic cost matrices of size (nBatch, n, n) or (n, n).
-            p (torch.Tensor): Batch of linear cost vectors of size (nBatch, n) or (n).
-            A (torch.Tensor, optional): Batch of eq. constraint matrices of size (nBatch, p, n) or (p, n).
-            b (torch.Tensor, optional): Batch of eq. constraint vectors of size (nBatch, p) or (p).
-            G (torch.Tensor): Batch of ineq. constraint matrices of size (nBatch, m, n) or (m, n).
-            l (torch.Tensor): Batch of ineq. lower bound vectors of size (nBatch, m) or (m).
-            u (torch.Tensor): Batch of ineq. upper bound vectors of size (nBatch, m) or (m).
+            Q (torch.tensor): Batch of quadratic cost matrices of size (nBatch, n, n) or (n, n).
+            p (torch.tensor): Batch of linear cost vectors of size (nBatch, n) or (n).
+            A (torch.tensor, optional): Batch of eq. constraint matrices of size (nBatch, p, n) or (p, n).
+            b (torch.tensor, optional): Batch of eq. constraint vectors of size (nBatch, p) or (p).
+            G (torch.tensor): Batch of ineq. constraint matrices of size (nBatch, m, n) or (m, n).
+            l (torch.tensor): Batch of ineq. lower bound vectors of size (nBatch, m) or (m).
+            u (torch.tensor): Batch of ineq. upper bound vectors of size (nBatch, m) or (m).
 
         Returns:
-            zhats (torch.Tensor): Batch of optimal primal solutions of size (nBatch, n).
-            lams (torch.Tensor): Batch of dual variables for eq. constraint of size (nBatch, m).
-            nus (torch.Tensor): Batch of dual variables for ineq. constraints of size (nBatch, p).
+            zhats (torch.tensor): Batch of optimal primal solutions of size (nBatch, n).
+            lams (torch.tensor): Batch of dual variables for eq. constraint of size (nBatch, m).
+            nus (torch.tensor): Batch of dual variables for ineq. constraints of size (nBatch, p).
             Only for infeasible case:
-                s_e (torch.Tensor): Batch of slack variables for eq. constraints of size (nBatch, m).
-                s_i (torch.Tensor): Batch of slack variables for ineq. constraints of size (nBatch, p).
+                s_e (torch.tensor): Batch of slack variables for eq. constraints of size (nBatch, m).
+                s_i (torch.tensor): Batch of slack variables for ineq. constraints of size (nBatch, p).
 
     Backward:
         Compute the gradients of the QP problem wrt its parameters.
 
         Args:
-            dl_dzhat (torch.Tensor): Batch of gradients of size (nBatch, n).
-            dl_dlams (torch.Tensor, optional): Batch of gradients of size (nBatch, p).
-            dl_dnus (torch.Tensor, optional): Batch of gradients of size (nBatch, m).
+            dl_dzhat (torch.tensor): Batch of gradients of size (nBatch, n).
+            dl_dlams (torch.tensor, optional): Batch of gradients of size (nBatch, p).
+            dl_dnus (torch.tensor, optional): Batch of gradients of size (nBatch, m).
             Only for infeasible case:
-                dl_ds_e (torch.Tensor, optional): Batch of gradients of size (nBatch, m).
-                dl_ds_i (torch.Tensor, optional): Batch of gradients of size (nBatch, m).
+                dl_ds_e (torch.tensor, optional): Batch of gradients of size (nBatch, m).
+                dl_ds_i (torch.tensor, optional): Batch of gradients of size (nBatch, m).
 
         Returns:
-            dQs (torch.Tensor): Batch of gradients of size (nBatch, n, n).
-            dps (torch.Tensor): Batch of gradients of size (nBatch, n).
-            dAs (torch.Tensor): Batch of gradients of size (nBatch, p, n).
-            dbs (torch.Tensor): Batch of gradients of size (nBatch, p).
-            dGs (torch.Tensor): Batch of gradients of size (nBatch, m, n).
-            dls (torch.Tensor): Batch of gradients of size (nBatch, m).
-            dus (torch.Tensor): Batch of gradients of size (nBatch, m).
+            dQs (torch.tensor): Batch of gradients of size (nBatch, n, n).
+            dps (torch.tensor): Batch of gradients of size (nBatch, n).
+            dAs (torch.tensor): Batch of gradients of size (nBatch, p, n).
+            dbs (torch.tensor): Batch of gradients of size (nBatch, p).
+            dGs (torch.tensor): Batch of gradients of size (nBatch, m, n).
+            dls (torch.tensor): Batch of gradients of size (nBatch, m).
+            dus (torch.tensor): Batch of gradients of size (nBatch, m).
     """
     global proxqp_parallel
     proxqp_parallel = omp_parallel
@@ -114,9 +114,9 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
             if ctx.cpu is not None:
                 ctx.cpu = max(1, int(ctx.cpu / 2))
 
-            zhats = torch.Tensor(nBatch, ctx.nz).type_as(Q)
-            lams = torch.Tensor(nBatch, ctx.neq).type_as(Q)
-            nus = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
+            zhats = torch.empty((nBatch, ctx.nz)).type_as(Q)
+            lams = torch.empty((nBatch, ctx.neq)).type_as(Q)
+            nus = torch.empty((nBatch, ctx.nineq)).type_as(Q)
 
             for i in range(nBatch):
                 qp = ctx.vector_of_qps.init_qp_in_place(ctx.nz, ctx.neq, ctx.nineq)
@@ -163,22 +163,23 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
                     ctx.vector_of_qps.get(i).solve()
 
             for i in range(nBatch):
-                zhats[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.x)
-                lams[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.y)
-                nus[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.z)
+                zhats[i] = torch.tensor(ctx.vector_of_qps.get(i).results.x)
+                lams[i] = torch.tensor(ctx.vector_of_qps.get(i).results.y)
+                nus[i] = torch.tensor(ctx.vector_of_qps.get(i).results.z)
 
             return zhats, lams, nus
 
         @staticmethod
         def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
+            device = dl_dzhat.device
             nBatch, dim, neq, nineq = ctx.nBatch, ctx.nz, ctx.neq, ctx.nineq
-            dQs = torch.Tensor(nBatch, ctx.nz, ctx.nz)
-            dps = torch.Tensor(nBatch, ctx.nz)
-            dGs = torch.Tensor(nBatch, ctx.nineq, ctx.nz)
-            dus = torch.Tensor(nBatch, ctx.nineq)
-            dls = torch.Tensor(nBatch, ctx.nineq)
-            dAs = torch.Tensor(nBatch, ctx.neq, ctx.nz)
-            dbs = torch.Tensor(nBatch, ctx.neq)
+            dQs = torch.empty(nBatch, ctx.nz, ctx.nz, device=device)
+            dps = torch.empty(nBatch, ctx.nz, device=device)
+            dGs = torch.empty(nBatch, ctx.nineq, ctx.nz, device=device)
+            dus = torch.empty(nBatch, ctx.nineq, device=device)
+            dls = torch.empty(nBatch, ctx.nineq, device=device)
+            dAs = torch.empty(nBatch, ctx.neq, ctx.nz, device=device)
+            dbs = torch.empty(nBatch, ctx.neq, device=device)
 
             ctx.cpu = os.cpu_count()
             if ctx.cpu is not None:
@@ -211,11 +212,11 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
             else:
                 for i in range(nBatch):
                     rhs = np.zeros(n_tot)
-                    rhs[:dim] = dl_dzhat[i]
+                    rhs[:dim] = dl_dzhat[i].cpu()
                     if dl_dlams != None:
-                        rhs[dim : dim + neq] = dl_dlams[i]
+                        rhs[dim : dim + neq] = dl_dlams[i].cpu()
                     if dl_dnus != None:
-                        rhs[dim + neq :] = dl_dnus[i]
+                        rhs[dim + neq :] = dl_dnus[i].cpu()
                     qpi = ctx.vector_of_qps.get(i)
                     proxsuite.proxqp.dense.compute_backward(
                         qp=qpi,
                     )
 
             for i in range(nBatch):
-                dQs[i] = torch.Tensor(
+                dQs[i] = torch.tensor(
                     ctx.vector_of_qps.get(i).model.backward_data.dL_dH
                 )
-                dps[i] = torch.Tensor(
+                dps[i] = torch.tensor(
                     ctx.vector_of_qps.get(i).model.backward_data.dL_dg
                 )
-                dGs[i] = torch.Tensor(
+                dGs[i] = torch.tensor(
                     ctx.vector_of_qps.get(i).model.backward_data.dL_dC
                 )
-                dus[i] = torch.Tensor(
+                dus[i] = torch.tensor(
                     ctx.vector_of_qps.get(i).model.backward_data.dL_du
                 )
-                dls[i] = torch.Tensor(
+                dls[i] = torch.tensor(
                     ctx.vector_of_qps.get(i).model.backward_data.dL_dl
                 )
-                dAs[i] = torch.Tensor(
+                dAs[i] = torch.tensor(
                     ctx.vector_of_qps.get(i).model.backward_data.dL_dA
                 )
-                dbs[i] = torch.Tensor(
+                dbs[i] = torch.tensor(
                     ctx.vector_of_qps.get(i).model.backward_data.dL_db
                 )
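
Reviewer note on the pattern above: moving the incoming gradients to the CPU with .cpu() before filling the NumPy rhs, and then allocating dQs, dps, ... with torch.empty(..., device=device), is the usual way to wrap a CPU-only (NumPy-backed) solver in a torch.autograd.Function. Below is a minimal self-contained sketch of the same transfer logic; CpuSolverFn and its "solver" (which merely doubles its input) are made up for illustration and are not part of proxsuite:

    import numpy as np
    import torch

    class CpuSolverFn(torch.autograd.Function):
        """Toy Function mirroring the device handling of QPFunction's backward."""

        @staticmethod
        def forward(ctx, x):
            # NumPy only addresses CPU memory, so detach and transfer first.
            y = np.asarray(x.detach().cpu()) * 2.0
            # torch.tensor (unlike torch.Tensor) keeps the NumPy dtype;
            # .to(...) returns the result on the caller's device.
            return torch.tensor(y).to(dtype=x.dtype, device=x.device)

        @staticmethod
        def backward(ctx, dl_dy):
            device = dl_dy.device  # where the caller expects gradients
            # Incoming gradients may live on the GPU; the CPU-side solver
            # needs an explicit .cpu() transfer, as in the patch.
            rhs = np.asarray(dl_dy.detach().cpu())
            grad = rhs * 2.0  # "backward solve" of the toy solver
            # Allocate the output directly on the original device, as the
            # patch now does with torch.empty(..., device=device).
            dx = torch.empty(dl_dy.shape, dtype=dl_dy.dtype, device=device)
            dx.copy_(torch.tensor(grad))
            return dx

    x = torch.randn(5, requires_grad=True)
    CpuSolverFn.apply(x).sum().backward()
    print(x.grad)  # gradients come back on x's device

The torch.Tensor replacements follow the same logic: torch.Tensor(n, m) returns an uninitialized CPU tensor of the default dtype and accepts no device= or dtype= keywords, whereas torch.empty does; and torch.tensor(array) preserves the dtype of the NumPy arrays returned by the solver instead of casting them to the default float type.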