Skip to content

Commit

Permalink
demo changes
Browse files Browse the repository at this point in the history
  • Loading branch information
XianBW committed Aug 2, 2024
1 parent b4d846d commit e2b8211
Showing 1 changed file with 60 additions and 49 deletions.
109 changes: 60 additions & 49 deletions rdagent/log/ui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,21 +195,28 @@ def evolving_feedback_window(wsf: FactorSingleFeedback | ModelCoderFeedback):
st.markdown(wsf.value_feedback)


def display_hypotheses(hypotheses: dict[int, Hypothesis], decisions: dict[int, bool]):
shd = {k: v.__dict__ for k, v in hypotheses.items()}
def display_hypotheses(hypotheses: dict[int, Hypothesis], decisions: dict[int, bool], success_only: bool = False):
if success_only:
shd = {k: v.__dict__ for k, v in hypotheses.items() if decisions[k]}
else:
shd = {k: v.__dict__ for k, v in hypotheses.items()}
df = pd.DataFrame(shd).T
if "reason" in df.columns:
df.drop(["reason"], axis=1, inplace=True)
df.columns = df.columns.map(lambda x: x.replace("_", " ").capitalize())

def highlight_rows(row):
def style_rows(row):
if decisions[row.name]:
return ['color: green; font-weight: bold;'] * len(row)

def background_color_columns(col):
if col.name == 'hypothesis':
return ['background-color: lightgrey'] * len(col)
return ['color: green;'] * len(row)
return [''] * len(row)

st.dataframe(df.style.apply(highlight_rows, axis=1).apply(background_color_columns, axis=0))
def style_columns(col):
if col.name != 'Hypothesis':
return ['font-style: italic;'] * len(col)
return ['font-weight: bold;'] * len(col)

st.markdown(df.style.apply(highlight_rows, axis=1).apply(background_color_columns, axis=0).to_html(), unsafe_allow_html=True)
# st.dataframe(df.style.apply(style_rows, axis=1).apply(style_columns, axis=0))
st.markdown(df.style.apply(style_rows, axis=1).apply(style_columns, axis=0).to_html(), unsafe_allow_html=True)


def metrics_window(df: pd.DataFrame, R: int, C: int, *, height: int = 300, colors: list[str] = None):
Expand All @@ -219,7 +226,9 @@ def hypothesis_hover_text(h: Hypothesis, d: bool = False):
text = h.hypothesis
lines = textwrap.wrap(text, width=60)
return f"<span style='color: {color};'>{'<br>'.join(lines)}</span>"
hover_texts = [hypothesis_hover_text(state.hypotheses[int(i[6:])], state.h_decisions[int(i[6:])]) for i in df.index]
hover_texts = [hypothesis_hover_text(state.hypotheses[int(i[6:])], state.h_decisions[int(i[6:])]) for i in df.index if i != "alpha158"]
if state.alpha158_metrics is not None:
hover_texts = ["Baseline: alpha158"] + hover_texts
for ci, col in enumerate(df.columns):
row = ci // C + 1
col_num = ci % C + 1
Expand Down Expand Up @@ -256,52 +265,41 @@ def summary_window():
with st.container():
# TODO: not fixed height
with st.container():
ac,bc,cc = st.columns([2,1,2], vertical_alignment="center")
with ac:
st.subheader("Hypotheses🏅", anchor="_hypotheses")
bc,cc = st.columns([2,2], vertical_alignment="center")
with bc:
st.subheader("Metrics📈", anchor="_metrics")
with cc:
show_true_only = st.toggle("successful hypotheses", value=False)

hypotheses_c, chart_c = st.columns([2, 3])
with hypotheses_c:
with st.container(height=700):
h_strs = []
for id, h in state.hypotheses.items():
if state.h_decisions[id]:
h_strs.append(f"{id}. :green[**{h.hypothesis}**]\n\t>:green-background[*{h.__dict__.get('concise_reason', '')}*]")
else:
h_strs.append(f"{id}. {h.hypothesis}\n\t>*{h.__dict__.get('concise_reason', '')}*")
# hypotheses_c, chart_c = st.columns([2, 3])
chart_c = st.container()
hypotheses_c = st.container()

if hasattr(h, "concise_observation"):
h_strs[-1] += f"\n\n\t>:blue[**Observation**]: {h.concise_observation}"
h_strs[-1] += f"\n\n\t>:blue[**Justification**]: {h.concise_justification}"
h_strs[-1] += f"\n\n\t>:blue[**Knowledge**]: {h.concise_knowledge}"
st.markdown("\n".join(h_strs))
with hypotheses_c:
st.subheader("Hypotheses🏅", anchor="_hypotheses")
display_hypotheses(state.hypotheses, state.h_decisions, show_true_only)

with chart_c:
with st.container(height=700):
if state.log_type == "qlib_factor":
df = pd.DataFrame([state.alpha158_metrics] + state.metric_series)
if state.log_type == "qlib_factor":
df = pd.DataFrame([state.alpha158_metrics] + state.metric_series)
else:
df = pd.DataFrame(state.metric_series)
if show_true_only and len(state.hypotheses) >= len(state.metric_series):
if state.alpha158_metrics is not None:
selected = ["alpha158"] + [i for i in df.index if state.h_decisions[int(i[6:])]]
else:
df = pd.DataFrame(state.metric_series)
if show_true_only and len(state.hypotheses) >= len(state.metric_series):
if state.alpha158_metrics is not None:
selected = ["alpha158"] + [i for i in df.index if state.h_decisions[int(i[6:])]]
else:
selected = [i for i in df.index if state.h_decisions[int(i[6:])]]
df = df.loc[selected]
if df.shape[0] == 1:
st.table(df.iloc[0])
elif df.shape[0] > 1:
if df.shape[1] == 1:
# suhan's scenario
fig = px.line(df, x=df.index, y=df.columns, markers=True)
fig.update_layout(xaxis_title="Loop Round", yaxis_title=None)
st.plotly_chart(fig)
else:
metrics_window(df, 2, 2, height=650, colors=['red', 'blue', 'orange', 'green'])
selected = [i for i in df.index if state.h_decisions[int(i[6:])]]
df = df.loc[selected]
if df.shape[0] == 1:
st.table(df.iloc[0])
elif df.shape[0] > 1:
if df.shape[1] == 1:
# suhan's scenario
fig = px.line(df, x=df.index, y=df.columns, markers=True)
fig.update_layout(xaxis_title="Loop Round", yaxis_title=None)
st.plotly_chart(fig)
else:
metrics_window(df, 1, 4, height=300, colors=['red', 'blue', 'orange', 'green'])


elif state.log_type == "model_extraction_and_implementation" and len(state.msgs[state.lround]["d.evolving code"]) > 0:
Expand Down Expand Up @@ -443,10 +441,23 @@ def tasks_window(tasks: list[FactorTask | ModelTask]):
if isinstance(state.last_msg.content, list):
st.write(state.last_msg.content[0])
elif not isinstance(state.last_msg.content, str):
st.write(state.last_msg.content)
st.write(state.last_msg.content.__dict__)


# Main Window
header_c1, header_c3 = st.columns([1, 6], vertical_alignment="center")
with st.container():
with header_c1:
st.image("https://img-prod-cms-rt-microsoft-com.akamaized.net/cms/api/am/imageFileData/RE1Mu3b?ver=5c31")
with header_c3:
st.markdown(
"""
<h1>
RD-Agent:<br>LLM-based autonomous evolving agents for industrial data-driven R&D
</h1>
""",
unsafe_allow_html=True
)

# Project Info
with st.container():
Expand Down

0 comments on commit e2b8211

Please sign in to comment.