From 4854e59a4f278a01209d859e0ec4d69e8e868c5f Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Fri, 24 May 2024 15:09:29 +0200 Subject: [PATCH] Add special suggestion for Math and Compare nodes in node selector (#2908) * Add special suggestion for Math and Compare nodes in node selector * docs * Fix type error --- backend/src/api/api.py | 11 +- backend/src/api/node_data.py | 55 ++++++- .../chaiNNer_standard/utility/math/compare.py | 39 ++++- .../chaiNNer_standard/utility/math/math.py | 33 ++++ backend/src/server.py | 1 + src/common/SchemaMap.ts | 1 + src/common/common-types.ts | 8 + .../components/PaneNodeSearchMenu.tsx | 149 ++++++++++++++---- src/renderer/helpers/reactFlowUtil.ts | 16 +- src/renderer/hooks/usePaneNodeSearchMenu.tsx | 5 +- 10 files changed, 273 insertions(+), 45 deletions(-) diff --git a/backend/src/api/api.py b/backend/src/api/api.py index 5277c4bd00..7fff1f3a8b 100644 --- a/backend/src/api/api.py +++ b/backend/src/api/api.py @@ -24,7 +24,13 @@ check_naming_conventions, check_schema_types, ) -from .node_data import IteratorInputInfo, IteratorOutputInfo, KeyInfo, NodeData +from .node_data import ( + IteratorInputInfo, + IteratorOutputInfo, + KeyInfo, + NodeData, + SpecialSuggestion, +) from .output import BaseOutput from .settings import Setting from .types import FeatureId, InputId, NodeId, NodeKind, OutputId, RunFn @@ -98,6 +104,7 @@ def to_dict(self): def register( self, schema_id: str, + *, name: str, description: str | list[str], inputs: list[BaseInput | NestedGroup], @@ -114,6 +121,7 @@ def register( iterator_outputs: list[IteratorOutputInfo] | IteratorOutputInfo | None = None, node_context: bool = False, key_info: KeyInfo | None = None, + suggestions: list[SpecialSuggestion] | None = None, ): if not isinstance(description, str): description = "\n\n".join(description) @@ -183,6 +191,7 @@ def inner_wrapper(wrapped_func: T) -> T: iterator_inputs=iterator_inputs, iterator_outputs=iterator_outputs, key_info=key_info, + suggestions=suggestions or [], side_effects=side_effects, deprecated=deprecated, node_context=node_context, diff --git a/backend/src/api/node_data.py b/backend/src/api/node_data.py index 31b222b23e..ea86851ac9 100644 --- a/backend/src/api/node_data.py +++ b/backend/src/api/node_data.py @@ -1,7 +1,8 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any +from enum import Enum +from typing import Any, Mapping import navi @@ -71,6 +72,57 @@ def to_dict(self): return self._data +class SpecialSuggestion: + """ + A special suggestion in chaiNNer's context node selector. + + A suggestion consists of 3 parts: + 1. The search query to match. The query may optionally contain a pattern at the end + to supply a value to an input. E.g. `+{2}` will match the search query "+123" + and "123" will be parsed for the input with ID 2. + 2. The name of the suggestion. This is the text that will be displayed in the + suggestion list. + 3. The input values to supply to the node. This is a mapping of input IDs to the + values to supply to them. Values that aren't defined here will be left as + default values. + """ + + def __init__( + self, + query: str, + *, + name: str | None = None, + inputs: Mapping[InputId | int, Any] = {}, + ) -> None: + self.query, self.parse_input = SpecialSuggestion._parse_query(query) + self.name = name + self.inputs: dict[InputId, Any] = {InputId(k): v for k, v in inputs.items()} + + @staticmethod + def _parse_query(query: str) -> tuple[str, InputId | None]: + # e.g. "+{2}" + if "{" in query: + query, input_id = query.split("{") + input_id = int(input_id[:-1]) + return query, InputId(input_id) + return query, None + + def to_dict(self): + def convert_value(value: Any) -> Any: + if isinstance(value, bool): + return int(value) + if isinstance(value, Enum): + return value.value + return value + + return { + "query": self.query, + "name": self.name, + "parseInput": self.parse_input, + "inputs": {k: convert_value(v) for k, v in self.inputs.items()}, + } + + @dataclass(frozen=True) class NodeData: schema_id: str @@ -88,6 +140,7 @@ class NodeData: iterator_outputs: list[IteratorOutputInfo] key_info: KeyInfo | None + suggestions: list[SpecialSuggestion] side_effects: bool deprecated: bool diff --git a/backend/src/packages/chaiNNer_standard/utility/math/compare.py b/backend/src/packages/chaiNNer_standard/utility/math/compare.py index 8965c700e4..1784987918 100644 --- a/backend/src/packages/chaiNNer_standard/utility/math/compare.py +++ b/backend/src/packages/chaiNNer_standard/utility/math/compare.py @@ -2,7 +2,7 @@ from enum import Enum -from api import KeyInfo +from api import KeyInfo, SpecialSuggestion from nodes.properties.inputs import EnumInput, NumberInput from nodes.properties.outputs import BoolOutput @@ -79,6 +79,43 @@ class Comparison(Enum): """, ).suggest(), ], + suggestions=[ + SpecialSuggestion( + "={2}", + name="Comparison: Equal", + inputs={0: Comparison.EQUAL}, + ), + SpecialSuggestion( + "=={2}", + name="Comparison: Equal", + inputs={0: Comparison.EQUAL}, + ), + SpecialSuggestion( + "!={2}", + name="Comparison: Not Equal", + inputs={0: Comparison.NOT_EQUAL}, + ), + SpecialSuggestion( + ">{2}", + name="Comparison: Greater", + inputs={0: Comparison.GREATER}, + ), + SpecialSuggestion( + "<{2}", + name="Comparison: Less", + inputs={0: Comparison.LESS}, + ), + SpecialSuggestion( + ">={2}", + name="Comparison: Greater or Equal", + inputs={0: Comparison.GREATER_EQUAL}, + ), + SpecialSuggestion( + "<={2}", + name="Comparison: Less or Equal", + inputs={0: Comparison.LESS_EQUAL}, + ), + ], key_info=KeyInfo.enum(0), ) def compare_node(op: Comparison, left: float, right: float) -> bool: diff --git a/backend/src/packages/chaiNNer_standard/utility/math/math.py b/backend/src/packages/chaiNNer_standard/utility/math/math.py index d36aa52e64..7ecf622d1d 100644 --- a/backend/src/packages/chaiNNer_standard/utility/math/math.py +++ b/backend/src/packages/chaiNNer_standard/utility/math/math.py @@ -3,6 +3,7 @@ import math from enum import Enum +from api import SpecialSuggestion from nodes.properties.inputs import EnumInput, NumberInput from nodes.properties.outputs import NumberOutput @@ -90,6 +91,38 @@ def nonZero(x: number): number { ) .as_passthrough_of(0) ], + suggestions=[ + SpecialSuggestion( + "+{2}", + name="Math: Add", + inputs={1: MathOperation.ADD}, + ), + SpecialSuggestion( + "-{2}", + name="Math: Subtract", + inputs={1: MathOperation.SUBTRACT}, + ), + SpecialSuggestion( + "*{2}", + name="Math: Multiply", + inputs={1: MathOperation.MULTIPLY, 2: 1}, + ), + SpecialSuggestion( + "/{2}", + name="Math: Divide", + inputs={1: MathOperation.DIVIDE, 2: 1}, + ), + SpecialSuggestion( + "**{2}", + name="Math: Power", + inputs={1: MathOperation.POWER, 2: 1}, + ), + SpecialSuggestion( + "^{2}", + name="Math: Power", + inputs={1: MathOperation.POWER, 2: 1}, + ), + ], ) def math_node(op: MathOperation, a: float, b: float) -> int | float: if op == MathOperation.ADD: diff --git a/backend/src/server.py b/backend/src/server.py index d6c25c68e8..061b1c2ff0 100644 --- a/backend/src/server.py +++ b/backend/src/server.py @@ -150,6 +150,7 @@ async def nodes(_request: Request): "iteratorInputs": [x.to_dict() for x in node.iterator_inputs], "iteratorOutputs": [x.to_dict() for x in node.iterator_outputs], "keyInfo": node.key_info.to_dict() if node.key_info else None, + "suggestions": [x.to_dict() for x in node.suggestions], "description": node.description, "seeAlso": node.see_also, "icon": node.icon, diff --git a/src/common/SchemaMap.ts b/src/common/SchemaMap.ts index f4b2fdeb48..95446110c4 100644 --- a/src/common/SchemaMap.ts +++ b/src/common/SchemaMap.ts @@ -27,6 +27,7 @@ const BLANK_SCHEMA: NodeSchema = { hasSideEffects: false, deprecated: false, features: [], + suggestions: [], }; export class SchemaMap { diff --git a/src/common/common-types.ts b/src/common/common-types.ts index a4e8225e59..aadc796a0a 100644 --- a/src/common/common-types.ts +++ b/src/common/common-types.ts @@ -301,6 +301,13 @@ export interface TypeKeyInfo { readonly expression: ExpressionJson; } +export interface SpecialSuggestion { + readonly query: string; + readonly name?: string | null; + readonly parseInput?: InputId | null; + readonly inputs: Partial; +} + export interface NodeSchema { readonly name: string; readonly category: CategoryId; @@ -315,6 +322,7 @@ export interface NodeSchema { readonly iteratorInputs: readonly IteratorInputInfo[]; readonly iteratorOutputs: readonly IteratorOutputInfo[]; readonly keyInfo?: KeyInfo | null; + readonly suggestions: readonly SpecialSuggestion[]; readonly schemaId: SchemaId; readonly hasSideEffects: boolean; readonly deprecated: boolean; diff --git a/src/renderer/components/PaneNodeSearchMenu.tsx b/src/renderer/components/PaneNodeSearchMenu.tsx index 5d28c43997..d090921a0c 100644 --- a/src/renderer/components/PaneNodeSearchMenu.tsx +++ b/src/renderer/components/PaneNodeSearchMenu.tsx @@ -19,8 +19,10 @@ import { CategoryId, FeatureId, FeatureState, + InputData, NodeSchema, SchemaId, + SpecialSuggestion, } from '../../common/common-types'; import { assertNever, cacheLast, groupBy, stopPropagation } from '../../common/util'; import { getCategoryAccentColor } from '../helpers/accentColors'; @@ -41,15 +43,16 @@ const clampWithWrap = (min: number, max: number, value: number): number => { }; interface SchemaItemProps { - schema: NodeSchema; + name: string; + icon: string; isFavorite: boolean; accentColor: string; - onClick: (schema: NodeSchema) => void; + onClick: () => void; isSelected: boolean; scrollRef?: React.RefObject; } const SchemaItem = memo( - ({ schema, onClick, isFavorite, accentColor, isSelected, scrollRef }: SchemaItemProps) => { + ({ name, icon, onClick, isFavorite, accentColor, isSelected, scrollRef }: SchemaItemProps) => { const bgColor = useThemeColor('--bg-700'); const menuBgColor = useThemeColor('--bg-800'); @@ -69,18 +72,17 @@ const SchemaItem = memo( : `linear(to-r, ${gradL}, ${gradR})` } borderRadius="md" - key={schema.schemaId} mx={1} my={0.5} outline={isSelected ? '1px solid' : undefined} px={2} py={0.5} ref={scrollRef} - onClick={() => onClick(schema)} + onClick={onClick} > - {schema.name} + {name} {isFavorite && ( ; +} + type SchemaGroup = FavoritesSchemaGroup | SuggestedSchemaGroup | CategorySchemaGroup; interface SchemaGroupBase { readonly name: string; - readonly schemata: readonly NodeSchema[]; + readonly items: readonly GroupItem[]; } interface FavoritesSchemaGroup extends SchemaGroupBase { type: 'favorites'; @@ -131,8 +148,18 @@ const groupSchemata = ( schemata: readonly NodeSchema[], categories: CategoryMap, favorites: ReadonlySet, - suggested: ReadonlySet + suggested: ReadonlySet, + specialSuggestions: readonly SuggestionGroupItem[] ): readonly SchemaGroup[] => { + const toItem = (schema: NodeSchema): SchemaGroupItem => { + return { + type: 'schema', + name: schema.name, + icon: schema.icon, + schema, + }; + }; + const cats = [...groupBy(schemata, 'category')].map( ([categoryId, categorySchemata]): CategorySchemaGroup => { const category = categories.get(categoryId); @@ -141,7 +168,7 @@ const groupSchemata = ( name: category?.name ?? categoryId, categoryId, category, - schemata: categorySchemata, + items: categorySchemata.map(toItem), }; } ); @@ -149,20 +176,19 @@ const groupSchemata = ( const favs: FavoritesSchemaGroup = { type: 'favorites', name: 'Favorites', - schemata: cats.flatMap((c) => c.schemata).filter((n) => favorites.has(n.schemaId)), + items: cats.flatMap((c) => c.items).filter((n) => favorites.has(n.schema.schemaId)), }; const suggs: SuggestedSchemaGroup = { type: 'suggested', name: 'Suggested', - schemata: schemata.filter((n) => suggested.has(n.schemaId)), + items: [ + ...specialSuggestions, + ...schemata.filter((n) => suggested.has(n.schemaId)).map(toItem), + ], }; - return [ - ...(suggs.schemata.length ? [suggs] : []), - ...(favs.schemata.length ? [favs] : []), - ...cats, - ]; + return [...(suggs.items.length ? [suggs] : []), ...(favs.items.length ? [favs] : []), ...cats]; }; const renderGroupIcon = (categories: CategoryMap, group: SchemaGroup) => { @@ -195,6 +221,50 @@ const renderGroupIcon = (categories: CategoryMap, group: SchemaGroup) => { } }; +// eslint-disable-next-line react-memo/require-memo +function* getSpecialSuggestions( + schemata: readonly NodeSchema[], + searchQuery: string +): Iterable { + const parse = ( + s: SpecialSuggestion, + schema: NodeSchema + ): { inputs: Partial } | undefined => { + if (searchQuery === s.query) { + return { inputs: s.inputs }; + } + if (s.parseInput != null && searchQuery.startsWith(s.query)) { + const rest = searchQuery.slice(s.query.length); + const input = schema.inputs.find((i) => i.id === s.parseInput); + if (input) { + // attempt to parse the rest of the query string + if (input.kind === 'number') { + const value = parseFloat(rest.trim()); + if (!Number.isNaN(value)) { + return { inputs: { ...s.inputs, [input.id]: value } }; + } + } + } + } + return undefined; + }; + + for (const schema of schemata) { + for (const suggestion of schema.suggestions) { + const parsed = parse(suggestion, schema); + if (parsed) { + yield { + type: 'suggestion', + icon: schema.icon, + name: suggestion.name ?? schema.name, + schema, + inputs: parsed.inputs, + }; + } + } + } +} + const createMatcher = ( schemata: readonly NodeSchema[], categories: CategoryMap, @@ -203,11 +273,18 @@ const createMatcher = ( featureStates: ReadonlyMap ) => { return cacheLast((searchQuery: string) => { + const specialSuggestions = [...getSpecialSuggestions(schemata, searchQuery)]; const matchingNodes = getMatchingNodes(searchQuery, schemata, categories); - const groups = groupSchemata(matchingNodes, categories, favorites, suggestions); - const flatGroups = groups.flatMap((group) => group.schemata); + const groups = groupSchemata( + matchingNodes, + categories, + favorites, + suggestions, + specialSuggestions + ); + const flatGroups: readonly GroupItem[] = groups.flatMap((group) => group.items); - const bestMatch = getBestMatch(searchQuery, matchingNodes, categories, (schema) => { + const bestSchema = getBestMatch(searchQuery, matchingNodes, categories, (schema) => { const isFeatureEnabled = schema.features.every((f) => { return featureStates.get(f)?.enabled ?? false; }); @@ -222,13 +299,14 @@ const createMatcher = ( } return 1; }); + const bestMatch = flatGroups.find((item) => item.schema === bestSchema); return { groups, flatGroups, bestMatch }; }); }; interface MenuProps { - onSelect: (schema: NodeSchema) => void; + onSelect: (schema: NodeSchema, inputs: Partial) => void; schemata: readonly NodeSchema[]; favorites: ReadonlySet; categories: CategoryMap; @@ -263,9 +341,9 @@ export const Menu = memo( const { groups, flatGroups } = useMemo(() => matcher(searchQuery), [searchQuery, matcher]); const onClickHandler = useCallback( - (schema: NodeSchema) => { + (item: GroupItem) => { changeSearchQuery(''); - onSelect(schema); + onSelect(item.schema, item.inputs ?? {}); }, [changeSearchQuery, onSelect] ); @@ -358,13 +436,13 @@ export const Menu = memo( {groups.map((group, groupIndex) => { const indexOffset = groups .slice(0, groupIndex) - .reduce((acc, g) => acc + g.schemata.length, 0); + .reduce((acc, g) => acc + g.items.length, 0); const nodeHeight = 28; const nodePadding = 2; const placeholderHeight = - nodeHeight * group.schemata.length + - nodePadding * (group.schemata.length + 1); + nodeHeight * group.items.length + + nodePadding * (group.items.length + 1); return ( @@ -380,26 +458,29 @@ export const Menu = memo( - {group.schemata.map((schema, schemaIndex) => { - const index = indexOffset + schemaIndex; + {group.items.map((item, itemIndex) => { + const index = indexOffset + itemIndex; const isSelected = selectedIndex === index; return ( { + onClickHandler(item); + }} /> ); })} diff --git a/src/renderer/helpers/reactFlowUtil.ts b/src/renderer/helpers/reactFlowUtil.ts index 21b9b77716..5a435afd1c 100644 --- a/src/renderer/helpers/reactFlowUtil.ts +++ b/src/renderer/helpers/reactFlowUtil.ts @@ -6,7 +6,7 @@ import { createUniqueId, deepCopy } from '../../common/util'; export interface NodeProto { id?: string; position: Readonly; - data: Omit & { inputData?: InputData }; + data: Omit & { inputData?: Partial }; } export const createNode = ( @@ -16,15 +16,19 @@ export const createNode = ( ): Node => { const schema = schemata.get(data.schemaId); + let inputData: InputData = schemata.getDefaultInput(data.schemaId); + if (data.inputData) { + inputData = { + ...inputData, + ...data.inputData, + }; + } + const newNode: Node> = { type: schema.kind, id, position: { ...position }, - data: { - ...data, - id, - inputData: data.inputData ?? schemata.getDefaultInput(data.schemaId), - }, + data: { ...data, id, inputData }, selected, }; diff --git a/src/renderer/hooks/usePaneNodeSearchMenu.tsx b/src/renderer/hooks/usePaneNodeSearchMenu.tsx index d73944d2a8..f67f80fcda 100644 --- a/src/renderer/hooks/usePaneNodeSearchMenu.tsx +++ b/src/renderer/hooks/usePaneNodeSearchMenu.tsx @@ -1,7 +1,7 @@ import { useCallback, useMemo, useState } from 'react'; import { OnConnectStartParams, useReactFlow } from 'reactflow'; import { useContext, useContextSelector } from 'use-context-selector'; -import { InputId, NodeSchema, OutputId, SchemaId } from '../../common/common-types'; +import { InputData, InputId, NodeSchema, OutputId, SchemaId } from '../../common/common-types'; import { getFirstPossibleInput, getFirstPossibleOutput } from '../../common/nodes/connectedInputs'; import { ChainLineage } from '../../common/nodes/lineage'; import { TypeState } from '../../common/nodes/TypeState'; @@ -162,7 +162,7 @@ export const usePaneNodeSearchMenu = (): UsePaneNodeSearchMenuValue => { }, [schemata.schemata, connectingFrom, typeState, chainLineage, functionDefinitions]); const onSchemaSelect = useCallback( - (schema: NodeSchema) => { + (schema: NodeSchema, inputs: Partial) => { const { x, y } = mousePosition; const projPosition = screenToFlowPosition({ x, y }); const nodeId = createUniqueId(); @@ -171,6 +171,7 @@ export const usePaneNodeSearchMenu = (): UsePaneNodeSearchMenuValue => { position: projPosition, data: { schemaId: schema.schemaId, + inputData: inputs, }, });