Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit e1bdbe8

Browse files
committedFeb 21, 2025·
feat: react-hook-form field array for provider muxes
1 parent 3c8b909 commit e1bdbe8

6 files changed

+404
-68
lines changed
 

‎package-lock.json

+38-8
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎package.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,15 @@
2424
"@dnd-kit/core": "^6.3.1",
2525
"@dnd-kit/sortable": "^10.0.0",
2626
"@hey-api/client-fetch": "^0.7.1",
27+
"@hookform/resolvers": "^4.1.0",
2728
"@jsonforms/core": "^3.5.1",
2829
"@jsonforms/react": "^3.5.1",
2930
"@jsonforms/vanilla-renderers": "^3.5.1",
3031
"@monaco-editor/react": "^4.6.0",
3132
"@radix-ui/react-dialog": "^1.1.4",
3233
"@radix-ui/react-separator": "^1.1.0",
3334
"@radix-ui/react-slot": "^1.1.0",
34-
"@stacklok/ui-kit": "^1.0.1-4",
35+
"@stacklok/ui-kit": "^1.0.1-9",
3536
"@tanstack/react-query": "^5.64.1",
3637
"@tanstack/react-query-devtools": "^5.66.0",
3738
"@types/lodash": "^4.17.15",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
import {
2+
Button,
3+
ComboBoxButton,
4+
ComboBoxClearButton,
5+
ComboBoxFieldGroup,
6+
ComboBoxInput,
7+
FormComboBox,
8+
FormTextField,
9+
Input,
10+
Label,
11+
OptionsSchema,
12+
TextField,
13+
Tooltip,
14+
TooltipInfoButton,
15+
TooltipTrigger,
16+
} from '@stacklok/ui-kit'
17+
import { useFieldArray, useFormContext } from 'react-hook-form'
18+
import {
19+
MUX_FIELD_NAME,
20+
WORKSPACE_CONFIG_FIELD_NAME,
21+
WorkspaceMuxFieldValues,
22+
} from '../lib/workspace-config-schema'
23+
import { useQueryListAllModelsForAllProviders } from '@/hooks/use-query-list-all-models-for-all-providers'
24+
import { ModelByProvider, MuxMatcherType, MuxRule } from '@/api/generated'
25+
import { groupBy, map } from 'lodash'
26+
import {
27+
BracketsSlash,
28+
DotsGrid,
29+
GridDotsTop,
30+
Plus,
31+
SearchMd,
32+
Trash01,
33+
} from '@untitled-ui/icons-react'
34+
import { tv } from 'tailwind-variants'
35+
import {
36+
SortableContext,
37+
sortableKeyboardCoordinates,
38+
useSortable,
39+
verticalListSortingStrategy,
40+
} from '@dnd-kit/sortable'
41+
import { CSS } from '@dnd-kit/utilities'
42+
import {
43+
closestCenter,
44+
DndContext,
45+
DragEndEvent,
46+
KeyboardSensor,
47+
PointerSensor,
48+
UniqueIdentifier,
49+
useSensor,
50+
useSensors,
51+
} from '@dnd-kit/core'
52+
import { ReactNode, useCallback } from 'react'
53+
import { twMerge } from 'tailwind-merge'
54+
55+
function getMuxComponentName({
56+
field,
57+
index,
58+
}: {
59+
index: number
60+
field: (typeof MUX_FIELD_NAME)[keyof typeof MUX_FIELD_NAME]
61+
}) {
62+
return `${MUX_FIELD_NAME}.${index}.${field}`
63+
}
64+
65+
function groupModels(
66+
models: ModelByProvider[] = []
67+
): OptionsSchema<'listbox'>[] {
68+
return map(groupBy(models, 'provider_name'), (items, providerName) => ({
69+
id: providerName,
70+
textValue: providerName,
71+
items: items.map((item) => ({
72+
id: `${item.provider_id}/${item.name}`,
73+
textValue: item.name,
74+
})),
75+
}))
76+
}
77+
78+
function getIndicesOnDragEnd<T extends { id: UniqueIdentifier }>(
79+
event: DragEndEvent,
80+
items: T[]
81+
): {
82+
from: number
83+
to: number
84+
} | null {
85+
const { active, over } = event
86+
87+
if (over == null || active.id || over.id) return null // no-op
88+
89+
const from = items.findIndex(({ id }) => id === active.id)
90+
const to = items.findIndex(({ id }) => id === over.id)
91+
92+
return {
93+
from,
94+
to,
95+
}
96+
}
97+
98+
function DndSortProvider<T extends { id: UniqueIdentifier }>({
99+
children,
100+
onDragEnd,
101+
items,
102+
}: {
103+
children: ReactNode
104+
onDragEnd: (event: DragEndEvent) => void
105+
items: T[]
106+
}) {
107+
const sensors = useSensors(
108+
useSensor(PointerSensor),
109+
useSensor(KeyboardSensor, {
110+
coordinateGetter: sortableKeyboardCoordinates,
111+
})
112+
)
113+
114+
return (
115+
<DndContext
116+
sensors={sensors}
117+
collisionDetection={closestCenter}
118+
onDragEnd={onDragEnd}
119+
>
120+
<SortableContext items={items} strategy={verticalListSortingStrategy}>
121+
{children}
122+
</SortableContext>
123+
</DndContext>
124+
)
125+
}
126+
127+
const gridStyles = tv({
128+
base: 'grid grid-cols-[2fr_1fr_2.5rem] items-center gap-2',
129+
})
130+
131+
function Labels() {
132+
return (
133+
<div className={gridStyles()}>
134+
<Label className="flex items-center gap-1">
135+
Filter by
136+
<TooltipTrigger delay={0}>
137+
<TooltipInfoButton aria-label="Filter by description" />
138+
<Tooltip placement="right" className="max-w-72 text-balance">
139+
Filters are applied in top-down order. The first rule that matches
140+
each prompt determines the chosen model. An empty filter applies to
141+
all prompts.
142+
</Tooltip>
143+
</TooltipTrigger>
144+
</Label>
145+
<Label>Preferred model</Label>
146+
</div>
147+
)
148+
}
149+
150+
function MuxRuleRow({
151+
index,
152+
item,
153+
models,
154+
hasDragDisabled,
155+
}: {
156+
index: number
157+
item: MuxRule & { id: string }
158+
models: OptionsSchema<'listbox'>[]
159+
hasDragDisabled: boolean
160+
}) {
161+
console.debug('👉 item:', item)
162+
163+
const isCatchAll = item.matcher_type === MuxMatcherType.CATCH_ALL
164+
165+
const { attributes, listeners, setNodeRef, transform, transition } =
166+
useSortable({ id: item.id })
167+
const style = {
168+
transform: CSS.Transform.toString(transform),
169+
transition,
170+
}
171+
172+
return (
173+
<li key={item.id} className={twMerge(gridStyles(), 'mb-2')} style={style}>
174+
<FormTextField
175+
aria-label="Matcher"
176+
isDisabled={isCatchAll}
177+
defaultValue={isCatchAll ? 'Catch-all' : undefined}
178+
name={getMuxComponentName({
179+
index,
180+
field: 'matcher',
181+
})}
182+
>
183+
<Input
184+
icon={
185+
isCatchAll ? undefined : (
186+
<div
187+
ref={setNodeRef}
188+
{...attributes}
189+
{...listeners}
190+
className="pointer-events-auto"
191+
>
192+
<DotsGrid className="size-5" />
193+
</div>
194+
)
195+
}
196+
/>
197+
</FormTextField>
198+
199+
<FormComboBox
200+
aria-label="Matcher"
201+
items={models}
202+
// isDisabled={isArchived || isDefaultRule}
203+
name={getMuxComponentName({
204+
index,
205+
field: 'model',
206+
})}
207+
>
208+
<ComboBoxFieldGroup>
209+
<ComboBoxInput
210+
icon={<SearchMd />}
211+
isBorderless
212+
placeholder="Type to search..."
213+
/>
214+
<ComboBoxClearButton />
215+
<ComboBoxButton />
216+
</ComboBoxFieldGroup>
217+
</FormComboBox>
218+
<Button
219+
aria-label="Delete"
220+
isIcon
221+
isDisabled
222+
isDestructive
223+
variant="secondary"
224+
// onPress={() => removeRule(index)}
225+
>
226+
<Trash01 />
227+
</Button>
228+
</li>
229+
)
230+
}
231+
232+
export function WorkspaceMuxesFieldsArray() {
233+
const { control } = useFormContext()
234+
235+
const { fields, move, prepend } = useFieldArray({
236+
control,
237+
name: WORKSPACE_CONFIG_FIELD_NAME.muxing_rules,
238+
})
239+
240+
const { data: models = [] } = useQueryListAllModelsForAllProviders({
241+
select: groupModels,
242+
})
243+
244+
const onDragEnd = useCallback(
245+
(event: DragEndEvent) => {
246+
const { from, to } = getIndicesOnDragEnd(event, fields) || {}
247+
if (from && to) move(from, to)
248+
},
249+
[fields, move]
250+
)
251+
252+
console.debug('👉 fields:', fields)
253+
254+
return (
255+
<>
256+
<Labels />
257+
<DndSortProvider items={fields} onDragEnd={onDragEnd}>
258+
<ul>
259+
{fields.map((item, index) => (
260+
<MuxRuleRow index={index} item={item} models={models} />
261+
))}
262+
</ul>
263+
</DndSortProvider>
264+
265+
<div className="flex gap-2">
266+
<Button
267+
className="w-fit"
268+
variant="tertiary"
269+
onPress={() => prepend({})}
270+
// isDisabled={isArchived}
271+
>
272+
<Plus /> Add Filter
273+
</Button>
274+
275+
{/* <LinkButton className="w-fit" variant="tertiary" href="/providers">
276+
<LayersThree01 /> Manage providers
277+
</LinkButton> */}
278+
</div>
279+
</>
280+
)
281+
}

