Skip to content

Commit

Permalink
pose viz cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Jan 22, 2025
1 parent acf6211 commit 30eb871
Show file tree
Hide file tree
Showing 10 changed files with 208 additions and 64 deletions.
Binary file added docs/_static/img/pose_plotter_cuda.mp4
Binary file not shown.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
# Setup configuration
setuptools.setup(
name="simba-uw-tf-dev",
version="2.5.4",
version="2.5.5",
author="Simon Nilsson, Jia Jie Choong, Sophia Hwang",
author_email="[email protected]",
description="Toolkit for computer classification and analysis of behaviors in experimental animals",
Expand Down
126 changes: 118 additions & 8 deletions simba/data_processors/cuda/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from numba import cuda
from numba.core.errors import NumbaPerformanceWarning

from simba.data_processors.cuda.utils import (_cuda_luminance_pixel_to_grey,
_cuda_mse, _is_cuda_available)
from simba.data_processors.cuda.utils import (_cuda_luminance_pixel_to_grey, _cuda_mse, _is_cuda_available)
from simba.mixins.plotting_mixin import PlottingMixin
from simba.mixins.image_mixin import ImageMixin
from simba.utils.checks import (check_file_exist_and_readable, check_float,
check_if_dir_exists,
Expand All @@ -37,14 +37,11 @@
check_that_hhmmss_start_is_before_end,
check_valid_array, check_valid_boolean,
is_video_color)
from simba.utils.data import find_frame_numbers_from_time_stamp
from simba.utils.data import find_frame_numbers_from_time_stamp, create_color_palette
from simba.utils.enums import Formats
from simba.utils.errors import (FFMPEGCodecGPUError, InvalidInputError,
SimBAGPUError)
from simba.utils.errors import (FFMPEGCodecGPUError, InvalidInputError, SimBAGPUError, FrameRangeError)
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import (
check_if_hhmmss_timestamp_is_valid_part_of_video, get_fn_ext,
get_video_meta_data, read_img_batch_from_video_gpu)
from simba.utils.read_write import (check_if_hhmmss_timestamp_is_valid_part_of_video, get_fn_ext, get_video_meta_data, read_img_batch_from_video_gpu, read_df)

warnings.simplefilter('ignore', category=NumbaPerformanceWarning)

Expand Down Expand Up @@ -1191,6 +1188,119 @@ def bg_subtraction_cupy(video_path: Union[str, os.PathLike],
timer.stop_timer()
stdout_success(msg=f'Video saved at {save_path}', elapsed_time=timer.elapsed_time_str)



@cuda.jit(max_registers=None)
def _pose_plot_kernel(imgs, data, circle_size, resolution, colors):
bp_n, img_n = cuda.grid(2)
if img_n < 0 or img_n > (imgs.shape[0] -1):
return
if bp_n < 0 or bp_n > (data[0].shape[0] -1):
return

img, bp_loc, color = imgs[img_n], data[img_n][bp_n], colors[bp_n]
for x1 in range(bp_loc[0]-circle_size[0], bp_loc[0]+circle_size[0]):
for y1 in range(bp_loc[1]-circle_size[0], bp_loc[1]+circle_size[0]):
if (x1 > 0) and (x1 < resolution[0]):
if (y1 > 0) and (y1 < resolution[1]):
b = (x1 - bp_loc[0]) ** 2
c = (y1 - bp_loc[1]) ** 2
if (b + c) < (circle_size[0] ** 2):
imgs[img_n][y1][x1][0] = int(color[0])
imgs[img_n][y1][x1][1] = int(color[1])
imgs[img_n][y1][x1][2] = int(color[2])


def pose_plotter(data: Union[str, os.PathLike, np.ndarray],
video_path: Union[str, os.PathLike],
save_path: Union[str, os.PathLike],
circle_size: Optional[int] = None,
colors: Optional[str] = 'Set1',
batch_size: int = 1500,
verbose: bool = True) -> None:

"""
Creates a video overlaying pose-estimation data on frames from a given video using GPU acceleration.
.. video:: _static/img/pose_plotter_cuda.mp4
:width: 800
:autoplay:
:loop:
.. seealso::
For CPU based methods, see :func:`~simba.plotting.path_plotter.PathPlotterSingleCore` and :func:`~simba.plotting.path_plotter_mp.PathPlotterMulticore`.
:param Union[str, os.PathLike, np.ndarray] data: Path to a CSV file with pose-estimation data or a 3d numpy array (n_images, n_bodyparts, 2) with pose-estimated locations.
:param Union[str, os.PathLike] video_path: Path to a video file where the ``data`` has been pose-estimated.
:param Union[str, os.PathLike] save_path: Location where to store the output visualization.
:param Optional[int] circle_size: The size of the circles representing the location of the pose-estimated locations. If None, the optimal size will be inferred as a 100th of the max(resultion_w, h).
:param int batch_size: The number of frames to process concurrently on the GPU. Default: 1500. Increase of host and device RAM allows it to improve runtime.
:example:
>>> DATA_PATH = "/mnt/c/troubleshooting/mitra/project_folder/csv/outlier_corrected_movement_location/501_MA142_Gi_CNO_0514.csv"
>>> VIDEO_PATH = "/mnt/c/troubleshooting/mitra/project_folder/videos/501_MA142_Gi_CNO_0514.mp4"
>>> SAVE_PATH = "/mnt/c/troubleshooting/mitra/project_folder/frames/output/pose_ex/test.mp4"
>>> pose_plotter(data=DATA_PATH, video_path=VIDEO_PATH, save_path=SAVE_PATH, circle_size=10)
"""

THREADS_PER_BLOCK = (32, 32, 1)
if isinstance(data, str):
check_file_exist_and_readable(file_path=data)
df = read_df(file_path=data, file_type='csv')
cols = [x for x in df.columns if not x.lower().endswith('_p')]
data = df[cols].values
data = np.ascontiguousarray(data.reshape(data.shape[0], int(data.shape[1] / 2), 2).astype(np.int32))
elif isinstance(data, np.ndarray):
check_valid_array(data=data, source=pose_plotter.__name__, accepted_ndims=(3,), accepted_dtypes=Formats.NUMERIC_DTYPES.value)

check_int(name=f'{pose_plotter.__name__} batch_size', value=batch_size, min_value=1)
check_int(name=f'{pose_plotter.__name__} circle_size', value=circle_size, min_value=1)
check_valid_boolean(value=[verbose], source=f'{pose_plotter.__name__} verbose')
video_meta_data = get_video_meta_data(video_path=video_path)
n, w, h = video_meta_data['frame_count'], video_meta_data['width'], video_meta_data['height']
check_if_dir_exists(in_dir=os.path.dirname(save_path))
if data.shape[0] != video_meta_data['frame_count']:
raise FrameRangeError(msg=f'The data contains {data.shape[0]} frames while the video contains {video_meta_data["frame_count"]} frames')

if circle_size is None:
circle_size = np.array([PlottingMixin().get_optimal_circle_size(frame_size=(w, h))]).astype(np.int32)
else:
circle_size = np.array([circle_size]).astype(np.int32)
fourcc = cv2.VideoWriter_fourcc(*Formats.MP4_CODEC.value)
video_writer = cv2.VideoWriter(save_path, fourcc, video_meta_data['fps'], (w, h))

colors = np.array(create_color_palette(pallete_name=colors, increments=data[0].shape[0])).astype(np.int32)
circle_size_dev = cuda.to_device(circle_size)
colors_dev = cuda.to_device(colors)
resolution_dev = cuda.to_device(np.array([video_meta_data['width'], video_meta_data['height']]))

img_dev = cuda.device_array((batch_size, h, w, 3), dtype=np.int32)
data_dev = cuda.device_array((batch_size, data.shape[1], 2), dtype=np.int32)

total_timer = SimbaTimer(start=True)
for batch_cnt, l in enumerate(range(0, data.shape[0], batch_size)):
r = min(data.shape[0], l + batch_size - 1)
if verbose:
print(f'Processing frames {l}-{r} of {data.shape[0]} frames (video: {video_meta_data["video_name"]})...')
batch_data = data[l:r + 1]
batch_n = batch_data.shape[0]
batch_frms = read_img_batch_from_video_gpu(video_path=video_path, start_frm=l, end_frm=r, out_format='array').astype(np.int32)
grid_x = math.ceil(batch_frms.shape[0] / THREADS_PER_BLOCK[0])
grid_z = math.ceil(batch_n / THREADS_PER_BLOCK[2])
bpg = (grid_x, grid_z)
img_dev[:batch_n].copy_to_device(batch_frms[:batch_n])
data_dev[:batch_n] = cuda.to_device(batch_data[:batch_n])

_pose_plot_kernel[bpg, THREADS_PER_BLOCK](img_dev, data_dev, circle_size_dev, resolution_dev, colors_dev)
batch_frms = img_dev.copy_to_host()
for img_idx in range(0, batch_n):
video_writer.write(batch_frms[img_idx].astype(np.uint8))

video_writer.release()
total_timer.stop_timer()
if verbose:
stdout_success(msg=f'Pose-estimation video saved at {save_path}.', elapsed_time=total_timer.elapsed_time_str)

#
# #from simba.data_processors.cuda.image import create_average_frm_cupy
# SAVE_PATH = "/mnt/c/Users/sroni/Downloads/bg_remove_nb/bg_removed_ex_7.mp4"
Expand Down
1 change: 0 additions & 1 deletion simba/roi_tools/ROI_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ def run(self):
if event_pose.shape[0] > 1:
distance, velocity = (FeatureExtractionSupplemental.distance_and_velocity(x=event_pose, fps=self.fps, pixels_per_mm=pix_per_mm, centimeters=True))
distances.append(distance)
print(distances, velocity)
velocities.append(velocity)
self.movements_df.loc[len(self.movements_df)] = [video_name, animal_name, roi_name, "Movement (cm)", sum(distances)]
self.movements_df.loc[len(self.movements_df)] = [video_name, animal_name, roi_name, "Average velocity (cm/s)", np.average(velocities)]
Expand Down
4 changes: 2 additions & 2 deletions simba/roi_tools/ROI_define.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self,
self.video_path = video_path
_, self.file_name, self.file_ext = get_fn_ext(self.video_path)
self.other_video_paths = list(find_all_videos_in_directory(directory=self.video_dir, as_dict=True).values())
self.other_video_paths.remove(video_path)
self.other_video_paths = [x for x in self.other_video_paths if x != video_path]
self.other_video_file_names = []
for video in self.other_video_paths:
self.other_video_file_names.append(os.path.basename(video))
Expand Down Expand Up @@ -94,7 +94,7 @@ def __init__(self,
self.img_no = 1
self.duplicate_jump_size = 20
self.click_sens = 10
self.text_size, _, _ = PlottingMixin().get_optimal_font_scales(text='TEN DIGITS', accepted_px_width=int(self.video_info['Resolution_width']/10), accepted_px_height=int(self.video_info['Resolution_height']/10), text_thickness=2, font=cv2.FONT_HERSHEY_SIMPLEX)
self.text_size, _, _ = PlottingMixin().get_optimal_font_scales(text='TEN DIGITS', accepted_px_width=int(self.video_info['Resolution_width'].iloc[0]/10), accepted_px_height=int(self.video_info['Resolution_height'].iloc[0]/10), text_thickness=2, font=cv2.FONT_HERSHEY_SIMPLEX)
self.text_thickness = 2
self.line_type = -1
self.named_shape_colors = get_color_dict()
Expand Down
41 changes: 8 additions & 33 deletions simba/roi_tools/ROI_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,11 @@ def initiate_x_y_callback():

def initiate_draw(self, draw_dict):
self.draw_info = draw_dict
if self.draw_info["Shape_type"] is "rectangle":
if self.draw_info["Shape_type"] == "rectangle":
self.draw_rectangle(self.draw_info)
if self.draw_info["Shape_type"] is "circle":
if self.draw_info["Shape_type"] == "circle":
self.draw_circle(self.draw_info)
if self.draw_info["Shape_type"] is "polygon":
if self.draw_info["Shape_type"] == "polygon":
self.draw_polygon()

self.all_shape_names = []
Expand Down Expand Up @@ -609,42 +609,17 @@ def insert_all_ROIs_into_image(
self.no_shapes += 1
pts = np.array(pg["vertices"]).reshape((-1, 1, 2))
try:
cv2.polylines(
self.working_frame,
[pts],
True,
pg["Color BGR"],
int(pg["Thickness"]),
lineType=self.line_type,
)
cv2.polylines(self.working_frame, [pts], True, pg["Color BGR"], int(pg["Thickness"]), lineType=self.line_type)
except cv2.error as e:
cv2.polylines(
self.working_frame,
[pts],
True,
pg["Color BGR"],
int(pg["Thickness"]),
lineType=4,
)
cv2.polylines(self.working_frame, [pts], True, pg["Color BGR"], int(pg["Thickness"]), lineType=4)

if ROI_ear_tags is True:
for p in pg["Tags"]:
try:
cv2.circle(
self.working_frame,
p,
int(pg["Ear_tag_size"]),
self.colors[pg["Color name"]],
-1,
)
cv2.circle(self.working_frame, p, int(pg["Ear_tag_size"]), self.colors[pg["Color name"]], -1)
except:
cv2.circle(
self.working_frame,
tuple(pg["Tags"][p]),
int(pg["Ear_tag_size"]),
self.colors[pg["Color name"]],
-1,
)
cv2.circle(self.working_frame, tuple(pg["Tags"][p]), int(pg["Ear_tag_size"]), self.colors[pg["Color name"]], -1)

if show_size_info is True:
area_cm = self.polygon_size_dict["Polygons"][pg["Name"]]["area_cm"]
self.working_frame = PlottingMixin().put_text(img=self.working_frame, text=f'AREA: {str(area_cm)}', pos=(pg["Center_X"], pg["Center_Y"]), font_size=self.text_size, font_thickness=self.text_thickness, font=cv2.FONT_HERSHEY_SIMPLEX, text_color=self.colors[pg["Color name"]])
Expand Down
7 changes: 4 additions & 3 deletions simba/roi_tools/ROI_size_calculations.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
__author__ = "Simon Nilsson"
__email__ = "[email protected]"



import math
from typing import Dict, Optional, Tuple, Union

import numpy as np
from scipy.spatial import ConvexHull
from shapely.geometry import Polygon

from simba.utils.checks import check_int, check_str, check_valid_tuple
from simba.utils.enums import Formats
Expand Down Expand Up @@ -105,7 +104,9 @@ def get_half_circle_vertices(center: Tuple[int, int],
a = np.linspace(np.pi, 2 * np.pi, n_points)
x, y = x_c + radius * np.cos(a), y_c + radius * np.sin(a)
vertices = np.column_stack((x, y)).astype(np.int32)
vertices_dict = {"Center_tag": (center[0], center[1])}
shape_center = np.array(Polygon(vertices).centroid.coords)[0].astype(np.int32)
vertices_dict = {"Center_tag": (shape_center[0], shape_center[1])}

for tag_id in range(vertices.shape[0]):
vertices_dict[f"Tag_{tag_id}"] = (vertices[tag_id][0], vertices[tag_id][1])
return (np.array(vertices).astype("int32"), vertices_dict)
Expand Down
40 changes: 37 additions & 3 deletions simba/third_party_label_appenders/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
from simba.mixins.geometry_mixin import GeometryMixin
from simba.mixins.image_mixin import ImageMixin
from simba.utils.checks import (check_file_exist_and_readable, check_float,
check_if_dir_exists,
check_if_dir_exists, check_str,
check_if_keys_exist_in_dict,
check_if_valid_img, check_int,
check_valid_array, check_valid_boolean,
check_valid_dict, check_valid_tuple)
from simba.utils.enums import Formats
from simba.utils.enums import Formats, Options
from simba.utils.errors import InvalidInputError, NoFilesFoundError
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import (copy_files_to_directory,
Expand Down Expand Up @@ -224,7 +224,6 @@ def b64_to_arr(img_b64) -> np.ndarray:
return img_arr



def arr_to_b64(x: np.ndarray) -> str:
"""
Helper to convert image in array format to an image in byte string format
Expand Down Expand Up @@ -824,6 +823,41 @@ def yolo_obb_data_to_bounding_box(center_x: float, center_y: float, width: float
return box.astype(np.int32)


def labelme_to_img_dir(labelme_dir: Union[str, os.PathLike],
img_dir: Union[str, os.PathLike],
img_format: str = 'png') -> None:

"""
Given a directory of labelme JSON annotations, extract the images from the JSONs in b64 format and store them as images in a directory
:param labelme_dir: Directory containing labelme json annotations.
:param img_dir: Directory where to store the images.
:param img_format: Format in which to save the images.
:return: None
:example:
>>> labelme_to_img_dir(img_dir=r"C:\troubleshooting\coco_data\labels\train_images", labelme_dir=r'C:\troubleshooting\coco_data\labels\train_')
"""

timer = SimbaTimer(start=True)
check_if_dir_exists(in_dir=labelme_dir)
check_if_dir_exists(in_dir=img_dir)
check_str(name=f'{labelme_to_img_dir.__name__} img_format', value=f'.{img_format}', options=Options.ALL_IMAGE_FORMAT_OPTIONS.value)
img_format = f'.{img_format}'
annotation_paths = find_files_of_filetypes_in_directory(directory=labelme_dir, extensions=['.json'], raise_error=True)
for file_cnt, annot_path in enumerate(annotation_paths):
with open(annot_path) as f: annot_data = json.load(f)
check_if_keys_exist_in_dict(data=annot_data, key=['shapes', 'imageData', 'imagePath'], name=annot_path)
img_name = os.path.basename(annot_data['imagePath'])
print(f'Reading image {file_cnt+1}/{len(annotation_paths)} ({img_name})...')
img = b64_to_arr(annot_data['imageData'])
save_path = os.path.join(img_dir, f'{img_name}{img_format}')
cv2.imwrite(filename=save_path, img=img)
timer.stop_timer()
stdout_success(msg=f'{len(annotation_paths)} images saved in {img_dir}.', elapsed_time=timer.elapsed_time_str)


#labelme_to_img_dir(img_dir=r"C:\troubleshooting\coco_data\labels\train_images", labelme_dir=r'C:\troubleshooting\coco_data\labels\train_')



Expand Down
Loading

0 comments on commit 30eb871

Please sign in to comment.