Skip to content

Commit

Permalink
done app
Browse files Browse the repository at this point in the history
  • Loading branch information
sasax7 committed Dec 23, 2024
1 parent edb44cc commit 132bb33
Show file tree
Hide file tree
Showing 12 changed files with 381 additions and 63 deletions.
10 changes: 9 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,10 @@

__pycache__
*.pyc
.env
.vscode
.git
venv
keys.txt
__pycache__/
.venv

Binary file modified __pycache__/get_trend_data.cpython-311.pyc
Binary file not shown.
2 changes: 1 addition & 1 deletion api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class CorrelationRequest(BaseModel):

class CorrelateChildrenRequest(BaseModel):
asset_id: int
lag: Optional[Dict[LagUnit, int]] = None
lags: Optional[List[Dict[LagUnit, int]]] = None
start_time: Optional[datetime] = None
end_time: Optional[datetime] = None

Expand Down
89 changes: 53 additions & 36 deletions api/openapi.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,42 @@
from fastapi import FastAPI, HTTPException

from datetime import datetime
from api.models import (
CorrelationRequest,
CorrelateChildrenRequest,
)
import pytz
import yaml
from api.models import CorrelationRequest, CorrelateChildrenRequest, AssetAttribute
from api.correlation import get_data, compute_correlation
from api.plot_correlation import (
create_best_correlation_heatmap,
in_depth_plot_scatter,
plot_lag_correlations,
)
import pytz
from get_trend_data import get_all_asset_children


# Create the FastAPI app instance
app = FastAPI(
title="Correlation App API",
description="API to manage and query correlations between assets.",
version="1.0.0",
openapi_url="/v1/version/openapi.json",
openapi_version="3.1.0",
)

# Load custom OpenAPI schema
with open("openapi.yaml", "r") as f:
openapi_yaml = yaml.safe_load(f)


def custom_openapi():
    """
    Return the hand-written OpenAPI document instead of a generated one.

    The schema parsed from ``openapi.yaml`` at import time is stored on the
    app as its OpenAPI schema and then handed back to the caller.
    """
    schema = openapi_yaml
    app.openapi_schema = schema
    return schema

app = FastAPI()

app.openapi = custom_openapi

@app.post("/correlate")

# Define endpoints
@app.post("/v1/correlate")
def correlate_assets(request: CorrelationRequest):
end_time = request.end_time or datetime.now()
dataframes = get_data(request)
Expand All @@ -30,19 +51,30 @@ def correlate_assets(request: CorrelationRequest):
}


@app.post("/correlate-children")
@app.post("/v1/correlate-children")
def correlate_asset_children(request: CorrelateChildrenRequest):
end_time = request.end_time or datetime.now()
child_asset_ids = get_all_asset_children(request.asset_id)
print(f"Found {len(child_asset_ids)} children for asset {request.asset_id}")
correlation_request = CorrelationRequest(
assets=child_asset_ids,
lags=request.lags,
start_time=request.start_time,
end_time=request.end_time,
)

correlations = correlate_assets(correlation_request)

return {
"asset_id": request.asset_id,
"lag": request.lag,
"assets": child_asset_ids,
"lags": request.lags,
"start_time": request.start_time,
"end_time": end_time,
"correlation": "To be implemented",
"correlation": correlations,
}