‎src/features/workspace/components/workspace-muxing-model.tsx

+53-58
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,18 @@
33
Button,
44
Card,
55
CardBody,
66
CardFooter,
7-
Form,
7+
FormSubmitButton,
8+
FormV2,
89
Input,
910
Label,
1011
Link,
1112
LinkButton,
1213
Text,
1314
TextField,
1415
Tooltip,
1516
TooltipInfoButton,
1617
TooltipTrigger,
1718
} from '@stacklok/ui-kit'
1819
import { twMerge } from 'tailwind-merge'
1920
import { useMutationPreferredModelWorkspace } from '../hooks/use-mutation-preferred-model-workspace'
@@ -21,7 +22,6 @@
2122
MuxMatcherType,
2223
V1ListAllModelsForAllProvidersResponse,
2324
} from '@/api/generated'
24-
import { FormEvent } from 'react'
2525
import {
2626
LayersThree01,
2727
LinkExternal01,
@@ -37,6 +37,24 @@
3737
useMuxingRulesFormState,
3838
} from '../hooks/use-muxing-rules-form-workspace'
3939
import { FormButtons } from '@/components/FormButtons'
40+
import { WorkspaceMuxesFieldsArray } from './workspace-muxes-fields-array'
41+
import {
42+
schemaWorkspaceConfig,
43+
WorkspaceConfigFieldValues,
44+
} from '../lib/workspace-config-schema'
45+
import { zodResolver } from '@hookform/resolvers/zod'
46+
47+
const DEFAULT_VALUES: WorkspaceConfigFieldValues = {
48+
muxing_rules: [
49+
{
50+
provider_id: '',
51+
provider_name: '',
52+
model: '',
53+
matcher: '',
54+
matcher_type: MuxMatcherType.CATCH_ALL,
55+
},
56+
],
57+
}
4058

