diff --git a/evaluation/agent_bench/scripts/summarise_results.py b/evaluation/agent_bench/scripts/summarise_results.py index 67a8964b1da2..3523f4d51f1a 100644 --- a/evaluation/agent_bench/scripts/summarise_results.py +++ b/evaluation/agent_bench/scripts/summarise_results.py @@ -5,13 +5,13 @@ def extract_test_results(res_file_path: str) -> tuple[list[str], list[str]]: passed = [] failed = [] - with open(res_file_path, 'r') as file: + with open(res_file_path, "r") as file: for line in file: data = json.loads(line.strip()) - instance_id = data['instance_id'] + instance_id = data["instance_id"] resolved = False - if 'test_result' in data and 'result' in data['test_result']: - resolved = data['test_result']['result'] + if "test_result" in data and "result" in data["test_result"]: + resolved = data["test_result"]["result"] if resolved: passed.append(instance_id) else: @@ -19,19 +19,19 @@ def extract_test_results(res_file_path: str) -> tuple[list[str], list[str]]: return passed, failed -if __name__ == '__main__': +if __name__ == "__main__": if len(sys.argv) != 2: print( - 'Usage: poetry run python summarise_results.py ' + "Usage: poetry run python summarise_results.py " ) sys.exit(1) json_file_path = sys.argv[1] passed_tests, failed_tests = extract_test_results(json_file_path) succ_rate = len(passed_tests) / (len(passed_tests) + len(failed_tests)) print( - f'\nPassed {len(passed_tests)} tests, failed {len(failed_tests)} tests, resolve rate = {succ_rate}' + f"\nPassed {len(passed_tests)} tests, failed {len(failed_tests)} tests, resolve rate = {succ_rate}" ) - print('PASSED TESTS:') + print("PASSED TESTS:") print(passed_tests) - print('FAILED TESTS:') + print("FAILED TESTS:") print(failed_tests) diff --git a/evaluation/aider_bench/scripts/summarize_results.py b/evaluation/aider_bench/scripts/summarize_results.py index 3dc66f5509d0..47f437d11355 100644 --- a/evaluation/aider_bench/scripts/summarize_results.py +++ b/evaluation/aider_bench/scripts/summarize_results.py @@ -8,10 +8,10 @@ def extract_test_results(df: pd.DataFrame) -> tuple[list[str], list[str]]: passed = [] failed = [] for _, row in df.iterrows(): - instance_id = row['instance_id'] + instance_id = row["instance_id"] resolved = False - if 'test_result' in row and 'exit_code' in row['test_result']: - resolved = row['test_result']['exit_code'] == 0 + if "test_result" in row and "exit_code" in row["test_result"]: + resolved = row["test_result"]["exit_code"] == 0 if resolved: passed.append(instance_id) else: @@ -21,38 +21,38 @@ def extract_test_results(df: pd.DataFrame) -> tuple[list[str], list[str]]: def visualize_results(df: pd.DataFrame): df1 = pd.DataFrame() - df1['cost'] = df['metrics'].apply(pd.Series)['accumulated_cost'] - df1['result'] = ( - df['test_result'].apply(pd.Series)['exit_code'].map({0: 'Pass', 1: 'Fail'}) + df1["cost"] = df["metrics"].apply(pd.Series)["accumulated_cost"] + df1["result"] = ( + df["test_result"].apply(pd.Series)["exit_code"].map({0: "Pass", 1: "Fail"}) ) - df1['actions'] = pd.Series([len(a) - 1 for a in df['history']]) + df1["actions"] = pd.Series([len(a) - 1 for a in df["history"]]) - passed = np.sum(df1['result'] == 'Pass') + passed = np.sum(df1["result"] == "Pass") total = df.shape[0] resolve_rate = round((passed / total) * 100, 2) - print('Number of passed tests:', f'{passed}/{total} {resolve_rate:.2f}%') - print('\nDescriptive statistics for number of actions:') - print(df1['actions'].describe()) - print('\nDescriptive statistics for costs:') - print(df1['cost'].describe()) + print("Number of passed tests:", f"{passed}/{total} {resolve_rate:.2f}%") + print("\nDescriptive statistics for number of actions:") + print(df1["actions"].describe()) + print("\nDescriptive statistics for costs:") + print(df1["cost"].describe()) # Bin counts for actions - action_bins = pd.cut(df1['actions'], bins=range(0, 32, 2)) - print('\nAction bin counts:') + action_bins = pd.cut(df1["actions"], bins=range(0, 32, 2)) + print("\nAction bin counts:") print(action_bins.value_counts().sort_index()) # Bin counts for costs - cost_bins = pd.cut(df1['cost'], bins=10) - print('\nCost bin counts:') + cost_bins = pd.cut(df1["cost"], bins=10) + print("\nCost bin counts:") print(cost_bins.value_counts().sort_index()) return resolve_rate -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Summarize AiderBench results') - parser.add_argument('input_filepath', type=str, help='Path to the JSONL file') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Summarize AiderBench results") + parser.add_argument("input_filepath", type=str, help="Path to the JSONL file") args = parser.parse_args() # Create DataFrame from JSONL file @@ -62,9 +62,9 @@ def visualize_results(df: pd.DataFrame): resolve_rate = visualize_results(df) print( - f'\nPassed {len(passed_tests)} tests, failed {len(failed_tests)} tests, resolve rate = {resolve_rate:.2f}%' + f"\nPassed {len(passed_tests)} tests, failed {len(failed_tests)} tests, resolve rate = {resolve_rate:.2f}%" ) - print('PASSED TESTS:') + print("PASSED TESTS:") print(passed_tests) - print('FAILED TESTS:') + print("FAILED TESTS:") print(failed_tests) diff --git a/evaluation/biocoder/scripts/setup/copy_changed_code.py b/evaluation/biocoder/scripts/setup/copy_changed_code.py index 2cee1e97b66f..0115d29a480d 100644 --- a/evaluation/biocoder/scripts/setup/copy_changed_code.py +++ b/evaluation/biocoder/scripts/setup/copy_changed_code.py @@ -7,18 +7,18 @@ def get_changed_code(target_filepath, line_start, include_signature=False): selected_lines = [] offset = 1 if include_signature else 0 - with open('/testing_files/first_line_after_removed.txt', 'r') as f: + with open("/testing_files/first_line_after_removed.txt", "r") as f: first_line_after_removed = f.read() if first_line_after_removed is None: - print('First line after removed is None') + print("First line after removed is None") - with open(target_filepath, 'r') as f: - lines = f.read().split('\n') + with open(target_filepath, "r") as f: + lines = f.read().split("\n") for i in range(line_start - offset, len(lines)): if lines[i].strip() == first_line_after_removed.strip(): break selected_lines.append(lines[i]) - text = '\n'.join(selected_lines) + text = "\n".join(selected_lines) return text @@ -26,16 +26,16 @@ def copy_changed_code( target_filepath, generated_code_filepath, line_start, include_signature=False ): changed_code = get_changed_code(target_filepath, line_start, include_signature) - with open(generated_code_filepath, 'w') as f: + with open(generated_code_filepath, "w") as f: f.write(changed_code) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--target_filepath', type=str, required=True) - parser.add_argument('--generated_code_filepath', type=str, required=True) - parser.add_argument('--line_start', type=int, required=True) - parser.add_argument('--include_signature', action='store_true') + parser.add_argument("--target_filepath", type=str, required=True) + parser.add_argument("--generated_code_filepath", type=str, required=True) + parser.add_argument("--line_start", type=int, required=True) + parser.add_argument("--include_signature", action="store_true") args = parser.parse_args() copy_changed_code( args.target_filepath, diff --git a/evaluation/biocoder/scripts/setup/remove_code.py b/evaluation/biocoder/scripts/setup/remove_code.py index 3c76a41738d5..51e35a2c27b4 100644 --- a/evaluation/biocoder/scripts/setup/remove_code.py +++ b/evaluation/biocoder/scripts/setup/remove_code.py @@ -19,24 +19,24 @@ def get_likely_indent_size(array_of_tabs) -> int: def get_target_filepath(self): target_filepath = os.path.join( self.workspace_mount_path, - self.biocoder_instance.repository.split('/')[1], + self.biocoder_instance.repository.split("/")[1], self.biocoder_instance.filePath, ) return target_filepath def remove_code(target_filepath: str, line_start: int, line_end: int, language: str): - comment_prefix = {'python': '#', 'java': '//'} + comment_prefix = {"python": "#", "java": "//"} - with open(target_filepath, 'r') as f: - lines = f.read().split('\n') + with open(target_filepath, "r") as f: + lines = f.read().split("\n") # print("="*10+"ORIGINAL"+"="*10) # print("\n".join(lines)) signature_line = lines[line_start - 1] # get the number of tabs def get_indent_size(s: str): - return len(re.match(r'\s*', s).group()) + return len(re.match(r"\s*", s).group()) indent_sizes = list(map(get_indent_size, lines)) indent_size = get_likely_indent_size(indent_sizes) @@ -46,7 +46,7 @@ def get_indent_size(s: str): + [ f"{' '*comment_indent_size+comment_prefix[language.lower()]}TODO: replace with your code here" ] - + ([''] * 2) + + ([""] * 2) + lines[line_end:] ) first_line_after_removed_index = line_start @@ -56,19 +56,19 @@ def get_indent_size(s: str): first_line_after_removed_index += 1 first_line_after_removed = lines[first_line_after_removed_index] - print('FIRST LINE AFTER REMOVED: ', first_line_after_removed) - with open('/testing_files/first_line_after_removed.txt', 'w') as f: + print("FIRST LINE AFTER REMOVED: ", first_line_after_removed) + with open("/testing_files/first_line_after_removed.txt", "w") as f: f.write(first_line_after_removed) - with open(target_filepath, 'w') as f: - f.write('\n'.join(lines)) + with open(target_filepath, "w") as f: + f.write("\n".join(lines)) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--target_filepath', type=str, required=True) - parser.add_argument('--line_start', type=int, required=True) - parser.add_argument('--line_end', type=int, required=True) - parser.add_argument('--language', type=str, required=True) + parser.add_argument("--target_filepath", type=str, required=True) + parser.add_argument("--line_start", type=int, required=True) + parser.add_argument("--line_end", type=int, required=True) + parser.add_argument("--language", type=str, required=True) args = parser.parse_args() remove_code(args.target_filepath, args.line_start, args.line_end, args.language) diff --git a/evaluation/discoverybench/eval_utils/eval_w_subhypo_gen.py b/evaluation/discoverybench/eval_utils/eval_w_subhypo_gen.py index a80df8279cfb..1ebf33e602a6 100644 --- a/evaluation/discoverybench/eval_utils/eval_w_subhypo_gen.py +++ b/evaluation/discoverybench/eval_utils/eval_w_subhypo_gen.py @@ -7,61 +7,61 @@ from .openai_helpers import get_response logging.basicConfig( - format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt='%m/%d/%Y %H:%M:%S', + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger = logging.getLogger(__name__) def get_score_from_answer(type, answer): - if type == 'context': - answer = answer.replace('Answer:', '').strip() - if answer.startswith('A)'): + if type == "context": + answer = answer.replace("Answer:", "").strip() + if answer.startswith("A)"): return 1.0 - elif answer.startswith('B)'): + elif answer.startswith("B)"): return 0.0 return -1.0 - elif type == 'var': + elif type == "var": try: var_json = json.loads(answer) # print(f"var_json:{var_json}") p = 0.0 r = 0.0 f1 = 0.0 - if var_json['sizeB']: - p = var_json['intersection'] / var_json['sizeB'] - if var_json['sizeA']: - r = var_json['intersection'] / var_json['sizeA'] + if var_json["sizeB"]: + p = var_json["intersection"] / var_json["sizeB"] + if var_json["sizeA"]: + r = var_json["intersection"] / var_json["sizeA"] if p > 0.0 and r > 0.0: f1 = (2 * p * r) / (p + r) else: f1 = 0.0 eval_rec = { - 'p': p, - 'r': r, - 'f1': f1, - 'sizeA': var_json['sizeA'], - 'sizeB': var_json['sizeB'], - 'intersection': var_json['intersection'], - 'explanation': var_json['explanation'], + "p": p, + "r": r, + "f1": f1, + "sizeA": var_json["sizeA"], + "sizeB": var_json["sizeB"], + "intersection": var_json["intersection"], + "explanation": var_json["explanation"], } - print(f'var_eval: {eval_rec}') + print(f"var_eval: {eval_rec}") return eval_rec except Exception: # COMMENT: added Exception - return {'p': -1.0, 'r': -1.0, 'f1': -1.0} - elif type == 'rel': + return {"p": -1.0, "r": -1.0, "f1": -1.0} + elif type == "rel": print(answer) rel_json = json.loads(answer) - answer_str = rel_json['answer'].strip() - if answer_str.startswith('A') or 'very similar' in answer_str: + answer_str = rel_json["answer"].strip() + if answer_str.startswith("A") or "very similar" in answer_str: return 1.0 elif ( - answer_str.startswith('B') or 'similar but general than HypoA' in answer_str + answer_str.startswith("B") or "similar but general than HypoA" in answer_str ): return 0.5 - elif answer_str.startswith('C') or 'different' in answer_str: + elif answer_str.startswith("C") or "different" in answer_str: return 0.0 return -1.0 return -1.0 @@ -79,28 +79,28 @@ def ask_dimension_question( dataset_type, use_column_metadata=True, ): - dimension_question = '' - answer = '' + dimension_question = "" + answer = "" score = 0.0 - if dimension == 'var': - score = {'p': -1.0, 'r': -1.0, 'f1': -1.0} + if dimension == "var": + score = {"p": -1.0, "r": -1.0, "f1": -1.0} num_tokens = 256 num_retries = 1 json_response = False messages = [ { - 'role': 'system', - 'content': 'You are an AI assistant that helps evaluate a data-driven hypothesis. You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.', + "role": "system", + "content": "You are an AI assistant that helps evaluate a data-driven hypothesis. You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.", }, ] - if dimension == 'context': + if dimension == "context": dimension_question = """\ Question: Is HypoB defined in the same context as HypoA? (Context refers to assumptions/stratification under which the hypotheses are defined.) Options: A) same B) different What is your answer?""" - elif dimension == 'var': + elif dimension == "var": dimension_question = """\ Question: For both HypoA and HypoB, what are the different variables found in the hypotheses? \ Return your answer as a JSON object in the following format: @@ -115,7 +115,7 @@ def ask_dimension_question( num_tokens = 512 num_retries = 1 json_response = True - elif dimension == 'rel': + elif dimension == "rel": dimension_question = """\ Question: Does HypoB exhibit the same relation as HypoA? Compare using following example hierarchy of relationships (based on specificity): \ @@ -161,7 +161,7 @@ def ask_dimension_question( {dimension_question}""" - messages.append({'role': 'user', 'content': dimension_question_str}) + messages.append({"role": "user", "content": dimension_question_str}) for retry in range(num_retries): response = run_chatgpt_query_multi_turn( messages=messages, @@ -184,32 +184,32 @@ def prepare_dataset_metadata_json(dataset_meta, dataset_type, use_column_metadat if dataset_meta is None: # COMMENT: changed from == to is None return [ { - 'dataset_description': '', - 'columns': [], + "dataset_description": "", + "columns": [], } ] datasets_json = [] - if dataset_type == 'real': - for d in dataset_meta['datasets']: + if dataset_type == "real": + for d in dataset_meta["datasets"]: datasets_json.append( { - 'dataset_description': d['description'], - 'columns': [ - {'name': col['name'], 'description': col['description']} - for col in d['columns']['raw'] + "dataset_description": d["description"], + "columns": [ + {"name": col["name"], "description": col["description"]} + for col in d["columns"]["raw"] ] if use_column_metadata else [], } ) else: - for d in dataset_meta['datasets']: + for d in dataset_meta["datasets"]: datasets_json.append( { - 'dataset_description': d['description'], - 'columns': [ - {'name': col['name'], 'description': col['description']} - for col in d['columns'] + "dataset_description": d["description"], + "columns": [ + {"name": col["name"], "description": col["description"]} + for col in d["columns"] ] if use_column_metadata else [], @@ -272,19 +272,19 @@ def get_sub_hypotheses( if sub_hypo_json is not None: # COMMENT: changed from != to is not # print(f"full hypothesis: {hypo}") - print(f'sub_hypo_json: {sub_hypo_json}') + print(f"sub_hypo_json: {sub_hypo_json}") else: sub_hypo_json = { - 'sub_hypo': [], + "sub_hypo": [], } - sub_hypo_json['full_hypo'] = hypo + sub_hypo_json["full_hypo"] = hypo return sub_hypo_json def match_context_with_gpt( - gold_hyp, gold_context, pred_hyp, pred_context, model='gpt-3.5-turbo' + gold_hyp, gold_context, pred_hyp, pred_context, model="gpt-3.5-turbo" ): prompt = f"""\ Given a gold hypothesis, a gold context, a predicted hypothesis, and a predicted context, your task is \ @@ -314,13 +314,13 @@ def match_context_with_gpt( client = OpenAI() output = get_response(client, prompt, model=model) - return output.get('match', False) + return output.get("match", False) def is_matching_context(gold_hyp, gold_context, pred_hyp, pred_context, llm_used): if gold_context == pred_context: return True - if 'None' in [gold_context, pred_context]: + if "None" in [gold_context, pred_context]: return False return match_context_with_gpt( gold_hyp, gold_context, pred_hyp, pred_context, model=llm_used @@ -342,14 +342,14 @@ def run_eval_gold_vs_gen_NL_subhypo( # GPT-4 based evaluation to evaluate generated hypothesis in terms of context, variables, relation eval_rec = { - 'query': query, - 'HypoA': gold_hypo, - 'WorkflowA': gold_workflow, - 'HypoB': gen_hypo, - 'WorkflowB': gen_workflow, + "query": query, + "HypoA": gold_hypo, + "WorkflowA": gold_workflow, + "HypoB": gen_hypo, + "WorkflowB": gen_workflow, } - for dimension in ['var', 'rel']: + for dimension in ["var", "rel"]: question, answer, score = ask_dimension_question( query, gold_hypo, @@ -363,14 +363,14 @@ def run_eval_gold_vs_gen_NL_subhypo( use_column_metadata=use_column_metadata, ) - eval_rec[dimension] = {'question': question, 'answer': answer, 'score': score} + eval_rec[dimension] = {"question": question, "answer": answer, "score": score} - eval_rec['context'] = context_score - eval_rec['accuracy_score'] = ( + eval_rec["context"] = context_score + eval_rec["accuracy_score"] = ( 1.0 - * eval_rec['context']['score'] - * eval_rec['var']['score']['f1'] - * eval_rec['rel']['score'] + * eval_rec["context"]["score"] + * eval_rec["var"]["score"]["f1"] + * eval_rec["rel"]["score"] ) return eval_rec @@ -409,11 +409,11 @@ def run_eval_gold_vs_gen_NL_hypo_workflow( # recall_context = 1.0 # COMMENT: never used eval_rec = { - 'query': query, - 'HypoA': gold_hypo, - 'WorkflowA': gold_workflow, - 'HypoB': gen_hypo, - 'WorkflowB': gen_workflow, + "query": query, + "HypoA": gold_hypo, + "WorkflowA": gold_workflow, + "HypoB": gen_hypo, + "WorkflowB": gen_workflow, } gold_sub_hypo_json = get_sub_hypotheses( @@ -425,17 +425,17 @@ def run_eval_gold_vs_gen_NL_hypo_workflow( dataset_type=dataset_type, use_column_metadata=use_column_metadata, ) - if len(gold_sub_hypo_json['sub_hypo']) == 0: - gold_sub_hypo_json['sub_hypo'] = [ + if len(gold_sub_hypo_json["sub_hypo"]) == 0: + gold_sub_hypo_json["sub_hypo"] = [ { - 'text': gold_hypo, - 'context': 'None', - 'variables': [], - 'relations': '', - 'explanation': 'unable to segment', + "text": gold_hypo, + "context": "None", + "variables": [], + "relations": "", + "explanation": "unable to segment", } ] - print(f'gold_sub_hypo_json: {gold_sub_hypo_json}') + print(f"gold_sub_hypo_json: {gold_sub_hypo_json}") gen_sub_hypo_json = get_sub_hypotheses( query=query, @@ -446,38 +446,38 @@ def run_eval_gold_vs_gen_NL_hypo_workflow( dataset_type=dataset_type, use_column_metadata=use_column_metadata, ) - if len(gen_sub_hypo_json['sub_hypo']) == 0: - gen_sub_hypo_json['sub_hypo'] = [ + if len(gen_sub_hypo_json["sub_hypo"]) == 0: + gen_sub_hypo_json["sub_hypo"] = [ { - 'text': gen_hypo, - 'context': 'None', - 'variables': [], - 'relations': '', - 'explanation': 'unable to segment', + "text": gen_hypo, + "context": "None", + "variables": [], + "relations": "", + "explanation": "unable to segment", } ] - print(f'gen_sub_hypo_json: {gen_sub_hypo_json}') + print(f"gen_sub_hypo_json: {gen_sub_hypo_json}") - eval_rec['gold_sub_hypo'] = gold_sub_hypo_json - eval_rec['gen_sub_hypo'] = gen_sub_hypo_json + eval_rec["gold_sub_hypo"] = gold_sub_hypo_json + eval_rec["gen_sub_hypo"] = gen_sub_hypo_json gold_subh_covered = [] gen_subh_to_gold_subh = dict() gen_gold_subh_to_context = dict() - for p_id, gen_subh in enumerate(gen_sub_hypo_json['sub_hypo']): + for p_id, gen_subh in enumerate(gen_sub_hypo_json["sub_hypo"]): gen_subh_to_gold_subh[p_id] = -1 - for g_id, gold_subh in enumerate(gold_sub_hypo_json['sub_hypo']): + for g_id, gold_subh in enumerate(gold_sub_hypo_json["sub_hypo"]): if g_id in gold_subh_covered: continue # match context context_bool = is_matching_context( - gold_subh['text'], - gold_subh.get('context', ''), - gen_subh['text'], - gen_subh.get('context', ''), + gold_subh["text"], + gold_subh.get("context", ""), + gen_subh["text"], + gen_subh.get("context", ""), llm_used, ) if context_bool: @@ -488,21 +488,21 @@ def run_eval_gold_vs_gen_NL_hypo_workflow( if context_score == 1.0: # match only when context_score = 1.0 gen_subh_to_gold_subh[p_id] = g_id gold_subh_covered.append(g_id) - gen_gold_subh_to_context[f'P{p_id}||G{g_id}'] = { - 'question': f"""Comapring: GoldH: {gold_subh["text"]}, GoldC: {gold_subh['context']}\nGenH: {gen_subh['text']}, GenC: {gen_subh['context']}""", - 'answer': context_bool, - 'score': context_score, + gen_gold_subh_to_context[f"P{p_id}||G{g_id}"] = { + "question": f"""Comapring: GoldH: {gold_subh["text"]}, GoldC: {gold_subh['context']}\nGenH: {gen_subh['text']}, GenC: {gen_subh['context']}""", + "answer": context_bool, + "score": context_score, } break - print(f'gen_subh_to_gold_subh: {gen_subh_to_gold_subh}') - eval_rec['gen_subh_to_gold_subh'] = gen_subh_to_gold_subh - eval_rec['gold_subh_covered'] = gold_subh_covered + print(f"gen_subh_to_gold_subh: {gen_subh_to_gold_subh}") + eval_rec["gen_subh_to_gold_subh"] = gen_subh_to_gold_subh + eval_rec["gold_subh_covered"] = gold_subh_covered matched_gold_gen_subh_evals = dict() sum_accuracy_score = 0.0 for p_id, g_id in gen_subh_to_gold_subh.items(): if g_id >= 0: - key = f'P{p_id}||G{g_id}' + key = f"P{p_id}||G{g_id}" context_score = gen_gold_subh_to_context[key] subh_eval_rec = run_eval_gold_vs_gen_NL_subhypo( query, @@ -516,13 +516,13 @@ def run_eval_gold_vs_gen_NL_hypo_workflow( dataset_type=dataset_type, use_column_metadata=use_column_metadata, ) - sum_accuracy_score += subh_eval_rec['accuracy_score'] + sum_accuracy_score += subh_eval_rec["accuracy_score"] matched_gold_gen_subh_evals[key] = subh_eval_rec - eval_rec['matched_gold_gen_subh_evals'] = matched_gold_gen_subh_evals - eval_rec['recall_context'] = ( - len(gold_subh_covered) / len(gold_sub_hypo_json['sub_hypo']) - if len(gold_sub_hypo_json['sub_hypo']) + eval_rec["matched_gold_gen_subh_evals"] = matched_gold_gen_subh_evals + eval_rec["recall_context"] = ( + len(gold_subh_covered) / len(gold_sub_hypo_json["sub_hypo"]) + if len(gold_sub_hypo_json["sub_hypo"]) else 0.0 ) mean_accuracy_score = ( @@ -530,9 +530,9 @@ def run_eval_gold_vs_gen_NL_hypo_workflow( if len(gen_subh_to_gold_subh) else 0.0 ) - eval_rec['mean_accuracy_score'] = mean_accuracy_score - final_score = eval_rec['recall_context'] * mean_accuracy_score - eval_rec['final_score'] = final_score - print(f'eval_rec: {json.dumps(eval_rec, indent=2)}') + eval_rec["mean_accuracy_score"] = mean_accuracy_score + final_score = eval_rec["recall_context"] * mean_accuracy_score + eval_rec["final_score"] = final_score + print(f"eval_rec: {json.dumps(eval_rec, indent=2)}") return eval_rec diff --git a/evaluation/discoverybench/eval_utils/lm_utils.py b/evaluation/discoverybench/eval_utils/lm_utils.py index 10486ee82294..63db09e6a36f 100644 --- a/evaluation/discoverybench/eval_utils/lm_utils.py +++ b/evaluation/discoverybench/eval_utils/lm_utils.py @@ -15,22 +15,22 @@ from typing_extensions import Literal -Model = Literal['gpt-4', 'gpt-3.5-turbo', 'text-davinci-003'] +Model = Literal["gpt-4", "gpt-3.5-turbo", "text-davinci-003"] -OpenAI.api_key = os.getenv('OPENAI_API_KEY') +OpenAI.api_key = os.getenv("OPENAI_API_KEY") OPENAI_GEN_HYP = { - 'temperature': 0, - 'max_tokens': 250, - 'top_p': 1.0, - 'frequency_penalty': 0, - 'presence_penalty': 0, + "temperature": 0, + "max_tokens": 250, + "top_p": 1.0, + "frequency_penalty": 0, + "presence_penalty": 0, } @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) def run_chatgpt_query_multi_turn( messages, - model_name='gpt-4-turbo', # pass "gpt4" for more recent model output + model_name="gpt-4-turbo", # pass "gpt4" for more recent model output max_tokens=256, temperature=0.0, json_response=False, @@ -46,7 +46,7 @@ def run_chatgpt_query_multi_turn( if json_response: response = client.chat.completions.create( model=model_name, - response_format={'type': 'json_object'}, + response_format={"type": "json_object"}, messages=messages, **OPENAI_GEN_HYP, ) @@ -58,7 +58,7 @@ def run_chatgpt_query_multi_turn( except Exception as e: print(e) - print('GPT error. Retrying in 2 seconds...') + print("GPT error. Retrying in 2 seconds...") time.sleep(2) return response diff --git a/evaluation/discoverybench/eval_utils/openai_helpers.py b/evaluation/discoverybench/eval_utils/openai_helpers.py index 95ab23cf9c2e..5b945e9458df 100644 --- a/evaluation/discoverybench/eval_utils/openai_helpers.py +++ b/evaluation/discoverybench/eval_utils/openai_helpers.py @@ -4,34 +4,34 @@ def OPENAI_TOPIC_GEN_MESSAGES(n=10): return [ { - 'role': 'system', - 'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.', + "role": "system", + "content": "You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.", }, { - 'role': 'user', - 'content': f'Given `n`, come up with a list of `n` distinct topics and their descriptions. The topics can be absolutely anything. Be as creative as possible. Return your answer as a JSON object. \n\nFor example, for `n`=3, a valid answer might be:\n```json\n{{"topics": [\n {{"id": 1, "topic": "cooking", "description": "Related to recipes, ingredients, chefs, etc."}},\n {{"id": 2, "topic": "sports", "description": "Related to players, stadiums, trophies, etc."}},\n {{"id": 3, "topic": "antiquing", "description": "Related to unique items, history, etc."}}\n]}}```\n\nNow, give me a list for `n`={n}. Remember, pick diverse topics from everything possible. No consecutive topics should be broadly similar. Directly respond with the answer JSON object.', + "role": "user", + "content": f'Given `n`, come up with a list of `n` distinct topics and their descriptions. The topics can be absolutely anything. Be as creative as possible. Return your answer as a JSON object. \n\nFor example, for `n`=3, a valid answer might be:\n```json\n{{"topics": [\n {{"id": 1, "topic": "cooking", "description": "Related to recipes, ingredients, chefs, etc."}},\n {{"id": 2, "topic": "sports", "description": "Related to players, stadiums, trophies, etc."}},\n {{"id": 3, "topic": "antiquing", "description": "Related to unique items, history, etc."}}\n]}}```\n\nNow, give me a list for `n`={n}. Remember, pick diverse topics from everything possible. No consecutive topics should be broadly similar. Directly respond with the answer JSON object.', }, ] OPENAI_GEN_HYP = { - 'temperature': 1.0, - 'max_tokens': 4096, - 'top_p': 1.0, - 'frequency_penalty': 0, - 'presence_penalty': 0, + "temperature": 1.0, + "max_tokens": 4096, + "top_p": 1.0, + "frequency_penalty": 0, + "presence_penalty": 0, } def OPENAI_SEMANTICS_GEN_MESSAGES(dependent, relationship, domain, domain_desc): return [ { - 'role': 'system', - 'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.', + "role": "system", + "content": "You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.", }, { - 'role': 'user', - 'content': f'Given the true relationship in a dataset and a given domain, your task is to come up with an interpretation of some real-world concepts that the relationship could be modeling from the provided domain. It\'s okay to be wrong, but suggest something reasonable. Try as much as possible to make sure that the TARGET is actually derivable from the other variables. Give your answer as a JSON object. Here\'s an example:\n\nRelationship for x2 = "(96.4 * x1 ** 3) + (88.72 * x5 ** 2) + (81.96 * x6 ** -2) + (28.13 * x3) + (97.0) + (0 * x4)"\nDomain="Sales"\nDomain description="Related to product distribution, revenues, marketing, etc."\n\nBased on this, the following real-world concepts might be applicable:\n```json\n{{\n "dependent": "x2",\n "relationship": "(96.4 * x1 ** 3) + (88.72 * x5 ** 2) + (81.96 * x6 ** -2) + (28.13 * x3) + (97.0) + (0 * x4)",\n "domain": "Sales",\n "trends": {{\n "x1": "Positive, cubic factor",\n "x2": "TARGET",\n "x3": "Positive, linear factor",\n "x4": "No relation",\n "x5": "Positive quadratic factor",\n "x6": "Positive, inverse quadratic factor"\n }},\n "interpretation": {{\n "x2": {{"description": "Volume of product sales by area", "name": "sales_area", "is_target": true}},\n "x1": {{"description": "Population by area", "name": "pop_area"}},\n "x3": {{"description": "Advertising spending", "name": "ad_spend"}},\n "x4": {{"description": "Gender ratio of marketing team", "name": "gdr_ratio_mkt_team"}},\n "x5": {{"description": "Intensity of marketing campaign", "name": "mkt_intensity"}}\n }},\n "x6": {{"description": "Distance to distribution center", "name": "dist_to_distr_ctr"}}\n}}```\n\nHere\'s a new test question:\nRelationship for {dependent} = "{relationship}"\nDomain = "{domain}"\nDomain description="{domain_desc}"\n\nRespond only with the answer JSON. Make sure that you do not forget to include the TARGET variable in the interpretation object.', + "role": "user", + "content": f'Given the true relationship in a dataset and a given domain, your task is to come up with an interpretation of some real-world concepts that the relationship could be modeling from the provided domain. It\'s okay to be wrong, but suggest something reasonable. Try as much as possible to make sure that the TARGET is actually derivable from the other variables. Give your answer as a JSON object. Here\'s an example:\n\nRelationship for x2 = "(96.4 * x1 ** 3) + (88.72 * x5 ** 2) + (81.96 * x6 ** -2) + (28.13 * x3) + (97.0) + (0 * x4)"\nDomain="Sales"\nDomain description="Related to product distribution, revenues, marketing, etc."\n\nBased on this, the following real-world concepts might be applicable:\n```json\n{{\n "dependent": "x2",\n "relationship": "(96.4 * x1 ** 3) + (88.72 * x5 ** 2) + (81.96 * x6 ** -2) + (28.13 * x3) + (97.0) + (0 * x4)",\n "domain": "Sales",\n "trends": {{\n "x1": "Positive, cubic factor",\n "x2": "TARGET",\n "x3": "Positive, linear factor",\n "x4": "No relation",\n "x5": "Positive quadratic factor",\n "x6": "Positive, inverse quadratic factor"\n }},\n "interpretation": {{\n "x2": {{"description": "Volume of product sales by area", "name": "sales_area", "is_target": true}},\n "x1": {{"description": "Population by area", "name": "pop_area"}},\n "x3": {{"description": "Advertising spending", "name": "ad_spend"}},\n "x4": {{"description": "Gender ratio of marketing team", "name": "gdr_ratio_mkt_team"}},\n "x5": {{"description": "Intensity of marketing campaign", "name": "mkt_intensity"}}\n }},\n "x6": {{"description": "Distance to distribution center", "name": "dist_to_distr_ctr"}}\n}}```\n\nHere\'s a new test question:\nRelationship for {dependent} = "{relationship}"\nDomain = "{domain}"\nDomain description="{domain_desc}"\n\nRespond only with the answer JSON. Make sure that you do not forget to include the TARGET variable in the interpretation object.', }, ] @@ -41,12 +41,12 @@ def OPENAI_SEMANTICS_GEN_W_MAP_MESSAGES( ): return [ { - 'role': 'system', - 'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.', + "role": "system", + "content": "You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.", }, { - 'role': 'user', - 'content': f'Given a partial mapping from variables to real-world concepts and a true relationship in a dataset, your task is to come up with an interpretation of real-world concepts for the variables without any assigned mapping (those starting with x). Suggest something reasonable. The dependent variable must be derivable only from the other variables in the dependent relationship. Give your answer as a JSON object. Here\'s an example:\n\nExample partial mapping and relationship:\n```json\n{{\n "domain": "Sales",\n "domain_description": "Related to product distribution, revenues, marketing, etc.",\n "variable_mapping": {{\n "x1": {{"description": "Population by area", "name": "pop_area"}},\n "x2": {{"description": "Volume of product sales by area", "name": "sales_area"}},\n "x4": {{"description": "Gender ratio of marketing team", "name": "gdr_ratio_mkt_team"}},\n "x6": {{"description": "Distance to distribution center", "name": "dist_to_distr_ctr"}}\n }},\n "dependent_variable": "sales_area",\n "dependent_relationship": "(96.4 * pop_area ** 3) + (88.72 * x5 ** 2) + (81.96 * dist_to_distr_ctr ** -2) + (28.13 * x3) + (97.0)"\n}}```\nBased on this, an example answer would be:\n```json\n{{\n "dependent_variable": "sales_area",\n "missing_mapping": ["x3", "x5"],\n "trends": {{\n "x3": "Positive, linear factor",\n "x5": "Positive quadratic factor"\n }},\n "interpretation": {{\n "x3": {{"description": "Advertising spending", "name": "ad_spend"}},\n "x5": {{"description": "Intensity of marketing campaign", "name": "mkt_intensity"}}\n }}\n}}```\n\nHere\'s a new test question:\n```json\n{{\n "domain": "{domain}",\n "domain_description": "{domain_desc}",\n "variable_mapping": {json.dumps(mapping, indent=2)},\n "dependent_variable": "{dependent}",\n "dependent_relationship": "{relationship}"\n}}```\nRespond only with the answer JSON.', + "role": "user", + "content": f'Given a partial mapping from variables to real-world concepts and a true relationship in a dataset, your task is to come up with an interpretation of real-world concepts for the variables without any assigned mapping (those starting with x). Suggest something reasonable. The dependent variable must be derivable only from the other variables in the dependent relationship. Give your answer as a JSON object. Here\'s an example:\n\nExample partial mapping and relationship:\n```json\n{{\n "domain": "Sales",\n "domain_description": "Related to product distribution, revenues, marketing, etc.",\n "variable_mapping": {{\n "x1": {{"description": "Population by area", "name": "pop_area"}},\n "x2": {{"description": "Volume of product sales by area", "name": "sales_area"}},\n "x4": {{"description": "Gender ratio of marketing team", "name": "gdr_ratio_mkt_team"}},\n "x6": {{"description": "Distance to distribution center", "name": "dist_to_distr_ctr"}}\n }},\n "dependent_variable": "sales_area",\n "dependent_relationship": "(96.4 * pop_area ** 3) + (88.72 * x5 ** 2) + (81.96 * dist_to_distr_ctr ** -2) + (28.13 * x3) + (97.0)"\n}}```\nBased on this, an example answer would be:\n```json\n{{\n "dependent_variable": "sales_area",\n "missing_mapping": ["x3", "x5"],\n "trends": {{\n "x3": "Positive, linear factor",\n "x5": "Positive quadratic factor"\n }},\n "interpretation": {{\n "x3": {{"description": "Advertising spending", "name": "ad_spend"}},\n "x5": {{"description": "Intensity of marketing campaign", "name": "mkt_intensity"}}\n }}\n}}```\n\nHere\'s a new test question:\n```json\n{{\n "domain": "{domain}",\n "domain_description": "{domain_desc}",\n "variable_mapping": {json.dumps(mapping, indent=2)},\n "dependent_variable": "{dependent}",\n "dependent_relationship": "{relationship}"\n}}```\nRespond only with the answer JSON.', }, ] @@ -54,12 +54,12 @@ def OPENAI_SEMANTICS_GEN_W_MAP_MESSAGES( def OPENAI_SEMANTICS_GEN_SUMMARY_MESSAGES(dataset): return [ { - 'role': 'system', - 'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.', + "role": "system", + "content": "You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.", }, { - 'role': 'user', - 'content': f'Given the following descriptions of the columns of a dataset, your task is to come up with a natural language overview of the dataset, which should include (1) what the dataset is about, (2) how the data was collected, (3) when the data was collected, and (3) for what purpose the data was collected. Be specific and creative.\n\nExample dataset:\n```json\n{{ \n "dataset": {{ \n "x6": {{"description": "Ancient artifact significance score", "name": "artifact_significance_score", "is_target": true}},\n "x1": {{"description": "Distance to ancient city center", "name": "dist_to_ancient_city_ctr"}},\n "x2": {{"description": "Quantity of discovered relics", "name": "relic_discovery_qty"}},\n "x3": {{"description": "Years since last archaeological expedition", "name": "years_since_exp"}},\n "x4": {{"description": "Number of artifacts in excavation site", "name": "artifact_qty"}},\n "x5": {{"description": "Soil fertility coefficient", "name": "soil_fertility_coef"}},\n "x7": {{"description": "Distance to ancient burial grounds", "name": "dist_to_burial_grounds"}},\n "x8": {{"description": "Population estimate of ancient civilization", "name": "ancient_civilization_pop_estimate"}},\n "x9": {{"description": "Temperature variation in excavation region", "name": "temp_variation"}}\n }}\n}}```\nExample description:\nThis dataset is about archaeological explorations and findings linked to ancient civilizations. The data was collected in the form of field metrics during various archaeological expeditions during the late mid-20th century. The purpose of the data collection is to evaluate the significance of ancient artifacts discovered during excavations.\n\nHere is a new test dataset.\n{json.dumps(dataset, indent=2)}\nProvide only the description.', + "role": "user", + "content": f'Given the following descriptions of the columns of a dataset, your task is to come up with a natural language overview of the dataset, which should include (1) what the dataset is about, (2) how the data was collected, (3) when the data was collected, and (3) for what purpose the data was collected. Be specific and creative.\n\nExample dataset:\n```json\n{{ \n "dataset": {{ \n "x6": {{"description": "Ancient artifact significance score", "name": "artifact_significance_score", "is_target": true}},\n "x1": {{"description": "Distance to ancient city center", "name": "dist_to_ancient_city_ctr"}},\n "x2": {{"description": "Quantity of discovered relics", "name": "relic_discovery_qty"}},\n "x3": {{"description": "Years since last archaeological expedition", "name": "years_since_exp"}},\n "x4": {{"description": "Number of artifacts in excavation site", "name": "artifact_qty"}},\n "x5": {{"description": "Soil fertility coefficient", "name": "soil_fertility_coef"}},\n "x7": {{"description": "Distance to ancient burial grounds", "name": "dist_to_burial_grounds"}},\n "x8": {{"description": "Population estimate of ancient civilization", "name": "ancient_civilization_pop_estimate"}},\n "x9": {{"description": "Temperature variation in excavation region", "name": "temp_variation"}}\n }}\n}}```\nExample description:\nThis dataset is about archaeological explorations and findings linked to ancient civilizations. The data was collected in the form of field metrics during various archaeological expeditions during the late mid-20th century. The purpose of the data collection is to evaluate the significance of ancient artifacts discovered during excavations.\n\nHere is a new test dataset.\n{json.dumps(dataset, indent=2)}\nProvide only the description.', }, ] @@ -67,12 +67,12 @@ def OPENAI_SEMANTICS_GEN_SUMMARY_MESSAGES(dataset): def OPENAI_GEN_HYPO_MESSAGES(dataset): return [ { - 'role': 'system', - 'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.', + "role": "system", + "content": "You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.", }, { - 'role': 'user', - 'content': f'Given a dataset with its descriptions and the true functional relationship between its variables, your task is to generate 3 levels of hypotheses for the stated relationship in plain English. The three levels are "broad", "medium" and "narrow". Make sure that the hypotheses sound natural. *Only include concepts for variables that are present in the provided functional relationship.* Give your answer as a JSON.\n\nFor example, an example dataset might be the following:\n```json\n{{\n "domain": "cybersecurity",\n "summary": "This dataset is about measuring cybersecurity threats in a system. The data was collected by monitoring various cybersecurity metrics in a network environment. The purpose of the data collection is to assess and predict potential cybersecurity risks and vulnerabilities.",\n "variables": [\n {{\n "description": "Level of cybersecurity threat",\n "name": "cybersecurity_threat",\n "is_target": true\n }},\n {{\n "description": "Number of failed login attempts",\n "name": "failed_login_attempts"\n }},\n {{\n "description": "Amount of encrypted data",\n "name": "encrypted_data"\n }},\n {{\n "description": "Frequency of software updates",\n "name": "software_updates"\n }},\n {{\n "description": "Number of antivirus software installed",\n "name": "antivirus_software"\n }},\n {{\n "description": "Quality of firewall protection",\n "name": "firewall_quality"\n }}\n ],\n "relationship": {{\n "dependent": "cybersecurity_threat",\n "relation": "-53.5*encrypted_data**2 - 53.85*failed_login_attempts**2 + 67.75*firewall_quality - 92.16 - 36.68/software_updates**3"\n }}\n}}```\nGiven this dataset, the following is a valid answer:\n```json\n{{\n "broad": {{\n "instruction": "Be vague. Only indicate which concepts might be related but not how they are related",\n "hypothesis": "Threat to cybersecurity is influenced by several factors including the amount of encrypted data, the number of failed login attempts, the quality of the firewall, as well as how often the software is updated."\n }},\n "medium": {{\n "instruction": "Be slightly more specific. For each factor, indicate carefully whether it positively or negatively affects the relationship, but do not indicate what the exponent is.",\n "hypothesis": "Cybersecurity threat tends to decrease with the amount of data encryption, the number of failed login attempts, as well as the frequency of software updates to some extent, while improvement in the firewall quality has a positive effect."\n }},\n "narrow": {{\n "instruction": "Be specific. Communicate the concepts, whether there is a positive or negative effect (be careful), and the meaning of the exponent",\n "hypothesis": "The threat to cybersecurity interacts in a complex manner with various factors. As the amount of encrypted data increases, there is a quadratic decrease in threat. Similarly for the number of failed login attempts, there is a negative quadratic relationship. The quality of the firewall protection on the other hand demonstrates a positive and linear relationship. Finally, the frequency of software updates has an inverse cubic relationship to the threat."\n }},\n}}\n```\n\nBased on this, provide an answer for the following test dataset:\n```json\n{dataset}```\nRespond only with a JSON.', + "role": "user", + "content": f'Given a dataset with its descriptions and the true functional relationship between its variables, your task is to generate 3 levels of hypotheses for the stated relationship in plain English. The three levels are "broad", "medium" and "narrow". Make sure that the hypotheses sound natural. *Only include concepts for variables that are present in the provided functional relationship.* Give your answer as a JSON.\n\nFor example, an example dataset might be the following:\n```json\n{{\n "domain": "cybersecurity",\n "summary": "This dataset is about measuring cybersecurity threats in a system. The data was collected by monitoring various cybersecurity metrics in a network environment. The purpose of the data collection is to assess and predict potential cybersecurity risks and vulnerabilities.",\n "variables": [\n {{\n "description": "Level of cybersecurity threat",\n "name": "cybersecurity_threat",\n "is_target": true\n }},\n {{\n "description": "Number of failed login attempts",\n "name": "failed_login_attempts"\n }},\n {{\n "description": "Amount of encrypted data",\n "name": "encrypted_data"\n }},\n {{\n "description": "Frequency of software updates",\n "name": "software_updates"\n }},\n {{\n "description": "Number of antivirus software installed",\n "name": "antivirus_software"\n }},\n {{\n "description": "Quality of firewall protection",\n "name": "firewall_quality"\n }}\n ],\n "relationship": {{\n "dependent": "cybersecurity_threat",\n "relation": "-53.5*encrypted_data**2 - 53.85*failed_login_attempts**2 + 67.75*firewall_quality - 92.16 - 36.68/software_updates**3"\n }}\n}}```\nGiven this dataset, the following is a valid answer:\n```json\n{{\n "broad": {{\n "instruction": "Be vague. Only indicate which concepts might be related but not how they are related",\n "hypothesis": "Threat to cybersecurity is influenced by several factors including the amount of encrypted data, the number of failed login attempts, the quality of the firewall, as well as how often the software is updated."\n }},\n "medium": {{\n "instruction": "Be slightly more specific. For each factor, indicate carefully whether it positively or negatively affects the relationship, but do not indicate what the exponent is.",\n "hypothesis": "Cybersecurity threat tends to decrease with the amount of data encryption, the number of failed login attempts, as well as the frequency of software updates to some extent, while improvement in the firewall quality has a positive effect."\n }},\n "narrow": {{\n "instruction": "Be specific. Communicate the concepts, whether there is a positive or negative effect (be careful), and the meaning of the exponent",\n "hypothesis": "The threat to cybersecurity interacts in a complex manner with various factors. As the amount of encrypted data increases, there is a quadratic decrease in threat. Similarly for the number of failed login attempts, there is a negative quadratic relationship. The quality of the firewall protection on the other hand demonstrates a positive and linear relationship. Finally, the frequency of software updates has an inverse cubic relationship to the threat."\n }},\n}}\n```\n\nBased on this, provide an answer for the following test dataset:\n```json\n{dataset}```\nRespond only with a JSON.', }, ] @@ -80,14 +80,14 @@ def OPENAI_GEN_HYPO_MESSAGES(dataset): def create_prompt(usr_msg): return [ { - 'role': 'system', - 'content': 'You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.', + "role": "system", + "content": "You are a helpful assistant who is not talkative. You only respond with the exact answer to a query without additional conversation.", }, - {'role': 'user', 'content': usr_msg}, + {"role": "user", "content": usr_msg}, ] -def get_response(client, prompt, max_retry=5, model='gpt-3.5-turbo', verbose=False): +def get_response(client, prompt, max_retry=5, model="gpt-3.5-turbo", verbose=False): n_try = 0 while n_try < max_retry: response = client.chat.completions.create( @@ -97,26 +97,26 @@ def get_response(client, prompt, max_retry=5, model='gpt-3.5-turbo', verbose=Fal # COMMENT: changed from # response.choices[0].message.content.strip().strip('```json').strip('```') content = response.choices[0].message.content - cleaned_content = content.split('```json')[1].split('```')[0].strip() + cleaned_content = content.split("```json")[1].split("```")[0].strip() output = cleaned_content try: response_json = json.loads(output) return response_json except ValueError: if verbose: - print(f'Bad JSON output:\n\n{output}') + print(f"Bad JSON output:\n\n{output}") n_try += 1 if n_try < max_retry: if verbose: - print('Retrying...') + print("Retrying...") else: if verbose: - print('Retry limit reached') + print("Retry limit reached") return None def get_code_fix( - client, code, error, max_retry=5, model='gpt-3.5-turbo', verbose=False + client, code, error, max_retry=5, model="gpt-3.5-turbo", verbose=False ): prompt = f"""\ Given the following code snippet and error message, provide a single-line fix for the error. \ @@ -141,7 +141,7 @@ def get_code_fix( def get_new_hypothesis( - client, target, old, expr, cols, model='gpt-3.5-turbo', verbose=False + client, target, old, expr, cols, model="gpt-3.5-turbo", verbose=False ): prompt = f"""\ Given a target column from a dataset, a pandas expression to derive the column from existing columns, a list of \ @@ -168,7 +168,7 @@ def get_new_hypothesis( return response -def replace_variable(client, expr, old, new, model='gpt-3.5-turbo', verbose=False): +def replace_variable(client, expr, old, new, model="gpt-3.5-turbo", verbose=False): prompt = f"""\ Given a pandas "expression", replace mentions of the "old" column with its "new" value such that the resultant \ expression is equivalent to the original expression. diff --git a/evaluation/discoverybench/eval_utils/openai_semantic_gen_prompts.py b/evaluation/discoverybench/eval_utils/openai_semantic_gen_prompts.py index a0b5438e4c8a..4af6963e312a 100644 --- a/evaluation/discoverybench/eval_utils/openai_semantic_gen_prompts.py +++ b/evaluation/discoverybench/eval_utils/openai_semantic_gen_prompts.py @@ -1,51 +1,51 @@ common_hypothesis_features = [ - '1-2 sentences', - 'surprising finding', - 'includes numeric concepts', - 'includes categorical concepts', - 'includes binary concepts', + "1-2 sentences", + "surprising finding", + "includes numeric concepts", + "includes categorical concepts", + "includes binary concepts", ] hypothesis_features = [ - ['requires within-cluster analysis'], - ['requires across-cluster analysis'], - ['corresponds to a polynomial relationship of some columns'], - ['corresponds to a ratio between some columns'], - ['requires temporal analysis'], - ['relationship is based on descriptive statistics of some columns'], - ['requires concepts based on percentage or percentiles'], - ['relationship is only applicable to one cluster in the data and not the others'], + ["requires within-cluster analysis"], + ["requires across-cluster analysis"], + ["corresponds to a polynomial relationship of some columns"], + ["corresponds to a ratio between some columns"], + ["requires temporal analysis"], + ["relationship is based on descriptive statistics of some columns"], + ["requires concepts based on percentage or percentiles"], + ["relationship is only applicable to one cluster in the data and not the others"], ] column_features = [ [ - 'must have one target column', - 'must have quantifiable columns', - 'must have a few categorical columns', - 'make sure the categorical column values do not contain special characters', - 'include a few distractor columns', + "must have one target column", + "must have quantifiable columns", + "must have a few categorical columns", + "make sure the categorical column values do not contain special characters", + "include a few distractor columns", ] ] common_pandas_features = [ - 'must be executable using python `eval` to create the target column in variable `df` (pandas dataframe)', + "must be executable using python `eval` to create the target column in variable `df` (pandas dataframe)", "for e.g., df['A']**2 + 3*df['B'] + 9, np.where(df['A'] > 3, 'Yes', 'No'), etc.", - 'variables in pandas_expression must be from the existing columns listed above', - 'variables in pandas_expression must NOT contain the target column itself', + "variables in pandas_expression must be from the existing columns listed above", + "variables in pandas_expression must NOT contain the target column itself", ] pandas_features = [ - ['expression is a quadratic polynomial'], - ['expression is a cubic polynomial'], - ['expression is a ratio of existing columns'], - ['expression is derived through logical combination of existing columns'], + ["expression is a quadratic polynomial"], + ["expression is a cubic polynomial"], + ["expression is a ratio of existing columns"], + ["expression is derived through logical combination of existing columns"], # workflow ] pandas_features = [common_pandas_features + p for p in pandas_features] common_derived_features = [ - '1-2 sentences', - 'includes numeric concepts', - 'includes categorical concepts', - 'includes binary concepts', + "1-2 sentences", + "includes numeric concepts", + "includes categorical concepts", + "includes binary concepts", ] derived_features = [common_derived_features + h for h in hypothesis_features] hypothesis_features = [common_hypothesis_features + h for h in hypothesis_features] diff --git a/evaluation/discoverybench/eval_utils/response_parser.py b/evaluation/discoverybench/eval_utils/response_parser.py index b5de82b5df9e..592b4c054212 100644 --- a/evaluation/discoverybench/eval_utils/response_parser.py +++ b/evaluation/discoverybench/eval_utils/response_parser.py @@ -1,24 +1,24 @@ workflow_summary_markers = [ - 'WORKFLOW SUMMARY', - 'WORKFLOW_SUMMARY', - 'WORKFLOW-SUMMARY', - 'Workflow Summary', + "WORKFLOW SUMMARY", + "WORKFLOW_SUMMARY", + "WORKFLOW-SUMMARY", + "Workflow Summary", ] final_answer_markers = [ - 'FINAL ANSWER', - 'FINAL_ANSWER', - 'FINAL-ANSWER', - 'Final Answer', - 'Scientific Hypothesis', - 'Hypothesis', + "FINAL ANSWER", + "FINAL_ANSWER", + "FINAL-ANSWER", + "Final Answer", + "Scientific Hypothesis", + "Hypothesis", ] next_agent_markers = [ - 'NEXT AGENT', - 'NEXT-AGENT', - 'NEXT_AGENT', - 'FEEDBACK', + "NEXT AGENT", + "NEXT-AGENT", + "NEXT_AGENT", + "FEEDBACK", ] @@ -31,22 +31,22 @@ def extract_between(content, start_markers, end_markers=None): if end_marker in result: result = result.split(end_marker, 1)[0] return result - return '' + return "" def extract_gen_hypo_from_logs(content: str): - error = '' + error = "" gen_workflow = extract_between( content, workflow_summary_markers, final_answer_markers ) if not gen_workflow: - error += 'No Workflow Summary found in the line. | ' + error += "No Workflow Summary found in the line. | " gen_hypothesis = extract_between(content, final_answer_markers, next_agent_markers) if not gen_hypothesis: - error += 'No Final Answer in the line.' + error += "No Final Answer in the line." return gen_hypothesis, gen_workflow, error diff --git a/evaluation/discoverybench/run_infer.py b/evaluation/discoverybench/run_infer.py index 72148a64e759..b0a559721dd6 100644 --- a/evaluation/discoverybench/run_infer.py +++ b/evaluation/discoverybench/run_infer.py @@ -89,8 +89,7 @@ def get_config( def get_dv_query_for_real( datasets, question, domain_knowledge=None, workflow_tags=None ): - """ - Prepare a structured query for the agent to execute on the specified datasets. + """Prepare a structured query for the agent to execute on the specified datasets. This function constructs a query by compiling metadata from the provided datasets, along with any relevant domain knowledge and workflow tags. @@ -104,7 +103,6 @@ def get_dv_query_for_real( query_to_dv: Query to be run on the dataset dataset_meta: Metadata of the dataset """ - dataset_meta = '' for dataset_metadata in datasets: dataset_meta += 'Dataset name: ' + dataset_metadata['name'] @@ -140,8 +138,7 @@ def get_dv_query_for_real( def initialize_runtime(runtime: Runtime, data_files: list[str]): - """ - Initialize the runtime for the agent. + """Initialize the runtime for the agent. This function is called before the runtime is used to run the agent. """ @@ -231,8 +228,7 @@ def process_instance( metadata: EvalMetadata, reset_logger: bool = True, ): - """ - Process and evaluate a single instance of the dataset. + """Process and evaluate a single instance of the dataset. This function executes the OpenHands agent for a specific instance of the dataset. It retrieves @@ -247,7 +243,6 @@ def process_instance( Returns: output: EvalOutput object """ - config = get_config(metadata) # use a session id for concurrent evaluation @@ -359,8 +354,7 @@ def list_csv_files(list_of_datasets): def create_dataset(repo_location: str, split: str = 'test'): - """ - Create a dataset from the discoverybench repository + """Create a dataset from the discoverybench repository by walking through the repository and extracting metadata from the metadata_{}.json files @@ -371,7 +365,6 @@ def create_dataset(repo_location: str, split: str = 'test'): Returns: df: DataFrame containing the dataset instances """ - data_dict = {} data_location = os.path.join(repo_location, 'discoverybench', 'real', split) diff --git a/evaluation/integration_tests/tests/t01_fix_simple_typo.py b/evaluation/integration_tests/tests/t01_fix_simple_typo.py index 4cfa331df1b5..01e9f5ecfc64 100644 --- a/evaluation/integration_tests/tests/t01_fix_simple_typo.py +++ b/evaluation/integration_tests/tests/t01_fix_simple_typo.py @@ -8,32 +8,32 @@ class Test(BaseIntegrationTest): - INSTRUCTION = 'Fix typos in bad.txt.' + INSTRUCTION = "Fix typos in bad.txt." @classmethod def initialize_runtime(cls, runtime: Runtime) -> None: # create a file with a typo in /workspace/bad.txt with tempfile.TemporaryDirectory() as temp_dir: - temp_file_path = os.path.join(temp_dir, 'bad.txt') - with open(temp_file_path, 'w') as f: - f.write('This is a stupid typoo.\nReally?\nNo mor typos!\nEnjoy!') + temp_file_path = os.path.join(temp_dir, "bad.txt") + with open(temp_file_path, "w") as f: + f.write("This is a stupid typoo.\nReally?\nNo mor typos!\nEnjoy!") # Copy the file to the desired location - runtime.copy_to(temp_file_path, '/workspace') + runtime.copy_to(temp_file_path, "/workspace") @classmethod def verify_result(cls, runtime: Runtime, histories: list[Event]) -> TestResult: # check if the file /workspace/bad.txt has been fixed - action = CmdRunAction(command='cat /workspace/bad.txt', keep_prompt=False) + action = CmdRunAction(command="cat /workspace/bad.txt", keep_prompt=False) obs = runtime.run_action(action) if obs.exit_code != 0: return TestResult( - success=False, reason=f'Failed to run command: {obs.content}' + success=False, reason=f"Failed to run command: {obs.content}" ) # check if the file /workspace/bad.txt has been fixed if ( - obs.content.strip().replace('\r\n', '\n') - == 'This is a stupid typo.\nReally?\nNo more typos!\nEnjoy!' + obs.content.strip().replace("\r\n", "\n") + == "This is a stupid typo.\nReally?\nNo more typos!\nEnjoy!" ): return TestResult(success=True) - return TestResult(success=False, reason=f'File not fixed: {obs.content}') + return TestResult(success=False, reason=f"File not fixed: {obs.content}") diff --git a/evaluation/integration_tests/tests/t02_add_bash_hello.py b/evaluation/integration_tests/tests/t02_add_bash_hello.py index ac82e89bac05..1e39a101fb2d 100644 --- a/evaluation/integration_tests/tests/t02_add_bash_hello.py +++ b/evaluation/integration_tests/tests/t02_add_bash_hello.py @@ -10,30 +10,30 @@ class Test(BaseIntegrationTest): @classmethod def initialize_runtime(cls, runtime: Runtime) -> None: - action = CmdRunAction(command='mkdir -p /workspace', keep_prompt=False) + action = CmdRunAction(command="mkdir -p /workspace", keep_prompt=False) obs = runtime.run_action(action) - assert_and_raise(obs.exit_code == 0, f'Failed to run command: {obs.content}') + assert_and_raise(obs.exit_code == 0, f"Failed to run command: {obs.content}") @classmethod def verify_result(cls, runtime: Runtime, histories: list[Event]) -> TestResult: # check if the file /workspace/hello.sh exists - action = CmdRunAction(command='cat /workspace/hello.sh', keep_prompt=False) + action = CmdRunAction(command="cat /workspace/hello.sh", keep_prompt=False) obs = runtime.run_action(action) if obs.exit_code != 0: return TestResult( success=False, - reason=f'Failed to cat /workspace/hello.sh: {obs.content}.', + reason=f"Failed to cat /workspace/hello.sh: {obs.content}.", ) # execute the script - action = CmdRunAction(command='bash /workspace/hello.sh', keep_prompt=False) + action = CmdRunAction(command="bash /workspace/hello.sh", keep_prompt=False) obs = runtime.run_action(action) if obs.exit_code != 0: return TestResult( success=False, - reason=f'Failed to execute /workspace/hello.sh: {obs.content}.', + reason=f"Failed to execute /workspace/hello.sh: {obs.content}.", ) - if obs.content.strip() != 'hello': + if obs.content.strip() != "hello": return TestResult( success=False, reason=f'Script did not print "hello": {obs.content}.' ) diff --git a/evaluation/integration_tests/tests/t03_jupyter_write_file.py b/evaluation/integration_tests/tests/t03_jupyter_write_file.py index e1ed6c27c4a6..c0244cdad816 100644 --- a/evaluation/integration_tests/tests/t03_jupyter_write_file.py +++ b/evaluation/integration_tests/tests/t03_jupyter_write_file.py @@ -10,32 +10,32 @@ class Test(BaseIntegrationTest): @classmethod def initialize_runtime(cls, runtime: Runtime) -> None: - action = CmdRunAction(command='mkdir -p /workspace', keep_prompt=False) + action = CmdRunAction(command="mkdir -p /workspace", keep_prompt=False) obs = runtime.run_action(action) - assert_and_raise(obs.exit_code == 0, f'Failed to run command: {obs.content}') + assert_and_raise(obs.exit_code == 0, f"Failed to run command: {obs.content}") @classmethod def verify_result(cls, runtime: Runtime, histories: list[Event]) -> TestResult: # check if the file /workspace/hello.sh exists - action = CmdRunAction(command='cat /workspace/test.txt', keep_prompt=False) + action = CmdRunAction(command="cat /workspace/test.txt", keep_prompt=False) obs = runtime.run_action(action) if obs.exit_code != 0: return TestResult( success=False, - reason=f'Failed to cat /workspace/test.txt: {obs.content}.', + reason=f"Failed to cat /workspace/test.txt: {obs.content}.", ) # execute the script - action = CmdRunAction(command='cat /workspace/test.txt', keep_prompt=False) + action = CmdRunAction(command="cat /workspace/test.txt", keep_prompt=False) obs = runtime.run_action(action) if obs.exit_code != 0: return TestResult( success=False, - reason=f'Failed to cat /workspace/test.txt: {obs.content}.', + reason=f"Failed to cat /workspace/test.txt: {obs.content}.", ) - if 'hello world' not in obs.content.strip(): + if "hello world" not in obs.content.strip(): return TestResult( success=False, reason=f'File did not contain "hello world": {obs.content}.', diff --git a/evaluation/integration_tests/tests/t04_git_staging.py b/evaluation/integration_tests/tests/t04_git_staging.py index aadb861203e7..fcbdc1e55520 100644 --- a/evaluation/integration_tests/tests/t04_git_staging.py +++ b/evaluation/integration_tests/tests/t04_git_staging.py @@ -6,50 +6,50 @@ class Test(BaseIntegrationTest): - INSTRUCTION = 'Write a git commit message for the current staging area and commit the changes.' + INSTRUCTION = "Write a git commit message for the current staging area and commit the changes." @classmethod def initialize_runtime(cls, runtime: Runtime) -> None: - action = CmdRunAction(command='mkdir -p /workspace', keep_prompt=False) + action = CmdRunAction(command="mkdir -p /workspace", keep_prompt=False) obs = runtime.run_action(action) - assert_and_raise(obs.exit_code == 0, f'Failed to run command: {obs.content}') + assert_and_raise(obs.exit_code == 0, f"Failed to run command: {obs.content}") # git init - action = CmdRunAction(command='git init', keep_prompt=False) + action = CmdRunAction(command="git init", keep_prompt=False) obs = runtime.run_action(action) - assert_and_raise(obs.exit_code == 0, f'Failed to run command: {obs.content}') + assert_and_raise(obs.exit_code == 0, f"Failed to run command: {obs.content}") # create README.md action = CmdRunAction( - command='echo \'print("hello world")\' > hello.py', keep_prompt=False + command="echo 'print(\"hello world\")' > hello.py", keep_prompt=False ) obs = runtime.run_action(action) - assert_and_raise(obs.exit_code == 0, f'Failed to run command: {obs.content}') + assert_and_raise(obs.exit_code == 0, f"Failed to run command: {obs.content}") # git add README.md - action = CmdRunAction(command='git add hello.py', keep_prompt=False) + action = CmdRunAction(command="git add hello.py", keep_prompt=False) obs = runtime.run_action(action) - assert_and_raise(obs.exit_code == 0, f'Failed to run command: {obs.content}') + assert_and_raise(obs.exit_code == 0, f"Failed to run command: {obs.content}") @classmethod def verify_result(cls, runtime: Runtime, histories: list[Event]) -> TestResult: # check if the file /workspace/hello.py exists - action = CmdRunAction(command='cat /workspace/hello.py', keep_prompt=False) + action = CmdRunAction(command="cat /workspace/hello.py", keep_prompt=False) obs = runtime.run_action(action) if obs.exit_code != 0: return TestResult( success=False, - reason=f'Failed to cat /workspace/hello.py: {obs.content}.', + reason=f"Failed to cat /workspace/hello.py: {obs.content}.", ) # check if the staging area is empty - action = CmdRunAction(command='git status', keep_prompt=False) + action = CmdRunAction(command="git status", keep_prompt=False) obs = runtime.run_action(action) if obs.exit_code != 0: return TestResult( - success=False, reason=f'Failed to git status: {obs.content}.' + success=False, reason=f"Failed to git status: {obs.content}." ) - if 'nothing to commit, working tree clean' in obs.content.strip(): + if "nothing to commit, working tree clean" in obs.content.strip(): return TestResult(success=True) return TestResult( diff --git a/evaluation/integration_tests/tests/t05_simple_browsing.py b/evaluation/integration_tests/tests/t05_simple_browsing.py index 8f08cb4e7250..54f863e8bc1a 100644 --- a/evaluation/integration_tests/tests/t05_simple_browsing.py +++ b/evaluation/integration_tests/tests/t05_simple_browsing.py @@ -79,29 +79,29 @@ class Test(BaseIntegrationTest): - INSTRUCTION = 'Browse localhost:8000, and tell me the ultimate answer to life.' + INSTRUCTION = "Browse localhost:8000, and tell me the ultimate answer to life." @classmethod def initialize_runtime(cls, runtime: Runtime) -> None: - action = CmdRunAction(command='mkdir -p /workspace', keep_prompt=False) + action = CmdRunAction(command="mkdir -p /workspace", keep_prompt=False) obs = runtime.run_action(action) - assert_and_raise(obs.exit_code == 0, f'Failed to run command: {obs.content}') + assert_and_raise(obs.exit_code == 0, f"Failed to run command: {obs.content}") - action = CmdRunAction(command='mkdir -p /tmp/server', keep_prompt=False) + action = CmdRunAction(command="mkdir -p /tmp/server", keep_prompt=False) obs = runtime.run_action(action) - assert_and_raise(obs.exit_code == 0, f'Failed to run command: {obs.content}') + assert_and_raise(obs.exit_code == 0, f"Failed to run command: {obs.content}") # create a file with a typo in /workspace/bad.txt with tempfile.TemporaryDirectory() as temp_dir: - temp_file_path = os.path.join(temp_dir, 'index.html') - with open(temp_file_path, 'w') as f: + temp_file_path = os.path.join(temp_dir, "index.html") + with open(temp_file_path, "w") as f: f.write(HTML_FILE) # Copy the file to the desired location - runtime.copy_to(temp_file_path, '/tmp/server') + runtime.copy_to(temp_file_path, "/tmp/server") # create README.md action = CmdRunAction( - command='cd /tmp/server && nohup python3 -m http.server 8000 &', + command="cd /tmp/server && nohup python3 -m http.server 8000 &", keep_prompt=False, ) obs = runtime.run_action(action) @@ -120,15 +120,15 @@ def verify_result(cls, runtime: Runtime, histories: list[Event]) -> TestResult: if isinstance(event, AgentDelegateObservation): content = event.content elif isinstance(event, AgentFinishAction): - content = event.outputs.get('content', '') + content = event.outputs.get("content", "") elif isinstance(event, MessageAction): content = event.content else: - raise ValueError(f'Unknown event type: {type(event)}') + raise ValueError(f"Unknown event type: {type(event)}") - if 'OpenHands is all you need!' in content: + if "OpenHands is all you need!" in content: return TestResult(success=True) return TestResult( success=False, - reason=f'The answer is not found in any message. Total messages: {len(message_actions)}. Messages: {message_actions}', + reason=f"The answer is not found in any message. Total messages: {len(message_actions)}. Messages: {message_actions}", ) diff --git a/evaluation/integration_tests/tests/t06_github_pr_browsing.py b/evaluation/integration_tests/tests/t06_github_pr_browsing.py index 52ec927cd334..4707afdd68c7 100644 --- a/evaluation/integration_tests/tests/t06_github_pr_browsing.py +++ b/evaluation/integration_tests/tests/t06_github_pr_browsing.py @@ -6,7 +6,7 @@ class Test(BaseIntegrationTest): - INSTRUCTION = 'Look at https://github.com/All-Hands-AI/OpenHands/pull/8, and tell me what is happening there and what did @asadm suggest.' + INSTRUCTION = "Look at https://github.com/All-Hands-AI/OpenHands/pull/8, and tell me what is happening there and what did @asadm suggest." @classmethod def initialize_runtime(cls, runtime: Runtime) -> None: @@ -26,19 +26,19 @@ def verify_result(cls, runtime: Runtime, histories: list[Event]) -> TestResult: if isinstance(event, AgentDelegateObservation): content = event.content elif isinstance(event, AgentFinishAction): - content = event.outputs.get('content', '') + content = event.outputs.get("content", "") elif isinstance(event, MessageAction): content = event.content else: - raise ValueError(f'Unknown event type: {type(event)}') + raise ValueError(f"Unknown event type: {type(event)}") if ( - 'non-commercial' in content - or 'MIT' in content - or 'Apache 2.0' in content + "non-commercial" in content + or "MIT" in content + or "Apache 2.0" in content ): return TestResult(success=True) return TestResult( success=False, - reason=f'The answer is not found in any message. Total messages: {len(message_actions)}. Messages: {message_actions}', + reason=f"The answer is not found in any message. Total messages: {len(message_actions)}. Messages: {message_actions}", ) diff --git a/evaluation/mint/prompts/__init__.py b/evaluation/mint/prompts/__init__.py index e07c54e748a6..46a4c795256c 100644 --- a/evaluation/mint/prompts/__init__.py +++ b/evaluation/mint/prompts/__init__.py @@ -3,7 +3,7 @@ from utils import load_file PROMPT_DIR = os.path.dirname(__file__) -TEMPLATE_WITH_TOOL = load_file(os.path.join(PROMPT_DIR, 'template_with_tool.txt')) +TEMPLATE_WITH_TOOL = load_file(os.path.join(PROMPT_DIR, "template_with_tool.txt")) class PromptTemplate: @@ -21,5 +21,5 @@ def __init__(self, use_tool: bool): if use_tool: template = TEMPLATE_WITH_TOOL else: - raise NotImplementedError('Evaluation without tool is not supported yet.') + raise NotImplementedError("Evaluation without tool is not supported yet.") super().__init__(template) diff --git a/evaluation/mint/tasks/__init__.py b/evaluation/mint/tasks/__init__.py index 4f6ac721aca9..373a03fb7606 100644 --- a/evaluation/mint/tasks/__init__.py +++ b/evaluation/mint/tasks/__init__.py @@ -7,10 +7,10 @@ ) __all__ = [ - 'Task', - 'MultipleChoiceTask', - 'ReasoningTask', - 'TheoremqaTask', - 'MBPPTask', - 'HumanEvalTask', + "Task", + "MultipleChoiceTask", + "ReasoningTask", + "TheoremqaTask", + "MBPPTask", + "HumanEvalTask", ] diff --git a/evaluation/mint/tasks/base.py b/evaluation/mint/tasks/base.py index d00f4d17111d..cd46c214be44 100644 --- a/evaluation/mint/tasks/base.py +++ b/evaluation/mint/tasks/base.py @@ -5,34 +5,34 @@ from utils import load_file -LOGGER = logging.getLogger('MINT') +LOGGER = logging.getLogger("MINT") class Task(ABC): """Base class for a task instance.""" - task_name: str = 'base' + task_name: str = "base" in_context_example_dir = os.path.join( os.path.dirname(os.path.abspath(__file__)), - 'in_context_examples', + "in_context_examples", ) def __init__(self, **kwargs) -> None: - if 'loaded_history' in kwargs: - self.loaded_history = kwargs['loaded_history'] + if "loaded_history" in kwargs: + self.loaded_history = kwargs["loaded_history"] else: self.loaded_history = None # pre-load the in-context example task_dir = os.path.join(self.in_context_example_dir, self.task_name) self._in_context_example = { - 'with_tool': load_file(os.path.join(task_dir, 'with_tool.txt')), + "with_tool": load_file(os.path.join(task_dir, "with_tool.txt")), } self.metadata = {} @property def task_id(self) -> str: """Return the task id.""" - assert hasattr(self, '_id'), 'Task does not have an id.' + assert hasattr(self, "_id"), "Task does not have an id." return self._id def in_context_example( @@ -40,20 +40,20 @@ def in_context_example( ) -> str: """Return the in-context example for the task.""" if use_tool and not with_feedback: - return self._in_context_example['with_tool'] + return self._in_context_example["with_tool"] else: raise NotImplementedError @property def prompt(self) -> str: """Return the task prompt.""" - assert hasattr(self, '_prompt'), 'Task does not have a prompt.' + assert hasattr(self, "_prompt"), "Task does not have a prompt." return self._prompt @property def reference(self) -> str: """Return the reference solution for the task.""" - assert hasattr(self, '_reference'), 'Task does not have a reference solution.' + assert hasattr(self, "_reference"), "Task does not have a reference solution." return self._reference @abstractmethod @@ -71,20 +71,20 @@ def success(self, solution: str) -> bool: return answer == self.reference @classmethod - def load_tasks(cls, path: str) -> tuple[list['Task'], int]: + def load_tasks(cls, path: str) -> tuple[list["Task"], int]: """Load all the tasks from a given jsonl file.""" - assert path.endswith('.jsonl') or path.endswith('.json') - with open(path, 'r') as f: + assert path.endswith(".jsonl") or path.endswith(".json") + with open(path, "r") as f: tasks = [cls(**json.loads(line)) for line in f.readlines()] - LOGGER.info(f'Loaded {len(tasks)} tasks from {path}') + LOGGER.info(f"Loaded {len(tasks)} tasks from {path}") return tasks, len(tasks) def to_dict(self) -> dict: """Convert the task to a dictionary.""" return { - 'task_name': self.task_name, - 'task_id': self.task_id, - 'prompt': self.prompt, - 'reference': self.reference, - 'metadata': self.metadata, + "task_name": self.task_name, + "task_id": self.task_id, + "prompt": self.prompt, + "reference": self.reference, + "metadata": self.metadata, } diff --git a/evaluation/mint/tasks/codegen.py b/evaluation/mint/tasks/codegen.py index 8a80594ce4b7..2a9a46fd9ac4 100644 --- a/evaluation/mint/tasks/codegen.py +++ b/evaluation/mint/tasks/codegen.py @@ -4,7 +4,7 @@ from evaluation.mint.tasks.base import Task -LOGGER = logging.getLogger('MINT') +LOGGER = logging.getLogger("MINT") class CodeGenTask(Task): @@ -22,16 +22,16 @@ def success(self, solution: str) -> bool: Can be used to provides binary feedback. """ code_to_exec = self.extract_answer(solution) - LOGGER.debug(f'CODE_TO_EXEC:\n{code_to_exec}') - LOGGER.debug(f'TEST_CODE:\n{self._reference}') + LOGGER.debug(f"CODE_TO_EXEC:\n{code_to_exec}") + LOGGER.debug(f"TEST_CODE:\n{self._reference}") res = check_correctness( solution_code=code_to_exec, test_code=self._reference, timeout=10 ) - return res['success'] + return res["success"] class MBPPTask(CodeGenTask): - task_name = 'mbpp' + task_name = "mbpp" @property def prompt(self) -> str: @@ -39,7 +39,7 @@ def prompt(self) -> str: MBPP prompt contains \"\"\" enclosed at both ends. Need to remove it. """ - return self._prompt.replace('"""', '').strip() + return self._prompt.replace('"""', "").strip() def extract_answer(self, solution: str) -> str | None: """Extract the answer from the given solution. @@ -55,7 +55,7 @@ def extract_answer(self, solution: str) -> str | None: class HumanEvalTask(CodeGenTask): - task_name = 'humaneval' + task_name = "humaneval" @property def prompt(self) -> str: @@ -63,7 +63,7 @@ def prompt(self) -> str: MBPP prompt contains \"\"\" enclosed at both ends. Need to remove it. """ - return 'Complete the following code:\n\n' + self._prompt + return "Complete the following code:\n\n" + self._prompt def extract_answer(self, solution: str) -> str | None: """Extract the answer from the given solution. diff --git a/evaluation/mint/tasks/reasoning.py b/evaluation/mint/tasks/reasoning.py index 08cf320c359f..fd178cda1e0d 100644 --- a/evaluation/mint/tasks/reasoning.py +++ b/evaluation/mint/tasks/reasoning.py @@ -9,11 +9,11 @@ from tasks.base import Task -LOGGER = logging.getLogger('MINT') +LOGGER = logging.getLogger("MINT") class ReasoningTask(Task): - task_name = 'reasoning' + task_name = "reasoning" def __init__(self, id: str, prompt: str, reference: str, **kwargs): super().__init__(**kwargs) @@ -35,7 +35,7 @@ def compare_w_digits(self, reference: str, answer: str) -> bool: except ValueError: return reference in answer except Exception: - raise ValueError(f'Cannot compare {reference} and {answer}') + raise ValueError(f"Cannot compare {reference} and {answer}") def success(self, solution: str) -> bool: answer = self.extract_answer(solution) @@ -45,14 +45,14 @@ def success(self, solution: str) -> bool: class MultipleChoiceTask(Task): """Subclass of Task for multiple choice tasks.""" - task_name = 'reasoning' + task_name = "reasoning" def __init__(self, id, prompt: str, reference: str, **kwargs): super().__init__(**kwargs) self._id = id - self.hide_options = kwargs.get('hide_options', False) + self.hide_options = kwargs.get("hide_options", False) if self.hide_options: - self._prompt = prompt.split('Options:')[0].strip() + self._prompt = prompt.split("Options:")[0].strip() else: self._prompt = prompt self._reference = reference.strip().lower() @@ -64,17 +64,17 @@ def __init__(self, id, prompt: str, reference: str, **kwargs): self.hide_options = True except ValueError: pass - self.metadata.update({'options': self._options}) + self.metadata.update({"options": self._options}) def extract_answer(self, solution: str) -> str | None: # Extract the selected option from the solution solution = solution.lower().strip() - for letter in 'abcdefghijklmnopqrstuvwxyz': - if f'{letter})' in solution or f'{letter} )' in solution: - print('SOLUTION', letter) + for letter in "abcdefghijklmnopqrstuvwxyz": + if f"{letter})" in solution or f"{letter} )" in solution: + print("SOLUTION", letter) return letter else: - print('SOLUTION', solution) + print("SOLUTION", solution) return solution def compare_w_digits(self, reference: str, answer: str) -> bool: @@ -90,8 +90,8 @@ def success(self, solution: str) -> bool: else: correct_option = self._options[self._reference] wrong_option_list = list(self._options.values()) - print('OPTIONS', correct_option, wrong_option_list) - print('ANSWER', answer) + print("OPTIONS", correct_option, wrong_option_list) + print("ANSWER", answer) for i in wrong_option_list: if i in correct_option: wrong_option_list.remove(i) @@ -107,20 +107,20 @@ def success(self, solution: str) -> bool: def extract_options(self, prompt: str) -> dict: # Find the possible option separators (comma, semicolon, or parentheses) - prompt = prompt.split('Options: ')[-1] + prompt = prompt.split("Options: ")[-1] # Extract the options using the delimiter - options_match = prompt.split(' , ') + options_match = prompt.split(" , ") options = {} for i in range(len(options_match)): option = options_match[i].strip("[]' ") - option = option.split(')') + option = option.split(")") letter = option[0].lower().strip() content = ( option[1] .lower() - .strip('.') - .replace('. Which option is correct?', '') - .replace('. Which one is correct?', '') + .strip(".") + .replace(". Which option is correct?", "") + .replace(". Which one is correct?", "") .strip() ) options.update({letter: content}) @@ -172,35 +172,35 @@ def parse_number_list(s: str): def is_number(string): - pattern = r'^[-+]?(\d{1,3}(,\d{3})*|(\d+))(\.\d+)?$' + pattern = r"^[-+]?(\d{1,3}(,\d{3})*|(\d+))(\.\d+)?$" match = re.match(pattern, string) return bool(match) def is_scientific_number(string): - pattern = r'^[-+]?\d+(\.\d+)?e[-]?\d+$' + pattern = r"^[-+]?\d+(\.\d+)?e[-]?\d+$" match = re.match(pattern, string) return bool(match) def contain_num_and_str(string): - pattern_str = r'[a-zA-Z]' - pattern_num = r'[0-9]' + pattern_str = r"[a-zA-Z]" + pattern_num = r"[0-9]" return bool(re.search(pattern_str, string) and re.search(pattern_num, string)) class TheoremqaTask(Task): - task_name = 'reasoning' + task_name = "reasoning" def __init__(self, id: str, prompt: str, reference: str, **kwargs): super().__init__(**kwargs) self._id = id self._prompt = ( - 'Answer the following question with a number, a list of numbers or True or False. ' + "Answer the following question with a number, a list of numbers or True or False. " + prompt.strip() ) self._reference = reference - self._answer_type = kwargs.get('answer_type') + self._answer_type = kwargs.get("answer_type") def extract_answer(self, solution: str) -> Any: """Extract the answer from the given solution.""" @@ -210,107 +210,107 @@ def extract_answer(self, solution: str) -> Any: # Preprocessing the string [Stage 1] if not isinstance(prediction, str): - prediction = str(prediction) if prediction is not None else '0' + prediction = str(prediction) if prediction is not None else "0" # Replace special tokens - if '=' in prediction: - prediction = prediction.split('=')[-1].strip() - if '≈' in prediction: - prediction = prediction.split('≈')[-1].strip() - if '`' in prediction: - prediction = prediction.replace('`', '') - if '$' in prediction: - prediction = prediction.replace('$', '') - if '°' in prediction: - prediction = prediction.replace('°', '') + if "=" in prediction: + prediction = prediction.split("=")[-1].strip() + if "≈" in prediction: + prediction = prediction.split("≈")[-1].strip() + if "`" in prediction: + prediction = prediction.replace("`", "") + if "$" in prediction: + prediction = prediction.replace("$", "") + if "°" in prediction: + prediction = prediction.replace("°", "") # Detect the boolean keyword in the generation - if prediction in ('true', 'yes', 'false', 'no'): - if prediction in ('true', 'yes'): - prediction = 'True' + if prediction in ("true", "yes", "false", "no"): + if prediction in ("true", "yes"): + prediction = "True" else: - prediction = 'False' - if 'True' in prediction or 'False' in prediction: - prediction = 'True' if 'True' in prediction else 'False' + prediction = "False" + if "True" in prediction or "False" in prediction: + prediction = "True" if "True" in prediction else "False" # Detect the approximation keyword - if 'approximately' in prediction: - prediction = prediction.replace('approximately', '').strip() - if ' or ' in prediction: - prediction = prediction.split(' or ')[0] + if "approximately" in prediction: + prediction = prediction.replace("approximately", "").strip() + if " or " in prediction: + prediction = prediction.split(" or ")[0] # Drop the units before and after the number - if re.match(r'[-+]?(?:[\d,]*\.*\d+) [^0-9 ]+$', prediction): + if re.match(r"[-+]?(?:[\d,]*\.*\d+) [^0-9 ]+$", prediction): prediction = re.search( - r'([-+]?(?:[\d,]*\.*\d+)) [^0-9 ]+$', prediction + r"([-+]?(?:[\d,]*\.*\d+)) [^0-9 ]+$", prediction ).group(1) - if re.match(r'[^0-9 ]+ [-+]?(?:[\d,]*\.*\d+)$', prediction): + if re.match(r"[^0-9 ]+ [-+]?(?:[\d,]*\.*\d+)$", prediction): prediction = re.search( - r'[^0-9 ]+ ([-+]?(?:[\d,]*\.*\d+))$', prediction + r"[^0-9 ]+ ([-+]?(?:[\d,]*\.*\d+))$", prediction ).group(1) - if re.match(r'[-+]?(?:[\d,]*\.*\d+)[^\d]{1,2}$', prediction): + if re.match(r"[-+]?(?:[\d,]*\.*\d+)[^\d]{1,2}$", prediction): prediction = re.search( - r'([-+]?(?:[\d,]*\.*\d+))[^\d]{1,2}$', prediction + r"([-+]?(?:[\d,]*\.*\d+))[^\d]{1,2}$", prediction ).group(1) - if re.match(r'[^-+\d]{1,2}(?:[\d,]*\.*\d+)$', prediction): + if re.match(r"[^-+\d]{1,2}(?:[\d,]*\.*\d+)$", prediction): prediction = re.search( - r'[^-+\d]{1,2}((?:[\d,]*\.*\d+))$', prediction + r"[^-+\d]{1,2}((?:[\d,]*\.*\d+))$", prediction ).group(1) # Preprocessing the number [Stage 1] - if '10^' in prediction: - prediction = re.sub(r'10\^(-?\d+)', r'math.pow(10, \1)', prediction) - if ' x ' in prediction: - prediction = prediction.replace(' x ', '*') - if ' × ' in prediction: - prediction = prediction.replace(' × ', '*') + if "10^" in prediction: + prediction = re.sub(r"10\^(-?\d+)", r"math.pow(10, \1)", prediction) + if " x " in prediction: + prediction = prediction.replace(" x ", "*") + if " × " in prediction: + prediction = prediction.replace(" × ", "*") if is_number(prediction): - prediction = prediction.replace(',', '') + prediction = prediction.replace(",", "") # Preprocessing the option [Stage 3] if ( - 'a)' in prediction - or 'a )' in prediction - or prediction.lower().strip() == 'a' + "a)" in prediction + or "a )" in prediction + or prediction.lower().strip() == "a" ): - prediction = '(a)' + prediction = "(a)" if ( - 'b)' in prediction - or 'b )' in prediction - or prediction.lower().strip() == 'b' + "b)" in prediction + or "b )" in prediction + or prediction.lower().strip() == "b" ): - prediction = '(b)' + prediction = "(b)" if ( - 'c)' in prediction - or 'c )' in prediction - or prediction.lower().strip() == 'c' + "c)" in prediction + or "c )" in prediction + or prediction.lower().strip() == "c" ): - prediction = '(c)' + prediction = "(c)" if ( - 'd)' in prediction - or 'd )' in prediction - or prediction.lower().strip() == 'd' + "d)" in prediction + or "d )" in prediction + or prediction.lower().strip() == "d" ): - prediction = '(d)' + prediction = "(d)" if ( - '(a)' in prediction - or '(b)' in prediction - or '(c)' in prediction - or '(d)' in prediction + "(a)" in prediction + or "(b)" in prediction + or "(c)" in prediction + or "(d)" in prediction ): - prediction = '"' + re.search(r'\([a-d]\)', prediction).group(0) + '"' + prediction = '"' + re.search(r"\([a-d]\)", prediction).group(0) + '"' # If the prediction is empty, use dummy '0' if not prediction: - prediction = '0' + prediction = "0" # Converting the string answer to a number/list/bool/option try: prediction = eval(prediction) except Exception: LOGGER.warning( - f'[TASK] Failed to convert the answer: {prediction}\n{traceback.format_exc()}' + f"[TASK] Failed to convert the answer: {prediction}\n{traceback.format_exc()}" ) return None # failed to convert the answer @@ -336,19 +336,19 @@ def success(self, solution: str) -> bool: # Follow the implementation from TheoremQA # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L301C9-L317C1 prediction = self.extract_answer(solution) - LOGGER.info(f'TheoremQA Parsed Prediction: {prediction}') + LOGGER.info(f"TheoremQA Parsed Prediction: {prediction}") answer_type = self._answer_type gt = self.extract_answer(self.reference) if isinstance(prediction, (str, int, float, list)): # Comparing prediction against the reference - if answer_type in ['bool', 'option', 'Option']: - cur_correct = int(prediction == f'({gt})') or int(prediction == gt) - elif answer_type == 'integer': + if answer_type in ["bool", "option", "Option"]: + cur_correct = int(prediction == f"({gt})") or int(prediction == gt) + elif answer_type == "integer": cur_correct = int(compare_two_numbers(prediction, gt)) - elif answer_type == 'float': + elif answer_type == "float": cur_correct = int(compare_two_numbers(prediction, gt)) - elif answer_type in ['list of integer', 'list of float']: + elif answer_type in ["list of integer", "list of float"]: cur_correct = int(compare_two_list(prediction, gt)) else: cur_correct = 0 diff --git a/evaluation/ml_bench/scripts/summarise_results.py b/evaluation/ml_bench/scripts/summarise_results.py index fbc82293e45f..61f2889f265e 100644 --- a/evaluation/ml_bench/scripts/summarise_results.py +++ b/evaluation/ml_bench/scripts/summarise_results.py @@ -9,62 +9,62 @@ def extract_test_results(res_file_path: str) -> tuple[list[str], list[str]]: costs = [] instance_ids = set() instances = [] - with open(res_file_path, 'r') as file: + with open(res_file_path, "r") as file: for line in file: data = json.loads(line.strip()) - success = data['metrics']['success'] - if data['instance_id'] in instance_ids: + success = data["metrics"]["success"] + if data["instance_id"] in instance_ids: print(f'WARNING: Duplicate instance_id found: {data["instance_id"]}') continue - instance_ids.add(data['instance_id']) + instance_ids.add(data["instance_id"]) instances.append(data) if success: passed.append( { - 'instance_id': data['instance_id'], - 'repo': data['repo'], - 'instruction': data['instruction'], - 'eval_script': data['eval_script'], - 'eval_exit_code': data['eval_exit_code'], - 'eval_output': data['eval_output'], - 'accumulated_cost': data['metrics']['accumulated_cost'], + "instance_id": data["instance_id"], + "repo": data["repo"], + "instruction": data["instruction"], + "eval_script": data["eval_script"], + "eval_exit_code": data["eval_exit_code"], + "eval_output": data["eval_output"], + "accumulated_cost": data["metrics"]["accumulated_cost"], } ) else: failed.append( { - 'instance_id': data['instance_id'], - 'repo': data['repo'], - 'instruction': data['instruction'], - 'eval_script': data['eval_script'], - 'eval_exit_code': data['eval_exit_code'], - 'eval_output': data['eval_output'], - 'accumulated_cost': data['metrics']['accumulated_cost'], + "instance_id": data["instance_id"], + "repo": data["repo"], + "instruction": data["instruction"], + "eval_script": data["eval_script"], + "eval_exit_code": data["eval_exit_code"], + "eval_output": data["eval_output"], + "accumulated_cost": data["metrics"]["accumulated_cost"], } ) - costs.append(data['metrics']['accumulated_cost']) + costs.append(data["metrics"]["accumulated_cost"]) # sort by instance_id - instances.sort(key=lambda x: x['instance_id']) - with open(res_file_path, 'w') as file: + instances.sort(key=lambda x: x["instance_id"]) + with open(res_file_path, "w") as file: for instance in instances: - file.write(json.dumps(instance) + '\n') + file.write(json.dumps(instance) + "\n") return passed, failed, costs -if __name__ == '__main__': +if __name__ == "__main__": if len(sys.argv) != 2: print( - 'Usage: poetry run python summarise_results.py ' + "Usage: poetry run python summarise_results.py " ) sys.exit(1) json_file_path = sys.argv[1] passed_tests, failed_tests, costs = extract_test_results(json_file_path) success_rate = len(passed_tests) / (len(passed_tests) + len(failed_tests)) - print('PASSED TESTS:') + print("PASSED TESTS:") pprint.pprint(passed_tests) - print('FAILED TESTS:') + print("FAILED TESTS:") pprint.pprint(failed_tests) print( - f'\nPassed {len(passed_tests)} tests, failed {len(failed_tests)} tests, success rate = {success_rate}, average cost = {sum(costs) / len(costs)}' + f"\nPassed {len(passed_tests)} tests, failed {len(failed_tests)} tests, success rate = {success_rate}, average cost = {sum(costs) / len(costs)}" ) diff --git a/evaluation/regression/cases/hello-world/test_hello_world.py b/evaluation/regression/cases/hello-world/test_hello_world.py index 6b4b808c4eda..2bed2ad612c3 100644 --- a/evaluation/regression/cases/hello-world/test_hello_world.py +++ b/evaluation/regression/cases/hello-world/test_hello_world.py @@ -4,17 +4,17 @@ from conftest import agents -@pytest.mark.parametrize('agent', agents()) +@pytest.mark.parametrize("agent", agents()) def test_hello_world(task_file, run_test_case, agent): """Test case for the "Hello, World!" Bash script using different agents.""" # Run the test case for the specified agent - workspace_dir = run_test_case(agent, 'hello-world') + workspace_dir = run_test_case(agent, "hello-world") # Validate the generated workspace assert os.path.exists(workspace_dir) - assert os.path.isfile(os.path.join(workspace_dir, 'hello_world.sh')) + assert os.path.isfile(os.path.join(workspace_dir, "hello_world.sh")) # Execute the hello_world.sh script os.chdir(workspace_dir) - output = os.popen('bash hello_world.sh').read() - assert output == 'Hello, World!\n' + output = os.popen("bash hello_world.sh").read() + assert output == "Hello, World!\n" diff --git a/evaluation/regression/cases/node-cli-rewrite/start/commands/scramble.py b/evaluation/regression/cases/node-cli-rewrite/start/commands/scramble.py index 7470813dac82..29d34f8ef7f3 100644 --- a/evaluation/regression/cases/node-cli-rewrite/start/commands/scramble.py +++ b/evaluation/regression/cases/node-cli-rewrite/start/commands/scramble.py @@ -4,4 +4,4 @@ def scramble_string(s): s_list = list(s) random.shuffle(s_list) - return ''.join(s_list) + return "".join(s_list) diff --git a/evaluation/regression/cases/node-cli-rewrite/start/commands/spongebob.py b/evaluation/regression/cases/node-cli-rewrite/start/commands/spongebob.py index 782af450e16f..bd8d2ce644e5 100644 --- a/evaluation/regression/cases/node-cli-rewrite/start/commands/spongebob.py +++ b/evaluation/regression/cases/node-cli-rewrite/start/commands/spongebob.py @@ -1,5 +1,5 @@ def spongebob_case(s): - result = '' + result = "" for i, char in enumerate(s): if i % 2 == 0: result += char.lower() diff --git a/evaluation/regression/cases/node-cli-rewrite/start/string_cli.py b/evaluation/regression/cases/node-cli-rewrite/start/string_cli.py index 678455130571..4f0a6e41f0a6 100644 --- a/evaluation/regression/cases/node-cli-rewrite/start/string_cli.py +++ b/evaluation/regression/cases/node-cli-rewrite/start/string_cli.py @@ -16,40 +16,40 @@ def print_help(): print(help_text) -if __name__ == '__main__': - if len(sys.argv) == 2 and sys.argv[1] == '--help': +if __name__ == "__main__": + if len(sys.argv) == 2 and sys.argv[1] == "--help": print_help() sys.exit(0) elif len(sys.argv) < 3: - print('Usage: python string_cli.py ') + print("Usage: python string_cli.py ") sys.exit(1) command = sys.argv[1] input_string = sys.argv[2] - if command == 'reverse': + if command == "reverse": from commands.reverse import reverse_string print(reverse_string(input_string)) - elif command == 'uppercase': + elif command == "uppercase": from commands.uppercase import to_uppercase print(to_uppercase(input_string)) - elif command == 'lowercase': + elif command == "lowercase": from commands.lowercase import to_lowercase print(to_lowercase(input_string)) - elif command == 'spongebob': + elif command == "spongebob": from commands.spongebob import spongebob_case print(spongebob_case(input_string)) - elif command == 'length': + elif command == "length": from commands.length import string_length print(string_length(input_string)) - elif command == 'scramble': + elif command == "scramble": from commands.scramble import scramble_string print(scramble_string(input_string)) else: - print('Invalid command!') + print("Invalid command!") diff --git a/evaluation/regression/cases/python-cli-help/start/commands/scramble.py b/evaluation/regression/cases/python-cli-help/start/commands/scramble.py index 7470813dac82..29d34f8ef7f3 100644 --- a/evaluation/regression/cases/python-cli-help/start/commands/scramble.py +++ b/evaluation/regression/cases/python-cli-help/start/commands/scramble.py @@ -4,4 +4,4 @@ def scramble_string(s): s_list = list(s) random.shuffle(s_list) - return ''.join(s_list) + return "".join(s_list) diff --git a/evaluation/regression/cases/python-cli-help/start/commands/spongebob.py b/evaluation/regression/cases/python-cli-help/start/commands/spongebob.py index 782af450e16f..bd8d2ce644e5 100644 --- a/evaluation/regression/cases/python-cli-help/start/commands/spongebob.py +++ b/evaluation/regression/cases/python-cli-help/start/commands/spongebob.py @@ -1,5 +1,5 @@ def spongebob_case(s): - result = '' + result = "" for i, char in enumerate(s): if i % 2 == 0: result += char.lower() diff --git a/evaluation/regression/cases/python-cli-help/start/string_cli.py b/evaluation/regression/cases/python-cli-help/start/string_cli.py index 2deb02b0a670..060a1125e06c 100644 --- a/evaluation/regression/cases/python-cli-help/start/string_cli.py +++ b/evaluation/regression/cases/python-cli-help/start/string_cli.py @@ -1,36 +1,36 @@ import sys -if __name__ == '__main__': +if __name__ == "__main__": if len(sys.argv) < 3: - print('Usage: python string_cli.py ') + print("Usage: python string_cli.py ") sys.exit(1) command = sys.argv[1] input_string = sys.argv[2] - if command == 'reverse': + if command == "reverse": from commands.reverse import reverse_string print(reverse_string(input_string)) - elif command == 'uppercase': + elif command == "uppercase": from commands.uppercase import to_uppercase print(to_uppercase(input_string)) - elif command == 'lowercase': + elif command == "lowercase": from commands.lowercase import to_lowercase print(to_lowercase(input_string)) - elif command == 'spongebob': + elif command == "spongebob": from commands.spongebob import spongebob_case print(spongebob_case(input_string)) - elif command == 'length': + elif command == "length": from commands.length import string_length print(string_length(input_string)) - elif command == 'scramble': + elif command == "scramble": from commands.scramble import scramble_string print(scramble_string(input_string)) else: - print('Invalid command!') + print("Invalid command!") diff --git a/evaluation/regression/cases/server-test/start/server.py b/evaluation/regression/cases/server-test/start/server.py index 71a8d84c946b..a8cf7618330c 100644 --- a/evaluation/regression/cases/server-test/start/server.py +++ b/evaluation/regression/cases/server-test/start/server.py @@ -4,18 +4,18 @@ class HelloWorldHandler(BaseHTTPRequestHandler): def do_GET(self): self.send_response(200) - self.send_header('Content-type', 'text/plain') + self.send_header("Content-type", "text/plain") self.end_headers() - self.wfile.write(b'Hello World\n') + self.wfile.write(b"Hello World\n") def run(server_class=HTTPServer, handler_class=HelloWorldHandler, port=8000): - server_address = ('', port) + server_address = ("", port) httpd = server_class(server_address, handler_class) - print(f'Starting httpd on port {port}...') + print(f"Starting httpd on port {port}...") httpd.serve_forever() -if __name__ == '__main__': - print('starting server...') +if __name__ == "__main__": + print("starting server...") run() diff --git a/evaluation/swe_bench/eval_infer.py b/evaluation/swe_bench/eval_infer.py index 81eadeb33f10..4a13c70eda7c 100644 --- a/evaluation/swe_bench/eval_infer.py +++ b/evaluation/swe_bench/eval_infer.py @@ -99,8 +99,7 @@ def process_instance( reset_logger: bool = True, log_dir: str | None = None, ) -> EvalOutput: - """ - Evaluate agent performance on a SWE-bench problem instance. + """Evaluate agent performance on a SWE-bench problem instance. Note that this signature differs from the expected input to `run_evaluation`. Use `functools.partial` to provide optional arguments before passing to the evaluation harness. diff --git a/evaluation/swe_bench/scripts/docker/push_docker_instance_images.py b/evaluation/swe_bench/scripts/docker/push_docker_instance_images.py index 20fb1b94c0b6..931cadfe251a 100644 --- a/evaluation/swe_bench/scripts/docker/push_docker_instance_images.py +++ b/evaluation/swe_bench/scripts/docker/push_docker_instance_images.py @@ -31,49 +31,49 @@ from openhands.core.logger import openhands_logger as logger -logger.setLevel('ERROR') +logger.setLevel("ERROR") from evaluation.swe_bench.run_infer import get_instance_docker_image # noqa parser = argparse.ArgumentParser() -parser.add_argument('--dataset', type=str, default='princeton-nlp/SWE-bench_Lite') -parser.add_argument('--split', type=str, default='test') +parser.add_argument("--dataset", type=str, default="princeton-nlp/SWE-bench_Lite") +parser.add_argument("--split", type=str, default="test") args = parser.parse_args() dataset = load_dataset(args.dataset, split=args.split) client = docker.from_env() pbar = tqdm(total=len(dataset)) -counter = {'success': 0, 'failed': 0} +counter = {"success": 0, "failed": 0} failed_instances = [] for instance in dataset: - instance_id = instance['instance_id'] - image_name = f'sweb.eval.x86_64.{instance_id}' + instance_id = instance["instance_id"] + image_name = f"sweb.eval.x86_64.{instance_id}" target_image_name = get_instance_docker_image(instance_id) - print('-' * 100) + print("-" * 100) # check if image exists try: image: docker.models.images.Image = client.images.get(image_name) image.tag(target_image_name) - print(f'Image {image_name} -- tagging to --> {target_image_name}') + print(f"Image {image_name} -- tagging to --> {target_image_name}") ret_push = client.images.push(target_image_name) if isinstance(ret_push, str): print(ret_push) else: for line in ret_push: print(line) - print(f'Image {image_name} -- pushed to --> {target_image_name}') - counter['success'] += 1 + print(f"Image {image_name} -- pushed to --> {target_image_name}") + counter["success"] += 1 except docker.errors.ImageNotFound: - print(f'ERROR: Image {image_name} does not exist') - counter['failed'] += 1 + print(f"ERROR: Image {image_name} does not exist") + counter["failed"] += 1 failed_instances.append(instance_id) finally: pbar.update(1) pbar.set_postfix(counter) print(f'Success: {counter["success"]}, Failed: {counter["failed"]}') -print('Failed instances IDs:') +print("Failed instances IDs:") for failed_instance in failed_instances: print(failed_instance) diff --git a/evaluation/swe_bench/scripts/eval/compare_outputs.py b/evaluation/swe_bench/scripts/eval/compare_outputs.py index 2b4b8a40a850..f2af60ab17b6 100755 --- a/evaluation/swe_bench/scripts/eval/compare_outputs.py +++ b/evaluation/swe_bench/scripts/eval/compare_outputs.py @@ -4,18 +4,18 @@ import pandas as pd parser = argparse.ArgumentParser( - description='Compare two swe_bench output JSONL files and print the resolved diff' + description="Compare two swe_bench output JSONL files and print the resolved diff" ) -parser.add_argument('input_file_1', type=str) -parser.add_argument('input_file_2', type=str) +parser.add_argument("input_file_1", type=str) +parser.add_argument("input_file_2", type=str) args = parser.parse_args() -df1 = pd.read_json(args.input_file_1, orient='records', lines=True) -df2 = pd.read_json(args.input_file_2, orient='records', lines=True) +df1 = pd.read_json(args.input_file_1, orient="records", lines=True) +df2 = pd.read_json(args.input_file_2, orient="records", lines=True) # Get the intersection of the instance_ids -df = pd.merge(df1, df2, on='instance_id', how='inner') +df = pd.merge(df1, df2, on="instance_id", how="inner") def _get_resolved(report): @@ -24,44 +24,44 @@ def _get_resolved(report): if isinstance(report, float): return False else: - return report.get('resolved', False) + return report.get("resolved", False) -df['resolved_x'] = df['report_x'].apply(_get_resolved) -df['resolved_y'] = df['report_y'].apply(_get_resolved) -df['diff'] = df.apply(lambda x: x['resolved_x'] != x['resolved_y'], axis=1) +df["resolved_x"] = df["report_x"].apply(_get_resolved) +df["resolved_y"] = df["report_y"].apply(_get_resolved) +df["diff"] = df.apply(lambda x: x["resolved_x"] != x["resolved_y"], axis=1) -df_diff = df[df['diff']].sort_values( - by=['resolved_x', 'resolved_y'], ascending=[False, False] +df_diff = df[df["diff"]].sort_values( + by=["resolved_x", "resolved_y"], ascending=[False, False] ) # skip if any of the resolved is nan, which means one of the eval is not finished yet -df_diff = df_diff[df_diff['resolved_x'].notna() & df_diff['resolved_y'].notna()] +df_diff = df_diff[df_diff["resolved_x"].notna() & df_diff["resolved_y"].notna()] -print(f'X={args.input_file_1}') -print(f'Y={args.input_file_2}') -print(f'# diff={df_diff.shape[0]}') -df_diff = df_diff[['instance_id', 'resolved_x', 'resolved_y', 'report_x', 'report_y']] +print(f"X={args.input_file_1}") +print(f"Y={args.input_file_2}") +print(f"# diff={df_diff.shape[0]}") +df_diff = df_diff[["instance_id", "resolved_x", "resolved_y", "report_x", "report_y"]] # x resolved but y not -print('-' * 100) -df_diff_x_only = df_diff[df_diff['resolved_x'] & ~df_diff['resolved_y']].sort_values( - by='instance_id' +print("-" * 100) +df_diff_x_only = df_diff[df_diff["resolved_x"] & ~df_diff["resolved_y"]].sort_values( + by="instance_id" ) -print(f'# x resolved but y not={df_diff_x_only.shape[0]}') -print(df_diff_x_only[['instance_id', 'report_x', 'report_y']]) +print(f"# x resolved but y not={df_diff_x_only.shape[0]}") +print(df_diff_x_only[["instance_id", "report_x", "report_y"]]) # y resolved but x not -print('-' * 100) -df_diff_y_only = df_diff[~df_diff['resolved_x'] & df_diff['resolved_y']].sort_values( - by='instance_id' +print("-" * 100) +df_diff_y_only = df_diff[~df_diff["resolved_x"] & df_diff["resolved_y"]].sort_values( + by="instance_id" ) -print(f'# y resolved but x not={df_diff_y_only.shape[0]}') -print(df_diff_y_only[['instance_id', 'report_x', 'report_y']]) +print(f"# y resolved but x not={df_diff_y_only.shape[0]}") +print(df_diff_y_only[["instance_id", "report_x", "report_y"]]) # get instance_id from df_diff_y_only -print('-' * 100) -print('Instances that x resolved but y not:') -print(df_diff_x_only['instance_id'].tolist()) +print("-" * 100) +print("Instances that x resolved but y not:") +print(df_diff_x_only["instance_id"].tolist()) -print('-' * 100) -print('Instances that y resolved but x not:') -print(df_diff_y_only['instance_id'].tolist()) +print("-" * 100) +print("Instances that y resolved but x not:") +print(df_diff_y_only["instance_id"].tolist()) diff --git a/evaluation/swe_bench/scripts/eval/convert_oh_output_to_md.py b/evaluation/swe_bench/scripts/eval/convert_oh_output_to_md.py index 17a375ee3b79..5a52c2fa50d4 100755 --- a/evaluation/swe_bench/scripts/eval/convert_oh_output_to_md.py +++ b/evaluation/swe_bench/scripts/eval/convert_oh_output_to_md.py @@ -14,19 +14,19 @@ tqdm.pandas() parser = argparse.ArgumentParser() -parser.add_argument('oh_output_file', type=str) +parser.add_argument("oh_output_file", type=str) args = parser.parse_args() -output_md_folder = args.oh_output_file.replace('.jsonl', '.viz') -print(f'Converting {args.oh_output_file} to markdown files in {output_md_folder}') +output_md_folder = args.oh_output_file.replace(".jsonl", ".viz") +print(f"Converting {args.oh_output_file} to markdown files in {output_md_folder}") -oh_format = pd.read_json(args.oh_output_file, orient='records', lines=True) +oh_format = pd.read_json(args.oh_output_file, orient="records", lines=True) # model name is the folder name of oh_output_file model_name = os.path.basename(os.path.dirname(args.oh_output_file)) def convert_history_to_str(history): - ret = '' - separator = '\n\n' + '-' * 100 + '\n' + ret = "" + separator = "\n\n" + "-" * 100 + "\n" for i, event in enumerate(history): if i != 0: @@ -35,54 +35,54 @@ def convert_history_to_str(history): if isinstance(event, list): # "event" is a legacy pair of (action, observation) event_obj = event_from_dict(event[0]) - ret += f'## {i+1}| {event_obj.__class__.__name__}\n\n' + ret += f"## {i+1}| {event_obj.__class__.__name__}\n\n" ret += str(event_obj) ret += separator event_obj = event_from_dict(event[1]) - ret += f'## {i+1}| {event_obj.__class__.__name__}\n\n' + ret += f"## {i+1}| {event_obj.__class__.__name__}\n\n" ret += str(event_obj) else: # "event" is a single event event_obj = event_from_dict(event) - ret += f'## {i+1}| {event_obj.__class__.__name__}\n\n' + ret += f"## {i+1}| {event_obj.__class__.__name__}\n\n" ret += str(event_obj) return ret def write_row_to_md_file(row): - if 'git_patch' in row: - model_patch = row['git_patch'] - elif 'test_result' in row and 'git_patch' in row['test_result']: - model_patch = row['test_result']['git_patch'] + if "git_patch" in row: + model_patch = row["git_patch"] + elif "test_result" in row and "git_patch" in row["test_result"]: + model_patch = row["test_result"]["git_patch"] else: - raise ValueError(f'Row {row} does not have a git_patch') + raise ValueError(f"Row {row} does not have a git_patch") - if 'report' in row: - resolved = row['report'].get('resolved', False) + if "report" in row: + resolved = row["report"].get("resolved", False) else: resolved = None - instance_id = row['instance_id'] - filename = f'{str(resolved).lower()}.{instance_id}.md' + instance_id = row["instance_id"] + filename = f"{str(resolved).lower()}.{instance_id}.md" os.makedirs(output_md_folder, exist_ok=True) filepath = os.path.join(output_md_folder, filename) - with open(filepath, 'w') as f: - f.write(f'# {instance_id} (resolved: {resolved})\n') + with open(filepath, "w") as f: + f.write(f"# {instance_id} (resolved: {resolved})\n") # MetaData - f.write('## MetaData\n') - f.write('```json\n') - f.write(json.dumps(row['metadata'], indent=2)) - f.write('\n```\n') + f.write("## MetaData\n") + f.write("```json\n") + f.write(json.dumps(row["metadata"], indent=2)) + f.write("\n```\n") # Trajectory - f.write('## History\n') - f.write(convert_history_to_str(row['history'])) + f.write("## History\n") + f.write(convert_history_to_str(row["history"])) - f.write('## Model Patch\n') - f.write(f'{process_git_patch(model_patch)}\n') + f.write("## Model Patch\n") + f.write(f"{process_git_patch(model_patch)}\n") oh_format.progress_apply(write_row_to_md_file, axis=1) diff --git a/evaluation/swe_bench/scripts/eval/convert_oh_output_to_swe_json.py b/evaluation/swe_bench/scripts/eval/convert_oh_output_to_swe_json.py index 5006d3dde357..a0368d60036d 100644 --- a/evaluation/swe_bench/scripts/eval/convert_oh_output_to_swe_json.py +++ b/evaluation/swe_bench/scripts/eval/convert_oh_output_to_swe_json.py @@ -6,30 +6,30 @@ from evaluation.swe_bench.eval_infer import process_git_patch parser = argparse.ArgumentParser() -parser.add_argument('oh_output_file', type=str) +parser.add_argument("oh_output_file", type=str) args = parser.parse_args() -output_filepath = args.oh_output_file.replace('.jsonl', '.swebench.jsonl') -print(f'Converting {args.oh_output_file} to {output_filepath}') +output_filepath = args.oh_output_file.replace(".jsonl", ".swebench.jsonl") +print(f"Converting {args.oh_output_file} to {output_filepath}") -oh_format = pd.read_json(args.oh_output_file, orient='records', lines=True) +oh_format = pd.read_json(args.oh_output_file, orient="records", lines=True) # model name is the folder name of oh_output_file model_name = os.path.basename(os.path.dirname(args.oh_output_file)) def convert_row_to_swebench_format(row): - if 'git_patch' in row: - model_patch = row['git_patch'] - elif 'test_result' in row and 'git_patch' in row['test_result']: - model_patch = row['test_result']['git_patch'] + if "git_patch" in row: + model_patch = row["git_patch"] + elif "test_result" in row and "git_patch" in row["test_result"]: + model_patch = row["test_result"]["git_patch"] else: - raise ValueError(f'Row {row} does not have a git_patch') + raise ValueError(f"Row {row} does not have a git_patch") return { - 'instance_id': row['instance_id'], - 'model_patch': process_git_patch(model_patch), - 'model_name_or_path': model_name, + "instance_id": row["instance_id"], + "model_patch": process_git_patch(model_patch), + "model_name_or_path": model_name, } swebench_format = oh_format.apply(convert_row_to_swebench_format, axis=1) -swebench_format.to_json(output_filepath, lines=True, orient='records') +swebench_format.to_json(output_filepath, lines=True, orient="records") diff --git a/evaluation/swe_bench/scripts/eval/download_gold_patch.py b/evaluation/swe_bench/scripts/eval/download_gold_patch.py index 480df4cf9772..790f6c6a61ff 100644 --- a/evaluation/swe_bench/scripts/eval/download_gold_patch.py +++ b/evaluation/swe_bench/scripts/eval/download_gold_patch.py @@ -4,24 +4,24 @@ from datasets import load_dataset parser = argparse.ArgumentParser() -parser.add_argument('output_filepath', type=str, help='Path to save the output file') +parser.add_argument("output_filepath", type=str, help="Path to save the output file") parser.add_argument( - '--dataset_name', + "--dataset_name", type=str, - help='Name of the dataset to download', - default='princeton-nlp/SWE-bench_Lite', + help="Name of the dataset to download", + default="princeton-nlp/SWE-bench_Lite", ) -parser.add_argument('--split', type=str, help='Split to download', default='test') +parser.add_argument("--split", type=str, help="Split to download", default="test") args = parser.parse_args() dataset = load_dataset(args.dataset_name, split=args.split) output_filepath = args.output_filepath print( - f'Downloading gold patches from {args.dataset_name} (split: {args.split}) to {output_filepath}' + f"Downloading gold patches from {args.dataset_name} (split: {args.split}) to {output_filepath}" ) patches = [ - {'instance_id': row['instance_id'], 'model_patch': row['patch']} for row in dataset + {"instance_id": row["instance_id"], "model_patch": row["patch"]} for row in dataset ] -print(f'{len(patches)} gold patches loaded') -pd.DataFrame(patches).to_json(output_filepath, lines=True, orient='records') -print(f'Patches saved to {output_filepath}') +print(f"{len(patches)} gold patches loaded") +pd.DataFrame(patches).to_json(output_filepath, lines=True, orient="records") +print(f"Patches saved to {output_filepath}") diff --git a/evaluation/swe_bench/scripts/eval/summarize_outputs.py b/evaluation/swe_bench/scripts/eval/summarize_outputs.py index 5d5dbbf2a3bd..c15eb9eb7390 100755 --- a/evaluation/swe_bench/scripts/eval/summarize_outputs.py +++ b/evaluation/swe_bench/scripts/eval/summarize_outputs.py @@ -7,17 +7,17 @@ from openhands.events.utils import get_pairs_from_events ERROR_KEYWORDS = [ - 'Agent encountered an error while processing the last action', - 'APIError', - 'Action execution failed', + "Agent encountered an error while processing the last action", + "APIError", + "Action execution failed", ] -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('output_file', type=str, help='The file to summarize') + parser.add_argument("output_file", type=str, help="The file to summarize") args = parser.parse_args() - with open(args.output_file, 'r') as file: + with open(args.output_file, "r") as file: lines = file.readlines() num_lines = len(lines) @@ -37,7 +37,7 @@ _d = json.loads(line) # Cost - costs = _d['metrics'].get('costs', []) + costs = _d["metrics"].get("costs", []) _cur_main_agent_cost = 0 _cur_editor_cost = 0 for cost in costs: @@ -45,39 +45,39 @@ # backward compatible _cur_main_agent_cost += cost else: - if 'draft_editor' in cost['model']: - _cur_editor_cost += cost['cost'] + if "draft_editor" in cost["model"]: + _cur_editor_cost += cost["cost"] else: - _cur_main_agent_cost += cost['cost'] + _cur_main_agent_cost += cost["cost"] main_agent_cost.append(_cur_main_agent_cost) editor_cost.append(_cur_editor_cost) # Turn status - history = _d.get('history', []) + history = _d.get("history", []) events = [event_from_dict(event) for event in history] pairs = get_pairs_from_events(events) num_turns.append(len(pairs)) # Patch & resolve status - patch = _d.get('test_result', {}).get('git_patch', '') - if patch == '': + patch = _d.get("test_result", {}).get("git_patch", "") + if patch == "": num_empty_patch += 1 continue - report = _d.get('report', {}) or {} - resolved = report.get('resolved', False) + report = _d.get("report", {}) or {} + resolved = report.get("resolved", False) if resolved: num_resolved += 1 # Error - error = _d.get('error', None) + error = _d.get("error", None) if error is not None and isinstance(error, str): - agent_stuck_in_loop = 'Agent got stuck in a loop' in error + agent_stuck_in_loop = "Agent got stuck in a loop" in error contains_error = bool(error) and not agent_stuck_in_loop if agent_stuck_in_loop: - error_counter['Agent got stuck in a loop'] += 1 + error_counter["Agent got stuck in a loop"] += 1 num_agent_stuck_in_loop += 1 elif contains_error: error_counter[error] += 1 @@ -91,28 +91,28 @@ # print the error counter (with percentage) print( - f'Number of resolved: {num_resolved} / {num_lines} ({num_resolved / num_lines * 100:.2f}%)' + f"Number of resolved: {num_resolved} / {num_lines} ({num_resolved / num_lines * 100:.2f}%)" ) print( - f'Number of empty patch: {num_empty_patch} / {num_lines} ({num_empty_patch / num_lines * 100:.2f}%)' + f"Number of empty patch: {num_empty_patch} / {num_lines} ({num_empty_patch / num_lines * 100:.2f}%)" ) print( - f'Number of error lines: {num_error_lines} / {num_lines} ({num_error_lines / num_lines * 100:.2f}%)' + f"Number of error lines: {num_error_lines} / {num_lines} ({num_error_lines / num_lines * 100:.2f}%)" ) print( - f'Number of agent stuck in loop: {num_agent_stuck_in_loop} / {num_lines} ({num_agent_stuck_in_loop / num_lines * 100:.2f}%)' + f"Number of agent stuck in loop: {num_agent_stuck_in_loop} / {num_lines} ({num_agent_stuck_in_loop / num_lines * 100:.2f}%)" ) assert len(num_turns) == num_lines assert len(main_agent_cost) == num_lines assert len(editor_cost) == num_lines - print('## Statistics') - print(f'Avg. num of turns per instance: {sum(num_turns) / num_lines:.2f}') - print(f'Avg. agent cost per instance: {sum(main_agent_cost) / num_lines:.2f} USD') - print(f'Avg. editor cost per instance: {sum(editor_cost) / num_lines:.2f} USD') + print("## Statistics") + print(f"Avg. num of turns per instance: {sum(num_turns) / num_lines:.2f}") + print(f"Avg. agent cost per instance: {sum(main_agent_cost) / num_lines:.2f} USD") + print(f"Avg. editor cost per instance: {sum(editor_cost) / num_lines:.2f} USD") print( - f'Avg. total cost per instance: {(sum(main_agent_cost) + sum(editor_cost)) / num_lines:.2f} USD' + f"Avg. total cost per instance: {(sum(main_agent_cost) + sum(editor_cost)) / num_lines:.2f} USD" ) - print('## Detailed error breakdown:') + print("## Detailed error breakdown:") for error, count in error_counter.items(): - print(f'{error}: {count} ({count / num_lines * 100:.2f}%)') + print(f"{error}: {count} ({count / num_lines * 100:.2f}%)") diff --git a/evaluation/swe_bench/scripts/eval/update_output_with_eval.py b/evaluation/swe_bench/scripts/eval/update_output_with_eval.py index 662e640ca752..50cb8fbe1a4f 100644 --- a/evaluation/swe_bench/scripts/eval/update_output_with_eval.py +++ b/evaluation/swe_bench/scripts/eval/update_output_with_eval.py @@ -6,7 +6,7 @@ import pandas as pd parser = argparse.ArgumentParser() -parser.add_argument('input_file', type=str) +parser.add_argument("input_file", type=str) args = parser.parse_args() dirname = os.path.dirname(args.input_file) @@ -15,31 +15,31 @@ instance_id_to_status = defaultdict( lambda: { - 'empty_generation': False, - 'resolved': False, - 'failed_apply_patch': False, - 'error_eval': False, - 'test_timeout': False, + "empty_generation": False, + "resolved": False, + "failed_apply_patch": False, + "error_eval": False, + "test_timeout": False, } ) # Apply the status to the dataframe def apply_report(row): - instance_id = row['instance_id'] + instance_id = row["instance_id"] if instance_id in instance_id_to_status: return dict(instance_id_to_status[instance_id]) - return row.get('report', {}) + return row.get("report", {}) -swebench_official_report_json = os.path.join(dirname, 'report.json') +swebench_official_report_json = os.path.join(dirname, "report.json") openhands_remote_report_jsonl = args.input_file.replace( - '.jsonl', '.swebench_eval.jsonl' + ".jsonl", ".swebench_eval.jsonl" ) if os.path.exists(swebench_official_report_json): - output_md_filepath = os.path.join(dirname, 'README.md') - with open(swebench_official_report_json, 'r') as f: + output_md_filepath = os.path.join(dirname, "README.md") + with open(swebench_official_report_json, "r") as f: report = json.load(f) output_md = ( @@ -56,77 +56,77 @@ def apply_report(row): f"- unstopped instances: {report['unstopped_instances']}\n" ) - output_md += '\n## Resolved Instances\n' + output_md += "\n## Resolved Instances\n" # instance_id to status - for instance_id in report['resolved_ids']: - instance_id_to_status[instance_id]['resolved'] = True + for instance_id in report["resolved_ids"]: + instance_id_to_status[instance_id]["resolved"] = True output_md += ( - f'- [{instance_id}](./eval_outputs/{instance_id}/run_instance.log)\n' + f"- [{instance_id}](./eval_outputs/{instance_id}/run_instance.log)\n" ) - output_md += '\n## Unresolved Instances\n' - for instance_id in report['unresolved_ids']: + output_md += "\n## Unresolved Instances\n" + for instance_id in report["unresolved_ids"]: output_md += ( - f'- [{instance_id}](./eval_outputs/{instance_id}/run_instance.log)\n' + f"- [{instance_id}](./eval_outputs/{instance_id}/run_instance.log)\n" ) - output_md += '\n## Error Instances\n' - for instance_id in report['error_ids']: - instance_id_to_status[instance_id]['error_eval'] = True + output_md += "\n## Error Instances\n" + for instance_id in report["error_ids"]: + instance_id_to_status[instance_id]["error_eval"] = True output_md += ( - f'- [{instance_id}](./eval_outputs/{instance_id}/run_instance.log)\n' + f"- [{instance_id}](./eval_outputs/{instance_id}/run_instance.log)\n" ) - output_md += '\n## Empty Patch Instances\n' - for instance_id in report['empty_patch_ids']: - instance_id_to_status[instance_id]['empty_generation'] = True + output_md += "\n## Empty Patch Instances\n" + for instance_id in report["empty_patch_ids"]: + instance_id_to_status[instance_id]["empty_generation"] = True output_md += ( - f'- [{instance_id}](./eval_outputs/{instance_id}/run_instance.log)\n' + f"- [{instance_id}](./eval_outputs/{instance_id}/run_instance.log)\n" ) - output_md += '\n## Incomplete Instances\n' - for instance_id in report['incomplete_ids']: + output_md += "\n## Incomplete Instances\n" + for instance_id in report["incomplete_ids"]: output_md += ( - f'- [{instance_id}](./eval_outputs/{instance_id}/run_instance.log)\n' + f"- [{instance_id}](./eval_outputs/{instance_id}/run_instance.log)\n" ) - df['report'] = df.apply(apply_report, axis=1) + df["report"] = df.apply(apply_report, axis=1) - with open(output_md_filepath, 'w') as f: + with open(output_md_filepath, "w") as f: f.write(output_md) elif os.path.exists(openhands_remote_report_jsonl): - output_md_filepath = args.input_file.replace('.jsonl', '.swebench_eval.md') + output_md_filepath = args.input_file.replace(".jsonl", ".swebench_eval.md") - df_eval = pd.read_json(openhands_remote_report_jsonl, lines=True, orient='records') + df_eval = pd.read_json(openhands_remote_report_jsonl, lines=True, orient="records") - assert len(df['instance_id'].unique()) == len( + assert len(df["instance_id"].unique()) == len( df - ), 'There are duplicate instance ids in the original output which is not allowed' - assert len(df_eval['instance_id'].unique()) == len( + ), "There are duplicate instance ids in the original output which is not allowed" + assert len(df_eval["instance_id"].unique()) == len( df_eval - ), 'There are duplicate instance ids in the eval report which is not allowed' + ), "There are duplicate instance ids in the eval report which is not allowed" for _, row in df_eval.iterrows(): - instance_id_to_status[row['instance_id']] = row['test_result']['report'] - df['report'] = df.apply(apply_report, axis=1) + instance_id_to_status[row["instance_id"]] = row["test_result"]["report"] + df["report"] = df.apply(apply_report, axis=1) _n_instances = len(df) - _n_resolved = len(df[df['report'].apply(lambda x: x.get('resolved', False))]) + _n_resolved = len(df[df["report"].apply(lambda x: x.get("resolved", False))]) _n_unresolved = _n_instances - _n_resolved _n_empty_patch = len( - df[df['report'].apply(lambda x: x.get('empty_generation', False))] + df[df["report"].apply(lambda x: x.get("empty_generation", False))] ) - _n_error = len(df[df['report'].apply(lambda x: x.get('error_eval', False))]) + _n_error = len(df[df["report"].apply(lambda x: x.get("error_eval", False))]) output_md = ( - '# SWE-bench Report\n' - 'This folder contains the evaluation results of the SWE-bench using the [official evaluation docker containerization](https://github.com/princeton-nlp/SWE-bench/blob/main/docs/20240627_docker/README.md#choosing-the-right-cache_level).\n\n' - '## Summary\n' - f'- submitted instances: {_n_instances}\n' - f'- empty patch instances: {_n_empty_patch}\n' - f'- resolved instances: {_n_resolved}\n' - f'- unresolved instances: {_n_unresolved}\n' - f'- error instances: {_n_error}\n' + "# SWE-bench Report\n" + "This folder contains the evaluation results of the SWE-bench using the [official evaluation docker containerization](https://github.com/princeton-nlp/SWE-bench/blob/main/docs/20240627_docker/README.md#choosing-the-right-cache_level).\n\n" + "## Summary\n" + f"- submitted instances: {_n_instances}\n" + f"- empty patch instances: {_n_empty_patch}\n" + f"- resolved instances: {_n_resolved}\n" + f"- unresolved instances: {_n_unresolved}\n" + f"- error instances: {_n_error}\n" ) def _instance_id_to_log_path(instance_id): @@ -135,63 +135,63 @@ def _instance_id_to_log_path(instance_id): path = os.path.relpath(path, start=dirname) return path - output_md += '\n## Resolved Instances\n' + output_md += "\n## Resolved Instances\n" # instance_id to status for instance_id in sorted( - df[df['report'].apply(lambda x: x.get('resolved', False))][ - 'instance_id' + df[df["report"].apply(lambda x: x.get("resolved", False))][ + "instance_id" ].unique() ): - instance_id_to_status[instance_id]['resolved'] = True - output_md += f'- [{instance_id}]({_instance_id_to_log_path(instance_id)})\n' + instance_id_to_status[instance_id]["resolved"] = True + output_md += f"- [{instance_id}]({_instance_id_to_log_path(instance_id)})\n" - output_md += '\n## Unresolved Instances\n' + output_md += "\n## Unresolved Instances\n" for instance_id in sorted( - df[~df['report'].apply(lambda x: x.get('resolved', False))][ - 'instance_id' + df[~df["report"].apply(lambda x: x.get("resolved", False))][ + "instance_id" ].unique() ): - output_md += f'- [{instance_id}]({_instance_id_to_log_path(instance_id)})\n' + output_md += f"- [{instance_id}]({_instance_id_to_log_path(instance_id)})\n" - output_md += '\n## Error Instances\n' + output_md += "\n## Error Instances\n" for instance_id in sorted( - df[df['report'].apply(lambda x: x.get('error_eval', False))][ - 'instance_id' + df[df["report"].apply(lambda x: x.get("error_eval", False))][ + "instance_id" ].unique() ): - instance_id_to_status[instance_id]['error_eval'] = True - output_md += f'- [{instance_id}]({_instance_id_to_log_path(instance_id)})\n' + instance_id_to_status[instance_id]["error_eval"] = True + output_md += f"- [{instance_id}]({_instance_id_to_log_path(instance_id)})\n" - output_md += '\n## Empty Patch Instances\n' + output_md += "\n## Empty Patch Instances\n" for instance_id in sorted( - df[df['report'].apply(lambda x: x.get('empty_generation', False))][ - 'instance_id' + df[df["report"].apply(lambda x: x.get("empty_generation", False))][ + "instance_id" ].unique() ): - instance_id_to_status[instance_id]['empty_generation'] = True - output_md += f'- [{instance_id}]({_instance_id_to_log_path(instance_id)})\n' + instance_id_to_status[instance_id]["empty_generation"] = True + output_md += f"- [{instance_id}]({_instance_id_to_log_path(instance_id)})\n" - output_md += '\n## Incomplete Instances\n' + output_md += "\n## Incomplete Instances\n" for instance_id in sorted( - df[df['report'].apply(lambda x: x.get('test_timeout', False))][ - 'instance_id' + df[df["report"].apply(lambda x: x.get("test_timeout", False))][ + "instance_id" ].unique() ): - output_md += f'- [{instance_id}]({_instance_id_to_log_path(instance_id)})\n' - with open(output_md_filepath, 'w') as f: + output_md += f"- [{instance_id}]({_instance_id_to_log_path(instance_id)})\n" + with open(output_md_filepath, "w") as f: f.write(output_md) else: print( - f'No report file found: Both {swebench_official_report_json} and {openhands_remote_report_jsonl} do not exist.' + f"No report file found: Both {swebench_official_report_json} and {openhands_remote_report_jsonl} do not exist." ) exit() -if os.path.exists(args.input_file + '.bak'): - conf = input('Existing backup file found. Do you want to overwrite it? (y/n)') - if conf != 'y': +if os.path.exists(args.input_file + ".bak"): + conf = input("Existing backup file found. Do you want to overwrite it? (y/n)") + if conf != "y": exit() - os.remove(args.input_file + '.bak') + os.remove(args.input_file + ".bak") # backup the original file -os.rename(args.input_file, args.input_file + '.bak') -df.to_json(args.input_file, orient='records', lines=True) +os.rename(args.input_file, args.input_file + ".bak") +df.to_json(args.input_file, orient="records", lines=True) diff --git a/evaluation/swe_bench/scripts/setup/compare_patch_filename.py b/evaluation/swe_bench/scripts/setup/compare_patch_filename.py index 3f77119f55d4..e333450612b2 100755 --- a/evaluation/swe_bench/scripts/setup/compare_patch_filename.py +++ b/evaluation/swe_bench/scripts/setup/compare_patch_filename.py @@ -9,9 +9,9 @@ def extract_modified_files(patch): modified_files = set() - file_pattern = re.compile(r'^diff --git a/(.*?) b/') + file_pattern = re.compile(r"^diff --git a/(.*?) b/") - for line in patch.split('\n'): + for line in patch.split("\n"): match = file_pattern.match(line) if match: modified_files.add(match.group(1)) @@ -24,9 +24,9 @@ def process_report(oh_output_file): fail = 0 for line in open(oh_output_file): line = json.loads(line) - instance_id = line['instance_id'] - gold_patch = line['swe_instance']['patch'] - generated_patch = line['git_patch'] + instance_id = line["instance_id"] + gold_patch = line["swe_instance"]["patch"] + generated_patch = line["git_patch"] gold_modified_files = extract_modified_files(gold_patch) # swe-bench lite only: a gold patch always contains exactly one file assert len(gold_modified_files) == 1 @@ -39,16 +39,16 @@ def process_report(oh_output_file): else: fail += 1 print( - f'{instance_id}: file mismatch, gold = {gold_modified_files}, generated = {generated_modified_files}' + f"{instance_id}: file mismatch, gold = {gold_modified_files}, generated = {generated_modified_files}" ) print( - f'\nSUMMARY: {succ} out of {succ + fail} instances found correct files to edit, success rate = {succ / float(succ + fail)}' + f"\nSUMMARY: {succ} out of {succ + fail} instances found correct files to edit, success rate = {succ / float(succ + fail)}" ) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--oh_output_file', help='Path to the OH output file') + parser.add_argument("--oh_output_file", help="Path to the OH output file") args = parser.parse_args() process_report(args.oh_output_file) diff --git a/openhands/agenthub/browsing_agent/__init__.py b/openhands/agenthub/browsing_agent/__init__.py index 436d69d135f9..fe90dc828226 100644 --- a/openhands/agenthub/browsing_agent/__init__.py +++ b/openhands/agenthub/browsing_agent/__init__.py @@ -1,4 +1,4 @@ from openhands.agenthub.browsing_agent.browsing_agent import BrowsingAgent from openhands.controller.agent import Agent -Agent.register('BrowsingAgent', BrowsingAgent) +Agent.register("BrowsingAgent", BrowsingAgent) diff --git a/openhands/agenthub/browsing_agent/browsing_agent.py b/openhands/agenthub/browsing_agent/browsing_agent.py index 822677bab526..dec7753d1519 100644 --- a/openhands/agenthub/browsing_agent/browsing_agent.py +++ b/openhands/agenthub/browsing_agent/browsing_agent.py @@ -24,10 +24,10 @@ ) USE_NAV = ( - os.environ.get('USE_NAV', 'true') == 'true' + os.environ.get("USE_NAV", "true") == "true" ) # only disable NAV actions when running webarena and miniwob benchmarks USE_CONCISE_ANSWER = ( - os.environ.get('USE_CONCISE_ANSWER', 'false') == 'true' + os.environ.get("USE_CONCISE_ANSWER", "false") == "true" ) # only return concise answer when running webarena and miniwob benchmarks if not USE_NAV and USE_CONCISE_ANSWER: @@ -37,7 +37,7 @@ def get_error_prefix(last_browser_action: str) -> str: - return f'IMPORTANT! Last action is incorrect:\n{last_browser_action}\nThink again with the current observation of the page.\n' + return f"IMPORTANT! Last action is incorrect:\n{last_browser_action}\nThink again with the current observation of the page.\n" def get_system_message(goal: str, action_space: str) -> str: @@ -92,7 +92,7 @@ def get_prompt( class BrowsingAgent(Agent): - VERSION = '1.0' + VERSION = "1.0" """ An agent that interacts with the browser. """ @@ -113,9 +113,9 @@ def __init__( super().__init__(llm, config) # define a configurable action space, with chat functionality, web navigation, and webpage grounding using accessibility tree and HTML. # see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/highlevel.py for more details - action_subsets = ['chat', 'bid'] + action_subsets = ["chat", "bid"] if USE_NAV: - action_subsets.append('nav') + action_subsets.append("nav") self.action_space = HighLevelActionSet( subsets=action_subsets, strict=False, # less strict on the parsing of the actions @@ -144,9 +144,9 @@ def step(self, state: State) -> Action: """ messages: list[Message] = [] prev_actions = [] - cur_url = '' - cur_axtree_txt = '' - error_prefix = '' + cur_url = "" + cur_axtree_txt = "" + error_prefix = "" last_obs = None last_action = None @@ -154,7 +154,7 @@ def step(self, state: State) -> Action: # for webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env # initialize and retrieve the first observation by issuing an noop OP # For non-benchmark browsing, the browser env starts with a blank page, and the agent is expected to first navigate to desired websites - return BrowseInteractiveAction(browser_actions='noop()') + return BrowseInteractiveAction(browser_actions="noop()") for event in state.history: if isinstance(event, BrowseInteractiveAction): @@ -162,14 +162,14 @@ def step(self, state: State) -> Action: last_action = event elif isinstance(event, MessageAction) and event.source == EventSource.AGENT: # agent has responded, task finished. - return AgentFinishAction(outputs={'content': event.content}) + return AgentFinishAction(outputs={"content": event.content}) elif isinstance(event, Observation): last_obs = event if EVAL_MODE: prev_actions = prev_actions[1:] # remove the first noop action - prev_action_str = '\n'.join(prev_actions) + prev_action_str = "\n".join(prev_actions) # if the final BrowserInteractiveAction exec BrowserGym's send_msg_to_user, # we should also send a message back to the user in OpenHands and call it a day if ( @@ -184,7 +184,7 @@ def step(self, state: State) -> Action: error_prefix = get_error_prefix(last_obs.last_browser_action) self.error_accumulator += 1 if self.error_accumulator > 5: - return MessageAction('Too many errors encountered. Task failed.') + return MessageAction("Too many errors encountered. Task failed.") cur_url = last_obs.url @@ -197,27 +197,27 @@ def step(self, state: State) -> Action: ) except Exception as e: logger.error( - 'Error when trying to process the accessibility tree: %s', e + "Error when trying to process the accessibility tree: %s", e ) - return MessageAction('Error encountered when browsing.') + return MessageAction("Error encountered when browsing.") goal, _ = state.get_current_user_intent() if goal is None: - goal = state.inputs['task'] + goal = state.inputs["task"] system_msg = get_system_message( goal, self.action_space.describe(with_long_description=False, with_examples=True), ) - messages.append(Message(role='system', content=[TextContent(text=system_msg)])) + messages.append(Message(role="system", content=[TextContent(text=system_msg)])) prompt = get_prompt(error_prefix, cur_url, cur_axtree_txt, prev_action_str) - messages.append(Message(role='user', content=[TextContent(text=prompt)])) + messages.append(Message(role="user", content=[TextContent(text=prompt)])) response = self.llm.completion( messages=self.llm.format_messages_for_llm(messages), - stop=[')```', ')\n```'], + stop=[")```", ")\n```"], ) return self.response_parser.parse(response) diff --git a/openhands/agenthub/browsing_agent/prompt.py b/openhands/agenthub/browsing_agent/prompt.py index 354156841912..0bd6883bf873 100644 --- a/openhands/agenthub/browsing_agent/prompt.py +++ b/openhands/agenthub/browsing_agent/prompt.py @@ -31,22 +31,22 @@ class Flags: use_action_history: bool = False use_memory: bool = False use_diff: bool = False - html_type: str = 'pruned_html' + html_type: str = "pruned_html" use_concrete_example: bool = True use_abstract_example: bool = False multi_actions: bool = False action_space: Literal[ - 'python', 'bid', 'coord', 'bid+coord', 'bid+nav', 'coord+nav', 'bid+coord+nav' - ] = 'bid' + "python", "bid", "coord", "bid+coord", "bid+nav", "coord+nav", "bid+coord+nav" + ] = "bid" is_strict: bool = False # This flag will be automatically disabled `if not chat_model_args.has_vision()` use_screenshot: bool = True enable_chat: bool = False max_prompt_tokens: int = 100_000 extract_visible_tag: bool = False - extract_coords: Literal['False', 'center', 'box'] = 'False' + extract_coords: Literal["False", "center", "box"] = "False" extract_visible_elements_only: bool = False - demo_mode: Literal['off', 'default', 'only_visible_elements'] = 'off' + demo_mode: Literal["off", "default", "only_visible_elements"] = "off" def copy(self): return deepcopy(self) @@ -63,7 +63,7 @@ def from_dict(self, flags_dict): if not isinstance(flags_dict, dict): raise ValueError( - f'Unregcognized type for flags_dict of type {type(flags_dict)}.' + f"Unregcognized type for flags_dict of type {type(flags_dict)}." ) return Flags(**flags_dict) @@ -77,9 +77,9 @@ class PromptElement: attributes or @property decorator. """ - _prompt = '' - _abstract_ex = '' - _concrete_ex = '' + _prompt = "" + _abstract_ex = "" + _concrete_ex = "" def __init__(self, visible: bool = True) -> None: """Prompt element that can be hidden. @@ -131,7 +131,7 @@ def _hide(self, value): if self.is_visible: return value else: - return '' + return "" def _parse_answer(self, text_answer) -> dict: if self.is_visible: @@ -174,9 +174,9 @@ def shrink(self) -> None: lines = self._prompt.splitlines() new_line_count = int(len(lines) * (1 - self.shrink_speed)) self.deleted_lines += len(lines) - new_line_count - self._prompt = '\n'.join(lines[:new_line_count]) + self._prompt = "\n".join(lines[:new_line_count]) self._prompt += ( - f'\n... Deleted {self.deleted_lines} lines to reduce prompt size.' + f"\n... Deleted {self.deleted_lines} lines to reduce prompt size." ) self.shrink_calls += 1 @@ -212,9 +212,9 @@ def fit_tokens( if isinstance(prompt, str): prompt_str = prompt elif isinstance(prompt, list): - prompt_str = '\n'.join([p['text'] for p in prompt if p['type'] == 'text']) + prompt_str = "\n".join([p["text"] for p in prompt if p["type"] == "text"]) else: - raise ValueError(f'Unrecognized type for prompt: {type(prompt)}') + raise ValueError(f"Unrecognized type for prompt: {type(prompt)}") n_chars = len(prompt_str) if n_chars <= max_prompt_chars: return prompt @@ -231,33 +231,33 @@ def fit_tokens( class HTML(Truncater): - def __init__(self, html, visible: bool = True, prefix='') -> None: + def __init__(self, html, visible: bool = True, prefix="") -> None: super().__init__(visible=visible, start_truncate_iteration=5) - self._prompt = f'\n{prefix}HTML:\n{html}\n' + self._prompt = f"\n{prefix}HTML:\n{html}\n" class AXTree(Truncater): def __init__( - self, ax_tree, visible: bool = True, coord_type=None, prefix='' + self, ax_tree, visible: bool = True, coord_type=None, prefix="" ) -> None: super().__init__(visible=visible, start_truncate_iteration=10) - if coord_type == 'center': + if coord_type == "center": coord_note = """\ Note: center coordinates are provided in parenthesis and are relative to the top left corner of the page.\n\n""" - elif coord_type == 'box': + elif coord_type == "box": coord_note = """\ Note: bounding box of each object are provided in parenthesis and are relative to the top left corner of the page.\n\n""" else: - coord_note = '' - self._prompt = f'\n{prefix}AXTree:\n{coord_note}{ax_tree}\n' + coord_note = "" + self._prompt = f"\n{prefix}AXTree:\n{coord_note}{ax_tree}\n" class Error(PromptElement): - def __init__(self, error, visible: bool = True, prefix='') -> None: + def __init__(self, error, visible: bool = True, prefix="") -> None: super().__init__(visible=visible) - self._prompt = f'\n{prefix}Error from previous action:\n{error}\n' + self._prompt = f"\n{prefix}Error from previous action:\n{error}\n" class Observation(Shrinkable): @@ -270,17 +270,17 @@ def __init__(self, obs, flags: Flags) -> None: super().__init__() self.flags = flags self.obs = obs - self.html = HTML(obs[flags.html_type], visible=flags.use_html, prefix='## ') + self.html = HTML(obs[flags.html_type], visible=flags.use_html, prefix="## ") self.ax_tree = AXTree( - obs['axtree_txt'], + obs["axtree_txt"], visible=flags.use_ax_tree, coord_type=flags.extract_coords, - prefix='## ', + prefix="## ", ) self.error = Error( - obs['last_action_error'], - visible=flags.use_error_logs and obs['last_action_error'], - prefix='## ', + obs["last_action_error"], + visible=flags.use_error_logs and obs["last_action_error"], + prefix="## ", ) def shrink(self): @@ -289,24 +289,24 @@ def shrink(self): @property def _prompt(self) -> str: # type: ignore - return f'\n# Observation of current step:\n{self.html.prompt}{self.ax_tree.prompt}{self.error.prompt}\n\n' + return f"\n# Observation of current step:\n{self.html.prompt}{self.ax_tree.prompt}{self.error.prompt}\n\n" def add_screenshot(self, prompt): if self.flags.use_screenshot: if isinstance(prompt, str): - prompt = [{'type': 'text', 'text': prompt}] + prompt = [{"type": "text", "text": prompt}] img_url = BrowserEnv.image_to_jpg_base64_url( - self.obs['screenshot'], add_data_prefix=True + self.obs["screenshot"], add_data_prefix=True ) - prompt.append({'type': 'image_url', 'image_url': img_url}) + prompt.append({"type": "image_url", "image_url": img_url}) return prompt class MacNote(PromptElement): def __init__(self) -> None: - super().__init__(visible=platform.system() == 'Darwin') - self._prompt = '\nNote: you are on mac so you should use Meta instead of Control for Control+C etc.\n' + super().__init__(visible=platform.system() == "Darwin") + self._prompt = "\nNote: you are on mac so you should use Meta instead of Control for Control+C etc.\n" class BeCautious(PromptElement): @@ -351,7 +351,7 @@ def __init__(self, chat_messages, visible: bool = True) -> None: ## Chat messages: """ - self._prompt += '\n'.join( + self._prompt += "\n".join( [ f"""\ - [{msg['role']}], {msg['message']}""" @@ -381,20 +381,20 @@ def __init__( self.history = History(obs_history, actions, memories, thoughts, flags) if self.flags.enable_chat: self.instructions: Union[ChatInstructions, GoalInstructions] = ( - ChatInstructions(obs_history[-1]['chat_messages']) + ChatInstructions(obs_history[-1]["chat_messages"]) ) else: if ( - 'chat_messages' in obs_history[-1] + "chat_messages" in obs_history[-1] and sum( - [msg['role'] == 'user' for msg in obs_history[-1]['chat_messages']] + [msg["role"] == "user" for msg in obs_history[-1]["chat_messages"]] ) > 1 ): logging.warning( - 'Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`.' + "Agent is in goal mode, but multiple user messages are present in the chat. Consider switching to `enable_chat=True`." ) - self.instructions = GoalInstructions(obs_history[-1]['goal']) + self.instructions = GoalInstructions(obs_history[-1]["goal"]) self.obs = Observation(obs_history[-1], self.flags) self.action_space = ActionSpace(self.flags) @@ -456,7 +456,7 @@ def __init__(self, flags: Flags) -> None: self.action_space = _get_action_space(flags) self._prompt = ( - f'# Action space:\n{self.action_space.describe()}{MacNote().prompt}\n' + f"# Action space:\n{self.action_space.describe()}{MacNote().prompt}\n" ) self._abstract_ex = f""" @@ -471,17 +471,17 @@ def __init__(self, flags: Flags) -> None: def _parse_answer(self, text_answer): ans_dict = parse_html_tags_raise( - text_answer, keys=['action'], merge_multiple=True + text_answer, keys=["action"], merge_multiple=True ) try: # just check if action can be mapped to python code but keep action as is # the environment will be responsible for mapping it to python - self.action_space.to_python_code(ans_dict['action']) + self.action_space.to_python_code(ans_dict["action"]) except Exception as e: raise ParseError( - f'Error while parsing action\n: {e}\n' - 'Make sure your answer is restricted to the allowed actions.' + f"Error while parsing action\n: {e}\n" + "Make sure your answer is restricted to the allowed actions." ) return ans_dict @@ -489,34 +489,34 @@ def _parse_answer(self, text_answer): def _get_action_space(flags: Flags) -> AbstractActionSet: match flags.action_space: - case 'python': + case "python": action_space = PythonActionSet(strict=flags.is_strict) if flags.multi_actions: warn( - f'Flag action_space={repr(flags.action_space)} incompatible with multi_actions={repr(flags.multi_actions)}.', + f"Flag action_space={repr(flags.action_space)} incompatible with multi_actions={repr(flags.multi_actions)}.", stacklevel=2, ) - if flags.demo_mode != 'off': + if flags.demo_mode != "off": warn( - f'Flag action_space={repr(flags.action_space)} incompatible with demo_mode={repr(flags.demo_mode)}.', + f"Flag action_space={repr(flags.action_space)} incompatible with demo_mode={repr(flags.demo_mode)}.", stacklevel=2, ) return action_space - case 'bid': - action_subsets = ['chat', 'bid'] - case 'coord': - action_subsets = ['chat', 'coord'] - case 'bid+coord': - action_subsets = ['chat', 'bid', 'coord'] - case 'bid+nav': - action_subsets = ['chat', 'bid', 'nav'] - case 'coord+nav': - action_subsets = ['chat', 'coord', 'nav'] - case 'bid+coord+nav': - action_subsets = ['chat', 'bid', 'coord', 'nav'] + case "bid": + action_subsets = ["chat", "bid"] + case "coord": + action_subsets = ["chat", "coord"] + case "bid+coord": + action_subsets = ["chat", "bid", "coord"] + case "bid+nav": + action_subsets = ["chat", "bid", "nav"] + case "coord+nav": + action_subsets = ["chat", "coord", "nav"] + case "bid+coord+nav": + action_subsets = ["chat", "bid", "coord", "nav"] case _: raise NotImplementedError( - f'Unknown action_space {repr(flags.action_space)}' + f"Unknown action_space {repr(flags.action_space)}" ) action_space = HighLevelActionSet( @@ -530,7 +530,7 @@ def _get_action_space(flags: Flags) -> AbstractActionSet: class Memory(PromptElement): - _prompt = '' # provided in the abstract and concrete examples + _prompt = "" # provided in the abstract and concrete examples _abstract_ex = """ @@ -548,12 +548,12 @@ class Memory(PromptElement): def _parse_answer(self, text_answer): return parse_html_tags_raise( - text_answer, optional_keys=['memory'], merge_multiple=True + text_answer, optional_keys=["memory"], merge_multiple=True ) class Think(PromptElement): - _prompt = '' + _prompt = "" _abstract_ex = """ @@ -571,7 +571,7 @@ class Think(PromptElement): def _parse_answer(self, text_answer): return parse_html_tags_raise( - text_answer, optional_keys=['think'], merge_multiple=True + text_answer, optional_keys=["think"], merge_multiple=True ) @@ -581,10 +581,10 @@ def diff(previous, new): If the difference is above diff_threshold, return the diff string. """ if previous == new: - return 'Identical', [] + return "Identical", [] if len(previous) == 0 or previous is None: - return 'previous is empty', [] + return "previous is empty", [] diff_gen = difflib.ndiff(previous.splitlines(), new.splitlines()) @@ -592,23 +592,23 @@ def diff(previous, new): plus_count = 0 minus_count = 0 for line in diff_gen: - if line.strip().startswith('+'): + if line.strip().startswith("+"): diff_lines.append(line) plus_count += 1 - elif line.strip().startswith('-'): + elif line.strip().startswith("-"): diff_lines.append(line) minus_count += 1 else: continue - header = f'{plus_count} lines added and {minus_count} lines removed:' + header = f"{plus_count} lines added and {minus_count} lines removed:" return header, diff_lines class Diff(Shrinkable): def __init__( - self, previous, new, prefix='', max_line_diff=20, shrink_speed=2, visible=True + self, previous, new, prefix="", max_line_diff=20, shrink_speed=2, visible=True ) -> None: super().__init__(visible=visible) self.max_line_diff = max_line_diff @@ -622,11 +622,11 @@ def shrink(self): @property def _prompt(self) -> str: # type: ignore - diff_str = '\n'.join(self.diff_lines[: self.max_line_diff]) + diff_str = "\n".join(self.diff_lines[: self.max_line_diff]) if len(self.diff_lines) > self.max_line_diff: original_count = len(self.diff_lines) - diff_str = f'{diff_str}\nDiff truncated, {original_count - self.max_line_diff} changes now shown.' - return f'{self.prefix}{self.header}\n{diff_str}\n' + diff_str = f"{diff_str}\nDiff truncated, {original_count - self.max_line_diff} changes now shown." + return f"{self.prefix}{self.header}\n{diff_str}\n" class HistoryStep(Shrinkable): @@ -637,25 +637,25 @@ def __init__( self.html_diff = Diff( previous_obs[flags.html_type], current_obs[flags.html_type], - prefix='\n### HTML diff:\n', + prefix="\n### HTML diff:\n", shrink_speed=shrink_speed, visible=lambda: flags.use_html and flags.use_diff, ) self.ax_tree_diff = Diff( - previous_obs['axtree_txt'], - current_obs['axtree_txt'], - prefix='\n### Accessibility tree diff:\n', + previous_obs["axtree_txt"], + current_obs["axtree_txt"], + prefix="\n### Accessibility tree diff:\n", shrink_speed=shrink_speed, visible=lambda: flags.use_ax_tree and flags.use_diff, ) self.error = Error( - current_obs['last_action_error'], + current_obs["last_action_error"], visible=( flags.use_error_logs - and current_obs['last_action_error'] + and current_obs["last_action_error"] and flags.use_past_error_logs ), - prefix='### ', + prefix="### ", ) self.shrink_speed = shrink_speed self.action = action @@ -669,17 +669,17 @@ def shrink(self): @property def _prompt(self) -> str: # type: ignore - prompt = '' + prompt = "" if self.flags.use_action_history: - prompt += f'\n### Action:\n{self.action}\n' + prompt += f"\n### Action:\n{self.action}\n" prompt += ( - f'{self.error.prompt}{self.html_diff.prompt}{self.ax_tree_diff.prompt}' + f"{self.error.prompt}{self.html_diff.prompt}{self.ax_tree_diff.prompt}" ) if self.flags.use_memory and self.memory is not None: - prompt += f'\n### Memory:\n{self.memory}\n' + prompt += f"\n### Memory:\n{self.memory}\n" return prompt @@ -715,14 +715,14 @@ def shrink(self): @property def _prompt(self): - prompts = ['# History of interaction with the task:\n'] + prompts = ["# History of interaction with the task:\n"] for i, step in enumerate(self.history_steps): - prompts.append(f'## step {i}') + prompts.append(f"## step {i}") prompts.append(step.prompt) - return '\n'.join(prompts) + '\n' + return "\n".join(prompts) + "\n" -if __name__ == '__main__': +if __name__ == "__main__": html_template = """ @@ -736,27 +736,27 @@ def _prompt(self): OBS_HISTORY = [ { - 'goal': 'do this and that', - 'pruned_html': html_template.format(1), - 'axtree_txt': '[1] Click me', - 'last_action_error': '', + "goal": "do this and that", + "pruned_html": html_template.format(1), + "axtree_txt": "[1] Click me", + "last_action_error": "", }, { - 'goal': 'do this and that', - 'pruned_html': html_template.format(2), - 'axtree_txt': '[1] Click me', - 'last_action_error': '', + "goal": "do this and that", + "pruned_html": html_template.format(2), + "axtree_txt": "[1] Click me", + "last_action_error": "", }, { - 'goal': 'do this and that', - 'pruned_html': html_template.format(3), - 'axtree_txt': '[1] Click me', - 'last_action_error': 'Hey, there is an error now', + "goal": "do this and that", + "pruned_html": html_template.format(3), + "axtree_txt": "[1] Click me", + "last_action_error": "Hey, there is an error now", }, ] ACTIONS = ["click('41')", "click('42')"] - MEMORIES = ['memory A', 'memory B'] - THOUGHTS = ['thought A', 'thought B'] + MEMORIES = ["memory A", "memory B"] + THOUGHTS = ["thought A", "thought B"] flags = Flags( use_html=True, @@ -768,7 +768,7 @@ def _prompt(self): use_action_history=True, use_memory=True, use_diff=True, - html_type='pruned_html', + html_type="pruned_html", use_concrete_example=True, use_abstract_example=True, use_screenshot=False, diff --git a/openhands/agenthub/browsing_agent/response_parser.py b/openhands/agenthub/browsing_agent/response_parser.py index 8687016c6ad7..a04378fa31b0 100644 --- a/openhands/agenthub/browsing_agent/response_parser.py +++ b/openhands/agenthub/browsing_agent/response_parser.py @@ -21,17 +21,17 @@ def parse(self, response: str) -> Action: return self.parse_action(action_str) def parse_response(self, response) -> str: - action_str = response['choices'][0]['message']['content'] + action_str = response["choices"][0]["message"]["content"] if action_str is None: - return '' + return "" action_str = action_str.strip() # Ensure action_str ends with ')```' if action_str: - if not action_str.endswith('```'): - if action_str.endswith(')'): - action_str += '```' # prevent duplicate ending paranthesis, e.g. send_msg_to_user('Done')) + if not action_str.endswith("```"): + if action_str.endswith(")"): + action_str += "```" # prevent duplicate ending paranthesis, e.g. send_msg_to_user('Done')) else: - action_str += ')```' # expected format + action_str += ")```" # expected format logger.debug(action_str) return action_str @@ -53,7 +53,7 @@ def __init__( pass def check_condition(self, action_str: str) -> bool: - return '```' not in action_str + return "```" not in action_str def parse(self, action_str: str) -> Action: msg = f'send_msg_to_user("""{action_str}""")' @@ -92,29 +92,29 @@ def parse(self, action_str: str) -> Action: # when the LLM returns only one string, it looks like this: ### goto('https://www.whitehouse.gov/about-the-white-house/presidents/') # and parse_response added )``` to the end of the string - parts = action_str.split('```') + parts = action_str.split("```") browser_actions = ( - parts[1].strip() if parts[1].strip() != '' else parts[0].strip() + parts[1].strip() if parts[1].strip() != "" else parts[0].strip() ) - thought = parts[0].strip() if parts[1].strip() != '' else '' + thought = parts[0].strip() if parts[1].strip() != "" else "" # if the LLM wants to talk to the user, we extract the message - msg_content = '' - for sub_action in browser_actions.split('\n'): - if 'send_msg_to_user(' in sub_action: + msg_content = "" + for sub_action in browser_actions.split("\n"): + if "send_msg_to_user(" in sub_action: try: tree = ast.parse(sub_action) args = tree.body[0].value.args # type: ignore msg_content = args[0].value except SyntaxError: - logger.error(f'Error parsing action: {sub_action}') + logger.error(f"Error parsing action: {sub_action}") # the syntax was not correct, but we can still try to get the message # e.g. send_msg_to_user("Hello, world!") or send_msg_to_user('Hello, world!' match = re.search(r'send_msg_to_user\((["\'])(.*?)\1\)', sub_action) if match: msg_content = match.group(2) else: - msg_content = '' + msg_content = "" return BrowseInteractiveAction( browser_actions=browser_actions, diff --git a/openhands/agenthub/browsing_agent/utils.py b/openhands/agenthub/browsing_agent/utils.py index 8e67679966ae..d57deb37255f 100644 --- a/openhands/agenthub/browsing_agent/utils.py +++ b/openhands/agenthub/browsing_agent/utils.py @@ -8,12 +8,12 @@ def yaml_parser(message): """Parse a yaml message for the retry function.""" # saves gpt-3.5 from some yaml parsing errors - message = re.sub(r':\s*\n(?=\S|\n)', ': ', message) + message = re.sub(r":\s*\n(?=\S|\n)", ": ", message) try: value = yaml.safe_load(message) valid = True - retry_message = '' + retry_message = "" except yaml.YAMLError as e: warn(str(e), stacklevel=2) value = {} @@ -22,7 +22,7 @@ def yaml_parser(message): return value, valid, retry_message -def _compress_chunks(text, identifier, skip_list, split_regex='\n\n+'): +def _compress_chunks(text, identifier, skip_list, split_regex="\n\n+"): """Compress a string by replacing redundant chunks by identifiers. Chunks are defined by the split_regex.""" text_list = re.split(split_regex, text) text_list = [chunk.strip() for chunk in text_list] @@ -33,11 +33,11 @@ def _compress_chunks(text, identifier, skip_list, split_regex='\n\n+'): # Store items that occur more than once in a dictionary for item, count in counter.items(): if count > 1 and item not in skip_list and len(item) > 10: - def_dict[f'{identifier}-{id}'] = item + def_dict[f"{identifier}-{id}"] = item id += 1 # Replace redundant items with their identifiers in the text - compressed_text = '\n'.join(text_list) + compressed_text = "\n".join(text_list) for key, value in def_dict.items(): compressed_text = compressed_text.replace(value, key) @@ -48,23 +48,23 @@ def compress_string(text): """Compress a string by replacing redundant paragraphs and lines with identifiers.""" # Perform paragraph-level compression def_dict, compressed_text = _compress_chunks( - text, identifier='§', skip_list=[], split_regex='\n\n+' + text, identifier="§", skip_list=[], split_regex="\n\n+" ) # Perform line-level compression, skipping any paragraph identifiers line_dict, compressed_text = _compress_chunks( - compressed_text, '¶', list(def_dict.keys()), split_regex='\n+' + compressed_text, "¶", list(def_dict.keys()), split_regex="\n+" ) def_dict.update(line_dict) # Create a definitions section - def_lines = [''] + def_lines = [""] for key, value in def_dict.items(): - def_lines.append(f'{key}:\n{value}') - def_lines.append('') - definitions = '\n'.join(def_lines) + def_lines.append(f"{key}:\n{value}") + def_lines.append("") + definitions = "\n".join(def_lines) - return definitions + '\n' + compressed_text + return definitions + "\n" + compressed_text def extract_html_tags(text, keys): @@ -91,7 +91,7 @@ def extract_html_tags(text, keys): # text = text.lower() # keys = set([k.lower() for k in keys]) for key in keys: - pattern = f'<{key}>(.*?)' + pattern = f"<{key}>(.*?)" matches = re.findall(pattern, text, re.DOTALL) if matches: content_dict[key] = [match.strip() for match in matches] @@ -140,19 +140,19 @@ def parse_html_tags(text, keys=(), optional_keys=(), merge_multiple=False): for key in all_keys: if key not in content_dict: if key not in optional_keys: - retry_messages.append(f'Missing the key <{key}> in the answer.') + retry_messages.append(f"Missing the key <{key}> in the answer.") else: val = content_dict[key] content_dict[key] = val[0] if len(val) > 1: if not merge_multiple: retry_messages.append( - f'Found multiple instances of the key {key}. You should have only one of them.' + f"Found multiple instances of the key {key}. You should have only one of them." ) else: # merge the multiple instances - content_dict[key] = '\n'.join(val) + content_dict[key] = "\n".join(val) valid = len(retry_messages) == 0 - retry_message = '\n'.join(retry_messages) + retry_message = "\n".join(retry_messages) return content_dict, valid, retry_message diff --git a/openhands/agenthub/codeact_agent/__init__.py b/openhands/agenthub/codeact_agent/__init__.py index 63f1fdb820b4..9edd9eb08ef6 100644 --- a/openhands/agenthub/codeact_agent/__init__.py +++ b/openhands/agenthub/codeact_agent/__init__.py @@ -1,4 +1,4 @@ from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent from openhands.controller.agent import Agent -Agent.register('CodeActAgent', CodeActAgent) +Agent.register("CodeActAgent", CodeActAgent) diff --git a/openhands/agenthub/codeact_agent/action_parser.py b/openhands/agenthub/codeact_agent/action_parser.py index 75fab1156f8c..3e76fed5e898 100644 --- a/openhands/agenthub/codeact_agent/action_parser.py +++ b/openhands/agenthub/codeact_agent/action_parser.py @@ -46,21 +46,21 @@ def parse(self, response) -> Action: def parse_response(self, response) -> str: action = response.choices[0].message.content if action is None: - return '' - for lang in ['bash', 'ipython', 'browse']: + return "" + for lang in ["bash", "ipython", "browse"]: # special handling for DeepSeek: it has stop-word bug and returns - if f'' not in action: - action = action.replace(f'') + if f"" not in action: + action = action.replace(f"") - if f'' in action and f'' not in action: - action += f'' + if f"" in action and f"" not in action: + action += f"" # special handling for DeepSeek: it has stop-word bug and returns - if '' not in action: - action = action.replace('') + if "" not in action: + action = action.replace("") - if '' not in action: - action += '' + if "" not in action: + action += "" return action def parse_action(self, action_str: str) -> Action: @@ -72,19 +72,19 @@ def parse_action(self, action_str: str) -> Action: def action_to_str(self, action: Action) -> str: if isinstance(action, CmdRunAction): return ( - f'{action.thought}\n\n{action.command}\n' + f"{action.thought}\n\n{action.command}\n" ) elif isinstance(action, IPythonRunCellAction): - return f'{action.thought}\n\n{action.code}\n' + return f"{action.thought}\n\n{action.code}\n" elif isinstance(action, AgentDelegateAction): return f'{action.thought}\n\n{action.inputs["task"]}\n' elif isinstance(action, FileEditAction): - return f'{action.thought}\n\n{action.content}\n' + return f"{action.thought}\n\n{action.content}\n" elif isinstance(action, MessageAction): return action.content - elif isinstance(action, AgentFinishAction) and action.source == 'agent': + elif isinstance(action, AgentFinishAction) and action.source == "agent": return action.thought - return '' + return "" class CodeActActionParserFinish(ActionParser): @@ -98,14 +98,14 @@ def __init__( self.finish_command = None def check_condition(self, action_str: str) -> bool: - self.finish_command = re.search(r'.*', action_str, re.DOTALL) + self.finish_command = re.search(r".*", action_str, re.DOTALL) return self.finish_command is not None def parse(self, action_str: str) -> Action: assert ( self.finish_command is not None - ), 'self.finish_command should not be None when parse is called' - thought = action_str.replace(self.finish_command.group(0), '').strip() + ), "self.finish_command should not be None when parse is called" + thought = action_str.replace(self.finish_command.group(0), "").strip() return AgentFinishAction(thought=thought) @@ -122,18 +122,18 @@ def __init__( def check_condition(self, action_str: str) -> bool: self.bash_command = re.search( - r'(.*?)', action_str, re.DOTALL + r"(.*?)", action_str, re.DOTALL ) return self.bash_command is not None def parse(self, action_str: str) -> Action: assert ( self.bash_command is not None - ), 'self.bash_command should not be None when parse is called' - thought = action_str.replace(self.bash_command.group(0), '').strip() + ), "self.bash_command should not be None when parse is called" + thought = action_str.replace(self.bash_command.group(0), "").strip() # a command was found command_group = self.bash_command.group(1).strip() - if command_group.strip() == 'exit': + if command_group.strip() == "exit": return AgentFinishAction(thought=thought) return CmdRunAction(command=command_group, thought=thought) @@ -147,20 +147,20 @@ def __init__( self, ): self.python_code = None - self.jupyter_kernel_init_code: str = 'from agentskills import *' + self.jupyter_kernel_init_code: str = "from agentskills import *" def check_condition(self, action_str: str) -> bool: self.python_code = re.search( - r'(.*?)', action_str, re.DOTALL + r"(.*?)", action_str, re.DOTALL ) return self.python_code is not None def parse(self, action_str: str) -> Action: assert ( self.python_code is not None - ), 'self.python_code should not be None when parse is called' + ), "self.python_code should not be None when parse is called" code_group = self.python_code.group(1).strip() - thought = action_str.replace(self.python_code.group(0), '').strip() + thought = action_str.replace(self.python_code.group(0), "").strip() return IPythonRunCellAction( code=code_group, thought=thought, @@ -180,24 +180,24 @@ def __init__( def check_condition(self, action_str: str) -> bool: self.agent_delegate = re.search( - r'(.*)', action_str, re.DOTALL + r"(.*)", action_str, re.DOTALL ) return self.agent_delegate is not None def parse(self, action_str: str) -> Action: assert ( self.agent_delegate is not None - ), 'self.agent_delegate should not be None when parse is called' - thought = action_str.replace(self.agent_delegate.group(0), '').strip() + ), "self.agent_delegate should not be None when parse is called" + thought = action_str.replace(self.agent_delegate.group(0), "").strip() browse_actions = self.agent_delegate.group(1).strip() thought = ( - f'{thought}\nI should start with: {browse_actions}' + f"{thought}\nI should start with: {browse_actions}" if thought - else f'I should start with: {browse_actions}' + else f"I should start with: {browse_actions}" ) return AgentDelegateAction( - agent='BrowsingAgent', thought=thought, inputs={'task': browse_actions} + agent="BrowsingAgent", thought=thought, inputs={"task": browse_actions} ) @@ -229,7 +229,7 @@ def __init__(self): self.file_edit_match: re.Match | None = None def check_condition(self, action_str: str) -> bool: - if ' bool: f'FileEditAction detected but the format is incorrect. Unable to match for in:\n{"-" * 80}\n{action_str}\n{"-" * 80}' ) raise LLMMalformedActionError( - 'FileEditAction detected but the format is incorrect. Usage:\n' + "FileEditAction detected but the format is incorrect. Usage:\n" '\n' - '[content_to_edit]\n' - '\n' + "[content_to_edit]\n" + "\n" ) path = self.file_edit_match.group(2) @@ -256,7 +256,7 @@ def check_condition(self, action_str: str) -> bool: if not path: raise LLMMalformedActionError( - 'FileEditAction detected but no `path` specified. You should specify the path of the file to edit.' + "FileEditAction detected but no `path` specified. You should specify the path of the file to edit." ) if start: @@ -264,7 +264,7 @@ def check_condition(self, action_str: str) -> bool: int(start) except ValueError: raise LLMMalformedActionError( - f'FileEditAction detected but `start` is not a valid integer: {start}' + f"FileEditAction detected but `start` is not a valid integer: {start}" ) if end: @@ -272,7 +272,7 @@ def check_condition(self, action_str: str) -> bool: int(end) except ValueError: raise LLMMalformedActionError( - f'FileEditAction detected but `end` is not a valid integer: {end}' + f"FileEditAction detected but `end` is not a valid integer: {end}" ) return True @@ -280,7 +280,7 @@ def check_condition(self, action_str: str) -> bool: def parse(self, action_str: str) -> Action: assert ( self.file_edit_match is not None - ), 'self.file_edit_match should not be None when parse is called' + ), "self.file_edit_match should not be None when parse is called" file_path = self.file_edit_match.group(2).strip() start_line = ( @@ -294,7 +294,7 @@ def parse(self, action_str: str) -> Action: else None ) content = self.file_edit_match.group(7) - thought = action_str.replace(self.file_edit_match.group(0), '').strip() + thought = action_str.replace(self.file_edit_match.group(0), "").strip() action = FileEditAction(path=file_path, content=content, thought=thought) if start_line is not None: diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index 629c6edfb18b..c7f6b8525539 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -43,7 +43,7 @@ class CodeActAgent(Agent): - VERSION = '2.2' + VERSION = "2.2" """ The Code Act Agent is a minimalist agent. The agent works by passing the model a list of action-observation pairs and prompting the model to take the next step. @@ -70,7 +70,7 @@ class CodeActAgent(Agent): AgentSkillsRequirement(), JupyterRequirement(), ] - obs_prefix = 'OBSERVATION:\n' + obs_prefix = "OBSERVATION:\n" def __init__( self, @@ -88,8 +88,8 @@ def __init__( self.function_calling_active = self.config.function_calling if self.function_calling_active and not self.llm.is_function_calling_active(): logger.warning( - f'Function calling not supported for model {self.llm.config.model}. ' - 'Disabling function calling.' + f"Function calling not supported for model {self.llm.config.model}. " + "Disabling function calling." ) self.function_calling_active = False @@ -100,18 +100,24 @@ def __init__( codeact_enable_llm_editor=self.config.codeact_enable_llm_editor, ) logger.debug( - f'TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2)}' + f"TOOLS loaded for CodeActAgent: {json.dumps(self.tools, indent=2)}" ) self.prompt_manager = PromptManager( - microagent_dir=os.path.join(os.path.dirname(__file__), 'micro') if self.config.use_microagents else None, - prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts', 'tools'), + microagent_dir=os.path.join(os.path.dirname(__file__), "micro") + if self.config.use_microagents + else None, + prompt_dir=os.path.join(os.path.dirname(__file__), "prompts", "tools"), disabled_microagents=self.config.disabled_microagents, ) else: self.action_parser = CodeActResponseParser() self.prompt_manager = PromptManager( - microagent_dir=os.path.join(os.path.dirname(__file__), 'micro') if self.config.use_microagents else None, - prompt_dir=os.path.join(os.path.dirname(__file__), 'prompts', 'default'), + microagent_dir=os.path.join(os.path.dirname(__file__), "micro") + if self.config.use_microagents + else None, + prompt_dir=os.path.join( + os.path.dirname(__file__), "prompts", "default" + ), agent_skills_docs=AgentSkillsRequirement.documentation, disabled_microagents=self.config.disabled_microagents, ) @@ -162,11 +168,11 @@ def get_action_message( FileEditAction, BrowseInteractiveAction, ), - ) or (isinstance(action, AgentFinishAction) and action.source == 'agent'): + ) or (isinstance(action, AgentFinishAction) and action.source == "agent"): if self.function_calling_active: tool_metadata = action.tool_call_metadata assert tool_metadata is not None, ( - 'Tool call metadata should NOT be None when function calling is enabled. Action: ' + "Tool call metadata should NOT be None when function calling is enabled. Action: " + str(action) ) @@ -177,7 +183,7 @@ def get_action_message( pending_tool_call_action_messages[llm_response.id] = Message( role=assistant_msg.role, # tool call content SHOULD BE a string - content=[TextContent(text=assistant_msg.content or '')] + content=[TextContent(text=assistant_msg.content or "")] if assistant_msg.content is not None else [], tool_calls=assistant_msg.tool_calls, @@ -185,19 +191,19 @@ def get_action_message( return [] else: assert not isinstance(action, BrowseInteractiveAction), ( - 'BrowseInteractiveAction is not supported in non-function calling mode. Action: ' + "BrowseInteractiveAction is not supported in non-function calling mode. Action: " + str(action) ) content = [TextContent(text=self.action_parser.action_to_str(action))] return [ Message( - role='user' if action.source == 'user' else 'assistant', + role="user" if action.source == "user" else "assistant", content=content, ) ] elif isinstance(action, MessageAction): - role = 'user' if action.source == 'user' else 'assistant' - content = [TextContent(text=action.content or '')] + role = "user" if action.source == "user" else "assistant" + content = [TextContent(text=action.content or "")] if self.llm.vision_is_active() and action.image_urls: content.append(ImageContent(image_urls=action.image_urls)) return [ @@ -240,58 +246,58 @@ def get_observation_message( """ message: Message max_message_chars = self.llm.config.max_message_chars - obs_prefix = 'OBSERVATION:\n' + obs_prefix = "OBSERVATION:\n" if isinstance(obs, CmdOutputObservation): text = obs_prefix + truncate_content( obs.content + obs.interpreter_details, max_message_chars ) - text += f'\n[Command finished with exit code {obs.exit_code}]' - message = Message(role='user', content=[TextContent(text=text)]) + text += f"\n[Command finished with exit code {obs.exit_code}]" + message = Message(role="user", content=[TextContent(text=text)]) elif isinstance(obs, IPythonRunCellObservation): text = obs_prefix + obs.content # replace base64 images with a placeholder - splitted = text.split('\n') + splitted = text.split("\n") for i, line in enumerate(splitted): - if '![image](data:image/png;base64,' in line: + if "![image](data:image/png;base64," in line: splitted[i] = ( - '![image](data:image/png;base64, ...) already displayed to user' + "![image](data:image/png;base64, ...) already displayed to user" ) - text = '\n'.join(splitted) + text = "\n".join(splitted) text = truncate_content(text, max_message_chars) - message = Message(role='user', content=[TextContent(text=text)]) + message = Message(role="user", content=[TextContent(text=text)]) elif isinstance(obs, FileEditObservation): text = obs_prefix + truncate_content(str(obs), max_message_chars) - message = Message(role='user', content=[TextContent(text=text)]) + message = Message(role="user", content=[TextContent(text=text)]) elif isinstance(obs, BrowserOutputObservation): text = obs.get_agent_obs_text() message = Message( - role='user', + role="user", content=[TextContent(text=obs_prefix + text)], ) elif isinstance(obs, AgentDelegateObservation): text = obs_prefix + truncate_content( - obs.outputs['content'] if 'content' in obs.outputs else '', + obs.outputs["content"] if "content" in obs.outputs else "", max_message_chars, ) - message = Message(role='user', content=[TextContent(text=text)]) + message = Message(role="user", content=[TextContent(text=text)]) elif isinstance(obs, ErrorObservation): text = obs_prefix + truncate_content(obs.content, max_message_chars) - text += '\n[Error occurred in processing last action]' - message = Message(role='user', content=[TextContent(text=text)]) + text += "\n[Error occurred in processing last action]" + message = Message(role="user", content=[TextContent(text=text)]) elif isinstance(obs, UserRejectObservation): - text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars) - text += '\n[Last action has been rejected by the user]' - message = Message(role='user', content=[TextContent(text=text)]) + text = "OBSERVATION:\n" + truncate_content(obs.content, max_message_chars) + text += "\n[Last action has been rejected by the user]" + message = Message(role="user", content=[TextContent(text=text)]) else: # If an observation message is not returned, it will cause an error # when the LLM tries to return the next message - raise ValueError(f'Unknown observation type: {type(obs)}') + raise ValueError(f"Unknown observation type: {type(obs)}") if self.function_calling_active: # Update the message as tool response properly if (tool_call_metadata := obs.tool_call_metadata) is not None: tool_call_id_to_message[tool_call_metadata.tool_call_id] = Message( - role='tool', + role="tool", content=message.content, tool_call_id=tool_call_metadata.tool_call_id, name=tool_call_metadata.function_name, @@ -327,23 +333,23 @@ def step(self, state: State) -> Action: # if we're done, go back latest_user_message = state.get_last_user_message() - if latest_user_message and latest_user_message.content.strip() == '/exit': + if latest_user_message and latest_user_message.content.strip() == "/exit": return AgentFinishAction() # prepare what we want to send to the LLM messages = self._get_messages(state) params: dict = { - 'messages': self.llm.format_messages_for_llm(messages), + "messages": self.llm.format_messages_for_llm(messages), } if self.function_calling_active: - params['tools'] = self.tools - params['parallel_tool_calls'] = False + params["tools"] = self.tools + params["parallel_tool_calls"] = False else: - params['stop'] = [ - '', - '', - '', - '', + params["stop"] = [ + "", + "", + "", + "", ] response = self.llm.completion(**params) @@ -389,7 +395,7 @@ def _get_messages(self, state: State) -> list[Message]: """ messages: list[Message] = [ Message( - role='system', + role="system", content=[ TextContent( text=self.prompt_manager.get_system_message(), @@ -402,7 +408,7 @@ def _get_messages(self, state: State) -> list[Message]: if example_message: messages.append( Message( - role='user', + role="user", content=[TextContent(text=example_message)], cache_prompt=self.llm.is_caching_prompt_active(), ) @@ -424,7 +430,7 @@ def _get_messages(self, state: State) -> list[Message]: tool_call_id_to_message=tool_call_id_to_message, ) else: - raise ValueError(f'Unknown event type: {type(event)}') + raise ValueError(f"Unknown event type: {type(event)}") # Check pending tool call action messages and see if they are complete _response_ids_to_remove = [] @@ -433,8 +439,8 @@ def _get_messages(self, state: State) -> list[Message]: pending_message, ) in pending_tool_call_action_messages.items(): assert pending_message.tool_calls is not None, ( - 'Tool calls should NOT be None when function calling is enabled & the message is considered pending tool call. ' - f'Pending message: {pending_message}' + "Tool calls should NOT be None when function calling is enabled & the message is considered pending tool call. " + f"Pending message: {pending_message}" ) if all( tool_call.id in tool_call_id_to_message @@ -454,7 +460,7 @@ def _get_messages(self, state: State) -> list[Message]: for message in messages_to_add: if message: - if message.role == 'user': + if message.role == "user": self.prompt_manager.enhance_message(message) # handle error if the message is the SAME role as the previous message # litellm.exceptions.BadRequestError: litellm.BadRequestError: OpenAIException - Error code: 400 - {'detail': 'Only supports u/a/u/a/u...'} @@ -463,7 +469,7 @@ def _get_messages(self, state: State) -> list[Message]: if ( messages and messages[-1].role == message.role - and message.role != 'tool' + and message.role != "tool" ): messages[-1].content.extend(message.content) else: @@ -475,7 +481,7 @@ def _get_messages(self, state: State) -> list[Message]: # https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262 breakpoints_remaining = 3 # remaining 1 for system/tool for message in reversed(messages): - if message.role == 'user' or message.role == 'tool': + if message.role == "user" or message.role == "tool": if breakpoints_remaining > 0: message.content[ -1 diff --git a/openhands/agenthub/codeact_agent/function_calling.py b/openhands/agenthub/codeact_agent/function_calling.py index 177e7b7ff171..e230dbb0150f 100644 --- a/openhands/agenthub/codeact_agent/function_calling.py +++ b/openhands/agenthub/codeact_agent/function_calling.py @@ -32,19 +32,19 @@ """ CmdRunTool = ChatCompletionToolParam( - type='function', + type="function", function=ChatCompletionToolParamFunctionChunk( - name='execute_bash', + name="execute_bash", description=_BASH_DESCRIPTION, parameters={ - 'type': 'object', - 'properties': { - 'command': { - 'type': 'string', - 'description': 'The bash command to execute. Can be empty to view additional logs when previous exit code is `-1`. Can be `ctrl+c` to interrupt the currently running process.', + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The bash command to execute. Can be empty to view additional logs when previous exit code is `-1`. Can be `ctrl+c` to interrupt the currently running process.", }, }, - 'required': ['command'], + "required": ["command"], }, ), ) @@ -58,19 +58,19 @@ # {AgentSkillsRequirement.documentation}""" IPythonTool = ChatCompletionToolParam( - type='function', + type="function", function=ChatCompletionToolParamFunctionChunk( - name='execute_ipython_cell', + name="execute_ipython_cell", description=_IPYTHON_DESCRIPTION, parameters={ - 'type': 'object', - 'properties': { - 'code': { - 'type': 'string', - 'description': 'The Python code to execute. Supports magic commands like %pip.', + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "The Python code to execute. Supports magic commands like %pip.", }, }, - 'required': ['code'], + "required": ["code"], }, ), ) @@ -182,31 +182,31 @@ def __init__(self): """ LLMBasedFileEditTool = ChatCompletionToolParam( - type='function', + type="function", function=ChatCompletionToolParamFunctionChunk( - name='edit_file', + name="edit_file", description=_FILE_EDIT_DESCRIPTION, parameters={ - 'type': 'object', - 'properties': { - 'path': { - 'type': 'string', - 'description': 'The absolute path to the file to be edited.', + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "The absolute path to the file to be edited.", }, - 'new_content_draft': { - 'type': 'string', - 'description': 'A draft of the new content for the file being edited. Note that the assistant may skip unchanged lines.', + "new_content_draft": { + "type": "string", + "description": "A draft of the new content for the file being edited. Note that the assistant may skip unchanged lines.", }, - 'start': { - 'type': 'integer', - 'description': 'The starting line number for the edit (1-indexed, inclusive). Default is 1.', + "start": { + "type": "integer", + "description": "The starting line number for the edit (1-indexed, inclusive). Default is 1.", }, - 'end': { - 'type': 'integer', - 'description': 'The ending line number for the edit (1-indexed, inclusive). Default is -1 (end of file).', + "end": { + "type": "integer", + "description": "The ending line number for the edit (1-indexed, inclusive). Default is -1 (end of file).", }, }, - 'required': ['path', 'content'], + "required": ["path", "content"], }, ), ) @@ -225,52 +225,52 @@ def __init__(self): """ StrReplaceEditorTool = ChatCompletionToolParam( - type='function', + type="function", function=ChatCompletionToolParamFunctionChunk( - name='str_replace_editor', + name="str_replace_editor", description=_STR_REPLACE_EDITOR_DESCRIPTION, parameters={ - 'type': 'object', - 'properties': { - 'command': { - 'description': 'The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.', - 'enum': ['view', 'create', 'str_replace', 'insert', 'undo_edit'], - 'type': 'string', + "type": "object", + "properties": { + "command": { + "description": "The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`, `undo_edit`.", + "enum": ["view", "create", "str_replace", "insert", "undo_edit"], + "type": "string", }, - 'path': { - 'description': 'Absolute path to file or directory, e.g. `/workspace/file.py` or `/workspace`.', - 'type': 'string', + "path": { + "description": "Absolute path to file or directory, e.g. `/workspace/file.py` or `/workspace`.", + "type": "string", }, - 'file_text': { - 'description': 'Required parameter of `create` command, with the content of the file to be created.', - 'type': 'string', + "file_text": { + "description": "Required parameter of `create` command, with the content of the file to be created.", + "type": "string", }, - 'old_str': { - 'description': 'Required parameter of `str_replace` command containing the string in `path` to replace.', - 'type': 'string', + "old_str": { + "description": "Required parameter of `str_replace` command containing the string in `path` to replace.", + "type": "string", }, - 'new_str': { - 'description': 'Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.', - 'type': 'string', + "new_str": { + "description": "Optional parameter of `str_replace` command containing the new string (if not given, no string will be added). Required parameter of `insert` command containing the string to insert.", + "type": "string", }, - 'insert_line': { - 'description': 'Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.', - 'type': 'integer', + "insert_line": { + "description": "Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.", + "type": "integer", }, - 'view_range': { - 'description': 'Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.', - 'items': {'type': 'integer'}, - 'type': 'array', + "view_range": { + "description": "Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.", + "items": {"type": "integer"}, + "type": "array", }, }, - 'required': ['command', 'path'], + "required": ["command", "path"], }, ), ) # from browsergym/core/action/highlevel.py _browser_action_space = HighLevelActionSet( - subsets=['bid', 'nav'], + subsets=["bid", "nav"], strict=False, # less strict on the parsing of the actions multiaction=True, # enable to agent to take multiple actions at once ) @@ -395,28 +395,28 @@ def __init__(self): for _, action in _browser_action_space.action_set.items(): assert ( action.signature in _BROWSER_TOOL_DESCRIPTION - ), f'Browser description mismatch. Please double check if the BrowserGym updated their action space.\n\nAction: {action.signature}' + ), f"Browser description mismatch. Please double check if the BrowserGym updated their action space.\n\nAction: {action.signature}" assert ( action.description in _BROWSER_TOOL_DESCRIPTION - ), f'Browser description mismatch. Please double check if the BrowserGym updated their action space.\n\nAction: {action.description}' + ), f"Browser description mismatch. Please double check if the BrowserGym updated their action space.\n\nAction: {action.description}" BrowserTool = ChatCompletionToolParam( - type='function', + type="function", function=ChatCompletionToolParamFunctionChunk( - name='browser', + name="browser", description=_BROWSER_DESCRIPTION, parameters={ - 'type': 'object', - 'properties': { - 'code': { - 'type': 'string', - 'description': ( - 'The Python code that interacts with the browser.\n' + "type": "object", + "properties": { + "code": { + "type": "string", + "description": ( + "The Python code that interacts with the browser.\n" + _BROWSER_TOOL_DESCRIPTION ), } }, - 'required': ['code'], + "required": ["code"], }, ), ) @@ -424,16 +424,16 @@ def __init__(self): _FINISH_DESCRIPTION = """Finish the interaction when the task is complete OR if the assistant cannot proceed further with the task.""" FinishTool = ChatCompletionToolParam( - type='function', + type="function", function=ChatCompletionToolParamFunctionChunk( - name='finish', + name="finish", description=_FINISH_DESCRIPTION, ), ) def combine_thought(action: Action, thought: str) -> Action: - if not hasattr(action, 'thought'): + if not hasattr(action, "thought"): return action if thought: action.thought = thought @@ -442,17 +442,17 @@ def combine_thought(action: Action, thought: str) -> Action: def response_to_actions(response: ModelResponse) -> list[Action]: actions: list[Action] = [] - assert len(response.choices) == 1, 'Only one choice is supported for now' + assert len(response.choices) == 1, "Only one choice is supported for now" assistant_msg = response.choices[0].message if assistant_msg.tool_calls: # Check if there's assistant_msg.content. If so, add it to the thought - thought = '' + thought = "" if isinstance(assistant_msg.content, str): thought = assistant_msg.content elif isinstance(assistant_msg.content, list): for msg in assistant_msg.content: - if msg['type'] == 'text': - thought += msg['text'] + if msg["type"] == "text": + thought += msg["text"] # Process each tool call to OpenHands action for i, tool_call in enumerate(assistant_msg.tool_calls): @@ -461,33 +461,33 @@ def response_to_actions(response: ModelResponse) -> list[Action]: arguments = json.loads(tool_call.function.arguments) except json.decoder.JSONDecodeError as e: raise RuntimeError( - f'Failed to parse tool call arguments: {tool_call.function.arguments}' + f"Failed to parse tool call arguments: {tool_call.function.arguments}" ) from e - if tool_call.function.name == 'execute_bash': + if tool_call.function.name == "execute_bash": action = CmdRunAction(**arguments) - elif tool_call.function.name == 'execute_ipython_cell': + elif tool_call.function.name == "execute_ipython_cell": action = IPythonRunCellAction(**arguments) - elif tool_call.function.name == 'delegate_to_browsing_agent': + elif tool_call.function.name == "delegate_to_browsing_agent": action = AgentDelegateAction( - agent='BrowsingAgent', + agent="BrowsingAgent", inputs=arguments, ) - elif tool_call.function.name == 'finish': + elif tool_call.function.name == "finish": action = AgentFinishAction() - elif tool_call.function.name == 'edit_file': + elif tool_call.function.name == "edit_file": action = FileEditAction(**arguments) - elif tool_call.function.name == 'str_replace_editor': + elif tool_call.function.name == "str_replace_editor": # We implement this in agent_skills, which can be used via Jupyter # convert tool_call.function.arguments to kwargs that can be passed to file_editor - code = f'print(file_editor(**{arguments}))' + code = f"print(file_editor(**{arguments}))" logger.debug( - f'TOOL CALL: str_replace_editor -> file_editor with code: {code}' + f"TOOL CALL: str_replace_editor -> file_editor with code: {code}" ) action = IPythonRunCellAction(code=code, include_extra=False) - elif tool_call.function.name == 'browser': - action = BrowseInteractiveAction(browser_actions=arguments['code']) + elif tool_call.function.name == "browser": + action = BrowseInteractiveAction(browser_actions=arguments["code"]) else: - raise RuntimeError(f'Unknown tool call: {tool_call.function.name}') + raise RuntimeError(f"Unknown tool call: {tool_call.function.name}") # We only add thought to the first action if i == 0: diff --git a/openhands/agenthub/codeact_swe_agent/__init__.py b/openhands/agenthub/codeact_swe_agent/__init__.py index ef5233786194..d7e826febae6 100644 --- a/openhands/agenthub/codeact_swe_agent/__init__.py +++ b/openhands/agenthub/codeact_swe_agent/__init__.py @@ -1,4 +1,4 @@ from openhands.agenthub.codeact_swe_agent.codeact_swe_agent import CodeActSWEAgent from openhands.controller.agent import Agent -Agent.register('CodeActSWEAgent', CodeActSWEAgent) +Agent.register("CodeActSWEAgent", CodeActSWEAgent) diff --git a/openhands/agenthub/codeact_swe_agent/action_parser.py b/openhands/agenthub/codeact_swe_agent/action_parser.py index c77c1404a6e6..ffa16463b7b8 100644 --- a/openhands/agenthub/codeact_swe_agent/action_parser.py +++ b/openhands/agenthub/codeact_swe_agent/action_parser.py @@ -21,14 +21,14 @@ def __init__( self.finish_command = None def check_condition(self, action_str: str) -> bool: - self.finish_command = re.search(r'.*', action_str, re.DOTALL) + self.finish_command = re.search(r".*", action_str, re.DOTALL) return self.finish_command is not None def parse(self, action_str: str) -> Action: assert ( self.finish_command is not None - ), 'self.finish_command should not be None when parse is called' - thought = action_str.replace(self.finish_command.group(0), '').strip() + ), "self.finish_command should not be None when parse is called" + thought = action_str.replace(self.finish_command.group(0), "").strip() return AgentFinishAction(thought=thought) @@ -45,18 +45,18 @@ def __init__( def check_condition(self, action_str: str) -> bool: self.bash_command = re.search( - r'(.*?)', action_str, re.DOTALL + r"(.*?)", action_str, re.DOTALL ) return self.bash_command is not None def parse(self, action_str: str) -> Action: assert ( self.bash_command is not None - ), 'self.bash_command should not be None when parse is called' - thought = action_str.replace(self.bash_command.group(0), '').strip() + ), "self.bash_command should not be None when parse is called" + thought = action_str.replace(self.bash_command.group(0), "").strip() # a command was found command_group = self.bash_command.group(1).strip() - if command_group.strip() == 'exit': + if command_group.strip() == "exit": return AgentFinishAction() return CmdRunAction(command=command_group, thought=thought) @@ -70,20 +70,20 @@ def __init__( self, ): self.python_code = None - self.jupyter_kernel_init_code: str = 'from agentskills import *' + self.jupyter_kernel_init_code: str = "from agentskills import *" def check_condition(self, action_str: str) -> bool: self.python_code = re.search( - r'(.*?)', action_str, re.DOTALL + r"(.*?)", action_str, re.DOTALL ) return self.python_code is not None def parse(self, action_str: str) -> Action: assert ( self.python_code is not None - ), 'self.python_code should not be None when parse is called' + ), "self.python_code should not be None when parse is called" code_group = self.python_code.group(1).strip() - thought = action_str.replace(self.python_code.group(0), '').strip() + thought = action_str.replace(self.python_code.group(0), "").strip() return IPythonRunCellAction( code=code_group, thought=thought, diff --git a/openhands/agenthub/codeact_swe_agent/codeact_swe_agent.py b/openhands/agenthub/codeact_swe_agent/codeact_swe_agent.py index 8d403d357e03..ba7834b1ca98 100644 --- a/openhands/agenthub/codeact_swe_agent/codeact_swe_agent.py +++ b/openhands/agenthub/codeact_swe_agent/codeact_swe_agent.py @@ -34,7 +34,7 @@ def get_system_message() -> str: - return f'{SYSTEM_PREFIX}\n\n{COMMAND_DOCS}\n\n{SYSTEM_SUFFIX}' + return f"{SYSTEM_PREFIX}\n\n{COMMAND_DOCS}\n\n{SYSTEM_SUFFIX}" def get_in_context_example() -> str: @@ -42,7 +42,7 @@ def get_in_context_example() -> str: class CodeActSWEAgent(Agent): - VERSION = '1.6' + VERSION = "1.6" """ This agent is an adaptation of the original [SWE Agent](https://swe-agent.com/) based on CodeAct 1.5 using the `agentskills` library of OpenHands. @@ -80,13 +80,13 @@ def __init__( def action_to_str(self, action: Action) -> str: if isinstance(action, CmdRunAction): return ( - f'{action.thought}\n\n{action.command}\n' + f"{action.thought}\n\n{action.command}\n" ) elif isinstance(action, IPythonRunCellAction): - return f'{action.thought}\n\n{action.code}\n' + return f"{action.thought}\n\n{action.code}\n" elif isinstance(action, MessageAction): return action.content - return '' + return "" def get_action_message(self, action: Action) -> Message | None: if isinstance(action, (CmdRunAction, IPythonRunCellAction, MessageAction)): @@ -100,7 +100,7 @@ def get_action_message(self, action: Action) -> Message | None: content.append(ImageContent(image_urls=action.image_urls)) return Message( - role='user' if action.source == 'user' else 'assistant', content=content + role="user" if action.source == "user" else "assistant", content=content ) return None @@ -108,33 +108,33 @@ def get_action_message(self, action: Action) -> Message | None: def get_observation_message(self, obs: Observation) -> Message | None: max_message_chars = self.llm.config.max_message_chars if isinstance(obs, CmdOutputObservation): - text = 'OBSERVATION:\n' + truncate_content( + text = "OBSERVATION:\n" + truncate_content( obs.content + obs.interpreter_details, max_message_chars ) text += ( - f'\n[Command {obs.command_id} finished with exit code {obs.exit_code}]' + f"\n[Command {obs.command_id} finished with exit code {obs.exit_code}]" ) - return Message(role='user', content=[TextContent(text=text)]) + return Message(role="user", content=[TextContent(text=text)]) elif isinstance(obs, IPythonRunCellObservation): - text = 'OBSERVATION:\n' + obs.content + text = "OBSERVATION:\n" + obs.content # replace base64 images with a placeholder - splitted = text.split('\n') + splitted = text.split("\n") for i, line in enumerate(splitted): - if '![image](data:image/png;base64,' in line: + if "![image](data:image/png;base64," in line: splitted[i] = ( - '![image](data:image/png;base64, ...) already displayed to user' + "![image](data:image/png;base64, ...) already displayed to user" ) - text = '\n'.join(splitted) + text = "\n".join(splitted) text = truncate_content(text, max_message_chars) - return Message(role='user', content=[TextContent(text=text)]) + return Message(role="user", content=[TextContent(text=text)]) elif isinstance(obs, ErrorObservation): - text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars) - text += '\n[Error occurred in processing last action]' - return Message(role='user', content=[TextContent(text=text)]) + text = "OBSERVATION:\n" + truncate_content(obs.content, max_message_chars) + text += "\n[Error occurred in processing last action]" + return Message(role="user", content=[TextContent(text=text)]) else: # If an observation message is not returned, it will cause an error # when the LLM tries to return the next message - raise ValueError(f'Unknown observation type: {type(obs)}') + raise ValueError(f"Unknown observation type: {type(obs)}") def reset(self) -> None: """Resets the CodeAct Agent.""" @@ -155,7 +155,7 @@ def step(self, state: State) -> Action: """ # if we're done, go back last_user_message = state.get_last_user_message() - if last_user_message and last_user_message.content.strip() == '/exit': + if last_user_message and last_user_message.content.strip() == "/exit": return AgentFinishAction() # prepare what we want to send to the LLM @@ -163,8 +163,8 @@ def step(self, state: State) -> Action: response = self.llm.completion( messages=self.llm.format_messages_for_llm(messages), stop=[ - '', - '', + "", + "", ], ) @@ -172,8 +172,8 @@ def step(self, state: State) -> Action: def _get_messages(self, state: State) -> list[Message]: messages: list[Message] = [ - Message(role='system', content=[TextContent(text=self.system_message)]), - Message(role='user', content=[TextContent(text=self.in_context_example)]), + Message(role="system", content=[TextContent(text=self.system_message)]), + Message(role="user", content=[TextContent(text=self.in_context_example)]), ] for event in state.history: @@ -183,7 +183,7 @@ def _get_messages(self, state: State) -> list[Message]: elif isinstance(event, Observation): message = self.get_observation_message(event) else: - raise ValueError(f'Unknown event type: {type(event)}') + raise ValueError(f"Unknown event type: {type(event)}") # add regular message if message: @@ -198,7 +198,7 @@ def _get_messages(self, state: State) -> list[Message]: # the latest user message is important: # we want to remind the agent of the environment constraints latest_user_message = next( - (m for m in reversed(messages) if m.role == 'user'), None + (m for m in reversed(messages) if m.role == "user"), None ) # Get the last user text inside content @@ -211,7 +211,7 @@ def _get_messages(self, state: State) -> list[Message]: ) ) # add a reminder to the prompt - reminder_text = f'\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with .' + reminder_text = f"\n\nENVIRONMENT REMINDER: You have {state.max_iterations - state.iteration} turns left to complete the task. When finished reply with ." if latest_user_message_text: latest_user_message_text.text = ( diff --git a/openhands/agenthub/codeact_swe_agent/prompt.py b/openhands/agenthub/codeact_swe_agent/prompt.py index 1a2ffabad95f..ddb39ebdaa71 100644 --- a/openhands/agenthub/codeact_swe_agent/prompt.py +++ b/openhands/agenthub/codeact_swe_agent/prompt.py @@ -3,8 +3,8 @@ _AGENT_SKILLS_DOCS = AgentSkillsRequirement.documentation COMMAND_DOCS = ( - '\nApart from the standard Python library, the assistant can also use the following functions (already imported) in environment:\n' - f'{_AGENT_SKILLS_DOCS}' + "\nApart from the standard Python library, the assistant can also use the following functions (already imported) in environment:\n" + f"{_AGENT_SKILLS_DOCS}" "Please note that THE `edit_file` FUNCTION REQUIRES PROPER INDENTATION. If the assistant would like to add the line ' print(x)', it must fully write that out, with all those spaces before the code! Indentation is important and code that is not indented correctly will fail and require fixing before it can be run." ) diff --git a/openhands/agenthub/codeact_swe_agent/response_parser.py b/openhands/agenthub/codeact_swe_agent/response_parser.py index 147b8655f8f8..3375b2c2c291 100644 --- a/openhands/agenthub/codeact_swe_agent/response_parser.py +++ b/openhands/agenthub/codeact_swe_agent/response_parser.py @@ -33,10 +33,10 @@ def parse(self, response: str) -> Action: def parse_response(self, response) -> str: action = response.choices[0].message.content if action is None: - return '' - for lang in ['bash', 'ipython']: - if f'' in action and f'' not in action: - action += f'' + return "" + for lang in ["bash", "ipython"]: + if f"" in action and f"" not in action: + action += f"" return action def parse_action(self, action_str: str) -> Action: diff --git a/openhands/agenthub/delegator_agent/__init__.py b/openhands/agenthub/delegator_agent/__init__.py index 68e20efa3092..21a3fa14cca8 100644 --- a/openhands/agenthub/delegator_agent/__init__.py +++ b/openhands/agenthub/delegator_agent/__init__.py @@ -1,4 +1,4 @@ from openhands.agenthub.delegator_agent.agent import DelegatorAgent from openhands.controller.agent import Agent -Agent.register('DelegatorAgent', DelegatorAgent) +Agent.register("DelegatorAgent", DelegatorAgent) diff --git a/openhands/agenthub/delegator_agent/agent.py b/openhands/agenthub/delegator_agent/agent.py index 7cb987c8c3f7..f86335923ae9 100644 --- a/openhands/agenthub/delegator_agent/agent.py +++ b/openhands/agenthub/delegator_agent/agent.py @@ -7,12 +7,12 @@ class DelegatorAgent(Agent): - VERSION = '1.0' + VERSION = "1.0" """ The Delegator Agent is responsible for delegating tasks to other agents based on the current task. """ - current_delegate: str = '' + current_delegate: str = "" def __init__(self, llm: LLM, config: AgentConfig): """Initialize the Delegator Agent with an LLM @@ -33,11 +33,11 @@ def step(self, state: State) -> Action: - AgentFinishAction: If the last state was 'completed', 'verified', or 'abandoned' - AgentDelegateAction: The next agent to delegate the task to """ - if self.current_delegate == '': - self.current_delegate = 'study' + if self.current_delegate == "": + self.current_delegate = "study" task, _ = state.get_current_user_intent() return AgentDelegateAction( - agent='StudyRepoForTaskAgent', inputs={'task': task} + agent="StudyRepoForTaskAgent", inputs={"task": task} ) # last observation in history should be from the delegate @@ -48,40 +48,40 @@ def step(self, state: State) -> Action: break if not isinstance(last_observation, AgentDelegateObservation): - raise Exception('Last observation is not an AgentDelegateObservation') + raise Exception("Last observation is not an AgentDelegateObservation") goal, _ = state.get_current_user_intent() - if self.current_delegate == 'study': - self.current_delegate = 'coder' + if self.current_delegate == "study": + self.current_delegate = "coder" return AgentDelegateAction( - agent='CoderAgent', + agent="CoderAgent", inputs={ - 'task': goal, - 'summary': last_observation.outputs['summary'], + "task": goal, + "summary": last_observation.outputs["summary"], }, ) - elif self.current_delegate == 'coder': - self.current_delegate = 'verifier' + elif self.current_delegate == "coder": + self.current_delegate = "verifier" return AgentDelegateAction( - agent='VerifierAgent', + agent="VerifierAgent", inputs={ - 'task': goal, + "task": goal, }, ) - elif self.current_delegate == 'verifier': + elif self.current_delegate == "verifier": if ( - 'completed' in last_observation.outputs - and last_observation.outputs['completed'] + "completed" in last_observation.outputs + and last_observation.outputs["completed"] ): return AgentFinishAction() else: - self.current_delegate = 'coder' + self.current_delegate = "coder" return AgentDelegateAction( - agent='CoderAgent', + agent="CoderAgent", inputs={ - 'task': goal, - 'summary': last_observation.outputs['summary'], + "task": goal, + "summary": last_observation.outputs["summary"], }, ) else: - raise Exception('Invalid delegate state') + raise Exception("Invalid delegate state") diff --git a/openhands/agenthub/dummy_agent/__init__.py b/openhands/agenthub/dummy_agent/__init__.py index d0db8e26c9cd..0873b1d525d3 100644 --- a/openhands/agenthub/dummy_agent/__init__.py +++ b/openhands/agenthub/dummy_agent/__init__.py @@ -1,4 +1,4 @@ from openhands.agenthub.dummy_agent.agent import DummyAgent from openhands.controller.agent import Agent -Agent.register('DummyAgent', DummyAgent) +Agent.register("DummyAgent", DummyAgent) diff --git a/openhands/agenthub/dummy_agent/agent.py b/openhands/agenthub/dummy_agent/agent.py index 272e6c935f2e..5d7bc9c545e4 100644 --- a/openhands/agenthub/dummy_agent/agent.py +++ b/openhands/agenthub/dummy_agent/agent.py @@ -35,7 +35,7 @@ """ ActionObs = TypedDict( - 'ActionObs', {'action': Action, 'observations': list[Observation]} + "ActionObs", {"action": Action, "observations": list[Observation]} ) @@ -50,81 +50,81 @@ def __init__(self, llm: LLM, config: AgentConfig): super().__init__(llm, config) self.steps: list[ActionObs] = [ { - 'action': AddTaskAction( - parent='None', goal='check the current directory' + "action": AddTaskAction( + parent="None", goal="check the current directory" ), - 'observations': [], + "observations": [], }, { - 'action': AddTaskAction(parent='0', goal='run ls'), - 'observations': [], + "action": AddTaskAction(parent="0", goal="run ls"), + "observations": [], }, { - 'action': ModifyTaskAction(task_id='0', state='in_progress'), - 'observations': [], + "action": ModifyTaskAction(task_id="0", state="in_progress"), + "observations": [], }, { - 'action': MessageAction('Time to get started!'), - 'observations': [], + "action": MessageAction("Time to get started!"), + "observations": [], }, { - 'action': CmdRunAction(command='echo "foo"'), - 'observations': [ + "action": CmdRunAction(command='echo "foo"'), + "observations": [ CmdOutputObservation( - 'foo', command_id=-1, command='echo "foo"', exit_code=0 + "foo", command_id=-1, command='echo "foo"', exit_code=0 ) ], }, { - 'action': FileWriteAction( - content='echo "Hello, World!"', path='hello.sh' + "action": FileWriteAction( + content='echo "Hello, World!"', path="hello.sh" ), - 'observations': [ + "observations": [ FileWriteObservation( - content='echo "Hello, World!"', path='hello.sh' + content='echo "Hello, World!"', path="hello.sh" ) ], }, { - 'action': FileReadAction(path='hello.sh'), - 'observations': [ - FileReadObservation('echo "Hello, World!"\n', path='hello.sh') + "action": FileReadAction(path="hello.sh"), + "observations": [ + FileReadObservation('echo "Hello, World!"\n', path="hello.sh") ], }, { - 'action': CmdRunAction(command='bash hello.sh'), - 'observations': [ + "action": CmdRunAction(command="bash hello.sh"), + "observations": [ CmdOutputObservation( - 'bash: hello.sh: No such file or directory', + "bash: hello.sh: No such file or directory", command_id=-1, - command='bash workspace/hello.sh', + command="bash workspace/hello.sh", exit_code=127, ) ], }, { - 'action': BrowseURLAction(url='https://google.com'), - 'observations': [ + "action": BrowseURLAction(url="https://google.com"), + "observations": [ # BrowserOutputObservation('Simulated Google page',url='https://google.com',screenshot=''), ], }, { - 'action': BrowseInteractiveAction( + "action": BrowseInteractiveAction( browser_actions='goto("https://google.com")' ), - 'observations': [ + "observations": [ # BrowserOutputObservation('Simulated Google page after interaction',url='https://google.com',screenshot=''), ], }, { - 'action': AgentRejectAction(), - 'observations': [NullObservation('')], + "action": AgentRejectAction(), + "observations": [NullObservation("")], }, { - 'action': AgentFinishAction( - outputs={}, thought='Task completed', action='finish' + "action": AgentFinishAction( + outputs={}, thought="Task completed", action="finish" ), - 'observations': [AgentStateChangedObservation('', AgentState.FINISHED)], + "observations": [AgentStateChangedObservation("", AgentState.FINISHED)], }, ] @@ -133,23 +133,23 @@ def step(self, state: State) -> Action: return AgentFinishAction() current_step = self.steps[state.iteration] - action = current_step['action'] + action = current_step["action"] # If the action is AddTaskAction or ModifyTaskAction, update the parent ID or task_id if isinstance(action, AddTaskAction): - if action.parent == 'None': - action.parent = '' # Root task has no parent - elif action.parent == '0': + if action.parent == "None": + action.parent = "" # Root task has no parent + elif action.parent == "0": action.parent = state.root_task.id - elif action.parent.startswith('0.'): - action.parent = f'{state.root_task.id}{action.parent[1:]}' + elif action.parent.startswith("0."): + action.parent = f"{state.root_task.id}{action.parent[1:]}" elif isinstance(action, ModifyTaskAction): - if action.task_id == '0': + if action.task_id == "0": action.task_id = state.root_task.id - elif action.task_id.startswith('0.'): - action.task_id = f'{state.root_task.id}{action.task_id[1:]}' + elif action.task_id.startswith("0."): + action.task_id = f"{state.root_task.id}{action.task_id[1:]}" # Ensure the task_id doesn't start with a dot - if action.task_id.startswith('.'): + if action.task_id.startswith("."): action.task_id = action.task_id[1:] elif isinstance(action, (BrowseURLAction, BrowseInteractiveAction)): try: @@ -162,13 +162,13 @@ def step(self, state: State) -> Action: if state.iteration > 0: prev_step = self.steps[state.iteration - 1] - if 'observations' in prev_step and prev_step['observations']: - expected_observations = prev_step['observations'] + if "observations" in prev_step and prev_step["observations"]: + expected_observations = prev_step["observations"] hist_events = state.history[-len(expected_observations) :] if len(hist_events) < len(expected_observations): print( - f'Warning: Expected {len(expected_observations)} observations, but got {len(hist_events)}' + f"Warning: Expected {len(expected_observations)} observations, but got {len(hist_events)}" ) for i in range(min(len(expected_observations), len(hist_events))): @@ -177,16 +177,16 @@ def step(self, state: State) -> Action: # Remove dynamic fields for comparison for obs in [hist_obs, expected_obs]: - obs.pop('id', None) - obs.pop('timestamp', None) - obs.pop('cause', None) - obs.pop('source', None) - if 'extras' in obs: - obs['extras'].pop('command_id', None) + obs.pop("id", None) + obs.pop("timestamp", None) + obs.pop("cause", None) + obs.pop("source", None) + if "extras" in obs: + obs["extras"].pop("command_id", None) if hist_obs != expected_obs: print( - f'Warning: Observation mismatch. Expected {expected_obs}, got {hist_obs}' + f"Warning: Observation mismatch. Expected {expected_obs}, got {hist_obs}" ) return action @@ -201,11 +201,11 @@ def handle_browser_unavailable( self, action: Union[BrowseURLAction, BrowseInteractiveAction] ) -> Action: # Create a message action to inform that browsing is not available - message = 'Browser actions are not available in the DummyAgent environment.' + message = "Browser actions are not available in the DummyAgent environment." if isinstance(action, BrowseURLAction): - message += f' Unable to browse URL: {action.url}' + message += f" Unable to browse URL: {action.url}" elif isinstance(action, BrowseInteractiveAction): message += ( - f' Unable to perform interactive browsing: {action.browser_actions}' + f" Unable to perform interactive browsing: {action.browser_actions}" ) return MessageAction(content=message) diff --git a/openhands/agenthub/micro/agent.py b/openhands/agenthub/micro/agent.py index a9b0825afd9d..f5fbc05b181c 100644 --- a/openhands/agenthub/micro/agent.py +++ b/openhands/agenthub/micro/agent.py @@ -28,14 +28,12 @@ def to_json(obj, **kwargs): class MicroAgent(Agent): - VERSION = '1.0' - prompt = '' + VERSION = "1.0" + prompt = "" agent_definition: dict = {} def history_to_json(self, history: list[Event], max_events: int = 20, **kwargs): - """ - Serialize and simplify history to str format - """ + """Serialize and simplify history to str format""" processed_history = [] event_count = 0 @@ -54,11 +52,11 @@ def history_to_json(self, history: list[Event], max_events: int = 20, **kwargs): def __init__(self, llm: LLM, config: AgentConfig): super().__init__(llm, config) - if 'name' not in self.agent_definition: - raise ValueError('Agent definition must contain a name') + if "name" not in self.agent_definition: + raise ValueError("Agent definition must contain a name") self.prompt_template = Environment(loader=BaseLoader).from_string(self.prompt) self.delegates = all_microagents.copy() - del self.delegates[self.agent_definition['name']] + del self.delegates[self.agent_definition["name"]] def step(self, state: State) -> Action: last_user_message, last_image_urls = state.get_current_user_intent() @@ -73,10 +71,10 @@ def step(self, state: State) -> Action: content = [TextContent(text=prompt)] if self.llm.vision_is_active() and last_image_urls: content.append(ImageContent(image_urls=last_image_urls)) - message = Message(role='user', content=content) + message = Message(role="user", content=content) resp = self.llm.completion( messages=self.llm.format_messages_for_llm(message), ) - action_resp = resp['choices'][0]['message']['content'] + action_resp = resp["choices"][0]["message"]["content"] action = parse_response(action_resp) return action diff --git a/openhands/agenthub/micro/instructions.py b/openhands/agenthub/micro/instructions.py index 73e72eb2b3c9..e932961257df 100644 --- a/openhands/agenthub/micro/instructions.py +++ b/openhands/agenthub/micro/instructions.py @@ -2,7 +2,7 @@ instructions: dict = {} -base_dir = os.path.dirname(os.path.abspath(__file__)) + '/_instructions' +base_dir = os.path.dirname(os.path.abspath(__file__)) + "/_instructions" for root, dirs, files in os.walk(base_dir): if len(files) == 0: continue @@ -10,7 +10,7 @@ obj = instructions else: rel_base = os.path.relpath(root, base_dir) - keys = rel_base.split('/') + keys = rel_base.split("/") obj = instructions for key in keys: if key not in obj: @@ -18,5 +18,5 @@ obj = obj[key] for file in files: without_ext = os.path.splitext(file)[0] - with open(os.path.join(root, file), 'r') as f: + with open(os.path.join(root, file), "r") as f: obj[without_ext] = f.read() diff --git a/openhands/agenthub/micro/registry.py b/openhands/agenthub/micro/registry.py index cc16e4d26d9b..c10698415489 100644 --- a/openhands/agenthub/micro/registry.py +++ b/openhands/agenthub/micro/registry.py @@ -8,20 +8,20 @@ dirs = sorted(os.listdir(os.path.dirname(__file__))) for dir in dirs: - base = os.path.dirname(__file__) + '/' + dir + base = os.path.dirname(__file__) + "/" + dir if os.path.isfile(base): continue - if dir.startswith('_'): + if dir.startswith("_"): continue - promptFile = base + '/prompt.md' - agentFile = base + '/agent.yaml' + promptFile = base + "/prompt.md" + agentFile = base + "/agent.yaml" if not os.path.isfile(promptFile) or not os.path.isfile(agentFile): - raise Exception(f'Missing prompt or agent file in {base}. Please create them.') - with open(promptFile, 'r') as f: + raise Exception(f"Missing prompt or agent file in {base}. Please create them.") + with open(promptFile, "r") as f: prompt = f.read() - with open(agentFile, 'r') as f: + with open(agentFile, "r") as f: agent = yaml.safe_load(f) - if 'name' not in agent: - raise Exception(f'Missing name in {agentFile}') - agent['prompt'] = prompt - all_microagents[agent['name']] = agent + if "name" not in agent: + raise Exception(f"Missing name in {agentFile}") + agent["prompt"] = prompt + all_microagents[agent["name"]] = agent diff --git a/openhands/agenthub/planner_agent/__init__.py b/openhands/agenthub/planner_agent/__init__.py index e8c030e84c09..b09a79dcdc76 100644 --- a/openhands/agenthub/planner_agent/__init__.py +++ b/openhands/agenthub/planner_agent/__init__.py @@ -1,4 +1,4 @@ from openhands.agenthub.planner_agent.agent import PlannerAgent from openhands.controller.agent import Agent -Agent.register('PlannerAgent', PlannerAgent) +Agent.register("PlannerAgent", PlannerAgent) diff --git a/openhands/agenthub/planner_agent/agent.py b/openhands/agenthub/planner_agent/agent.py index f5aef523d9b9..659d0efe4007 100644 --- a/openhands/agenthub/planner_agent/agent.py +++ b/openhands/agenthub/planner_agent/agent.py @@ -9,7 +9,7 @@ class PlannerAgent(Agent): - VERSION = '1.0' + VERSION = "1.0" """ The planner agent utilizes a special prompting strategy to create long term plans for solving problems. The agent is given its previous action-observation pairs, current task, and hint based on last action taken at every step. @@ -36,9 +36,9 @@ def step(self, state: State) -> Action: - Action: The next action to take based on llm response """ if state.root_task.state in [ - 'completed', - 'verified', - 'abandoned', + "completed", + "verified", + "abandoned", ]: return AgentFinishAction() @@ -48,6 +48,6 @@ def step(self, state: State) -> Action: content = [TextContent(text=prompt)] if self.llm.vision_is_active() and image_urls: content.append(ImageContent(image_urls=image_urls)) - message = Message(role='user', content=content) + message = Message(role="user", content=content) resp = self.llm.completion(messages=self.llm.format_messages_for_llm(message)) return self.response_parser.parse(resp) diff --git a/openhands/agenthub/planner_agent/prompt.py b/openhands/agenthub/planner_agent/prompt.py index 7b73f4353131..8ab0e6e58d17 100644 --- a/openhands/agenthub/planner_agent/prompt.py +++ b/openhands/agenthub/planner_agent/prompt.py @@ -101,18 +101,18 @@ def get_hint(latest_action_id: str) -> str: """Returns action type hint based on given action_id""" hints = { - '': "You haven't taken any actions yet. Start by using `ls` to check out what files you're working with.", - ActionType.RUN: 'You should think about the command you just ran, what output it gave, and how that affects your plan.', - ActionType.READ: 'You should think about the file you just read, what you learned from it, and how that affects your plan.', - ActionType.WRITE: 'You just changed a file. You should think about how it affects your plan.', - ActionType.BROWSE: 'You should think about the page you just visited, and what you learned from it.', + "": "You haven't taken any actions yet. Start by using `ls` to check out what files you're working with.", + ActionType.RUN: "You should think about the command you just ran, what output it gave, and how that affects your plan.", + ActionType.READ: "You should think about the file you just read, what you learned from it, and how that affects your plan.", + ActionType.WRITE: "You just changed a file. You should think about how it affects your plan.", + ActionType.BROWSE: "You should think about the page you just visited, and what you learned from it.", ActionType.MESSAGE: "Look at your last thought in the history above. What does it suggest? Don't think anymore--take action.", - ActionType.ADD_TASK: 'You should think about the next action to take.', - ActionType.MODIFY_TASK: 'You should think about the next action to take.', - ActionType.SUMMARIZE: '', - ActionType.FINISH: '', + ActionType.ADD_TASK: "You should think about the next action to take.", + ActionType.MODIFY_TASK: "You should think about the next action to take.", + ActionType.SUMMARIZE: "", + ActionType.FINISH: "", } - return hints.get(latest_action_id, '') + return hints.get(latest_action_id, "") def get_prompt_and_images( @@ -159,19 +159,19 @@ def get_prompt_and_images( plan_status = "You're not currently working on any tasks. Your next action MUST be to mark a task as in_progress." # the hint, based on the last action - hint = get_hint(event_to_memory(latest_action, max_message_chars).get('action', '')) - logger.debug('HINT:\n' + hint, extra={'msg_type': 'DETAIL'}) + hint = get_hint(event_to_memory(latest_action, max_message_chars).get("action", "")) + logger.debug("HINT:\n" + hint, extra={"msg_type": "DETAIL"}) # the last relevant user message (the task) message, image_urls = state.get_current_user_intent() # finally, fill in the prompt return prompt % { - 'task': message, - 'plan': plan_str, - 'history': history_str, - 'hint': hint, - 'plan_status': plan_status, + "task": message, + "plan": plan_str, + "history": history_str, + "hint": hint, + "plan_status": plan_status, }, image_urls @@ -184,8 +184,8 @@ def parse_response(response: str) -> Action: - Action: A valid next action to perform from model output """ action_dict = json.loads(response) - if 'contents' in action_dict: + if "contents" in action_dict: # The LLM gets confused here. Might as well be robust - action_dict['content'] = action_dict.pop('contents') + action_dict["content"] = action_dict.pop("contents") action = action_from_dict(action_dict) return action diff --git a/openhands/agenthub/planner_agent/response_parser.py b/openhands/agenthub/planner_agent/response_parser.py index 12068cd5b769..d0f93e8a7879 100644 --- a/openhands/agenthub/planner_agent/response_parser.py +++ b/openhands/agenthub/planner_agent/response_parser.py @@ -16,7 +16,7 @@ def parse(self, response: str) -> Action: def parse_response(self, response) -> str: # get the next action from the response - return response['choices'][0]['message']['content'] + return response["choices"][0]["message"]["content"] def parse_action(self, action_str: str) -> Action: """Parses a string to find an action within it @@ -30,8 +30,8 @@ def parse_action(self, action_str: str) -> Action: # attempt to load the JSON dict from the response action_dict = json.loads(action_str) - if 'content' in action_dict: + if "content" in action_dict: # The LLM gets confused here. Might as well be robust - action_dict['contents'] = action_dict.pop('content') + action_dict["contents"] = action_dict.pop("content") return action_from_dict(action_dict) diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index d0b806a3f19b..09da4d9f300d 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -131,7 +131,8 @@ def __init__( async def close(self): """Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream. - Note that it's fairly important that this closes properly, otherwise the state is incomplete.""" + Note that it's fairly important that this closes properly, otherwise the state is incomplete. + """ await self.set_agent_state_to(AgentState.STOPPED) # we made history, now is the time to rewrite it! @@ -189,7 +190,6 @@ async def _react_to_exception( async def start_step_loop(self): """The main loop for the agent's step-by-step execution.""" - self.log('info', 'Starting step loop...') while should_continue(): try: @@ -303,7 +303,6 @@ async def _handle_message_action(self, action: MessageAction): def reset_task(self): """Resets the agent's task.""" - self.almost_stuck = 0 self.agent.reset() @@ -660,7 +659,6 @@ def _init_history(self): - Excludes all events between the action and observation - Includes the delegate action and observation themselves """ - # define range of events to fetch # delegates start with a start_id and initially won't find any events # otherwise we're restoring a previous session diff --git a/openhands/controller/state/state.py b/openhands/controller/state/state.py index d52844d418b4..f9438c0f8c2b 100644 --- a/openhands/controller/state/state.py +++ b/openhands/controller/state/state.py @@ -18,13 +18,13 @@ class TrafficControlState(str, Enum): # default state, no rate limiting - NORMAL = 'normal' + NORMAL = "normal" # task paused due to traffic control - THROTTLING = 'throttling' + THROTTLING = "throttling" # traffic control is temporarily paused - PAUSED = 'paused' + PAUSED = "paused" RESUMABLE_STATES = [ @@ -37,8 +37,7 @@ class TrafficControlState(str, Enum): @dataclass class State: - """ - Represents the running state of an agent in the OpenHands system, saving data of its operation and memory. + """Represents the running state of an agent in the OpenHands system, saving data of its operation and memory. - Multi-agent/delegate state: - store the task (conversation between the agent and the user) @@ -97,26 +96,26 @@ class State: # NOTE: This will never be used by the controller, but it can be used by different # evaluation tasks to store extra data needed to track the progress/state of the task. extra_data: dict[str, Any] = field(default_factory=dict) - last_error: str = '' + last_error: str = "" def save_to_session(self, sid: str, file_store: FileStore): pickled = pickle.dumps(self) - logger.debug(f'Saving state to session {sid}:{self.agent_state}') - encoded = base64.b64encode(pickled).decode('utf-8') + logger.debug(f"Saving state to session {sid}:{self.agent_state}") + encoded = base64.b64encode(pickled).decode("utf-8") try: - file_store.write(f'sessions/{sid}/agent_state.pkl', encoded) + file_store.write(f"sessions/{sid}/agent_state.pkl", encoded) except Exception as e: - logger.error(f'Failed to save state to session: {e}') + logger.error(f"Failed to save state to session: {e}") raise e @staticmethod - def restore_from_session(sid: str, file_store: FileStore) -> 'State': + def restore_from_session(sid: str, file_store: FileStore) -> "State": try: - encoded = file_store.read(f'sessions/{sid}/agent_state.pkl') + encoded = file_store.read(f"sessions/{sid}/agent_state.pkl") pickled = base64.b64decode(encoded) state = pickle.loads(pickled) except Exception as e: - logger.warning(f'Could not restore state from session: {e}') + logger.warning(f"Could not restore state from session: {e}") raise e # update state @@ -132,14 +131,14 @@ def restore_from_session(sid: str, file_store: FileStore) -> 'State': def __getstate__(self): # don't pickle history, it will be restored from the event stream state = self.__dict__.copy() - state['history'] = [] + state["history"] = [] return state def __setstate__(self, state): self.__dict__.update(state) # make sure we always have the attribute history - if not hasattr(self, 'history'): + if not hasattr(self, "history"): self.history = [] def get_current_user_intent(self) -> tuple[str | None, list[str] | None]: @@ -147,7 +146,7 @@ def get_current_user_intent(self) -> tuple[str | None, list[str] | None]: last_user_message = None last_user_message_image_urls: list[str] | None = [] for event in reversed(self.history): - if isinstance(event, MessageAction) and event.source == 'user': + if isinstance(event, MessageAction) and event.source == "user": last_user_message = event.content last_user_message_image_urls = event.image_urls elif isinstance(event, AgentFinishAction): diff --git a/openhands/controller/state/task.py b/openhands/controller/state/task.py index 456ae0f0a27d..ffd3adaa14db 100644 --- a/openhands/controller/state/task.py +++ b/openhands/controller/state/task.py @@ -4,11 +4,11 @@ ) from openhands.core.logger import openhands_logger as logger -OPEN_STATE = 'open' -COMPLETED_STATE = 'completed' -ABANDONED_STATE = 'abandoned' -IN_PROGRESS_STATE = 'in_progress' -VERIFIED_STATE = 'verified' +OPEN_STATE = "open" +COMPLETED_STATE = "completed" +ABANDONED_STATE = "abandoned" +IN_PROGRESS_STATE = "in_progress" +VERIFIED_STATE = "verified" STATES = [ OPEN_STATE, COMPLETED_STATE, @@ -21,12 +21,12 @@ class Task: id: str goal: str - parent: 'Task | None' - subtasks: list['Task'] + parent: "Task | None" + subtasks: list["Task"] def __init__( self, - parent: 'Task', + parent: "Task", goal: str, state: str = OPEN_STATE, subtasks=None, # noqa: B006 @@ -42,26 +42,26 @@ def __init__( if subtasks is None: subtasks = [] if parent.id: - self.id = parent.id + '.' + str(len(parent.subtasks)) + self.id = parent.id + "." + str(len(parent.subtasks)) else: self.id = str(len(parent.subtasks)) self.parent = parent self.goal = goal - logger.debug(f'Creating task {self.id} with parent={parent.id}, goal={goal}') + logger.debug(f"Creating task {self.id} with parent={parent.id}, goal={goal}") self.subtasks = [] for subtask in subtasks or []: if isinstance(subtask, Task): self.subtasks.append(subtask) else: - goal = subtask.get('goal') - state = subtask.get('state') - subtasks = subtask.get('subtasks') - logger.debug(f'Reading: {goal}, {state}, {subtasks}') + goal = subtask.get("goal") + state = subtask.get("state") + subtasks = subtask.get("subtasks") + logger.debug(f"Reading: {goal}, {state}, {subtasks}") self.subtasks.append(Task(self, goal, state, subtasks)) self.state = OPEN_STATE - def to_string(self, indent=''): + def to_string(self, indent=""): """Returns a string representation of the task and its subtasks. Args: @@ -70,20 +70,20 @@ def to_string(self, indent=''): Returns: A string representation of the task and its subtasks. """ - emoji = '' + emoji = "" if self.state == VERIFIED_STATE: - emoji = '✅' + emoji = "✅" elif self.state == COMPLETED_STATE: - emoji = '🟢' + emoji = "🟢" elif self.state == ABANDONED_STATE: - emoji = '❌' + emoji = "❌" elif self.state == IN_PROGRESS_STATE: - emoji = '💪' + emoji = "💪" elif self.state == OPEN_STATE: - emoji = '🔵' - result = indent + emoji + ' ' + self.id + ' ' + self.goal + '\n' + emoji = "🔵" + result = indent + emoji + " " + self.id + " " + self.goal + "\n" for subtask in self.subtasks: - result += subtask.to_string(indent + ' ') + result += subtask.to_string(indent + " ") return result def to_dict(self): @@ -93,10 +93,10 @@ def to_dict(self): A dictionary containing the task's attributes. """ return { - 'id': self.id, - 'goal': self.goal, - 'state': self.state, - 'subtasks': [t.to_dict() for t in self.subtasks], + "id": self.id, + "goal": self.goal, + "state": self.state, + "subtasks": [t.to_dict() for t in self.subtasks], } def set_state(self, state): @@ -108,7 +108,7 @@ def set_state(self, state): TaskInvalidStateError: If the provided state is invalid. """ if state not in STATES: - logger.error('Invalid state: %s', state) + logger.error("Invalid state: %s", state) raise TaskInvalidStateError(state) self.state = state if ( @@ -123,7 +123,7 @@ def set_state(self, state): if self.parent is not None: self.parent.set_state(state) - def get_current_task(self) -> 'Task | None': + def get_current_task(self) -> "Task | None": """Retrieves the current task in progress. Returns: @@ -151,8 +151,8 @@ class RootTask(Task): state: The state of the root_task. """ - id: str = '' - goal: str = '' + id: str = "" + goal: str = "" parent: None = None def __init__(self): @@ -179,18 +179,18 @@ def get_task_by_id(self, id: str) -> Task: Raises: AgentMalformedActionError: If the provided task ID is invalid or does not exist. """ - if id == '': + if id == "": return self if len(self.subtasks) == 0: - raise LLMMalformedActionError('Task does not exist:' + id) + raise LLMMalformedActionError("Task does not exist:" + id) try: - parts = [int(p) for p in id.split('.')] + parts = [int(p) for p in id.split(".")] except ValueError: - raise LLMMalformedActionError('Invalid task id:' + id) + raise LLMMalformedActionError("Invalid task id:" + id) task: Task = self for part in parts: if part >= len(task.subtasks): - raise LLMMalformedActionError('Task does not exist:' + id) + raise LLMMalformedActionError("Task does not exist:" + id) task = task.subtasks[part] return task @@ -215,7 +215,7 @@ def set_subtask_state(self, id: str, state: str): state: The new state of the subtask. """ task = self.get_task_by_id(id) - logger.debug('Setting task {task.id} from state {task.state} to {state}') + logger.debug("Setting task {task.id} from state {task.state} to {state}") task.set_state(state) unfinished_tasks = [ t diff --git a/openhands/core/cli.py b/openhands/core/cli.py index 5a4f30da7fdc..b870e1e4fe14 100644 --- a/openhands/core/cli.py +++ b/openhands/core/cli.py @@ -78,7 +78,6 @@ def display_event(event: Event): async def main(): """Runs the agent in CLI mode""" - parser = get_parser() # Add the version argument parser.add_argument( diff --git a/openhands/core/config/__init__.py b/openhands/core/config/__init__.py index b8fefb715cf3..8cc2cd082e76 100644 --- a/openhands/core/config/__init__.py +++ b/openhands/core/config/__init__.py @@ -19,19 +19,19 @@ ) __all__ = [ - 'OH_DEFAULT_AGENT', - 'OH_MAX_ITERATIONS', - 'AgentConfig', - 'AppConfig', - 'LLMConfig', - 'SandboxConfig', - 'SecurityConfig', - 'load_app_config', - 'load_from_env', - 'load_from_toml', - 'finalize_config', - 'get_llm_config_arg', - 'get_field_info', - 'get_parser', - 'parse_arguments', + "OH_DEFAULT_AGENT", + "OH_MAX_ITERATIONS", + "AgentConfig", + "AppConfig", + "LLMConfig", + "SandboxConfig", + "SecurityConfig", + "load_app_config", + "load_from_env", + "load_from_toml", + "finalize_config", + "get_llm_config_arg", + "get_field_info", + "get_parser", + "parse_arguments", ] diff --git a/openhands/core/config/app_config.py b/openhands/core/config/app_config.py index d11072a9c92d..3b1951532072 100644 --- a/openhands/core/config/app_config.py +++ b/openhands/core/config/app_config.py @@ -48,62 +48,62 @@ class AppConfig: default_agent: str = OH_DEFAULT_AGENT sandbox: SandboxConfig = field(default_factory=SandboxConfig) security: SecurityConfig = field(default_factory=SecurityConfig) - runtime: str = 'eventstream' - file_store: str = 'memory' - file_store_path: str = '/tmp/file_store' + runtime: str = "eventstream" + file_store: str = "memory" + file_store_path: str = "/tmp/file_store" trajectories_path: str | None = None workspace_base: str | None = None workspace_mount_path: str | None = None - workspace_mount_path_in_sandbox: str = '/workspace' + workspace_mount_path_in_sandbox: str = "/workspace" workspace_mount_rewrite: str | None = None - cache_dir: str = '/tmp/cache' + cache_dir: str = "/tmp/cache" run_as_openhands: bool = True max_iterations: int = OH_MAX_ITERATIONS max_budget_per_task: float | None = None - e2b_api_key: str = '' - modal_api_token_id: str = '' - modal_api_token_secret: str = '' + e2b_api_key: str = "" + modal_api_token_id: str = "" + modal_api_token_secret: str = "" disable_color: bool = False jwt_secret: str = uuid.uuid4().hex debug: bool = False file_uploads_max_file_size_mb: int = 0 file_uploads_restrict_file_types: bool = False - file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*']) + file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: [".*"]) runloop_api_key: str | None = None defaults_dict: ClassVar[dict] = {} - def get_llm_config(self, name='llm') -> LLMConfig: + def get_llm_config(self, name="llm") -> LLMConfig: """Llm is the name for default config (for backward compatibility prior to 0.8)""" if name in self.llms: return self.llms[name] - if name is not None and name != 'llm': + if name is not None and name != "llm": logger.openhands_logger.warning( - f'llm config group {name} not found, using default config' + f"llm config group {name} not found, using default config" ) - if 'llm' not in self.llms: - self.llms['llm'] = LLMConfig() - return self.llms['llm'] + if "llm" not in self.llms: + self.llms["llm"] = LLMConfig() + return self.llms["llm"] - def set_llm_config(self, value: LLMConfig, name='llm'): + def set_llm_config(self, value: LLMConfig, name="llm"): self.llms[name] = value - def get_agent_config(self, name='agent') -> AgentConfig: + def get_agent_config(self, name="agent") -> AgentConfig: """Agent is the name for default config (for backward compability prior to 0.8)""" if name in self.agents: return self.agents[name] - if 'agent' not in self.agents: - self.agents['agent'] = AgentConfig() - return self.agents['agent'] + if "agent" not in self.agents: + self.agents["agent"] = AgentConfig() + return self.agents["agent"] - def set_agent_config(self, value: AgentConfig, name='agent'): + def set_agent_config(self, value: AgentConfig, name="agent"): self.agents[name] = value def get_agent_to_llm_config_map(self) -> dict[str, LLMConfig]: """Get a map of agent names to llm configs.""" return {name: self.get_llm_config_from_agent(name) for name in self.agents} - def get_llm_config_from_agent(self, name='agent') -> LLMConfig: + def get_llm_config_from_agent(self, name="agent") -> LLMConfig: agent_config: AgentConfig = self.get_agent_config(name) llm_config_name = agent_config.llm_config return self.get_llm_config(llm_config_name) @@ -135,16 +135,16 @@ def __str__(self): attr_value = getattr(self, f.name) if attr_name in [ - 'e2b_api_key', - 'github_token', - 'jwt_secret', - 'modal_api_token_id', - 'modal_api_token_secret', - 'runloop_api_key', + "e2b_api_key", + "github_token", + "jwt_secret", + "modal_api_token_id", + "modal_api_token_secret", + "runloop_api_key", ]: - attr_value = '******' if attr_value else None + attr_value = "******" if attr_value else None - attr_str.append(f'{attr_name}={repr(attr_value)}') + attr_str.append(f"{attr_name}={repr(attr_value)}") return f"AppConfig({', '.join(attr_str)}" diff --git a/openhands/core/config/config_utils.py b/openhands/core/config/config_utils.py index 6e7ddebac611..1b324ef8172c 100644 --- a/openhands/core/config/config_utils.py +++ b/openhands/core/config/config_utils.py @@ -1,7 +1,7 @@ from types import UnionType from typing import get_args, get_origin -OH_DEFAULT_AGENT = 'CodeActAgent' +OH_DEFAULT_AGENT = "CodeActAgent" OH_MAX_ITERATIONS = 100 @@ -29,11 +29,11 @@ def get_field_info(f): # type name in a pretty format type_name = ( - field_type.__name__ if hasattr(field_type, '__name__') else str(field_type) + field_type.__name__ if hasattr(field_type, "__name__") else str(field_type) ) # default is always present default = f.default # return a schema with the useful info for frontend - return {'type': type_name.lower(), 'optional': optional, 'default': default} + return {"type": type_name.lower(), "optional": optional, "default": default} diff --git a/openhands/core/config/llm_config.py b/openhands/core/config/llm_config.py index 477b47ccdbe1..6eac2372fc98 100644 --- a/openhands/core/config/llm_config.py +++ b/openhands/core/config/llm_config.py @@ -5,7 +5,7 @@ from openhands.core.config.config_utils import get_field_info from openhands.core.logger import LOG_DIR -LLM_SENSITIVE_FIELDS = ['api_key', 'aws_access_key_id', 'aws_secret_access_key'] +LLM_SENSITIVE_FIELDS = ["api_key", "aws_access_key_id", "aws_secret_access_key"] @dataclass @@ -45,18 +45,18 @@ class LLMConfig: draft_editor: A more efficient LLM to use for file editing. Introduced in [PR 3985](https://github.com/All-Hands-AI/OpenHands/pull/3985). """ - model: str = 'claude-3-5-sonnet-20241022' + model: str = "claude-3-5-sonnet-20241022" api_key: str | None = None base_url: str | None = None api_version: str | None = None - embedding_model: str = 'local' + embedding_model: str = "local" embedding_base_url: str | None = None embedding_deployment_name: str | None = None aws_access_key_id: str | None = None aws_secret_access_key: str | None = None aws_region_name: str | None = None - openrouter_site_url: str = 'https://docs.all-hands.dev/' - openrouter_app_name: str = 'OpenHands' + openrouter_site_url: str = "https://docs.all-hands.dev/" + openrouter_app_name: str = "OpenHands" num_retries: int = 8 retry_multiplier: float = 2 retry_min_wait: int = 15 @@ -75,8 +75,8 @@ class LLMConfig: disable_vision: bool | None = None caching_prompt: bool = True log_completions: bool = False - log_completions_folder: str = os.path.join(LOG_DIR, 'completions') - draft_editor: Optional['LLMConfig'] = None + log_completions_folder: str = os.path.join(LOG_DIR, "completions") + draft_editor: Optional["LLMConfig"] = None def defaults_to_dict(self) -> dict: """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional.""" @@ -86,16 +86,14 @@ def defaults_to_dict(self) -> dict: return result def __post_init__(self): - """ - Post-initialization hook to assign OpenRouter-related variables to environment variables. + """Post-initialization hook to assign OpenRouter-related variables to environment variables. This ensures that these values are accessible to litellm at runtime. """ - # Assign OpenRouter-specific variables to environment variables if self.openrouter_site_url: - os.environ['OR_SITE_URL'] = self.openrouter_site_url + os.environ["OR_SITE_URL"] = self.openrouter_site_url if self.openrouter_app_name: - os.environ['OR_APP_NAME'] = self.openrouter_app_name + os.environ["OR_APP_NAME"] = self.openrouter_app_name def __str__(self): attr_str = [] @@ -104,9 +102,9 @@ def __str__(self): attr_value = getattr(self, f.name) if attr_name in LLM_SENSITIVE_FIELDS: - attr_value = '******' if attr_value else None + attr_value = "******" if attr_value else None - attr_str.append(f'{attr_name}={repr(attr_value)}') + attr_str.append(f"{attr_name}={repr(attr_value)}") return f"LLMConfig({', '.join(attr_str)})" @@ -118,20 +116,20 @@ def to_safe_dict(self): ret = self.__dict__.copy() for k, v in ret.items(): if k in LLM_SENSITIVE_FIELDS: - ret[k] = '******' if v else None + ret[k] = "******" if v else None elif isinstance(v, LLMConfig): ret[k] = v.to_safe_dict() return ret @classmethod - def from_dict(cls, llm_config_dict: dict) -> 'LLMConfig': + def from_dict(cls, llm_config_dict: dict) -> "LLMConfig": """Create an LLMConfig object from a dictionary. This function is used to create an LLMConfig object from a dictionary, with the exception of the 'draft_editor' key, which is a nested LLMConfig object. """ args = {k: v for k, v in llm_config_dict.items() if not isinstance(v, dict)} - if 'draft_editor' in llm_config_dict: - draft_editor_config = LLMConfig(**llm_config_dict['draft_editor']) - args['draft_editor'] = draft_editor_config + if "draft_editor" in llm_config_dict: + draft_editor_config = LLMConfig(**llm_config_dict["draft_editor"]) + args["draft_editor"] = draft_editor_config return cls(**args) diff --git a/openhands/core/config/sandbox_config.py b/openhands/core/config/sandbox_config.py index 57f4b189b182..21ba2b01b896 100644 --- a/openhands/core/config/sandbox_config.py +++ b/openhands/core/config/sandbox_config.py @@ -34,13 +34,13 @@ class SandboxConfig: platform: The platform on which the image should be built. Default is None. """ - remote_runtime_api_url: str = 'http://localhost:8000' - local_runtime_url: str = 'http://localhost' + remote_runtime_api_url: str = "http://localhost:8000" + local_runtime_url: str = "http://localhost" keep_runtime_alive: bool = True api_key: str | None = None - base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22' # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime + base_container_image: str = "nikolaik/python-nodejs:python3.12-nodejs22" # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime runtime_container_image: str | None = None - user_id: int = os.getuid() if hasattr(os, 'getuid') else 1000 + user_id: int = os.getuid() if hasattr(os, "getuid") else 1000 timeout: int = 120 remote_runtime_init_timeout: int = 180 enable_auto_lint: bool = ( @@ -67,7 +67,7 @@ def __str__(self): attr_name = f.name attr_value = getattr(self, f.name) - attr_str.append(f'{attr_name}={repr(attr_value)}') + attr_str.append(f"{attr_name}={repr(attr_value)}") return f"SandboxConfig({', '.join(attr_str)})" diff --git a/openhands/core/config/security_config.py b/openhands/core/config/security_config.py index a4c49c2b0cda..3b143f927524 100644 --- a/openhands/core/config/security_config.py +++ b/openhands/core/config/security_config.py @@ -28,7 +28,7 @@ def __str__(self): attr_name = f.name attr_value = getattr(self, f.name) - attr_str.append(f'{attr_name}={repr(attr_value)}') + attr_str.append(f"{attr_name}={repr(attr_value)}") return f"SecurityConfig({', '.join(attr_str)})" diff --git a/openhands/core/config/utils.py b/openhands/core/config/utils.py index 86794e8aac2f..0e52b9ee5651 100644 --- a/openhands/core/config/utils.py +++ b/openhands/core/config/utils.py @@ -37,7 +37,7 @@ def get_optional_type(union_type: UnionType) -> Any: return next((t for t in types if t is not type(None)), None) # helper function to set attributes based on env vars - def set_attr_from_env(sub_config: Any, prefix=''): + def set_attr_from_env(sub_config: Any, prefix=""): """Set attributes of a config dataclass based on environment variables.""" for field_name, field_type in sub_config.__annotations__.items(): # compute the expected env var name from the prefix and field name @@ -47,7 +47,7 @@ def set_attr_from_env(sub_config: Any, prefix=''): if is_dataclass(field_type): # nested dataclass nested_sub_config = getattr(sub_config, field_name) - set_attr_from_env(nested_sub_config, prefix=field_name + '_') + set_attr_from_env(nested_sub_config, prefix=field_name + "_") elif env_var_name in env_or_toml_dict: # convert the env var to the correct type and set it value = env_or_toml_dict[env_var_name] @@ -63,13 +63,13 @@ def set_attr_from_env(sub_config: Any, prefix=''): # Attempt to cast the env var to type hinted in the dataclass if field_type is bool: - cast_value = str(value).lower() in ['true', '1'] + cast_value = str(value).lower() in ["true", "1"] else: cast_value = field_type(value) setattr(sub_config, field_name, cast_value) except (ValueError, TypeError): logger.openhands_logger.error( - f'Error setting env var {env_var_name}={value}: check that the value is of the right type' + f"Error setting env var {env_var_name}={value}: check that the value is of the right type" ) # Start processing from the root of the config object @@ -77,13 +77,13 @@ def set_attr_from_env(sub_config: Any, prefix=''): # load default LLM config from env default_llm_config = cfg.get_llm_config() - set_attr_from_env(default_llm_config, 'LLM_') + set_attr_from_env(default_llm_config, "LLM_") # load default agent config from env default_agent_config = cfg.get_agent_config() - set_attr_from_env(default_agent_config, 'AGENT_') + set_attr_from_env(default_agent_config, "AGENT_") -def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'): +def load_from_toml(cfg: AppConfig, toml_file: str = "config.toml"): """Load the config from the toml file. Supports both styles of config vars. Args: @@ -92,65 +92,65 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'): """ # try to read the config.toml file into the config object try: - with open(toml_file, 'r', encoding='utf-8') as toml_contents: + with open(toml_file, "r", encoding="utf-8") as toml_contents: toml_config = toml.load(toml_contents) except FileNotFoundError: return except toml.TomlDecodeError as e: logger.openhands_logger.warning( - f'Cannot parse config from toml, toml values have not been applied.\nError: {e}', + f"Cannot parse config from toml, toml values have not been applied.\nError: {e}", exc_info=False, ) return # if there was an exception or core is not in the toml, try to use the old-style toml - if 'core' not in toml_config: + if "core" not in toml_config: # re-use the env loader to set the config from env-style vars load_from_env(cfg, toml_config) return - core_config = toml_config['core'] + core_config = toml_config["core"] # load llm configs and agent configs for key, value in toml_config.items(): if isinstance(value, dict): try: - if key is not None and key.lower() == 'agent': + if key is not None and key.lower() == "agent": logger.openhands_logger.debug( - 'Attempt to load default agent config from config toml' + "Attempt to load default agent config from config toml" ) non_dict_fields = { k: v for k, v in value.items() if not isinstance(v, dict) } agent_config = AgentConfig(**non_dict_fields) - cfg.set_agent_config(agent_config, 'agent') + cfg.set_agent_config(agent_config, "agent") for nested_key, nested_value in value.items(): if isinstance(nested_value, dict): logger.openhands_logger.debug( - f'Attempt to load group {nested_key} from config toml as agent config' + f"Attempt to load group {nested_key} from config toml as agent config" ) agent_config = AgentConfig(**nested_value) cfg.set_agent_config(agent_config, nested_key) - elif key is not None and key.lower() == 'llm': + elif key is not None and key.lower() == "llm": logger.openhands_logger.debug( - 'Attempt to load default LLM config from config toml' + "Attempt to load default LLM config from config toml" ) llm_config = LLMConfig.from_dict(value) - cfg.set_llm_config(llm_config, 'llm') + cfg.set_llm_config(llm_config, "llm") for nested_key, nested_value in value.items(): if isinstance(nested_value, dict): logger.openhands_logger.debug( - f'Attempt to load group {nested_key} from config toml as llm config' + f"Attempt to load group {nested_key} from config toml as llm config" ) llm_config = LLMConfig.from_dict(nested_value) cfg.set_llm_config(llm_config, nested_key) - elif not key.startswith('sandbox') and key.lower() != 'core': + elif not key.startswith("sandbox") and key.lower() != "core": logger.openhands_logger.warning( f'Unknown key in {toml_file}: "{key}"' ) except (TypeError, KeyError) as e: logger.openhands_logger.warning( - f'Cannot parse config from toml, toml values have not been applied.\n Error: {e}', + f"Cannot parse config from toml, toml values have not been applied.\n Error: {e}", exc_info=False, ) else: @@ -161,18 +161,18 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'): sandbox_config = cfg.sandbox # migrate old sandbox configs from [core] section to sandbox config - keys_to_migrate = [key for key in core_config if key.startswith('sandbox_')] + keys_to_migrate = [key for key in core_config if key.startswith("sandbox_")] for key in keys_to_migrate: - new_key = key.replace('sandbox_', '') + new_key = key.replace("sandbox_", "") if new_key in sandbox_config.__annotations__: # read the key in sandbox and remove it from core setattr(sandbox_config, new_key, core_config.pop(key)) else: - logger.openhands_logger.warning(f'Unknown sandbox config: {key}') + logger.openhands_logger.warning(f"Unknown sandbox config: {key}") # the new style values override the old style values - if 'sandbox' in toml_config: - sandbox_config = SandboxConfig(**toml_config['sandbox']) + if "sandbox" in toml_config: + sandbox_config = SandboxConfig(**toml_config["sandbox"]) # update the config object with the new values cfg.sandbox = sandbox_config @@ -180,10 +180,10 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'): if hasattr(cfg, key): setattr(cfg, key, value) else: - logger.openhands_logger.warning(f'Unknown core config key: {key}') + logger.openhands_logger.warning(f"Unknown core config key: {key}") except (TypeError, KeyError) as e: logger.openhands_logger.warning( - f'Cannot parse config from toml, toml values have not been applied.\nError: {e}', + f"Cannot parse config from toml, toml values have not been applied.\nError: {e}", exc_info=False, ) @@ -197,7 +197,7 @@ def finalize_config(cfg: AppConfig): if cfg.workspace_mount_rewrite: base = cfg.workspace_base or os.getcwd() - parts = cfg.workspace_mount_rewrite.split(':') + parts = cfg.workspace_mount_rewrite.split(":") cfg.workspace_mount_path = base.replace(parts[0], parts[1]) # make sure log_completions_folder is an absolute path @@ -206,10 +206,10 @@ def finalize_config(cfg: AppConfig): if llm.embedding_base_url is None: llm.embedding_base_url = llm.base_url - if cfg.sandbox.use_host_network and platform.system() == 'Darwin': + if cfg.sandbox.use_host_network and platform.system() == "Darwin": logger.openhands_logger.warning( - 'Please upgrade to Docker Desktop 4.29.0 or later to use host network mode on macOS. ' - 'See https://github.com/docker/roadmap/issues/238#issuecomment-2044688144 for more information.' + "Please upgrade to Docker Desktop 4.29.0 or later to use host network mode on macOS. " + "See https://github.com/docker/roadmap/issues/238#issuecomment-2044688144 for more information." ) # make sure cache dir exists @@ -219,7 +219,7 @@ def finalize_config(cfg: AppConfig): # Utility function for command line --group argument def get_llm_config_arg( - llm_config_arg: str, toml_file: str = 'config.toml' + llm_config_arg: str, toml_file: str = "config.toml" ) -> LLMConfig | None: """Get a group of llm settings from the config file. @@ -246,127 +246,127 @@ def get_llm_config_arg( LLMConfig: The LLMConfig object with the settings from the config file. """ # keep only the name, just in case - llm_config_arg = llm_config_arg.strip('[]') + llm_config_arg = llm_config_arg.strip("[]") # truncate the prefix, just in case - if llm_config_arg.startswith('llm.'): + if llm_config_arg.startswith("llm."): llm_config_arg = llm_config_arg[4:] - logger.openhands_logger.debug(f'Loading llm config from {llm_config_arg}') + logger.openhands_logger.debug(f"Loading llm config from {llm_config_arg}") # load the toml file try: - with open(toml_file, 'r', encoding='utf-8') as toml_contents: + with open(toml_file, "r", encoding="utf-8") as toml_contents: toml_config = toml.load(toml_contents) except FileNotFoundError as e: - logger.openhands_logger.error(f'Config file not found: {e}') + logger.openhands_logger.error(f"Config file not found: {e}") return None except toml.TomlDecodeError as e: logger.openhands_logger.error( - f'Cannot parse llm group from {llm_config_arg}. Exception: {e}' + f"Cannot parse llm group from {llm_config_arg}. Exception: {e}" ) return None # update the llm config with the specified section - if 'llm' in toml_config and llm_config_arg in toml_config['llm']: - return LLMConfig.from_dict(toml_config['llm'][llm_config_arg]) - logger.openhands_logger.debug(f'Loading from toml failed for {llm_config_arg}') + if "llm" in toml_config and llm_config_arg in toml_config["llm"]: + return LLMConfig.from_dict(toml_config["llm"][llm_config_arg]) + logger.openhands_logger.debug(f"Loading from toml failed for {llm_config_arg}") return None # Command line arguments def get_parser() -> argparse.ArgumentParser: """Get the parser for the command line arguments.""" - parser = argparse.ArgumentParser(description='Run an agent with a specific task') + parser = argparse.ArgumentParser(description="Run an agent with a specific task") parser.add_argument( - '--config-file', + "--config-file", type=str, - default='config.toml', - help='Path to the config file (default: config.toml in the current directory)', + default="config.toml", + help="Path to the config file (default: config.toml in the current directory)", ) parser.add_argument( - '-d', - '--directory', + "-d", + "--directory", type=str, - help='The working directory for the agent', + help="The working directory for the agent", ) parser.add_argument( - '-t', - '--task', + "-t", + "--task", type=str, - default='', - help='The task for the agent to perform', + default="", + help="The task for the agent to perform", ) parser.add_argument( - '-f', - '--file', + "-f", + "--file", type=str, - help='Path to a file containing the task. Overrides -t if both are provided.', + help="Path to a file containing the task. Overrides -t if both are provided.", ) parser.add_argument( - '-c', - '--agent-cls', + "-c", + "--agent-cls", default=OH_DEFAULT_AGENT, type=str, - help='Name of the default agent to use', + help="Name of the default agent to use", ) parser.add_argument( - '-i', - '--max-iterations', + "-i", + "--max-iterations", default=OH_MAX_ITERATIONS, type=int, - help='The maximum number of iterations to run the agent', + help="The maximum number of iterations to run the agent", ) parser.add_argument( - '-b', - '--max-budget-per-task', + "-b", + "--max-budget-per-task", type=float, - help='The maximum budget allowed per task, beyond which the agent will stop.', + help="The maximum budget allowed per task, beyond which the agent will stop.", ) # --eval configs are for evaluations only parser.add_argument( - '--eval-output-dir', - default='evaluation/evaluation_outputs/outputs', + "--eval-output-dir", + default="evaluation/evaluation_outputs/outputs", type=str, - help='The directory to save evaluation output', + help="The directory to save evaluation output", ) parser.add_argument( - '--eval-n-limit', + "--eval-n-limit", default=None, type=int, - help='The number of instances to evaluate', + help="The number of instances to evaluate", ) parser.add_argument( - '--eval-num-workers', + "--eval-num-workers", default=4, type=int, - help='The number of workers to use for evaluation', + help="The number of workers to use for evaluation", ) parser.add_argument( - '--eval-note', + "--eval-note", default=None, type=str, - help='The note to add to the evaluation directory', + help="The note to add to the evaluation directory", ) parser.add_argument( - '-l', - '--llm-config', + "-l", + "--llm-config", default=None, type=str, help='Replace default LLM ([llm] section in config.toml) config with the specified LLM config, e.g. "llama3" for [llm.llama3] section in config.toml', ) parser.add_argument( - '-n', - '--name', - default='default', + "-n", + "--name", + default="default", type=str, - help='Name for the session', + help="Name for the session", ) parser.add_argument( - '--eval-ids', + "--eval-ids", default=None, type=str, - help='The comma-separated list (in quotes) of IDs of the instances to evaluate', + help="The comma-separated list (in quotes) of IDs of the instances to evaluate", ) return parser @@ -379,7 +379,7 @@ def parse_arguments() -> argparse.Namespace: def load_app_config( - set_logging_levels: bool = True, config_file: str = 'config.toml' + set_logging_levels: bool = True, config_file: str = "config.toml" ) -> AppConfig: """Load the configuration from the specified config file and environment variables. diff --git a/openhands/core/const/guide_url.py b/openhands/core/const/guide_url.py index c401de8bb6a1..031fa92dc48e 100644 --- a/openhands/core/const/guide_url.py +++ b/openhands/core/const/guide_url.py @@ -1 +1 @@ -TROUBLESHOOTING_URL = 'https://docs.all-hands.dev/modules/usage/troubleshooting' +TROUBLESHOOTING_URL = "https://docs.all-hands.dev/modules/usage/troubleshooting" diff --git a/openhands/core/logger.py b/openhands/core/logger.py index 20a4a4d6581a..b6968688a459 100644 --- a/openhands/core/logger.py +++ b/openhands/core/logger.py @@ -114,18 +114,14 @@ def print_lines(self): self.replace_current_line(line) def move_back(self, amount=-1): - """ - '\033[F' moves the cursor up one line. - """ + """'\033[F' moves the cursor up one line.""" if amount == -1: amount = self.max_lines self._write('\033[F' * (self.max_lines)) self._flush() def replace_current_line(self, line=''): - """ - '\033[2K\r' clears the line and moves the cursor to the beginning of the line. - """ + """'\033[2K\r' clears the line and moves the cursor to the beginning of the line.""" self._write('\033[2K' + line + '\n') self._flush() diff --git a/openhands/core/loop.py b/openhands/core/loop.py index 2a2808dd0980..26fda9d0e3f8 100644 --- a/openhands/core/loop.py +++ b/openhands/core/loop.py @@ -11,8 +11,7 @@ async def run_agent_until_done( runtime: Runtime, end_states: list[AgentState], ): - """ - run_agent_until_done takes a controller and a runtime, and will run + """run_agent_until_done takes a controller and a runtime, and will run the agent until it reaches a terminal state. Note that runtime must be connected before being passed in here. """ diff --git a/openhands/core/schema/__init__.py b/openhands/core/schema/__init__.py index 370bf022f8ff..47c1eb93e008 100644 --- a/openhands/core/schema/__init__.py +++ b/openhands/core/schema/__init__.py @@ -4,8 +4,8 @@ from openhands.core.schema.observation import ObservationType __all__ = [ - 'ActionType', - 'ObservationType', - 'ConfigType', - 'AgentState', + "ActionType", + "ObservationType", + "ConfigType", + "AgentState", ] diff --git a/openhands/core/schema/action.py b/openhands/core/schema/action.py index dc4cfe542e0a..646d29cac4e4 100644 --- a/openhands/core/schema/action.py +++ b/openhands/core/schema/action.py @@ -1,89 +1,89 @@ from pydantic import BaseModel, Field -__all__ = ['ActionType'] +__all__ = ["ActionType"] class ActionTypeSchema(BaseModel): - INIT: str = Field(default='initialize') + INIT: str = Field(default="initialize") """Initializes the agent. Only sent by client. """ - MESSAGE: str = Field(default='message') + MESSAGE: str = Field(default="message") """Represents a message. """ - START: str = Field(default='start') + START: str = Field(default="start") """Starts a new development task OR send chat from the user. Only sent by the client. """ - READ: str = Field(default='read') + READ: str = Field(default="read") """Reads the content of a file. """ - WRITE: str = Field(default='write') + WRITE: str = Field(default="write") """Writes the content to a file. """ - EDIT: str = Field(default='edit') + EDIT: str = Field(default="edit") """Edits a file by providing a draft. """ - RUN: str = Field(default='run') + RUN: str = Field(default="run") """Runs a command. """ - RUN_IPYTHON: str = Field(default='run_ipython') + RUN_IPYTHON: str = Field(default="run_ipython") """Runs a IPython cell. """ - BROWSE: str = Field(default='browse') + BROWSE: str = Field(default="browse") """Opens a web page. """ - BROWSE_INTERACTIVE: str = Field(default='browse_interactive') + BROWSE_INTERACTIVE: str = Field(default="browse_interactive") """Interact with the browser instance. """ - DELEGATE: str = Field(default='delegate') + DELEGATE: str = Field(default="delegate") """Delegates a task to another agent. """ - FINISH: str = Field(default='finish') + FINISH: str = Field(default="finish") """If you're absolutely certain that you've completed your task and have tested your work, use the finish action to stop working. """ - REJECT: str = Field(default='reject') + REJECT: str = Field(default="reject") """If you're absolutely certain that you cannot complete the task with given requirements, use the reject action to stop working. """ - NULL: str = Field(default='null') + NULL: str = Field(default="null") - SUMMARIZE: str = Field(default='summarize') + SUMMARIZE: str = Field(default="summarize") - ADD_TASK: str = Field(default='add_task') + ADD_TASK: str = Field(default="add_task") - MODIFY_TASK: str = Field(default='modify_task') + MODIFY_TASK: str = Field(default="modify_task") - PAUSE: str = Field(default='pause') + PAUSE: str = Field(default="pause") """Pauses the task. """ - RESUME: str = Field(default='resume') + RESUME: str = Field(default="resume") """Resumes the task. """ - STOP: str = Field(default='stop') + STOP: str = Field(default="stop") """Stops the task. Must send a start action to restart a new task. """ - CHANGE_AGENT_STATE: str = Field(default='change_agent_state') + CHANGE_AGENT_STATE: str = Field(default="change_agent_state") - PUSH: str = Field(default='push') + PUSH: str = Field(default="push") """Push a branch to github.""" - SEND_PR: str = Field(default='send_pr') + SEND_PR: str = Field(default="send_pr") """Send a PR to github.""" diff --git a/openhands/core/schema/agent.py b/openhands/core/schema/agent.py index 4ea09d7afc2a..19eb6975403a 100644 --- a/openhands/core/schema/agent.py +++ b/openhands/core/schema/agent.py @@ -2,50 +2,50 @@ class AgentState(str, Enum): - LOADING = 'loading' + LOADING = "loading" """The agent is loading. """ - INIT = 'init' + INIT = "init" """The agent is initialized. """ - RUNNING = 'running' + RUNNING = "running" """The agent is running. """ - AWAITING_USER_INPUT = 'awaiting_user_input' + AWAITING_USER_INPUT = "awaiting_user_input" """The agent is awaiting user input. """ - PAUSED = 'paused' + PAUSED = "paused" """The agent is paused. """ - STOPPED = 'stopped' + STOPPED = "stopped" """The agent is stopped. """ - FINISHED = 'finished' + FINISHED = "finished" """The agent is finished with the current task. """ - REJECTED = 'rejected' + REJECTED = "rejected" """The agent rejects the task. """ - ERROR = 'error' + ERROR = "error" """An error occurred during the task. """ - AWAITING_USER_CONFIRMATION = 'awaiting_user_confirmation' + AWAITING_USER_CONFIRMATION = "awaiting_user_confirmation" """The agent is awaiting user confirmation. """ - USER_CONFIRMED = 'user_confirmed' + USER_CONFIRMED = "user_confirmed" """The user confirmed the agent's action. """ - USER_REJECTED = 'user_rejected' + USER_REJECTED = "user_rejected" """The user rejected the agent's action. """ diff --git a/openhands/core/schema/config.py b/openhands/core/schema/config.py index 1272ebe655a5..c58f315b35f6 100644 --- a/openhands/core/schema/config.py +++ b/openhands/core/schema/config.py @@ -3,47 +3,47 @@ class ConfigType(str, Enum): # For frontend - AGENT = 'AGENT' - AGENT_MEMORY_ENABLED = 'AGENT_MEMORY_ENABLED' - AGENT_MEMORY_MAX_THREADS = 'AGENT_MEMORY_MAX_THREADS' - AWS_ACCESS_KEY_ID = 'AWS_ACCESS_KEY_ID' - AWS_REGION_NAME = 'AWS_REGION_NAME' - AWS_SECRET_ACCESS_KEY = 'AWS_SECRET_ACCESS_KEY' - BASE_CONTAINER_IMAGE = 'BASE_CONTAINER_IMAGE' - CACHE_DIR = 'CACHE_DIR' - CONFIRMATION_MODE = 'CONFIRMATION_MODE' - DEBUG = 'DEBUG' - DISABLE_COLOR = 'DISABLE_COLOR' - E2B_API_KEY = 'E2B_API_KEY' - FILE_UPLOADS_ALLOWED_EXTENSIONS = 'FILE_UPLOADS_ALLOWED_EXTENSIONS' - FILE_UPLOADS_MAX_FILE_SIZE_MB = 'FILE_UPLOADS_MAX_FILE_SIZE_MB' - FILE_UPLOADS_RESTRICT_FILE_TYPES = 'FILE_UPLOADS_RESTRICT_FILE_TYPES' - LLM_API_KEY = 'LLM_API_KEY' - LLM_API_VERSION = 'LLM_API_VERSION' - LLM_BASE_URL = 'LLM_BASE_URL' - LLM_CACHING_PROMPT = 'LLM_CACHING_PROMPT' - LLM_CUSTOM_LLM_PROVIDER = 'LLM_CUSTOM_LLM_PROVIDER' - LLM_DROP_PARAMS = 'LLM_DROP_PARAMS' - LLM_EMBEDDING_BASE_URL = 'LLM_EMBEDDING_BASE_URL' - LLM_EMBEDDING_DEPLOYMENT_NAME = 'LLM_EMBEDDING_DEPLOYMENT_NAME' - LLM_EMBEDDING_MODEL = 'LLM_EMBEDDING_MODEL' - LLM_MAX_INPUT_TOKENS = 'LLM_MAX_INPUT_TOKENS' - LLM_MAX_OUTPUT_TOKENS = 'LLM_MAX_OUTPUT_TOKENS' - LLM_MODEL = 'LLM_MODEL' - LLM_NUM_RETRIES = 'LLM_NUM_RETRIES' - LLM_RETRY_MAX_WAIT = 'LLM_RETRY_MAX_WAIT' - LLM_RETRY_MIN_WAIT = 'LLM_RETRY_MIN_WAIT' - LLM_TEMPERATURE = 'LLM_TEMPERATURE' - LLM_TIMEOUT = 'LLM_TIMEOUT' - LLM_TOP_P = 'LLM_TOP_P' - LLM_DISABLE_VISION = 'LLM_DISABLE_VISION' - MAX_ITERATIONS = 'MAX_ITERATIONS' - RUN_AS_OPENHANDS = 'RUN_AS_OPENHANDS' - SANDBOX_TIMEOUT = 'SANDBOX_TIMEOUT' - SANDBOX_USER_ID = 'SANDBOX_USER_ID' - SECURITY_ANALYZER = 'SECURITY_ANALYZER' - USE_HOST_NETWORK = 'USE_HOST_NETWORK' - WORKSPACE_BASE = 'WORKSPACE_BASE' - WORKSPACE_MOUNT_PATH = 'WORKSPACE_MOUNT_PATH' - WORKSPACE_MOUNT_PATH_IN_SANDBOX = 'WORKSPACE_MOUNT_PATH_IN_SANDBOX' - WORKSPACE_MOUNT_REWRITE = 'WORKSPACE_MOUNT_REWRITE' + AGENT = "AGENT" + AGENT_MEMORY_ENABLED = "AGENT_MEMORY_ENABLED" + AGENT_MEMORY_MAX_THREADS = "AGENT_MEMORY_MAX_THREADS" + AWS_ACCESS_KEY_ID = "AWS_ACCESS_KEY_ID" + AWS_REGION_NAME = "AWS_REGION_NAME" + AWS_SECRET_ACCESS_KEY = "AWS_SECRET_ACCESS_KEY" + BASE_CONTAINER_IMAGE = "BASE_CONTAINER_IMAGE" + CACHE_DIR = "CACHE_DIR" + CONFIRMATION_MODE = "CONFIRMATION_MODE" + DEBUG = "DEBUG" + DISABLE_COLOR = "DISABLE_COLOR" + E2B_API_KEY = "E2B_API_KEY" + FILE_UPLOADS_ALLOWED_EXTENSIONS = "FILE_UPLOADS_ALLOWED_EXTENSIONS" + FILE_UPLOADS_MAX_FILE_SIZE_MB = "FILE_UPLOADS_MAX_FILE_SIZE_MB" + FILE_UPLOADS_RESTRICT_FILE_TYPES = "FILE_UPLOADS_RESTRICT_FILE_TYPES" + LLM_API_KEY = "LLM_API_KEY" + LLM_API_VERSION = "LLM_API_VERSION" + LLM_BASE_URL = "LLM_BASE_URL" + LLM_CACHING_PROMPT = "LLM_CACHING_PROMPT" + LLM_CUSTOM_LLM_PROVIDER = "LLM_CUSTOM_LLM_PROVIDER" + LLM_DROP_PARAMS = "LLM_DROP_PARAMS" + LLM_EMBEDDING_BASE_URL = "LLM_EMBEDDING_BASE_URL" + LLM_EMBEDDING_DEPLOYMENT_NAME = "LLM_EMBEDDING_DEPLOYMENT_NAME" + LLM_EMBEDDING_MODEL = "LLM_EMBEDDING_MODEL" + LLM_MAX_INPUT_TOKENS = "LLM_MAX_INPUT_TOKENS" + LLM_MAX_OUTPUT_TOKENS = "LLM_MAX_OUTPUT_TOKENS" + LLM_MODEL = "LLM_MODEL" + LLM_NUM_RETRIES = "LLM_NUM_RETRIES" + LLM_RETRY_MAX_WAIT = "LLM_RETRY_MAX_WAIT" + LLM_RETRY_MIN_WAIT = "LLM_RETRY_MIN_WAIT" + LLM_TEMPERATURE = "LLM_TEMPERATURE" + LLM_TIMEOUT = "LLM_TIMEOUT" + LLM_TOP_P = "LLM_TOP_P" + LLM_DISABLE_VISION = "LLM_DISABLE_VISION" + MAX_ITERATIONS = "MAX_ITERATIONS" + RUN_AS_OPENHANDS = "RUN_AS_OPENHANDS" + SANDBOX_TIMEOUT = "SANDBOX_TIMEOUT" + SANDBOX_USER_ID = "SANDBOX_USER_ID" + SECURITY_ANALYZER = "SECURITY_ANALYZER" + USE_HOST_NETWORK = "USE_HOST_NETWORK" + WORKSPACE_BASE = "WORKSPACE_BASE" + WORKSPACE_MOUNT_PATH = "WORKSPACE_MOUNT_PATH" + WORKSPACE_MOUNT_PATH_IN_SANDBOX = "WORKSPACE_MOUNT_PATH_IN_SANDBOX" + WORKSPACE_MOUNT_REWRITE = "WORKSPACE_MOUNT_REWRITE" diff --git a/openhands/core/schema/observation.py b/openhands/core/schema/observation.py index 622f2680f785..232d8b7ee090 100644 --- a/openhands/core/schema/observation.py +++ b/openhands/core/schema/observation.py @@ -1,48 +1,48 @@ from pydantic import BaseModel, Field -__all__ = ['ObservationType'] +__all__ = ["ObservationType"] class ObservationTypeSchema(BaseModel): - READ: str = Field(default='read') + READ: str = Field(default="read") """The content of a file """ - WRITE: str = Field(default='write') + WRITE: str = Field(default="write") - EDIT: str = Field(default='edit') + EDIT: str = Field(default="edit") - BROWSE: str = Field(default='browse') + BROWSE: str = Field(default="browse") """The HTML content of a URL """ - RUN: str = Field(default='run') + RUN: str = Field(default="run") """The output of a command """ - RUN_IPYTHON: str = Field(default='run_ipython') + RUN_IPYTHON: str = Field(default="run_ipython") """Runs a IPython cell. """ - CHAT: str = Field(default='chat') + CHAT: str = Field(default="chat") """A message from the user """ - DELEGATE: str = Field(default='delegate') + DELEGATE: str = Field(default="delegate") """The result of a task delegated to another agent """ - MESSAGE: str = Field(default='message') + MESSAGE: str = Field(default="message") - ERROR: str = Field(default='error') + ERROR: str = Field(default="error") - SUCCESS: str = Field(default='success') + SUCCESS: str = Field(default="success") - NULL: str = Field(default='null') + NULL: str = Field(default="null") - AGENT_STATE_CHANGED: str = Field(default='agent_state_changed') + AGENT_STATE_CHANGED: str = Field(default="agent_state_changed") - USER_REJECTED: str = Field(default='user_rejected') + USER_REJECTED: str = Field(default="user_rejected") ObservationType = ObservationTypeSchema() diff --git a/openhands/core/utils/json.py b/openhands/core/utils/json.py index c0b22740bec4..ad78bf977f8c 100644 --- a/openhands/core/utils/json.py +++ b/openhands/core/utils/json.py @@ -37,11 +37,11 @@ def loads(json_str, **kwargs): depth = 0 start = -1 for i, char in enumerate(json_str): - if char == '{': + if char == "{": if depth == 0: start = i depth += 1 - elif char == '}': + elif char == "}": depth -= 1 if depth == 0 and start != -1: response = json_str[start : i + 1] @@ -50,6 +50,6 @@ def loads(json_str, **kwargs): return json.loads(json_str, **kwargs) except (json.JSONDecodeError, ValueError, TypeError) as e: raise LLMResponseError( - 'Invalid JSON in response. Please make sure the response is a valid JSON object.' + "Invalid JSON in response. Please make sure the response is a valid JSON object." ) from e - raise LLMResponseError('No valid JSON object found in response.') + raise LLMResponseError("No valid JSON object found in response.") diff --git a/openhands/events/action/__init__.py b/openhands/events/action/__init__.py index 129cb3073982..e78b15600c31 100644 --- a/openhands/events/action/__init__.py +++ b/openhands/events/action/__init__.py @@ -18,22 +18,22 @@ from openhands.events.action.tasks import AddTaskAction, ModifyTaskAction __all__ = [ - 'Action', - 'NullAction', - 'CmdRunAction', - 'BrowseURLAction', - 'BrowseInteractiveAction', - 'FileReadAction', - 'FileWriteAction', - 'FileEditAction', - 'AgentFinishAction', - 'AgentRejectAction', - 'AgentDelegateAction', - 'AgentSummarizeAction', - 'AddTaskAction', - 'ModifyTaskAction', - 'ChangeAgentStateAction', - 'IPythonRunCellAction', - 'MessageAction', - 'ActionConfirmationStatus', + "Action", + "NullAction", + "CmdRunAction", + "BrowseURLAction", + "BrowseInteractiveAction", + "FileReadAction", + "FileWriteAction", + "FileEditAction", + "AgentFinishAction", + "AgentRejectAction", + "AgentDelegateAction", + "AgentSummarizeAction", + "AddTaskAction", + "ModifyTaskAction", + "ChangeAgentStateAction", + "IPythonRunCellAction", + "MessageAction", + "ActionConfirmationStatus", ] diff --git a/openhands/events/action/action.py b/openhands/events/action/action.py index 0605af7ed53e..e70cf037ce3f 100644 --- a/openhands/events/action/action.py +++ b/openhands/events/action/action.py @@ -6,9 +6,9 @@ class ActionConfirmationStatus(str, Enum): - CONFIRMED = 'confirmed' - REJECTED = 'rejected' - AWAITING_CONFIRMATION = 'awaiting_confirmation' + CONFIRMED = "confirmed" + REJECTED = "rejected" + AWAITING_CONFIRMATION = "awaiting_confirmation" class ActionSecurityRisk(int, Enum): diff --git a/openhands/events/action/agent.py b/openhands/events/action/agent.py index f49f573ed698..1e11ec5c67a4 100644 --- a/openhands/events/action/agent.py +++ b/openhands/events/action/agent.py @@ -10,12 +10,12 @@ class ChangeAgentStateAction(Action): """Fake action, just to notify the client that a task state has changed.""" agent_state: str - thought: str = '' + thought: str = "" action: str = ActionType.CHANGE_AGENT_STATE @property def message(self) -> str: - return f'Agent state changed to {self.agent_state}' + return f"Agent state changed to {self.agent_state}" @dataclass @@ -28,8 +28,8 @@ def message(self) -> str: return self.summary def __str__(self) -> str: - ret = '**AgentSummarizeAction**\n' - ret += f'SUMMARY: {self.summary}' + ret = "**AgentSummarizeAction**\n" + ret += f"SUMMARY: {self.summary}" return ret @@ -44,12 +44,12 @@ class AgentFinishAction(Action): """ outputs: dict[str, Any] = field(default_factory=dict) - thought: str = '' + thought: str = "" action: str = ActionType.FINISH @property def message(self) -> str: - if self.thought != '': + if self.thought != "": return self.thought return "All done! What's next on the agenda?" @@ -57,14 +57,14 @@ def message(self) -> str: @dataclass class AgentRejectAction(Action): outputs: dict = field(default_factory=dict) - thought: str = '' + thought: str = "" action: str = ActionType.REJECT @property def message(self) -> str: - msg: str = 'Task is rejected by the agent.' - if 'reason' in self.outputs: - msg += ' Reason: ' + self.outputs['reason'] + msg: str = "Task is rejected by the agent." + if "reason" in self.outputs: + msg += " Reason: " + self.outputs["reason"] return msg @@ -72,7 +72,7 @@ def message(self) -> str: class AgentDelegateAction(Action): agent: str inputs: dict - thought: str = '' + thought: str = "" action: str = ActionType.DELEGATE @property diff --git a/openhands/events/action/browse.py b/openhands/events/action/browse.py index 41816216d6d5..d2a9dab7f6ab 100644 --- a/openhands/events/action/browse.py +++ b/openhands/events/action/browse.py @@ -8,28 +8,28 @@ @dataclass class BrowseURLAction(Action): url: str - thought: str = '' + thought: str = "" action: str = ActionType.BROWSE runnable: ClassVar[bool] = True security_risk: ActionSecurityRisk | None = None @property def message(self) -> str: - return f'Browsing URL: {self.url}' + return f"Browsing URL: {self.url}" def __str__(self) -> str: - ret = '**BrowseURLAction**\n' + ret = "**BrowseURLAction**\n" if self.thought: - ret += f'THOUGHT: {self.thought}\n' - ret += f'URL: {self.url}' + ret += f"THOUGHT: {self.thought}\n" + ret += f"URL: {self.url}" return ret @dataclass class BrowseInteractiveAction(Action): browser_actions: str - thought: str = '' - browsergym_send_msg_to_user: str = '' + thought: str = "" + browsergym_send_msg_to_user: str = "" action: str = ActionType.BROWSE_INTERACTIVE runnable: ClassVar[bool] = True security_risk: ActionSecurityRisk | None = None @@ -37,12 +37,12 @@ class BrowseInteractiveAction(Action): @property def message(self) -> str: return ( - f'I am interacting with the browser:\n' f'```\n{self.browser_actions}\n```' + f"I am interacting with the browser:\n" f"```\n{self.browser_actions}\n```" ) def __str__(self) -> str: - ret = '**BrowseInteractiveAction**\n' + ret = "**BrowseInteractiveAction**\n" if self.thought: - ret += f'THOUGHT: {self.thought}\n' - ret += f'BROWSER_ACTIONS: {self.browser_actions}' + ret += f"THOUGHT: {self.thought}\n" + ret += f"BROWSER_ACTIONS: {self.browser_actions}" return ret diff --git a/openhands/events/action/commands.py b/openhands/events/action/commands.py index 83dd19f9d161..1989081ebe61 100644 --- a/openhands/events/action/commands.py +++ b/openhands/events/action/commands.py @@ -12,7 +12,7 @@ @dataclass class CmdRunAction(Action): command: str - thought: str = '' + thought: str = "" blocking: bool = False # If False, the command will be run in a non-blocking / interactive way # The partial command outputs will be returned as output observation. @@ -33,20 +33,20 @@ class CmdRunAction(Action): @property def message(self) -> str: - return f'Running command: {self.command}' + return f"Running command: {self.command}" def __str__(self) -> str: - ret = f'**CmdRunAction (source={self.source})**\n' + ret = f"**CmdRunAction (source={self.source})**\n" if self.thought: - ret += f'THOUGHT: {self.thought}\n' - ret += f'COMMAND:\n{self.command}' + ret += f"THOUGHT: {self.thought}\n" + ret += f"COMMAND:\n{self.command}" return ret @dataclass class IPythonRunCellAction(Action): code: str - thought: str = '' + thought: str = "" include_extra: bool = ( True # whether to include CWD & Python interpreter in the output ) @@ -54,15 +54,15 @@ class IPythonRunCellAction(Action): runnable: ClassVar[bool] = True confirmation_state: ActionConfirmationStatus = ActionConfirmationStatus.CONFIRMED security_risk: ActionSecurityRisk | None = None - kernel_init_code: str = '' # code to run in the kernel (if the kernel is restarted) + kernel_init_code: str = "" # code to run in the kernel (if the kernel is restarted) def __str__(self) -> str: - ret = '**IPythonRunCellAction**\n' + ret = "**IPythonRunCellAction**\n" if self.thought: - ret += f'THOUGHT: {self.thought}\n' - ret += f'CODE:\n{self.code}' + ret += f"THOUGHT: {self.thought}\n" + ret += f"CODE:\n{self.code}" return ret @property def message(self) -> str: - return f'Running Python code interactively: {self.code}' + return f"Running Python code interactively: {self.code}" diff --git a/openhands/events/action/empty.py b/openhands/events/action/empty.py index 32e034600102..a1b496cd725c 100644 --- a/openhands/events/action/empty.py +++ b/openhands/events/action/empty.py @@ -12,4 +12,4 @@ class NullAction(Action): @property def message(self) -> str: - return 'No action' + return "No action" diff --git a/openhands/events/action/files.py b/openhands/events/action/files.py index 3e2131228b6b..d4b3c848e17b 100644 --- a/openhands/events/action/files.py +++ b/openhands/events/action/files.py @@ -15,14 +15,14 @@ class FileReadAction(Action): path: str start: int = 0 end: int = -1 - thought: str = '' + thought: str = "" action: str = ActionType.READ runnable: ClassVar[bool] = True security_risk: ActionSecurityRisk | None = None @property def message(self) -> str: - return f'Reading file: {self.path}' + return f"Reading file: {self.path}" @dataclass @@ -36,14 +36,14 @@ class FileWriteAction(Action): content: str start: int = 0 end: int = -1 - thought: str = '' + thought: str = "" action: str = ActionType.WRITE runnable: ClassVar[bool] = True security_risk: ActionSecurityRisk | None = None @property def message(self) -> str: - return f'Writing file: {self.path}' + return f"Writing file: {self.path}" @dataclass @@ -60,15 +60,15 @@ class FileEditAction(Action): content: str start: int = 1 end: int = -1 - thought: str = '' + thought: str = "" action: str = ActionType.EDIT runnable: ClassVar[bool] = True security_risk: ActionSecurityRisk | None = None def __repr__(self) -> str: - ret = '**FileEditAction**\n' - ret += f'Thought: {self.thought}\n' - ret += f'Range: [L{self.start}:L{self.end}]\n' - ret += f'Path: [{self.path}]\n' - ret += f'Content:\n```\n{self.content}\n```\n' + ret = "**FileEditAction**\n" + ret += f"Thought: {self.thought}\n" + ret += f"Range: [L{self.start}:L{self.end}]\n" + ret += f"Path: [{self.path}]\n" + ret += f"Content:\n```\n{self.content}\n```\n" return ret diff --git a/openhands/events/action/message.py b/openhands/events/action/message.py index 86d7c439e936..c9500004971f 100644 --- a/openhands/events/action/message.py +++ b/openhands/events/action/message.py @@ -24,10 +24,11 @@ def images_urls(self): @images_urls.setter def images_urls(self, value): self.image_urls = value + def __str__(self) -> str: - ret = f'**MessageAction** (source={self.source})\n' - ret += f'CONTENT: {self.content}' + ret = f"**MessageAction** (source={self.source})\n" + ret += f"CONTENT: {self.content}" if self.image_urls: for url in self.image_urls: - ret += f'\nIMAGE_URL: {url}' + ret += f"\nIMAGE_URL: {url}" return ret diff --git a/openhands/events/action/tasks.py b/openhands/events/action/tasks.py index b1f1c215f74d..429c24a959d6 100644 --- a/openhands/events/action/tasks.py +++ b/openhands/events/action/tasks.py @@ -9,21 +9,21 @@ class AddTaskAction(Action): parent: str goal: str subtasks: list = field(default_factory=list) - thought: str = '' + thought: str = "" action: str = ActionType.ADD_TASK @property def message(self) -> str: - return f'Added task: {self.goal}' + return f"Added task: {self.goal}" @dataclass class ModifyTaskAction(Action): task_id: str state: str - thought: str = '' + thought: str = "" action: str = ActionType.MODIFY_TASK @property def message(self) -> str: - return f'Set task {self.task_id} to {self.state}' + return f"Set task {self.task_id} to {self.state}" diff --git a/openhands/events/observation/__init__.py b/openhands/events/observation/__init__.py index 28525b09aabb..7cd08f5449ee 100644 --- a/openhands/events/observation/__init__.py +++ b/openhands/events/observation/__init__.py @@ -17,17 +17,17 @@ from openhands.events.observation.success import SuccessObservation __all__ = [ - 'Observation', - 'NullObservation', - 'CmdOutputObservation', - 'IPythonRunCellObservation', - 'BrowserOutputObservation', - 'FileReadObservation', - 'FileWriteObservation', - 'FileEditObservation', - 'ErrorObservation', - 'AgentStateChangedObservation', - 'AgentDelegateObservation', - 'SuccessObservation', - 'UserRejectObservation', + "Observation", + "NullObservation", + "CmdOutputObservation", + "IPythonRunCellObservation", + "BrowserOutputObservation", + "FileReadObservation", + "FileWriteObservation", + "FileEditObservation", + "ErrorObservation", + "AgentStateChangedObservation", + "AgentDelegateObservation", + "SuccessObservation", + "UserRejectObservation", ] diff --git a/openhands/events/observation/agent.py b/openhands/events/observation/agent.py index 802c23c3786d..2ad71a4773ca 100644 --- a/openhands/events/observation/agent.py +++ b/openhands/events/observation/agent.py @@ -13,4 +13,4 @@ class AgentStateChangedObservation(Observation): @property def message(self) -> str: - return '' + return "" diff --git a/openhands/events/observation/browse.py b/openhands/events/observation/browse.py index 9632fac57d54..b7aa3d1936a2 100644 --- a/openhands/events/observation/browse.py +++ b/openhands/events/observation/browse.py @@ -22,43 +22,43 @@ class BrowserOutputObservation(Observation): extra_element_properties: dict = field( default_factory=dict, repr=False ) # don't show in repr - last_browser_action: str = '' - last_browser_action_error: str = '' - focused_element_bid: str = '' + last_browser_action: str = "" + last_browser_action_error: str = "" + focused_element_bid: str = "" @property def message(self) -> str: - return 'Visited ' + self.url + return "Visited " + self.url def __str__(self) -> str: ret = ( - '**BrowserOutputObservation**\n' - f'URL: {self.url}\n' - f'Error: {self.error}\n' - f'Open pages: {self.open_pages_urls}\n' - f'Active page index: {self.active_page_index}\n' - f'Last browser action: {self.last_browser_action}\n' - f'Last browser action error: {self.last_browser_action_error}\n' - f'Focused element bid: {self.focused_element_bid}\n' - f'Content: {self.content}\n' + "**BrowserOutputObservation**\n" + f"URL: {self.url}\n" + f"Error: {self.error}\n" + f"Open pages: {self.open_pages_urls}\n" + f"Active page index: {self.active_page_index}\n" + f"Last browser action: {self.last_browser_action}\n" + f"Last browser action error: {self.last_browser_action_error}\n" + f"Focused element bid: {self.focused_element_bid}\n" + f"Content: {self.content}\n" ) - ret += '--- Agent Observation ---\n' + ret += "--- Agent Observation ---\n" ret += self.get_agent_obs_text() return ret def get_agent_obs_text(self) -> str: """Get a concise text that will be shown to the agent.""" - text = f'[Current URL: {self.url}]\n' - text += f'[Focused element bid: {self.focused_element_bid}]\n\n' + text = f"[Current URL: {self.url}]\n" + text += f"[Focused element bid: {self.focused_element_bid}]\n\n" if self.error: text += ( - '================ BEGIN error message ===============\n' - 'The following error occurred when executing the last action:\n' - f'{self.last_browser_action_error}\n' - '================ END error message ===============\n' + "================ BEGIN error message ===============\n" + "The following error occurred when executing the last action:\n" + f"{self.last_browser_action_error}\n" + "================ END error message ===============\n" ) else: - text += '[Action executed successfully.]\n' + text += "[Action executed successfully.]\n" try: # We do not filter visible only here because we want to show the full content @@ -66,12 +66,12 @@ def get_agent_obs_text(self) -> str: # FIXME: handle the case when the web page is too large cur_axtree_txt = self.get_axtree_str(filter_visible_only=False) text += ( - f'============== BEGIN accessibility tree ==============\n' - f'{cur_axtree_txt}\n' - f'============== END accessibility tree ==============\n' + f"============== BEGIN accessibility tree ==============\n" + f"{cur_axtree_txt}\n" + f"============== END accessibility tree ==============\n" ) except Exception as e: - text += f'\n[Error encountered when processing the accessibility tree: {e}]' + text += f"\n[Error encountered when processing the accessibility tree: {e}]" return text def get_axtree_str(self, filter_visible_only: bool = False) -> str: diff --git a/openhands/events/observation/commands.py b/openhands/events/observation/commands.py index a182168e694a..c1b76b247cfa 100644 --- a/openhands/events/observation/commands.py +++ b/openhands/events/observation/commands.py @@ -13,7 +13,7 @@ class CmdOutputObservation(Observation): exit_code: int = 0 hidden: bool = False observation: str = ObservationType.RUN - interpreter_details: str = '' + interpreter_details: str = "" @property def error(self) -> bool: @@ -21,10 +21,10 @@ def error(self) -> bool: @property def message(self) -> str: - return f'Command `{self.command}` executed with exit code {self.exit_code}.' + return f"Command `{self.command}` executed with exit code {self.exit_code}." def __str__(self) -> str: - return f'**CmdOutputObservation (source={self.source}, exit code={self.exit_code})**\n{self.content}' + return f"**CmdOutputObservation (source={self.source}, exit code={self.exit_code})**\n{self.content}" @dataclass @@ -40,7 +40,7 @@ def error(self) -> bool: @property def message(self) -> str: - return 'Code executed in IPython cell.' + return "Code executed in IPython cell." def __str__(self) -> str: - return f'**IPythonRunCellObservation**\n{self.content}' + return f"**IPythonRunCellObservation**\n{self.content}" diff --git a/openhands/events/observation/delegate.py b/openhands/events/observation/delegate.py index 9e98c6b5982a..95e05b5ff3ff 100644 --- a/openhands/events/observation/delegate.py +++ b/openhands/events/observation/delegate.py @@ -19,4 +19,4 @@ class AgentDelegateObservation(Observation): @property def message(self) -> str: - return '' + return "" diff --git a/openhands/events/observation/empty.py b/openhands/events/observation/empty.py index 9d7d0f18a792..5cc802804d22 100644 --- a/openhands/events/observation/empty.py +++ b/openhands/events/observation/empty.py @@ -14,4 +14,4 @@ class NullObservation(Observation): @property def message(self) -> str: - return 'No observation' + return "No observation" diff --git a/openhands/events/observation/error.py b/openhands/events/observation/error.py index 4ed05b89ac78..d91ac1bf2a1e 100644 --- a/openhands/events/observation/error.py +++ b/openhands/events/observation/error.py @@ -13,11 +13,11 @@ class ErrorObservation(Observation): """ observation: str = ObservationType.ERROR - error_id: str = '' + error_id: str = "" @property def message(self) -> str: return self.content def __str__(self) -> str: - return f'**ErrorObservation**\n{self.content}' + return f"**ErrorObservation**\n{self.content}" diff --git a/openhands/events/observation/files.py b/openhands/events/observation/files.py index bfc45264ccae..563c601579cd 100644 --- a/openhands/events/observation/files.py +++ b/openhands/events/observation/files.py @@ -14,7 +14,7 @@ class FileReadObservation(Observation): @property def message(self) -> str: - return f'I read the file {self.path}.' + return f"I read the file {self.path}." @dataclass @@ -26,7 +26,7 @@ class FileWriteObservation(Observation): @property def message(self) -> str: - return f'I wrote to the file {self.path}.' + return f"I wrote to the file {self.path}." @dataclass @@ -42,12 +42,12 @@ class FileEditObservation(Observation): @property def message(self) -> str: - return f'I edited the file {self.path}.' + return f"I edited the file {self.path}." def get_edit_groups(self, n_context_lines: int = 2) -> list[dict[str, list[str]]]: """Get the edit groups of the file edit.""" - old_lines = self.old_content.split('\n') - new_lines = self.new_content.split('\n') + old_lines = self.old_content.split("\n") + new_lines = self.new_content.split("\n") # Borrowed from difflib.unified_diff to directly parse into structured format. edit_groups: list[dict] = [] for group in SequenceMatcher(None, old_lines, new_lines).get_grouped_opcodes( @@ -56,29 +56,29 @@ def get_edit_groups(self, n_context_lines: int = 2) -> list[dict[str, list[str]] # take the max line number in the group _indent_pad_size = len(str(group[-1][3])) + 1 # +1 for the "*" prefix cur_group: dict[str, list[str]] = { - 'before_edits': [], - 'after_edits': [], + "before_edits": [], + "after_edits": [], } for tag, i1, i2, j1, j2 in group: - if tag == 'equal': + if tag == "equal": for idx, line in enumerate(old_lines[i1:i2]): - cur_group['before_edits'].append( - f'{i1+idx+1:>{_indent_pad_size}}|{line}' + cur_group["before_edits"].append( + f"{i1+idx+1:>{_indent_pad_size}}|{line}" ) for idx, line in enumerate(new_lines[j1:j2]): - cur_group['after_edits'].append( - f'{j1+idx+1:>{_indent_pad_size}}|{line}' + cur_group["after_edits"].append( + f"{j1+idx+1:>{_indent_pad_size}}|{line}" ) continue - if tag in {'replace', 'delete'}: + if tag in {"replace", "delete"}: for idx, line in enumerate(old_lines[i1:i2]): - cur_group['before_edits'].append( - f'-{i1+idx+1:>{_indent_pad_size-1}}|{line}' + cur_group["before_edits"].append( + f"-{i1+idx+1:>{_indent_pad_size-1}}|{line}" ) - if tag in {'replace', 'insert'}: + if tag in {"replace", "insert"}: for idx, line in enumerate(new_lines[j1:j2]): - cur_group['after_edits'].append( - f'+{j1+idx+1:>{_indent_pad_size-1}}|{line}' + cur_group["after_edits"].append( + f"+{j1+idx+1:>{_indent_pad_size-1}}|{line}" ) edit_groups.append(cur_group) return edit_groups @@ -97,37 +97,37 @@ def visualize_diff( n_context_lines: The number of lines of context to show before and after the changes. change_applied: Whether the changes are applied to the file. If true, the file have been modified. If not, the file is not modified (due to linting errors). """ - if change_applied and self.content.strip() == '': + if change_applied and self.content.strip() == "": # diff patch is empty - return '(no changes detected. Please make sure your edits changes the content of the existing file.)\n' + return "(no changes detected. Please make sure your edits changes the content of the existing file.)\n" edit_groups = self.get_edit_groups(n_context_lines=n_context_lines) result = [ - f'[Existing file {self.path} is edited with {len(edit_groups)} changes.]' + f"[Existing file {self.path} is edited with {len(edit_groups)} changes.]" if change_applied else f"[Changes are NOT applied to {self.path} - Here's how the file looks like if changes are applied.]" ] - op_type = 'edit' if change_applied else 'ATTEMPTED edit' + op_type = "edit" if change_applied else "ATTEMPTED edit" for i, cur_edit_group in enumerate(edit_groups): if i != 0: - result.append('-------------------------') - result.append(f'[begin of {op_type} {i+1} / {len(edit_groups)}]') - result.append(f'(content before {op_type})') - result.extend(cur_edit_group['before_edits']) - result.append(f'(content after {op_type})') - result.extend(cur_edit_group['after_edits']) - result.append(f'[end of {op_type} {i+1} / {len(edit_groups)}]') - return '\n'.join(result) + result.append("-------------------------") + result.append(f"[begin of {op_type} {i+1} / {len(edit_groups)}]") + result.append(f"(content before {op_type})") + result.extend(cur_edit_group["before_edits"]) + result.append(f"(content after {op_type})") + result.extend(cur_edit_group["after_edits"]) + result.append(f"[end of {op_type} {i+1} / {len(edit_groups)}]") + return "\n".join(result) def __str__(self) -> str: - ret = '' + ret = "" if not self.prev_exist: assert ( - self.old_content == '' - ), 'old_content should be empty if the file is new (prev_exist=False).' - ret += f'[New file {self.path} is created with the provided content.]\n' - return ret.rstrip() + '\n' + self.old_content == "" + ), "old_content should be empty if the file is new (prev_exist=False)." + ret += f"[New file {self.path} is created with the provided content.]\n" + return ret.rstrip() + "\n" ret += self.visualize_diff() - return ret.rstrip() + '\n' + return ret.rstrip() + "\n" diff --git a/openhands/events/serialization/__init__.py b/openhands/events/serialization/__init__.py index f36d08d86cf0..1e70dbe8f6c1 100644 --- a/openhands/events/serialization/__init__.py +++ b/openhands/events/serialization/__init__.py @@ -12,10 +12,10 @@ ) __all__ = [ - 'action_from_dict', - 'event_from_dict', - 'event_to_dict', - 'event_to_memory', - 'event_to_trajectory', - 'observation_from_dict', + "action_from_dict", + "event_from_dict", + "event_to_dict", + "event_to_memory", + "event_to_trajectory", + "observation_from_dict", ] diff --git a/openhands/events/serialization/action.py b/openhands/events/serialization/action.py index defac3b5dda6..4fc92f9d7d4e 100644 --- a/openhands/events/serialization/action.py +++ b/openhands/events/serialization/action.py @@ -43,37 +43,37 @@ def action_from_dict(action: dict) -> Action: if not isinstance(action, dict): - raise LLMMalformedActionError('action must be a dictionary') + raise LLMMalformedActionError("action must be a dictionary") action = action.copy() - if 'action' not in action: + if "action" not in action: raise LLMMalformedActionError(f"'action' key is not found in {action=}") - if not isinstance(action['action'], str): + if not isinstance(action["action"], str): raise LLMMalformedActionError( f"'{action['action']=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}" ) - action_class = ACTION_TYPE_TO_CLASS.get(action['action']) + action_class = ACTION_TYPE_TO_CLASS.get(action["action"]) if action_class is None: raise LLMMalformedActionError( f"'{action['action']=}' is not defined. Available actions: {ACTION_TYPE_TO_CLASS.keys()}" ) - args = action.get('args', {}) + args = action.get("args", {}) # Remove timestamp from args if present - timestamp = args.pop('timestamp', None) + timestamp = args.pop("timestamp", None) # compatibility for older event streams # is_confirmed has been renamed to confirmation_state - is_confirmed = args.pop('is_confirmed', None) + is_confirmed = args.pop("is_confirmed", None) if is_confirmed is not None: - args['confirmation_state'] = is_confirmed + args["confirmation_state"] = is_confirmed # images_urls has been renamed to image_urls - if 'images_urls' in args: - args['image_urls'] = args.pop('images_urls') - + if "images_urls" in args: + args["image_urls"] = args.pop("images_urls") + try: decoded_action = action_class(**args) - if 'timeout' in action: - decoded_action.timeout = action['timeout'] + if "timeout" in action: + decoded_action.timeout = action["timeout"] # Set timestamp if it was provided if timestamp: @@ -81,6 +81,6 @@ def action_from_dict(action: dict) -> Action: except TypeError as e: raise LLMMalformedActionError( - f'action={action} has the wrong arguments: {str(e)}' + f"action={action} has the wrong arguments: {str(e)}" ) return decoded_action diff --git a/openhands/events/serialization/event.py b/openhands/events/serialization/event.py index 78f7940626d4..f381eda8aef2 100644 --- a/openhands/events/serialization/event.py +++ b/openhands/events/serialization/event.py @@ -10,109 +10,109 @@ # TODO: move `content` into `extras` TOP_KEYS = [ - 'id', - 'timestamp', - 'source', - 'message', - 'cause', - 'action', - 'observation', - 'tool_call_metadata', + "id", + "timestamp", + "source", + "message", + "cause", + "action", + "observation", + "tool_call_metadata", ] -UNDERSCORE_KEYS = ['id', 'timestamp', 'source', 'cause', 'tool_call_metadata'] +UNDERSCORE_KEYS = ["id", "timestamp", "source", "cause", "tool_call_metadata"] DELETE_FROM_TRAJECTORY_EXTRAS = { - 'screenshot', - 'dom_object', - 'axtree_object', - 'active_page_index', - 'last_browser_action', - 'last_browser_action_error', - 'focused_element_bid', - 'extra_element_properties', + "screenshot", + "dom_object", + "axtree_object", + "active_page_index", + "last_browser_action", + "last_browser_action_error", + "focused_element_bid", + "extra_element_properties", } -DELETE_FROM_MEMORY_EXTRAS = DELETE_FROM_TRAJECTORY_EXTRAS | {'open_pages_urls'} +DELETE_FROM_MEMORY_EXTRAS = DELETE_FROM_TRAJECTORY_EXTRAS | {"open_pages_urls"} -def event_from_dict(data) -> 'Event': +def event_from_dict(data) -> "Event": evt: Event - if 'action' in data: + if "action" in data: evt = action_from_dict(data) - elif 'observation' in data: + elif "observation" in data: evt = observation_from_dict(data) else: - raise ValueError('Unknown event type: ' + data) + raise ValueError("Unknown event type: " + data) for key in UNDERSCORE_KEYS: if key in data: value = data[key] - if key == 'timestamp' and isinstance(value, datetime): + if key == "timestamp" and isinstance(value, datetime): value = value.isoformat() - if key == 'source': + if key == "source": value = EventSource(value) - if key == 'tool_call_metadata': + if key == "tool_call_metadata": value = ToolCallMetadata(**value) - setattr(evt, '_' + key, value) + setattr(evt, "_" + key, value) return evt -def event_to_dict(event: 'Event') -> dict: +def event_to_dict(event: "Event") -> dict: props = asdict(event) d = {} for key in TOP_KEYS: if hasattr(event, key) and getattr(event, key) is not None: d[key] = getattr(event, key) - elif hasattr(event, f'_{key}') and getattr(event, f'_{key}') is not None: - d[key] = getattr(event, f'_{key}') - if key == 'id' and d.get('id') == -1: - d.pop('id', None) - if key == 'timestamp' and 'timestamp' in d: - if isinstance(d['timestamp'], datetime): - d['timestamp'] = d['timestamp'].isoformat() - if key == 'source' and 'source' in d: - d['source'] = d['source'].value - if key == 'tool_call_metadata' and 'tool_call_metadata' in d: - d['tool_call_metadata'] = d['tool_call_metadata'].model_dump() + elif hasattr(event, f"_{key}") and getattr(event, f"_{key}") is not None: + d[key] = getattr(event, f"_{key}") + if key == "id" and d.get("id") == -1: + d.pop("id", None) + if key == "timestamp" and "timestamp" in d: + if isinstance(d["timestamp"], datetime): + d["timestamp"] = d["timestamp"].isoformat() + if key == "source" and "source" in d: + d["source"] = d["source"].value + if key == "tool_call_metadata" and "tool_call_metadata" in d: + d["tool_call_metadata"] = d["tool_call_metadata"].model_dump() props.pop(key, None) - if 'security_risk' in props and props['security_risk'] is None: - props.pop('security_risk') - if 'action' in d: - d['args'] = props + if "security_risk" in props and props["security_risk"] is None: + props.pop("security_risk") + if "action" in d: + d["args"] = props if event.timeout is not None: - d['timeout'] = event.timeout - elif 'observation' in d: - d['content'] = props.pop('content', '') - d['extras'] = props + d["timeout"] = event.timeout + elif "observation" in d: + d["content"] = props.pop("content", "") + d["extras"] = props else: - raise ValueError('Event must be either action or observation') + raise ValueError("Event must be either action or observation") return d -def event_to_trajectory(event: 'Event') -> dict: +def event_to_trajectory(event: "Event") -> dict: d = event_to_dict(event) - if 'extras' in d: - remove_fields(d['extras'], DELETE_FROM_TRAJECTORY_EXTRAS) + if "extras" in d: + remove_fields(d["extras"], DELETE_FROM_TRAJECTORY_EXTRAS) return d -def event_to_memory(event: 'Event', max_message_chars: int) -> dict: +def event_to_memory(event: "Event", max_message_chars: int) -> dict: d = event_to_dict(event) - d.pop('id', None) - d.pop('cause', None) - d.pop('timestamp', None) - d.pop('message', None) - d.pop('image_urls', None) + d.pop("id", None) + d.pop("cause", None) + d.pop("timestamp", None) + d.pop("message", None) + d.pop("image_urls", None) # runnable actions have some extra fields used in the BE/FE, which should not be sent to the LLM - if 'args' in d: - d['args'].pop('blocking', None) - d['args'].pop('keep_prompt', None) - d['args'].pop('confirmation_state', None) - - if 'extras' in d: - remove_fields(d['extras'], DELETE_FROM_MEMORY_EXTRAS) - if isinstance(event, Observation) and 'content' in d: - d['content'] = truncate_content(d['content'], max_message_chars) + if "args" in d: + d["args"].pop("blocking", None) + d["args"].pop("keep_prompt", None) + d["args"].pop("confirmation_state", None) + + if "extras" in d: + remove_fields(d["extras"], DELETE_FROM_MEMORY_EXTRAS) + if isinstance(event, Observation) and "content" in d: + d["content"] = truncate_content(d["content"], max_message_chars) return d @@ -125,6 +125,6 @@ def truncate_content(content: str, max_chars: int) -> str: half = max_chars // 2 return ( content[:half] - + '\n[... Observation truncated due to length ...]\n' + + "\n[... Observation truncated due to length ...]\n" + content[-half:] ) diff --git a/openhands/events/serialization/observation.py b/openhands/events/serialization/observation.py index 9030ccb1e1dd..18e9ecdba81b 100644 --- a/openhands/events/serialization/observation.py +++ b/openhands/events/serialization/observation.py @@ -39,15 +39,15 @@ def observation_from_dict(observation: dict) -> Observation: observation = observation.copy() - if 'observation' not in observation: + if "observation" not in observation: raise KeyError(f"'observation' key is not found in {observation=}") - observation_class = OBSERVATION_TYPE_TO_CLASS.get(observation['observation']) + observation_class = OBSERVATION_TYPE_TO_CLASS.get(observation["observation"]) if observation_class is None: raise KeyError( f"'{observation['observation']=}' is not defined. Available observations: {OBSERVATION_TYPE_TO_CLASS.keys()}" ) - observation.pop('observation') - observation.pop('message', None) - content = observation.pop('content', '') - extras = observation.pop('extras', {}) + observation.pop("observation") + observation.pop("message", None) + content = observation.pop("content", "") + extras = observation.pop("extras", {}) return observation_class(content=content, **extras) diff --git a/openhands/events/serialization/utils.py b/openhands/events/serialization/utils.py index de448e01429b..fb0f1303ddf1 100644 --- a/openhands/events/serialization/utils.py +++ b/openhands/events/serialization/utils.py @@ -14,7 +14,7 @@ def remove_fields(obj, fields: set[str]): elif isinstance(obj, (list, tuple)): for item in obj: remove_fields(item, fields) - elif hasattr(obj, '__dataclass_fields__'): + elif hasattr(obj, "__dataclass_fields__"): raise ValueError( - 'Object must not contain dataclass, consider converting to dict first' + "Object must not contain dataclass, consider converting to dict first" ) diff --git a/openhands/events/stream.py b/openhands/events/stream.py index 625abd7a5221..465f27405a58 100644 --- a/openhands/events/stream.py +++ b/openhands/events/stream.py @@ -90,8 +90,7 @@ def get_events( filter_out_type: tuple[type[Event], ...] | None = None, filter_hidden=False, ) -> Iterable[Event]: - """ - Retrieve events from the event stream, optionally filtering out events of a given type + """Retrieve events from the event stream, optionally filtering out events of a given type and events marked as hidden. Args: diff --git a/openhands/linter/languages/python.py b/openhands/linter/languages/python.py index 9b7e944a2868..fc302146aaa4 100644 --- a/openhands/linter/languages/python.py +++ b/openhands/linter/languages/python.py @@ -7,13 +7,13 @@ def python_compile_lint(fname: str) -> list[LintResult]: try: - with open(fname, 'r') as f: + with open(fname, "r") as f: code = f.read() - compile(code, fname, 'exec') # USE TRACEBACK BELOW HERE + compile(code, fname, "exec") # USE TRACEBACK BELOW HERE return [] except SyntaxError as err: - err_lineno = getattr(err, 'end_lineno', err.lineno) - err_offset = getattr(err, 'end_offset', err.offset) + err_lineno = getattr(err, "end_lineno", err.lineno) + err_offset = getattr(err, "end_offset", err.offset) if err_offset and err_offset < 0: err_offset = err.offset return [ @@ -24,8 +24,8 @@ def python_compile_lint(fname: str) -> list[LintResult]: def flake_lint(filepath: str) -> list[LintResult]: - fatal = 'F821,F822,F831,E112,E113,E999,E902' - flake8_cmd = f'flake8 --select={fatal} --isolated {filepath}' + fatal = "F821,F822,F831,E112,E113,E999,E902" + flake8_cmd = f"flake8 --select={fatal} --isolated {filepath}" try: cmd_outputs = run_cmd(flake8_cmd) @@ -35,17 +35,17 @@ def flake_lint(filepath: str) -> list[LintResult]: if not cmd_outputs: return results for line in cmd_outputs.splitlines(): - parts = line.split(':') + parts = line.split(":") if len(parts) >= 4: _msg = parts[3].strip() if len(parts) > 4: - _msg += ': ' + parts[4].strip() + _msg += ": " + parts[4].strip() try: line_num = int(parts[1]) except ValueError as e: logger.warning( - f'Error parsing flake8 output for line: {e}. Parsed parts: {parts}. Skipping...' + f"Error parsing flake8 output for line: {e}. Parsed parts: {parts}. Skipping..." ) continue @@ -54,10 +54,10 @@ def flake_lint(filepath: str) -> list[LintResult]: except ValueError as e: column_num = 1 _msg = ( - parts[2].strip() + ' ' + _msg + parts[2].strip() + " " + _msg ) # add the unparsed message to the original message logger.warning( - f'Error parsing flake8 output for column: {e}. Parsed parts: {parts}. Using default column 1.' + f"Error parsing flake8 output for column: {e}. Parsed parts: {parts}. Using default column 1." ) results.append( @@ -74,7 +74,7 @@ def flake_lint(filepath: str) -> list[LintResult]: class PythonLinter(BaseLinter): @property def supported_extensions(self) -> List[str]: - return ['.py'] + return [".py"] def lint(self, file_path: str) -> list[LintResult]: error = flake_lint(file_path) @@ -84,7 +84,7 @@ def lint(self, file_path: str) -> list[LintResult]: def compile_lint(self, file_path: str, code: str) -> List[LintResult]: try: - compile(code, file_path, 'exec') + compile(code, file_path, "exec") return [] except SyntaxError as e: return [ @@ -93,6 +93,6 @@ def compile_lint(self, file_path: str, code: str) -> List[LintResult]: line=e.lineno, column=e.offset, message=str(e), - rule='SyntaxError', + rule="SyntaxError", ) ] diff --git a/openhands/linter/languages/treesitter.py b/openhands/linter/languages/treesitter.py index 83b5d466aecc..e85c381e886d 100644 --- a/openhands/linter/languages/treesitter.py +++ b/openhands/linter/languages/treesitter.py @@ -7,7 +7,7 @@ from openhands.linter.base import BaseLinter, LintResult # tree_sitter is throwing a FutureWarning -warnings.simplefilter('ignore', category=FutureWarning) +warnings.simplefilter("ignore", category=FutureWarning) def tree_context(fname, code, line_nums): @@ -34,10 +34,10 @@ def tree_context(fname, code, line_nums): def traverse_tree(node): """Traverses the tree to find errors.""" errors = [] - if node.type == 'ERROR' or node.is_missing: + if node.type == "ERROR" or node.is_missing: line_no = node.start_point[0] + 1 col_no = node.start_point[1] + 1 - error_type = 'Missing node' if node.is_missing else 'Syntax error' + error_type = "Missing node" if node.is_missing else "Syntax error" errors.append((line_no, col_no, error_type)) for child in node.children: @@ -57,9 +57,9 @@ def lint(self, file_path: str) -> list[LintResult]: if not lang: return [] parser = get_parser(lang) - with open(file_path, 'r') as f: + with open(file_path, "r") as f: code = f.read() - tree = parser.parse(bytes(code, 'utf-8')) + tree = parser.parse(bytes(code, "utf-8")) errors = traverse_tree(tree.root_node) if not errors: return [] diff --git a/openhands/linter/utils/__init__.py b/openhands/linter/utils/__init__.py index e48f26f076b5..c346a8f561a4 100644 --- a/openhands/linter/utils/__init__.py +++ b/openhands/linter/utils/__init__.py @@ -1,3 +1,3 @@ from .cmd import check_tool_installed, run_cmd -__all__ = ['run_cmd', 'check_tool_installed'] +__all__ = ["run_cmd", "check_tool_installed"] diff --git a/openhands/linter/utils/cmd.py b/openhands/linter/utils/cmd.py index f5c2803c3d77..13989d645bab 100644 --- a/openhands/linter/utils/cmd.py +++ b/openhands/linter/utils/cmd.py @@ -7,14 +7,13 @@ def run_cmd(cmd: str, cwd: str | None = None) -> str | None: If the command succeeds, return None. If the command fails, return the stdout. """ - process = subprocess.Popen( cmd.split(), cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - encoding='utf-8', - errors='replace', + encoding="utf-8", + errors="replace", ) stdout, _ = process.communicate() if process.returncode == 0: @@ -26,7 +25,7 @@ def check_tool_installed(tool_name: str) -> bool: """Check if a tool is installed.""" try: subprocess.run( - [tool_name, '--version'], + [tool_name, "--version"], check=True, cwd=os.getcwd(), stdout=subprocess.PIPE, diff --git a/openhands/llm/retry_mixin.py b/openhands/llm/retry_mixin.py index 1005677320e1..a1705f0d30ca 100644 --- a/openhands/llm/retry_mixin.py +++ b/openhands/llm/retry_mixin.py @@ -13,8 +13,7 @@ class RetryMixin: """Mixin class for retry logic.""" def retry_decorator(self, **kwargs): - """ - Create a LLM retry decorator with customizable parameters. This is used for 429 errors, and a few other exceptions in LLM classes. + """Create a LLM retry decorator with customizable parameters. This is used for 429 errors, and a few other exceptions in LLM classes. Args: **kwargs: Keyword arguments to override default retry behavior. diff --git a/openhands/memory/memory.py b/openhands/memory/memory.py index 9d83cc9cdc8c..509bb8b2ffbe 100644 --- a/openhands/memory/memory.py +++ b/openhands/memory/memory.py @@ -35,7 +35,6 @@ def __init__( event_stream: EventStream, ): """Initialize the chromadb and set up ChromaVectorStore for later use.""" - check_llama_index() # initialize the chromadb client diff --git a/openhands/runtime/action_execution_server.py b/openhands/runtime/action_execution_server.py index 2e060337a58e..3afba6b6f5bb 100644 --- a/openhands/runtime/action_execution_server.py +++ b/openhands/runtime/action_execution_server.py @@ -1,5 +1,4 @@ -""" -This is the main file for the runtime client. +"""This is the main file for the runtime client. It is responsible for executing actions received from OpenHands backend and producing observations. NOTE: this will be executed inside the docker sandbox. diff --git a/openhands/runtime/browser/__init__.py b/openhands/runtime/browser/__init__.py index 2687e03c684f..db13465101ce 100644 --- a/openhands/runtime/browser/__init__.py +++ b/openhands/runtime/browser/__init__.py @@ -1,3 +1,3 @@ from openhands.runtime.browser.utils import browse -__all__ = ['browse'] +__all__ = ["browse"] diff --git a/openhands/runtime/browser/browser_env.py b/openhands/runtime/browser/browser_env.py index 9bad97b9bb2b..8f9b06e14321 100644 --- a/openhands/runtime/browser/browser_env.py +++ b/openhands/runtime/browser/browser_env.py @@ -19,22 +19,22 @@ from openhands.runtime.utils.shutdown_listener import should_continue, should_exit from openhands.utils.tenacity_stop import stop_if_should_exit -BROWSER_EVAL_GET_GOAL_ACTION = 'GET_EVAL_GOAL' -BROWSER_EVAL_GET_REWARDS_ACTION = 'GET_EVAL_REWARDS' +BROWSER_EVAL_GET_GOAL_ACTION = "GET_EVAL_GOAL" +BROWSER_EVAL_GET_REWARDS_ACTION = "GET_EVAL_REWARDS" class BrowserEnv: def __init__(self, browsergym_eval_env: str | None = None): self.html_text_converter = self.get_html_text_converter() self.eval_mode = False - self.eval_dir = '' + self.eval_dir = "" # EVAL only: browsergym_eval_env must be provided for evaluation self.browsergym_eval_env = browsergym_eval_env self.eval_mode = bool(browsergym_eval_env) # Initialize browser environment process - multiprocessing.set_start_method('spawn', force=True) + multiprocessing.set_start_method("spawn", force=True) self.browser_side, self.agent_side = multiprocessing.Pipe() self.init_browser() @@ -57,42 +57,42 @@ def get_html_text_converter(self): retry=tenacity.retry_if_exception_type(BrowserInitException), ) def init_browser(self): - logger.debug('Starting browser env...') + logger.debug("Starting browser env...") try: self.process = multiprocessing.Process(target=self.browser_process) self.process.start() except Exception as e: - logger.error(f'Failed to start browser process: {e}') + logger.error(f"Failed to start browser process: {e}") raise if not self.check_alive(): self.close() - raise BrowserInitException('Failed to start browser environment.') + raise BrowserInitException("Failed to start browser environment.") def browser_process(self): if self.eval_mode: assert self.browsergym_eval_env is not None - logger.debug('Initializing browser env for web browsing evaluation.') - if 'webarena' in self.browsergym_eval_env: + logger.debug("Initializing browser env for web browsing evaluation.") + if "webarena" in self.browsergym_eval_env: import browsergym.webarena # noqa F401 register webarena tasks as gym environments - elif 'miniwob' in self.browsergym_eval_env: + elif "miniwob" in self.browsergym_eval_env: import browsergym.miniwob # noqa F401 register miniwob tasks as gym environments else: raise ValueError( - f'Unsupported browsergym eval env: {self.browsergym_eval_env}' + f"Unsupported browsergym eval env: {self.browsergym_eval_env}" ) env = gym.make( self.browsergym_eval_env, - tags_to_mark='all', + tags_to_mark="all", ) else: env = gym.make( - 'browsergym/openended', - task_kwargs={'start_url': 'about:blank', 'goal': 'PLACEHOLDER_GOAL'}, + "browsergym/openended", + task_kwargs={"start_url": "about:blank", "goal": "PLACEHOLDER_GOAL"}, wait_for_user_message=False, headless=True, disable_env_checker=True, - tags_to_mark='all', + tags_to_mark="all", ) obs, info = env.reset() @@ -102,39 +102,39 @@ def browser_process(self): self.eval_rewards: list[float] = [] if self.eval_mode: logger.debug(f"Browsing goal: {obs['goal']}") - self.eval_goal = obs['goal'] + self.eval_goal = obs["goal"] - logger.debug('Browser env started.') + logger.debug("Browser env started.") while should_continue(): try: if self.browser_side.poll(timeout=0.01): unique_request_id, action_data = self.browser_side.recv() # shutdown the browser environment - if unique_request_id == 'SHUTDOWN': - logger.debug('SHUTDOWN recv, shutting down browser env...') + if unique_request_id == "SHUTDOWN": + logger.debug("SHUTDOWN recv, shutting down browser env...") env.close() return - elif unique_request_id == 'IS_ALIVE': - self.browser_side.send(('ALIVE', None)) + elif unique_request_id == "IS_ALIVE": + self.browser_side.send(("ALIVE", None)) continue # EVAL ONLY: Get evaluation info - if action_data['action'] == BROWSER_EVAL_GET_GOAL_ACTION: + if action_data["action"] == BROWSER_EVAL_GET_GOAL_ACTION: self.browser_side.send( - (unique_request_id, {'text_content': self.eval_goal}) + (unique_request_id, {"text_content": self.eval_goal}) ) continue - elif action_data['action'] == BROWSER_EVAL_GET_REWARDS_ACTION: + elif action_data["action"] == BROWSER_EVAL_GET_REWARDS_ACTION: self.browser_side.send( ( unique_request_id, - {'text_content': json.dumps(self.eval_rewards)}, + {"text_content": json.dumps(self.eval_rewards)}, ) ) continue - action = action_data['action'] + action = action_data["action"] obs, reward, terminated, truncated, info = env.step(action) # EVAL ONLY: Save the rewards into file for evaluation @@ -142,15 +142,15 @@ def browser_process(self): self.eval_rewards.append(reward) # add text content of the page - html_str = flatten_dom_to_str(obs['dom_object']) - obs['text_content'] = self.html_text_converter.handle(html_str) + html_str = flatten_dom_to_str(obs["dom_object"]) + obs["text_content"] = self.html_text_converter.handle(html_str) # make observation serializable - obs['screenshot'] = self.image_to_png_base64_url(obs['screenshot']) - obs['active_page_index'] = obs['active_page_index'].item() - obs['elapsed_time'] = obs['elapsed_time'].item() + obs["screenshot"] = self.image_to_png_base64_url(obs["screenshot"]) + obs["active_page_index"] = obs["active_page_index"].item() + obs["elapsed_time"] = obs["elapsed_time"].item() self.browser_side.send((unique_request_id, obs)) except KeyboardInterrupt: - logger.debug('Browser env process interrupted by user.') + logger.debug("Browser env process interrupted by user.") try: env.close() except Exception: @@ -160,33 +160,33 @@ def browser_process(self): def step(self, action_str: str, timeout: float = 30) -> dict: """Execute an action in the browser environment and return the observation.""" unique_request_id = str(uuid.uuid4()) - self.agent_side.send((unique_request_id, {'action': action_str})) + self.agent_side.send((unique_request_id, {"action": action_str})) start_time = time.time() while True: if should_exit() or time.time() - start_time > timeout: - raise TimeoutError('Browser environment took too long to respond.') + raise TimeoutError("Browser environment took too long to respond.") if self.agent_side.poll(timeout=0.01): response_id, obs = self.agent_side.recv() if response_id == unique_request_id: return obs def check_alive(self, timeout: float = 60): - self.agent_side.send(('IS_ALIVE', None)) + self.agent_side.send(("IS_ALIVE", None)) if self.agent_side.poll(timeout=timeout): response_id, _ = self.agent_side.recv() - if response_id == 'ALIVE': + if response_id == "ALIVE": return True - logger.debug(f'Browser env is not alive. Response ID: {response_id}') + logger.debug(f"Browser env is not alive. Response ID: {response_id}") def close(self): if not self.process.is_alive(): return try: - self.agent_side.send(('SHUTDOWN', None)) + self.agent_side.send(("SHUTDOWN", None)) self.process.join(5) # Wait for the process to terminate if self.process.is_alive(): logger.error( - 'Browser process did not terminate, forcefully terminating...' + "Browser process did not terminate, forcefully terminating..." ) self.process.terminate() self.process.join(5) # Wait for the process to terminate @@ -196,7 +196,7 @@ def close(self): self.agent_side.close() self.browser_side.close() except Exception: - logger.error('Encountered an error when closing browser env', exc_info=True) + logger.error("Encountered an error when closing browser env", exc_info=True) @staticmethod def image_to_png_base64_url( @@ -205,16 +205,16 @@ def image_to_png_base64_url( """Convert a numpy array to a base64 encoded png image url.""" if isinstance(image, np.ndarray): image = Image.fromarray(image) - if image.mode in ('RGBA', 'LA'): - image = image.convert('RGB') + if image.mode in ("RGBA", "LA"): + image = image.convert("RGB") buffered = io.BytesIO() - image.save(buffered, format='PNG') + image.save(buffered, format="PNG") image_base64 = base64.b64encode(buffered.getvalue()).decode() return ( - f'data:image/png;base64,{image_base64}' + f"data:image/png;base64,{image_base64}" if add_data_prefix - else f'{image_base64}' + else f"{image_base64}" ) @staticmethod @@ -224,14 +224,14 @@ def image_to_jpg_base64_url( """Convert a numpy array to a base64 encoded jpeg image url.""" if isinstance(image, np.ndarray): image = Image.fromarray(image) - if image.mode in ('RGBA', 'LA'): - image = image.convert('RGB') + if image.mode in ("RGBA", "LA"): + image = image.convert("RGB") buffered = io.BytesIO() - image.save(buffered, format='JPEG') + image.save(buffered, format="JPEG") image_base64 = base64.b64encode(buffered.getvalue()).decode() return ( - f'data:image/jpeg;base64,{image_base64}' + f"data:image/jpeg;base64,{image_base64}" if add_data_prefix - else f'{image_base64}' + else f"{image_base64}" ) diff --git a/openhands/runtime/browser/utils.py b/openhands/runtime/browser/utils.py index 336b3801e3e2..517954f78a1b 100644 --- a/openhands/runtime/browser/utils.py +++ b/openhands/runtime/browser/utils.py @@ -16,7 +16,7 @@ async def browse( if isinstance(action, BrowseURLAction): # legacy BrowseURLAction asked_url = action.url - if not asked_url.startswith('http'): + if not asked_url.startswith("http"): asked_url = os.path.abspath(os.curdir) + action.url action_str = f'goto("{asked_url}")' @@ -25,36 +25,36 @@ async def browse( # action in BrowserGym: see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/action/functions.py action_str = action.browser_actions else: - raise ValueError(f'Invalid action type: {action.action}') + raise ValueError(f"Invalid action type: {action.action}") try: # obs provided by BrowserGym: see https://github.com/ServiceNow/BrowserGym/blob/main/core/src/browsergym/core/env.py#L396 obs = browser.step(action_str) return BrowserOutputObservation( - content=obs['text_content'], # text content of the page - url=obs.get('url', ''), # URL of the page - screenshot=obs.get('screenshot', None), # base64-encoded screenshot, png - open_pages_urls=obs.get('open_pages_urls', []), # list of open pages + content=obs["text_content"], # text content of the page + url=obs.get("url", ""), # URL of the page + screenshot=obs.get("screenshot", None), # base64-encoded screenshot, png + open_pages_urls=obs.get("open_pages_urls", []), # list of open pages active_page_index=obs.get( - 'active_page_index', -1 + "active_page_index", -1 ), # index of the active page - dom_object=obs.get('dom_object', {}), # DOM object - axtree_object=obs.get('axtree_object', {}), # accessibility tree object - extra_element_properties=obs.get('extra_element_properties', {}), + dom_object=obs.get("dom_object", {}), # DOM object + axtree_object=obs.get("axtree_object", {}), # accessibility tree object + extra_element_properties=obs.get("extra_element_properties", {}), focused_element_bid=obs.get( - 'focused_element_bid', None + "focused_element_bid", None ), # focused element bid last_browser_action=obs.get( - 'last_action', '' + "last_action", "" ), # last browser env action performed - last_browser_action_error=obs.get('last_action_error', ''), - error=True if obs.get('last_action_error', '') else False, # error flag + last_browser_action_error=obs.get("last_action_error", ""), + error=True if obs.get("last_action_error", "") else False, # error flag ) except Exception as e: return BrowserOutputObservation( content=str(e), - screenshot='', + screenshot="", error=True, last_browser_action_error=str(e), - url=asked_url if action.action == ActionType.BROWSE else '', + url=asked_url if action.action == ActionType.BROWSE else "", ) diff --git a/openhands/runtime/builder/__init__.py b/openhands/runtime/builder/__init__.py index fcebb8a24056..2d0767f6ef58 100644 --- a/openhands/runtime/builder/__init__.py +++ b/openhands/runtime/builder/__init__.py @@ -1,4 +1,4 @@ from openhands.runtime.builder.base import RuntimeBuilder from openhands.runtime.builder.docker import DockerRuntimeBuilder -__all__ = ['RuntimeBuilder', 'DockerRuntimeBuilder'] +__all__ = ["RuntimeBuilder", "DockerRuntimeBuilder"] diff --git a/openhands/runtime/builder/base.py b/openhands/runtime/builder/base.py index 4930b13d7ffd..df2ee99035c9 100644 --- a/openhands/runtime/builder/base.py +++ b/openhands/runtime/builder/base.py @@ -9,13 +9,13 @@ def build( tags: list[str], platform: str | None = None, ) -> str: - """ - Build the runtime image. + """Build the runtime image. Args: path (str): The path to the runtime image's build directory. tags (list[str]): The tags to apply to the runtime image (e.g., ["repo:my-repo", "sha:my-sha"]). platform (str, optional): The target platform for the build. Defaults to None. + Returns: str: The name:tag of the runtime image after build (e.g., "repo:sha"). This can be different from the tags input if the builder chooses to mutate the tags (e.g., adding a @@ -28,8 +28,7 @@ def build( @abc.abstractmethod def image_exists(self, image_name: str, pull_from_repo: bool = True) -> bool: - """ - Check if the runtime image exists. + """Check if the runtime image exists. Args: image_name (str): The name of the runtime image (e.g., "repo:sha"). diff --git a/openhands/runtime/builder/docker.py b/openhands/runtime/builder/docker.py index a3cb5af39f3d..c304b1ebaa83 100644 --- a/openhands/runtime/builder/docker.py +++ b/openhands/runtime/builder/docker.py @@ -16,9 +16,9 @@ def __init__(self, docker_client: docker.DockerClient): self.docker_client = docker_client version_info = self.docker_client.version() - server_version = version_info.get('Version', '').replace('-', '.') - if tuple(map(int, server_version.split('.')[:2])) < (18, 9): - raise RuntimeError('Docker server version must be >= 18.09 to use BuildKit') + server_version = version_info.get("Version", "").replace("-", ".") + if tuple(map(int, server_version.split(".")[:2])) < (18, 9): + raise RuntimeError("Docker server version must be >= 18.09 to use BuildKit") self.rolling_logger = RollingLogger(max_lines=10) @@ -52,35 +52,35 @@ def build( """ self.docker_client = docker.from_env() version_info = self.docker_client.version() - server_version = version_info.get('Version', '').replace('-', '.') - if tuple(map(int, server_version.split('.'))) < (18, 9): - raise RuntimeError('Docker server version must be >= 18.09 to use BuildKit') + server_version = version_info.get("Version", "").replace("-", ".") + if tuple(map(int, server_version.split("."))) < (18, 9): + raise RuntimeError("Docker server version must be >= 18.09 to use BuildKit") target_image_hash_name = tags[0] - target_image_repo, target_image_source_tag = target_image_hash_name.split(':') - target_image_tag = tags[1].split(':')[1] if len(tags) > 1 else None + target_image_repo, target_image_source_tag = target_image_hash_name.split(":") + target_image_tag = tags[1].split(":")[1] if len(tags) > 1 else None buildx_cmd = [ - 'docker', - 'buildx', - 'build', - '--progress=plain', - f'--build-arg=OPENHANDS_RUNTIME_VERSION={oh_version}', - f'--build-arg=OPENHANDS_RUNTIME_BUILD_TIME={datetime.datetime.now().isoformat()}', - f'--tag={target_image_hash_name}', - '--load', + "docker", + "buildx", + "build", + "--progress=plain", + f"--build-arg=OPENHANDS_RUNTIME_VERSION={oh_version}", + f"--build-arg=OPENHANDS_RUNTIME_BUILD_TIME={datetime.datetime.now().isoformat()}", + f"--tag={target_image_hash_name}", + "--load", ] # Include the platform argument only if platform is specified if platform: - buildx_cmd.append(f'--platform={platform}') + buildx_cmd.append(f"--platform={platform}") - cache_dir = '/tmp/.buildx-cache' + cache_dir = "/tmp/.buildx-cache" if use_local_cache and self._is_cache_usable(cache_dir): buildx_cmd.extend( [ - f'--cache-from=type=local,src={cache_dir}', - f'--cache-to=type=local,dest={cache_dir},mode=max', + f"--cache-from=type=local,src={cache_dir}", + f"--cache-to=type=local,dest={cache_dir},mode=max", ] ) @@ -90,7 +90,7 @@ def build( buildx_cmd.append(path) # must be last! self.rolling_logger.start( - '================ DOCKER BUILD STARTED ================' + "================ DOCKER BUILD STARTED ================" ) try: @@ -103,7 +103,7 @@ def build( ) if process.stdout: - for line in iter(process.stdout.readline, ''): + for line in iter(process.stdout.readline, ""): line = line.strip() if line: self._output_logs(line) @@ -119,51 +119,51 @@ def build( ) except subprocess.CalledProcessError as e: - logger.error(f'Image build failed:\n{e}') - logger.error(f'Command output:\n{e.output}') + logger.error(f"Image build failed:\n{e}") + logger.error(f"Command output:\n{e.output}") raise except subprocess.TimeoutExpired: - logger.error('Image build timed out') + logger.error("Image build timed out") raise except FileNotFoundError as e: - logger.error(f'Python executable not found: {e}') + logger.error(f"Python executable not found: {e}") raise except PermissionError as e: logger.error( - f'Permission denied when trying to execute the build command:\n{e}' + f"Permission denied when trying to execute the build command:\n{e}" ) raise except Exception as e: - logger.error(f'An unexpected error occurred during the build process: {e}') + logger.error(f"An unexpected error occurred during the build process: {e}") raise - logger.info(f'Image [{target_image_hash_name}] build finished.') + logger.info(f"Image [{target_image_hash_name}] build finished.") if target_image_tag: image = self.docker_client.images.get(target_image_hash_name) image.tag(target_image_repo, target_image_tag) logger.info( - f'Re-tagged image [{target_image_hash_name}] with more generic tag [{target_image_tag}]' + f"Re-tagged image [{target_image_hash_name}] with more generic tag [{target_image_tag}]" ) # Check if the image is built successfully image = self.docker_client.images.get(target_image_hash_name) if image is None: raise RuntimeError( - f'Build failed: Image {target_image_hash_name} not found' + f"Build failed: Image {target_image_hash_name} not found" ) tags_str = ( - f'{target_image_source_tag}, {target_image_tag}' + f"{target_image_source_tag}, {target_image_tag}" if target_image_tag else target_image_source_tag ) logger.info( - f'Image {target_image_repo} with tags [{tags_str}] built successfully' + f"Image {target_image_repo} with tags [{tags_str}] built successfully" ) return target_image_hash_name @@ -177,28 +177,28 @@ def image_exists(self, image_name: str, pull_from_repo: bool = True) -> bool: bool: Whether the Docker image exists in the registry or in the local store """ if not image_name: - logger.error(f'Invalid image name: `{image_name}`') + logger.error(f"Invalid image name: `{image_name}`") return False try: - logger.debug(f'Checking, if image exists locally:\n{image_name}') + logger.debug(f"Checking, if image exists locally:\n{image_name}") self.docker_client.images.get(image_name) - logger.debug('Image found locally.') + logger.debug("Image found locally.") return True except docker.errors.ImageNotFound: if not pull_from_repo: - logger.debug(f'Image {image_name} not found locally') + logger.debug(f"Image {image_name} not found locally") return False try: logger.debug( - 'Image not found locally. Trying to pull it, please wait...' + "Image not found locally. Trying to pull it, please wait..." ) layers: dict[str, dict[str, str]] = {} previous_layer_count = 0 - if ':' in image_name: - image_repo, image_tag = image_name.split(':', 1) + if ":" in image_name: + image_repo, image_tag = image_name.split(":", 1) else: image_repo = image_name image_tag = None @@ -208,18 +208,18 @@ def image_exists(self, image_name: str, pull_from_repo: bool = True) -> bool: ): self._output_build_progress(line, layers, previous_layer_count) previous_layer_count = len(layers) - logger.debug('Image pulled') + logger.debug("Image pulled") return True except docker.errors.ImageNotFound: - logger.debug('Could not find image locally or in registry.') + logger.debug("Could not find image locally or in registry.") return False except Exception as e: - msg = 'Image could not be pulled: ' + msg = "Image could not be pulled: " ex_msg = str(e) - if 'Not Found' in ex_msg: - msg += 'image not found in registry.' + if "Not Found" in ex_msg: + msg += "image not found in registry." else: - msg += f'{ex_msg}' + msg += f"{ex_msg}" logger.debug(msg) return False @@ -232,62 +232,61 @@ def _output_logs(self, new_line: str) -> None: def _output_build_progress( self, current_line: dict, layers: dict, previous_layer_count: int ) -> None: - if 'id' in current_line and 'progressDetail' in current_line: - layer_id = current_line['id'] + if "id" in current_line and "progressDetail" in current_line: + layer_id = current_line["id"] if layer_id not in layers: - layers[layer_id] = {'status': '', 'progress': '', 'last_logged': 0} + layers[layer_id] = {"status": "", "progress": "", "last_logged": 0} - if 'status' in current_line: - layers[layer_id]['status'] = current_line['status'] + if "status" in current_line: + layers[layer_id]["status"] = current_line["status"] - if 'progress' in current_line: - layers[layer_id]['progress'] = current_line['progress'] + if "progress" in current_line: + layers[layer_id]["progress"] = current_line["progress"] - if 'progressDetail' in current_line: - progress_detail = current_line['progressDetail'] - if 'total' in progress_detail and 'current' in progress_detail: - total = progress_detail['total'] - current = progress_detail['current'] + if "progressDetail" in current_line: + progress_detail = current_line["progressDetail"] + if "total" in progress_detail and "current" in progress_detail: + total = progress_detail["total"] + current = progress_detail["current"] percentage = min( (current / total) * 100, 100 ) # Ensure it doesn't exceed 100% else: percentage = ( - 100 if layers[layer_id]['status'] == 'Download complete' else 0 + 100 if layers[layer_id]["status"] == "Download complete" else 0 ) if self.rolling_logger.is_enabled(): self.rolling_logger.move_back(previous_layer_count) for lid, layer_data in sorted(layers.items()): self.rolling_logger.replace_current_line() - status = layer_data['status'] - progress = layer_data['progress'] - if status == 'Download complete': + status = layer_data["status"] + progress = layer_data["progress"] + if status == "Download complete": self.rolling_logger.write_immediately( - f'Layer {lid}: Download complete' + f"Layer {lid}: Download complete" ) - elif status == 'Already exists': + elif status == "Already exists": self.rolling_logger.write_immediately( - f'Layer {lid}: Already exists' + f"Layer {lid}: Already exists" ) else: self.rolling_logger.write_immediately( - f'Layer {lid}: {progress} {status}' + f"Layer {lid}: {progress} {status}" ) elif percentage != 0 and ( - percentage - layers[layer_id]['last_logged'] >= 10 or percentage == 100 + percentage - layers[layer_id]["last_logged"] >= 10 or percentage == 100 ): logger.debug( f'Layer {layer_id}: {layers[layer_id]["progress"]} {layers[layer_id]["status"]}' ) - layers[layer_id]['last_logged'] = percentage - elif 'status' in current_line: - logger.debug(current_line['status']) + layers[layer_id]["last_logged"] = percentage + elif "status" in current_line: + logger.debug(current_line["status"]) def _prune_old_cache_files(self, cache_dir: str, max_age_days: int = 7) -> None: - """ - Prune cache files older than the specified number of days. + """Prune cache files older than the specified number of days. Args: cache_dir (str): The path to the cache directory. @@ -304,15 +303,14 @@ def _prune_old_cache_files(self, cache_dir: str, max_age_days: int = 7) -> None: file_age = current_time - os.path.getmtime(file_path) if file_age > max_age_seconds: os.remove(file_path) - logger.debug(f'Removed old cache file: {file_path}') + logger.debug(f"Removed old cache file: {file_path}") except Exception as e: - logger.warning(f'Error processing cache file {file_path}: {e}') + logger.warning(f"Error processing cache file {file_path}: {e}") except Exception as e: - logger.warning(f'Error during build cache pruning: {e}') + logger.warning(f"Error during build cache pruning: {e}") def _is_cache_usable(self, cache_dir: str) -> bool: - """ - Check if the cache directory is usable (exists and is writable). + """Check if the cache directory is usable (exists and is writable). Args: cache_dir (str): The path to the cache directory. @@ -323,18 +321,18 @@ def _is_cache_usable(self, cache_dir: str) -> bool: if not os.path.exists(cache_dir): try: os.makedirs(cache_dir, exist_ok=True) - logger.debug(f'Created cache directory: {cache_dir}') + logger.debug(f"Created cache directory: {cache_dir}") except OSError as e: - logger.debug(f'Failed to create cache directory {cache_dir}: {e}') + logger.debug(f"Failed to create cache directory {cache_dir}: {e}") return False if not os.access(cache_dir, os.W_OK): logger.warning( - f'Cache directory {cache_dir} is not writable. Caches will not be used for Docker builds.' + f"Cache directory {cache_dir} is not writable. Caches will not be used for Docker builds." ) return False self._prune_old_cache_files(cache_dir) - logger.debug(f'Cache directory {cache_dir} is usable') + logger.debug(f"Cache directory {cache_dir} is usable") return True diff --git a/openhands/runtime/builder/remote.py b/openhands/runtime/builder/remote.py index b1b14752cb89..8969d07c5792 100644 --- a/openhands/runtime/builder/remote.py +++ b/openhands/runtime/builder/remote.py @@ -21,87 +21,87 @@ def __init__(self, api_url: str, api_key: str): self.api_url = api_url self.api_key = api_key self.session = requests.Session() - self.session.headers.update({'X-API-Key': self.api_key}) + self.session.headers.update({"X-API-Key": self.api_key}) def build(self, path: str, tags: list[str], platform: str | None = None) -> str: """Builds a Docker image using the Runtime API's /build endpoint.""" # Create a tar archive of the build context tar_buffer = io.BytesIO() - with tarfile.open(fileobj=tar_buffer, mode='w:gz') as tar: - tar.add(path, arcname='.') + with tarfile.open(fileobj=tar_buffer, mode="w:gz") as tar: + tar.add(path, arcname=".") tar_buffer.seek(0) # Encode the tar file as base64 - base64_encoded_tar = base64.b64encode(tar_buffer.getvalue()).decode('utf-8') + base64_encoded_tar = base64.b64encode(tar_buffer.getvalue()).decode("utf-8") # Prepare the multipart form data files = [ - ('context', ('context.tar.gz', base64_encoded_tar)), - ('target_image', (None, tags[0])), + ("context", ("context.tar.gz", base64_encoded_tar)), + ("target_image", (None, tags[0])), ] # Add additional tags if present for tag in tags[1:]: - files.append(('tags', (None, tag))) + files.append(("tags", (None, tag))) # Send the POST request to /build (Begins the build process) try: response = send_request( self.session, - 'POST', - f'{self.api_url}/build', + "POST", + f"{self.api_url}/build", files=files, timeout=30, ) except requests.exceptions.HTTPError as e: if e.response.status_code == 429: - logger.warning('Build was rate limited. Retrying in 30 seconds.') + logger.warning("Build was rate limited. Retrying in 30 seconds.") time.sleep(30) return self.build(path, tags, platform) else: raise e build_data = response.json() - build_id = build_data['build_id'] - logger.info(f'Build initiated with ID: {build_id}') + build_id = build_data["build_id"] + logger.info(f"Build initiated with ID: {build_id}") # Poll /build_status until the build is complete start_time = time.time() timeout = 30 * 60 # 20 minutes in seconds while should_continue(): if time.time() - start_time > timeout: - logger.error('Build timed out after 30 minutes') - raise RuntimeError('Build timed out after 30 minutes') + logger.error("Build timed out after 30 minutes") + raise RuntimeError("Build timed out after 30 minutes") status_response = send_request( self.session, - 'GET', - f'{self.api_url}/build_status', - params={'build_id': build_id}, + "GET", + f"{self.api_url}/build_status", + params={"build_id": build_id}, ) if status_response.status_code != 200: - logger.error(f'Failed to get build status: {status_response.text}') + logger.error(f"Failed to get build status: {status_response.text}") raise RuntimeError( - f'Failed to get build status: {status_response.text}' + f"Failed to get build status: {status_response.text}" ) status_data = status_response.json() - status = status_data['status'] - logger.info(f'Build status: {status}') + status = status_data["status"] + logger.info(f"Build status: {status}") - if status == 'SUCCESS': + if status == "SUCCESS": logger.debug(f"Successfully built {status_data['image']}") - return status_data['image'] + return status_data["image"] elif status in [ - 'FAILURE', - 'INTERNAL_ERROR', - 'TIMEOUT', - 'CANCELLED', - 'EXPIRED', + "FAILURE", + "INTERNAL_ERROR", + "TIMEOUT", + "CANCELLED", + "EXPIRED", ]: error_message = status_data.get( - 'error', f'Build failed with status: {status}. Build ID: {build_id}' + "error", f"Build failed with status: {status}. Build ID: {build_id}" ) logger.error(error_message) raise RuntimeError(error_message) @@ -109,31 +109,31 @@ def build(self, path: str, tags: list[str], platform: str | None = None) -> str: # Wait before polling again sleep_if_should_continue(30) - raise RuntimeError('Build interrupted (likely received SIGTERM or SIGINT).') + raise RuntimeError("Build interrupted (likely received SIGTERM or SIGINT).") def image_exists(self, image_name: str, pull_from_repo: bool = True) -> bool: """Checks if an image exists in the remote registry using the /image_exists endpoint.""" - params = {'image': image_name} + params = {"image": image_name} response = send_request( self.session, - 'GET', - f'{self.api_url}/image_exists', + "GET", + f"{self.api_url}/image_exists", params=params, ) if response.status_code != 200: - logger.error(f'Failed to check image existence: {response.text}') - raise RuntimeError(f'Failed to check image existence: {response.text}') + logger.error(f"Failed to check image existence: {response.text}") + raise RuntimeError(f"Failed to check image existence: {response.text}") result = response.json() - if result['exists']: + if result["exists"]: logger.debug( f"Image {image_name} exists. " f"Uploaded at: {result['image']['upload_time']}, " f"Size: {result['image']['image_size_bytes'] / 1024 / 1024:.2f} MB" ) else: - logger.debug(f'Image {image_name} does not exist.') + logger.debug(f"Image {image_name} does not exist.") - return result['exists'] + return result["exists"] diff --git a/openhands/runtime/impl/e2b/e2b_runtime.py b/openhands/runtime/impl/e2b/e2b_runtime.py index 7c9c297f424c..8c084daab3a8 100644 --- a/openhands/runtime/impl/e2b/e2b_runtime.py +++ b/openhands/runtime/impl/e2b/e2b_runtime.py @@ -24,7 +24,7 @@ def __init__( self, config: AppConfig, event_stream: EventStream, - sid: str = 'default', + sid: str = "default", plugins: list[PluginRequirement] | None = None, sandbox: E2BSandbox | None = None, status_callback: Optional[Callable] = None, @@ -39,27 +39,27 @@ def __init__( if sandbox is None: self.sandbox = E2BSandbox() if not isinstance(self.sandbox, E2BSandbox): - raise ValueError('E2BRuntime requires an E2BSandbox') + raise ValueError("E2BRuntime requires an E2BSandbox") self.file_store = E2BFileStore(self.sandbox.filesystem) def read(self, action: FileReadAction) -> Observation: content = self.file_store.read(action.path) - lines = read_lines(content.split('\n'), action.start, action.end) - code_view = ''.join(lines) + lines = read_lines(content.split("\n"), action.start, action.end) + code_view = "".join(lines) return FileReadObservation(code_view, path=action.path) def write(self, action: FileWriteAction) -> Observation: if action.start == 0 and action.end == -1: self.file_store.write(action.path, action.content) - return FileWriteObservation(content='', path=action.path) + return FileWriteObservation(content="", path=action.path) files = self.file_store.list(action.path) if action.path in files: - all_lines = self.file_store.read(action.path).split('\n') + all_lines = self.file_store.read(action.path).split("\n") new_file = insert_lines( - action.content.split('\n'), all_lines, action.start, action.end + action.content.split("\n"), all_lines, action.start, action.end ) - self.file_store.write(action.path, ''.join(new_file)) - return FileWriteObservation('', path=action.path) + self.file_store.write(action.path, "".join(new_file)) + return FileWriteObservation("", path=action.path) else: # FIXME: we should create a new file here - return ErrorObservation(f'File not found: {action.path}') + return ErrorObservation(f"File not found: {action.path}") diff --git a/openhands/runtime/impl/e2b/sandbox.py b/openhands/runtime/impl/e2b/sandbox.py index d145dac35115..fd4858d278bf 100644 --- a/openhands/runtime/impl/e2b/sandbox.py +++ b/openhands/runtime/impl/e2b/sandbox.py @@ -12,7 +12,7 @@ class E2BBox: closed = False - _cwd: str = '/home/user' + _cwd: str = "/home/user" _env: dict[str, str] = {} is_initial_session: bool = True @@ -20,7 +20,7 @@ def __init__( self, config: SandboxConfig, e2b_api_key: str, - template: str = 'openhands', + template: str = "openhands", ): self.config = copy.deepcopy(config) self.initialize_plugins: bool = config.initialize_plugins @@ -28,8 +28,8 @@ def __init__( api_key=e2b_api_key, template=template, # It's possible to stream stdout and stderr from sandbox and from each process - on_stderr=lambda x: logger.debug(f'E2B sandbox stderr: {x}'), - on_stdout=lambda x: logger.debug(f'E2B sandbox stdout: {x}'), + on_stderr=lambda x: logger.debug(f"E2B sandbox stderr: {x}"), + on_stdout=lambda x: logger.debug(f"E2B sandbox stdout: {x}"), cwd=self._cwd, # Default workdir inside sandbox ) logger.debug(f'Started E2B sandbox with ID "{self.sandbox.id}"') @@ -42,11 +42,11 @@ def _archive(self, host_src: str, recursive: bool = False): if recursive: assert os.path.isdir( host_src - ), 'Source must be a directory when recursive is True' - files = glob(host_src + '/**/*', recursive=True) + ), "Source must be a directory when recursive is True" + files = glob(host_src + "/**/*", recursive=True) srcname = os.path.basename(host_src) - tar_filename = os.path.join(os.path.dirname(host_src), srcname + '.tar') - with tarfile.open(tar_filename, mode='w') as tar: + tar_filename = os.path.join(os.path.dirname(host_src), srcname + ".tar") + with tarfile.open(tar_filename, mode="w") as tar: for file in files: tar.add( file, arcname=os.path.relpath(file, os.path.dirname(host_src)) @@ -54,10 +54,10 @@ def _archive(self, host_src: str, recursive: bool = False): else: assert os.path.isfile( host_src - ), 'Source must be a file when recursive is False' + ), "Source must be a file when recursive is False" srcname = os.path.basename(host_src) - tar_filename = os.path.join(os.path.dirname(host_src), srcname + '.tar') - with tarfile.open(tar_filename, mode='w') as tar: + tar_filename = os.path.join(os.path.dirname(host_src), srcname + ".tar") + with tarfile.open(tar_filename, mode="w") as tar: tar.add(host_src, arcname=srcname) return tar_filename @@ -67,12 +67,12 @@ def execute(self, cmd: str, timeout: int | None = None) -> tuple[int, str]: try: process_output = process.wait(timeout=timeout) except TimeoutException: - logger.debug('Command timed out, killing process...') + logger.debug("Command timed out, killing process...") process.kill() return -1, f'Command: "{cmd}" timed out' logs = [m.line for m in process_output.messages] - logs_str = '\n'.join(logs) + logs_str = "\n".join(logs) if process.exit_code is None: return -1, logs_str @@ -84,24 +84,24 @@ def copy_to(self, host_src: str, sandbox_dest: str, recursive: bool = False): tar_filename = self._archive(host_src, recursive) # Prepend the sandbox destination with our sandbox cwd - sandbox_dest = os.path.join(self._cwd, sandbox_dest.removeprefix('/')) + sandbox_dest = os.path.join(self._cwd, sandbox_dest.removeprefix("/")) - with open(tar_filename, 'rb') as tar_file: + with open(tar_filename, "rb") as tar_file: # Upload the archive to /home/user (default destination that always exists) uploaded_path = self.sandbox.upload_file(tar_file) # Check if sandbox_dest exists. If not, create it. - process = self.sandbox.process.start_and_wait(f'test -d {sandbox_dest}') + process = self.sandbox.process.start_and_wait(f"test -d {sandbox_dest}") if process.exit_code != 0: self.sandbox.filesystem.make_dir(sandbox_dest) # Extract the archive into the destination and delete the archive process = self.sandbox.process.start_and_wait( - f'sudo tar -xf {uploaded_path} -C {sandbox_dest} && sudo rm {uploaded_path}' + f"sudo tar -xf {uploaded_path} -C {sandbox_dest} && sudo rm {uploaded_path}" ) if process.exit_code != 0: raise Exception( - f'Failed to extract {uploaded_path} to {sandbox_dest}: {process.stderr}' + f"Failed to extract {uploaded_path} to {sandbox_dest}: {process.stderr}" ) # Delete the local archive diff --git a/openhands/runtime/impl/eventstream/eventstream_runtime.py b/openhands/runtime/impl/eventstream/eventstream_runtime.py index dbf6599ea66a..5ecf665f0c6b 100644 --- a/openhands/runtime/impl/eventstream/eventstream_runtime.py +++ b/openhands/runtime/impl/eventstream/eventstream_runtime.py @@ -44,7 +44,7 @@ from openhands.utils.async_utils import call_sync_from_async from openhands.utils.tenacity_stop import stop_if_should_exit -CONTAINER_NAME_PREFIX = 'openhands-runtime-' +CONTAINER_NAME_PREFIX = "openhands-runtime-" def remove_all_runtime_containers(): @@ -63,7 +63,7 @@ class LogBuffer: """ def __init__(self, container: docker.models.containers.Container, logFn: Callable): - self.init_msg = 'Runtime client initialized.' + self.init_msg = "Runtime client initialized." self.buffer: list[str] = [] self.lock = threading.Lock() @@ -95,15 +95,15 @@ def stream_logs(self): if self._stop_event.is_set(): break if log_line: - decoded_line = log_line.decode('utf-8').rstrip() + decoded_line = log_line.decode("utf-8").rstrip() self.append(decoded_line) except Exception as e: - self.log('error', f'Error streaming docker logs: {e}') + self.log("error", f"Error streaming docker logs: {e}") def __del__(self): if self.log_stream_thread.is_alive(): self.log( - 'warn', + "warn", "LogBuffer was not properly closed. Use 'log_buffer.close()' for clean shutdown.", ) self.close(timeout=5) @@ -131,7 +131,7 @@ def init_base_runtime( self, config: AppConfig, event_stream: EventStream, - sid: str = 'default', + sid: str = "default", plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, status_callback: Callable | None = None, @@ -151,7 +151,7 @@ def __init__( self, config: AppConfig, event_stream: EventStream, - sid: str = 'default', + sid: str = "default", plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, status_callback: Callable | None = None, @@ -160,7 +160,7 @@ def __init__( self.config = config self._host_port = 30000 # initial dummy value self._container_port = 30001 # initial dummy value - self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}' + self.api_url = f"{self.config.sandbox.local_runtime_url}:{self._container_port}" self.session = requests.Session() self.status_callback = status_callback @@ -178,8 +178,8 @@ def __init__( if self.config.sandbox.runtime_extra_deps: self.log( - 'debug', - f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}', + "debug", + f"Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}", ) self.init_base_runtime( @@ -193,22 +193,22 @@ def __init__( ) async def connect(self): - self.send_status_message('STATUS$STARTING_RUNTIME') + self.send_status_message("STATUS$STARTING_RUNTIME") try: await call_sync_from_async(self._attach_to_container) except docker.errors.NotFound as e: if self.attach_to_existing: self.log( - 'error', - f'Container {self.container_name} not found.', + "error", + f"Container {self.container_name} not found.", ) raise e if self.runtime_container_image is None: if self.base_container_image is None: raise ValueError( - 'Neither runtime container image nor base container image is set' + "Neither runtime container image nor base container image is set" ) - self.send_status_message('STATUS$STARTING_CONTAINER') + self.send_status_message("STATUS$STARTING_CONTAINER") self.runtime_container_image = build_runtime_image( self.base_container_image, self.runtime_builder, @@ -218,29 +218,29 @@ async def connect(self): ) self.log( - 'info', f'Starting runtime with image: {self.runtime_container_image}' + "info", f"Starting runtime with image: {self.runtime_container_image}" ) await call_sync_from_async(self._init_container) - self.log('info', f'Container started: {self.container_name}') + self.log("info", f"Container started: {self.container_name}") if not self.attach_to_existing: - self.log('info', f'Waiting for client to become ready at {self.api_url}...') - self.send_status_message('STATUS$WAITING_FOR_CLIENT') + self.log("info", f"Waiting for client to become ready at {self.api_url}...") + self.send_status_message("STATUS$WAITING_FOR_CLIENT") await call_sync_from_async(self._wait_until_alive) if not self.attach_to_existing: - self.log('info', 'Runtime is ready.') + self.log("info", "Runtime is ready.") if not self.attach_to_existing: await call_sync_from_async(self.setup_initial_env) self.log( - 'debug', - f'Container initialized with plugins: {[plugin.name for plugin in self.plugins]}', + "debug", + f"Container initialized with plugins: {[plugin.name for plugin in self.plugins]}", ) if not self.attach_to_existing: - self.send_status_message(' ') + self.send_status_message(" ") @staticmethod @lru_cache(maxsize=1) @@ -249,14 +249,14 @@ def _init_docker_client() -> docker.DockerClient: return docker.from_env() except Exception as ex: logger.error( - 'Launch docker client failed. Please make sure you have installed docker and started docker desktop/daemon.', + "Launch docker client failed. Please make sure you have installed docker and started docker desktop/daemon.", ) raise ex def _init_container(self): - self.log('debug', 'Preparing to start container...') - self.send_status_message('STATUS$PREPARING_CONTAINER') - plugin_arg = '' + self.log("debug", "Preparing to start container...") + self.send_status_message("STATUS$PREPARING_CONTAINER") + plugin_arg = "" if self.plugins is not None and len(self.plugins) > 0: plugin_arg = ( f'--plugins {" ".join([plugin.name for plugin in self.plugins])} ' @@ -266,31 +266,31 @@ def _init_container(self): self._container_port = ( self._host_port ) # in future this might differ from host port - self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}' + self.api_url = f"{self.config.sandbox.local_runtime_url}:{self._container_port}" use_host_network = self.config.sandbox.use_host_network - network_mode: str | None = 'host' if use_host_network else None + network_mode: str | None = "host" if use_host_network else None port_mapping: dict[str, list[dict[str, str]]] | None = ( None if use_host_network - else {f'{self._container_port}/tcp': [{'HostPort': str(self._host_port)}]} + else {f"{self._container_port}/tcp": [{"HostPort": str(self._host_port)}]} ) if use_host_network: self.log( - 'warn', - 'Using host network mode. If you are using MacOS, please make sure you have the latest version of Docker Desktop and enabled host network feature: https://docs.docker.com/network/drivers/host/#docker-desktop', + "warn", + "Using host network mode. If you are using MacOS, please make sure you have the latest version of Docker Desktop and enabled host network feature: https://docs.docker.com/network/drivers/host/#docker-desktop", ) # Combine environment variables environment = { - 'port': str(self._container_port), - 'PYTHONUNBUFFERED': 1, + "port": str(self._container_port), + "PYTHONUNBUFFERED": 1, } if self.config.debug or DEBUG: - environment['DEBUG'] = 'true' + environment["DEBUG"] = "true" - self.log('debug', f'Workspace Base: {self.config.workspace_base}') + self.log("debug", f"Workspace Base: {self.config.workspace_base}") if ( self.config.workspace_mount_path is not None and self.config.workspace_mount_path_in_sandbox is not None @@ -298,27 +298,27 @@ def _init_container(self): # e.g. result would be: {"/home/user/openhands/workspace": {'bind': "/workspace", 'mode': 'rw'}} volumes = { self.config.workspace_mount_path: { - 'bind': self.config.workspace_mount_path_in_sandbox, - 'mode': 'rw', + "bind": self.config.workspace_mount_path_in_sandbox, + "mode": "rw", } } - logger.debug(f'Mount dir: {self.config.workspace_mount_path}') + logger.debug(f"Mount dir: {self.config.workspace_mount_path}") else: logger.debug( - 'Mount dir is not set, will not mount the workspace directory to the container' + "Mount dir is not set, will not mount the workspace directory to the container" ) volumes = None self.log( - 'debug', - f'Sandbox workspace: {self.config.workspace_mount_path_in_sandbox}', + "debug", + f"Sandbox workspace: {self.config.workspace_mount_path_in_sandbox}", ) if self.config.sandbox.browsergym_eval_env is not None: browsergym_arg = ( - f'--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}' + f"--browsergym-eval-env {self.config.sandbox.browsergym_eval_env}" ) else: - browsergym_arg = '' + browsergym_arg = "" try: self.container = self.docker_client.containers.run( @@ -335,35 +335,35 @@ def _init_container(self): ), network_mode=network_mode, ports=port_mapping, - working_dir='/openhands/code/', # do not change this! + working_dir="/openhands/code/", # do not change this! name=self.container_name, detach=True, environment=environment, volumes=volumes, ) self.log_buffer = LogBuffer(self.container, self.log) - self.log('debug', f'Container started. Server url: {self.api_url}') - self.send_status_message('STATUS$CONTAINER_STARTED') + self.log("debug", f"Container started. Server url: {self.api_url}") + self.send_status_message("STATUS$CONTAINER_STARTED") except docker.errors.APIError as e: - if '409' in str(e): + if "409" in str(e): self.log( - 'warning', - f'Container {self.container_name} already exists. Removing...', + "warning", + f"Container {self.container_name} already exists. Removing...", ) remove_all_containers(self.container_name) return self._init_container() else: self.log( - 'error', - f'Error: Instance {self.container_name} FAILED to start container!\n', + "error", + f"Error: Instance {self.container_name} FAILED to start container!\n", ) except Exception as e: self.log( - 'error', - f'Error: Instance {self.container_name} FAILED to start container!\n', + "error", + f"Error: Instance {self.container_name} FAILED to start container!\n", ) - self.log('error', str(e)) + self.log("error", str(e)) self.close() raise e @@ -372,35 +372,35 @@ def _attach_to_container(self): self.log_buffer = LogBuffer(container, self.log) self.container = container self._container_port = 0 - for port in container.attrs['NetworkSettings']['Ports']: - self._container_port = int(port.split('/')[0]) + for port in container.attrs["NetworkSettings"]["Ports"]: + self._container_port = int(port.split("/")[0]) break self._host_port = self._container_port - self.api_url = f'{self.config.sandbox.local_runtime_url}:{self._container_port}' + self.api_url = f"{self.config.sandbox.local_runtime_url}:{self._container_port}" self.log( - 'debug', - f'attached to container: {self.container_name} {self._container_port} {self.api_url}', + "debug", + f"attached to container: {self.container_name} {self._container_port} {self.api_url}", ) def _refresh_logs(self): - self.log('debug', 'Getting container logs...') + self.log("debug", "Getting container logs...") assert ( self.log_buffer is not None - ), 'Log buffer is expected to be initialized when container is started' + ), "Log buffer is expected to be initialized when container is started" logs = self.log_buffer.get_and_clear() if logs: - formatted_logs = '\n'.join([f' |{log}' for log in logs]) + formatted_logs = "\n".join([f" |{log}" for log in logs]) self.log( - 'debug', - '\n' - + '-' * 35 - + 'Container logs:' - + '-' * 35 - + f'\n{formatted_logs}' - + '\n' - + '-' * 80, + "debug", + "\n" + + "-" * 35 + + "Container logs:" + + "-" * 35 + + f"\n{formatted_logs}" + + "\n" + + "-" * 80, ) @tenacity.retry( @@ -411,12 +411,12 @@ def _refresh_logs(self): def _wait_until_alive(self): self._refresh_logs() if not self.log_buffer: - raise RuntimeError('Runtime client is not ready.') + raise RuntimeError("Runtime client is not ready.") send_request( self.session, - 'GET', - f'{self.api_url}/alive', + "GET", + f"{self.api_url}/alive", timeout=5, ) @@ -449,27 +449,27 @@ def run_action(self, action: Action) -> Observation: with self.action_semaphore: if not action.runnable: - return NullObservation('') + return NullObservation("") if ( - hasattr(action, 'confirmation_state') + hasattr(action, "confirmation_state") and action.confirmation_state == ActionConfirmationStatus.AWAITING_CONFIRMATION ): - return NullObservation('') + return NullObservation("") action_type = action.action # type: ignore[attr-defined] if action_type not in ACTION_TYPE_TO_CLASS: - raise ValueError(f'Action {action_type} does not exist.') + raise ValueError(f"Action {action_type} does not exist.") if not hasattr(self, action_type): return ErrorObservation( - f'Action {action_type} is not supported in the current runtime.', - error_id='AGENT_ERROR$BAD_ACTION', + f"Action {action_type} is not supported in the current runtime.", + error_id="AGENT_ERROR$BAD_ACTION", ) if ( - getattr(action, 'confirmation_state', None) + getattr(action, "confirmation_state", None) == ActionConfirmationStatus.REJECTED ): return UserRejectObservation( - 'Action has been rejected by the user! Waiting for further user input.' + "Action has been rejected by the user! Waiting for further user input." ) self._refresh_logs() @@ -479,9 +479,9 @@ def run_action(self, action: Action) -> Observation: try: response = send_request( self.session, - 'POST', - f'{self.api_url}/execute_action', - json={'action': event_to_dict(action)}, + "POST", + f"{self.api_url}/execute_action", + json={"action": event_to_dict(action)}, # wait a few more seconds to get the timeout error from client side timeout=action.timeout + 5, ) @@ -490,7 +490,7 @@ def run_action(self, action: Action) -> Observation: obs._cause = action.id # type: ignore[attr-defined] except requests.Timeout: raise RuntimeError( - f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s' + f"Runtime failed to return execute_action before the requested timeout of {action.timeout}s" ) self._refresh_logs() return obs @@ -521,18 +521,18 @@ def copy_to( self, host_src: str, sandbox_dest: str, recursive: bool = False ) -> None: if not os.path.exists(host_src): - raise FileNotFoundError(f'Source file {host_src} does not exist') + raise FileNotFoundError(f"Source file {host_src} does not exist") self._refresh_logs() try: if recursive: # For recursive copy, create a zip file with tempfile.NamedTemporaryFile( - suffix='.zip', delete=False + suffix=".zip", delete=False ) as temp_zip: temp_zip_path = temp_zip.name - with ZipFile(temp_zip_path, 'w') as zipf: + with ZipFile(temp_zip_path, "w") as zipf: for root, _, files in os.walk(host_src): for file in files: file_path = os.path.join(root, file) @@ -541,31 +541,31 @@ def copy_to( ) zipf.write(file_path, arcname) - upload_data = {'file': open(temp_zip_path, 'rb')} + upload_data = {"file": open(temp_zip_path, "rb")} else: # For single file copy - upload_data = {'file': open(host_src, 'rb')} + upload_data = {"file": open(host_src, "rb")} - params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()} + params = {"destination": sandbox_dest, "recursive": str(recursive).lower()} send_request( self.session, - 'POST', - f'{self.api_url}/upload_file', + "POST", + f"{self.api_url}/upload_file", files=upload_data, params=params, timeout=300, ) except requests.Timeout: - raise TimeoutError('Copy operation timed out') + raise TimeoutError("Copy operation timed out") except Exception as e: - raise RuntimeError(f'Copy operation failed: {str(e)}') + raise RuntimeError(f"Copy operation failed: {str(e)}") finally: if recursive: os.unlink(temp_zip_path) self.log( - 'debug', f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}' + "debug", f"Copy completed: host:{host_src} -> runtime:{sandbox_dest}" ) self._refresh_logs() @@ -578,12 +578,12 @@ def list_files(self, path: str | None = None) -> list[str]: try: data = {} if path is not None: - data['path'] = path + data["path"] = path response = send_request( self.session, - 'POST', - f'{self.api_url}/list_files', + "POST", + f"{self.api_url}/list_files", json=data, timeout=10, ) @@ -591,17 +591,17 @@ def list_files(self, path: str | None = None) -> list[str]: assert isinstance(response_json, list) return response_json except requests.Timeout: - raise TimeoutError('List files operation timed out') + raise TimeoutError("List files operation timed out") def copy_from(self, path: str) -> Path: """Zip all files in the sandbox and return as a stream of bytes.""" self._refresh_logs() try: - params = {'path': path} + params = {"path": path} response = send_request( self.session, - 'GET', - f'{self.api_url}/download_files', + "GET", + f"{self.api_url}/download_files", params=params, stream=True, timeout=30, @@ -612,7 +612,7 @@ def copy_from(self, path: str) -> Path: temp_file.write(chunk) return Path(temp_file.name) except requests.Timeout: - raise TimeoutError('Copy operation timed out') + raise TimeoutError("Copy operation timed out") def _is_port_in_use_docker(self, port): containers = self.docker_client.containers.list() diff --git a/openhands/runtime/impl/modal/modal_runtime.py b/openhands/runtime/impl/modal/modal_runtime.py index 0e598a437f41..a71021de8e9d 100644 --- a/openhands/runtime/impl/modal/modal_runtime.py +++ b/openhands/runtime/impl/modal/modal_runtime.py @@ -29,7 +29,7 @@ # Modal's log generator returns strings, but the upstream LogBuffer expects bytes. def bytes_shim(string_generator) -> Generator[bytes, None, None]: for line in string_generator: - yield line.encode('utf-8') + yield line.encode("utf-8") class ModalLogBuffer(LogBuffer): @@ -41,7 +41,7 @@ class ModalLogBuffer(LogBuffer): """ def __init__(self, sandbox: modal.Sandbox): - self.init_msg = 'Runtime client initialized.' + self.init_msg = "Runtime client initialized." self.buffer: list[str] = [] self.lock = threading.Lock() @@ -65,21 +65,21 @@ class ModalRuntime(EventStreamRuntime): env_vars (dict[str, str] | None, optional): Environment variables to set. Defaults to None. """ - container_name_prefix = 'openhands-sandbox-' + container_name_prefix = "openhands-sandbox-" sandbox: modal.Sandbox | None def __init__( self, config: AppConfig, event_stream: EventStream, - sid: str = 'default', + sid: str = "default", plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, status_callback: Callable | None = None, attach_to_existing: bool = False, ): - assert config.modal_api_token_id, 'Modal API token id is required' - assert config.modal_api_token_secret, 'Modal API token secret is required' + assert config.modal_api_token_id, "Modal API token id is required" + assert config.modal_api_token_secret, "Modal API token secret is required" self.config = config self.sandbox = None @@ -88,14 +88,14 @@ def __init__( config.modal_api_token_id, config.modal_api_token_secret ) self.app = modal.App.lookup( - 'openhands', create_if_missing=True, client=self.modal_client + "openhands", create_if_missing=True, client=self.modal_client ) # workspace_base cannot be used because we can't bind mount into a sandbox. if self.config.workspace_base is not None: self.log( - 'warning', - 'Setting workspace_base is not supported in the modal runtime.', + "warning", + "Setting workspace_base is not supported in the modal runtime.", ) # This value is arbitrary as it's private to the container @@ -112,8 +112,8 @@ def __init__( if self.config.sandbox.runtime_extra_deps: self.log( - 'debug', - f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}', + "debug", + f"Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}", ) self.init_base_runtime( @@ -127,9 +127,9 @@ def __init__( ) async def connect(self): - self.send_status_message('STATUS$STARTING_RUNTIME') + self.send_status_message("STATUS$STARTING_RUNTIME") - self.log('debug', f'ModalRuntime `{self.sid}`') + self.log("debug", f"ModalRuntime `{self.sid}`") self.image = self._get_image_definition( self.base_container_image_id, @@ -140,36 +140,36 @@ async def connect(self): if self.attach_to_existing: if self.sid in MODAL_RUNTIME_IDS: sandbox_id = MODAL_RUNTIME_IDS[self.sid] - self.log('debug', f'Attaching to existing Modal sandbox: {sandbox_id}') + self.log("debug", f"Attaching to existing Modal sandbox: {sandbox_id}") self.sandbox = modal.Sandbox.from_id( sandbox_id, client=self.modal_client ) else: - self.send_status_message('STATUS$PREPARING_CONTAINER') + self.send_status_message("STATUS$PREPARING_CONTAINER") await call_sync_from_async( self._init_sandbox, sandbox_workspace_dir=self.config.workspace_mount_path_in_sandbox, plugins=self.plugins, ) - self.send_status_message('STATUS$CONTAINER_STARTED') + self.send_status_message("STATUS$CONTAINER_STARTED") self.log_buffer = ModalLogBuffer(self.sandbox) if self.sandbox is None: - raise Exception('Sandbox not initialized') + raise Exception("Sandbox not initialized") tunnel = self.sandbox.tunnels()[self.container_port] self.api_url = tunnel.url - self.log('debug', f'Container started. Server url: {self.api_url}') + self.log("debug", f"Container started. Server url: {self.api_url}") if not self.attach_to_existing: - self.log('debug', 'Waiting for client to become ready...') - self.send_status_message('STATUS$WAITING_FOR_CLIENT') + self.log("debug", "Waiting for client to become ready...") + self.send_status_message("STATUS$WAITING_FOR_CLIENT") self._wait_until_alive() self.setup_initial_env() if not self.attach_to_existing: - self.send_status_message(' ') + self.send_status_message(" ") def _get_image_definition( self, @@ -189,15 +189,15 @@ def _get_image_definition( ) base_runtime_image = modal.Image.from_dockerfile( - path=os.path.join(build_folder, 'Dockerfile'), + path=os.path.join(build_folder, "Dockerfile"), context_mount=modal.Mount.from_local_dir( local_path=build_folder, - remote_path='.', # to current WORKDIR + remote_path=".", # to current WORKDIR ), ) else: raise ValueError( - 'Neither runtime container image nor base container image is set' + "Neither runtime container image nor base container image is set" ) return base_runtime_image.run_commands( @@ -219,43 +219,43 @@ def _init_sandbox( plugins: list[PluginRequirement] | None = None, ): try: - self.log('debug', 'Preparing to start container...') + self.log("debug", "Preparing to start container...") plugin_args = [] if plugins is not None and len(plugins) > 0: - plugin_args.append('--plugins') + plugin_args.append("--plugins") plugin_args.extend([plugin.name for plugin in plugins]) # Combine environment variables environment: dict[str, str | None] = { - 'port': str(self.container_port), - 'PYTHONUNBUFFERED': '1', + "port": str(self.container_port), + "PYTHONUNBUFFERED": "1", } if self.config.debug: - environment['DEBUG'] = 'true' + environment["DEBUG"] = "true" browsergym_args = [] if self.config.sandbox.browsergym_eval_env is not None: browsergym_args = [ - '-browsergym-eval-env', + "-browsergym-eval-env", self.config.sandbox.browsergym_eval_env, ] env_secret = modal.Secret.from_dict(environment) - self.log('debug', f'Sandbox workspace: {sandbox_workspace_dir}') + self.log("debug", f"Sandbox workspace: {sandbox_workspace_dir}") sandbox_start_cmd = get_remote_startup_command( self.container_port, sandbox_workspace_dir, - 'openhands' if self.config.run_as_openhands else 'root', + "openhands" if self.config.run_as_openhands else "root", self.config.sandbox.user_id, plugin_args, browsergym_args, ) - self.log('debug', f'Starting container with command: {sandbox_start_cmd}') + self.log("debug", f"Starting container with command: {sandbox_start_cmd}") self.sandbox = modal.Sandbox.create( *sandbox_start_cmd, secrets=[env_secret], - workdir='/openhands/code', + workdir="/openhands/code", encrypted_ports=[self.container_port], image=self.image, app=self.app, @@ -263,13 +263,13 @@ def _init_sandbox( timeout=60 * 60, ) MODAL_RUNTIME_IDS[self.sid] = self.sandbox.object_id - self.log('debug', 'Container started') + self.log("debug", "Container started") except Exception as e: self.log( - 'error', f'Error: Instance {self.sid} FAILED to start container!\n' + "error", f"Error: Instance {self.sid} FAILED to start container!\n" ) - self.log('error', str(e)) + self.log("error", str(e)) self.close() raise e diff --git a/openhands/runtime/impl/remote/remote_runtime.py b/openhands/runtime/impl/remote/remote_runtime.py index 97b16c1c83fa..3a26fc65f7df 100644 --- a/openhands/runtime/impl/remote/remote_runtime.py +++ b/openhands/runtime/impl/remote/remote_runtime.py @@ -52,7 +52,7 @@ def __init__( self, config: AppConfig, event_stream: EventStream, - sid: str = 'default', + sid: str = "default", plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, status_callback: Optional[Callable] = None, @@ -73,15 +73,15 @@ def __init__( ) if self.config.sandbox.api_key is None: raise ValueError( - 'API key is required to use the remote runtime. ' - 'Please set the API key in the config (config.toml) or as an environment variable (SANDBOX_API_KEY).' + "API key is required to use the remote runtime. " + "Please set the API key in the config (config.toml) or as an environment variable (SANDBOX_API_KEY)." ) - self.session.headers.update({'X-API-Key': self.config.sandbox.api_key}) + self.session.headers.update({"X-API-Key": self.config.sandbox.api_key}) if self.config.workspace_base is not None: self.log( - 'debug', - 'Setting workspace_base is not supported in the remote runtime.', + "debug", + "Setting workspace_base is not supported in the remote runtime.", ) self.runtime_builder = RemoteRuntimeBuilder( @@ -94,98 +94,98 @@ async def connect(self): try: await call_sync_from_async(self._start_or_attach_to_runtime) except RuntimeNotReadyError: - self.log('error', 'Runtime failed to start, timed out before ready') + self.log("error", "Runtime failed to start, timed out before ready") raise await call_sync_from_async(self.setup_initial_env) def _start_or_attach_to_runtime(self): existing_runtime = self._check_existing_runtime() if existing_runtime: - self.log('debug', f'Using existing runtime with ID: {self.runtime_id}') + self.log("debug", f"Using existing runtime with ID: {self.runtime_id}") elif self.attach_to_existing: - raise RuntimeError('Could not find existing runtime to attach to.') + raise RuntimeError("Could not find existing runtime to attach to.") else: - self.send_status_message('STATUS$STARTING_CONTAINER') + self.send_status_message("STATUS$STARTING_CONTAINER") if self.config.sandbox.runtime_container_image is None: self.log( - 'info', - f'Building remote runtime with base image: {self.config.sandbox.base_container_image}', + "info", + f"Building remote runtime with base image: {self.config.sandbox.base_container_image}", ) self._build_runtime() else: self.log( - 'info', - f'Starting remote runtime with image: {self.config.sandbox.runtime_container_image}', + "info", + f"Starting remote runtime with image: {self.config.sandbox.runtime_container_image}", ) self.container_image = self.config.sandbox.runtime_container_image self._start_runtime() assert ( self.runtime_id is not None - ), 'Runtime ID is not set. This should never happen.' + ), "Runtime ID is not set. This should never happen." assert ( self.runtime_url is not None - ), 'Runtime URL is not set. This should never happen.' - self.send_status_message('STATUS$WAITING_FOR_CLIENT') + ), "Runtime URL is not set. This should never happen." + self.send_status_message("STATUS$WAITING_FOR_CLIENT") if not self.attach_to_existing: - self.log('info', 'Waiting for runtime to be alive...') + self.log("info", "Waiting for runtime to be alive...") self._wait_until_alive() if not self.attach_to_existing: - self.log('info', 'Runtime is ready.') - self.send_status_message(' ') + self.log("info", "Runtime is ready.") + self.send_status_message(" ") def _check_existing_runtime(self) -> bool: try: response = self._send_request( - 'GET', - f'{self.config.sandbox.remote_runtime_api_url}/sessions/{self.sid}', + "GET", + f"{self.config.sandbox.remote_runtime_api_url}/sessions/{self.sid}", is_retry=False, timeout=5, ) except requests.HTTPError as e: if e.response.status_code == 404: return False - self.log('debug', f'Error while looking for remote runtime: {e}') + self.log("debug", f"Error while looking for remote runtime: {e}") raise data = response.json() - status = data.get('status') - if status == 'running': + status = data.get("status") + if status == "running": self._parse_runtime_response(response) return True - elif status == 'stopped': - self.log('debug', 'Found existing remote runtime, but it is stopped') + elif status == "stopped": + self.log("debug", "Found existing remote runtime, but it is stopped") return False - elif status == 'paused': - self.log('debug', 'Found existing remote runtime, but it is paused') + elif status == "paused": + self.log("debug", "Found existing remote runtime, but it is paused") self._parse_runtime_response(response) self._resume_runtime() return True else: - self.log('error', f'Invalid response from runtime API: {data}') + self.log("error", f"Invalid response from runtime API: {data}") return False def _build_runtime(self): - self.log('debug', f'Building RemoteRuntime config:\n{self.config}') + self.log("debug", f"Building RemoteRuntime config:\n{self.config}") response = self._send_request( - 'GET', - f'{self.config.sandbox.remote_runtime_api_url}/registry_prefix', + "GET", + f"{self.config.sandbox.remote_runtime_api_url}/registry_prefix", is_retry=False, timeout=10, ) response_json = response.json() - registry_prefix = response_json['registry_prefix'] - os.environ['OH_RUNTIME_RUNTIME_IMAGE_REPO'] = ( - registry_prefix.rstrip('/') + '/runtime' + registry_prefix = response_json["registry_prefix"] + os.environ["OH_RUNTIME_RUNTIME_IMAGE_REPO"] = ( + registry_prefix.rstrip("/") + "/runtime" ) self.log( - 'debug', + "debug", f'Runtime image repo: {os.environ["OH_RUNTIME_RUNTIME_IMAGE_REPO"]}', ) if self.config.sandbox.runtime_extra_deps: self.log( - 'debug', - f'Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}', + "debug", + f"Installing extra user-provided dependencies in the runtime image: {self.config.sandbox.runtime_extra_deps}", ) # Build the container image @@ -198,71 +198,71 @@ def _build_runtime(self): ) response = self._send_request( - 'GET', - f'{self.config.sandbox.remote_runtime_api_url}/image_exists', + "GET", + f"{self.config.sandbox.remote_runtime_api_url}/image_exists", is_retry=False, - params={'image': self.container_image}, + params={"image": self.container_image}, timeout=10, ) - if not response.json()['exists']: - raise RuntimeError(f'Container image {self.container_image} does not exist') + if not response.json()["exists"]: + raise RuntimeError(f"Container image {self.container_image} does not exist") def _start_runtime(self): # Prepare the request body for the /start endpoint plugin_args = [] if self.plugins is not None and len(self.plugins) > 0: - plugin_args = ['--plugins'] + [plugin.name for plugin in self.plugins] + plugin_args = ["--plugins"] + [plugin.name for plugin in self.plugins] browsergym_args = [] if self.config.sandbox.browsergym_eval_env is not None: browsergym_args = [ - '--browsergym-eval-env' - ] + self.config.sandbox.browsergym_eval_env.split(' ') + "--browsergym-eval-env" + ] + self.config.sandbox.browsergym_eval_env.split(" ") command = get_remote_startup_command( self.port, self.config.workspace_mount_path_in_sandbox, - 'openhands' if self.config.run_as_openhands else 'root', + "openhands" if self.config.run_as_openhands else "root", self.config.sandbox.user_id, plugin_args, browsergym_args, ) start_request = { - 'image': self.container_image, - 'command': command, - 'working_dir': '/openhands/code/', - 'environment': {'DEBUG': 'true'} if self.config.debug else {}, - 'session_id': self.sid, + "image": self.container_image, + "command": command, + "working_dir": "/openhands/code/", + "environment": {"DEBUG": "true"} if self.config.debug else {}, + "session_id": self.sid, } # Start the sandbox using the /start endpoint response = self._send_request( - 'POST', - f'{self.config.sandbox.remote_runtime_api_url}/start', + "POST", + f"{self.config.sandbox.remote_runtime_api_url}/start", is_retry=False, json=start_request, ) self._parse_runtime_response(response) self.log( - 'debug', - f'Runtime started. URL: {self.runtime_url}', + "debug", + f"Runtime started. URL: {self.runtime_url}", ) def _resume_runtime(self): self._send_request( - 'POST', - f'{self.config.sandbox.remote_runtime_api_url}/resume', + "POST", + f"{self.config.sandbox.remote_runtime_api_url}/resume", is_retry=False, - json={'runtime_id': self.runtime_id}, + json={"runtime_id": self.runtime_id}, timeout=30, ) - self.log('debug', 'Runtime resumed.') + self.log("debug", "Runtime resumed.") def _parse_runtime_response(self, response: requests.Response): start_response = response.json() - self.runtime_id = start_response['runtime_id'] - self.runtime_url = start_response['url'] - if 'session_api_key' in start_response: + self.runtime_id = start_response["runtime_id"] + self.runtime_url = start_response["url"] + if "session_api_key" in start_response: self.session.headers.update( - {'X-Session-API-Key': start_response['session_api_key']} + {"X-Session-API-Key": start_response["session_api_key"]} ) def _wait_until_alive(self): @@ -278,56 +278,56 @@ def _wait_until_alive(self): return retry_decorator(self._wait_until_alive_impl)() def _wait_until_alive_impl(self): - self.log('debug', f'Waiting for runtime to be alive at url: {self.runtime_url}') + self.log("debug", f"Waiting for runtime to be alive at url: {self.runtime_url}") runtime_info_response = self._send_request( - 'GET', - f'{self.config.sandbox.remote_runtime_api_url}/sessions/{self.sid}', + "GET", + f"{self.config.sandbox.remote_runtime_api_url}/sessions/{self.sid}", ) runtime_data = runtime_info_response.json() - assert 'runtime_id' in runtime_data - assert runtime_data['runtime_id'] == self.runtime_id - assert 'pod_status' in runtime_data - pod_status = runtime_data['pod_status'] - self.log('debug', f'Pod status: {pod_status}') + assert "runtime_id" in runtime_data + assert runtime_data["runtime_id"] == self.runtime_id + assert "pod_status" in runtime_data + pod_status = runtime_data["pod_status"] + self.log("debug", f"Pod status: {pod_status}") # FIXME: We should fix it at the backend of /start endpoint, make sure # the pod is created before returning the response. # Retry a period of time to give the cluster time to start the pod - if pod_status == 'Ready': + if pod_status == "Ready": try: self._send_request( - 'GET', - f'{self.runtime_url}/alive', + "GET", + f"{self.runtime_url}/alive", ) # will raise exception if we don't get 200 back. except requests.HTTPError as e: self.log( - 'warning', f"Runtime /alive failed, but pod says it's ready: {e}" + "warning", f"Runtime /alive failed, but pod says it's ready: {e}" ) raise RuntimeNotReadyError( - f'Runtime /alive failed to respond with 200: {e}' + f"Runtime /alive failed to respond with 200: {e}" ) return elif ( - pod_status == 'Not Found' - or pod_status == 'Pending' - or pod_status == 'Running' + pod_status == "Not Found" + or pod_status == "Pending" + or pod_status == "Running" ): # nb: Running is not yet Ready raise RuntimeNotReadyError( - f'Runtime (ID={self.runtime_id}) is not yet ready. Status: {pod_status}' + f"Runtime (ID={self.runtime_id}) is not yet ready. Status: {pod_status}" ) - elif pod_status in ('Failed', 'Unknown'): + elif pod_status in ("Failed", "Unknown"): # clean up the runtime self.close() raise RuntimeError( - f'Runtime (ID={self.runtime_id}) failed to start. Current status: {pod_status}' + f"Runtime (ID={self.runtime_id}) failed to start. Current status: {pod_status}" ) else: # Maybe this should be a hard failure, but passing through in case the API changes - self.log('warning', f'Unknown pod status: {pod_status}') + self.log("warning", f"Unknown pod status: {pod_status}") self.log( - 'debug', - f'Waiting for runtime pod to be active. Current status: {pod_status}', + "debug", + f"Waiting for runtime pod to be active. Current status: {pod_status}", ) raise RuntimeNotReadyError() @@ -338,19 +338,19 @@ def close(self, timeout: int = 10): if self.runtime_id and self.session: try: response = self._send_request( - 'POST', - f'{self.config.sandbox.remote_runtime_api_url}/stop', + "POST", + f"{self.config.sandbox.remote_runtime_api_url}/stop", is_retry=False, - json={'runtime_id': self.runtime_id}, + json={"runtime_id": self.runtime_id}, timeout=timeout, ) if response.status_code != 200: self.log( - 'error', - f'Failed to stop runtime: {response.text}', + "error", + f"Failed to stop runtime: {response.text}", ) else: - self.log('debug', 'Runtime stopped.') + self.log("debug", "Runtime stopped.") except Exception as e: raise e finally: @@ -363,24 +363,24 @@ def run_action(self, action: Action, is_retry: bool = False) -> Observation: return self.edit(action) with self.action_semaphore: if not action.runnable: - return NullObservation('') + return NullObservation("") action_type = action.action # type: ignore[attr-defined] if action_type not in ACTION_TYPE_TO_CLASS: - raise ValueError(f'Action {action_type} does not exist.') + raise ValueError(f"Action {action_type} does not exist.") if not hasattr(self, action_type): return ErrorObservation( - f'[Runtime (ID={self.runtime_id})] Action {action_type} is not supported in the current runtime.', - error_id='AGENT_ERROR$BAD_ACTION', + f"[Runtime (ID={self.runtime_id})] Action {action_type} is not supported in the current runtime.", + error_id="AGENT_ERROR$BAD_ACTION", ) assert action.timeout is not None try: - request_body = {'action': event_to_dict(action)} - self.log('debug', f'Request body: {request_body}') + request_body = {"action": event_to_dict(action)} + self.log("debug", f"Request body: {request_body}") response = self._send_request( - 'POST', - f'{self.runtime_url}/execute_action', + "POST", + f"{self.runtime_url}/execute_action", is_retry=False, json=request_body, # wait a few more seconds to get the timeout error from client side @@ -391,7 +391,7 @@ def run_action(self, action: Action, is_retry: bool = False) -> Observation: obs._cause = action.id # type: ignore[attr-defined] except requests.Timeout: raise RuntimeError( - f'Runtime failed to return execute_action before the requested timeout of {action.timeout}s' + f"Runtime failed to return execute_action before the requested timeout of {action.timeout}s" ) return obs @@ -400,16 +400,16 @@ def _send_request(self, method, url, is_retry=False, **kwargs): try: return send_request(self.session, method, url, **kwargs) except requests.Timeout: - self.log('error', 'No response received within the timeout period.') + self.log("error", "No response received within the timeout period.") raise except requests.HTTPError as e: if is_runtime_request and e.response.status_code == 404: raise RuntimeDisconnectedError( - f'404 error while connecting to {self.runtime_url}' + f"404 error while connecting to {self.runtime_url}" ) elif is_runtime_request and e.response.status_code == 503: if not is_retry: - self.log('warning', 'Runtime appears to be paused. Resuming...') + self.log("warning", "Runtime appears to be paused. Resuming...") self._resume_runtime() self._wait_until_alive() return self._send_request(method, url, True, **kwargs) @@ -441,16 +441,16 @@ def copy_to( self, host_src: str, sandbox_dest: str, recursive: bool = False ) -> None: if not os.path.exists(host_src): - raise FileNotFoundError(f'Source file {host_src} does not exist') + raise FileNotFoundError(f"Source file {host_src} does not exist") try: if recursive: with tempfile.NamedTemporaryFile( - suffix='.zip', delete=False + suffix=".zip", delete=False ) as temp_zip: temp_zip_path = temp_zip.name - with ZipFile(temp_zip_path, 'w') as zipf: + with ZipFile(temp_zip_path, "w") as zipf: for root, _, files in os.walk(host_src): for file in files: file_path = os.path.join(root, file) @@ -459,39 +459,39 @@ def copy_to( ) zipf.write(file_path, arcname) - upload_data = {'file': open(temp_zip_path, 'rb')} + upload_data = {"file": open(temp_zip_path, "rb")} else: - upload_data = {'file': open(host_src, 'rb')} + upload_data = {"file": open(host_src, "rb")} - params = {'destination': sandbox_dest, 'recursive': str(recursive).lower()} + params = {"destination": sandbox_dest, "recursive": str(recursive).lower()} response = self._send_request( - 'POST', - f'{self.runtime_url}/upload_file', + "POST", + f"{self.runtime_url}/upload_file", is_retry=False, files=upload_data, params=params, timeout=300, ) self.log( - 'debug', - f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}', + "debug", + f"Copy completed: host:{host_src} -> runtime:{sandbox_dest}. Response: {response.text}", ) finally: if recursive: os.unlink(temp_zip_path) self.log( - 'debug', f'Copy completed: host:{host_src} -> runtime:{sandbox_dest}' + "debug", f"Copy completed: host:{host_src} -> runtime:{sandbox_dest}" ) def list_files(self, path: str | None = None) -> list[str]: data = {} if path is not None: - data['path'] = path + data["path"] = path response = self._send_request( - 'POST', - f'{self.runtime_url}/list_files', + "POST", + f"{self.runtime_url}/list_files", is_retry=False, json=data, timeout=30, @@ -502,10 +502,10 @@ def list_files(self, path: str | None = None) -> list[str]: def copy_from(self, path: str) -> Path: """Zip all files in the sandbox and return as a stream of bytes.""" - params = {'path': path} + params = {"path": path} response = self._send_request( - 'GET', - f'{self.runtime_url}/download_files', + "GET", + f"{self.runtime_url}/download_files", is_retry=False, params=params, stream=True, diff --git a/openhands/runtime/impl/runloop/runloop_runtime.py b/openhands/runtime/impl/runloop/runloop_runtime.py index 36ad4590b7a5..7e052f2e73b6 100644 --- a/openhands/runtime/impl/runloop/runloop_runtime.py +++ b/openhands/runtime/impl/runloop/runloop_runtime.py @@ -21,7 +21,7 @@ from openhands.runtime.utils.request import send_request from openhands.utils.tenacity_stop import stop_if_should_exit -CONTAINER_NAME_PREFIX = 'openhands-runtime-' +CONTAINER_NAME_PREFIX = "openhands-runtime-" class RunloopLogBuffer(LogBuffer): @@ -34,7 +34,7 @@ class RunloopLogBuffer(LogBuffer): def __init__(self, runloop_api_client: Runloop, devbox_id: str): self.client_ready = False - self.init_msg = 'Runtime client initialized.' + self.init_msg = "Runtime client initialized." self.buffer: list[str] = [] self.lock = threading.Lock() @@ -52,7 +52,6 @@ def stream_logs(self): This method runs in its own thread to handle the blocking operation of reading log lines from the Docker SDK's synchronous generator. """ - try: # TODO(Runloop) Replace with stream while True: @@ -76,7 +75,7 @@ def stream_logs(self): time.sleep(1) except Exception as e: - logger.error(f'Error streaming runloop logs: {e}') + logger.error(f"Error streaming runloop logs: {e}") # NB: Match LogBuffer behavior on below methods @@ -104,13 +103,13 @@ def __init__( self, config: AppConfig, event_stream: EventStream, - sid: str = 'default', + sid: str = "default", plugins: list[PluginRequirement] | None = None, env_vars: dict[str, str] | None = None, status_callback: Callable | None = None, attach_to_existing: bool = False, ): - assert config.runloop_api_key is not None, 'Runloop API key is required' + assert config.runloop_api_key is not None, "Runloop API key is required" self.devbox: DevboxView | None = None self.config = config self.runloop_api_client = Runloop( @@ -137,15 +136,15 @@ def __init__( ) def _wait_for_devbox(self, devbox: DevboxView) -> DevboxView: """Pull devbox status until it is running""" - if devbox == 'running': + if devbox == "running": return devbox devbox = self.runloop_api_client.devboxes.retrieve(id=devbox.id) - if devbox.status != 'running': - raise ConnectionRefusedError('Devbox is not running') + if devbox.status != "running": + raise ConnectionRefusedError("Devbox is not running") # Devbox is connected and running - logging.debug(f'devbox.id={devbox.id} is running') + logging.debug(f"devbox.id={devbox.id} is running") return devbox def _create_new_devbox(self) -> DevboxView: @@ -153,13 +152,13 @@ def _create_new_devbox(self) -> DevboxView: sandbox_workspace_dir = self.config.workspace_mount_path_in_sandbox plugin_args = [] if self.plugins is not None and len(self.plugins) > 0: - plugin_args.append('--plugins') + plugin_args.append("--plugins") plugin_args.extend([plugin.name for plugin in self.plugins]) browsergym_args = [] if self.config.sandbox.browsergym_eval_env is not None: browsergym_args = [ - '-browsergym-eval-env', + "-browsergym-eval-env", self.config.sandbox.browsergym_eval_env, ] @@ -167,7 +166,7 @@ def _create_new_devbox(self) -> DevboxView: start_command = get_remote_startup_command( self._sandbox_port, sandbox_workspace_dir, - 'openhands' if self.config.run_as_openhands else 'root', + "openhands" if self.config.run_as_openhands else "root", self.config.sandbox.user_id, plugin_args, browsergym_args, @@ -177,33 +176,33 @@ def _create_new_devbox(self) -> DevboxView: # NB: start off as root, action_execution_server will ultimately choose user but expects all context # (ie browser) to be installed as root start_command = ( - 'export MAMBA_ROOT_PREFIX=/openhands/micromamba && ' - 'cd /openhands/code && ' - + '/openhands/micromamba/bin/micromamba run -n openhands poetry config virtualenvs.path /openhands/poetry && ' - + ' '.join(start_command) + "export MAMBA_ROOT_PREFIX=/openhands/micromamba && " + "cd /openhands/code && " + + "/openhands/micromamba/bin/micromamba run -n openhands poetry config virtualenvs.path /openhands/poetry && " + + " ".join(start_command) ) entrypoint = f"sudo bash -c '{start_command}'" devbox = self.runloop_api_client.devboxes.create( entrypoint=entrypoint, - setup_commands=[f'mkdir -p {self.config.workspace_mount_path_in_sandbox}'], + setup_commands=[f"mkdir -p {self.config.workspace_mount_path_in_sandbox}"], name=self.sid, - environment_variables={'DEBUG': 'true'} if self.config.debug else {}, - prebuilt='openhands', + environment_variables={"DEBUG": "true"} if self.config.debug else {}, + prebuilt="openhands", launch_parameters=LaunchParameters( available_ports=[self._sandbox_port], - resource_size_request='LARGE', + resource_size_request="LARGE", ), - metadata={'container-name': self.container_name}, + metadata={"container-name": self.container_name}, ) return self._wait_for_devbox(devbox) async def connect(self): - self.send_status_message('STATUS$STARTING_RUNTIME') + self.send_status_message("STATUS$STARTING_RUNTIME") if self.attach_to_existing: active_devboxes = self.runloop_api_client.devboxes.list( - status='running' + status="running" ).devboxes self.devbox = next( (devbox for devbox in active_devboxes if devbox.name == self.sid), None @@ -220,22 +219,22 @@ async def connect(self): # Hook up logs self.log_buffer = RunloopLogBuffer(self.runloop_api_client, self.devbox.id) - self.api_url = f'https://{tunnel.url}' - logger.info(f'Container started. Server url: {self.api_url}') + self.api_url = f"https://{tunnel.url}" + logger.info(f"Container started. Server url: {self.api_url}") # End Runloop connect # NOTE: Copied from EventStreamRuntime - logger.info('Waiting for client to become ready...') - self.send_status_message('STATUS$WAITING_FOR_CLIENT') + logger.info("Waiting for client to become ready...") + self.send_status_message("STATUS$WAITING_FOR_CLIENT") self._wait_until_alive() if not self.attach_to_existing: self.setup_initial_env() logger.info( - f'Container initialized with plugins: {[plugin.name for plugin in self.plugins]}' + f"Container initialized with plugins: {[plugin.name for plugin in self.plugins]}" ) - self.send_status_message(' ') + self.send_status_message(" ") @tenacity.retry( stop=tenacity.stop_after_delay(120) | stop_if_should_exit(), @@ -246,17 +245,17 @@ def _wait_until_alive(self): # NB(Runloop): Remote logs are not guaranteed realtime, removing client_ready check from logs self._refresh_logs() if not self.log_buffer: - raise RuntimeError('Runtime client is not ready.') + raise RuntimeError("Runtime client is not ready.") response = send_request( self.session, - 'GET', - f'{self.api_url}/alive', + "GET", + f"{self.api_url}/alive", timeout=5, ) if response.status_code == 200: return else: - msg = f'Action execution API is not alive. Response: {response}' + msg = f"Action execution API is not alive. Response: {response}" logger.error(msg) raise RuntimeError(msg) diff --git a/openhands/runtime/plugins/__init__.py b/openhands/runtime/plugins/__init__.py index 66bc499a112b..25bc9cf4cc97 100644 --- a/openhands/runtime/plugins/__init__.py +++ b/openhands/runtime/plugins/__init__.py @@ -7,15 +7,15 @@ from openhands.runtime.plugins.requirement import Plugin, PluginRequirement __all__ = [ - 'Plugin', - 'PluginRequirement', - 'AgentSkillsRequirement', - 'AgentSkillsPlugin', - 'JupyterRequirement', - 'JupyterPlugin', + "Plugin", + "PluginRequirement", + "AgentSkillsRequirement", + "AgentSkillsPlugin", + "JupyterRequirement", + "JupyterPlugin", ] ALL_PLUGINS = { - 'jupyter': JupyterPlugin, - 'agent_skills': AgentSkillsPlugin, + "jupyter": JupyterPlugin, + "agent_skills": AgentSkillsPlugin, } diff --git a/openhands/runtime/plugins/agent_skills/__init__.py b/openhands/runtime/plugins/agent_skills/__init__.py index 01f9d7e028ee..afaa0e3b74f1 100644 --- a/openhands/runtime/plugins/agent_skills/__init__.py +++ b/openhands/runtime/plugins/agent_skills/__init__.py @@ -6,9 +6,9 @@ @dataclass class AgentSkillsRequirement(PluginRequirement): - name: str = 'agent_skills' + name: str = "agent_skills" documentation: str = agentskills.DOCUMENTATION class AgentSkillsPlugin(Plugin): - name: str = 'agent_skills' + name: str = "agent_skills" diff --git a/openhands/runtime/plugins/agent_skills/agentskills.py b/openhands/runtime/plugins/agent_skills/agentskills.py index 046f8af20c61..b1d88cb470ee 100644 --- a/openhands/runtime/plugins/agent_skills/agentskills.py +++ b/openhands/runtime/plugins/agent_skills/agentskills.py @@ -11,21 +11,21 @@ ) __all__ = file_ops.__all__ + file_reader.__all__ -DOCUMENTATION = '' +DOCUMENTATION = "" for func_name in __all__: func = globals()[func_name] cur_doc = func.__doc__ # remove indentation from docstring and extra empty lines - cur_doc = '\n'.join(filter(None, map(lambda x: x.strip(), cur_doc.split('\n')))) + cur_doc = "\n".join(filter(None, map(lambda x: x.strip(), cur_doc.split("\n")))) # now add a consistent 4 indentation - cur_doc = '\n'.join(map(lambda x: ' ' * 4 + x, cur_doc.split('\n'))) + cur_doc = "\n".join(map(lambda x: " " * 4 + x, cur_doc.split("\n"))) - fn_signature = f'{func.__name__}' + str(signature(func)) - DOCUMENTATION += f'{fn_signature}:\n{cur_doc}\n\n' + fn_signature = f"{func.__name__}" + str(signature(func)) + DOCUMENTATION += f"{fn_signature}:\n{cur_doc}\n\n" # Add file_editor (a function) from openhands.runtime.plugins.agent_skills.file_editor import file_editor # noqa: E402 -__all__ += ['file_editor'] +__all__ += ["file_editor"] diff --git a/openhands/runtime/plugins/agent_skills/file_editor/__init__.py b/openhands/runtime/plugins/agent_skills/file_editor/__init__.py index 06d5bcca6325..8fdfd6761be6 100644 --- a/openhands/runtime/plugins/agent_skills/file_editor/__init__.py +++ b/openhands/runtime/plugins/agent_skills/file_editor/__init__.py @@ -5,4 +5,4 @@ from openhands_aci.editor import file_editor -__all__ = ['file_editor'] +__all__ = ["file_editor"] diff --git a/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py b/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py index b2e1b4c8aa4c..ada615f4ec13 100644 --- a/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py +++ b/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py @@ -21,15 +21,15 @@ WINDOW = 100 # This is also used in unit tests! -MSG_FILE_UPDATED = '[File updated (edited at line {line_number}). Please review the changes and make sure they are correct (correct indentation, no duplicate lines, etc). Edit the file again if necessary.]' -LINTER_ERROR_MSG = '[Your proposed edit has introduced new syntax error(s). Please understand the errors and retry your edit command.]\n' +MSG_FILE_UPDATED = "[File updated (edited at line {line_number}). Please review the changes and make sure they are correct (correct indentation, no duplicate lines, etc). Edit the file again if necessary.]" +LINTER_ERROR_MSG = "[Your proposed edit has introduced new syntax error(s). Please understand the errors and retry your edit command.]\n" # ================================================================================================== def _output_error(error_msg: str) -> bool: - print(f'ERROR: {error_msg}') + print(f"ERROR: {error_msg}") return False @@ -37,10 +37,10 @@ def _is_valid_filename(file_name) -> bool: if not file_name or not isinstance(file_name, str) or not file_name.strip(): return False invalid_chars = '<>:"/\\|?*' - if os.name == 'nt': # Windows + if os.name == "nt": # Windows invalid_chars = '<>:"/\\|?*' - elif os.name == 'posix': # Unix-like systems - invalid_chars = '\0' + elif os.name == "posix": # Unix-like systems + invalid_chars = "\0" for char in invalid_chars: if char in file_name: @@ -72,7 +72,7 @@ def _check_current_file(file_path: str | None = None) -> bool: if not file_path: file_path = CURRENT_FILE if not file_path or not os.path.isfile(file_path): - return _output_error('No file open. Use the open_file function first.') + return _output_error("No file open. Use the open_file function first.") return True @@ -93,8 +93,8 @@ def _lint_file(file_path: str) -> tuple[str | None, int | None]: # Linting successful. No issues found. return None, None first_error_line = lint_error[0].line if len(lint_error) > 0 else None - error_text = 'ERRORS:\n' + '\n'.join( - [f'{file_path}:{err.line}:{err.column}: {err.message}' for err in lint_error] + error_text = "ERRORS:\n" + "\n".join( + [f"{file_path}:{err.line}:{err.column}: {err.message}" for err in lint_error] ) return error_text, first_error_line @@ -108,8 +108,8 @@ def _print_window( content = file.read() # Ensure the content ends with a newline character - if not content.endswith('\n'): - content += '\n' + if not content.endswith("\n"): + content += "\n" lines = content.splitlines(True) # Keep all line ending characters total_lines = len(lines) @@ -132,22 +132,22 @@ def _print_window( if end == total_lines: start = max(1, end - window + 1) - output = '' + output = "" # only display this when there's at least one line above if start > 1: - output += f'({start - 1} more lines above)\n' + output += f"({start - 1} more lines above)\n" else: - output += '(this is the beginning of the file)\n' + output += "(this is the beginning of the file)\n" for i in range(start, end + 1): - _new_line = f'{i}|{lines[i-1]}' - if not _new_line.endswith('\n'): - _new_line += '\n' + _new_line = f"{i}|{lines[i-1]}" + if not _new_line.endswith("\n"): + _new_line += "\n" output += _new_line if end < total_lines: - output += f'({total_lines - end} more lines below)\n' + output += f"({total_lines - end} more lines below)\n" else: - output += '(this is the end of the file)\n' + output += "(this is the end of the file)\n" output = output.rstrip() if return_str: @@ -158,8 +158,8 @@ def _print_window( def _cur_file_header(current_file, total_lines) -> str: if not current_file: - return '' - return f'[File: {os.path.abspath(current_file)} ({total_lines} lines total)]\n' + return "" + return f"[File: {os.path.abspath(current_file)} ({total_lines} lines total)]\n" def open_file( @@ -177,7 +177,7 @@ def open_file( global CURRENT_FILE, CURRENT_LINE, WINDOW if not os.path.isfile(path): - _output_error(f'File {path} not found.') + _output_error(f"File {path} not found.") return CURRENT_FILE = os.path.abspath(path) @@ -185,7 +185,7 @@ def open_file( total_lines = max(1, sum(1 for _ in file)) if not isinstance(line_number, int) or line_number < 1 or line_number > total_lines: - _output_error(f'Line number must be between 1 and {total_lines}') + _output_error(f"Line number must be between 1 and {total_lines}") return CURRENT_LINE = line_number @@ -201,8 +201,8 @@ def open_file( return_str=True, ignore_window=False, ) - if output.strip().endswith('more lines below)'): - output += '\n[Use `scroll_down` to view the next 100 lines of the file!]' + if output.strip().endswith("more lines below)"): + output += "\n[Use `scroll_down` to view the next 100 lines of the file!]" print(output) @@ -218,7 +218,7 @@ def goto_line(line_number: int) -> None: with open(str(CURRENT_FILE)) as file: total_lines = max(1, sum(1 for _ in file)) if not isinstance(line_number, int) or line_number < 1 or line_number > total_lines: - _output_error(f'Line number must be between 1 and {total_lines}.') + _output_error(f"Line number must be between 1 and {total_lines}.") return CURRENT_LINE = _clamp(line_number, 1, total_lines) @@ -272,7 +272,7 @@ class LineNumberError(Exception): pass -def search_dir(search_term: str, dir_path: str = './') -> None: +def search_dir(search_term: str, dir_path: str = "./") -> None: """Searches for search_term in all files in dir. If dir is not provided, searches in the current directory. Args: @@ -280,15 +280,15 @@ def search_dir(search_term: str, dir_path: str = './') -> None: dir_path: str: The path to the directory to search. """ if not os.path.isdir(dir_path): - _output_error(f'Directory {dir_path} not found') + _output_error(f"Directory {dir_path} not found") return matches = [] for root, _, files in os.walk(dir_path): for file in files: - if file.startswith('.'): + if file.startswith("."): continue file_path = os.path.join(root, file) - with open(file_path, 'r', errors='ignore') as f: + with open(file_path, "r", errors="ignore") as f: for line_num, line in enumerate(f, 1): if search_term in line: matches.append((file_path, line_num, line.strip())) @@ -308,7 +308,7 @@ def search_dir(search_term: str, dir_path: str = './') -> None: print(f'[Found {num_matches} matches for "{search_term}" in {dir_path}]') for file_path, line_num, line in matches: - print(f'{file_path} (Line {line_num}): {line}') + print(f"{file_path} (Line {line_num}): {line}") print(f'[End of matches for "{search_term}" in {dir_path}]') @@ -323,10 +323,10 @@ def search_file(search_term: str, file_path: str | None = None) -> None: if file_path is None: file_path = CURRENT_FILE if file_path is None: - _output_error('No file specified or open. Use the open_file function first.') + _output_error("No file specified or open. Use the open_file function first.") return if not os.path.isfile(file_path): - _output_error(f'File {file_path} not found.') + _output_error(f"File {file_path} not found.") return matches = [] @@ -338,13 +338,13 @@ def search_file(search_term: str, file_path: str | None = None) -> None: if matches: print(f'[Found {len(matches)} matches for "{search_term}" in {file_path}]') for match in matches: - print(f'Line {match[0]}: {match[1]}') + print(f"Line {match[0]}: {match[1]}") print(f'[End of matches for "{search_term}" in {file_path}]') else: print(f'[No matches found for "{search_term}" in {file_path}]') -def find_file(file_name: str, dir_path: str = './') -> None: +def find_file(file_name: str, dir_path: str = "./") -> None: """Finds all files with the given name in the specified directory. Args: @@ -352,7 +352,7 @@ def find_file(file_name: str, dir_path: str = './') -> None: dir_path: str: The path to the directory to search. """ if not os.path.isdir(dir_path): - _output_error(f'Directory {dir_path} not found') + _output_error(f"Directory {dir_path} not found") return matches = [] @@ -364,18 +364,18 @@ def find_file(file_name: str, dir_path: str = './') -> None: if matches: print(f'[Found {len(matches)} matches for "{file_name}" in {dir_path}]') for match in matches: - print(f'{match}') + print(f"{match}") print(f'[End of matches for "{file_name}" in {dir_path}]') else: print(f'[No matches found for "{file_name}" in {dir_path}]') __all__ = [ - 'open_file', - 'goto_line', - 'scroll_down', - 'scroll_up', - 'search_dir', - 'search_file', - 'find_file', + "open_file", + "goto_line", + "scroll_down", + "scroll_up", + "search_dir", + "search_file", + "find_file", ] diff --git a/openhands/runtime/plugins/agent_skills/file_reader/file_readers.py b/openhands/runtime/plugins/agent_skills/file_reader/file_readers.py index ee41eab0e4bb..7f61ac910e9c 100644 --- a/openhands/runtime/plugins/agent_skills/file_reader/file_readers.py +++ b/openhands/runtime/plugins/agent_skills/file_reader/file_readers.py @@ -40,14 +40,14 @@ def parse_pdf(file_path: str) -> None: Args: file_path: str: The path to the file to open. """ - print(f'[Reading PDF file from {file_path}]') + print(f"[Reading PDF file from {file_path}]") content = PyPDF2.PdfReader(file_path) - text = '' + text = "" for page_idx in range(len(content.pages)): text += ( - f'@@ Page {page_idx + 1} @@\n' + f"@@ Page {page_idx + 1} @@\n" + content.pages[page_idx].extract_text() - + '\n\n' + + "\n\n" ) print(text.strip()) @@ -58,11 +58,11 @@ def parse_docx(file_path: str) -> None: Args: file_path: str: The path to the file to open. """ - print(f'[Reading DOCX file from {file_path}]') + print(f"[Reading DOCX file from {file_path}]") content = docx.Document(file_path) - text = '' + text = "" for i, para in enumerate(content.paragraphs): - text += f'@@ Page {i + 1} @@\n' + para.text + '\n\n' + text += f"@@ Page {i + 1} @@\n" + para.text + "\n\n" print(text) @@ -72,7 +72,7 @@ def parse_latex(file_path: str) -> None: Args: file_path: str: The path to the file to open. """ - print(f'[Reading LaTex file from {file_path}]') + print(f"[Reading LaTex file from {file_path}]") with open(file_path) as f: data = f.read() text = LatexNodes2Text().latex_to_text(data) @@ -80,8 +80,8 @@ def parse_latex(file_path: str) -> None: def _base64_img(file_path: str) -> str: - with open(file_path, 'rb') as image_file: - encoded_image = base64.b64encode(image_file.read()).decode('utf-8') + with open(file_path, "rb") as image_file: + encoded_image = base64.b64encode(image_file.read()).decode("utf-8") return encoded_image @@ -96,8 +96,8 @@ def _base64_video(file_path: str, frame_interval: int = 10) -> list[str]: if not success: break if frame_count % frame_interval == 0: - _, buffer = cv2.imencode('.jpg', frame) - base64_frames.append(base64.b64encode(buffer).decode('utf-8')) + _, buffer = cv2.imencode(".jpg", frame) + base64_frames.append(base64.b64encode(buffer).decode("utf-8")) frame_count += 1 video.release() return base64_frames @@ -106,40 +106,40 @@ def _base64_video(file_path: str, frame_interval: int = 10) -> list[str]: def _prepare_image_messages(task: str, base64_image: str): return [ { - 'role': 'user', - 'content': [ - {'type': 'text', 'text': task}, + "role": "user", + "content": [ + {"type": "text", "text": task}, { - 'type': 'image_url', - 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}, + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}, }, ], } ] -def parse_audio(file_path: str, model: str = 'whisper-1') -> None: +def parse_audio(file_path: str, model: str = "whisper-1") -> None: """Parses the content of an audio file and prints it. Args: file_path: str: The path to the audio file to transcribe. model: str: The audio model to use for transcription. Defaults to 'whisper-1'. """ - print(f'[Transcribing audio file from {file_path}]') + print(f"[Transcribing audio file from {file_path}]") try: # TODO: record the COST of the API call - with open(file_path, 'rb') as audio_file: + with open(file_path, "rb") as audio_file: transcript = _get_openai_client().audio.translations.create( model=model, file=audio_file ) print(transcript.text) except Exception as e: - print(f'Error transcribing audio file: {e}') + print(f"Error transcribing audio file: {e}") def parse_image( - file_path: str, task: str = 'Describe this image as detail as possible.' + file_path: str, task: str = "Describe this image as detail as possible." ) -> None: """Parses the content of an image file and prints the description. @@ -147,7 +147,7 @@ def parse_image( file_path: str: The path to the file to open. task: str: The task description for the API call. Defaults to 'Describe this image as detail as possible.'. """ - print(f'[Reading image file from {file_path}]') + print(f"[Reading image file from {file_path}]") # TODO: record the COST of the API call try: base64_image = _base64_img(file_path) @@ -160,12 +160,12 @@ def parse_image( print(content) except Exception as error: - print(f'Error with the request: {error}') + print(f"Error with the request: {error}") def parse_video( file_path: str, - task: str = 'Describe this image as detail as possible.', + task: str = "Describe this image as detail as possible.", frame_interval: int = 30, ) -> None: """Parses the content of an image file and prints the description. @@ -177,10 +177,10 @@ def parse_video( """ print( - f'[Processing video file from {file_path} with frame interval {frame_interval}]' + f"[Processing video file from {file_path} with frame interval {frame_interval}]" ) - task = task or 'This is one frame from a video, please summarize this frame.' + task = task or "This is one frame from a video, please summarize this frame." base64_frames = _base64_video(file_path) selected_frames = base64_frames[::frame_interval] @@ -188,12 +188,12 @@ def parse_video( new_interval = len(base64_frames) // 30 selected_frames = base64_frames[::new_interval] - print(f'Totally {len(selected_frames)} would be analyze...\n') + print(f"Totally {len(selected_frames)} would be analyze...\n") idx = 0 for base64_frame in selected_frames: idx += 1 - print(f'Process the {file_path}, current No. {idx * frame_interval} frame...') + print(f"Process the {file_path}, current No. {idx * frame_interval} frame...") # TODO: record the COST of the API call try: response = _get_openai_client().chat.completions.create( @@ -207,7 +207,7 @@ def parse_video( print(current_frame_content) except Exception as error: - print(f'Error with the request: {error}') + print(f"Error with the request: {error}") def parse_pptx(file_path: str) -> None: @@ -216,29 +216,29 @@ def parse_pptx(file_path: str) -> None: Args: file_path: str: The path to the file to open. """ - print(f'[Reading PowerPoint file from {file_path}]') + print(f"[Reading PowerPoint file from {file_path}]") try: pres = Presentation(str(file_path)) text = [] for slide_idx, slide in enumerate(pres.slides): - text.append(f'@@ Slide {slide_idx + 1} @@') + text.append(f"@@ Slide {slide_idx + 1} @@") for shape in slide.shapes: - if hasattr(shape, 'text'): + if hasattr(shape, "text"): text.append(shape.text) - print('\n'.join(text)) + print("\n".join(text)) except Exception as e: - print(f'Error reading PowerPoint file: {e}') + print(f"Error reading PowerPoint file: {e}") __all__ = [ - 'parse_pdf', - 'parse_docx', - 'parse_latex', - 'parse_pptx', + "parse_pdf", + "parse_docx", + "parse_latex", + "parse_pptx", ] # This is called from OpenHands's side # If SANDBOX_ENV_OPENAI_API_KEY is set, we will be able to use these tools in the sandbox environment if _get_openai_api_key() and _get_openai_base_url(): - __all__ += ['parse_audio', 'parse_video', 'parse_image'] + __all__ += ["parse_audio", "parse_video", "parse_image"] diff --git a/openhands/runtime/plugins/agent_skills/utils/config.py b/openhands/runtime/plugins/agent_skills/utils/config.py index f0084c540393..a0a9bc45895e 100644 --- a/openhands/runtime/plugins/agent_skills/utils/config.py +++ b/openhands/runtime/plugins/agent_skills/utils/config.py @@ -10,19 +10,19 @@ # AFTER the agentskills is imported (the case for EventStreamRuntime) # ================================================================================================== def _get_openai_api_key(): - return os.getenv('OPENAI_API_KEY', os.getenv('SANDBOX_ENV_OPENAI_API_KEY', '')) + return os.getenv("OPENAI_API_KEY", os.getenv("SANDBOX_ENV_OPENAI_API_KEY", "")) def _get_openai_base_url(): - return os.getenv('OPENAI_BASE_URL', 'https://api.openai.com/v1') + return os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1") def _get_openai_model(): - return os.getenv('OPENAI_MODEL', 'gpt-4o') + return os.getenv("OPENAI_MODEL", "gpt-4o") def _get_max_token(): - return os.getenv('MAX_TOKEN', 500) + return os.getenv("MAX_TOKEN", 500) def _get_openai_client(): diff --git a/openhands/runtime/plugins/agent_skills/utils/dependency.py b/openhands/runtime/plugins/agent_skills/utils/dependency.py index 1ff1636fb545..e0c2f52cee51 100644 --- a/openhands/runtime/plugins/agent_skills/utils/dependency.py +++ b/openhands/runtime/plugins/agent_skills/utils/dependency.py @@ -8,4 +8,4 @@ def import_functions( if hasattr(module, name): target_globals[name] = getattr(module, name) else: - raise ValueError(f'Function {name} not found in {module.__name__}') + raise ValueError(f"Function {name} not found in {module.__name__}") diff --git a/openhands/runtime/plugins/jupyter/__init__.py b/openhands/runtime/plugins/jupyter/__init__.py index dd9842830283..422115e947b2 100644 --- a/openhands/runtime/plugins/jupyter/__init__.py +++ b/openhands/runtime/plugins/jupyter/__init__.py @@ -13,46 +13,46 @@ @dataclass class JupyterRequirement(PluginRequirement): - name: str = 'jupyter' + name: str = "jupyter" class JupyterPlugin(Plugin): - name: str = 'jupyter' + name: str = "jupyter" - async def initialize(self, username: str, kernel_id: str = 'openhands-default'): + async def initialize(self, username: str, kernel_id: str = "openhands-default"): self.kernel_gateway_port = find_available_tcp_port(40000, 49999) self.kernel_id = kernel_id self.gateway_process = subprocess.Popen( ( f"su - {username} -s /bin/bash << 'EOF'\n" - 'cd /openhands/code\n' - 'export POETRY_VIRTUALENVS_PATH=/openhands/poetry;\n' - 'export PYTHONPATH=/openhands/code:$PYTHONPATH;\n' - 'export MAMBA_ROOT_PREFIX=/openhands/micromamba;\n' - '/openhands/micromamba/bin/micromamba run -n openhands ' - 'poetry run jupyter kernelgateway ' - '--KernelGatewayApp.ip=0.0.0.0 ' - f'--KernelGatewayApp.port={self.kernel_gateway_port}\n' - 'EOF' + "cd /openhands/code\n" + "export POETRY_VIRTUALENVS_PATH=/openhands/poetry;\n" + "export PYTHONPATH=/openhands/code:$PYTHONPATH;\n" + "export MAMBA_ROOT_PREFIX=/openhands/micromamba;\n" + "/openhands/micromamba/bin/micromamba run -n openhands " + "poetry run jupyter kernelgateway " + "--KernelGatewayApp.ip=0.0.0.0 " + f"--KernelGatewayApp.port={self.kernel_gateway_port}\n" + "EOF" ), stderr=subprocess.STDOUT, shell=True, ) # read stdout until the kernel gateway is ready - output = '' + output = "" while should_continue() and self.gateway_process.stdout is not None: - line = self.gateway_process.stdout.readline().decode('utf-8') + line = self.gateway_process.stdout.readline().decode("utf-8") output += line - if 'at' in line: + if "at" in line: break time.sleep(1) - logger.debug('Waiting for jupyter kernel gateway to start...') + logger.debug("Waiting for jupyter kernel gateway to start...") logger.debug( - f'Jupyter kernel gateway started at port {self.kernel_gateway_port}. Output: {output}' + f"Jupyter kernel gateway started at port {self.kernel_gateway_port}. Output: {output}" ) _obs = await self.run( - IPythonRunCellAction(code='import sys; print(sys.executable)') + IPythonRunCellAction(code="import sys; print(sys.executable)") ) self.python_interpreter_path = _obs.content.strip() @@ -60,12 +60,12 @@ async def _run(self, action: Action) -> IPythonRunCellObservation: """Internal method to run a code cell in the jupyter kernel.""" if not isinstance(action, IPythonRunCellAction): raise ValueError( - f'Jupyter plugin only supports IPythonRunCellAction, but got {action}' + f"Jupyter plugin only supports IPythonRunCellAction, but got {action}" ) - if not hasattr(self, 'kernel'): + if not hasattr(self, "kernel"): self.kernel = JupyterKernel( - f'localhost:{self.kernel_gateway_port}', self.kernel_id + f"localhost:{self.kernel_gateway_port}", self.kernel_id ) if not self.kernel.initialized: diff --git a/openhands/runtime/plugins/jupyter/execute_server.py b/openhands/runtime/plugins/jupyter/execute_server.py index da038d526686..54bef75fc59a 100644 --- a/openhands/runtime/plugins/jupyter/execute_server.py +++ b/openhands/runtime/plugins/jupyter/execute_server.py @@ -44,21 +44,21 @@ def strip_ansi(o: str) -> str: 'Lorem dolor sit ipsum' """ # pattern = re.compile(r'/(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]/') - pattern = re.compile(r'\x1B\[\d+(;\d+){0,2}m') - stripped = pattern.sub('', o) + pattern = re.compile(r"\x1B\[\d+(;\d+){0,2}m") + stripped = pattern.sub("", o) return stripped class JupyterKernel: - def __init__(self, url_suffix, convid, lang='python'): - self.base_url = f'http://{url_suffix}' - self.base_ws_url = f'ws://{url_suffix}' + def __init__(self, url_suffix, convid, lang="python"): + self.base_url = f"http://{url_suffix}" + self.base_ws_url = f"ws://{url_suffix}" self.lang = lang self.kernel_id = None self.ws = None self.convid = convid logging.info( - f'Jupyter kernel created for conversation {convid} at {url_suffix}' + f"Jupyter kernel created for conversation {convid} at {url_suffix}" ) self.heartbeat_interval = 10000 # 10 seconds @@ -66,14 +66,14 @@ def __init__(self, url_suffix, convid, lang='python'): self.initialized = False async def initialize(self): - await self.execute(r'%colors nocolor') + await self.execute(r"%colors nocolor") # pre-defined tools self.tools_to_run: list[str] = [ # TODO: You can add code for your pre-defined tools here ] for tool in self.tools_to_run: res = await self.execute(tool) - logging.info(f'Tool [{tool}] initialized:\n{res}') + logging.info(f"Tool [{tool}] initialized:\n{res}") self.initialized = True async def _send_heartbeat(self): @@ -88,7 +88,7 @@ async def _send_heartbeat(self): await self._connect() except ConnectionRefusedError: logging.info( - 'ConnectionRefusedError: Failed to reconnect to kernel websocket - Is the kernel still running?' + "ConnectionRefusedError: Failed to reconnect to kernel websocket - Is the kernel still running?" ) async def _connect(self): @@ -102,12 +102,12 @@ async def _connect(self): while n_tries > 0: try: response = await client.fetch( - '{}/api/kernels'.format(self.base_url), - method='POST', - body=json_encode({'name': self.lang}), + "{}/api/kernels".format(self.base_url), + method="POST", + body=json_encode({"name": self.lang}), ) kernel = json_decode(response.body) - self.kernel_id = kernel['id'] + self.kernel_id = kernel["id"] break except Exception: # kernels are not ready yet @@ -115,15 +115,15 @@ async def _connect(self): await asyncio.sleep(1) if n_tries == 0: - raise ConnectionRefusedError('Failed to connect to kernel') + raise ConnectionRefusedError("Failed to connect to kernel") ws_req = HTTPRequest( - url='{}/api/kernels/{}/channels'.format( + url="{}/api/kernels/{}/channels".format( self.base_ws_url, url_escape(self.kernel_id) ) ) self.ws = await websocket_connect(ws_req) - logging.info('Connected to kernel websocket') + logging.info("Connected to kernel websocket") # Setup heartbeat if self.heartbeat_callback: @@ -147,28 +147,28 @@ async def execute(self, code, timeout=120): res = await self.ws.write_message( json_encode( { - 'header': { - 'username': '', - 'version': '5.0', - 'session': '', - 'msg_id': msg_id, - 'msg_type': 'execute_request', + "header": { + "username": "", + "version": "5.0", + "session": "", + "msg_id": msg_id, + "msg_type": "execute_request", }, - 'parent_header': {}, - 'channel': 'shell', - 'content': { - 'code': code, - 'silent': False, - 'store_history': False, - 'user_expressions': {}, - 'allow_stdin': False, + "parent_header": {}, + "channel": "shell", + "content": { + "code": code, + "silent": False, + "store_history": False, + "user_expressions": {}, + "allow_stdin": False, }, - 'metadata': {}, - 'buffers': {}, + "metadata": {}, + "buffers": {}, } ) ) - logging.info(f'Executed code in jupyter kernel:\n{res}') + logging.info(f"Executed code in jupyter kernel:\n{res}") outputs = [] @@ -178,68 +178,68 @@ async def wait_for_messages(): assert self.ws is not None msg = await self.ws.read_message() msg = json_decode(msg) - msg_type = msg['msg_type'] - parent_msg_id = msg['parent_header'].get('msg_id', None) + msg_type = msg["msg_type"] + parent_msg_id = msg["parent_header"].get("msg_id", None) if parent_msg_id != msg_id: continue - if os.environ.get('DEBUG'): + if os.environ.get("DEBUG"): logging.info( f"MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg['content']}" ) - if msg_type == 'error': - traceback = '\n'.join(msg['content']['traceback']) + if msg_type == "error": + traceback = "\n".join(msg["content"]["traceback"]) outputs.append(traceback) execution_done = True - elif msg_type == 'stream': - outputs.append(msg['content']['text']) - elif msg_type in ['execute_result', 'display_data']: - outputs.append(msg['content']['data']['text/plain']) - if 'image/png' in msg['content']['data']: + elif msg_type == "stream": + outputs.append(msg["content"]["text"]) + elif msg_type in ["execute_result", "display_data"]: + outputs.append(msg["content"]["data"]["text/plain"]) + if "image/png" in msg["content"]["data"]: # use markdone to display image (in case of large image) outputs.append( f"\n![image](data:image/png;base64,{msg['content']['data']['image/png']})\n" ) - elif msg_type == 'execute_reply': + elif msg_type == "execute_reply": execution_done = True return execution_done async def interrupt_kernel(): client = AsyncHTTPClient() interrupt_response = await client.fetch( - f'{self.base_url}/api/kernels/{self.kernel_id}/interrupt', - method='POST', - body=json_encode({'kernel_id': self.kernel_id}), + f"{self.base_url}/api/kernels/{self.kernel_id}/interrupt", + method="POST", + body=json_encode({"kernel_id": self.kernel_id}), ) - logging.info(f'Kernel interrupted: {interrupt_response}') + logging.info(f"Kernel interrupted: {interrupt_response}") try: execution_done = await asyncio.wait_for(wait_for_messages(), timeout) except asyncio.TimeoutError: await interrupt_kernel() - return f'[Execution timed out ({timeout} seconds).]' + return f"[Execution timed out ({timeout} seconds).]" if not outputs and execution_done: - ret = '[Code executed successfully with no output]' + ret = "[Code executed successfully with no output]" else: - ret = ''.join(outputs) + ret = "".join(outputs) # Remove ANSI ret = strip_ansi(ret) - if os.environ.get('DEBUG'): - logging.info(f'OUTPUT:\n{ret}') + if os.environ.get("DEBUG"): + logging.info(f"OUTPUT:\n{ret}") return ret async def shutdown_async(self): if self.kernel_id: client = AsyncHTTPClient() await client.fetch( - '{}/api/kernels/{}'.format(self.base_url, self.kernel_id), - method='DELETE', + "{}/api/kernels/{}".format(self.base_url, self.kernel_id), + method="DELETE", ) self.kernel_id = None if self.ws: @@ -253,11 +253,11 @@ def initialize(self, jupyter_kernel): async def post(self): data = json_decode(self.request.body) - code = data.get('code') + code = data.get("code") if not code: self.set_status(400) - self.write('Missing code') + self.write("Missing code") return output = await self.jupyter_kernel.execute(code) @@ -268,18 +268,18 @@ async def post(self): def make_app(): jupyter_kernel = JupyterKernel( f"localhost:{os.environ.get('JUPYTER_GATEWAY_PORT')}", - os.environ.get('JUPYTER_GATEWAY_KERNEL_ID'), + os.environ.get("JUPYTER_GATEWAY_KERNEL_ID"), ) asyncio.get_event_loop().run_until_complete(jupyter_kernel.initialize()) return tornado.web.Application( [ - (r'/execute', ExecuteHandler, {'jupyter_kernel': jupyter_kernel}), + (r"/execute", ExecuteHandler, {"jupyter_kernel": jupyter_kernel}), ] ) -if __name__ == '__main__': +if __name__ == "__main__": app = make_app() - app.listen(os.environ.get('JUPYTER_EXEC_SERVER_PORT')) + app.listen(os.environ.get("JUPYTER_EXEC_SERVER_PORT")) tornado.ioloop.IOLoop.current().start() diff --git a/openhands/runtime/utils/__init__.py b/openhands/runtime/utils/__init__.py index 622a0b609573..fdb4411a530c 100644 --- a/openhands/runtime/utils/__init__.py +++ b/openhands/runtime/utils/__init__.py @@ -3,4 +3,4 @@ find_available_tcp_port, ) -__all__ = ['display_number_matrix', 'find_available_tcp_port'] +__all__ = ["display_number_matrix", "find_available_tcp_port"] diff --git a/openhands/runtime/utils/bash.py b/openhands/runtime/utils/bash.py index a5019315a038..c08aa4fe25db 100644 --- a/openhands/runtime/utils/bash.py +++ b/openhands/runtime/utils/bash.py @@ -17,15 +17,15 @@ def split_bash_commands(commands): if not commands.strip(): - return [''] + return [""] try: parsed = bashlex.parse(commands) except bashlex.errors.ParsingError as e: logger.debug( - f'Failed to parse bash commands\n' - f'[input]: {commands}\n' - f'[warning]: {e}\n' - f'The original command will be returned as is.' + f"Failed to parse bash commands\n" + f"[input]: {commands}\n" + f"[warning]: {e}\n" + f"The original command will be returned as is." ) # If parsing fails, return the original commands return [commands] @@ -39,7 +39,7 @@ def split_bash_commands(commands): # Include any text between the last command and this one if start > last_end: between = commands[last_end:start] - logger.debug(f'BASH PARSING between: {between}') + logger.debug(f"BASH PARSING between: {between}") if result: result[-1] += between.rstrip() elif between.strip(): @@ -48,21 +48,21 @@ def split_bash_commands(commands): # Extract the command, preserving original formatting command = commands[start:end].rstrip() - logger.debug(f'BASH PARSING command: {command}') + logger.debug(f"BASH PARSING command: {command}") result.append(command) last_end = end # Add any remaining text after the last command to the last command remaining = commands[last_end:].rstrip() - logger.debug(f'BASH PARSING remaining: {remaining}') + logger.debug(f"BASH PARSING remaining: {remaining}") if last_end < len(commands) and result: result[-1] += remaining - logger.debug(f'BASH PARSING result[-1] += remaining: {result[-1]}') + logger.debug(f"BASH PARSING result[-1] += remaining: {result[-1]}") elif last_end < len(commands): if remaining: result.append(remaining) - logger.debug(f'BASH PARSING result.append(remaining): {result[-1]}') + logger.debug(f"BASH PARSING result.append(remaining): {result[-1]}") return result @@ -73,9 +73,9 @@ def __init__(self, work_dir: str, username: str): self._pwd = work_dir self.shell = pexpect.spawn( - f'su {username}', - encoding='utf-8', - codec_errors='replace', + f"su {username}", + encoding="utf-8", + codec_errors="replace", echo=False, ) self._init_bash_shell(work_dir) @@ -93,23 +93,23 @@ def workdir(self): def _get_working_directory(self): # NOTE: this is part of initialization, so we hard code the timeout - result, exit_code = self._execute_bash('pwd', timeout=60, keep_prompt=False) + result, exit_code = self._execute_bash("pwd", timeout=60, keep_prompt=False) if exit_code != 0: raise RuntimeError( - f'Failed to get working directory (exit code: {exit_code}): {result}' + f"Failed to get working directory (exit code: {exit_code}): {result}" ) return result.strip() def _init_bash_shell(self, work_dir: str): self.__bash_PS1 = ( - r'[PEXPECT_BEGIN]\n' + r"[PEXPECT_BEGIN]\n" r'$(which python >/dev/null 2>&1 && echo "[Python Interpreter: $(which python)]\n")' - r'\u@\h:\w\n' - r'[PEXPECT_END]' + r"\u@\h:\w\n" + r"[PEXPECT_END]" ) # This should NOT match "PS1=\u@\h:\w [PEXPECT]$" when `env` is executed - self.__bash_expect_regex = r'\[PEXPECT_BEGIN\]\s*(.*?)\s*([a-z0-9_-]*)@([a-zA-Z0-9.-]*):(.+)\s*\[PEXPECT_END\]' + self.__bash_expect_regex = r"\[PEXPECT_BEGIN\]\s*(.*?)\s*([a-z0-9_-]*)@([a-zA-Z0-9.-]*):(.+)\s*\[PEXPECT_END\]" # Set umask to allow group write permissions self.shell.sendline(f'umask 002; export PS1="{self.__bash_PS1}"; export PS2=""') self.shell.expect(self.__bash_expect_regex) @@ -119,7 +119,7 @@ def _init_bash_shell(self, work_dir: str): ) self.shell.expect(self.__bash_expect_regex) logger.debug( - f'Bash initialized. Working directory: {work_dir}. Output: [{self.shell.before}]' + f"Bash initialized. Working directory: {work_dir}. Output: [{self.shell.before}]" ) # Ensure the group has write permissions on the working directory self.shell.sendline(f'chmod g+rw "{work_dir}"') @@ -128,17 +128,17 @@ def _init_bash_shell(self, work_dir: str): def _get_bash_prompt_and_update_pwd(self): ps1 = self.shell.after if ps1 == pexpect.EOF: - logger.error(f'Bash shell EOF! {self.shell.after=}, {self.shell.before=}') - raise RuntimeError('Bash shell EOF') + logger.error(f"Bash shell EOF! {self.shell.after=}, {self.shell.before=}") + raise RuntimeError("Bash shell EOF") if ps1 == pexpect.TIMEOUT: - logger.warning('Bash shell timeout') - return '' + logger.warning("Bash shell timeout") + return "" # begin at the last occurrence of '[PEXPECT_BEGIN]'. # In multi-line bash commands, the prompt will be repeated # and the matched regex captures all of them # - we only want the last one (newest prompt) - _begin_pos = ps1.rfind('[PEXPECT_BEGIN]') + _begin_pos = ps1.rfind("[PEXPECT_BEGIN]") if _begin_pos != -1: ps1 = ps1[_begin_pos:] @@ -146,19 +146,19 @@ def _get_bash_prompt_and_update_pwd(self): matched = re.match(self.__bash_expect_regex, ps1) assert ( matched is not None - ), f'Failed to parse bash prompt: {ps1}. This should not happen.' + ), f"Failed to parse bash prompt: {ps1}. This should not happen." other_info, username, hostname, working_dir = matched.groups() working_dir = working_dir.rstrip() self._pwd = os.path.expanduser(working_dir) # re-assemble the prompt # ignore the hostname AND use 'openhands-workspace' - prompt = f'{other_info.strip()}\n{username}@openhands-workspace:{working_dir} ' - if username == 'root': - prompt += '#' + prompt = f"{other_info.strip()}\n{username}@openhands-workspace:{working_dir} " + if username == "root": + prompt += "#" else: - prompt += '$' - return prompt + ' ' + prompt += "$" + return prompt + " " def _execute_bash( self, @@ -167,7 +167,7 @@ def _execute_bash( keep_prompt: bool = True, kill_on_timeout: bool = True, ) -> tuple[str, int]: - logger.debug(f'Executing command: {command}') + logger.debug(f"Executing command: {command}") self.shell.sendline(command) return self._continue_bash( timeout=timeout, keep_prompt=keep_prompt, kill_on_timeout=kill_on_timeout @@ -183,51 +183,51 @@ def _interrupt_bash( # try to interrupt the bash shell use SIGINT while max_retries > 0: self.shell.sendintr() # send SIGINT to the shell - logger.debug('Sent SIGINT to bash. Waiting for output...') + logger.debug("Sent SIGINT to bash. Waiting for output...") try: self.shell.expect(self.__bash_expect_regex, timeout=interrupt_timeout) output = self.shell.before - logger.debug(f'Received output after SIGINT: {output}') + logger.debug(f"Received output after SIGINT: {output}") exit_code = 130 # SIGINT - _additional_msg = '' + _additional_msg = "" if action_timeout is not None: _additional_msg = ( - f'Command timed out after {action_timeout} seconds. ' + f"Command timed out after {action_timeout} seconds. " ) output += ( - '\r\n\r\n' - + f'[{_additional_msg}SIGINT was sent to interrupt the command.]' + "\r\n\r\n" + + f"[{_additional_msg}SIGINT was sent to interrupt the command.]" ) return output, exit_code except pexpect.TIMEOUT as e: - logger.warning(f'Bash pexpect.TIMEOUT while waiting for SIGINT: {e}') + logger.warning(f"Bash pexpect.TIMEOUT while waiting for SIGINT: {e}") max_retries -= 1 # fall back to send control-z logger.error( - 'Failed to get output after SIGINT. Max retries reached. Sending control-z...' + "Failed to get output after SIGINT. Max retries reached. Sending control-z..." ) - self.shell.sendcontrol('z') + self.shell.sendcontrol("z") self.shell.expect(self.__bash_expect_regex) output = self.shell.before - logger.debug(f'Received output after control-z: {output}') + logger.debug(f"Received output after control-z: {output}") # Try to kill the job - self.shell.sendline('kill -9 %1') + self.shell.sendline("kill -9 %1") self.shell.expect(self.__bash_expect_regex) - logger.debug(f'Received output after killing job %1: {self.shell.before}') + logger.debug(f"Received output after killing job %1: {self.shell.before}") output += self.shell.before - _additional_msg = '' + _additional_msg = "" if action_timeout is not None: - _additional_msg = f'Command timed out after {action_timeout} seconds. ' + _additional_msg = f"Command timed out after {action_timeout} seconds. " output += ( - '\r\n\r\n' - + f'[{_additional_msg}SIGINT was sent to interrupt the command, but failed. The command was killed.]' + "\r\n\r\n" + + f"[{_additional_msg}SIGINT was sent to interrupt the command, but failed. The command was killed.]" ) # Try to get the exit code again - self.shell.sendline('echo $?') + self.shell.sendline("echo $?") self.shell.expect(self.__bash_expect_regex) _exit_code_output = self.shell.before exit_code = self._parse_exit_code(_exit_code_output) @@ -238,7 +238,7 @@ def _parse_exit_code(self, output: str) -> int: try: exit_code = int(output.strip().split()[0]) except Exception: - logger.error('Error getting exit code from bash script') + logger.error("Error getting exit code from bash script") # If we try to run an invalid shell script the output sometimes includes error text # rather than the error code - we assume this is an error exit_code = 2 @@ -250,47 +250,47 @@ def _continue_bash( keep_prompt: bool = True, kill_on_timeout: bool = True, ) -> tuple[str, int]: - logger.debug(f'Continuing bash with timeout={timeout}') + logger.debug(f"Continuing bash with timeout={timeout}") try: self.shell.expect(self.__bash_expect_regex, timeout=timeout) output = self.shell.before # Get exit code - self.shell.sendline('echo $?') - logger.debug('Requesting exit code...') + self.shell.sendline("echo $?") + logger.debug("Requesting exit code...") self.shell.expect(self.__bash_expect_regex, timeout=timeout) _exit_code_output = self.shell.before exit_code = self._parse_exit_code(_exit_code_output) except pexpect.TIMEOUT as e: - logger.warning(f'Bash pexpect.TIMEOUT while executing bash command: {e}') + logger.warning(f"Bash pexpect.TIMEOUT while executing bash command: {e}") if kill_on_timeout: output, exit_code = self._interrupt_bash(action_timeout=timeout) else: - output = self.shell.before or '' + output = self.shell.before or "" exit_code = -1 finally: bash_prompt = self._get_bash_prompt_and_update_pwd() if keep_prompt: - output += '\r\n' + bash_prompt + output += "\r\n" + bash_prompt return output, exit_code def run(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservation: try: assert ( action.timeout is not None - ), f'Timeout argument is required for CmdRunAction: {action}' + ), f"Timeout argument is required for CmdRunAction: {action}" commands = split_bash_commands(action.command) - all_output = '' - python_interpreter = '' + all_output = "" + python_interpreter = "" for command in commands: - if command == '': + if command == "": output, exit_code = self._continue_bash( timeout=SOFT_TIMEOUT_SECONDS, keep_prompt=action.keep_prompt, kill_on_timeout=False, ) - elif command.lower() == 'ctrl+c': + elif command.lower() == "ctrl+c": output, exit_code = self._interrupt_bash( action_timeout=None, # intentionally None ) @@ -305,24 +305,24 @@ def run(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservation: ) # Get rid of the python interpreter string from each line of the output. # We need it only once at the end. - parts = output.rsplit('[Python Interpreter: ', 1) + parts = output.rsplit("[Python Interpreter: ", 1) output = parts[0] if len(parts) == 2: - python_interpreter = '[Python Interpreter: ' + parts[1] + python_interpreter = "[Python Interpreter: " + parts[1] if all_output: # previous output already exists so we add a newline - all_output += '\r\n' + all_output += "\r\n" # If the command originated with the agent, append the command that was run... if action.source == EventSource.AGENT: - all_output += command + '\r\n' + all_output += command + "\r\n" all_output += str(output) if exit_code != 0: break return CmdOutputObservation( command_id=-1, - content=all_output.rstrip('\r\n'), + content=all_output.rstrip("\r\n"), command=action.command, hidden=action.hidden, exit_code=exit_code, @@ -330,5 +330,5 @@ def run(self, action: CmdRunAction) -> CmdOutputObservation | ErrorObservation: ) except UnicodeDecodeError as e: return ErrorObservation( - f'Runtime bash execution failed: Command output could not be decoded as utf-8. {str(e)}', + f"Runtime bash execution failed: Command output could not be decoded as utf-8. {str(e)}", ) diff --git a/openhands/runtime/utils/command.py b/openhands/runtime/utils/command.py index 1617ec20f36f..b46486de5c60 100644 --- a/openhands/runtime/utils/command.py +++ b/openhands/runtime/utils/command.py @@ -7,23 +7,23 @@ def get_remote_startup_command( browsergym_args: list[str], ): return [ - '/openhands/micromamba/bin/micromamba', - 'run', - '-n', - 'openhands', - 'poetry', - 'run', - 'python', - '-u', - '-m', - 'openhands.runtime.action_execution_server', + "/openhands/micromamba/bin/micromamba", + "run", + "-n", + "openhands", + "poetry", + "run", + "python", + "-u", + "-m", + "openhands.runtime.action_execution_server", str(port), - '--working-dir', + "--working-dir", sandbox_workspace_dir, *plugin_args, - '--username', + "--username", username, - '--user-id', + "--user-id", str(user_id), *browsergym_args, ] diff --git a/openhands/runtime/utils/edit.py b/openhands/runtime/utils/edit.py index cd3ffd0b71ce..1e46c571dfe0 100644 --- a/openhands/runtime/utils/edit.py +++ b/openhands/runtime/utils/edit.py @@ -52,7 +52,7 @@ def _extract_code(string): - pattern = r'```(?:\w*\n)?(.*?)```' + pattern = r"```(?:\w*\n)?(.*?)```" matches = re.findall(pattern, string, re.DOTALL) if not matches: return None @@ -64,16 +64,16 @@ def get_new_file_contents( ) -> str | None: while num_retries > 0: messages = [ - {'role': 'system', 'content': SYS_MSG}, + {"role": "system", "content": SYS_MSG}, { - 'role': 'user', - 'content': USER_MSG.format( + "role": "user", + "content": USER_MSG.format( old_contents=old_contents, draft_changes=draft_changes ), }, ] resp = llm.completion(messages=messages) - new_contents = _extract_code(resp['choices'][0]['message']['content']) + new_contents = _extract_code(resp["choices"][0]["message"]["content"]) if new_contents is not None: return new_contents num_retries -= 1 @@ -107,18 +107,18 @@ def __init__(self, *args, **kwargs): # manually set the model name for the draft editor LLM to distinguish token costs llm_metrics = Metrics( - model_name='draft_editor:' + llm_config.draft_editor.model + model_name="draft_editor:" + llm_config.draft_editor.model ) if llm_config.draft_editor.caching_prompt: logger.debug( - 'It is not recommended to cache draft editor LLM prompts as it may incur high costs for the same prompt. ' - 'Automatically setting caching_prompt=false.' + "It is not recommended to cache draft editor LLM prompts as it may incur high costs for the same prompt. " + "Automatically setting caching_prompt=false." ) llm_config.draft_editor.caching_prompt = False self.draft_editor_llm = LLM(llm_config.draft_editor, metrics=llm_metrics) logger.debug( - f'[Draft edit functionality] enabled with LLM: {self.draft_editor_llm}' + f"[Draft edit functionality] enabled with LLM: {self.draft_editor_llm}" ) def _validate_range( @@ -131,7 +131,7 @@ def _validate_range( or (start > end and end != -1 and start != -1) ): return ErrorObservation( - f'Invalid range for editing: start={start}, end={end}, total lines={total_lines}. start must be >= 1 and <={total_lines} (total lines of the edited file), start <= end, or start == -1 (append to the end of the file).' + f"Invalid range for editing: start={start}, end={end}, total lines={total_lines}. start must be >= 1 and <={total_lines} (total lines of the edited file), start <= end, or start == -1 (append to the end of the file)." ) if ( (end < 1 and end != -1) @@ -139,7 +139,7 @@ def _validate_range( or (end < start and start != -1 and end != -1) ): return ErrorObservation( - f'Invalid range for editing: start={start}, end={end}, total lines={total_lines}. end must be >= 1 and <= {total_lines} (total lines of the edited file), end >= start, or end == -1 (to edit till the end of the file).' + f"Invalid range for editing: start={start}, end={end}, total lines={total_lines}. end must be >= 1 and <= {total_lines} (total lines of the edited file), end >= start, or end == -1 (to edit till the end of the file)." ) return None @@ -154,9 +154,9 @@ def _get_lint_error( linter = DefaultLinter() # Copy the original file to a temporary file (with the same ext) and lint it with tempfile.NamedTemporaryFile( - suffix=suffix, mode='w+', encoding='utf-8' + suffix=suffix, mode="w+", encoding="utf-8" ) as original_file_copy, tempfile.NamedTemporaryFile( - suffix=suffix, mode='w+', encoding='utf-8' + suffix=suffix, mode="w+", encoding="utf-8" ) as updated_file_copy: # Lint the original file original_file_copy.write(old_content) @@ -180,20 +180,20 @@ def _get_lint_error( ) error_message = ( ( - f'\n[Linting failed for edited file {filepath}. {len(updated_lint_error)} lint errors found.]\n' - '[begin attempted changes]\n' - f'{_obs.visualize_diff(change_applied=False)}\n' - '[end attempted changes]\n' + f"\n[Linting failed for edited file {filepath}. {len(updated_lint_error)} lint errors found.]\n" + "[begin attempted changes]\n" + f"{_obs.visualize_diff(change_applied=False)}\n" + "[end attempted changes]\n" ) - + '-' * 40 - + '\n' + + "-" * 40 + + "\n" ) - error_message += '-' * 20 + 'First 5 lint errors' + '-' * 20 + '\n' + error_message += "-" * 20 + "First 5 lint errors" + "-" * 20 + "\n" for i, lint_error in enumerate(updated_lint_error[:5]): - error_message += f'[begin lint error {i}]\n' - error_message += lint_error.visualize().strip() + '\n' - error_message += f'[end lint error {i}]\n' - error_message += '-' * 40 + '\n' + error_message += f"[begin lint error {i}]\n" + error_message += lint_error.visualize().strip() + "\n" + error_message += f"[end lint error {i}]\n" + error_message += "-" * 40 + "\n" return ErrorObservation(error_message) return None @@ -201,10 +201,10 @@ def edit(self, action: FileEditAction) -> Observation: obs = self.read(FileReadAction(path=action.path)) if ( isinstance(obs, ErrorObservation) - and 'File not found'.lower() in obs.content.lower() + and "File not found".lower() in obs.content.lower() ): logger.debug( - f'Agent attempted to edit a file that does not exist. Creating the file. Error msg: {obs.content}' + f"Agent attempted to edit a file that does not exist. Creating the file. Error msg: {obs.content}" ) # directly write the new content obs = self.write( @@ -214,22 +214,22 @@ def edit(self, action: FileEditAction) -> Observation: return obs if not isinstance(obs, FileWriteObservation): raise ValueError( - f'Expected FileWriteObservation, got {type(obs)}: {str(obs)}' + f"Expected FileWriteObservation, got {type(obs)}: {str(obs)}" ) return FileEditObservation( - content=get_diff('', action.content, action.path), + content=get_diff("", action.content, action.path), path=action.path, prev_exist=False, - old_content='', + old_content="", new_content=action.content, ) if not isinstance(obs, FileReadObservation): raise ValueError( - f'Expected FileReadObservation, got {type(obs)}: {str(obs)}' + f"Expected FileReadObservation, got {type(obs)}: {str(obs)}" ) original_file_content = obs.content - old_file_lines = original_file_content.split('\n') + old_file_lines = original_file_content.split("\n") # NOTE: start and end are 1-indexed start = action.start end = action.end @@ -240,7 +240,7 @@ def edit(self, action: FileEditAction) -> Observation: # append to the end of the file if start == -1: - updated_content = '\n'.join(old_file_lines + action.content.split('\n')) + updated_content = "\n".join(old_file_lines + action.content.split("\n")) diff = get_diff(original_file_content, updated_content, action.path) # Lint the updated content if self.config.sandbox.enable_auto_lint: @@ -279,9 +279,9 @@ def edit(self, action: FileEditAction) -> Observation: length_of_range = end_idx - start_idx if length_of_range > self.MAX_LINES_TO_EDIT + 1: error_msg = ( - f'[Edit error: The range of lines to edit is too long.]\n' - f'[The maximum number of lines allowed to edit at once is {self.MAX_LINES_TO_EDIT}. ' - f'Got (L{start_idx + 1}-L{end_idx}) {length_of_range} lines.]\n' # [start_idx, end_idx), so no need to + 1 + f"[Edit error: The range of lines to edit is too long.]\n" + f"[The maximum number of lines allowed to edit at once is {self.MAX_LINES_TO_EDIT}. " + f"Got (L{start_idx + 1}-L{end_idx}) {length_of_range} lines.]\n" # [start_idx, end_idx), so no need to + 1 ) # search for relevant ranges to hint the agent topk_chunks: list[Chunk] = get_top_k_chunk_matches( @@ -291,29 +291,29 @@ def edit(self, action: FileEditAction) -> Observation: max_chunk_size=20, # lines ) error_msg += ( - 'Here are some snippets that maybe relevant to the provided edit.\n' + "Here are some snippets that maybe relevant to the provided edit.\n" ) for i, chunk in enumerate(topk_chunks): - error_msg += f'[begin relevant snippet {i+1}. Line range: L{chunk.line_range[0]}-L{chunk.line_range[1]}. Similarity: {chunk.normalized_lcs}]\n' + error_msg += f"[begin relevant snippet {i+1}. Line range: L{chunk.line_range[0]}-L{chunk.line_range[1]}. Similarity: {chunk.normalized_lcs}]\n" error_msg += f'[Browse around it via `open_file("{action.path}", {(chunk.line_range[0] + chunk.line_range[1]) // 2})`]\n' - error_msg += chunk.visualize() + '\n' - error_msg += f'[end relevant snippet {i+1}]\n' - error_msg += '-' * 40 + '\n' + error_msg += chunk.visualize() + "\n" + error_msg += f"[end relevant snippet {i+1}]\n" + error_msg += "-" * 40 + "\n" - error_msg += 'Consider using `open_file` to explore around the relevant snippets if needed.\n' + error_msg += "Consider using `open_file` to explore around the relevant snippets if needed.\n" error_msg += f'**IMPORTANT**: Please REDUCE the range of edits to less than {self.MAX_LINES_TO_EDIT} lines by setting `start` and `end` in the edit action (e.g. ``). ' return ErrorObservation(error_msg) - content_to_edit = '\n'.join(old_file_lines[start_idx:end_idx]) + content_to_edit = "\n".join(old_file_lines[start_idx:end_idx]) self.draft_editor_llm.reset() _edited_content = get_new_file_contents( self.draft_editor_llm, content_to_edit, action.content ) if _edited_content is None: ret_err = ErrorObservation( - 'Failed to get new file contents. ' - 'Please try to reduce the number of edits and try again.' + "Failed to get new file contents. " + "Please try to reduce the number of edits and try again." ) ret_err.llm_metrics = self.draft_editor_llm.metrics return ret_err @@ -321,10 +321,10 @@ def edit(self, action: FileEditAction) -> Observation: # piece the updated content with the unchanged content updated_lines = ( old_file_lines[:start_idx] - + _edited_content.split('\n') + + _edited_content.split("\n") + old_file_lines[end_idx:] ) - updated_content = '\n'.join(updated_lines) + updated_content = "\n".join(updated_lines) diff = get_diff(original_file_content, updated_content, action.path) # Lint the updated content diff --git a/openhands/runtime/utils/files.py b/openhands/runtime/utils/files.py index b9664cafc45f..54145fc990ed 100644 --- a/openhands/runtime/utils/files.py +++ b/openhands/runtime/utils/files.py @@ -38,7 +38,7 @@ def resolve_path( # If the path is outside the workspace, deny it if not abs_path_in_sandbox.is_relative_to(workspace_mount_path_in_sandbox): - raise PermissionError(f'File access not permitted: {file_path}') + raise PermissionError(f"File access not permitted: {file_path}") # Get path relative to the root of the workspace inside the sandbox path_in_workspace = abs_path_in_sandbox.relative_to( @@ -81,15 +81,15 @@ async def read_file( ) try: - with open(whole_path, 'r', encoding='utf-8') as file: + with open(whole_path, "r", encoding="utf-8") as file: lines = read_lines(file.readlines(), start, end) except FileNotFoundError: - return ErrorObservation(f'File not found: {path}') + return ErrorObservation(f"File not found: {path}") except UnicodeDecodeError: - return ErrorObservation(f'File could not be decoded as utf-8: {path}') + return ErrorObservation(f"File could not be decoded as utf-8: {path}") except IsADirectoryError: - return ErrorObservation(f'Path is a directory: {path}. You can only read files') - code_view = ''.join(lines) + return ErrorObservation(f"Path is a directory: {path}. You can only read files") + code_view = "".join(lines) return FileReadObservation(path=path, content=code_view) @@ -97,9 +97,9 @@ def insert_lines( to_insert: list[str], original: list[str], start: int = 0, end: int = -1 ): """Insert the new content to the original content based on start and end""" - new_lines = [''] if start == 0 else original[:start] - new_lines += [i + '\n' for i in to_insert] - new_lines += [''] if end == -1 else original[end:] + new_lines = [""] if start == 0 else original[:start] + new_lines += [i + "\n" for i in to_insert] + new_lines += [""] if end == -1 else original[end:] return new_lines @@ -112,7 +112,7 @@ async def write_file( start=0, end=-1, ) -> Observation: - insert = content.split('\n') + insert = content.split("\n") try: whole_path = resolve_path( @@ -120,26 +120,26 @@ async def write_file( ) if not os.path.exists(os.path.dirname(whole_path)): os.makedirs(os.path.dirname(whole_path)) - mode = 'w' if not os.path.exists(whole_path) else 'r+' + mode = "w" if not os.path.exists(whole_path) else "r+" try: - with open(whole_path, mode, encoding='utf-8') as file: - if mode != 'w': + with open(whole_path, mode, encoding="utf-8") as file: + if mode != "w": all_lines = file.readlines() new_file = insert_lines(insert, all_lines, start, end) else: - new_file = [i + '\n' for i in insert] + new_file = [i + "\n" for i in insert] file.seek(0) file.writelines(new_file) file.truncate() except FileNotFoundError: - return ErrorObservation(f'File not found: {path}') + return ErrorObservation(f"File not found: {path}") except IsADirectoryError: return ErrorObservation( - f'Path is a directory: {path}. You can only write to files' + f"Path is a directory: {path}. You can only write to files" ) except UnicodeDecodeError: - return ErrorObservation(f'File could not be decoded as utf-8: {path}') + return ErrorObservation(f"File could not be decoded as utf-8: {path}") except PermissionError: - return ErrorObservation(f'Malformed paths not permitted: {path}') - return FileWriteObservation(content='', path=path) + return ErrorObservation(f"Malformed paths not permitted: {path}") + return FileWriteObservation(content="", path=path) diff --git a/openhands/runtime/utils/runtime_build.py b/openhands/runtime/utils/runtime_build.py index eab98befe538..2dc55e45a81d 100644 --- a/openhands/runtime/utils/runtime_build.py +++ b/openhands/runtime/utils/runtime_build.py @@ -19,13 +19,13 @@ class BuildFromImageType(Enum): - SCRATCH = 'scratch' # Slowest: Build from base image (no dependencies are reused) - VERSIONED = 'versioned' # Medium speed: Reuse the most recent image with the same base image & OH version (a lot of dependencies are already installed) - LOCK = 'lock' # Fastest: Reuse the most recent image with the exact SAME dependencies (lock files) + SCRATCH = "scratch" # Slowest: Build from base image (no dependencies are reused) + VERSIONED = "versioned" # Medium speed: Reuse the most recent image with the same base image & OH version (a lot of dependencies are already installed) + LOCK = "lock" # Fastest: Reuse the most recent image with the exact SAME dependencies (lock files) def get_runtime_image_repo(): - return os.getenv('OH_RUNTIME_RUNTIME_IMAGE_REPO', 'ghcr.io/all-hands-ai/runtime') + return os.getenv("OH_RUNTIME_RUNTIME_IMAGE_REPO", "ghcr.io/all-hands-ai/runtime") def _generate_dockerfile( @@ -45,16 +45,16 @@ def _generate_dockerfile( """ env = Environment( loader=FileSystemLoader( - searchpath=os.path.join(os.path.dirname(__file__), 'runtime_templates') + searchpath=os.path.join(os.path.dirname(__file__), "runtime_templates") ) ) - template = env.get_template('Dockerfile.j2') + template = env.get_template("Dockerfile.j2") dockerfile_content = template.render( base_image=base_image, build_from_scratch=build_from == BuildFromImageType.SCRATCH, build_from_versioned=build_from == BuildFromImageType.VERSIONED, - extra_deps=extra_deps if extra_deps is not None else '', + extra_deps=extra_deps if extra_deps is not None else "", ) return dockerfile_content @@ -68,36 +68,35 @@ def get_runtime_image_repo_and_tag(base_image: str) -> tuple[str, str]: Returns: - tuple[str, str]: The Docker repo and tag of the Docker image """ - if get_runtime_image_repo() in base_image: logger.debug( - f'The provided image [{base_image}] is already a valid runtime image.\n' - f'Will try to reuse it as is.' + f"The provided image [{base_image}] is already a valid runtime image.\n" + f"Will try to reuse it as is." ) - if ':' not in base_image: - base_image = base_image + ':latest' - repo, tag = base_image.split(':') + if ":" not in base_image: + base_image = base_image + ":latest" + repo, tag = base_image.split(":") return repo, tag else: - if ':' not in base_image: - base_image = base_image + ':latest' - [repo, tag] = base_image.split(':') + if ":" not in base_image: + base_image = base_image + ":latest" + [repo, tag] = base_image.split(":") # Hash the repo if it's too long if len(repo) > 32: repo_hash = hashlib.md5(repo[:-24].encode()).hexdigest()[:8] - repo = f'{repo_hash}_{repo[-24:]}' # Use 8 char hash + last 24 chars + repo = f"{repo_hash}_{repo[-24:]}" # Use 8 char hash + last 24 chars else: - repo = repo.replace('/', '_s_') + repo = repo.replace("/", "_s_") - new_tag = f'oh_v{oh_version}_image_{repo}_tag_{tag}' + new_tag = f"oh_v{oh_version}_image_{repo}_tag_{tag}" # if it's still too long, hash the entire image name if len(new_tag) > 128: - new_tag = f'oh_v{oh_version}_image_{hashlib.md5(new_tag.encode()).hexdigest()[:64]}' + new_tag = f"oh_v{oh_version}_image_{hashlib.md5(new_tag.encode()).hexdigest()[:64]}" logger.warning( - f'The new tag [{new_tag}] is still too long, so we use an hash of the entire image name: {new_tag}' + f"The new tag [{new_tag}] is still too long, so we use an hash of the entire image name: {new_tag}" ) return get_runtime_image_repo(), new_tag @@ -164,19 +163,19 @@ def build_runtime_image_in_folder( platform: str | None = None, ) -> str: runtime_image_repo, _ = get_runtime_image_repo_and_tag(base_image) - lock_tag = f'oh_v{oh_version}_{get_hash_for_lock_files(base_image)}' + lock_tag = f"oh_v{oh_version}_{get_hash_for_lock_files(base_image)}" versioned_tag = ( # truncate the base image to 96 characters to fit in the tag max length (128 characters) - f'oh_v{oh_version}_{get_tag_for_versioned_image(base_image)}' + f"oh_v{oh_version}_{get_tag_for_versioned_image(base_image)}" ) - versioned_image_name = f'{runtime_image_repo}:{versioned_tag}' - source_tag = f'{lock_tag}_{get_hash_for_source_files()}' - hash_image_name = f'{runtime_image_repo}:{source_tag}' + versioned_image_name = f"{runtime_image_repo}:{versioned_tag}" + source_tag = f"{lock_tag}_{get_hash_for_source_files()}" + hash_image_name = f"{runtime_image_repo}:{source_tag}" - logger.info(f'Building image: {hash_image_name}') + logger.info(f"Building image: {hash_image_name}") if force_rebuild: logger.debug( - f'Force rebuild: [{runtime_image_repo}:{source_tag}] from scratch.' + f"Force rebuild: [{runtime_image_repo}:{source_tag}] from scratch." ) prep_build_folder( build_folder, @@ -196,29 +195,29 @@ def build_runtime_image_in_folder( ) return hash_image_name - lock_image_name = f'{runtime_image_repo}:{lock_tag}' + lock_image_name = f"{runtime_image_repo}:{lock_tag}" build_from = BuildFromImageType.SCRATCH # If the exact image already exists, we do not need to build it if runtime_builder.image_exists(hash_image_name, False): - logger.debug(f'Reusing Image [{hash_image_name}]') + logger.debug(f"Reusing Image [{hash_image_name}]") return hash_image_name # We look for an existing image that shares the same lock_tag. If such an image exists, we # can use it as the base image for the build and just copy source files. This makes the build # much faster. if runtime_builder.image_exists(lock_image_name): - logger.debug(f'Build [{hash_image_name}] from lock image [{lock_image_name}]') + logger.debug(f"Build [{hash_image_name}] from lock image [{lock_image_name}]") build_from = BuildFromImageType.LOCK base_image = lock_image_name elif runtime_builder.image_exists(versioned_image_name): logger.info( - f'Build [{hash_image_name}] from versioned image [{versioned_image_name}]' + f"Build [{hash_image_name}] from versioned image [{versioned_image_name}]" ) build_from = BuildFromImageType.VERSIONED base_image = versioned_image_name else: - logger.debug(f'Build [{hash_image_name}] from scratch') + logger.debug(f"Build [{hash_image_name}] from scratch") prep_build_folder(build_folder, base_image, build_from, extra_deps) if not dry_run: @@ -249,26 +248,26 @@ def prep_build_folder( # If package is not found, build from source code openhands_source_dir = Path(openhands.__file__).parent project_root = openhands_source_dir.parent - logger.debug(f'Building source distribution using project root: {project_root}') + logger.debug(f"Building source distribution using project root: {project_root}") # Copy the 'openhands' directory (Source code) shutil.copytree( openhands_source_dir, - Path(build_folder, 'code', 'openhands'), + Path(build_folder, "code", "openhands"), ignore=shutil.ignore_patterns( - '.*/', - '__pycache__/', - '*.pyc', - '*.md', + ".*/", + "__pycache__/", + "*.pyc", + "*.md", ), ) # Copy pyproject.toml and poetry.lock files - for file in ['pyproject.toml', 'poetry.lock']: + for file in ["pyproject.toml", "poetry.lock"]: src = Path(openhands_source_dir, file) if not src.exists(): src = Path(project_root, file) - shutil.copy2(src, Path(build_folder, 'code', file)) + shutil.copy2(src, Path(build_folder, "code", file)) # Create a Dockerfile and write it to build_folder dockerfile_content = _generate_dockerfile( @@ -276,7 +275,7 @@ def prep_build_folder( build_from=build_from, extra_deps=extra_deps, ) - with open(Path(build_folder, 'Dockerfile'), 'w') as file: # type: ignore + with open(Path(build_folder, "Dockerfile"), "w") as file: # type: ignore file.write(dockerfile_content) # type: ignore @@ -290,19 +289,19 @@ def truncate_hash(hash: str) -> str: while value > 0 and len(result) < 16: value, remainder = divmod(value, len(_ALPHABET)) result.append(_ALPHABET[remainder]) - return ''.join(result) + return "".join(result) def get_hash_for_lock_files(base_image: str): openhands_source_dir = Path(openhands.__file__).parent md5 = hashlib.md5() md5.update(base_image.encode()) - for file in ['pyproject.toml', 'poetry.lock']: + for file in ["pyproject.toml", "poetry.lock"]: src = Path(openhands_source_dir, file) if not src.exists(): src = Path(openhands_source_dir.parent, file) - with open(src, 'rb') as f: - for chunk in iter(lambda: f.read(4096), b''): + with open(src, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): md5.update(chunk) # We get away with truncation because we want something that is unique # rather than something that is cryptographically secure @@ -311,18 +310,18 @@ def get_hash_for_lock_files(base_image: str): def get_tag_for_versioned_image(base_image: str): - return base_image.replace('/', '_s_').replace(':', '_t_').lower()[-96:] + return base_image.replace("/", "_s_").replace(":", "_t_").lower()[-96:] def get_hash_for_source_files(): openhands_source_dir = Path(openhands.__file__).parent dir_hash = dirhash( openhands_source_dir, - 'md5', + "md5", ignore=[ - '.*/', # hidden directories - '__pycache__/', - '*.pyc', + ".*/", # hidden directories + "__pycache__/", + "*.pyc", ], ) # We get away with truncation because we want something that is unique @@ -342,30 +341,30 @@ def _build_sandbox_image( ): """Build and tag the sandbox image. The image will be tagged with all tags that do not yet exist""" names = [ - f'{runtime_image_repo}:{source_tag}', - f'{runtime_image_repo}:{lock_tag}', + f"{runtime_image_repo}:{source_tag}", + f"{runtime_image_repo}:{lock_tag}", ] if versioned_tag is not None: - names.append(f'{runtime_image_repo}:{versioned_tag}') + names.append(f"{runtime_image_repo}:{versioned_tag}") names = [name for name in names if not runtime_builder.image_exists(name, False)] image_name = runtime_builder.build( path=str(build_folder), tags=names, platform=platform ) if not image_name: - raise RuntimeError(f'Build failed for image {names}') + raise RuntimeError(f"Build failed for image {names}") return image_name -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - '--base_image', type=str, default='nikolaik/python-nodejs:python3.12-nodejs22' + "--base_image", type=str, default="nikolaik/python-nodejs:python3.12-nodejs22" ) - parser.add_argument('--build_folder', type=str, default=None) - parser.add_argument('--force_rebuild', action='store_true', default=False) - parser.add_argument('--platform', type=str, default=None) + parser.add_argument("--build_folder", type=str, default=None) + parser.add_argument("--force_rebuild", action="store_true", default=False) + parser.add_argument("--platform", type=str, default=None) args = parser.parse_args() if args.build_folder is not None: @@ -375,16 +374,16 @@ def _build_sandbox_image( build_folder = args.build_folder assert os.path.exists( build_folder - ), f'Build folder {build_folder} does not exist' + ), f"Build folder {build_folder} does not exist" logger.debug( - f'Copying the source code and generating the Dockerfile in the build folder: {build_folder}' + f"Copying the source code and generating the Dockerfile in the build folder: {build_folder}" ) runtime_image_repo, runtime_image_tag = get_runtime_image_repo_and_tag( args.base_image ) logger.debug( - f'Runtime image repo: {runtime_image_repo} and runtime image tag: {runtime_image_tag}' + f"Runtime image repo: {runtime_image_repo} and runtime image tag: {runtime_image_tag}" ) with tempfile.TemporaryDirectory() as temp_dir: @@ -400,38 +399,38 @@ def _build_sandbox_image( ) _runtime_image_repo, runtime_image_source_tag = ( - runtime_image_hash_name.split(':') + runtime_image_hash_name.split(":") ) # Move contents of temp_dir to build_folder shutil.copytree(temp_dir, build_folder, dirs_exist_ok=True) logger.debug( - f'Build folder [{build_folder}] is ready: {os.listdir(build_folder)}' + f"Build folder [{build_folder}] is ready: {os.listdir(build_folder)}" ) # We now update the config.sh in the build_folder to contain the required values. This is used in the # containers/build.sh script which is called to actually build the Docker image - with open(os.path.join(build_folder, 'config.sh'), 'a') as file: + with open(os.path.join(build_folder, "config.sh"), "a") as file: file.write( ( - f'\n' - f'DOCKER_IMAGE_TAG={runtime_image_tag}\n' - f'DOCKER_IMAGE_SOURCE_TAG={runtime_image_source_tag}\n' + f"\n" + f"DOCKER_IMAGE_TAG={runtime_image_tag}\n" + f"DOCKER_IMAGE_SOURCE_TAG={runtime_image_source_tag}\n" ) ) logger.debug( - f'`config.sh` is updated with the image repo[{runtime_image_repo}] and tags [{runtime_image_tag}, {runtime_image_source_tag}]' + f"`config.sh` is updated with the image repo[{runtime_image_repo}] and tags [{runtime_image_tag}, {runtime_image_source_tag}]" ) logger.debug( - f'Dockerfile, source code and config.sh are ready in {build_folder}' + f"Dockerfile, source code and config.sh are ready in {build_folder}" ) else: # If a build_folder is not provided, after copying the required source code and dynamically creating the # Dockerfile, we actually build the Docker image - logger.debug('Building image in a temporary folder') + logger.debug("Building image in a temporary folder") docker_builder = DockerRuntimeBuilder(docker.from_env()) image_name = build_runtime_image( args.base_image, docker_builder, platform=args.platform ) - logger.debug(f'\nBuilt image: {image_name}\n') + logger.debug(f"\nBuilt image: {image_name}\n") diff --git a/openhands/runtime/utils/runtime_init.py b/openhands/runtime/utils/runtime_init.py index 9ebba67fcd31..6615db359a19 100644 --- a/openhands/runtime/utils/runtime_init.py +++ b/openhands/runtime/utils/runtime_init.py @@ -31,42 +31,41 @@ def init_user_and_working_directory( Returns: int | None: The user ID if it was updated, None otherwise. """ - # First create the working directory, independent of the user - logger.debug(f'Client working directory: {initial_pwd}') - command = f'umask 002; mkdir -p {initial_pwd}' + logger.debug(f"Client working directory: {initial_pwd}") + command = f"umask 002; mkdir -p {initial_pwd}" output = subprocess.run(command, shell=True, capture_output=True) out_str = output.stdout.decode() - command = f'chown -R {username}:root {initial_pwd}' + command = f"chown -R {username}:root {initial_pwd}" output = subprocess.run(command, shell=True, capture_output=True) out_str += output.stdout.decode() - command = f'chmod g+rw {initial_pwd}' + command = f"chmod g+rw {initial_pwd}" output = subprocess.run(command, shell=True, capture_output=True) out_str += output.stdout.decode() - logger.debug(f'Created working directory. Output: [{out_str}]') + logger.debug(f"Created working directory. Output: [{out_str}]") # Skip root since it is already created - if username == 'root': + if username == "root": return None # Check if the username already exists existing_user_id = -1 try: result = subprocess.run( - f'id -u {username}', shell=True, check=True, capture_output=True + f"id -u {username}", shell=True, check=True, capture_output=True ) existing_user_id = int(result.stdout.decode().strip()) # The user ID already exists, skip setup if existing_user_id == user_id: logger.debug( - f'User `{username}` already has the provided UID {user_id}. Skipping user setup.' + f"User `{username}` already has the provided UID {user_id}. Skipping user setup." ) else: logger.warning( - f'User `{username}` already exists with UID {existing_user_id}. Skipping user setup.' + f"User `{username}` already exists with UID {existing_user_id}. Skipping user setup." ) return existing_user_id return None @@ -74,30 +73,30 @@ def init_user_and_working_directory( # Returncode 1 indicates, that the user does not exist yet if e.returncode == 1: logger.debug( - f'User `{username}` does not exist. Proceeding with user creation.' + f"User `{username}` does not exist. Proceeding with user creation." ) else: - logger.error(f'Error checking user `{username}`, skipping setup:\n{e}\n') + logger.error(f"Error checking user `{username}`, skipping setup:\n{e}\n") raise # Add sudoer sudoer_line = r"echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers" output = subprocess.run(sudoer_line, shell=True, capture_output=True) if output.returncode != 0: - raise RuntimeError(f'Failed to add sudoer: {output.stderr.decode()}') - logger.debug(f'Added sudoer successfully. Output: [{output.stdout.decode()}]') + raise RuntimeError(f"Failed to add sudoer: {output.stderr.decode()}") + logger.debug(f"Added sudoer successfully. Output: [{output.stdout.decode()}]") command = ( - f'useradd -rm -d /home/{username} -s /bin/bash ' - f'-g root -G sudo -u {user_id} {username}' + f"useradd -rm -d /home/{username} -s /bin/bash " + f"-g root -G sudo -u {user_id} {username}" ) output = subprocess.run(command, shell=True, capture_output=True) if output.returncode == 0: logger.debug( - f'Added user `{username}` successfully with UID {user_id}. Output: [{output.stdout.decode()}]' + f"Added user `{username}` successfully with UID {user_id}. Output: [{output.stdout.decode()}]" ) else: raise RuntimeError( - f'Failed to create user `{username}` with UID {user_id}. Output: [{output.stderr.decode()}]' + f"Failed to create user `{username}` with UID {user_id}. Output: [{output.stderr.decode()}]" ) return None diff --git a/openhands/runtime/utils/shutdown_listener.py b/openhands/runtime/utils/shutdown_listener.py index 3aedd2672270..9d1adb5338f9 100644 --- a/openhands/runtime/utils/shutdown_listener.py +++ b/openhands/runtime/utils/shutdown_listener.py @@ -1,6 +1,4 @@ -""" -This module monitors the app for shutdown signals -""" +"""This module monitors the app for shutdown signals""" import asyncio import signal diff --git a/openhands/runtime/utils/system.py b/openhands/runtime/utils/system.py index 921a8bf94b06..9d5bfb3173ae 100644 --- a/openhands/runtime/utils/system.py +++ b/openhands/runtime/utils/system.py @@ -21,7 +21,7 @@ def find_available_tcp_port(min_port=30000, max_port=39999, max_attempts=10) -> for port in ports[:max_attempts]: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: - sock.bind(('localhost', port)) + sock.bind(("localhost", port)) return port except OSError: time.sleep(0.1) # Short delay to further reduce chance of collisions @@ -37,16 +37,16 @@ def display_number_matrix(number: int) -> str | None: # Define the matrix representation for each digit digits = { - '0': ['###', '# #', '# #', '# #', '###'], - '1': [' #', ' #', ' #', ' #', ' #'], - '2': ['###', ' #', '###', '# ', '###'], - '3': ['###', ' #', '###', ' #', '###'], - '4': ['# #', '# #', '###', ' #', ' #'], - '5': ['###', '# ', '###', ' #', '###'], - '6': ['###', '# ', '###', '# #', '###'], - '7': ['###', ' #', ' #', ' #', ' #'], - '8': ['###', '# #', '###', '# #', '###'], - '9': ['###', '# #', '###', ' #', '###'], + "0": ["###", "# #", "# #", "# #", "###"], + "1": [" #", " #", " #", " #", " #"], + "2": ["###", " #", "###", "# ", "###"], + "3": ["###", " #", "###", " #", "###"], + "4": ["# #", "# #", "###", " #", " #"], + "5": ["###", "# ", "###", " #", "###"], + "6": ["###", "# ", "###", "# #", "###"], + "7": ["###", " #", " #", " #", " #"], + "8": ["###", "# #", "###", "# #", "###"], + "9": ["###", "# #", "###", " #", "###"], } # alternatively, with leading zeros: num_str = f"{number:03d}" @@ -54,8 +54,8 @@ def display_number_matrix(number: int) -> str | None: result = [] for row in range(5): - line = ' '.join(digits[digit][row] for digit in num_str) + line = " ".join(digits[digit][row] for digit in num_str) result.append(line) - matrix_display = '\n'.join(result) - return f'\n{matrix_display}\n' + matrix_display = "\n".join(result) + return f"\n{matrix_display}\n" diff --git a/openhands/runtime/utils/tenacity_stop.py b/openhands/runtime/utils/tenacity_stop.py index 48fdead86647..e6dc01db4606 100644 --- a/openhands/runtime/utils/tenacity_stop.py +++ b/openhands/runtime/utils/tenacity_stop.py @@ -7,5 +7,5 @@ class stop_if_should_exit(stop_base): """Stop if the should_exit flag is set.""" - def __call__(self, retry_state: 'RetryCallState') -> bool: + def __call__(self, retry_state: "RetryCallState") -> bool: return should_exit() diff --git a/openhands/security/invariant/__init__.py b/openhands/security/invariant/__init__.py index 9445ef804a01..24c7709e3118 100644 --- a/openhands/security/invariant/__init__.py +++ b/openhands/security/invariant/__init__.py @@ -1,5 +1,5 @@ from openhands.security.invariant.analyzer import InvariantAnalyzer __all__ = [ - 'InvariantAnalyzer', + "InvariantAnalyzer", ] diff --git a/openhands/security/invariant/analyzer.py b/openhands/security/invariant/analyzer.py index 0ba13b4ecddf..2275ddd0f1f3 100644 --- a/openhands/security/invariant/analyzer.py +++ b/openhands/security/invariant/analyzer.py @@ -28,9 +28,9 @@ class InvariantAnalyzer(SecurityAnalyzer): trace: list[TraceElement] input: list[dict] - container_name: str = 'openhands-invariant-server' - image_name: str = 'ghcr.io/invariantlabs-ai/server:openhands' - api_host: str = 'http://localhost' + container_name: str = "openhands-invariant-server" + image_name: str = "ghcr.io/invariantlabs-ai/server:openhands" + api_host: str = "http://localhost" timeout: int = 180 settings: dict = {} @@ -52,16 +52,16 @@ def __init__( self.docker_client = docker.from_env() except Exception as ex: logger.exception( - 'Error creating Invariant Security Analyzer container. Please check that Docker is running or disable the Security Analyzer in settings.', + "Error creating Invariant Security Analyzer container. Please check that Docker is running or disable the Security Analyzer in settings.", exc_info=False, ) raise ex running_containers = self.docker_client.containers.list( - filters={'name': self.container_name} + filters={"name": self.container_name} ) if not running_containers: all_containers = self.docker_client.containers.list( - all=True, filters={'name': self.container_name} + all=True, filters={"name": self.container_name} ) if all_containers: self.container = all_containers[0] @@ -71,33 +71,33 @@ def __init__( self.container = self.docker_client.containers.run( self.image_name, name=self.container_name, - platform='linux/amd64', - ports={'8000/tcp': self.api_port}, + platform="linux/amd64", + ports={"8000/tcp": self.api_port}, detach=True, ) else: self.container = running_containers[0] elapsed = 0 - while self.container.status != 'running': + while self.container.status != "running": self.container = self.docker_client.containers.get(self.container_name) elapsed += 1 logger.debug( - f'waiting for container to start: {elapsed}, container status: {self.container.status}' + f"waiting for container to start: {elapsed}, container status: {self.container.status}" ) if elapsed > self.timeout: break self.api_port = int( - self.container.attrs['NetworkSettings']['Ports']['8000/tcp'][0]['HostPort'] + self.container.attrs["NetworkSettings"]["Ports"]["8000/tcp"][0]["HostPort"] ) - self.api_server = f'{self.api_host}:{self.api_port}' + self.api_server = f"{self.api_host}:{self.api_port}" self.client = InvariantClient(self.api_server, self.sid) if policy is None: policy, _ = self.client.Policy.get_template() if policy is None: - policy = '' + policy = "" self.monitor = self.client.Monitor.from_string(policy) async def close(self): @@ -109,15 +109,15 @@ async def log_event(self, event: Event) -> None: self.trace.extend(element) self.input.extend([e.model_dump(exclude_none=True) for e in element]) # type: ignore [call-overload] else: - logger.debug('Invariant skipping element: event') + logger.debug("Invariant skipping element: event") def get_risk(self, results: list[str]) -> ActionSecurityRisk: mapping = { - 'high': ActionSecurityRisk.HIGH, - 'medium': ActionSecurityRisk.MEDIUM, - 'low': ActionSecurityRisk.LOW, + "high": ActionSecurityRisk.HIGH, + "medium": ActionSecurityRisk.MEDIUM, + "low": ActionSecurityRisk.LOW, } - regex = r'(?<=risk=)\w+' + regex = r"(?<=risk=)\w+" risks = [] for result in results: m = re.search(regex, result) @@ -137,22 +137,22 @@ async def should_confirm(self, event: Event) -> bool: risk = event.security_risk # type: ignore [attr-defined] return ( risk is not None - and risk < self.settings.get('RISK_SEVERITY', ActionSecurityRisk.MEDIUM) - and hasattr(event, 'confirmation_state') + and risk < self.settings.get("RISK_SEVERITY", ActionSecurityRisk.MEDIUM) + and hasattr(event, "confirmation_state") and event.confirmation_state == ActionConfirmationStatus.AWAITING_CONFIRMATION ) async def confirm(self, event: Event) -> None: new_event = action_from_dict( - {'action': 'change_agent_state', 'args': {'agent_state': 'user_confirmed'}} + {"action": "change_agent_state", "args": {"agent_state": "user_confirmed"}} ) # we should confirm only on agent actions event_source = event.source if event.source else EventSource.AGENT await call_sync_from_async(self.event_stream.add_event, new_event, event_source) async def security_risk(self, event: Action) -> ActionSecurityRisk: - logger.debug('Calling security_risk on InvariantAnalyzer') + logger.debug("Calling security_risk on InvariantAnalyzer") new_elements = parse_element(self.trace, event) input = [e.model_dump(exclude_none=True) for e in new_elements] # type: ignore [call-overload] self.trace.extend(new_elements) @@ -160,7 +160,7 @@ async def security_risk(self, event: Action) -> ActionSecurityRisk: self.input.extend(input) risk = ActionSecurityRisk.UNKNOWN if err: - logger.warning(f'Error checking policy: {err}') + logger.warning(f"Error checking policy: {err}") return risk risk = self.get_risk(result) @@ -169,35 +169,35 @@ async def security_risk(self, event: Action) -> ActionSecurityRisk: ### Handle API requests async def handle_api_request(self, request: Request) -> Any: - path_parts = request.url.path.strip('/').split('/') + path_parts = request.url.path.strip("/").split("/") endpoint = path_parts[-1] # Get the last part of the path - if request.method == 'GET': - if endpoint == 'export-trace': + if request.method == "GET": + if endpoint == "export-trace": return await self.export_trace(request) - elif endpoint == 'policy': + elif endpoint == "policy": return await self.get_policy(request) - elif endpoint == 'settings': + elif endpoint == "settings": return await self.get_settings(request) - elif request.method == 'POST': - if endpoint == 'policy': + elif request.method == "POST": + if endpoint == "policy": return await self.update_policy(request) - elif endpoint == 'settings': + elif endpoint == "settings": return await self.update_settings(request) - raise HTTPException(status_code=405, detail='Method Not Allowed') + raise HTTPException(status_code=405, detail="Method Not Allowed") async def export_trace(self, request: Request) -> Any: return JSONResponse(content=self.input) async def get_policy(self, request: Request) -> Any: - return JSONResponse(content={'policy': self.monitor.policy}) + return JSONResponse(content={"policy": self.monitor.policy}) async def update_policy(self, request: Request) -> Any: data = await request.json() - policy = data.get('policy') + policy = data.get("policy") new_monitor = self.client.Monitor.from_string(policy) self.monitor = new_monitor - return JSONResponse(content={'policy': policy}) + return JSONResponse(content={"policy": policy}) async def get_settings(self, request: Request) -> Any: return JSONResponse(content=self.settings) diff --git a/openhands/security/invariant/client.py b/openhands/security/invariant/client.py index c41828745658..a9228f268f83 100644 --- a/openhands/security/invariant/client.py +++ b/openhands/security/invariant/client.py @@ -12,7 +12,7 @@ def __init__(self, server_url: str, session_id: str | None = None): self.server = server_url self.session_id, err = self._create_session(session_id) if err: - raise RuntimeError(f'Failed to create session: {err}') + raise RuntimeError(f"Failed to create session: {err}") self.Policy = self._Policy(self) self.Monitor = self._Monitor(self) @@ -24,12 +24,12 @@ def _create_session( try: if session_id: response = requests.get( - f'{self.server}/session/new?session_id={session_id}', timeout=60 + f"{self.server}/session/new?session_id={session_id}", timeout=60 ) else: - response = requests.get(f'{self.server}/session/new', timeout=60) + response = requests.get(f"{self.server}/session/new", timeout=60) response.raise_for_status() - return response.json().get('id'), None + return response.json().get("id"), None except (ConnectionError, Timeout): elapsed += 1 time.sleep(1) @@ -37,12 +37,12 @@ def _create_session( return None, http_err except Exception as err: return None, err - return None, ConnectionError('Connection timed out') + return None, ConnectionError("Connection timed out") def close_session(self) -> Union[None, Exception]: try: response = requests.delete( - f'{self.server}/session/?session_id={self.session_id}', timeout=60 + f"{self.server}/session/?session_id={self.session_id}", timeout=60 ) response.raise_for_status() except (ConnectionError, Timeout, HTTPError) as err: @@ -57,19 +57,19 @@ def __init__(self, invariant): def _create_policy(self, rule: str) -> tuple[str | None, Exception | None]: try: response = requests.post( - f'{self.server}/policy/new?session_id={self.session_id}', - json={'rule': rule}, + f"{self.server}/policy/new?session_id={self.session_id}", + json={"rule": rule}, timeout=60, ) response.raise_for_status() - return response.json().get('policy_id'), None + return response.json().get("policy_id"), None except (ConnectionError, Timeout, HTTPError) as err: return None, err def get_template(self) -> tuple[str | None, Exception | None]: try: response = requests.get( - f'{self.server}/policy/template', + f"{self.server}/policy/template", timeout=60, ) response.raise_for_status() @@ -87,8 +87,8 @@ def from_string(self, rule: str): def analyze(self, trace: list[dict]) -> Union[Any, Exception]: try: response = requests.post( - f'{self.server}/policy/{self.policy_id}/analyze?session_id={self.session_id}', - json={'trace': trace}, + f"{self.server}/policy/{self.policy_id}/analyze?session_id={self.session_id}", + json={"trace": trace}, timeout=60, ) response.raise_for_status() @@ -100,17 +100,17 @@ class _Monitor: def __init__(self, invariant): self.server = invariant.server self.session_id = invariant.session_id - self.policy = '' + self.policy = "" def _create_monitor(self, rule: str) -> tuple[str | None, Exception | None]: try: response = requests.post( - f'{self.server}/monitor/new?session_id={self.session_id}', - json={'rule': rule}, + f"{self.server}/monitor/new?session_id={self.session_id}", + json={"rule": rule}, timeout=60, ) response.raise_for_status() - return response.json().get('monitor_id'), None + return response.json().get("monitor_id"), None except (ConnectionError, Timeout, HTTPError) as err: return None, err @@ -127,8 +127,8 @@ def check( ) -> Union[Any, Exception]: try: response = requests.post( - f'{self.server}/monitor/{self.monitor_id}/check?session_id={self.session_id}', - json={'past_events': past_events, 'pending_events': pending_events}, + f"{self.server}/monitor/{self.monitor_id}/check?session_id={self.session_id}", + json={"past_events": past_events, "pending_events": pending_events}, timeout=60, ) response.raise_for_status() diff --git a/openhands/security/invariant/nodes.py b/openhands/security/invariant/nodes.py index 47410264743b..42d7b6a6011f 100644 --- a/openhands/security/invariant/nodes.py +++ b/openhands/security/invariant/nodes.py @@ -10,7 +10,7 @@ class LLM: class Event(BaseModel): metadata: dict | None = Field( - default_factory=dict, description='Metadata associated with the event' + default_factory=dict, description="Metadata associated with the event" ) @@ -32,9 +32,9 @@ class Message(Event): def __rich_repr__(self): # Print on separate line - yield 'role', self.role - yield 'content', self.content - yield 'tool_calls', self.tool_calls + yield "role", self.role + yield "content", self.content + yield "tool_calls", self.tool_calls class ToolOutput(Event): diff --git a/openhands/security/invariant/parser.py b/openhands/security/invariant/parser.py index dea128692442..b3d6e06ed167 100644 --- a/openhands/security/invariant/parser.py +++ b/openhands/security/invariant/parser.py @@ -26,7 +26,7 @@ def get_next_id(trace: list[TraceElement]) -> str: for i in range(1, len(used_ids) + 2): if str(i) not in used_ids: return str(i) - return '1' + return "1" def get_last_id( @@ -43,21 +43,21 @@ def parse_action(trace: list[TraceElement], action: Action) -> list[TraceElement inv_trace = [] # type: list[TraceElement] if type(action) == MessageAction: if action.source == EventSource.USER: - inv_trace.append(Message(role='user', content=action.content)) + inv_trace.append(Message(role="user", content=action.content)) else: - inv_trace.append(Message(role='assistant', content=action.content)) + inv_trace.append(Message(role="assistant", content=action.content)) elif type(action) in [NullAction, ChangeAgentStateAction]: pass - elif hasattr(action, 'action') and action.action is not None: + elif hasattr(action, "action") and action.action is not None: event_dict = event_to_dict(action) - args = event_dict.get('args', {}) - thought = args.pop('thought', None) + args = event_dict.get("args", {}) + thought = args.pop("thought", None) function = Function(name=action.action, arguments=args) if thought is not None: - inv_trace.append(Message(role='assistant', content=thought)) - inv_trace.append(ToolCall(id=next_id, type='function', function=function)) + inv_trace.append(Message(role="assistant", content=thought)) + inv_trace.append(ToolCall(id=next_id, type="function", function=function)) else: - logger.error(f'Unknown action type: {type(action)}') + logger.error(f"Unknown action type: {type(action)}") return inv_trace @@ -67,10 +67,10 @@ def parse_observation( last_id = get_last_id(trace) if type(obs) in [NullObservation, AgentStateChangedObservation]: return [] - elif hasattr(obs, 'content') and obs.content is not None: - return [ToolOutput(role='tool', content=obs.content, tool_call_id=last_id)] + elif hasattr(obs, "content") and obs.content is not None: + return [ToolOutput(role="tool", content=obs.content, tool_call_id=last_id)] else: - logger.error(f'Unknown observation type: {type(obs)}') + logger.error(f"Unknown observation type: {type(obs)}") return [] @@ -99,5 +99,5 @@ def add_action(self, action: Action): def add_observation(self, obs: Observation): self.trace.extend(parse_observation(self.trace, obs)) - def concatenate(self, other: 'InvariantState'): + def concatenate(self, other: "InvariantState"): self.trace.extend(other.trace) diff --git a/openhands/server/auth/__init__.py b/openhands/server/auth/__init__.py index 0fe3ddd8cc0c..ed33d8d618c7 100644 --- a/openhands/server/auth/__init__.py +++ b/openhands/server/auth/__init__.py @@ -1,3 +1,3 @@ from openhands.server.auth.auth import get_sid_from_token, sign_token -__all__ = ['get_sid_from_token', 'sign_token'] +__all__ = ["get_sid_from_token", "sign_token"] diff --git a/openhands/server/auth/auth.py b/openhands/server/auth/auth.py index d668650f5834..3547361ea369 100644 --- a/openhands/server/auth/auth.py +++ b/openhands/server/auth/auth.py @@ -15,19 +15,19 @@ def get_sid_from_token(token: str, jwt_secret: str) -> str: """ try: # Decode the JWT using the specified secret and algorithm - payload = jwt.decode(token, jwt_secret, algorithms=['HS256']) + payload = jwt.decode(token, jwt_secret, algorithms=["HS256"]) # Ensure the payload contains 'sid' - if 'sid' in payload: - return payload['sid'] + if "sid" in payload: + return payload["sid"] else: - logger.error('SID not found in token') - return '' + logger.error("SID not found in token") + return "" except InvalidTokenError: - logger.error('Invalid token') + logger.error("Invalid token") except Exception as e: - logger.exception('Unexpected error decoding token: %s', e) - return '' + logger.exception("Unexpected error decoding token: %s", e) + return "" def sign_token(payload: dict[str, object], jwt_secret: str) -> str: @@ -36,4 +36,4 @@ def sign_token(payload: dict[str, object], jwt_secret: str) -> str: # "sid": sid, # # "exp": datetime.now(timezone.utc) + timedelta(minutes=15), # } - return jwt.encode(payload, jwt_secret, algorithm='HS256') + return jwt.encode(payload, jwt_secret, algorithm="HS256") diff --git a/openhands/server/data_models/feedback.py b/openhands/server/data_models/feedback.py index 59f32008b520..54463e24169f 100644 --- a/openhands/server/data_models/feedback.py +++ b/openhands/server/data_models/feedback.py @@ -10,36 +10,36 @@ class FeedbackDataModel(BaseModel): version: str email: str - polarity: Literal['positive', 'negative'] + polarity: Literal["positive", "negative"] feedback: Literal[ - 'positive', 'negative' + "positive", "negative" ] # TODO: remove this, its here for backward compatibility - permissions: Literal['public', 'private'] + permissions: Literal["public", "private"] trajectory: Optional[list[dict[str, Any]]] -FEEDBACK_URL = 'https://share-od-trajectory-3u9bw9tx.uc.gateway.dev/share_od_trajectory' +FEEDBACK_URL = "https://share-od-trajectory-3u9bw9tx.uc.gateway.dev/share_od_trajectory" def store_feedback(feedback: FeedbackDataModel) -> dict[str, str]: # Start logging feedback.feedback = feedback.polarity display_feedback = feedback.model_dump() - if 'trajectory' in display_feedback: - display_feedback['trajectory'] = ( + if "trajectory" in display_feedback: + display_feedback["trajectory"] = ( f"elided [length: {len(display_feedback['trajectory'])}" ) - if 'token' in display_feedback: - display_feedback['token'] = 'elided' - logger.debug(f'Got feedback: {display_feedback}') + if "token" in display_feedback: + display_feedback["token"] = "elided" + logger.debug(f"Got feedback: {display_feedback}") # Start actual request response = requests.post( FEEDBACK_URL, - headers={'Content-Type': 'application/json'}, + headers={"Content-Type": "application/json"}, json=feedback.model_dump(), ) if response.status_code != 200: - raise ValueError(f'Failed to store feedback: {response.text}') + raise ValueError(f"Failed to store feedback: {response.text}") response_data = json.loads(response.text) - logger.debug(f'Stored feedback: {response.text}') + logger.debug(f"Stored feedback: {response.text}") return response_data diff --git a/openhands/server/listen.py b/openhands/server/listen.py index 3b4db2daddad..94c956a1459b 100644 --- a/openhands/server/listen.py +++ b/openhands/server/listen.py @@ -368,8 +368,7 @@ async def websocket_endpoint(websocket: WebSocket): @app.get('/api/options/models') async def get_litellm_models() -> list[str]: - """ - Get all models supported by LiteLLM. + """Get all models supported by LiteLLM. This function combines models from litellm and Bedrock, removing any error-prone Bedrock models. diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py index 218a949fca58..f8fbeebec7c1 100644 --- a/openhands/server/middleware.py +++ b/openhands/server/middleware.py @@ -6,8 +6,7 @@ class LocalhostCORSMiddleware(CORSMiddleware): - """ - Custom CORS middleware that allows any request from localhost/127.0.0.1 domains, + """Custom CORS middleware that allows any request from localhost/127.0.0.1 domains, while using standard CORS rules for other origins. """ @@ -28,9 +27,7 @@ def is_allowed_origin(self, origin: str) -> bool: class NoCacheMiddleware(BaseHTTPMiddleware): - """ - Middleware to disable caching for all routes by adding appropriate headers - """ + """Middleware to disable caching for all routes by adding appropriate headers""" async def dispatch(self, request, call_next): response = await call_next(request) diff --git a/openhands/server/mock/listen.py b/openhands/server/mock/listen.py index 9b9d1560e88b..650b653786b8 100644 --- a/openhands/server/mock/listen.py +++ b/openhands/server/mock/listen.py @@ -8,55 +8,55 @@ app = FastAPI() -@app.websocket('/ws') +@app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): await websocket.accept() # send message to mock connection await websocket.send_json( - {'action': ActionType.INIT, 'message': 'Control loop started.'} + {"action": ActionType.INIT, "message": "Control loop started."} ) try: while should_continue(): # receive message data = await websocket.receive_json() - logger.debug(f'Received message: {data}') + logger.debug(f"Received message: {data}") # send mock response to client - response = {'message': f'receive {data}'} + response = {"message": f"receive {data}"} await websocket.send_json(response) - logger.debug(f'Sent message: {response}') + logger.debug(f"Sent message: {response}") except Exception as e: - logger.debug(f'WebSocket Error: {e}') + logger.debug(f"WebSocket Error: {e}") -@app.get('/') +@app.get("/") def read_root(): - return {'message': 'This is a mock server'} + return {"message": "This is a mock server"} -@app.get('/api/options/models') +@app.get("/api/options/models") def read_llm_models(): return [ - 'gpt-4', - 'gpt-4-turbo-preview', - 'gpt-4-0314', - 'gpt-4-0613', + "gpt-4", + "gpt-4-turbo-preview", + "gpt-4-0314", + "gpt-4-0613", ] -@app.get('/api/options/agents') +@app.get("/api/options/agents") def read_llm_agents(): return [ - 'CodeActAgent', - 'PlannerAgent', + "CodeActAgent", + "PlannerAgent", ] -@app.get('/api/list-files') +@app.get("/api/list-files") def refresh_files(): - return ['hello_world.py'] + return ["hello_world.py"] -if __name__ == '__main__': - uvicorn.run(app, host='127.0.0.1', port=3000) +if __name__ == "__main__": + uvicorn.run(app, host="127.0.0.1", port=3000) diff --git a/openhands/server/session/__init__.py b/openhands/server/session/__init__.py index 3ee03d959461..0c6af2bdb38d 100644 --- a/openhands/server/session/__init__.py +++ b/openhands/server/session/__init__.py @@ -1,4 +1,4 @@ from openhands.server.session.manager import SessionManager from openhands.server.session.session import Session -__all__ = ['Session', 'SessionManager'] +__all__ = ["Session", "SessionManager"] diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 8e7376a7668b..8bb258015cc0 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -44,7 +44,6 @@ def __init__( - sid: The session ID - file_store: Instance of the FileStore """ - self.sid = sid self.event_stream = EventStream(sid, file_store) self.file_store = file_store @@ -72,7 +71,7 @@ async def start( """ if self.controller or self.runtime: raise RuntimeError( - 'Session already started. You need to close this session and start a new one.' + "Session already started. You need to close this session and start a new one." ) asyncio.get_event_loop().run_in_executor( @@ -91,8 +90,8 @@ def _start_thread(self, *args): try: asyncio.run(self._start(*args), debug=True) except RuntimeError: - logger.error(f'Error starting session: {RuntimeError}', exc_info=True) - logger.debug('Session Finished') + logger.error(f"Error starting session: {RuntimeError}", exc_info=True) + logger.debug("Session Finished") async def _start( self, @@ -157,9 +156,8 @@ def _create_security_analyzer(self, security_analyzer: str | None): Parameters: - security_analyzer: The name of the security analyzer to use """ - if security_analyzer: - logger.debug(f'Using security analyzer: {security_analyzer}') + logger.debug(f"Using security analyzer: {security_analyzer}") self.security_analyzer = options.SecurityAnalyzers.get( security_analyzer, SecurityAnalyzer )(self.event_stream) @@ -177,11 +175,10 @@ async def _create_runtime( - config: - agent: """ - if self.runtime is not None: - raise RuntimeError('Runtime already created') + raise RuntimeError("Runtime already created") - logger.debug(f'Initializing runtime `{runtime_name}` now...') + logger.debug(f"Initializing runtime `{runtime_name}` now...") runtime_cls = get_runtime_cls(runtime_name) self.runtime = runtime_cls( config=config, @@ -194,19 +191,19 @@ async def _create_runtime( try: await self.runtime.connect() except Exception as e: - logger.error(f'Runtime initialization failed: {e}', exc_info=True) + logger.error(f"Runtime initialization failed: {e}", exc_info=True) if self._status_callback: self._status_callback( - 'error', 'STATUS$ERROR_RUNTIME_DISCONNECTED', str(e) + "error", "STATUS$ERROR_RUNTIME_DISCONNECTED", str(e) ) raise if self.runtime is not None: logger.debug( - f'Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}' + f"Runtime initialized with plugins: {[plugin.name for plugin in self.runtime.plugins]}" ) else: - logger.warning('Runtime initialization failed') + logger.warning("Runtime initialization failed") def _create_controller( self, @@ -227,29 +224,28 @@ def _create_controller( - agent_to_llm_config: - agent_configs: """ - if self.controller is not None: - raise RuntimeError('Controller already created') + raise RuntimeError("Controller already created") if self.runtime is None: raise RuntimeError( - 'Runtime must be initialized before the agent controller' + "Runtime must be initialized before the agent controller" ) msg = ( - '\n--------------------------------- OpenHands Configuration ---------------------------------\n' - f'LLM: {agent.llm.config.model}\n' - f'Base URL: {agent.llm.config.base_url}\n' + "\n--------------------------------- OpenHands Configuration ---------------------------------\n" + f"LLM: {agent.llm.config.model}\n" + f"Base URL: {agent.llm.config.base_url}\n" ) if agent.llm.config.draft_editor: msg += ( - f'Draft editor LLM (for file editing): {agent.llm.config.draft_editor.model}\n' - f'Draft editor LLM (for file editing) Base URL: {agent.llm.config.draft_editor.base_url}\n' + f"Draft editor LLM (for file editing): {agent.llm.config.draft_editor.model}\n" + f"Draft editor LLM (for file editing) Base URL: {agent.llm.config.draft_editor.base_url}\n" ) msg += ( - f'Agent: {agent.name}\n' - f'Runtime: {self.runtime.__class__.__name__}\n' - f'Plugins: {agent.sandbox_plugins}\n' - '-------------------------------------------------------------------------------------------' + f"Agent: {agent.name}\n" + f"Runtime: {self.runtime.__class__.__name__}\n" + f"Plugins: {agent.sandbox_plugins}\n" + "-------------------------------------------------------------------------------------------" ) logger.debug(msg) @@ -270,7 +266,7 @@ def _create_controller( self.controller.set_initial_state( agent_state, max_iterations, confirmation_mode ) - logger.debug(f'Restored agent state from session, sid: {self.sid}') + logger.debug(f"Restored agent state from session, sid: {self.sid}") except Exception as e: - logger.debug(f'State could not be restored: {e}') - logger.debug('Agent controller initialized.') + logger.debug(f"State could not be restored: {e}") + logger.debug("Agent controller initialized.") diff --git a/openhands/server/session/manager.py b/openhands/server/session/manager.py index f746b3676e29..fee610c1b8cc 100644 --- a/openhands/server/session/manager.py +++ b/openhands/server/session/manager.py @@ -29,7 +29,7 @@ async def attach_to_conversation(self, sid: str) -> Conversation | None: await c.connect() end_time = time.time() logger.info( - f'Conversation {c.sid} connected in {end_time - start_time} seconds' + f"Conversation {c.sid} connected in {end_time - start_time} seconds" ) return c diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index 25f707f15f53..91fe15ae2aa9 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -61,14 +61,14 @@ async def loop_recv(self): try: data = await self.websocket.receive_json() except ValueError: - await self.send_error('Invalid JSON') + await self.send_error("Invalid JSON") continue await self.dispatch(data) except WebSocketDisconnect: - logger.info('WebSocket disconnected, sid: %s', self.sid) + logger.info("WebSocket disconnected, sid: %s", self.sid) self.close() except RuntimeError as e: - logger.exception('Error in loop_recv: %s', e) + logger.exception("Error in loop_recv: %s", e) self.close() async def _initialize_agent(self, data: dict): @@ -76,16 +76,16 @@ async def _initialize_agent(self, data: dict): ChangeAgentStateAction(AgentState.LOADING), EventSource.ENVIRONMENT ) self.agent_session.event_stream.add_event( - AgentStateChangedObservation('', AgentState.LOADING), + AgentStateChangedObservation("", AgentState.LOADING), EventSource.ENVIRONMENT, ) # Extract the agent-relevant arguments from the request - args = {key: value for key, value in data.get('args', {}).items()} + args = {key: value for key, value in data.get("args", {}).items()} agent_cls = args.get(ConfigType.AGENT, self.config.default_agent) self.config.security.confirmation_mode = args.get( ConfigType.CONFIRMATION_MODE, self.config.security.confirmation_mode ) - self.config.security.security_analyzer = data.get('args', {}).get( + self.config.security.security_analyzer = data.get("args", {}).get( ConfigType.SECURITY_ANALYZER, self.config.security.security_analyzer ) max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations) @@ -119,9 +119,9 @@ async def _initialize_agent(self, data: dict): agent_configs=self.config.get_agent_configs(), ) except Exception as e: - logger.exception(f'Error creating controller: {e}') + logger.exception(f"Error creating controller: {e}") await self.send_error( - f'Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information..' + f"Error creating controller. Please check Docker is running and visit `{TROUBLESHOOTING_URL}` for more debugging information.." ) return @@ -148,16 +148,16 @@ async def on_event(self, event: Event): ): # feedback from the environment to agent actions is understood as agent events by the UI event_dict = event_to_dict(event) - event_dict['source'] = EventSource.AGENT + event_dict["source"] = EventSource.AGENT await self.send(event_dict) elif isinstance(event, ErrorObservation): # send error events as agent events to the UI event_dict = event_to_dict(event) - event_dict['source'] = EventSource.AGENT + event_dict["source"] = EventSource.AGENT await self.send(event_dict) async def dispatch(self, data: dict): - action = data.get('action', '') + action = data.get("action", "") if action == ActionType.INIT: await self._initialize_agent(data) return @@ -168,12 +168,12 @@ async def dispatch(self, data: dict): if controller: if controller.agent.llm.config.disable_vision: await self.send_error( - 'Support for images is disabled for this model, try without an image.' + "Support for images is disabled for this model, try without an image." ) return if not controller.agent.llm.vision_is_active(): await self.send_error( - 'Model does not support image upload, change to a different model or try without an image.' + "Model does not support image upload, change to a different model or try without an image." ) return self.agent_session.event_stream.add_event(event, EventSource.USER) @@ -192,15 +192,15 @@ async def send(self, data: dict[str, object]) -> bool: async def send_error(self, message: str) -> bool: """Sends an error message to the client.""" - return await self.send({'error': True, 'message': message}) + return await self.send({"error": True, "message": message}) async def _send_status_message(self, msg_type: str, id: str, message: str) -> bool: """Sends a status message to the client.""" - if msg_type == 'error': + if msg_type == "error": await self.agent_session.stop_agent_loop_for_error() return await self.send( - {'status_update': True, 'type': msg_type, 'id': id, 'message': message} + {"status_update": True, "type": msg_type, "id": id, "message": message} ) def queue_status_message(self, msg_type: str, id: str, message: str): diff --git a/openhands/storage/google_cloud.py b/openhands/storage/google_cloud.py index bbd2da273098..4f426532df32 100644 --- a/openhands/storage/google_cloud.py +++ b/openhands/storage/google_cloud.py @@ -9,8 +9,7 @@ class GoogleCloudFileStore(FileStore): def __init__(self, bucket_name: Optional[str] = None) -> None: - """ - Create a new FileStore. If GOOGLE_APPLICATION_CREDENTIALS is defined in the + """Create a new FileStore. If GOOGLE_APPLICATION_CREDENTIALS is defined in the environment it will be used for authentication. Otherwise access will be anonymous. """ diff --git a/openhands/utils/async_utils.py b/openhands/utils/async_utils.py index 2a3b73f5da7d..bcf5467f75a9 100644 --- a/openhands/utils/async_utils.py +++ b/openhands/utils/async_utils.py @@ -8,8 +8,7 @@ async def call_sync_from_async(fn: Callable, *args, **kwargs): - """ - Shorthand for running a function in the default background thread pool executor + """Shorthand for running a function in the default background thread pool executor and awaiting the result. The nature of synchronous code is that the future returned by this function is not cancellable """ @@ -22,11 +21,9 @@ async def call_sync_from_async(fn: Callable, *args, **kwargs): def call_async_from_sync( corofn: Callable, timeout: float = GENERAL_TIMEOUT, *args, **kwargs ): - """ - Shorthand for running a coroutine in the default background thread pool executor + """Shorthand for running a coroutine in the default background thread pool executor and awaiting the result """ - if corofn is None: raise ValueError('corofn is None') if not asyncio.iscoroutinefunction(corofn): @@ -61,8 +58,7 @@ async def call_coro_in_bg_thread( async def wait_all( iterable: Iterable[Coroutine], timeout: int = GENERAL_TIMEOUT ) -> List: - """ - Shorthand for waiting for all the coroutines in the iterable given in parallel. Creates + """Shorthand for waiting for all the coroutines in the iterable given in parallel. Creates a task for each coroutine. Returns a list of results in the original order. If any single task raised an exception, this is raised. If multiple tasks raised exceptions, an AsyncException is raised containing all exceptions. diff --git a/openhands/utils/embeddings.py b/openhands/utils/embeddings.py index 900b43052b13..00a118d2b798 100644 --- a/openhands/utils/embeddings.py +++ b/openhands/utils/embeddings.py @@ -76,7 +76,6 @@ def get_embedding_model(strategy: str, llm_config: LLMConfig) -> 'BaseEmbedding' Returns: - An instance of the selected embedding model or None. """ - if strategy in SUPPORTED_OLLAMA_EMBED_MODELS: from llama_index.embeddings.ollama import OllamaEmbedding @@ -152,7 +151,6 @@ def run_pipeline( embed_model: 'BaseEmbedding', documents: list['Document'], num_workers: int ) -> list['TextNode']: """Run a pipeline embedding documents.""" - # set up a pipeline with the transformations to make pipeline = IngestionPipeline( transformations=[ diff --git a/openhands/utils/prompt.py b/openhands/utils/prompt.py index 5d0d92968d35..85907f663347 100644 --- a/openhands/utils/prompt.py +++ b/openhands/utils/prompt.py @@ -9,8 +9,7 @@ class PromptManager: - """ - Manages prompt templates and micro-agents for AI interactions. + """Manages prompt templates and micro-agents for AI interactions. This class handles loading and rendering of system and user prompt templates, as well as loading micro-agent specifications. It provides methods to access diff --git a/tests/runtime/test_stress_remote_runtime.py b/tests/runtime/test_stress_remote_runtime.py index a38b5c5dbe24..a83ef230d362 100644 --- a/tests/runtime/test_stress_remote_runtime.py +++ b/tests/runtime/test_stress_remote_runtime.py @@ -204,7 +204,6 @@ def next_command(*args, **kwargs): ) def test_stress_remote_runtime(n_eval_workers: int = 64): """Mimic evaluation setting to test remote runtime in a multi-processing setting.""" - llm_config = LLMConfig() metadata = make_metadata( llm_config, diff --git a/tests/unit/linters/conftest.py b/tests/unit/linters/conftest.py index 4a2b51812bb9..bf2312c9a18d 100644 --- a/tests/unit/linters/conftest.py +++ b/tests/unit/linters/conftest.py @@ -9,7 +9,7 @@ def foo(): print("Wrong indent") foo( """ - file_path = tmp_path / 'test_file.py' + file_path = tmp_path / "test_file.py" file_path.write_text(file_content) return str(file_path) @@ -20,7 +20,7 @@ def wrongly_indented_py_file(tmp_path): def foo(): print("Hello, World!") """ - file_path = tmp_path / 'test_file.py' + file_path = tmp_path / "test_file.py" file_path.write_text(file_content) return str(file_path) @@ -28,7 +28,7 @@ def foo(): @pytest.fixture def simple_correct_py_file(tmp_path): file_content = 'print("Hello, World!")\n' - file_path = tmp_path / 'test_file.py' + file_path = tmp_path / "test_file.py" file_path.write_text(file_content) return str(file_path) @@ -39,7 +39,7 @@ def simple_correct_py_func_def(tmp_path): print("Hello, World!") foo() """ - file_path = tmp_path / 'test_file.py' + file_path = tmp_path / "test_file.py" file_path.write_text(file_content) return str(file_path) @@ -51,7 +51,7 @@ def simple_correct_ruby_file(tmp_path): end foo """ - file_path = tmp_path / 'test_file.rb' + file_path = tmp_path / "test_file.rb" file_path.write_text(file_content) return str(file_path) @@ -62,7 +62,7 @@ def simple_incorrect_ruby_file(tmp_path): print("Hello, World!") foo() """ - file_path = tmp_path / 'test_file.rb' + file_path = tmp_path / "test_file.rb" file_path.write_text(file_content) return str(file_path) @@ -70,6 +70,6 @@ def simple_incorrect_ruby_file(tmp_path): @pytest.fixture def parenthesis_incorrect_ruby_file(tmp_path): file_content = """def print_hello_world()\n puts 'Hello World'\n""" - file_path = tmp_path / 'test_file.rb' + file_path = tmp_path / "test_file.rb" file_path.write_text(file_content) return str(file_path) diff --git a/tests/unit/linters/test_lint_diff.py b/tests/unit/linters/test_lint_diff.py index f3b560c3df32..ce2fb6698e17 100644 --- a/tests/unit/linters/test_lint_diff.py +++ b/tests/unit/linters/test_lint_diff.py @@ -26,7 +26,7 @@ def foo(): def test_get_and_parse_diff(tmp_path): - diff = get_diff(OLD_CONTENT, NEW_CONTENT_V1, 'test.py') + diff = get_diff(OLD_CONTENT, NEW_CONTENT_V1, "test.py") print(diff) assert ( diff @@ -41,8 +41,8 @@ def test_get_and_parse_diff(tmp_path): ) print( - '\n'.join( - [f'{i+1}|{line}' for i, line in enumerate(NEW_CONTENT_V1.splitlines())] + "\n".join( + [f"{i+1}|{line}" for i, line in enumerate(NEW_CONTENT_V1.splitlines())] ) ) changes = parse_diff(diff) @@ -50,26 +50,26 @@ def test_get_and_parse_diff(tmp_path): assert ( changes[0].old is None and changes[0].new == 7 - and changes[0].line == 'def new_function_that_causes_error():' + and changes[0].line == "def new_function_that_causes_error():" ) assert ( changes[1].old is None and changes[1].new == 8 - and changes[1].line == ' y = ANOTHER_UNDEFINED_VARIABLE' + and changes[1].line == " y = ANOTHER_UNDEFINED_VARIABLE" ) - assert changes[2].old is None and changes[2].new == 9 and changes[2].line == '' + assert changes[2].old is None and changes[2].new == 9 and changes[2].line == "" def test_lint_with_diff_append(tmp_path): - with open(tmp_path / 'old.py', 'w') as f: + with open(tmp_path / "old.py", "w") as f: f.write(OLD_CONTENT) - with open(tmp_path / 'new.py', 'w') as f: + with open(tmp_path / "new.py", "w") as f: f.write(NEW_CONTENT_V1) linter = DefaultLinter() result: list[LintResult] = linter.lint_file_diff( - str(tmp_path / 'old.py'), - str(tmp_path / 'new.py'), + str(tmp_path / "old.py"), + str(tmp_path / "new.py"), ) print(result) assert len(result) == 1 @@ -81,15 +81,15 @@ def test_lint_with_diff_append(tmp_path): def test_lint_with_diff_insert(tmp_path): - with open(tmp_path / 'old.py', 'w') as f: + with open(tmp_path / "old.py", "w") as f: f.write(OLD_CONTENT) - with open(tmp_path / 'new.py', 'w') as f: + with open(tmp_path / "new.py", "w") as f: f.write(NEW_CONTENT_V2) linter = DefaultLinter() result: list[LintResult] = linter.lint_file_diff( - str(tmp_path / 'old.py'), - str(tmp_path / 'new.py'), + str(tmp_path / "old.py"), + str(tmp_path / "new.py"), ) assert len(result) == 1 assert ( @@ -119,15 +119,15 @@ def bar(): foo() bar() """ - with open(tmp_path / 'old.py', 'w') as f: + with open(tmp_path / "old.py", "w") as f: f.write(old_content) - with open(tmp_path / 'new.py', 'w') as f: + with open(tmp_path / "new.py", "w") as f: f.write(new_content) linter = DefaultLinter() result: list[LintResult] = linter.lint_file_diff( - str(tmp_path / 'old.py'), - str(tmp_path / 'new.py'), + str(tmp_path / "old.py"), + str(tmp_path / "new.py"), ) assert len(result) == 2 assert ( @@ -152,15 +152,15 @@ def test_lint_with_introduced_and_fixed_errors(tmp_path): y = ANOTHER_UNDEFINED_VARIABLE z = UNDEFINED_VARIABLE """ - with open(tmp_path / 'old.py', 'w') as f: + with open(tmp_path / "old.py", "w") as f: f.write(old_content) - with open(tmp_path / 'new.py', 'w') as f: + with open(tmp_path / "new.py", "w") as f: f.write(new_content) linter = DefaultLinter() result: list[LintResult] = linter.lint_file_diff( - str(tmp_path / 'old.py'), - str(tmp_path / 'new.py'), + str(tmp_path / "old.py"), + str(tmp_path / "new.py"), ) assert len(result) == 2 assert ( @@ -189,15 +189,15 @@ def complex_function(a, b, c): b + c) """ - with open(tmp_path / 'old.py', 'w') as f: + with open(tmp_path / "old.py", "w") as f: f.write(old_content) - with open(tmp_path / 'new.py', 'w') as f: + with open(tmp_path / "new.py", "w") as f: f.write(new_content) linter = DefaultLinter() result: list[LintResult] = linter.lint_file_diff( - str(tmp_path / 'old.py'), - str(tmp_path / 'new.py'), + str(tmp_path / "old.py"), + str(tmp_path / "new.py"), ) assert len(result) == 1 assert ( @@ -216,15 +216,15 @@ def foo(): def foo(): print("Hello, World!" """ - with open(tmp_path / 'old.py', 'w') as f: + with open(tmp_path / "old.py", "w") as f: f.write(old_content) - with open(tmp_path / 'new.py', 'w') as f: + with open(tmp_path / "new.py", "w") as f: f.write(new_content) linter = DefaultLinter() result: list[LintResult] = linter.lint_file_diff( - str(tmp_path / 'old.py'), - str(tmp_path / 'new.py'), + str(tmp_path / "old.py"), + str(tmp_path / "new.py"), ) assert len(result) == 1 assert ( @@ -248,15 +248,15 @@ def foo(): """ print("Hello, World!") ''' - with open(tmp_path / 'old.py', 'w') as f: + with open(tmp_path / "old.py", "w") as f: f.write(old_content) - with open(tmp_path / 'new.py', 'w') as f: + with open(tmp_path / "new.py", "w") as f: f.write(new_content) linter = DefaultLinter() result: list[LintResult] = linter.lint_file_diff( - str(tmp_path / 'old.py'), - str(tmp_path / 'new.py'), + str(tmp_path / "old.py"), + str(tmp_path / "new.py"), ) assert len(result) == 0 # Linter should ignore changes in docstrings @@ -274,15 +274,15 @@ def foo(): x = UNDEFINED_VARIABLE + ANOTHER_UNDEFINED_VARIABLE foo() """ - with open(tmp_path / 'old.py', 'w') as f: + with open(tmp_path / "old.py", "w") as f: f.write(old_content) - with open(tmp_path / 'new.py', 'w') as f: + with open(tmp_path / "new.py", "w") as f: f.write(new_content) linter = DefaultLinter() result: list[LintResult] = linter.lint_file_diff( - str(tmp_path / 'old.py'), - str(tmp_path / 'new.py'), + str(tmp_path / "old.py"), + str(tmp_path / "new.py"), ) print(result) assert len(result) == 2 @@ -299,14 +299,13 @@ def foo(): def test_parse_diff_with_empty_patch(): - diff_patch = '' + diff_patch = "" changes = parse_diff(diff_patch) assert len(changes) == 0 def test_lint_file_diff_ignore_existing_errors(tmp_path): - """ - Make sure we allow edits as long as it does not introduce new errors. In other + """Make sure we allow edits as long as it does not introduce new errors. In other words, we don't care about existing linting errors. Although they might be real syntax issues, sometimes they are just false positives, or errors that we don't care about. @@ -323,10 +322,10 @@ def some_wrong_but_unused_function(): def sum(a, b): return a - b """ - new_content = content.replace(' return a - b', ' return a + b') - temp_file_old_path = tmp_path / 'problematic-file-test.py' + new_content = content.replace(" return a - b", " return a + b") + temp_file_old_path = tmp_path / "problematic-file-test.py" temp_file_old_path.write_text(content) - temp_file_new_path = tmp_path / 'problematic-file-test-new.py' + temp_file_new_path = tmp_path / "problematic-file-test-new.py" temp_file_new_path.write_text(new_content) linter = DefaultLinter() @@ -338,8 +337,7 @@ def sum(a, b): def test_lint_file_diff_catch_new_errors_in_edits(tmp_path): - """ - Make sure we catch new linting errors in our edit chunk, and at the same + """Make sure we catch new linting errors in our edit chunk, and at the same time, ignore old linting errors (in this case, the old linting error is a false positive) """ @@ -352,10 +350,10 @@ def sum(a, b): return a - b """ - temp_file_old_path = tmp_path / 'problematic-file-test.py' + temp_file_old_path = tmp_path / "problematic-file-test.py" temp_file_old_path.write_text(content) - new_content = content.replace(' return a - b', ' return a + variable') - temp_file_new_path = tmp_path / 'problematic-file-test-new.py' + new_content = content.replace(" return a - b", " return a + variable") + temp_file_new_path = tmp_path / "problematic-file-test-new.py" temp_file_new_path.write_text(new_content) linter = DefaultLinter() @@ -373,8 +371,7 @@ def sum(a, b): def test_lint_file_diff_catch_new_errors_outside_edits(tmp_path): - """ - Make sure we catch new linting errors induced by our edits, even + """Make sure we catch new linting errors induced by our edits, even though the error itself is not in the edit chunk """ content = """def valid_func1(): @@ -390,13 +387,13 @@ def valid_func2(): # linting would pass, and thus there won't be any comparison # between pre-edit and post-edit linting. for _ in range(100): - content += '\ninvalid_func()' + content += "\ninvalid_func()" - temp_file_old_path = tmp_path / 'problematic-file-test.py' + temp_file_old_path = tmp_path / "problematic-file-test.py" temp_file_old_path.write_text(content) - new_content = content.replace('def my_sum(a, b):', 'def my_sum2(a, b):') - temp_file_new_path = tmp_path / 'problematic-file-test-new.py' + new_content = content.replace("def my_sum(a, b):", "def my_sum2(a, b):") + temp_file_new_path = tmp_path / "problematic-file-test-new.py" temp_file_new_path.write_text(new_content) linter = DefaultLinter() diff --git a/tests/unit/linters/test_python_linter.py b/tests/unit/linters/test_python_linter.py index 40aed81ec3f3..096fc1d4baa0 100644 --- a/tests/unit/linters/test_python_linter.py +++ b/tests/unit/linters/test_python_linter.py @@ -9,7 +9,7 @@ def test_wrongly_indented_py_file(wrongly_indented_py_file): # Test Python linter linter = PythonLinter() - assert '.py' in linter.supported_extensions + assert ".py" in linter.supported_extensions result = linter.lint(wrongly_indented_py_file) print(result) assert isinstance(result, list) and len(result) == 1 @@ -17,21 +17,21 @@ def test_wrongly_indented_py_file(wrongly_indented_py_file): file=wrongly_indented_py_file, line=2, column=5, - message='E999 IndentationError: unexpected indent', + message="E999 IndentationError: unexpected indent", ) print(result[0].visualize()) assert result[0].visualize() == ( - '1|\n' - '\033[91m2| def foo():\033[0m\n' - ' ^ ERROR HERE: E999 IndentationError: unexpected indent\n' + "1|\n" + "\033[91m2| def foo():\033[0m\n" + " ^ ERROR HERE: E999 IndentationError: unexpected indent\n" '3| print("Hello, World!")\n' - '4|' + "4|" ) # General linter should have same result as Python linter # bc it uses PythonLinter under the hood general_linter = DefaultLinter() - assert '.py' in general_linter.supported_extensions + assert ".py" in general_linter.supported_extensions result = general_linter.lint(wrongly_indented_py_file) assert result == linter.lint(wrongly_indented_py_file) @@ -42,18 +42,18 @@ def test_wrongly_indented_py_file(wrongly_indented_py_file): compile_result = python_compile_lint(wrongly_indented_py_file) assert isinstance(compile_result, list) and len(compile_result) == 1 assert compile_result[0] == LintResult( - file=wrongly_indented_py_file, line=2, column=4, message='unexpected indent' + file=wrongly_indented_py_file, line=2, column=4, message="unexpected indent" ) def test_simple_correct_py_file(simple_correct_py_file): linter = PythonLinter() - assert '.py' in linter.supported_extensions + assert ".py" in linter.supported_extensions result = linter.lint(simple_correct_py_file) assert result == [] general_linter = DefaultLinter() - assert '.py' in general_linter.supported_extensions + assert ".py" in general_linter.supported_extensions result = general_linter.lint(simple_correct_py_file) assert result == linter.lint(simple_correct_py_file) @@ -72,7 +72,7 @@ def test_simple_correct_py_func_def(simple_correct_py_func_def): assert result == [] general_linter = DefaultLinter() - assert '.py' in general_linter.supported_extensions + assert ".py" in general_linter.supported_extensions result = general_linter.lint(simple_correct_py_func_def) assert result == linter.lint(simple_correct_py_func_def) diff --git a/tests/unit/linters/test_treesitter_linter.py b/tests/unit/linters/test_treesitter_linter.py index 195a48bf3632..c5d661eebf99 100644 --- a/tests/unit/linters/test_treesitter_linter.py +++ b/tests/unit/linters/test_treesitter_linter.py @@ -11,18 +11,18 @@ def test_syntax_error_py_file(syntax_error_py_file): file=syntax_error_py_file, line=5, column=5, - message='Syntax error', + message="Syntax error", ) assert ( result[0].visualize() == ( - '2| def foo():\n' + "2| def foo():\n" '3| print("Hello, World!")\n' '4| print("Wrong indent")\n' - '\033[91m5| foo(\033[0m\n' # color red - ' ^ ERROR HERE: Syntax error\n' - '6|' + "\033[91m5| foo(\033[0m\n" # color red + " ^ ERROR HERE: Syntax error\n" + "6|" ) ) print(result[0].visualize()) @@ -54,32 +54,32 @@ def test_simple_incorrect_ruby_file(simple_incorrect_ruby_file): file=simple_incorrect_ruby_file, line=1, column=1, - message='Syntax error', + message="Syntax error", ) print(result[0].visualize()) assert ( result[0].visualize() == ( - '\033[91m1|def foo():\033[0m\n' # color red - ' ^ ERROR HERE: Syntax error\n' + "\033[91m1|def foo():\033[0m\n" # color red + " ^ ERROR HERE: Syntax error\n" '2| print("Hello, World!")\n' - '3|foo()' + "3|foo()" ) ) assert result[1] == LintResult( file=simple_incorrect_ruby_file, line=1, column=10, - message='Syntax error', + message="Syntax error", ) print(result[1].visualize()) assert ( result[1].visualize() == ( - '\033[91m1|def foo():\033[0m\n' # color red - ' ^ ERROR HERE: Syntax error\n' + "\033[91m1|def foo():\033[0m\n" # color red + " ^ ERROR HERE: Syntax error\n" '2| print("Hello, World!")\n' - '3|foo()' + "3|foo()" ) ) @@ -98,12 +98,12 @@ def test_parenthesis_incorrect_ruby_file(parenthesis_incorrect_ruby_file): file=parenthesis_incorrect_ruby_file, line=1, column=1, - message='Syntax error', + message="Syntax error", ) print(result[0].visualize()) assert result[0].visualize() == ( - '\033[91m1|def print_hello_world()\033[0m\n' - ' ^ ERROR HERE: Syntax error\n' + "\033[91m1|def print_hello_world()\033[0m\n" + " ^ ERROR HERE: Syntax error\n" "2| puts 'Hello World'" ) diff --git a/tests/unit/linters/test_visualize.py b/tests/unit/linters/test_visualize.py index e8232afd0117..344f82b69e81 100644 --- a/tests/unit/linters/test_visualize.py +++ b/tests/unit/linters/test_visualize.py @@ -7,15 +7,15 @@ @pytest.fixture def mock_file_content(): - return '\n'.join([f'Line {i}' for i in range(1, 21)]) + return "\n".join([f"Line {i}" for i in range(1, 21)]) def test_visualize_standard_case(mock_file_content): lint_result = LintResult( - file='test_file.py', line=10, column=5, message='Test error message' + file="test_file.py", line=10, column=5, message="Test error message" ) - with patch('builtins.open', mock_open(read_data=mock_file_content)): + with patch("builtins.open", mock_open(read_data=mock_file_content)): result = lint_result.visualize(half_window=3) expected_output = ( @@ -34,10 +34,10 @@ def test_visualize_standard_case(mock_file_content): def test_visualize_small_window(mock_file_content): lint_result = LintResult( - file='test_file.py', line=10, column=5, message='Test error message' + file="test_file.py", line=10, column=5, message="Test error message" ) - with patch('builtins.open', mock_open(read_data=mock_file_content)): + with patch("builtins.open", mock_open(read_data=mock_file_content)): result = lint_result.visualize(half_window=1) expected_output = ( @@ -52,10 +52,10 @@ def test_visualize_small_window(mock_file_content): def test_visualize_error_at_start(mock_file_content): lint_result = LintResult( - file='test_file.py', line=1, column=3, message='Start error' + file="test_file.py", line=1, column=3, message="Start error" ) - with patch('builtins.open', mock_open(read_data=mock_file_content)): + with patch("builtins.open", mock_open(read_data=mock_file_content)): result = lint_result.visualize(half_window=2) expected_output = ( @@ -70,10 +70,10 @@ def test_visualize_error_at_start(mock_file_content): def test_visualize_error_at_end(mock_file_content): lint_result = LintResult( - file='test_file.py', line=20, column=1, message='End error' + file="test_file.py", line=20, column=1, message="End error" ) - with patch('builtins.open', mock_open(read_data=mock_file_content)): + with patch("builtins.open", mock_open(read_data=mock_file_content)): result = lint_result.visualize(half_window=2) expected_output = (