Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DataGrid] Introduce selectors with arguments #14236

Merged
merged 6 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@ import {
GridDataSourceGroupNode,
useGridSelector,
} from '@mui/x-data-grid';
import { useGridSelectorV8 } from '@mui/x-data-grid/internals';
import CircularProgress from '@mui/material/CircularProgress';
import { useGridRootProps } from '../hooks/utils/useGridRootProps';
import { useGridPrivateApiContext } from '../hooks/utils/useGridPrivateApiContext';
import { DataGridProProcessedProps } from '../models/dataGridProProps';
import { GridPrivateApiPro } from '../models/gridApiPro';
import { GridStatePro } from '../models/gridStatePro';
import {
gridDataSourceErrorSelector,
gridDataSourceLoadingIdSelector,
} from '../hooks/features/dataSource/gridDataSourceSelector';

type OwnerState = DataGridProProcessedProps;

Expand Down Expand Up @@ -50,10 +55,8 @@ function GridTreeDataGroupingCellIcon(props: GridTreeDataGroupingCellIconProps)
const classes = useUtilityClasses(rootProps);
const { rowNode, id, field, descendantCount } = props;

const loadingSelector = (state: GridStatePro) => state.dataSource.loading[id] ?? false;
const errorSelector = (state: GridStatePro) => state.dataSource.errors[id];
const isDataLoading = useGridSelector(apiRef, loadingSelector);
const error = useGridSelector(apiRef, errorSelector);
const isDataLoading = useGridSelectorV8(apiRef, gridDataSourceLoadingIdSelector, id);
const error = useGridSelectorV8(apiRef, gridDataSourceErrorSelector, id);

