Skip to content

Commit

Permalink
Add Compare & Conditional node (#2866)
Browse files Browse the repository at this point in the history
* Add Compare & Conditional node

* Fix errors

---------

Co-authored-by: Joey Ballentine <[email protected]>
  • Loading branch information
RunDevelopment and joeyballentine authored May 18, 2024
1 parent 1562e1d commit a3e33f4
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 2 deletions.
30 changes: 29 additions & 1 deletion backend/src/nodes/properties/inputs/generic_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def wrap_with_conditional_group(self):
return group("conditional", {"condition": condition})(self)


class BoolInput(DropDownInput[bool]):
class _BoolEnumInput(DropDownInput[bool]):
def __init__(self, label: str, *, default: bool = True, icon: str | None = None):
super().__init__(
input_type="bool",
Expand Down Expand Up @@ -175,6 +175,34 @@ def enforce(self, value: object) -> bool:
return bool(value)


class _BoolGenericInput(BaseInput[bool]):
def __init__(self, label: str):
super().__init__(input_type="bool", label=label)
self.associated_type = bool

def enforce(self, value: object) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, int):
return bool(value)

raise ValueError(
f"The value of input '{self.label}' should have been either True or False."
)


def BoolInput(
label: str,
*,
default: bool = True,
icon: str | None = None,
has_handle: bool = False,
):
if has_handle:
return _BoolGenericInput(label)
return _BoolEnumInput(label, default=default, icon=icon)


E = TypeVar("E", bound=Enum)


Expand Down
14 changes: 14 additions & 0 deletions backend/src/nodes/properties/outputs/generic_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,20 @@ def enforce(self, value: object) -> Color:
return value


class BoolOutput(BaseOutput):
def __init__(
self,
label: str = "Logical",
*,
output_type: navi.ExpressionJson = "bool",
):
super().__init__(
output_type=navi.intersect_with_error("bool", output_type),
label=label,
kind="generic",
)


class AudioStreamOutput(BaseOutput):
def __init__(self, label: str = "Audio Stream"):
super().__init__(
Expand Down
96 changes: 96 additions & 0 deletions backend/src/packages/chaiNNer_standard/utility/math/compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations

from enum import Enum

from api import KeyInfo
from nodes.properties.inputs import EnumInput, NumberInput
from nodes.properties.outputs import BoolOutput

from .. import math_group


class Comparison(Enum):
EQUAL = 0
NOT_EQUAL = 1
LESS = 3
LESS_EQUAL = 5
GREATER = 2
GREATER_EQUAL = 4


@math_group.register(
schema_id="chainner:utility:compare",
name="Compare",
description="Compares the given numbers.",
icon="MdCalculate",
inputs=[
EnumInput(
Comparison,
label="Operation",
option_labels={
Comparison.EQUAL: "L == R",
Comparison.NOT_EQUAL: "L == R",
Comparison.GREATER: "L > R",
Comparison.LESS: "L < R",
Comparison.GREATER_EQUAL: "L >= R",
Comparison.LESS_EQUAL: "L <= R",
},
icons={
Comparison.EQUAL: "=",
Comparison.NOT_EQUAL: "≠",
Comparison.GREATER: ">",
Comparison.LESS: "<",
Comparison.GREATER_EQUAL: "≥",
Comparison.LESS_EQUAL: "≤",
},
label_style="hidden",
).with_id(0),
NumberInput(
label="Left",
min=None,
max=None,
precision=100,
step=1,
).with_id(1),
NumberInput(
label="Right",
min=None,
max=None,
precision=100,
step=1,
).with_id(2),
],
outputs=[
BoolOutput(
label="Result",
output_type="""
let op = Input0;
let l = Input1;
let r = Input2;
match op {
Comparison::Equal => l == r,
Comparison::NotEqual => l != r,
Comparison::Greater => l > r,
Comparison::Less => l < r,
Comparison::GreaterEqual => l >= r,
Comparison::LessEqual => l <= r,
}
""",
).suggest(),
],
key_info=KeyInfo.enum(0),
)
def compare_node(op: Comparison, left: float, right: float) -> bool:
if op == Comparison.EQUAL:
return left == right
elif op == Comparison.NOT_EQUAL:
return left != right
elif op == Comparison.GREATER:
return left > right
elif op == Comparison.LESS:
return left < right
elif op == Comparison.GREATER_EQUAL:
return left >= right
elif op == Comparison.LESS_EQUAL:
return left <= right
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

from api import Lazy
from nodes.properties.inputs import AnyInput, BoolInput
from nodes.properties.outputs import BaseOutput

from .. import value_group


@value_group.register(
schema_id="chainner:utility:conditional",
name="Conditional",
description="Allows you to pass in multiple inputs and then change which one passes through to the output.",
icon="BsShuffle",
inputs=[
BoolInput("Condition", default=True, has_handle=True).with_id(0),
AnyInput(label="If True").make_lazy(),
AnyInput(label="If False").make_lazy(),
],
outputs=[
BaseOutput(
output_type="""
if Input0 { Input1 } else { Input2 }
""",
label="Value",
).as_passthrough_of(1),
],
see_also=["chainner:utility:switch"],
)
def conditional_node(
cond: bool, if_true: Lazy[object], if_false: Lazy[object]
) -> object:
return if_true.value if cond else if_false.value
5 changes: 5 additions & 0 deletions src/common/types/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ interface KnownStructDefinitions {
Directory: {
readonly path: StringPrimitive;
};

true: Record<string, never>;
false: Record<string, never>;
}
interface KnownInstance<N extends keyof KnownStructDefinitions> {
readonly descriptor: StructDescriptor & { readonly name: N };
Expand All @@ -44,6 +47,8 @@ const createAssertFn = <N extends keyof KnownStructDefinitions>(
export const isImage = createAssertFn('Image');
export const isColor = createAssertFn('Color');
export const isDirectory = createAssertFn('Directory');
export const isTrue = createAssertFn('true');
export const isFalse = createAssertFn('false');

export const getFields = <N extends keyof KnownStructDefinitions>(
type: StructInstanceType & KnownInstance<N>
Expand Down
17 changes: 16 additions & 1 deletion src/renderer/components/TypeTag.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@ import { Tag, Tooltip, forwardRef } from '@chakra-ui/react';
import React, { ReactNode, memo } from 'react';
import { useTranslation } from 'react-i18next';
import { explain } from '../../common/types/explain';
import { getFields, isColor, isDirectory, isImage, withoutNull } from '../../common/types/util';
import {
getFields,
isColor,
isDirectory,
isFalse,
isImage,
isTrue,
withoutNull,
} from '../../common/types/util';
import { assertNever } from '../../common/util';

const getColorMode = (channels: number) => {
Expand Down Expand Up @@ -215,6 +223,13 @@ const getTypeText = (type: Type): TagValue[] => {
}
}

if (isTrue(type)) {
tags.push({ kind: 'literal', value: 'True' });
}
if (isFalse(type)) {
tags.push({ kind: 'literal', value: 'False' });
}

if (isStructInstance(type)) {
if (
type.descriptor.name === 'PyTorchModel' ||
Expand Down

0 comments on commit a3e33f4

Please sign in to comment.