Skip to content

Commit

Permalink
[DataGrid] Introduce selectors with arguments (#14236)
Browse files Browse the repository at this point in the history
  • Loading branch information
MBilalShafi authored Aug 22, 2024
1 parent 64fbda1 commit c94b26d
Show file tree
Hide file tree
Showing 5 changed files with 262 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@ import {
GridDataSourceGroupNode,
useGridSelector,
} from '@mui/x-data-grid';
import { useGridSelectorV8 } from '@mui/x-data-grid/internals';
import CircularProgress from '@mui/material/CircularProgress';
import { useGridRootProps } from '../hooks/utils/useGridRootProps';
import { useGridPrivateApiContext } from '../hooks/utils/useGridPrivateApiContext';
import { DataGridProProcessedProps } from '../models/dataGridProProps';
import { GridPrivateApiPro } from '../models/gridApiPro';
import { GridStatePro } from '../models/gridStatePro';
import {
gridDataSourceErrorSelector,
gridDataSourceLoadingIdSelector,
} from '../hooks/features/dataSource/gridDataSourceSelector';

type OwnerState = DataGridProProcessedProps;

Expand Down Expand Up @@ -50,10 +55,8 @@ function GridTreeDataGroupingCellIcon(props: GridTreeDataGroupingCellIconProps)
const classes = useUtilityClasses(rootProps);
const { rowNode, id, field, descendantCount } = props;

const loadingSelector = (state: GridStatePro) => state.dataSource.loading[id] ?? false;
const errorSelector = (state: GridStatePro) => state.dataSource.errors[id];
const isDataLoading = useGridSelector(apiRef, loadingSelector);
const error = useGridSelector(apiRef, errorSelector);
const isDataLoading = useGridSelectorV8(apiRef, gridDataSourceLoadingIdSelector, id);
const error = useGridSelectorV8(apiRef, gridDataSourceErrorSelector, id);

const handleClick = (event: React.MouseEvent<HTMLButtonElement>) => {
if (!rowNode.childrenExpanded) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ import {
gridFilterModelSelector,
gridSortModelSelector,
gridPaginationModelSelector,
GridRowId,
} from '@mui/x-data-grid';
import { createSelector } from '@mui/x-data-grid/internals';
import { createSelector, createSelectorV8 } from '@mui/x-data-grid/internals';
import { GridStatePro } from '../../../models/gridStatePro';

const computeStartEnd = (paginationModel: GridPaginationModel) => {
Expand Down Expand Up @@ -37,7 +38,17 @@ export const gridDataSourceLoadingSelector = createSelector(
(dataSource) => dataSource.loading,
);

export const gridDataSourceLoadingIdSelector = createSelectorV8(
gridDataSourceStateSelector,
(dataSource, id: GridRowId) => dataSource.loading[id] ?? false,
);

export const gridDataSourceErrorsSelector = createSelector(
gridDataSourceStateSelector,
(dataSource) => dataSource.errors,
);

export const gridDataSourceErrorSelector = createSelectorV8(
gridDataSourceStateSelector,
(dataSource, id: GridRowId) => dataSource.errors[id],
);
75 changes: 74 additions & 1 deletion packages/x-data-grid/src/hooks/utils/useGridSelector.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import * as React from 'react';
import { fastObjectShallowCompare } from '@mui/x-internals/fastObjectShallowCompare';
import type { GridApiCommon } from '../../models/api/gridApiCommon';
import { OutputSelector } from '../../utils/createSelector';
import { OutputSelector, OutputSelectorV8 } from '../../utils/createSelector';
import { useLazyRef } from './useLazyRef';
import { useOnMount } from './useOnMount';
import { warnOnce } from '../../internals/utils/warning';
import type { GridCoreApi } from '../../models/api/gridCoreApi';

function isOutputSelector<Api extends GridApiCommon, T>(
selector: any,
): selector is OutputSelector<Api['state'], T> {
return selector.acceptsApiRef;
}

type Selector<Api extends GridApiCommon, Args, T> =
| ((state: Api['state']) => T)
| OutputSelectorV8<Api['state'], Args, T>;

// TODO v8: Remove this function
function applySelector<Api extends GridApiCommon, T>(
apiRef: React.MutableRefObject<Api>,
selector: ((state: Api['state']) => T) | OutputSelector<Api['state'], T>,
Expand All @@ -22,11 +28,25 @@ function applySelector<Api extends GridApiCommon, T>(
return selector(apiRef.current.state);
}

// TODO v8: Rename this function to `applySelector`
function applySelectorV8<Api extends GridApiCommon, Args, T>(
apiRef: React.MutableRefObject<Api>,
selector: Selector<Api, Args, T>,
args: Args,
instanceId: GridCoreApi['instanceId'],
) {
if (isOutputSelector(selector)) {
return selector(apiRef, args);
}
return selector(apiRef.current.state, instanceId);
}

const defaultCompare = Object.is;
export const objectShallowCompare = fastObjectShallowCompare;

const createRefs = () => ({ state: null, equals: null, selector: null }) as any;

// TODO v8: Remove this function
export const useGridSelector = <Api extends GridApiCommon, T>(
apiRef: React.MutableRefObject<Api>,
selector: ((state: Api['state']) => T) | OutputSelector<Api['state'], T>,
Expand Down Expand Up @@ -72,3 +92,56 @@ export const useGridSelector = <Api extends GridApiCommon, T>(

return state;
};

// TODO v8: Rename this function to `useGridSelector`
export const useGridSelectorV8 = <Api extends GridApiCommon, Args, T>(
apiRef: React.MutableRefObject<Api>,
selector: Selector<Api, Args, T>,
args: Args = undefined as Args,
equals: (a: T, b: T) => boolean = defaultCompare,
) => {
if (process.env.NODE_ENV !== 'production') {
if (!apiRef.current.state) {
warnOnce([
'MUI X: `useGridSelector` has been called before the initialization of the state.',
'This hook can only be used inside the context of the grid.',
]);
}
}

const refs = useLazyRef<
{
state: T;
equals: typeof equals;
selector: typeof selector;
},
never
>(createRefs);
const didInit = refs.current.selector !== null;

const [state, setState] = React.useState<T>(
// We don't use an initialization function to avoid allocations
(didInit ? null : applySelectorV8(apiRef, selector, args, apiRef.current.instanceId)) as T,
);

refs.current.state = state;
refs.current.equals = equals;
refs.current.selector = selector;

useOnMount(() => {
return apiRef.current.store.subscribe(() => {
const newState = applySelectorV8(
apiRef,
refs.current.selector,
args,
apiRef.current.instanceId,
) as T;
if (!refs.current.equals(refs.current.state, newState)) {
refs.current.state = newState;
setState(newState);
}
});
});

return state;
};
8 changes: 7 additions & 1 deletion packages/x-data-grid/src/internals/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,13 @@ export type * from '../models/props/DataGridProps';
export type * from '../models/gridDataSource';
export { getColumnsToExport, defaultGetRowsToExport } from '../hooks/features/export/utils';
export * from '../utils/createControllablePromise';
export { createSelector, createSelectorMemoized } from '../utils/createSelector';
export {
createSelector,
createSelectorV8,
createSelectorMemoized,
createSelectorMemoizedV8,
} from '../utils/createSelector';
export { useGridSelectorV8 } from '../hooks/utils/useGridSelector';
export { gridRowGroupsToFetchSelector } from '../hooks/features/rows/gridRowsSelector';
export {
findParentElementFromClassName,
Expand Down
162 changes: 162 additions & 0 deletions packages/x-data-grid/src/utils/createSelector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,23 @@ import { warnOnce } from '../internals/utils/warning';

type CacheKey = { id: number };

// TODO v8: Remove this type
export interface OutputSelector<State, Result> {
(apiRef: React.MutableRefObject<{ state: State; instanceId: GridCoreApi['instanceId'] }>): Result;
(state: State, instanceId: GridCoreApi['instanceId']): Result;
acceptsApiRef: boolean;
}

// TODO v8: Rename this type to `OutputSelector`
export interface OutputSelectorV8<State, Args, Result> {
(
apiRef: React.MutableRefObject<{ state: State; instanceId: GridCoreApi['instanceId'] }>,
args: Args,
): Result;
(state: State, instanceId: GridCoreApi['instanceId']): Result;
acceptsApiRef: boolean;
}

type StateFromSelector<T> = T extends (first: infer F, ...args: any[]) => any
? F extends { state: infer F2 }
? F2
Expand All @@ -26,16 +37,38 @@ type StateFromSelectorList<Selectors extends readonly any[]> = Selectors extends
: StateFromSelectorList<R>
: {};

// TODO v8: Remove this type
type SelectorArgs<Selectors extends ReadonlyArray<Selector<any>>, Result> =
// Input selectors as a separate array
| [selectors: [...Selectors], combiner: (...args: SelectorResultArray<Selectors>) => Result]
// Input selectors as separate inline arguments
| [...Selectors, (...args: SelectorResultArray<Selectors>) => Result];

type SelectorResultArrayWithArgs<Selectors extends ReadonlyArray<Selector<any>>, Args> = [
...SelectorResultArray<Selectors>,
Args,
];

// TODO v8: Rename this type to `SelectorArgs`
type SelectorArgsV8<Selectors extends ReadonlyArray<Selector<any>>, Args, Result> =
// Input selectors as a separate array
| [
selectors: [...Selectors],
combiner: (...args: SelectorResultArrayWithArgs<Selectors, Args>) => Result,
]
// Input selectors as separate inline arguments
| [...Selectors, (...args: SelectorResultArrayWithArgs<Selectors, Args>) => Result];

// TODO v8: Remove this type
type CreateSelectorFunction = <Selectors extends ReadonlyArray<Selector<any>>, Result>(
...items: SelectorArgs<Selectors, Result>
) => OutputSelector<StateFromSelectorList<Selectors>, Result>;

// TODO v8: Rename this type to `CreateSelectorFunction`
type CreateSelectorFunctionV8 = <Selectors extends ReadonlyArray<Selector<any>>, Args, Result>(
...items: SelectorArgsV8<Selectors, Args, Result>
) => OutputSelectorV8<StateFromSelectorList<Selectors>, Args, Result>;

const cache = new WeakMap<CacheKey, Map<any[], any>>();

function checkIsAPIRef(value: any) {
Expand All @@ -44,6 +77,7 @@ function checkIsAPIRef(value: any) {

const DEFAULT_INSTANCE_ID = { id: 'default' };

// TODO v8: Remove this function
export const createSelector = ((
a: Function,
b: Function,
Expand Down Expand Up @@ -125,6 +159,89 @@ export const createSelector = ((
return selector;
}) as unknown as CreateSelectorFunction;

// TODO v8: Rename this function to `createSelector`
export const createSelectorV8 = ((
a: Function,
b: Function,
c?: Function,
d?: Function,
e?: Function,
f?: Function,
...other: any[]
) => {
if (other.length > 0) {
throw new Error('Unsupported number of selectors');
}

let selector: any;

if (a && b && c && d && e && f) {
selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => {
const isAPIRef = checkIsAPIRef(stateOrApiRef);
const instanceId =
instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID);
const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef;
const va = a(state, args, instanceId);
const vb = b(state, args, instanceId);
const vc = c(state, args, instanceId);
const vd = d(state, args, instanceId);
const ve = e(state, args, instanceId);
return f(va, vb, vc, vd, ve, args);
};
} else if (a && b && c && d && e) {
selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => {
const isAPIRef = checkIsAPIRef(stateOrApiRef);
const instanceId =
instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID);
const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef;
const va = a(state, args, instanceId);
const vb = b(state, args, instanceId);
const vc = c(state, args, instanceId);
const vd = d(state, args, instanceId);
return e(va, vb, vc, vd, args);
};
} else if (a && b && c && d) {
selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => {
const isAPIRef = checkIsAPIRef(stateOrApiRef);
const instanceId =
instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID);
const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef;
const va = a(state, args, instanceId);
const vb = b(state, args, instanceId);
const vc = c(state, args, instanceId);
return d(va, vb, vc, args);
};
} else if (a && b && c) {
selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => {
const isAPIRef = checkIsAPIRef(stateOrApiRef);
const instanceId =
instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID);
const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef;
const va = a(state, args, instanceId);
const vb = b(state, args, instanceId);
return c(va, vb, args);
};
} else if (a && b) {
selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => {
const isAPIRef = checkIsAPIRef(stateOrApiRef);
const instanceId =
instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID);
const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef;
const va = a(state, args, instanceId);
return b(va, args);
};
} else {
throw new Error('Missing arguments');
}

// We use this property to detect if the selector was created with createSelector
// or it's only a simple function the receives the state and returns part of it.
selector.acceptsApiRef = true;

return selector;
}) as unknown as CreateSelectorFunctionV8;

// TODO v8: Remove this function
export const createSelectorMemoized: CreateSelectorFunction = (...args: any) => {
const selector = (stateOrApiRef: any, instanceId?: any) => {
const isAPIRef = checkIsAPIRef(stateOrApiRef);
Expand Down Expand Up @@ -168,3 +285,48 @@ export const createSelectorMemoized: CreateSelectorFunction = (...args: any) =>

return selector;
};

// TODO v8: Rename this function to `createSelectorMemoized`
export const createSelectorMemoizedV8: CreateSelectorFunctionV8 = (...args: any) => {
const selector = (stateOrApiRef: any, selectorArgs: any, instanceId?: any) => {
const isAPIRef = checkIsAPIRef(stateOrApiRef);
const cacheKey = isAPIRef
? stateOrApiRef.current.instanceId
: (instanceId ?? DEFAULT_INSTANCE_ID);
const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef;

if (process.env.NODE_ENV !== 'production') {
if (cacheKey.id === 'default') {
warnOnce([
'MUI X: A selector was called without passing the instance ID, which may impact the performance of the grid.',
'To fix, call it with `apiRef`, for example `mySelector(apiRef)`, or pass the instance ID explicitly, for example `mySelector(state, apiRef.current.instanceId)`.',
]);
}
}

const cacheArgsInit = cache.get(cacheKey);
const cacheArgs = cacheArgsInit ?? new Map();
const cacheFn = cacheArgs?.get(args);

if (cacheArgs && cacheFn) {
// We pass the cache key because the called selector might have as
// dependency another selector created with this `createSelector`.
return cacheFn(state, selectorArgs, cacheKey);
}

const fn = reselectCreateSelector(...args);

if (!cacheArgsInit) {
cache.set(cacheKey, cacheArgs);
}
cacheArgs.set(args, fn);

return fn(state, selectorArgs, cacheKey);
};

// We use this property to detect if the selector was created with createSelector
// or it's only a simple function the receives the state and returns part of it.
selector.acceptsApiRef = true;

return selector;
};

0 comments on commit c94b26d

Please sign in to comment.