const handleClick = (event: React.MouseEvent<HTMLButtonElement>) => {
if (!rowNode.childrenExpanded) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ import {
gridFilterModelSelector,
gridSortModelSelector,
gridPaginationModelSelector,
GridRowId,
} from '@mui/x-data-grid';
import { createSelector } from '@mui/x-data-grid/internals';
import { createSelector, createArgumentsSelector } from '@mui/x-data-grid/internals';
import { GridStatePro } from '../../../models/gridStatePro';

const computeStartEnd = (paginationModel: GridPaginationModel) => {
Expand Down Expand Up @@ -37,7 +38,17 @@ export const gridDataSourceLoadingSelector = createSelector(
(dataSource) => dataSource.loading,
);

export const gridDataSourceLoadingIdSelector = createArgumentsSelector<GridRowId>()(
gridDataSourceStateSelector,
(dataSource, id) => dataSource.loading[id] ?? false,
);

export const gridDataSourceErrorsSelector = createSelector(
gridDataSourceStateSelector,
(dataSource) => dataSource.errors,
);

export const gridDataSourceErrorSelector = createArgumentsSelector<GridRowId>()(
gridDataSourceStateSelector,
(dataSource, id) => dataSource.errors[id],
);
75 changes: 74 additions & 1 deletion packages/x-data-grid/src/hooks/utils/useGridSelector.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
import * as React from 'react';
import { fastObjectShallowCompare } from '@mui/x-internals/fastObjectShallowCompare';
import type { GridApiCommon } from '../../models/api/gridApiCommon';
import { OutputSelector } from '../../utils/createSelector';
import { OutputSelector, OutputArgumentsSelector } from '../../utils/createSelector';
import { useLazyRef } from './useLazyRef';
import { useOnMount } from './useOnMount';
import { warnOnce } from '../../internals/utils/warning';
import type { GridCoreApi } from '../../models/api/gridCoreApi';

function isOutputSelector<Api extends GridApiCommon, T>(
selector: any,
): selector is OutputSelector<Api['state'], T> {
return selector.acceptsApiRef;
}

type Selector<Api extends GridApiCommon, Args, T> =
| ((state: Api['state']) => T)
| OutputArgumentsSelector<Api['state'], Args, T>;

// TODO v8: Remove this function
function applySelector<Api extends GridApiCommon, T>(
apiRef: React.MutableRefObject<Api>,
selector: ((state: Api['state']) => T) | OutputSelector<Api['state'], T>,
Expand All @@ -22,11 +28,25 @@ function applySelector<Api extends GridApiCommon, T>(
return selector(apiRef.current.state);
}

// TODO v8: Rename this function to `applySelector`
function applySelectorV8<Api extends GridApiCommon, Args, T>(
apiRef: React.MutableRefObject<Api>,
selector: Selector<Api, Args, T>,
args: Args,
instanceId: GridCoreApi['instanceId'],
) {
if (isOutputSelector(selector)) {
return selector(apiRef, args);
}
return selector(apiRef.current.state, instanceId);
}

const defaultCompare = Object.is;
export const objectShallowCompare = fastObjectShallowCompare;

const createRefs = () => ({ state: null, equals: null, selector: null }) as any;

// TODO v8: Remove this function
export const useGridSelector = <Api extends GridApiCommon, T>(
apiRef: React.MutableRefObject<Api>,
selector: ((state: Api['state']) => T) | OutputSelector<Api['state'], T>,
Expand Down Expand Up @@ -72,3 +92,56 @@ export const useGridSelector = <Api extends GridApiCommon, T>(

return state;
};

// TODO v8: Rename this function to `useGridSelector`
export const useGridSelectorV8 = <Api extends GridApiCommon, Args, T>(
apiRef: React.MutableRefObject<Api>,
selector: Selector<Api, Args, T>,
args: Args = {} as Args,
MBilalShafi marked this conversation as resolved.
Show resolved Hide resolved
equals: (a: T, b: T) => boolean = defaultCompare,
) => {
if (process.env.NODE_ENV !== 'production') {
if (!apiRef.current.state) {
warnOnce([
'MUI X: `useGridSelector` has been called before the initialization of the state.',
'This hook can only be used inside the context of the grid.',
]);
}
}

const refs = useLazyRef<
{
state: T;
equals: typeof equals;
selector: typeof selector;
},
never
>(createRefs);
const didInit = refs.current.selector !== null;

const [state, setState] = React.useState<T>(
// We don't use an initialization function to avoid allocations
(didInit ? null : applySelectorV8(apiRef, selector, args, apiRef.current.instanceId)) as T,
);

refs.current.state = state;
refs.current.equals = equals;
refs.current.selector = selector;

useOnMount(() => {
return apiRef.current.store.subscribe(() => {
const newState = applySelectorV8(
apiRef,
refs.current.selector,
args,
apiRef.current.instanceId,
) as T;
if (!refs.current.equals(refs.current.state, newState)) {
refs.current.state = newState;
setState(newState);
}
});
});

return state;
};
8 changes: 7 additions & 1 deletion packages/x-data-grid/src/internals/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,13 @@ export type * from '../models/props/DataGridProps';
export type * from '../models/gridDataSource';
export { getColumnsToExport, defaultGetRowsToExport } from '../hooks/features/export/utils';
export * from '../utils/createControllablePromise';
export { createSelector, createSelectorMemoized } from '../utils/createSelector';
export {
createSelector,
createArgumentsSelector,
createSelectorMemoized,
createArgumentsSelectorMemoized,
} from '../utils/createSelector';
export { useGridSelectorV8 } from '../hooks/utils/useGridSelector';
export { gridRowGroupsToFetchSelector } from '../hooks/features/rows/gridRowsSelector';
export {
findParentElementFromClassName,
Expand Down
158 changes: 158 additions & 0 deletions packages/x-data-grid/src/utils/createSelector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@ export interface OutputSelector<State, Result> {
acceptsApiRef: boolean;
}

export interface OutputArgumentsSelector<State, Args, Result> {
(
apiRef: React.MutableRefObject<{ state: State; instanceId: GridCoreApi['instanceId'] }>,
args: Args,
): Result;
(state: State, instanceId: GridCoreApi['instanceId']): Result;
acceptsApiRef: boolean;
}

type StateFromSelector<T> = T extends (first: infer F, ...args: any[]) => any
? F extends { state: infer F2 }
? F2
Expand All @@ -32,10 +41,31 @@ type SelectorArgs<Selectors extends ReadonlyArray<Selector<any>>, Result> =
// Input selectors as separate inline arguments
| [...Selectors, (...args: SelectorResultArray<Selectors>) => Result];

type SelectorResultArrayWithArgs<Selectors extends ReadonlyArray<Selector<any>>, Args> = [
...SelectorResultArray<Selectors>,
Args,
];

type ArgumentsSelectorArgs<Selectors extends ReadonlyArray<Selector<any>>, Args, Result> =
// Input selectors as a separate array
| [
selectors: [...Selectors],
combiner: (...args: SelectorResultArrayWithArgs<Selectors, Args>) => Result,
]
// Input selectors as separate inline arguments
| [...Selectors, (...args: SelectorResultArrayWithArgs<Selectors, Args>) => Result];

type CreateSelectorFunction = <Selectors extends ReadonlyArray<Selector<any>>, Result>(
...items: SelectorArgs<Selectors, Result>
) => OutputSelector<StateFromSelectorList<Selectors>, Result>;

type CreateArgumentsSelectorFunction = <Args = void>() => <
MBilalShafi marked this conversation as resolved.
Show resolved Hide resolved
Selectors extends ReadonlyArray<Selector<any>>,
Result,
>(
...items: ArgumentsSelectorArgs<Selectors, Args, Result>
) => OutputArgumentsSelector<StateFromSelectorList<Selectors>, Args, Result>;

const cache = new WeakMap<CacheKey, Map<any[], any>>();

function checkIsAPIRef(value: any) {
Expand Down Expand Up @@ -125,6 +155,88 @@ export const createSelector = ((
return selector;
}) as unknown as CreateSelectorFunction;

export const createArgumentsSelector = (() =>
(
a: Function,
b: Function,
c?: Function,
d?: Function,
e?: Function,
f?: Function,
...other: any[]
) => {
if (other.length > 0) {
throw new Error('Unsupported number of selectors');
}

let selector: any;

if (a && b && c && d && e && f) {
selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => {
const isAPIRef = checkIsAPIRef(stateOrApiRef);
const instanceId =
instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID);
const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef;
const va = a(state, args, instanceId);
const vb = b(state, args, instanceId);
const vc = c(state, args, instanceId);
const vd = d(state, args, instanceId);
const ve = e(state, args, instanceId);
return f(va, vb, vc, vd, ve, args);
};
} else if (a && b && c && d && e) {
selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => {
const isAPIRef = checkIsAPIRef(stateOrApiRef);
const instanceId =
instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID);
const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef;
const va = a(state, args, instanceId);
const vb = b(state, args, instanceId);
const vc = c(state, args, instanceId);
const vd = d(state, args, instanceId);
return e(va, vb, vc, vd, args);
};
} else if (a && b && c && d) {
selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => {
const isAPIRef = checkIsAPIRef(stateOrApiRef);
const instanceId =
instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID);
const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef;
const va = a(state, args, instanceId);
const vb = b(state, args, instanceId);
const vc = c(state, args, instanceId);
return d(va, vb, vc, args);
};
} else if (a && b && c) {
selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => {
const isAPIRef = checkIsAPIRef(stateOrApiRef);
const instanceId =
instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID);
const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef;
const va = a(state, args, instanceId);
const vb = b(state, args, instanceId);
return c(va, vb, args);
};
} else if (a && b) {
selector = (stateOrApiRef: any, args: any, instanceIdParam: any) => {
const isAPIRef = checkIsAPIRef(stateOrApiRef);
const instanceId =
instanceIdParam ?? (isAPIRef ? stateOrApiRef.current.instanceId : DEFAULT_INSTANCE_ID);
const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef;
const va = a(state, args, instanceId);
return b(va, args);
};
} else {
throw new Error('Missing arguments');
}

// We use this property to detect if the selector was created with createSelector
// or it's only a simple function the receives the state and returns part of it.
selector.acceptsApiRef = true;

return selector;
}) as unknown as CreateArgumentsSelectorFunction;

export const createSelectorMemoized: CreateSelectorFunction = (...args: any) => {
const selector = (stateOrApiRef: any, instanceId?: any) => {
const isAPIRef = checkIsAPIRef(stateOrApiRef);
Expand Down Expand Up @@ -168,3 +280,49 @@ export const createSelectorMemoized: CreateSelectorFunction = (...args: any) =>

return selector;
};

export const createArgumentsSelectorMemoized: CreateArgumentsSelectorFunction =
() =>
(...args: any) => {
const selector = (stateOrApiRef: any, selectorArgs: any, instanceId?: any) => {
const isAPIRef = checkIsAPIRef(stateOrApiRef);
const cacheKey = isAPIRef
? stateOrApiRef.current.instanceId
: (instanceId ?? DEFAULT_INSTANCE_ID);
const state = isAPIRef ? stateOrApiRef.current.state : stateOrApiRef;

if (process.env.NODE_ENV !== 'production') {
if (cacheKey.id === 'default') {
warnOnce([
'MUI X: A selector was called without passing the instance ID, which may impact the performance of the grid.',
'To fix, call it with `apiRef`, for example `mySelector(apiRef)`, or pass the instance ID explicitly, for example `mySelector(state, apiRef.current.instanceId)`.',
]);
}
}

const cacheArgsInit = cache.get(cacheKey);
const cacheArgs = cacheArgsInit ?? new Map();
const cacheFn = cacheArgs?.get(args);

if (cacheArgs && cacheFn) {
// We pass the cache key because the called selector might have as
// dependency another selector created with this `createSelector`.
return cacheFn(state, selectorArgs, cacheKey);
}

const fn = reselectCreateSelector(...args);

if (!cacheArgsInit) {
cache.set(cacheKey, cacheArgs);
}
cacheArgs.set(args, fn);

return fn(state, selectorArgs, cacheKey);
};

// We use this property to detect if the selector was created with createSelector
// or it's only a simple function the receives the state and returns part of it.
selector.acceptsApiRef = true;

return selector;
};