@app.post("/in-depth-correlation")
@app.post("/v1/in-depth-correlation")
def in_depth_correlation(request: CorrelationRequest):
"""
1) Fetch data for exactly two assets/attributes.
Expand All @@ -61,29 +93,13 @@ def in_depth_correlation(request: CorrelationRequest):
detail="Could not retrieve data for both assets/attributes. Check logs.",
)

# 2) Compute correlation (including any lags) for these two DataFrameInfo
correlations = compute_correlation(df_infos, request)
# 'correlations' is a dict that includes "lag_details" for the single pair, e.g.:
# {
# "867_energy_costs and 867_Wirkleistung": {
# "best_correlation": 0.16,
# "best_lag": 3,
# "best_lag_unit": "hours",
# "lag_details": [
# {"lag_unit": "hours", "lag_step": -10, "correlation": 0.05}, ...
# ]
# }
# }

# 3) Plot the lag correlation lines for that single pair
# This will create line plots in "lag_plots/" (by default) for each lag unit
plot_lag_correlations(correlations, output_dir="lag_plots")

# 4) Also create a scatter plot of the raw data to see direct x-y relationship
# (We re-use 'df_infos' -> 2 DataFrameInfo objects)

lag_plots = plot_lag_correlations(correlations, output_dir="/tmp/lag_plots")

try:
scatter_result = in_depth_plot_scatter(
df_infos, output_file="in_depth_scatter.png"
df_infos, output_file="/tmp/in_depth_scatter.png"
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
Expand All @@ -92,10 +108,11 @@ def in_depth_correlation(request: CorrelationRequest):

return {
"assets": request.assets,
"lags": request.lags,
"start_time": request.start_time,
"end_time": end_time,
"best_correlation": scatter_result["correlation"], # correlation from scatter
"plot_base64_png": scatter_result["plot_base64_png"], # scatter in Base64
"lag_correlation_plots": "Saved in lag_plots/ directory",
"detailed_correlations": correlations, # full correlation details with lags
"correlation": correlations,
"scatter_plot": scatter_result["plot_base64_png"],
"columns": scatter_result["columns"],
"lag_plots": lag_plots,
}
47 changes: 23 additions & 24 deletions api/plot_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import base64
import io
import os


def create_best_correlation_heatmap(correlations_dict, output_file="heatmap.png"):
def create_best_correlation_heatmap(correlations_dict, output_file="/tmp/heatmap.png"):
"""
Creates a heatmap from the 'best_correlation' values in correlations_dict.
In addition to correlation, each cell is annotated with the best_lag and lag_unit.
Saves the resulting figure to 'output_file' instead of showing it.
The input 'correlations_dict' is expected to look like:
Expand All @@ -28,7 +26,6 @@ def create_best_correlation_heatmap(correlations_dict, output_file="heatmap.png"
Only the best_correlation field is used for coloring the heatmap
(pairs with None or NaN are left blank).
The annotation in each cell will display correlation + lag info if available.
"""
# 1) Gather all columns (col1, col2) by splitting each key on " and "
all_cols = set()
Expand All @@ -48,51 +45,39 @@ def create_best_correlation_heatmap(correlations_dict, output_file="heatmap.png"
# Convert the set of columns to a sorted list (for consistent ordering)
all_cols = sorted(all_cols)

# 2) Initialize an NxN DataFrame for numeric correlation and a second for text annotations
# 2) Initialize an NxN DataFrame for numeric correlation
df_matrix = pd.DataFrame(np.nan, index=all_cols, columns=all_cols)
df_annot = pd.DataFrame("", index=all_cols, columns=all_cols)

# 3) Fill in the best correlations and annotations
# 3) Fill in the best correlations
for col1, col2, best_corr, best_lag, best_lag_unit in data_for_matrix:
if best_corr is not None and not np.isnan(best_corr):
df_matrix.loc[col1, col2] = best_corr
df_matrix.loc[col2, col1] = best_corr

corr_str = f"{best_corr:.2f}"
if best_lag_unit and best_lag != 0:
annotation = f"{corr_str}\n({best_lag} {best_lag_unit})"
else:
annotation = corr_str

df_annot.loc[col1, col2] = annotation
df_annot.loc[col2, col1] = annotation

# (Optional) Set diagonal to 1.0 correlation, with a simple annotation like "1.00"
# (Optional) Set diagonal to 1.0 correlation
for col in all_cols:
df_matrix.loc[col, col] = 1.0
df_annot.loc[col, col] = "1.00"

# 4) Plot the heatmap (no plt.show())
plt.figure(figsize=(10, 8))
sns.heatmap(
df_matrix,
annot=df_annot,
fmt="",
annot=False, # Disable text annotations
cmap="coolwarm",
square=True,
center=0.0,
vmin=-1,
vmax=1,
)
plt.title("Best Correlation Heatmap (with Lag Info)")
plt.title("Best Correlation Heatmap")
plt.tight_layout()

# Save the figure to disk
plt.savefig(output_file)
plt.close() # Close the figure to free resources


def in_depth_plot_scatter(df_info_list, output_file="in_depth_scatter.png"):
def in_depth_plot_scatter(df_info_list, output_file="/tmp/in_depth_scatter.png"):
"""
Accepts a list of TWO DataFrameInfo objects (each with one column),
merges them, computes correlation, and returns:
Expand Down Expand Up @@ -161,7 +146,7 @@ def in_depth_plot_scatter(df_info_list, output_file="in_depth_scatter.png"):
}


def plot_lag_correlations(correlations_dict, output_dir="lag_plots"):
def plot_lag_correlations(correlations_dict, output_dir="/tmp/lag_plots"):
"""
For each pair of columns in 'correlations_dict', we look at 'lag_details'
and group them by lag_unit (e.g., hours, days). Then we make a separate plot
Expand All @@ -174,10 +159,13 @@ def plot_lag_correlations(correlations_dict, output_dir="lag_plots"):
"""
import os
import matplotlib.pyplot as plt
import io
import base64

