diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml index fb33a04..8187cc8 100644 --- a/.github/workflows/python-app.yml +++ b/.github/workflows/python-app.yml @@ -5,9 +5,9 @@ name: Python application on: push: - branches: [ "main" ] + branches: [ "main", "development" ] pull_request: - branches: [ "main" ] + branches: [ "main", "development" ] permissions: contents: read diff --git a/.gitignore b/.gitignore index 6da870c..ef2725c 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,6 @@ src/hackingBuddyGPT.egg-info/ build/ dist/ .coverage +src/hackingBuddyGPT/usecases/web_api_testing/openapi_spec/ +src/hackingBuddyGPT/usecases/web_api_testing/converted_files/ +/src/hackingBuddyGPT/usecases/web_api_testing/utils/openapi_spec/ diff --git a/README.md b/README.md index cc22c02..b7a64c1 100644 --- a/README.md +++ b/README.md @@ -8,11 +8,18 @@ HackingBuddyGPT helps security researchers use LLMs to discover new attack vecto We aim to become **THE go-to framework for security researchers** and pen-testers interested in using LLMs or LLM-based autonomous agents for security testing. To aid their experiments, we also offer re-usable [linux priv-esc benchmarks](https://github.com/ipa-lab/benchmark-privesc-linux) and publish all our findings as open-access reports. -How can LLMs aid or even emulate hackers? Threat actors are [already using LLMs](https://arxiv.org/abs/2307.00691), to better protect against this new threat we must learn more about LLMs' capabilities and help blue teams preparing for them. +If you want to use hackingBuddyGPT and need help selecting the best LLM for your tasks, [we have a paper comparing multiple LLMs](https://arxiv.org/abs/2310.11409). -**[Join us](https://discord.gg/vr4PhSM8yN) / Help us, more people need to be involved in the future of LLM-assisted pen-testing:** +## hackingBuddyGPT in the News -To ground our research in reality, we performed a comprehensive analysis into [understanding hackers' work](https://arxiv.org/abs/2308.07057). There seems to be a mismatch between some academic research and the daily work of penetration testers, please help us to create more visibility for this issue by citing this paper (if suitable and fitting). 
+- **upcoming** 2024-11-20: [Manuel Reinsperger](https://www.github.com/neverbolt) will present hackingBuddyGPT at the [European Symposium on Security and Artificial Intelligence (ESSAI)](https://essai-conference.eu/) +- 2024-07-26: The [GitHub Accelerator Showcase](https://github.blog/open-source/maintainers/github-accelerator-showcase-celebrating-our-second-cohort-and-whats-next/) features hackingBuddyGPT +- 2024-07-24: [Juergen](https://github.com/citostyle) speaks at [Open Source + mezcal night @ GitHub HQ](https://lu.ma/bx120myg) +- 2024-05-23: hackingBuddyGPT is part of [GitHub Accelerator 2024](https://github.blog/news-insights/company-news/2024-github-accelerator-meet-the-11-projects-shaping-open-source-ai/) +- 2023-12-05: [Andreas](https://github.com/andreashappe) presented hackingBuddyGPT at FSE'23 in San Francisco ([paper](https://arxiv.org/abs/2308.00121), [video](https://2023.esec-fse.org/details/fse-2023-ideas--visions-and-reflections/9/Towards-Automated-Software-Security-Testing-Augmenting-Penetration-Testing-through-L)) +- 2023-09-20: [Andreas](https://github.com/andreashappe) presented preliminary results at [FIRST AI Security SIG](https://www.first.org/global/sigs/ai-security/) + +## Original Paper hackingBuddyGPT is described in [Getting pwn'd by AI: Penetration Testing with Large Language Models ](https://arxiv.org/abs/2308.00121), help us by citing it through: @@ -29,7 +36,6 @@ hackingBuddyGPT is described in [Getting pwn'd by AI: Penetration Testing with L } ~~~ - ## Getting help If you need help or want to chat about using AI for security or education, please join our [discord server where we talk about all things AI + Offensive Security](https://discord.gg/vr4PhSM8yN)! @@ -74,12 +80,10 @@ The following would create a new (minimal) linux privilege-escalation agent. 
Thr template_dir = pathlib.Path(__file__).parent template_next_cmd = Template(filename=str(template_dir / "next_cmd.txt")) -@use_case("minimal_linux_privesc", "Showcase Minimal Linux Priv-Escalation") -@dataclass + class MinimalLinuxPrivesc(Agent): conn: SSHConnection = None - _sliding_history: SlidingCliHistory = None def init(self): @@ -89,10 +93,10 @@ class MinimalLinuxPrivesc(Agent): self.add_capability(SSHTestCredential(conn=self.conn)) self._template_size = self.llm.count_tokens(template_next_cmd.source) - def perform_round(self, turn): - got_root : bool = False + def perform_round(self, turn: int) -> bool: + got_root: bool = False - with self.console.status("[bold green]Asking LLM for a new command..."): + with self._log.console.status("[bold green]Asking LLM for a new command..."): # get as much history as fits into the target context size history = self._sliding_history.get_history(self.llm.context_size - llm_util.SAFETY_MARGIN - self._template_size) @@ -100,17 +104,22 @@ class MinimalLinuxPrivesc(Agent): answer = self.llm.get_response(template_next_cmd, capabilities=self.get_capability_block(), history=history, conn=self.conn) cmd = llm_util.cmd_output_fixer(answer.result) - with self.console.status("[bold green]Executing that command..."): - self.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) - result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd) + with self._log.console.status("[bold green]Executing that command..."): + self._log.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) + result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd) # log and output the command and its result - self.log_db.add_log_query(self._run_id, turn, cmd, result, answer) + self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer) self._sliding_history.add_command(cmd, result) - self.console.print(Panel(result, title=f"[bold cyan]{cmd}")) + self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}")) # if we got root, we can stop the loop return got_root + + +@use_case("Showcase Minimal Linux Priv-Escalation") +class MinimalLinuxPrivescUseCase(AutonomousAgentUseCase[MinimalLinuxPrivesc]): + pass ~~~ The corresponding `next_cmd.txt` template would be: @@ -170,6 +179,9 @@ wintermute.py: error: the following arguments are required: {linux_privesc,windo # start wintermute, i.e., attack the configured virtual machine $ python wintermute.py minimal_linux_privesc + +# install dependencies for testing if you want to run the tests +$ pip install .[testing] ~~~ ## Publications about hackingBuddyGPT diff --git a/pyproject.toml b/pyproject.toml index 52546ca..8133e90 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,11 +29,14 @@ dependencies = [ 'requests == 2.32.0', 'rich == 13.7.1', 'tiktoken == 0.6.0', - 'instructor == 1.2.2', + 'instructor == 1.3.5', 'PyYAML == 6.0.1', 'python-dotenv == 1.0.1', 'pypsexec == 0.3.0', + 'pydantic == 2.8.2', 'openai == 1.28.0', + 'BeautifulSoup4', + 'nltk' ] [project.urls] @@ -54,6 +57,11 @@ pythonpath = "src" addopts = [ "--import-mode=importlib", ] +[project.optional-dependencies] +testing = [ + 'pytest', + 'pytest-mock' +] [project.scripts] wintermute = "hackingBuddyGPT.cli.wintermute:main" diff --git a/src/hackingBuddyGPT/capabilities/http_request.py b/src/hackingBuddyGPT/capabilities/http_request.py index 62f31f8..3a508d8 100644 --- a/src/hackingBuddyGPT/capabilities/http_request.py +++ b/src/hackingBuddyGPT/capabilities/http_request.py @@ -41,7 +41,16 @@ def 
__call__(self, ) -> str:
         if body is not None and body_is_base64:
             body = base64.b64decode(body).decode()
-
+        if self.host[-1] != "/":
+            path = "/" + path
         try:
             resp = self._client.request(
                 method,
diff --git a/src/hackingBuddyGPT/capabilities/submit_http_method.py b/src/hackingBuddyGPT/capabilities/submit_http_method.py
new file mode 100644
index 0000000..7a9d40b
--- /dev/null
+++ b/src/hackingBuddyGPT/capabilities/submit_http_method.py
@@ -0,0 +1,78 @@
+import base64
+from dataclasses import dataclass, field
+from typing import Set, Dict, Callable, Literal, Optional
+import inspect
+
+import requests
+from pydantic import create_model, BaseModel
+
+from . import Capability
+
+
+@dataclass
+class SubmitHTTPMethod(Capability):
+    http_methods_format: str
+    valid_http_methods: Set[str]
+    _client = requests.Session()
+    host: str
+    follow_redirects: bool = False
+    success_function: Callable[[], None] = None
+
+    submitted_valid_http_methods: Set[str] = field(default_factory=set, init=False)
+
+    def describe(self) -> str:
+        return f"Submits HTTP methods, which is the goal of all activities. The HTTP methods format is {self.http_methods_format}. If you find a flag following that format that you have not yet submitted, then the most important first step is to submit it, ignoring all other possibilities of further action"
+
+    def to_model(self) -> BaseModel:
+        """
+        Converts the parameters of the `__call__` function of the capability to a pydantic model, that can be used to
+        interface with an LLM using eg instructor or the openAI function calling API.
+        The model will have the same name as the capability class and will have the same fields as the `__call__`,
+        the `__call__` method can then be accessed by calling the `execute` method of the model.
+        """
+        sig = inspect.signature(self.__call__)
+        fields = {param: (param_info.annotation, ...)
for param, param_info in sig.parameters.items()}
+        model_type = create_model(self.__class__.__name__, __doc__=self.describe(), **fields)
+
+        def execute(model):
+            m = model.dict()
+            return self(**m)
+
+        model_type.execute = execute
+
+        return model_type
+
+    def __call__(self, method: Literal["GET", "HEAD", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"],
+                 path: str,
+                 query: Optional[str] = None,
+                 body: Optional[str] = None,
+                 body_is_base64: Optional[bool] = False,
+                 headers: Optional[Dict[str, str]] = None
+                 ) -> str:
+
+        if body is not None and body_is_base64:
+            body = base64.b64decode(body).decode()
+
+        resp = self._client.request(
+            method,
+            self.host + path,
+            params=query,
+            data=body,
+            headers=headers,
+            allow_redirects=self.follow_redirects,
+        )
+        try:
+            resp.raise_for_status()
+        except requests.exceptions.HTTPError as e:
+            return str(e)
+
+        # remember which of the valid methods have been submitted so the completion check below can trigger
+        if method in self.valid_http_methods:
+            self.submitted_valid_http_methods.add(method)
+
+        headers = "\r\n".join(f"{k}: {v}" for k, v in resp.headers.items())
+        if len(self.submitted_valid_http_methods) == len(self.valid_http_methods):
+            if self.success_function is not None:
+                self.success_function()
+            else:
+                return "All methods submitted, congratulations"
+        # turn the response into "plain text format" for responding to the prompt
+        return f"HTTP/1.1 {resp.status_code} {resp.reason}\r\n{headers}\r\n\r\n{resp.text}"
+
diff --git a/src/hackingBuddyGPT/capabilities/yamlFile.py b/src/hackingBuddyGPT/capabilities/yamlFile.py
new file mode 100644
index 0000000..e46f357
--- /dev/null
+++ b/src/hackingBuddyGPT/capabilities/yamlFile.py
@@ -0,0 +1,44 @@
+from dataclasses import dataclass, field
+from typing import Tuple, List
+
+import yaml
+
+from . import Capability
+
+@dataclass
+class YAMLFile(Capability):
+
+    def describe(self) -> str:
+        return "Takes a Yaml file and updates it with the given information"
+
+    def __call__(self, yaml_str: str) -> str:
+        """
+        Updates a YAML string based on provided inputs and returns the updated YAML string.
+
+        Args:
+            yaml_str (str): Original YAML content in string form.
+            updates (dict): A dictionary representing the updates to be applied.
+
+        Returns:
+            str: Updated YAML content as a string.
+        """
+        try:
+            # Load the YAML content from string
+            data = yaml.safe_load(yaml_str)
+
+            print(f'Updates:{yaml_str}')
+
+            # Apply updates from the updates dictionary
+            #for key, value in updates.items():
+            #    if key in data:
+            #        data[key] = value
+            #    else:
+            #        print(f"Warning: Key '{key}' not found in the original data.
Adding new key.") + # data[key] = value + # + ## Convert the updated dictionary back into a YAML string + #updated_yaml_str = yaml.safe_dump(data, sort_keys=False) + #return updated_yaml_str + except yaml.YAMLError as e: + print(f"Error processing YAML data: {e}") + return "None" \ No newline at end of file diff --git a/src/hackingBuddyGPT/cli/wintermute.py b/src/hackingBuddyGPT/cli/wintermute.py index 85552b3..4f6f0c1 100644 --- a/src/hackingBuddyGPT/cli/wintermute.py +++ b/src/hackingBuddyGPT/cli/wintermute.py @@ -8,12 +8,13 @@ def main(): parser = argparse.ArgumentParser() subparser = parser.add_subparsers(required=True) for name, use_case in use_cases.items(): - use_case.build_parser(subparser.add_parser( + subb = subparser.add_parser( name=use_case.name, help=use_case.description - )) - - parsed = parser.parse_args(sys.argv[1:]) + ) + use_case.build_parser(subb) + x= sys.argv[1:] + parsed = parser.parse_args(x) instance = parsed.use_case(parsed) instance.init() instance.run() diff --git a/src/hackingBuddyGPT/usecases/agents.py b/src/hackingBuddyGPT/usecases/agents.py index 003d455..a018b58 100644 --- a/src/hackingBuddyGPT/usecases/agents.py +++ b/src/hackingBuddyGPT/usecases/agents.py @@ -4,18 +4,33 @@ from rich.panel import Panel from typing import Dict +from hackingBuddyGPT.usecases.base import Logger from hackingBuddyGPT.utils import llm_util - from hackingBuddyGPT.capabilities.capability import Capability, capabilities_to_simple_text_handler -from .common_patterns import RoundBasedUseCase +from hackingBuddyGPT.utils.openai.openai_llm import OpenAIConnection + @dataclass -class Agent(RoundBasedUseCase, ABC): +class Agent(ABC): _capabilities: Dict[str, Capability] = field(default_factory=dict) _default_capability: Capability = None + _log: Logger = None + + llm: OpenAIConnection = None def init(self): - super().init() + pass + + def before_run(self): + pass + + def after_run(self): + pass + + # callback + @abstractmethod + def perform_round(self, turn: int) -> bool: + pass def add_capability(self, cap: Capability, default: bool = False): self._capabilities[cap.get_name()] = cap @@ -29,6 +44,7 @@ def get_capability_block(self) -> str: capability_descriptions, _parser = capabilities_to_simple_text_handler(self._capabilities) return "You can either\n\n" + "\n".join(f"- {description}" for description in capability_descriptions.values()) + @dataclass class AgentWorldview(ABC): @@ -40,6 +56,7 @@ def to_template(self): def update(self, capability, cmd, result): pass + class TemplatedAgent(Agent): _state: AgentWorldview = None @@ -59,7 +76,7 @@ def set_template(self, template:str): def perform_round(self, turn:int) -> bool: got_root : bool = False - with self.console.status("[bold green]Asking LLM for a new command..."): + with self._log.console.status("[bold green]Asking LLM for a new command..."): # TODO output/log state options = self._state.to_template() options.update({ @@ -70,16 +87,16 @@ def perform_round(self, turn:int) -> bool: answer = self.llm.get_response(self._template, **options) cmd = llm_util.cmd_output_fixer(answer.result) - with self.console.status("[bold green]Executing that command..."): - self.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) + with self._log.console.status("[bold green]Executing that command..."): + self._log.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) capability = self.get_capability(cmd.split(" ", 1)[0]) result, got_root = capability(cmd) # log and output the command and its result - 
self.log_db.add_log_query(self._run_id, turn, cmd, result, answer) + self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer) self._state.update(capability, cmd, result) # TODO output/log new state - self.console.print(Panel(result, title=f"[bold cyan]{cmd}")) + self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}")) # if we got root, we can stop the loop return got_root diff --git a/src/hackingBuddyGPT/usecases/base.py b/src/hackingBuddyGPT/usecases/base.py index f090e4e..459db92 100644 --- a/src/hackingBuddyGPT/usecases/base.py +++ b/src/hackingBuddyGPT/usecases/base.py @@ -1,10 +1,24 @@ import abc import argparse -from dataclasses import dataclass, field +import typing +from dataclasses import dataclass +from rich.panel import Panel from typing import Dict, Type -from hackingBuddyGPT.utils.configurable import ParameterDefinitions, build_parser, get_arguments, get_class_parameters +from hackingBuddyGPT.utils.configurable import ParameterDefinitions, build_parser, get_arguments, get_class_parameters, transparent +from hackingBuddyGPT.utils.console.console import Console +from hackingBuddyGPT.utils.db_storage.db_storage import DbStorage + +@dataclass +class Logger: + log_db: DbStorage + console: Console + tag: str = "" + run_id: int = 0 + + +@dataclass class UseCase(abc.ABC): """ A UseCase is the combination of tools and capabilities to solve a specific problem. @@ -16,13 +30,21 @@ class UseCase(abc.ABC): so that they can be automatically discovered and run from the command line. """ + log_db: DbStorage + console: Console + tag: str = "" + + _run_id: int = 0 + _log: Logger = None + def init(self): """ The init method is called before the run method. It is used to initialize the UseCase, and can be used to perform any dynamic setup that is needed before the run method is called. One of the most common use cases is setting up the llm capabilities from the tools that were injected. """ - pass + self._run_id = self.log_db.create_new_run(self.get_name(), self.tag) + self._log = Logger(self.log_db, self.console, self.tag, self._run_id) @abc.abstractmethod def run(self): @@ -33,6 +55,57 @@ def run(self): """ pass + @abc.abstractmethod + def get_name(self) -> str: + """ + This method should return the name of the use case. It is used for logging and debugging purposes. 
+ """ + pass + + +# this runs the main loop for a bounded amount of turns or until root was achieved +@dataclass +class AutonomousUseCase(UseCase, abc.ABC): + max_turns: int = 10 + + _got_root: bool = False + + @abc.abstractmethod + def perform_round(self, turn: int): + pass + + def before_run(self): + pass + + def after_run(self): + pass + + def run(self): + + self.before_run() + + turn = 1 + while turn <= self.max_turns and not self._got_root: + self._log.console.log(f"[yellow]Starting turn {turn} of {self.max_turns}") + + self._got_root = self.perform_round(turn) + + # finish turn and commit logs to storage + self._log.log_db.commit() + turn += 1 + + self.after_run() + + # write the final result to the database and console + if self._got_root: + self._log.log_db.run_was_success(self._run_id, turn) + self._log.console.print(Panel("[bold green]Got Root!", title="Run finished")) + else: + self._log.log_db.run_was_failure(self._run_id, turn) + self._log.console.print(Panel("[green]maximum turn number reached", title="Run finished")) + + return self._got_root + @dataclass class _WrappedUseCase: @@ -56,17 +129,63 @@ def __call__(self, args: argparse.Namespace): use_cases: Dict[str, _WrappedUseCase] = dict() -def use_case(name: str, desc: str): - """ - By wrapping a UseCase with this decorator, it will be automatically discoverable and can be run from the command - line. - """ +T = typing.TypeVar("T") + + +class AutonomousAgentUseCase(AutonomousUseCase, typing.Generic[T]): + agent: T = None + + def perform_round(self, turn: int): + raise ValueError("Do not use AutonomousAgentUseCase without supplying an agent type as generic") + + def get_name(self) -> str: + raise ValueError("Do not use AutonomousAgentUseCase without supplying an agent type as generic") + + @classmethod + def __class_getitem__(cls, item): + item = dataclass(item) + item.__parameters__ = get_class_parameters(item) - def inner(cls: Type[UseCase]): + class AutonomousAgentUseCase(AutonomousUseCase): + agent: transparent(item) = None + + def init(self): + super().init() + self.agent._log = self._log + self.agent.init() + + def get_name(self) -> str: + return self.__class__.__name__ + + def before_run(self): + return self.agent.before_run() + + def after_run(self): + return self.agent.after_run() + + def perform_round(self, turn: int): + return self.agent.perform_round(turn) + + constructed_class = dataclass(AutonomousAgentUseCase) + + return constructed_class + + +def use_case(description): + def inner(cls): + cls = dataclass(cls) + name = cls.__name__.removesuffix("UseCase") if name in use_cases: raise IndexError(f"Use case with name {name} already exists") - use_cases[name] = _WrappedUseCase(name, desc, cls, get_class_parameters(cls, name)) - + use_cases[name] = _WrappedUseCase(name, description, cls, get_class_parameters(cls)) return cls - return inner + + +def register_use_case(name: str, description: str, use_case: Type[UseCase]): + """ + This function is used to register a UseCase that was created manually, and not through the use_case decorator. 
+ """ + if name in use_cases: + raise IndexError(f"Use case with name {name} already exists") + use_cases[name] = _WrappedUseCase(name, description, use_case, get_class_parameters(use_case)) diff --git a/src/hackingBuddyGPT/usecases/common_patterns.py b/src/hackingBuddyGPT/usecases/common_patterns.py deleted file mode 100644 index 357a56f..0000000 --- a/src/hackingBuddyGPT/usecases/common_patterns.py +++ /dev/null @@ -1,62 +0,0 @@ -import abc - -from dataclasses import dataclass -from rich.panel import Panel - -from .base import UseCase -from hackingBuddyGPT.utils import Console, DbStorage -from hackingBuddyGPT.utils.openai.openai_llm import OpenAIConnection - -# this set ups all the console and database stuff, and runs the main loop for a bounded amount of turns -@dataclass -class RoundBasedUseCase(UseCase, abc.ABC): - log_db: DbStorage - console: Console - llm: OpenAIConnection = None - tag: str = "" - max_turns: int =10 - - _got_root: bool = False - _run_id: int = 0 - - def init(self): - super().init() - self._run_id = self.log_db.create_new_run(self.llm.model, self.llm.context_size, self.tag) - - # callback - def setup(self): - pass - - # callback - @abc.abstractmethod - def perform_round(self, turn: int): - pass - - # callback - def teardown(self): - pass - - def run(self): - - self.setup() - - turn = 1 - while turn <= self.max_turns and not self._got_root: - self.console.log(f"[yellow]Starting turn {turn} of {self.max_turns}") - - self._got_root = self.perform_round(turn) - - # finish turn and commit logs to storage - self.log_db.commit() - turn += 1 - - # write the final result to the database and console - if self._got_root: - self.log_db.run_was_success(self._run_id, turn) - self.console.print(Panel("[bold green]Got Root!", title="Run finished")) - else: - self.log_db.run_was_failure(self._run_id, turn) - self.console.print(Panel("[green]maximum turn number reached", title="Run finished")) - - self.teardown() - return self._got_root \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/minimal/agent.py b/src/hackingBuddyGPT/usecases/minimal/agent.py index 555a068..e7e6442 100644 --- a/src/hackingBuddyGPT/usecases/minimal/agent.py +++ b/src/hackingBuddyGPT/usecases/minimal/agent.py @@ -1,23 +1,20 @@ import pathlib -from dataclasses import dataclass, field from mako.template import Template from rich.panel import Panel from hackingBuddyGPT.capabilities import SSHRunCommand, SSHTestCredential from hackingBuddyGPT.utils import SSHConnection, llm_util -from hackingBuddyGPT.usecases.base import use_case +from hackingBuddyGPT.usecases.base import use_case, AutonomousAgentUseCase from hackingBuddyGPT.usecases.agents import Agent from hackingBuddyGPT.utils.cli_history import SlidingCliHistory template_dir = pathlib.Path(__file__).parent template_next_cmd = Template(filename=str(template_dir / "next_cmd.txt")) -@use_case("minimal_linux_privesc", "Showcase Minimal Linux Priv-Escalation") -@dataclass + class MinimalLinuxPrivesc(Agent): conn: SSHConnection = None - _sliding_history: SlidingCliHistory = None def init(self): @@ -27,10 +24,10 @@ def init(self): self.add_capability(SSHTestCredential(conn=self.conn)) self._template_size = self.llm.count_tokens(template_next_cmd.source) - def perform_round(self, turn): - got_root : bool = False + def perform_round(self, turn: int) -> bool: + got_root: bool = False - with self.console.status("[bold green]Asking LLM for a new command..."): + with self._log.console.status("[bold green]Asking LLM for a new command..."): # get as much 
history as fits into the target context size history = self._sliding_history.get_history(self.llm.context_size - llm_util.SAFETY_MARGIN - self._template_size) @@ -38,14 +35,19 @@ def perform_round(self, turn): answer = self.llm.get_response(template_next_cmd, capabilities=self.get_capability_block(), history=history, conn=self.conn) cmd = llm_util.cmd_output_fixer(answer.result) - with self.console.status("[bold green]Executing that command..."): - self.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) - result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd) + with self._log.console.status("[bold green]Executing that command..."): + self._log.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) + result, got_root = self.get_capability(cmd.split(" ", 1)[0])(cmd) # log and output the command and its result - self.log_db.add_log_query(self._run_id, turn, cmd, result, answer) + self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer) self._sliding_history.add_command(cmd, result) - self.console.print(Panel(result, title=f"[bold cyan]{cmd}")) + self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}")) # if we got root, we can stop the loop return got_root + + +@use_case("Showcase Minimal Linux Priv-Escalation") +class MinimalLinuxPrivescUseCase(AutonomousAgentUseCase[MinimalLinuxPrivesc]): + pass diff --git a/src/hackingBuddyGPT/usecases/minimal/agent_with_state.py b/src/hackingBuddyGPT/usecases/minimal/agent_with_state.py index 85955ad..ed21d7f 100644 --- a/src/hackingBuddyGPT/usecases/minimal/agent_with_state.py +++ b/src/hackingBuddyGPT/usecases/minimal/agent_with_state.py @@ -5,15 +5,15 @@ from hackingBuddyGPT.capabilities import SSHRunCommand, SSHTestCredential from hackingBuddyGPT.utils import SSHConnection, llm_util -from hackingBuddyGPT.usecases.base import use_case +from hackingBuddyGPT.usecases.base import use_case, AutonomousAgentUseCase from hackingBuddyGPT.usecases.agents import TemplatedAgent, AgentWorldview from hackingBuddyGPT.utils.cli_history import SlidingCliHistory + @dataclass class MinimalLinuxTemplatedPrivescState(AgentWorldview): - sliding_history: SlidingCliHistory = None + sliding_history: SlidingCliHistory max_history_size: int = 0 - conn: SSHConnection = None def __init__(self, conn, llm, max_history_size): @@ -30,8 +30,7 @@ def to_template(self) -> dict[str, Any]: 'conn': self.conn } -@use_case("minimal_linux_templated_agent", "Showcase Minimal Linux Priv-Escalation") -@dataclass + class MinimalLinuxTemplatedPrivesc(TemplatedAgent): conn: SSHConnection = None @@ -49,3 +48,8 @@ def init(self): # setup state max_history_size = self.llm.context_size - llm_util.SAFETY_MARGIN - self._template_size self.set_initial_state(MinimalLinuxTemplatedPrivescState(self.conn, self.llm, max_history_size)) + + +@use_case("Showcase Minimal Linux Priv-Escalation") +class MinimalLinuxTemplatedPrivescUseCase(AutonomousAgentUseCase[MinimalLinuxTemplatedPrivesc]): + pass diff --git a/src/hackingBuddyGPT/usecases/privesc/common.py b/src/hackingBuddyGPT/usecases/privesc/common.py index 56721ea..0760082 100644 --- a/src/hackingBuddyGPT/usecases/privesc/common.py +++ b/src/hackingBuddyGPT/usecases/privesc/common.py @@ -35,9 +35,9 @@ class Privesc(Agent): def init(self): super().init() - def setup(self): + def before_run(self): if self.hint != "": - self.console.print(f"[bold green]Using the following hint: '{self.hint}'") + self._log.console.print(f"[bold green]Using the following hint: '{self.hint}'") if 
self.disable_history is False: self._sliding_history = SlidingCliHistory(self.llm) @@ -57,48 +57,48 @@ def setup(self): def perform_round(self, turn: int) -> bool: got_root: bool = False - with self.console.status("[bold green]Asking LLM for a new command..."): + with self._log.console.status("[bold green]Asking LLM for a new command..."): answer = self.get_next_command() cmd = answer.result - with self.console.status("[bold green]Executing that command..."): - self.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) + with self._log.console.status("[bold green]Executing that command..."): + self._log.console.print(Panel(answer.result, title="[bold cyan]Got command from LLM:")) _capability_descriptions, parser = capabilities_to_simple_text_handler(self._capabilities, default_capability=self._default_capability) success, *output = parser(cmd) if not success: - self.console.print(Panel(output[0], title="[bold red]Error parsing command:")) + self._log.console.print(Panel(output[0], title="[bold red]Error parsing command:")) return False assert(len(output) == 1) capability, cmd, (result, got_root) = output[0] # log and output the command and its result - self.log_db.add_log_query(self._run_id, turn, cmd, result, answer) + self._log.log_db.add_log_query(self._log.run_id, turn, cmd, result, answer) if self._sliding_history: self._sliding_history.add_command(cmd, result) - self.console.print(Panel(result, title=f"[bold cyan]{cmd}")) + self._log.console.print(Panel(result, title=f"[bold cyan]{cmd}")) # analyze the result.. if self.enable_explanation: - with self.console.status("[bold green]Analyze its result..."): + with self._log.console.status("[bold green]Analyze its result..."): answer = self.analyze_result(cmd, result) - self.log_db.add_log_analyze_response(self._run_id, turn, cmd, answer.result, answer) + self._log.log_db.add_log_analyze_response(self._log.run_id, turn, cmd, answer.result, answer) # .. and let our local model update its state if self.enable_update_state: # this must happen before the table output as we might include the # status processing time in the table.. - with self.console.status("[bold green]Updating fact list.."): + with self._log.console.status("[bold green]Updating fact list.."): state = self.update_state(cmd, result) - self.log_db.add_log_update_state(self._run_id, turn, "", state.result, state) + self._log.log_db.add_log_update_state(self._log.run_id, turn, "", state.result, state) # Output Round Data.. - self.console.print(ui.get_history_table(self.enable_explanation, self.enable_update_state, self._run_id, self.log_db, turn)) + self._log.console.print(ui.get_history_table(self.enable_explanation, self.enable_update_state, self._log.run_id, self._log.log_db, turn)) # .. 
and output the updated state if self.enable_update_state: - self.console.print(Panel(self._state, title="What does the LLM Know about the system?")) + self._log.console.print(Panel(self._state, title="What does the LLM Know about the system?")) # if we got root, we can stop the loop return got_root diff --git a/src/hackingBuddyGPT/usecases/privesc/linux.py b/src/hackingBuddyGPT/usecases/privesc/linux.py index ccd1065..a701d3b 100644 --- a/src/hackingBuddyGPT/usecases/privesc/linux.py +++ b/src/hackingBuddyGPT/usecases/privesc/linux.py @@ -1,148 +1,148 @@ import json import pathlib -from dataclasses import dataclass from mako.template import Template from hackingBuddyGPT.capabilities import SSHRunCommand, SSHTestCredential +from hackingBuddyGPT.utils.openai.openai_llm import OpenAIConnection from .common import Privesc from hackingBuddyGPT.utils import SSHConnection -from hackingBuddyGPT.usecases.base import use_case, UseCase -from hackingBuddyGPT.utils.console.console import Console -from hackingBuddyGPT.utils.db_storage.db_storage import DbStorage -from hackingBuddyGPT.utils.openai.openai_llm import OpenAIConnection +from hackingBuddyGPT.usecases.base import UseCase, use_case, AutonomousAgentUseCase template_dir = pathlib.Path(__file__).parent / "templates" -template_next_cmd = Template(filename=str(template_dir / "query_next_command.txt")) -template_analyze = Template(filename=str(template_dir / "analyze_cmd.txt")) -template_state = Template(filename=str(template_dir / "update_state.txt")) template_lse = Template(filename=str(template_dir / "get_hint_from_lse.txt")) -@use_case("linux_privesc_hintfile", "Linux Privilege Escalation using a hints file") -@dataclass -class PrivescWithHintFile(UseCase): + +class LinuxPrivesc(Privesc): conn: SSHConnection = None - system: str = '' - enable_explanation: bool = False - enable_update_state: bool = False - disable_history: bool = False - hints: str = "" + system: str = "linux" - # all of these would typically be set by RoundBasedUseCase :-/ - # but we need them here so that we can pass them on to the inner - # use-case - log_db: DbStorage = None - console: Console = None - llm: OpenAIConnection = None - tag: str = "" - max_turns: int = 10 + def init(self): + super().init() + self.add_capability(SSHRunCommand(conn=self.conn), default=True) + self.add_capability(SSHTestCredential(conn=self.conn)) + + +@use_case("Linux Privilege Escalation") +class LinuxPrivescUseCase(AutonomousAgentUseCase[LinuxPrivesc]): + pass + + +@use_case("Linux Privilege Escalation using hints from a hint file initial guidance") +class LinuxPrivescWithHintFileUseCase(AutonomousAgentUseCase[LinuxPrivesc]): + hints: str = None def init(self): super().init() + self.agent.hint = self.read_hint() # simple helper that reads the hints file and returns the hint # for the current machine (test-case) def read_hint(self): - if self.hints != "": - try: - with open(self.hints, "r") as hint_file: - hints = json.load(hint_file) - if self.conn.hostname in hints: - return hints[self.conn.hostname] - except: - self.console.print("[yellow]Was not able to load hint file") - else: - self.console.print("[yellow]calling the hintfile use-case without a hint file?") + try: + with open(self.hints, "r") as hint_file: + hints = json.load(hint_file) + if self.agent.conn.hostname in hints: + return hints[self.agent.conn.hostname] + except FileNotFoundError: + self._log.console.print("[yellow]Hint file not found") + except Exception as e: + self._log.console.print("[yellow]Hint file could not loaded:", 
str(e)) return "" - def run(self): - # read the hint - hint = self.read_hint() - - # call the inner use-case - priv_esc = LinuxPrivesc( - conn=self.conn, # must be set in sub classes - enable_explanation=self.enable_explanation, - disable_history=self.disable_history, - hint=hint, - log_db = self.log_db, - console = self.console, - llm = self.llm, - tag = self.tag, - max_turns = self.max_turns - ) - - priv_esc.init() - priv_esc.run() -@use_case("linux_privesc_guided", "Linux Privilege Escalation using lse.sh for initial guidance") -@dataclass -class PrivescWithLSE(UseCase): +@use_case("Linux Privilege Escalation using lse.sh for initial guidance") +class LinuxPrivescWithLSEUseCase(UseCase): conn: SSHConnection = None - system: str = '' + max_turns: int = 20 enable_explanation: bool = False enable_update_state: bool = False disable_history: bool = False - - # all of these would typically be set by RoundBasedUseCase :-/ - # but we need them here so that we can pass them on to the inner - # use-case - log_db: DbStorage = None - console: Console = None llm: OpenAIConnection = None - tag: str = "" - max_turns: int = 10 - low_llm: OpenAIConnection = None + + _got_root: bool = False + + # use either an use-case or an agent to perform the privesc + use_use_case: bool = False def init(self): super().init() # simple helper that uses lse.sh to get hints from the system - def read_hint(self): - - self.console.print("[green]performing initial enumeration with lse.sh") + def call_lse_against_host(self): + self._log.console.print("[green]performing initial enumeration with lse.sh") run_cmd = "wget -q 'https://github.com/diego-treitos/linux-smart-enumeration/releases/latest/download/lse.sh' -O lse.sh;chmod 700 lse.sh; ./lse.sh -c -i -l 0 | grep -v 'nope$' | grep -v 'skip$'" - result, got_root = SSHRunCommand(conn=self.conn, timeout=120)(run_cmd) + result, _ = SSHRunCommand(conn=self.conn, timeout=120)(run_cmd) self.console.print("[yellow]got the output: " + result) cmd = self.llm.get_response(template_lse, lse_output=result, number=3) self.console.print("[yellow]got the cmd: " + cmd.result) - return cmd.result + return [x for x in cmd.result.splitlines() if x.strip()] + def get_name(self) -> str: + return self.__class__.__name__ + def run(self): - # read the hint - hint = self.read_hint() - - for i in hint.splitlines(): - self.console.print("[green]Now using Hint: " + i) - - # call the inner use-case - priv_esc = LinuxPrivesc( - conn=self.conn, # must be set in sub classes - enable_explanation=self.enable_explanation, - disable_history=self.disable_history, - hint=i, - log_db = self.log_db, - console = self.console, - llm = self.low_llm, - tag = self.tag + "_hint_" +i, - max_turns = self.max_turns - ) - - priv_esc.init() - if priv_esc.run(): - # we are root! w00t! 
-            return True
+        # get the hints through running LSE on the target system
+        hints = self.call_lse_against_host()
+        turns_per_hint = int(self.max_turns / len(hints))
-@use_case("linux_privesc", "Linux Privilege Escalation")
-@dataclass
-class LinuxPrivesc(Privesc):
-    conn: SSHConnection = None
-    system: str = "linux"
+        # now try to escalate privileges using the hints
+        for hint in hints:
-    def init(self):
-        super().init()
-        self.add_capability(SSHRunCommand(conn=self.conn), default=True)
-        self.add_capability(SSHTestCredential(conn=self.conn))
\ No newline at end of file
+            if self.use_use_case:
+                result = self.run_using_usecases(hint, turns_per_hint)
+            else:
+                result = self.run_using_agent(hint, turns_per_hint)
+
+            if result is True:
+                self.console.print("[green]Got root!")
+                return True
+
+    def run_using_usecases(self, hint, turns_per_hint):
+        # TODO: init usecase
+        linux_privesc = LinuxPrivescUseCase(
+            agent = LinuxPrivesc(
+                conn = self.conn,
+                enable_explanation = self.enable_explanation,
+                enable_update_state = self.enable_update_state,
+                disable_history = self.disable_history,
+                llm = self.llm,
+                hint = hint
+            ),
+            max_turns = turns_per_hint,
+            log_db = self.log_db,
+            console = self.console
+        )
+        linux_privesc.init()
+        return linux_privesc.run()
+
+    def run_using_agent(self, hint, turns_per_hint):
+        # init agent
+        agent = LinuxPrivesc(
+            conn = self.conn,
+            llm = self.llm,
+            hint = hint,
+            enable_explanation = self.enable_explanation,
+            enable_update_state = self.enable_update_state,
+            disable_history = self.disable_history
+        )
+        agent._log = self._log
+        agent.init()
+
+        # perform the privilege escalation
+        agent.before_run()
+        turn = 1
+        got_root = False
+        while turn <= turns_per_hint and not got_root:
+            self._log.console.log(f"[yellow]Starting turn {turn} of {turns_per_hint}")
+
+            if agent.perform_round(turn) is True:
+                got_root = True
+            turn += 1
+
+        # cleanup and finish
+        agent.after_run()
+        return got_root
diff --git a/src/hackingBuddyGPT/usecases/privesc/windows.py b/src/hackingBuddyGPT/usecases/privesc/windows.py
index cf56509..7225dc0 100644
--- a/src/hackingBuddyGPT/usecases/privesc/windows.py
+++ b/src/hackingBuddyGPT/usecases/privesc/windows.py
@@ -1,14 +1,10 @@
-from dataclasses import dataclass
-
 from hackingBuddyGPT.capabilities.psexec_run_command import PSExecRunCommand
 from hackingBuddyGPT.capabilities.psexec_test_credential import PSExecTestCredential
-from hackingBuddyGPT.usecases.base import use_case
+from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case
 from hackingBuddyGPT.usecases.privesc.common import Privesc
 from hackingBuddyGPT.utils.psexec.psexec import PSExecConnection
-@use_case("windows_privesc", "Windows Privilege Escalation")
-@dataclass
 class WindowsPrivesc(Privesc):
     conn: PSExecConnection = None
     system: str = "Windows"
@@ -16,4 +12,9 @@ class WindowsPrivesc(Privesc):
     def init(self):
         super().init()
         self.add_capability(PSExecRunCommand(conn=self.conn), default=True)
-        self.add_capability(PSExecTestCredential(conn=self.conn))
\ No newline at end of file
+        self.add_capability(PSExecTestCredential(conn=self.conn))
+
+
+@use_case("Windows Privilege Escalation")
+class WindowsPrivescUseCase(AutonomousAgentUseCase[WindowsPrivesc]):
+    pass
diff --git a/src/hackingBuddyGPT/usecases/web/simple.py b/src/hackingBuddyGPT/usecases/web/simple.py
index c9177d3..22152b5 100644
--- a/src/hackingBuddyGPT/usecases/web/simple.py
+++ b/src/hackingBuddyGPT/usecases/web/simple.py
@@ -1,7 +1,7 @@
 import pydantic_core
 import time
-from dataclasses import
dataclass, field +from dataclasses import field from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage from rich.panel import Panel from typing import List, Any, Union, Dict @@ -11,9 +11,9 @@ from hackingBuddyGPT.capabilities.http_request import HTTPRequest from hackingBuddyGPT.capabilities.record_note import RecordNote from hackingBuddyGPT.capabilities.submit_flag import SubmitFlag +from hackingBuddyGPT.usecases.agents import Agent from hackingBuddyGPT.utils import LLMResult, tool_message -from hackingBuddyGPT.usecases.base import use_case -from hackingBuddyGPT.usecases.common_patterns import RoundBasedUseCase +from hackingBuddyGPT.usecases.base import use_case, AutonomousAgentUseCase from hackingBuddyGPT.utils.configurable import parameter from hackingBuddyGPT.utils.openai.openai_lib import OpenAILib @@ -22,9 +22,7 @@ Context = Any -@use_case("simple_web_test", "Minimal implementation of a web testing use case") -@dataclass -class MinimalWebTesting(RoundBasedUseCase): +class MinimalWebTesting(Agent): llm: OpenAILib host: str = parameter(desc="The host to test", default="http://localhost") flag_format_description: str = parameter(desc="Description of the flag provided to the LLM", default="a string starting with 'FLAG.' and ending with '.GALF'") @@ -54,11 +52,11 @@ def init(self): } def all_flags_found(self): - self.console.print(Panel("All flags found! Congratulations!", title="system")) + self._log.console.print(Panel("All flags found! Congratulations!", title="system")) self._all_flags_found = True def perform_round(self, turn: int): - with self.console.status("[bold green]Asking LLM for a new command..."): + with self._log.console.status("[bold green]Asking LLM for a new command..."): prompt = self._prompt_history # TODO: in the future, this should do some context truncation tic = time.perf_counter() @@ -68,15 +66,20 @@ def perform_round(self, turn: int): message = completion.choices[0].message tool_call_id = message.tool_calls[0].id command = pydantic_core.to_json(response).decode() - self.console.print(Panel(command, title="assistant")) + self._log.console.print(Panel(command, title="assistant")) self._prompt_history.append(message) answer = LLMResult(completion.choices[0].message.content, str(prompt), completion.choices[0].message.content, toc-tic, completion.usage.prompt_tokens, completion.usage.completion_tokens) - with self.console.status("[bold green]Executing that command..."): + with self._log.console.status("[bold green]Executing that command..."): result = response.execute() - self.console.print(Panel(result, title="tool")) + self._log.console.print(Panel(result, title="tool")) self._prompt_history.append(tool_message(result, tool_call_id)) - self.log_db.add_log_query(self._run_id, turn, command, result, answer) + self._log.log_db.add_log_query(self._log.run_id, turn, command, result, answer) return self._all_flags_found + + +@use_case("Minimal implementation of a web testing use case") +class MinimalWebTestingUseCase(AutonomousAgentUseCase[MinimalWebTesting]): + pass diff --git a/src/hackingBuddyGPT/usecases/web/with_explanation.py b/src/hackingBuddyGPT/usecases/web/with_explanation.py index 19fb203..96dd657 100644 --- a/src/hackingBuddyGPT/usecases/web/with_explanation.py +++ b/src/hackingBuddyGPT/usecases/web/with_explanation.py @@ -1,5 +1,5 @@ import time -from dataclasses import dataclass, field +from dataclasses import field from typing import List, Any, Union, Dict from openai.types.chat import ChatCompletionMessageParam, 
ChatCompletionMessage @@ -8,9 +8,9 @@ from hackingBuddyGPT.capabilities import Capability from hackingBuddyGPT.capabilities.http_request import HTTPRequest from hackingBuddyGPT.capabilities.submit_flag import SubmitFlag +from hackingBuddyGPT.usecases.agents import Agent +from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case from hackingBuddyGPT.utils import LLMResult, tool_message -from hackingBuddyGPT.usecases import use_case -from hackingBuddyGPT.usecases.common_patterns import RoundBasedUseCase from hackingBuddyGPT.utils.configurable import parameter from hackingBuddyGPT.utils.openai.openai_lib import OpenAILib @@ -19,9 +19,7 @@ Context = Any -@use_case("web_test_with_explanation", "Minimal implementation of a web testing use case while allowing the llm to 'talk'") -@dataclass -class WebTestingWithExplanation(RoundBasedUseCase): +class WebTestingWithExplanation(Agent): llm: OpenAILib host: str = parameter(desc="The host to test", default="http://localhost") flag_format_description: str = parameter(desc="Description of the flag provided to the LLM", default="a string starting with 'FLAG.' and ending with '.GALF'") @@ -51,19 +49,19 @@ def init(self): } def all_flags_found(self): - self.console.print(Panel("All flags found! Congratulations!", title="system")) + self._log.console.print(Panel("All flags found! Congratulations!", title="system")) self._all_flags_found = True def perform_round(self, turn: int): prompt = self._prompt_history # TODO: in the future, this should do some context truncation result: LLMResult = None - stream = self.llm.stream_response(prompt, self.console, capabilities=self._capabilities) + stream = self.llm.stream_response(prompt, self._log.console, capabilities=self._capabilities) for part in stream: result = part message: ChatCompletionMessage = result.result - message_id = self.log_db.add_log_message(self._run_id, message.role, message.content, result.tokens_query, result.tokens_response, result.duration) + message_id = self._log.log_db.add_log_message(self._log.run_id, message.role, message.content, result.tokens_query, result.tokens_response, result.duration) self._prompt_history.append(result.result) if message.tool_calls is not None: @@ -72,9 +70,14 @@ def perform_round(self, turn: int): tool_call_result = self._capabilities[tool_call.function.name].to_model().model_validate_json(tool_call.function.arguments).execute() toc = time.perf_counter() - self.console.print(f"\n[bold green on gray3]{' '*self.console.width}\nTOOL RESPONSE:[/bold green on gray3]") - self.console.print(tool_call_result) + self._log.console.print(f"\n[bold green on gray3]{' '*self._log.console.width}\nTOOL RESPONSE:[/bold green on gray3]") + self._log.console.print(tool_call_result) self._prompt_history.append(tool_message(tool_call_result, tool_call.id)) - self.log_db.add_log_tool_call(self._run_id, message_id, tool_call.id, tool_call.function.name, tool_call.function.arguments, tool_call_result, toc - tic) + self._log.log_db.add_log_tool_call(self._log.run_id, message_id, tool_call.id, tool_call.function.name, tool_call.function.arguments, tool_call_result, toc - tic) return self._all_flags_found + + +@use_case("Minimal implementation of a web testing use case while allowing the llm to 'talk'") +class WebTestingWithExplanationUseCase(AutonomousAgentUseCase[WebTestingWithExplanation]): + pass diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_engineer.py b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_engineer.py index 5d7fcf8..8615e54 100644 --- 
a/src/hackingBuddyGPT/usecases/web_api_testing/prompt_engineer.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/prompt_engineer.py @@ -1,43 +1,49 @@ -from openai.types.chat import ChatCompletionMessage +import nltk +from nltk.tokenize import word_tokenize +from instructor.retry import InstructorRetryException -from hackingBuddyGPT.utils import openai class PromptEngineer(object): '''Prompt engineer that creates prompts of different types''' - def __init__(self, strategy, api_key, history): + def __init__(self, strategy, llm_handler, history, schemas, response_handler): """ - Initializes the PromptEngineer with a specific strategy and API key. + Initializes the PromptEngineer with a specific strategy and handlers for LLM and responses. Args: strategy (PromptStrategy): The prompt engineering strategy to use. - api_key (str): The API key for OpenAI. - + llm_handler (object): The LLM handler. history (dict, optional): The history of chats. Defaults to None. + schemas (object): The schemas to use. + response_handler (object): The handler for managing responses. Attributes: strategy (PromptStrategy): Stores the provided strategy. - api_key (str): Stores the provided API key. - host (str): Stores the provided host for OpenAI API. - flag_format_description (str): Stores the provided flag description format. - prompt_history (list): A list that keeps track of the conversation history. - initial_prompt (str): The initial prompt used for conversation. - prompt (str): The current prompt to be used. + llm_handler (object): Handles the interaction with the LLM. + nlp (spacy.lang.en.English): The spaCy English model used for NLP tasks. + _prompt_history (dict): Keeps track of the conversation history. + prompt (dict): The current state of the prompt history. + previous_prompt (str): The previous prompt content based on the conversation history. + schemas (object): Stores the provided schemas. + response_handler (object): Manages the response handling logic. + round (int): Tracks the current round of conversation. strategies (dict): Maps strategies to their corresponding methods. """ self.strategy = strategy - self.api_key = api_key - # Set the OpenAI API key - openai.api_key = self.api_key + self.response_handler = response_handler + self.llm_handler = llm_handler self.round = 0 - - - - # Initialize prompt history + self.found_endpoints = ["/"] + self.endpoint_methods = {} + self.endpoint_found_methods = {} + # Check if the models are already installed + nltk.download('punkt') + nltk.download('stopwords') self._prompt_history = history - self.prompt = self._prompt_history + self.prompt = {self.round: {"content": "initial_prompt"}} + self.previous_prompt = self._prompt_history[self.round]["content"] + self.schemas = schemas - # Set up strategy map self.strategies = { PromptStrategy.IN_CONTEXT: self.in_context_learning, PromptStrategy.CHAIN_OF_THOUGHT: self.chain_of_thought, @@ -53,42 +59,22 @@ def generate_prompt(self, doc=False): """ # Directly call the method using the strategy mapping prompt_func = self.strategies.get(self.strategy) + is_good = False if prompt_func: - print(f'prompt history:{self._prompt_history[self.round]}') - if not isinstance(self._prompt_history[self.round],ChatCompletionMessage ): + while not is_good: prompt = prompt_func(doc) - self._prompt_history[self.round]["content"] = prompt - self.round = self.round +1 - return self._prompt_history - #self.get_response(prompt) - - def get_response(self, prompt): - """ - Sends a prompt to OpenAI's API and retrieves the response. 
-
-        Args:
-            prompt (str): The prompt to be sent to the API.
-
-        Returns:
-            str: The response from the API.
-        """
-        response = openai.Completion.create(
-            engine="text-davinci-002",
-            prompt=prompt,
-            max_tokens=150,
-            n=1,
-            stop=None,
-            temperature=0.7,
-        )
-        # Update history
-        response_text = response.choices[0].text.strip()
-        self._prompt_history.extend([f"[User]: {prompt}", f"[System]: {response_text}"])
-
-        return response_text
-
-
-
-    def in_context_learning(self, doc=False):
+                try:
+                    response_text = self.response_handler.get_response_for_prompt(prompt)
+                    is_good = self.evaluate_response(prompt, response_text)
+                except InstructorRetryException:
+                    prompt = prompt_func(doc, hint=f"invalid prompt:{prompt}")
+                if is_good:
+                    self._prompt_history.append({"role": "system", "content": prompt})
+                    self.previous_prompt = prompt
+                    self.round = self.round + 1
+                    return self._prompt_history
+
+    def in_context_learning(self, doc=False, hint=""):
         """
         Generates a prompt for in-context learning.
@@ -98,54 +84,137 @@ def in_context_learning(self, doc=False):
         Returns:
             str: The generated prompt.
         """
-        return str("\n".join(self._prompt_history[self.round]["content"] + [self.prompt]))
+        history_content = [entry["content"] for entry in self._prompt_history]
+        prompt_content = self.prompt.get(self.round, {}).get("content", "")
-    def chain_of_thought(self, doc=False):
+        # Add hint if provided
+        if hint:
+            prompt_content += f"\n{hint}"
+
+        return "\n".join(history_content + [prompt_content])
+
+    def get_http_action_template(self, method):
+        """Helper to construct a consistent HTTP action description."""
+        if method == "POST" or method == "PUT":
+            return (
+                f"Create HTTPRequests of type {method} considering the found schemas: {self.schemas} and understand the responses. Ensure that they are correct requests."
+            )
+
+        else:
+            return (
+                f"Create HTTPRequests of type {method} considering only the object with id=1 for the endpoint and understand the responses. Ensure that they are correct requests.")
+    def get_initial_steps(self, common_steps):
+        return [
+            f"Identify all available endpoints via GET Requests. Exclude those in this list: {self.found_endpoints}",
+            "Note down the response structures, status codes, and headers for each endpoint.",
+            "For each endpoint, document the following details: URL, HTTP method, query parameters and path variables, expected request body structure for requests, response structure for successful and error responses."
+        ] + common_steps
+
+    def get_phase_steps(self, phase, common_steps):
+        if phase != "DELETE":
+            return [
+                f"Identify for all endpoints {self.found_endpoints} excluding {self.endpoint_found_methods[phase]} a valid HTTP method {phase} call.",
+                self.get_http_action_template(phase)
+            ] + common_steps
+        else:
+            return [
+                "Check for all endpoints the DELETE method.
Delete the first instance for all endpoints.", + self.get_http_action_template(phase) + ] + common_steps + + def get_endpoints_needing_help(self): + endpoints_needing_help = [] + endpoints_and_needed_methods = {} + http_methods_set = {"GET", "POST", "PUT", "DELETE"} + + for endpoint, methods in self.endpoint_methods.items(): + missing_methods = http_methods_set - set(methods) + if len(methods) < 4: + endpoints_needing_help.append(endpoint) + endpoints_and_needed_methods[endpoint] = list(missing_methods) + + if endpoints_needing_help: + first_endpoint = endpoints_needing_help[0] + needed_method = endpoints_and_needed_methods[first_endpoint][0] + return [ + f"For endpoint {first_endpoint} find this missing method: {needed_method}. If all the HTTP methods have already been found for an endpoint, then do not include this endpoint in your search."] + return [] + def chain_of_thought(self, doc=False, hint=""): """ - Generates a prompt using the chain-of-thought strategy. https://www.promptingguide.ai/techniques/cot + Generates a prompt using the chain-of-thought strategy. - This method adds a step-by-step reasoning prompt to the current prompt. + Args: + doc (bool): Determines whether the documentation-oriented chain of thought should be used. + hint (str): Additional hint to be added to the chain of thought. Returns: str: The generated prompt. """ - - previous_prompt = self._prompt_history[self.round]["content"] - - if doc : - chain_of_thought_steps = [ - "Explore the API by reviewing any available documentation to learn about the API endpoints, data models, and behaviors.", - "Identify all available endpoints.", - "Create GET, POST, PUT, DELETE requests to understand the responses.", - "Note down the response structures, status codes, and headers for each endpoint.", - "For each endpoint, document the following details: URL, HTTP method (GET, POST, PUT, DELETE), query parameters and path variables, expected request body structure for POST and PUT requests, response structure for successful and error responses.", - "First execute the GET requests, then POST, then PUT and DELETE." + common_steps = [ "Identify common data structures returned by various endpoints and define them as reusable schemas. Determine the type of each field (e.g., integer, string, array) and define common response structures as components that can be referenced in multiple endpoint definitions.", "Create an OpenAPI document including metadata such as API title, version, and description, define the base URL of the API, list all endpoints, methods, parameters, and responses, and define reusable schemas, response types, and parameters.", "Ensure the correctness and completeness of the OpenAPI specification by validating the syntax and completeness of the document using tools like Swagger Editor, and ensure the specification matches the actual behavior of the API.", "Refine the document based on feedback and additional testing, share the draft with others, gather feedback, and make necessary adjustments. Regularly update the specification as the API evolves.", "Make the OpenAPI specification available to developers by incorporating it into your API documentation site and keep the documentation up to date with API changes." 
- ] + ] + + http_methods = ["PUT", "DELETE"] + http_phase = {10: http_methods[0], 15: http_methods[1]} + if doc: + if self.round <= 5: + chain_of_thought_steps = self.get_initial_steps(common_steps) + elif self.round <= 10: + phase = http_phase.get(min(filter(lambda x: self.round <= x, http_phase.keys()))) + chain_of_thought_steps = self.get_phase_steps(phase, common_steps) + else: + chain_of_thought_steps = self.get_endpoints_needing_help() else: - if round == 0: - chain_of_thought_steps = [ - "Let's think step by step." # zero shot prompt - ] - elif self.round <= 5: - chain_of_thought_steps = ["Just Focus on the endpoints for now."] - elif self.round >5 and self.round <= 10: - chain_of_thought_steps = ["Just Focus on the HTTP method GET for now."] - elif self.round > 10 and self.round <= 15: - chain_of_thought_steps = ["Just Focus on the HTTP method POST and PUT for now."] - elif self.round > 15 and self.round <= 20: - chain_of_thought_steps = ["Just Focus on the HTTP method DELETE for now."] + if self.round == 0: + chain_of_thought_steps = ["Let's think step by step."] + elif self.round <= 20: + focus_phases = ["endpoints", "HTTP method GET", "HTTP method POST and PUT", "HTTP method DELETE"] + focus_phase = focus_phases[self.round // 5] + chain_of_thought_steps = [f"Just focus on the {focus_phase} for now."] else: chain_of_thought_steps = ["Look for exploits."] + if hint: + chain_of_thought_steps.append(hint) + + prompt = self.check_prompt(self.previous_prompt, chain_of_thought_steps) + return prompt + + def token_count(self, text): + """ + Counts the number of word tokens in the provided text using NLTK's tokenizer. + + Args: + text (str): The input text to tokenize and count. + + Returns: + int: The number of tokens in the input text. + """ + # Tokenize the text using NLTK + tokens = word_tokenize(text) + # Filter out punctuation marks + words = [token for token in tokens if token.isalnum()] + return len(words) + - return "\n".join([previous_prompt] + chain_of_thought_steps) + def check_prompt(self, previous_prompt, chain_of_thought_steps, max_tokens=900): + def validate_prompt(prompt): + if self.token_count(prompt) <= max_tokens: + return prompt + shortened_prompt = self.response_handler.get_response_for_prompt("Shorten this prompt." + prompt ) + if self.token_count(shortened_prompt) <= max_tokens: + return shortened_prompt + return "Prompt is still too long after summarization." 
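+        # Illustrative note: token_count() above relies on NLTK's word_tokenize,
+        # which assumes the 'punkt' tokenizer data is installed, e.g.:
+        #     import nltk
+        #     nltk.download('punkt')
+        # A rough, dependency-free fallback for budgeting prompt length could be:
+        #     def approx_token_count(text):
+        #         return len([t for t in text.split() if t.isalnum()])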
+ if not all(step in previous_prompt for step in chain_of_thought_steps): + potential_prompt = "\n".join(chain_of_thought_steps) + return validate_prompt(potential_prompt) + return validate_prompt(previous_prompt) def tree_of_thought(self, doc=False): """ @@ -167,6 +236,8 @@ def tree_of_thought(self, doc=False): )] return "\n".join([self._prompt_history[self.round]["content"]] + tree_of_thoughts_steps) + def evaluate_response(self, prompt, response_text): #TODO find a good way of evaluating result of prompt + return True diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py b/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py index 03b34cb..285bd34 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/simple_openapi_documentation.py @@ -1,32 +1,26 @@ -import datetime -import os -import pydantic_core -import time -import yaml +from dataclasses import field +from typing import List, Any, Union, Dict -from dataclasses import dataclass, field +import pydantic_core from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage from rich.panel import Panel -from typing import List, Any, Union, Dict from hackingBuddyGPT.capabilities import Capability -from hackingBuddyGPT.capabilities.capability import capabilities_to_action_model from hackingBuddyGPT.capabilities.http_request import HTTPRequest from hackingBuddyGPT.capabilities.record_note import RecordNote -from hackingBuddyGPT.capabilities.submit_flag import SubmitFlag -from hackingBuddyGPT.usecases.common_patterns import RoundBasedUseCase +from hackingBuddyGPT.usecases.agents import Agent +from hackingBuddyGPT.usecases.web_api_testing.utils.openapi_specification_manager import OpenAPISpecificationManager +from hackingBuddyGPT.usecases.web_api_testing.utils.llm_handler import LLMHandler from hackingBuddyGPT.usecases.web_api_testing.prompt_engineer import PromptEngineer, PromptStrategy -from hackingBuddyGPT.utils import LLMResult, tool_message, ui +from hackingBuddyGPT.usecases.web_api_testing.utils.response_handler import ResponseHandler +from hackingBuddyGPT.utils import tool_message from hackingBuddyGPT.utils.configurable import parameter from hackingBuddyGPT.utils.openai.openai_lib import OpenAILib -from hackingBuddyGPT.usecases import use_case - +from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case Prompt = List[Union[ChatCompletionMessage, ChatCompletionMessageParam]] Context = Any -@use_case("simple_web_api_documentation", "Minimal implementation of a web api documentation use case") -@dataclass -class SimpleWebAPIDocumentation(RoundBasedUseCase): +class SimpleWebAPIDocumentation(Agent): llm: OpenAILib host: str = parameter(desc="The host to test", default="https://jsonplaceholder.typicode.com") _prompt_history: Prompt = field(default_factory=list) @@ -34,187 +28,108 @@ class SimpleWebAPIDocumentation(RoundBasedUseCase): _capabilities: Dict[str, Capability] = field(default_factory=dict) _all_http_methods_found: bool = False - # Parameter specifying the pattern description for expected HTTP methods in the API response - http_method_description: str = parameter( + # Description for expected HTTP methods + _http_method_description: str = parameter( desc="Pattern description for expected HTTP methods in the API response", default="A string that represents an HTTP method (e.g., 'GET', 'POST', etc.)." 
) - # Parameter specifying the template used to format HTTP methods in API requests - http_method_template: str = parameter( - desc="Template used to format HTTP methods in API requests. The {method} placeholder will be replaced by actual HTTP method names.", - default="{method} request" + # Template for HTTP methods in API requests + _http_method_template: str = parameter( + desc="Template to format HTTP methods in API requests, with {method} replaced by actual HTTP method names.", + default="{method}" ) - # Parameter specifying the expected HTTP methods as a comma-separated list - http_methods: str = parameter( - desc="Comma-separated list of HTTP methods expected to be used in the API response.", + # List of expected HTTP methods + _http_methods: str = parameter( + desc="Expected HTTP methods in the API, as a comma-separated list.", default="GET,POST,PUT,PATCH,DELETE" ) def init(self): super().init() - self.openapi_spec = self.openapi_spec = { - "openapi": "3.0.0", - "info": { - "title": "Generated API Documentation", - "version": "1.0", - "description": "Automatically generated description of the API." - }, - "servers": [{"url": "https://jsonplaceholder.typicode.com"}], - "endpoints": {} - } - self._prompt_history.append( - { - "role": "system", - "content": f"You're tasked with documenting the REST APIs of a website hosted at {self.host}. " - f"Your main goal is to comprehensively explore the APIs endpoints and responses, and then document your findings in form of a OpenAPI specification." - f"Start with an empty OpenAPI specification.\n" - f"Maintain meticulousness in documenting your observations as you traverse the APIs. This will streamline the documentation process.\n" - f"Avoid resorting to brute-force methods. All essential information should be accessible through the API endpoints.\n" - - }) - self.prompt_engineer = PromptEngineer( - strategy=PromptStrategy.CHAIN_OF_THOUGHT, - api_key=self.llm.api_key, - history=self._prompt_history) - - self._context["host"] = self.host - sett = set(self.http_method_template.format(method=method) for method in self.http_methods.split(",")) + self._setup_capabilities() + self.llm_handler = LLMHandler(self.llm, self._capabilities) + self.response_handler = ResponseHandler(self.llm_handler) + self._setup_initial_prompt() + self.documentation_handler = OpenAPISpecificationManager(self.llm_handler, self.response_handler) + + def _setup_capabilities(self): + notes = self._context["notes"] self._capabilities = { - "submit_http_method": SubmitFlag(self.http_method_description, - sett, - success_function=self.all_http_methods_found), "http_request": HTTPRequest(self.host), - "record_note": RecordNote(self._context["notes"]), + "record_note": RecordNote(notes) } - self.current_time = datetime.datetime.now() - - def all_http_methods_found(self): - self.console.print(Panel("All HTTP methods found! 
Congratulations!", title="system")) - self._all_http_methods_found = True - - def perform_round(self, turn: int, FINAL_ROUND=20): - - with self.console.status("[bold green]Asking LLM for a new command..."): - # generate prompt - prompt = self.prompt_engineer.generate_prompt(doc=True) - - tic = time.perf_counter() - response, completion = self.llm.instructor.chat.completions.create_with_completion(model=self.llm.model, - messages=prompt, - response_model=capabilities_to_action_model( - self._capabilities)) - toc = time.perf_counter() - - message = completion.choices[0].message - - tool_call_id = message.tool_calls[0].id - command = pydantic_core.to_json(response).decode() - self.console.print(Panel(command, title="assistant")) + def _setup_initial_prompt(self): + initial_prompt = { + "role": "system", + "content": f"You're tasked with documenting the REST APIs of a website hosted at {self.host}. " + f"Start with an empty OpenAPI specification.\n" + f"Maintain meticulousness in documenting your observations as you traverse the APIs." + } + self._prompt_history.append(initial_prompt) + self.prompt_engineer = PromptEngineer(strategy=PromptStrategy.CHAIN_OF_THOUGHT, llm_handler=self.llm_handler, + history=self._prompt_history, schemas={}, + response_handler=self.response_handler) + + + def all_http_methods_found(self,turn): + print(f'found endpoints:{self.documentation_handler.endpoint_methods.items()}') + print(f'found endpoints values:{self.documentation_handler.endpoint_methods.values()}') + + found_endpoints = sum(len(value_list) for value_list in self.documentation_handler.endpoint_methods.values()) + expected_endpoints = len(self.documentation_handler.endpoint_methods.keys())*4 + print(f'found endpoints:{found_endpoints}') + print(f'expected endpoints:{expected_endpoints}') + print(f'correct? 
{found_endpoints== expected_endpoints}') + if found_endpoints > 0 and (found_endpoints== expected_endpoints) : + return True + else: + if turn == 20: + if found_endpoints > 0 and (found_endpoints == expected_endpoints): + return True + return False + + def perform_round(self, turn: int): + prompt = self.prompt_engineer.generate_prompt(doc=True) + response, completion = self.llm_handler.call_llm(prompt) + return self._handle_response(completion, response, turn) + + def _handle_response(self, completion, response, turn): + message = completion.choices[0].message + tool_call_id = message.tool_calls[0].id + command = pydantic_core.to_json(response).decode() + self._log.console.print(Panel(command, title="assistant")) + self._prompt_history.append(message) + + with self._log.console.status("[bold green]Executing that command..."): + result = response.execute() + self._log.console.print(Panel(result[:30], title="tool")) + result_str = self.response_handler.parse_http_status_line(result) + self._prompt_history.append(tool_message(result_str, tool_call_id)) + invalid_flags = ["recorded","Not a valid HTTP method", "404" ,"Client Error: Not Found"] + if not result_str in invalid_flags or any(item in result_str for item in invalid_flags): + self.prompt_engineer.found_endpoints = self.documentation_handler.update_openapi_spec(response, result) + self.documentation_handler.write_openapi_to_yaml() + self.prompt_engineer.schemas = self.documentation_handler.schemas + from collections import defaultdict + http_methods_dict = defaultdict(list) - self._prompt_history.append(message) - content = completion.choices[0].message.content + # Iterate through the original dictionary + for endpoint, methods in self.documentation_handler.endpoint_methods.items(): + for method in methods: + http_methods_dict[method].append(endpoint) + self.prompt_engineer.endpoint_found_methods = http_methods_dict + self.prompt_engineer.endpoint_methods = self.documentation_handler.endpoint_methods + return self.all_http_methods_found(turn) - answer = LLMResult(content, str(prompt), - content, toc - tic, completion.usage.prompt_tokens, - completion.usage.completion_tokens) - with self.console.status("[bold green]Executing that command..."): - result = response.execute() - self.console.print(Panel(result, title="tool")) - result_str = self.parse_http_status_line(result) - self._prompt_history.append(tool_message(result_str, tool_call_id)) - if result_str == '200 OK': - self.update_openapi_spec(response ) + def has_no_numbers(self, path): + return not any(char.isdigit() for char in path) - self.log_db.add_log_query(self._run_id, turn, command, result, answer) - self.write_openapi_to_yaml() - return self._all_http_methods_found - def parse_http_status_line(self, status_line): - if status_line is None or status_line == "Not a valid flag": - return status_line - else: - # Split the status line into components - parts = status_line.split(' ', 2) - - # Check if the parts are at least three in number - if len(parts) >= 3: - protocol = parts[0] # e.g., "HTTP/1.1" - status_code = parts[1] # e.g., "200" - status_message = parts[2].split("\r\n")[0] # e.g., "OK" - print(f'status code:{status_code}, status msg:{status_message}') - return str(status_code + " " + status_message) - else: - raise ValueError("Invalid HTTP status line") - - def has_no_numbers(self,path): - for char in path: - if char.isdigit(): - return False - return True - def update_openapi_spec(self, response): - # This function should parse the request and update the OpenAPI 
specification - # For the purpose of this example, let's assume it parses JSON requests and updates paths - request = response.action - path = request.path - method = request.method - if path and method: - if path not in self.openapi_spec['endpoints']:#and self.has_no_numbers(path): - self.openapi_spec['endpoints'][path] = {} - self.openapi_spec['endpoints'][path][method.lower()] = { - "summary": f"{method} operation on {path}", - "responses": { - "200": { - "description": "Successful response", - "content": { - "application/json": { - "schema": {"type": "object"} # Simplified for example - } - } - } - } - } - - def write_openapi_to_yaml(self, filename='openapi_spec.yaml'): - """Write the OpenAPI specification to a YAML file.""" - try: - openapi_data = { - "openapi": self.openapi_spec["openapi"], - "info": self.openapi_spec["info"], - "servers": self.openapi_spec["servers"], - "paths": self.openapi_spec["endpoints"] - } - - # Ensure the directory exists - file_path = filename.split(".yaml")[0] - file_name = filename.split(".yaml")[0] + "_"+ self.current_time.strftime("%Y-%m-%d %H:%M:%S")+".yaml" - os.makedirs(file_path, exist_ok=True) - - with open(os.path.join(file_path, file_name), 'w') as yaml_file: - yaml.dump(openapi_data, yaml_file, allow_unicode=True, default_flow_style=False) - self.console.print(f"[green]OpenAPI specification written to [bold]{filename}[/bold].") - except Exception as e: - raise Exception(e) - - #self.console.print(f"[red]Error writing YAML file: {e}") - def write_openapi_to_yaml2(self, filename='openapi_spec.yaml'): - """Write the OpenAPI specification to a YAML file.""" - try: - # self.setup_yaml() # Configure YAML to handle complex types - with open(filename, 'w') as yaml_file: - yaml.dump(self.openapi_spec, yaml_file, allow_unicode=True, default_flow_style=False) - self.console.print(f"[green]OpenAPI specification written to [bold]{filename}[/bold].") - except TypeError as e: - raise Exception(e) - #self.console.print(f"[red]Error writing YAML file: {e}") - - def represent_dict_order(self, data): - return self.represent_mapping('tag:yaml.org,2002:map', data.items()) - - def setup_yaml(self): - """Configure YAML to output OrderedDicts as regular dicts (helpful for better YAML readability).""" - yaml.add_representer(dict, self.represent_dict_order) +@use_case("Minimal implementation of a web API testing use case") +class SimpleWebAPIDocumentationUseCase(AutonomousAgentUseCase[SimpleWebAPIDocumentation]): + pass \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py b/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py index 96d4a78..3f8e1dd 100644 --- a/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py +++ b/src/hackingBuddyGPT/usecases/web_api_testing/simple_web_api_testing.py @@ -1,46 +1,37 @@ -import time - from dataclasses import dataclass, field +from typing import List, Any, Union, Dict + +import pydantic_core from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessage from rich.panel import Panel -from typing import List, Any, Union, Dict from hackingBuddyGPT.capabilities import Capability -from hackingBuddyGPT.capabilities.capability import capabilities_to_action_model from hackingBuddyGPT.capabilities.http_request import HTTPRequest from hackingBuddyGPT.capabilities.record_note import RecordNote -from hackingBuddyGPT.capabilities.submit_flag import SubmitFlag -from hackingBuddyGPT.usecases.common_patterns import RoundBasedUseCase +from 
hackingBuddyGPT.capabilities.submit_http_method import SubmitHTTPMethod +from hackingBuddyGPT.usecases.agents import Agent +from hackingBuddyGPT.usecases.web_api_testing.utils.llm_handler import LLMHandler from hackingBuddyGPT.usecases.web_api_testing.prompt_engineer import PromptEngineer, PromptStrategy -from hackingBuddyGPT.utils import LLMResult, tool_message, ui +from hackingBuddyGPT.usecases.web_api_testing.utils.response_handler import ResponseHandler +from hackingBuddyGPT.utils import tool_message from hackingBuddyGPT.utils.configurable import parameter from hackingBuddyGPT.utils.openai.openai_lib import OpenAILib -from hackingBuddyGPT.usecases import use_case - -import pydantic_core +from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case Prompt = List[Union[ChatCompletionMessage, ChatCompletionMessageParam]] Context = Any - -@use_case("simple_web_api_testing", "Minimal implementation of a web api testing use case") -@dataclass -class SimpleWebAPITesting(RoundBasedUseCase): +class SimpleWebAPITesting(Agent): llm: OpenAILib host: str = parameter(desc="The host to test", default="https://jsonplaceholder.typicode.com") - # Parameter specifying the pattern description for expected HTTP methods in the API response http_method_description: str = parameter( desc="Pattern description for expected HTTP methods in the API response", default="A string that represents an HTTP method (e.g., 'GET', 'POST', etc.)." ) - - # Parameter specifying the template used to format HTTP methods in API requests http_method_template: str = parameter( desc="Template used to format HTTP methods in API requests. The {method} placeholder will be replaced by actual HTTP method names.", default="{method}" ) - - # Parameter specifying the expected HTTP methods as a comma-separated list http_methods: str = parameter( desc="Comma-separated list of HTTP methods expected to be used in the API response.", default="GET,POST,PUT,DELETE" @@ -51,83 +42,99 @@ class SimpleWebAPITesting(RoundBasedUseCase): _capabilities: Dict[str, Capability] = field(default_factory=dict) _all_http_methods_found: bool = False + + def init(self): + """ + Initializes the SimpleWebAPITesting use case by setting up the context, response handler, + LLM handler, capabilities, and the initial prompt. + """ super().init() - self._prompt_history.append( - { - "role": "system", - "content": f"You're a seasoned API tester tasked with evaluating the REST APIs of a website hosted at {self.host}. " - f"Your primary objective is to thoroughly explore the APIs, understanding their endpoints, parameters, and responses, and then devise and execute potential exploits.\n" - f"Be diligent in documenting your findings as you navigate through the APIs..\n" - f"Avoid resorting to brute-force tactics. All necessary information can be obtained from the API endpoints themselves.\n" - f"Remember, if you encounter a http methods ({self.http_method_description}), promptly submit it as it is of utmost importance." 
- }) - self.prompt_engineer = PromptEngineer(strategy=PromptStrategy.CHAIN_OF_THOUGHT, - api_key=self.llm.api_key, - history=self._prompt_history) - self._context["host"] = self.host - sett = set(self.http_method_template.format(method=method) for method in self.http_methods.split(",")) - flag = SubmitFlag(self.http_method_description, - sett, - success_function=self.all_http_methods_found) - print(f'Valid flags:{flag.valid_flags}') - self._capabilities = { - "submit_http_method": flag, - "http_request": HTTPRequest(self.host), - "record_note": RecordNote(self._context["notes"]), + self._setup_capabilities() + self.llm_handler = LLMHandler(self.llm, self._capabilities) + self.response_handler = ResponseHandler(self.llm_handler) + + self._setup_initial_prompt() + + def _setup_initial_prompt(self): + """ + Sets up the initial prompt for the LLM. The prompt provides instructions for the LLM + to evaluate the REST APIs of the specified host and to document findings. + """ + initial_prompt = { + "role": "system", + "content": ( + f"You're a seasoned API tester tasked with evaluating the REST APIs of a website hosted at {self.host}. " + f"Your primary objective is to thoroughly explore the APIs, understanding their endpoints, parameters, and responses, " + f"and then devise and execute potential exploits. Be diligent in documenting your findings as you navigate through the APIs. " + f"Avoid resorting to brute-force tactics. All necessary information can be obtained from the API endpoints themselves. " + f"Remember, if you encounter an HTTP method ({self.http_method_description}), promptly submit it as it is of utmost importance." + ) } + self._prompt_history.append(initial_prompt) + self.prompt_engineer = PromptEngineer( + strategy=PromptStrategy.CHAIN_OF_THOUGHT, llm_handler=self.llm_handler, + history=self._prompt_history, schemas={}, response_handler=self.response_handler + ) def all_http_methods_found(self): - self.console.print(Panel("All HTTP methods found! Congratulations!", title="system")) + """ + Handles the event when all HTTP methods are found. Displays a congratulatory message + and sets the _all_http_methods_found flag to True. + """ + self._log.console.print(Panel("All HTTP methods found! Congratulations!", title="system")) self._all_http_methods_found = True - def perform_round(self, turn: int): - with self.console.status("[bold green]Asking LLM for a new command..."): - # generate prompt - prompt = self.prompt_engineer.generate_prompt() - - - tic = time.perf_counter() - response, completion = self.llm.instructor.chat.completions.create_with_completion(model=self.llm.model, - messages=prompt, - response_model=capabilities_to_action_model( - self._capabilities)) - toc = time.perf_counter() - - message = completion.choices[0].message - tool_call_id = message.tool_calls[0].id - command = pydantic_core.to_json(response).decode() - self.console.print(Panel(command, title="assistant")) - self._prompt_history.append(message) - - answer = LLMResult(completion.choices[0].message.content, str(prompt), - completion.choices[0].message.content, toc - tic, completion.usage.prompt_tokens, - completion.usage.completion_tokens) + def _setup_capabilities(self): + """ + Sets up the capabilities required for the use case. Initializes HTTP request capabilities, + note recording capabilities, and HTTP method submission capabilities based on the provided + configuration. 
+ """ + methods_set = {self.http_method_template.format(method=method) for method in self.http_methods.split(",")} + notes = self._context["notes"] + self._capabilities = { + "submit_http_method": HTTPRequest(self.host), + "http_request": HTTPRequest(self.host), + "record_note": RecordNote(notes) + } - with self.console.status("[bold green]Executing that command..."): + def perform_round(self, turn: int, FINAL_ROUND=30): + """ + Performs a single round of interaction with the LLM. Generates a prompt, sends it to the LLM, + and handles the response. + + Args: + turn (int): The current round number. + FINAL_ROUND (int, optional): The final round number. Defaults to 30. + """ + prompt = self.prompt_engineer.generate_prompt(doc=True) + response, completion = self.llm_handler.call_llm(prompt) + self._handle_response(completion, response) + + def _handle_response(self, completion, response): + """ + Handles the response from the LLM. Parses the response, executes the necessary actions, + and updates the prompt history. + + Args: + completion (Any): The completion object from the LLM. + response (Any): The response object from the LLM. + """ + message = completion.choices[0].message + tool_call_id = message.tool_calls[0].id + command = pydantic_core.to_json(response).decode() + self._log.console.print(Panel(command, title="assistant")) + self._prompt_history.append(message) + + with self._log.console.status("[bold green]Executing that command..."): result = response.execute() - self.console.print(Panel(result, title="tool")) - result_str = self.parse_http_status_line(result) + self._log.console.print(Panel(result[:30], title="tool")) + result_str = self.response_handler.parse_http_status_line(result) self._prompt_history.append(tool_message(result_str, tool_call_id)) - - self.log_db.add_log_query(self._run_id, turn, command, result, answer) - return self._all_http_methods_found - - def parse_http_status_line(self, status_line): - if status_line is None or status_line == "Not a valid flag": - return status_line - else: - # Split the status line into components - parts = status_line.split(' ', 2) - - # Check if the parts are at least three in number - if len(parts) >= 3: - protocol = parts[0] # e.g., "HTTP/1.1" - status_code = parts[1] # e.g., "200" - status_message = parts[2].split("\r\n")[0] # e.g., "OK" - print(f'status code:{status_code}, status msg:{status_message}') - return str(status_code + " " + status_message) - else: - raise ValueError("Invalid HTTP status line") + return self.all_http_methods_found() +@use_case("Minimal implementation of a web API testing use case") +class SimpleWebAPITestingUseCase(AutonomousAgentUseCase[SimpleWebAPITesting]): + pass \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/__init__.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/__init__.py new file mode 100644 index 0000000..a856540 --- /dev/null +++ b/src/hackingBuddyGPT/usecases/web_api_testing/utils/__init__.py @@ -0,0 +1,5 @@ +from .openapi_specification_manager import OpenAPISpecificationManager +from .llm_handler import LLMHandler +from .response_handler import ResponseHandler +from .openapi_parser import OpenAPISpecificationParser +from .yaml_assistant import YamlFileAssistant \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/llm_handler.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/llm_handler.py new file mode 100644 index 0000000..1fe0026 --- /dev/null +++ 
b/src/hackingBuddyGPT/usecases/web_api_testing/utils/llm_handler.py @@ -0,0 +1,64 @@ +from hackingBuddyGPT.capabilities.capability import capabilities_to_action_model + +class LLMHandler(object): + """ + LLMHandler is a class responsible for managing interactions with a large language model (LLM). + It handles the execution of prompts and the management of created objects based on the capabilities. + + Attributes: + llm (object): The large language model to interact with. + _capabilities (dict): A dictionary of capabilities that define the actions the LLM can perform. + created_objects (dict): A dictionary to keep track of created objects by their type. + """ + + def __init__(self, llm, capabilities): + """ + Initializes the LLMHandler with the specified LLM and capabilities. + + Args: + llm (object): The large language model to interact with. + capabilities (dict): A dictionary of capabilities that define the actions the LLM can perform. + """ + self.llm = llm + self._capabilities = capabilities + self.created_objects = {} + + def call_llm(self, prompt): + """ + Calls the LLM with the specified prompt and retrieves the response. + + Args: + prompt (list): The prompt messages to send to the LLM. + + Returns: + response (object): The response from the LLM. + """ + print(f'Capabilities:{self._capabilities}') + return self.llm.instructor.chat.completions.create_with_completion( + model=self.llm.model, + messages=prompt, + response_model=capabilities_to_action_model(self._capabilities) + ) + + def add_created_object(self, created_object, object_type): + """ + Adds a created object to the dictionary of created objects, categorized by object type. + + Args: + created_object (object): The object that was created. + object_type (str): The type/category of the created object. + """ + if object_type not in self.created_objects: + self.created_objects[object_type] = [] + if len(self.created_objects[object_type]) < 7: + self.created_objects[object_type].append(created_object) + + def get_created_objects(self): + """ + Retrieves the dictionary of created objects and prints its contents. + + Returns: + dict: The dictionary of created objects. + """ + print(f'created_objects: {self.created_objects}') + return self.created_objects diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/openapi_converter.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/openapi_converter.py new file mode 100644 index 0000000..5b9c5ed --- /dev/null +++ b/src/hackingBuddyGPT/usecases/web_api_testing/utils/openapi_converter.py @@ -0,0 +1,96 @@ +import os.path +import yaml +import json + +class OpenAPISpecificationConverter: + """ + OpenAPISpecificationConverter is a class for converting OpenAPI specification files between YAML and JSON formats. + + Attributes: + base_directory (str): The base directory for the output files. + """ + + def __init__(self, base_directory): + """ + Initializes the OpenAPISpecificationConverter with the specified base directory. + + Args: + base_directory (str): The base directory for the output files. + """ + self.base_directory = base_directory + + def convert_file(self, input_filepath, output_directory, input_type, output_type): + """ + Converts files between YAML and JSON formats. + + Args: + input_filepath (str): The path to the input file. + output_directory (str): The subdirectory for the output files. + input_type (str): The type of the input file ('yaml' or 'json'). + output_type (str): The type of the output file ('json' or 'yaml'). 
+ + Returns: + str: The path to the converted output file, or None if an error occurred. + """ + try: + filename = os.path.basename(input_filepath) + output_filename = filename.replace(f".{input_type}", f".{output_type}") + output_path = os.path.join(self.base_directory, output_directory, output_filename) + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + with open(input_filepath, 'r') as infile: + if input_type == 'yaml': + content = yaml.safe_load(infile) + else: + content = json.load(infile) + + with open(output_path, 'w') as outfile: + if output_type == 'yaml': + yaml.dump(content, outfile, allow_unicode=True, default_flow_style=False) + else: + json.dump(content, outfile, indent=2) + + print(f"Successfully converted {input_filepath} to {output_filename}") + return output_path + + except Exception as e: + print(f"Error converting {input_filepath}: {e}") + return None + + def yaml_to_json(self, yaml_filepath): + """ + Converts a YAML file to a JSON file. + + Args: + yaml_filepath (str): The path to the YAML file to be converted. + + Returns: + str: The path to the converted JSON file, or None if an error occurred. + """ + return self.convert_file(yaml_filepath, "json", 'yaml', 'json') + + def json_to_yaml(self, json_filepath): + """ + Converts a JSON file to a YAML file. + + Args: + json_filepath (str): The path to the JSON file to be converted. + + Returns: + str: The path to the converted YAML file, or None if an error occurred. + """ + return self.convert_file(json_filepath, "yaml", 'json', 'yaml') + + +# Usage example +if __name__ == '__main__': + yaml_input = '/home/diana/Desktop/masterthesis/hackingBuddyGPT/src/hackingBuddyGPT/usecases/web_api_testing/openapi_spec/openapi_spec_2024-06-13_17-16-25.yaml' + + converter = OpenAPISpecificationConverter("converted_files") + # Convert YAML to JSON + json_file = converter.yaml_to_json(yaml_input) + + # Convert JSON to YAML + if json_file: + converter.json_to_yaml(json_file) diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/openapi_parser.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/openapi_parser.py new file mode 100644 index 0000000..182b0a5 --- /dev/null +++ b/src/hackingBuddyGPT/usecases/web_api_testing/utils/openapi_parser.py @@ -0,0 +1,87 @@ +import yaml + +class OpenAPISpecificationParser: + """ + OpenAPISpecificationParser is a class for parsing and extracting information from an OpenAPI specification file. + + Attributes: + filepath (str): The path to the OpenAPI specification YAML file. + api_data (dict): The parsed data from the YAML file. + """ + + def __init__(self, filepath): + """ + Initializes the OpenAPISpecificationParser with the specified file path. + + Args: + filepath (str): The path to the OpenAPI specification YAML file. + """ + self.filepath = filepath + self.api_data = self.load_yaml() + + def load_yaml(self): + """ + Loads YAML data from the specified file. + + Returns: + dict: The parsed data from the YAML file. + """ + with open(self.filepath, 'r') as file: + return yaml.safe_load(file) + + def get_servers(self): + """ + Retrieves the list of server URLs from the OpenAPI specification. + + Returns: + list: A list of server URLs. + """ + return [server['url'] for server in self.api_data.get('servers', [])] + + def get_paths(self): + """ + Retrieves all API paths and their methods from the OpenAPI specification. + + Returns: + dict: A dictionary with API paths as keys and methods as values. 
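+
+        Example (illustrative; actual paths depend on the loaded specification):
+            {'/users': {'get': {...}, 'post': {...}}}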
+ """ + paths_info = {} + paths = self.api_data.get('paths', {}) + for path, methods in paths.items(): + paths_info[path] = {method: details for method, details in methods.items()} + return paths_info + + def get_operations(self, path): + """ + Retrieves operations for a specific path from the OpenAPI specification. + + Args: + path (str): The API path to retrieve operations for. + + Returns: + dict: A dictionary with methods as keys and operation details as values. + """ + return self.api_data['paths'].get(path, {}) + + def print_api_details(self): + """ + Prints details of the API extracted from the OpenAPI document, including title, version, servers, + paths, and operations. + """ + print("API Title:", self.api_data['info']['title']) + print("API Version:", self.api_data['info']['version']) + print("Servers:", self.get_servers()) + print("\nAvailable Paths and Operations:") + for path, operations in self.get_paths().items(): + print(f"\nPath: {path}") + for operation, details in operations.items(): + print(f" Operation: {operation.upper()}") + print(f" Summary: {details.get('summary')}") + print(f" Description: {details['responses']['200']['description']}") + +# Usage example +if __name__ == '__main__': + openapi_parser = OpenAPISpecificationParser( + '/hackingBuddyGPT/usecases/web_api_testing/openapi_spec/openapi_spec_2024-06-13_17-16-25.yaml' + ) + openapi_parser.print_api_details() diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/openapi_specification_manager.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/openapi_specification_manager.py new file mode 100644 index 0000000..bdfc2e7 --- /dev/null +++ b/src/hackingBuddyGPT/usecases/web_api_testing/utils/openapi_specification_manager.py @@ -0,0 +1,154 @@ +import os +import yaml +from datetime import datetime +from hackingBuddyGPT.capabilities.yamlFile import YAMLFile + +class OpenAPISpecificationManager: + """ + Handles the generation and updating of an OpenAPI specification document based on dynamic API responses. + + Attributes: + response_handler (object): An instance of the response handler for processing API responses. + schemas (dict): A dictionary to store API schemas. + filename (str): The filename for the OpenAPI specification file. + openapi_spec (dict): The OpenAPI specification document structure. + llm_handler (object): An instance of the LLM handler for interacting with the LLM. + api_key (str): The API key for accessing the LLM. + file_path (str): The path to the directory where the OpenAPI specification file will be stored. + file (str): The complete path to the OpenAPI specification file. + _capabilities (dict): A dictionary to store capabilities related to YAML file handling. + """ + + def __init__(self, llm_handler, response_handler): + """ + Initializes the handler with a template OpenAPI specification. + + Args: + llm_handler (object): An instance of the LLM handler for interacting with the LLM. + response_handler (object): An instance of the response handler for processing API responses. + """ + self.response_handler = response_handler + self.schemas = {} + self.endpoint_methods ={} + self.filename = f"openapi_spec_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.yaml" + self.openapi_spec = { + "openapi": "3.0.0", + "info": { + "title": "Generated API Documentation", + "version": "1.0", + "description": "Automatically generated description of the API." 
+ }, + "servers": [{"url": "https://jsonplaceholder.typicode.com"}], + "endpoints": {}, + "components": {"schemas": {}} + } + self.llm_handler = llm_handler + #self.api_key = llm_handler.llm.api_key + current_path = os.path.dirname(os.path.abspath(__file__)) + self.file_path = os.path.join(current_path, "openapi_spec") + self.file = os.path.join(self.file_path, self.filename) + self._capabilities = { + "yaml": YAMLFile() + } + + def is_partial_match(self, element, string_list): + return any(element in string or string in element for string in string_list) + + def update_openapi_spec(self, resp, result): + """ + Updates the OpenAPI specification based on the API response provided. + + Args: + resp (object): The response object containing details like the path and method which should be documented. + result (str): The result of the API call. + """ + request = resp.action + + if request.__class__.__name__ == 'RecordNote': # TODO: check why isinstance does not work + self.check_openapi_spec(resp) + elif request.__class__.__name__ == 'HTTPRequest': + path = request.path + method = request.method + print(f'method: {method}') + # Ensure that path and method are not None and method has no numeric characters + # Ensure path and method are valid and method has no numeric characters + if path and method: + endpoint_methods = self.endpoint_methods + endpoints = self.openapi_spec['endpoints'] + x = path.split('/')[1] + + # Initialize the path if not already present + if path not in endpoints and x != "": + endpoints[path] = {} + if '1' not in path: + endpoint_methods[path] = [] + + # Update the method description within the path + example, reference, self.openapi_spec = self.response_handler.parse_http_response_to_openapi_example( + self.openapi_spec, result, path, method + ) + self.schemas = self.openapi_spec["components"]["schemas"] + + if example or reference: + endpoints[path][method.lower()] = { + "summary": f"{method} operation on {path}", + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": {"$ref": reference}, + "examples": example + } + } + } + } + } + + if '1' not in path and x != "": + endpoint_methods[path].append(method) + elif self.is_partial_match(x, endpoints.keys()): + path = f"/{x}" + print(f'endpoint methods = {endpoint_methods}') + print(f'new path:{path}') + endpoint_methods[path].append(method) + + endpoint_methods[path] = list(set(endpoint_methods[path])) + + return list(endpoints.keys()) + + def write_openapi_to_yaml(self): + """ + Writes the updated OpenAPI specification to a YAML file with a timestamped filename. + """ + try: + # Prepare data to be written to YAML + openapi_data = { + "openapi": self.openapi_spec["openapi"], + "info": self.openapi_spec["info"], + "servers": self.openapi_spec["servers"], + "components": self.openapi_spec["components"], + "paths": self.openapi_spec["endpoints"] + } + + # Create directory if it doesn't exist and generate the timestamped filename + os.makedirs(self.file_path, exist_ok=True) + + # Write to YAML file + with open(self.file, 'w') as yaml_file: + yaml.dump(openapi_data, yaml_file, allow_unicode=True, default_flow_style=False) + print(f"OpenAPI specification written to {self.filename}.") + except Exception as e: + raise Exception(f"Error writing YAML file: {e}") + + def check_openapi_spec(self, note): + """ + Uses OpenAI's GPT model to generate a complete OpenAPI specification based on a natural language description. 
+ + Args: + note (object): The note object containing the description of the API. + """ + description = self.response_handler.extract_description(note) + from hackingBuddyGPT.usecases.web_api_testing.utils.yaml_assistant import YamlFileAssistant + yaml_file_assistant = YamlFileAssistant(self.file_path, self.llm_handler) + yaml_file_assistant.run(description) diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/response_handler.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/response_handler.py new file mode 100644 index 0000000..da87481 --- /dev/null +++ b/src/hackingBuddyGPT/usecases/web_api_testing/utils/response_handler.py @@ -0,0 +1,224 @@ +import json +from bs4 import BeautifulSoup +import re + +class ResponseHandler(object): + """ + ResponseHandler is a class responsible for handling various types of responses from an LLM (Large Language Model). + It processes prompts, parses HTTP responses, extracts examples, and handles OpenAPI specifications. + + Attributes: + llm_handler (object): An instance of the LLM handler for interacting with the LLM. + """ + + def __init__(self, llm_handler): + """ + Initializes the ResponseHandler with the specified LLM handler. + + Args: + llm_handler (object): An instance of the LLM handler for interacting with the LLM. + """ + self.llm_handler = llm_handler + + def get_response_for_prompt(self, prompt): + """ + Sends a prompt to the LLM's API and retrieves the response. + + Args: + prompt (str): The prompt to be sent to the API. + + Returns: + str: The response from the API. + """ + messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}] + response, completion = self.llm_handler.call_llm(messages) + response_text = response.execute() + return response_text + + def parse_http_status_line(self, status_line): + """ + Parses an HTTP status line and returns the status code and message. + + Args: + status_line (str): The HTTP status line to be parsed. + + Returns: + str: The parsed status code and message. + + Raises: + ValueError: If the status line is invalid. + """ + if status_line == "Not a valid HTTP method": + return status_line + status_line = status_line.split('\r\n')[0] + # Regular expression to match valid HTTP status lines + match = re.match(r'^(HTTP/\d\.\d) (\d{3}) (.*)$', status_line) + if match: + protocol, status_code, status_message = match.groups() + return f'{status_code} {status_message}' + else: + raise ValueError("Invalid HTTP status line") + + def extract_response_example(self, html_content): + """ + Extracts the JavaScript example code and result placeholder from HTML content. + + Args: + html_content (str): The HTML content containing the example code. + + Returns: + dict: The extracted response example as a dictionary, or None if extraction fails. + """ + soup = BeautifulSoup(html_content, 'html.parser') + example_code = soup.find('code', {'id': 'example'}) + result_code = soup.find('code', {'id': 'result'}) + if example_code and result_code: + example_text = example_code.get_text() + result_text = result_code.get_text() + return json.loads(result_text) + return None + + def parse_http_response_to_openapi_example(self, openapi_spec, http_response, path, method): + """ + Parses an HTTP response to generate an OpenAPI example. + + Args: + openapi_spec (dict): The OpenAPI specification to update. + http_response (str): The HTTP response to parse. + path (str): The API path. + method (str): The HTTP method. 
+ + Returns: + tuple: A tuple containing the entry dictionary, reference, and updated OpenAPI specification. + """ + + headers, body = http_response.split('\r\n\r\n', 1) + try: + body_dict = json.loads(body) + except json.decoder.JSONDecodeError: + return None, None, openapi_spec + + reference, object_name, openapi_spec = self.parse_http_response_to_schema(openapi_spec, body_dict, path) + entry_dict = {} + + if len(body_dict) == 1: + entry_dict["id"] = {"value": body_dict} + self.llm_handler.add_created_object(entry_dict, object_name) + else: + if isinstance(body_dict, list): + for entry in body_dict: + key = entry.get("title") or entry.get("name") or entry.get("id") + entry_dict[key] = {"value": entry} + self.llm_handler.add_created_object(entry_dict[key], object_name) + else: + print(f'entry: {body_dict}') + + key = body_dict.get("title") or body_dict.get("name") or body_dict.get("id") + entry_dict[key] = {"value": body_dict} + self.llm_handler.add_created_object(entry_dict[key], object_name) + + + return entry_dict, reference, openapi_spec + + def extract_description(self, note): + """ + Extracts the description from a note. + + Args: + note (object): The note containing the description. + + Returns: + str: The extracted description. + """ + return note.action.content + + def parse_http_response_to_schema(self, openapi_spec, body_dict, path): + """ + Parses an HTTP response body to generate an OpenAPI schema. + + Args: + openapi_spec (dict): The OpenAPI specification to update. + body_dict (dict): The HTTP response body as a dictionary. + path (str): The API path. + + Returns: + tuple: A tuple containing the reference, object name, and updated OpenAPI specification. + """ + object_name = path.split("/")[1].capitalize().rstrip('s') + properties_dict = {} + + if len(body_dict) == 1: + properties_dict["id"] = {"type": "int", "format": "uuid", "example": str(body_dict["id"])} + else: + + for param in body_dict: + if isinstance(body_dict, list): + for key, value in param.items(): + properties_dict =self.extract_keys(key, value, properties_dict) + break + else: + for key, value in body_dict.items(): + properties_dict = self.extract_keys(key, value, properties_dict) + print(f'properzies: {properties_dict}') + + + object_dict = {"type": "object", "properties": properties_dict} + + if object_name not in openapi_spec["components"]["schemas"]: + openapi_spec["components"]["schemas"][object_name] = object_dict + + reference = f"#/components/schemas/{object_name}" + return reference, object_name, openapi_spec + + def read_yaml_to_string(self, filepath): + """ + Reads a YAML file and returns its contents as a string. + + Args: + filepath (str): The path to the YAML file. + + Returns: + str: The contents of the YAML file, or None if an error occurred. + """ + try: + with open(filepath, 'r') as file: + return file.read() + except FileNotFoundError: + print(f"Error: The file {filepath} does not exist.") + return None + except IOError as e: + print(f"Error reading file {filepath}: {e}") + return None + + def extract_endpoints(self, note): + """ + Extracts API endpoints from a note using regular expressions. + + Args: + note (str): The note containing endpoint definitions. + + Returns: + dict: A dictionary with endpoints as keys and HTTP methods as values. 
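+
+        Example (illustrative):
+            extract_endpoints("1. GET /users\n2. GET /comments")
+            returns {'/users': ['GET'], '/comments': ['GET']}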
+ """ + required_endpoints = {} + pattern = r"(\d+\.\s+GET)\s(/[\w{}]+)" + matches = re.findall(pattern, note) + + for match in matches: + method, endpoint = match + method = method.split()[1] + if endpoint in required_endpoints: + if method not in required_endpoints[endpoint]: + required_endpoints[endpoint].append(method) + else: + required_endpoints[endpoint] = [method] + + return required_endpoints + + def extract_keys(self, key, value, properties_dict): + if key == "id": + properties_dict[key] = {"type": str(type(value).__name__), "format": "uuid", "example": str(value)} + else: + properties_dict[key] = {"type": str(type(value).__name__), "example": str(value)} + + return properties_dict diff --git a/src/hackingBuddyGPT/usecases/web_api_testing/utils/yaml_assistant.py b/src/hackingBuddyGPT/usecases/web_api_testing/utils/yaml_assistant.py new file mode 100644 index 0000000..d0e62b4 --- /dev/null +++ b/src/hackingBuddyGPT/usecases/web_api_testing/utils/yaml_assistant.py @@ -0,0 +1,58 @@ +from openai import OpenAI + + +class YamlFileAssistant(object): + def __init__(self, yaml_file, client): + self.yaml_file = yaml_file + self.client = client + + def run(self, recorded_note): + ''' assistant = self.client.beta.assistants.create( + name="Yaml File Analysis Assistant", + instructions="You are an OpenAPI specification analyst. Use you knowledge to check " + f"if the following information is contained in the provided yaml file. Information:{recorded_note}", + model="gpt-4o", + tools=[{"type": "file_search"}], + ) + + # Create a vector store caled "Financial Statements" + vector_store = self.client.beta.vector_stores.create(name="Financial Statements") + + # Ready the files for upload to OpenAI + file_streams = [open(self.yaml_file, "rb") ] + + # Use the upload and poll SDK helper to upload the files, add them to the vector store, + # and poll the status of the file batch for completion. + file_batch = self.client.beta.vector_stores.file_batches.upload_and_poll( + vector_store_id=vector_store.id, files=file_streams + ) + + # You can print the status and the file counts of the batch to see the result of this operation. + print(file_batch.status) + print(file_batch.file_counts) + + assistant = self.client.beta.assistants.update( + assistant_id=assistant.id, + tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}}, + ) + # Upload the user provided file to OpenAI + message_file = self.client.files.create( + file=open("edgar/aapl-10k.pdf", "rb"), purpose="assistants" + ) + + # Create a thread and attach the file to the message + thread = self.client.beta.threads.create( + messages=[ + { + "role": "user", + "content": "How many shares of AAPL were outstanding at the end of of October 2023?", + # Attach the new file to the message. + "attachments": [ + {"file_id": message_file.id, "tools": [{"type": "file_search"}]} + ], + } + ] + ) + + # The thread now has a vector store with that file in its tool resources. 
+ print(thread.tool_resources.file_search)''' diff --git a/src/hackingBuddyGPT/utils/configurable.py b/src/hackingBuddyGPT/utils/configurable.py index 33a451c..6a41e79 100644 --- a/src/hackingBuddyGPT/utils/configurable.py +++ b/src/hackingBuddyGPT/utils/configurable.py @@ -3,7 +3,7 @@ import inspect import os from dataclasses import dataclass -from typing import Any, Dict +from typing import Any, Dict, TypeVar from dotenv import load_dotenv @@ -14,7 +14,7 @@ def parameter(*, desc: str, default=dataclasses.MISSING, init: bool = True, repr: bool = True, hash=None, - compare: bool = True, metadata: Dict = None, kw_only: bool = dataclasses.MISSING) -> dataclasses.Field: + compare: bool = True, metadata: Dict = None, kw_only: bool = dataclasses.MISSING): if metadata is None: metadata = dict() metadata["desc"] = desc @@ -37,15 +37,14 @@ class ParameterDefinition: default: Any description: str - def parser(self, basename: str, parser: argparse.ArgumentParser): - name = f"{basename}{self.name}" + def parser(self, name: str, parser: argparse.ArgumentParser): default = get_default(name, self.default) parser.add_argument(f"--{name}", type=self.type, default=default, required=default is None, help=self.description) - def get(self, basename: str, args: argparse.Namespace): - return getattr(args, f"{basename}{self.name}") + def get(self, name: str, args: argparse.Namespace): + return getattr(args, name) ParameterDefinitions = Dict[str, ParameterDefinition] @@ -60,19 +59,25 @@ class ComplexParameterDefinition(ParameterDefinition): it. So if you have recursive type definitions that you try to make configurable, this will not work. """ parameters: ParameterDefinitions + transparent: bool = False def parser(self, basename: str, parser: argparse.ArgumentParser): for name, parameter in self.parameters.items(): if isinstance(parameter, dict): - build_parser(parameter, parser, f"{basename}{self.name}.") + build_parser(parameter, parser, next_name(basename, name, parameter)) else: - parameter.parser(f"{basename}{self.name}.", parser) + parameter.parser(next_name(basename, name, parameter), parser) + + def get(self, name: str, args: argparse.Namespace): + args = get_arguments(self.parameters, args, name) - def get(self, basename: str, args: argparse.Namespace): - parameter = self.type(**get_arguments(self.parameters, args, f"{basename}{self.name}.")) - if hasattr(parameter, "init"): - parameter.init() - return parameter + def create(): + instance = self.type(**args) + if hasattr(instance, "init") and not getattr(self.type, "__transparent__", False): + instance.init() + setattr(instance, "configurable_recreate", create) + return instance + return create() def get_class_parameters(cls, name: str = None, fields: Dict[str, dataclasses.Field] = None) -> ParameterDefinitions: @@ -94,7 +99,7 @@ def get_parameters(fun, basename: str, fields: Dict[str, dataclasses.Field] = No continue if not param.annotation: - raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have a type annotation") + raise ValueError(f"Parameter {name} of {basename} must have a type annotation") default = param.default if param.default != inspect.Parameter.empty else None description = None @@ -113,22 +118,22 @@ def get_parameters(fun, basename: str, fields: Dict[str, dataclasses.Field] = No type = field.type if hasattr(type, "__parameters__"): - params[name] = ComplexParameterDefinition(name, type, default, description, get_class_parameters(type, f"{basename}.{fun.__name__}")) + params[name] = ComplexParameterDefinition(name, 
type, default, description, get_class_parameters(type, basename), transparent=getattr(type, "__transparent__", False)) elif type in (str, int, float, bool): params[name] = ParameterDefinition(name, type, default, description) else: - raise ValueError(f"Parameter {name} of {basename}.{fun.__name__} must have str, int, bool, or a __parameters__ class as type, not {type}") + raise ValueError(f"Parameter {name} of {basename} must have str, int, bool, or a __parameters__ class as type, not {type}") return params def build_parser(parameters: ParameterDefinitions, parser: argparse.ArgumentParser, basename: str = ""): for name, parameter in parameters.items(): - parameter.parser(basename, parser) + parameter.parser(next_name(basename, name, parameter), parser) def get_arguments(parameters: ParameterDefinitions, args: argparse.Namespace, basename: str = "") -> Dict[str, Any]: - return {name: parameter.get(basename, args) for name, parameter in parameters.items()} + return {name: parameter.get(next_name(basename, name, parameter), args) for name, parameter in parameters.items()} Configurable = Type # TODO: Define type @@ -149,3 +154,45 @@ def inner(cls) -> Configurable: return cls return inner + + +T = TypeVar("T") + + +def transparent(subclass: T) -> T: + """ + setting a type to be transparent means, that it will not increase a level in the configuration tree, so if you have the following classes: + + class Inner: + a: int + b: str + + def init(self): + print("inner init") + + class Outer: + inner: transparent(Inner) + + def init(self): + inner.init() + + the configuration will be `--a` and `--b` instead of `--inner.a` and `--inner.b`. + + A transparent attribute will also not have its init function called automatically, so you will need to do that on your own, as seen in the Outer init. 
+ """ + class Cloned(subclass): + __transparent__ = True + Cloned.__name__ = subclass.__name__ + Cloned.__qualname__ = subclass.__qualname__ + Cloned.__module__ = subclass.__module__ + return Cloned + + +def next_name(basename: str, name: str, param: Any) -> str: + if isinstance(param, ComplexParameterDefinition) and param.transparent: + return basename + elif basename == "": + return name + else: + return f"{basename}.{name}" + diff --git a/src/hackingBuddyGPT/utils/db_storage/db_storage.py b/src/hackingBuddyGPT/utils/db_storage/db_storage.py index 06787e4..497c023 100644 --- a/src/hackingBuddyGPT/utils/db_storage/db_storage.py +++ b/src/hackingBuddyGPT/utils/db_storage/db_storage.py @@ -33,7 +33,6 @@ def setup_db(self): self.cursor.execute("""CREATE TABLE IF NOT EXISTS runs ( id INTEGER PRIMARY KEY, model text, - context_size INTEGER, state TEXT, tag TEXT, started_at text, @@ -81,10 +80,10 @@ def setup_db(self): self.analyze_response_id = self.insert_or_select_cmd('analyze_response') self.state_update_id = self.insert_or_select_cmd('update_state') - def create_new_run(self, model, context_size, tag): + def create_new_run(self, model, tag): self.cursor.execute( - "INSERT INTO runs (model, context_size, state, tag, started_at) VALUES (?, ?, ?, ?, datetime('now'))", - (model, context_size, "in progress", tag)) + "INSERT INTO runs (model, state, tag, started_at) VALUES (?, ?, ?, datetime('now'))", + (model, "in progress", tag)) return self.cursor.lastrowid def add_log_query(self, run_id, round, cmd, result, answer): diff --git a/tests/integration_minimal_test.py b/tests/integration_minimal_test.py index 0f4abe5..06319bc 100644 --- a/tests/integration_minimal_test.py +++ b/tests/integration_minimal_test.py @@ -1,8 +1,8 @@ from typing import Tuple -from hackingBuddyGPT.usecases.minimal.agent import MinimalLinuxPrivesc -from hackingBuddyGPT.usecases.minimal.agent_with_state import MinimalLinuxTemplatedPrivesc -from hackingBuddyGPT.usecases.privesc.linux import LinuxPrivesc +from hackingBuddyGPT.usecases.minimal.agent import MinimalLinuxPrivesc, MinimalLinuxPrivescUseCase +from hackingBuddyGPT.usecases.minimal.agent_with_state import MinimalLinuxTemplatedPrivesc, MinimalLinuxTemplatedPrivescUseCase +from hackingBuddyGPT.usecases.privesc.linux import LinuxPrivesc, LinuxPrivescUseCase from hackingBuddyGPT.utils.console.console import Console from hackingBuddyGPT.utils.db_storage.db_storage import DbStorage from hackingBuddyGPT.utils.llm_util import LLM, LLMResult @@ -74,14 +74,16 @@ def test_linuxprivesc(): log_db.init() - priv_esc = LinuxPrivesc( - conn=conn, - enable_explanation=False, - disable_history=False, - hint='', + priv_esc = LinuxPrivescUseCase( + agent = LinuxPrivesc( + conn=conn, + enable_explanation=False, + disable_history=False, + hint='', + llm = llm, + ), log_db = log_db, console = console, - llm = llm, tag = 'integration_test_linuxprivesc', max_turns = len(llm.responses) ) @@ -99,12 +101,14 @@ def test_minimal_agent(): log_db.init() - priv_esc = MinimalLinuxPrivesc( - conn=conn, + priv_esc = MinimalLinuxPrivescUseCase( + agent = MinimalLinuxPrivesc( + conn=conn, + llm=llm + ), log_db = log_db, console = console, - llm = llm, - tag = 'integration_test_linuxprivesc', + tag = 'integration_test_minimallinuxprivesc', max_turns = len(llm.responses) ) @@ -121,11 +125,13 @@ def test_minimal_agent_state(): log_db.init() - priv_esc = MinimalLinuxTemplatedPrivesc( - conn=conn, + priv_esc = MinimalLinuxTemplatedPrivescUseCase( + agent = MinimalLinuxTemplatedPrivesc( + conn=conn, + llm = 
llm, + ), log_db = log_db, console = console, - llm = llm, tag = 'integration_test_linuxprivesc', max_turns = len(llm.responses) ) diff --git a/tests/test_llm_handler.py b/tests/test_llm_handler.py new file mode 100644 index 0000000..9b209d2 --- /dev/null +++ b/tests/test_llm_handler.py @@ -0,0 +1,61 @@ +import unittest +from unittest.mock import MagicMock, patch +from hackingBuddyGPT.capabilities.capability import capabilities_to_action_model +from hackingBuddyGPT.usecases.web_api_testing.utils import LLMHandler + + +class TestLLMHandler(unittest.TestCase): + def setUp(self): + self.llm_mock = MagicMock() + self.capabilities = {'cap1': MagicMock(), 'cap2': MagicMock()} + self.llm_handler = LLMHandler(self.llm_mock, self.capabilities) + + '''@patch('hackingBuddyGPT.usecases.web_api_testing.utils.capabilities_to_action_model') + def test_call_llm(self, mock_capabilities_to_action_model): + prompt = [{'role': 'user', 'content': 'Hello, LLM!'}] + response_mock = MagicMock() + self.llm_mock.instructor.chat.completions.create_with_completion.return_value = response_mock + + # Mock the capabilities_to_action_model to return a dummy Pydantic model + mock_model = MagicMock() + mock_capabilities_to_action_model.return_value = mock_model + + response = self.llm_handler.call_llm(prompt) + + self.llm_mock.instructor.chat.completions.create_with_completion.assert_called_once_with( + model=self.llm_mock.model, + messages=prompt, + response_model=mock_model + ) + self.assertEqual(response, response_mock)''' + def test_add_created_object(self): + created_object = MagicMock() + object_type = 'test_type' + + self.llm_handler.add_created_object(created_object, object_type) + + self.assertIn(object_type, self.llm_handler.created_objects) + self.assertIn(created_object, self.llm_handler.created_objects[object_type]) + + def test_add_created_object_limit(self): + created_object = MagicMock() + object_type = 'test_type' + + for _ in range(8): # Exceed the limit of 7 objects + self.llm_handler.add_created_object(created_object, object_type) + + self.assertEqual(len(self.llm_handler.created_objects[object_type]), 7) + + def test_get_created_objects(self): + created_object = MagicMock() + object_type = 'test_type' + self.llm_handler.add_created_object(created_object, object_type) + + created_objects = self.llm_handler.get_created_objects() + + self.assertIn(object_type, created_objects) + self.assertIn(created_object, created_objects[object_type]) + self.assertEqual(created_objects, self.llm_handler.created_objects) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_openAPI_specification_manager.py b/tests/test_openAPI_specification_manager.py new file mode 100644 index 0000000..35b5dc2 --- /dev/null +++ b/tests/test_openAPI_specification_manager.py @@ -0,0 +1,53 @@ +import unittest +from unittest.mock import MagicMock, patch + +from hackingBuddyGPT.capabilities.http_request import HTTPRequest +from hackingBuddyGPT.usecases.web_api_testing.utils import OpenAPISpecificationManager + + +class TestSpecificationHandler(unittest.TestCase): + def setUp(self): + self.llm_handler = MagicMock() + self.response_handler = MagicMock() + self.doc_handler = OpenAPISpecificationManager(self.llm_handler, self.response_handler) + + @patch('os.makedirs') + @patch('builtins.open') + def test_write_openapi_to_yaml(self, mock_open, mock_makedirs): + self.doc_handler.write_openapi_to_yaml() + mock_makedirs.assert_called_once_with(self.doc_handler.file_path, exist_ok=True) + 
mock_open.assert_called_once_with(self.doc_handler.file, 'w') + + # Create a mock HTTPRequest object + response_mock = MagicMock() + response_mock.action = HTTPRequest( + host="https://jsonplaceholder.typicode.com", + follow_redirects=False, + use_cookie_jar=True + ) + response_mock.action.method = "GET" + response_mock.action.path = "/test" + + result = '{"key": "value"}' + + self.response_handler.parse_http_response_to_openapi_example = MagicMock( + return_value=({}, "#/components/schemas/TestSchema", self.doc_handler.openapi_spec) + ) + + endpoints = self.doc_handler.update_openapi_spec(response_mock, result) + + self.assertIn("/test", self.doc_handler.openapi_spec["endpoints"]) + self.assertIn("get", self.doc_handler.openapi_spec["endpoints"]["/test"]) + self.assertEqual(self.doc_handler.openapi_spec["endpoints"]["/test"]["get"]["summary"], + "GET operation on /test") + self.assertEqual(endpoints, ["/test"]) + + + def test_partial_match(self): + string_list = ["test_endpoint", "another_endpoint"] + self.assertTrue(self.doc_handler.is_partial_match("test", string_list)) + self.assertFalse(self.doc_handler.is_partial_match("not_in_list", string_list)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_openapi_converter.py b/tests/test_openapi_converter.py new file mode 100644 index 0000000..43354aa --- /dev/null +++ b/tests/test_openapi_converter.py @@ -0,0 +1,87 @@ +import unittest +from unittest.mock import patch, mock_open, MagicMock +import os +import yaml +import json + +from hackingBuddyGPT.usecases.web_api_testing.utils.openapi_converter import OpenAPISpecificationConverter + + +class TestOpenAPISpecificationConverter(unittest.TestCase): + def setUp(self): + self.converter = OpenAPISpecificationConverter("base_directory") + + @patch("os.makedirs") + @patch("builtins.open", new_callable=mock_open, read_data="yaml_content") + @patch("yaml.safe_load", return_value={"key": "value"}) + @patch("json.dump") + def test_convert_file_yaml_to_json(self, mock_json_dump, mock_yaml_safe_load, mock_open_file, mock_makedirs): + input_filepath = "input.yaml" + output_directory = "json" + input_type = "yaml" + output_type = "json" + expected_output_path = os.path.join("base_directory", output_directory, "input.json") + + result = self.converter.convert_file(input_filepath, output_directory, input_type, output_type) + + mock_open_file.assert_any_call(input_filepath, 'r') + mock_yaml_safe_load.assert_called_once() + mock_open_file.assert_any_call(expected_output_path, 'w') + mock_json_dump.assert_called_once_with({"key": "value"}, mock_open_file(), indent=2) + mock_makedirs.assert_called_once_with(os.path.join("base_directory", output_directory), exist_ok=True) + self.assertEqual(result, expected_output_path) + + @patch("os.makedirs") + @patch("builtins.open", new_callable=mock_open, read_data='{"key": "value"}') + @patch("json.load", return_value={"key": "value"}) + @patch("yaml.dump") + def test_convert_file_json_to_yaml(self, mock_yaml_dump, mock_json_load, mock_open_file, mock_makedirs): + input_filepath = "input.json" + output_directory = "yaml" + input_type = "json" + output_type = "yaml" + expected_output_path = os.path.join("base_directory", output_directory, "input.yaml") + + result = self.converter.convert_file(input_filepath, output_directory, input_type, output_type) + + mock_open_file.assert_any_call(input_filepath, 'r') + mock_json_load.assert_called_once() + mock_open_file.assert_any_call(expected_output_path, 'w') + mock_yaml_dump.assert_called_once_with({"key": 
"value"}, mock_open_file(), allow_unicode=True, default_flow_style=False) + mock_makedirs.assert_called_once_with(os.path.join("base_directory", output_directory), exist_ok=True) + self.assertEqual(result, expected_output_path) + + @patch("os.makedirs") + @patch("builtins.open", new_callable=mock_open, read_data="yaml_content") + @patch("yaml.safe_load", side_effect=Exception("YAML error")) + def test_convert_file_yaml_to_json_error(self, mock_yaml_safe_load, mock_open_file, mock_makedirs): + input_filepath = "input.yaml" + output_directory = "json" + input_type = "yaml" + output_type = "json" + + result = self.converter.convert_file(input_filepath, output_directory, input_type, output_type) + + mock_open_file.assert_any_call(input_filepath, 'r') + mock_yaml_safe_load.assert_called_once() + mock_makedirs.assert_called_once_with(os.path.join("base_directory", output_directory), exist_ok=True) + self.assertIsNone(result) + + @patch("os.makedirs") + @patch("builtins.open", new_callable=mock_open, read_data='{"key": "value"}') + @patch("json.load", side_effect=Exception("JSON error")) + def test_convert_file_json_to_yaml_error(self, mock_json_load, mock_open_file, mock_makedirs): + input_filepath = "input.json" + output_directory = "yaml" + input_type = "json" + output_type = "yaml" + + result = self.converter.convert_file(input_filepath, output_directory, input_type, output_type) + + mock_open_file.assert_any_call(input_filepath, 'r') + mock_json_load.assert_called_once() + mock_makedirs.assert_called_once_with(os.path.join("base_directory", output_directory), exist_ok=True) + self.assertIsNone(result) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_openapi_parser.py b/tests/test_openapi_parser.py new file mode 100644 index 0000000..0d52251 --- /dev/null +++ b/tests/test_openapi_parser.py @@ -0,0 +1,256 @@ +import unittest +from unittest.mock import patch, mock_open +import yaml +from hackingBuddyGPT.usecases.web_api_testing.utils import OpenAPISpecificationParser + +class TestOpenAPISpecificationParser(unittest.TestCase): + def setUp(self): + self.filepath = "dummy_path.yaml" + self.yaml_content = """ + openapi: 3.0.0 + info: + title: Sample API + version: 1.0.0 + servers: + - url: https://api.example.com + - url: https://staging.api.example.com + paths: + /pets: + get: + summary: List all pets + responses: + '200': + description: A paged array of pets + post: + summary: Create a pet + responses: + '200': + description: Pet created + /pets/{petId}: + get: + summary: Info for a specific pet + responses: + '200': + description: Expected response to a valid request + """ + + @patch("builtins.open", new_callable=mock_open, read_data="") + @patch("yaml.safe_load", return_value=yaml.safe_load(""" + openapi: 3.0.0 + info: + title: Sample API + version: 1.0.0 + servers: + - url: https://api.example.com + - url: https://staging.api.example.com + paths: + /pets: + get: + summary: List all pets + responses: + '200': + description: A paged array of pets + post: + summary: Create a pet + responses: + '200': + description: Pet created + /pets/{petId}: + get: + summary: Info for a specific pet + responses: + '200': + description: Expected response to a valid request + """)) + def test_load_yaml(self, mock_yaml_load, mock_open_file): + parser = OpenAPISpecificationParser(self.filepath) + self.assertEqual(parser.api_data['info']['title'], "Sample API") + self.assertEqual(parser.api_data['info']['version'], "1.0.0") + self.assertEqual(len(parser.api_data['servers']), 2) + + 
@patch("builtins.open", new_callable=mock_open, read_data="") + @patch("yaml.safe_load", return_value=yaml.safe_load(""" + openapi: 3.0.0 + info: + title: Sample API + version: 1.0.0 + servers: + - url: https://api.example.com + - url: https://staging.api.example.com + paths: + /pets: + get: + summary: List all pets + responses: + '200': + description: A paged array of pets + post: + summary: Create a pet + responses: + '200': + description: Pet created + /pets/{petId}: + get: + summary: Info for a specific pet + responses: + '200': + description: Expected response to a valid request + """)) + def test_get_servers(self, mock_yaml_load, mock_open_file): + parser = OpenAPISpecificationParser(self.filepath) + servers = parser.get_servers() + self.assertEqual(servers, ["https://api.example.com", "https://staging.api.example.com"]) + + @patch("builtins.open", new_callable=mock_open, read_data="") + @patch("yaml.safe_load", return_value=yaml.safe_load(""" + openapi: 3.0.0 + info: + title: Sample API + version: 1.0.0 + servers: + - url: https://api.example.com + - url: https://staging.api.example.com + paths: + /pets: + get: + summary: List all pets + responses: + '200': + description: A paged array of pets + post: + summary: Create a pet + responses: + '200': + description: Pet created + /pets/{petId}: + get: + summary: Info for a specific pet + responses: + '200': + description: Expected response to a valid request + """)) + def test_get_paths(self, mock_yaml_load, mock_open_file): + parser = OpenAPISpecificationParser(self.filepath) + paths = parser.get_paths() + expected_paths = { + "/pets": { + "get": { + "summary": "List all pets", + "responses": { + "200": { + "description": "A paged array of pets" + } + } + }, + "post": { + "summary": "Create a pet", + "responses": { + "200": { + "description": "Pet created" + } + } + } + }, + "/pets/{petId}": { + "get": { + "summary": "Info for a specific pet", + "responses": { + "200": { + "description": "Expected response to a valid request" + } + } + } + } + } + self.assertEqual(paths, expected_paths) + + @patch("builtins.open", new_callable=mock_open, read_data="") + @patch("yaml.safe_load", return_value=yaml.safe_load(""" + openapi: 3.0.0 + info: + title: Sample API + version: 1.0.0 + servers: + - url: https://api.example.com + - url: https://staging.api.example.com + paths: + /pets: + get: + summary: List all pets + responses: + '200': + description: A paged array of pets + post: + summary: Create a pet + responses: + '200': + description: Pet created + /pets/{petId}: + get: + summary: Info for a specific pet + responses: + '200': + description: Expected response to a valid request + """)) + def test_get_operations(self, mock_yaml_load, mock_open_file): + parser = OpenAPISpecificationParser(self.filepath) + operations = parser.get_operations("/pets") + expected_operations = { + "get": { + "summary": "List all pets", + "responses": { + "200": { + "description": "A paged array of pets" + } + } + }, + "post": { + "summary": "Create a pet", + "responses": { + "200": { + "description": "Pet created" + } + } + } + } + self.assertEqual(operations, expected_operations) + + @patch("builtins.open", new_callable=mock_open, read_data="") + @patch("yaml.safe_load", return_value=yaml.safe_load(""" + openapi: 3.0.0 + info: + title: Sample API + version: 1.0.0 + servers: + - url: https://api.example.com + - url: https://staging.api.example.com + paths: + /pets: + get: + summary: List all pets + responses: + '200': + description: A paged array of pets + post: + 
summary: Create a pet + responses: + '200': + description: Pet created + /pets/{petId}: + get: + summary: Info for a specific pet + responses: + '200': + description: Expected response to a valid request + """)) + def test_print_api_details(self, mock_yaml_load, mock_open_file): + parser = OpenAPISpecificationParser(self.filepath) + with patch('builtins.print') as mocked_print: + parser.print_api_details() + mocked_print.assert_any_call("API Title:", "Sample API") + mocked_print.assert_any_call("API Version:", "1.0.0") + mocked_print.assert_any_call("Servers:", ["https://api.example.com", "https://staging.api.example.com"]) + mocked_print.assert_any_call("\nAvailable Paths and Operations:") + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_prompt_engineer.py b/tests/test_prompt_engineer.py new file mode 100644 index 0000000..f8b9e44 --- /dev/null +++ b/tests/test_prompt_engineer.py @@ -0,0 +1,64 @@ +import unittest +from unittest.mock import MagicMock +from hackingBuddyGPT.usecases.web_api_testing.prompt_engineer import PromptStrategy, PromptEngineer + + +class TestPromptEngineer(unittest.TestCase): + def setUp(self): + self.strategy = PromptStrategy.IN_CONTEXT + self.llm_handler = MagicMock() + self.history = [{"content": "initial_prompt", "role": "system"}] + self.schemas = MagicMock() + self.response_handler = MagicMock() + self.prompt_engineer = PromptEngineer( + self.strategy, self.llm_handler, self.history, self.schemas, self.response_handler + ) + def test_token_count(self): + text = "This is a sample text with several words." + count = self.prompt_engineer.token_count(text) + self.assertEqual(8, count) + def test_check_prompt(self): + self.response_handler.get_response_for_prompt = MagicMock(return_value="shortened_prompt") + prompt = self.prompt_engineer.check_prompt("previous_prompt", + ["step1", "step2", "step3", "step4", "step5", "step6"], max_tokens=5) + self.assertEqual(prompt, "shortened_prompt") + + def test_in_context_learning_no_hint(self): + expected_prompt = "initial_prompt\ninitial_prompt" + actual_prompt = self.prompt_engineer.in_context_learning() + self.assertEqual(expected_prompt, actual_prompt) + + def test_in_context_learning_with_hint(self): + hint = "This is a hint." + expected_prompt = "initial_prompt\ninitial_prompt\nThis is a hint." + actual_prompt = self.prompt_engineer.in_context_learning(hint=hint) + self.assertEqual(expected_prompt, actual_prompt) + + def test_in_context_learning_with_doc_and_hint(self): + hint = "This is another hint." + expected_prompt = "initial_prompt\ninitial_prompt\nThis is another hint." 
+ actual_prompt = self.prompt_engineer.in_context_learning(doc=True, hint=hint) + self.assertEqual(expected_prompt, actual_prompt) + def test_generate_prompt_chain_of_thought(self): + self.prompt_engineer.strategy = PromptStrategy.CHAIN_OF_THOUGHT + self.response_handler.get_response_for_prompt = MagicMock(return_value="response_text") + self.prompt_engineer.evaluate_response = MagicMock(return_value=True) + + prompt_history = self.prompt_engineer.generate_prompt() + + self.assertEqual( 2, len(prompt_history)) + + def test_generate_prompt_tree_of_thought(self): + self.prompt_engineer.strategy = PromptStrategy.TREE_OF_THOUGHT + self.response_handler.get_response_for_prompt = MagicMock(return_value="response_text") + self.prompt_engineer.evaluate_response = MagicMock(return_value=True) + + prompt_history = self.prompt_engineer.generate_prompt() + + self.assertEqual(len(prompt_history), 2) + + + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/test_response_handler.py b/tests/test_response_handler.py new file mode 100644 index 0000000..d80ce55 --- /dev/null +++ b/tests/test_response_handler.py @@ -0,0 +1,117 @@ +import unittest +from unittest.mock import MagicMock, patch +from bs4 import BeautifulSoup +import json +from hackingBuddyGPT.usecases.web_api_testing.utils import ResponseHandler + +class TestResponseHandler(unittest.TestCase): + def setUp(self): + self.llm_handler_mock = MagicMock() + self.response_handler = ResponseHandler(self.llm_handler_mock) + + def test_get_response_for_prompt(self): + prompt = "Test prompt" + response_mock = MagicMock() + response_mock.execute.return_value = "Response text" + self.llm_handler_mock.call_llm.return_value = (response_mock, MagicMock()) + + response_text = self.response_handler.get_response_for_prompt(prompt) + + self.llm_handler_mock.call_llm.assert_called_once_with([{"role": "user", "content": [{"type": "text", "text": prompt}]}]) + self.assertEqual(response_text, "Response text") + + def test_parse_http_status_line_valid(self): + status_line = "HTTP/1.1 200 OK" + result = self.response_handler.parse_http_status_line(status_line) + self.assertEqual(result, "200 OK") + + def test_parse_http_status_line_invalid(self): + status_line = "Invalid status line" + with self.assertRaises(ValueError): + self.response_handler.parse_http_status_line(status_line) + + def test_extract_response_example(self): + html_content = """ + +
+{"example": "test"}
+ {"key": "value"}
+
+
+ """
+ result = self.response_handler.extract_response_example(html_content)
+ self.assertEqual(result, {"key": "value"})
+
+ def test_extract_response_example_invalid(self):
+ html_content = "No code tags"
+ result = self.response_handler.extract_response_example(html_content)
+ self.assertIsNone(result)
+
+ @patch('hackingBuddyGPT.usecases.web_api_testing.utils.ResponseHandler.parse_http_response_to_schema')
+ def test_parse_http_response_to_openapi_example(self, mock_parse_http_response_to_schema):
+ openapi_spec = {
+ "components": {"schemas": {}}
+ }
+ http_response = "HTTP/1.1 200 OK\r\n\r\n{\"id\": 1, \"name\": \"test\"}"
+ path = "/test"
+ method = "GET"
+
+ mock_parse_http_response_to_schema.return_value = ("#/components/schemas/Test", "Test", openapi_spec)
+
+ entry_dict, reference, updated_spec = self.response_handler.parse_http_response_to_openapi_example(openapi_spec, http_response, path, method)
+
+ self.assertEqual(reference, "#/components/schemas/Test")
+ self.assertEqual(updated_spec, openapi_spec)
+ self.assertIn("test", entry_dict)
+
+ def test_extract_description(self):
+ note = MagicMock()
+ note.action.content = "Test description"
+ description = self.response_handler.extract_description(note)
+ self.assertEqual(description, "Test description")
+
+ @patch('hackingBuddyGPT.usecases.web_api_testing.utils.ResponseHandler.extract_keys')
+ def test_parse_http_response_to_schema(self, mock_extract_keys):
+ openapi_spec = {
+ "components": {"schemas": {}}
+ }
+ body_dict = {"id": 1, "name": "test"}
+ path = "/tests"
+
+ mock_extract_keys.side_effect = lambda key, value, properties: {**properties, key: {"type": type(value).__name__, "example": value}}
+
+ reference, object_name, updated_spec = self.response_handler.parse_http_response_to_schema(openapi_spec, body_dict, path)
+
+ self.assertEqual(reference, "#/components/schemas/Test")
+ self.assertEqual(object_name, "Test")
+ self.assertIn("Test", updated_spec["components"]["schemas"])
+ self.assertIn("id", updated_spec["components"]["schemas"]["Test"]["properties"])
+ self.assertIn("name", updated_spec["components"]["schemas"]["Test"]["properties"])
+
+ @patch('builtins.open', new_callable=unittest.mock.mock_open, read_data='yaml_content')
+ def test_read_yaml_to_string(self, mock_open):
+ filepath = "test.yaml"
+ result = self.response_handler.read_yaml_to_string(filepath)
+ mock_open.assert_called_once_with(filepath, 'r')
+ self.assertEqual(result, 'yaml_content')
+
+ def test_read_yaml_to_string_file_not_found(self):
+ filepath = "nonexistent.yaml"
+ result = self.response_handler.read_yaml_to_string(filepath)
+ self.assertIsNone(result)
+
+ def test_extract_endpoints(self):
+ note = "1. GET /test\n"
+ result = self.response_handler.extract_endpoints(note)
+ self.assertEqual( {'/test': ['GET']}, result)
+
+ def test_extract_keys(self):
+ key = "name"
+ value = "test"
+ properties_dict = {}
+ result = self.response_handler.extract_keys(key, value, properties_dict)
+ self.assertIn(key, result)
+ self.assertEqual(result[key], {"type": "str", "example": "test"})
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_web_api_documentation.py b/tests/test_web_api_documentation.py
new file mode 100644
index 0000000..ce70be6
--- /dev/null
+++ b/tests/test_web_api_documentation.py
@@ -0,0 +1,76 @@
+import unittest
+from unittest.mock import MagicMock, patch
+from hackingBuddyGPT.usecases.web_api_testing.simple_openapi_documentation import SimpleWebAPIDocumentationUseCase, \
+ SimpleWebAPIDocumentation
+from hackingBuddyGPT.utils import DbStorage, Console
+
+
+class TestSimpleWebAPIDocumentationTest(unittest.TestCase):
+
+ @patch('hackingBuddyGPT.utils.openai.openai_lib.OpenAILib')
+ def setUp(self, MockOpenAILib):
+ # Mock the OpenAILib instance
+ self.mock_llm = MockOpenAILib.return_value
+ log_db = DbStorage(':memory:')
+ console = Console()
+
+ log_db.init()
+ self.agent = SimpleWebAPIDocumentation(llm=self.mock_llm)
+ self.agent.init()
+ self.simple_api_testing = SimpleWebAPIDocumentationUseCase(
+ agent=self.agent,
+ log_db=log_db,
+ console=console,
+ tag='webApiDocumentation',
+ max_turns=len(self.mock_llm.responses)
+ )
+ self.simple_api_testing.init()
+
+ def test_initial_prompt(self):
+ # Test if the initial prompt is set correctly
+ expected_prompt = "You're tasked with documenting the REST APIs of a website hosted at https://jsonplaceholder.typicode.com. Start with an empty OpenAPI specification.\nMaintain meticulousness in documenting your observations as you traverse the APIs."
+
+ self.assertIn(expected_prompt, self.agent._prompt_history[0]['content'])
+
+ def test_all_flags_found(self):
+ # Mock console.print to suppress output during testing
+ with patch('rich.console.Console.print'):
+ self.agent.all_http_methods_found(1)
+ self.assertFalse(self.agent.all_http_methods_found(1))
+
+ @patch('time.perf_counter', side_effect=[1, 2]) # Mocking perf_counter for consistent timing
+ def test_perform_round(self, mock_perf_counter):
+ # Prepare mock responses
+ mock_response = MagicMock()
+ mock_completion = MagicMock()
+
+        # Set up the completion response with mocked data
+ mock_completion.choices[0].message.content = "Mocked LLM response"
+ mock_completion.choices[0].message.tool_calls = [MagicMock(id="tool_call_1")]
+ mock_completion.usage.prompt_tokens = 10
+ mock_completion.usage.completion_tokens = 20
+
+ # Mock the OpenAI LLM response
+ self.agent.llm.instructor.chat.completions.create_with_completion.return_value = (
+ mock_response, mock_completion)
+
+ # Mock the tool execution result
+ mock_response.execute.return_value = "HTTP/1.1 200 OK"
+
+ # Perform the round
+ result = self.agent.perform_round(1)
+
+ # Assertions
+ self.assertFalse(result)
+
+ # Check if the LLM was called with the correct parameters
+ mock_create_with_completion = self.agent.llm.instructor.chat.completions.create_with_completion
+
+        # The mocked LLM should have been called exactly twice during this round
+        self.assertEqual(2, mock_create_with_completion.call_count)
+
+ # Check if the prompt history was updated correctly
+        self.assertEqual(5, len(self.agent._prompt_history))  # Initial message + 2 LLM responses + 2 tool messages
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tests/test_web_api_testing.py b/tests/test_web_api_testing.py
new file mode 100644
index 0000000..aa4d5da
--- /dev/null
+++ b/tests/test_web_api_testing.py
@@ -0,0 +1,74 @@
+import unittest
+from unittest.mock import MagicMock, patch
+from hackingBuddyGPT.usecases import SimpleWebAPITesting
+from hackingBuddyGPT.usecases.web_api_testing.simple_web_api_testing import SimpleWebAPITestingUseCase
+from hackingBuddyGPT.utils import DbStorage, Console
+
+
+class TestSimpleWebAPITestingTest(unittest.TestCase):
+
+ @patch('hackingBuddyGPT.utils.openai.openai_lib.OpenAILib')
+ def setUp(self, MockOpenAILib):
+ # Mock the OpenAILib instance
+ self.mock_llm = MockOpenAILib.return_value
+ log_db = DbStorage(':memory:')
+ console = Console()
+
+ log_db.init()
+ self.agent = SimpleWebAPITesting(llm=self.mock_llm)
+ self.agent.init()
+ self.simple_api_testing = SimpleWebAPITestingUseCase(
+ agent=self.agent,
+ log_db=log_db,
+ console=console,
+ tag='integration_test_linuxprivesc',
+ max_turns=len(self.mock_llm.responses)
+ )
+ self.simple_api_testing.init()
+
+ def test_initial_prompt(self):
+ # Test if the initial prompt is set correctly
+ self.assertIn("You're a seasoned API tester tasked with evaluating the REST APIs of a website hosted at https://jsonplaceholder.typicode.com. Your primary objective is to thoroughly explore the APIs, understanding their endpoints, parameters, and responses, and then devise and execute potential exploits. Be diligent in documenting your findings as you navigate through the APIs. Avoid resorting to brute-force tactics. All necessary information can be obtained from the API endpoints themselves. Remember, if you encounter an HTTP method (A string that represents an HTTP method (e.g., 'GET', 'POST', etc.).), promptly submit it as it is of utmost importance.", self.agent._prompt_history[0]['content'])
+
+ def test_all_flags_found(self):
+ # Mock console.print to suppress output during testing
+ with patch('rich.console.Console.print'):
+ self.agent.all_http_methods_found()
+ self.assertFalse(self.agent.all_http_methods_found())
+
+ @patch('time.perf_counter', side_effect=[1, 2]) # Mocking perf_counter for consistent timing
+ def test_perform_round(self, mock_perf_counter):
+ # Prepare mock responses
+ mock_response = MagicMock()
+ mock_completion = MagicMock()
+
+        # Set up the completion response with mocked data
+ mock_completion.choices[0].message.content = "Mocked LLM response"
+ mock_completion.choices[0].message.tool_calls = [MagicMock(id="tool_call_1")]
+ mock_completion.usage.prompt_tokens = 10
+ mock_completion.usage.completion_tokens = 20
+
+ # Mock the OpenAI LLM response
+ self.agent.llm.instructor.chat.completions.create_with_completion.return_value = (
+ mock_response, mock_completion)
+
+ # Mock the tool execution result
+ mock_response.execute.return_value = "HTTP/1.1 200 OK"
+
+ # Perform the round
+ result = self.agent.perform_round(1)
+
+ # Assertions
+ self.assertFalse(result) # No flags found in this round
+
+ # Check if the LLM was called with the correct parameters
+ mock_create_with_completion = self.agent.llm.instructor.chat.completions.create_with_completion
+
+        # The mocked LLM should have been called exactly twice during this round
+        self.assertEqual(2, mock_create_with_completion.call_count)
+
+ # Check if the prompt history was updated correctly
+        self.assertEqual(5, len(self.agent._prompt_history))  # Initial message + 2 LLM responses + 2 tool messages
+
+if __name__ == '__main__':
+ unittest.main()