diff --git a/packages/x-tree-view/src/internals/plugins/useTreeViewFocus/useTreeViewFocus.ts b/packages/x-tree-view/src/internals/plugins/useTreeViewFocus/useTreeViewFocus.ts index 0491e5e3946b..567b676dc117 100644 --- a/packages/x-tree-view/src/internals/plugins/useTreeViewFocus/useTreeViewFocus.ts +++ b/packages/x-tree-view/src/internals/plugins/useTreeViewFocus/useTreeViewFocus.ts @@ -6,6 +6,7 @@ import { TreeViewPlugin, TreeViewUsedInstance } from '../../models'; import { UseTreeViewFocusSignature } from './useTreeViewFocus.types'; import { useInstanceEventHandler } from '../../hooks/useInstanceEventHandler'; import { getActiveElement } from '../../utils/utils'; +import { getFirstNavigableItem } from '../../useTreeView/useTreeView.utils'; const useTabbableItemId = ( instance: TreeViewUsedInstance, @@ -24,7 +25,7 @@ const useTabbableItemId = ( } if (tabbableItemId == null) { - tabbableItemId = instance.getNavigableChildrenIds(null)[0]; + tabbableItemId = getFirstNavigableItem(instance); } return tabbableItemId; @@ -95,7 +96,7 @@ export const useTreeViewFocus: TreeViewPlugin = ({ } if (itemToFocusId == null) { - itemToFocusId = instance.getNavigableChildrenIds(null)[0]; + itemToFocusId = getFirstNavigableItem(instance); } innerFocusItem(event, itemToFocusId); diff --git a/packages/x-tree-view/src/internals/plugins/useTreeViewItems/useTreeViewItems.ts b/packages/x-tree-view/src/internals/plugins/useTreeViewItems/useTreeViewItems.ts index 268a86c87877..5e36cf69e72b 100644 --- a/packages/x-tree-view/src/internals/plugins/useTreeViewItems/useTreeViewItems.ts +++ b/packages/x-tree-view/src/internals/plugins/useTreeViewItems/useTreeViewItems.ts @@ -144,13 +144,11 @@ export const useTreeViewItems: TreeViewPlugin = ({ [state.items.itemMetaMap], ); - const getNavigableChildrenIds = (itemId: string | null) => { - let childrenIds = instance.getChildrenIds(itemId); - - if (!params.disabledItemsFocusable) { - childrenIds = childrenIds.filter((item) => !instance.isItemDisabled(item)); + const isItemNavigable = (itemId: string) => { + if (params.disabledItemsFocusable) { + return true; } - return childrenIds; + return !instance.isItemDisabled(itemId); }; const areItemUpdatesPreventedRef = React.useRef(false); @@ -216,8 +214,8 @@ export const useTreeViewItems: TreeViewPlugin = ({ getItem, getItemsToRender, getChildrenIds, - getNavigableChildrenIds, isItemDisabled, + isItemNavigable, preventItemUpdates, areItemUpdatesPrevented, }, diff --git a/packages/x-tree-view/src/internals/plugins/useTreeViewItems/useTreeViewItems.types.ts b/packages/x-tree-view/src/internals/plugins/useTreeViewItems/useTreeViewItems.types.ts index c0cefb6e8639..e381ca448c73 100644 --- a/packages/x-tree-view/src/internals/plugins/useTreeViewItems/useTreeViewItems.types.ts +++ b/packages/x-tree-view/src/internals/plugins/useTreeViewItems/useTreeViewItems.types.ts @@ -21,8 +21,8 @@ export interface UseTreeViewItemsInstance extends UseTreeViewItems getItemMeta: (itemId: string) => TreeViewItemMeta; getItemsToRender: () => TreeViewItemProps[]; getChildrenIds: (itemId: string | null) => string[]; - getNavigableChildrenIds: (itemId: string | null) => string[]; - isItemDisabled: (itemId: string | null) => itemId is string; + isItemDisabled: (itemId: string) => itemId is string; + isItemNavigable: (itemId: string) => boolean; /** * Freeze any future update to the state based on the `items` prop. * This is useful when `useTreeViewJSXNodes` is used to avoid having conflicting sources of truth. diff --git a/packages/x-tree-view/src/internals/plugins/useTreeViewKeyboardNavigation/useTreeViewKeyboardNavigation.ts b/packages/x-tree-view/src/internals/plugins/useTreeViewKeyboardNavigation/useTreeViewKeyboardNavigation.ts index 0a97109fe0df..994bba6a0cae 100644 --- a/packages/x-tree-view/src/internals/plugins/useTreeViewKeyboardNavigation/useTreeViewKeyboardNavigation.ts +++ b/packages/x-tree-view/src/internals/plugins/useTreeViewKeyboardNavigation/useTreeViewKeyboardNavigation.ts @@ -3,10 +3,10 @@ import { useTheme } from '@mui/material/styles'; import useEventCallback from '@mui/utils/useEventCallback'; import { TreeViewItemMeta, TreeViewPlugin } from '../../models'; import { - getFirstItem, - getLastItem, - getNextItem, - getPreviousItem, + getFirstNavigableItem, + getLastNavigableItem, + getNextNavigableItem, + getPreviousNavigableItem, } from '../../useTreeView/useTreeView.utils'; import { TreeViewFirstCharMap, @@ -157,7 +157,7 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin< // Focus the next focusable item case key === 'ArrowDown': { - const nextItem = getNextItem(instance, itemId); + const nextItem = getNextNavigableItem(instance, itemId); if (nextItem) { event.preventDefault(); instance.focusItem(event, nextItem); @@ -181,7 +181,7 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin< // Focuses the previous focusable item case key === 'ArrowUp': { - const previousItem = getPreviousItem(instance, itemId); + const previousItem = getPreviousNavigableItem(instance, itemId); if (previousItem) { event.preventDefault(); instance.focusItem(event, previousItem); @@ -207,7 +207,7 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin< // If the focused item is collapsed and has children, we expand it case (key === 'ArrowRight' && !isRTL) || (key === 'ArrowLeft' && isRTL): { if (instance.isItemExpanded(itemId)) { - const nextItemId = getNextItem(instance, itemId); + const nextItemId = getNextNavigableItem(instance, itemId); if (nextItemId) { instance.focusItem(event, nextItemId); event.preventDefault(); @@ -239,7 +239,7 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin< // Focuses the first item in the tree case key === 'Home': { - instance.focusItem(event, getFirstItem(instance)); + instance.focusItem(event, getFirstNavigableItem(instance)); // Multi select behavior when pressing Ctrl + Shift + Home // Selects the focused item and all items up to the first item. @@ -253,7 +253,7 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin< // Focuses the last item in the tree case key === 'End': { - instance.focusItem(event, getLastItem(instance)); + instance.focusItem(event, getLastNavigableItem(instance)); // Multi select behavior when pressing Ctrl + Shirt + End // Selects the focused item and all the items down to the last item. @@ -276,8 +276,8 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin< // Selects all the items case key === 'a' && ctrlPressed && params.multiSelect && !params.disableSelection: { instance.selectRange(event, { - start: getFirstItem(instance), - end: getLastItem(instance), + start: getFirstNavigableItem(instance), + end: getLastNavigableItem(instance), }); event.preventDefault(); break; diff --git a/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.ts b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.ts index 58c30f122b10..2a98e5553570 100644 --- a/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.ts +++ b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.ts @@ -1,8 +1,11 @@ import * as React from 'react'; import { TreeViewPlugin, TreeViewItemRange } from '../../models'; -import { getNextItem, getFirstItem, getLastItem } from '../../useTreeView/useTreeView.utils'; +import { + getFirstNavigableItem, + getLastNavigableItem, + getNavigableItemsInRange, +} from '../../useTreeView/useTreeView.utils'; import { UseTreeViewSelectionSignature } from './useTreeViewSelection.types'; -import { findOrderInTremauxTree } from './useTreeViewSelection.utils'; export const useTreeViewSelection: TreeViewPlugin = ({ instance, @@ -80,20 +83,6 @@ export const useTreeViewSelection: TreeViewPlugin currentRangeSelection.current = []; }; - const getItemsInRange = (itemAId: string, itemBId: string) => { - const [first, last] = findOrderInTremauxTree(instance, itemAId, itemBId); - const items = [first]; - - let current = first; - - while (current !== last) { - current = getNextItem(instance, current)!; - items.push(current); - } - - return items; - }; - const handleRangeArrowSelect = (event: React.SyntheticEvent, items: TreeViewItemRange) => { let base = (models.selectedItems.value as string[]).slice(); const { start, next, current } = items; @@ -134,7 +123,7 @@ export const useTreeViewSelection: TreeViewPlugin base = base.filter((id) => currentRangeSelection.current.indexOf(id) === -1); } - let range = getItemsInRange(start, end); + let range = getNavigableItemsInRange(instance, start, end); range = range.filter((item) => !instance.isItemDisabled(item)); currentRangeSelection.current = range; let newSelected = base.concat(range); @@ -165,7 +154,7 @@ export const useTreeViewSelection: TreeViewPlugin instance.selectRange(event, { start, - end: getFirstItem(instance), + end: getFirstNavigableItem(instance), }); }; @@ -178,7 +167,7 @@ export const useTreeViewSelection: TreeViewPlugin instance.selectRange(event, { start, - end: getLastItem(instance), + end: getLastNavigableItem(instance), }); }; diff --git a/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.utils.ts b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.utils.ts deleted file mode 100644 index 03c3daa60d60..000000000000 --- a/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.utils.ts +++ /dev/null @@ -1,75 +0,0 @@ -import { TreeViewInstance } from '../../models'; -import { UseTreeViewItemsSignature } from '../useTreeViewItems'; - -/** - * This is used to determine the start and end of a selection range so - * we can get the items between the two border items. - * - * It finds the items' common ancestor using - * a naive implementation of a lowest common ancestor algorithm - * (https://en.wikipedia.org/wiki/Lowest_common_ancestor). - * Then compares the ancestor's 2 children that are ancestors of itemA and ItemB - * so we can compare their indexes to work out which item comes first in a depth first search. - * (https://en.wikipedia.org/wiki/Depth-first_search) - * - * Another way to put it is which item is shallower in a trémaux tree - * https://en.wikipedia.org/wiki/Tr%C3%A9maux_tree - */ -export const findOrderInTremauxTree = ( - instance: TreeViewInstance<[UseTreeViewItemsSignature]>, - itemAId: string, - itemBId: string, -) => { - if (itemAId === itemBId) { - return [itemAId, itemBId]; - } - - const itemA = instance.getItemMeta(itemAId); - const itemB = instance.getItemMeta(itemBId); - - if (itemA.parentId === itemB.id || itemB.parentId === itemA.id) { - return itemB.parentId === itemA.id ? [itemA.id, itemB.id] : [itemB.id, itemA.id]; - } - - const aFamily: (string | null)[] = [itemA.id]; - const bFamily: (string | null)[] = [itemB.id]; - - let aAncestor = itemA.parentId; - let bAncestor = itemB.parentId; - - let aAncestorIsCommon = bFamily.indexOf(aAncestor) !== -1; - let bAncestorIsCommon = aFamily.indexOf(bAncestor) !== -1; - - let continueA = true; - let continueB = true; - - while (!bAncestorIsCommon && !aAncestorIsCommon) { - if (continueA) { - aFamily.push(aAncestor); - aAncestorIsCommon = bFamily.indexOf(aAncestor) !== -1; - continueA = aAncestor !== null; - if (!aAncestorIsCommon && continueA) { - aAncestor = instance.getItemMeta(aAncestor!).parentId; - } - } - - if (continueB && !aAncestorIsCommon) { - bFamily.push(bAncestor); - bAncestorIsCommon = aFamily.indexOf(bAncestor) !== -1; - continueB = bAncestor !== null; - if (!bAncestorIsCommon && continueB) { - bAncestor = instance.getItemMeta(bAncestor!).parentId; - } - } - } - - const commonAncestor = aAncestorIsCommon ? aAncestor : bAncestor; - const ancestorFamily = instance.getChildrenIds(commonAncestor); - - const aSide = aFamily[aFamily.indexOf(commonAncestor) - 1]; - const bSide = bFamily[bFamily.indexOf(commonAncestor) - 1]; - - return ancestorFamily.indexOf(aSide!) < ancestorFamily.indexOf(bSide!) - ? [itemAId, itemBId] - : [itemBId, itemAId]; -}; diff --git a/packages/x-tree-view/src/internals/useTreeView/useTreeView.utils.ts b/packages/x-tree-view/src/internals/useTreeView/useTreeView.utils.ts index 4c3033421b20..17cbc41f3ef6 100644 --- a/packages/x-tree-view/src/internals/useTreeView/useTreeView.utils.ts +++ b/packages/x-tree-view/src/internals/useTreeView/useTreeView.utils.ts @@ -2,46 +2,79 @@ import { TreeViewInstance } from '../models'; import type { UseTreeViewExpansionSignature } from '../plugins/useTreeViewExpansion'; import type { UseTreeViewItemsSignature } from '../plugins/useTreeViewItems'; -export const getPreviousItem = ( +const getLastNavigableItemInArray = ( + instance: TreeViewInstance<[UseTreeViewItemsSignature]>, + items: string[], +) => { + // Equivalent to Array.prototype.findLastIndex + let itemIndex = items.length - 1; + while (itemIndex >= 0 && !instance.isItemNavigable(items[itemIndex])) { + itemIndex -= 1; + } + + if (itemIndex === -1) { + return undefined; + } + + return items[itemIndex]; +}; + +export const getPreviousNavigableItem = ( instance: TreeViewInstance<[UseTreeViewItemsSignature, UseTreeViewExpansionSignature]>, itemId: string, ) => { const itemMeta = instance.getItemMeta(itemId); - const siblings = instance.getNavigableChildrenIds(itemMeta.parentId); + const siblings = instance.getChildrenIds(itemMeta.parentId); const itemIndex = siblings.indexOf(itemId); + // TODO: What should we do if the parent is not navigable? if (itemIndex === 0) { return itemMeta.parentId; } - let currentItem: string = siblings[itemIndex - 1]; - while ( - instance.isItemExpanded(currentItem) && - instance.getNavigableChildrenIds(currentItem).length > 0 - ) { - currentItem = instance.getNavigableChildrenIds(currentItem).pop()!; + let currentItemId: string = siblings[itemIndex - 1]; + let lastNavigableChild = getLastNavigableItemInArray( + instance, + instance.getChildrenIds(currentItemId), + ); + while (instance.isItemExpanded(currentItemId) && lastNavigableChild != null) { + currentItemId = lastNavigableChild; + lastNavigableChild = instance.getChildrenIds(currentItemId).find(instance.isItemNavigable); } - return currentItem; + return currentItemId; }; -export const getNextItem = ( +export const getNextNavigableItem = ( instance: TreeViewInstance<[UseTreeViewExpansionSignature, UseTreeViewItemsSignature]>, itemId: string, ) => { - // If expanded get first child - if (instance.isItemExpanded(itemId) && instance.getNavigableChildrenIds(itemId).length > 0) { - return instance.getNavigableChildrenIds(itemId)[0]; + // If the item is expanded and has some navigable children, return the first of them. + if (instance.isItemExpanded(itemId)) { + const firstNavigableChild = instance.getChildrenIds(itemId).find(instance.isItemNavigable); + if (firstNavigableChild != null) { + return firstNavigableChild; + } } let itemMeta = instance.getItemMeta(itemId); while (itemMeta != null) { - // Try to get next sibling - const siblings = instance.getNavigableChildrenIds(itemMeta.parentId); - const nextSibling = siblings[siblings.indexOf(itemMeta.id) + 1]; + // Try to find the first navigable sibling after the current item. + const siblings = instance.getChildrenIds(itemMeta.parentId); + const currentItemIndex = siblings.indexOf(itemMeta.id); + + if (currentItemIndex < siblings.length - 1) { + let nextItemIndex = currentItemIndex + 1; + while ( + !instance.isItemNavigable(siblings[nextItemIndex]) && + nextItemIndex < siblings.length - 1 + ) { + nextItemIndex += 1; + } - if (nextSibling) { - return nextSibling; + if (instance.isItemNavigable(siblings[nextItemIndex])) { + return siblings[nextItemIndex]; + } } // If the sibling does not exist, go up a level to the parent and try again. @@ -51,16 +84,142 @@ export const getNextItem = ( return null; }; -export const getLastItem = ( +export const getLastNavigableItem = ( instance: TreeViewInstance<[UseTreeViewExpansionSignature, UseTreeViewItemsSignature]>, ) => { - let lastItem = instance.getNavigableChildrenIds(null).pop()!; + let itemId: string | null = null; + while (itemId == null || instance.isItemExpanded(itemId)) { + const children = instance.getChildrenIds(itemId); + const lastNavigableChild = getLastNavigableItemInArray(instance, children); - while (instance.isItemExpanded(lastItem)) { - lastItem = instance.getNavigableChildrenIds(lastItem).pop()!; + // The item has no navigable children. + if (lastNavigableChild == null) { + return itemId!; + } + + itemId = lastNavigableChild; } - return lastItem; + + return itemId!; }; -export const getFirstItem = (instance: TreeViewInstance<[UseTreeViewItemsSignature]>) => - instance.getNavigableChildrenIds(null)[0]; +export const getFirstNavigableItem = (instance: TreeViewInstance<[UseTreeViewItemsSignature]>) => + instance.getChildrenIds(null).find(instance.isItemNavigable)!; + +/** + * This is used to determine the start and end of a selection range so + * we can get the items between the two border items. + * + * It finds the items' common ancestor using + * a naive implementation of a lowest common ancestor algorithm + * (https://en.wikipedia.org/wiki/Lowest_common_ancestor). + * Then compares the ancestor's 2 children that are ancestors of itemA and ItemB + * so we can compare their indexes to work out which item comes first in a depth first search. + * (https://en.wikipedia.org/wiki/Depth-first_search) + * + * Another way to put it is which item is shallower in a trémaux tree + * https://en.wikipedia.org/wiki/Tr%C3%A9maux_tree + */ +const findOrderInTremauxTree = ( + instance: TreeViewInstance<[UseTreeViewItemsSignature]>, + itemAId: string, + itemBId: string, +) => { + if (itemAId === itemBId) { + return [itemAId, itemBId]; + } + + const itemMetaA = instance.getItemMeta(itemAId); + const itemMetaB = instance.getItemMeta(itemBId); + + if (itemMetaA.parentId === itemMetaB.id || itemMetaB.parentId === itemMetaA.id) { + return itemMetaB.parentId === itemMetaA.id + ? [itemMetaA.id, itemMetaB.id] + : [itemMetaB.id, itemMetaA.id]; + } + + const aFamily: (string | null)[] = [itemMetaA.id]; + const bFamily: (string | null)[] = [itemMetaB.id]; + + let aAncestor = itemMetaA.parentId; + let bAncestor = itemMetaB.parentId; + + let aAncestorIsCommon = bFamily.indexOf(aAncestor) !== -1; + let bAncestorIsCommon = aFamily.indexOf(bAncestor) !== -1; + + let continueA = true; + let continueB = true; + + while (!bAncestorIsCommon && !aAncestorIsCommon) { + if (continueA) { + aFamily.push(aAncestor); + aAncestorIsCommon = bFamily.indexOf(aAncestor) !== -1; + continueA = aAncestor !== null; + if (!aAncestorIsCommon && continueA) { + aAncestor = instance.getItemMeta(aAncestor!).parentId; + } + } + + if (continueB && !aAncestorIsCommon) { + bFamily.push(bAncestor); + bAncestorIsCommon = aFamily.indexOf(bAncestor) !== -1; + continueB = bAncestor !== null; + if (!bAncestorIsCommon && continueB) { + bAncestor = instance.getItemMeta(bAncestor!).parentId; + } + } + } + + const commonAncestor = aAncestorIsCommon ? aAncestor : bAncestor; + const ancestorFamily = instance.getChildrenIds(commonAncestor); + + const aSide = aFamily[aFamily.indexOf(commonAncestor) - 1]; + const bSide = bFamily[bFamily.indexOf(commonAncestor) - 1]; + + return ancestorFamily.indexOf(aSide!) < ancestorFamily.indexOf(bSide!) + ? [itemAId, itemBId] + : [itemBId, itemAId]; +}; + +export const getNavigableItemsInRange = ( + instance: TreeViewInstance<[UseTreeViewItemsSignature, UseTreeViewExpansionSignature]>, + itemAId: string, + itemBId: string, +) => { + const [firstItemId, lastItemId] = findOrderInTremauxTree(instance, itemAId, itemBId); + const items = [firstItemId]; + + let currentItemSiblings = instance.getChildrenIds(instance.getItemMeta(firstItemId).parentId); + let currentItemIndex = currentItemSiblings.indexOf(firstItemId); + + while (currentItemSiblings[currentItemIndex] !== lastItemId) { + const currentItemId = currentItemSiblings[currentItemIndex]; + // If the item is expanded, get its first children. + if (instance.isItemExpanded(currentItemId)) { + currentItemSiblings = instance.getChildrenIds(currentItemId); + currentItemIndex = 0; + } + // If the item is not the last of its siblings, get the next of them + else if (currentItemIndex < currentItemSiblings.length - 1) { + currentItemIndex += 1; + } + // If the item is the last of its siblings, get the first ancestor that has a next sibling and get this next sibling. + else { + let parentId = instance.getItemMeta(currentItemId).parentId!; + let parentSiblings = instance.getChildrenIds(instance.getItemMeta(parentId).parentId); + while (parentId === parentSiblings[parentSiblings.length - 1]) { + parentId = instance.getItemMeta(parentId).parentId!; + parentSiblings = instance.getChildrenIds(instance.getItemMeta(parentId).parentId); + } + + currentItemSiblings = parentSiblings; + currentItemIndex = currentItemSiblings.indexOf(parentId) + 1; + } + + items.push(currentItemId); + } + + items.push(lastItemId); + + return items.filter(instance.isItemNavigable); +};