Skip to content

Commit

Permalink
after small refactor x2
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexKaravaev committed Apr 1, 2024
1 parent 921ee06 commit 554a03d
Showing 1 changed file with 132 additions and 122 deletions.
254 changes: 132 additions & 122 deletions ciare_world_creator/sim_interfaces/mujoco.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,138 +115,148 @@ def create_main_root(self, full_placed_models, objects):
material_count = 0
for i, _ in enumerate(full_placed_models):
material_map = {}
print(full_placed_models[i])
full_placed_models[i]["model_loc"] = objects[full_placed_models[i]["uuid"]]

mesh = trimesh.load(full_placed_models[i]["model_loc"], force="mesh")

mesh.apply_scale(full_placed_models[i]["scale"])
# trimesh.exchange.obj.export_obj(mesh)
obj, data = trimesh.exchange.export.export_obj(
mesh, include_texture=True, return_texture=True
)

path = f"./converted/{full_placed_models[i]['uuid']}"
path = os.path.abspath(path)
path = Path(path)
path.mkdir(parents=True, exist_ok=True)
nested_path = Path(
os.path.abspath(str(path) + f"/{full_placed_models[i]['uuid']}")
mesh = self.load_and_scale_mesh(full_placed_models[i])
obj, data = self.export_mesh(mesh)
path = self.create_model_path(full_placed_models[i])
obj_path = self.write_obj_file(obj, path, full_placed_models[i])
self.save_material_and_images(data, path)
self.copy_images_to_nested_path(path, full_placed_models[i])
args = self.create_args(path)
printed_output = self.process_obj_file(obj_path, args)
if "Error compiling model" in printed_output:
continue
saved_mjc_path = self.get_saved_mjc_path(path, full_placed_models[i])
tree, root, included_tree, included_root = self.parse_xml(saved_mjc_path)
self.modify_default_class_attributes(
included_root, material_map, visual_count, collision_count
)
nested_path.mkdir(parents=True, exist_ok=True)

obj_path = f"{path}/{full_placed_models[i]['uuid']}.obj"
with open(obj_path, "w") as f:
f.write(obj)
# save the MTL and images
for k, v in data.items():
with open(os.path.join(path, k), "wb") as f:
f.write(v)
# List all files in the source directory
files = os.listdir(path)
for file in files:
# Check if the file is a .png or .jpg file
if file.endswith(".png") or file.endswith(".jpg"):
print(file)
# Construct paths for the source and destination
source_path = os.path.join(path, file)
destination_path = os.path.join(nested_path, file)
print(source_path, destination_path)
shutil.copy(source_path, destination_path)
# sys.exit(0)
args = Args(
obj_dir=path,
verbose=True,
save_mjcf=True,
compile_model=True,
overwrite=True,
self.modify_body_tag(included_root, full_placed_models[i])
self.rewrite_material_name_and_references(
included_root, material_map, material_count
)
self.write_modified_xml(included_tree, saved_mjc_path)
self.insert_include_tags(main_root, saved_mjc_path)
return main_root

sys.stdout = StringIO()
process_obj(Path(obj_path), args)
printed_output = sys.stdout.getvalue()
def load_and_scale_mesh(self, model):
mesh = trimesh.load(model["model_loc"], force="mesh")
mesh.apply_scale(model["scale"])
return mesh

# Restore stdout
sys.stdout = sys.__stdout__
if "Error compiling model" in printed_output:
continue
def export_mesh(self, mesh):
return trimesh.exchange.export.export_obj(
mesh, include_texture=True, return_texture=True
)

saved_mjc_path = Path(
os.path.abspath(
str(path)
+ f"/{full_placed_models[i]['uuid']}"
+ f"/{full_placed_models[i]['uuid']}.xml"
)
)
tree = ET.parse(saved_mjc_path)
root = tree.getroot()

included_tree = ET.parse(saved_mjc_path)
included_root = included_tree.getroot()

# Step 1: Modify default class attributes
for default in included_root.findall(".//default"):
print(default)
class_attribute = default.get("class")
if class_attribute == "visual":
material_map[class_attribute] = f"visual{visual_count}"
default.set("class", material_map[class_attribute])
visual_count += 1
elif class_attribute == "collision":
material_map[class_attribute] = f"collision{visual_count}"
default.set("class", material_map[class_attribute])
collision_count += 1

# Step 2: Find and modify body tag
for body in included_root.findall(".//body"):
# Add pos and euler attributes
body.set(
"pos",
" ".join(str(e) for e in full_placed_models[i]["Pose"].values()),
)
body.set("euler", "90 0 0")
def create_model_path(self, model):
path = f"./converted/{model['uuid']}"
path = os.path.abspath(path)
path = Path(path)
path.mkdir(parents=True, exist_ok=True)
return path

def write_obj_file(self, obj, path, model):
obj_path = f"{path}/{model['uuid']}.obj"
with open(obj_path, "w") as f:
f.write(obj)
return obj_path

def save_material_and_images(self, data, path):
for k, v in data.items():
with open(os.path.join(path, k), "wb") as f:
f.write(v)

def copy_images_to_nested_path(self, path, model):
nested_path = Path(os.path.abspath(str(path) + f"/{model['uuid']}"))
nested_path.mkdir(parents=True, exist_ok=True)
files = os.listdir(path)
for file in files:
if file.endswith(".png") or file.endswith(".jpg"):
source_path = os.path.join(path, file)
destination_path = os.path.join(nested_path, file)
shutil.copy(source_path, destination_path)

def create_args(self, path):
return Args(
obj_dir=path,
verbose=True,
save_mjcf=True,
compile_model=True,
overwrite=True,
)

joint_tag = ET.SubElement(body, "joint", type="free")
def process_obj_file(self, obj_path, args):
sys.stdout = StringIO()
process_obj(Path(obj_path), args)
printed_output = sys.stdout.getvalue()
sys.stdout = sys.__stdout__
return printed_output

# Step 3: Rewrite material name and references
materials = included_root.findall(".//material")
for i, texture in enumerate(included_root.findall(".//texture")):
old_name = texture.get("name")
def get_saved_mjc_path(self, path, model):
return Path(
os.path.abspath(str(path) + f"/{model['uuid']}" + f"/{model['uuid']}.xml")
)

def parse_xml(self, saved_mjc_path):
tree = ET.parse(saved_mjc_path)
root = tree.getroot()
included_tree = ET.parse(saved_mjc_path)
included_root = included_tree.getroot()
return tree, root, included_tree, included_root

def modify_default_class_attributes(
self, included_root, material_map, visual_count, collision_count
):
for default in included_root.findall(".//default"):
class_attribute = default.get("class")
if class_attribute == "visual":
material_map[class_attribute] = f"visual{visual_count}"
default.set("class", material_map[class_attribute])
visual_count += 1
elif class_attribute == "collision":
material_map[class_attribute] = f"collision{visual_count}"
default.set("class", material_map[class_attribute])
collision_count += 1

def modify_body_tag(self, included_root, model):
for body in included_root.findall(".//body"):
body.set("pos", " ".join(str(e) for e in model["Pose"].values()))
body.set("euler", "90 0 0")
ET.SubElement(body, "joint", type="free")

def rewrite_material_name_and_references(
self, included_root, material_map, material_count
):
materials = included_root.findall(".//material")
for i, texture in enumerate(included_root.findall(".//texture")):
old_name = texture.get("name")
material_map[old_name] = f"material_{material_count}"
texture.set("name", material_map[old_name])
material_count += 1
for i, material in enumerate(materials):
old_name = material.get("name")
if old_name not in material_map:
material_map[old_name] = f"material_{material_count}"
texture.set("name", material_map[old_name])
material_count += 1

for i, material in enumerate(materials):
old_name = material.get("name")
if old_name not in material_map:
material_map[old_name] = f"material_{material_count}"
material_count += 1
material.set("name", material_map[old_name])

texture = material.get("texture")
if texture:
material.set("texture", material_map[old_name])
for i, geom in enumerate(included_root.findall(".//geom")):
# material.set('name', f'material_{visual_count}')
material = geom.get("material")
if material:
geom.set("material", material_map[material])
class_ = geom.get("class")
if class_:
geom.set("class", material_map[class_])
# Replace include element with modified content
# include.clear()
# include.tag = included_root.tag
# include.attrib = included_root.attrib
# include.extend(included_root)
# Step 4: Write the modified XML to a file
print(saved_mjc_path)
included_tree.write(saved_mjc_path)
# Insert include tags for each filepath
include = ET.SubElement(main_root, "include", file=str(saved_mjc_path))
include.tail = "\n"
# Create the tree
return main_root
material.set("name", material_map[old_name])
texture = material.get("texture")
if texture:
material.set("texture", material_map[old_name])
for i, geom in enumerate(included_root.findall(".//geom")):
material = geom.get("material")
if material:
geom.set("material", material_map[material])
class_ = geom.get("class")
if class_:
geom.set("class", material_map[class_])

def write_modified_xml(self, included_tree, saved_mjc_path):
included_tree.write(saved_mjc_path)

def insert_include_tags(self, main_root, saved_mjc_path):
include = ET.SubElement(main_root, "include", file=str(saved_mjc_path))
include.tail = "\n"

def create_tree(self, main_root):
tree = ET.ElementTree(main_root)
Expand Down

0 comments on commit 554a03d

Please sign in to comment.