Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bluenaas tool test #18

Merged
merged 3 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Add get morphoelectric (me) model tool
- BlueNaaS simulation tool.
- BlueNaaS tool test.

## [0.1.1] - 26.09.2024

Expand Down
60 changes: 47 additions & 13 deletions src/neuroagent/tools/bluenaas_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,52 @@ async def _arun(
) -> BaseToolOutput:
"""Run the BlueNaaS tool."""
logger.info("Running BlueNaaS tool")
json_api = self.create_json_api(
current_injection__inject_to=current_injection__inject_to,
current_injection__stimulus__stimulus_type=current_injection__stimulus__stimulus_type,
current_injection__stimulus__stimulus_protocol=current_injection__stimulus__stimulus_protocol,
current_injection__stimulus__amplitudes=current_injection__stimulus__amplitudes,
record_from=record_from,
conditions__celsius=conditions__celsius,
conditions__vinit=conditions__vinit,
conditions__hypamp=conditions__hypamp,
conditions__max_time=conditions__max_time,
conditions__seed=conditions__seed,
)

try:
_ = await self.metadata["httpx_client"].post(
url=self.metadata["url"],
params={"model_id": me_model_id},
headers={"Authorization": f'Bearer {self.metadata["token"]}'},
json=json_api,
timeout=5.0,
)

return BlueNaaSOutput(status="success")

except Exception as e:
raise ToolException(str(e), self.name)

@staticmethod
def create_json_api(
current_injection__inject_to: str = "soma[0]",
current_injection__stimulus__stimulus_type: Literal[
"current_clamp", "voltage_clamp", "conductance"
] = "current_clamp",
current_injection__stimulus__stimulus_protocol: Literal[
"ap_waveform", "idrest", "iv", "fire_pattern"
] = "ap_waveform",
current_injection__stimulus__amplitudes: list[float] | None = None,
record_from: list[RecordingLocation] | None = None,
conditions__celsius: int = 34,
conditions__vinit: int = -73,
conditions__hypamp: int = 0,
conditions__max_time: int = 100,
conditions__time_step: float = 0.05,
conditions__seed: int = 100,
) -> dict[str, Any]:
"""Based on the simulation config, create a valid JSON for the API."""
if not current_injection__stimulus__amplitudes:
current_injection__stimulus__amplitudes = [0.1]
if not record_from:
Expand All @@ -132,16 +178,4 @@ async def _arun(
"type": "single-neuron-simulation",
"simulationDuration": conditions__max_time,
}
try:
_ = await self.metadata["httpx_client"].post(
url=self.metadata["url"],
params={"model_id": me_model_id},
headers={"Authorization": f'Bearer {self.metadata["token"]}'},
json=json_api,
timeout=5.0,
)

return BlueNaaSOutput(status="success")

except Exception as e:
raise ToolException(str(e), self.name)
return json_api
98 changes: 98 additions & 0 deletions tests/tools/test_bluenaas_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Tests BlueNaaS tool."""

import httpx
import pytest
from langchain_core.tools import ToolException

from neuroagent.tools import BlueNaaSTool
from neuroagent.tools.bluenaas_tool import BlueNaaSOutput, RecordingLocation


@pytest.mark.asyncio
async def test_arun(httpx_mock):
me_model_id = "great_id"
url = "http://fake_url"

httpx_mock.add_response(
url=url + f"?model_id={me_model_id}",
json={"t": [0.05, 0.1, 0.15, 0.2], "v": [-1.14, -0.67, -1.78]},
)
tool = BlueNaaSTool(
metadata={
"url": url,
"httpx_client": httpx.AsyncClient(),
"token": "fake_token",
}
)
response = await tool._arun(
me_model_id=me_model_id,
conditions__celsius=7,
current_injection__inject_to="axon[1]",
)
assert isinstance(response, BlueNaaSOutput)


@pytest.mark.asyncio
async def test_arun_errors(httpx_mock, brain_region_json_path, tmp_path):
url = "http://fake_url"
httpx_mock.add_exception(httpx.ReadTimeout("Unable to read within timeout"))

tool = BlueNaaSTool(
metadata={
"url": url,
"httpx_client": httpx.AsyncClient(),
"token": "fake_token",
}
)
with pytest.raises(ToolException) as tool_exception:
_ = await tool._arun(
me_model_id="great_id",
)

assert tool_exception.value.args[0] == "Unable to read within timeout"


def test_create_json_api():
url = "http://fake_url"

tool = BlueNaaSTool(
metadata={
"url": url,
"httpx_client": httpx.AsyncClient(),
"token": "fake_token",
}
)

json_api = tool.create_json_api(
conditions__vinit=-3,
current_injection__stimulus__stimulus_type="conductance",
record_from=[
RecordingLocation(),
RecordingLocation(section="axon[78]", offset=0.1),
],
current_injection__stimulus__amplitudes=[0.1, 0.5],
)
assert json_api == {
"currentInjection": {
"injectTo": "soma[0]",
"stimulus": {
"stimulusType": "conductance",
"stimulusProtocol": "ap_waveform",
"amplitudes": [0.1, 0.5],
},
},
"recordFrom": [
{"section": "soma[0]", "offset": 0.5},
{"section": "axon[78]", "offset": 0.1},
],
"conditions": {
"celsius": 34,
"vinit": -3,
"hypamp": 0,
"max_time": 100,
"time_step": 0.05,
"seed": 100,
},
"type": "single-neuron-simulation",
"simulationDuration": 100,
}
Loading