From 01d64d64c5730873ffa4046f85d03968349dd62f Mon Sep 17 00:00:00 2001 From: AlexKaravaev Date: Sun, 24 Mar 2024 22:57:40 +0100 Subject: [PATCH] fixup --- .isort.cfg | 2 +- ciare_world_creator/collections/utils.py | 4 +- ciare_world_creator/commands/create.py | 76 ++++++++++--------- .../model_databases/objaverse.py | 14 ++++ .../sim_interfaces/__init__.py | 0 ciare_world_creator/sim_interfaces/gazebo.py | 61 +++++++++++++++ ciare_world_creator/sim_interfaces/mujoco.py | 62 +++++++++++++++ ciare_world_creator/xml/worlds.py | 57 -------------- test.py | 41 ++++++---- 9 files changed, 208 insertions(+), 109 deletions(-) create mode 100644 ciare_world_creator/sim_interfaces/__init__.py create mode 100644 ciare_world_creator/sim_interfaces/gazebo.py create mode 100644 ciare_world_creator/sim_interfaces/mujoco.py diff --git a/.isort.cfg b/.isort.cfg index d5cab83..9e91008 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -13,4 +13,4 @@ default_section=THIRDPARTY known_localfolder= ciare_world_creator, -known_third_party = aiohttp,aspose,chromadb,click,langchain,lxml,objaverse,openai,pandas,pytest,questionary,requests,tabulate,tinydb,tqdm +known_third_party = aiohttp,aspose,chromadb,click,langchain,lxml,obj2mjcf,objaverse,openai,pandas,pytest,questionary,requests,tabulate,tinydb,tqdm,trimesh diff --git a/ciare_world_creator/collections/utils.py b/ciare_world_creator/collections/utils.py index 9b1dae9..93c2770 100644 --- a/ciare_world_creator/collections/utils.py +++ b/ciare_world_creator/collections/utils.py @@ -16,8 +16,8 @@ def fill_index(collection, loader: BaseLoader): f"Generating indicies for chromadb. This might take a while, but it's done only once", style="bold italic fg:green", ) - models = loader.get_models() - + models, _ = loader.get_models() + print(models) df_models = pd.DataFrame(models) df_models = df_models.drop_duplicates(subset="name") df_models["tags"] = df_models["tags"].apply( diff --git a/ciare_world_creator/commands/create.py b/ciare_world_creator/commands/create.py index 9e6c5d8..52c9936 100644 --- a/ciare_world_creator/commands/create.py +++ b/ciare_world_creator/commands/create.py @@ -16,15 +16,11 @@ from ciare_world_creator.model_databases.fetch_worlds import download_world from ciare_world_creator.model_databases.gazebo import GazeboLoader from ciare_world_creator.model_databases.objaverse import ObjaverseLoader +from ciare_world_creator.sim_interfaces.gazebo import GazeboSimInterface +from ciare_world_creator.sim_interfaces.mujoco import MujocoSimInterface from ciare_world_creator.utils.cache import Cache from ciare_world_creator.utils.style import STYLE -from ciare_world_creator.xml.worlds import ( - add_model_to_xml, - check_world, - find_model, - find_world, - save_xml, -) +from ciare_world_creator.xml.worlds import find_model @click.command( @@ -39,25 +35,26 @@ def cli(ctx): from ciare_world_creator.llm.model import prompt_model simulators = ["mujoco", "gazebo"] - chosen_simulator = questionary.select( - message=("Choose simulator to generate world for."), - choices=simulators, - style=STYLE, - ).ask() - + # chosen_simulator = questionary.select( + # message=("Choose simulator to generate world for."), + # choices=simulators, + # style=STYLE, + # ).ask() + chosen_simulator = "mujoco" if chosen_simulator == "gazebo": # Only gazebo is supported loader = GazeboLoader() - full_models = loader.get_models_full() - full_worlds = loader.get_worlds_full() + interface = GazeboSimInterface() elif chosen_simulator == "mujoco": loader = ObjaverseLoader() + interface = MujocoSimInterface() models, worlds = loader.get_models() - world_query = questionary.text( - "Enter query for world generation(E.g Two cars and person next to it)", - style=STYLE, - ).ask() + # world_query = questionary.text( + # "Enter query for world generation(E.g Two cars and person next to it)", + # style=STYLE, + # ).ask() + world_query = "10 cups" if not world_query: sys.exit(os.EX_OK) @@ -68,7 +65,7 @@ def cli(ctx): exists = db.search(World.prompt == query) openai.api_key = os.getenv("OPENAI_API_KEY") - models = openai.Model.list() + llm_models = openai.Model.list() chosen_model = "gpt-4" if exists: @@ -78,7 +75,7 @@ def cli(ctx): ) return - model_collection = get_or_create_collection("models_" + chosen_simulator) + model_collection = get_or_create_collection("models_" + chosen_simulator, loader) try: claim_query_result = model_collection.query( query_texts=[query], @@ -101,16 +98,18 @@ def cli(ctx): ) ] - generate_world = questionary.confirm( - "Do you want to spawn model in an empty world?" - " Saying no will download world from database, but it's very unstable. Y/n", - style=STYLE, - ).ask() + # generate_world = questionary.confirm( + # "Do you want to spawn model in an empty world?" + # " Saying no will download world from database, but it's very unstable. Y/n", + # style=STYLE, + # ).ask() + + generate_world = False if generate_world is None: sys.exit(os.EX_OK) - if not generate_world: + if generate_world: content = fmt_world_qa_tmpl.format(context_str=worlds) questionary.print("Generating world... 🌎", style="bold fg:yellow") @@ -120,7 +119,7 @@ def cli(ctx): f"World is {world['World']}, downloading it", style="bold italic fg:green" ) - full_world = find_world(world["World"], full_worlds) + full_world = interface.find_world(world["World"], worlds) template_world_path = None if world["World"] != "None": template_world_path = download_world( @@ -137,7 +136,7 @@ def cli(ctx): world = {"World": "None"} template_world_path = os.path.join(cache.worlds_path, "empty.sdf") - if not check_world(template_world_path): + if not interface.check_world(template_world_path): questionary.print( "Suggested world is malformed. Falling back to empty world", style="bold italic fg:red", @@ -148,11 +147,12 @@ def cli(ctx): "Spawning models in the world... 🫖", style="bold italic fg:yellow" ) content = fmt_model_qa_tmpl.format(context_str=context) - models = prompt_model(content, query, chosen_model) + chosen_models = prompt_model(content, query, chosen_model) - for model in models: - if not find_model(model["Model"], full_models): - models = prompt_model( + print(chosen_models) + for model in chosen_models: + if not find_model(model["Model"], models): + chosen_models = prompt_model( content, f"{model} was not found in context list. " "Generate only the one that are in the context", @@ -161,12 +161,14 @@ def cli(ctx): questionary.print("Placing models in the world... 📍", style="bold italic fg:yellow") content = fmt_place_qa_tmpl.format( - context_str=f"Arrange following models: {str(models)}", + context_str=f"Arrange following models: {str(chosen_models)}", world_file=open(template_world_path, "r"), ) - + # print(content) + # sys.exit(0) placement = prompt_model(content, query, chosen_model) + print(placement) # TODO handle ,.; etc cleaned_query = re.sub(r'[<>:;.,"/\\|?*]', "", query).strip() world_name = f'world_{cleaned_query.replace(" ", "_")}' @@ -177,9 +179,11 @@ def cli(ctx): # TODO add asserts on model fields non_existent_models = [] + interface.add_models(placement, models) + sys.exit(0) for model in placement: # Example usage - m = find_model(model["Model"], full_models) + m = find_model(model["Model"], models) if not m: questionary.print( f"Model {model} was not found in database. " diff --git a/ciare_world_creator/model_databases/objaverse.py b/ciare_world_creator/model_databases/objaverse.py index aea16e1..051763f 100644 --- a/ciare_world_creator/model_databases/objaverse.py +++ b/ciare_world_creator/model_databases/objaverse.py @@ -1,3 +1,6 @@ +import json +import os + import objaverse from ciare_world_creator.model_databases.base import BaseLoader @@ -5,6 +8,15 @@ class ObjaverseLoader(BaseLoader): def __init__(self): + fp = "./LVIS.json" + if os.path.exists(fp): + # If the file does not exist, create it and dump the JSON data + with open(fp, "r") as file: + cached = json.load(file) + self.annotations = cached[0] + self.uid_to_category = cached[1] + return + lvis_annotations = objaverse.load_lvis_annotations() truncated_annotations = ( @@ -19,6 +31,8 @@ def __init__(self): self.uid_to_category[item] = key self.annotations = objaverse.load_annotations(self.uid_to_category.keys()) + with open(fp, "w") as file: + json.dump([self.annotations, self.uid_to_category], file) def get_models(self): only_description_models = [] diff --git a/ciare_world_creator/sim_interfaces/__init__.py b/ciare_world_creator/sim_interfaces/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ciare_world_creator/sim_interfaces/gazebo.py b/ciare_world_creator/sim_interfaces/gazebo.py new file mode 100644 index 0000000..7c47c91 --- /dev/null +++ b/ciare_world_creator/sim_interfaces/gazebo.py @@ -0,0 +1,61 @@ +from lxml import etree +from lxml import etree as ET + + +class GazeboSimInterface: + def __init__(self): + pass + + def check_world(self, template_world_path): + """Load world and asserts if basic tags are there.""" + parser = ET.XMLParser(recover=True, remove_blank_text=True) + + tree = etree.parse(template_world_path, parser=parser) + + root = tree.getroot() + world_xml = root.find("world") + return world_xml is not None + + def add_model_to_xml( + self, model_name, pose_x, pose_y, pose_z, pose_roll, pose_pitch, pose_yaw, uri + ): + # Create the new element + include = ET.Element("include") + + name = ET.SubElement(include, "name") + name.text = model_name + + pose = ET.SubElement(include, "pose") + pose.text = f"{pose_x} {pose_y} {pose_z} {pose_roll} {pose_pitch} {pose_yaw}" + + uri_element = ET.SubElement(include, "uri") + uri_element.text = uri + + return include + + def save_xml(self, xml_file, template_world_path, include_tags): + parser = ET.XMLParser(recover=True, remove_blank_text=True) + + tree = etree.parse(template_world_path, parser=parser) + + root = tree.getroot() + + world_xml = root.find("world") + + for include in include_tags: + world_xml.append(include) + + # Indent the XML with two spaces + tree_str = ET.tostring( + root, + pretty_print=True, + encoding="utf-8", + xml_declaration=True, + with_tail=True, + ) + + # parsed_tree = ET.fromstring(tree_str) + + # Save the formatted XML to the file + with open(xml_file, "wb") as file: + file.write(tree_str) diff --git a/ciare_world_creator/sim_interfaces/mujoco.py b/ciare_world_creator/sim_interfaces/mujoco.py new file mode 100644 index 0000000..4179b59 --- /dev/null +++ b/ciare_world_creator/sim_interfaces/mujoco.py @@ -0,0 +1,62 @@ +import os +from pathlib import Path + +import objaverse +import trimesh +from obj2mjcf.cli import Args, process_obj + + +class MujocoSimInterface: + def __init__(self): + pass + + def check_world(self, world): + pass + + def generate_world(self): + world = {"World": "None"} + template_world_path = os.path.join(cache.worlds_path, "empty.sdf") + + def find_entry_by_name(self, name, full_list): + for entry in full_list: + if entry["name"] == name: + return entry + return None + + def add_models(self, placed_models, models): + full_placed_models = [] + + for model in placed_models: + if model_entry := self.find_entry_by_name(model["Model"], models): + model_entry.update(model) + full_placed_models.append(model_entry) + print(model_entry) + print(full_placed_models) + + # model_db_interface.load_models(full_placed_models) + objects = objaverse.load_objects( + uids=[entry["uuid"] for entry in full_placed_models] + ) + obj_locs = list(objects.values()) + + print(obj_locs) + + for i in range(len(full_placed_models)): + full_placed_models[i]["model_loc"] = obj_locs[i] + + mesh = trimesh.load(obj_locs[i]) + # trimesh.exchange.obj.export_obj(mesh) + obj, data = trimesh.exchange.export.export_obj( + mesh, include_texture=True, return_texture=True + ) + + obj_path = f"./converted/{full_placed_models[i]['uuid']}.obj" + with open(obj_path, "w") as f: + f.write(obj) + # save the MTL and images + for k, v in data.items(): + with open(os.path.join("./converted/", k), "wb") as f: + f.write(v) + args = Args("./", save_mjcf=True, compile_model=True, overwrite=True) + + process_obj(Path(obj_path), args) diff --git a/ciare_world_creator/xml/worlds.py b/ciare_world_creator/xml/worlds.py index 8092675..6ea23e5 100644 --- a/ciare_world_creator/xml/worlds.py +++ b/ciare_world_creator/xml/worlds.py @@ -1,60 +1,3 @@ -from lxml import etree -from lxml import etree as ET - - -def check_world(template_world_path): - """Load world and asserts if basic tags are there.""" - parser = ET.XMLParser(recover=True, remove_blank_text=True) - - tree = etree.parse(template_world_path, parser=parser) - - root = tree.getroot() - world_xml = root.find("world") - return world_xml is not None - - -def add_model_to_xml( - model_name, pose_x, pose_y, pose_z, pose_roll, pose_pitch, pose_yaw, uri -): - # Create the new element - include = ET.Element("include") - - name = ET.SubElement(include, "name") - name.text = model_name - - pose = ET.SubElement(include, "pose") - pose.text = f"{pose_x} {pose_y} {pose_z} {pose_roll} {pose_pitch} {pose_yaw}" - - uri_element = ET.SubElement(include, "uri") - uri_element.text = uri - - return include - - -def save_xml(xml_file, template_world_path, include_tags): - parser = ET.XMLParser(recover=True, remove_blank_text=True) - - tree = etree.parse(template_world_path, parser=parser) - - root = tree.getroot() - - world_xml = root.find("world") - - for include in include_tags: - world_xml.append(include) - - # Indent the XML with two spaces - tree_str = ET.tostring( - root, pretty_print=True, encoding="utf-8", xml_declaration=True, with_tail=True - ) - - # parsed_tree = ET.fromstring(tree_str) - - # Save the formatted XML to the file - with open(xml_file, "wb") as file: - file.write(tree_str) - - def find_model(model, models): for m in models: if m["name"] == model: diff --git a/test.py b/test.py index 2cdfaa4..ebe1980 100644 --- a/test.py +++ b/test.py @@ -1,4 +1,7 @@ +import os + import objaverse +import trimesh from aspose.threed import Scene from aspose.threed.formats import ObjSaveOptions @@ -26,17 +29,29 @@ print(obj_locs) -# trimesh.load(list(objects.values())[4]).show() - - -scene = Scene.from_file( - "/home/Alexander.Karavaev/.objaverse/hf-objaverse-v1/glbs/000-023/bear.glb" +mesh = trimesh.load(list(objects.values())[4]) +# trimesh.exchange.obj.export_obj(mesh) +obj, data = trimesh.exchange.export.export_obj( + mesh, include_texture=True, return_texture=True ) - -# Specify OBJ save options -obj_save_options = ObjSaveOptions() -# Import materials from external material library file -obj_save_options.enable_materials = True - -# Save it as an OBJ -scene.save("test.obj", obj_save_options) +obj_path = "test_m.obj" +with open(obj_path, "w") as f: + f.write(obj) +# save the MTL and images +for k, v in data.items(): + with open(os.path.join("./", k), "wb") as f: + f.write(v) +# reload the mesh from the export +# rec = trimesh.load(obj_path) + +# scene = Scene.from_file( +# "/home/Alexander.Karavaev/.objaverse/hf-objaverse-v1/glbs/000-023/bear.glb" +# ) + +# # Specify OBJ save options +# obj_save_options = ObjSaveOptions() +# # Import materials from external material library file +# obj_save_options.enable_materials = True + +# # Save it as an OBJ +# scene.save("test.obj", obj_save_options)