diff --git a/lilac/formats/openai_json.py b/lilac/formats/openai_json.py index 3b861daec..47a198047 100644 --- a/lilac/formats/openai_json.py +++ b/lilac/formats/openai_json.py @@ -32,7 +32,7 @@ class OpenAIJSON(DatasetFormat): Taken from: https://platform.openai.com/docs/api-reference/chat """ - name: ClassVar[str] = 'openai_json' + name: ClassVar[str] = 'OpenAI JSON' data_schema: Schema = schema( { 'messages': [ @@ -88,7 +88,7 @@ class OpenAIConversationJSON(DatasetFormat): Note that here "messages" is "conversation" for support with common datasets. """ - name: ClassVar[str] = 'openai_conversation_json' + name: ClassVar[str] = 'OpenAI Conversation JSON' data_schema: Schema = schema( { 'conversation': [ diff --git a/lilac/formats/openchat.py b/lilac/formats/openchat.py index 815268e0d..9bee2ee32 100644 --- a/lilac/formats/openchat.py +++ b/lilac/formats/openchat.py @@ -10,7 +10,7 @@ class OpenChat(DatasetFormat): """OpenChat format.""" - name: ClassVar[str] = 'openchat' + name: ClassVar[str] = 'OpenChat' data_schema: Schema = schema( { 'items': [ diff --git a/lilac/formats/sharegpt.py b/lilac/formats/sharegpt.py index 49a6b1038..30134205d 100644 --- a/lilac/formats/sharegpt.py +++ b/lilac/formats/sharegpt.py @@ -37,7 +37,7 @@ def _sharegpt_selector(item: Item, conv_from: str) -> str: class ShareGPT(DatasetFormat): """ShareGPT format.""" - name: ClassVar[str] = 'sharegpt' + name: ClassVar[str] = 'ShareGPT' data_schema: Schema = schema( { 'conversations': [ diff --git a/lilac/router_dataset_signals.py b/lilac/router_dataset_signals.py index b20455a60..700668752 100644 --- a/lilac/router_dataset_signals.py +++ b/lilac/router_dataset_signals.py @@ -7,7 +7,6 @@ from pydantic import Field as PydanticField from .auth import UserInfo, get_session_user, get_user_access -from .config import ClusterInputSelectorConfig from .dataset_format import DatasetFormatInputSelector, get_dataset_format_cls from .db_manager import get_dataset from .router_utils import RouteErrorHandler @@ -85,7 +84,7 @@ class ClusterOptions(BaseModel): """The request for the cluster endpoint.""" input: Optional[Path] = None - input_selector: Optional[ClusterInputSelectorConfig] = None + input_selector: Optional[str] = None output_path: Optional[Path] = None use_garden: bool = PydanticField( @@ -111,9 +110,6 @@ def cluster( if not get_user_access(user).dataset.compute_signals: raise HTTPException(401, 'User does not have access to compute clusters over this dataset.') - if options.input is None and options.input_selector is None: - raise HTTPException(400, 'Either input or input_selector must be provided.') - dataset = get_dataset(namespace, dataset_name) manifest = dataset.manifest() @@ -129,21 +125,15 @@ def cluster( format_cls = get_dataset_format_cls(dataset_format.name) if format_cls is None: - raise ValueError(f'Unknown format: {c.input_selector.format}') + raise ValueError(f'Unknown format: {dataset_format.name}') - format = format_cls() - if format != manifest.dataset_format: - raise ValueError( - f'Cluster input format {c.input_selector.format} does not match ' - f'dataset format {manifest.dataset_format}' - ) - - cluster_input = format_cls.input_selectors[c.input_selector.selector] + cluster_input = format_cls.input_selectors[options.input_selector] task_name = ( - f'[{namespace}/{dataset_name}] Clustering using input selector ' - f'"{options.input_selector.selector}"' + f'[{namespace}/{dataset_name}] Clustering using input selector ' f'"{options.input_selector}"' ) + else: + raise HTTPException(400, 'Either input or input_selector must be provided.') task_id = get_task_manager().task_id(name=task_name) diff --git a/web/blueprint/src/lib/components/ComputeClusterModal.svelte b/web/blueprint/src/lib/components/ComputeClusterModal.svelte index 5856b28ff..1be5de9d9 100644 --- a/web/blueprint/src/lib/components/ComputeClusterModal.svelte +++ b/web/blueprint/src/lib/components/ComputeClusterModal.svelte @@ -18,7 +18,11 @@