os.makedirs(output_dir, exist_ok=True)

seen_pairs = set() # Track pairs we've already plotted
plot_images = {} # Dictionary to store base64-encoded images

for pair_name, info in correlations_dict.items():
# Split on " and " to get the two column names.
Expand Down Expand Up @@ -231,4 +219,15 @@ def plot_lag_correlations(correlations_dict, output_dir="lag_plots"):

plt.tight_layout()
plt.savefig(filepath)
plt.close(fig)

# Save the figure to a buffer and encode it in base64
buf = io.BytesIO()
fig.savefig(buf, format="png")
buf.seek(0)
img_base64 = base64.b64encode(buf.read()).decode("utf-8")
plt.close(fig) # Free resources

# Store the base64 image in the dictionary
plot_images[f"{pair_label}_{safe_unit}"] = img_base64

return plot_images
20 changes: 20 additions & 0 deletions dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
FROM python:3.11.9

WORKDIR /app

# System packages needed to build/install the Python dependencies.
# The package index is removed in the same layer so it does not end up
# in the final image.
RUN apt-get update && apt-get install -y \
    git \
    libpq-dev \
    gcc \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Copy only the requirements first so the dependency layer stays cached
# until requirements.txt itself changes.
COPY requirements.txt .

RUN pip install --no-cache-dir --upgrade pip \
    && pip install --no-cache-dir -r requirements.txt

COPY . .

# Default value of API_SERVER_PORT used by main.py.
EXPOSE 3000

CMD ["python", "main.py"]
20 changes: 20 additions & 0 deletions get_trend_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import eliona.api_client2
from eliona.api_client2.rest import ApiException
from eliona.api_client2.api.data_api import DataApi
from eliona.api_client2.api.assets_api import AssetsApi
import os
import logging
import pytz
from datetime import datetime
from api.models import AssetAttribute

# Initialize the logger
logger = logging.getLogger(__name__)
Expand All @@ -18,6 +20,24 @@
# Create an instance of the API client
api_client = eliona.api_client2.ApiClient(configuration)
data_api = DataApi(api_client)
assets_api = AssetsApi(api_client)


def get_all_asset_children(asset_id):
    """
    Collect the given asset together with every asset located beneath it.

    All assets are fetched with ``assets_api.get_assets()``; an asset counts
    as a child when ``asset_id`` appears in its ``locational_asset_id_path``.

    Args:
        asset_id: ID of the parent asset.

    Returns:
        A list of ``AssetAttribute`` objects: the parent first, followed by
        each matching asset exactly once. If the API call fails, the list
        contains only the parent.
    """
    # Keep the try body limited to the call that can raise ApiException.
    try:
        logger.info(f"Fetching all assets to find children for asset {asset_id}")
        assets = assets_api.get_assets()
    except ApiException as e:
        logger.error(f"Exception when calling AssetsApi->get_assets: {e}")
        return [AssetAttribute(asset_id=asset_id)]

    child_ids = [asset_id]  # Start with the parent asset_id
    seen = {asset_id}  # Guards against listing the parent or any asset twice

    for asset in assets:
        # The path may be unset on an asset; treat that as "no ancestors"
        # instead of failing on a membership test against None.
        path = asset.locational_asset_id_path or ()
        if asset_id in path and asset.id not in seen:
            seen.add(asset.id)
            child_ids.append(asset.id)

    logger.info(f"Found {len(child_ids) - 1} children for asset {asset_id}")
    return [AssetAttribute(asset_id=child_id) for child_id in child_ids]


def get_trend_data(asset_id, start_date, end_date):
Expand Down
1 change: 1 addition & 0 deletions icon
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ def start_api():
port = int(os.getenv("API_SERVER_PORT", 3000))
uvicorn.run("api.openapi:app", host="0.0.0.0", port=port)

#Initialize()

# Initialize()
start_api()
19 changes: 19 additions & 0 deletions metadata.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{
"name": "Correlation",
"elionaMinVersion": "v11.0.0",
"displayName": {
"en": "correlation app",
"de": "Korrelations-App"
},
"description": {
"en": "This app provides a correlation analysis.",
"de": "Diese App bietet eine Korrelationsanalyse."
},
"dashboardTemplateNames": [
"correlation"
],
"apiUrl": "v1",
"apiSpecificationPath": "/version/openapi.json",
"documentationUrl": "https://doc.eliona.io/eliona/referenzen/app-entwicklung",
"useEnvironment": [ "CONNECTION_STRING", "API_ENDPOINT", "API_TOKEN" ]
}
Loading

0 comments on commit 132bb33

Please sign in to comment.