From 51ff21be0190a1ad3db486a30bee26ab8473184a Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Tue, 30 Jul 2024 09:11:59 +0800 Subject: [PATCH 01/89] feat(core): Add UI component for AWEL flow --- dbgpt/core/awel/flow/exceptions.py | 11 + dbgpt/core/awel/flow/ui.py | 348 +++++++++++++++++++++++++++++ 2 files changed, 359 insertions(+) create mode 100644 dbgpt/core/awel/flow/ui.py diff --git a/dbgpt/core/awel/flow/exceptions.py b/dbgpt/core/awel/flow/exceptions.py index 0c3dc667d..68c02f8ac 100644 --- a/dbgpt/core/awel/flow/exceptions.py +++ b/dbgpt/core/awel/flow/exceptions.py @@ -44,3 +44,14 @@ class FlowDAGMetadataException(FlowMetadataException): def __init__(self, message: str, error_type="build_dag_metadata_error"): """Create a new FlowDAGMetadataException.""" super().__init__(message, error_type) + + +class FlowUIComponentException(FlowException): + """The exception for UI parameter failed.""" + + def __init__( + self, message: str, component_name: str, error_type="build_ui_component_error" + ): + """Create a new FlowUIParameterException.""" + new_message = f"{component_name}: {message}" + super().__init__(new_message, error_type) diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py new file mode 100644 index 000000000..a9f220961 --- /dev/null +++ b/dbgpt/core/awel/flow/ui.py @@ -0,0 +1,348 @@ +"""UI components for AWEL flow.""" + +from typing import Any, Dict, List, Literal, Optional + +from dbgpt._private.pydantic import BaseModel, Field + +from .exceptions import FlowUIComponentException + +_UI_TYPE = Literal[ + "cascader", + "checkbox", + "date_picker", + "input", + "text_area", + "auto_complete", + "slider", + "time_picker", + "tree_select", + "upload", + "variable", + "password", + "code_editor", +] + + +class RefreshableMixin(BaseModel): + """Refreshable mixin.""" + + refresh: Optional[bool] = Field( + False, + description="Whether to enable the refresh", + ) + refresh_depends: Optional[List[str]] = Field( + None, + description="The dependencies of the refresh", + ) + + +class UIComponent(RefreshableMixin, BaseModel): + """UI component.""" + + class UIRange(BaseModel): + """UI range.""" + + min: int | float | str | None = Field(None, description="Minimum value") + max: int | float | str | None = Field(None, description="Maximum value") + step: int | float | str | None = Field(None, description="Step value") + format: str | None = Field(None, description="Format") + + ui_type: _UI_TYPE = Field(..., description="UI component type") + + disabled: bool = Field( + False, + description="Whether the component is disabled", + ) + + def check_parameter(self, parameter_dict: Dict[str, Any]): + """Check parameter. + + Raises: + FlowUIParameterException: If the parameter is invalid. 
+ """ + + def _check_options(self, options: Dict[str, Any]): + """Check options.""" + if not options: + raise FlowUIComponentException("options is required", self.ui_type) + + +class StatusMixin(BaseModel): + """Status mixin.""" + + status: Optional[Literal["error", "warning"]] = Field( + None, + description="Status of the input", + ) + + +class RangeMixin(BaseModel): + """Range mixin.""" + + ui_range: Optional[UIComponent.UIRange] = Field( + None, + description="Range for the component", + ) + + +class InputMixin(BaseModel): + """Input mixin.""" + + class Count(BaseModel): + """Count.""" + + show: Optional[bool] = Field( + None, + description="Whether to show count", + ) + max: Optional[int] = Field( + None, + description="The maximum count", + ) + exceed_strategy: Optional[Literal["cut", "warning"]] = Field( + None, + description="The strategy when the count exceeds", + ) + + count: Optional[Count] = Field( + None, + description="Count configuration", + ) + + +class PanelEditorMixin(BaseModel): + """Edit the content in the panel.""" + + class Editor(BaseModel): + """Editor configuration.""" + + width: Optional[int] = Field( + None, + description="The width of the panel", + ) + height: Optional[int] = Field( + None, + description="The height of the panel", + ) + + editor: Optional[Editor] = Field( + None, + description="The editor configuration", + ) + + +class UICascader(StatusMixin, UIComponent): + """Cascader component.""" + + ui_type: Literal["cascader"] = Field("cascader", frozen=True) + + show_search: bool = Field( + False, + description="Whether to show search input", + ) + + def check_parameter(self, parameter_dict: Dict[str, Any]): + """Check parameter.""" + options = parameter_dict.get("options") + if not options: + raise FlowUIComponentException("options is required", self.ui_type) + first_level = options[0] + if "children" not in first_level: + raise FlowUIComponentException( + "children is required in options", self.ui_type + ) + + +class UICheckbox(UIComponent): + """Checkbox component.""" + + ui_type: Literal["checkbox"] = Field("checkbox", frozen=True) + + def check_parameter(self, parameter_dict: Dict[str, Any]): + """Check parameter.""" + self._check_options(parameter_dict.get("options", {})) + + +class UIDatePicker(StatusMixin, RangeMixin, UIComponent): + """Date picker component.""" + + ui_type: Literal["date_picker"] = Field("date_picker", frozen=True) + + placement: Optional[ + Literal["topLeft", "topRight", "bottomLeft", "bottomRight"] + ] = Field( + None, + description="The position of the picker panel, None means bottomLeft", + ) + + +class UIInput(StatusMixin, InputMixin, UIComponent): + """Input component.""" + + ui_type: Literal["input"] = Field("input", frozen=True) + + prefix: Optional[str] = Field( + None, + description="The prefix, icon or text", + examples=["$", "icon:UserOutlined"], + ) + suffix: Optional[str] = Field( + None, + description="The suffix, icon or text", + examples=["$", "icon:SearchOutlined"], + ) + + +class UITextArea(PanelEditorMixin, UIInput): + """Text area component.""" + + ui_type: Literal["text_area"] = Field("text_area", frozen=True) # type: ignore + auto_size: Optional[bool] = Field( + None, + description="Whether the height of the textarea automatically adjusts based " + "on the content", + ) + min_rows: Optional[int] = Field( + None, + description="The minimum number of rows", + ) + max_rows: Optional[int] = Field( + None, + description="The maximum number of rows", + ) + + +class UIAutoComplete(UIInput): + """Auto complete 
component.""" + + ui_type: Literal["auto_complete"] = Field( # type: ignore + "auto_complete", frozen=True + ) + + +class UISlider(RangeMixin, UIComponent): + """Slider component.""" + + ui_type: Literal["slider"] = Field("slider", frozen=True) + + show_input: bool = Field( + False, description="Whether to display the value in a input component" + ) + + +class UITimePicker(StatusMixin, UIComponent): + """Time picker component.""" + + ui_type: Literal["time_picker"] = Field("time_picker", frozen=True) + + format: Optional[str] = Field( + None, + description="The format of the time", + examples=["HH:mm:ss", "HH:mm"], + ) + hour_step: Optional[int] = Field( + None, + description="The step of the hour input", + ) + minute_step: Optional[int] = Field( + None, + description="The step of the minute input", + ) + second_step: Optional[int] = Field( + None, + description="The step of the second input", + ) + + +class UITreeSelect(StatusMixin, UIComponent): + """Tree select component.""" + + ui_type: Literal["tree_select"] = Field("tree_select", frozen=True) + + def check_parameter(self, parameter_dict: Dict[str, Any]): + """Check parameter.""" + options = parameter_dict.get("options") + if not options: + raise FlowUIComponentException("options is required", self.ui_type) + first_level = options[0] + if "children" not in first_level: + raise FlowUIComponentException( + "children is required in options", self.ui_type + ) + + +class UIUpload(StatusMixin, UIComponent): + """Upload component.""" + + ui_type: Literal["upload"] = Field("upload", frozen=True) + + max_file_size: Optional[int] = Field( + None, + description="The maximum size of the file, in bytes", + ) + max_count: Optional[int] = Field( + None, + description="The maximum number of files that can be uploaded", + ) + file_types: Optional[List[str]] = Field( + None, + description="The file types that can be accepted", + examples=[[".png", ".jpg"]], + ) + up_event: Optional[Literal["after_select", "button_click"]] = Field( + None, + description="The event that triggers the upload", + ) + drag: bool = Field( + False, + description="Whether to support drag and drop upload", + ) + action: Optional[str] = Field( + None, + description="The URL for the file upload", + ) + + +class UIVariableInput(UIInput): + """Variable input component.""" + + ui_type: Literal["variable"] = Field("variable", frozen=True) # type: ignore + key: str = Field(..., description="The key of the variable") + key_type: Literal["common", "secret"] = Field( + "common", + description="The type of the key", + ) + refresh: Optional[bool] = Field( + True, + description="Whether to enable the refresh", + ) + + def check_parameter(self, parameter_dict: Dict[str, Any]): + """Check parameter.""" + self._check_options(parameter_dict.get("options", {})) + + +class UIPasswordInput(UIVariableInput): + """Password input component.""" + + ui_type: Literal["password"] = Field("password", frozen=True) # type: ignore + + key_type: Literal["secret"] = Field( + "secret", + description="The type of the key", + ) + + def check_parameter(self, parameter_dict: Dict[str, Any]): + """Check parameter.""" + self._check_options(parameter_dict.get("options", {})) + + +class UICodeEditor(UITextArea): + """Code editor component.""" + + ui_type: Literal["code_editor"] = Field("code_editor", frozen=True) # type: ignore + + language: Optional[str] = Field( + "python", + description="The language of the code", + ) From 8465726dc8e3ee0109d836bbf952bd9793bf8622 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: 
Mon, 5 Aug 2024 18:08:02 +0800 Subject: [PATCH 02/89] feat: Add flow2.0 examples --- dbgpt/core/awel/flow/base.py | 18 +- dbgpt/core/awel/flow/ui.py | 289 +++++---- dbgpt/core/awel/util/parameter_util.py | 5 +- .../core/interface/operators/llm_operator.py | 5 + .../interface/operators/prompt_operator.py | 4 + dbgpt/serve/flow/api/endpoints.py | 5 +- examples/awel/awel_flow_ui_components.py | 583 ++++++++++++++++++ 7 files changed, 786 insertions(+), 123 deletions(-) create mode 100644 examples/awel/awel_flow_ui_components.py diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index fb60538ba..da0b2c378 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -19,6 +19,7 @@ from dbgpt.core.interface.serialization import Serializable from .exceptions import FlowMetadataException, FlowParameterMetadataException +from .ui import UIComponent _TYPE_REGISTRY: Dict[str, Type] = {} @@ -136,6 +137,7 @@ def __init__(self, label: str, description: str): "agent": _CategoryDetail("Agent", "The agent operator"), "rag": _CategoryDetail("RAG", "The RAG operator"), "experimental": _CategoryDetail("EXPERIMENTAL", "EXPERIMENTAL operator"), + "example": _CategoryDetail("Example", "Example operator"), } @@ -151,6 +153,7 @@ class OperatorCategory(str, Enum): AGENT = "agent" RAG = "rag" EXPERIMENTAL = "experimental" + EXAMPLE = "example" def label(self) -> str: """Get the label of the category.""" @@ -193,6 +196,7 @@ class OperatorType(str, Enum): "embeddings": _CategoryDetail("Embeddings", "The embeddings resource"), "rag": _CategoryDetail("RAG", "The resource"), "vector_store": _CategoryDetail("Vector Store", "The vector store resource"), + "example": _CategoryDetail("Example", "The example resource"), } @@ -209,6 +213,7 @@ class ResourceCategory(str, Enum): EMBEDDINGS = "embeddings" RAG = "rag" VECTOR_STORE = "vector_store" + EXAMPLE = "example" def label(self) -> str: """Get the label of the category.""" @@ -343,6 +348,9 @@ class Parameter(TypeMetadata, Serializable): alias: Optional[List[str]] = Field( None, description="The alias of the parameter(Compatible with old version)" ) + ui: Optional[UIComponent] = Field( + None, description="The UI component of the parameter" + ) @model_validator(mode="before") @classmethod @@ -398,6 +406,7 @@ def build_from( label: str, name: str, type: Type, + is_list: bool = False, optional: bool = False, default: Optional[Union[DefaultParameterType, _MISSING_TYPE]] = _MISSING_VALUE, placeholder: Optional[DefaultParameterType] = None, @@ -405,6 +414,7 @@ def build_from( options: Optional[Union[BaseDynamicOptions, List[OptionValue]]] = None, resource_type: ResourceType = ResourceType.INSTANCE, alias: Optional[List[str]] = None, + ui: Optional[UIComponent] = None, ): """Build the parameter from the type.""" type_name = type.__qualname__ @@ -419,6 +429,7 @@ def build_from( name=name, type_name=type_name, type_cls=type_cls, + is_list=is_list, category=category.value, resource_type=resource_type, optional=optional, @@ -427,6 +438,7 @@ def build_from( description=description or label, options=options, alias=alias, + ui=ui, ) @classmethod @@ -456,11 +468,12 @@ def build_from_ui(cls, data: Dict) -> "Parameter": description=data["description"], options=data["options"], value=data["value"], + ui=data.get("ui"), ) def to_dict(self) -> Dict: """Convert current metadata to json dict.""" - dict_value = model_to_dict(self, exclude={"options", "alias"}) + dict_value = model_to_dict(self, exclude={"options", "alias", "ui"}) if not self.options: 
dict_value["options"] = None elif isinstance(self.options, BaseDynamicOptions): @@ -468,6 +481,9 @@ def to_dict(self) -> Dict: dict_value["options"] = [value.to_dict() for value in values] else: dict_value["options"] = [value.to_dict() for value in self.options] + + if self.ui: + dict_value["ui"] = self.ui.to_dict() return dict_value def get_dict_options(self) -> Optional[List[Dict]]: diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index a9f220961..ca4361276 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py @@ -1,8 +1,9 @@ """UI components for AWEL flow.""" -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, Field, model_to_dict +from dbgpt.core.interface.serialization import Serializable from .exceptions import FlowUIComponentException @@ -36,37 +37,6 @@ class RefreshableMixin(BaseModel): ) -class UIComponent(RefreshableMixin, BaseModel): - """UI component.""" - - class UIRange(BaseModel): - """UI range.""" - - min: int | float | str | None = Field(None, description="Minimum value") - max: int | float | str | None = Field(None, description="Maximum value") - step: int | float | str | None = Field(None, description="Step value") - format: str | None = Field(None, description="Format") - - ui_type: _UI_TYPE = Field(..., description="UI component type") - - disabled: bool = Field( - False, - description="Whether the component is disabled", - ) - - def check_parameter(self, parameter_dict: Dict[str, Any]): - """Check parameter. - - Raises: - FlowUIParameterException: If the parameter is invalid. - """ - - def _check_options(self, options: Dict[str, Any]): - """Check options.""" - if not options: - raise FlowUIComponentException("options is required", self.ui_type) - - class StatusMixin(BaseModel): """Status mixin.""" @@ -76,40 +46,6 @@ class StatusMixin(BaseModel): ) -class RangeMixin(BaseModel): - """Range mixin.""" - - ui_range: Optional[UIComponent.UIRange] = Field( - None, - description="Range for the component", - ) - - -class InputMixin(BaseModel): - """Input mixin.""" - - class Count(BaseModel): - """Count.""" - - show: Optional[bool] = Field( - None, - description="Whether to show count", - ) - max: Optional[int] = Field( - None, - description="The maximum count", - ) - exceed_strategy: Optional[Literal["cut", "warning"]] = Field( - None, - description="The strategy when the count exceeds", - ) - - count: Optional[Count] = Field( - None, - description="Count configuration", - ) - - class PanelEditorMixin(BaseModel): """Edit the content in the panel.""" @@ -126,19 +62,62 @@ class Editor(BaseModel): ) editor: Optional[Editor] = Field( - None, + default_factory=lambda: PanelEditorMixin.Editor(width=800, height=400), description="The editor configuration", ) -class UICascader(StatusMixin, UIComponent): +class UIComponent(RefreshableMixin, Serializable, BaseModel): + """UI component.""" + + class UIAttribute(StatusMixin, BaseModel): + """Base UI attribute.""" + + disabled: bool = Field( + False, + description="Whether the component is disabled", + ) + + ui_type: _UI_TYPE = Field(..., description="UI component type") + + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) + + def check_parameter(self, parameter_dict: Dict[str, Any]): + """Check parameter. + + Raises: + FlowUIParameterException: If the parameter is invalid. 
+ """ + + def _check_options(self, options: Dict[str, Any]): + """Check options.""" + if not options: + raise FlowUIComponentException("options is required", self.ui_type) + + def to_dict(self) -> Dict: + """Convert current metadata to json dict.""" + return model_to_dict(self) + + +class UICascader(UIComponent): """Cascader component.""" + class UIAttribute(UIComponent.UIAttribute): + """Cascader attribute.""" + + show_search: bool = Field( + False, + description="Whether to show search input", + ) + ui_type: Literal["cascader"] = Field("cascader", frozen=True) - show_search: bool = Field( - False, - description="Whether to show search input", + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", ) def check_parameter(self, parameter_dict: Dict[str, Any]): @@ -163,53 +142,81 @@ def check_parameter(self, parameter_dict: Dict[str, Any]): self._check_options(parameter_dict.get("options", {})) -class UIDatePicker(StatusMixin, RangeMixin, UIComponent): +class UIDatePicker(UIComponent): """Date picker component.""" + class UIAttribute(UIComponent.UIAttribute): + """Date picker attribute.""" + + placement: Optional[ + Literal["topLeft", "topRight", "bottomLeft", "bottomRight"] + ] = Field( + None, + description="The position of the picker panel, None means bottomLeft", + ) + ui_type: Literal["date_picker"] = Field("date_picker", frozen=True) - placement: Optional[ - Literal["topLeft", "topRight", "bottomLeft", "bottomRight"] - ] = Field( + attr: Optional[UIAttribute] = Field( None, - description="The position of the picker panel, None means bottomLeft", + description="The attributes of the component", ) -class UIInput(StatusMixin, InputMixin, UIComponent): +class UIInput(UIComponent): """Input component.""" + class UIAttribute(UIComponent.UIAttribute): + """Input attribute.""" + + prefix: Optional[str] = Field( + None, + description="The prefix, icon or text", + examples=["$", "icon:UserOutlined"], + ) + suffix: Optional[str] = Field( + None, + description="The suffix, icon or text", + examples=["$", "icon:SearchOutlined"], + ) + show_count: Optional[bool] = Field( + None, + description="Whether to show count", + ) + maxlength: Optional[int] = Field( + None, + description="The maximum length of the input", + ) + ui_type: Literal["input"] = Field("input", frozen=True) - prefix: Optional[str] = Field( + attr: Optional[UIAttribute] = Field( None, - description="The prefix, icon or text", - examples=["$", "icon:UserOutlined"], - ) - suffix: Optional[str] = Field( - None, - description="The suffix, icon or text", - examples=["$", "icon:SearchOutlined"], + description="The attributes of the component", ) class UITextArea(PanelEditorMixin, UIInput): """Text area component.""" + class AutoSize(BaseModel): + """Auto size configuration.""" + + min_rows: Optional[int] = Field( + None, + description="The minimum number of rows", + ) + max_rows: Optional[int] = Field( + None, + description="The maximum number of rows", + ) + ui_type: Literal["text_area"] = Field("text_area", frozen=True) # type: ignore - auto_size: Optional[bool] = Field( + autosize: Optional[Union[bool, AutoSize]] = Field( None, description="Whether the height of the textarea automatically adjusts based " "on the content", ) - min_rows: Optional[int] = Field( - None, - description="The minimum number of rows", - ) - max_rows: Optional[int] = Field( - None, - description="The maximum number of rows", - ) class UIAutoComplete(UIInput): @@ -220,44 +227,73 @@ class UIAutoComplete(UIInput): ) -class 
UISlider(RangeMixin, UIComponent): +class UISlider(UIComponent): """Slider component.""" + class UIAttribute(UIComponent.UIAttribute): + """Slider attribute.""" + + min: Optional[int | float] = Field( + None, + description="The minimum value", + ) + max: Optional[int | float] = Field( + None, + description="The maximum value", + ) + step: Optional[int | float] = Field( + None, + description="The step of the slider", + ) + ui_type: Literal["slider"] = Field("slider", frozen=True) show_input: bool = Field( False, description="Whether to display the value in a input component" ) + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) + -class UITimePicker(StatusMixin, UIComponent): +class UITimePicker(UIComponent): """Time picker component.""" + class UIAttribute(UIComponent.UIAttribute): + """Time picker attribute.""" + + format: Optional[str] = Field( + None, + description="The format of the time", + examples=["HH:mm:ss", "HH:mm"], + ) + hour_step: Optional[int] = Field( + None, + description="The step of the hour input", + ) + minute_step: Optional[int] = Field( + None, + description="The step of the minute input", + ) + second_step: Optional[int] = Field( + None, + description="The step of the second input", + ) + ui_type: Literal["time_picker"] = Field("time_picker", frozen=True) - format: Optional[str] = Field( - None, - description="The format of the time", - examples=["HH:mm:ss", "HH:mm"], - ) - hour_step: Optional[int] = Field( + attr: Optional[UIAttribute] = Field( None, - description="The step of the hour input", - ) - minute_step: Optional[int] = Field( - None, - description="The step of the minute input", - ) - second_step: Optional[int] = Field( - None, - description="The step of the second input", + description="The attributes of the component", ) -class UITreeSelect(StatusMixin, UIComponent): +class UITreeSelect(UICascader): """Tree select component.""" - ui_type: Literal["tree_select"] = Field("tree_select", frozen=True) + ui_type: Literal["tree_select"] = Field("tree_select", frozen=True) # type: ignore def check_parameter(self, parameter_dict: Dict[str, Any]): """Check parameter.""" @@ -271,19 +307,24 @@ def check_parameter(self, parameter_dict: Dict[str, Any]): ) -class UIUpload(StatusMixin, UIComponent): +class UIUpload(UIComponent): """Upload component.""" + class UIAttribute(UIComponent.UIAttribute): + """Upload attribute.""" + + max_count: Optional[int] = Field( + None, + description="The maximum number of files that can be uploaded", + ) + ui_type: Literal["upload"] = Field("upload", frozen=True) max_file_size: Optional[int] = Field( None, description="The maximum size of the file, in bytes", ) - max_count: Optional[int] = Field( - None, - description="The maximum number of files that can be uploaded", - ) + file_types: Optional[List[str]] = Field( None, description="The file types that can be accepted", @@ -346,3 +387,13 @@ class UICodeEditor(UITextArea): "python", description="The language of the code", ) + + +class DefaultUITextArea(UITextArea): + """Default text area component.""" + + autosize: Union[bool, UITextArea.AutoSize] = Field( + default_factory=lambda: UITextArea.AutoSize(min_rows=2, max_rows=40), + description="Whether the height of the textarea automatically adjusts based " + "on the content", + ) diff --git a/dbgpt/core/awel/util/parameter_util.py b/dbgpt/core/awel/util/parameter_util.py index defd99a3b..70015c9ba 100644 --- a/dbgpt/core/awel/util/parameter_util.py +++ b/dbgpt/core/awel/util/parameter_util.py 
@@ -2,7 +2,7 @@ import inspect from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional from dbgpt._private.pydantic import BaseModel, Field, model_validator from dbgpt.core.interface.serialization import Serializable @@ -16,6 +16,9 @@ class OptionValue(Serializable, BaseModel): label: str = Field(..., description="The label of the option") name: str = Field(..., description="The name of the option") value: Any = Field(..., description="The value of the option") + children: Optional[List["OptionValue"]] = Field( + None, description="The children of the option" + ) def to_dict(self) -> Dict: """Convert current metadata to json dict.""" diff --git a/dbgpt/core/interface/operators/llm_operator.py b/dbgpt/core/interface/operators/llm_operator.py index 53e34ffe5..45863d0a9 100644 --- a/dbgpt/core/interface/operators/llm_operator.py +++ b/dbgpt/core/interface/operators/llm_operator.py @@ -24,6 +24,7 @@ OperatorType, Parameter, ViewMetadata, + ui, ) from dbgpt.core.interface.llm import ( LLMClient, @@ -69,6 +70,10 @@ class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest]): optional=True, default=None, description=_("The temperature of the model request."), + ui=ui.UISlider( + show_input=True, + attr=ui.UISlider.UIAttribute(min=0.0, max=2.0, step=0.1), + ), ), Parameter.build_from( _("Max New Tokens"), diff --git a/dbgpt/core/interface/operators/prompt_operator.py b/dbgpt/core/interface/operators/prompt_operator.py index c3765aa67..7d97230ac 100644 --- a/dbgpt/core/interface/operators/prompt_operator.py +++ b/dbgpt/core/interface/operators/prompt_operator.py @@ -1,4 +1,5 @@ """The prompt operator.""" + from abc import ABC from typing import Any, Dict, List, Optional, Union @@ -18,6 +19,7 @@ ResourceCategory, ViewMetadata, register_resource, + ui, ) from dbgpt.core.interface.message import BaseMessage from dbgpt.core.interface.operators.llm_operator import BaseLLM @@ -48,6 +50,7 @@ optional=True, default="You are a helpful AI Assistant.", description=_("The system message."), + ui=ui.DefaultUITextArea(), ), Parameter.build_from( label=_("Message placeholder"), @@ -65,6 +68,7 @@ default="{user_input}", placeholder="{user_input}", description=_("The human message."), + ui=ui.DefaultUITextArea(), ), ], ) diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 6cb5ef879..98ff81d2f 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -209,7 +209,7 @@ async def query_page( @router.get("/nodes", dependencies=[Depends(check_api_key)]) -async def get_nodes() -> Result[List[Union[ViewMetadata, ResourceMetadata]]]: +async def get_nodes(): """Get the operator or resource nodes Returns: @@ -218,7 +218,8 @@ async def get_nodes() -> Result[List[Union[ViewMetadata, ResourceMetadata]]]: """ from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY - return Result.succ(_OPERATOR_REGISTRY.metadata_list()) + metadata_list = _OPERATOR_REGISTRY.metadata_list() + return Result.succ(metadata_list) def init_endpoints(system_app: SystemApp) -> None: diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py new file mode 100644 index 000000000..2af3e2bf3 --- /dev/null +++ b/examples/awel/awel_flow_ui_components.py @@ -0,0 +1,583 @@ +"""Some UI components for the AWEL flow.""" + +import logging +from typing import List, Optional + +from dbgpt.core.awel import MapOperator +from dbgpt.core.awel.flow import ( + IOField, + 
OperatorCategory, + OptionValue, + Parameter, + ViewMetadata, + ui, +) + +logger = logging.getLogger(__name__) + + +class ExampleFlowCascaderOperator(MapOperator[str, str]): + """An example flow operator that includes a cascader as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Cascader", + name="example_flow_cascader", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a cascader as parameter.", + parameters=[ + Parameter.build_from( + "Address Selector", + "address", + type=str, + is_list=True, + optional=True, + default=None, + placeholder="Select the address", + description="The address of the location.", + options=[ + OptionValue( + label="Zhejiang", + name="zhejiang", + value="zhejiang", + children=[ + OptionValue( + label="Hangzhou", + name="hangzhou", + value="hangzhou", + children=[ + OptionValue( + label="Xihu", + name="xihu", + value="xihu", + ), + OptionValue( + label="Feilaifeng", + name="feilaifeng", + value="feilaifeng", + ), + ], + ), + ], + ), + OptionValue( + label="Jiangsu", + name="jiangsu", + value="jiangsu", + children=[ + OptionValue( + label="Nanjing", + name="nanjing", + value="nanjing", + children=[ + OptionValue( + label="Zhonghua Gate", + name="zhonghuamen", + value="zhonghuamen", + ), + OptionValue( + label="Zhongshanling", + name="zhongshanling", + value="zhongshanling", + ), + ], + ), + ], + ), + ], + ui=ui.UICascader(attr=ui.UICascader.UIAttribute(show_search=True)), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Address", + "address", + str, + description="User's address.", + ) + ], + ) + + def __int__(self, address: Optional[List[str]] = None, **kwargs): + super().__init__(**kwargs) + self.address = address or [] + + async def map(self, user_name: str) -> str: + """Map the user name to the address.""" + full_address_str = " ".join(self.address) + return "Your name is %s, and your address is %s." % ( + user_name, + full_address_str, + ) + + +class ExampleFlowCheckboxOperator(MapOperator[str, str]): + """An example flow operator that includes a checkbox as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Checkbox", + name="example_flow_checkbox", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a checkbox as parameter.", + parameters=[ + Parameter.build_from( + "Fruits Selector", + "fruits", + type=str, + is_list=True, + optional=True, + default=None, + placeholder="Select the fruits", + description="The fruits you like.", + options=[ + OptionValue(label="Apple", name="apple", value="apple"), + OptionValue(label="Banana", name="banana", value="banana"), + OptionValue(label="Orange", name="orange", value="orange"), + OptionValue(label="Pear", name="pear", value="pear"), + ], + ui=ui.UICheckbox(attr=ui.UICheckbox.UIAttribute(show_search=True)), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Fruits", + "fruits", + str, + description="User's favorite fruits.", + ) + ], + ) + + def __init__(self, fruits: Optional[List[str]] = None, **kwargs): + super().__init__(**kwargs) + self.fruits = fruits or [] + + async def map(self, user_name: str) -> str: + """Map the user name to the fruits.""" + return "Your name is %s, and you like %s." 
% (user_name, ", ".join(self.fruits)) + + +class ExampleFlowDatePickerOperator(MapOperator[str, str]): + """An example flow operator that includes a date picker as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Date Picker", + name="example_flow_date_picker", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a date picker as parameter.", + parameters=[ + Parameter.build_from( + "Date Selector", + "date", + type=str, + placeholder="Select the date", + description="The date you choose.", + ui=ui.UIDatePicker( + attr=ui.UIDatePicker.UIAttribute(placement="bottomLeft") + ), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Date", + "date", + str, + description="User's selected date.", + ) + ], + ) + + def __init__(self, date: str, **kwargs): + super().__init__(**kwargs) + self.date = date + + async def map(self, user_name: str) -> str: + """Map the user name to the date.""" + return "Your name is %s, and you choose the date %s." % (user_name, self.date) + + +class ExampleFlowInputOperator(MapOperator[str, str]): + """An example flow operator that includes an input as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Input", + name="example_flow_input", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a input as parameter.", + parameters=[ + Parameter.build_from( + "Your hobby", + "hobby", + type=str, + placeholder="Please input your hobby", + description="The hobby you like.", + ui=ui.UIInput( + attr=ui.UIInput.UIAttribute( + prefix="icon:UserOutlined", show_count=True, maxlength=200 + ) + ), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "User Hobby", + "hobby", + str, + description="User's hobby.", + ) + ], + ) + + def __init__(self, hobby: str, **kwargs): + super().__init__(**kwargs) + self.hobby = hobby + + async def map(self, user_name: str) -> str: + """Map the user name to the input.""" + return "Your name is %s, and your hobby is %s." % (user_name, self.hobby) + + +class ExampleFlowTextAreaOperator(MapOperator[str, str]): + """An example flow operator that includes a text area as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Text Area", + name="example_flow_text_area", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a text area as parameter.", + parameters=[ + Parameter.build_from( + "Your comment", + "comment", + type=str, + placeholder="Please input your comment", + description="The comment you want to say.", + ui=ui.UITextArea( + attr=ui.UITextArea.UIAttribute(show_count=True, maxlength=1000), + autosize=ui.UITextArea.AutoSize(min_rows=2, max_rows=6), + ), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "User Comment", + "comment", + str, + description="User's comment.", + ) + ], + ) + + def __init__(self, comment: str, **kwargs): + super().__init__(**kwargs) + self.comment = comment + + async def map(self, user_name: str) -> str: + """Map the user name to the text area.""" + return "Your name is %s, and your comment is %s." 
% (user_name, self.comment) + + +class ExampleFlowSliderOperator(MapOperator[float, float]): + + metadata = ViewMetadata( + label="Example Flow Slider", + name="example_flow_slider", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a slider as parameter.", + parameters=[ + Parameter.build_from( + "Default Temperature", + "default_temperature", + type=float, + optional=True, + default=0.7, + placeholder="Set the default temperature, e.g., 0.7", + description="The default temperature to pass to the LLM.", + ui=ui.UISlider( + show_input=True, + attr=ui.UISlider.UIAttribute(min=0.0, max=2.0, step=0.1), + ), + ) + ], + inputs=[ + IOField.build_from( + "Temperature", + "temperature", + float, + description="The temperature.", + ) + ], + outputs=[ + IOField.build_from( + "Temperature", + "temperature", + float, + description="The temperature to pass to the LLM.", + ) + ], + ) + + def __init__(self, default_temperature: float = 0.7, **kwargs): + super().__init__(**kwargs) + self.default_temperature = default_temperature + + async def map(self, temperature: float) -> float: + """Map the temperature to the result.""" + if temperature < 0.0 or temperature > 2.0: + logger.warning("Temperature out of range: %s", temperature) + return self.default_temperature + else: + return temperature + + +class ExampleFlowSliderListOperator(MapOperator[float, float]): + """An example flow operator that includes a slider list as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Slider List", + name="example_flow_slider_list", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a slider list as parameter.", + parameters=[ + Parameter.build_from( + "Temperature Selector", + "temperature_range", + type=float, + is_list=True, + optional=True, + default=None, + placeholder="Set the temperature, e.g., [0.1, 0.9]", + description="The temperature range to pass to the LLM.", + ui=ui.UISlider( + show_input=True, + attr=ui.UISlider.UIAttribute(min=0.0, max=2.0, step=0.1), + ), + ) + ], + inputs=[ + IOField.build_from( + "Temperature", + "temperature", + float, + description="The temperature.", + ) + ], + outputs=[ + IOField.build_from( + "Temperature", + "temperature", + float, + description="The temperature to pass to the LLM.", + ) + ], + ) + + def __init__(self, temperature_range: Optional[List[float]] = None, **kwargs): + super().__init__(**kwargs) + temperature_range = temperature_range or [0.1, 0.9] + if temperature_range and len(temperature_range) != 2: + raise ValueError("The length of temperature range must be 2.") + self.temperature_range = temperature_range + + async def map(self, temperature: float) -> float: + """Map the temperature to the result.""" + min_temperature, max_temperature = self.temperature_range + if temperature < min_temperature or temperature > max_temperature: + logger.warning( + "Temperature out of range: %s, min: %s, max: %s", + temperature, + min_temperature, + max_temperature, + ) + return min_temperature + return temperature + + +class ExampleFlowTimePickerOperator(MapOperator[str, str]): + """An example flow operator that includes a time picker as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Time Picker", + name="example_flow_time_picker", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a time picker as parameter.", + parameters=[ + Parameter.build_from( + "Time Selector", + "time", + type=str, + placeholder="Select the time", + 
description="The time you choose.", + ui=ui.UITimePicker( + attr=ui.UITimePicker.UIAttribute( + format="HH:mm:ss", hour_step=2, minute_step=10, second_step=10 + ), + ), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Time", + "time", + str, + description="User's selected time.", + ) + ], + ) + + def __init__(self, time: str, **kwargs): + super().__init__(**kwargs) + self.time = time + + async def map(self, user_name: str) -> str: + """Map the user name to the time.""" + return "Your name is %s, and you choose the time %s." % (user_name, self.time) + + +class ExampleFlowTreeSelectOperator(MapOperator[str, str]): + """An example flow operator that includes a tree select as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Tree Select", + name="example_flow_tree_select", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a tree select as parameter.", + parameters=[ + Parameter.build_from( + "Address Selector", + "address", + type=str, + is_list=True, + optional=True, + default=None, + placeholder="Select the address", + description="The address of the location.", + options=[ + OptionValue( + label="Zhejiang", + name="zhejiang", + value="zhejiang", + children=[ + OptionValue( + label="Hangzhou", + name="hangzhou", + value="hangzhou", + children=[ + OptionValue( + label="Xihu", + name="xihu", + value="xihu", + ), + OptionValue( + label="Feilaifeng", + name="feilaifeng", + value="feilaifeng", + ), + ], + ), + ], + ), + OptionValue( + label="Jiangsu", + name="jiangsu", + value="jiangsu", + children=[ + OptionValue( + label="Nanjing", + name="nanjing", + value="nanjing", + children=[ + OptionValue( + label="Zhonghua Gate", + name="zhonghuamen", + value="zhonghuamen", + ), + OptionValue( + label="Zhongshanling", + name="zhongshanling", + value="zhongshanling", + ), + ], + ), + ], + ), + ], + ui=ui.UITreeSelect(attr=ui.UITreeSelect.UIAttribute(show_search=True)), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Address", + "address", + str, + description="User's address.", + ) + ], + ) + + def __int__(self, address: Optional[List[str]] = None, **kwargs): + super().__init__(**kwargs) + self.address = address or [] + + async def map(self, user_name: str) -> str: + """Map the user name to the address.""" + full_address_str = " ".join(self.address) + return "Your name is %s, and your address is %s." 
% ( + user_name, + full_address_str, + ) From db44f2b3a2d83dc7a56eb68c04902f82ad1588e2 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Tue, 6 Aug 2024 10:17:58 +0800 Subject: [PATCH 03/89] feat(core): Support refresh for AWEL flow --- dbgpt/core/awel/flow/base.py | 52 ++++++++- dbgpt/core/awel/flow/ui.py | 33 ++++++ dbgpt/core/awel/util/parameter_util.py | 38 ++++++- dbgpt/serve/flow/api/endpoints.py | 19 +++- dbgpt/serve/flow/api/schemas.py | 40 ++++++- examples/awel/awel_flow_ui_components.py | 136 +++++++++++++++++++++++ 6 files changed, 307 insertions(+), 11 deletions(-) diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index da0b2c378..846b18baf 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -15,7 +15,11 @@ model_to_dict, model_validator, ) -from dbgpt.core.awel.util.parameter_util import BaseDynamicOptions, OptionValue +from dbgpt.core.awel.util.parameter_util import ( + BaseDynamicOptions, + OptionValue, + RefreshOptionRequest, +) from dbgpt.core.interface.serialization import Serializable from .exceptions import FlowMetadataException, FlowParameterMetadataException @@ -486,6 +490,25 @@ def to_dict(self) -> Dict: dict_value["ui"] = self.ui.to_dict() return dict_value + def refresh(self, request: Optional[RefreshOptionRequest] = None) -> Dict: + """Refresh the options of the parameter. + + Args: + request (RefreshOptionRequest): The request to refresh the options. + + Returns: + Dict: The response. + """ + dict_value = self.to_dict() + if not self.options: + dict_value["options"] = None + elif isinstance(self.options, BaseDynamicOptions): + values = self.options.refresh(request) + dict_value["options"] = [value.to_dict() for value in values] + else: + dict_value["options"] = [value.to_dict() for value in self.options] + return dict_value + def get_dict_options(self) -> Optional[List[Dict]]: """Get the options of the parameter.""" if not self.options: @@ -655,10 +678,10 @@ class BaseMetadata(BaseResource): ], ) - tags: Optional[List[str]] = Field( + tags: Optional[Dict[str, str]] = Field( default=None, description="The tags of the operator", - examples=[["llm", "openai", "gpt3"]], + examples=[{"order": "higher-order"}, {"order": "first-order"}], ) parameters: List[Parameter] = Field( @@ -768,6 +791,20 @@ def to_dict(self) -> Dict: ] return dict_value + def refresh(self, request: List[RefreshOptionRequest]) -> Dict: + """Refresh the metadata.""" + name_to_request = {req.name: req for req in request} + parameter_requests = { + parameter.name: name_to_request.get(parameter.name) + for parameter in self.parameters + } + dict_value = self.to_dict() + dict_value["parameters"] = [ + parameter.refresh(parameter_requests.get(parameter.name)) + for parameter in self.parameters + ] + return dict_value + class ResourceMetadata(BaseMetadata, TypeMetadata): """The metadata of the resource.""" @@ -1051,6 +1088,15 @@ def metadata_list(self): """Get the metadata list.""" return [item.metadata.to_dict() for item in self._registry.values()] + def refresh( + self, key: str, is_operator: bool, request: List[RefreshOptionRequest] + ) -> Dict: + """Refresh the metadata.""" + if is_operator: + return _get_operator_class(key).metadata.refresh(request) # type: ignore + else: + return _get_resource_class(key).metadata.refresh(request) + _OPERATOR_REGISTRY: FlowRegistry = FlowRegistry() diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index ca4361276..91008269e 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py 
@@ -8,6 +8,7 @@ from .exceptions import FlowUIComponentException _UI_TYPE = Literal[ + "select", "cascader", "checkbox", "date_picker", @@ -102,6 +103,38 @@ def to_dict(self) -> Dict: return model_to_dict(self) +class UISelect(UIComponent): + """Select component.""" + + class UIAttribute(UIComponent.UIAttribute): + """Select attribute.""" + + show_search: bool = Field( + False, + description="Whether to show search input", + ) + mode: Optional[Literal["tags"]] = Field( + None, + description="The mode of the select", + ) + placement: Optional[ + Literal["topLeft", "topRight", "bottomLeft", "bottomRight"] + ] = Field( + None, + description="The position of the picker panel, None means bottomLeft", + ) + + ui_type: Literal["select"] = Field("select", frozen=True) + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) + + def check_parameter(self, parameter_dict: Dict[str, Any]): + """Check parameter.""" + self._check_options(parameter_dict.get("options", {})) + + class UICascader(UIComponent): """Cascader component.""" diff --git a/dbgpt/core/awel/util/parameter_util.py b/dbgpt/core/awel/util/parameter_util.py index 70015c9ba..2393aed89 100644 --- a/dbgpt/core/awel/util/parameter_util.py +++ b/dbgpt/core/awel/util/parameter_util.py @@ -10,6 +10,27 @@ _DEFAULT_DYNAMIC_REGISTRY = {} +class RefreshOptionDependency(BaseModel): + """The refresh dependency.""" + + name: str = Field(..., description="The name of the refresh dependency") + value: Optional[Any] = Field( + None, description="The value of the refresh dependency" + ) + has_value: bool = Field( + False, description="Whether the refresh dependency has value" + ) + + +class RefreshOptionRequest(BaseModel): + """The refresh option request.""" + + name: str = Field(..., description="The name of parameter to refresh") + depends: Optional[List[RefreshOptionDependency]] = Field( + None, description="The depends of the refresh config" + ) + + class OptionValue(Serializable, BaseModel): """The option value of the parameter.""" @@ -28,24 +49,31 @@ def to_dict(self) -> Dict: class BaseDynamicOptions(Serializable, BaseModel, ABC): """The base dynamic options.""" - @abstractmethod def option_values(self) -> List[OptionValue]: """Return the option values of the parameter.""" + return self.refresh(None) + + @abstractmethod + def refresh(self, request: Optional[RefreshOptionRequest]) -> List[OptionValue]: + """Refresh the dynamic options.""" class FunctionDynamicOptions(BaseDynamicOptions): """The function dynamic options.""" - func: Callable[[], List[OptionValue]] = Field( + func: Callable[..., List[OptionValue]] = Field( ..., description="The function to generate the dynamic options" ) func_id: str = Field( ..., description="The unique id of the function to generate the dynamic options" ) - def option_values(self) -> List[OptionValue]: - """Return the option values of the parameter.""" - return self.func() + def refresh(self, request: Optional[RefreshOptionRequest]) -> List[OptionValue]: + """Refresh the dynamic options.""" + if not request or not request.depends: + return self.func() + kwargs = {dep.name: dep.value for dep in request.depends if dep.has_value} + return self.func(**kwargs) @model_validator(mode="before") @classmethod diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 98ff81d2f..99852271a 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -11,7 +11,7 @@ from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, 
ServeConfig from ..service.service import Service -from .schemas import ServeRequest, ServerResponse +from .schemas import RefreshNodeRequest, ServeRequest, ServerResponse router = APIRouter() @@ -222,6 +222,23 @@ async def get_nodes(): return Result.succ(metadata_list) +@router.post("/nodes/refresh", dependencies=[Depends(check_api_key)]) +async def refresh_nodes(refresh_request: RefreshNodeRequest): + """Refresh the operator or resource nodes + + Returns: + Result[None]: The response + """ + from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY + + new_metadata = _OPERATOR_REGISTRY.refresh( + key=refresh_request.id, + is_operator=refresh_request.flow_type == "operator", + request=refresh_request.refresh, + ) + return Result.succ(new_metadata) + + def init_endpoints(system_app: SystemApp) -> None: """Initialize the endpoints""" global global_system_app diff --git a/dbgpt/serve/flow/api/schemas.py b/dbgpt/serve/flow/api/schemas.py index 6fb8c1924..2daa8f581 100644 --- a/dbgpt/serve/flow/api/schemas.py +++ b/dbgpt/serve/flow/api/schemas.py @@ -1,7 +1,8 @@ -from dbgpt._private.pydantic import ConfigDict +from typing import List, Literal -# Define your Pydantic schemas here +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from dbgpt.core.awel.flow.flow_factory import FlowPanel +from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest from ..config import SERVE_APP_NAME_HUMP @@ -14,3 +15,38 @@ class ServerResponse(FlowPanel): # TODO define your own fields here model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}") + + +class RefreshNodeRequest(BaseModel): + """Flow response model""" + + model_config = ConfigDict(title=f"RefreshNodeRequest") + id: str = Field( + ..., + title="The id of the node", + description="The id of the node to refresh", + examples=["operator_llm_operator___$$___llm___$$___v1"], + ) + flow_type: Literal["operator", "resource"] = Field( + "operator", + title="The type of the node", + description="The type of the node to refresh", + examples=["operator", "resource"], + ) + type_name: str = Field( + ..., + title="The type of the node", + description="The type of the node to refresh", + examples=["LLMOperator"], + ) + type_cls: str = Field( + ..., + title="The class of the node", + description="The class of the node to refresh", + examples=["dbgpt.core.operator.llm.LLMOperator"], + ) + refresh: List[RefreshOptionRequest] = Field( + ..., + title="The refresh options", + description="The refresh options", + ) diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py index 2af3e2bf3..fc8d9a5c4 100644 --- a/examples/awel/awel_flow_ui_components.py +++ b/examples/awel/awel_flow_ui_components.py @@ -5,6 +5,7 @@ from dbgpt.core.awel import MapOperator from dbgpt.core.awel.flow import ( + FunctionDynamicOptions, IOField, OperatorCategory, OptionValue, @@ -16,6 +17,59 @@ logger = logging.getLogger(__name__) +class ExampleFlowSelectOperator(MapOperator[str, str]): + """An example flow operator that includes a select as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Select", + name="example_flow_select", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a select as parameter.", + parameters=[ + Parameter.build_from( + "Fruits Selector", + "fruits", + type=str, + optional=True, + default=None, + placeholder="Select the fruits", + description="The fruits you like.", + options=[ + OptionValue(label="Apple", name="apple", value="apple"), + 
OptionValue(label="Banana", name="banana", value="banana"), + OptionValue(label="Orange", name="orange", value="orange"), + OptionValue(label="Pear", name="pear", value="pear"), + ], + ui=ui.UISelect(attr=ui.UISelect.UIAttribute(show_search=True)), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Fruits", + "fruits", + str, + description="User's favorite fruits.", + ) + ], + ) + + def __init__(self, fruits: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.fruits = fruits + + async def map(self, user_name: str) -> str: + """Map the user name to the fruits.""" + return "Your name is %s, and you like %s." % (user_name, self.fruits) + + class ExampleFlowCascaderOperator(MapOperator[str, str]): """An example flow operator that includes a cascader as parameter.""" @@ -581,3 +635,85 @@ async def map(self, user_name: str) -> str: user_name, full_address_str, ) + + +def get_recent_3_times(time_interval: int = 1) -> List[OptionValue]: + """Get the recent times.""" + from datetime import datetime, timedelta + + now = datetime.now() + recent_times = [now - timedelta(hours=time_interval * i) for i in range(3)] + formatted_times = [time.strftime("%Y-%m-%d %H:%M:%S") for time in recent_times] + option_values = [ + OptionValue(label=formatted_time, name=f"time_{i + 1}", value=formatted_time) + for i, formatted_time in enumerate(formatted_times) + ] + + return option_values + + +class ExampleFlowRefreshOperator(MapOperator[str, str]): + """An example flow operator that includes a refresh option.""" + + metadata = ViewMetadata( + label="Example Refresh Operator", + name="example_refresh_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a refresh option.", + parameters=[ + Parameter.build_from( + "Time Interval", + "time_interval", + type=int, + optional=True, + default=1, + placeholder="Set the time interval", + description="The time interval to fetch the times", + ), + Parameter.build_from( + "Recent Time", + "recent_time", + type=str, + optional=True, + default=None, + placeholder="Select the recent time", + description="The recent time to choose.", + options=FunctionDynamicOptions(func=get_recent_3_times), + ui=ui.UISelect( + refresh=True, + refresh_depends=["time_interval"], + attr=ui.UISelect.UIAttribute(show_search=True), + ), + ), + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Time", + "time", + str, + description="User's selected time.", + ) + ], + ) + + def __init__( + self, time_interval: int = 1, recent_time: Optional[str] = None, **kwargs + ): + super().__init__(**kwargs) + self.time_interval = time_interval + self.recent_time = recent_time + + async def map(self, user_name: str) -> str: + """Map the user name to the time.""" + return "Your name is %s, and you choose the time %s." 
% ( + user_name, + self.recent_time, + ) From e97b63a4d4f347fa34b8e791397018c336a4b597 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Thu, 8 Aug 2024 19:27:52 +0800 Subject: [PATCH 04/89] =?UTF-8?q?feat:=20=E7=BC=96=E6=8E=92=E7=94=BB?= =?UTF-8?q?=E5=B8=83=E6=96=B0=E5=A2=9E=E5=9F=BA=E4=BA=8EAWEL2.0=E7=9A=84Se?= =?UTF-8?q?lect=E7=B1=BB=E5=9E=8B=E8=8A=82=E7=82=B9=E6=B8=B2=E6=9F=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/components/flow/node-param-handler.tsx | 36 ++++++++++++++----- .../flow/node-renderer/cascader.tsx | 3 ++ web/components/flow/node-renderer/index.ts | 2 ++ web/components/flow/node-renderer/select.tsx | 21 +++++++++++ web/types/flow.ts | 9 +++++ 5 files changed, 62 insertions(+), 9 deletions(-) create mode 100644 web/components/flow/node-renderer/cascader.tsx create mode 100644 web/components/flow/node-renderer/index.ts create mode 100644 web/components/flow/node-renderer/select.tsx diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index c617369cf..d86da6653 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -4,6 +4,7 @@ import React from 'react'; import RequiredIcon from './required-icon'; import NodeHandler from './node-handler'; import { InfoCircleOutlined } from '@ant-design/icons'; +import { RenderSelect, RenderCascader } from './node-renderer'; interface NodeParamHandlerProps { node: IFlowNode; @@ -14,14 +15,14 @@ interface NodeParamHandlerProps { // render node parameters item const NodeParamHandler: React.FC = ({ node, data, label, index }) => { - function handleChange(value: any) { + function onChange(value: any) { data.value = value; } - if (data.category === 'resource') { - return ; - } else if (data.category === 'common') { + // 基于AWEL1.0的流程设计器,对节点参数的渲染 + function renderNodeWithoutUiParam(data: IFlowNodeParameter) { let defaultValue = data.value !== null && data.value !== undefined ? data.value : data.default; + switch (data.type_name) { case 'int': case 'float': @@ -39,7 +40,7 @@ const NodeParamHandler: React.FC = ({ node, data, label, className="w-full" defaultValue={defaultValue} onChange={(value: number | null) => { - handleChange(value); + onChange(value); }} /> @@ -60,20 +61,20 @@ const NodeParamHandler: React.FC = ({ node, data, label, className="w-full nodrag" defaultValue={defaultValue} options={data.options.map((item: any) => ({ label: item.label, value: item.value }))} - onChange={handleChange} + onChange={onChange} /> ) : ( { - handleChange(e.target.value); + onChange(e.target.value); }} /> )} ); - case 'bool': + case 'checkbox': defaultValue = defaultValue === 'False' ? false : defaultValue; defaultValue = defaultValue === 'True' ? true : defaultValue; return ( @@ -89,7 +90,7 @@ const NodeParamHandler: React.FC = ({ node, data, label, className="ml-2" defaultChecked={defaultValue} onChange={(e) => { - handleChange(e.target.checked); + onChange(e.target.checked); }} />
{node.label}
{node.description}
{data.label}: {data.description && ( @@ -59,7 +59,7 @@ const NodeParamHandler: React.FC = ({ node, data, label, ); case 'str': return ( - + {data.label}: {data.description && ( @@ -86,11 +86,11 @@ const NodeParamHandler: React.FC = ({ node, data, label, )} ); - case 'checkbox': + case 'bool': defaultValue = defaultValue === 'False' ? false : defaultValue; defaultValue = defaultValue === 'True' ? true : defaultValue; return ( - + {data.label}: {data.description && ( diff --git a/web/components/flow/node-renderer/cascader.tsx b/web/components/flow/node-renderer/cascader.tsx index 4c9d69ac8..118c3ac40 100644 --- a/web/components/flow/node-renderer/cascader.tsx +++ b/web/components/flow/node-renderer/cascader.tsx @@ -13,15 +13,13 @@ export const RenderCascader = (params: Props) => { const attr = convertKeysToCamelCase(data.ui?.attr || {}); return ( - - - + ); }; diff --git a/web/components/flow/node-renderer/checkbox.tsx b/web/components/flow/node-renderer/checkbox.tsx index 34b8ac657..0500d3498 100644 --- a/web/components/flow/node-renderer/checkbox.tsx +++ b/web/components/flow/node-renderer/checkbox.tsx @@ -14,7 +14,7 @@ export const RenderCheckbox = (params: Props) => { return ( data.options?.length > 0 && ( - + ) diff --git a/web/components/flow/node-renderer/date-picker.tsx b/web/components/flow/node-renderer/date-picker.tsx index d69538a61..25ddd33cc 100644 --- a/web/components/flow/node-renderer/date-picker.tsx +++ b/web/components/flow/node-renderer/date-picker.tsx @@ -17,9 +17,5 @@ export const RenderDatePicker = (params: Props) => { onChange(dateString); }; - return ( - - - - ); + return ; }; diff --git a/web/components/flow/node-renderer/input.tsx b/web/components/flow/node-renderer/input.tsx index 8a47d3d2f..538d14081 100644 --- a/web/components/flow/node-renderer/input.tsx +++ b/web/components/flow/node-renderer/input.tsx @@ -13,16 +13,14 @@ export const RenderInput = (params: Props) => { const attr = convertKeysToCamelCase(data.ui?.attr || {}); return ( - - { - onChange(e.target.value); - }} - /> - + { + onChange(e.target.value); + }} + /> ); }; diff --git a/web/components/flow/node-renderer/radio.tsx b/web/components/flow/node-renderer/radio.tsx index 1bc1763a6..056681ef4 100644 --- a/web/components/flow/node-renderer/radio.tsx +++ b/web/components/flow/node-renderer/radio.tsx @@ -13,7 +13,7 @@ export const RenderRadio = (params: Props) => { const attr = convertKeysToCamelCase(data.ui?.attr || {}); return ( - + { const attr = convertKeysToCamelCase(data.ui?.attr || {}); return ( - - - + ); }; diff --git a/web/components/flow/node-renderer/slider-list.tsx b/web/components/flow/node-renderer/slider-list.tsx deleted file mode 100644 index a1e3553f4..000000000 --- a/web/components/flow/node-renderer/slider-list.tsx +++ /dev/null @@ -1,40 +0,0 @@ -import { IFlowNodeParameter } from '@/types/flow'; -import { convertKeysToCamelCase } from '@/utils/flow'; -import { Col, InputNumber, Row, Slider, Space } from 'antd'; -import type { InputNumberProps } from 'antd'; -import React, { useState } from 'react'; - -type TextAreaProps = { - data: IFlowNodeParameter; - defaultValue: any; - onChange: (value: any) => void; -}; - -export const RenderSlider = (params: TextAreaProps) => { - const { data, defaultValue, onChange } = params; - const attr = convertKeysToCamelCase(data.ui?.attr || {}); - const [inputValue, setInputValue] = useState(defaultValue); - - const onChangeSlider: InputNumberProps['onChange'] = (newValue) => { - setInputValue(newValue as number); - onChange(newValue as number); - 
}; - - return ( - - {data?.ui?.show_input ? ( - - - - - - - - - - ) : ( - - )} - - ); -}; diff --git a/web/components/flow/node-renderer/slider.tsx b/web/components/flow/node-renderer/slider.tsx index bb8456adb..1017e20bb 100644 --- a/web/components/flow/node-renderer/slider.tsx +++ b/web/components/flow/node-renderer/slider.tsx @@ -21,7 +21,7 @@ export const RenderSlider = (params: TextAreaProps) => { }; return ( - + <> {data?.ui?.show_input ? ( @@ -34,6 +34,6 @@ export const RenderSlider = (params: TextAreaProps) => { ) : ( )} - + > ); }; diff --git a/web/components/flow/node-renderer/textarea.tsx b/web/components/flow/node-renderer/textarea.tsx index 5f8c55ac0..4a2c40f20 100644 --- a/web/components/flow/node-renderer/textarea.tsx +++ b/web/components/flow/node-renderer/textarea.tsx @@ -1,6 +1,7 @@ import { IFlowNodeParameter } from '@/types/flow'; import { Input } from 'antd'; import { convertKeysToCamelCase } from '@/utils/flow'; +import classNames from 'classnames'; const { TextArea } = Input; @@ -12,11 +13,11 @@ type TextAreaProps = { export const RenderTextArea = (params: TextAreaProps) => { const { data, defaultValue, onChange } = params; - convertKeysToCamelCase(data?.ui?.attr?.autosize || {}); + const attr = convertKeysToCamelCase(data.ui?.attr || {}); return ( - - onChange(e.target.value)} {...data.ui.attr.autosize} rows={4} /> + + ); }; diff --git a/web/components/flow/node-renderer/time-picker.tsx b/web/components/flow/node-renderer/time-picker.tsx index 358f0b8a9..bfe3f8cb7 100644 --- a/web/components/flow/node-renderer/time-picker.tsx +++ b/web/components/flow/node-renderer/time-picker.tsx @@ -17,9 +17,5 @@ export const RenderTimePicker = (params: TextAreaProps) => { onChange(timeString); }; - return ( - - - - ); + return ; }; diff --git a/web/components/flow/node-renderer/tree-select.tsx b/web/components/flow/node-renderer/tree-select.tsx index 7acc3d73a..508407217 100644 --- a/web/components/flow/node-renderer/tree-select.tsx +++ b/web/components/flow/node-renderer/tree-select.tsx @@ -79,18 +79,16 @@ export const RenderTreeSelect = (params: TextAreaProps) => { }; return ( - - - + // TODO: Implement the TreeSelect component // Date: Thu, 15 Aug 2024 00:25:42 +0800 Subject: [PATCH 20/89] fix: fixed the issue of incorrect background color of canvas node in dark mode --- web/components/flow/canvas-node.tsx | 10 +++++----- web/components/flow/node-handler.tsx | 2 +- web/components/flow/node-renderer/date-picker.tsx | 12 +++++++++++- web/components/flow/node-renderer/time-picker.tsx | 2 +- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/web/components/flow/canvas-node.tsx b/web/components/flow/canvas-node.tsx index de33a0e7b..0177aa4f1 100644 --- a/web/components/flow/canvas-node.tsx +++ b/web/components/flow/canvas-node.tsx @@ -74,7 +74,7 @@ const CanvasNode: React.FC = ({ data }) => { function renderOutput(data: IFlowNode) { if (flowType === 'operator' && outputs?.length > 0) { return ( - + {outputs?.map((output, index) => ( @@ -86,7 +86,7 @@ const CanvasNode: React.FC = ({ data }) => { } else if (flowType === 'resource') { // resource nodes show output default return ( - + @@ -126,7 +126,7 @@ const CanvasNode: React.FC = ({ data }) => { > = ({ data }) => { {inputs?.length > 0 && ( - + {inputs?.map((input, index) => ( @@ -155,7 +155,7 @@ const CanvasNode: React.FC = ({ data }) => { )} {parameters?.length > 0 && ( - + {parameters?.map((parameter, index) => ( diff --git a/web/components/flow/node-handler.tsx b/web/components/flow/node-handler.tsx index 
9b7682291..dfe144b2c 100644 --- a/web/components/flow/node-handler.tsx +++ b/web/components/flow/node-handler.tsx @@ -101,7 +101,7 @@ const NodeHandler: React.FC = ({ node, data, type, label, inde isValidConnection={(connection) => isValidConnection(connection)} /> diff --git a/web/components/flow/node-renderer/date-picker.tsx b/web/components/flow/node-renderer/date-picker.tsx index 25ddd33cc..478295c75 100644 --- a/web/components/flow/node-renderer/date-picker.tsx +++ b/web/components/flow/node-renderer/date-picker.tsx @@ -11,11 +11,21 @@ type Props = { export const RenderDatePicker = (params: Props) => { const { data, defaultValue, onChange } = params; + console.log('data', data); + const attr = convertKeysToCamelCase(data.ui?.attr || {}); const onChangeDate: DatePickerProps['onChange'] = (date, dateString) => { onChange(dateString); }; - return ; + return ( + + ); }; diff --git a/web/components/flow/node-renderer/time-picker.tsx b/web/components/flow/node-renderer/time-picker.tsx index bfe3f8cb7..95e13524b 100644 --- a/web/components/flow/node-renderer/time-picker.tsx +++ b/web/components/flow/node-renderer/time-picker.tsx @@ -17,5 +17,5 @@ export const RenderTimePicker = (params: TextAreaProps) => { onChange(timeString); }; - return ; + return ; }; From 9796bf9a9d396973e1808d179f9f4deda1cbfdfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Thu, 15 Aug 2024 00:32:31 +0800 Subject: [PATCH 21/89] style: delete redundant font styles --- web/components/flow/canvas-node.tsx | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/web/components/flow/canvas-node.tsx b/web/components/flow/canvas-node.tsx index 0177aa4f1..59129718f 100644 --- a/web/components/flow/canvas-node.tsx +++ b/web/components/flow/canvas-node.tsx @@ -26,7 +26,7 @@ const CanvasNode: React.FC = ({ data }) => { const { inputs, outputs, parameters, flow_type: flowType } = node; const [isHovered, setIsHovered] = useState(false); const reactFlow = useReactFlow(); - + function onHover() { setIsHovered(true); } @@ -76,7 +76,7 @@ const CanvasNode: React.FC = ({ data }) => { return ( - + {outputs?.map((output, index) => ( ))} @@ -88,9 +88,7 @@ const CanvasNode: React.FC = ({ data }) => { return ( - - - + ); } @@ -146,7 +144,7 @@ const CanvasNode: React.FC = ({ data }) => { {inputs?.length > 0 && ( - + {inputs?.map((input, index) => ( ))} @@ -157,7 +155,7 @@ const CanvasNode: React.FC = ({ data }) => { {parameters?.length > 0 && ( - + {parameters?.map((parameter, index) => ( ))} From 6f31c01fb37a4eaa3304ff3943150a7ab2b8a2e8 Mon Sep 17 00:00:00 2001 From: yanzhiyong <932374019@qq.com> Date: Thu, 15 Aug 2024 01:27:09 +0800 Subject: [PATCH 22/89] =?UTF-8?q?feat:=20add=20components=20=EF=BC=88codeE?= =?UTF-8?q?ditor80%=20updata=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/components/flow/node-param-handler.tsx | 12 ++++++- .../flow/node-renderer/codeEditor.tsx | 28 ++++++++++++++++ web/components/flow/node-renderer/index.ts | 2 ++ web/components/flow/node-renderer/updata.tsx | 32 +++++++++++++++++++ web/utils/request.ts | 4 ++- 5 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 web/components/flow/node-renderer/codeEditor.tsx create mode 100644 web/components/flow/node-renderer/updata.tsx diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index 58e9c8d78..233d4a033 100644 --- a/web/components/flow/node-param-handler.tsx +++ 
b/web/components/flow/node-param-handler.tsx @@ -15,7 +15,11 @@ import { RenderTreeSelect, RenderTimePicker, RenderTextArea, + RenderUpdata, + RenderCodeEditor, } from './node-renderer'; +import MonacoEditor from '@/components/chat/monaco-editor' +// C:\Users\Administrator\Desktop\ai\DB-GPT\web\components\chat\monaco-editor.tsx import { convertKeysToCamelCase } from '@/utils/flow'; interface NodeParamHandlerProps { @@ -130,7 +134,9 @@ const NodeParamHandler: React.FC = ({ node, data, label, case 'select': return ; case 'text_area': - return ; + return ; + + // return ; case 'slider': return ; case 'date_picker': @@ -139,6 +145,10 @@ const NodeParamHandler: React.FC = ({ node, data, label, return ; case 'tree_select': return ; + case 'upload': + return ; + case 'code_editor': + return ; default: return null; } diff --git a/web/components/flow/node-renderer/codeEditor.tsx b/web/components/flow/node-renderer/codeEditor.tsx new file mode 100644 index 000000000..c56d0626c --- /dev/null +++ b/web/components/flow/node-renderer/codeEditor.tsx @@ -0,0 +1,28 @@ +import React, { useState } from 'react'; +import { UnControlled as CodeMirror } from 'react-codemirror2'; + + +export const RenderCodeEditor = ()=> { + const [code, setCode] = useState('// 输入你的代码'); + + const handleChange = (editor, data, value) => { + // 处理代码变化 + setCode(value); + }; + + return ( + + ); +} + + + diff --git a/web/components/flow/node-renderer/index.ts b/web/components/flow/node-renderer/index.ts index 4eb0d5c39..bb12bea61 100644 --- a/web/components/flow/node-renderer/index.ts +++ b/web/components/flow/node-renderer/index.ts @@ -8,3 +8,5 @@ export * from './textarea'; export * from './slider'; export * from './time-picker'; export * from './tree-select'; +export * from './codeEditor'; +export * from './updata'; diff --git a/web/components/flow/node-renderer/updata.tsx b/web/components/flow/node-renderer/updata.tsx new file mode 100644 index 000000000..a1ae57ac6 --- /dev/null +++ b/web/components/flow/node-renderer/updata.tsx @@ -0,0 +1,32 @@ +import React from 'react'; +import { UploadOutlined } from '@ant-design/icons'; +import type { UploadProps } from 'antd'; +import { Button, message, Upload } from 'antd'; + +const props: UploadProps = { + name: 'file', + action: 'https://660d2bd96ddfa2943b33731c.mockapi.io/api/upload', + headers: { + authorization: 'authorization-text', + }, + onChange(info) { + if (info.file.status !== 'uploading') { + console.log(info.file, info.fileList); + } + if (info.file.status === 'done') { + message.success(`${info.file.name} file uploaded successfully`); + } else if (info.file.status === 'error') { + message.error(`${info.file.name} file upload failed.`); + } + }, +}; + +export const RenderUpdata: React.FC = () => ( + + + }>上传数据 + + + +); + diff --git a/web/utils/request.ts b/web/utils/request.ts index 102b7fdbd..10e38d27f 100644 --- a/web/utils/request.ts +++ b/web/utils/request.ts @@ -1,7 +1,9 @@ import { message } from 'antd'; import axios from './ctx-axios'; import { isPlainObject } from 'lodash'; - +import 'codemirror/lib/codemirror.css'; +import 'codemirror/theme/material.css'; // 引入你喜欢的主题 +import 'codemirror/mode/javascript/javascript'; // 引入JavaScript语言模式 const DEFAULT_HEADERS = { 'content-type': 'application/json', }; From 6284f6aaaa313b544b6b60b8f7f9533d066d2f1b Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Thu, 15 Aug 2024 09:22:09 +0800 Subject: [PATCH 23/89] feat(core): Add AWEL flow radio component --- dbgpt/core/awel/flow/ui.py | 46 +++++++++++------ 
dbgpt/serve/flow/api/endpoints.py | 16 +++--- dbgpt/serve/flow/api/schemas.py | 26 +++++++++- examples/awel/awel_flow_ui_components.py | 64 ++++++++++++++++++++++-- 4 files changed, 124 insertions(+), 28 deletions(-) diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index 875547e9a..c763859b0 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py @@ -11,6 +11,7 @@ "select", "cascader", "checkbox", + "radio", "date_picker", "input", "text_area", @@ -175,6 +176,12 @@ def check_parameter(self, parameter_dict: Dict[str, Any]): self._check_options(parameter_dict.get("options", {})) +class UIRadio(UICheckbox): + """Radio component.""" + + ui_type: Literal["radio"] = Field("radio", frozen=True) # type: ignore + + class UIDatePicker(UIComponent): """Date picker component.""" @@ -232,23 +239,31 @@ class UIAttribute(StatusMixin, UIComponent.UIAttribute): class UITextArea(PanelEditorMixin, UIInput): """Text area component.""" - class AutoSize(BaseModel): - """Auto size configuration.""" + class UIAttribute(UIInput.UIAttribute): + """Text area attribute.""" - min_rows: Optional[int] = Field( - None, - description="The minimum number of rows", - ) - max_rows: Optional[int] = Field( + class AutoSize(BaseModel): + """Auto size configuration.""" + + min_rows: Optional[int] = Field( + None, + description="The minimum number of rows", + ) + max_rows: Optional[int] = Field( + None, + description="The maximum number of rows", + ) + + auto_size: Optional[Union[bool, AutoSize]] = Field( None, - description="The maximum number of rows", + description="Whether the height of the textarea automatically adjusts " + "based on the content", ) ui_type: Literal["text_area"] = Field("text_area", frozen=True) # type: ignore - autosize: Optional[Union[bool, AutoSize]] = Field( + attr: Optional[UIAttribute] = Field( None, - description="Whether the height of the textarea automatically adjusts based " - "on the content", + description="The attributes of the component", ) @@ -430,8 +445,9 @@ class UICodeEditor(UITextArea): class DefaultUITextArea(UITextArea): """Default text area component.""" - autosize: Union[bool, UITextArea.AutoSize] = Field( - default_factory=lambda: UITextArea.AutoSize(min_rows=2, max_rows=40), - description="Whether the height of the textarea automatically adjusts based " - "on the content", + attr: Optional[UITextArea.UIAttribute] = Field( + default_factory=lambda: UITextArea.UIAttribute( + auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=40) + ), + description="The attributes of the component", ) diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 28c6a28c4..4b28641e8 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -1,20 +1,12 @@ import json from functools import cache -from typing import Dict, List, Literal, Optional, Union +from typing import Dict, List, Optional, Union from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer from dbgpt.component import SystemApp from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata -from dbgpt.core.interface.variables import ( - BUILTIN_VARIABLES_CORE_FLOW_NODES, - BUILTIN_VARIABLES_CORE_FLOWS, - BUILTIN_VARIABLES_CORE_SECRETS, - BUILTIN_VARIABLES_CORE_VARIABLES, - BuiltinVariablesProvider, - StorageVariables, -) from dbgpt.serve.core import Result, blocking_func_to_async from dbgpt.util import PaginationResult @@ -330,6 +322,12 @@ async 
def update_variables( return Result.succ(res) +@router.post("/flow/debug") +async def debug(): + """Debug the flow.""" + # TODO: Implement the debug endpoint + + def init_endpoints(system_app: SystemApp) -> None: """Initialize the endpoints""" from .variables_provider import ( diff --git a/dbgpt/serve/flow/api/schemas.py b/dbgpt/serve/flow/api/schemas.py index e63d3e6ce..537996fe7 100644 --- a/dbgpt/serve/flow/api/schemas.py +++ b/dbgpt/serve/flow/api/schemas.py @@ -1,6 +1,7 @@ -from typing import Any, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union from dbgpt._private.pydantic import BaseModel, ConfigDict, Field +from dbgpt.core.awel import CommonLLMHttpRequestBody from dbgpt.core.awel.flow.flow_factory import FlowPanel from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest @@ -113,3 +114,26 @@ class RefreshNodeRequest(BaseModel): title="The refresh options", description="The refresh options", ) + + +class FlowDebugRequest(BaseModel): + """Flow response model""" + + model_config = ConfigDict(title=f"FlowDebugRequest") + flow: ServeRequest = Field( + ..., + title="The flow to debug", + description="The flow to debug", + ) + request: Union[CommonLLMHttpRequestBody, Dict[str, Any]] = Field( + ..., + title="The request to debug", + description="The request to debug", + ) + variables: Optional[Dict[str, Any]] = Field( + None, + title="The variables to debug", + description="The variables to debug", + ) + user_name: Optional[str] = Field(None, description="User name") + sys_code: Optional[str] = Field(None, description="System code") diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py index 9ce611c39..cba0c14df 100644 --- a/examples/awel/awel_flow_ui_components.py +++ b/examples/awel/awel_flow_ui_components.py @@ -206,7 +206,7 @@ class ExampleFlowCheckboxOperator(MapOperator[str, str]): OptionValue(label="Orange", name="orange", value="orange"), OptionValue(label="Pear", name="pear", value="pear"), ], - ui=ui.UICheckbox(attr=ui.UICheckbox.UIAttribute(show_search=True)), + ui=ui.UICheckbox(), ) ], inputs=[ @@ -236,6 +236,59 @@ async def map(self, user_name: str) -> str: return "Your name is %s, and you like %s." 
% (user_name, ", ".join(self.fruits)) +class ExampleFlowRadioOperator(MapOperator[str, str]): + """An example flow operator that includes a radio as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Radio", + name="example_flow_radio", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a radio as parameter.", + parameters=[ + Parameter.build_from( + "Fruits Selector", + "fruits", + type=str, + optional=True, + default=None, + placeholder="Select the fruits", + description="The fruits you like.", + options=[ + OptionValue(label="Apple", name="apple", value="apple"), + OptionValue(label="Banana", name="banana", value="banana"), + OptionValue(label="Orange", name="orange", value="orange"), + OptionValue(label="Pear", name="pear", value="pear"), + ], + ui=ui.UIRadio(), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Fruits", + "fruits", + str, + description="User's favorite fruits.", + ) + ], + ) + + def __init__(self, fruits: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.fruits = fruits + + async def map(self, user_name: str) -> str: + """Map the user name to the fruits.""" + return "Your name is %s, and you like %s." % (user_name, self.fruits) + + class ExampleFlowDatePickerOperator(MapOperator[str, str]): """An example flow operator that includes a date picker as parameter.""" @@ -348,8 +401,13 @@ class ExampleFlowTextAreaOperator(MapOperator[str, str]): placeholder="Please input your comment", description="The comment you want to say.", ui=ui.UITextArea( - attr=ui.UITextArea.UIAttribute(show_count=True, maxlength=1000), - autosize=ui.UITextArea.AutoSize(min_rows=2, max_rows=6), + attr=ui.UITextArea.UIAttribute( + show_count=True, + maxlength=1000, + auto_size=ui.UITextArea.UIAttribute.AutoSize( + min_rows=2, max_rows=6 + ), + ), ), ) ], From 2696d9a9e0ad4b64df107336efc04a09832088d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Thu, 15 Aug 2024 10:38:01 +0800 Subject: [PATCH 24/89] feat: add Password component to flow --- web/components/flow/node-param-handler.tsx | 3 +++ web/components/flow/node-renderer/index.ts | 1 + web/components/flow/node-renderer/input.tsx | 1 + web/components/flow/node-renderer/password.tsx | 18 ++++++++++++++++++ 4 files changed, 23 insertions(+) create mode 100644 web/components/flow/node-renderer/password.tsx diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index 2dbccd3fc..165eb59bb 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -15,6 +15,7 @@ import { RenderTreeSelect, RenderTimePicker, RenderTextArea, + RenderPassword, } from './node-renderer'; import { convertKeysToCamelCase } from '@/utils/flow'; @@ -138,6 +139,8 @@ const NodeParamHandler: React.FC = ({ node, data, label, case 'time_picker': return ; case 'tree_select': + return ; + case 'password': return ; default: return null; diff --git a/web/components/flow/node-renderer/index.ts b/web/components/flow/node-renderer/index.ts index 4eb0d5c39..59f1e44ef 100644 --- a/web/components/flow/node-renderer/index.ts +++ b/web/components/flow/node-renderer/index.ts @@ -8,3 +8,4 @@ export * from './textarea'; export * from './slider'; export * from './time-picker'; export * from './tree-select'; +export * from './password'; diff --git a/web/components/flow/node-renderer/input.tsx 
b/web/components/flow/node-renderer/input.tsx index 538d14081..60c559baa 100644 --- a/web/components/flow/node-renderer/input.tsx +++ b/web/components/flow/node-renderer/input.tsx @@ -18,6 +18,7 @@ export const RenderInput = (params: Props) => { className="w-full" placeholder="please input" defaultValue={defaultValue} + allowClear onChange={(e) => { onChange(e.target.value); }} diff --git a/web/components/flow/node-renderer/password.tsx b/web/components/flow/node-renderer/password.tsx new file mode 100644 index 000000000..93dec6e85 --- /dev/null +++ b/web/components/flow/node-renderer/password.tsx @@ -0,0 +1,18 @@ +import { IFlowNodeParameter } from '@/types/flow'; +import { Input } from 'antd'; +import { convertKeysToCamelCase } from '@/utils/flow'; + +const { Password } = Input; + +type TextAreaProps = { + data: IFlowNodeParameter; + defaultValue: any; + onChange: (value: any) => void; +}; + +export const RenderPassword = (params: TextAreaProps) => { + const { data, defaultValue, onChange } = params; + const attr = convertKeysToCamelCase(data.ui?.attr || {}); + + return ; +}; From 27ee22df47c780b46b0cf3cd916e4f3b7d4cf883 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=A5=E5=BF=97=E5=8B=87?= Date: Thu, 15 Aug 2024 10:59:07 +0800 Subject: [PATCH 25/89] feat:Remove test code --- web/components/flow/node-param-handler.tsx | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index 233d4a033..2b4273203 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -134,9 +134,7 @@ const NodeParamHandler: React.FC = ({ node, data, label, case 'select': return ; case 'text_area': - return ; - - // return ; + return ; case 'slider': return ; case 'date_picker': From b264325151543cb1ba25156155163462b7ac5ac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=A5=E5=BF=97=E5=8B=87?= Date: Thu, 15 Aug 2024 18:24:53 +0800 Subject: [PATCH 26/89] feat:code editor --- .../flow/node-renderer/codeEditor.tsx | 121 ++++++++++++++---- web/types/flow.ts | 5 + 2 files changed, 100 insertions(+), 26 deletions(-) diff --git a/web/components/flow/node-renderer/codeEditor.tsx b/web/components/flow/node-renderer/codeEditor.tsx index c56d0626c..eed5ec256 100644 --- a/web/components/flow/node-renderer/codeEditor.tsx +++ b/web/components/flow/node-renderer/codeEditor.tsx @@ -1,28 +1,97 @@ -import React, { useState } from 'react'; -import { UnControlled as CodeMirror } from 'react-codemirror2'; - - -export const RenderCodeEditor = ()=> { - const [code, setCode] = useState('// 输入你的代码'); - - const handleChange = (editor, data, value) => { - // 处理代码变化 - setCode(value); - }; - - return ( - - ); -} - +import React, { useState, useMemo } from 'react'; +import { Button, Modal } from 'antd'; +import Editor from '@monaco-editor/react'; +import { IFlowNodeParameter } from '@/types/flow'; +// import { MonacoEditor } from '../../chat/monaco-editor'; +// import { github, githubDark } from './ob-editor/theme'; +import { github, githubDark } from '../../chat/ob-editor/theme'; +type Props = { + data: IFlowNodeParameter; + defaultValue: any; + onChange: (value: any) => void; +}; +export const RenderCodeEditor = (params: Props) => { + const { data, defaultValue, onChange } = params; + + const [isModalOpen, setIsModalOpen] = useState(false); + const showModal = () => { + setIsModalOpen(true); + }; + + const handleOk = () => { + setIsModalOpen(false); + }; + + const handleCancel = () => { + 
setIsModalOpen(false); + }; + /** + * 设置弹窗宽度 + */ + const modalWidth = useMemo(() => { + if (data?.ui?.editor?.width) { + return data?.ui?.editor?.width + 100 + } + return '80%'; + }, [data?.ui?.editor?.width]); + + return ( + + + 打开代码编辑器 + + + + + {/* { + console.log(value); + onChange(value) + }} + options={{ + theme: {github}, // 编辑器主题颜色 + folding: true, // 是否折叠 + foldingHighlight: true, // 折叠等高线 + foldingStrategy: 'indentation', // 折叠方式 auto | indentation + showFoldingControls: 'always', // 是否一直显示折叠 always | mouseover + disableLayerHinting: true, // 等宽优化 + emptySelectionClipboard: false, // 空选择剪切板 + selectionClipboard: false, // 选择剪切板 + automaticLayout: true, // 自动布局 + codeLens: false, // 代码镜头 + scrollBeyondLastLine: false, // 滚动完最后一行后再滚动一屏幕 + colorDecorators: true, // 颜色装饰器 + accessibilitySupport: 'auto', // 辅助功能支持 "auto" | "off" | "on" + lineNumbers: 'on', // 行号 取值: "on" | "off" | "relative" | "interval" | function + lineNumbersMinChars: 5, // 行号最小字符 number + readOnly: false, //是否只读 取值 true | false + }} + /> */} + + + + ); +}; diff --git a/web/types/flow.ts b/web/types/flow.ts index b62df1840..175047e31 100644 --- a/web/types/flow.ts +++ b/web/types/flow.ts @@ -54,10 +54,15 @@ export type IFlowNodeParameter = { export type IFlowNodeParameterUI = { ui_type: string; + language: string; attr: { disabled: boolean; [key: string]: any; }; + editor: { + width: Number; + height: Number; + }; show_input: boolean; }; From b301860d0f3e85ce4f8dab2794b39d2a10fdac3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=A5=E5=BF=97=E5=8B=87?= Date: Thu, 15 Aug 2024 19:41:07 +0800 Subject: [PATCH 27/89] fix:Attribute error --- .../flow/node-renderer/codeEditor.tsx | 42 ++---------- .../flow/node-renderer/textarea.tsx | 3 +- .../flow/node-renderer/tree-select.tsx | 66 ++----------------- 3 files changed, 12 insertions(+), 99 deletions(-) diff --git a/web/components/flow/node-renderer/codeEditor.tsx b/web/components/flow/node-renderer/codeEditor.tsx index eed5ec256..3aeeb9b6b 100644 --- a/web/components/flow/node-renderer/codeEditor.tsx +++ b/web/components/flow/node-renderer/codeEditor.tsx @@ -2,9 +2,6 @@ import React, { useState, useMemo } from 'react'; import { Button, Modal } from 'antd'; import Editor from '@monaco-editor/react'; import { IFlowNodeParameter } from '@/types/flow'; -// import { MonacoEditor } from '../../chat/monaco-editor'; -// import { github, githubDark } from './ob-editor/theme'; -import { github, githubDark } from '../../chat/ob-editor/theme'; type Props = { data: IFlowNodeParameter; @@ -14,6 +11,7 @@ type Props = { export const RenderCodeEditor = (params: Props) => { const { data, defaultValue, onChange } = params; + const attr = convertKeysToCamelCase(data.ui?.attr || {}); const [isModalOpen, setIsModalOpen] = useState(false); const showModal = () => { @@ -37,8 +35,10 @@ export const RenderCodeEditor = (params: Props) => { return '80%'; }, [data?.ui?.editor?.width]); + + return ( - + 打开代码编辑器 @@ -47,7 +47,7 @@ export const RenderCodeEditor = (params: Props) => { {...data?.ui?.attr} width={data?.ui?.editor?.width || '100%'} value={defaultValue} - style={{padding:'10px'}} + style={{ padding: '10px' }} height={data?.ui?.editor?.height || 200} defaultLanguage={data?.ui?.language} onChange={onChange} @@ -59,39 +59,7 @@ export const RenderCodeEditor = (params: Props) => { wordWrap: 'on', }} /> - - {/* { - console.log(value); - onChange(value) - }} - options={{ - theme: {github}, // 编辑器主题颜色 - folding: true, // 是否折叠 - foldingHighlight: true, // 折叠等高线 - foldingStrategy: 'indentation', // 
折叠方式 auto | indentation - showFoldingControls: 'always', // 是否一直显示折叠 always | mouseover - disableLayerHinting: true, // 等宽优化 - emptySelectionClipboard: false, // 空选择剪切板 - selectionClipboard: false, // 选择剪切板 - automaticLayout: true, // 自动布局 - codeLens: false, // 代码镜头 - scrollBeyondLastLine: false, // 滚动完最后一行后再滚动一屏幕 - colorDecorators: true, // 颜色装饰器 - accessibilitySupport: 'auto', // 辅助功能支持 "auto" | "off" | "on" - lineNumbers: 'on', // 行号 取值: "on" | "off" | "relative" | "interval" | function - lineNumbersMinChars: 5, // 行号最小字符 number - readOnly: false, //是否只读 取值 true | false - }} - /> */} - ); }; diff --git a/web/components/flow/node-renderer/textarea.tsx b/web/components/flow/node-renderer/textarea.tsx index 5f8c55ac0..a2f0df4aa 100644 --- a/web/components/flow/node-renderer/textarea.tsx +++ b/web/components/flow/node-renderer/textarea.tsx @@ -13,10 +13,11 @@ type TextAreaProps = { export const RenderTextArea = (params: TextAreaProps) => { const { data, defaultValue, onChange } = params; convertKeysToCamelCase(data?.ui?.attr?.autosize || {}); + const attr = convertKeysToCamelCase(data.ui?.attr || {}); return ( - onChange(e.target.value)} {...data.ui.attr.autosize} rows={4} /> + onChange(e.target.value)} {...data.ui.attr.autosize} rows={4} /> ); }; diff --git a/web/components/flow/node-renderer/tree-select.tsx b/web/components/flow/node-renderer/tree-select.tsx index 7acc3d73a..d2a4400d6 100644 --- a/web/components/flow/node-renderer/tree-select.tsx +++ b/web/components/flow/node-renderer/tree-select.tsx @@ -3,86 +3,30 @@ import { TreeSelect } from 'antd'; import type { TreeSelectProps } from 'antd'; import { IFlowNodeParameter } from '@/types/flow'; import { Label } from '@mui/icons-material'; +import { convertKeysToCamelCase } from '@/utils/flow'; type TextAreaProps = { data: IFlowNodeParameter; defaultValue: any; onChange: (value: any) => void; }; -const treeData = [ - { - value: 'parent 1', - title: 'parent 1', - children: [ - { - value: 'parent 1-0', - title: 'parent 1-0', - children: [ - { - value: 'leaf1', - title: 'leaf1', - }, - { - value: 'leaf2', - title: 'leaf2', - }, - { - value: 'leaf3', - title: 'leaf3', - }, - { - value: 'leaf4', - title: 'leaf4', - }, - { - value: 'leaf5', - title: 'leaf5', - }, - { - value: 'leaf6', - title: 'leaf6', - }, - ], - }, - { - value: 'parent 1-1', - title: 'parent 1-1', - children: [ - { - value: 'leaf11', - title: leaf11, - }, - ], - }, - ], - }, -]; export const RenderTreeSelect = (params: TextAreaProps) => { const { data, defaultValue, onChange } = params; - // console.log(data.options); - // const [value, setValue] = useState(); + const attr = convertKeysToCamelCase(data.ui?.attr || {}); - // const onChange = (newValue: string) => { - // setValue(newValue); - // }; const [dropdownVisible, setDropdownVisible] = useState(false); const handleDropdownVisibleChange = (visible: boolean | ((prevState: boolean) => boolean)) => { setDropdownVisible(visible); - - // 你可以在这里执行更多的逻辑,比如发送请求、更新状态等 - console.log('Dropdown is now:', visible ? 
'visible' : 'hidden'); - }; - - const focus = () => { - // console.log('focus=========='); }; + console.log(data); + return ( Date: Thu, 15 Aug 2024 20:52:38 +0800 Subject: [PATCH 28/89] fix:Description error --- web/components/flow/node-param-handler.tsx | 4 +-- web/components/flow/node-renderer/index.ts | 2 +- .../flow/node-renderer/tree-select.tsx | 25 ------------------- .../node-renderer/{updata.tsx => upload.tsx} | 20 ++++++++++----- 4 files changed, 17 insertions(+), 34 deletions(-) rename web/components/flow/node-renderer/{updata.tsx => upload.tsx} (62%) diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index 2b4273203..2eadffe08 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -15,7 +15,7 @@ import { RenderTreeSelect, RenderTimePicker, RenderTextArea, - RenderUpdata, + RenderUpload, RenderCodeEditor, } from './node-renderer'; import MonacoEditor from '@/components/chat/monaco-editor' @@ -144,7 +144,7 @@ const NodeParamHandler: React.FC = ({ node, data, label, case 'tree_select': return ; case 'upload': - return ; + return ; case 'code_editor': return ; default: diff --git a/web/components/flow/node-renderer/index.ts b/web/components/flow/node-renderer/index.ts index bb12bea61..7049aab99 100644 --- a/web/components/flow/node-renderer/index.ts +++ b/web/components/flow/node-renderer/index.ts @@ -9,4 +9,4 @@ export * from './slider'; export * from './time-picker'; export * from './tree-select'; export * from './codeEditor'; -export * from './updata'; +export * from './upload'; diff --git a/web/components/flow/node-renderer/tree-select.tsx b/web/components/flow/node-renderer/tree-select.tsx index d2a4400d6..5db73438f 100644 --- a/web/components/flow/node-renderer/tree-select.tsx +++ b/web/components/flow/node-renderer/tree-select.tsx @@ -1,8 +1,6 @@ import React, { useState } from 'react'; import { TreeSelect } from 'antd'; -import type { TreeSelectProps } from 'antd'; import { IFlowNodeParameter } from '@/types/flow'; -import { Label } from '@mui/icons-material'; import { convertKeysToCamelCase } from '@/utils/flow'; type TextAreaProps = { @@ -14,14 +12,6 @@ export const RenderTreeSelect = (params: TextAreaProps) => { const { data, defaultValue, onChange } = params; const attr = convertKeysToCamelCase(data.ui?.attr || {}); - const [dropdownVisible, setDropdownVisible] = useState(false); - - const handleDropdownVisibleChange = (visible: boolean | ((prevState: boolean) => boolean)) => { - setDropdownVisible(visible); - }; - console.log(data); - - return ( { treeDefaultExpandAll onChange={onChange} treeData={data.options} - onDropdownVisibleChange={handleDropdownVisibleChange} /> - - // TODO: Implement the TreeSelect component - // document.body} - // /> ); }; diff --git a/web/components/flow/node-renderer/updata.tsx b/web/components/flow/node-renderer/upload.tsx similarity index 62% rename from web/components/flow/node-renderer/updata.tsx rename to web/components/flow/node-renderer/upload.tsx index a1ae57ac6..1658cd68f 100644 --- a/web/components/flow/node-renderer/updata.tsx +++ b/web/components/flow/node-renderer/upload.tsx @@ -2,6 +2,7 @@ import React from 'react'; import { UploadOutlined } from '@ant-design/icons'; import type { UploadProps } from 'antd'; import { Button, message, Upload } from 'antd'; +import { convertKeysToCamelCase } from '@/utils/flow'; const props: UploadProps = { name: 'file', @@ -21,12 +22,19 @@ const props: UploadProps = { }, }; -export const 
RenderUpdata: React.FC = () => ( - - - }>上传数据 - - +export const RenderUpload: React.FC = (params) => ( + const { data, defaultValue, onChange } = params; + +const attr = convertKeysToCamelCase(data.ui?.attr || {}); + +return ( + + + }>上传数据 + + +) + ); From c5ba26461e209cd10c2d585fc2f9f86cfbdbf9c5 Mon Sep 17 00:00:00 2001 From: yanzhiyong <932374019@qq.com> Date: Fri, 16 Aug 2024 00:19:21 +0800 Subject: [PATCH 29/89] =?UTF-8?q?feat:1=E3=80=81Set=20language=20=202?= =?UTF-8?q?=E3=80=81Remove=20junk=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/app/i18n.ts | 6 +++ web/components/flow/node-param-handler.tsx | 3 -- .../flow/node-renderer/codeEditor.tsx | 12 ++--- web/components/flow/node-renderer/upload.tsx | 44 +++++++++---------- web/utils/request.ts | 3 -- 5 files changed, 35 insertions(+), 33 deletions(-) diff --git a/web/app/i18n.ts b/web/app/i18n.ts index 1ecde2e6c..a03216c55 100644 --- a/web/app/i18n.ts +++ b/web/app/i18n.ts @@ -2,6 +2,9 @@ import i18n from 'i18next'; import { initReactI18next } from 'react-i18next'; const en = { + UploadData: 'Upload Data', + CodeEditor: 'Code Editor:', + openCodeEditor:'Open Code Editor', Knowledge_Space: 'Knowledge', space: 'space', Vector: 'Vector', @@ -234,6 +237,9 @@ export interface Resources { } const zh: Resources['translation'] = { + UploadData: '上传数据', + CodeEditor: '代码编辑:', + openCodeEditor: '打开代码编辑器', Knowledge_Space: '知识库', space: '知识库', Vector: '向量', diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index 2eadffe08..31357d5ad 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -18,9 +18,6 @@ import { RenderUpload, RenderCodeEditor, } from './node-renderer'; -import MonacoEditor from '@/components/chat/monaco-editor' -// C:\Users\Administrator\Desktop\ai\DB-GPT\web\components\chat\monaco-editor.tsx -import { convertKeysToCamelCase } from '@/utils/flow'; interface NodeParamHandlerProps { node: IFlowNode; diff --git a/web/components/flow/node-renderer/codeEditor.tsx b/web/components/flow/node-renderer/codeEditor.tsx index 3aeeb9b6b..03586f25c 100644 --- a/web/components/flow/node-renderer/codeEditor.tsx +++ b/web/components/flow/node-renderer/codeEditor.tsx @@ -2,6 +2,8 @@ import React, { useState, useMemo } from 'react'; import { Button, Modal } from 'antd'; import Editor from '@monaco-editor/react'; import { IFlowNodeParameter } from '@/types/flow'; +import { convertKeysToCamelCase } from '@/utils/flow'; +import { useTranslation } from 'react-i18next'; type Props = { data: IFlowNodeParameter; @@ -10,6 +12,8 @@ type Props = { }; export const RenderCodeEditor = (params: Props) => { + const { t } = useTranslation(); + const { data, defaultValue, onChange } = params; const attr = convertKeysToCamelCase(data.ui?.attr || {}); @@ -35,14 +39,12 @@ export const RenderCodeEditor = (params: Props) => { return '80%'; }, [data?.ui?.editor?.width]); - - return ( - 打开代码编辑器 + {t('openCodeEditor')} - + { height={data?.ui?.editor?.height || 200} defaultLanguage={data?.ui?.language} onChange={onChange} - theme='vs-dark' // 编辑器主题颜色 + theme='vs-dark' options={{ minimap: { enabled: false, diff --git a/web/components/flow/node-renderer/upload.tsx b/web/components/flow/node-renderer/upload.tsx index 1658cd68f..35e1ddd75 100644 --- a/web/components/flow/node-renderer/upload.tsx +++ b/web/components/flow/node-renderer/upload.tsx @@ -3,6 +3,8 @@ import { UploadOutlined } from '@ant-design/icons'; 
import type { UploadProps } from 'antd'; import { Button, message, Upload } from 'antd'; import { convertKeysToCamelCase } from '@/utils/flow'; +import { IFlowNodeParameter } from '@/types/flow'; +import { useTranslation } from 'react-i18next'; const props: UploadProps = { name: 'file', @@ -10,31 +12,29 @@ const props: UploadProps = { headers: { authorization: 'authorization-text', }, - onChange(info) { - if (info.file.status !== 'uploading') { - console.log(info.file, info.fileList); - } - if (info.file.status === 'done') { - message.success(`${info.file.name} file uploaded successfully`); - } else if (info.file.status === 'error') { - message.error(`${info.file.name} file upload failed.`); - } - }, }; -export const RenderUpload: React.FC = (params) => ( - const { data, defaultValue, onChange } = params; +type Props = { + data: IFlowNodeParameter; + defaultValue: any; + onChange: (value: any) => void; +}; + +export const RenderUpload = (params: Props) => { + const { t } = useTranslation(); + + const { data, defaultValue, onChange } = params; + + const attr = convertKeysToCamelCase(data.ui?.attr || {}); -const attr = convertKeysToCamelCase(data.ui?.attr || {}); + return ( + + + }>{t('UploadData')} + + + ) -return ( - - - }>上传数据 - - -) - -); +} diff --git a/web/utils/request.ts b/web/utils/request.ts index 10e38d27f..8a490bfb5 100644 --- a/web/utils/request.ts +++ b/web/utils/request.ts @@ -1,9 +1,6 @@ import { message } from 'antd'; import axios from './ctx-axios'; import { isPlainObject } from 'lodash'; -import 'codemirror/lib/codemirror.css'; -import 'codemirror/theme/material.css'; // 引入你喜欢的主题 -import 'codemirror/mode/javascript/javascript'; // 引入JavaScript语言模式 const DEFAULT_HEADERS = { 'content-type': 'application/json', }; From 77be0e6eace10eab209f0a98dffd58794541b899 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=A5=E5=BF=97=E5=8B=87?= Date: Fri, 16 Aug 2024 16:47:27 +0800 Subject: [PATCH 30/89] =?UTF-8?q?fix:=201=E3=80=81components=20treeSelect?= =?UTF-8?q?=20and=20Password=20switch=20locations=202=E3=80=81remove=20con?= =?UTF-8?q?sole?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/components/flow/node-param-handler.tsx | 4 ++-- web/components/flow/node-renderer/date-picker.tsx | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index b7553ac51..918568cc9 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -140,9 +140,9 @@ const NodeParamHandler: React.FC = ({ node, data, label, case 'time_picker': return ; case 'tree_select': - return ; - case 'password': return ; + case 'password': + return ; case 'upload': return ; case 'code_editor': diff --git a/web/components/flow/node-renderer/date-picker.tsx b/web/components/flow/node-renderer/date-picker.tsx index 478295c75..a65391d79 100644 --- a/web/components/flow/node-renderer/date-picker.tsx +++ b/web/components/flow/node-renderer/date-picker.tsx @@ -11,7 +11,6 @@ type Props = { export const RenderDatePicker = (params: Props) => { const { data, defaultValue, onChange } = params; - console.log('data', data); const attr = convertKeysToCamelCase(data.ui?.attr || {}); From 6ad7199c3b83c33ad0d6189b1a0a31e18f0fac23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=A5=E5=BF=97=E5=8B=87?= Date: Fri, 16 Aug 2024 17:43:49 +0800 Subject: [PATCH 31/89] fix: Components slider and treeSelect drag --- 
web/components/flow/node-renderer/slider.tsx | 4 ++--
 web/components/flow/node-renderer/tree-select.tsx | 1 +
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/web/components/flow/node-renderer/slider.tsx b/web/components/flow/node-renderer/slider.tsx
index 1017e20bb..b23780379 100644
--- a/web/components/flow/node-renderer/slider.tsx
+++ b/web/components/flow/node-renderer/slider.tsx
@@ -25,14 +25,14 @@ export const RenderSlider = (params: TextAreaProps) => {
 
       {data?.ui?.show_input ? (
         
           
-            
+            
           
           
             
           
         
       ) : (
-        
+        
       )}
     >
   );
 };
diff --git a/web/components/flow/node-renderer/tree-select.tsx b/web/components/flow/node-renderer/tree-select.tsx
index 5db73438f..74ee226fd 100644
--- a/web/components/flow/node-renderer/tree-select.tsx
+++ b/web/components/flow/node-renderer/tree-select.tsx
@@ -15,6 +15,7 @@ export const RenderTreeSelect = (params: TextAreaProps) => {

   return (
From e3e08a83e6d0261092473254ee88a8da52c30d6b Mon Sep 17 00:00:00 2001
From: Fangyin Cheng 
Date: 
Mon, 5 Aug 2024 18:08:02 +0800 Subject: [PATCH 33/89] feat: Add flow2.0 examples --- dbgpt/core/awel/flow/base.py | 18 +- dbgpt/core/awel/flow/ui.py | 289 +++++---- dbgpt/core/awel/util/parameter_util.py | 5 +- .../core/interface/operators/llm_operator.py | 5 + .../interface/operators/prompt_operator.py | 4 + dbgpt/serve/flow/api/endpoints.py | 5 +- examples/awel/awel_flow_ui_components.py | 583 ++++++++++++++++++ 7 files changed, 786 insertions(+), 123 deletions(-) create mode 100644 examples/awel/awel_flow_ui_components.py diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index fb60538ba..da0b2c378 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -19,6 +19,7 @@ from dbgpt.core.interface.serialization import Serializable from .exceptions import FlowMetadataException, FlowParameterMetadataException +from .ui import UIComponent _TYPE_REGISTRY: Dict[str, Type] = {} @@ -136,6 +137,7 @@ def __init__(self, label: str, description: str): "agent": _CategoryDetail("Agent", "The agent operator"), "rag": _CategoryDetail("RAG", "The RAG operator"), "experimental": _CategoryDetail("EXPERIMENTAL", "EXPERIMENTAL operator"), + "example": _CategoryDetail("Example", "Example operator"), } @@ -151,6 +153,7 @@ class OperatorCategory(str, Enum): AGENT = "agent" RAG = "rag" EXPERIMENTAL = "experimental" + EXAMPLE = "example" def label(self) -> str: """Get the label of the category.""" @@ -193,6 +196,7 @@ class OperatorType(str, Enum): "embeddings": _CategoryDetail("Embeddings", "The embeddings resource"), "rag": _CategoryDetail("RAG", "The resource"), "vector_store": _CategoryDetail("Vector Store", "The vector store resource"), + "example": _CategoryDetail("Example", "The example resource"), } @@ -209,6 +213,7 @@ class ResourceCategory(str, Enum): EMBEDDINGS = "embeddings" RAG = "rag" VECTOR_STORE = "vector_store" + EXAMPLE = "example" def label(self) -> str: """Get the label of the category.""" @@ -343,6 +348,9 @@ class Parameter(TypeMetadata, Serializable): alias: Optional[List[str]] = Field( None, description="The alias of the parameter(Compatible with old version)" ) + ui: Optional[UIComponent] = Field( + None, description="The UI component of the parameter" + ) @model_validator(mode="before") @classmethod @@ -398,6 +406,7 @@ def build_from( label: str, name: str, type: Type, + is_list: bool = False, optional: bool = False, default: Optional[Union[DefaultParameterType, _MISSING_TYPE]] = _MISSING_VALUE, placeholder: Optional[DefaultParameterType] = None, @@ -405,6 +414,7 @@ def build_from( options: Optional[Union[BaseDynamicOptions, List[OptionValue]]] = None, resource_type: ResourceType = ResourceType.INSTANCE, alias: Optional[List[str]] = None, + ui: Optional[UIComponent] = None, ): """Build the parameter from the type.""" type_name = type.__qualname__ @@ -419,6 +429,7 @@ def build_from( name=name, type_name=type_name, type_cls=type_cls, + is_list=is_list, category=category.value, resource_type=resource_type, optional=optional, @@ -427,6 +438,7 @@ def build_from( description=description or label, options=options, alias=alias, + ui=ui, ) @classmethod @@ -456,11 +468,12 @@ def build_from_ui(cls, data: Dict) -> "Parameter": description=data["description"], options=data["options"], value=data["value"], + ui=data.get("ui"), ) def to_dict(self) -> Dict: """Convert current metadata to json dict.""" - dict_value = model_to_dict(self, exclude={"options", "alias"}) + dict_value = model_to_dict(self, exclude={"options", "alias", "ui"}) if not self.options: 
dict_value["options"] = None elif isinstance(self.options, BaseDynamicOptions): @@ -468,6 +481,9 @@ def to_dict(self) -> Dict: dict_value["options"] = [value.to_dict() for value in values] else: dict_value["options"] = [value.to_dict() for value in self.options] + + if self.ui: + dict_value["ui"] = self.ui.to_dict() return dict_value def get_dict_options(self) -> Optional[List[Dict]]: diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index a9f220961..ca4361276 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py @@ -1,8 +1,9 @@ """UI components for AWEL flow.""" -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union -from dbgpt._private.pydantic import BaseModel, Field +from dbgpt._private.pydantic import BaseModel, Field, model_to_dict +from dbgpt.core.interface.serialization import Serializable from .exceptions import FlowUIComponentException @@ -36,37 +37,6 @@ class RefreshableMixin(BaseModel): ) -class UIComponent(RefreshableMixin, BaseModel): - """UI component.""" - - class UIRange(BaseModel): - """UI range.""" - - min: int | float | str | None = Field(None, description="Minimum value") - max: int | float | str | None = Field(None, description="Maximum value") - step: int | float | str | None = Field(None, description="Step value") - format: str | None = Field(None, description="Format") - - ui_type: _UI_TYPE = Field(..., description="UI component type") - - disabled: bool = Field( - False, - description="Whether the component is disabled", - ) - - def check_parameter(self, parameter_dict: Dict[str, Any]): - """Check parameter. - - Raises: - FlowUIParameterException: If the parameter is invalid. - """ - - def _check_options(self, options: Dict[str, Any]): - """Check options.""" - if not options: - raise FlowUIComponentException("options is required", self.ui_type) - - class StatusMixin(BaseModel): """Status mixin.""" @@ -76,40 +46,6 @@ class StatusMixin(BaseModel): ) -class RangeMixin(BaseModel): - """Range mixin.""" - - ui_range: Optional[UIComponent.UIRange] = Field( - None, - description="Range for the component", - ) - - -class InputMixin(BaseModel): - """Input mixin.""" - - class Count(BaseModel): - """Count.""" - - show: Optional[bool] = Field( - None, - description="Whether to show count", - ) - max: Optional[int] = Field( - None, - description="The maximum count", - ) - exceed_strategy: Optional[Literal["cut", "warning"]] = Field( - None, - description="The strategy when the count exceeds", - ) - - count: Optional[Count] = Field( - None, - description="Count configuration", - ) - - class PanelEditorMixin(BaseModel): """Edit the content in the panel.""" @@ -126,19 +62,62 @@ class Editor(BaseModel): ) editor: Optional[Editor] = Field( - None, + default_factory=lambda: PanelEditorMixin.Editor(width=800, height=400), description="The editor configuration", ) -class UICascader(StatusMixin, UIComponent): +class UIComponent(RefreshableMixin, Serializable, BaseModel): + """UI component.""" + + class UIAttribute(StatusMixin, BaseModel): + """Base UI attribute.""" + + disabled: bool = Field( + False, + description="Whether the component is disabled", + ) + + ui_type: _UI_TYPE = Field(..., description="UI component type") + + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) + + def check_parameter(self, parameter_dict: Dict[str, Any]): + """Check parameter. + + Raises: + FlowUIParameterException: If the parameter is invalid. 
+ """ + + def _check_options(self, options: Dict[str, Any]): + """Check options.""" + if not options: + raise FlowUIComponentException("options is required", self.ui_type) + + def to_dict(self) -> Dict: + """Convert current metadata to json dict.""" + return model_to_dict(self) + + +class UICascader(UIComponent): """Cascader component.""" + class UIAttribute(UIComponent.UIAttribute): + """Cascader attribute.""" + + show_search: bool = Field( + False, + description="Whether to show search input", + ) + ui_type: Literal["cascader"] = Field("cascader", frozen=True) - show_search: bool = Field( - False, - description="Whether to show search input", + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", ) def check_parameter(self, parameter_dict: Dict[str, Any]): @@ -163,53 +142,81 @@ def check_parameter(self, parameter_dict: Dict[str, Any]): self._check_options(parameter_dict.get("options", {})) -class UIDatePicker(StatusMixin, RangeMixin, UIComponent): +class UIDatePicker(UIComponent): """Date picker component.""" + class UIAttribute(UIComponent.UIAttribute): + """Date picker attribute.""" + + placement: Optional[ + Literal["topLeft", "topRight", "bottomLeft", "bottomRight"] + ] = Field( + None, + description="The position of the picker panel, None means bottomLeft", + ) + ui_type: Literal["date_picker"] = Field("date_picker", frozen=True) - placement: Optional[ - Literal["topLeft", "topRight", "bottomLeft", "bottomRight"] - ] = Field( + attr: Optional[UIAttribute] = Field( None, - description="The position of the picker panel, None means bottomLeft", + description="The attributes of the component", ) -class UIInput(StatusMixin, InputMixin, UIComponent): +class UIInput(UIComponent): """Input component.""" + class UIAttribute(UIComponent.UIAttribute): + """Input attribute.""" + + prefix: Optional[str] = Field( + None, + description="The prefix, icon or text", + examples=["$", "icon:UserOutlined"], + ) + suffix: Optional[str] = Field( + None, + description="The suffix, icon or text", + examples=["$", "icon:SearchOutlined"], + ) + show_count: Optional[bool] = Field( + None, + description="Whether to show count", + ) + maxlength: Optional[int] = Field( + None, + description="The maximum length of the input", + ) + ui_type: Literal["input"] = Field("input", frozen=True) - prefix: Optional[str] = Field( + attr: Optional[UIAttribute] = Field( None, - description="The prefix, icon or text", - examples=["$", "icon:UserOutlined"], - ) - suffix: Optional[str] = Field( - None, - description="The suffix, icon or text", - examples=["$", "icon:SearchOutlined"], + description="The attributes of the component", ) class UITextArea(PanelEditorMixin, UIInput): """Text area component.""" + class AutoSize(BaseModel): + """Auto size configuration.""" + + min_rows: Optional[int] = Field( + None, + description="The minimum number of rows", + ) + max_rows: Optional[int] = Field( + None, + description="The maximum number of rows", + ) + ui_type: Literal["text_area"] = Field("text_area", frozen=True) # type: ignore - auto_size: Optional[bool] = Field( + autosize: Optional[Union[bool, AutoSize]] = Field( None, description="Whether the height of the textarea automatically adjusts based " "on the content", ) - min_rows: Optional[int] = Field( - None, - description="The minimum number of rows", - ) - max_rows: Optional[int] = Field( - None, - description="The maximum number of rows", - ) class UIAutoComplete(UIInput): @@ -220,44 +227,73 @@ class UIAutoComplete(UIInput): ) -class 
UISlider(RangeMixin, UIComponent): +class UISlider(UIComponent): """Slider component.""" + class UIAttribute(UIComponent.UIAttribute): + """Slider attribute.""" + + min: Optional[int | float] = Field( + None, + description="The minimum value", + ) + max: Optional[int | float] = Field( + None, + description="The maximum value", + ) + step: Optional[int | float] = Field( + None, + description="The step of the slider", + ) + ui_type: Literal["slider"] = Field("slider", frozen=True) show_input: bool = Field( False, description="Whether to display the value in a input component" ) + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) + -class UITimePicker(StatusMixin, UIComponent): +class UITimePicker(UIComponent): """Time picker component.""" + class UIAttribute(UIComponent.UIAttribute): + """Time picker attribute.""" + + format: Optional[str] = Field( + None, + description="The format of the time", + examples=["HH:mm:ss", "HH:mm"], + ) + hour_step: Optional[int] = Field( + None, + description="The step of the hour input", + ) + minute_step: Optional[int] = Field( + None, + description="The step of the minute input", + ) + second_step: Optional[int] = Field( + None, + description="The step of the second input", + ) + ui_type: Literal["time_picker"] = Field("time_picker", frozen=True) - format: Optional[str] = Field( - None, - description="The format of the time", - examples=["HH:mm:ss", "HH:mm"], - ) - hour_step: Optional[int] = Field( + attr: Optional[UIAttribute] = Field( None, - description="The step of the hour input", - ) - minute_step: Optional[int] = Field( - None, - description="The step of the minute input", - ) - second_step: Optional[int] = Field( - None, - description="The step of the second input", + description="The attributes of the component", ) -class UITreeSelect(StatusMixin, UIComponent): +class UITreeSelect(UICascader): """Tree select component.""" - ui_type: Literal["tree_select"] = Field("tree_select", frozen=True) + ui_type: Literal["tree_select"] = Field("tree_select", frozen=True) # type: ignore def check_parameter(self, parameter_dict: Dict[str, Any]): """Check parameter.""" @@ -271,19 +307,24 @@ def check_parameter(self, parameter_dict: Dict[str, Any]): ) -class UIUpload(StatusMixin, UIComponent): +class UIUpload(UIComponent): """Upload component.""" + class UIAttribute(UIComponent.UIAttribute): + """Upload attribute.""" + + max_count: Optional[int] = Field( + None, + description="The maximum number of files that can be uploaded", + ) + ui_type: Literal["upload"] = Field("upload", frozen=True) max_file_size: Optional[int] = Field( None, description="The maximum size of the file, in bytes", ) - max_count: Optional[int] = Field( - None, - description="The maximum number of files that can be uploaded", - ) + file_types: Optional[List[str]] = Field( None, description="The file types that can be accepted", @@ -346,3 +387,13 @@ class UICodeEditor(UITextArea): "python", description="The language of the code", ) + + +class DefaultUITextArea(UITextArea): + """Default text area component.""" + + autosize: Union[bool, UITextArea.AutoSize] = Field( + default_factory=lambda: UITextArea.AutoSize(min_rows=2, max_rows=40), + description="Whether the height of the textarea automatically adjusts based " + "on the content", + ) diff --git a/dbgpt/core/awel/util/parameter_util.py b/dbgpt/core/awel/util/parameter_util.py index defd99a3b..70015c9ba 100644 --- a/dbgpt/core/awel/util/parameter_util.py +++ b/dbgpt/core/awel/util/parameter_util.py 
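The parameter_util.py hunks below add an optional `children` field to OptionValue so that cascader and tree-select options can be nested. As a minimal sketch (not part of the patch, and assuming OptionValue stays importable from dbgpt.core.awel.flow as the example module added later in this commit does), nested options can be declared like this:

    from dbgpt.core.awel.flow import OptionValue

    # One province with a single nested city; deeper levels nest the same way.
    nested_options = [
        OptionValue(
            label="Zhejiang",
            name="zhejiang",
            value="zhejiang",
            children=[
                OptionValue(label="Hangzhou", name="hangzhou", value="hangzhou"),
            ],
        ),
    ]
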
@@ -2,7 +2,7 @@ import inspect from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List +from typing import Any, Callable, Dict, List, Optional from dbgpt._private.pydantic import BaseModel, Field, model_validator from dbgpt.core.interface.serialization import Serializable @@ -16,6 +16,9 @@ class OptionValue(Serializable, BaseModel): label: str = Field(..., description="The label of the option") name: str = Field(..., description="The name of the option") value: Any = Field(..., description="The value of the option") + children: Optional[List["OptionValue"]] = Field( + None, description="The children of the option" + ) def to_dict(self) -> Dict: """Convert current metadata to json dict.""" diff --git a/dbgpt/core/interface/operators/llm_operator.py b/dbgpt/core/interface/operators/llm_operator.py index 53e34ffe5..45863d0a9 100644 --- a/dbgpt/core/interface/operators/llm_operator.py +++ b/dbgpt/core/interface/operators/llm_operator.py @@ -24,6 +24,7 @@ OperatorType, Parameter, ViewMetadata, + ui, ) from dbgpt.core.interface.llm import ( LLMClient, @@ -69,6 +70,10 @@ class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest]): optional=True, default=None, description=_("The temperature of the model request."), + ui=ui.UISlider( + show_input=True, + attr=ui.UISlider.UIAttribute(min=0.0, max=2.0, step=0.1), + ), ), Parameter.build_from( _("Max New Tokens"), diff --git a/dbgpt/core/interface/operators/prompt_operator.py b/dbgpt/core/interface/operators/prompt_operator.py index c3765aa67..7d97230ac 100644 --- a/dbgpt/core/interface/operators/prompt_operator.py +++ b/dbgpt/core/interface/operators/prompt_operator.py @@ -1,4 +1,5 @@ """The prompt operator.""" + from abc import ABC from typing import Any, Dict, List, Optional, Union @@ -18,6 +19,7 @@ ResourceCategory, ViewMetadata, register_resource, + ui, ) from dbgpt.core.interface.message import BaseMessage from dbgpt.core.interface.operators.llm_operator import BaseLLM @@ -48,6 +50,7 @@ optional=True, default="You are a helpful AI Assistant.", description=_("The system message."), + ui=ui.DefaultUITextArea(), ), Parameter.build_from( label=_("Message placeholder"), @@ -65,6 +68,7 @@ default="{user_input}", placeholder="{user_input}", description=_("The human message."), + ui=ui.DefaultUITextArea(), ), ], ) diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 6cb5ef879..98ff81d2f 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -209,7 +209,7 @@ async def query_page( @router.get("/nodes", dependencies=[Depends(check_api_key)]) -async def get_nodes() -> Result[List[Union[ViewMetadata, ResourceMetadata]]]: +async def get_nodes(): """Get the operator or resource nodes Returns: @@ -218,7 +218,8 @@ async def get_nodes() -> Result[List[Union[ViewMetadata, ResourceMetadata]]]: """ from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY - return Result.succ(_OPERATOR_REGISTRY.metadata_list()) + metadata_list = _OPERATOR_REGISTRY.metadata_list() + return Result.succ(metadata_list) def init_endpoints(system_app: SystemApp) -> None: diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py new file mode 100644 index 000000000..2af3e2bf3 --- /dev/null +++ b/examples/awel/awel_flow_ui_components.py @@ -0,0 +1,583 @@ +"""Some UI components for the AWEL flow.""" + +import logging +from typing import List, Optional + +from dbgpt.core.awel import MapOperator +from dbgpt.core.awel.flow import ( + IOField, + 
OperatorCategory, + OptionValue, + Parameter, + ViewMetadata, + ui, +) + +logger = logging.getLogger(__name__) + + +class ExampleFlowCascaderOperator(MapOperator[str, str]): + """An example flow operator that includes a cascader as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Cascader", + name="example_flow_cascader", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a cascader as parameter.", + parameters=[ + Parameter.build_from( + "Address Selector", + "address", + type=str, + is_list=True, + optional=True, + default=None, + placeholder="Select the address", + description="The address of the location.", + options=[ + OptionValue( + label="Zhejiang", + name="zhejiang", + value="zhejiang", + children=[ + OptionValue( + label="Hangzhou", + name="hangzhou", + value="hangzhou", + children=[ + OptionValue( + label="Xihu", + name="xihu", + value="xihu", + ), + OptionValue( + label="Feilaifeng", + name="feilaifeng", + value="feilaifeng", + ), + ], + ), + ], + ), + OptionValue( + label="Jiangsu", + name="jiangsu", + value="jiangsu", + children=[ + OptionValue( + label="Nanjing", + name="nanjing", + value="nanjing", + children=[ + OptionValue( + label="Zhonghua Gate", + name="zhonghuamen", + value="zhonghuamen", + ), + OptionValue( + label="Zhongshanling", + name="zhongshanling", + value="zhongshanling", + ), + ], + ), + ], + ), + ], + ui=ui.UICascader(attr=ui.UICascader.UIAttribute(show_search=True)), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Address", + "address", + str, + description="User's address.", + ) + ], + ) + + def __int__(self, address: Optional[List[str]] = None, **kwargs): + super().__init__(**kwargs) + self.address = address or [] + + async def map(self, user_name: str) -> str: + """Map the user name to the address.""" + full_address_str = " ".join(self.address) + return "Your name is %s, and your address is %s." % ( + user_name, + full_address_str, + ) + + +class ExampleFlowCheckboxOperator(MapOperator[str, str]): + """An example flow operator that includes a checkbox as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Checkbox", + name="example_flow_checkbox", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a checkbox as parameter.", + parameters=[ + Parameter.build_from( + "Fruits Selector", + "fruits", + type=str, + is_list=True, + optional=True, + default=None, + placeholder="Select the fruits", + description="The fruits you like.", + options=[ + OptionValue(label="Apple", name="apple", value="apple"), + OptionValue(label="Banana", name="banana", value="banana"), + OptionValue(label="Orange", name="orange", value="orange"), + OptionValue(label="Pear", name="pear", value="pear"), + ], + ui=ui.UICheckbox(attr=ui.UICheckbox.UIAttribute(show_search=True)), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Fruits", + "fruits", + str, + description="User's favorite fruits.", + ) + ], + ) + + def __init__(self, fruits: Optional[List[str]] = None, **kwargs): + super().__init__(**kwargs) + self.fruits = fruits or [] + + async def map(self, user_name: str) -> str: + """Map the user name to the fruits.""" + return "Your name is %s, and you like %s." 
% (user_name, ", ".join(self.fruits)) + + +class ExampleFlowDatePickerOperator(MapOperator[str, str]): + """An example flow operator that includes a date picker as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Date Picker", + name="example_flow_date_picker", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a date picker as parameter.", + parameters=[ + Parameter.build_from( + "Date Selector", + "date", + type=str, + placeholder="Select the date", + description="The date you choose.", + ui=ui.UIDatePicker( + attr=ui.UIDatePicker.UIAttribute(placement="bottomLeft") + ), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Date", + "date", + str, + description="User's selected date.", + ) + ], + ) + + def __init__(self, date: str, **kwargs): + super().__init__(**kwargs) + self.date = date + + async def map(self, user_name: str) -> str: + """Map the user name to the date.""" + return "Your name is %s, and you choose the date %s." % (user_name, self.date) + + +class ExampleFlowInputOperator(MapOperator[str, str]): + """An example flow operator that includes an input as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Input", + name="example_flow_input", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a input as parameter.", + parameters=[ + Parameter.build_from( + "Your hobby", + "hobby", + type=str, + placeholder="Please input your hobby", + description="The hobby you like.", + ui=ui.UIInput( + attr=ui.UIInput.UIAttribute( + prefix="icon:UserOutlined", show_count=True, maxlength=200 + ) + ), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "User Hobby", + "hobby", + str, + description="User's hobby.", + ) + ], + ) + + def __init__(self, hobby: str, **kwargs): + super().__init__(**kwargs) + self.hobby = hobby + + async def map(self, user_name: str) -> str: + """Map the user name to the input.""" + return "Your name is %s, and your hobby is %s." % (user_name, self.hobby) + + +class ExampleFlowTextAreaOperator(MapOperator[str, str]): + """An example flow operator that includes a text area as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Text Area", + name="example_flow_text_area", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a text area as parameter.", + parameters=[ + Parameter.build_from( + "Your comment", + "comment", + type=str, + placeholder="Please input your comment", + description="The comment you want to say.", + ui=ui.UITextArea( + attr=ui.UITextArea.UIAttribute(show_count=True, maxlength=1000), + autosize=ui.UITextArea.AutoSize(min_rows=2, max_rows=6), + ), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "User Comment", + "comment", + str, + description="User's comment.", + ) + ], + ) + + def __init__(self, comment: str, **kwargs): + super().__init__(**kwargs) + self.comment = comment + + async def map(self, user_name: str) -> str: + """Map the user name to the text area.""" + return "Your name is %s, and your comment is %s." 
% (user_name, self.comment) + + +class ExampleFlowSliderOperator(MapOperator[float, float]): + + metadata = ViewMetadata( + label="Example Flow Slider", + name="example_flow_slider", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a slider as parameter.", + parameters=[ + Parameter.build_from( + "Default Temperature", + "default_temperature", + type=float, + optional=True, + default=0.7, + placeholder="Set the default temperature, e.g., 0.7", + description="The default temperature to pass to the LLM.", + ui=ui.UISlider( + show_input=True, + attr=ui.UISlider.UIAttribute(min=0.0, max=2.0, step=0.1), + ), + ) + ], + inputs=[ + IOField.build_from( + "Temperature", + "temperature", + float, + description="The temperature.", + ) + ], + outputs=[ + IOField.build_from( + "Temperature", + "temperature", + float, + description="The temperature to pass to the LLM.", + ) + ], + ) + + def __init__(self, default_temperature: float = 0.7, **kwargs): + super().__init__(**kwargs) + self.default_temperature = default_temperature + + async def map(self, temperature: float) -> float: + """Map the temperature to the result.""" + if temperature < 0.0 or temperature > 2.0: + logger.warning("Temperature out of range: %s", temperature) + return self.default_temperature + else: + return temperature + + +class ExampleFlowSliderListOperator(MapOperator[float, float]): + """An example flow operator that includes a slider list as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Slider List", + name="example_flow_slider_list", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a slider list as parameter.", + parameters=[ + Parameter.build_from( + "Temperature Selector", + "temperature_range", + type=float, + is_list=True, + optional=True, + default=None, + placeholder="Set the temperature, e.g., [0.1, 0.9]", + description="The temperature range to pass to the LLM.", + ui=ui.UISlider( + show_input=True, + attr=ui.UISlider.UIAttribute(min=0.0, max=2.0, step=0.1), + ), + ) + ], + inputs=[ + IOField.build_from( + "Temperature", + "temperature", + float, + description="The temperature.", + ) + ], + outputs=[ + IOField.build_from( + "Temperature", + "temperature", + float, + description="The temperature to pass to the LLM.", + ) + ], + ) + + def __init__(self, temperature_range: Optional[List[float]] = None, **kwargs): + super().__init__(**kwargs) + temperature_range = temperature_range or [0.1, 0.9] + if temperature_range and len(temperature_range) != 2: + raise ValueError("The length of temperature range must be 2.") + self.temperature_range = temperature_range + + async def map(self, temperature: float) -> float: + """Map the temperature to the result.""" + min_temperature, max_temperature = self.temperature_range + if temperature < min_temperature or temperature > max_temperature: + logger.warning( + "Temperature out of range: %s, min: %s, max: %s", + temperature, + min_temperature, + max_temperature, + ) + return min_temperature + return temperature + + +class ExampleFlowTimePickerOperator(MapOperator[str, str]): + """An example flow operator that includes a time picker as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Time Picker", + name="example_flow_time_picker", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a time picker as parameter.", + parameters=[ + Parameter.build_from( + "Time Selector", + "time", + type=str, + placeholder="Select the time", + 
description="The time you choose.", + ui=ui.UITimePicker( + attr=ui.UITimePicker.UIAttribute( + format="HH:mm:ss", hour_step=2, minute_step=10, second_step=10 + ), + ), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Time", + "time", + str, + description="User's selected time.", + ) + ], + ) + + def __init__(self, time: str, **kwargs): + super().__init__(**kwargs) + self.time = time + + async def map(self, user_name: str) -> str: + """Map the user name to the time.""" + return "Your name is %s, and you choose the time %s." % (user_name, self.time) + + +class ExampleFlowTreeSelectOperator(MapOperator[str, str]): + """An example flow operator that includes a tree select as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Tree Select", + name="example_flow_tree_select", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a tree select as parameter.", + parameters=[ + Parameter.build_from( + "Address Selector", + "address", + type=str, + is_list=True, + optional=True, + default=None, + placeholder="Select the address", + description="The address of the location.", + options=[ + OptionValue( + label="Zhejiang", + name="zhejiang", + value="zhejiang", + children=[ + OptionValue( + label="Hangzhou", + name="hangzhou", + value="hangzhou", + children=[ + OptionValue( + label="Xihu", + name="xihu", + value="xihu", + ), + OptionValue( + label="Feilaifeng", + name="feilaifeng", + value="feilaifeng", + ), + ], + ), + ], + ), + OptionValue( + label="Jiangsu", + name="jiangsu", + value="jiangsu", + children=[ + OptionValue( + label="Nanjing", + name="nanjing", + value="nanjing", + children=[ + OptionValue( + label="Zhonghua Gate", + name="zhonghuamen", + value="zhonghuamen", + ), + OptionValue( + label="Zhongshanling", + name="zhongshanling", + value="zhongshanling", + ), + ], + ), + ], + ), + ], + ui=ui.UITreeSelect(attr=ui.UITreeSelect.UIAttribute(show_search=True)), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Address", + "address", + str, + description="User's address.", + ) + ], + ) + + def __int__(self, address: Optional[List[str]] = None, **kwargs): + super().__init__(**kwargs) + self.address = address or [] + + async def map(self, user_name: str) -> str: + """Map the user name to the address.""" + full_address_str = " ".join(self.address) + return "Your name is %s, and your address is %s." 
% ( + user_name, + full_address_str, + ) From f6669d3b2672b11f487feed78050a46851fdb391 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Tue, 6 Aug 2024 10:17:58 +0800 Subject: [PATCH 34/89] feat(core): Support refresh for AWEL flow --- dbgpt/core/awel/flow/base.py | 52 ++++++++- dbgpt/core/awel/flow/ui.py | 33 ++++++ dbgpt/core/awel/util/parameter_util.py | 38 ++++++- dbgpt/serve/flow/api/endpoints.py | 19 +++- dbgpt/serve/flow/api/schemas.py | 40 ++++++- examples/awel/awel_flow_ui_components.py | 136 +++++++++++++++++++++++ 6 files changed, 307 insertions(+), 11 deletions(-) diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index da0b2c378..846b18baf 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -15,7 +15,11 @@ model_to_dict, model_validator, ) -from dbgpt.core.awel.util.parameter_util import BaseDynamicOptions, OptionValue +from dbgpt.core.awel.util.parameter_util import ( + BaseDynamicOptions, + OptionValue, + RefreshOptionRequest, +) from dbgpt.core.interface.serialization import Serializable from .exceptions import FlowMetadataException, FlowParameterMetadataException @@ -486,6 +490,25 @@ def to_dict(self) -> Dict: dict_value["ui"] = self.ui.to_dict() return dict_value + def refresh(self, request: Optional[RefreshOptionRequest] = None) -> Dict: + """Refresh the options of the parameter. + + Args: + request (RefreshOptionRequest): The request to refresh the options. + + Returns: + Dict: The response. + """ + dict_value = self.to_dict() + if not self.options: + dict_value["options"] = None + elif isinstance(self.options, BaseDynamicOptions): + values = self.options.refresh(request) + dict_value["options"] = [value.to_dict() for value in values] + else: + dict_value["options"] = [value.to_dict() for value in self.options] + return dict_value + def get_dict_options(self) -> Optional[List[Dict]]: """Get the options of the parameter.""" if not self.options: @@ -655,10 +678,10 @@ class BaseMetadata(BaseResource): ], ) - tags: Optional[List[str]] = Field( + tags: Optional[Dict[str, str]] = Field( default=None, description="The tags of the operator", - examples=[["llm", "openai", "gpt3"]], + examples=[{"order": "higher-order"}, {"order": "first-order"}], ) parameters: List[Parameter] = Field( @@ -768,6 +791,20 @@ def to_dict(self) -> Dict: ] return dict_value + def refresh(self, request: List[RefreshOptionRequest]) -> Dict: + """Refresh the metadata.""" + name_to_request = {req.name: req for req in request} + parameter_requests = { + parameter.name: name_to_request.get(parameter.name) + for parameter in self.parameters + } + dict_value = self.to_dict() + dict_value["parameters"] = [ + parameter.refresh(parameter_requests.get(parameter.name)) + for parameter in self.parameters + ] + return dict_value + class ResourceMetadata(BaseMetadata, TypeMetadata): """The metadata of the resource.""" @@ -1051,6 +1088,15 @@ def metadata_list(self): """Get the metadata list.""" return [item.metadata.to_dict() for item in self._registry.values()] + def refresh( + self, key: str, is_operator: bool, request: List[RefreshOptionRequest] + ) -> Dict: + """Refresh the metadata.""" + if is_operator: + return _get_operator_class(key).metadata.refresh(request) # type: ignore + else: + return _get_resource_class(key).metadata.refresh(request) + _OPERATOR_REGISTRY: FlowRegistry = FlowRegistry() diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index ca4361276..91008269e 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py 
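The ui.py hunks below register a new "select" UI type and add a UISelect component. As a minimal sketch (not part of the patch; the parameter name and options are illustrative only), a flow parameter can attach it roughly the way the ExampleFlowSelectOperator added later in this commit does:

    from dbgpt.core.awel.flow import OptionValue, Parameter, ui

    # Hypothetical parameter declaration using the new select control.
    fruits_param = Parameter.build_from(
        "Fruits Selector",
        "fruits",
        type=str,
        optional=True,
        default=None,
        description="The fruits you like.",
        options=[
            OptionValue(label="Apple", name="apple", value="apple"),
            OptionValue(label="Banana", name="banana", value="banana"),
        ],
        ui=ui.UISelect(attr=ui.UISelect.UIAttribute(show_search=True)),
    )
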
@@ -8,6 +8,7 @@ from .exceptions import FlowUIComponentException _UI_TYPE = Literal[ + "select", "cascader", "checkbox", "date_picker", @@ -102,6 +103,38 @@ def to_dict(self) -> Dict: return model_to_dict(self) +class UISelect(UIComponent): + """Select component.""" + + class UIAttribute(UIComponent.UIAttribute): + """Select attribute.""" + + show_search: bool = Field( + False, + description="Whether to show search input", + ) + mode: Optional[Literal["tags"]] = Field( + None, + description="The mode of the select", + ) + placement: Optional[ + Literal["topLeft", "topRight", "bottomLeft", "bottomRight"] + ] = Field( + None, + description="The position of the picker panel, None means bottomLeft", + ) + + ui_type: Literal["select"] = Field("select", frozen=True) + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) + + def check_parameter(self, parameter_dict: Dict[str, Any]): + """Check parameter.""" + self._check_options(parameter_dict.get("options", {})) + + class UICascader(UIComponent): """Cascader component.""" diff --git a/dbgpt/core/awel/util/parameter_util.py b/dbgpt/core/awel/util/parameter_util.py index 70015c9ba..2393aed89 100644 --- a/dbgpt/core/awel/util/parameter_util.py +++ b/dbgpt/core/awel/util/parameter_util.py @@ -10,6 +10,27 @@ _DEFAULT_DYNAMIC_REGISTRY = {} +class RefreshOptionDependency(BaseModel): + """The refresh dependency.""" + + name: str = Field(..., description="The name of the refresh dependency") + value: Optional[Any] = Field( + None, description="The value of the refresh dependency" + ) + has_value: bool = Field( + False, description="Whether the refresh dependency has value" + ) + + +class RefreshOptionRequest(BaseModel): + """The refresh option request.""" + + name: str = Field(..., description="The name of parameter to refresh") + depends: Optional[List[RefreshOptionDependency]] = Field( + None, description="The depends of the refresh config" + ) + + class OptionValue(Serializable, BaseModel): """The option value of the parameter.""" @@ -28,24 +49,31 @@ def to_dict(self) -> Dict: class BaseDynamicOptions(Serializable, BaseModel, ABC): """The base dynamic options.""" - @abstractmethod def option_values(self) -> List[OptionValue]: """Return the option values of the parameter.""" + return self.refresh(None) + + @abstractmethod + def refresh(self, request: Optional[RefreshOptionRequest]) -> List[OptionValue]: + """Refresh the dynamic options.""" class FunctionDynamicOptions(BaseDynamicOptions): """The function dynamic options.""" - func: Callable[[], List[OptionValue]] = Field( + func: Callable[..., List[OptionValue]] = Field( ..., description="The function to generate the dynamic options" ) func_id: str = Field( ..., description="The unique id of the function to generate the dynamic options" ) - def option_values(self) -> List[OptionValue]: - """Return the option values of the parameter.""" - return self.func() + def refresh(self, request: Optional[RefreshOptionRequest]) -> List[OptionValue]: + """Refresh the dynamic options.""" + if not request or not request.depends: + return self.func() + kwargs = {dep.name: dep.value for dep in request.depends if dep.has_value} + return self.func(**kwargs) @model_validator(mode="before") @classmethod diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 98ff81d2f..99852271a 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -11,7 +11,7 @@ from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, 
ServeConfig from ..service.service import Service -from .schemas import ServeRequest, ServerResponse +from .schemas import RefreshNodeRequest, ServeRequest, ServerResponse router = APIRouter() @@ -222,6 +222,23 @@ async def get_nodes(): return Result.succ(metadata_list) +@router.post("/nodes/refresh", dependencies=[Depends(check_api_key)]) +async def refresh_nodes(refresh_request: RefreshNodeRequest): + """Refresh the operator or resource nodes + + Returns: + Result[None]: The response + """ + from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY + + new_metadata = _OPERATOR_REGISTRY.refresh( + key=refresh_request.id, + is_operator=refresh_request.flow_type == "operator", + request=refresh_request.refresh, + ) + return Result.succ(new_metadata) + + def init_endpoints(system_app: SystemApp) -> None: """Initialize the endpoints""" global global_system_app diff --git a/dbgpt/serve/flow/api/schemas.py b/dbgpt/serve/flow/api/schemas.py index 6fb8c1924..2daa8f581 100644 --- a/dbgpt/serve/flow/api/schemas.py +++ b/dbgpt/serve/flow/api/schemas.py @@ -1,7 +1,8 @@ -from dbgpt._private.pydantic import ConfigDict +from typing import List, Literal -# Define your Pydantic schemas here +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from dbgpt.core.awel.flow.flow_factory import FlowPanel +from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest from ..config import SERVE_APP_NAME_HUMP @@ -14,3 +15,38 @@ class ServerResponse(FlowPanel): # TODO define your own fields here model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}") + + +class RefreshNodeRequest(BaseModel): + """Flow response model""" + + model_config = ConfigDict(title=f"RefreshNodeRequest") + id: str = Field( + ..., + title="The id of the node", + description="The id of the node to refresh", + examples=["operator_llm_operator___$$___llm___$$___v1"], + ) + flow_type: Literal["operator", "resource"] = Field( + "operator", + title="The type of the node", + description="The type of the node to refresh", + examples=["operator", "resource"], + ) + type_name: str = Field( + ..., + title="The type of the node", + description="The type of the node to refresh", + examples=["LLMOperator"], + ) + type_cls: str = Field( + ..., + title="The class of the node", + description="The class of the node to refresh", + examples=["dbgpt.core.operator.llm.LLMOperator"], + ) + refresh: List[RefreshOptionRequest] = Field( + ..., + title="The refresh options", + description="The refresh options", + ) diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py index 2af3e2bf3..fc8d9a5c4 100644 --- a/examples/awel/awel_flow_ui_components.py +++ b/examples/awel/awel_flow_ui_components.py @@ -5,6 +5,7 @@ from dbgpt.core.awel import MapOperator from dbgpt.core.awel.flow import ( + FunctionDynamicOptions, IOField, OperatorCategory, OptionValue, @@ -16,6 +17,59 @@ logger = logging.getLogger(__name__) +class ExampleFlowSelectOperator(MapOperator[str, str]): + """An example flow operator that includes a select as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Select", + name="example_flow_select", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a select as parameter.", + parameters=[ + Parameter.build_from( + "Fruits Selector", + "fruits", + type=str, + optional=True, + default=None, + placeholder="Select the fruits", + description="The fruits you like.", + options=[ + OptionValue(label="Apple", name="apple", value="apple"), + 
OptionValue(label="Banana", name="banana", value="banana"), + OptionValue(label="Orange", name="orange", value="orange"), + OptionValue(label="Pear", name="pear", value="pear"), + ], + ui=ui.UISelect(attr=ui.UISelect.UIAttribute(show_search=True)), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Fruits", + "fruits", + str, + description="User's favorite fruits.", + ) + ], + ) + + def __init__(self, fruits: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.fruits = fruits + + async def map(self, user_name: str) -> str: + """Map the user name to the fruits.""" + return "Your name is %s, and you like %s." % (user_name, self.fruits) + + class ExampleFlowCascaderOperator(MapOperator[str, str]): """An example flow operator that includes a cascader as parameter.""" @@ -581,3 +635,85 @@ async def map(self, user_name: str) -> str: user_name, full_address_str, ) + + +def get_recent_3_times(time_interval: int = 1) -> List[OptionValue]: + """Get the recent times.""" + from datetime import datetime, timedelta + + now = datetime.now() + recent_times = [now - timedelta(hours=time_interval * i) for i in range(3)] + formatted_times = [time.strftime("%Y-%m-%d %H:%M:%S") for time in recent_times] + option_values = [ + OptionValue(label=formatted_time, name=f"time_{i + 1}", value=formatted_time) + for i, formatted_time in enumerate(formatted_times) + ] + + return option_values + + +class ExampleFlowRefreshOperator(MapOperator[str, str]): + """An example flow operator that includes a refresh option.""" + + metadata = ViewMetadata( + label="Example Refresh Operator", + name="example_refresh_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a refresh option.", + parameters=[ + Parameter.build_from( + "Time Interval", + "time_interval", + type=int, + optional=True, + default=1, + placeholder="Set the time interval", + description="The time interval to fetch the times", + ), + Parameter.build_from( + "Recent Time", + "recent_time", + type=str, + optional=True, + default=None, + placeholder="Select the recent time", + description="The recent time to choose.", + options=FunctionDynamicOptions(func=get_recent_3_times), + ui=ui.UISelect( + refresh=True, + refresh_depends=["time_interval"], + attr=ui.UISelect.UIAttribute(show_search=True), + ), + ), + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Time", + "time", + str, + description="User's selected time.", + ) + ], + ) + + def __init__( + self, time_interval: int = 1, recent_time: Optional[str] = None, **kwargs + ): + super().__init__(**kwargs) + self.time_interval = time_interval + self.recent_time = recent_time + + async def map(self, user_name: str) -> str: + """Map the user name to the time.""" + return "Your name is %s, and you choose the time %s." 
% ( + user_name, + self.recent_time, + ) From 125765a3ab46d68b1987f0af4cc4769c8c116cef Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Fri, 9 Aug 2024 17:52:23 +0800 Subject: [PATCH 35/89] feat(core): Support variables for AWEL --- .env.template | 6 + dbgpt/_private/config.py | 1 + .../initialization/db_model_initialization.py | 2 + .../initialization/serve_initialization.py | 2 + dbgpt/component.py | 1 + dbgpt/core/awel/dag/base.py | 36 +- dbgpt/core/awel/flow/__init__.py | 2 + dbgpt/core/awel/flow/base.py | 85 ++- dbgpt/core/awel/flow/ui.py | 27 +- dbgpt/core/awel/operators/base.py | 49 +- dbgpt/core/awel/tests/test_dag_variables.py | 111 +++ dbgpt/core/awel/util/parameter_util.py | 176 ++++- dbgpt/core/interface/storage.py | 3 +- dbgpt/core/interface/tests/test_variables.py | 114 +++ dbgpt/core/interface/variables.py | 678 ++++++++++++++++++ dbgpt/serve/core/__init__.py | 12 + dbgpt/serve/flow/api/endpoints.py | 100 ++- dbgpt/serve/flow/api/schemas.py | 65 +- dbgpt/serve/flow/api/variables_provider.py | 260 +++++++ dbgpt/serve/flow/config.py | 5 + dbgpt/serve/flow/models/models.py | 165 ++++- dbgpt/serve/flow/models/variables_adapter.py | 69 ++ dbgpt/serve/flow/serve.py | 37 +- dbgpt/serve/flow/service/variables_service.py | 148 ++++ examples/awel/awel_flow_ui_components.py | 164 +++++ setup.py | 2 + 26 files changed, 2273 insertions(+), 47 deletions(-) create mode 100644 dbgpt/core/awel/tests/test_dag_variables.py create mode 100644 dbgpt/core/interface/tests/test_variables.py create mode 100644 dbgpt/core/interface/variables.py create mode 100644 dbgpt/serve/flow/api/variables_provider.py create mode 100644 dbgpt/serve/flow/models/variables_adapter.py create mode 100644 dbgpt/serve/flow/service/variables_service.py diff --git a/.env.template b/.env.template index f90af90ee..44aa2d710 100644 --- a/.env.template +++ b/.env.template @@ -271,6 +271,12 @@ DBGPT_LOG_LEVEL=INFO # API_KEYS - The list of API keys that are allowed to access the API. Each of the below are an option, separated by commas. 
# API_KEYS=dbgpt +#*******************************************************************# +#** ENCRYPT **# +#*******************************************************************# +# ENCRYPT KEY - The key used to encrypt and decrypt the data +# ENCRYPT_KEY=your_secret_key + #*******************************************************************# #** Application Config **# diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index 0ea313bac..2dbfac0f0 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -303,6 +303,7 @@ def __init__(self) -> None: ) # global dbgpt api key self.API_KEYS = os.getenv("API_KEYS", None) + self.ENCRYPT_KEY = os.getenv("ENCRYPT_KEY", "your_secret_key") # Non-streaming scene retries self.DBGPT_APP_SCENE_NON_STREAMING_RETRIES_BASE = int( diff --git a/dbgpt/app/initialization/db_model_initialization.py b/dbgpt/app/initialization/db_model_initialization.py index 0749ccdf0..b8808c400 100644 --- a/dbgpt/app/initialization/db_model_initialization.py +++ b/dbgpt/app/initialization/db_model_initialization.py @@ -9,6 +9,7 @@ from dbgpt.serve.agent.db.my_plugin_db import MyPluginEntity from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity from dbgpt.serve.flow.models.models import ServeEntity as FlowServeEntity +from dbgpt.serve.flow.models.models import VariablesEntity as FlowVariableEntity from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity from dbgpt.serve.rag.models.models import KnowledgeSpaceEntity from dbgpt.storage.chat_history.chat_history_db import ( @@ -29,4 +30,5 @@ ChatHistoryMessageEntity, ModelInstanceEntity, FlowServeEntity, + FlowVariableEntity, ] diff --git a/dbgpt/app/initialization/serve_initialization.py b/dbgpt/app/initialization/serve_initialization.py index 106da8fc9..f0b9c9e42 100644 --- a/dbgpt/app/initialization/serve_initialization.py +++ b/dbgpt/app/initialization/serve_initialization.py @@ -7,6 +7,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE) if cfg.API_KEYS: system_app.config.set("dbgpt.app.global.api_keys", cfg.API_KEYS) + if cfg.ENCRYPT_KEY: + system_app.config.set("dbgpt.app.global.encrypt_key", cfg.ENCRYPT_KEY) # ################################ Prompt Serve Register Begin ###################################### from dbgpt.serve.prompt.serve import ( diff --git a/dbgpt/component.py b/dbgpt/component.py index bb7a7a9e4..cb88a61ec 100644 --- a/dbgpt/component.py +++ b/dbgpt/component.py @@ -89,6 +89,7 @@ class ComponentType(str, Enum): CONNECTOR_MANAGER = "dbgpt_connector_manager" AGENT_MANAGER = "dbgpt_agent_manager" RESOURCE_MANAGER = "dbgpt_resource_manager" + VARIABLES_PROVIDER = "dbgpt_variables_provider" _EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT" diff --git a/dbgpt/core/awel/dag/base.py b/dbgpt/core/awel/dag/base.py index ddcfd52bc..512cd6126 100644 --- a/dbgpt/core/awel/dag/base.py +++ b/dbgpt/core/awel/dag/base.py @@ -11,7 +11,18 @@ from abc import ABC, abstractmethod from collections import deque from concurrent.futures import Executor -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Union, + cast, +) from dbgpt.component import SystemApp @@ -23,6 +34,9 @@ DependencyType = Union["DependencyMixin", Sequence["DependencyMixin"]] +if TYPE_CHECKING: + from ...interface.variables import VariablesProvider + def _is_async_context(): try: @@ -128,6 
+142,8 @@ class DAGVar: # The executor for current DAG, this is used run some sync tasks in async DAG _executor: Optional[Executor] = None + _variables_provider: Optional["VariablesProvider"] = None + @classmethod def enter_dag(cls, dag) -> None: """Enter a DAG context. @@ -221,6 +237,24 @@ def set_executor(cls, executor: Executor) -> None: """ cls._executor = executor + @classmethod + def get_variables_provider(cls) -> Optional["VariablesProvider"]: + """Get the current variables provider. + + Returns: + Optional[VariablesProvider]: The current variables provider + """ + return cls._variables_provider + + @classmethod + def set_variables_provider(cls, variables_provider: "VariablesProvider") -> None: + """Set the current variables provider. + + Args: + variables_provider (VariablesProvider): The variables provider to set + """ + cls._variables_provider = variables_provider + class DAGLifecycle: """The lifecycle of DAG.""" diff --git a/dbgpt/core/awel/flow/__init__.py b/dbgpt/core/awel/flow/__init__.py index 5a173565f..80db5b7e6 100644 --- a/dbgpt/core/awel/flow/__init__.py +++ b/dbgpt/core/awel/flow/__init__.py @@ -7,6 +7,7 @@ BaseDynamicOptions, FunctionDynamicOptions, OptionValue, + VariablesDynamicOptions, ) from .base import ( # noqa: F401 IOField, @@ -35,4 +36,5 @@ "IOField", "BaseDynamicOptions", "FunctionDynamicOptions", + "VariablesDynamicOptions", ] diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index 846b18baf..57081420e 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -6,7 +6,7 @@ from abc import ABC from datetime import date, datetime from enum import Enum -from typing import Any, Dict, List, Optional, Type, TypeVar, Union, cast +from typing import Any, Dict, List, Literal, Optional, Type, TypeVar, Union, cast from dbgpt._private.pydantic import ( BaseModel, @@ -15,12 +15,14 @@ model_to_dict, model_validator, ) +from dbgpt.component import SystemApp from dbgpt.core.awel.util.parameter_util import ( BaseDynamicOptions, OptionValue, RefreshOptionRequest, ) from dbgpt.core.interface.serialization import Serializable +from dbgpt.util.executor_utils import DefaultExecutorFactory, blocking_func_to_async from .exceptions import FlowMetadataException, FlowParameterMetadataException from .ui import UIComponent @@ -490,11 +492,19 @@ def to_dict(self) -> Dict: dict_value["ui"] = self.ui.to_dict() return dict_value - def refresh(self, request: Optional[RefreshOptionRequest] = None) -> Dict: + async def refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> Dict: """Refresh the options of the parameter. Args: request (RefreshOptionRequest): The request to refresh the options. + trigger (Literal["default", "http"], optional): The trigger type. + Defaults to "default". + system_app (Optional[SystemApp], optional): The system app. Returns: Dict: The response. 
@@ -503,7 +513,7 @@ def refresh(self, request: Optional[RefreshOptionRequest] = None) -> Dict: if not self.options: dict_value["options"] = None elif isinstance(self.options, BaseDynamicOptions): - values = self.options.refresh(request) + values = self.options.refresh(request, trigger, system_app) dict_value["options"] = [value.to_dict() for value in values] else: dict_value["options"] = [value.to_dict() for value in self.options] @@ -791,18 +801,56 @@ def to_dict(self) -> Dict: ] return dict_value - def refresh(self, request: List[RefreshOptionRequest]) -> Dict: - """Refresh the metadata.""" + async def refresh( + self, + request: List[RefreshOptionRequest], + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> Dict: + """Refresh the metadata. + + Args: + request (List[RefreshOptionRequest]): The refresh request + trigger (Literal["default", "http"]): The trigger type, how to trigger + the refresh + system_app (Optional[SystemApp]): The system app + """ + executor = DefaultExecutorFactory.get_instance(system_app).create() + name_to_request = {req.name: req for req in request} parameter_requests = { parameter.name: name_to_request.get(parameter.name) for parameter in self.parameters } - dict_value = self.to_dict() - dict_value["parameters"] = [ - parameter.refresh(parameter_requests.get(parameter.name)) - for parameter in self.parameters - ] + dict_value = model_to_dict(self, exclude={"parameters"}) + parameters = [] + for parameter in self.parameters: + parameter_dict = parameter.to_dict() + parameter_request = parameter_requests.get(parameter.name) + if not parameter.options: + options = None + elif isinstance(parameter.options, BaseDynamicOptions): + options_obj = parameter.options + if options_obj.support_async(system_app, parameter_request): + values = await options_obj.async_refresh( + parameter_request, trigger, system_app + ) + else: + values = await blocking_func_to_async( + executor, + options_obj.refresh, + parameter_request, + trigger, + system_app, + ) + options = [value.to_dict() for value in values] + else: + options = [value.to_dict() for value in self.options] + parameter_dict["options"] = options + parameters.append(parameter_dict) + + dict_value["parameters"] = parameters + return dict_value @@ -1088,14 +1136,23 @@ def metadata_list(self): """Get the metadata list.""" return [item.metadata.to_dict() for item in self._registry.values()] - def refresh( - self, key: str, is_operator: bool, request: List[RefreshOptionRequest] + async def refresh( + self, + key: str, + is_operator: bool, + request: List[RefreshOptionRequest], + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, ) -> Dict: """Refresh the metadata.""" if is_operator: - return _get_operator_class(key).metadata.refresh(request) # type: ignore + return await _get_operator_class(key).metadata.refresh( # type: ignore + request, trigger, system_app + ) else: - return _get_resource_class(key).metadata.refresh(request) + return await _get_resource_class(key).metadata.refresh( + request, trigger, system_app + ) _OPERATOR_REGISTRY: FlowRegistry = FlowRegistry() diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index 91008269e..66b413a9f 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py @@ -71,7 +71,7 @@ class Editor(BaseModel): class UIComponent(RefreshableMixin, Serializable, BaseModel): """UI component.""" - class UIAttribute(StatusMixin, BaseModel): + class UIAttribute(BaseModel): """Base UI 
attribute.""" disabled: bool = Field( @@ -106,7 +106,7 @@ def to_dict(self) -> Dict: class UISelect(UIComponent): """Select component.""" - class UIAttribute(UIComponent.UIAttribute): + class UIAttribute(StatusMixin, UIComponent.UIAttribute): """Select attribute.""" show_search: bool = Field( @@ -138,7 +138,7 @@ def check_parameter(self, parameter_dict: Dict[str, Any]): class UICascader(UIComponent): """Cascader component.""" - class UIAttribute(UIComponent.UIAttribute): + class UIAttribute(StatusMixin, UIComponent.UIAttribute): """Cascader attribute.""" show_search: bool = Field( @@ -178,7 +178,7 @@ def check_parameter(self, parameter_dict: Dict[str, Any]): class UIDatePicker(UIComponent): """Date picker component.""" - class UIAttribute(UIComponent.UIAttribute): + class UIAttribute(StatusMixin, UIComponent.UIAttribute): """Date picker attribute.""" placement: Optional[ @@ -199,7 +199,7 @@ class UIAttribute(UIComponent.UIAttribute): class UIInput(UIComponent): """Input component.""" - class UIAttribute(UIComponent.UIAttribute): + class UIAttribute(StatusMixin, UIComponent.UIAttribute): """Input attribute.""" prefix: Optional[str] = Field( @@ -216,7 +216,7 @@ class UIAttribute(UIComponent.UIAttribute): None, description="Whether to show count", ) - maxlength: Optional[int] = Field( + max_length: Optional[int] = Field( None, description="The maximum length of the input", ) @@ -294,7 +294,7 @@ class UIAttribute(UIComponent.UIAttribute): class UITimePicker(UIComponent): """Time picker component.""" - class UIAttribute(UIComponent.UIAttribute): + class UIAttribute(StatusMixin, UIComponent.UIAttribute): """Time picker attribute.""" format: Optional[str] = Field( @@ -377,15 +377,20 @@ class UIAttribute(UIComponent.UIAttribute): ) -class UIVariableInput(UIInput): - """Variable input component.""" +class UIVariablesInput(UIInput): + """Variables input component.""" - ui_type: Literal["variable"] = Field("variable", frozen=True) # type: ignore + ui_type: Literal["variable"] = Field("variables", frozen=True) # type: ignore key: str = Field(..., description="The key of the variable") key_type: Literal["common", "secret"] = Field( "common", description="The type of the key", ) + scope: str = Field("global", description="The scope of the variables") + scope_key: Optional[str] = Field( + None, + description="The key of the scope", + ) refresh: Optional[bool] = Field( True, description="Whether to enable the refresh", @@ -396,7 +401,7 @@ def check_parameter(self, parameter_dict: Dict[str, Any]): self._check_options(parameter_dict.get("options", {})) -class UIPasswordInput(UIVariableInput): +class UIPasswordInput(UIVariablesInput): """Password input component.""" ui_type: Literal["password"] = Field("password", frozen=True) # type: ignore diff --git a/dbgpt/core/awel/operators/base.py b/dbgpt/core/awel/operators/base.py index 58f3acabc..77f056042 100644 --- a/dbgpt/core/awel/operators/base.py +++ b/dbgpt/core/awel/operators/base.py @@ -2,10 +2,12 @@ import asyncio import functools +import logging from abc import ABC, ABCMeta, abstractmethod from contextvars import ContextVar from types import FunctionType from typing import ( + TYPE_CHECKING, Any, AsyncIterator, Dict, @@ -29,6 +31,11 @@ from ..dag.base import DAG, DAGContext, DAGNode, DAGVar from ..task.base import EMPTY_DATA, OUT, T, TaskOutput, is_empty_data +if TYPE_CHECKING: + from ...interface.variables import VariablesProvider + +logger = logging.getLogger(__name__) + F = TypeVar("F", bound=FunctionType) CALL_DATA = Union[Dict[str, Any], Any] @@ 
-92,6 +99,9 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: kwargs.get("system_app") or DAGVar.get_current_system_app() ) executor = kwargs.get("executor") or DAGVar.get_executor() + variables_provider = ( + kwargs.get("variables_provider") or DAGVar.get_variables_provider() + ) if not executor: if system_app: executor = system_app.get_component( @@ -102,14 +112,24 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: else: executor = DefaultExecutorFactory().create() DAGVar.set_executor(executor) + if not variables_provider: + from ...interface.variables import VariablesProvider + + if system_app: + variables_provider = system_app.get_component( + ComponentType.VARIABLES_PROVIDER, + VariablesProvider, + default_component=None, + ) + else: + from ...interface.variables import StorageVariablesProvider + + variables_provider = StorageVariablesProvider() + DAGVar.set_variables_provider(variables_provider) if not task_id and dag: task_id = dag._new_node_id() runner: Optional[WorkflowRunner] = kwargs.get("runner") or default_runner - # print(f"self: {self}, kwargs dag: {kwargs.get('dag')}, kwargs: {kwargs}") - # for arg in sig_cache.parameters: - # if arg not in kwargs: - # kwargs[arg] = default_args[arg] if not kwargs.get("dag"): kwargs["dag"] = dag if not kwargs.get("task_id"): @@ -120,6 +140,8 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any: kwargs["system_app"] = system_app if not kwargs.get("executor"): kwargs["executor"] = executor + if not kwargs.get("variables_provider"): + kwargs["variables_provider"] = variables_provider real_obj = func(self, *args, **kwargs) return real_obj @@ -150,6 +172,7 @@ def __init__( dag: Optional[DAG] = None, runner: Optional[WorkflowRunner] = None, can_skip_in_branch: bool = True, + variables_provider: Optional["VariablesProvider"] = None, **kwargs, ) -> None: """Create a BaseOperator with an optional workflow runner. 
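The hunks above and below thread a VariablesProvider into BaseOperator, and the `_resolve_variables` hook added further down replaces any VariablesPlaceHolder attribute with its resolved value before the task runs. A minimal sketch of that resolution path (not part of the patch; condensed from the test_dag_variables.py test added in this commit):

    from dbgpt.core.interface.variables import (
        StorageVariables,
        StorageVariablesProvider,
        VariablesIdentifier,
        VariablesPlaceHolder,
    )

    provider = StorageVariablesProvider()
    key = "str_key@my_str_var@global"
    vid = VariablesIdentifier.from_str_identifier(key)
    # Store the variable, then reference it through a placeholder.
    provider.save(
        StorageVariables.from_identifier(vid, "1", "str", label="", category="common")
    )
    placeholder = VariablesPlaceHolder("str_var", key, "str")
    resolved = placeholder.parse(provider)  # expected to resolve to "1"
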
@@ -171,6 +194,7 @@ def __init__( self._runner: WorkflowRunner = runner self._dag_ctx: Optional[DAGContext] = None self._can_skip_in_branch = can_skip_in_branch + self._variables_provider = variables_provider @property def current_dag_context(self) -> DAGContext: @@ -199,6 +223,8 @@ async def _run(self, dag_ctx: DAGContext, task_log_id: str) -> TaskOutput[OUT]: if not task_log_id: raise ValueError(f"The task log ID can't be empty, current node {self}") CURRENT_DAG_CONTEXT.set(dag_ctx) + # Resolve variables + await self._resolve_variables(dag_ctx) return await self._do_run(dag_ctx) @abstractmethod @@ -349,6 +375,21 @@ def can_skip_in_branch(self) -> bool: """Check if the operator can be skipped in the branch.""" return self._can_skip_in_branch + async def _resolve_variables(self, _: DAGContext): + from ...interface.variables import VariablesPlaceHolder + + if not self._variables_provider: + return + for attr, value in self.__dict__.items(): + if isinstance(value, VariablesPlaceHolder): + resolved_value = await self.blocking_func_to_async( + value.parse, self._variables_provider + ) + logger.debug( + f"Resolve variable {attr} with value {resolved_value} for {self}" + ) + setattr(self, attr, resolved_value) + def initialize_runner(runner: WorkflowRunner): """Initialize the default runner.""" diff --git a/dbgpt/core/awel/tests/test_dag_variables.py b/dbgpt/core/awel/tests/test_dag_variables.py new file mode 100644 index 000000000..88c9b6660 --- /dev/null +++ b/dbgpt/core/awel/tests/test_dag_variables.py @@ -0,0 +1,111 @@ +from contextlib import asynccontextmanager + +import pytest +import pytest_asyncio + +from ...interface.variables import ( + StorageVariables, + StorageVariablesProvider, + VariablesIdentifier, + VariablesPlaceHolder, +) +from .. import DAG, DAGVar, InputOperator, MapOperator, SimpleInputSource + + +class VariablesOperator(MapOperator[str, str]): + def __init__(self, int_var: int, str_var: str, secret: str, **kwargs): + super().__init__(**kwargs) + self._int_var = int_var + self._str_var = str_var + self._secret = secret + + async def map(self, x: str) -> str: + return ( + f"x: {x}, int_var: {self._int_var}, str_var: {self._str_var}, " + f"secret: {self._secret}" + ) + + +@pytest.fixture +def default_dag(): + with DAG("test_dag") as dag: + input_node = InputOperator(input_source=SimpleInputSource.from_callable()) + map_node = MapOperator(lambda x: x * 2) + input_node >> map_node + return dag + + +@asynccontextmanager +async def _create_variables(**kwargs): + variables_provider = StorageVariablesProvider() + DAGVar.set_variables_provider(variables_provider) + + vars = kwargs.get("vars") + variables = {} + if vars and isinstance(vars, dict): + for param_key, param_var in vars.items(): + key = param_var.get("key") + value = param_var.get("value") + value_type = param_var.get("value_type") + category = param_var.get("category", "common") + id = VariablesIdentifier.from_str_identifier(key) + variables_provider.save( + StorageVariables.from_identifier( + id, value, value_type, label="", category=category + ) + ) + variables[param_key] = VariablesPlaceHolder(param_key, key, value_type) + else: + raise ValueError("vars is required.") + + with DAG("simple_dag") as dag: + map_node = VariablesOperator(**variables) + yield map_node + + +@pytest_asyncio.fixture +async def variables_node(request): + param = getattr(request, "param", {}) + async with _create_variables(**param) as node: + yield node + + +@pytest.mark.asyncio +async def test_default_dag(default_dag: DAG): + leaf_node = 
default_dag.leaf_nodes[0] + res = await leaf_node.call(2) + assert res == 4 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "variables_node", + [ + ( + { + "vars": { + "int_var": { + "key": "int_key@my_int_var@global", + "value": 0, + "value_type": "int", + }, + "str_var": { + "key": "str_key@my_str_var@global", + "value": "1", + "value_type": "str", + }, + "secret": { + "key": "secret_key@my_secret_var@global", + "value": "2131sdsdf", + "value_type": "str", + "category": "secret", + }, + } + } + ), + ], + indirect=["variables_node"], +) +async def test_input_nodes(variables_node: VariablesOperator): + res = await variables_node.call("test") + assert res == "x: test, int_var: 0, str_var: 1, secret: 2131sdsdf" diff --git a/dbgpt/core/awel/util/parameter_util.py b/dbgpt/core/awel/util/parameter_util.py index 2393aed89..a492169c5 100644 --- a/dbgpt/core/awel/util/parameter_util.py +++ b/dbgpt/core/awel/util/parameter_util.py @@ -2,9 +2,10 @@ import inspect from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Literal, Optional from dbgpt._private.pydantic import BaseModel, Field, model_validator +from dbgpt.component import SystemApp from dbgpt.core.interface.serialization import Serializable _DEFAULT_DYNAMIC_REGISTRY = {} @@ -29,6 +30,21 @@ class RefreshOptionRequest(BaseModel): depends: Optional[List[RefreshOptionDependency]] = Field( None, description="The depends of the refresh config" ) + variables_key: Optional[str] = Field( + None, description="The variables key to refresh" + ) + variables_scope: Optional[str] = Field( + None, description="The variables scope to refresh" + ) + variables_scope_key: Optional[str] = Field( + None, description="The variables scope key to refresh" + ) + variables_sys_code: Optional[str] = Field( + None, description="The system code to refresh" + ) + variables_user_name: Optional[str] = Field( + None, description="The user name to refresh" + ) class OptionValue(Serializable, BaseModel): @@ -49,13 +65,57 @@ def to_dict(self) -> Dict: class BaseDynamicOptions(Serializable, BaseModel, ABC): """The base dynamic options.""" + def support_async( + self, + system_app: Optional[SystemApp] = None, + request: Optional[RefreshOptionRequest] = None, + ) -> bool: + """Whether the dynamic options support async. + + Args: + system_app (Optional[SystemApp]): The system app + request (Optional[RefreshOptionRequest]): The refresh request + + Returns: + bool: Whether the dynamic options support async + """ + return False + def option_values(self) -> List[OptionValue]: """Return the option values of the parameter.""" return self.refresh(None) @abstractmethod - def refresh(self, request: Optional[RefreshOptionRequest]) -> List[OptionValue]: - """Refresh the dynamic options.""" + def refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: + """Refresh the dynamic options. + + Args: + request (Optional[RefreshOptionRequest]): The refresh request + trigger (Literal["default", "http"]): The trigger type, how to trigger + the refresh + system_app (Optional[SystemApp]): The system app + """ + + async def async_refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: + """Refresh the dynamic options async. 
+ + Args: + request (Optional[RefreshOptionRequest]): The refresh request + trigger (Literal["default", "http"]): The trigger type, how to trigger + the refresh + system_app (Optional[SystemApp]): The system app + """ + raise NotImplementedError("The dynamic options does not support async.") class FunctionDynamicOptions(BaseDynamicOptions): @@ -68,7 +128,12 @@ class FunctionDynamicOptions(BaseDynamicOptions): ..., description="The unique id of the function to generate the dynamic options" ) - def refresh(self, request: Optional[RefreshOptionRequest]) -> List[OptionValue]: + def refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: """Refresh the dynamic options.""" if not request or not request.depends: return self.func() @@ -96,6 +161,109 @@ def to_dict(self) -> Dict: return {"func_id": self.func_id} +class VariablesDynamicOptions(BaseDynamicOptions): + """The variables dynamic options.""" + + def support_async( + self, + system_app: Optional[SystemApp] = None, + request: Optional[RefreshOptionRequest] = None, + ) -> bool: + """Whether the dynamic options support async.""" + if not system_app or not request or not request.variables_key: + return False + + from ...interface.variables import BuiltinVariablesProvider + + provider: BuiltinVariablesProvider = system_app.get_component( + request.variables_key, + component_type=BuiltinVariablesProvider, + default_component=None, + ) + if not provider: + return False + return provider.support_async() + + def refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: + """Refresh the dynamic options.""" + if ( + trigger == "default" + or not request + or not request.variables_key + or not request.variables_scope + ): + # Only refresh when trigger is http and request is not None + return [] + if not system_app: + raise ValueError("The system app is required when refresh the variables.") + from ...interface.variables import VariablesProvider + + vp: VariablesProvider = VariablesProvider.get_instance(system_app) + variables = vp.get_variables( + key=request.variables_key, + scope=request.variables_scope, + scope_key=request.variables_scope_key, + sys_code=request.variables_sys_code, + user_name=request.variables_user_name, + ) + options = [] + for var in variables: + options.append( + OptionValue( + label=var.label, + name=var.name, + value=var.value, + ) + ) + return options + + async def async_refresh( + self, + request: Optional[RefreshOptionRequest] = None, + trigger: Literal["default", "http"] = "default", + system_app: Optional[SystemApp] = None, + ) -> List[OptionValue]: + """Refresh the dynamic options async.""" + if ( + trigger == "default" + or not request + or not request.variables_key + or not request.variables_scope + ): + return [] + if not system_app: + raise ValueError("The system app is required when refresh the variables.") + from ...interface.variables import VariablesProvider + + vp: VariablesProvider = VariablesProvider.get_instance(system_app) + variables = await vp.async_get_variables( + key=request.variables_key, + scope=request.variables_scope, + scope_key=request.variables_scope_key, + sys_code=request.variables_sys_code, + user_name=request.variables_user_name, + ) + options = [] + for var in variables: + options.append( + OptionValue( + label=var.label, + 
name=var.name, + value=var.value, + ) + ) + return options + + def to_dict(self) -> Dict: + """Convert current metadata to json dict.""" + return {"key": self.key} + + def _generate_unique_id(func: Callable) -> str: if func.__name__ == "": func_id = f"lambda_{inspect.getfile(func)}_{inspect.getsourcelines(func)}" diff --git a/dbgpt/core/interface/storage.py b/dbgpt/core/interface/storage.py index 2a61746ec..4bf152ab8 100644 --- a/dbgpt/core/interface/storage.py +++ b/dbgpt/core/interface/storage.py @@ -3,13 +3,14 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, cast -from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource from dbgpt.core.interface.serialization import Serializable, Serializer from dbgpt.util.annotations import PublicAPI from dbgpt.util.i18n_utils import _ from dbgpt.util.pagination_utils import PaginationResult from dbgpt.util.serialization.json_serialization import JsonSerializer +from ..awel.flow import Parameter, ResourceCategory, register_resource + @PublicAPI(stability="beta") class ResourceIdentifier(Serializable, ABC): diff --git a/dbgpt/core/interface/tests/test_variables.py b/dbgpt/core/interface/tests/test_variables.py new file mode 100644 index 000000000..3b7ab8157 --- /dev/null +++ b/dbgpt/core/interface/tests/test_variables.py @@ -0,0 +1,114 @@ +import base64 +import os + +from cryptography.fernet import Fernet + +from ..variables import ( + FernetEncryption, + InMemoryStorage, + SimpleEncryption, + StorageVariables, + StorageVariablesProvider, + VariablesIdentifier, +) + + +def test_fernet_encryption(): + key = Fernet.generate_key() + encryption = FernetEncryption(key) + new_encryption = FernetEncryption(key) + data = "test_data" + salt = "test_salt" + + encrypted_data = encryption.encrypt(data, salt) + assert encrypted_data != data + + decrypted_data = encryption.decrypt(encrypted_data, salt) + assert decrypted_data == data + assert decrypted_data == new_encryption.decrypt(encrypted_data, salt) + + +def test_simple_encryption(): + key = base64.b64encode(os.urandom(32)).decode() + encryption = SimpleEncryption(key) + data = "test_data" + salt = "test_salt" + + encrypted_data = encryption.encrypt(data, salt) + assert encrypted_data != data + + decrypted_data = encryption.decrypt(encrypted_data, salt) + assert decrypted_data == data + + +def test_storage_variables_provider(): + storage = InMemoryStorage() + encryption = SimpleEncryption() + provider = StorageVariablesProvider(storage, encryption) + + full_key = "key@name@global" + value = "secret_value" + value_type = "str" + label = "test_label" + + id = VariablesIdentifier.from_str_identifier(full_key) + provider.save( + StorageVariables.from_identifier( + id, value, value_type, label, category="secret" + ) + ) + + loaded_variable_value = provider.get(full_key) + assert loaded_variable_value == value + + +def test_variables_identifier(): + full_key = "key@name@global@scope_key@sys_code@user_name" + identifier = VariablesIdentifier.from_str_identifier(full_key) + + assert identifier.key == "key" + assert identifier.name == "name" + assert identifier.scope == "global" + assert identifier.scope_key == "scope_key" + assert identifier.sys_code == "sys_code" + assert identifier.user_name == "user_name" + + str_identifier = identifier.str_identifier + assert str_identifier == full_key + + +def test_storage_variables(): + key = "test_key" + name = "test_name" + label = "test_label" + value = "test_value" + value_type = "str" + 
category = "common" + scope = "global" + + storage_variable = StorageVariables( + key=key, + name=name, + label=label, + value=value, + value_type=value_type, + category=category, + scope=scope, + ) + + assert storage_variable.key == key + assert storage_variable.name == name + assert storage_variable.label == label + assert storage_variable.value == value + assert storage_variable.value_type == value_type + assert storage_variable.category == category + assert storage_variable.scope == scope + + dict_representation = storage_variable.to_dict() + assert dict_representation["key"] == key + assert dict_representation["name"] == name + assert dict_representation["label"] == label + assert dict_representation["value"] == value + assert dict_representation["value_type"] == value_type + assert dict_representation["category"] == category + assert dict_representation["scope"] == scope diff --git a/dbgpt/core/interface/variables.py b/dbgpt/core/interface/variables.py new file mode 100644 index 000000000..8f99d1e30 --- /dev/null +++ b/dbgpt/core/interface/variables.py @@ -0,0 +1,678 @@ +"""Variables Module.""" + +import base64 +import dataclasses +import hashlib +import json +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +from dbgpt.component import BaseComponent, ComponentType, SystemApp +from dbgpt.util.executor_utils import ( + DefaultExecutorFactory, + blocking_func_to_async, + blocking_func_to_async_no_executor, +) + +from .storage import ( + InMemoryStorage, + QuerySpec, + ResourceIdentifier, + StorageInterface, + StorageItem, +) + +_EMPTY_DEFAULT_VALUE = "_EMPTY_DEFAULT_VALUE" + +BUILTIN_VARIABLES_CORE_FLOWS = "dbgpt.core.flow.flows" +BUILTIN_VARIABLES_CORE_FLOW_NODES = "dbgpt.core.flow.nodes" +BUILTIN_VARIABLES_CORE_VARIABLES = "dbgpt.core.variables" +BUILTIN_VARIABLES_CORE_SECRETS = "dbgpt.core.secrets" +BUILTIN_VARIABLES_CORE_LLMS = "dbgpt.core.model.llms" +BUILTIN_VARIABLES_CORE_EMBEDDINGS = "dbgpt.core.model.embeddings" +BUILTIN_VARIABLES_CORE_RERANKERS = "dbgpt.core.model.rerankers" +BUILTIN_VARIABLES_CORE_DATASOURCES = "dbgpt.core.datasources" +BUILTIN_VARIABLES_CORE_AGENTS = "dbgpt.core.agent.agents" +BUILTIN_VARIABLES_CORE_KNOWLEDGE_SPACES = "dbgpt.core.knowledge_spaces" + + +class Encryption(ABC): + """Encryption interface.""" + + name: str = "__abstract__" + + @abstractmethod + def encrypt(self, data: str, salt: str) -> str: + """Encrypt the data.""" + + @abstractmethod + def decrypt(self, encrypted_data: str, salt: str) -> str: + """Decrypt the data.""" + + +def _generate_key_from_password( + password: bytes, salt: Optional[Union[str, bytes]] = None +): + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + + if salt is None: + salt = os.urandom(16) + elif isinstance(salt, str): + salt = salt.encode() + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + ) + key = base64.urlsafe_b64encode(kdf.derive(password)) + return key, salt + + +class FernetEncryption(Encryption): + """Fernet encryption. + + A symmetric encryption algorithm that uses the same key for both encryption and + decryption which is powered by the cryptography library. 
+ """ + + name = "fernet" + + def __init__(self, key: Optional[bytes] = None): + """Initialize the fernet encryption.""" + if key is not None and isinstance(key, str): + key = key.encode() + try: + from cryptography.fernet import Fernet + except ImportError: + raise ImportError( + "cryptography is required for encryption, please install by running " + "`pip install cryptography`" + ) + if key is None: + key = Fernet.generate_key() + self.key = key + + def encrypt(self, data: str, salt: str) -> str: + """Encrypt the data with the salt. + + Args: + data (str): The data to encrypt. + salt (str): The salt to use, which is used to derive the key. + + Returns: + str: The encrypted data. + """ + from cryptography.fernet import Fernet + + key, salt = _generate_key_from_password(self.key, salt) + fernet = Fernet(key) + encrypted_secret = fernet.encrypt(data.encode()).decode() + return encrypted_secret + + def decrypt(self, encrypted_data: str, salt: str) -> str: + """Decrypt the data with the salt. + + Args: + encrypted_data (str): The encrypted data. + salt (str): The salt to use, which is used to derive the key. + + Returns: + str: The decrypted data. + """ + from cryptography.fernet import Fernet + + key, salt = _generate_key_from_password(self.key, salt) + fernet = Fernet(key) + return fernet.decrypt(encrypted_data.encode()).decode() + + +class SimpleEncryption(Encryption): + """Simple implementation of encryption. + + A simple encryption algorithm that uses a key to XOR the data. + """ + + name = "simple" + + def __init__(self, key: Optional[str] = None): + """Initialize the simple encryption.""" + if key is None: + key = base64.b64encode(os.urandom(32)).decode() + self.key = key + + def _derive_key(self, salt: str) -> bytes: + return hashlib.pbkdf2_hmac("sha256", self.key.encode(), salt.encode(), 100000) + + def encrypt(self, data: str, salt: str) -> str: + """Encrypt the data with the salt.""" + key = self._derive_key(salt) + encrypted = bytes( + x ^ y for x, y in zip(data.encode(), key * (len(data) // len(key) + 1)) + ) + return base64.b64encode(encrypted).decode() + + def decrypt(self, encrypted_data: str, salt: str) -> str: + """Decrypt the data with the salt.""" + key = self._derive_key(salt) + data = base64.b64decode(encrypted_data) + decrypted = bytes( + x ^ y for x, y in zip(data, key * (len(data) // len(key) + 1)) + ) + return decrypted.decode() + + +@dataclasses.dataclass +class VariablesIdentifier(ResourceIdentifier): + """The variables identifier.""" + + identifier_split: str = dataclasses.field(default="@", init=False) + + key: str + name: str + scope: str = "global" + scope_key: Optional[str] = None + sys_code: Optional[str] = None + user_name: Optional[str] = None + + def __post_init__(self): + """Post init method.""" + if not self.key or not self.name or not self.scope: + raise ValueError("Key, name, and scope are required.") + + if any( + self.identifier_split in key + for key in [ + self.key, + self.name, + self.scope, + self.scope_key, + self.sys_code, + self.user_name, + ] + if key is not None + ): + raise ValueError( + f"identifier_split {self.identifier_split} is not allowed in " + f"key, name, scope, scope_key, sys_code, user_name." + ) + + @property + def str_identifier(self) -> str: + """Return the string identifier of the identifier.""" + return self.identifier_split.join( + key or "" + for key in [ + self.key, + self.name, + self.scope, + self.scope_key, + self.sys_code, + self.user_name, + ] + ) + + def to_dict(self) -> Dict: + """Convert the identifier to a dict. 
+ + Returns: + Dict: The dict of the identifier. + """ + return { + "key": self.key, + "name": self.name, + "scope": self.scope, + "scope_key": self.scope_key, + "sys_code": self.sys_code, + "user_name": self.user_name, + } + + @classmethod + def from_str_identifier( + cls, str_identifier: str, identifier_split: str = "@" + ) -> "VariablesIdentifier": + """Create a VariablesIdentifier from a string identifier. + + Args: + str_identifier (str): The string identifier. + identifier_split (str): The identifier split. + + Returns: + VariablesIdentifier: The VariablesIdentifier. + """ + keys = str_identifier.split(identifier_split) + if not keys: + raise ValueError("Invalid string identifier.") + if len(keys) < 2: + raise ValueError("Invalid string identifier, must have name") + if len(keys) < 3: + raise ValueError("Invalid string identifier, must have scope") + + return cls( + key=keys[0], + name=keys[1], + scope=keys[2], + scope_key=keys[3] if len(keys) > 3 else None, + sys_code=keys[4] if len(keys) > 4 else None, + user_name=keys[5] if len(keys) > 5 else None, + ) + + +@dataclasses.dataclass +class StorageVariables(StorageItem): + """The storage variables.""" + + key: str + name: str + label: str + value: Any + category: Literal["common", "secret"] = "common" + scope: str = "global" + value_type: Optional[str] = None + scope_key: Optional[str] = None + sys_code: Optional[str] = None + user_name: Optional[str] = None + encryption_method: Optional[str] = None + salt: Optional[str] = None + enabled: int = 1 + + _identifier: VariablesIdentifier = dataclasses.field(init=False) + + def __post_init__(self): + """Post init method.""" + self._identifier = VariablesIdentifier( + key=self.key, + name=self.name, + scope=self.scope, + scope_key=self.scope_key, + sys_code=self.sys_code, + user_name=self.user_name, + ) + if not self.value_type: + self.value_type = type(self.value).__name__ + + @property + def identifier(self) -> ResourceIdentifier: + """Return the identifier.""" + return self._identifier + + def merge(self, other: "StorageItem") -> None: + """Merge with another storage variables.""" + if not isinstance(other, StorageVariables): + raise ValueError(f"Cannot merge with {type(other)}") + self.from_object(other) + + def to_dict(self) -> Dict: + """Convert the storage variables to a dict. + + Returns: + Dict: The dict of the storage variables. 
+ """ + return { + **self._identifier.to_dict(), + "label": self.label, + "value": self.value, + "value_type": self.value_type, + "category": self.category, + "encryption_method": self.encryption_method, + "salt": self.salt, + } + + def from_object(self, other: "StorageVariables") -> None: + """Copy the values from another storage variables object.""" + self.label = other.label + self.value = other.value + self.value_type = other.value_type + self.category = other.category + self.scope = other.scope + self.scope_key = other.scope_key + self.sys_code = other.sys_code + self.user_name = other.user_name + self.encryption_method = other.encryption_method + self.salt = other.salt + + @classmethod + def from_identifier( + cls, + identifier: VariablesIdentifier, + value: Any, + value_type: str, + label: str = "", + category: Literal["common", "secret"] = "common", + encryption_method: Optional[str] = None, + salt: Optional[str] = None, + ) -> "StorageVariables": + """Copy the values from an identifier.""" + return cls( + key=identifier.key, + name=identifier.name, + label=label, + value=value, + value_type=value_type, + category=category, + scope=identifier.scope, + scope_key=identifier.scope_key, + sys_code=identifier.sys_code, + user_name=identifier.user_name, + encryption_method=encryption_method, + salt=salt, + ) + + +class VariablesProvider(BaseComponent, ABC): + """The variables provider interface.""" + + name = ComponentType.VARIABLES_PROVIDER.value + + @abstractmethod + def get( + self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE + ) -> Any: + """Query variables from storage.""" + + @abstractmethod + def save(self, variables_item: StorageVariables) -> None: + """Save variables to storage.""" + + @abstractmethod + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get variables by key.""" + + async def async_get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get variables by key async.""" + raise NotImplementedError("Current variables provider does not support async.") + + def support_async(self) -> bool: + """Whether the variables provider support async.""" + return False + + +class VariablesPlaceHolder: + """The variables place holder.""" + + def __init__( + self, + param_name: str, + full_key: str, + value_type: str, + default_value: Any = _EMPTY_DEFAULT_VALUE, + ): + """Initialize the variables place holder.""" + self.param_name = param_name + self.full_key = full_key + self.value_type = value_type + self.default_value = default_value + + def parse(self, variables_provider: VariablesProvider) -> Any: + """Parse the variables.""" + value = variables_provider.get(self.full_key, self.default_value) + if value: + return self._cast_to_type(value) + else: + return value + + def _cast_to_type(self, value: Any) -> Any: + if self.value_type == "str": + return str(value) + elif self.value_type == "int": + return int(value) + elif self.value_type == "float": + return float(value) + elif self.value_type == "bool": + if value.lower() in ["true", "1"]: + return True + elif value.lower() in ["false", "0"]: + return False + else: + return bool(value) + else: + return value + + def __repr__(self): + """Return the representation of the variables place holder.""" + return ( + f"" + ) + 
+ +class StorageVariablesProvider(VariablesProvider): + """The storage variables provider.""" + + def __init__( + self, + storage: Optional[StorageInterface] = None, + encryption: Optional[Encryption] = None, + system_app: Optional[SystemApp] = None, + key: Optional[str] = None, + ): + """Initialize the storage variables provider.""" + if storage is None: + storage = InMemoryStorage() + self.system_app = system_app + self.encryption = encryption or SimpleEncryption(key) + + self.storage = storage + super().__init__(system_app) + + def init_app(self, system_app: SystemApp): + """Initialize the storage variables provider.""" + self.system_app = system_app + + def get( + self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE + ) -> Any: + """Query variables from storage.""" + key = VariablesIdentifier.from_str_identifier(full_key) + variable: Optional[StorageVariables] = self.storage.load(key, StorageVariables) + if variable is None: + if default_value == _EMPTY_DEFAULT_VALUE: + raise ValueError(f"Variable {full_key} not found") + return default_value + variable.value = self.deserialize_value(variable.value) + if ( + variable.value is not None + and variable.category == "secret" + and variable.encryption_method + and variable.salt + ): + variable.value = self.encryption.decrypt(variable.value, variable.salt) + return variable.value + + def save(self, variables_item: StorageVariables) -> None: + """Save variables to storage.""" + if variables_item.category == "secret": + salt = base64.b64encode(os.urandom(16)).decode() + variables_item.value = self.encryption.encrypt( + str(variables_item.value), salt + ) + variables_item.encryption_method = self.encryption.name + variables_item.salt = salt + # Replace value to a json serializable object + variables_item.value = self.serialize_value(variables_item.value) + + self.storage.save_or_update(variables_item) + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Query variables from storage.""" + # Try to get builtin variables + is_builtin, builtin_variables = self._get_builtins_variables( + key, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + if is_builtin: + return builtin_variables + variables = self.storage.query( + QuerySpec( + conditions={ + "key": key, + "scope": scope, + "scope_key": scope_key, + "sys_code": sys_code, + "user_name": user_name, + "enabled": 1, + } + ), + StorageVariables, + ) + for variable in variables: + variable.value = self.deserialize_value(variable.value) + return variables + + async def async_get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Query variables from storage async.""" + # Try to get builtin variables + is_builtin, builtin_variables = await self._async_get_builtins_variables( + key, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + if is_builtin: + return builtin_variables + executor_factory: Optional[ + DefaultExecutorFactory + ] = DefaultExecutorFactory.get_instance(self.system_app, default_component=None) + if executor_factory: + return await blocking_func_to_async( + executor_factory.create(), + self.get_variables, + key, + scope, + scope_key, + sys_code, + user_name, + ) + else: + return await 
blocking_func_to_async_no_executor( + self.get_variables, key, scope, scope_key, sys_code, user_name + ) + + def _get_builtins_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> Tuple[bool, List[StorageVariables]]: + """Get builtin variables.""" + if self.system_app is None: + return False, [] + provider: BuiltinVariablesProvider = self.system_app.get_component( + key, + component_type=BuiltinVariablesProvider, + default_component=None, + ) + if not provider: + return False, [] + return True, provider.get_variables( + key, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + + async def _async_get_builtins_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> Tuple[bool, List[StorageVariables]]: + """Get builtin variables.""" + if self.system_app is None: + return False, [] + provider: BuiltinVariablesProvider = self.system_app.get_component( + key, + component_type=BuiltinVariablesProvider, + default_component=None, + ) + if not provider: + return False, [] + if not provider.support_async(): + return False, [] + return True, await provider.async_get_variables( + key, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + + @classmethod + def serialize_value(cls, value: Any) -> str: + """Serialize the value.""" + value_dict = {"value": value} + return json.dumps(value_dict, ensure_ascii=False) + + @classmethod + def deserialize_value(cls, value: str) -> Any: + """Deserialize the value.""" + value_dict = json.loads(value) + return value_dict["value"] + + +class BuiltinVariablesProvider(VariablesProvider, ABC): + """The builtin variables provider. + + You can implement this class to provide builtin variables. Such LLMs, agents, + datasource, knowledge base, etc. 
+ """ + + name = "dbgpt_variables_builtin" + + def __init__(self, system_app: Optional[SystemApp] = None): + """Initialize the builtin variables provider.""" + self.system_app = system_app + super().__init__(system_app) + + def init_app(self, system_app: SystemApp): + """Initialize the builtin variables provider.""" + self.system_app = system_app + + def get( + self, full_key: str, default_value: Optional[str] = _EMPTY_DEFAULT_VALUE + ) -> Any: + """Query variables from storage.""" + raise NotImplementedError("BuiltinVariablesProvider does not support get.") + + def save(self, variables_item: StorageVariables) -> None: + """Save variables to storage.""" + raise NotImplementedError("BuiltinVariablesProvider does not support save.") diff --git a/dbgpt/serve/core/__init__.py b/dbgpt/serve/core/__init__.py index 090288128..31edd5d6c 100644 --- a/dbgpt/serve/core/__init__.py +++ b/dbgpt/serve/core/__init__.py @@ -1,7 +1,11 @@ +from typing import Any + from dbgpt.serve.core.config import BaseServeConfig from dbgpt.serve.core.schemas import Result, add_exception_handler from dbgpt.serve.core.serve import BaseServe from dbgpt.serve.core.service import BaseService +from dbgpt.util.executor_utils import BlockingFunction, DefaultExecutorFactory +from dbgpt.util.executor_utils import blocking_func_to_async as _blocking_func_to_async __ALL__ = [ "Result", @@ -10,3 +14,11 @@ "BaseService", "BaseServe", ] + + +async def blocking_func_to_async( + system_app, func: BlockingFunction, *args, **kwargs +) -> Any: + """Run a potentially blocking function within an executor.""" + executor = DefaultExecutorFactory.get_instance(system_app).create() + return await _blocking_func_to_async(executor, func, *args, **kwargs) diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 99852271a..74da7dd72 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -1,17 +1,32 @@ from functools import cache -from typing import List, Optional, Union +from typing import List, Literal, Optional, Union from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer from dbgpt.component import SystemApp from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata -from dbgpt.serve.core import Result +from dbgpt.core.interface.variables import ( + BUILTIN_VARIABLES_CORE_FLOW_NODES, + BUILTIN_VARIABLES_CORE_FLOWS, + BUILTIN_VARIABLES_CORE_SECRETS, + BUILTIN_VARIABLES_CORE_VARIABLES, + BuiltinVariablesProvider, + StorageVariables, +) +from dbgpt.serve.core import Result, blocking_func_to_async from dbgpt.util import PaginationResult from ..config import APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..service.service import Service -from .schemas import RefreshNodeRequest, ServeRequest, ServerResponse +from ..service.variables_service import VariablesService +from .schemas import ( + RefreshNodeRequest, + ServeRequest, + ServerResponse, + VariablesRequest, + VariablesResponse, +) router = APIRouter() @@ -22,7 +37,12 @@ def get_service() -> Service: """Get the service instance""" - return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service) + return Service.get_instance(global_system_app) + + +def get_variable_service() -> VariablesService: + """Get the service instance""" + return VariablesService.get_instance(global_system_app) get_bearer_token = HTTPBearer(auto_error=False) @@ -231,16 +251,80 @@ async def refresh_nodes(refresh_request: RefreshNodeRequest): """ 
from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY - new_metadata = _OPERATOR_REGISTRY.refresh( - key=refresh_request.id, - is_operator=refresh_request.flow_type == "operator", - request=refresh_request.refresh, + # Make sure the variables provider is initialized + _ = get_variable_service().variables_provider + + new_metadata = await _OPERATOR_REGISTRY.refresh( + refresh_request.id, + refresh_request.flow_type == "operator", + refresh_request.refresh, + "http", + global_system_app, ) return Result.succ(new_metadata) +@router.post( + "/variables", + response_model=Result[VariablesResponse], + dependencies=[Depends(check_api_key)], +) +async def create_variables( + variables_request: VariablesRequest, +) -> Result[VariablesResponse]: + """Create a new Variables entity + + Args: + variables_request (VariablesRequest): The request + Returns: + VariablesResponse: The response + """ + res = await blocking_func_to_async( + global_system_app, get_variable_service().create, variables_request + ) + return Result.succ(res) + + +@router.put( + "/variables/{v_id}", + response_model=Result[VariablesResponse], + dependencies=[Depends(check_api_key)], +) +async def update_variables( + v_id: int, variables_request: VariablesRequest +) -> Result[VariablesResponse]: + """Update a Variables entity + + Args: + v_id (int): The variable id + variables_request (VariablesRequest): The request + Returns: + VariablesResponse: The response + """ + res = await blocking_func_to_async( + global_system_app, get_variable_service().update, v_id, variables_request + ) + return Result.succ(res) + + def init_endpoints(system_app: SystemApp) -> None: """Initialize the endpoints""" + from .variables_provider import ( + BuiltinAllSecretVariablesProvider, + BuiltinAllVariablesProvider, + BuiltinEmbeddingsVariablesProvider, + BuiltinFlowVariablesProvider, + BuiltinLLMVariablesProvider, + BuiltinNodeVariablesProvider, + ) + global global_system_app system_app.register(Service) + system_app.register(VariablesService) + system_app.register(BuiltinFlowVariablesProvider) + system_app.register(BuiltinNodeVariablesProvider) + system_app.register(BuiltinAllVariablesProvider) + system_app.register(BuiltinAllSecretVariablesProvider) + system_app.register(BuiltinLLMVariablesProvider) + system_app.register(BuiltinEmbeddingsVariablesProvider) global_system_app = system_app diff --git a/dbgpt/serve/flow/api/schemas.py b/dbgpt/serve/flow/api/schemas.py index 2daa8f581..e63d3e6ce 100644 --- a/dbgpt/serve/flow/api/schemas.py +++ b/dbgpt/serve/flow/api/schemas.py @@ -1,4 +1,4 @@ -from typing import List, Literal +from typing import Any, List, Literal, Optional from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from dbgpt.core.awel.flow.flow_factory import FlowPanel @@ -17,6 +17,69 @@ class ServerResponse(FlowPanel): model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}") +class VariablesRequest(BaseModel): + """Variable request model. + + For creating a new variable in the DB-GPT. 
+ """ + + key: str = Field( + ..., + description="The key of the variable to create", + examples=["dbgpt.model.openai.api_key"], + ) + name: str = Field( + ..., + description="The name of the variable to create", + examples=["my_first_openai_key"], + ) + label: str = Field( + ..., + description="The label of the variable to create", + examples=["My First OpenAI Key"], + ) + value: Any = Field( + ..., description="The value of the variable to create", examples=["1234567890"] + ) + value_type: Literal["str", "int", "float", "bool"] = Field( + "str", + description="The type of the value of the variable to create", + examples=["str", "int", "float", "bool"], + ) + category: Literal["common", "secret"] = Field( + ..., + description="The category of the variable to create", + examples=["common"], + ) + scope: str = Field( + ..., + description="The scope of the variable to create", + examples=["global"], + ) + scope_key: Optional[str] = Field( + ..., + description="The scope key of the variable to create", + examples=["dbgpt"], + ) + enabled: Optional[bool] = Field( + True, + description="Whether the variable is enabled", + examples=[True], + ) + user_name: Optional[str] = Field(None, description="User name") + sys_code: Optional[str] = Field(None, description="System code") + + +class VariablesResponse(VariablesRequest): + """Variable response model.""" + + id: int = Field( + ..., + description="The id of the variable", + examples=[1], + ) + + class RefreshNodeRequest(BaseModel): """Flow response model""" diff --git a/dbgpt/serve/flow/api/variables_provider.py b/dbgpt/serve/flow/api/variables_provider.py new file mode 100644 index 000000000..4728f80e6 --- /dev/null +++ b/dbgpt/serve/flow/api/variables_provider.py @@ -0,0 +1,260 @@ +from typing import List, Literal, Optional + +from dbgpt.core.interface.variables import ( + BUILTIN_VARIABLES_CORE_EMBEDDINGS, + BUILTIN_VARIABLES_CORE_FLOW_NODES, + BUILTIN_VARIABLES_CORE_FLOWS, + BUILTIN_VARIABLES_CORE_LLMS, + BUILTIN_VARIABLES_CORE_SECRETS, + BUILTIN_VARIABLES_CORE_VARIABLES, + BuiltinVariablesProvider, + StorageVariables, +) + +from ..service.service import Service +from .endpoints import get_service, get_variable_service +from .schemas import ServerResponse + + +class BuiltinFlowVariablesProvider(BuiltinVariablesProvider): + """Builtin flow variables provider. + + Provide all flows by variables "${dbgpt.core.flow.flows}" + """ + + name = BUILTIN_VARIABLES_CORE_FLOWS + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + service: Service = get_service() + page_result = service.get_list_by_page( + { + "user_name": user_name, + "sys_code": sys_code, + }, + 1, + 1000, + ) + flows: List[ServerResponse] = page_result.items + variables = [] + for flow in flows: + variables.append( + StorageVariables( + key=key, + name=flow.name, + label=flow.label, + value=flow.uid, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + ) + return variables + + +class BuiltinNodeVariablesProvider(BuiltinVariablesProvider): + """Builtin node variables provider. 
+ + Provide all nodes by variables "${dbgpt.core.flow.nodes}" + """ + + name = BUILTIN_VARIABLES_CORE_FLOW_NODES + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY + + metadata_list = _OPERATOR_REGISTRY.metadata_list() + variables = [] + for metadata in metadata_list: + variables.append( + StorageVariables( + key=key, + name=metadata["name"], + label=metadata["label"], + value=metadata["id"], + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + ) + return variables + + +class BuiltinAllVariablesProvider(BuiltinVariablesProvider): + """Builtin all variables provider. + + Provide all variables by variables "${dbgpt.core.variables}" + """ + + name = BUILTIN_VARIABLES_CORE_VARIABLES + + def _get_variables_from_db( + self, + key: str, + scope: str, + scope_key: Optional[str], + sys_code: Optional[str], + user_name: Optional[str], + category: Literal["common", "secret"] = "common", + ) -> List[StorageVariables]: + storage_variables = get_variable_service().list_all_variables(category) + variables = [] + for var in storage_variables: + variables.append( + StorageVariables( + key=key, + name=var.name, + label=var.label, + value=var.value, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + ) + return variables + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables. + + TODO: Return all builtin variables + """ + return self._get_variables_from_db(key, scope, scope_key, sys_code, user_name) + + +class BuiltinAllSecretVariablesProvider(BuiltinAllVariablesProvider): + """Builtin all secret variables provider. + + Provide all secret variables by variables "${dbgpt.core.secrets}" + """ + + name = BUILTIN_VARIABLES_CORE_SECRETS + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + return self._get_variables_from_db( + key, scope, scope_key, sys_code, user_name, "secret" + ) + + +class BuiltinLLMVariablesProvider(BuiltinVariablesProvider): + """Builtin LLM variables provider. 
+ + Provide all LLM variables by variables "${dbgpt.core.llmv}" + """ + + name = BUILTIN_VARIABLES_CORE_LLMS + + def support_async(self) -> bool: + """Whether the dynamic options support async.""" + return True + + async def _get_models( + self, + key: str, + scope: str, + scope_key: Optional[str], + sys_code: Optional[str], + user_name: Optional[str], + expect_worker_type: str = "llm", + ) -> List[StorageVariables]: + from dbgpt.model.cluster.controller.controller import BaseModelController + + controller = BaseModelController.get_instance(self.system_app) + models = await controller.get_all_instances(healthy_only=True) + model_dict = {} + for model in models: + worker_name, worker_type = model.model_name.split("@") + if expect_worker_type == worker_type: + model_dict[worker_name] = model + variables = [] + for worker_name, model in model_dict.items(): + variables.append( + StorageVariables( + key=key, + name=worker_name, + label=worker_name, + value=worker_name, + scope=scope, + scope_key=scope_key, + sys_code=sys_code, + user_name=user_name, + ) + ) + return variables + + async def async_get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + return await self._get_models(key, scope, scope_key, sys_code, user_name) + + def get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + raise NotImplementedError( + "Not implemented get variables sync, please use async_get_variables" + ) + + +class BuiltinEmbeddingsVariablesProvider(BuiltinLLMVariablesProvider): + """Builtin embeddings variables provider. + + Provide all embeddings variables by variables "${dbgpt.core.embeddings}" + """ + + name = BUILTIN_VARIABLES_CORE_EMBEDDINGS + + async def async_get_variables( + self, + key: str, + scope: str = "global", + scope_key: Optional[str] = None, + sys_code: Optional[str] = None, + user_name: Optional[str] = None, + ) -> List[StorageVariables]: + """Get the builtin variables.""" + return await self._get_models( + key, scope, scope_key, sys_code, user_name, "text2vec" + ) diff --git a/dbgpt/serve/flow/config.py b/dbgpt/serve/flow/config.py index 97eea7478..0cc35667d 100644 --- a/dbgpt/serve/flow/config.py +++ b/dbgpt/serve/flow/config.py @@ -8,8 +8,10 @@ SERVE_APP_NAME_HUMP = "dbgpt_serve_Flow" SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.flow." 
SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service" +SERVE_VARIABLES_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_variables_service" # Database table name SERVER_APP_TABLE_NAME = "dbgpt_serve_flow" +SERVER_APP_VARIABLES_TABLE_NAME = "dbgpt_serve_variables" @dataclass @@ -23,3 +25,6 @@ class ServeConfig(BaseServeConfig): load_dbgpts_interval: int = field( default=5, metadata={"help": "Interval to load dbgpts from installed packages"} ) + encrypt_key: Optional[str] = field( + default=None, metadata={"help": "The key to encrypt the data"} + ) diff --git a/dbgpt/serve/flow/models/models.py b/dbgpt/serve/flow/models/models.py index ea4c7f3ea..c4166147d 100644 --- a/dbgpt/serve/flow/models/models.py +++ b/dbgpt/serve/flow/models/models.py @@ -10,11 +10,17 @@ from dbgpt._private.pydantic import model_to_dict from dbgpt.core.awel.flow.flow_factory import State +from dbgpt.core.interface.variables import StorageVariablesProvider from dbgpt.storage.metadata import BaseDao, Model from dbgpt.storage.metadata._base_dao import QUERY_SPEC -from ..api.schemas import ServeRequest, ServerResponse -from ..config import SERVER_APP_TABLE_NAME, ServeConfig +from ..api.schemas import ( + ServeRequest, + ServerResponse, + VariablesRequest, + VariablesResponse, +) +from ..config import SERVER_APP_TABLE_NAME, SERVER_APP_VARIABLES_TABLE_NAME, ServeConfig class ServeEntity(Model): @@ -74,6 +80,56 @@ def to_bool_editable(cls, editable: int) -> bool: return editable is None or editable == 0 +class VariablesEntity(Model): + __tablename__ = SERVER_APP_VARIABLES_TABLE_NAME + + id = Column(Integer, primary_key=True, comment="Auto increment id") + key = Column(String(128), index=True, nullable=False, comment="Variable key") + name = Column(String(128), index=True, nullable=True, comment="Variable name") + label = Column(String(128), nullable=True, comment="Variable label") + value = Column(Text, nullable=True, comment="Variable value, JSON format") + value_type = Column( + String(32), + nullable=True, + comment="Variable value type(string, int, float, bool)", + ) + category = Column( + String(32), + default="common", + nullable=True, + comment="Variable category(common or secret)", + ) + encryption_method = Column( + String(32), + nullable=True, + comment="Variable encryption method(fernet, simple, rsa, aes)", + ) + salt = Column(String(128), nullable=True, comment="Variable salt") + scope = Column( + String(32), + default="global", + nullable=True, + comment="Variable scope(global,flow,app,agent,datasource,flow:uid," + "flow:dag_name,agent:agent_name) etc", + ) + scope_key = Column( + String(256), + nullable=True, + comment="Variable scope key, default is empty, for scope is 'flow:uid', " + "the scope_key is uid of flow", + ) + enabled = Column( + Integer, + default=1, + nullable=True, + comment="Variable enabled, 0: disabled, 1: enabled", + ) + user_name = Column(String(128), index=True, nullable=True, comment="User name") + sys_code = Column(String(128), index=True, nullable=True, comment="System code") + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") + + class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): """The DAO class for Flow""" @@ -222,3 +278,108 @@ def update( session.merge(entry) session.commit() return self.get_one(query_request) + + +class VariablesDao(BaseDao[VariablesEntity, VariablesRequest, VariablesResponse]): + """The DAO class for Variables""" + + def 
__init__(self, serve_config: ServeConfig): + super().__init__() + self._serve_config = serve_config + + def from_request( + self, request: Union[VariablesRequest, Dict[str, Any]] + ) -> VariablesEntity: + """Convert the request to an entity + + Args: + request (Union[VariablesRequest, Dict[str, Any]]): The request + + Returns: + T: The entity + """ + request_dict = ( + model_to_dict(request) if isinstance(request, VariablesRequest) else request + ) + value = StorageVariablesProvider.serialize_value(request_dict.get("value")) + enabled = 1 if request_dict.get("enabled", True) else 0 + new_dict = { + "key": request_dict.get("key"), + "name": request_dict.get("name"), + "label": request_dict.get("label"), + "value": value, + "value_type": request_dict.get("value_type"), + "category": request_dict.get("category"), + "encryption_method": request_dict.get("encryption_method"), + "salt": request_dict.get("salt"), + "scope": request_dict.get("scope"), + "scope_key": request_dict.get("scope_key"), + "enabled": enabled, + "user_name": request_dict.get("user_name"), + "sys_code": request_dict.get("sys_code"), + } + entity = VariablesEntity(**new_dict) + return entity + + def to_request(self, entity: VariablesEntity) -> VariablesRequest: + """Convert the entity to a request + + Args: + entity (T): The entity + + Returns: + REQ: The request + """ + value = StorageVariablesProvider.deserialize_value(entity.value) + if entity.category == "secret": + value = "******" + enabled = entity.enabled == 1 + return VariablesRequest( + key=entity.key, + name=entity.name, + label=entity.label, + value=value, + value_type=entity.value_type, + category=entity.category, + encryption_method=entity.encryption_method, + salt=entity.salt, + scope=entity.scope, + scope_key=entity.scope_key, + enabled=enabled, + user_name=entity.user_name, + sys_code=entity.sys_code, + ) + + def to_response(self, entity: VariablesEntity) -> VariablesResponse: + """Convert the entity to a response + + Args: + entity (T): The entity + + Returns: + RES: The response + """ + value = StorageVariablesProvider.deserialize_value(entity.value) + if entity.category == "secret": + value = "******" + gmt_created_str = entity.gmt_created.strftime("%Y-%m-%d %H:%M:%S") + gmt_modified_str = entity.gmt_modified.strftime("%Y-%m-%d %H:%M:%S") + enabled = entity.enabled == 1 + return VariablesResponse( + id=entity.id, + key=entity.key, + name=entity.name, + label=entity.label, + value=value, + value_type=entity.value_type, + category=entity.category, + encryption_method=entity.encryption_method, + salt=entity.salt, + scope=entity.scope, + scope_key=entity.scope_key, + enabled=enabled, + user_name=entity.user_name, + sys_code=entity.sys_code, + gmt_created=gmt_created_str, + gmt_modified=gmt_modified_str, + ) diff --git a/dbgpt/serve/flow/models/variables_adapter.py b/dbgpt/serve/flow/models/variables_adapter.py new file mode 100644 index 000000000..d8a1ef1e0 --- /dev/null +++ b/dbgpt/serve/flow/models/variables_adapter.py @@ -0,0 +1,69 @@ +from typing import Type + +from sqlalchemy.orm import Session + +from dbgpt.core.interface.storage import StorageItemAdapter +from dbgpt.core.interface.variables import StorageVariables, VariablesIdentifier + +from .models import VariablesEntity + + +class VariablesAdapter(StorageItemAdapter[StorageVariables, VariablesEntity]): + """Variables adapter. + + Convert between storage format and database model. 
+ """ + + def to_storage_format(self, item: StorageVariables) -> VariablesEntity: + """Convert to storage format.""" + return VariablesEntity( + key=item.key, + name=item.name, + label=item.label, + value=item.value, + value_type=item.value_type, + category=item.category, + encryption_method=item.encryption_method, + salt=item.salt, + scope=item.scope, + scope_key=item.scope_key, + sys_code=item.sys_code, + user_name=item.user_name, + ) + + def from_storage_format(self, model: VariablesEntity) -> StorageVariables: + """Convert from storage format.""" + return StorageVariables( + key=model.key, + name=model.name, + label=model.label, + value=model.value, + value_type=model.value_type, + category=model.category, + encryption_method=model.encryption_method, + salt=model.salt, + scope=model.scope, + scope_key=model.scope_key, + sys_code=model.sys_code, + user_name=model.user_name, + ) + + def get_query_for_identifier( + self, + storage_format: Type[VariablesEntity], + resource_id: VariablesIdentifier, + **kwargs, + ): + """Get query for identifier.""" + session: Session = kwargs.get("session") + if session is None: + raise Exception("session is None") + query_obj = session.query(VariablesEntity) + for key, value in resource_id.to_dict().items(): + if value is None: + continue + query_obj = query_obj.filter(getattr(VariablesEntity, key) == value) + + # enabled must be True + query_obj = query_obj.filter(VariablesEntity.enabled == 1) + return query_obj diff --git a/dbgpt/serve/flow/serve.py b/dbgpt/serve/flow/serve.py index 126841e57..a27e3d28f 100644 --- a/dbgpt/serve/flow/serve.py +++ b/dbgpt/serve/flow/serve.py @@ -4,6 +4,7 @@ from sqlalchemy import URL from dbgpt.component import SystemApp +from dbgpt.core.interface.variables import VariablesProvider from dbgpt.serve.core import BaseServe from dbgpt.storage.metadata import DatabaseManager @@ -40,6 +41,8 @@ def __init__( system_app, api_prefix, api_tags, db_url_or_db, try_create_tables ) self._db_manager: Optional[DatabaseManager] = None + self._variables_provider: Optional[VariablesProvider] = None + self._serve_config: Optional[ServeConfig] = None def init_app(self, system_app: SystemApp): if self._app_has_initiated: @@ -62,5 +65,37 @@ def on_init(self): def before_start(self): """Called before the start of the application.""" - # TODO: Your code here + from dbgpt.core.interface.variables import ( + FernetEncryption, + StorageVariablesProvider, + ) + from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage + from dbgpt.util.serialization.json_serialization import JsonSerializer + + from .models.models import ServeEntity, VariablesEntity + from .models.variables_adapter import VariablesAdapter + self._db_manager = self.create_or_get_db_manager() + self._serve_config = ServeConfig.from_app_config( + self._system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + + self._db_manager = self.create_or_get_db_manager() + storage_adapter = VariablesAdapter() + serializer = JsonSerializer() + storage = SQLAlchemyStorage( + self._db_manager, + VariablesEntity, + storage_adapter, + serializer, + ) + self._variables_provider = StorageVariablesProvider( + storage=storage, + encryption=FernetEncryption(self._serve_config.encrypt_key), + system_app=self._system_app, + ) + + @property + def variables_provider(self): + """Get the variables provider of the serve app with db storage""" + return self._variables_provider diff --git a/dbgpt/serve/flow/service/variables_service.py b/dbgpt/serve/flow/service/variables_service.py new file mode 100644 index 
000000000..4b79d27db --- /dev/null +++ b/dbgpt/serve/flow/service/variables_service.py @@ -0,0 +1,148 @@ +from typing import List, Optional + +from dbgpt import SystemApp +from dbgpt.core.interface.variables import StorageVariables, VariablesProvider +from dbgpt.serve.core import BaseService + +from ..api.schemas import VariablesRequest, VariablesResponse +from ..config import ( + SERVE_CONFIG_KEY_PREFIX, + SERVE_VARIABLES_SERVICE_COMPONENT_NAME, + ServeConfig, +) +from ..models.models import VariablesDao, VariablesEntity + + +class VariablesService( + BaseService[VariablesEntity, VariablesRequest, VariablesResponse] +): + """Variables service""" + + name = SERVE_VARIABLES_SERVICE_COMPONENT_NAME + + def __init__(self, system_app: SystemApp, dao: Optional[VariablesDao] = None): + self._system_app = None + self._serve_config: ServeConfig = None + self._dao: VariablesDao = dao + + super().__init__(system_app) + + def init_app(self, system_app: SystemApp) -> None: + """Initialize the service + + Args: + system_app (SystemApp): The system app + """ + super().init_app(system_app) + + self._serve_config = ServeConfig.from_app_config( + system_app.config, SERVE_CONFIG_KEY_PREFIX + ) + self._dao = self._dao or VariablesDao(self._serve_config) + self._system_app = system_app + + @property + def dao(self) -> VariablesDao: + """Returns the internal DAO.""" + return self._dao + + @property + def variables_provider(self) -> VariablesProvider: + """Returns the internal VariablesProvider. + + Returns: + VariablesProvider: The internal VariablesProvider + """ + variables_provider = VariablesProvider.get_instance( + self._system_app, default_component=None + ) + if variables_provider: + return variables_provider + else: + from ..serve import Serve + + variables_provider = Serve.get_instance(self._system_app).variables_provider + self._system_app.register_instance(variables_provider) + return variables_provider + + @property + def config(self) -> ServeConfig: + """Returns the internal ServeConfig.""" + return self._serve_config + + def create(self, request: VariablesRequest) -> VariablesResponse: + """Create a new entity + + Args: + request (VariablesRequest): The request + + Returns: + VariablesResponse: The response + """ + variables = StorageVariables( + key=request.key, + name=request.name, + label=request.label, + value=request.value, + value_type=request.value_type, + category=request.category, + scope=request.scope, + scope_key=request.scope_key, + user_name=request.user_name, + sys_code=request.sys_code, + ) + self.variables_provider.save(variables) + query = { + "key": request.key, + "name": request.name, + "scope": request.scope, + "scope_key": request.scope_key, + "sys_code": request.sys_code, + "user_name": request.user_name, + "enabled": request.enabled, + } + return self.dao.get_one(query) + + def update(self, _: int, request: VariablesRequest) -> VariablesResponse: + """Update variables. 
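# A usage sketch for create()/update() in this service, assuming `service` is an
# initialized VariablesService; the field values are placeholders. update() looks
# the variable up by its identifier (key/name/scope/...), not by the numeric id,
# and raises ValueError if it has not been created yet.
from dbgpt.serve.flow.api.schemas import VariablesRequest

request = VariablesRequest(
    key="dbgpt.model.openai.api_key",
    name="my_key",
    label="My OpenAI API Key",
    value="sk-xxx",
    value_type="str",
    category="secret",
    scope="global",
    scope_key=None,
)
created = service.create(request)  # saved via variables_provider, read back through the DAO
request.value = "sk-yyy"
updated = service.update(0, request)  # the first argument is ignored, see the signature above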
+ + Args: + request (VariablesRequest): The request + + Returns: + VariablesResponse: The response + """ + variables = StorageVariables( + key=request.key, + name=request.name, + label=request.label, + value=request.value, + value_type=request.value_type, + category=request.category, + scope=request.scope, + scope_key=request.scope_key, + user_name=request.user_name, + sys_code=request.sys_code, + ) + exist_value = self.variables_provider.get( + variables.identifier.str_identifier, None + ) + if exist_value is None: + raise ValueError( + f"Variable {variables.identifier.str_identifier} not found" + ) + self.variables_provider.save(variables) + query = { + "key": request.key, + "name": request.name, + "scope": request.scope, + "scope_key": request.scope_key, + "sys_code": request.sys_code, + "user_name": request.user_name, + "enabled": request.enabled, + } + return self.dao.get_one(query) + + def list_all_variables(self, category: str = "common") -> List[VariablesResponse]: + """List all variables.""" + return self.dao.get_list({"enabled": True, "category": category}) diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py index fc8d9a5c4..7a38f8d4b 100644 --- a/examples/awel/awel_flow_ui_components.py +++ b/examples/awel/awel_flow_ui_components.py @@ -1,5 +1,6 @@ """Some UI components for the AWEL flow.""" +import json import logging from typing import List, Optional @@ -10,9 +11,18 @@ OperatorCategory, OptionValue, Parameter, + VariablesDynamicOptions, ViewMetadata, ui, ) +from dbgpt.core.interface.variables import ( + BUILTIN_VARIABLES_CORE_EMBEDDINGS, + BUILTIN_VARIABLES_CORE_FLOW_NODES, + BUILTIN_VARIABLES_CORE_FLOWS, + BUILTIN_VARIABLES_CORE_LLMS, + BUILTIN_VARIABLES_CORE_SECRETS, + BUILTIN_VARIABLES_CORE_VARIABLES, +) logger = logging.getLogger(__name__) @@ -717,3 +727,157 @@ async def map(self, user_name: str) -> str: user_name, self.recent_time, ) + + +class ExampleFlowVariablesOperator(MapOperator[str, str]): + """An example flow operator that includes a variables option.""" + + metadata = ViewMetadata( + label="Example Variables Operator", + name="example_variables_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a variables option.", + parameters=[ + Parameter.build_from( + "OpenAI API Key", + "openai_api_key", + type=str, + placeholder="Please select the OpenAI API key", + description="The OpenAI API key to use.", + options=VariablesDynamicOptions(), + ui=ui.UIPasswordInput( + key="dbgpt.model.openai.api_key", + ), + ), + Parameter.build_from( + "Model", + "model", + type=str, + placeholder="Please select the model", + description="The model to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key="dbgpt.model.openai.model", + ), + ), + Parameter.build_from( + "Builtin Flows", + "builtin_flow", + type=str, + placeholder="Please select the builtin flows", + description="The builtin flows to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key=BUILTIN_VARIABLES_CORE_FLOWS, + ), + ), + Parameter.build_from( + "Builtin Flow Nodes", + "builtin_flow_node", + type=str, + placeholder="Please select the builtin flow nodes", + description="The builtin flow nodes to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key=BUILTIN_VARIABLES_CORE_FLOW_NODES, + ), + ), + Parameter.build_from( + "Builtin Variables", + "builtin_variable", + type=str, + placeholder="Please select the builtin variables", + description="The builtin variables 
to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key=BUILTIN_VARIABLES_CORE_VARIABLES, + ), + ), + Parameter.build_from( + "Builtin Secrets", + "builtin_secret", + type=str, + placeholder="Please select the builtin secrets", + description="The builtin secrets to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key=BUILTIN_VARIABLES_CORE_SECRETS, + ), + ), + Parameter.build_from( + "Builtin LLMs", + "builtin_llm", + type=str, + placeholder="Please select the builtin LLMs", + description="The builtin LLMs to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key=BUILTIN_VARIABLES_CORE_LLMS, + ), + ), + Parameter.build_from( + "Builtin Embeddings", + "builtin_embedding", + type=str, + placeholder="Please select the builtin embeddings", + description="The builtin embeddings to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key=BUILTIN_VARIABLES_CORE_EMBEDDINGS, + ), + ), + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ), + ], + outputs=[ + IOField.build_from( + "Model info", + "model", + str, + description="The model info.", + ), + ], + ) + + def __init__( + self, + openai_api_key: str, + model: str, + builtin_flow: str, + builtin_flow_node: str, + builtin_variable: str, + builtin_secret: str, + builtin_llm: str, + builtin_embedding: str, + **kwargs, + ): + super().__init__(**kwargs) + self.openai_api_key = openai_api_key + self.model = model + self.builtin_flow = builtin_flow + self.builtin_flow_node = builtin_flow_node + self.builtin_variable = builtin_variable + self.builtin_secret = builtin_secret + self.builtin_llm = builtin_llm + self.builtin_embedding = builtin_embedding + + async def map(self, user_name: str) -> str: + """Map the user name to the model.""" + dict_dict = { + "openai_api_key": self.openai_api_key, + "model": self.model, + "builtin_flow": self.builtin_flow, + "builtin_flow_node": self.builtin_flow_node, + "builtin_variable": self.builtin_variable, + "builtin_secret": self.builtin_secret, + "builtin_llm": self.builtin_llm, + "builtin_embedding": self.builtin_embedding, + } + json_data = json.dumps(dict_dict, ensure_ascii=False) + return "Your name is %s, and your model info is %s." % (user_name, json_data) diff --git a/setup.py b/setup.py index cbe5592ce..a968892df 100644 --- a/setup.py +++ b/setup.py @@ -498,6 +498,8 @@ def core_requires(): "GitPython", # For AWEL dag visualization, graphviz is a small package, also we can move it to default. 
"graphviz", + # For security + "cryptography", ] From 94ef5da873333595669b554447804c6734f9d776 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sun, 11 Aug 2024 16:09:26 +0800 Subject: [PATCH 36/89] feat(core): Support complex variables parsing --- dbgpt/core/awel/flow/base.py | 27 +- .../awel/flow/tests/test_flow_variables.py | 223 ++++++++++ dbgpt/core/awel/flow/ui.py | 2 +- dbgpt/core/awel/operators/base.py | 1 + dbgpt/core/awel/tests/conftest.py | 43 +- dbgpt/core/awel/tests/test_dag_variables.py | 8 +- dbgpt/core/interface/tests/test_variables.py | 217 +++++++++- dbgpt/core/interface/variables.py | 399 ++++++++++++++---- 8 files changed, 827 insertions(+), 93 deletions(-) create mode 100644 dbgpt/core/awel/flow/tests/test_flow_variables.py diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index 57081420e..61e0dfa75 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -380,27 +380,40 @@ def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values @classmethod - def _covert_to_real_type(cls, type_cls: str, v: Any): + def _covert_to_real_type(cls, type_cls: str, v: Any) -> Any: if type_cls and v is not None: + typed_value: Any = v try: # Try to convert the value to the type. if type_cls == "builtins.str": - return str(v) + typed_value = str(v) elif type_cls == "builtins.int": - return int(v) + typed_value = int(v) elif type_cls == "builtins.float": - return float(v) + typed_value = float(v) elif type_cls == "builtins.bool": if str(v).lower() in ["false", "0", "", "no", "off"]: return False - return bool(v) + typed_value = bool(v) + return typed_value except ValueError: raise ValidationError(f"Value '{v}' is not valid for type {type_cls}") return v def get_typed_value(self) -> Any: - """Get the typed value.""" - return self._covert_to_real_type(self.type_cls, self.value) + """Get the typed value. + + Returns: + Any: The typed value. VariablesPlaceHolder if the value is a variable + string. Otherwise, the real type value. 
+ """ + from ...interface.variables import VariablesPlaceHolder, is_variable_string + + is_variables = is_variable_string(self.value) if self.value else False + if is_variables and self.value is not None and isinstance(self.value, str): + return VariablesPlaceHolder(self.name, self.value) + else: + return self._covert_to_real_type(self.type_cls, self.value) def get_typed_default(self) -> Any: """Get the typed default.""" diff --git a/dbgpt/core/awel/flow/tests/test_flow_variables.py b/dbgpt/core/awel/flow/tests/test_flow_variables.py new file mode 100644 index 000000000..eaa548b09 --- /dev/null +++ b/dbgpt/core/awel/flow/tests/test_flow_variables.py @@ -0,0 +1,223 @@ +import json +from typing import cast + +import pytest + +from dbgpt.core.awel import BaseOperator, DAGVar, MapOperator +from dbgpt.core.awel.flow import ( + IOField, + OperatorCategory, + Parameter, + VariablesDynamicOptions, + ViewMetadata, + ui, +) +from dbgpt.core.awel.flow.flow_factory import FlowData, FlowFactory, FlowPanel + +from ...tests.conftest import variables_provider + + +class MyVariablesOperator(MapOperator[str, str]): + metadata = ViewMetadata( + label="My Test Variables Operator", + name="my_test_variables_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a variables option.", + parameters=[ + Parameter.build_from( + "OpenAI API Key", + "openai_api_key", + type=str, + placeholder="Please select the OpenAI API key", + description="The OpenAI API key to use.", + options=VariablesDynamicOptions(), + ui=ui.UIPasswordInput( + key="dbgpt.model.openai.api_key", + ), + ), + Parameter.build_from( + "Model", + "model", + type=str, + placeholder="Please select the model", + description="The model to use.", + options=VariablesDynamicOptions(), + ui=ui.UIVariablesInput( + key="dbgpt.model.openai.model", + ), + ), + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ), + ], + outputs=[ + IOField.build_from( + "Model info", + "model", + str, + description="The model info.", + ), + ], + ) + + def __init__(self, openai_api_key: str, model: str, **kwargs): + super().__init__(**kwargs) + self._openai_api_key = openai_api_key + self._model = model + + async def map(self, user_name: str) -> str: + dict_dict = { + "openai_api_key": self._openai_api_key, + "model": self._model, + } + json_data = json.dumps(dict_dict, ensure_ascii=False) + return "Your name is %s, and your model info is %s." 
% (user_name, json_data) + + +class EndOperator(MapOperator[str, str]): + metadata = ViewMetadata( + label="End Operator", + name="end_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that ends the flow.", + parameters=[], + inputs=[ + IOField.build_from( + "Input", + "input", + str, + description="The input to the end operator.", + ), + ], + outputs=[ + IOField.build_from( + "Output", + "output", + str, + description="The output of the end operator.", + ), + ], + ) + + async def map(self, input: str) -> str: + return f"End operator received input: {input}" + + +@pytest.fixture +def json_flow(): + operators = [MyVariablesOperator, EndOperator] + metadata_list = [operator.metadata.to_dict() for operator in operators] + node_names = {} + name_to_parameters_dict = { + "my_test_variables_operator": { + "openai_api_key": "${dbgpt.model.openai.api_key:my_key@global}", + "model": "${dbgpt.model.openai.model:default_model@global}", + } + } + name_to_metadata_dict = {metadata["name"]: metadata for metadata in metadata_list} + ui_nodes = [] + for metadata in metadata_list: + type_name = metadata["type_name"] + name = metadata["name"] + id = metadata["id"] + if type_name in node_names: + raise ValueError(f"Duplicate node type name: {type_name}") + # Replace id to flow data id. + metadata["id"] = f"{id}_0" + parameters = metadata["parameters"] + parameters_dict = name_to_parameters_dict.get(name, {}) + for parameter in parameters: + parameter_name = parameter["name"] + if parameter_name in parameters_dict: + parameter["value"] = parameters_dict[parameter_name] + ui_nodes.append( + { + "width": 288, + "height": 352, + "id": metadata["id"], + "position": { + "x": -149.98120112708142, + "y": 666.9468497341901, + "zoom": 0.0, + }, + "type": "customNode", + "position_absolute": { + "x": -149.98120112708142, + "y": 666.9468497341901, + "zoom": 0.0, + }, + "data": metadata, + } + ) + + ui_edges = [] + source_id = name_to_metadata_dict["my_test_variables_operator"]["id"] + target_id = name_to_metadata_dict["end_operator"]["id"] + ui_edges.append( + { + "source": source_id, + "target": target_id, + "source_order": 0, + "target_order": 0, + "id": f"{source_id}|{target_id}", + "source_handle": f"{source_id}|outputs|0", + "target_handle": f"{target_id}|inputs|0", + "type": "buttonedge", + } + ) + return { + "nodes": ui_nodes, + "edges": ui_edges, + "viewport": { + "x": 509.2191773722104, + "y": -66.11286175905718, + "zoom": 1.252741002590748, + }, + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "variables_provider", + [ + ( + { + "vars": { + "openai_api_key": { + "key": "${dbgpt.model.openai.api_key:my_key@global}", + "value": "my_openai_api_key", + "value_type": "str", + "category": "secret", + }, + "model": { + "key": "${dbgpt.model.openai.model:default_model@global}", + "value": "GPT-4o", + "value_type": "str", + }, + } + } + ), + ], + indirect=["variables_provider"], +) +async def test_build_flow(json_flow, variables_provider): + DAGVar.set_variables_provider(variables_provider) + flow_data = FlowData(**json_flow) + flow_panel = FlowPanel( + label="My Test Flow", name="my_test_flow", flow_data=flow_data, state="deployed" + ) + factory = FlowFactory() + dag = factory.build(flow_panel) + + leaf_node: BaseOperator = cast(BaseOperator, dag.leaf_nodes[0]) + result = await leaf_node.call("Alice") + assert ( + result + == "End operator received input: Your name is Alice, and your model info is " + '{"openai_api_key": "my_openai_api_key", "model": "GPT-4o"}.' 
+ ) diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index 66b413a9f..875547e9a 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py @@ -19,7 +19,7 @@ "time_picker", "tree_select", "upload", - "variable", + "variables", "password", "code_editor", ] diff --git a/dbgpt/core/awel/operators/base.py b/dbgpt/core/awel/operators/base.py index 77f056042..0933a4547 100644 --- a/dbgpt/core/awel/operators/base.py +++ b/dbgpt/core/awel/operators/base.py @@ -380,6 +380,7 @@ async def _resolve_variables(self, _: DAGContext): if not self._variables_provider: return + # TODO: Resolve variables parallel for attr, value in self.__dict__.items(): if isinstance(value, VariablesPlaceHolder): resolved_value = await self.blocking_func_to_async( diff --git a/dbgpt/core/awel/tests/conftest.py b/dbgpt/core/awel/tests/conftest.py index d68ddcfc8..607783028 100644 --- a/dbgpt/core/awel/tests/conftest.py +++ b/dbgpt/core/awel/tests/conftest.py @@ -1,17 +1,15 @@ -from contextlib import asynccontextmanager, contextmanager +from contextlib import asynccontextmanager from typing import AsyncIterator, List import pytest import pytest_asyncio -from .. import ( - DAGContext, - DefaultWorkflowRunner, - InputOperator, - SimpleInputSource, - TaskState, - WorkflowRunner, +from ...interface.variables import ( + StorageVariables, + StorageVariablesProvider, + VariablesIdentifier, ) +from .. import DefaultWorkflowRunner, InputOperator, SimpleInputSource from ..task.task_impl import _is_async_iterator @@ -102,3 +100,32 @@ async def stream_input_nodes(request): param["is_stream"] = True async with _create_input_node(**param) as input_nodes: yield input_nodes + + +@asynccontextmanager +async def _create_variables(**kwargs): + vp = StorageVariablesProvider() + vars = kwargs.get("vars") + if vars and isinstance(vars, dict): + for param_key, param_var in vars.items(): + key = param_var.get("key") + value = param_var.get("value") + value_type = param_var.get("value_type") + category = param_var.get("category", "common") + id = VariablesIdentifier.from_str_identifier(key) + vp.save( + StorageVariables.from_identifier( + id, value, value_type, label="", category=category + ) + ) + else: + raise ValueError("vars is required.") + + yield vp + + +@pytest_asyncio.fixture +async def variables_provider(request): + param = getattr(request, "param", {}) + async with _create_variables(**param) as vp: + yield vp diff --git a/dbgpt/core/awel/tests/test_dag_variables.py b/dbgpt/core/awel/tests/test_dag_variables.py index 88c9b6660..8bdb29143 100644 --- a/dbgpt/core/awel/tests/test_dag_variables.py +++ b/dbgpt/core/awel/tests/test_dag_variables.py @@ -54,7 +54,7 @@ async def _create_variables(**kwargs): id, value, value_type, label="", category=category ) ) - variables[param_key] = VariablesPlaceHolder(param_key, key, value_type) + variables[param_key] = VariablesPlaceHolder(param_key, key) else: raise ValueError("vars is required.") @@ -85,17 +85,17 @@ async def test_default_dag(default_dag: DAG): { "vars": { "int_var": { - "key": "int_key@my_int_var@global", + "key": "${int_key:my_int_var@global}", "value": 0, "value_type": "int", }, "str_var": { - "key": "str_key@my_str_var@global", + "key": "${str_key:my_str_var@global}", "value": "1", "value_type": "str", }, "secret": { - "key": "secret_key@my_secret_var@global", + "key": "${secret_key:my_secret_var@global}", "value": "2131sdsdf", "value_type": "str", "category": "secret", diff --git a/dbgpt/core/interface/tests/test_variables.py 
b/dbgpt/core/interface/tests/test_variables.py index 3b7ab8157..313657b4e 100644 --- a/dbgpt/core/interface/tests/test_variables.py +++ b/dbgpt/core/interface/tests/test_variables.py @@ -1,5 +1,6 @@ import base64 import os +from itertools import product from cryptography.fernet import Fernet @@ -10,6 +11,8 @@ StorageVariables, StorageVariablesProvider, VariablesIdentifier, + build_variable_string, + parse_variable, ) @@ -46,7 +49,7 @@ def test_storage_variables_provider(): encryption = SimpleEncryption() provider = StorageVariablesProvider(storage, encryption) - full_key = "key@name@global" + full_key = "${key:name@global}" value = "secret_value" value_type = "str" label = "test_label" @@ -63,7 +66,7 @@ def test_storage_variables_provider(): def test_variables_identifier(): - full_key = "key@name@global@scope_key@sys_code@user_name" + full_key = "${key:name@global:scope_key#sys_code%user_name}" identifier = VariablesIdentifier.from_str_identifier(full_key) assert identifier.key == "key" @@ -112,3 +115,213 @@ def test_storage_variables(): assert dict_representation["value_type"] == value_type assert dict_representation["category"] == category assert dict_representation["scope"] == scope + + +def generate_test_cases(enable_escape=False): + # Define possible values for each field, including special characters for escaping + _EMPTY_ = "___EMPTY___" + fields = { + "name": [ + None, + "test_name", + "test:name" if enable_escape else _EMPTY_, + "test::name" if enable_escape else _EMPTY_, + "test#name" if enable_escape else _EMPTY_, + "test##name" if enable_escape else _EMPTY_, + "test::@@@#22name" if enable_escape else _EMPTY_, + ], + "scope": [ + None, + "test_scope", + "test@scope" if enable_escape else _EMPTY_, + "test@@scope" if enable_escape else _EMPTY_, + "test:scope" if enable_escape else _EMPTY_, + "test:#:scope" if enable_escape else _EMPTY_, + ], + "scope_key": [ + None, + "test_scope_key", + "test:scope_key" if enable_escape else _EMPTY_, + ], + "sys_code": [ + None, + "test_sys_code", + "test#sys_code" if enable_escape else _EMPTY_, + ], + "user_name": [ + None, + "test_user_name", + "test%user_name" if enable_escape else _EMPTY_, + ], + } + # Remove empty values + fields = {k: [v for v in values if v != _EMPTY_] for k, values in fields.items()} + + # Generate all possible combinations + combinations = product(*fields.values()) + + test_cases = [] + for combo in combinations: + name, scope, scope_key, sys_code, user_name = combo + + var_str = build_variable_string( + { + "key": "test_key", + "name": name, + "scope": scope, + "scope_key": scope_key, + "sys_code": sys_code, + "user_name": user_name, + }, + enable_escape=enable_escape, + ) + + # Construct the expected output + expected = { + "key": "test_key", + "name": name, + "scope": scope, + "scope_key": scope_key, + "sys_code": sys_code, + "user_name": user_name, + } + + test_cases.append((var_str, expected, enable_escape)) + + return test_cases + + +def test_parse_variables(): + # Run test cases without escape + test_cases = generate_test_cases(enable_escape=False) + for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1): + result = parse_variable(input_str, enable_escape=enable_escape) + assert result == expected_output, f"Test case {i} failed without escape" + + # Run test cases with escape + test_cases = generate_test_cases(enable_escape=True) + for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1): + print(f"input_str: {input_str}, expected_output: {expected_output}") + result = 
parse_variable(input_str, enable_escape=enable_escape) + assert result == expected_output, f"Test case {i} failed with escape" + + +def generate_build_test_cases(enable_escape=False): + # Define possible values for each field, including special characters for escaping + _EMPTY_ = "___EMPTY___" + fields = { + "name": [ + None, + "test_name", + "test:name" if enable_escape else _EMPTY_, + "test::name" if enable_escape else _EMPTY_, + "test\name" if enable_escape else _EMPTY_, + "test\\name" if enable_escape else _EMPTY_, + "test\:\#\@\%name" if enable_escape else _EMPTY_, + "test\::\###\@@\%%name" if enable_escape else _EMPTY_, + "test\\::\\###\\@@\\%%name" if enable_escape else _EMPTY_, + "test\:#:name" if enable_escape else _EMPTY_, + ], + "scope": [None, "test_scope", "test@scope" if enable_escape else _EMPTY_], + "scope_key": [ + None, + "test_scope_key", + "test:scope_key" if enable_escape else _EMPTY_, + ], + "sys_code": [ + None, + "test_sys_code", + "test#sys_code" if enable_escape else _EMPTY_, + ], + "user_name": [ + None, + "test_user_name", + "test%user_name" if enable_escape else _EMPTY_, + ], + } + # Remove empty values + fields = {k: [v for v in values if v != _EMPTY_] for k, values in fields.items()} + + # Generate all possible combinations + combinations = product(*fields.values()) + + test_cases = [] + + def escape_special_chars(s): + if not enable_escape or s is None: + return s + return ( + s.replace(":", "\\:") + .replace("@", "\\@") + .replace("%", "\\%") + .replace("#", "\\#") + ) + + for combo in combinations: + name, scope, scope_key, sys_code, user_name = combo + + # Construct the input dictionary + input_dict = { + "key": "test_key", + "name": name, + "scope": scope, + "scope_key": scope_key, + "sys_code": sys_code, + "user_name": user_name, + } + input_dict_with_escape = { + k: escape_special_chars(v) for k, v in input_dict.items() + } + + # Construct the expected variable string + expected_str = "${test_key" + if name: + expected_str += f":{input_dict_with_escape['name']}" + if scope or scope_key: + expected_str += "@" + if scope: + expected_str += input_dict_with_escape["scope"] + if scope_key: + expected_str += f":{input_dict_with_escape['scope_key']}" + if sys_code: + expected_str += f"#{input_dict_with_escape['sys_code']}" + if user_name: + expected_str += f"%{input_dict_with_escape['user_name']}" + expected_str += "}" + + test_cases.append((input_dict, expected_str, enable_escape)) + + return test_cases + + +def test_build_variable_string(): + # Run test cases without escape + test_cases = generate_build_test_cases(enable_escape=False) + for i, (input_dict, expected_str, enable_escape) in enumerate(test_cases, 1): + result = build_variable_string(input_dict, enable_escape=enable_escape) + assert result == expected_str, f"Test case {i} failed without escape" + + # Run test cases with escape + test_cases = generate_build_test_cases(enable_escape=True) + for i, (input_dict, expected_str, enable_escape) in enumerate(test_cases, 1): + print(f"input_dict: {input_dict}, expected_str: {expected_str}") + result = build_variable_string(input_dict, enable_escape=enable_escape) + assert result == expected_str, f"Test case {i} failed with escape" + + +def test_variable_string_round_trip(): + # Run test cases without escape + test_cases = generate_test_cases(enable_escape=False) + for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1): + parsed_result = parse_variable(input_str, enable_escape=enable_escape) + built_result = 
build_variable_string(parsed_result, enable_escape=enable_escape) + assert ( + built_result == input_str + ), f"Round trip test case {i} failed without escape" + + # Run test cases with escape + test_cases = generate_test_cases(enable_escape=True) + for i, (input_str, expected_output, enable_escape) in enumerate(test_cases, 1): + parsed_result = parse_variable(input_str, enable_escape=enable_escape) + built_result = build_variable_string(parsed_result, enable_escape=enable_escape) + assert built_result == input_str, f"Round trip test case {i} failed with escape" diff --git a/dbgpt/core/interface/variables.py b/dbgpt/core/interface/variables.py index 8f99d1e30..7e308127c 100644 --- a/dbgpt/core/interface/variables.py +++ b/dbgpt/core/interface/variables.py @@ -182,36 +182,18 @@ def __post_init__(self): if not self.key or not self.name or not self.scope: raise ValueError("Key, name, and scope are required.") - if any( - self.identifier_split in key - for key in [ - self.key, - self.name, - self.scope, - self.scope_key, - self.sys_code, - self.user_name, - ] - if key is not None - ): - raise ValueError( - f"identifier_split {self.identifier_split} is not allowed in " - f"key, name, scope, scope_key, sys_code, user_name." - ) - @property def str_identifier(self) -> str: """Return the string identifier of the identifier.""" - return self.identifier_split.join( - key or "" - for key in [ - self.key, - self.name, - self.scope, - self.scope_key, - self.sys_code, - self.user_name, - ] + return build_variable_string( + { + "key": self.key, + "name": self.name, + "scope": self.scope, + "scope_key": self.scope_key, + "sys_code": self.sys_code, + "user_name": self.user_name, + } ) def to_dict(self) -> Dict: @@ -230,33 +212,30 @@ def to_dict(self) -> Dict: } @classmethod - def from_str_identifier( - cls, str_identifier: str, identifier_split: str = "@" - ) -> "VariablesIdentifier": + def from_str_identifier(cls, str_identifier: str) -> "VariablesIdentifier": """Create a VariablesIdentifier from a string identifier. Args: str_identifier (str): The string identifier. - identifier_split (str): The identifier split. Returns: VariablesIdentifier: The VariablesIdentifier. 
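# A sketch of the new ${...} identifier syntax handled here; the key, name, and
# scope values are placeholders.
from dbgpt.core.interface.variables import VariablesIdentifier

vid = VariablesIdentifier.from_str_identifier(
    "${dbgpt.model.openai.api_key:my_key@global:dbgpt#my_app%alice}"
)
assert vid.key == "dbgpt.model.openai.api_key"
assert vid.name == "my_key"
assert vid.scope == "global" and vid.scope_key == "dbgpt"
assert vid.sys_code == "my_app" and vid.user_name == "alice"
# str_identifier rebuilds the same string via build_variable_string()
assert (
    vid.str_identifier
    == "${dbgpt.model.openai.api_key:my_key@global:dbgpt#my_app%alice}"
)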
""" - keys = str_identifier.split(identifier_split) - if not keys: + variable_dict = parse_variable(str_identifier) + if not variable_dict: raise ValueError("Invalid string identifier.") - if len(keys) < 2: + if not variable_dict.get("key"): + raise ValueError("Invalid string identifier, must have key") + if not variable_dict.get("name"): raise ValueError("Invalid string identifier, must have name") - if len(keys) < 3: - raise ValueError("Invalid string identifier, must have scope") return cls( - key=keys[0], - name=keys[1], - scope=keys[2], - scope_key=keys[3] if len(keys) > 3 else None, - sys_code=keys[4] if len(keys) > 4 else None, - user_name=keys[5] if len(keys) > 5 else None, + key=variable_dict["key"], + name=variable_dict["name"], + scope=variable_dict.get("scope", "global"), + scope_key=variable_dict.get("scope_key"), + sys_code=variable_dict.get("sys_code"), + user_name=variable_dict.get("user_name"), ) @@ -402,6 +381,26 @@ def support_async(self) -> bool: """Whether the variables provider support async.""" return False + def _convert_to_value_type(self, var: StorageVariables): + """Convert the variable to the value type.""" + if var.value is None: + return None + if var.value_type == "str": + return str(var.value) + elif var.value_type == "int": + return int(var.value) + elif var.value_type == "float": + return float(var.value) + elif var.value_type == "bool": + if var.value.lower() in ["true", "1"]: + return True + elif var.value.lower() in ["false", "0"]: + return False + else: + return bool(var.value) + else: + return var.value + class VariablesPlaceHolder: """The variables place holder.""" @@ -410,46 +409,20 @@ def __init__( self, param_name: str, full_key: str, - value_type: str, default_value: Any = _EMPTY_DEFAULT_VALUE, ): """Initialize the variables place holder.""" self.param_name = param_name self.full_key = full_key - self.value_type = value_type self.default_value = default_value def parse(self, variables_provider: VariablesProvider) -> Any: """Parse the variables.""" - value = variables_provider.get(self.full_key, self.default_value) - if value: - return self._cast_to_type(value) - else: - return value - - def _cast_to_type(self, value: Any) -> Any: - if self.value_type == "str": - return str(value) - elif self.value_type == "int": - return int(value) - elif self.value_type == "float": - return float(value) - elif self.value_type == "bool": - if value.lower() in ["true", "1"]: - return True - elif value.lower() in ["false", "0"]: - return False - else: - return bool(value) - else: - return value + return variables_provider.get(self.full_key, self.default_value) def __repr__(self): """Return the representation of the variables place holder.""" - return ( - f"" - ) + return f"" class StorageVariablesProvider(VariablesProvider): @@ -493,7 +466,7 @@ def get( and variable.salt ): variable.value = self.encryption.decrypt(variable.value, variable.salt) - return variable.value + return self._convert_to_value_type(variable) def save(self, variables_item: StorageVariables) -> None: """Save variables to storage.""" @@ -676,3 +649,287 @@ def get( def save(self, variables_item: StorageVariables) -> None: """Save variables to storage.""" raise NotImplementedError("BuiltinVariablesProvider does not support save.") + + +def parse_variable( + variable_str: str, + enable_escape: bool = True, +) -> Dict[str, Any]: + """Parse the variable string. + + Examples: + .. 
code-block:: python + + cases = [ + { + "full_key": "${test_key:test_name@test_scope:test_scope_key}", + "expected": { + "key": "test_key", + "name": "test_name", + "scope": "test_scope", + "scope_key": "test_scope_key", + "sys_code": None, + "user_name": None, + }, + }, + { + "full_key": "${test_key#test_sys_code}", + "expected": { + "key": "test_key", + "name": None, + "scope": None, + "scope_key": None, + "sys_code": "test_sys_code", + "user_name": None, + }, + }, + { + "full_key": "${test_key@:test_scope_key}", + "expected": { + "key": "test_key", + "name": None, + "scope": None, + "scope_key": "test_scope_key", + "sys_code": None, + "user_name": None, + }, + }, + ] + for case in cases: + assert parse_variable(case["full_key"]) == case["expected"] + Args: + variable_str (str): The variable string. + enable_escape (bool): Whether to handle escaped characters. + Returns: + Dict[str, Any]: The parsed variable. + """ + if not variable_str.startswith("${") or not variable_str.endswith("}"): + raise ValueError( + "Invalid variable format, must start with '${' and end with '}'" + ) + + # Remove the surrounding ${ and } + content = variable_str[2:-1] + + # Define placeholders for escaped characters + placeholders = { + r"\@": "__ESCAPED_AT__", + r"\#": "__ESCAPED_HASH__", + r"\%": "__ESCAPED_PERCENT__", + r"\:": "__ESCAPED_COLON__", + } + + if enable_escape: + # Replace escaped characters with placeholders + for original, placeholder in placeholders.items(): + content = content.replace(original, placeholder) + + # Initialize the result dictionary + result: Dict[str, Optional[str]] = { + "key": None, + "name": None, + "scope": None, + "scope_key": None, + "sys_code": None, + "user_name": None, + } + + # Split the content by special characters + parts = content.split("@") + + # Parse key and name + key_name = parts[0].split("#")[0].split("%")[0] + if ":" in key_name: + result["key"], result["name"] = key_name.split(":", 1) + else: + result["key"] = key_name + + # Parse scope and scope_key + if len(parts) > 1: + scope_part = parts[1].split("#")[0].split("%")[0] + if ":" in scope_part: + result["scope"], result["scope_key"] = scope_part.split(":", 1) + else: + result["scope"] = scope_part + + # Parse sys_code + if "#" in content: + result["sys_code"] = content.split("#", 1)[1].split("%")[0] + + # Parse user_name + if "%" in content: + result["user_name"] = content.split("%", 1)[1] + + if enable_escape: + # Replace placeholders back with escaped characters + reverse_placeholders = {v: k[1:] for k, v in placeholders.items()} + for key, value in result.items(): + if value: + for placeholder, original in reverse_placeholders.items(): + result[key] = result[key].replace( # type: ignore + placeholder, original + ) + + # Replace empty strings with None + for key, value in result.items(): + if value == "": + result[key] = None + + return result + + +def _is_variable_format(value: str) -> bool: + if not value.startswith("${") or not value.endswith("}"): + return False + return True + + +def is_variable_string(variable_str: str) -> bool: + """Check if the given string is a variable string. + + A valid variable string should start with "${" and end with "}", and contain key + and name + + Args: + variable_str (str): The string to check. + + Returns: + bool: True if the string is a variable string, False otherwise. 
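# A sketch of how this check differs from is_variable_list_string() defined
# below; the keys are placeholders.
from dbgpt.core.interface.variables import is_variable_list_string, is_variable_string

assert is_variable_string("${dbgpt.core.flows:my_flow@global}") is True  # key and name
assert is_variable_list_string("${dbgpt.core.flows:my_flow@global}") is False
assert is_variable_string("${dbgpt.core.flows}") is False  # key only, no name
assert is_variable_list_string("${dbgpt.core.flows}") is True
assert is_variable_string("plain text") is False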
+ """ + if not _is_variable_format(variable_str): + return False + try: + variable_dict = parse_variable(variable_str) + if not variable_dict.get("key"): + return False + if not variable_dict.get("name"): + return False + return True + except Exception: + return False + + +def is_variable_list_string(variable_str: str) -> bool: + """Check if the given string is a variable string. + + A valid variable list string should start with "${" and end with "}", and contain + key and not contain name + + A valid variable list string means that the variable is a list of variables with the + same key. + + Args: + variable_str (str): The string to check. + + Returns: + bool: True if the string is a variable string, False otherwise. + """ + if not _is_variable_format(variable_str): + return False + try: + variable_dict = parse_variable(variable_str) + if not variable_dict.get("key"): + return False + if variable_dict.get("name"): + return False + return True + except Exception: + return False + + +def build_variable_string( + variable_dict: Dict[str, Any], + scope_sig: str = "@", + sys_code_sig: str = "#", + user_sig: str = "%", + kv_sig: str = ":", + enable_escape: bool = True, +) -> str: + """Build a variable string from the given dictionary. + + Args: + variable_dict (Dict[str, Any]): The dictionary containing the variable details. + scope_sig (str): The scope signature. + sys_code_sig (str): The sys code signature. + user_sig (str): The user signature. + kv_sig (str): The key-value split signature. + enable_escape (bool): Whether to escape special characters + + Returns: + str: The formatted variable string. + + Examples: + >>> build_variable_string( + ... { + ... "key": "test_key", + ... "name": "test_name", + ... "scope": "test_scope", + ... "scope_key": "test_scope_key", + ... "sys_code": "test_sys_code", + ... "user_name": "test_user", + ... } + ... 
) + '${test_key:test_name@test_scope:test_scope_key#test_sys_code%test_user}' + + >>> build_variable_string({"key": "test_key", "scope_key": "test_scope_key"}) + '${test_key@:test_scope_key}' + + >>> build_variable_string({"key": "test_key", "sys_code": "test_sys_code"}) + '${test_key#test_sys_code}' + + >>> build_variable_string({"key": "test_key"}) + '${test_key}' + """ + special_chars = {scope_sig, sys_code_sig, user_sig, kv_sig} + # Replace None with "" + new_variable_dict = {key: value or "" for key, value in variable_dict.items()} + + # Check if the variable_dict contains any special characters + for key, value in new_variable_dict.items(): + if value != "" and any(char in value for char in special_chars): + if enable_escape: + # Escape special characters + new_variable_dict[key] = ( + value.replace("@", "\\@") + .replace("#", "\\#") + .replace("%", "\\%") + .replace(":", "\\:") + ) + else: + raise ValueError( + f"{key} contains special characters, error value: {value}, special " + f"characters: {special_chars}" + ) + + key = new_variable_dict.get("key", "") + name = new_variable_dict.get("name", "") + scope = new_variable_dict.get("scope", "") + scope_key = new_variable_dict.get("scope_key", "") + sys_code = new_variable_dict.get("sys_code", "") + user_name = new_variable_dict.get("user_name", "") + + # Construct the base of the variable string + variable_str = f"${{{key}" + + # Add name if present + if name: + variable_str += f":{name}" + + # Add scope and scope_key if present + if scope or scope_key: + variable_str += f"@{scope}" + if scope_key: + variable_str += f":{scope_key}" + + # Add sys_code if present + if sys_code: + variable_str += f"#{sys_code}" + + # Add user_name if present + if user_name: + variable_str += f"%{user_name}" + + # Close the variable string + variable_str += "}" + + return variable_str From abf1c78748ea60761f147e25074c047ae2c2e16c Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sun, 11 Aug 2024 22:07:21 +0800 Subject: [PATCH 37/89] feat(core): Fetch flow nodes API supports filterd by tags --- dbgpt/core/awel/flow/base.py | 37 ++++++++++++++++++++++-- dbgpt/core/interface/variables.py | 2 ++ dbgpt/serve/flow/api/endpoints.py | 29 +++++++++++++++++-- examples/awel/awel_flow_ui_components.py | 36 +++++++++++++++++++++++ 4 files changed, 98 insertions(+), 6 deletions(-) diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index 61e0dfa75..7ab7cbb34 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -1145,9 +1145,40 @@ def get_registry_item(self, key: str) -> Optional[_RegistryItem]: """Get the registry item by the key.""" return self._registry.get(key) - def metadata_list(self): - """Get the metadata list.""" - return [item.metadata.to_dict() for item in self._registry.values()] + def metadata_list( + self, + tags: Optional[Dict[str, str]] = None, + user_name: Optional[str] = None, + sys_code: Optional[str] = None, + ) -> List[Dict]: + """Get the metadata list. + + TODO: Support the user and system code filter. + + Args: + tags (Optional[Dict[str, str]], optional): The tags. Defaults to None. + user_name (Optional[str], optional): The user name. Defaults to None. + sys_code (Optional[str], optional): The system code. Defaults to None. + + Returns: + List[Dict]: The metadata list. 
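# A sketch of how the tags filter is expected to be used by the /nodes endpoint,
# mirroring the tags declared on ExampleFlowTagsOperator later in this patch
# series. user_name and sys_code are accepted but not applied yet (see the TODO
# above).
from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY

all_nodes = _OPERATOR_REGISTRY.metadata_list()
higher_order_nodes = _OPERATOR_REGISTRY.metadata_list(tags={"order": "higher-order"})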
+ """ + if not tags: + return [item.metadata.to_dict() for item in self._registry.values()] + else: + results = [] + for item in self._registry.values(): + node_tags = item.metadata.tags + is_match = True + if not node_tags or not isinstance(node_tags, dict): + continue + for k, v in tags.items(): + if node_tags.get(k) != v: + is_match = False + break + if is_match: + results.append(item.metadata.to_dict()) + return results async def refresh( self, diff --git a/dbgpt/core/interface/variables.py b/dbgpt/core/interface/variables.py index 7e308127c..2d00df44c 100644 --- a/dbgpt/core/interface/variables.py +++ b/dbgpt/core/interface/variables.py @@ -796,6 +796,8 @@ def is_variable_string(variable_str: str) -> bool: Returns: bool: True if the string is a variable string, False otherwise. """ + if not variable_str or not isinstance(variable_str, str): + return False if not _is_variable_format(variable_str): return False try: diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 74da7dd72..28c6a28c4 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -1,5 +1,6 @@ +import json from functools import cache -from typing import List, Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Union from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer @@ -229,16 +230,38 @@ async def query_page( @router.get("/nodes", dependencies=[Depends(check_api_key)]) -async def get_nodes(): +async def get_nodes( + user_name: Optional[str] = Query(default=None, description="user name"), + sys_code: Optional[str] = Query(default=None, description="system code"), + tags: Optional[str] = Query(default=None, description="tags"), +): """Get the operator or resource nodes + Args: + user_name (Optional[str]): The username + sys_code (Optional[str]): The system code + tags (Optional[str]): The tags encoded in JSON format + Returns: Result[List[Union[ViewMetadata, ResourceMetadata]]]: The operator or resource nodes """ from dbgpt.core.awel.flow.base import _OPERATOR_REGISTRY - metadata_list = _OPERATOR_REGISTRY.metadata_list() + tags_dict: Optional[Dict[str, str]] = None + if tags: + try: + tags_dict = json.loads(tags) + except json.JSONDecodeError: + return Result.fail("Invalid JSON format for tags") + + metadata_list = await blocking_func_to_async( + global_system_app, + _OPERATOR_REGISTRY.metadata_list, + tags_dict, + user_name, + sys_code, + ) return Result.succ(metadata_list) diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py index 7a38f8d4b..9ce611c39 100644 --- a/examples/awel/awel_flow_ui_components.py +++ b/examples/awel/awel_flow_ui_components.py @@ -881,3 +881,39 @@ async def map(self, user_name: str) -> str: } json_data = json.dumps(dict_dict, ensure_ascii=False) return "Your name is %s, and your model info is %s." 
% (user_name, json_data) + + +class ExampleFlowTagsOperator(MapOperator[str, str]): + """An example flow operator that includes a tags option.""" + + metadata = ViewMetadata( + label="Example Tags Operator", + name="example_tags_operator", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a tags", + parameters=[], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ), + ], + outputs=[ + IOField.build_from( + "Tags", + "tags", + str, + description="The tags to use.", + ), + ], + tags={"order": "higher-order", "type": "example"}, + ) + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + async def map(self, user_name: str) -> str: + """Map the user name to the tags.""" + return "Your name is %s, and your tags are %s." % (user_name, "higher-order") From 74f343331b8444474da2c3f55d981c4de725e30f Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Thu, 15 Aug 2024 09:22:09 +0800 Subject: [PATCH 38/89] feat(core): Add AWEL flow radio component --- dbgpt/core/awel/flow/ui.py | 46 +++++++++++------ dbgpt/serve/flow/api/endpoints.py | 16 +++--- dbgpt/serve/flow/api/schemas.py | 26 +++++++++- examples/awel/awel_flow_ui_components.py | 64 ++++++++++++++++++++++-- 4 files changed, 124 insertions(+), 28 deletions(-) diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index 875547e9a..c763859b0 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py @@ -11,6 +11,7 @@ "select", "cascader", "checkbox", + "radio", "date_picker", "input", "text_area", @@ -175,6 +176,12 @@ def check_parameter(self, parameter_dict: Dict[str, Any]): self._check_options(parameter_dict.get("options", {})) +class UIRadio(UICheckbox): + """Radio component.""" + + ui_type: Literal["radio"] = Field("radio", frozen=True) # type: ignore + + class UIDatePicker(UIComponent): """Date picker component.""" @@ -232,23 +239,31 @@ class UIAttribute(StatusMixin, UIComponent.UIAttribute): class UITextArea(PanelEditorMixin, UIInput): """Text area component.""" - class AutoSize(BaseModel): - """Auto size configuration.""" + class UIAttribute(UIInput.UIAttribute): + """Text area attribute.""" - min_rows: Optional[int] = Field( - None, - description="The minimum number of rows", - ) - max_rows: Optional[int] = Field( + class AutoSize(BaseModel): + """Auto size configuration.""" + + min_rows: Optional[int] = Field( + None, + description="The minimum number of rows", + ) + max_rows: Optional[int] = Field( + None, + description="The maximum number of rows", + ) + + auto_size: Optional[Union[bool, AutoSize]] = Field( None, - description="The maximum number of rows", + description="Whether the height of the textarea automatically adjusts " + "based on the content", ) ui_type: Literal["text_area"] = Field("text_area", frozen=True) # type: ignore - autosize: Optional[Union[bool, AutoSize]] = Field( + attr: Optional[UIAttribute] = Field( None, - description="Whether the height of the textarea automatically adjusts based " - "on the content", + description="The attributes of the component", ) @@ -430,8 +445,9 @@ class UICodeEditor(UITextArea): class DefaultUITextArea(UITextArea): """Default text area component.""" - autosize: Union[bool, UITextArea.AutoSize] = Field( - default_factory=lambda: UITextArea.AutoSize(min_rows=2, max_rows=40), - description="Whether the height of the textarea automatically adjusts based " - "on the content", + attr: Optional[UITextArea.UIAttribute] = Field( + default_factory=lambda: 
UITextArea.UIAttribute( + auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=40) + ), + description="The attributes of the component", ) diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 28c6a28c4..4b28641e8 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -1,20 +1,12 @@ import json from functools import cache -from typing import Dict, List, Literal, Optional, Union +from typing import Dict, List, Optional, Union from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer from dbgpt.component import SystemApp from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata -from dbgpt.core.interface.variables import ( - BUILTIN_VARIABLES_CORE_FLOW_NODES, - BUILTIN_VARIABLES_CORE_FLOWS, - BUILTIN_VARIABLES_CORE_SECRETS, - BUILTIN_VARIABLES_CORE_VARIABLES, - BuiltinVariablesProvider, - StorageVariables, -) from dbgpt.serve.core import Result, blocking_func_to_async from dbgpt.util import PaginationResult @@ -330,6 +322,12 @@ async def update_variables( return Result.succ(res) +@router.post("/flow/debug") +async def debug(): + """Debug the flow.""" + # TODO: Implement the debug endpoint + + def init_endpoints(system_app: SystemApp) -> None: """Initialize the endpoints""" from .variables_provider import ( diff --git a/dbgpt/serve/flow/api/schemas.py b/dbgpt/serve/flow/api/schemas.py index e63d3e6ce..537996fe7 100644 --- a/dbgpt/serve/flow/api/schemas.py +++ b/dbgpt/serve/flow/api/schemas.py @@ -1,6 +1,7 @@ -from typing import Any, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union from dbgpt._private.pydantic import BaseModel, ConfigDict, Field +from dbgpt.core.awel import CommonLLMHttpRequestBody from dbgpt.core.awel.flow.flow_factory import FlowPanel from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest @@ -113,3 +114,26 @@ class RefreshNodeRequest(BaseModel): title="The refresh options", description="The refresh options", ) + + +class FlowDebugRequest(BaseModel): + """Flow response model""" + + model_config = ConfigDict(title=f"FlowDebugRequest") + flow: ServeRequest = Field( + ..., + title="The flow to debug", + description="The flow to debug", + ) + request: Union[CommonLLMHttpRequestBody, Dict[str, Any]] = Field( + ..., + title="The request to debug", + description="The request to debug", + ) + variables: Optional[Dict[str, Any]] = Field( + None, + title="The variables to debug", + description="The variables to debug", + ) + user_name: Optional[str] = Field(None, description="User name") + sys_code: Optional[str] = Field(None, description="System code") diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py index 9ce611c39..cba0c14df 100644 --- a/examples/awel/awel_flow_ui_components.py +++ b/examples/awel/awel_flow_ui_components.py @@ -206,7 +206,7 @@ class ExampleFlowCheckboxOperator(MapOperator[str, str]): OptionValue(label="Orange", name="orange", value="orange"), OptionValue(label="Pear", name="pear", value="pear"), ], - ui=ui.UICheckbox(attr=ui.UICheckbox.UIAttribute(show_search=True)), + ui=ui.UICheckbox(), ) ], inputs=[ @@ -236,6 +236,59 @@ async def map(self, user_name: str) -> str: return "Your name is %s, and you like %s." 
% (user_name, ", ".join(self.fruits)) +class ExampleFlowRadioOperator(MapOperator[str, str]): + """An example flow operator that includes a radio as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Radio", + name="example_flow_radio", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a radio as parameter.", + parameters=[ + Parameter.build_from( + "Fruits Selector", + "fruits", + type=str, + optional=True, + default=None, + placeholder="Select the fruits", + description="The fruits you like.", + options=[ + OptionValue(label="Apple", name="apple", value="apple"), + OptionValue(label="Banana", name="banana", value="banana"), + OptionValue(label="Orange", name="orange", value="orange"), + OptionValue(label="Pear", name="pear", value="pear"), + ], + ui=ui.UIRadio(), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Fruits", + "fruits", + str, + description="User's favorite fruits.", + ) + ], + ) + + def __init__(self, fruits: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.fruits = fruits + + async def map(self, user_name: str) -> str: + """Map the user name to the fruits.""" + return "Your name is %s, and you like %s." % (user_name, self.fruits) + + class ExampleFlowDatePickerOperator(MapOperator[str, str]): """An example flow operator that includes a date picker as parameter.""" @@ -348,8 +401,13 @@ class ExampleFlowTextAreaOperator(MapOperator[str, str]): placeholder="Please input your comment", description="The comment you want to say.", ui=ui.UITextArea( - attr=ui.UITextArea.UIAttribute(show_count=True, maxlength=1000), - autosize=ui.UITextArea.AutoSize(min_rows=2, max_rows=6), + attr=ui.UITextArea.UIAttribute( + show_count=True, + maxlength=1000, + auto_size=ui.UITextArea.UIAttribute.AutoSize( + min_rows=2, max_rows=6 + ), + ), ), ) ], From 58fb29bec0836c00f467a87c9ab9b124713745f8 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Sun, 18 Aug 2024 07:41:46 +0800 Subject: [PATCH 39/89] feat(core): Add debug and export/import for AWEL flow --- dbgpt/core/awel/dag/dag_manager.py | 2 +- dbgpt/core/awel/flow/flow_factory.py | 61 +++++++++- dbgpt/core/awel/operators/common_operator.py | 7 +- dbgpt/core/awel/trigger/http_trigger.py | 42 ++++++- dbgpt/serve/flow/api/endpoints.py | 121 ++++++++++++++++++- dbgpt/serve/flow/api/schemas.py | 55 +-------- dbgpt/serve/flow/service/service.py | 65 +++++++++- dbgpt/serve/flow/service/share_utils.py | 121 +++++++++++++++++++ dbgpt/util/dbgpts/loader.py | 47 ++++++- 9 files changed, 446 insertions(+), 75 deletions(-) create mode 100644 dbgpt/serve/flow/service/share_utils.py diff --git a/dbgpt/core/awel/dag/dag_manager.py b/dbgpt/core/awel/dag/dag_manager.py index 91a49a166..15a07254a 100644 --- a/dbgpt/core/awel/dag/dag_manager.py +++ b/dbgpt/core/awel/dag/dag_manager.py @@ -197,7 +197,7 @@ def get_dag_metadata( return self._dag_metadata_map.get(dag.dag_id) -def _parse_metadata(dag: DAG): +def _parse_metadata(dag: DAG) -> DAGMetadata: from ..util.chat_util import _is_sse_output metadata = DAGMetadata() diff --git a/dbgpt/core/awel/flow/flow_factory.py b/dbgpt/core/awel/flow/flow_factory.py index 3f847c07c..e0d505aa5 100644 --- a/dbgpt/core/awel/flow/flow_factory.py +++ b/dbgpt/core/awel/flow/flow_factory.py @@ -4,7 +4,7 @@ import uuid from contextlib import suppress from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast +from 
typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union, cast from typing_extensions import Annotated @@ -166,6 +166,59 @@ class FlowData(BaseModel): viewport: FlowPositionData = Field(..., description="Viewport of the flow") +class VariablesRequest(BaseModel): + """Variable request model. + + For creating a new variable in the DB-GPT. + """ + + key: str = Field( + ..., + description="The key of the variable to create", + examples=["dbgpt.model.openai.api_key"], + ) + name: str = Field( + ..., + description="The name of the variable to create", + examples=["my_first_openai_key"], + ) + label: str = Field( + ..., + description="The label of the variable to create", + examples=["My First OpenAI Key"], + ) + value: Any = Field( + ..., description="The value of the variable to create", examples=["1234567890"] + ) + value_type: Literal["str", "int", "float", "bool"] = Field( + "str", + description="The type of the value of the variable to create", + examples=["str", "int", "float", "bool"], + ) + category: Literal["common", "secret"] = Field( + ..., + description="The category of the variable to create", + examples=["common"], + ) + scope: str = Field( + ..., + description="The scope of the variable to create", + examples=["global"], + ) + scope_key: Optional[str] = Field( + ..., + description="The scope key of the variable to create", + examples=["dbgpt"], + ) + enabled: Optional[bool] = Field( + True, + description="Whether the variable is enabled", + examples=[True], + ) + user_name: Optional[str] = Field(None, description="User name") + sys_code: Optional[str] = Field(None, description="System code") + + class State(str, Enum): """State of a flow panel.""" @@ -356,6 +409,12 @@ class FlowPanel(BaseModel): metadata: Optional[Union[DAGMetadata, Dict[str, Any]]] = Field( default=None, description="The metadata of the flow" ) + variables: Optional[List[VariablesRequest]] = Field( + default=None, description="The variables of the flow" + ) + authors: Optional[List[str]] = Field( + default=None, description="The authors of the flow" + ) @model_validator(mode="before") @classmethod diff --git a/dbgpt/core/awel/operators/common_operator.py b/dbgpt/core/awel/operators/common_operator.py index fc2dc098b..f8bc25370 100644 --- a/dbgpt/core/awel/operators/common_operator.py +++ b/dbgpt/core/awel/operators/common_operator.py @@ -334,7 +334,8 @@ def __init__(self, input_source: InputSource[OUT], **kwargs) -> None: async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[OUT]: curr_task_ctx: TaskContext[OUT] = dag_ctx.current_task_context task_output = await self._input_source.read(curr_task_ctx) - curr_task_ctx.set_task_output(task_output) + new_task_output: TaskOutput[OUT] = await task_output.map(self.map) + curr_task_ctx.set_task_output(new_task_output) return task_output @classmethod @@ -342,6 +343,10 @@ def dummy_input(cls, dummy_data: Any = SKIP_DATA, **kwargs) -> "InputOperator[OU """Create a dummy InputOperator with a given input value.""" return cls(input_source=InputSource.from_data(dummy_data), **kwargs) + async def map(self, input_data: OUT) -> OUT: + """Map the input data to a new value.""" + return input_data + class TriggerOperator(InputOperator[OUT], Generic[OUT]): """Operator node that triggers the DAG to run.""" diff --git a/dbgpt/core/awel/trigger/http_trigger.py b/dbgpt/core/awel/trigger/http_trigger.py index 22e025c13..8f0298297 100644 --- a/dbgpt/core/awel/trigger/http_trigger.py +++ b/dbgpt/core/awel/trigger/http_trigger.py @@ -87,7 +87,9 @@ class 
HttpTriggerMetadata(TriggerMetadata): path: str = Field(..., description="The path of the trigger") methods: List[str] = Field(..., description="The methods of the trigger") - + trigger_mode: str = Field( + default="command", description="The mode of the trigger, command or chat" + ) trigger_type: Optional[str] = Field( default="http", description="The type of the trigger" ) @@ -477,7 +479,9 @@ def mount_to_router( )(dynamic_route_function) logger.info(f"Mount http trigger success, path: {path}") - return HttpTriggerMetadata(path=path, methods=self._methods) + return HttpTriggerMetadata( + path=path, methods=self._methods, trigger_mode=self._trigger_mode() + ) def mount_to_app( self, app: "FastAPI", global_prefix: Optional[str] = None @@ -512,7 +516,9 @@ def mount_to_app( app.openapi_schema = None app.middleware_stack = None logger.info(f"Mount http trigger success, path: {path}") - return HttpTriggerMetadata(path=path, methods=self._methods) + return HttpTriggerMetadata( + path=path, methods=self._methods, trigger_mode=self._trigger_mode() + ) def remove_from_app( self, app: "FastAPI", global_prefix: Optional[str] = None @@ -537,6 +543,36 @@ def remove_from_app( # TODO, remove with path and methods del app_router.routes[i] + def _trigger_mode(self) -> str: + if ( + self._req_body + and isinstance(self._req_body, type) + and issubclass(self._req_body, CommonLLMHttpRequestBody) + ): + return "chat" + return "command" + + async def map(self, input_data: Any) -> Any: + """Map the input data. + + Do some transformation for the input data. + + Args: + input_data (Any): The input data from caller. + + Returns: + Any: The mapped data. + """ + if not self._req_body or not input_data: + return await super().map(input_data) + if ( + isinstance(self._req_body, type) + and issubclass(self._req_body, BaseModel) + and isinstance(input_data, dict) + ): + return self._req_body(**input_data) + return await super().map(input_data) + def _create_route_func(self): from inspect import Parameter, Signature from typing import get_type_hints diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 4b28641e8..4174502a5 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -1,9 +1,11 @@ +import io import json from functools import cache -from typing import Dict, List, Optional, Union +from typing import Dict, List, Literal, Optional, Union -from fastapi import APIRouter, Depends, HTTPException, Query, Request +from fastapi import APIRouter, Depends, File, HTTPException, Query, Request, UploadFile from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer +from starlette.responses import JSONResponse, StreamingResponse from dbgpt.component import SystemApp from dbgpt.core.awel.flow import ResourceMetadata, ViewMetadata @@ -14,6 +16,7 @@ from ..service.service import Service from ..service.variables_service import VariablesService from .schemas import ( + FlowDebugRequest, RefreshNodeRequest, ServeRequest, ServerResponse, @@ -322,10 +325,116 @@ async def update_variables( return Result.succ(res) -@router.post("/flow/debug") -async def debug(): - """Debug the flow.""" - # TODO: Implement the debug endpoint +@router.post("/flow/debug", dependencies=[Depends(check_api_key)]) +async def debug_flow( + flow_debug_request: FlowDebugRequest, service: Service = Depends(get_service) +): + """Run the flow in debug mode.""" + # Return the no-incremental stream by default + stream_iter = service.debug_flow(flow_debug_request, 
default_incremental=False) + + headers = { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Transfer-Encoding": "chunked", + } + return StreamingResponse( + service._wrapper_chat_stream_flow_str(stream_iter), + headers=headers, + media_type="text/event-stream", + ) + + +@router.get("/flow/export/{uid}", dependencies=[Depends(check_api_key)]) +async def export_flow( + uid: str, + export_type: Literal["json", "dbgpts"] = Query( + "json", description="export type(json or dbgpts)" + ), + format: Literal["file", "json"] = Query( + "file", description="response format(file or json)" + ), + file_name: Optional[str] = Query(default=None, description="file name to export"), + user_name: Optional[str] = Query(default=None, description="user name"), + sys_code: Optional[str] = Query(default=None, description="system code"), + service: Service = Depends(get_service), +): + """Export the flow to a file.""" + flow = service.get({"uid": uid, "user_name": user_name, "sys_code": sys_code}) + if not flow: + raise HTTPException(status_code=404, detail=f"Flow {uid} not found") + package_name = flow.name.replace("_", "-") + file_name = file_name or package_name + if export_type == "json": + flow_dict = {"flow": flow.to_dict()} + if format == "json": + return JSONResponse(content=flow_dict) + else: + # Return the json file + return StreamingResponse( + io.BytesIO(json.dumps(flow_dict, ensure_ascii=False).encode("utf-8")), + media_type="application/file", + headers={ + "Content-Disposition": f"attachment;filename={file_name}.json" + }, + ) + + elif export_type == "dbgpts": + from ..service.share_utils import _generate_dbgpts_zip + + if format == "json": + raise HTTPException( + status_code=400, detail="json response is not supported for dbgpts" + ) + + zip_buffer = await blocking_func_to_async( + global_system_app, _generate_dbgpts_zip, package_name, flow + ) + return StreamingResponse( + zip_buffer, + media_type="application/x-zip-compressed", + headers={"Content-Disposition": f"attachment;filename={file_name}.zip"}, + ) + + +@router.post( + "/flow/import", + response_model=Result[ServerResponse], + dependencies=[Depends(check_api_key)], +) +async def import_flow( + file: UploadFile = File(...), + save_flow: bool = Query( + False, description="Whether to save the flow after importing" + ), + service: Service = Depends(get_service), +): + """Import the flow from a file.""" + filename = file.filename + file_extension = filename.split(".")[-1].lower() + if file_extension == "json": + # Handle json file + json_content = await file.read() + json_dict = json.loads(json_content) + if "flow" not in json_dict: + raise HTTPException( + status_code=400, detail="invalid json file, missing 'flow' key" + ) + flow = ServeRequest.parse_obj(json_dict["flow"]) + elif file_extension == "zip": + from ..service.share_utils import _parse_flow_from_zip_file + + # Handle zip file + flow = await _parse_flow_from_zip_file(file, global_system_app) + else: + raise HTTPException( + status_code=400, detail=f"invalid file extension {file_extension}" + ) + if save_flow: + return Result.succ(service.create_and_save_dag(flow)) + else: + return Result.succ(flow) def init_endpoints(system_app: SystemApp) -> None: diff --git a/dbgpt/serve/flow/api/schemas.py b/dbgpt/serve/flow/api/schemas.py index 537996fe7..cf82de982 100644 --- a/dbgpt/serve/flow/api/schemas.py +++ b/dbgpt/serve/flow/api/schemas.py @@ -2,7 +2,7 @@ from dbgpt._private.pydantic import BaseModel, ConfigDict, Field from 
dbgpt.core.awel import CommonLLMHttpRequestBody -from dbgpt.core.awel.flow.flow_factory import FlowPanel +from dbgpt.core.awel.flow.flow_factory import FlowPanel, VariablesRequest from dbgpt.core.awel.util.parameter_util import RefreshOptionRequest from ..config import SERVE_APP_NAME_HUMP @@ -18,59 +18,6 @@ class ServerResponse(FlowPanel): model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}") -class VariablesRequest(BaseModel): - """Variable request model. - - For creating a new variable in the DB-GPT. - """ - - key: str = Field( - ..., - description="The key of the variable to create", - examples=["dbgpt.model.openai.api_key"], - ) - name: str = Field( - ..., - description="The name of the variable to create", - examples=["my_first_openai_key"], - ) - label: str = Field( - ..., - description="The label of the variable to create", - examples=["My First OpenAI Key"], - ) - value: Any = Field( - ..., description="The value of the variable to create", examples=["1234567890"] - ) - value_type: Literal["str", "int", "float", "bool"] = Field( - "str", - description="The type of the value of the variable to create", - examples=["str", "int", "float", "bool"], - ) - category: Literal["common", "secret"] = Field( - ..., - description="The category of the variable to create", - examples=["common"], - ) - scope: str = Field( - ..., - description="The scope of the variable to create", - examples=["global"], - ) - scope_key: Optional[str] = Field( - ..., - description="The scope key of the variable to create", - examples=["dbgpt"], - ) - enabled: Optional[bool] = Field( - True, - description="Whether the variable is enabled", - examples=[True], - ) - user_name: Optional[str] = Field(None, description="User name") - sys_code: Optional[str] = Field(None, description="System code") - - class VariablesResponse(VariablesRequest): """Variable response model.""" diff --git a/dbgpt/serve/flow/service/service.py b/dbgpt/serve/flow/service/service.py index 83b79847f..3cdb136eb 100644 --- a/dbgpt/serve/flow/service/service.py +++ b/dbgpt/serve/flow/service/service.py @@ -8,7 +8,6 @@ from dbgpt._private.pydantic import model_to_json from dbgpt.component import SystemApp from dbgpt.core.awel import DAG, BaseOperator, CommonLLMHttpRequestBody -from dbgpt.core.awel.dag.dag_manager import DAGManager from dbgpt.core.awel.flow.flow_factory import ( FlowCategory, FlowFactory, @@ -33,7 +32,7 @@ from dbgpt.util.dbgpts.loader import DBGPTsLoader from dbgpt.util.pagination_utils import PaginationResult -from ..api.schemas import ServeRequest, ServerResponse +from ..api.schemas import FlowDebugRequest, ServeRequest, ServerResponse from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig from ..models.models import ServeDao, ServeEntity @@ -146,7 +145,9 @@ def create_and_save_dag( raise ValueError( f"Create DAG {request.name} error, define_type: {request.define_type}, error: {str(e)}" ) from e - res = self.dao.create(request) + self.dao.create(request) + # Query from database + res = self.get({"uid": request.uid}) state = request.state try: @@ -563,3 +564,61 @@ def _parse_flow_category(self, dag: DAG) -> FlowCategory: return FlowCategory.CHAT_FLOW except Exception: return FlowCategory.COMMON + + async def debug_flow( + self, request: FlowDebugRequest, default_incremental: Optional[bool] = None + ) -> AsyncIterator[ModelOutput]: + """Debug the flow. 
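+
+        Build the flow into a DAG, then stream the output of its single leaf
+        task as ``ModelOutput`` chunks; errors raised while streaming are
+        yielded as outputs with a non-zero error code.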
+ + Args: + request (FlowDebugRequest): The request + default_incremental (Optional[bool]): The default incremental configuration + + Returns: + AsyncIterator[ModelOutput]: The output + """ + from dbgpt.core.awel.dag.dag_manager import DAGMetadata, _parse_metadata + + dag = self._flow_factory.build(request.flow) + leaf_nodes = dag.leaf_nodes + if len(leaf_nodes) != 1: + raise ValueError("Chat Flow just support one leaf node in dag") + task = cast(BaseOperator, leaf_nodes[0]) + dag_metadata = _parse_metadata(dag) + # TODO: Run task with variables + variables = request.variables + dag_request = request.request + + if isinstance(request.request, CommonLLMHttpRequestBody): + incremental = request.request.incremental + elif isinstance(request.request, dict): + incremental = request.request.get("incremental", False) + else: + raise ValueError("Invalid request type") + + if default_incremental is not None: + incremental = default_incremental + + try: + async for output in safe_chat_stream_with_dag_task( + task, dag_request, incremental + ): + yield output + except HTTPException as e: + yield ModelOutput(error_code=1, text=e.detail, incremental=incremental) + except Exception as e: + yield ModelOutput(error_code=1, text=str(e), incremental=incremental) + + async def _wrapper_chat_stream_flow_str( + self, stream_iter: AsyncIterator[ModelOutput] + ) -> AsyncIterator[str]: + + async for output in stream_iter: + text = output.text + if text: + text = text.replace("\n", "\\n") + if output.error_code != 0: + yield f"data:[SERVER_ERROR]{text}\n\n" + break + else: + yield f"data:{text}\n\n" diff --git a/dbgpt/serve/flow/service/share_utils.py b/dbgpt/serve/flow/service/share_utils.py new file mode 100644 index 000000000..99ba222a9 --- /dev/null +++ b/dbgpt/serve/flow/service/share_utils.py @@ -0,0 +1,121 @@ +import io +import json +import os +import tempfile +import zipfile + +import aiofiles +import tomlkit +from fastapi import UploadFile + +from dbgpt.component import SystemApp +from dbgpt.serve.core import blocking_func_to_async + +from ..api.schemas import ServeRequest + + +def _generate_dbgpts_zip(package_name: str, flow: ServeRequest) -> io.BytesIO: + + zip_buffer = io.BytesIO() + flow_name = flow.name + flow_label = flow.label + flow_description = flow.description + dag_json = json.dumps(flow.flow_data.dict(), indent=4, ensure_ascii=False) + with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED, False) as zip_file: + manifest = f"include dbgpts.toml\ninclude {flow_name}/definition/*.json" + readme = f"# {flow_label}\n\n{flow_description}" + zip_file.writestr(f"{package_name}/MANIFEST.in", manifest) + zip_file.writestr(f"{package_name}/README.md", readme) + zip_file.writestr( + f"{package_name}/{flow_name}/__init__.py", + "", + ) + zip_file.writestr( + f"{package_name}/{flow_name}/definition/flow_definition.json", + dag_json, + ) + dbgpts_toml = tomlkit.document() + # Add flow information + dbgpts_flow_toml = tomlkit.document() + dbgpts_flow_toml.add("label", "Simple Streaming Chat") + name_with_comment = tomlkit.string("awel_flow_simple_streaming_chat") + name_with_comment.comment("A unique name for all dbgpts") + dbgpts_flow_toml.add("name", name_with_comment) + + dbgpts_flow_toml.add("version", "0.1.0") + dbgpts_flow_toml.add( + "description", + flow_description, + ) + dbgpts_flow_toml.add("authors", []) + + definition_type_with_comment = tomlkit.string("json") + definition_type_with_comment.comment("How to define the flow, python or json") + dbgpts_flow_toml.add("definition_type", 
definition_type_with_comment) + + dbgpts_toml.add("flow", dbgpts_flow_toml) + + # Add python and json config + python_config = tomlkit.table() + dbgpts_toml.add("python_config", python_config) + + json_config = tomlkit.table() + json_config.add("file_path", "definition/flow_definition.json") + json_config.comment("Json config") + + dbgpts_toml.add("json_config", json_config) + + # Transform to string + toml_string = tomlkit.dumps(dbgpts_toml) + zip_file.writestr(f"{package_name}/dbgpts.toml", toml_string) + + pyproject_toml = tomlkit.document() + + # Add [tool.poetry] section + tool_poetry_toml = tomlkit.table() + tool_poetry_toml.add("name", package_name) + tool_poetry_toml.add("version", "0.1.0") + tool_poetry_toml.add("description", "A dbgpts package") + tool_poetry_toml.add("authors", []) + tool_poetry_toml.add("readme", "README.md") + pyproject_toml["tool"] = tomlkit.table() + pyproject_toml["tool"]["poetry"] = tool_poetry_toml + + # Add [tool.poetry.dependencies] section + dependencies = tomlkit.table() + dependencies.add("python", "^3.10") + pyproject_toml["tool"]["poetry"]["dependencies"] = dependencies + + # Add [build-system] section + build_system = tomlkit.table() + build_system.add("requires", ["poetry-core"]) + build_system.add("build-backend", "poetry.core.masonry.api") + pyproject_toml["build-system"] = build_system + + # Transform to string + pyproject_toml_string = tomlkit.dumps(pyproject_toml) + zip_file.writestr(f"{package_name}/pyproject.toml", pyproject_toml_string) + zip_buffer.seek(0) + return zip_buffer + + +async def _parse_flow_from_zip_file( + file: UploadFile, sys_app: SystemApp +) -> ServeRequest: + from dbgpt.util.dbgpts.loader import _load_flow_package_from_zip_path + + filename = file.filename + if not filename.endswith(".zip"): + raise ValueError("Uploaded file must be a ZIP file") + + with tempfile.TemporaryDirectory() as temp_dir: + zip_path = os.path.join(temp_dir, filename) + + # Save uploaded file to temporary directory + async with aiofiles.open(zip_path, "wb") as out_file: + while content := await file.read(1024 * 64): # Read in chunks of 64KB + await out_file.write(content) + flow = await blocking_func_to_async( + sys_app, _load_flow_package_from_zip_path, zip_path + ) + return flow diff --git a/dbgpt/util/dbgpts/loader.py b/dbgpt/util/dbgpts/loader.py index 8545ad067..4151546e9 100644 --- a/dbgpt/util/dbgpts/loader.py +++ b/dbgpt/util/dbgpts/loader.py @@ -320,14 +320,19 @@ def _load_package_from_path(path: str): return parsed_packages -def _load_flow_package_from_path(name: str, path: str = INSTALL_DIR) -> FlowPackage: +def _load_flow_package_from_path( + name: str, path: str = INSTALL_DIR, filter_by_name: bool = True +) -> FlowPackage: raw_packages = _load_installed_package(path) new_name = name.replace("_", "-") - packages = [p for p in raw_packages if p.package == name or p.name == name] - if not packages: - packages = [ - p for p in raw_packages if p.package == new_name or p.name == new_name - ] + if filter_by_name: + packages = [p for p in raw_packages if p.package == name or p.name == name] + if not packages: + packages = [ + p for p in raw_packages if p.package == new_name or p.name == new_name + ] + else: + packages = raw_packages if not packages: raise ValueError(f"Can't find the package {name} or {new_name}") flow_package = _parse_package_metadata(packages[0]) @@ -336,6 +341,35 @@ def _load_flow_package_from_path(name: str, path: str = INSTALL_DIR) -> FlowPack return cast(FlowPackage, flow_package) +def 
_load_flow_package_from_zip_path(zip_path: str) -> FlowPanel: + import tempfile + import zipfile + + with tempfile.TemporaryDirectory() as temp_dir: + with zipfile.ZipFile(zip_path, "r") as zip_ref: + zip_ref.extractall(temp_dir) + package_names = os.listdir(temp_dir) + if not package_names: + raise ValueError("No package found in the zip file") + if len(package_names) > 1: + raise ValueError("Only support one package in the zip file") + package_name = package_names[0] + with open( + Path(temp_dir) / package_name / INSTALL_METADATA_FILE, mode="w+" + ) as f: + # Write the metadata + import tomlkit + + install_metadata = { + "name": package_name, + "repo": "local/dbgpts", + } + tomlkit.dump(install_metadata, f) + + package = _load_flow_package_from_path("", path=temp_dir, filter_by_name=False) + return _flow_package_to_flow_panel(package) + + def _flow_package_to_flow_panel(package: FlowPackage) -> FlowPanel: dict_value = { "name": package.name, @@ -345,6 +379,7 @@ def _flow_package_to_flow_panel(package: FlowPackage) -> FlowPanel: "description": package.description, "source": package.repo, "define_type": "json", + "authors": package.authors, } if isinstance(package, FlowJsonPackage): dict_value["flow_data"] = package.read_definition_json() From b6d54ed8ab3d69a7ce4f70894e30acdc8c9c5a1d Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Mon, 19 Aug 2024 00:19:53 +0800 Subject: [PATCH 40/89] feat(core): Add file server for DB-GPT --- .env.template | 5 + .mypy.ini | 3 + dbgpt/_private/config.py | 7 + dbgpt/app/component_configs.py | 2 +- .../initialization/db_model_initialization.py | 2 + .../initialization/serve_initialization.py | 36 +- dbgpt/component.py | 1 + dbgpt/configs/model_config.py | 1 + dbgpt/core/interface/file.py | 791 ++++++++++++++++++ dbgpt/core/interface/tests/test_file.py | 506 +++++++++++ dbgpt/serve/file/__init__.py | 2 + dbgpt/serve/file/api/__init__.py | 2 + dbgpt/serve/file/api/endpoints.py | 159 ++++ dbgpt/serve/file/api/schemas.py | 43 + dbgpt/serve/file/config.py | 68 ++ dbgpt/serve/file/dependencies.py | 1 + dbgpt/serve/file/models/__init__.py | 2 + dbgpt/serve/file/models/file_adapter.py | 66 ++ dbgpt/serve/file/models/models.py | 87 ++ dbgpt/serve/file/serve.py | 113 +++ dbgpt/serve/file/service/__init__.py | 0 dbgpt/serve/file/service/service.py | 106 +++ dbgpt/serve/file/tests/__init__.py | 0 dbgpt/serve/file/tests/test_endpoints.py | 124 +++ dbgpt/serve/file/tests/test_models.py | 99 +++ dbgpt/serve/file/tests/test_service.py | 78 ++ 26 files changed, 2301 insertions(+), 3 deletions(-) create mode 100644 dbgpt/core/interface/file.py create mode 100644 dbgpt/core/interface/tests/test_file.py create mode 100644 dbgpt/serve/file/__init__.py create mode 100644 dbgpt/serve/file/api/__init__.py create mode 100644 dbgpt/serve/file/api/endpoints.py create mode 100644 dbgpt/serve/file/api/schemas.py create mode 100644 dbgpt/serve/file/config.py create mode 100644 dbgpt/serve/file/dependencies.py create mode 100644 dbgpt/serve/file/models/__init__.py create mode 100644 dbgpt/serve/file/models/file_adapter.py create mode 100644 dbgpt/serve/file/models/models.py create mode 100644 dbgpt/serve/file/serve.py create mode 100644 dbgpt/serve/file/service/__init__.py create mode 100644 dbgpt/serve/file/service/service.py create mode 100644 dbgpt/serve/file/tests/__init__.py create mode 100644 dbgpt/serve/file/tests/test_endpoints.py create mode 100644 dbgpt/serve/file/tests/test_models.py create mode 100644 dbgpt/serve/file/tests/test_service.py diff --git a/.env.template 
b/.env.template index 44aa2d710..2a281e698 100644 --- a/.env.template +++ b/.env.template @@ -277,6 +277,11 @@ DBGPT_LOG_LEVEL=INFO # ENCRYPT KEY - The key used to encrypt and decrypt the data # ENCRYPT_KEY=your_secret_key +#*******************************************************************# +#** File Server **# +#*******************************************************************# +## The local storage path of the file server, the default is pilot/data/file_server +# FILE_SERVER_LOCAL_STORAGE_PATH = #*******************************************************************# #** Application Config **# diff --git a/.mypy.ini b/.mypy.ini index e2c2bc3ab..52ae00c35 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -115,3 +115,6 @@ ignore_missing_imports = True [mypy-networkx.*] ignore_missing_imports = True + +[mypy-pypdf.*] +ignore_missing_imports = True diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index 2dbfac0f0..18e972a4c 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -316,6 +316,13 @@ def __init__(self) -> None: # experimental financial report model configuration self.FIN_REPORT_MODEL = os.getenv("FIN_REPORT_MODEL", None) + # file server configuration + # The host of the current file server, if None, get the host automatically + self.FILE_SERVER_HOST = os.getenv("FILE_SERVER_HOST") + self.FILE_SERVER_LOCAL_STORAGE_PATH = os.getenv( + "FILE_SERVER_LOCAL_STORAGE_PATH" + ) + @property def local_db_manager(self) -> "ConnectorManager": from dbgpt.datasource.manages import ConnectorManager diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py index 3ef08d4bc..29c9e59be 100644 --- a/dbgpt/app/component_configs.py +++ b/dbgpt/app/component_configs.py @@ -59,7 +59,7 @@ def initialize_components( _initialize_agent(system_app) _initialize_openapi(system_app) # Register serve apps - register_serve_apps(system_app, CFG) + register_serve_apps(system_app, CFG, param.port) def _initialize_model_cache(system_app: SystemApp): diff --git a/dbgpt/app/initialization/db_model_initialization.py b/dbgpt/app/initialization/db_model_initialization.py index b8808c400..969340c44 100644 --- a/dbgpt/app/initialization/db_model_initialization.py +++ b/dbgpt/app/initialization/db_model_initialization.py @@ -8,6 +8,7 @@ from dbgpt.model.cluster.registry_impl.db_storage import ModelInstanceEntity from dbgpt.serve.agent.db.my_plugin_db import MyPluginEntity from dbgpt.serve.agent.db.plugin_hub_db import PluginHubEntity +from dbgpt.serve.file.models.models import ServeEntity as FileServeEntity from dbgpt.serve.flow.models.models import ServeEntity as FlowServeEntity from dbgpt.serve.flow.models.models import VariablesEntity as FlowVariableEntity from dbgpt.serve.prompt.models.models import ServeEntity as PromptManageEntity @@ -19,6 +20,7 @@ _MODELS = [ PluginHubEntity, + FileServeEntity, MyPluginEntity, PromptManageEntity, KnowledgeSpaceEntity, diff --git a/dbgpt/app/initialization/serve_initialization.py b/dbgpt/app/initialization/serve_initialization.py index f0b9c9e42..7838644e0 100644 --- a/dbgpt/app/initialization/serve_initialization.py +++ b/dbgpt/app/initialization/serve_initialization.py @@ -2,7 +2,7 @@ from dbgpt.component import SystemApp -def register_serve_apps(system_app: SystemApp, cfg: Config): +def register_serve_apps(system_app: SystemApp, cfg: Config, webserver_port: int): """Register serve apps""" system_app.config.set("dbgpt.app.global.language", cfg.LANGUAGE) if cfg.API_KEYS: @@ -47,6 +47,8 @@ def register_serve_apps(system_app: SystemApp, cfg: 
Config): # Register serve app system_app.register(FlowServe) + # ################################ AWEL Flow Serve Register End ######################################## + # ################################ Rag Serve Register Begin ###################################### from dbgpt.serve.rag.serve import ( @@ -57,6 +59,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): # Register serve app system_app.register(RagServe) + # ################################ Rag Serve Register End ######################################## + # ################################ Datasource Serve Register Begin ###################################### from dbgpt.serve.datasource.serve import ( @@ -66,4 +70,32 @@ def register_serve_apps(system_app: SystemApp, cfg: Config): # Register serve app system_app.register(DatasourceServe) - # ################################ AWEL Flow Serve Register End ######################################## + + # ################################ Datasource Serve Register End ######################################## + + # ################################ File Serve Register Begin ###################################### + + from dbgpt.configs.model_config import FILE_SERVER_LOCAL_STORAGE_PATH + from dbgpt.serve.file.serve import ( + SERVE_CONFIG_KEY_PREFIX as FILE_SERVE_CONFIG_KEY_PREFIX, + ) + from dbgpt.serve.file.serve import Serve as FileServe + + local_storage_path = ( + cfg.FILE_SERVER_LOCAL_STORAGE_PATH or FILE_SERVER_LOCAL_STORAGE_PATH + ) + # Set config + system_app.config.set( + f"{FILE_SERVE_CONFIG_KEY_PREFIX}local_storage_path", local_storage_path + ) + system_app.config.set( + f"{FILE_SERVE_CONFIG_KEY_PREFIX}file_server_port", webserver_port + ) + if cfg.FILE_SERVER_HOST: + system_app.config.set( + f"{FILE_SERVE_CONFIG_KEY_PREFIX}file_server_host", cfg.FILE_SERVER_HOST + ) + # Register serve app + system_app.register(FileServe) + + # ################################ File Serve Register End ######################################## diff --git a/dbgpt/component.py b/dbgpt/component.py index cb88a61ec..da3c5e753 100644 --- a/dbgpt/component.py +++ b/dbgpt/component.py @@ -90,6 +90,7 @@ class ComponentType(str, Enum): AGENT_MANAGER = "dbgpt_agent_manager" RESOURCE_MANAGER = "dbgpt_resource_manager" VARIABLES_PROVIDER = "dbgpt_variables_provider" + FILE_STORAGE_CLIENT = "dbgpt_file_storage_client" _EMPTY_DEFAULT_COMPONENT = "_EMPTY_DEFAULT_COMPONENT" diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py index 4d02a2730..e4abac3e7 100644 --- a/dbgpt/configs/model_config.py +++ b/dbgpt/configs/model_config.py @@ -14,6 +14,7 @@ DATA_DIR = os.path.join(PILOT_PATH, "data") PLUGINS_DIR = os.path.join(ROOT_PATH, "plugins") MODEL_DISK_CACHE_DIR = os.path.join(DATA_DIR, "model_cache") +FILE_SERVER_LOCAL_STORAGE_PATH = os.path.join(DATA_DIR, "file_server") _DAG_DEFINITION_DIR = os.path.join(ROOT_PATH, "examples/awel") # Global language setting LOCALES_DIR = os.path.join(ROOT_PATH, "i18n/locales") diff --git a/dbgpt/core/interface/file.py b/dbgpt/core/interface/file.py new file mode 100644 index 000000000..5bd6cf842 --- /dev/null +++ b/dbgpt/core/interface/file.py @@ -0,0 +1,791 @@ +"""File storage interface.""" + +import dataclasses +import hashlib +import io +import os +import uuid +from abc import ABC, abstractmethod +from io import BytesIO +from typing import Any, BinaryIO, Dict, List, Optional, Tuple +from urllib.parse import parse_qs, urlencode, urlparse + +import requests + +from dbgpt.component import BaseComponent, ComponentType, SystemApp +from 
dbgpt.util.tracer import root_tracer, trace + +from .storage import ( + InMemoryStorage, + QuerySpec, + ResourceIdentifier, + StorageError, + StorageInterface, + StorageItem, +) + +_SCHEMA = "dbgpt-fs" + + +@dataclasses.dataclass +class FileMetadataIdentifier(ResourceIdentifier): + """File metadata identifier.""" + + file_id: str + bucket: str + + def to_dict(self) -> Dict: + """Convert the identifier to a dictionary.""" + return {"file_id": self.file_id, "bucket": self.bucket} + + @property + def str_identifier(self) -> str: + """Get the string identifier. + + Returns: + str: The string identifier + """ + return f"{self.bucket}/{self.file_id}" + + +@dataclasses.dataclass +class FileMetadata(StorageItem): + """File metadata for storage.""" + + file_id: str + bucket: str + file_name: str + file_size: int + storage_type: str + storage_path: str + uri: str + custom_metadata: Dict[str, Any] + file_hash: str + _identifier: FileMetadataIdentifier = dataclasses.field(init=False) + + def __post_init__(self): + """Post init method.""" + self._identifier = FileMetadataIdentifier( + file_id=self.file_id, bucket=self.bucket + ) + + @property + def identifier(self) -> ResourceIdentifier: + """Get the resource identifier.""" + return self._identifier + + def merge(self, other: "StorageItem") -> None: + """Merge the metadata with another item.""" + if not isinstance(other, FileMetadata): + raise StorageError("Cannot merge different types of items") + self._from_object(other) + + def to_dict(self) -> Dict: + """Convert the metadata to a dictionary.""" + return { + "file_id": self.file_id, + "bucket": self.bucket, + "file_name": self.file_name, + "file_size": self.file_size, + "storage_type": self.storage_type, + "storage_path": self.storage_path, + "uri": self.uri, + "custom_metadata": self.custom_metadata, + "file_hash": self.file_hash, + } + + def _from_object(self, obj: "FileMetadata") -> None: + self.file_id = obj.file_id + self.bucket = obj.bucket + self.file_name = obj.file_name + self.file_size = obj.file_size + self.storage_type = obj.storage_type + self.storage_path = obj.storage_path + self.uri = obj.uri + self.custom_metadata = obj.custom_metadata + self.file_hash = obj.file_hash + self._identifier = obj._identifier + + +class FileStorageURI: + """File storage URI.""" + + def __init__( + self, + storage_type: str, + bucket: str, + file_id: str, + version: Optional[str] = None, + custom_params: Optional[Dict[str, Any]] = None, + ): + """Initialize the file storage URI.""" + self.scheme = _SCHEMA + self.storage_type = storage_type + self.bucket = bucket + self.file_id = file_id + self.version = version + self.custom_params = custom_params or {} + + @classmethod + def parse(cls, uri: str) -> "FileStorageURI": + """Parse the URI string.""" + parsed = urlparse(uri) + if parsed.scheme != _SCHEMA: + raise ValueError(f"Invalid URI scheme. Must be '{_SCHEMA}'") + path_parts = parsed.path.strip("/").split("/") + if len(path_parts) < 2: + raise ValueError("Invalid URI path. 
Must contain bucket and file ID") + storage_type = parsed.netloc + bucket = path_parts[0] + file_id = path_parts[1] + version = path_parts[2] if len(path_parts) > 2 else None + custom_params = parse_qs(parsed.query) + return cls(storage_type, bucket, file_id, version, custom_params) + + def __str__(self) -> str: + """Get the string representation of the URI.""" + base_uri = f"{self.scheme}://{self.storage_type}/{self.bucket}/{self.file_id}" + if self.version: + base_uri += f"/{self.version}" + if self.custom_params: + query_string = urlencode(self.custom_params, doseq=True) + base_uri += f"?{query_string}" + return base_uri + + +class StorageBackend(ABC): + """Storage backend interface.""" + + storage_type: str = "__base__" + + @abstractmethod + def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str: + """Save the file data to the storage backend. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + file_data (BinaryIO): The file data + + Returns: + str: The storage path + """ + + @abstractmethod + def load(self, fm: FileMetadata) -> BinaryIO: + """Load the file data from the storage backend. + + Args: + fm (FileMetadata): The file metadata + + Returns: + BinaryIO: The file data + """ + + @abstractmethod + def delete(self, fm: FileMetadata) -> bool: + """Delete the file data from the storage backend. + + Args: + fm (FileMetadata): The file metadata + + Returns: + bool: True if the file was deleted, False otherwise + """ + + @property + @abstractmethod + def save_chunk_size(self) -> int: + """Get the save chunk size. + + Returns: + int: The save chunk size + """ + + +class LocalFileStorage(StorageBackend): + """Local file storage backend.""" + + storage_type: str = "local" + + def __init__(self, base_path: str, save_chunk_size: int = 1024 * 1024): + """Initialize the local file storage backend.""" + self.base_path = base_path + self._save_chunk_size = save_chunk_size + os.makedirs(self.base_path, exist_ok=True) + + @property + def save_chunk_size(self) -> int: + """Get the save chunk size.""" + return self._save_chunk_size + + def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str: + """Save the file data to the local storage backend.""" + bucket_path = os.path.join(self.base_path, bucket) + os.makedirs(bucket_path, exist_ok=True) + file_path = os.path.join(bucket_path, file_id) + with open(file_path, "wb") as f: + while True: + chunk = file_data.read(self.save_chunk_size) + if not chunk: + break + f.write(chunk) + return file_path + + def load(self, fm: FileMetadata) -> BinaryIO: + """Load the file data from the local storage backend.""" + bucket_path = os.path.join(self.base_path, fm.bucket) + file_path = os.path.join(bucket_path, fm.file_id) + return open(file_path, "rb") # noqa: SIM115 + + def delete(self, fm: FileMetadata) -> bool: + """Delete the file data from the local storage backend.""" + bucket_path = os.path.join(self.base_path, fm.bucket) + file_path = os.path.join(bucket_path, fm.file_id) + if os.path.exists(file_path): + os.remove(file_path) + return True + return False + + +class FileStorageSystem: + """File storage system.""" + + def __init__( + self, + storage_backends: Dict[str, StorageBackend], + metadata_storage: Optional[StorageInterface[FileMetadata, Any]] = None, + check_hash: bool = True, + ): + """Initialize the file storage system.""" + metadata_storage = metadata_storage or InMemoryStorage() + self.storage_backends = storage_backends + self.metadata_storage = metadata_storage + self.check_hash = check_hash + 
self._save_chunk_size = min( + backend.save_chunk_size for backend in storage_backends.values() + ) + + def _calculate_file_hash(self, file_data: BinaryIO) -> str: + """Calculate the MD5 hash of the file data.""" + if not self.check_hash: + return "-1" + hasher = hashlib.md5() + file_data.seek(0) + while chunk := file_data.read(self._save_chunk_size): + hasher.update(chunk) + file_data.seek(0) + return hasher.hexdigest() + + @trace("file_storage_system.save_file") + def save_file( + self, + bucket: str, + file_name: str, + file_data: BinaryIO, + storage_type: str, + custom_metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """Save the file data to the storage backend.""" + file_id = str(uuid.uuid4()) + backend = self.storage_backends.get(storage_type) + if not backend: + raise ValueError(f"Unsupported storage type: {storage_type}") + + with root_tracer.start_span( + "file_storage_system.save_file.backend_save", + metadata={ + "bucket": bucket, + "file_id": file_id, + "file_name": file_name, + "storage_type": storage_type, + }, + ): + storage_path = backend.save(bucket, file_id, file_data) + file_data.seek(0, 2) # Move to the end of the file + file_size = file_data.tell() # Get the file size + file_data.seek(0) # Reset file pointer + + with root_tracer.start_span( + "file_storage_system.save_file.calculate_hash", + ): + file_hash = self._calculate_file_hash(file_data) + uri = FileStorageURI( + storage_type, bucket, file_id, custom_params=custom_metadata + ) + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name=file_name, + file_size=file_size, + storage_type=storage_type, + storage_path=storage_path, + uri=str(uri), + custom_metadata=custom_metadata or {}, + file_hash=file_hash, + ) + + self.metadata_storage.save(metadata) + return str(uri) + + @trace("file_storage_system.get_file") + def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]: + """Get the file data from the storage backend.""" + parsed_uri = FileStorageURI.parse(uri) + metadata = self.metadata_storage.load( + FileMetadataIdentifier( + file_id=parsed_uri.file_id, bucket=parsed_uri.bucket + ), + FileMetadata, + ) + if not metadata: + raise FileNotFoundError(f"No metadata found for URI: {uri}") + + backend = self.storage_backends.get(metadata.storage_type) + if not backend: + raise ValueError(f"Unsupported storage type: {metadata.storage_type}") + + with root_tracer.start_span( + "file_storage_system.get_file.backend_load", + metadata={ + "bucket": metadata.bucket, + "file_id": metadata.file_id, + "file_name": metadata.file_name, + "storage_type": metadata.storage_type, + }, + ): + file_data = backend.load(metadata) + + with root_tracer.start_span( + "file_storage_system.get_file.verify_hash", + ): + calculated_hash = self._calculate_file_hash(file_data) + if calculated_hash != "-1" and calculated_hash != metadata.file_hash: + raise ValueError("File integrity check failed. Hash mismatch.") + + return file_data, metadata + + def get_file_metadata(self, bucket: str, file_id: str) -> Optional[FileMetadata]: + """Get the file metadata. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + + Returns: + Optional[FileMetadata]: The file metadata + """ + fid = FileMetadataIdentifier(file_id=file_id, bucket=bucket) + return self.metadata_storage.load(fid, FileMetadata) + + def delete_file(self, uri: str) -> bool: + """Delete the file data from the storage backend. 
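+
+        Both the stored file and its metadata are removed. If the metadata
+        cannot be deleted after the backend delete succeeds, ``False`` is
+        returned.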
+ + Args: + uri (str): The file URI + + Returns: + bool: True if the file was deleted, False otherwise + """ + parsed_uri = FileStorageURI.parse(uri) + fid = FileMetadataIdentifier( + file_id=parsed_uri.file_id, bucket=parsed_uri.bucket + ) + metadata = self.metadata_storage.load(fid, FileMetadata) + if not metadata: + return False + + backend = self.storage_backends.get(metadata.storage_type) + if not backend: + raise ValueError(f"Unsupported storage type: {metadata.storage_type}") + + if backend.delete(metadata): + try: + self.metadata_storage.delete(fid) + return True + except Exception: + # If the metadata deletion fails, log the error and return False + return False + return False + + def list_files( + self, bucket: str, filters: Optional[Dict[str, Any]] = None + ) -> List[FileMetadata]: + """List the files in the bucket.""" + filters = filters or {} + filters["bucket"] = bucket + return self.metadata_storage.query(QuerySpec(conditions=filters), FileMetadata) + + +class FileStorageClient(BaseComponent): + """File storage client component.""" + + name = ComponentType.FILE_STORAGE_CLIENT.value + + def __init__( + self, + system_app: Optional[SystemApp] = None, + storage_system: Optional[FileStorageSystem] = None, + ): + """Initialize the file storage client.""" + super().__init__(system_app=system_app) + if not storage_system: + from pathlib import Path + + base_path = Path.home() / ".cache" / "dbgpt" / "files" + storage_system = FileStorageSystem( + { + LocalFileStorage.storage_type: LocalFileStorage( + base_path=str(base_path) + ) + } + ) + + self.system_app = system_app + self._storage_system = storage_system + + def init_app(self, system_app: SystemApp): + """Initialize the application.""" + self.system_app = system_app + + @property + def storage_system(self) -> FileStorageSystem: + """Get the file storage system.""" + if not self._storage_system: + raise ValueError("File storage system not initialized") + return self._storage_system + + def upload_file( + self, + bucket: str, + file_path: str, + storage_type: str, + custom_metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """Upload a file to the storage system. + + Args: + bucket (str): The bucket name + file_path (str): The file path + storage_type (str): The storage type + custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to + None. + + Returns: + str: The file URI + """ + with open(file_path, "rb") as file: + return self.save_file( + bucket, os.path.basename(file_path), file, storage_type, custom_metadata + ) + + def save_file( + self, + bucket: str, + file_name: str, + file_data: BinaryIO, + storage_type: str, + custom_metadata: Optional[Dict[str, Any]] = None, + ) -> str: + """Save the file data to the storage system. + + Args: + bucket (str): The bucket name + file_name (str): The file name + file_data (BinaryIO): The file data + storage_type (str): The storage type + custom_metadata (Dict[str, Any], optional): Custom metadata. Defaults to + None. + + Returns: + str: The file URI + """ + return self.storage_system.save_file( + bucket, file_name, file_data, storage_type, custom_metadata + ) + + def download_file(self, uri: str, destination_path: str) -> None: + """Download a file from the storage system. 
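+
+        The file content is read through the underlying storage system and
+        written to ``destination_path``.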
+ + Args: + uri (str): The file URI + destination_path (str): The destination + + Raises: + FileNotFoundError: If the file is not found + """ + file_data, _ = self.storage_system.get_file(uri) + with open(destination_path, "wb") as f: + f.write(file_data.read()) + + def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]: + """Get the file data from the storage system. + + Args: + uri (str): The file URI + + Returns: + Tuple[BinaryIO, FileMetadata]: The file data and metadata + """ + return self.storage_system.get_file(uri) + + def get_file_by_id( + self, bucket: str, file_id: str + ) -> Tuple[BinaryIO, FileMetadata]: + """Get the file data from the storage system by ID. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + + Returns: + Tuple[BinaryIO, FileMetadata]: The file data and metadata + """ + metadata = self.storage_system.get_file_metadata(bucket, file_id) + if not metadata: + raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}") + return self.get_file(metadata.uri) + + def delete_file(self, uri: str) -> bool: + """Delete the file data from the storage system. + + Args: + uri (str): The file URI + + Returns: + bool: True if the file was deleted, False otherwise + """ + return self.storage_system.delete_file(uri) + + def delete_file_by_id(self, bucket: str, file_id: str) -> bool: + """Delete the file data from the storage system by ID. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + + Returns: + bool: True if the file was deleted, False otherwise + """ + metadata = self.storage_system.get_file_metadata(bucket, file_id) + if not metadata: + raise FileNotFoundError(f"File {file_id} not found in bucket {bucket}") + return self.delete_file(metadata.uri) + + def list_files( + self, bucket: str, filters: Optional[Dict[str, Any]] = None + ) -> List[FileMetadata]: + """List the files in the bucket. + + Args: + bucket (str): The bucket name + filters (Dict[str, Any], optional): Filters. Defaults to None. 
+ + Returns: + List[FileMetadata]: The list of file metadata + """ + return self.storage_system.list_files(bucket, filters) + + +class SimpleDistributedStorage(StorageBackend): + """Simple distributed storage backend.""" + + storage_type: str = "distributed" + + def __init__( + self, + node_address: str, + local_storage_path: str, + save_chunk_size: int = 1024 * 1024, + transfer_chunk_size: int = 1024 * 1024, + transfer_timeout: int = 360, + api_prefix: str = "/api/v2/serve/file/files", + ): + """Initialize the simple distributed storage backend.""" + self.node_address = node_address + self.local_storage_path = local_storage_path + os.makedirs(self.local_storage_path, exist_ok=True) + self._save_chunk_size = save_chunk_size + self._transfer_chunk_size = transfer_chunk_size + self._transfer_timeout = transfer_timeout + self._api_prefix = api_prefix + + @property + def save_chunk_size(self) -> int: + """Get the save chunk size.""" + return self._save_chunk_size + + def _get_file_path(self, bucket: str, file_id: str, node_address: str) -> str: + node_id = hashlib.md5(node_address.encode()).hexdigest() + return os.path.join(self.local_storage_path, bucket, f"{file_id}_{node_id}") + + def _parse_node_address(self, fm: FileMetadata) -> str: + storage_path = fm.storage_path + if not storage_path.startswith("distributed://"): + raise ValueError("Invalid storage path") + return storage_path.split("//")[1].split("/")[0] + + def save(self, bucket: str, file_id: str, file_data: BinaryIO) -> str: + """Save the file data to the distributed storage backend. + + Just save the file locally. + + Args: + bucket (str): The bucket name + file_id (str): The file ID + file_data (BinaryIO): The file data + + Returns: + str: The storage path + """ + file_path = self._get_file_path(bucket, file_id, self.node_address) + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "wb") as f: + while True: + chunk = file_data.read(self.save_chunk_size) + if not chunk: + break + f.write(chunk) + + return f"distributed://{self.node_address}/{bucket}/{file_id}" + + def load(self, fm: FileMetadata) -> BinaryIO: + """Load the file data from the distributed storage backend. + + If the file is stored on the local node, load it from the local storage. + + Args: + fm (FileMetadata): The file metadata + + Returns: + BinaryIO: The file data + """ + file_id = fm.file_id + bucket = fm.bucket + node_address = self._parse_node_address(fm) + file_path = self._get_file_path(bucket, file_id, node_address) + + # TODO: check if the file is cached in local storage + if node_address == self.node_address: + if os.path.exists(file_path): + return open(file_path, "rb") # noqa: SIM115 + else: + raise FileNotFoundError(f"File {file_id} not found on the local node") + else: + response = requests.get( + f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}", + timeout=self._transfer_timeout, + stream=True, + ) + response.raise_for_status() + # TODO: cache the file in local storage + return StreamedBytesIO( + response.iter_content(chunk_size=self._transfer_chunk_size) + ) + + def delete(self, fm: FileMetadata) -> bool: + """Delete the file data from the distributed storage backend. + + If the file is stored on the local node, delete it from the local storage. + If the file is stored on a remote node, send a delete request to the remote + node. 
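+        A remote delete that fails for any reason is reported as ``False``
+        rather than raising an exception.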
+ + Args: + fm (FileMetadata): The file metadata + + Returns: + bool: True if the file was deleted, False otherwise + """ + file_id = fm.file_id + bucket = fm.bucket + node_address = self._parse_node_address(fm) + file_path = self._get_file_path(bucket, file_id, node_address) + if node_address == self.node_address: + if os.path.exists(file_path): + os.remove(file_path) + return True + return False + else: + try: + response = requests.delete( + f"http://{node_address}{self._api_prefix}/{bucket}/{file_id}", + timeout=self._transfer_timeout, + ) + response.raise_for_status() + return True + except Exception: + return False + + +class StreamedBytesIO(io.BytesIO): + """A BytesIO subclass that can be used with streaming responses. + + Adapted from: https://gist.github.com/obskyr/b9d4b4223e7eaf4eedcd9defabb34f13 + """ + + def __init__(self, request_iterator): + """Initialize the StreamedBytesIO instance.""" + super().__init__() + self._bytes = BytesIO() + self._iterator = request_iterator + + def _load_all(self): + self._bytes.seek(0, io.SEEK_END) + for chunk in self._iterator: + self._bytes.write(chunk) + + def _load_until(self, goal_position): + current_position = self._bytes.seek(0, io.SEEK_END) + while current_position < goal_position: + try: + current_position += self._bytes.write(next(self._iterator)) + except StopIteration: + break + + def tell(self) -> int: + """Get the current position.""" + return self._bytes.tell() + + def read(self, size: Optional[int] = None) -> bytes: + """Read the data from the stream. + + Args: + size (Optional[int], optional): The number of bytes to read. Defaults to + None. + + Returns: + bytes: The read data + """ + left_off_at = self._bytes.tell() + if size is None: + self._load_all() + else: + goal_position = left_off_at + size + self._load_until(goal_position) + + self._bytes.seek(left_off_at) + return self._bytes.read(size) + + def seek(self, position: int, whence: int = io.SEEK_SET): + """Seek to a position in the stream. + + Args: + position (int): The position + whence (int, optional): The reference point. 
Defaults to io.SEEK + + Raises: + ValueError: If the reference point is invalid + """ + if whence == io.SEEK_END: + self._load_all() + else: + self._bytes.seek(position, whence) + + def __enter__(self): + """Enter the context manager.""" + return self + + def __exit__(self, ext_type, value, tb): + """Exit the context manager.""" + self._bytes.close() diff --git a/dbgpt/core/interface/tests/test_file.py b/dbgpt/core/interface/tests/test_file.py new file mode 100644 index 000000000..f6e462944 --- /dev/null +++ b/dbgpt/core/interface/tests/test_file.py @@ -0,0 +1,506 @@ +import hashlib +import io +import os +from unittest import mock + +import pytest + +from ..file import ( + FileMetadata, + FileMetadataIdentifier, + FileStorageClient, + FileStorageSystem, + InMemoryStorage, + LocalFileStorage, + SimpleDistributedStorage, +) + + +@pytest.fixture +def temp_test_file_dir(tmpdir): + return str(tmpdir) + + +@pytest.fixture +def temp_storage_path(tmpdir): + return str(tmpdir) + + +@pytest.fixture +def local_storage_backend(temp_storage_path): + return LocalFileStorage(temp_storage_path) + + +@pytest.fixture +def distributed_storage_backend(temp_storage_path): + node_address = "127.0.0.1:8000" + return SimpleDistributedStorage(node_address, temp_storage_path) + + +@pytest.fixture +def file_storage_system(local_storage_backend): + backends = {"local": local_storage_backend} + metadata_storage = InMemoryStorage() + return FileStorageSystem(backends, metadata_storage) + + +@pytest.fixture +def file_storage_client(file_storage_system): + return FileStorageClient(storage_system=file_storage_system) + + +@pytest.fixture +def sample_file_path(temp_test_file_dir): + file_path = os.path.join(temp_test_file_dir, "sample.txt") + with open(file_path, "wb") as f: + f.write(b"Sample file content") + return file_path + + +@pytest.fixture +def sample_file_data(): + return io.BytesIO(b"Sample file content for distributed storage") + + +def test_save_file(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + assert uri.startswith("dbgpt-fs://local/test-bucket/") + assert os.path.exists(sample_file_path) + + +def test_get_file(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + file_data, metadata = file_storage_client.storage_system.get_file(uri) + assert file_data.read() == b"Sample file content" + assert metadata.file_name == "sample.txt" + assert metadata.bucket == bucket + + +def test_delete_file(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + assert len(file_storage_client.list_files(bucket=bucket)) == 1 + result = file_storage_client.delete_file(uri) + assert result is True + assert len(file_storage_client.list_files(bucket=bucket)) == 0 + + +def test_list_files(file_storage_client, sample_file_path): + bucket = "test-bucket" + uri1 = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + files = file_storage_client.list_files(bucket=bucket) + assert len(files) == 1 + + +def test_save_file_unsupported_storage(file_storage_system, sample_file_path): + bucket = "test-bucket" + with pytest.raises(ValueError): + file_storage_system.save_file( + bucket=bucket, + file_name="unsupported.txt", + 
file_data=io.BytesIO(b"Unsupported storage"), + storage_type="unsupported", + ) + + +def test_get_file_not_found(file_storage_system): + with pytest.raises(FileNotFoundError): + file_storage_system.get_file("dbgpt-fs://local/test-bucket/nonexistent") + + +def test_delete_file_not_found(file_storage_system): + result = file_storage_system.delete_file("dbgpt-fs://local/test-bucket/nonexistent") + assert result is False + + +def test_metadata_management(file_storage_system): + bucket = "test-bucket" + file_id = "test_file" + metadata = file_storage_system.metadata_storage.save( + FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=100, + storage_type="local", + storage_path="/path/to/test.txt", + uri="dbgpt-fs://local/test-bucket/test_file", + custom_metadata={"key": "value"}, + file_hash="hash", + ) + ) + + loaded_metadata = file_storage_system.metadata_storage.load( + FileMetadataIdentifier(file_id=file_id, bucket=bucket), FileMetadata + ) + assert loaded_metadata.file_name == "test.txt" + assert loaded_metadata.custom_metadata["key"] == "value" + assert loaded_metadata.bucket == bucket + + +def test_concurrent_save_and_delete(file_storage_client, sample_file_path): + bucket = "test-bucket" + + # Simulate concurrent file save and delete operations + def save_file(): + return file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + + def delete_file(uri): + return file_storage_client.delete_file(uri) + + uri = save_file() + + # Simulate concurrent operations + save_file() + delete_file(uri) + assert len(file_storage_client.list_files(bucket=bucket)) == 1 + + +def test_large_file_handling(file_storage_client, temp_storage_path): + bucket = "test-bucket" + large_file_path = os.path.join(temp_storage_path, "large_sample.bin") + with open(large_file_path, "wb") as f: + f.write(os.urandom(10 * 1024 * 1024)) # 10 MB file + + uri = file_storage_client.upload_file( + bucket=bucket, + file_path=large_file_path, + storage_type="local", + custom_metadata={"description": "Large file test"}, + ) + file_data, metadata = file_storage_client.storage_system.get_file(uri) + assert file_data.read() == open(large_file_path, "rb").read() + assert metadata.file_name == "large_sample.bin" + assert metadata.bucket == bucket + + +def test_file_hash_verification_success(file_storage_client, sample_file_path): + bucket = "test-bucket" + # Upload file and + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + + file_data, metadata = file_storage_client.storage_system.get_file(uri) + file_hash = metadata.file_hash + calculated_hash = file_storage_client.storage_system._calculate_file_hash(file_data) + + assert ( + file_hash == calculated_hash + ), "File hash should match after saving and loading" + + +def test_file_hash_verification_failure(file_storage_client, sample_file_path): + bucket = "test-bucket" + # Upload file and + uri = file_storage_client.upload_file( + bucket=bucket, file_path=sample_file_path, storage_type="local" + ) + + # Modify the file content manually to simulate file tampering + storage_system = file_storage_client.storage_system + metadata = storage_system.metadata_storage.load( + FileMetadataIdentifier(file_id=uri.split("/")[-1], bucket=bucket), FileMetadata + ) + with open(metadata.storage_path, "wb") as f: + f.write(b"Tampered content") + + # Get file should raise an exception due to hash mismatch + with pytest.raises(ValueError, match="File integrity check 
failed. Hash mismatch."): + storage_system.get_file(uri) + + +def test_file_isolation_across_buckets(file_storage_client, sample_file_path): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload the same file to two different buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Verify both URIs are different and point to different files + assert uri1 != uri2 + + file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1) + file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2) + + assert file_data1.read() == b"Sample file content" + assert file_data2.read() == b"Sample file content" + assert metadata1.bucket == bucket1 + assert metadata2.bucket == bucket2 + + +def test_list_files_in_specific_bucket(file_storage_client, sample_file_path): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload a file to both buckets + file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # List files in bucket1 and bucket2 + files_in_bucket1 = file_storage_client.list_files(bucket=bucket1) + files_in_bucket2 = file_storage_client.list_files(bucket=bucket2) + + assert len(files_in_bucket1) == 1 + assert len(files_in_bucket2) == 1 + assert files_in_bucket1[0].bucket == bucket1 + assert files_in_bucket2[0].bucket == bucket2 + + +def test_delete_file_in_one_bucket_does_not_affect_other_bucket( + file_storage_client, sample_file_path +): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload the same file to two different buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Delete the file in bucket1 + file_storage_client.delete_file(uri1) + + # Check that the file in bucket1 is deleted + assert len(file_storage_client.list_files(bucket=bucket1)) == 0 + + # Check that the file in bucket2 is still there + assert len(file_storage_client.list_files(bucket=bucket2)) == 1 + file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2) + assert file_data2.read() == b"Sample file content" + + +def test_file_hash_verification_in_different_buckets( + file_storage_client, sample_file_path +): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload the file to both buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + file_data1, metadata1 = file_storage_client.storage_system.get_file(uri1) + file_data2, metadata2 = file_storage_client.storage_system.get_file(uri2) + + # Verify that file hashes are the same for the same content + file_hash1 = file_storage_client.storage_system._calculate_file_hash(file_data1) + file_hash2 = file_storage_client.storage_system._calculate_file_hash(file_data2) + + assert file_hash1 == metadata1.file_hash + assert file_hash2 == metadata2.file_hash + assert file_hash1 == file_hash2 + + +def test_file_download_from_different_buckets( + file_storage_client, sample_file_path, temp_storage_path +): + bucket1 = "bucket1" + bucket2 = "bucket2" + 
+ # Upload the file to both buckets + uri1 = file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + uri2 = file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Download files to different locations + download_path1 = os.path.join(temp_storage_path, "downloaded_bucket1.txt") + download_path2 = os.path.join(temp_storage_path, "downloaded_bucket2.txt") + + file_storage_client.download_file(uri1, download_path1) + file_storage_client.download_file(uri2, download_path2) + + # Verify contents of downloaded files + assert open(download_path1, "rb").read() == b"Sample file content" + assert open(download_path2, "rb").read() == b"Sample file content" + + +def test_delete_all_files_in_bucket(file_storage_client, sample_file_path): + bucket1 = "bucket1" + bucket2 = "bucket2" + + # Upload files to both buckets + file_storage_client.upload_file( + bucket=bucket1, file_path=sample_file_path, storage_type="local" + ) + file_storage_client.upload_file( + bucket=bucket2, file_path=sample_file_path, storage_type="local" + ) + + # Delete all files in bucket1 + for file in file_storage_client.list_files(bucket=bucket1): + file_storage_client.delete_file(file.uri) + + # Verify bucket1 is empty + assert len(file_storage_client.list_files(bucket=bucket1)) == 0 + + # Verify bucket2 still has files + assert len(file_storage_client.list_files(bucket=bucket2)) == 1 + + +def test_simple_distributed_storage_save_file( + distributed_storage_backend, sample_file_data, temp_storage_path +): + bucket = "test-bucket" + file_id = "test_file" + file_path = distributed_storage_backend.save(bucket, file_id, sample_file_data) + + expected_path = os.path.join( + temp_storage_path, + bucket, + f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}", + ) + assert file_path == f"distributed://127.0.0.1:8000/{bucket}/{file_id}" + assert os.path.exists(expected_path) + + +def test_simple_distributed_storage_load_file_local( + distributed_storage_backend, sample_file_data +): + bucket = "test-bucket" + file_id = "test_file" + distributed_storage_backend.save(bucket, file_id, sample_file_data) + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) + + file_data = distributed_storage_backend.load(metadata) + assert file_data.read() == b"Sample file content for distributed storage" + + +@mock.patch("requests.get") +def test_simple_distributed_storage_load_file_remote( + mock_get, distributed_storage_backend, sample_file_data +): + bucket = "test-bucket" + file_id = "test_file" + remote_node_address = "127.0.0.2:8000" + + # Mock the response from remote node + mock_response = mock.Mock() + mock_response.iter_content = mock.Mock( + return_value=iter([b"Sample file content for distributed storage"]) + ) + mock_response.raise_for_status = mock.Mock(return_value=None) + mock_get.return_value = mock_response + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}", + uri=f"distributed://{remote_node_address}/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) 
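+    # The storage_path/uri encode the owning node as
+    # distributed://<host>:<port>/<bucket>/<file_id>; because this metadata points
+    # at a remote node (127.0.0.2:8000), the backend is expected to fetch the
+    # bytes over HTTP from that node's file API rather than from the local disk.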
+ + file_data = distributed_storage_backend.load(metadata) + assert file_data.read() == b"Sample file content for distributed storage" + mock_get.assert_called_once_with( + f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}", + stream=True, + timeout=360, + ) + + +def test_simple_distributed_storage_delete_file_local( + distributed_storage_backend, sample_file_data, temp_storage_path +): + bucket = "test-bucket" + file_id = "test_file" + distributed_storage_backend.save(bucket, file_id, sample_file_data) + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + uri=f"distributed://127.0.0.1:8000/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) + + result = distributed_storage_backend.delete(metadata) + file_path = os.path.join( + temp_storage_path, + bucket, + f"{file_id}_{hashlib.md5('127.0.0.1:8000'.encode()).hexdigest()}", + ) + assert result is True + assert not os.path.exists(file_path) + + +@mock.patch("requests.delete") +def test_simple_distributed_storage_delete_file_remote( + mock_delete, distributed_storage_backend, sample_file_data +): + bucket = "test-bucket" + file_id = "test_file" + remote_node_address = "127.0.0.2:8000" + + mock_response = mock.Mock() + mock_response.raise_for_status = mock.Mock(return_value=None) + mock_delete.return_value = mock_response + + metadata = FileMetadata( + file_id=file_id, + bucket=bucket, + file_name="test.txt", + file_size=len(sample_file_data.getvalue()), + storage_type="distributed", + storage_path=f"distributed://{remote_node_address}/{bucket}/{file_id}", + uri=f"distributed://{remote_node_address}/{bucket}/{file_id}", + custom_metadata={}, + file_hash="hash", + ) + + result = distributed_storage_backend.delete(metadata) + assert result is True + mock_delete.assert_called_once_with( + f"http://{remote_node_address}/api/v2/serve/file/files/{bucket}/{file_id}", + timeout=360, + ) diff --git a/dbgpt/serve/file/__init__.py b/dbgpt/serve/file/__init__.py new file mode 100644 index 000000000..54a428180 --- /dev/null +++ b/dbgpt/serve/file/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve file` diff --git a/dbgpt/serve/file/api/__init__.py b/dbgpt/serve/file/api/__init__.py new file mode 100644 index 000000000..54a428180 --- /dev/null +++ b/dbgpt/serve/file/api/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve file` diff --git a/dbgpt/serve/file/api/endpoints.py b/dbgpt/serve/file/api/endpoints.py new file mode 100644 index 000000000..edf1d2d98 --- /dev/null +++ b/dbgpt/serve/file/api/endpoints.py @@ -0,0 +1,159 @@ +import logging +from functools import cache +from typing import List, Optional +from urllib.parse import quote + +from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile +from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer +from starlette.responses import StreamingResponse + +from dbgpt.component import SystemApp +from dbgpt.serve.core import Result, blocking_func_to_async +from dbgpt.util import PaginationResult + +from ..config import APP_NAME, SERVE_APP_NAME, SERVE_SERVICE_COMPONENT_NAME, ServeConfig +from ..service.service import Service +from .schemas import ServeRequest, ServerResponse, UploadFileResponse + +router = APIRouter() +logger = 
logging.getLogger(__name__)
+
+# Add your API endpoints here
+
+global_system_app: Optional[SystemApp] = None
+
+
+def get_service() -> Service:
+    """Get the service instance"""
+    return global_system_app.get_component(SERVE_SERVICE_COMPONENT_NAME, Service)
+
+
+get_bearer_token = HTTPBearer(auto_error=False)
+
+
+@cache
+def _parse_api_keys(api_keys: str) -> List[str]:
+    """Parse the string api keys to a list
+
+    Args:
+        api_keys (str): The string api keys
+
+    Returns:
+        List[str]: The list of api keys
+    """
+    if not api_keys:
+        return []
+    return [key.strip() for key in api_keys.split(",")]
+
+
+async def check_api_key(
+    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
+    service: Service = Depends(get_service),
+) -> Optional[str]:
+    """Check the api key
+
+    If the api key is not set, allow all.
+
+    You can pass the token in your request header like this:
+
+    .. code-block:: python
+
+        import requests
+
+        client_api_key = "your_api_key"
+        headers = {"Authorization": "Bearer " + client_api_key}
+        res = requests.get("http://test/hello", headers=headers)
+        assert res.status_code == 200
+
+    """
+    if service.config.api_keys:
+        api_keys = _parse_api_keys(service.config.api_keys)
+        if auth is None or (token := auth.credentials) not in api_keys:
+            raise HTTPException(
+                status_code=401,
+                detail={
+                    "error": {
+                        "message": "",
+                        "type": "invalid_request_error",
+                        "param": None,
+                        "code": "invalid_api_key",
+                    }
+                },
+            )
+        return token
+    else:
+        # api_keys not set; allow all
+        return None
+
+
+@router.get("/health")
+async def health():
+    """Health check endpoint"""
+    return {"status": "ok"}
+
+
+@router.get("/test_auth", dependencies=[Depends(check_api_key)])
+async def test_auth():
+    """Test auth endpoint"""
+    return {"status": "ok"}
+
+
+@router.post(
+    "/files/{bucket}",
+    response_model=Result[List[UploadFileResponse]],
+    dependencies=[Depends(check_api_key)],
+)
+async def upload_files(
+    bucket: str, files: List[UploadFile], service: Service = Depends(get_service)
+) -> Result[List[UploadFileResponse]]:
+    """Upload files by a list of UploadFile."""
+    logger.info(f"upload_files: bucket={bucket}, files={files}")
+    results = await blocking_func_to_async(
+        global_system_app, service.upload_files, bucket, "distributed", files
+    )
+    return Result.succ(results)
+
+
+@router.get("/files/{bucket}/{file_id}", dependencies=[Depends(check_api_key)])
+async def download_file(
+    bucket: str, file_id: str, service: Service = Depends(get_service)
+):
+    """Download a file by file_id."""
+    logger.info(f"download_file: bucket={bucket}, file_id={file_id}")
+    file_data, file_metadata = await blocking_func_to_async(
+        global_system_app, service.download_file, bucket, file_id
+    )
+    file_name_encoded = quote(file_metadata.file_name)
+
+    def file_iterator(raw_iter):
+        with raw_iter:
+            while chunk := raw_iter.read(
+                service.config.file_server_download_chunk_size
+            ):
+                yield chunk
+
+    response = StreamingResponse(
+        file_iterator(file_data), media_type="application/octet-stream"
+    )
+    response.headers[
+        "Content-Disposition"
+    ] = f"attachment; filename={file_name_encoded}"
+    return response
+
+
+@router.delete("/files/{bucket}/{file_id}", dependencies=[Depends(check_api_key)])
+async def delete_file(
+    bucket: str, file_id: str, service: Service = Depends(get_service)
+):
+    """Delete a file by file_id."""
+    await blocking_func_to_async(
+        global_system_app, service.delete_file, bucket, file_id
+    )
+    return Result.succ(None)
+
+
+def init_endpoints(system_app: SystemApp) -> None:
+    
"""Initialize the endpoints""" + global global_system_app + system_app.register(Service) + global_system_app = system_app diff --git a/dbgpt/serve/file/api/schemas.py b/dbgpt/serve/file/api/schemas.py new file mode 100644 index 000000000..911f71db3 --- /dev/null +++ b/dbgpt/serve/file/api/schemas.py @@ -0,0 +1,43 @@ +# Define your Pydantic schemas here +from typing import Any, Dict + +from dbgpt._private.pydantic import BaseModel, ConfigDict, Field, model_to_dict + +from ..config import SERVE_APP_NAME_HUMP + + +class ServeRequest(BaseModel): + """File request model""" + + # TODO define your own fields here + + model_config = ConfigDict(title=f"ServeRequest for {SERVE_APP_NAME_HUMP}") + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) + + +class ServerResponse(BaseModel): + """File response model""" + + # TODO define your own fields here + + model_config = ConfigDict(title=f"ServerResponse for {SERVE_APP_NAME_HUMP}") + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) + + +class UploadFileResponse(BaseModel): + """Upload file response model""" + + file_name: str = Field(..., title="The name of the uploaded file") + file_id: str = Field(..., title="The ID of the uploaded file") + bucket: str = Field(..., title="The bucket of the uploaded file") + uri: str = Field(..., title="The URI of the uploaded file") + + def to_dict(self, **kwargs) -> Dict[str, Any]: + """Convert the model to a dictionary""" + return model_to_dict(self, **kwargs) diff --git a/dbgpt/serve/file/config.py b/dbgpt/serve/file/config.py new file mode 100644 index 000000000..1ab1afede --- /dev/null +++ b/dbgpt/serve/file/config.py @@ -0,0 +1,68 @@ +from dataclasses import dataclass, field +from typing import Optional + +from dbgpt.serve.core import BaseServeConfig + +APP_NAME = "file" +SERVE_APP_NAME = "dbgpt_serve_file" +SERVE_APP_NAME_HUMP = "dbgpt_serve_File" +SERVE_CONFIG_KEY_PREFIX = "dbgpt.serve.file." 
+SERVE_SERVICE_COMPONENT_NAME = f"{SERVE_APP_NAME}_service" +# Database table name +SERVER_APP_TABLE_NAME = "dbgpt_serve_file" + + +@dataclass +class ServeConfig(BaseServeConfig): + """Parameters for the serve command""" + + # TODO: add your own parameters here + api_keys: Optional[str] = field( + default=None, metadata={"help": "API keys for the endpoint, if None, allow all"} + ) + check_hash: Optional[bool] = field( + default=True, metadata={"help": "Check the hash of the file when downloading"} + ) + file_server_host: Optional[str] = field( + default=None, metadata={"help": "The host of the file server"} + ) + file_server_port: Optional[int] = field( + default=5670, metadata={"help": "The port of the file server"} + ) + file_server_download_chunk_size: Optional[int] = field( + default=1024 * 1024, + metadata={"help": "The chunk size when downloading the file"}, + ) + file_server_save_chunk_size: Optional[int] = field( + default=1024 * 1024, metadata={"help": "The chunk size when saving the file"} + ) + file_server_transfer_chunk_size: Optional[int] = field( + default=1024 * 1024, + metadata={"help": "The chunk size when transferring the file"}, + ) + file_server_transfer_timeout: Optional[int] = field( + default=360, metadata={"help": "The timeout when transferring the file"} + ) + local_storage_path: Optional[str] = field( + default=None, metadata={"help": "The local storage path"} + ) + + def get_node_address(self) -> str: + """Get the node address""" + file_server_host = self.file_server_host + if not file_server_host: + from dbgpt.util.net_utils import _get_ip_address + + file_server_host = _get_ip_address() + file_server_port = self.file_server_port or 5670 + return f"{file_server_host}:{file_server_port}" + + def get_local_storage_path(self) -> str: + """Get the local storage path""" + local_storage_path = self.local_storage_path + if not local_storage_path: + from pathlib import Path + + base_path = Path.home() / ".cache" / "dbgpt" / "files" + local_storage_path = str(base_path) + return local_storage_path diff --git a/dbgpt/serve/file/dependencies.py b/dbgpt/serve/file/dependencies.py new file mode 100644 index 000000000..8598ecd97 --- /dev/null +++ b/dbgpt/serve/file/dependencies.py @@ -0,0 +1 @@ +# Define your dependencies here diff --git a/dbgpt/serve/file/models/__init__.py b/dbgpt/serve/file/models/__init__.py new file mode 100644 index 000000000..54a428180 --- /dev/null +++ b/dbgpt/serve/file/models/__init__.py @@ -0,0 +1,2 @@ +# This is an auto-generated __init__.py file +# generated by `dbgpt new serve file` diff --git a/dbgpt/serve/file/models/file_adapter.py b/dbgpt/serve/file/models/file_adapter.py new file mode 100644 index 000000000..a8ab36465 --- /dev/null +++ b/dbgpt/serve/file/models/file_adapter.py @@ -0,0 +1,66 @@ +import json +from typing import Type + +from sqlalchemy.orm import Session + +from dbgpt.core.interface.file import FileMetadata, FileMetadataIdentifier +from dbgpt.core.interface.storage import StorageItemAdapter + +from .models import ServeEntity + + +class FileMetadataAdapter(StorageItemAdapter[FileMetadata, ServeEntity]): + """File metadata adapter. + + Convert between storage format and database model. 
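+
+    Note: ``custom_metadata`` is serialized to a JSON string for the ``Text``
+    column and parsed back into a dict on load, so it should be JSON-serializable.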
+ """ + + def to_storage_format(self, item: FileMetadata) -> ServeEntity: + """Convert to storage format.""" + custom_metadata = ( + json.dumps(item.custom_metadata, ensure_ascii=False) + if item.custom_metadata + else None + ) + return ServeEntity( + bucket=item.bucket, + file_id=item.file_id, + file_name=item.file_name, + file_size=item.file_size, + storage_type=item.storage_type, + storage_path=item.storage_path, + uri=item.uri, + custom_metadata=custom_metadata, + file_hash=item.file_hash, + ) + + def from_storage_format(self, model: ServeEntity) -> FileMetadata: + """Convert from storage format.""" + custom_metadata = ( + json.loads(model.custom_metadata) if model.custom_metadata else None + ) + return FileMetadata( + bucket=model.bucket, + file_id=model.file_id, + file_name=model.file_name, + file_size=model.file_size, + storage_type=model.storage_type, + storage_path=model.storage_path, + uri=model.uri, + custom_metadata=custom_metadata, + file_hash=model.file_hash, + ) + + def get_query_for_identifier( + self, + storage_format: Type[ServeEntity], + resource_id: FileMetadataIdentifier, + **kwargs, + ): + """Get query for identifier.""" + session: Session = kwargs.get("session") + if session is None: + raise Exception("session is None") + return session.query(storage_format).filter( + storage_format.file_id == resource_id.file_id + ) diff --git a/dbgpt/serve/file/models/models.py b/dbgpt/serve/file/models/models.py new file mode 100644 index 000000000..62dd1ef80 --- /dev/null +++ b/dbgpt/serve/file/models/models.py @@ -0,0 +1,87 @@ +"""This is an auto-generated model file +You can define your own models and DAOs here +""" + +from datetime import datetime +from typing import Any, Dict, Union + +from sqlalchemy import Column, DateTime, Index, Integer, String, Text + +from dbgpt.storage.metadata import BaseDao, Model, db + +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVER_APP_TABLE_NAME, ServeConfig + + +class ServeEntity(Model): + __tablename__ = SERVER_APP_TABLE_NAME + id = Column(Integer, primary_key=True, comment="Auto increment id") + + bucket = Column(String(255), nullable=False, comment="Bucket name") + file_id = Column(String(255), nullable=False, comment="File id") + file_name = Column(String(256), nullable=False, comment="File name") + file_size = Column(Integer, nullable=True, comment="File size") + storage_type = Column(String(32), nullable=False, comment="Storage type") + storage_path = Column(String(512), nullable=False, comment="Storage path") + uri = Column(String(512), nullable=False, comment="File URI") + custom_metadata = Column( + Text, nullable=True, comment="Custom metadata, JSON format" + ) + file_hash = Column(String(128), nullable=True, comment="File hash") + + gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") + gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") + + def __repr__(self): + return ( + f"ServeEntity(id={self.id}, gmt_created='{self.gmt_created}', " + f"gmt_modified='{self.gmt_modified}')" + ) + + +class ServeDao(BaseDao[ServeEntity, ServeRequest, ServerResponse]): + """The DAO class for File""" + + def __init__(self, serve_config: ServeConfig): + super().__init__() + self._serve_config = serve_config + + def from_request(self, request: Union[ServeRequest, Dict[str, Any]]) -> ServeEntity: + """Convert the request to an entity + + Args: + request (Union[ServeRequest, Dict[str, Any]]): The request + + Returns: + T: The entity + """ + request_dict = ( 
+            request.to_dict() if isinstance(request, ServeRequest) else request
+        )
+        entity = ServeEntity(**request_dict)
+        # TODO implement your own logic here, transfer the request_dict to an entity
+        return entity
+
+    def to_request(self, entity: ServeEntity) -> ServeRequest:
+        """Convert the entity to a request
+
+        Args:
+            entity (T): The entity
+
+        Returns:
+            REQ: The request
+        """
+        # TODO implement your own logic here, transfer the entity to a request
+        return ServeRequest()
+
+    def to_response(self, entity: ServeEntity) -> ServerResponse:
+        """Convert the entity to a response
+
+        Args:
+            entity (T): The entity
+
+        Returns:
+            RES: The response
+        """
+        # TODO implement your own logic here, transfer the entity to a response
+        return ServerResponse()
diff --git a/dbgpt/serve/file/serve.py b/dbgpt/serve/file/serve.py
new file mode 100644
index 000000000..559509573
--- /dev/null
+++ b/dbgpt/serve/file/serve.py
@@ -0,0 +1,113 @@
+import logging
+from typing import List, Optional, Union
+
+from sqlalchemy import URL
+
+from dbgpt.component import SystemApp
+from dbgpt.core.interface.file import FileStorageClient
+from dbgpt.serve.core import BaseServe
+from dbgpt.storage.metadata import DatabaseManager
+
+from .api.endpoints import init_endpoints, router
+from .config import (
+    APP_NAME,
+    SERVE_APP_NAME,
+    SERVE_APP_NAME_HUMP,
+    SERVE_CONFIG_KEY_PREFIX,
+    ServeConfig,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class Serve(BaseServe):
+    """Serve component for DB-GPT"""
+
+    name = SERVE_APP_NAME
+
+    def __init__(
+        self,
+        system_app: SystemApp,
+        api_prefix: Optional[str] = f"/api/v2/serve/{APP_NAME}",
+        api_tags: Optional[List[str]] = None,
+        db_url_or_db: Union[str, URL, DatabaseManager] = None,
+        try_create_tables: Optional[bool] = False,
+    ):
+        if api_tags is None:
+            api_tags = [SERVE_APP_NAME_HUMP]
+        super().__init__(
+            system_app, api_prefix, api_tags, db_url_or_db, try_create_tables
+        )
+        self._db_manager: Optional[DatabaseManager] = None
+        self._file_storage_client: Optional[FileStorageClient] = None
+        self._serve_config: Optional[ServeConfig] = None
+
+    def init_app(self, system_app: SystemApp):
+        if self._app_has_initiated:
+            return
+        self._system_app = system_app
+        self._system_app.app.include_router(
+            router, prefix=self._api_prefix, tags=self._api_tags
+        )
+        init_endpoints(self._system_app)
+        self._app_has_initiated = True
+
+    def on_init(self):
+        """Called when initializing the application.
+
+        You can do some initialization here.
+        You can't get other components here because they may not be initialized yet.
+        """
+        # import your own module here to ensure the module is loaded before the application starts
+        from .models.models import ServeEntity
+
+    def before_start(self):
+        """Called before the start of the application."""
+        from dbgpt.core.interface.file import (
+            FileStorageSystem,
+            SimpleDistributedStorage,
+        )
+        from dbgpt.storage.metadata.db_storage import SQLAlchemyStorage
+        from dbgpt.util.serialization.json_serialization import JsonSerializer
+
+        from .models.file_adapter import FileMetadataAdapter
+        from .models.models import ServeEntity
+
+        self._serve_config = ServeConfig.from_app_config(
+            self._system_app.config, SERVE_CONFIG_KEY_PREFIX
+        )
+
+        self._db_manager = self.create_or_get_db_manager()
+        serializer = JsonSerializer()
+        storage = SQLAlchemyStorage(
+            self._db_manager,
+            ServeEntity,
+            FileMetadataAdapter(),
+            serializer,
+        )
+        simple_distributed_storage = SimpleDistributedStorage(
+            node_address=self._serve_config.get_node_address(),
+            local_storage_path=self._serve_config.get_local_storage_path(),
+            save_chunk_size=self._serve_config.file_server_save_chunk_size,
+            transfer_chunk_size=self._serve_config.file_server_transfer_chunk_size,
+            transfer_timeout=self._serve_config.file_server_transfer_timeout,
+        )
+        storage_backends = {
+            simple_distributed_storage.storage_type: simple_distributed_storage,
+        }
+        fs = FileStorageSystem(
+            storage_backends,
+            metadata_storage=storage,
+            check_hash=self._serve_config.check_hash,
+        )
+        self._file_storage_client = FileStorageClient(
+            system_app=self._system_app, storage_system=fs
+        )
+
+    @property
+    def file_storage_client(self) -> FileStorageClient:
+        """Returns the file storage client."""
+        if not self._file_storage_client:
+            raise ValueError("File storage client is not initialized")
+        return self._file_storage_client
diff --git a/dbgpt/serve/file/service/__init__.py b/dbgpt/serve/file/service/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/dbgpt/serve/file/service/service.py b/dbgpt/serve/file/service/service.py
new file mode 100644
index 000000000..d4d0118f3
--- /dev/null
+++ b/dbgpt/serve/file/service/service.py
@@ -0,0 +1,106 @@
+import logging
+from typing import BinaryIO, List, Optional, Tuple
+
+from fastapi import UploadFile
+
+from dbgpt.component import BaseComponent, SystemApp
+from dbgpt.core.interface.file import FileMetadata, FileStorageClient, FileStorageURI
+from dbgpt.serve.core import BaseService
+from dbgpt.storage.metadata import BaseDao
+from dbgpt.util.pagination_utils import PaginationResult
+from dbgpt.util.tracer import root_tracer, trace
+
+from ..api.schemas import ServeRequest, ServerResponse, UploadFileResponse
+from ..config import SERVE_CONFIG_KEY_PREFIX, SERVE_SERVICE_COMPONENT_NAME, ServeConfig
+from ..models.models import ServeDao, ServeEntity
+
+logger = logging.getLogger(__name__)
+
+
+class Service(BaseService[ServeEntity, ServeRequest, ServerResponse]):
+    """The service class for File"""
+
+    name = SERVE_SERVICE_COMPONENT_NAME
+
+    def __init__(self, system_app: SystemApp, dao: Optional[ServeDao] = None):
+        self._system_app = None
+        self._serve_config: ServeConfig = None
+        self._dao: ServeDao = dao
+        super().__init__(system_app)
+
+    def init_app(self, system_app: SystemApp) -> None:
+        """Initialize the service
+
+        Args:
+            system_app (SystemApp): The system app
+        """
+        super().init_app(system_app)
+        self._serve_config = ServeConfig.from_app_config(
+            system_app.config, SERVE_CONFIG_KEY_PREFIX
+        
) + self._dao = self._dao or ServeDao(self._serve_config) + self._system_app = system_app + + @property + def dao(self) -> BaseDao[ServeEntity, ServeRequest, ServerResponse]: + """Returns the internal DAO.""" + return self._dao + + @property + def config(self) -> ServeConfig: + """Returns the internal ServeConfig.""" + return self._serve_config + + @property + def file_storage_client(self) -> FileStorageClient: + """Returns the internal FileStorageClient. + + Returns: + FileStorageClient: The internal FileStorageClient + """ + file_storage_client = FileStorageClient.get_instance( + self._system_app, default_component=None + ) + if file_storage_client: + return file_storage_client + else: + from ..serve import Serve + + file_storage_client = Serve.get_instance( + self._system_app + ).file_storage_client + self._system_app.register_instance(file_storage_client) + return file_storage_client + + @trace("upload_files") + def upload_files( + self, bucket: str, storage_type: str, files: List[UploadFile] + ) -> List[UploadFileResponse]: + """Upload files by a list of UploadFile.""" + results = [] + for file in files: + file_name = file.filename + logger.info(f"Uploading file {file_name} to bucket {bucket}") + uri = self.file_storage_client.save_file( + bucket, file_name, file_data=file.file, storage_type=storage_type + ) + parsed_uri = FileStorageURI.parse(uri) + logger.info(f"Uploaded file {file_name} to bucket {bucket}, uri={uri}") + results.append( + UploadFileResponse( + file_name=file_name, + file_id=parsed_uri.file_id, + bucket=bucket, + uri=uri, + ) + ) + return results + + @trace("download_file") + def download_file(self, bucket: str, file_id: str) -> Tuple[BinaryIO, FileMetadata]: + """Download a file by file_id.""" + return self.file_storage_client.get_file_by_id(bucket, file_id) + + def delete_file(self, bucket: str, file_id: str) -> None: + """Delete a file by file_id.""" + self.file_storage_client.delete_file_by_id(bucket, file_id) diff --git a/dbgpt/serve/file/tests/__init__.py b/dbgpt/serve/file/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/dbgpt/serve/file/tests/test_endpoints.py b/dbgpt/serve/file/tests/test_endpoints.py new file mode 100644 index 000000000..ba7b4f0cd --- /dev/null +++ b/dbgpt/serve/file/tests/test_endpoints.py @@ -0,0 +1,124 @@ +import pytest +from fastapi import FastAPI +from httpx import AsyncClient + +from dbgpt.component import SystemApp +from dbgpt.serve.core.tests.conftest import asystem_app, client +from dbgpt.storage.metadata import db +from dbgpt.util import PaginationResult + +from ..api.endpoints import init_endpoints, router +from ..api.schemas import ServeRequest, ServerResponse +from ..config import SERVE_CONFIG_KEY_PREFIX + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + + yield + + +def client_init_caller(app: FastAPI, system_app: SystemApp): + app.include_router(router) + init_endpoints(system_app) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "client, asystem_app, has_auth", + [ + ( + { + "app_caller": client_init_caller, + "client_api_key": "test_token1", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + True, + ), + ( + { + "app_caller": client_init_caller, + "client_api_key": "error_token", + }, + { + "app_config": { + f"{SERVE_CONFIG_KEY_PREFIX}api_keys": "test_token1,test_token2" + } + }, + False, + ), + ], + indirect=["client", "asystem_app"], +) +async def 
test_api_auth(client: AsyncClient, asystem_app, has_auth: bool):
+    response = await client.get("/test_auth")
+    if has_auth:
+        assert response.status_code == 200
+        assert response.json() == {"status": "ok"}
+    else:
+        assert response.status_code == 401
+        assert response.json() == {
+            "detail": {
+                "error": {
+                    "message": "",
+                    "type": "invalid_request_error",
+                    "param": None,
+                    "code": "invalid_api_key",
+                }
+            }
+        }
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client", [{"app_caller": client_init_caller}], indirect=["client"]
+)
+async def test_api_health(client: AsyncClient):
+    response = await client.get("/health")
+    assert response.status_code == 200
+    assert response.json() == {"status": "ok"}
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client", [{"app_caller": client_init_caller}], indirect=["client"]
+)
+async def test_api_create(client: AsyncClient):
+    # TODO: add your test case
+    pass
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client", [{"app_caller": client_init_caller}], indirect=["client"]
+)
+async def test_api_update(client: AsyncClient):
+    # TODO: implement your test case
+    pass
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client", [{"app_caller": client_init_caller}], indirect=["client"]
+)
+async def test_api_query(client: AsyncClient):
+    # TODO: implement your test case
+    pass
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "client", [{"app_caller": client_init_caller}], indirect=["client"]
+)
+async def test_api_query_by_page(client: AsyncClient):
+    # TODO: implement your test case
+    pass
+
+
+# Add more test cases according to your own logic
diff --git a/dbgpt/serve/file/tests/test_models.py b/dbgpt/serve/file/tests/test_models.py
new file mode 100644
index 000000000..8b66e9f97
--- /dev/null
+++ b/dbgpt/serve/file/tests/test_models.py
@@ -0,0 +1,99 @@
+import pytest
+
+from dbgpt.storage.metadata import db
+
+from ..api.schemas import ServeRequest, ServerResponse
+from ..config import ServeConfig
+from ..models.models import ServeDao, ServeEntity
+
+
+@pytest.fixture(autouse=True)
+def setup_and_teardown():
+    db.init_db("sqlite:///:memory:")
+    db.create_all()
+
+    yield
+
+
+@pytest.fixture
+def server_config():
+    # TODO : build your server config
+    return ServeConfig()
+
+
+@pytest.fixture
+def dao(server_config):
+    return ServeDao(server_config)
+
+
+@pytest.fixture
+def default_entity_dict():
+    # TODO: build your default entity dict
+    return {}
+
+
+def test_table_exist():
+    assert ServeEntity.__tablename__ in db.metadata.tables
+
+
+def test_entity_create(default_entity_dict):
+    # TODO: implement your test case
+    pass
+
+
+def test_entity_unique_key(default_entity_dict):
+    # TODO: implement your test case
+    pass
+
+
+def test_entity_get(default_entity_dict):
+    # TODO: implement your test case
+    pass
+
+
+def test_entity_update(default_entity_dict):
+    # TODO: implement your test case
+    pass
+
+
+def test_entity_delete(default_entity_dict):
+    # TODO: implement your test case
+    pass
+
+
+def test_entity_all():
+    # TODO: implement your test case
+    pass
+
+
+def test_dao_create(dao, default_entity_dict):
+    # TODO: implement your test case
+    pass
+
+
+def test_dao_get_one(dao, default_entity_dict):
+    # TODO: implement your test case
+    pass
+
+
+def test_get_dao_get_list(dao):
+    # TODO: implement your test case
+    pass
+
+
+def test_dao_update(dao, default_entity_dict):
+    # TODO: implement your test case
+    pass
+
+
+def test_dao_delete(dao, default_entity_dict):
+    # TODO: implement your test case
+    pass
+
+
+def 
test_dao_get_list_page(dao): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic diff --git a/dbgpt/serve/file/tests/test_service.py b/dbgpt/serve/file/tests/test_service.py new file mode 100644 index 000000000..00177924d --- /dev/null +++ b/dbgpt/serve/file/tests/test_service.py @@ -0,0 +1,78 @@ +from typing import List + +import pytest + +from dbgpt.component import SystemApp +from dbgpt.serve.core.tests.conftest import system_app +from dbgpt.storage.metadata import db + +from ..api.schemas import ServeRequest, ServerResponse +from ..models.models import ServeEntity +from ..service.service import Service + + +@pytest.fixture(autouse=True) +def setup_and_teardown(): + db.init_db("sqlite:///:memory:") + db.create_all() + yield + + +@pytest.fixture +def service(system_app: SystemApp): + instance = Service(system_app) + instance.init_app(system_app) + return instance + + +@pytest.fixture +def default_entity_dict(): + # TODO: build your default entity dict + return {} + + +@pytest.mark.parametrize( + "system_app", + [{"app_config": {"DEBUG": True, "dbgpt.serve.test_key": "hello"}}], + indirect=True, +) +def test_config_exists(service: Service): + system_app: SystemApp = service._system_app + assert system_app.config.get("DEBUG") is True + assert system_app.config.get("dbgpt.serve.test_key") == "hello" + assert service.config is not None + + +def test_service_create(service: Service, default_entity_dict): + # TODO: implement your test case + # eg. entity: ServerResponse = service.create(ServeRequest(**default_entity_dict)) + # ... + pass + + +def test_service_update(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_delete(service: Service, default_entity_dict): + # TODO: implement your test case + pass + + +def test_service_get_list(service: Service): + # TODO: implement your test case + pass + + +def test_service_get_list_by_page(service: Service): + # TODO: implement your test case + pass + + +# Add more test cases according to your own logic From 48312ed94683224b785dc04b6dfc6a4531909709 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Mon, 19 Aug 2024 07:35:04 +0800 Subject: [PATCH 41/89] feat(core): Add file upload operator --- assets/schema/dbgpt.sql | 44 ++ .../upgrade/v0_6_0/upgrade_to_v0.6.0.sql | 43 ++ assets/schema/upgrade/v0_6_0/v0.5.10.sql | 419 ++++++++++++++++++ dbgpt/core/awel/flow/ui.py | 4 +- dbgpt/core/interface/file.py | 7 + dbgpt/serve/file/api/endpoints.py | 14 +- dbgpt/serve/file/models/file_adapter.py | 32 +- dbgpt/serve/file/models/models.py | 7 +- dbgpt/serve/file/service/service.py | 17 +- examples/awel/awel_flow_ui_components.py | 106 ++++- 10 files changed, 679 insertions(+), 14 deletions(-) create mode 100644 assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql create mode 100644 assets/schema/upgrade/v0_6_0/v0.5.10.sql diff --git a/assets/schema/dbgpt.sql b/assets/schema/dbgpt.sql index 0cdd7d17e..f0683d5a6 100644 --- a/assets/schema/dbgpt.sql +++ b/assets/schema/dbgpt.sql @@ -295,6 +295,50 @@ CREATE TABLE `dbgpt_serve_flow` ( KEY `ix_dbgpt_serve_flow_name` (`name`) ) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; +-- dbgpt.dbgpt_serve_file definition +CREATE TABLE `dbgpt_serve_file` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `bucket` varchar(255) NOT NULL COMMENT 'Bucket name', + `file_id` 
varchar(255) NOT NULL COMMENT 'File id', + `file_name` varchar(256) NOT NULL COMMENT 'File name', + `file_size` int DEFAULT NULL COMMENT 'File size', + `storage_type` varchar(32) NOT NULL COMMENT 'Storage type', + `storage_path` varchar(512) NOT NULL COMMENT 'Storage path', + `uri` varchar(512) NOT NULL COMMENT 'File URI', + `custom_metadata` text DEFAULT NULL COMMENT 'Custom metadata, JSON format', + `file_hash` varchar(128) DEFAULT NULL COMMENT 'File hash', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_bucket_file_id` (`bucket`, `file_id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- dbgpt.dbgpt_serve_variables definition +CREATE TABLE `dbgpt_serve_variables` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `key` varchar(128) NOT NULL COMMENT 'Variable key', + `name` varchar(128) DEFAULT NULL COMMENT 'Variable name', + `label` varchar(128) DEFAULT NULL COMMENT 'Variable label', + `value` text DEFAULT NULL COMMENT 'Variable value, JSON format', + `value_type` varchar(32) DEFAULT NULL COMMENT 'Variable value type(string, int, float, bool)', + `category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)', + `encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)', + `salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt', + `scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow:uid, flow:dag_name,agent:agent_name) etc', + `scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow:uid", the scope_key is uid of flow', + `enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + KEY `ix_your_table_name_key` (`key`), + KEY `ix_your_table_name_name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + + -- dbgpt.gpts_app definition CREATE TABLE `gpts_app` ( `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', diff --git a/assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql b/assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql new file mode 100644 index 000000000..fa345fabe --- /dev/null +++ b/assets/schema/upgrade/v0_6_0/upgrade_to_v0.6.0.sql @@ -0,0 +1,43 @@ +-- dbgpt.dbgpt_serve_file definition +CREATE TABLE `dbgpt_serve_file` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `bucket` varchar(255) NOT NULL COMMENT 'Bucket name', + `file_id` varchar(255) NOT NULL COMMENT 'File id', + `file_name` varchar(256) NOT NULL COMMENT 'File name', + `file_size` int DEFAULT NULL COMMENT 'File size', + `storage_type` varchar(32) NOT NULL COMMENT 'Storage type', + `storage_path` varchar(512) NOT NULL COMMENT 'Storage path', + `uri` varchar(512) NOT NULL COMMENT 'File URI', + `custom_metadata` text DEFAULT NULL COMMENT 'Custom metadata, JSON format', + 
`file_hash` varchar(128) DEFAULT NULL COMMENT 'File hash', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_bucket_file_id` (`bucket`, `file_id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- dbgpt.dbgpt_serve_variables definition +CREATE TABLE `dbgpt_serve_variables` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `key` varchar(128) NOT NULL COMMENT 'Variable key', + `name` varchar(128) DEFAULT NULL COMMENT 'Variable name', + `label` varchar(128) DEFAULT NULL COMMENT 'Variable label', + `value` text DEFAULT NULL COMMENT 'Variable value, JSON format', + `value_type` varchar(32) DEFAULT NULL COMMENT 'Variable value type(string, int, float, bool)', + `category` varchar(32) DEFAULT 'common' COMMENT 'Variable category(common or secret)', + `encryption_method` varchar(32) DEFAULT NULL COMMENT 'Variable encryption method(fernet, simple, rsa, aes)', + `salt` varchar(128) DEFAULT NULL COMMENT 'Variable salt', + `scope` varchar(32) DEFAULT 'global' COMMENT 'Variable scope(global,flow,app,agent,datasource,flow:uid, flow:dag_name,agent:agent_name) etc', + `scope_key` varchar(256) DEFAULT NULL COMMENT 'Variable scope key, default is empty, for scope is "flow:uid", the scope_key is uid of flow', + `enabled` int DEFAULT 1 COMMENT 'Variable enabled, 0: disabled, 1: enabled', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + KEY `ix_your_table_name_key` (`key`), + KEY `ix_your_table_name_name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + diff --git a/assets/schema/upgrade/v0_6_0/v0.5.10.sql b/assets/schema/upgrade/v0_6_0/v0.5.10.sql new file mode 100644 index 000000000..a70d8e643 --- /dev/null +++ b/assets/schema/upgrade/v0_6_0/v0.5.10.sql @@ -0,0 +1,419 @@ +-- Full SQL of v0.5.10, please not modify this file(It must be same as the file in the release package) + +CREATE +DATABASE IF NOT EXISTS dbgpt; +use dbgpt; + +-- For alembic migration tool +CREATE TABLE IF NOT EXISTS `alembic_version` +( + version_num VARCHAR(32) NOT NULL, + CONSTRAINT alembic_version_pkc PRIMARY KEY (version_num) +) DEFAULT CHARSET=utf8mb4 ; + +CREATE TABLE IF NOT EXISTS `knowledge_space` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `name` varchar(100) NOT NULL COMMENT 'knowledge space name', + `vector_type` varchar(50) NOT NULL COMMENT 'vector type', + `domain_type` varchar(50) NOT NULL COMMENT 'domain type', + `desc` varchar(500) NOT NULL COMMENT 'description', + `owner` varchar(100) DEFAULT NULL COMMENT 'owner', + `context` TEXT DEFAULT NULL COMMENT 'context argument', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_name` (`name`) COMMENT 'index:idx_name' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge space table'; 
+ +CREATE TABLE IF NOT EXISTS `knowledge_document` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `space` varchar(50) NOT NULL COMMENT 'knowledge space', + `chunk_size` int NOT NULL COMMENT 'chunk size', + `last_sync` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'last sync time', + `status` varchar(50) NOT NULL COMMENT 'status TODO,RUNNING,FAILED,FINISHED', + `content` LONGTEXT NOT NULL COMMENT 'knowledge embedding sync result', + `result` TEXT NULL COMMENT 'knowledge content', + `vector_ids` LONGTEXT NULL COMMENT 'vector_ids', + `summary` LONGTEXT NULL COMMENT 'knowledge summary', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_doc_name` (`doc_name`) COMMENT 'index:idx_doc_name' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document table'; + +CREATE TABLE IF NOT EXISTS `document_chunk` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'auto increment id', + `doc_name` varchar(100) NOT NULL COMMENT 'document path name', + `doc_type` varchar(50) NOT NULL COMMENT 'doc type', + `document_id` int NOT NULL COMMENT 'document parent id', + `content` longtext NOT NULL COMMENT 'chunk content', + `meta_info` varchar(200) NOT NULL COMMENT 'metadata info', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + KEY `idx_document_id` (`document_id`) COMMENT 'index:document_id' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='knowledge document chunk detail'; + + + +CREATE TABLE IF NOT EXISTS `connect_config` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `db_type` varchar(255) NOT NULL COMMENT 'db type', + `db_name` varchar(255) NOT NULL COMMENT 'db name', + `db_path` varchar(255) DEFAULT NULL COMMENT 'file db path', + `db_host` varchar(255) DEFAULT NULL COMMENT 'db connect host(not file db)', + `db_port` varchar(255) DEFAULT NULL COMMENT 'db cnnect port(not file db)', + `db_user` varchar(255) DEFAULT NULL COMMENT 'db user', + `db_pwd` varchar(255) DEFAULT NULL COMMENT 'db password', + `comment` text COMMENT 'db comment', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_db` (`db_name`), + KEY `idx_q_db_type` (`db_type`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT 'Connection confi'; + +CREATE TABLE IF NOT EXISTS `chat_history` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id', + `chat_mode` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation scene mode', + `summary` longtext COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record summary', + `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'interlocutor', + `messages` text COLLATE utf8mb4_unicode_ci COMMENT 'Conversation details', + `message_ids` text COLLATE utf8mb4_unicode_ci COMMENT 'Message id list, split by comma', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL 
DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + UNIQUE KEY `conv_uid` (`conv_uid`), + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history'; + +CREATE TABLE IF NOT EXISTS `chat_history_message` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_uid` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'Conversation record unique id', + `index` int NOT NULL COMMENT 'Message index', + `round_index` int NOT NULL COMMENT 'Round of conversation', + `message_detail` text COLLATE utf8mb4_unicode_ci COMMENT 'Message details, json format', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + UNIQUE KEY `message_uid_index` (`conv_uid`, `index`), + PRIMARY KEY (`id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT 'Chat history message'; + +CREATE TABLE IF NOT EXISTS `chat_feed_back` +( + `id` bigint(20) NOT NULL AUTO_INCREMENT, + `conv_uid` varchar(128) DEFAULT NULL COMMENT 'Conversation ID', + `conv_index` int(4) DEFAULT NULL COMMENT 'Round of conversation', + `score` int(1) DEFAULT NULL COMMENT 'Score of user', + `ques_type` varchar(32) DEFAULT NULL COMMENT 'User question category', + `question` longtext DEFAULT NULL COMMENT 'User question', + `knowledge_space` varchar(128) DEFAULT NULL COMMENT 'Knowledge space name', + `messages` longtext DEFAULT NULL COMMENT 'The details of user feedback', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_conv` (`conv_uid`,`conv_index`), + KEY `idx_conv` (`conv_uid`,`conv_index`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='User feedback table'; + + +CREATE TABLE IF NOT EXISTS `my_plugin` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `tenant` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user tenant', + `user_code` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'user code', + `user_name` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'user name', + `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name', + `file_name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin package file name', + `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type', + `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', + `use_count` int DEFAULT NULL COMMENT 'plugin total use count', + `succ_count` int DEFAULT NULL COMMENT 'plugin total success count', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin install time', + PRIMARY KEY (`id`), + UNIQUE KEY `name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='User plugin table'; + +CREATE TABLE IF NOT EXISTS `plugin_hub` +( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `name` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin name', + `description` varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'plugin description', + `author` varchar(255) 
COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author', + `email` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin author email', + `type` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin type', + `version` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin version', + `storage_channel` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin storage channel', + `storage_url` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download url', + `download_param` varchar(255) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'plugin download param', + `gmt_created` TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT 'plugin upload time', + `installed` int DEFAULT NULL COMMENT 'plugin already installed count', + PRIMARY KEY (`id`), + UNIQUE KEY `name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Plugin Hub table'; + + +CREATE TABLE IF NOT EXISTS `prompt_manage` +( + `id` int(11) NOT NULL AUTO_INCREMENT, + `chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Chat scene', + `sub_chat_scene` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Sub chat scene', + `prompt_type` varchar(100) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt type: common or private', + `prompt_name` varchar(256) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'prompt name', + `content` longtext COLLATE utf8mb4_unicode_ci COMMENT 'Prompt content', + `input_variables` varchar(1024) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt input variables(split by comma))', + `model` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt model name(we can use different models for different prompt)', + `prompt_language` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt language(eg:en, zh-cn)', + `prompt_format` varchar(32) COLLATE utf8mb4_unicode_ci DEFAULT 'f-string' COMMENT 'Prompt format(eg: f-string, jinja2)', + `prompt_desc` varchar(512) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'Prompt description', + `user_name` varchar(128) COLLATE utf8mb4_unicode_ci DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` timestamp NULL DEFAULT CURRENT_TIMESTAMP COMMENT 'created time', + `gmt_modified` timestamp NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'update time', + PRIMARY KEY (`id`), + UNIQUE KEY `prompt_name_uiq` (`prompt_name`, `sys_code`, `prompt_language`, `model`), + KEY `gmt_created_idx` (`gmt_created`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='Prompt management table'; + + CREATE TABLE IF NOT EXISTS `gpts_conversations` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `user_goal` text NOT NULL COMMENT 'User''s goals content', + `gpts_name` varchar(255) NOT NULL COMMENT 'The gpts name', + `state` varchar(255) DEFAULT NULL COMMENT 'The gpts state', + `max_auto_reply_round` int(11) NOT NULL COMMENT 'max auto reply round', + `auto_reply_count` int(11) NOT NULL COMMENT 'auto reply count', + `user_code` varchar(255) DEFAULT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app ', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `team_mode` varchar(255) NULL COMMENT 
'agent team work mode', + + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_conversations` (`conv_id`), + KEY `idx_gpts_name` (`gpts_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt conversations"; + +CREATE TABLE IF NOT EXISTS `gpts_instance` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `gpts_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `gpts_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant describe', + `resource_db` text COMMENT 'List of structured database names contained in the current gpts', + `resource_internet` text COMMENT 'Is it possible to retrieve information from the internet', + `resource_knowledge` text COMMENT 'List of unstructured database names contained in the current gpts', + `gpts_agents` varchar(1000) DEFAULT NULL COMMENT 'List of agents names contained in the current gpts', + `gpts_models` varchar(1000) DEFAULT NULL COMMENT 'List of llm model names contained in the current gpts', + `language` varchar(100) DEFAULT NULL COMMENT 'gpts language', + `user_code` varchar(255) NOT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `team_mode` varchar(255) NOT NULL COMMENT 'Team work mode', + `is_sustainable` tinyint(1) NOT NULL COMMENT 'Applications for sustainable dialogue', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts` (`gpts_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts instance"; + +CREATE TABLE `gpts_messages` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `sender` varchar(255) NOT NULL COMMENT 'Who speaking in the current conversation turn', + `receiver` varchar(255) NOT NULL COMMENT 'Who receive message in the current conversation turn', + `model_name` varchar(255) DEFAULT NULL COMMENT 'message generate model', + `rounds` int(11) NOT NULL COMMENT 'dialogue turns', + `content` text COMMENT 'Content of the speech', + `current_goal` text COMMENT 'The target corresponding to the current message', + `context` text COMMENT 'Current conversation context', + `review_info` text COMMENT 'Current conversation review info', + `action_report` text COMMENT 'Current conversation action report', + `role` varchar(255) DEFAULT NULL COMMENT 'The role of the current message content', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + KEY `idx_q_messages` (`conv_id`,`rounds`,`sender`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpts message"; + + +CREATE TABLE `gpts_plans` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `conv_id` varchar(255) NOT NULL COMMENT 'The unique id of the conversation record', + `sub_task_num` int(11) NOT NULL COMMENT 'Subtask number', + `sub_task_title` varchar(255) NOT NULL COMMENT 'subtask title', + `sub_task_content` text NOT NULL COMMENT 'subtask content', + `sub_task_agent` varchar(255) DEFAULT NULL COMMENT 'Available agents corresponding to subtasks', + `resource_name` varchar(255) DEFAULT NULL COMMENT 'resource name', + `rely` varchar(255) DEFAULT NULL COMMENT 'Subtask dependencies,like: 1,2,3', + `agent_model` varchar(255) DEFAULT NULL COMMENT 'LLM model used by subtask processing agents', + `retry_times` int(11) DEFAULT NULL 
COMMENT 'number of retries', + `max_retry_times` int(11) DEFAULT NULL COMMENT 'Maximum number of retries', + `state` varchar(255) DEFAULT NULL COMMENT 'subtask status', + `result` longtext COMMENT 'subtask result', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_sub_task` (`conv_id`,`sub_task_num`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt plan"; + +-- dbgpt.dbgpt_serve_flow definition +CREATE TABLE `dbgpt_serve_flow` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `uid` varchar(128) NOT NULL COMMENT 'Unique id', + `dag_id` varchar(128) DEFAULT NULL COMMENT 'DAG id', + `name` varchar(128) DEFAULT NULL COMMENT 'Flow name', + `flow_data` text COMMENT 'Flow data, JSON format', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT NULL COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT NULL COMMENT 'Record update time', + `flow_category` varchar(64) DEFAULT NULL COMMENT 'Flow category', + `description` varchar(512) DEFAULT NULL COMMENT 'Flow description', + `state` varchar(32) DEFAULT NULL COMMENT 'Flow state', + `error_message` varchar(512) NULL comment 'Error message', + `source` varchar(64) DEFAULT NULL COMMENT 'Flow source', + `source_url` varchar(512) DEFAULT NULL COMMENT 'Flow source url', + `version` varchar(32) DEFAULT NULL COMMENT 'Flow version', + `define_type` varchar(32) null comment 'Flow define type(json or python)', + `label` varchar(128) DEFAULT NULL COMMENT 'Flow label', + `editable` int DEFAULT NULL COMMENT 'Editable, 0: editable, 1: not editable', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_uid` (`uid`), + KEY `ix_dbgpt_serve_flow_sys_code` (`sys_code`), + KEY `ix_dbgpt_serve_flow_uid` (`uid`), + KEY `ix_dbgpt_serve_flow_dag_id` (`dag_id`), + KEY `ix_dbgpt_serve_flow_user_name` (`user_name`), + KEY `ix_dbgpt_serve_flow_name` (`name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- dbgpt.gpts_app definition +CREATE TABLE `gpts_app` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `app_describe` varchar(2255) NOT NULL COMMENT 'Current AI assistant describe', + `language` varchar(100) NOT NULL COMMENT 'gpts language', + `team_mode` varchar(255) NOT NULL COMMENT 'Team work mode', + `team_context` text COMMENT 'The execution logic and team member content that teams with different working modes rely on', + `user_code` varchar(255) DEFAULT NULL COMMENT 'user code', + `sys_code` varchar(255) DEFAULT NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + `icon` varchar(1024) DEFAULT NULL COMMENT 'app icon, url', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_app` (`app_name`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +CREATE TABLE `gpts_app_collection` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `user_code` int(11) NOT NULL COMMENT 'user code', + `sys_code` varchar(255) NOT NULL COMMENT 'system app code', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + 
`updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + KEY `idx_app_code` (`app_code`), + KEY `idx_user_code` (`user_code`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT="gpt collections"; + +-- dbgpt.gpts_app_detail definition +CREATE TABLE `gpts_app_detail` ( + `id` int NOT NULL AUTO_INCREMENT COMMENT 'autoincrement id', + `app_code` varchar(255) NOT NULL COMMENT 'Current AI assistant code', + `app_name` varchar(255) NOT NULL COMMENT 'Current AI assistant name', + `agent_name` varchar(255) NOT NULL COMMENT ' Agent name', + `node_id` varchar(255) NOT NULL COMMENT 'Current AI assistant Agent Node id', + `resources` text COMMENT 'Agent bind resource', + `prompt_template` text COMMENT 'Agent bind template', + `llm_strategy` varchar(25) DEFAULT NULL COMMENT 'Agent use llm strategy', + `llm_strategy_value` text COMMENT 'Agent use llm strategy value', + `created_at` datetime DEFAULT NULL COMMENT 'create time', + `updated_at` datetime DEFAULT NULL COMMENT 'last update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_gpts_app_agent_node` (`app_name`,`agent_name`,`node_id`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + + +-- For deploy model cluster of DB-GPT(StorageModelRegistry) +CREATE TABLE IF NOT EXISTS `dbgpt_cluster_registry_instance` ( + `id` int(11) NOT NULL AUTO_INCREMENT COMMENT 'Auto increment id', + `model_name` varchar(128) NOT NULL COMMENT 'Model name', + `host` varchar(128) NOT NULL COMMENT 'Host of the model', + `port` int(11) NOT NULL COMMENT 'Port of the model', + `weight` float DEFAULT 1.0 COMMENT 'Weight of the model', + `check_healthy` tinyint(1) DEFAULT 1 COMMENT 'Whether to check the health of the model', + `healthy` tinyint(1) DEFAULT 0 COMMENT 'Whether the model is healthy', + `enabled` tinyint(1) DEFAULT 1 COMMENT 'Whether the model is enabled', + `prompt_template` varchar(128) DEFAULT NULL COMMENT 'Prompt template for the model instance', + `last_heartbeat` datetime DEFAULT NULL COMMENT 'Last heartbeat time of the model instance', + `user_name` varchar(128) DEFAULT NULL COMMENT 'User name', + `sys_code` varchar(128) DEFAULT NULL COMMENT 'System code', + `gmt_created` datetime DEFAULT CURRENT_TIMESTAMP COMMENT 'Record creation time', + `gmt_modified` datetime DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'Record update time', + PRIMARY KEY (`id`), + UNIQUE KEY `uk_model_instance` (`model_name`, `host`, `port`, `sys_code`) +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='Cluster model instance table, for registering and managing model instances'; + + +CREATE +DATABASE IF NOT EXISTS EXAMPLE_1; +use EXAMPLE_1; +CREATE TABLE IF NOT EXISTS `users` +( + `id` int NOT NULL AUTO_INCREMENT, + `username` varchar(50) NOT NULL COMMENT '用户名', + `password` varchar(50) NOT NULL COMMENT '密码', + `email` varchar(50) NOT NULL COMMENT '邮箱', + `phone` varchar(20) DEFAULT NULL COMMENT '电话', + PRIMARY KEY (`id`), + KEY `idx_username` (`username`) COMMENT '索引:按用户名查询' +) ENGINE=InnoDB AUTO_INCREMENT=1 DEFAULT CHARSET=utf8mb4 COMMENT='聊天用户表'; + +INSERT INTO users (username, password, email, phone) +VALUES ('user_1', 'password_1', 'user_1@example.com', '12345678901'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_2', 'password_2', 'user_2@example.com', '12345678902'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_3', 'password_3', 'user_3@example.com', '12345678903'); +INSERT INTO users (username, password, email, phone) +VALUES 
('user_4', 'password_4', 'user_4@example.com', '12345678904'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_5', 'password_5', 'user_5@example.com', '12345678905'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_6', 'password_6', 'user_6@example.com', '12345678906'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_7', 'password_7', 'user_7@example.com', '12345678907'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_8', 'password_8', 'user_8@example.com', '12345678908'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_9', 'password_9', 'user_9@example.com', '12345678909'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_10', 'password_10', 'user_10@example.com', '12345678900'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_11', 'password_11', 'user_11@example.com', '12345678901'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_12', 'password_12', 'user_12@example.com', '12345678902'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_13', 'password_13', 'user_13@example.com', '12345678903'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_14', 'password_14', 'user_14@example.com', '12345678904'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_15', 'password_15', 'user_15@example.com', '12345678905'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_16', 'password_16', 'user_16@example.com', '12345678906'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_17', 'password_17', 'user_17@example.com', '12345678907'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_18', 'password_18', 'user_18@example.com', '12345678908'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_19', 'password_19', 'user_19@example.com', '12345678909'); +INSERT INTO users (username, password, email, phone) +VALUES ('user_20', 'password_20', 'user_20@example.com', '12345678900'); \ No newline at end of file diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index c763859b0..7fd2a4ba4 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py @@ -387,8 +387,8 @@ class UIAttribute(UIComponent.UIAttribute): description="Whether to support drag and drop upload", ) action: Optional[str] = Field( - None, - description="The URL for the file upload", + "/api/v2/serve/file/files/dbgpt", + description="The URL for the file upload(default bucket is 'dbgpt')", ) diff --git a/dbgpt/core/interface/file.py b/dbgpt/core/interface/file.py index 5bd6cf842..83a524510 100644 --- a/dbgpt/core/interface/file.py +++ b/dbgpt/core/interface/file.py @@ -61,6 +61,8 @@ class FileMetadata(StorageItem): uri: str custom_metadata: Dict[str, Any] file_hash: str + user_name: Optional[str] = None + sys_code: Optional[str] = None _identifier: FileMetadataIdentifier = dataclasses.field(init=False) def __post_init__(self): @@ -68,6 +70,11 @@ def __post_init__(self): self._identifier = FileMetadataIdentifier( file_id=self.file_id, bucket=self.bucket ) + custom_metadata = self.custom_metadata or {} + if not self.user_name: + self.user_name = custom_metadata.get("user_name") + if not self.sys_code: + self.sys_code = custom_metadata.get("sys_code") @property def identifier(self) -> ResourceIdentifier: diff --git a/dbgpt/serve/file/api/endpoints.py b/dbgpt/serve/file/api/endpoints.py index edf1d2d98..26bbb9673 100644 --- 
a/dbgpt/serve/file/api/endpoints.py +++ b/dbgpt/serve/file/api/endpoints.py @@ -104,12 +104,22 @@ async def test_auth(): dependencies=[Depends(check_api_key)], ) async def upload_files( - bucket: str, files: List[UploadFile], service: Service = Depends(get_service) + bucket: str, + files: List[UploadFile], + user_name: Optional[str] = Query(default=None, description="user name"), + sys_code: Optional[str] = Query(default=None, description="system code"), + service: Service = Depends(get_service), ) -> Result[List[UploadFileResponse]]: """Upload files by a list of UploadFile.""" logger.info(f"upload_files: bucket={bucket}, files={files}") results = await blocking_func_to_async( - global_system_app, service.upload_files, bucket, "distributed", files + global_system_app, + service.upload_files, + bucket, + "distributed", + files, + user_name, + sys_code, ) return Result.succ(results) diff --git a/dbgpt/serve/file/models/file_adapter.py b/dbgpt/serve/file/models/file_adapter.py index a8ab36465..29ee831f4 100644 --- a/dbgpt/serve/file/models/file_adapter.py +++ b/dbgpt/serve/file/models/file_adapter.py @@ -18,9 +18,18 @@ class FileMetadataAdapter(StorageItemAdapter[FileMetadata, ServeEntity]): def to_storage_format(self, item: FileMetadata) -> ServeEntity: """Convert to storage format.""" custom_metadata = ( - json.dumps(item.custom_metadata, ensure_ascii=False) + {k: v for k, v in item.custom_metadata.items()} if item.custom_metadata - else None + else {} + ) + user_name = item.user_name or custom_metadata.get("user_name") + sys_code = item.sys_code or custom_metadata.get("sys_code") + if "user_name" in custom_metadata: + del custom_metadata["user_name"] + if "sys_code" in custom_metadata: + del custom_metadata["sys_code"] + custom_metadata_json = ( + json.dumps(custom_metadata, ensure_ascii=False) if custom_metadata else None ) return ServeEntity( bucket=item.bucket, @@ -30,8 +39,10 @@ def to_storage_format(self, item: FileMetadata) -> ServeEntity: storage_type=item.storage_type, storage_path=item.storage_path, uri=item.uri, - custom_metadata=custom_metadata, + custom_metadata=custom_metadata_json, file_hash=item.file_hash, + user_name=user_name, + sys_code=sys_code, ) def from_storage_format(self, model: ServeEntity) -> FileMetadata: @@ -39,6 +50,13 @@ def from_storage_format(self, model: ServeEntity) -> FileMetadata: custom_metadata = ( json.loads(model.custom_metadata) if model.custom_metadata else None ) + if custom_metadata is None: + custom_metadata = {} + if model.user_name: + custom_metadata["user_name"] = model.user_name + if model.sys_code: + custom_metadata["sys_code"] = model.sys_code + return FileMetadata( bucket=model.bucket, file_id=model.file_id, @@ -49,6 +67,8 @@ def from_storage_format(self, model: ServeEntity) -> FileMetadata: uri=model.uri, custom_metadata=custom_metadata, file_hash=model.file_hash, + user_name=model.user_name, + sys_code=model.sys_code, ) def get_query_for_identifier( @@ -61,6 +81,8 @@ def get_query_for_identifier( session: Session = kwargs.get("session") if session is None: raise Exception("session is None") - return session.query(storage_format).filter( - storage_format.file_id == resource_id.file_id + return ( + session.query(storage_format) + .filter(storage_format.bucket == resource_id.bucket) + .filter(storage_format.file_id == resource_id.file_id) ) diff --git a/dbgpt/serve/file/models/models.py b/dbgpt/serve/file/models/models.py index 62dd1ef80..fd816740d 100644 --- a/dbgpt/serve/file/models/models.py +++ b/dbgpt/serve/file/models/models.py @@ 
-5,7 +5,7 @@ from datetime import datetime from typing import Any, Dict, Union -from sqlalchemy import Column, DateTime, Index, Integer, String, Text +from sqlalchemy import Column, DateTime, Index, Integer, String, Text, UniqueConstraint from dbgpt.storage.metadata import BaseDao, Model, db @@ -15,6 +15,8 @@ class ServeEntity(Model): __tablename__ = SERVER_APP_TABLE_NAME + __table_args__ = (UniqueConstraint("bucket", "file_id", name="uk_bucket_file_id"),) + id = Column(Integer, primary_key=True, comment="Auto increment id") bucket = Column(String(255), nullable=False, comment="Bucket name") @@ -28,7 +30,8 @@ class ServeEntity(Model): Text, nullable=True, comment="Custom metadata, JSON format" ) file_hash = Column(String(128), nullable=True, comment="File hash") - + user_name = Column(String(128), index=True, nullable=True, comment="User name") + sys_code = Column(String(128), index=True, nullable=True, comment="System code") gmt_created = Column(DateTime, default=datetime.now, comment="Record creation time") gmt_modified = Column(DateTime, default=datetime.now, comment="Record update time") diff --git a/dbgpt/serve/file/service/service.py b/dbgpt/serve/file/service/service.py index d4d0118f3..13e8b6225 100644 --- a/dbgpt/serve/file/service/service.py +++ b/dbgpt/serve/file/service/service.py @@ -74,15 +74,28 @@ def file_storage_client(self) -> FileStorageClient: @trace("upload_files") def upload_files( - self, bucket: str, storage_type: str, files: List[UploadFile] + self, + bucket: str, + storage_type: str, + files: List[UploadFile], + user_name: Optional[str] = None, + sys_code: Optional[str] = None, ) -> List[UploadFileResponse]: """Upload files by a list of UploadFile.""" results = [] for file in files: file_name = file.filename logger.info(f"Uploading file {file_name} to bucket {bucket}") + custom_metadata = { + "user_name": user_name, + "sys_code": sys_code, + } uri = self.file_storage_client.save_file( - bucket, file_name, file_data=file.file, storage_type=storage_type + bucket, + file_name, + file_data=file.file, + storage_type=storage_type, + custom_metadata=custom_metadata, ) parsed_uri = FileStorageURI.parse(uri) logger.info(f"Uploaded file {file_name} to bucket {bucket}, uri={uri}") diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py index cba0c14df..7fa2dc236 100644 --- a/examples/awel/awel_flow_ui_components.py +++ b/examples/awel/awel_flow_ui_components.py @@ -2,7 +2,7 @@ import json import logging -from typing import List, Optional +from typing import Any, Dict, List, Optional from dbgpt.core.awel import MapOperator from dbgpt.core.awel.flow import ( @@ -15,6 +15,7 @@ ViewMetadata, ui, ) +from dbgpt.core.interface.file import FileStorageClient from dbgpt.core.interface.variables import ( BUILTIN_VARIABLES_CORE_EMBEDDINGS, BUILTIN_VARIABLES_CORE_FLOW_NODES, @@ -787,6 +788,109 @@ async def map(self, user_name: str) -> str: ) +class ExampleFlowUploadOperator(MapOperator[str, str]): + """An example flow operator that includes an upload as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Upload", + name="example_flow_upload", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a upload as parameter.", + parameters=[ + Parameter.build_from( + "Single File Selector", + "file", + type=str, + optional=True, + default=None, + placeholder="Select the file", + description="The file you want to upload.", + ui=ui.UIUpload( + max_file_size=1024 * 1024 * 100, + up_event="after_select", + 
attr=ui.UIUpload.UIAttribute(max_count=1), + ), + ), + Parameter.build_from( + "Multiple Files Selector", + "multiple_files", + type=str, + is_list=True, + optional=True, + default=None, + placeholder="Select the multiple files", + description="The multiple files you want to upload.", + ui=ui.UIUpload( + max_file_size=1024 * 1024 * 100, + up_event="button_click", + attr=ui.UIUpload.UIAttribute(max_count=5), + ), + ), + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "File", + "file", + str, + description="User's uploaded file.", + ) + ], + ) + + def __init__( + self, + file: Optional[str] = None, + multiple_files: Optional[List[str]] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.file = file + self.multiple_files = multiple_files or [] + + async def map(self, user_name: str) -> str: + """Map the user name to the file.""" + + fsc = FileStorageClient.get_instance(self.system_app) + files_metadata = await self.blocking_func_to_async( + self._parse_files_metadata, fsc + ) + files_metadata_str = json.dumps(files_metadata, ensure_ascii=False) + return "Your name is %s, and you files are %s." % ( + user_name, + files_metadata_str, + ) + + def _parse_files_metadata(self, fsc: FileStorageClient) -> List[Dict[str, Any]]: + """Parse the files metadata.""" + if not self.file: + raise ValueError("The file is not uploaded.") + if not self.multiple_files: + raise ValueError("The multiple files are not uploaded.") + files = [self.file] + self.multiple_files + results = [] + for file in files: + _, metadata = fsc.get_file(file) + results.append( + { + "bucket": metadata.bucket, + "file_id": metadata.file_id, + "file_size": metadata.file_size, + "storage_type": metadata.storage_type, + "uri": metadata.uri, + "file_hash": metadata.file_hash, + } + ) + return results + + class ExampleFlowVariablesOperator(MapOperator[str, str]): """An example flow operator that includes a variables option.""" From 7c241b64054199395b1cd54d29bb0e6a2cbbaafe Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Mon, 19 Aug 2024 07:42:11 +0800 Subject: [PATCH 42/89] fix(core): Fix upload ui component attr error --- dbgpt/core/awel/flow/ui.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dbgpt/core/awel/flow/ui.py b/dbgpt/core/awel/flow/ui.py index 7fd2a4ba4..928755a20 100644 --- a/dbgpt/core/awel/flow/ui.py +++ b/dbgpt/core/awel/flow/ui.py @@ -367,7 +367,10 @@ class UIAttribute(UIComponent.UIAttribute): ) ui_type: Literal["upload"] = Field("upload", frozen=True) - + attr: Optional[UIAttribute] = Field( + None, + description="The attributes of the component", + ) max_file_size: Optional[int] = Field( None, description="The maximum size of the file, in bytes", From 3d1d2757ceb0bef52d671ea2eeb84c71ef74b8e7 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Mon, 19 Aug 2024 08:39:09 +0800 Subject: [PATCH 43/89] feat(core): Add multi-instance config --- dbgpt/_private/config.py | 4 ++++ dbgpt/app/component_configs.py | 6 ++++-- dbgpt/app/initialization/serve_initialization.py | 2 ++ 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py index 18e972a4c..403570c5a 100644 --- a/dbgpt/_private/config.py +++ b/dbgpt/_private/config.py @@ -322,6 +322,10 @@ def __init__(self) -> None: self.FILE_SERVER_LOCAL_STORAGE_PATH = os.getenv( "FILE_SERVER_LOCAL_STORAGE_PATH" ) + # multi-instance flag + self.WEBSERVER_MULTI_INSTANCE = ( + 
os.getenv("MULTI_INSTANCE", "False").lower() == "true" + ) @property def local_db_manager(self) -> "ConnectorManager": diff --git a/dbgpt/app/component_configs.py b/dbgpt/app/component_configs.py index 29c9e59be..a8a0f24d1 100644 --- a/dbgpt/app/component_configs.py +++ b/dbgpt/app/component_configs.py @@ -52,7 +52,7 @@ def initialize_components( param, system_app, embedding_model_name, embedding_model_path ) _initialize_rerank_model(param, system_app, rerank_model_name, rerank_model_path) - _initialize_model_cache(system_app) + _initialize_model_cache(system_app, param.port) _initialize_awel(system_app, param) # Initialize resource manager of agent _initialize_resource_manager(system_app) @@ -62,7 +62,7 @@ def initialize_components( register_serve_apps(system_app, CFG, param.port) -def _initialize_model_cache(system_app: SystemApp): +def _initialize_model_cache(system_app: SystemApp, port: int): from dbgpt.storage.cache import initialize_cache if not CFG.MODEL_CACHE_ENABLE: @@ -72,6 +72,8 @@ def _initialize_model_cache(system_app: SystemApp): storage_type = CFG.MODEL_CACHE_STORAGE_TYPE or "disk" max_memory_mb = CFG.MODEL_CACHE_MAX_MEMORY_MB or 256 persist_dir = CFG.MODEL_CACHE_STORAGE_DISK_DIR or MODEL_DISK_CACHE_DIR + if CFG.WEBSERVER_MULTI_INSTANCE: + persist_dir = f"{persist_dir}_{port}" initialize_cache(system_app, storage_type, max_memory_mb, persist_dir) diff --git a/dbgpt/app/initialization/serve_initialization.py b/dbgpt/app/initialization/serve_initialization.py index 7838644e0..5b29ce455 100644 --- a/dbgpt/app/initialization/serve_initialization.py +++ b/dbgpt/app/initialization/serve_initialization.py @@ -84,6 +84,8 @@ def register_serve_apps(system_app: SystemApp, cfg: Config, webserver_port: int) local_storage_path = ( cfg.FILE_SERVER_LOCAL_STORAGE_PATH or FILE_SERVER_LOCAL_STORAGE_PATH ) + if cfg.WEBSERVER_MULTI_INSTANCE: + local_storage_path = f"{local_storage_path}_{webserver_port}" # Set config system_app.config.set( f"{FILE_SERVE_CONFIG_KEY_PREFIX}local_storage_path", local_storage_path From 8ced94066c7ee155e235a2cf76c8196f554d0eea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Mon, 19 Aug 2024 10:36:27 +0800 Subject: [PATCH 44/89] feat: add Variables component to flow --- web/components/flow/node-param-handler.tsx | 5 +++- web/components/flow/node-renderer/index.ts | 1 + .../flow/node-renderer/variables.tsx | 27 +++++++++++++++++++ 3 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 web/components/flow/node-renderer/variables.tsx diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index 165eb59bb..d10bfd1d2 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -16,6 +16,7 @@ import { RenderTimePicker, RenderTextArea, RenderPassword, + RenderVariables, } from './node-renderer'; import { convertKeysToCamelCase } from '@/utils/flow'; @@ -116,7 +117,7 @@ const NodeParamHandler: React.FC = ({ node, data, label, function renderNodeWithUiParam(data: IFlowNodeParameter) { let defaultValue = data.value ?? 
data.default; const props = { data, defaultValue, onChange }; - + switch (data?.ui?.ui_type) { case 'select': return ; @@ -142,6 +143,8 @@ const NodeParamHandler: React.FC = ({ node, data, label, return ; case 'password': return ; + case 'variables': + return ; default: return null; } diff --git a/web/components/flow/node-renderer/index.ts b/web/components/flow/node-renderer/index.ts index 59f1e44ef..f95706572 100644 --- a/web/components/flow/node-renderer/index.ts +++ b/web/components/flow/node-renderer/index.ts @@ -9,3 +9,4 @@ export * from './slider'; export * from './time-picker'; export * from './tree-select'; export * from './password'; +export * from './variables'; diff --git a/web/components/flow/node-renderer/variables.tsx b/web/components/flow/node-renderer/variables.tsx new file mode 100644 index 000000000..b3fb37af4 --- /dev/null +++ b/web/components/flow/node-renderer/variables.tsx @@ -0,0 +1,27 @@ +import { IFlowNodeParameter } from '@/types/flow'; +import { convertKeysToCamelCase } from '@/utils/flow'; +import { Input } from 'antd'; + +type Props = { + data: IFlowNodeParameter; + defaultValue: any; + onChange: (value: any) => void; +}; + +export const RenderVariables = (params: Props) => { + const { data, defaultValue, onChange } = params; + const attr = convertKeysToCamelCase(data.ui?.attr || {}); + + return ( + { + onChange(e.target.value); + }} + /> + ); +}; From a731233fdc224cbd7dc8a968c0689d82405093ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Mon, 19 Aug 2024 14:42:52 +0800 Subject: [PATCH 45/89] style: delete repeated Select component --- web/components/flow/node-param-handler.tsx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index edc3a8952..549fa996e 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -130,8 +130,6 @@ const NodeParamHandler: React.FC = ({ node, data, label, return ; case 'input': return ; - case 'select': - return ; case 'text_area': return ; case 'slider': @@ -146,6 +144,8 @@ const NodeParamHandler: React.FC = ({ node, data, label, return ; case 'upload': return ; + case 'variables': + return ; case 'code_editor': return ; default: From 9394d34b199b5e4b0946830657074389e17691d4 Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Mon, 19 Aug 2024 15:39:42 +0800 Subject: [PATCH 46/89] feat(core): Add code editor for UI component --- dbgpt/core/interface/file.py | 38 ++++++- dbgpt/serve/flow/api/endpoints.py | 4 +- examples/awel/awel_flow_ui_components.py | 129 +++++++++++++++++++++++ 3 files changed, 169 insertions(+), 2 deletions(-) diff --git a/dbgpt/core/interface/file.py b/dbgpt/core/interface/file.py index 83a524510..ea1ddb2f3 100644 --- a/dbgpt/core/interface/file.py +++ b/dbgpt/core/interface/file.py @@ -3,6 +3,7 @@ import dataclasses import hashlib import io +import logging import os import uuid from abc import ABC, abstractmethod @@ -24,6 +25,7 @@ StorageItem, ) +logger = logging.getLogger(__name__) _SCHEMA = "dbgpt-fs" @@ -133,6 +135,14 @@ def __init__( self.version = version self.custom_params = custom_params or {} + @classmethod + def is_local_file(cls, uri: str) -> bool: + """Check if the URI is local.""" + parsed = urlparse(uri) + if not parsed.scheme or parsed.scheme == "file": + return True + return False + @classmethod def parse(cls, uri: str) -> "FileStorageURI": """Parse the URI string.""" @@ -313,6 +323,13 @@ def save_file( file_size = 
file_data.tell() # Get the file size file_data.seek(0) # Reset file pointer + # filter None value + custom_metadata = ( + {k: v for k, v in custom_metadata.items() if v is not None} + if custom_metadata + else {} + ) + with root_tracer.start_span( "file_storage_system.save_file.calculate_hash", ): @@ -329,7 +346,7 @@ def save_file( storage_type=storage_type, storage_path=storage_path, uri=str(uri), - custom_metadata=custom_metadata or {}, + custom_metadata=custom_metadata, file_hash=file_hash, ) @@ -339,6 +356,25 @@ def save_file( @trace("file_storage_system.get_file") def get_file(self, uri: str) -> Tuple[BinaryIO, FileMetadata]: """Get the file data from the storage backend.""" + if FileStorageURI.is_local_file(uri): + local_file_name = uri.split("/")[-1] + if not os.path.exists(uri): + raise FileNotFoundError(f"File not found: {uri}") + + dummy_metadata = FileMetadata( + file_id=local_file_name, + bucket="dummy_bucket", + file_name=local_file_name, + file_size=-1, + storage_type="local", + storage_path=uri, + uri=uri, + custom_metadata={}, + file_hash="", + ) + logger.info(f"Reading local file: {uri}") + return open(uri, "rb"), dummy_metadata # noqa: SIM115 + parsed_uri = FileStorageURI.parse(uri) metadata = self.metadata_storage.load( FileMetadataIdentifier( diff --git a/dbgpt/serve/flow/api/endpoints.py b/dbgpt/serve/flow/api/endpoints.py index 4174502a5..40342f7f1 100644 --- a/dbgpt/serve/flow/api/endpoints.py +++ b/dbgpt/serve/flow/api/endpoints.py @@ -117,7 +117,9 @@ async def test_auth(): @router.post( - "/flows", response_model=Result[None], dependencies=[Depends(check_api_key)] + "/flows", + response_model=Result[ServerResponse], + dependencies=[Depends(check_api_key)], ) async def create( request: ServeRequest, service: Service = Depends(get_service) diff --git a/examples/awel/awel_flow_ui_components.py b/examples/awel/awel_flow_ui_components.py index 7fa2dc236..b187b085d 100644 --- a/examples/awel/awel_flow_ui_components.py +++ b/examples/awel/awel_flow_ui_components.py @@ -1079,3 +1079,132 @@ def __init__(self, **kwargs): async def map(self, user_name: str) -> str: """Map the user name to the tags.""" return "Your name is %s, and your tags are %s." 
% (user_name, "higher-order") + + +class ExampleFlowCodeEditorOperator(MapOperator[str, str]): + """An example flow operator that includes a code editor as parameter.""" + + metadata = ViewMetadata( + label="Example Flow Code Editor", + name="example_flow_code_editor", + category=OperatorCategory.EXAMPLE, + description="An example flow operator that includes a code editor as parameter.", + parameters=[ + Parameter.build_from( + "Code Editor", + "code", + type=str, + placeholder="Please input your code", + description="The code you want to edit.", + ui=ui.UICodeEditor( + language="python", + ), + ) + ], + inputs=[ + IOField.build_from( + "User Name", + "user_name", + str, + description="The name of the user.", + ) + ], + outputs=[ + IOField.build_from( + "Code", + "code", + str, + description="Result of the code.", + ) + ], + ) + + def __init__(self, code: str, **kwargs): + super().__init__(**kwargs) + self.code = code + + async def map(self, user_name: str) -> str: + """Map the user name to the code.""" + from dbgpt.util.code_utils import UNKNOWN, extract_code + + code = self.code + exitcode = -1 + try: + code_blocks = extract_code(self.code) + if len(code_blocks) < 1: + logger.info( + f"No executable code found in: \n{code}", + ) + raise ValueError(f"No executable code found in: \n{code}") + elif len(code_blocks) > 1 and code_blocks[0][0] == UNKNOWN: + # found code blocks, execute code and push "last_n_messages" back + logger.info( + f"Missing available code block type, unable to execute code," + f"\n{code}", + ) + raise ValueError( + "Missing available code block type, unable to execute code, " + f"\n{code}" + ) + exitcode, logs = await self.blocking_func_to_async( + self.execute_code_blocks, code_blocks + ) + # exitcode, logs = self.execute_code_blocks(code_blocks) + except Exception as e: + logger.error(f"Failed to execute code: {e}") + logs = f"Failed to execute code: {e}" + return ( + f"Your name is {user_name}, and your code is \n\n```python\n{self.code}" + f"\n\n```\n\nThe execution result is \n\n```\n{logs}\n\n```\n\n" + f"Exit code: {exitcode}." + ) + + def execute_code_blocks(self, code_blocks): + """Execute the code blocks and return the result.""" + from dbgpt.util.code_utils import execute_code, infer_lang + from dbgpt.util.utils import colored + + logs_all = "" + exitcode = -1 + _code_execution_config = {"use_docker": False} + for i, code_block in enumerate(code_blocks): + lang, code = code_block + if not lang: + lang = infer_lang(code) + print( + colored( + f"\n>>>>>>>> EXECUTING CODE BLOCK {i} " + f"(inferred language is {lang})...", + "red", + ), + flush=True, + ) + if lang in ["bash", "shell", "sh"]: + exitcode, logs, image = execute_code( + code, lang=lang, **_code_execution_config + ) + elif lang in ["python", "Python"]: + if code.startswith("# filename: "): + filename = code[11 : code.find("\n")].strip() + else: + filename = None + exitcode, logs, image = execute_code( + code, + lang="python", + filename=filename, + **_code_execution_config, + ) + else: + # In case the language is not supported, we return an error message. 
+ exitcode, logs, image = ( + 1, + f"unknown language {lang}", + None, + ) + # raise NotImplementedError + if image is not None: + _code_execution_config["use_docker"] = image + logs_all += "\n" + logs + if exitcode != 0: + return exitcode, logs_all + return exitcode, logs_all From 978ffe175d720fd354d7647798cc0caf965108e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=A5=E5=BF=97=E5=8B=87?= Date: Mon, 19 Aug 2024 20:57:41 +0800 Subject: [PATCH 47/89] =?UTF-8?q?feat:=201=E3=80=81component=20=20textArea?= =?UTF-8?q?=20support=20mouse=20scrolling=202=E3=80=81component=20slider?= =?UTF-8?q?=20=20support=20range?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/components/flow/node-renderer/slider.tsx | 10 +++++----- web/components/flow/node-renderer/textarea.tsx | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/web/components/flow/node-renderer/slider.tsx b/web/components/flow/node-renderer/slider.tsx index b23780379..7ee011f08 100644 --- a/web/components/flow/node-renderer/slider.tsx +++ b/web/components/flow/node-renderer/slider.tsx @@ -14,25 +14,25 @@ export const RenderSlider = (params: TextAreaProps) => { const { data, defaultValue, onChange } = params; const attr = convertKeysToCamelCase(data.ui?.attr || {}); const [inputValue, setInputValue] = useState(defaultValue); - const onChangeSlider: InputNumberProps['onChange'] = (newValue) => { - setInputValue(newValue as number); - onChange(newValue as number); + setInputValue(newValue); + onChange(newValue); }; +console.log(data); return ( <> {data?.ui?.show_input ? ( - + ) : ( - + )} > ); diff --git a/web/components/flow/node-renderer/textarea.tsx b/web/components/flow/node-renderer/textarea.tsx index 47af5a53b..8f62df8d8 100644 --- a/web/components/flow/node-renderer/textarea.tsx +++ b/web/components/flow/node-renderer/textarea.tsx @@ -18,7 +18,7 @@ export const RenderTextArea = (params: TextAreaProps) => { return ( - + ); }; From 09b95235c5f903979399f64f8eaf221bbd1dad02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Tue, 20 Aug 2024 08:51:18 +0800 Subject: [PATCH 48/89] fix: add label of component --- web/components/flow/node-param-handler.tsx | 79 +++++++++++----------- 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index 549fa996e..540a74c86 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -28,6 +28,17 @@ interface NodeParamHandlerProps { index: number; // index of array } +const renderLabelWithTooltip = (data: IFlowNodeParameter) => ( + + {data.label}: + {data.description && ( + + + + )} + +); + // render node parameters item const NodeParamHandler: React.FC = ({ node, data, label, index }) => { function onChange(value: any) { @@ -43,14 +54,7 @@ const NodeParamHandler: React.FC = ({ node, data, label, case 'float': return ( - - {data.label}: - {data.description && ( - - - - )} - + {renderLabelWithTooltip(data)} = ({ node, data, label, case 'str': return ( - - {data.label}: - {data.description && ( - - - - )} - + {renderLabelWithTooltip(data)} {data.options?.length > 0 ? ( = ({ node, data, label, defaultValue = defaultValue === 'True' ? 
true : defaultValue; return ( - - {data.label}: - {data.description && ( - - - - )} - { - onChange(e.target.checked); - }} - /> - + {renderLabelWithTooltip(data)} + { + onChange(e.target.checked); + }} + /> ); } } - // render node parameters based on AWEL2.0 - function renderNodeWithUiParam(data: IFlowNodeParameter) { - let defaultValue = data.value ?? data.default; - const props = { data, defaultValue, onChange }; - - switch (data?.ui?.ui_type) { + function renderComponentByType(type: string, props?: any) { + switch (type) { case 'select': return ; case 'cascader': @@ -144,8 +130,8 @@ const NodeParamHandler: React.FC = ({ node, data, label, return ; case 'upload': return ; - case 'variables': - return ; + case 'variables': + return ; case 'code_editor': return ; default: @@ -153,6 +139,19 @@ const NodeParamHandler: React.FC = ({ node, data, label, } } + // render node parameters based on AWEL2.0 + function renderNodeWithUiParam(data: IFlowNodeParameter) { + let defaultValue = data.value ?? data.default; + const props = { data, defaultValue, onChange }; + + return ( + + {renderLabelWithTooltip(data)} + {renderComponentByType(data?.ui?.ui_type, props)} + + ); + } + if (data.category === 'resource') { return ; } else if (data.category === 'common') { @@ -160,4 +159,4 @@ const NodeParamHandler: React.FC = ({ node, data, label, } }; -export default NodeParamHandler; +export default NodeParamHandler; \ No newline at end of file From 861ed9e9804a45b57453c933027b47f877eb5d2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Tue, 20 Aug 2024 08:54:53 +0800 Subject: [PATCH 49/89] feat: refactor NodeParamHandler component --- web/components/flow/node-param-handler.tsx | 24 ++++++++++++---------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index 540a74c86..ab76b46a1 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -28,23 +28,25 @@ interface NodeParamHandlerProps { index: number; // index of array } -const renderLabelWithTooltip = (data: IFlowNodeParameter) => ( - - {data.label}: - {data.description && ( - - - - )} - -); - // render node parameters item const NodeParamHandler: React.FC = ({ node, data, label, index }) => { function onChange(value: any) { data.value = value; } + function renderLabelWithTooltip(data: IFlowNodeParameter) { + return ( + + {data.label}: + {data.description && ( + + + + )} + + ); + } + // render node parameters based on AWEL1.0 function renderNodeWithoutUiParam(data: IFlowNodeParameter) { let defaultValue = data.value ?? 
data.default; From 5cf607012147007c4d3554683664103b34352176 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Tue, 20 Aug 2024 11:24:30 +0800 Subject: [PATCH 50/89] feat: Update UI components for better user experience --- web/client/api/request.ts | 18 ++++++---- web/components/flow/node-param-handler.tsx | 34 ++++++++++++++++++- .../flow/node-renderer/codeEditor.tsx | 18 +++++----- .../flow/node-renderer/date-picker.tsx | 1 - .../flow/node-renderer/password.tsx | 4 +-- web/components/flow/node-renderer/slider.tsx | 5 ++- .../flow/node-renderer/textarea.tsx | 4 +-- .../flow/node-renderer/time-picker.tsx | 7 ++-- .../flow/node-renderer/tree-select.tsx | 25 ++++++-------- web/components/flow/node-renderer/upload.tsx | 4 +-- web/pages/flow/canvas/index.tsx | 1 + web/types/flow.ts | 19 +++++++++-- 12 files changed, 93 insertions(+), 47 deletions(-) diff --git a/web/client/api/request.ts b/web/client/api/request.ts index bd735d8b1..91141f586 100644 --- a/web/client/api/request.ts +++ b/web/client/api/request.ts @@ -34,7 +34,7 @@ import { SpaceConfig, } from '@/types/knowledge'; import { UpdatePromptParams, IPrompt, PromptParams } from '@/types/prompt'; -import { IFlow, IFlowNode, IFlowResponse, IFlowUpdateParam } from '@/types/flow'; +import { IFlow, IFlowNode, IFlowResponse, IFlowUpdateParam, IFlowRefreshParams } from '@/types/flow'; import { IAgent, IApp, IAppData, ITeamModal } from '@/types/app'; /** App */ @@ -262,27 +262,31 @@ export const addPrompt = (data: UpdatePromptParams) => { /** AWEL Flow */ export const addFlow = (data: IFlowUpdateParam) => { - return POST('/api/v1/serve/awel/flows', data); + return POST('/api/v2/serve/awel/flows', data); }; export const getFlows = () => { - return GET('/api/v1/serve/awel/flows'); + return GET('/api/v2/serve/awel/flows'); }; export const getFlowById = (id: string) => { - return GET(`/api/v1/serve/awel/flows/${id}`); + return GET(`/api/v2/serve/awel/flows/${id}`); }; export const updateFlowById = (id: string, data: IFlowUpdateParam) => { - return PUT(`/api/v1/serve/awel/flows/${id}`, data); + return PUT(`/api/v2/serve/awel/flows/${id}`, data); }; export const deleteFlowById = (id: string) => { - return DELETE(`/api/v1/serve/awel/flows/${id}`); + return DELETE(`/api/v2/serve/awel/flows/${id}`); }; export const getFlowNodes = () => { - return GET>(`/api/v1/serve/awel/nodes`); + return GET>(`/api/v2/serve/awel/nodes`); +}; + +export const refreshFlowNodeById = (data: IFlowRefreshParams) => { + return POST('/api/v2/serve/awel/nodes/refresh', data); }; /** app */ diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index ab76b46a1..f7d3abadf 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -1,4 +1,5 @@ import { IFlowNode, IFlowNodeParameter } from '@/types/flow'; +import { refreshFlowNodeById, apiInterceptors } from '@/client/api'; import { Checkbox, Input, InputNumber, Select, Tooltip } from 'antd'; import React from 'react'; import RequiredIcon from './required-icon'; @@ -51,6 +52,7 @@ const NodeParamHandler: React.FC = ({ node, data, label, function renderNodeWithoutUiParam(data: IFlowNodeParameter) { let defaultValue = data.value ?? 
data.default; + console.log('datacc', data); switch (data.type_name) { case 'int': case 'float': @@ -61,6 +63,8 @@ const NodeParamHandler: React.FC = ({ node, data, label, className="w-full" defaultValue={defaultValue} onChange={(value: number | null) => { + console.log('value', value); + onChange(value); }} /> @@ -106,6 +110,32 @@ const NodeParamHandler: React.FC = ({ node, data, label, } } + // TODO: refresh flow node + async function refreshFlowNode() { + // setLoading(true); + const params = { + id: '', + type_name: '', + type_cls: '', + flow_type: 'operator' as const, + refresh: [ + { + name: '', + depends: [ + { + name: '', + value: '', + has_value: true, + }, + ], + }, + ], + }; + const [_, data] = await apiInterceptors(refreshFlowNodeById(params)); + // setLoading(false); + // setFlowList(data?.items ?? []); + } + function renderComponentByType(type: string, props?: any) { switch (type) { case 'select': @@ -146,6 +176,8 @@ const NodeParamHandler: React.FC = ({ node, data, label, let defaultValue = data.value ?? data.default; const props = { data, defaultValue, onChange }; + console.log('xxx', props); + return ( {renderLabelWithTooltip(data)} @@ -161,4 +193,4 @@ const NodeParamHandler: React.FC = ({ node, data, label, } }; -export default NodeParamHandler; \ No newline at end of file +export default NodeParamHandler; diff --git a/web/components/flow/node-renderer/codeEditor.tsx b/web/components/flow/node-renderer/codeEditor.tsx index 03586f25c..6be8de655 100644 --- a/web/components/flow/node-renderer/codeEditor.tsx +++ b/web/components/flow/node-renderer/codeEditor.tsx @@ -22,19 +22,17 @@ export const RenderCodeEditor = (params: Props) => { setIsModalOpen(true); }; - const handleOk = () => { + const onOk = () => { setIsModalOpen(false); }; - const handleCancel = () => { + const onCancel = () => { setIsModalOpen(false); }; - /** - * 设置弹窗宽度 - */ + const modalWidth = useMemo(() => { if (data?.ui?.editor?.width) { - return data?.ui?.editor?.width + 100 + return data?.ui?.editor?.width + 100; } return '80%'; }, [data?.ui?.editor?.width]); @@ -44,16 +42,16 @@ export const RenderCodeEditor = (params: Props) => { {t('openCodeEditor')} - + + { const { data, defaultValue, onChange } = params; - const attr = convertKeysToCamelCase(data.ui?.attr || {}); const onChangeDate: DatePickerProps['onChange'] = (date, dateString) => { diff --git a/web/components/flow/node-renderer/password.tsx b/web/components/flow/node-renderer/password.tsx index 93dec6e85..8ab44cd90 100644 --- a/web/components/flow/node-renderer/password.tsx +++ b/web/components/flow/node-renderer/password.tsx @@ -4,13 +4,13 @@ import { convertKeysToCamelCase } from '@/utils/flow'; const { Password } = Input; -type TextAreaProps = { +type Props = { data: IFlowNodeParameter; defaultValue: any; onChange: (value: any) => void; }; -export const RenderPassword = (params: TextAreaProps) => { +export const RenderPassword = (params: Props) => { const { data, defaultValue, onChange } = params; const attr = convertKeysToCamelCase(data.ui?.attr || {}); diff --git a/web/components/flow/node-renderer/slider.tsx b/web/components/flow/node-renderer/slider.tsx index 7ee011f08..adc5c1f25 100644 --- a/web/components/flow/node-renderer/slider.tsx +++ b/web/components/flow/node-renderer/slider.tsx @@ -4,13 +4,13 @@ import { Col, InputNumber, Row, Slider, Space } from 'antd'; import type { InputNumberProps } from 'antd'; import React, { useState } from 'react'; -type TextAreaProps = { +type Props = { data: IFlowNodeParameter; defaultValue: any; onChange: 
(value: any) => void; }; -export const RenderSlider = (params: TextAreaProps) => { +export const RenderSlider = (params: Props) => { const { data, defaultValue, onChange } = params; const attr = convertKeysToCamelCase(data.ui?.attr || {}); const [inputValue, setInputValue] = useState(defaultValue); @@ -18,7 +18,6 @@ export const RenderSlider = (params: TextAreaProps) => { setInputValue(newValue); onChange(newValue); }; -console.log(data); return ( <> diff --git a/web/components/flow/node-renderer/textarea.tsx b/web/components/flow/node-renderer/textarea.tsx index 8f62df8d8..e59f74ec2 100644 --- a/web/components/flow/node-renderer/textarea.tsx +++ b/web/components/flow/node-renderer/textarea.tsx @@ -5,13 +5,13 @@ import classNames from 'classnames'; const { TextArea } = Input; -type TextAreaProps = { +type Props = { data: IFlowNodeParameter; defaultValue: any; onChange: (value: any) => void; }; -export const RenderTextArea = (params: TextAreaProps) => { +export const RenderTextArea = (params: Props) => { const { data, defaultValue, onChange } = params; const attr = convertKeysToCamelCase(data.ui?.attr || {}); diff --git a/web/components/flow/node-renderer/time-picker.tsx b/web/components/flow/node-renderer/time-picker.tsx index 95e13524b..4c560aa4b 100644 --- a/web/components/flow/node-renderer/time-picker.tsx +++ b/web/components/flow/node-renderer/time-picker.tsx @@ -1,15 +1,16 @@ -import React, { useState } from 'react'; +import React from 'react'; import type { TimePickerProps } from 'antd'; import { TimePicker } from 'antd'; import { IFlowNodeParameter } from '@/types/flow'; import { convertKeysToCamelCase } from '@/utils/flow'; -type TextAreaProps = { +type Props = { data: IFlowNodeParameter; defaultValue: any; onChange: (value: any) => void; }; -export const RenderTimePicker = (params: TextAreaProps) => { + +export const RenderTimePicker = (params: Props) => { const { data, defaultValue, onChange } = params; const attr = convertKeysToCamelCase(data.ui?.attr || {}); diff --git a/web/components/flow/node-renderer/tree-select.tsx b/web/components/flow/node-renderer/tree-select.tsx index 74ee226fd..f619485c7 100644 --- a/web/components/flow/node-renderer/tree-select.tsx +++ b/web/components/flow/node-renderer/tree-select.tsx @@ -3,27 +3,24 @@ import { TreeSelect } from 'antd'; import { IFlowNodeParameter } from '@/types/flow'; import { convertKeysToCamelCase } from '@/utils/flow'; -type TextAreaProps = { +type Props = { data: IFlowNodeParameter; defaultValue: any; onChange: (value: any) => void; }; -export const RenderTreeSelect = (params: TextAreaProps) => { +export const RenderTreeSelect = (params: Props) => { const { data, defaultValue, onChange } = params; const attr = convertKeysToCamelCase(data.ui?.attr || {}); return ( - - - + ); }; diff --git a/web/components/flow/node-renderer/upload.tsx b/web/components/flow/node-renderer/upload.tsx index 35e1ddd75..9f2944af7 100644 --- a/web/components/flow/node-renderer/upload.tsx +++ b/web/components/flow/node-renderer/upload.tsx @@ -1,7 +1,7 @@ import React from 'react'; import { UploadOutlined } from '@ant-design/icons'; import type { UploadProps } from 'antd'; -import { Button, message, Upload } from 'antd'; +import { Button, Upload } from 'antd'; import { convertKeysToCamelCase } from '@/utils/flow'; import { IFlowNodeParameter } from '@/types/flow'; import { useTranslation } from 'react-i18next'; @@ -28,7 +28,7 @@ export const RenderUpload = (params: Props) => { const attr = convertKeysToCamelCase(data.ui?.attr || {}); return ( - + 
}>{t('UploadData')} diff --git a/web/pages/flow/canvas/index.tsx b/web/pages/flow/canvas/index.tsx index 8c304845a..072884f5c 100644 --- a/web/pages/flow/canvas/index.tsx +++ b/web/pages/flow/canvas/index.tsx @@ -179,6 +179,7 @@ const Canvas: React.FC = () => { const { name, label, description = '', editable = false, state = 'deployed' } = form.getFieldsValue(); console.log(form.getFieldsValue()); const reactFlowObject = mapHumpToUnderline(reactFlow.toObject() as IFlowData); + if (id) { const [, , res] = await apiInterceptors(updateFlowById(id, { name, label, description, editable, uid: id, flow_data: reactFlowObject, state })); setIsModalVisible(false); diff --git a/web/types/flow.ts b/web/types/flow.ts index 175047e31..04a935052 100644 --- a/web/types/flow.ts +++ b/web/types/flow.ts @@ -12,6 +12,21 @@ export type IFlowUpdateParam = { state?: FlowState; }; +export type IFlowRefreshParams = { + id: string; + type_name: string; + type_cls: string; + flow_type: 'resource' | 'operator'; + refresh: { + name: string; + depends?: Array<{ + name: string; + value: string; + has_value: boolean; + }>; + }[]; +}; + export type IFlow = { dag_id: string; gmt_created: string; @@ -60,8 +75,8 @@ export type IFlowNodeParameterUI = { [key: string]: any; }; editor: { - width: Number; - height: Number; + width: number; + height: number; }; show_input: boolean; }; From 37240c9c2fde7f7745ca2570a5d1620f7e776bcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Tue, 20 Aug 2024 12:04:45 +0800 Subject: [PATCH 51/89] chore: update api --- web/client/api/request.ts | 43 +++++++++++++++++++++++++++++++++++++-- web/types/flow.ts | 27 ++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/web/client/api/request.ts b/web/client/api/request.ts index 91141f586..40b45f672 100644 --- a/web/client/api/request.ts +++ b/web/client/api/request.ts @@ -34,8 +34,18 @@ import { SpaceConfig, } from '@/types/knowledge'; import { UpdatePromptParams, IPrompt, PromptParams } from '@/types/prompt'; -import { IFlow, IFlowNode, IFlowResponse, IFlowUpdateParam, IFlowRefreshParams } from '@/types/flow'; -import { IAgent, IApp, IAppData, ITeamModal } from '@/types/app'; +import { + IFlow, + IFlowNode, + IFlowResponse, + IFlowUpdateParam, + IFlowRefreshParams, + IFlowExportParams, + IFlowImportParams, + IUploadFileRequestParams, + IUploadFileResponse, +} from '@/types/flow'; +import { IAgent, IApp, IAppData } from '@/types/app'; /** App */ export const postScenes = () => { @@ -289,6 +299,35 @@ export const refreshFlowNodeById = (data: IFlowRefreshParams) => { return POST('/api/v2/serve/awel/nodes/refresh', data); }; +// TODO: wait for interface update +export const debugFlow = (data: any) => { + return POST('/api/v2/serve/awel/flow/debug', data); +}; + +export const exportFlow = (data: IFlowExportParams) => { + return GET('/api/v2/serve/awel/flow/export', data); +}; + +export const importFlow = (data: IFlowImportParams) => { + return POST('/api/v2/serve/awel/flow/import', data); +}; + +export const getFlowTemplateList = () => { + return GET>('/api/v2/serve/awel/flow/templates'); +}; + +export const getFlowTemplateById = (id: string) => { + return GET(`/api/v2/serve/awel/flow/templates/${id}`); +}; + +export const uploadFile = (data: IUploadFileRequestParams) => { + return POST>('/api/v2/serve/file/files/dbgpt', data); +}; + +export const downloadFile = (fileId: string) => { + return GET(`/api/v2/serve/file/files/dbgpt/${fileId}`); +}; + /** app */ export const addApp = (data: IApp) => { return 
POST('/api/v1/app/create', data); diff --git a/web/types/flow.ts b/web/types/flow.ts index 04a935052..07b193d75 100644 --- a/web/types/flow.ts +++ b/web/types/flow.ts @@ -1,3 +1,4 @@ +import { File } from 'buffer'; import { Node } from 'reactflow'; export type FlowState = 'deployed' | 'developing' | 'initializing' | 'testing' | 'disabled' | 'running' | 'load_failed'; @@ -165,3 +166,29 @@ export type IFlowData = { edges: Array; viewport: IFlowDataViewport; }; + +export type IFlowExportParams = { + export_type?: 'json' | 'dbgpts'; + format?: 'json' | 'file'; + file_name?: string; + user_name?: string; + sys_code?: string; +}; + +export type IFlowImportParams = { + file: File; + save_flow?: boolean; +}; + +export type IUploadFileRequestParams = { + files: Array; + user_name?: string; + sys_code?: string; +}; + +export type IUploadFileResponse = { + file_name: string; + file_id: string; + bucket: string; + uri?: string; +}; From 74cbfb6be2f84c49d5bb88f6d9a36dfeb4c9a359 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Tue, 20 Aug 2024 14:30:15 +0800 Subject: [PATCH 52/89] feat: add export/import functionality to flow canvas --- web/components/flow/node-param-handler.tsx | 26 ---------- web/pages/flow/canvas/index.tsx | 58 ++++++++++++++++++---- 2 files changed, 49 insertions(+), 35 deletions(-) diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index f7d3abadf..ff2901aab 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -110,32 +110,6 @@ const NodeParamHandler: React.FC = ({ node, data, label, } } - // TODO: refresh flow node - async function refreshFlowNode() { - // setLoading(true); - const params = { - id: '', - type_name: '', - type_cls: '', - flow_type: 'operator' as const, - refresh: [ - { - name: '', - depends: [ - { - name: '', - value: '', - has_value: true, - }, - ], - }, - ], - }; - const [_, data] = await apiInterceptors(refreshFlowNodeById(params)); - // setLoading(false); - // setFlowList(data?.items ?? 
[]); - } - function renderComponentByType(type: string, props?: any) { switch (type) { case 'select': diff --git a/web/pages/flow/canvas/index.tsx b/web/pages/flow/canvas/index.tsx index 072884f5c..e7ae601aa 100644 --- a/web/pages/flow/canvas/index.tsx +++ b/web/pages/flow/canvas/index.tsx @@ -5,8 +5,8 @@ import ButtonEdge from '@/components/flow/button-edge'; import CanvasNode from '@/components/flow/canvas-node'; import { IFlowData, IFlowUpdateParam } from '@/types/flow'; import { checkFlowDataRequied, getUniqueNodeId, mapHumpToUnderline, mapUnderlineToHump } from '@/utils/flow'; -import { FrownOutlined, SaveOutlined } from '@ant-design/icons'; -import { Button, Checkbox, Divider, Form, Input, Modal, Space, message, notification } from 'antd'; +import { ExportOutlined, FrownOutlined, ImportOutlined, SaveOutlined } from '@ant-design/icons'; +import { Button, Checkbox, Divider, Form, Input, Modal, Space, Tooltip, message, notification } from 'antd'; import { useSearchParams } from 'next/navigation'; import React, { DragEvent, useCallback, useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; @@ -150,7 +150,7 @@ const Canvas: React.FC = () => { form.setFieldsValue({ name: result }); } - function clickSave() { + function onSave() { const flowData = reactFlow.toObject() as IFlowData; const [check, node, message] = checkFlowDataRequied(flowData); if (!check && message) { @@ -175,11 +175,40 @@ const Canvas: React.FC = () => { setIsModalVisible(true); } + // TODO: EXport flow data + function onExport() { + const flowData = reactFlow.toObject() as IFlowData; + const blob = new Blob([JSON.stringify(flowData)], { type: 'text/plain;charset=utf-8' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = 'flow.json'; + a.click(); + } + + // TODO: Import flow data + function onImport() { + const input = document.createElement('input'); + input.type = 'file'; + input.accept = '.json'; + input.onchange = async (e: any) => { + const file = e.target.files[0]; + const reader = new FileReader(); + reader.onload = async (event) => { + const flowData = JSON.parse(event.target?.result as string) as IFlowData; + setNodes(flowData.nodes); + setEdges(flowData.edges); + }; + reader.readAsText(file); + }; + input.click; + } + async function handleSaveFlow() { const { name, label, description = '', editable = false, state = 'deployed' } = form.getFieldsValue(); console.log(form.getFieldsValue()); const reactFlowObject = mapHumpToUnderline(reactFlow.toObject() as IFlowData); - + if (id) { const [, , res] = await apiInterceptors(updateFlowById(id, { name, label, description, editable, uid: id, flow_data: reactFlowObject, state })); setIsModalVisible(false); @@ -202,11 +231,22 @@ const Canvas: React.FC = () => { return ( <> - - - - - + + {[ + { title: 'import', icon: }, + { title: 'export', icon: }, + { title: 'save', icon: }, + ].map(({ title, icon }) => ( + + {icon} + + ))} + + Date: Tue, 20 Aug 2024 14:39:59 +0800 Subject: [PATCH 53/89] feat: refactor NodeParamHandler component --- web/components/flow/node-param-handler.tsx | 26 ++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index ff2901aab..3feaf95fe 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -110,6 +110,32 @@ const NodeParamHandler: React.FC = ({ node, data, label, } } + // TODO: refresh flow node 
+ async function refreshFlowNode() { + // setLoading(true); + const params = { + id: '', + type_name: '', + type_cls: '', + flow_type: 'operator' as const, + refresh: [ + { + name: '', + depends: [ + { + name: '', + value: '', + has_value: true, + }, + ], + }, + ], + }; + const [_, data] = await apiInterceptors(refreshFlowNodeById(params)); + // setLoading(false); + // setFlowList(data?.items ?? []); + } + function renderComponentByType(type: string, props?: any) { switch (type) { case 'select': From 093ab9affb7424c0651d091ab4a3b647b17abe84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Tue, 20 Aug 2024 15:54:01 +0800 Subject: [PATCH 54/89] refactor: refactor canvasNode component to improve rendering performance --- web/components/flow/canvas-node.tsx | 48 +++++++++++------------------ 1 file changed, 18 insertions(+), 30 deletions(-) diff --git a/web/components/flow/canvas-node.tsx b/web/components/flow/canvas-node.tsx index 59129718f..d46212714 100644 --- a/web/components/flow/canvas-node.tsx +++ b/web/components/flow/canvas-node.tsx @@ -15,8 +15,6 @@ type CanvasNodeProps = { data: IFlowNode; }; -const ICON_PATH_PREFIX = '/icons/node/'; - function TypeLabel({ label }: { label: string }) { return {label}; } @@ -71,29 +69,6 @@ const CanvasNode: React.FC = ({ data }) => { reactFlow.setEdges((edges) => edges.filter((edge) => edge.source !== node.id && edge.target !== node.id)); } - function renderOutput(data: IFlowNode) { - if (flowType === 'operator' && outputs?.length > 0) { - return ( - - - - {outputs?.map((output, index) => ( - - ))} - - - ); - } else if (flowType === 'resource') { - // resource nodes show output default - return ( - - - - - ); - } - } - return ( = ({ data }) => { - {inputs?.map((input, index) => ( - + {inputs?.map((item, index) => ( + ))} @@ -156,14 +131,27 @@ const CanvasNode: React.FC = ({ data }) => { - {parameters?.map((parameter, index) => ( - + {parameters?.map((item, index) => ( + ))} )} - {renderOutput(node)} + {outputs?.length > 0 && ( + + + {flowType === 'operator' ? ( + + {outputs.map((item, index) => ( + + ))} + + ) : ( + flowType === 'resource' && + )} + + )} ); From dd4a90ddf108c49038db6f7173ff70e42926d407 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Tue, 20 Aug 2024 15:54:58 +0800 Subject: [PATCH 55/89] style: remove console.log statements in NodeParamHandler component --- web/components/flow/node-param-handler.tsx | 3 --- 1 file changed, 3 deletions(-) diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index 3feaf95fe..d34d66642 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -52,7 +52,6 @@ const NodeParamHandler: React.FC = ({ node, data, label, function renderNodeWithoutUiParam(data: IFlowNodeParameter) { let defaultValue = data.value ?? data.default; - console.log('datacc', data); switch (data.type_name) { case 'int': case 'float': @@ -176,8 +175,6 @@ const NodeParamHandler: React.FC = ({ node, data, label, let defaultValue = data.value ?? 
data.default; const props = { data, defaultValue, onChange }; - console.log('xxx', props); - return ( {renderLabelWithTooltip(data)} From 4736b80dd2b4c05e059f124719ba592f502cd68b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=A5=E5=BF=97=E5=8B=87?= Date: Tue, 20 Aug 2024 21:42:31 +0800 Subject: [PATCH 56/89] feat: Component Upload complete --- web/app/i18n.ts | 4 ++ web/components/flow/node-renderer/upload.tsx | 58 +++++++++++++++----- web/types/flow.ts | 2 + 3 files changed, 50 insertions(+), 14 deletions(-) diff --git a/web/app/i18n.ts b/web/app/i18n.ts index a03216c55..b15332d35 100644 --- a/web/app/i18n.ts +++ b/web/app/i18n.ts @@ -2,6 +2,8 @@ import i18n from 'i18next'; import { initReactI18next } from 'react-i18next'; const en = { + UploadDataSuccessfully: 'file uploaded successfully', + UploadDataFailed: 'file upload failed', UploadData: 'Upload Data', CodeEditor: 'Code Editor:', openCodeEditor:'Open Code Editor', @@ -237,6 +239,8 @@ export interface Resources { } const zh: Resources['translation'] = { + UploadDataSuccessfully: '文件上传成功', + UploadDataFailed: '文件上传失败', UploadData: '上传数据', CodeEditor: '代码编辑:', openCodeEditor: '打开代码编辑器', diff --git a/web/components/flow/node-renderer/upload.tsx b/web/components/flow/node-renderer/upload.tsx index 9f2944af7..0f2a433c2 100644 --- a/web/components/flow/node-renderer/upload.tsx +++ b/web/components/flow/node-renderer/upload.tsx @@ -1,37 +1,67 @@ -import React from 'react'; +import React,{useState,useRef}from 'react'; import { UploadOutlined } from '@ant-design/icons'; import type { UploadProps } from 'antd'; -import { Button, Upload } from 'antd'; +import { Button, Upload,message } from 'antd'; import { convertKeysToCamelCase } from '@/utils/flow'; import { IFlowNodeParameter } from '@/types/flow'; import { useTranslation } from 'react-i18next'; -const props: UploadProps = { - name: 'file', - action: 'https://660d2bd96ddfa2943b33731c.mockapi.io/api/upload', - headers: { - authorization: 'authorization-text', - }, -}; - type Props = { data: IFlowNodeParameter; defaultValue: any; onChange: (value: any) => void; }; - export const RenderUpload = (params: Props) => { const { t } = useTranslation(); + const urlList =useRef([]); const { data, defaultValue, onChange } = params; const attr = convertKeysToCamelCase(data.ui?.attr || {}); + const [uploading, setUploading] = useState(false); + const getUploadSuccessUrl = (url:string) => { + if (urlList.current.length === data.ui.attr.max_count) { + urlList.current.pop(); + } + urlList.current.push(url) + console.log('上传数据'+urlList.current); + + onChange(urlList.current.toString()) + } + const handleFileRemove = (file:any) => { + const index = urlList.current.indexOf(file.response.data[0].uri); + if (index !== -1) { + urlList.current.splice(index, 1); + } + onChange(urlList.current.toString()) + } + const props: UploadProps = { + name: 'files', + action: process.env.API_BASE_URL + data.ui.action, + headers: { + 'authorization': 'authorization-text', + }, + onChange(info) { + setUploading(true) + if (info.file.status !== 'uploading') { + console.log(info.file, info.fileList); + } + if (info.file.status === 'done') { + setUploading(false) + message.success(`${info.file.response.data[0].file_name} ${t('UploadDataSuccessfully')}`); + getUploadSuccessUrl(info.file.response.data[0].uri) + } else if (info.file.status === 'error') { + setUploading(false) + message.error(`${info.file.response.data[0].file_name} ${t('UploadDataFailed')}`); + } + }, + }; return ( - - }>{t('UploadData')} - + + }>{t('UploadData')} + 
) diff --git a/web/types/flow.ts b/web/types/flow.ts index 07b193d75..6a307ec1d 100644 --- a/web/types/flow.ts +++ b/web/types/flow.ts @@ -71,6 +71,8 @@ export type IFlowNodeParameter = { export type IFlowNodeParameterUI = { ui_type: string; language: string; + file_types: string; + action: string; attr: { disabled: boolean; [key: string]: any; From 1781d0d1a662e8161dcc7b31c80aa9a8f3a0f3a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=A5=E5=BF=97=E5=8B=87?= Date: Tue, 20 Aug 2024 21:44:02 +0800 Subject: [PATCH 57/89] feat:remove console --- web/components/flow/node-renderer/upload.tsx | 2 -- 1 file changed, 2 deletions(-) diff --git a/web/components/flow/node-renderer/upload.tsx b/web/components/flow/node-renderer/upload.tsx index 0f2a433c2..e16d59cd5 100644 --- a/web/components/flow/node-renderer/upload.tsx +++ b/web/components/flow/node-renderer/upload.tsx @@ -25,7 +25,6 @@ export const RenderUpload = (params: Props) => { urlList.current.pop(); } urlList.current.push(url) - console.log('上传数据'+urlList.current); onChange(urlList.current.toString()) } @@ -45,7 +44,6 @@ export const RenderUpload = (params: Props) => { onChange(info) { setUploading(true) if (info.file.status !== 'uploading') { - console.log(info.file, info.fileList); } if (info.file.status === 'done') { setUploading(false) From 1c536213d8b19f0da26d02bee967127a9c35153d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Tue, 20 Aug 2024 21:47:43 +0800 Subject: [PATCH 58/89] refactor: remove unused API functions and update flow utils --- web/client/api/request.ts | 18 +++++++++--------- web/components/flow/node-handler.tsx | 9 ++++++--- web/utils/flow.ts | 7 ++++++- 3 files changed, 21 insertions(+), 13 deletions(-) diff --git a/web/client/api/request.ts b/web/client/api/request.ts index 40b45f672..756c1b70b 100644 --- a/web/client/api/request.ts +++ b/web/client/api/request.ts @@ -299,7 +299,6 @@ export const refreshFlowNodeById = (data: IFlowRefreshParams) => { return POST('/api/v2/serve/awel/nodes/refresh', data); }; -// TODO: wait for interface update export const debugFlow = (data: any) => { return POST('/api/v2/serve/awel/flow/debug', data); }; @@ -312,14 +311,6 @@ export const importFlow = (data: IFlowImportParams) => { return POST('/api/v2/serve/awel/flow/import', data); }; -export const getFlowTemplateList = () => { - return GET>('/api/v2/serve/awel/flow/templates'); -}; - -export const getFlowTemplateById = (id: string) => { - return GET(`/api/v2/serve/awel/flow/templates/${id}`); -}; - export const uploadFile = (data: IUploadFileRequestParams) => { return POST>('/api/v2/serve/file/files/dbgpt', data); }; @@ -328,6 +319,15 @@ export const downloadFile = (fileId: string) => { return GET(`/api/v2/serve/file/files/dbgpt/${fileId}`); }; +// TODO:wait for interface update +export const getFlowTemplateList = () => { + return GET>('/api/v2/serve/awel/flow/templates'); +}; + +export const getFlowTemplateById = (id: string) => { + return GET(`/api/v2/serve/awel/flow/templates/${id}`); +}; + /** app */ export const addApp = (data: IApp) => { return POST('/api/v1/app/create', data); diff --git a/web/components/flow/node-handler.tsx b/web/components/flow/node-handler.tsx index dfe144b2c..e28a50261 100644 --- a/web/components/flow/node-handler.tsx +++ b/web/components/flow/node-handler.tsx @@ -101,8 +101,10 @@ const NodeHandler: React.FC = ({ node, data, type, label, inde isValidConnection={(connection) => isValidConnection(connection)} /> = ({ node, data, type, label, inde } > - {['inputs', 
'parameters'].includes(label) && } + {['inputs', 'parameters'].includes(label) && } - {data.type_name}:{label !== 'outputs' && } + {label !== 'outputs' && } + {data.type_name} {data.description && ( diff --git a/web/utils/flow.ts b/web/utils/flow.ts index f7ba7d3b7..3a61085cf 100644 --- a/web/utils/flow.ts +++ b/web/utils/flow.ts @@ -11,6 +11,11 @@ export const getUniqueNodeId = (nodeData: IFlowNode, nodes: Node[]) => { return `${nodeData.id}_${count}`; }; +// function getUniqueNodeId will add '_${count}' to id, so we need to remove it when we want to get the original id +export const removeIndexFromNodeId = (id: string) => { + const indexPattern = /_\d+$/; + return id.replace(indexPattern, ''); +}; // 驼峰转下划线,接口协议字段命名规范 export const mapHumpToUnderline = (flowData: IFlowData) => { @@ -111,7 +116,7 @@ export const convertKeysToCamelCase = (obj: Record): Record convert(item)); + return obj.map((item) => convert(item)); } else if (isObject(obj)) { const newObj: Record = {}; for (const key in obj) { From 438adac7123d16deec5f98a47e415ac132439874 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Tue, 20 Aug 2024 21:47:54 +0800 Subject: [PATCH 59/89] refactor: Update canvasNode component to support parameter refreshing --- web/components/flow/canvas-node.tsx | 78 ++++++++- web/components/flow/node-param-handler.tsx | 190 +++++++++++++++------ web/types/flow.ts | 8 +- 3 files changed, 220 insertions(+), 56 deletions(-) diff --git a/web/components/flow/canvas-node.tsx b/web/components/flow/canvas-node.tsx index d46212714..f91e0f128 100644 --- a/web/components/flow/canvas-node.tsx +++ b/web/components/flow/canvas-node.tsx @@ -1,15 +1,16 @@ -import { IFlowNode } from '@/types/flow'; +import { IFlowNode, IFlowRefreshParams } from '@/types/flow'; import Image from 'next/image'; import NodeParamHandler from './node-param-handler'; import classNames from 'classnames'; import { useState } from 'react'; import NodeHandler from './node-handler'; -import { Popover, Tooltip } from 'antd'; +import { Form, Popover, Tooltip } from 'antd'; import { CopyOutlined, DeleteOutlined, InfoCircleOutlined } from '@ant-design/icons'; import { useReactFlow } from 'reactflow'; import IconWrapper from '../common/icon-wrapper'; -import { getUniqueNodeId } from '@/utils/flow'; +import { getUniqueNodeId, removeIndexFromNodeId } from '@/utils/flow'; import { cloneDeep } from 'lodash'; +import { apiInterceptors, refreshFlowNodeById } from '@/client/api'; type CanvasNodeProps = { data: IFlowNode; @@ -24,6 +25,7 @@ const CanvasNode: React.FC = ({ data }) => { const { inputs, outputs, parameters, flow_type: flowType } = node; const [isHovered, setIsHovered] = useState(false); const reactFlow = useReactFlow(); + const [form] = Form.useForm(); function onHover() { setIsHovered(true); @@ -69,6 +71,68 @@ const CanvasNode: React.FC = ({ data }) => { reactFlow.setEdges((edges) => edges.filter((edge) => edge.source !== node.id && edge.target !== node.id)); } + // function onChange(value: any) { + // data.value = value; + // } + + function onValuesChange(changedValues: any, allValues: any) { + // onChange(changedValues); + console.log('Changed xxx', changedValues); + console.log('All xxx', allValues); + console.log('xxxx', parameters); + + const [changedKey, changedVal] = Object.entries(changedValues)[0]; + console.log('====', changedKey, changedVal); + + // 获取以当前改变项目为 refresh_depends 的参数name + const needChangeNodes = parameters.filter(({ ui }) => ui?.refresh_depends?.includes(changedKey)); + 
console.log('needChangeNodes====', needChangeNodes); + + if (needChangeNodes?.length === 0) return; + + needChangeNodes.forEach(async (item) => { + const params = { + id: removeIndexFromNodeId(data?.id), + type_name: data.type_name, + type_cls: data.type_cls, + flow_type: 'operator' as const, + refresh: [ + { + name: item.name, // 要刷新的参数的name + depends: [ + { + name: changedKey, // 依赖的参数的name + value: changedVal, // 依赖的参数的值 + has_value: true, + }, + ], + }, + ], + }; + + // const params = { + // id: 'operator_example_refresh_operator___$$___example___$$___v1', + // type_name: 'ExampleFlowRefreshOperator', + // type_cls: 'unusual_prefix_90027f35e50ecfda77e3c7c7b20a0272d562480c_awel_flow_ui_components.ExampleFlowRefreshOperator', + // flow_type: 'operator' as const, + // refresh: [ + // { + // name: 'recent_time', // 要刷新的参数的name + // depends: [ + // { + // name: 'time_interval', // 依赖的参数的name + // value: 3, // 依赖的参数的值 + // has_value: true, + // }, + // ], + // }, + // ], + // }; + + const [_, res] = await apiInterceptors(refreshFlowNodeById(params)); + }); + } + return ( = ({ data }) => { {parameters?.length > 0 && ( - + {/* */} + + {parameters?.map((item, index) => ( ))} - + + + {/* */} )} diff --git a/web/components/flow/node-param-handler.tsx b/web/components/flow/node-param-handler.tsx index d34d66642..1d4821262 100644 --- a/web/components/flow/node-param-handler.tsx +++ b/web/components/flow/node-param-handler.tsx @@ -1,6 +1,5 @@ import { IFlowNode, IFlowNodeParameter } from '@/types/flow'; -import { refreshFlowNodeById, apiInterceptors } from '@/client/api'; -import { Checkbox, Input, InputNumber, Select, Tooltip } from 'antd'; +import { Checkbox, Form, Input, InputNumber, Select, Tooltip } from 'antd'; import React from 'react'; import RequiredIcon from './required-icon'; import NodeHandler from './node-handler'; @@ -38,17 +37,75 @@ const NodeParamHandler: React.FC = ({ node, data, label, function renderLabelWithTooltip(data: IFlowNodeParameter) { return ( - {data.label}: - {data.description && ( - - - - )} + + {data.label} + + + ); } // render node parameters based on AWEL1.0 + // function renderNodeWithoutUiParam(data: IFlowNodeParameter) { + // let defaultValue = data.value ?? data.default; + + // switch (data.type_name) { + // case 'int': + // case 'float': + // return ( + // + // {renderLabelWithTooltip(data)} + // { + // console.log('value', value); + + // onChange(value); + // }} + // /> + // + // ); + // case 'str': + // return ( + // + // {renderLabelWithTooltip(data)} + // {data.options?.length > 0 ? ( + // ({ label: item.label, value: item.value }))} + // onChange={onChange} + // /> + // ) : ( + // { + // onChange(e.target.value); + // }} + // /> + // )} + // + // ); + // case 'bool': + // defaultValue = defaultValue === 'False' ? false : defaultValue; + // defaultValue = defaultValue === 'True' ? true : defaultValue; + // return ( + // + // {renderLabelWithTooltip(data)} + // { + // onChange(e.target.checked); + // }} + // /> + // + // ); + // } + // } function renderNodeWithoutUiParam(data: IFlowNodeParameter) { let defaultValue = data.value ?? data.default; @@ -56,23 +113,39 @@ const NodeParamHandler: React.FC = ({ node, data, label, case 'int': case 'float': return ( - + {data.label}} + tooltip={data.description ? 
{ title: data.description, icon: } : ''} + rules={[{ required: !data.optional }]} + > + + + ); + { + /* {renderLabelWithTooltip(data)} { console.log('value', value); - onChange(value); }} /> - - ); + */ + } + case 'str': return ( - - {renderLabelWithTooltip(data)} + {data.label}} + tooltip={data.description ? { title: data.description, icon: } : ''} + rules={[{ required: !data.optional }]} + > {data.options?.length > 0 ? ( = ({ node, data, label, }} /> )} - + + // + // {renderLabelWithTooltip(data)} + // {data.options?.length > 0 ? ( + // ({ label: item.label, value: item.value }))} + // onChange={onChange} + // /> + // ) : ( + // { + // onChange(e.target.value); + // }} + // /> + // )} + // ); case 'bool': defaultValue = defaultValue === 'False' ? false : defaultValue; defaultValue = defaultValue === 'True' ? true : defaultValue; return ( - - {renderLabelWithTooltip(data)} + // + // {renderLabelWithTooltip(data)} + // { + // onChange(e.target.checked); + // }} + // /> + // + + {data.label}} + tooltip={data.description ? { title: data.description, icon: } : ''} + rules={[{ required: !data.optional }]} + > = ({ node, data, label, onChange(e.target.checked); }} /> - + ); } } - // TODO: refresh flow node - async function refreshFlowNode() { - // setLoading(true); - const params = { - id: '', - type_name: '', - type_cls: '', - flow_type: 'operator' as const, - refresh: [ - { - name: '', - depends: [ - { - name: '', - value: '', - has_value: true, - }, - ], - }, - ], - }; - const [_, data] = await apiInterceptors(refreshFlowNodeById(params)); - // setLoading(false); - // setFlowList(data?.items ?? []); - } - function renderComponentByType(type: string, props?: any) { switch (type) { case 'select': @@ -172,14 +254,26 @@ const NodeParamHandler: React.FC = ({ node, data, label, // render node parameters based on AWEL2.0 function renderNodeWithUiParam(data: IFlowNodeParameter) { + const { refresh_depends, ui_type } = data.ui; let defaultValue = data.value ?? data.default; const props = { data, defaultValue, onChange }; return ( - - {renderLabelWithTooltip(data)} - {renderComponentByType(data?.ui?.ui_type, props)} - + // + // {renderLabelWithTooltip(data)} + // {renderComponentByType(data?.ui?.ui_type, props)} + // + + {data.label}} + tooltip={data.description ? 
{ title: data.description, icon: } : ''} + {...(refresh_depends && { dependencies: refresh_depends })} + rules={[{ required: !data.optional }]} + > + {renderComponentByType(ui_type, props)} + ); } diff --git a/web/types/flow.ts b/web/types/flow.ts index 07b193d75..ed497c018 100644 --- a/web/types/flow.ts +++ b/web/types/flow.ts @@ -22,7 +22,7 @@ export type IFlowRefreshParams = { name: string; depends?: Array<{ name: string; - value: string; + value: any; has_value: boolean; }>; }[]; @@ -75,11 +75,13 @@ export type IFlowNodeParameterUI = { disabled: boolean; [key: string]: any; }; - editor: { + editor?: { width: number; height: number; }; - show_input: boolean; + show_input?: boolean; + refresh?: boolean; + refresh_depends?: string[]; }; export type IFlowNodeInput = { From 3859ef17b52e4259430920503e840515ec485f9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=A5=E5=BF=97=E5=8B=87?= Date: Tue, 20 Aug 2024 21:48:38 +0800 Subject: [PATCH 60/89] feat: flow Component Upload support multiple selection --- web/components/flow/node-renderer/upload.tsx | 21 +++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/web/components/flow/node-renderer/upload.tsx b/web/components/flow/node-renderer/upload.tsx index e16d59cd5..cfb6b1554 100644 --- a/web/components/flow/node-renderer/upload.tsx +++ b/web/components/flow/node-renderer/upload.tsx @@ -1,7 +1,7 @@ -import React,{useState,useRef}from 'react'; +import React, { useState, useRef } from 'react'; import { UploadOutlined } from '@ant-design/icons'; import type { UploadProps } from 'antd'; -import { Button, Upload,message } from 'antd'; +import { Button, Upload, message } from 'antd'; import { convertKeysToCamelCase } from '@/utils/flow'; import { IFlowNodeParameter } from '@/types/flow'; import { useTranslation } from 'react-i18next'; @@ -13,22 +13,22 @@ type Props = { }; export const RenderUpload = (params: Props) => { const { t } = useTranslation(); - const urlList =useRef([]); + const urlList = useRef([]); const { data, defaultValue, onChange } = params; const attr = convertKeysToCamelCase(data.ui?.attr || {}); const [uploading, setUploading] = useState(false); - const getUploadSuccessUrl = (url:string) => { + const getUploadSuccessUrl = (url: string) => { if (urlList.current.length === data.ui.attr.max_count) { urlList.current.pop(); } urlList.current.push(url) - + onChange(urlList.current.toString()) } - const handleFileRemove = (file:any) => { + const handleFileRemove = (file: any) => { const index = urlList.current.indexOf(file.response.data[0].uri); if (index !== -1) { urlList.current.splice(index, 1); @@ -57,9 +57,12 @@ export const RenderUpload = (params: Props) => { }; return ( - - }>{t('UploadData')} - + {data.is_list ? 
+ }>{t('UploadData')} + : + }>{t('UploadData')} + } + ) From 132afece0d67842351271a966b78cdff8f39d3be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=A5=E5=BF=97=E5=8B=87?= Date: Tue, 20 Aug 2024 22:20:14 +0800 Subject: [PATCH 61/89] feat: upload file type --- web/components/flow/node-renderer/upload.tsx | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/web/components/flow/node-renderer/upload.tsx b/web/components/flow/node-renderer/upload.tsx index cfb6b1554..35da07eb4 100644 --- a/web/components/flow/node-renderer/upload.tsx +++ b/web/components/flow/node-renderer/upload.tsx @@ -14,11 +14,11 @@ type Props = { export const RenderUpload = (params: Props) => { const { t } = useTranslation(); const urlList = useRef([]); - const { data, defaultValue, onChange } = params; const attr = convertKeysToCamelCase(data.ui?.attr || {}); const [uploading, setUploading] = useState(false); + const [uploadType, setUploadType] = useState(''); const getUploadSuccessUrl = (url: string) => { if (urlList.current.length === data.ui.attr.max_count) { @@ -55,11 +55,15 @@ export const RenderUpload = (params: Props) => { } }, }; + + if (data.ui?.file_types && Array.isArray(data.ui?.file_types)) { + setUploadType(data.ui?.file_types.toString()) + } return ( - {data.is_list ? + {data.is_list ? }>{t('UploadData')} - : + : }>{t('UploadData')} } From 06d7bd41fb2ccc6f8563ce1278df2bdae042e96d Mon Sep 17 00:00:00 2001 From: Fangyin Cheng Date: Tue, 20 Aug 2024 22:22:55 +0800 Subject: [PATCH 62/89] chore: Fix merge code error --- dbgpt/core/awel/flow/base.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/dbgpt/core/awel/flow/base.py b/dbgpt/core/awel/flow/base.py index e88a47d54..7ab7cbb34 100644 --- a/dbgpt/core/awel/flow/base.py +++ b/dbgpt/core/awel/flow/base.py @@ -1198,15 +1198,6 @@ async def refresh( request, trigger, system_app ) - def refresh( - self, key: str, is_operator: bool, request: List[RefreshOptionRequest] - ) -> Dict: - """Refresh the metadata.""" - if is_operator: - return _get_operator_class(key).metadata.refresh(request) # type: ignore - else: - return _get_resource_class(key).metadata.refresh(request) - _OPERATOR_REGISTRY: FlowRegistry = FlowRegistry() From 0723cda2d81ca844ff790807befe9d33eaa4b523 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A8=E6=AC=A3?= Date: Tue, 20 Aug 2024 23:05:06 +0800 Subject: [PATCH 63/89] feat: add prefix icon support to RenderInput component --- web/components/flow/node-renderer/input.tsx | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/web/components/flow/node-renderer/input.tsx b/web/components/flow/node-renderer/input.tsx index 60c559baa..436899fc1 100644 --- a/web/components/flow/node-renderer/input.tsx +++ b/web/components/flow/node-renderer/input.tsx @@ -1,6 +1,8 @@ import { IFlowNodeParameter } from '@/types/flow'; import { convertKeysToCamelCase } from '@/utils/flow'; import { Input } from 'antd'; +import * as Icons from '@ant-design/icons'; +import { FC } from 'react'; type Props = { data: IFlowNodeParameter; @@ -8,9 +10,27 @@ type Props = { onChange: (value: any) => void; }; +const isValidIconComponent = (component: any): component is FC => { + console.log('222', typeof component); + + return component && typeof component === 'function'; +}; + +const getIconComponent = (iconString: string) => { + const match = iconString.match(/^icon:(\w+)$/); + if (match) { + const iconName = match[1] as keyof typeof Icons; + const IconComponent = Icons[iconName]; + // @ts-ignore + return IconComponent ? 
: null; + } + return null; +}; + export const RenderInput = (params: Props) => { const { data, defaultValue, onChange } = params; const attr = convertKeysToCamelCase(data.ui?.attr || {}); + attr.prefix = getIconComponent(data.ui?.attr?.prefix || ''); return ( Date: Wed, 21 Aug 2024 21:56:16 +0800 Subject: [PATCH 64/89] feat: fix import typo in node-renderer/index.ts --- web/app/i18n.ts | 24 +- web/components/flow/canvas-node.tsx | 74 +++--- web/components/flow/node-handler.tsx | 4 +- web/components/flow/node-param-handler.tsx | 245 ++++-------------- .../flow/node-renderer/cascader.tsx | 11 +- .../flow/node-renderer/checkbox.tsx | 11 +- .../{codeEditor.tsx => code-editor.tsx} | 49 ++-- .../flow/node-renderer/date-picker.tsx | 26 +- web/components/flow/node-renderer/index.ts | 2 +- web/components/flow/node-renderer/input.tsx | 29 +-- .../flow/node-renderer/password.tsx | 11 +- web/components/flow/node-renderer/radio.tsx | 18 +- web/components/flow/node-renderer/select.tsx | 15 +- web/components/flow/node-renderer/slider.tsx | 36 +-- .../flow/node-renderer/textarea.tsx | 12 +- .../flow/node-renderer/time-picker.tsx | 16 +- .../flow/node-renderer/tree-select.tsx | 13 +- web/components/flow/node-renderer/upload.tsx | 63 +++-- .../flow/node-renderer/variables.tsx | 22 +- 19 files changed, 186 insertions(+), 495 deletions(-) rename web/components/flow/node-renderer/{codeEditor.tsx => code-editor.tsx} (50%) diff --git a/web/app/i18n.ts b/web/app/i18n.ts index b15332d35..f4ae1c6bd 100644 --- a/web/app/i18n.ts +++ b/web/app/i18n.ts @@ -1,12 +1,13 @@ +import { Domain } from '@mui/icons-material'; import i18n from 'i18next'; import { initReactI18next } from 'react-i18next'; const en = { - UploadDataSuccessfully: 'file uploaded successfully', - UploadDataFailed: 'file upload failed', - UploadData: 'Upload Data', - CodeEditor: 'Code Editor:', - openCodeEditor:'Open Code Editor', + Upload_Data_Successfully: 'file uploaded successfully', + Upload_Data_Failed: 'file upload failed', + Upload_Data: 'Upload Data', + Code_Editor: 'Code Editor', + Open_Code_Editor: 'Open Code Editor', Knowledge_Space: 'Knowledge', space: 'space', Vector: 'Vector', @@ -23,6 +24,7 @@ const en = { Please_select_file: 'Please select one file', Description: 'Description', Storage: 'Storage', + Domain: 'Domain', Please_input_the_description: 'Please input the description', Please_select_the_storage: 'Please select the storage', Please_select_the_domain_type: 'Please select the domain type', @@ -229,7 +231,7 @@ const en = { Chinese: 'Chinese', English: 'English', refreshSuccess: 'Refresh Success', - Download: 'Download' + Download: 'Download', } as const; export type I18nKeys = keyof typeof en; @@ -239,11 +241,11 @@ export interface Resources { } const zh: Resources['translation'] = { - UploadDataSuccessfully: '文件上传成功', - UploadDataFailed: '文件上传失败', - UploadData: '上传数据', - CodeEditor: '代码编辑:', - openCodeEditor: '打开代码编辑器', + Upload_Data_Successfully: '文件上传成功', + Upload_Data_Failed: '文件上传失败', + Upload_Data: '上传数据', + Code_Editor: '代码编辑器', + Open_Code_Editor: '打开代码编辑器', Knowledge_Space: '知识库', space: '知识库', Vector: '向量', diff --git a/web/components/flow/canvas-node.tsx b/web/components/flow/canvas-node.tsx index f91e0f128..461d933f7 100644 --- a/web/components/flow/canvas-node.tsx +++ b/web/components/flow/canvas-node.tsx @@ -71,26 +71,22 @@ const CanvasNode: React.FC = ({ data }) => { reactFlow.setEdges((edges) => edges.filter((edge) => edge.source !== node.id && edge.target !== node.id)); } - // function onChange(value: any) { - // 
data.value = value; - // } - - function onValuesChange(changedValues: any, allValues: any) { - // onChange(changedValues); - console.log('Changed xxx', changedValues); - console.log('All xxx', allValues); - console.log('xxxx', parameters); + function updateCurrentNodeValue(changedKey: string, changedVal: any) { + parameters.forEach((item) => { + if (item.name === changedKey) { + item.value = changedVal; + } + }); + } - const [changedKey, changedVal] = Object.entries(changedValues)[0]; - console.log('====', changedKey, changedVal); + async function updateDependsNodeValue(changedKey: string, changedVal: any) { + if (!changedVal) return; - // 获取以当前改变项目为 refresh_depends 的参数name - const needChangeNodes = parameters.filter(({ ui }) => ui?.refresh_depends?.includes(changedKey)); - console.log('needChangeNodes====', needChangeNodes); + const dependParamNodes = parameters.filter(({ ui }) => ui?.refresh_depends?.includes(changedKey)); - if (needChangeNodes?.length === 0) return; + if (dependParamNodes?.length === 0) return; - needChangeNodes.forEach(async (item) => { + dependParamNodes.forEach(async (item) => { const params = { id: removeIndexFromNodeId(data?.id), type_name: data.type_name, @@ -98,11 +94,11 @@ const CanvasNode: React.FC = ({ data }) => { flow_type: 'operator' as const, refresh: [ { - name: item.name, // 要刷新的参数的name + name: item.name, depends: [ { - name: changedKey, // 依赖的参数的name - value: changedVal, // 依赖的参数的值 + name: changedKey, + value: changedVal, has_value: true, }, ], @@ -110,29 +106,23 @@ const CanvasNode: React.FC = ({ data }) => { ], }; - // const params = { - // id: 'operator_example_refresh_operator___$$___example___$$___v1', - // type_name: 'ExampleFlowRefreshOperator', - // type_cls: 'unusual_prefix_90027f35e50ecfda77e3c7c7b20a0272d562480c_awel_flow_ui_components.ExampleFlowRefreshOperator', - // flow_type: 'operator' as const, - // refresh: [ - // { - // name: 'recent_time', // 要刷新的参数的name - // depends: [ - // { - // name: 'time_interval', // 依赖的参数的name - // value: 3, // 依赖的参数的值 - // has_value: true, - // }, - // ], - // }, - // ], - // }; - const [_, res] = await apiInterceptors(refreshFlowNodeById(params)); + // TODO: update node value + console.log('res', res); }); } + function onParameterValuesChange(changedValues: any, allValues: any) { + // TODO: update node value + console.log('Changed xxx', changedValues); + console.log('All xxx', allValues); + + const [changedKey, changedVal] = Object.entries(changedValues)[0]; + + updateCurrentNodeValue(changedKey, changedVal); + updateDependsNodeValue(changedKey, changedVal); + } + return ( = ({ data }) => { {parameters?.length > 0 && ( - {/* */} - - + {parameters?.map((item, index) => ( - + ))} - - {/* */} )} diff --git a/web/components/flow/node-handler.tsx b/web/components/flow/node-handler.tsx index e28a50261..4feed9ccd 100644 --- a/web/components/flow/node-handler.tsx +++ b/web/components/flow/node-handler.tsx @@ -101,10 +101,8 @@ const NodeHandler: React.FC = ({ node, data, type, label, inde isValidConnection={(connection) => isValidConnection(connection)} /> = ({ node, data, label, index }) => { - function onChange(value: any) { - data.value = value; - } - - function renderLabelWithTooltip(data: IFlowNodeParameter) { - return ( - - - {data.label} - - - - - ); - } - +const NodeParamHandler: React.FC = ({ node, paramData, label, index }) => { // render node parameters based on AWEL1.0 - // function renderNodeWithoutUiParam(data: IFlowNodeParameter) { - // let defaultValue = data.value ?? 
data.default; - - // switch (data.type_name) { - // case 'int': - // case 'float': - // return ( - // - // {renderLabelWithTooltip(data)} - // { - // console.log('value', value); - - // onChange(value); - // }} - // /> - // - // ); - // case 'str': - // return ( - // - // {renderLabelWithTooltip(data)} - // {data.options?.length > 0 ? ( - // ({ label: item.label, value: item.value }))} - // onChange={onChange} - // /> - // ) : ( - // { - // onChange(e.target.value); - // }} - // /> - // )} - // - // ); - // case 'bool': - // defaultValue = defaultValue === 'False' ? false : defaultValue; - // defaultValue = defaultValue === 'True' ? true : defaultValue; - // return ( - // - // {renderLabelWithTooltip(data)} - // { - // onChange(e.target.checked); - // }} - // /> - // - // ); - // } - // } function renderNodeWithoutUiParam(data: IFlowNodeParameter) { let defaultValue = data.value ?? data.default; @@ -114,139 +38,83 @@ const NodeParamHandler: React.FC = ({ node, data, label, case 'float': return ( {data.label}} tooltip={data.description ? { title: data.description, icon: } : ''} - rules={[{ required: !data.optional }]} > - + ); - { - /* - {renderLabelWithTooltip(data)} - { - console.log('value', value); - onChange(value); - }} - /> - */ - } case 'str': return ( {data.label}} tooltip={data.description ? { title: data.description, icon: } : ''} - rules={[{ required: !data.optional }]} > {data.options?.length > 0 ? ( - ({ label: item.label, value: item.value }))} - onChange={onChange} - /> + ({ label: item.label, value: item.value }))} /> ) : ( - { - onChange(e.target.value); - }} - /> + )} - // - // {renderLabelWithTooltip(data)} - // {data.options?.length > 0 ? ( - // ({ label: item.label, value: item.value }))} - // onChange={onChange} - // /> - // ) : ( - // { - // onChange(e.target.value); - // }} - // /> - // )} - // ); + case 'bool': defaultValue = defaultValue === 'False' ? false : defaultValue; defaultValue = defaultValue === 'True' ? true : defaultValue; return ( - // - // {renderLabelWithTooltip(data)} - // { - // onChange(e.target.checked); - // }} - // /> - // - {data.label}} tooltip={data.description ? { title: data.description, icon: } : ''} - rules={[{ required: !data.optional }]} > - { - onChange(e.target.checked); - }} - /> + ); } } - function renderComponentByType(type: string, props?: any) { + function renderComponentByType(type: string, data: IFlowNodeParameter) { switch (type) { case 'select': - return ; + return renderSelect(data); case 'cascader': - return ; + return renderCascader(data); case 'checkbox': - return ; + return renderCheckbox(data); case 'radio': - return ; + return renderRadio(data); case 'input': - return ; + return renderInput(data); case 'text_area': - return ; + return renderTextArea(data); case 'slider': - return ; + return renderSlider(data); case 'date_picker': - return ; + return renderDatePicker(data); case 'time_picker': - return ; + return renderTimePicker(data); case 'tree_select': - return ; + return renderTreeSelect(data); case 'password': - return ; + return renderPassword(data); case 'upload': - return ; + return renderUpload({ data }); case 'variables': - return ; + return renderVariables(data); case 'code_editor': - return ; + return renderCodeEditor({ data }); default: return null; } @@ -256,31 +124,26 @@ const NodeParamHandler: React.FC = ({ node, data, label, function renderNodeWithUiParam(data: IFlowNodeParameter) { const { refresh_depends, ui_type } = data.ui; let defaultValue = data.value ?? 
data.default; - const props = { data, defaultValue, onChange }; return ( - // - // {renderLabelWithTooltip(data)} - // {renderComponentByType(data?.ui?.ui_type, props)} - // - {data.label}} - tooltip={data.description ? { title: data.description, icon: } : ''} {...(refresh_depends && { dependencies: refresh_depends })} - rules={[{ required: !data.optional }]} + {...(data.description && { tooltip: { title: data.description, icon: } })} > - {renderComponentByType(ui_type, props)} + {renderComponentByType(ui_type, data)} ); } - if (data.category === 'resource') { - return ; - } else if (data.category === 'common') { - return data?.ui ? renderNodeWithUiParam(data) : renderNodeWithoutUiParam(data); + if (paramData.category === 'resource') { + return ; + } else if (paramData.category === 'common') { + return paramData?.ui ? renderNodeWithUiParam(paramData) : renderNodeWithoutUiParam(paramData); } }; diff --git a/web/components/flow/node-renderer/cascader.tsx b/web/components/flow/node-renderer/cascader.tsx index 118c3ac40..d16d59e58 100644 --- a/web/components/flow/node-renderer/cascader.tsx +++ b/web/components/flow/node-renderer/cascader.tsx @@ -2,24 +2,15 @@ import { IFlowNodeParameter } from '@/types/flow'; import { convertKeysToCamelCase } from '@/utils/flow'; import { Cascader } from 'antd'; -type Props = { - data: IFlowNodeParameter; - defaultValue: any; - onChange: (value: any) => void; -}; - -export const RenderCascader = (params: Props) => { - const { data, defaultValue, onChange } = params; +export const renderCascader = (data: IFlowNodeParameter) => { const attr = convertKeysToCamelCase(data.ui?.attr || {}); return ( ); }; diff --git a/web/components/flow/node-renderer/checkbox.tsx b/web/components/flow/node-renderer/checkbox.tsx index 0500d3498..973b58b2f 100644 --- a/web/components/flow/node-renderer/checkbox.tsx +++ b/web/components/flow/node-renderer/checkbox.tsx @@ -2,20 +2,13 @@ import { IFlowNodeParameter } from '@/types/flow'; import { convertKeysToCamelCase } from '@/utils/flow'; import { Checkbox } from 'antd'; -type Props = { - data: IFlowNodeParameter; - defaultValue: any; - onChange: (value: any) => void; -}; - -export const RenderCheckbox = (params: Props) => { - const { data, defaultValue, onChange } = params; +export const renderCheckbox = (data: IFlowNodeParameter) => { const attr = convertKeysToCamelCase(data.ui?.attr || {}); return ( data.options?.length > 0 && ( - + ) ); diff --git a/web/components/flow/node-renderer/codeEditor.tsx b/web/components/flow/node-renderer/code-editor.tsx similarity index 50% rename from web/components/flow/node-renderer/codeEditor.tsx rename to web/components/flow/node-renderer/code-editor.tsx index 6be8de655..d0b55dfce 100644 --- a/web/components/flow/node-renderer/codeEditor.tsx +++ b/web/components/flow/node-renderer/code-editor.tsx @@ -1,5 +1,5 @@ import React, { useState, useMemo } from 'react'; -import { Button, Modal } from 'antd'; +import { Button, Form, Modal } from 'antd'; import Editor from '@monaco-editor/react'; import { IFlowNodeParameter } from '@/types/flow'; import { convertKeysToCamelCase } from '@/utils/flow'; @@ -7,13 +7,12 @@ import { useTranslation } from 'react-i18next'; type Props = { data: IFlowNodeParameter; - defaultValue: any; - onChange: (value: any) => void; + defaultValue?: any; + onChange?: (value: any) => void; }; -export const RenderCodeEditor = (params: Props) => { +export const renderCodeEditor = (params: Props) => { const { t } = useTranslation(); - const { data, defaultValue, onChange } = params; 
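A note on the parameter-renderer refactor above: the render functions (renderSelect, renderCascader, renderCheckbox, and the rest) drop the old defaultValue/onChange props because each control is now the single child of an antd Form.Item, and Form.Item injects value and onChange into that child itself. A minimal, hedged sketch of the resulting contract, using the cascader as the example (assumed shapes, not the exact PR code):

import React from 'react';
import { Cascader, Form } from 'antd';
import { IFlowNodeParameter } from '@/types/flow';
import { convertKeysToCamelCase } from '@/utils/flow';

// The renderer only builds the control; it no longer wires up value handling itself.
const renderCascaderSketch = (data: IFlowNodeParameter) => {
  const attr = convertKeysToCamelCase(data.ui?.attr || {});
  return <Cascader {...attr} options={data.options} />;
};

// Form.Item registers the control under the parameter name, seeds it from the stored
// value (or the declared default), and re-evaluates it when any field listed in
// ui.refresh_depends changes; onValuesChange on the surrounding Form then triggers
// the server-side refresh shown in canvas-node.tsx.
const ParamItemSketch = ({ data }: { data: IFlowNodeParameter }) => (
  <Form.Item
    name={data.name}
    initialValue={data.value ?? data.default}
    {...(data.ui?.refresh_depends && { dependencies: data.ui.refresh_depends })}
  >
    {renderCascaderSketch(data)}
  </Form.Item>
);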
const attr = convertKeysToCamelCase(data.ui?.attr || {}); @@ -38,27 +37,29 @@ export const RenderCodeEditor = (params: Props) => { }, [data?.ui?.editor?.width]); return ( - - - {t('openCodeEditor')} + + + {t('Open_Code_Editor')} - - + +
{data.label}: {data.description && ( @@ -86,11 +86,11 @@ const NodeParamHandler: React.FC = ({ node, data, label, )}
{data.label}: {data.description && ( diff --git a/web/components/flow/node-renderer/cascader.tsx b/web/components/flow/node-renderer/cascader.tsx index 4c9d69ac8..118c3ac40 100644 --- a/web/components/flow/node-renderer/cascader.tsx +++ b/web/components/flow/node-renderer/cascader.tsx @@ -13,15 +13,13 @@ export const RenderCascader = (params: Props) => { const attr = convertKeysToCamelCase(data.ui?.attr || {}); return ( -
- {data.label}: - {data.description && ( - - - - )} -
- {data.label}: - {data.description && ( - - - - )} - { - onChange(e.target.checked); - }} - /> -