diff --git a/src/crewai/cli/create_crew.py b/src/crewai/cli/create_crew.py index c658b0de15..8626924121 100644 --- a/src/crewai/cli/create_crew.py +++ b/src/crewai/cli/create_crew.py @@ -93,50 +93,66 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): folder_path, folder_name, class_name = create_folder_structure(name, parent_folder) env_vars = load_env_vars(folder_path) if not skip_provider: - if not provider: - provider_models = get_provider_data() - if not provider_models: - return - - existing_provider = None - for provider, env_keys in ENV_VARS.items(): - if any( - "key_name" in details and details["key_name"] in env_vars - for details in env_keys - ): - existing_provider = provider - break - - if existing_provider: - if not click.confirm( - f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?" - ): - click.secho("Keeping existing provider configuration.", fg="yellow") - return - provider_models = get_provider_data() if not provider_models: + click.secho("Could not retrieve provider data.", fg="red") return - while True: - selected_provider = select_provider(provider_models) - if selected_provider is None: # User typed 'q' - click.secho("Exiting...", fg="yellow") - sys.exit(0) - if selected_provider: # Valid selection - break - click.secho( - "No provider selected. Please try again or press 'q' to exit.", fg="red" - ) + selected_provider = None + + if provider: + provider = provider.lower() + if provider in provider_models: + selected_provider = provider + click.secho(f"Using specified provider: {selected_provider.capitalize()}", fg="green") + else: + click.secho(f"Warning: Specified provider '{provider}' is not recognized. Please select one.", fg="yellow") + + if not selected_provider: + existing_provider = None + for p, env_keys in ENV_VARS.items(): + if any( + "key_name" in details and details["key_name"] in env_vars + for details in env_keys + ): + existing_provider = p + break + + if existing_provider: + if not click.confirm( + f"Found existing environment variable configuration for {existing_provider.capitalize()}. Do you want to override it?" + ): + click.secho("Keeping existing provider configuration. Exiting provider setup.", fg="yellow") + copy_template_files(folder_path, name, class_name, parent_folder) + click.secho(f"Crew '{name}' created successfully!", fg="green") + click.secho(f"To run your crew, cd into '{folder_name}' and run 'crewai run'", fg="cyan") + return + else: + pass + + while True: + selected_provider = select_provider(provider_models) + if selected_provider is None: + click.secho("Exiting...", fg="yellow") + sys.exit(0) + if selected_provider: + break + click.secho( + "No provider selected. Please try again or press 'q' to exit.", fg="red" + ) + + if not selected_provider: + click.secho("Provider selection failed. Exiting.", fg="red") + sys.exit(1) + - # Check if the selected provider has predefined models if selected_provider in MODELS and MODELS[selected_provider]: while True: selected_model = select_model(selected_provider, provider_models) - if selected_model is None: # User typed 'q' + if selected_model is None: click.secho("Exiting...", fg="yellow") sys.exit(0) - if selected_model: # Valid selection + if selected_model: break click.secho( "No model selected. Please try again or press 'q' to exit.", @@ -144,17 +160,14 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): ) env_vars["MODEL"] = selected_model - # Check if the selected provider requires API keys if selected_provider in ENV_VARS: provider_env_vars = ENV_VARS[selected_provider] for details in provider_env_vars: if details.get("default", False): - # Automatically add default key-value pairs for key, value in details.items(): if key not in ["prompt", "key_name", "default"]: env_vars[key] = value elif "key_name" in details: - # Prompt for non-default key-value pairs prompt = details["prompt"] key_name = details["key_name"] api_key_value = click.prompt(prompt, default="", show_default=False) @@ -167,41 +180,12 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None): click.secho("API keys and model saved to .env file", fg="green") else: click.secho( - "No API keys provided. Skipping .env file creation.", fg="yellow" + "No API keys provided or required by provider. Skipping .env file creation.", fg="yellow" ) click.secho(f"Selected model: {env_vars.get('MODEL', 'N/A')}", fg="green") - package_dir = Path(__file__).parent - templates_dir = package_dir / "templates" / "crew" - - root_template_files = ( - [".gitignore", "pyproject.toml", "README.md", "knowledge/user_preference.txt"] - if not parent_folder - else [] - ) - tools_template_files = ["tools/custom_tool.py", "tools/__init__.py"] - config_template_files = ["config/agents.yaml", "config/tasks.yaml"] - src_template_files = ( - ["__init__.py", "main.py", "crew.py"] if not parent_folder else ["crew.py"] - ) - - for file_name in root_template_files: - src_file = templates_dir / file_name - dst_file = folder_path / file_name - copy_template(src_file, dst_file, name, class_name, folder_name) - - src_folder = folder_path / "src" / folder_name if not parent_folder else folder_path - - for file_name in src_template_files: - src_file = templates_dir / file_name - dst_file = src_folder / file_name - copy_template(src_file, dst_file, name, class_name, folder_name) - - if not parent_folder: - for file_name in tools_template_files + config_template_files: - src_file = templates_dir / file_name - dst_file = src_folder / file_name - copy_template(src_file, dst_file, name, class_name, folder_name) + copy_template_files(folder_path, name, class_name, parent_folder) - click.secho(f"Crew {name} created successfully!", fg="green", bold=True) + click.secho(f"Crew '{name}' created successfully!", fg="green") + click.secho(f"To run your crew, cd into '{folder_name}' and run 'crewai run'", fg="cyan") diff --git a/tests/cli/test_create_crew.py b/tests/cli/test_create_crew.py new file mode 100644 index 0000000000..ce4cb993f2 --- /dev/null +++ b/tests/cli/test_create_crew.py @@ -0,0 +1,142 @@ +import pytest +from click.testing import CliRunner +from unittest.mock import patch, MagicMock +from pathlib import Path +import sys + +# Ensure the src directory is in the Python path for imports +sys.path.insert(0, str(Path(__file__).parent.parent.parent / 'src')) + +from crewai.cli.cli import crewai +from crewai.cli import create_crew +from crewai.cli.constants import MODELS, ENV_VARS + +# Mock provider data for testing +MOCK_PROVIDER_DATA = { + 'openai': {'models': ['gpt-4', 'gpt-3.5-turbo']}, + 'google': {'models': ['gemini-pro']}, + 'anthropic': {'models': ['claude-3-opus']} +} + +MOCK_VALID_PROVIDERS = list(MOCK_PROVIDER_DATA.keys()) + +@pytest.fixture +def runner(): + return CliRunner() + +@pytest.fixture(autouse=True) +def isolate_fs(monkeypatch): + # Prevent tests from interacting with the actual filesystem or real env vars + monkeypatch.setattr(Path, 'mkdir', lambda *args, **kwargs: None) + monkeypatch.setattr(Path, 'exists', lambda *args: False) # Assume folders don't exist initially + monkeypatch.setattr(create_crew, 'load_env_vars', lambda *args: {}) # Start with empty env vars + monkeypatch.setattr(create_crew, 'write_env_file', lambda *args, **kwargs: None) + monkeypatch.setattr(create_crew, 'copy_template_files', lambda *args, **kwargs: None) + +@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA) +@patch('crewai.cli.create_crew.select_provider') +@patch('crewai.cli.create_crew.select_model') +@patch('click.prompt') +@patch('click.confirm', return_value=True) # Default to confirming prompts +def test_create_crew_with_valid_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner): + """Test `crewai create crew --provider `""" + result = runner.invoke(crewai, ['create', 'crew', 'testcrew', '--provider', 'openai']) + + print(f"CLI Output:\n{result.output}") # Debug output + assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}" + assert "Using specified provider: Openai" in result.output + mock_select_provider.assert_not_called() # Should not ask interactively + # Depending on whether openai needs models/keys, check select_model/prompt calls + assert "Crew 'testcrew' created successfully!" in result.output + +@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA) +@patch('crewai.cli.create_crew.select_provider', return_value='google') # Simulate user selecting google +@patch('crewai.cli.create_crew.select_model', return_value='gemini-pro') +@patch('click.prompt') +@patch('click.confirm', return_value=True) +def test_create_crew_with_invalid_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner): + """Test `crewai create crew --provider `""" + result = runner.invoke(crewai, ['create', 'crew', 'testcrew', '--provider', 'invalidprovider']) + + print(f"CLI Output:\n{result.output}") # Debug output + assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}" + assert "Warning: Specified provider 'invalidprovider' is not recognized." in result.output + mock_select_provider.assert_called_once() # Should ask interactively + # Check if subsequent steps for the selected provider (google) ran + mock_select_model.assert_called_once() + assert "Crew 'testcrew' created successfully!" in result.output + +@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA) +@patch('crewai.cli.create_crew.select_provider', return_value='anthropic') # Simulate user selecting anthropic +@patch('crewai.cli.create_crew.select_model', return_value='claude-3-opus') +@patch('click.prompt', return_value='sk-abc') # Simulate API key entry +@patch('click.confirm', return_value=True) +def test_create_crew_no_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner): + """Test `crewai create crew `""" + result = runner.invoke(crewai, ['create', 'crew', 'testcrew']) + + print(f"CLI Output:\n{result.output}") # Debug output + assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}" + assert "Using specified provider:" not in result.output # Should not mention specified provider + mock_select_provider.assert_called_once() # Should ask interactively + mock_select_model.assert_called_once() + # Check if prompt for API key was called (assuming anthropic needs one) + if 'anthropic' in ENV_VARS and any('key_name' in d for d in ENV_VARS['anthropic']): + mock_prompt.assert_called() + assert "Crew 'testcrew' created successfully!" in result.output + +@patch('crewai.cli.create_crew.get_provider_data') +@patch('crewai.cli.create_crew.select_provider') +@patch('crewai.cli.create_crew.select_model') +@patch('click.prompt') +@patch('click.confirm') +def test_create_crew_skip_provider(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, runner): + """Test `crewai create crew --skip_provider`""" + result = runner.invoke(crewai, ['create', 'crew', 'testcrew', '--skip_provider']) + + print(f"CLI Output:\n{result.output}") # Debug output + assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}" + mock_get_data.assert_not_called() + mock_select_provider.assert_not_called() + mock_select_model.assert_not_called() + mock_prompt.assert_not_called() + mock_confirm.assert_not_called() + assert "Crew 'testcrew' created successfully!" in result.output + +@patch('crewai.cli.create_crew.load_env_vars', return_value={'OPENAI_API_KEY': 'existing_key'}) # Simulate existing env +@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA) +@patch('crewai.cli.create_crew.select_provider', return_value='google') # Simulate selecting new provider +@patch('crewai.cli.create_crew.select_model', return_value='gemini-pro') +@patch('click.prompt') +@patch('click.confirm', return_value=True) # User confirms override +def test_create_crew_existing_override(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, mock_load_env, runner): + """Test `crewai create crew ` with existing config and user overrides.""" + result = runner.invoke(crewai, ['create', 'crew', 'testcrew']) + + print(f"CLI Output:\n{result.output}") # Debug output + assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}" + mock_confirm.assert_called_once_with( + 'Found existing environment variable configuration for Openai. Do you want to override it?' + ) + mock_select_provider.assert_called_once() # Should ask for new provider after confirming override + assert "Crew 'testcrew' created successfully!" in result.output + +@patch('crewai.cli.create_crew.load_env_vars', return_value={'OPENAI_API_KEY': 'existing_key'}) # Simulate existing env +@patch('crewai.cli.create_crew.get_provider_data', return_value=MOCK_PROVIDER_DATA) +@patch('crewai.cli.create_crew.select_provider') +@patch('crewai.cli.create_crew.select_model') +@patch('click.prompt') +@patch('click.confirm', return_value=False) # User denies override +def test_create_crew_existing_keep(mock_confirm, mock_prompt, mock_select_model, mock_select_provider, mock_get_data, mock_load_env, runner): + """Test `crewai create crew ` with existing config and user keeps it.""" + result = runner.invoke(crewai, ['create', 'crew', 'testcrew']) + + print(f"CLI Output:\n{result.output}") # Debug output + assert result.exit_code == 0, f"CLI exited with code {result.exit_code}\nOutput: {result.output}" + mock_confirm.assert_called_once_with( + 'Found existing environment variable configuration for Openai. Do you want to override it?' + ) + assert "Keeping existing provider configuration. Exiting provider setup." in result.output + mock_select_provider.assert_not_called() # Should NOT ask for new provider + assert "Crew 'testcrew' created successfully!" in result.output +