Skip to content

Commit

Permalink
feat(storage): try to generify pagination logic
Browse files Browse the repository at this point in the history
  • Loading branch information
amimart committed Mar 13, 2023
1 parent 4326b4c commit e229988
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 47 deletions.
69 changes: 22 additions & 47 deletions contracts/cw-storage/src/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ pub mod query {
use super::*;
use crate::cursor;
use crate::msg::{BucketResponse, Cursor, ObjectResponse, ObjectsResponse, PageInfo};
use crate::pagination_handler::PaginationHandler;
use cosmwasm_std::{Addr, Storage, Uint128};
use cw_storage_plus::Bound;
use std::cmp::min;
Expand All @@ -238,7 +239,7 @@ pub mod query {
pub fn object(deps: Deps, id: ObjectId) -> StdResult<ObjectResponse> {
objects()
.load(deps.storage, id)
.map(|object| map_object(object))
.map(|object| map_object(&object))
}

pub fn data(deps: Deps, id: ObjectId) -> StdResult<Binary> {
Expand All @@ -251,59 +252,33 @@ pub mod query {
after: Option<Cursor>,
first: Option<u32>,
) -> StdResult<ObjectsResponse> {
let page_size = match first {
Some(req) => {
if req > BUCKET.load(deps.storage)?.pagination.max_page_size {
return Err(StdError::generic_err(
"Requested page size exceed maximum allowed",
));
}
Ok(req)
}
_ => BUCKET
.load(deps.storage)
.map(|b| b.pagination.default_page_size),
}? as usize;

let min_bound = match after {
Some(cursor) => {
let id: String = cursor::decode(cursor)?;
Some(Bound::exclusive(id))
}
let address = match address {
Some(raw) => Some(deps.api.addr_validate(&raw)?),
_ => None,
};

let iter = match address {
Some(raw_addr) => {
let addr = deps.api.addr_validate(raw_addr.as_str())?;
objects().idx.owner.prefix(addr).range(
let handler: PaginationHandler<Object, String> =
PaginationHandler::from(BUCKET.load(deps.storage)?.pagination);

let page: (Vec<Object>, PageInfo) = handler.query_page(
|min_bound| match address {
Some(addr) => objects().idx.owner.prefix(addr).range(
deps.storage,
min_bound,
None,
Order::Ascending,
)
}
_ => objects().range(deps.storage, min_bound, None, Order::Ascending),
};

let raw_objects = iter.take(page_size + 1).collect::<StdResult<Vec<_>>>()?;
let mapped_objects: Vec<ObjectResponse> = raw_objects
.iter()
.take(page_size)
.map(|(_, object)| map_object(object.to_owned()))
.collect();

let cursor = mapped_objects
.last()
.map(|object| cursor::encode(object.id.clone()))
.unwrap_or("".to_string());
),
_ => objects().range(deps.storage, min_bound, None, Order::Ascending),
},
|c| cursor::decode(c),
|o: &Object| cursor::encode(o.id.clone()),
after,
first,
)?;

Ok(ObjectsResponse {
data: mapped_objects,
page_info: PageInfo {
has_next_page: raw_objects.len() > page_size,
cursor,
},
data: page.0.iter().map(|object| map_object(object)).collect(),
page_info: page.1,
})
}

Expand All @@ -316,11 +291,11 @@ pub mod query {
Err(StdError::generic_err("Not implemented"))
}

fn map_object(object: Object) -> ObjectResponse {
fn map_object(object: &Object) -> ObjectResponse {
ObjectResponse {
id: object.id.clone(),
size: object.size,
owner: object.owner.into(),
owner: object.owner.clone().into(),
is_pinned: object.pin_count > Uint128::zero(),
}
}
Expand Down
1 change: 1 addition & 0 deletions contracts/cw-storage/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ pub mod crypto;
mod cursor;
mod error;
pub mod msg;
mod pagination_handler;
pub mod state;

pub use crate::error::ContractError;
103 changes: 103 additions & 0 deletions contracts/cw-storage/src/pagination_handler.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use crate::msg::{Cursor, PageInfo};
use crate::state::Pagination;
use cosmwasm_std::{StdError, StdResult, Storage};
use cw_storage_plus::{Bound, PrimaryKey};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::marker::PhantomData;

pub struct PaginationHandler<'a, T, PK>
where
T: Serialize + DeserializeOwned,
PK: PrimaryKey<'a>,
{
max_page_size: u32,
default_page_size: u32,

_data_type: PhantomData<T>,
_pk_type: PhantomData<PK>,
_lifetime: PhantomData<&'a ()>,
}

impl<'a, T, PK> From<Pagination> for PaginationHandler<'a, T, PK>
where
T: Serialize + DeserializeOwned,
PK: PrimaryKey<'a>,
{
fn from(value: Pagination) -> Self {
PaginationHandler::new(value.max_page_size, value.default_page_size)
}
}

impl<'a, T, PK> PaginationHandler<'a, T, PK>
where
T: Serialize + DeserializeOwned,
PK: PrimaryKey<'a>,
{
pub const fn new(max_page_size: u32, default_page_size: u32) -> Self {
PaginationHandler {
max_page_size,
default_page_size,
_data_type: PhantomData,
_pk_type: PhantomData,
_lifetime: PhantomData,
}
}

pub fn query_page<I, CE, CD>(
self,
iter_fn: I,
cursor_dec_fn: CD,
cursor_enc_fn: CE,
after: Option<Cursor>,
first: Option<u32>,
) -> StdResult<(Vec<T>, PageInfo)>
where
I: FnOnce(Option<Bound<PK>>) -> Box<dyn Iterator<Item = StdResult<(PK, T)>> + 'a>,
CD: FnOnce(Cursor) -> StdResult<PK>,
CE: FnOnce(&T) -> Cursor,
{
let min_bound = match after {
Some(cursor) => Some(Bound::exclusive(cursor_dec_fn(cursor)?)),
_ => None,
};
let page_size = self.compute_page_size(first)?;
let mut raw_items: Vec<T> = iter_fn(min_bound)
.take(page_size + 1)
.map(|res: StdResult<(PK, T)>| res.map(|(_, item)| item))
.collect::<StdResult<Vec<T>>>()?;

let has_next_page = raw_items.len() > page_size;
if has_next_page {
raw_items.pop();
}

let cursor = raw_items
.last()
.map(|item| cursor_enc_fn(item))
.unwrap_or("".to_string());

Ok((
raw_items,
PageInfo {
has_next_page,
cursor,
},
))
}

fn compute_page_size(self, first: Option<u32>) -> StdResult<usize> {
match first {
Some(req) => {
if req > self.max_page_size {
return Err(StdError::generic_err(
"Requested page size exceed maximum allowed",
));
}
Ok(req)
}
_ => Ok(self.default_page_size),
}
.map(|size| size as usize)
}
}

0 comments on commit e229988

Please sign in to comment.