Skip to content

Commit

Permalink
Bug fix: Change make_service_call_with_data to accept a list of entit…
Browse files Browse the repository at this point in the history
…y IDs, modify tools that call this function to always pass in a list
  • Loading branch information
hemanthpai committed Sep 13, 2024
1 parent 4a8c580 commit f13c35d
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions custom_components/ai_assistant/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,30 +278,30 @@ async def make_service_call_with_action(entity_id: str, action: str, tool_name:
return str(tool_call_result)


async def make_service_call_with_data(entity_id: str, service: str, data: dict, tool_name: str):
async def make_service_call_with_data(entity_ids: list[str], service: str, data: dict, tool_name: str):
"""Make a service call to Home Assistant.
Args:
entity_id: The entity ID to call the service on.
entity_ids: The entity IDs to call the service on.
service: The service to call.
data: The data to pass to the service.
tool_name: The name of the tool making the service call.
"""

domain_entity_map = {}
service_call_results: list[HomeAssistantServiceResult] = []
tool_call_result = ToolCallResult()

if "." not in entity_id:
tool_call_result.add_missing_domain_entity_id([entity_id])
else:
domain = entity_id.split(".")[0]
validate_entity_ids(entity_ids, domain_entity_map,
tool_call_result, tool_name)

if domain not in SUPPORTED_DOMAINS[tool_name]:
tool_call_result.add_domain_not_supported_entity_id([entity_id])
else:
result = await HomeAssistantService.async_call_service(
[entity_id], domain, service, data)
process_service_call_results([result], tool_call_result)
for domain, ids in domain_entity_map.items():
result = await HomeAssistantService.async_call_service(
ids, domain, service, data)
service_call_results.append(result)

process_service_call_results(service_call_results, tool_call_result)

return str(tool_call_result)

Expand Down Expand Up @@ -720,7 +720,7 @@ async def hass_set_temperature(entity_id: str, temperature: float):
LOGGER.debug(
f"{entity_id} is in auto mode. Setting the target high temperature and retaining the target low temperature.")

result = await make_service_call_with_data(entity_id, "set_temperature", {
result = await make_service_call_with_data([entity_id], "set_temperature", {
"target_temp_low": thermostat_attributes.target_temperature_low,
"target_temp_high": temperature
}, "hass_set_temperature")
Expand All @@ -731,7 +731,7 @@ async def hass_set_temperature(entity_id: str, temperature: float):
LOGGER.debug(
f"{entity_id} is in heat or cool mode. Setting the temperature.")

result = await make_service_call_with_data(entity_id, "set_temperature", {
result = await make_service_call_with_data([entity_id], "set_temperature", {
"temperature": temperature
}, "hass_set_temperature")

Expand Down Expand Up @@ -759,7 +759,7 @@ async def hass_set_humidity(entity_id: str, humidity: float):

LOGGER.debug(f"Setting humidity of entity {entity_id} to {humidity}")

result = await make_service_call_with_data(entity_id, "set_humidity", {
result = await make_service_call_with_data([entity_id], "set_humidity", {
"humidity": humidity
}, "hass_set_humidity")

Expand All @@ -783,7 +783,7 @@ async def hass_set_fan_mode(entity_id: str, fan_mode: str):

LOGGER.debug(f"Setting fan mode of entity {entity_id} to {fan_mode.value}")

result = await make_service_call_with_data(entity_id, "set_fan_mode", {
result = await make_service_call_with_data([entity_id], "set_fan_mode", {
"fan_mode": fan_mode
}, "hass_set_fan_mode")

Expand All @@ -808,7 +808,7 @@ async def hass_set_hvac_mode(entity_id: str, hvac_mode: str):
LOGGER.debug(f"Setting HVAC mode of entity {
entity_id} to {hvac_mode.value}")

result = await make_service_call_with_data(entity_id, "set_hvac_mode", {
result = await make_service_call_with_data([entity_id], "set_hvac_mode", {
"hvac_mode": hvac_mode
}, "hass_set_hvac_mode")

Expand All @@ -833,7 +833,7 @@ async def hass_set_preset_mode(entity_id: str, preset_mode: str):
LOGGER.debug(f"Setting preset mode of entity {
entity_id} to {preset_mode.value}")

result = await make_service_call_with_data(entity_id, "set_preset_mode", {
result = await make_service_call_with_data([entity_id], "set_preset_mode", {
"preset_mode": preset_mode
}, "hass_set_preset_mode")

Expand Down

0 comments on commit f13c35d

Please sign in to comment.