diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fa45e7..8a27457 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Add get morphoelectric (me) model tool - BlueNaaS simulation tool. +- BlueNaaS tool test. ## [0.1.1] - 26.09.2024 diff --git a/src/neuroagent/tools/bluenaas_tool.py b/src/neuroagent/tools/bluenaas_tool.py index 1c7cc6f..4fb7bf3 100644 --- a/src/neuroagent/tools/bluenaas_tool.py +++ b/src/neuroagent/tools/bluenaas_tool.py @@ -107,6 +107,52 @@ async def _arun( ) -> BaseToolOutput: """Run the BlueNaaS tool.""" logger.info("Running BlueNaaS tool") + json_api = self.create_json_api( + current_injection__inject_to=current_injection__inject_to, + current_injection__stimulus__stimulus_type=current_injection__stimulus__stimulus_type, + current_injection__stimulus__stimulus_protocol=current_injection__stimulus__stimulus_protocol, + current_injection__stimulus__amplitudes=current_injection__stimulus__amplitudes, + record_from=record_from, + conditions__celsius=conditions__celsius, + conditions__vinit=conditions__vinit, + conditions__hypamp=conditions__hypamp, + conditions__max_time=conditions__max_time, + conditions__seed=conditions__seed, + ) + + try: + _ = await self.metadata["httpx_client"].post( + url=self.metadata["url"], + params={"model_id": me_model_id}, + headers={"Authorization": f'Bearer {self.metadata["token"]}'}, + json=json_api, + timeout=5.0, + ) + + return BlueNaaSOutput(status="success") + + except Exception as e: + raise ToolException(str(e), self.name) + + @staticmethod + def create_json_api( + current_injection__inject_to: str = "soma[0]", + current_injection__stimulus__stimulus_type: Literal[ + "current_clamp", "voltage_clamp", "conductance" + ] = "current_clamp", + current_injection__stimulus__stimulus_protocol: Literal[ + "ap_waveform", "idrest", "iv", "fire_pattern" + ] = "ap_waveform", + current_injection__stimulus__amplitudes: list[float] | None = None, + record_from: list[RecordingLocation] | None = None, + conditions__celsius: int = 34, + conditions__vinit: int = -73, + conditions__hypamp: int = 0, + conditions__max_time: int = 100, + conditions__time_step: float = 0.05, + conditions__seed: int = 100, + ) -> dict[str, Any]: + """Based on the simulation config, create a valid JSON for the API.""" if not current_injection__stimulus__amplitudes: current_injection__stimulus__amplitudes = [0.1] if not record_from: @@ -132,16 +178,4 @@ async def _arun( "type": "single-neuron-simulation", "simulationDuration": conditions__max_time, } - try: - _ = await self.metadata["httpx_client"].post( - url=self.metadata["url"], - params={"model_id": me_model_id}, - headers={"Authorization": f'Bearer {self.metadata["token"]}'}, - json=json_api, - timeout=5.0, - ) - - return BlueNaaSOutput(status="success") - - except Exception as e: - raise ToolException(str(e), self.name) + return json_api diff --git a/tests/tools/test_bluenaas_tool.py b/tests/tools/test_bluenaas_tool.py new file mode 100644 index 0000000..0d610f9 --- /dev/null +++ b/tests/tools/test_bluenaas_tool.py @@ -0,0 +1,98 @@ +"""Tests BlueNaaS tool.""" + +import httpx +import pytest +from langchain_core.tools import ToolException + +from neuroagent.tools import BlueNaaSTool +from neuroagent.tools.bluenaas_tool import BlueNaaSOutput, RecordingLocation + + +@pytest.mark.asyncio +async def test_arun(httpx_mock): + me_model_id = "great_id" + url = "http://fake_url" + + httpx_mock.add_response( + url=url + f"?model_id={me_model_id}", + json={"t": [0.05, 0.1, 0.15, 0.2], "v": [-1.14, -0.67, -1.78]}, + ) + tool = BlueNaaSTool( + metadata={ + "url": url, + "httpx_client": httpx.AsyncClient(), + "token": "fake_token", + } + ) + response = await tool._arun( + me_model_id=me_model_id, + conditions__celsius=7, + current_injection__inject_to="axon[1]", + ) + assert isinstance(response, BlueNaaSOutput) + + +@pytest.mark.asyncio +async def test_arun_errors(httpx_mock, brain_region_json_path, tmp_path): + url = "http://fake_url" + httpx_mock.add_exception(httpx.ReadTimeout("Unable to read within timeout")) + + tool = BlueNaaSTool( + metadata={ + "url": url, + "httpx_client": httpx.AsyncClient(), + "token": "fake_token", + } + ) + with pytest.raises(ToolException) as tool_exception: + _ = await tool._arun( + me_model_id="great_id", + ) + + assert tool_exception.value.args[0] == "Unable to read within timeout" + + +def test_create_json_api(): + url = "http://fake_url" + + tool = BlueNaaSTool( + metadata={ + "url": url, + "httpx_client": httpx.AsyncClient(), + "token": "fake_token", + } + ) + + json_api = tool.create_json_api( + conditions__vinit=-3, + current_injection__stimulus__stimulus_type="conductance", + record_from=[ + RecordingLocation(), + RecordingLocation(section="axon[78]", offset=0.1), + ], + current_injection__stimulus__amplitudes=[0.1, 0.5], + ) + assert json_api == { + "currentInjection": { + "injectTo": "soma[0]", + "stimulus": { + "stimulusType": "conductance", + "stimulusProtocol": "ap_waveform", + "amplitudes": [0.1, 0.5], + }, + }, + "recordFrom": [ + {"section": "soma[0]", "offset": 0.5}, + {"section": "axon[78]", "offset": 0.1}, + ], + "conditions": { + "celsius": 34, + "vinit": -3, + "hypamp": 0, + "max_time": 100, + "time_step": 0.05, + "seed": 100, + }, + "type": "single-neuron-simulation", + "simulationDuration": 100, + }