4159
function MissingProviderBanner() {
4260
return (
@@ -80,19 +98,17 @@
8098
const placeholder = isDefaultRule ? 'Catch-all' : 'e.g. file type, file name'
8199
return (
82100
<div className="flex items-center gap-2" key={rule.id}>
83-
<div className="flex w-full justify-between">
84-
<TextField
85-
aria-labelledby="filter-by-label-id"
86-
value={rule?.matcher ?? ''}
87-
isDisabled={isArchived || isDefaultRule}
88-
name="matcher"
89-
onChange={(matcher) => {
90-
setRuleItem({ ...rule, matcher })
91-
}}
92-
>
93-
<Input placeholder={placeholder} />
94-
</TextField>
95-
</div>
101+
<TextField
102+
aria-labelledby="filter-by-label-id"
103+
value={rule?.matcher ?? ''}
104+
isDisabled={isArchived || isDefaultRule}
105+
name="matcher"
106+
onChange={(matcher) => {
107+
setRuleItem({ ...rule, matcher })
108+
}}
109+
>
110+
<Input placeholder={placeholder} />
111+
</TextField>
96112
<div className="flex w-3/5 gap-2">
97113
<WorkspaceModelsDropdown
98114
rule={rule}
@@ -141,25 +157,15 @@
141157
const isModelsEmpty = !isPending && providerModels.length === 0
142158
const showRemoveButton = rules.length > 1
143159

144-
const handleSubmit = (event: FormEvent) => {
145-
event.preventDefault()
146-
mutateAsync(
147-
{
148-
path: { workspace_name: workspaceName },
149-
body: rules.map(({ id, ...rest }) => {
150-
void id
151-
152-
return rest.matcher
153-
? { ...rest, matcher_type: MuxMatcherType.FILENAME_MATCH }
154-
: { ...rest }
155-
}),
156-
},
157-
{
158-
onSuccess: () => {
159-
formState.setInitialValues({ rules })
160-
},
161-
}
162-
)
160+
const handleSubmit = (data: WorkspaceConfigFieldValues) => {
161+
mutateAsync({
162+
path: { workspace_name: workspaceName },
163+
body: data.muxing_rules.map((rule) => {
164+
return rule.matcher
165+
? { ...rule, matcher_type: MuxMatcherType.FILENAME_MATCH }
166+
: { ...rule }
167+
}),
168+
})
163169
}
164170

165171
if (isModelsEmpty) {
@@ -174,10 +180,13 @@
174180
}
175181

176182
return (
177-
<Form
183+
<FormV2<WorkspaceConfigFieldValues>
178184
onSubmit={handleSubmit}
179-
validationBehavior="aria"
180185
data-testid="preferred-model"
186+
options={{
187+
defaultValues: DEFAULT_VALUES,
188+
resolver: zodResolver(schemaWorkspaceConfig),
189+
}}
181190
>
182191
<Card className={twMerge(className, 'shrink-0')}>
183192
<CardBody className="flex flex-col gap-6">
@@ -198,25 +207,8 @@
198207
</div>
199208

200209
<div className="flex w-full flex-col gap-2">
201-
<div className="flex gap-2">
202-
<div className="w-12">&nbsp;</div>
203-
<div className="w-full">
204-
<Label id="filter-by-label-id" className="flex items-center">
205-
Filter by
206-
<TooltipTrigger delay={0}>
207-
<TooltipInfoButton aria-label="Filter by description" />
208-
<Tooltip>
209-
Filters are applied in top-down order. The first rule that
210-
matches each prompt determines the chosen model. An empty
211-
filter applies to all prompts.
212-
</Tooltip>
213-
</TooltipTrigger>
214-
</Label>
215-
</div>
216-
<div className="w-3/5">
217-
<Label id="preferred-model-id">Preferred Model</Label>
218-
</div>
219-
</div>
210+
<WorkspaceMuxesFieldsArray />
211+
220212
<SortableArea
221213
items={rules}
222214
setItems={setRules}
@@ -241,7 +233,10 @@
241233
</SortableArea>
242234
</div>
243235
</CardBody>
244-
<CardFooter className="justify-between">
236+
<div>
237+
<FormSubmitButton />
238+
</div>
239+
{/* <CardFooter className="justify-between">
245240
<div className="flex gap-2">
246241
<Button
247242
className="w-fit"
@@ -261,8 +256,8 @@
261256
formState={formState}
262257
canSubmit={!isArchived}
263258
/>
264-
</CardFooter>
259+
</CardFooter> */}
265260
</Card>
266-
</Form>
261+
</FormV2>
267262
)
268263
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import { MuxMatcherType } from '@/api/generated'
2+
import { z } from 'zod'
3+
4+
const schemaWorkspaceMux = z.object({
5+
provider_name: z.string().nullable(),
6+
provider_id: z.string().uuid(),
7+
model: z.string(),
8+
matcher_type: z.nativeEnum(MuxMatcherType),
9+
matcher: z.string().nullable(),
10+
})
11+
12+
export type WorkspaceMuxFieldValues = z.infer<typeof schemaWorkspaceMux>
13+
14+
export const schemaWorkspaceConfig = z.object({
15+
muxing_rules: z.array(schemaWorkspaceMux),
16+
})
17+
18+
export type WorkspaceConfigFieldValues = z.infer<typeof schemaWorkspaceConfig>
19+
20+
export const WORKSPACE_CONFIG_FIELD_NAME = schemaWorkspaceConfig.keyof().Enum
21+
export const MUX_FIELD_NAME = schemaWorkspaceMux.keyof().Enum
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
import { useQuery } from '@tanstack/react-query'
22
import { v1ListAllModelsForAllProvidersOptions } from '@/api/generated/@tanstack/react-query.gen'
3+
import { V1ListAllModelsForAllProvidersResponse } from '@/api/generated/types.gen'
34

4-
export const useQueryListAllModelsForAllProviders = () => {
5+
export function useQueryListAllModelsForAllProviders<
6+
T = V1ListAllModelsForAllProvidersResponse,
7+
>({
8+
select,
9+
}: {
10+
select?: (data: V1ListAllModelsForAllProvidersResponse) => T
11+
} = {}) {
512
return useQuery({
613
...v1ListAllModelsForAllProvidersOptions(),
14+
select,
715
})
816
}

0 commit comments

Comments
 (0)
Please sign in to comment.