From e2299887cc1748a02f4fb3d17628c00f08cbb5f0 Mon Sep 17 00:00:00 2001 From: Arnaud Mimart <33665250+amimart@users.noreply.github.com> Date: Wed, 1 Mar 2023 19:40:03 +0100 Subject: [PATCH] feat(storage): try to generify pagination logic --- contracts/cw-storage/src/contract.rs | 69 ++++-------- contracts/cw-storage/src/lib.rs | 1 + .../cw-storage/src/pagination_handler.rs | 103 ++++++++++++++++++ 3 files changed, 126 insertions(+), 47 deletions(-) create mode 100644 contracts/cw-storage/src/pagination_handler.rs diff --git a/contracts/cw-storage/src/contract.rs b/contracts/cw-storage/src/contract.rs index 53bb0346..ebd25b4d 100644 --- a/contracts/cw-storage/src/contract.rs +++ b/contracts/cw-storage/src/contract.rs @@ -221,6 +221,7 @@ pub mod query { use super::*; use crate::cursor; use crate::msg::{BucketResponse, Cursor, ObjectResponse, ObjectsResponse, PageInfo}; + use crate::pagination_handler::PaginationHandler; use cosmwasm_std::{Addr, Storage, Uint128}; use cw_storage_plus::Bound; use std::cmp::min; @@ -238,7 +239,7 @@ pub mod query { pub fn object(deps: Deps, id: ObjectId) -> StdResult { objects() .load(deps.storage, id) - .map(|object| map_object(object)) + .map(|object| map_object(&object)) } pub fn data(deps: Deps, id: ObjectId) -> StdResult { @@ -251,59 +252,33 @@ pub mod query { after: Option, first: Option, ) -> StdResult { - let page_size = match first { - Some(req) => { - if req > BUCKET.load(deps.storage)?.pagination.max_page_size { - return Err(StdError::generic_err( - "Requested page size exceed maximum allowed", - )); - } - Ok(req) - } - _ => BUCKET - .load(deps.storage) - .map(|b| b.pagination.default_page_size), - }? as usize; - - let min_bound = match after { - Some(cursor) => { - let id: String = cursor::decode(cursor)?; - Some(Bound::exclusive(id)) - } + let address = match address { + Some(raw) => Some(deps.api.addr_validate(&raw)?), _ => None, }; - let iter = match address { - Some(raw_addr) => { - let addr = deps.api.addr_validate(raw_addr.as_str())?; - objects().idx.owner.prefix(addr).range( + let handler: PaginationHandler = + PaginationHandler::from(BUCKET.load(deps.storage)?.pagination); + + let page: (Vec, PageInfo) = handler.query_page( + |min_bound| match address { + Some(addr) => objects().idx.owner.prefix(addr).range( deps.storage, min_bound, None, Order::Ascending, - ) - } - _ => objects().range(deps.storage, min_bound, None, Order::Ascending), - }; - - let raw_objects = iter.take(page_size + 1).collect::>>()?; - let mapped_objects: Vec = raw_objects - .iter() - .take(page_size) - .map(|(_, object)| map_object(object.to_owned())) - .collect(); - - let cursor = mapped_objects - .last() - .map(|object| cursor::encode(object.id.clone())) - .unwrap_or("".to_string()); + ), + _ => objects().range(deps.storage, min_bound, None, Order::Ascending), + }, + |c| cursor::decode(c), + |o: &Object| cursor::encode(o.id.clone()), + after, + first, + )?; Ok(ObjectsResponse { - data: mapped_objects, - page_info: PageInfo { - has_next_page: raw_objects.len() > page_size, - cursor, - }, + data: page.0.iter().map(|object| map_object(object)).collect(), + page_info: page.1, }) } @@ -316,11 +291,11 @@ pub mod query { Err(StdError::generic_err("Not implemented")) } - fn map_object(object: Object) -> ObjectResponse { + fn map_object(object: &Object) -> ObjectResponse { ObjectResponse { id: object.id.clone(), size: object.size, - owner: object.owner.into(), + owner: object.owner.clone().into(), is_pinned: object.pin_count > Uint128::zero(), } } diff --git a/contracts/cw-storage/src/lib.rs b/contracts/cw-storage/src/lib.rs index 6d87cde9..d710bd9b 100644 --- a/contracts/cw-storage/src/lib.rs +++ b/contracts/cw-storage/src/lib.rs @@ -3,6 +3,7 @@ pub mod crypto; mod cursor; mod error; pub mod msg; +mod pagination_handler; pub mod state; pub use crate::error::ContractError; diff --git a/contracts/cw-storage/src/pagination_handler.rs b/contracts/cw-storage/src/pagination_handler.rs new file mode 100644 index 00000000..c87d1292 --- /dev/null +++ b/contracts/cw-storage/src/pagination_handler.rs @@ -0,0 +1,103 @@ +use crate::msg::{Cursor, PageInfo}; +use crate::state::Pagination; +use cosmwasm_std::{StdError, StdResult, Storage}; +use cw_storage_plus::{Bound, PrimaryKey}; +use serde::de::DeserializeOwned; +use serde::Serialize; +use std::marker::PhantomData; + +pub struct PaginationHandler<'a, T, PK> +where + T: Serialize + DeserializeOwned, + PK: PrimaryKey<'a>, +{ + max_page_size: u32, + default_page_size: u32, + + _data_type: PhantomData, + _pk_type: PhantomData, + _lifetime: PhantomData<&'a ()>, +} + +impl<'a, T, PK> From for PaginationHandler<'a, T, PK> +where + T: Serialize + DeserializeOwned, + PK: PrimaryKey<'a>, +{ + fn from(value: Pagination) -> Self { + PaginationHandler::new(value.max_page_size, value.default_page_size) + } +} + +impl<'a, T, PK> PaginationHandler<'a, T, PK> +where + T: Serialize + DeserializeOwned, + PK: PrimaryKey<'a>, +{ + pub const fn new(max_page_size: u32, default_page_size: u32) -> Self { + PaginationHandler { + max_page_size, + default_page_size, + _data_type: PhantomData, + _pk_type: PhantomData, + _lifetime: PhantomData, + } + } + + pub fn query_page( + self, + iter_fn: I, + cursor_dec_fn: CD, + cursor_enc_fn: CE, + after: Option, + first: Option, + ) -> StdResult<(Vec, PageInfo)> + where + I: FnOnce(Option>) -> Box> + 'a>, + CD: FnOnce(Cursor) -> StdResult, + CE: FnOnce(&T) -> Cursor, + { + let min_bound = match after { + Some(cursor) => Some(Bound::exclusive(cursor_dec_fn(cursor)?)), + _ => None, + }; + let page_size = self.compute_page_size(first)?; + let mut raw_items: Vec = iter_fn(min_bound) + .take(page_size + 1) + .map(|res: StdResult<(PK, T)>| res.map(|(_, item)| item)) + .collect::>>()?; + + let has_next_page = raw_items.len() > page_size; + if has_next_page { + raw_items.pop(); + } + + let cursor = raw_items + .last() + .map(|item| cursor_enc_fn(item)) + .unwrap_or("".to_string()); + + Ok(( + raw_items, + PageInfo { + has_next_page, + cursor, + }, + )) + } + + fn compute_page_size(self, first: Option) -> StdResult { + match first { + Some(req) => { + if req > self.max_page_size { + return Err(StdError::generic_err( + "Requested page size exceed maximum allowed", + )); + } + Ok(req) + } + _ => Ok(self.default_page_size), + } + .map(|size| size as usize) + } +}