nova_rerun_bridge.robot_visualizer

  1import re
  2
  3import numpy as np
  4import rerun as rr
  5import trimesh
  6from scipy.spatial.transform import Rotation
  7
  8from nova.api import models
  9from nova_rerun_bridge import colors
 10from nova_rerun_bridge.conversion_helpers import normalize_pose
 11from nova_rerun_bridge.dh_robot import DHRobot
 12from nova_rerun_bridge.helper_scripts.download_models import get_project_root
 13from nova_rerun_bridge.hull_visualizer import HullVisualizer
 14
 15
 16def get_model_path(model_name: str) -> str:
 17    """Get absolute path to model file in project directory"""
 18    return str(get_project_root() / "models" / f"{model_name}.glb")
 19
 20
 21class RobotVisualizer:
 22    def __init__(
 23        self,
 24        robot: DHRobot,
 25        robot_model_geometries,
 26        tcp_geometries,
 27        static_transform: bool = True,
 28        base_entity_path: str = "robot",
 29        albedo_factor: list = [255, 255, 255],
 30        collision_link_chain=None,
 31        collision_tcp=None,
 32        model_from_controller="",
 33    ):
 34        """
 35        :param robot: DHRobot instance
 36        :param robot_model_geometries: List of geometries for each link
 37        :param tcp_geometries: TCP geometries (similar structure to link geometries)
 38        :param static_transform: If True, transforms are logged as static, else temporal.
 39        :param base_entity_path: A base path prefix for logging the entities (e.g. motion group name)
 40        :param albedo_factor: A list representing the RGB values [R, G, B] to apply as the albedo factor.
 41        :param glb_path: Path to the GLB file for the robot model.
 42        """
 43        self.robot = robot
 44        self.link_geometries: dict[int, list[trimesh.Trimesh]] = {}
 45        self.tcp_geometries = tcp_geometries
 46        self.logged_meshes: set[str] = set()
 47        self.static_transform = static_transform
 48        self.base_entity_path = base_entity_path.rstrip("/")
 49        self.albedo_factor = albedo_factor
 50        self.mesh_loaded = False
 51        self.collision_link_geometries = {}
 52        self.collision_tcp_geometries = collision_tcp
 53
 54        # This will hold the names of discovered joints (e.g. ["robot_J00", "robot_J01", ...])
 55        self.joint_names: list[str] = []
 56        self.layer_nodes_dict: dict[str, list[str]] = {}
 57        self.parent_nodes_dict: dict[str, str] = {}
 58
 59        # load mesh
 60        try:
 61            glb_path = get_model_path(model_from_controller)
 62            self.scene = trimesh.load_scene(glb_path, file_type="glb")
 63            self.mesh_loaded = True
 64            self.edge_data = self.scene.graph.transforms.edge_data
 65
 66            # After loading, auto-discover any child nodes that match *_J0n
 67            self.discover_joints()
 68        except Exception as e:
 69            print(f"Failed to load mesh: {e}")
 70
 71        # Group geometries by link
 72        for gm in robot_model_geometries:
 73            self.link_geometries.setdefault(gm.link_index, []).append(gm.geometry)
 74
 75        # Group geometries by link
 76        self.collision_link_geometries = collision_link_chain
 77
 78    def discover_joints(self):
 79        """
 80        Find all child node names that contain '_J0' followed by digits or '_FLG'.
 81        Store joints with their parent nodes and print layer information.
 82        """
 83        joint_pattern = re.compile(r"_J0(\d+)")
 84        flg_pattern = re.compile(r"_FLG")
 85        matches = []
 86        flg_nodes = []
 87        joint_parents = {}  # Store parent for each joint/FLG
 88
 89        for (parent, child), data in self.edge_data.items():
 90            # Check for joints
 91            joint_match = joint_pattern.search(child)
 92            if joint_match:
 93                j_idx = int(joint_match.group(1))
 94                matches.append((j_idx, child))
 95                joint_parents[child] = parent
 96
 97            # Check for FLG
 98            flg_match = flg_pattern.search(child)
 99            if flg_match:
100                flg_nodes.append(child)
101                joint_parents[child] = parent
102
103        matches.sort(key=lambda x: x[0])
104        self.joint_names = [name for _, name in matches] + flg_nodes
105
106        # print("Discovered nodes:", self.joint_names)
107        # Print layer information for each joint
108        for joint in self.joint_names:
109            self.get_nodes_on_same_layer(joint_parents[joint], joint)
110            # print(f"\nNodes on same layer as {joint}:")
111            # print(f"Parent node: {joint_parents[joint]}")
112            # print(f"Layer nodes: {same_layer_nodes}")
113
114    def get_nodes_on_same_layer(self, parent_node, joint):
115        """
116        Find nodes on same layer and only add descendants of link nodes.
117        """
118        same_layer = []
119        # First get immediate layer nodes
120        for (parent, child), data in self.edge_data.items():
121            if parent == parent_node:
122                if child == joint:
123                    continue
124                if "geometry" in data:
125                    same_layer.append(data["geometry"])
126                    self.parent_nodes_dict[data["geometry"]] = child
127
128                # Get all descendants for this link
129                parentChild = child
130                stack = [child]
131                while stack:
132                    current = stack.pop()
133                    for (p, c), data in self.edge_data.items():
134                        if p == current:
135                            if "geometry" in data:
136                                same_layer.append(data["geometry"])
137                                self.parent_nodes_dict[data["geometry"]] = parentChild
138                            stack.append(c)
139
140        self.layer_nodes_dict[joint] = same_layer
141        return same_layer
142
    def geometry_pose_to_matrix(self, init_pose: models.PlannerPose):
        """Convert a geometry's initial pose into a 4x4 homogeneous matrix (delegates to the robot)."""
        return self.robot.pose_to_matrix(init_pose)
145
146    def compute_forward_kinematics(self, joint_values):
147        """Compute link transforms using the robot's methods."""
148        accumulated = self.robot.pose_to_matrix(self.robot.mounting)
149        transforms = [accumulated.copy()]
150        for dh_param, joint_rot in zip(self.robot.dh_parameters, joint_values.joints, strict=False):
151            transform = self.robot.dh_transform(dh_param, joint_rot)
152            accumulated = accumulated @ transform
153            transforms.append(accumulated.copy())
154        return transforms
155
156    def rotation_matrix_to_axis_angle(self, Rm):
157        """Use scipy for cleaner axis-angle extraction."""
158        rot = Rotation.from_matrix(Rm)
159        angle = rot.magnitude()
160        axis = rot.as_rotvec() / angle if angle > 1e-8 else np.array([1.0, 0.0, 0.0])
161        return axis, angle
162
163    def gamma_lift_single_color(self, color: np.ndarray, gamma: float = 0.8) -> np.ndarray:
164        """
165        Apply gamma correction to a single RGBA color in-place.
166        color: shape (4,) with [R, G, B, A] in 0..255, dtype=uint8
167        gamma: < 1.0 brightens midtones, > 1.0 darkens them.
168        """
169        rgb_float = color[:3].astype(np.float32) / 255.0
170        rgb_float = np.power(rgb_float, gamma)
171        color[:3] = (rgb_float * 255.0).astype(np.uint8)
172
173        return color
174
175    def get_transform_matrix(self):
176        """
177        Creates a transformation matrix that converts from glTF's right-handed Y-up
178        coordinate system to Rerun's right-handed Z-up coordinate system.
179
180        Returns:
181            np.ndarray: A 4x4 transformation matrix
182        """
183        # Convert from glTF's Y-up to Rerun's Z-up coordinate system
184        return np.array(
185            [
186                [1.0, 0.0, 0.0, 0.0],  # X stays the same
187                [0.0, 0.0, -1.0, 0.0],  # Y becomes -Z
188                [0.0, 1.0, 0.0, 0.0],  # Z becomes Y
189                [0.0, 0.0, 0.0, 1.0],  # Homogeneous coordinate
190            ]
191        )
192
    def init_mesh(self, entity_path: str, geom, joint_name):
        """
        Log one visual mesh geometry to Rerun exactly once per entity path.

        The mesh is transformed into the joint-local frame (positions scaled
        from meters to millimeters) and logged as an ``rr.Mesh3D``. Subsequent
        calls for the same ``entity_path`` are no-ops via ``self.logged_meshes``.

        :param entity_path: Rerun entity path under which to log the mesh.
        :param geom: Trimesh geometry; ``geom.metadata["node"]`` must name its
            scene-graph node (set by the callers in log_robot_geometry*).
        :param joint_name: Scene-graph joint node this geometry belongs to.
        """

        if entity_path not in self.logged_meshes:
            # Geometries whose node was never attributed to a link layer
            # (see get_nodes_on_same_layer) are skipped.
            if geom.metadata.get("node") not in self.parent_nodes_dict:
                return

            base_transform = np.eye(4)
            # if the dh parameters are not at 0,0,0 from the mesh we have to move the first mesh joint
            if "J00" in joint_name:
                base_transform_, _ = self.scene.graph.get(frame_to=joint_name)
                base_transform = base_transform_.copy()
            base_transform[:3, 3] *= 1000

            # if the mesh has the pivot not in the center, we need to adjust the transform
            cumulative_transform, _ = self.scene.graph.get(
                frame_to=self.parent_nodes_dict[geom.metadata.get("node")]
            )
            ctransform = cumulative_transform.copy()

            # scale positions to mm
            ctransform[:3, 3] *= 1000

            # scale mesh to mm (uniform 1000x scale on the rotation block)
            transform = base_transform @ ctransform
            mesh_scale_matrix = np.eye(4)
            mesh_scale_matrix[:3, :3] *= 1000
            transform = transform @ mesh_scale_matrix
            # Work on a copy so the shared scene geometry stays untouched.
            transformed_mesh = geom.copy()

            transformed_mesh.apply_transform(transform)

            # Convert texture/material visuals to plain vertex colors.
            if transformed_mesh.visual is not None:
                transformed_mesh.visual = transformed_mesh.visual.to_color()

            vertex_colors = None
            if transformed_mesh.visual and hasattr(transformed_mesh.visual, "vertex_colors"):
                vertex_colors = transformed_mesh.visual.vertex_colors

            rr.log(
                entity_path,
                rr.Mesh3D(
                    vertex_positions=transformed_mesh.vertices,
                    triangle_indices=transformed_mesh.faces,
                    vertex_normals=getattr(transformed_mesh, "vertex_normals", None),
                    # Brighten the first vertex color (gamma < 1 lifts midtones);
                    # NOTE(review): this uses one color for the whole mesh.
                    albedo_factor=self.gamma_lift_single_color(vertex_colors, gamma=0.5)
                    if vertex_colors is not None
                    else None,
                ),
            )

            self.logged_meshes.add(entity_path)
245
    def init_collision_geometry(
        self, entity_path: str, collider: models.Collider, pose: models.PlannerPose
    ):
        """
        Log one collision shape (sphere, box, capsule, or convex hull) to Rerun,
        once per entity path.

        :param entity_path: Rerun entity path for the collider.
        :param collider: Collider whose ``shape.actual_instance`` selects the branch.
        :param pose: Pose used to place sphere/box centers and the capsule offset.
        """
        if entity_path in self.logged_meshes:
            return

        if isinstance(collider.shape.actual_instance, models.Sphere2):
            rr.log(
                f"{entity_path}",
                rr.Ellipsoids3D(
                    # A sphere is an ellipsoid with three equal radii.
                    radii=[
                        collider.shape.actual_instance.radius,
                        collider.shape.actual_instance.radius,
                        collider.shape.actual_instance.radius,
                    ],
                    centers=[[pose.position.x, pose.position.y, pose.position.z]]
                    if pose.position
                    else [0, 0, 0],
                    colors=[(221, 193, 193, 255)],
                ),
            )

        elif isinstance(collider.shape.actual_instance, models.Box2):
            rr.log(
                f"{entity_path}",
                rr.Boxes3D(
                    centers=[[pose.position.x, pose.position.y, pose.position.z]]
                    if pose.position
                    else [0, 0, 0],
                    sizes=[
                        collider.shape.actual_instance.size_x,
                        collider.shape.actual_instance.size_y,
                        collider.shape.actual_instance.size_z,
                    ],
                    colors=[(221, 193, 193, 255)],
                ),
            )

        elif isinstance(collider.shape.actual_instance, models.Capsule2):
            height = collider.shape.actual_instance.cylinder_height
            radius = collider.shape.actual_instance.radius

            # Generate a coarse trimesh capsule (low segment counts keep it cheap).
            capsule = trimesh.creation.capsule(height=height, radius=radius, count=[6, 8])

            # Extract vertices for hull-outline visualization.
            vertices = np.array(capsule.vertices)

            # Transform vertices to world position; the capsule is shifted down
            # by half its height so the pose marks its center.
            transform = np.eye(4)
            if pose.position:
                transform[:3, 3] = [pose.position.x, pose.position.y, pose.position.z - height / 2]
            else:
                transform[:3, 3] = [0, 0, -height / 2]

            if collider.pose and collider.pose.orientation:
                # NOTE(review): indexes the orientation as [x, y, z, w] —
                # confirm the orientation type supports positional indexing
                # and matches scipy's quaternion order.
                rot_mat = Rotation.from_quat(
                    [
                        collider.pose.orientation[0],
                        collider.pose.orientation[1],
                        collider.pose.orientation[2],
                        collider.pose.orientation[3],
                    ]
                )
                transform[:3, :3] = rot_mat.as_matrix()

            vertices = np.array([transform @ np.append(v, 1) for v in vertices])[:, :3]

            polygons = HullVisualizer.compute_hull_outlines_from_points(vertices)

            if polygons:
                line_segments = [p.tolist() for p in polygons]
                # NOTE(review): logged static=True unlike the sphere/box branches.
                rr.log(
                    f"{entity_path}",
                    rr.LineStrips3D(
                        line_segments,
                        radii=rr.Radius.ui_points(0.75),
                        colors=[[221, 193, 193, 255]],
                    ),
                    static=True,
                )

        elif isinstance(collider.shape.actual_instance, models.ConvexHull2):
            polygons = HullVisualizer.compute_hull_outlines_from_points(
                collider.shape.actual_instance.vertices
            )

            if polygons:
                line_segments = [p.tolist() for p in polygons]
                # Outline of the hull.
                rr.log(
                    f"{entity_path}",
                    rr.LineStrips3D(
                        line_segments, radii=rr.Radius.ui_points(1.5), colors=[colors.colors[2]]
                    ),
                    static=True,
                )

                # Solid hull surface on the same entity path.
                vertices, triangles, normals = HullVisualizer.compute_hull_mesh(polygons)

                rr.log(
                    f"{entity_path}",
                    rr.Mesh3D(
                        vertex_positions=vertices,
                        triangle_indices=triangles,
                        vertex_normals=normals,
                        albedo_factor=colors.colors[0],
                    ),
                    static=True,
                )

        # Marked as logged even when no branch matched, so unknown shapes are
        # not retried on every call.
        self.logged_meshes.add(entity_path)
357
358    def init_geometry(self, entity_path: str, capsule):
359        """Generic method to log a single geometry, either capsule or box."""
360
361        if entity_path in self.logged_meshes:
362            return
363
364        if capsule:
365            radius = capsule.radius
366            height = capsule.cylinder_height
367
368            # Slightly shrink the capsule if static to reduce z-fighting
369            if self.static_transform:
370                radius *= 0.99
371                height *= 0.99
372
373            # Create capsule and retrieve normals
374            cap_mesh = trimesh.creation.capsule(radius=radius, height=height)
375            vertex_normals = cap_mesh.vertex_normals.tolist()
376
377            rr.log(
378                entity_path,
379                rr.Mesh3D(
380                    vertex_positions=cap_mesh.vertices.tolist(),
381                    triangle_indices=cap_mesh.faces.tolist(),
382                    vertex_normals=vertex_normals,
383                    albedo_factor=self.albedo_factor,
384                ),
385            )
386            self.logged_meshes.add(entity_path)
387        else:
388            # fallback to a box
389            rr.log(entity_path, rr.Boxes3D(half_sizes=[[50, 50, 50]]))
390            self.logged_meshes.add(entity_path)
391
    def log_robot_geometry(self, joint_position):
        """
        Log the pose of every robot geometry (visual meshes, safety link and
        TCP capsules) for a single joint configuration.

        :param joint_position: Object with a ``joints`` sequence of joint
            values, fed to :meth:`compute_forward_kinematics`.
        """
        transforms = self.compute_forward_kinematics(joint_position)

        def log_geometry(entity_path, transform):
            # Publish the pose as translation + axis-angle rotation.
            translation = transform[:3, 3]
            Rm = transform[:3, :3]
            axis, angle = self.rotation_matrix_to_axis_angle(Rm)
            rr.log(
                entity_path,
                rr.InstancePoses3D(
                    translations=[translation.tolist()],
                    rotation_axis_angles=[
                        rr.RotationAxisAngle(axis=axis.tolist(), angle=float(angle))
                    ],
                ),
                static=self.static_transform,
            )

        # Log robot joint geometries (visual meshes from the GLB scene).
        if self.mesh_loaded:
            for link_index, joint_name in enumerate(self.joint_names):
                link_transform = transforms[link_index]

                # Get nodes on same layer using dictionary
                same_layer_nodes = self.layer_nodes_dict.get(joint_name)
                if not same_layer_nodes:
                    continue

                filtered_geoms = []
                for node_name in same_layer_nodes:
                    if node_name in self.scene.geometry:
                        geom = self.scene.geometry[node_name]
                        # Add metadata that would normally come from dump
                        geom.metadata = {"node": node_name}
                        filtered_geoms.append(geom)

                for geom in filtered_geoms:
                    entity_path = f"{self.base_entity_path}/visual/links/link_{link_index}/mesh/{geom.metadata.get('node')}"

                    # calculate the inverse transform to get the mesh in the correct position
                    cumulative_transform, _ = self.scene.graph.get(frame_to=joint_name)
                    ctransform = cumulative_transform.copy()
                    inverse_transform = np.linalg.inv(ctransform)

                    # DH theta is rotated, rotate mesh around z in direction of theta
                    rotation_matrix_z_4x4 = np.eye(4)
                    if len(self.robot.dh_parameters) > link_index:
                        rotation_z_minus_90 = Rotation.from_euler(
                            "z", self.robot.dh_parameters[link_index].theta, degrees=False
                        ).as_matrix()
                        rotation_matrix_z_4x4[:3, :3] = rotation_z_minus_90

                    # scale positions to mm
                    inverse_transform[:3, 3] *= 1000

                    # glTF Y-up -> Rerun Z-up conversion.
                    root_transform = self.get_transform_matrix()

                    transform = root_transform @ inverse_transform

                    final_transform = link_transform @ rotation_matrix_z_4x4 @ transform

                    # Mesh data is logged once; only the pose is updated after.
                    self.init_mesh(entity_path, geom, joint_name)
                    log_geometry(entity_path, final_transform)

        # Log link geometries (safety capsules from the controller).
        for link_index, geometries in self.link_geometries.items():
            link_transform = transforms[link_index]
            for i, geom in enumerate(geometries):
                entity_path = f"{self.base_entity_path}/safety_from_controller/links/link_{link_index}/geometry_{i}"
                final_transform = link_transform @ self.geometry_pose_to_matrix(geom.init_pose)

                self.init_geometry(entity_path, geom.capsule)
                log_geometry(entity_path, final_transform)

        # Log TCP geometries (attached to the final frame).
        if self.tcp_geometries:
            tcp_transform = transforms[-1]  # the final frame transform
            for i, geom in enumerate(self.tcp_geometries):
                entity_path = f"{self.base_entity_path}/safety_from_controller/tcp/geometry_{i}"
                final_transform = tcp_transform @ self.geometry_pose_to_matrix(geom.init_pose)

                self.init_geometry(entity_path, geom.capsule)
                log_geometry(entity_path, final_transform)
475
    def log_robot_geometries(self, trajectory: list[models.TrajectorySample], times_column):
        """
        Log the robot geometries for each link and TCP as separate entities.

        Poses for the whole trajectory are collected per entity and sent in one
        batched ``rr.send_columns`` call per entity path.

        Args:
            trajectory (List[wb.models.TrajectorySample]): The list of trajectory sample points.
            times_column: The Rerun time column associated with the trajectory points.
        """
        link_positions = {}
        link_rotations = {}

        def collect_geometry_data(entity_path, transform):
            """Helper to collect geometry data for a given entity."""
            translation = transform[:3, 3].tolist()
            Rm = transform[:3, :3]
            axis, angle = self.rotation_matrix_to_axis_angle(Rm)
            if entity_path not in link_positions:
                link_positions[entity_path] = []
                link_rotations[entity_path] = []
            link_positions[entity_path].append(translation)
            link_rotations[entity_path].append(rr.RotationAxisAngle(axis=axis, angle=angle))

        for point in trajectory:
            transforms = self.compute_forward_kinematics(point.joint_position)

            # Log robot joint geometries (visual meshes from the GLB scene).
            if self.mesh_loaded:
                for link_index, joint_name in enumerate(self.joint_names):
                    # Guard: there may be more discovered nodes than transforms.
                    if link_index >= len(transforms):
                        break
                    link_transform = transforms[link_index]

                    # Get nodes on same layer using dictionary
                    same_layer_nodes = self.layer_nodes_dict.get(joint_name)
                    if not same_layer_nodes:
                        continue

                    filtered_geoms = []
                    for node_name in same_layer_nodes:
                        if node_name in self.scene.geometry:
                            geom = self.scene.geometry[node_name]
                            # Add metadata that would normally come from dump
                            geom.metadata = {"node": node_name}
                            filtered_geoms.append(geom)

                    for geom in filtered_geoms:
                        entity_path = f"{self.base_entity_path}/visual/links/link_{link_index}/mesh/{geom.metadata.get('node')}"

                        # calculate the inverse transform to get the mesh in the correct position
                        cumulative_transform, _ = self.scene.graph.get(frame_to=joint_name)
                        ctransform = cumulative_transform.copy()
                        inverse_transform = np.linalg.inv(ctransform)

                        # DH theta is rotated, rotate mesh around z in direction of theta
                        rotation_matrix_z_4x4 = np.eye(4)
                        if len(self.robot.dh_parameters) > link_index:
                            rotation_z_minus_90 = Rotation.from_euler(
                                "z", self.robot.dh_parameters[link_index].theta, degrees=False
                            ).as_matrix()
                            rotation_matrix_z_4x4[:3, :3] = rotation_z_minus_90

                        # scale positions to mm
                        inverse_transform[:3, 3] *= 1000

                        # glTF Y-up -> Rerun Z-up conversion.
                        root_transform = self.get_transform_matrix()

                        transform = root_transform @ inverse_transform

                        final_transform = link_transform @ rotation_matrix_z_4x4 @ transform

                        self.init_mesh(entity_path, geom, joint_name)
                        collect_geometry_data(entity_path, final_transform)

            # Collect data for link geometries
            for link_index, geometries in self.link_geometries.items():
                link_transform = transforms[link_index]
                for i, geom in enumerate(geometries):
                    entity_path = f"{self.base_entity_path}/safety_from_controller/links/link_{link_index}/geometry_{i}"
                    final_transform = link_transform @ self.geometry_pose_to_matrix(geom.init_pose)
                    self.init_geometry(entity_path, geom.capsule)
                    collect_geometry_data(entity_path, final_transform)

            # Collect data for TCP geometries
            if self.tcp_geometries:
                tcp_transform = transforms[-1]  # End-effector transform
                for i, geom in enumerate(self.tcp_geometries):
                    entity_path = f"{self.base_entity_path}/safety_from_controller/tcp/geometry_{i}"
                    final_transform = tcp_transform @ self.geometry_pose_to_matrix(geom.init_pose)
                    self.init_geometry(entity_path, geom.capsule)
                    collect_geometry_data(entity_path, final_transform)

            # Collect data for collision link geometries.
            # NOTE(review): collision_link_geometries defaults to None in
            # __init__ (collision_link_chain); this enumerate would raise
            # TypeError in that case — confirm callers always pass a chain.
            # It appears to be a mapping per link (geom_id -> collider).
            for link_index, geometries in enumerate(self.collision_link_geometries):
                link_transform = transforms[link_index]
                for i, geom_id in enumerate(geometries):
                    entity_path = f"{self.base_entity_path}/collision/links/link_{link_index}/geometry_{geom_id}"

                    pose = normalize_pose(geometries[geom_id].pose)

                    final_transform = link_transform @ self.geometry_pose_to_matrix(pose)
                    self.init_collision_geometry(entity_path, geometries[geom_id], pose)
                    collect_geometry_data(entity_path, final_transform)

            # Collect data for collision TCP geometries
            if self.collision_tcp_geometries:
                tcp_transform = transforms[-1]  # End-effector transform
                for i, geom_id in enumerate(self.collision_tcp_geometries):
                    entity_path = f"{self.base_entity_path}/collision/tcp/geometry_{geom_id}"

                    pose = normalize_pose(self.collision_tcp_geometries[geom_id].pose)
                    final_transform = tcp_transform @ self.geometry_pose_to_matrix(pose)

                    # tcp collision geometries are defined in flange frame
                    identity_pose = models.PlannerPose(
                        position=models.Vector3d(x=0, y=0, z=0),
                        orientation=models.Quaternion(x=0, y=0, z=0, w=1),
                    )
                    self.init_collision_geometry(
                        entity_path, self.collision_tcp_geometries[geom_id], identity_pose
                    )
                    collect_geometry_data(entity_path, final_transform)

        # Send collected columns for all geometries in one batch per entity.
        for entity_path, positions in link_positions.items():
            rr.send_columns(
                entity_path,
                indexes=[times_column],
                columns=[
                    *rr.Transform3D.columns(
                        translation=positions, rotation_axis_angle=link_rotations[entity_path]
                    )
                ],
            )
def get_model_path(model_name: str) -> str:
17def get_model_path(model_name: str) -> str:
18    """Get absolute path to model file in project directory"""
19    return str(get_project_root() / "models" / f"{model_name}.glb")

Get absolute path to model file in project directory

class RobotVisualizer:
    def __init__(
        self,
        robot: DHRobot,
        robot_model_geometries,
        tcp_geometries,
        static_transform: bool = True,
        base_entity_path: str = "robot",
        albedo_factor: "list | None" = None,
        collision_link_chain=None,
        collision_tcp=None,
        model_from_controller="",
    ):
        """
        Visualize a DH robot, its safety geometries and collision geometries in Rerun.

        :param robot: DHRobot instance
        :param robot_model_geometries: List of geometries for each link (objects exposing ``link_index`` and ``geometry``)
        :param tcp_geometries: TCP geometries (similar structure to link geometries)
        :param static_transform: If True, transforms are logged as static, else temporal.
        :param base_entity_path: A base path prefix for logging the entities (e.g. motion group name)
        :param albedo_factor: RGB values [R, G, B] applied as the albedo factor; defaults to white.
        :param collision_link_chain: Collision geometries per link, indexable by geometry id.
        :param collision_tcp: Collision geometries attached to the flange/TCP, indexable by geometry id.
        :param model_from_controller: Model name used to locate the robot's GLB file on disk.
        """
        self.robot = robot
        # Visual geometries grouped by link index.
        self.link_geometries: dict[int, list[trimesh.Trimesh]] = {}
        self.tcp_geometries = tcp_geometries
        # Entity paths already logged once (meshes are logged lazily).
        self.logged_meshes: set[str] = set()
        self.static_transform = static_transform
        self.base_entity_path = base_entity_path.rstrip("/")
        # Avoid the mutable-default-argument pitfall: resolve the default here.
        self.albedo_factor = [255, 255, 255] if albedo_factor is None else albedo_factor
        self.mesh_loaded = False
        self.collision_tcp_geometries = collision_tcp

        # This will hold the names of discovered joints (e.g. ["robot_J00", "robot_J01", ...])
        self.joint_names: list[str] = []
        # joint node name -> geometry node names that move with that joint
        self.layer_nodes_dict: dict[str, list[str]] = {}
        # geometry name -> owning link node in the GLB scene graph
        self.parent_nodes_dict: dict[str, str] = {}

        # Load the robot mesh; visualization degrades gracefully if this fails.
        try:
            glb_path = get_model_path(model_from_controller)
            self.scene = trimesh.load_scene(glb_path, file_type="glb")
            self.mesh_loaded = True
            self.edge_data = self.scene.graph.transforms.edge_data

            # After loading, auto-discover any child nodes that match *_J0n
            self.discover_joints()
        except Exception as e:
            # Best effort: keep running without the visual mesh.
            print(f"Failed to load mesh: {e}")

        # Group visual geometries by link index.
        for gm in robot_model_geometries:
            self.link_geometries.setdefault(gm.link_index, []).append(gm.geometry)

        # Collision geometries per link come directly from the collision chain.
        self.collision_link_geometries = collision_link_chain
 79    def discover_joints(self):
 80        """
 81        Find all child node names that contain '_J0' followed by digits or '_FLG'.
 82        Store joints with their parent nodes and print layer information.
 83        """
 84        joint_pattern = re.compile(r"_J0(\d+)")
 85        flg_pattern = re.compile(r"_FLG")
 86        matches = []
 87        flg_nodes = []
 88        joint_parents = {}  # Store parent for each joint/FLG
 89
 90        for (parent, child), data in self.edge_data.items():
 91            # Check for joints
 92            joint_match = joint_pattern.search(child)
 93            if joint_match:
 94                j_idx = int(joint_match.group(1))
 95                matches.append((j_idx, child))
 96                joint_parents[child] = parent
 97
 98            # Check for FLG
 99            flg_match = flg_pattern.search(child)
100            if flg_match:
101                flg_nodes.append(child)
102                joint_parents[child] = parent
103
104        matches.sort(key=lambda x: x[0])
105        self.joint_names = [name for _, name in matches] + flg_nodes
106
107        # print("Discovered nodes:", self.joint_names)
108        # Print layer information for each joint
109        for joint in self.joint_names:
110            self.get_nodes_on_same_layer(joint_parents[joint], joint)
111            # print(f"\nNodes on same layer as {joint}:")
112            # print(f"Parent node: {joint_parents[joint]}")
113            # print(f"Layer nodes: {same_layer_nodes}")
114
115    def get_nodes_on_same_layer(self, parent_node, joint):
116        """
117        Find nodes on same layer and only add descendants of link nodes.
118        """
119        same_layer = []
120        # First get immediate layer nodes
121        for (parent, child), data in self.edge_data.items():
122            if parent == parent_node:
123                if child == joint:
124                    continue
125                if "geometry" in data:
126                    same_layer.append(data["geometry"])
127                    self.parent_nodes_dict[data["geometry"]] = child
128
129                # Get all descendants for this link
130                parentChild = child
131                stack = [child]
132                while stack:
133                    current = stack.pop()
134                    for (p, c), data in self.edge_data.items():
135                        if p == current:
136                            if "geometry" in data:
137                                same_layer.append(data["geometry"])
138                                self.parent_nodes_dict[data["geometry"]] = parentChild
139                            stack.append(c)
140
141        self.layer_nodes_dict[joint] = same_layer
142        return same_layer
143
    def geometry_pose_to_matrix(self, init_pose: models.PlannerPose):
        """Convert a geometry's init pose into a 4x4 homogeneous matrix (delegates to the robot)."""
        return self.robot.pose_to_matrix(init_pose)
147    def compute_forward_kinematics(self, joint_values):
148        """Compute link transforms using the robot's methods."""
149        accumulated = self.robot.pose_to_matrix(self.robot.mounting)
150        transforms = [accumulated.copy()]
151        for dh_param, joint_rot in zip(self.robot.dh_parameters, joint_values.joints, strict=False):
152            transform = self.robot.dh_transform(dh_param, joint_rot)
153            accumulated = accumulated @ transform
154            transforms.append(accumulated.copy())
155        return transforms
156
157    def rotation_matrix_to_axis_angle(self, Rm):
158        """Use scipy for cleaner axis-angle extraction."""
159        rot = Rotation.from_matrix(Rm)
160        angle = rot.magnitude()
161        axis = rot.as_rotvec() / angle if angle > 1e-8 else np.array([1.0, 0.0, 0.0])
162        return axis, angle
163
164    def gamma_lift_single_color(self, color: np.ndarray, gamma: float = 0.8) -> np.ndarray:
165        """
166        Apply gamma correction to a single RGBA color in-place.
167        color: shape (4,) with [R, G, B, A] in 0..255, dtype=uint8
168        gamma: < 1.0 brightens midtones, > 1.0 darkens them.
169        """
170        rgb_float = color[:3].astype(np.float32) / 255.0
171        rgb_float = np.power(rgb_float, gamma)
172        color[:3] = (rgb_float * 255.0).astype(np.uint8)
173
174        return color
175
176    def get_transform_matrix(self):
177        """
178        Creates a transformation matrix that converts from glTF's right-handed Y-up
179        coordinate system to Rerun's right-handed Z-up coordinate system.
180
181        Returns:
182            np.ndarray: A 4x4 transformation matrix
183        """
184        # Convert from glTF's Y-up to Rerun's Z-up coordinate system
185        return np.array(
186            [
187                [1.0, 0.0, 0.0, 0.0],  # X stays the same
188                [0.0, 0.0, -1.0, 0.0],  # Y becomes -Z
189                [0.0, 1.0, 0.0, 0.0],  # Z becomes Y
190                [0.0, 0.0, 0.0, 1.0],  # Homogeneous coordinate
191            ]
192        )
193
    def init_mesh(self, entity_path: str, geom, joint_name):
        """
        Log a single GLB mesh geometry to Rerun, once per entity path.

        Re-anchors the mesh from the glTF scene graph into the robot's link
        frame: the node-hierarchy pivot offset is baked into the vertices and
        positions are scaled from meters to millimeters.

        :param entity_path: Rerun entity path under which the mesh is logged.
        :param geom: trimesh geometry; ``geom.metadata["node"]`` must name its scene node.
        :param joint_name: Name of the joint node this mesh is attached to.
        """

        if entity_path not in self.logged_meshes:
            # Skip geometries whose scene node was never associated with a link.
            if geom.metadata.get("node") not in self.parent_nodes_dict:
                return

            base_transform = np.eye(4)
            # if the dh parameters are not at 0,0,0 from the mesh we have to move the first mesh joint
            if "J00" in joint_name:
                base_transform_, _ = self.scene.graph.get(frame_to=joint_name)
                base_transform = base_transform_.copy()
            base_transform[:3, 3] *= 1000

            # if the mesh has the pivot not in the center, we need to adjust the transform
            cumulative_transform, _ = self.scene.graph.get(
                frame_to=self.parent_nodes_dict[geom.metadata.get("node")]
            )
            ctransform = cumulative_transform.copy()

            # scale positions to mm
            ctransform[:3, 3] *= 1000

            # scale mesh to mm
            transform = base_transform @ ctransform
            mesh_scale_matrix = np.eye(4)
            mesh_scale_matrix[:3, :3] *= 1000
            transform = transform @ mesh_scale_matrix
            transformed_mesh = geom.copy()

            transformed_mesh.apply_transform(transform)

            # Flatten texture/material visuals to per-vertex colors so an
            # albedo factor can be derived below.
            if transformed_mesh.visual is not None:
                transformed_mesh.visual = transformed_mesh.visual.to_color()

            vertex_colors = None
            if transformed_mesh.visual and hasattr(transformed_mesh.visual, "vertex_colors"):
                vertex_colors = transformed_mesh.visual.vertex_colors

            # NOTE(review): vertex_colors is typically an (N, 4) array, while
            # gamma_lift_single_color documents a single (4,) color — confirm
            # the intended albedo behavior for multi-color meshes.
            rr.log(
                entity_path,
                rr.Mesh3D(
                    vertex_positions=transformed_mesh.vertices,
                    triangle_indices=transformed_mesh.faces,
                    vertex_normals=getattr(transformed_mesh, "vertex_normals", None),
                    albedo_factor=self.gamma_lift_single_color(vertex_colors, gamma=0.5)
                    if vertex_colors is not None
                    else None,
                ),
            )

            self.logged_meshes.add(entity_path)
    def init_collision_geometry(
        self, entity_path: str, collider: models.Collider, pose: models.PlannerPose
    ):
        """
        Log a collision geometry to Rerun once per entity path, dispatching on
        the collider's concrete shape (sphere, box, capsule or convex hull).

        :param entity_path: Rerun entity path under which the shape is logged.
        :param collider: Collider whose ``shape.actual_instance`` selects the branch.
        :param pose: Pose used to place sphere/box centers and the capsule offset.
        """
        if entity_path in self.logged_meshes:
            return

        if isinstance(collider.shape.actual_instance, models.Sphere2):
            # NOTE(review): the fallback `[0, 0, 0]` is a flat list while the
            # main branch passes a nested `[[x, y, z]]` — confirm Rerun accepts
            # both shapes for `centers`.
            rr.log(
                f"{entity_path}",
                rr.Ellipsoids3D(
                    radii=[
                        collider.shape.actual_instance.radius,
                        collider.shape.actual_instance.radius,
                        collider.shape.actual_instance.radius,
                    ],
                    centers=[[pose.position.x, pose.position.y, pose.position.z]]
                    if pose.position
                    else [0, 0, 0],
                    colors=[(221, 193, 193, 255)],
                ),
            )

        elif isinstance(collider.shape.actual_instance, models.Box2):
            rr.log(
                f"{entity_path}",
                rr.Boxes3D(
                    centers=[[pose.position.x, pose.position.y, pose.position.z]]
                    if pose.position
                    else [0, 0, 0],
                    sizes=[
                        collider.shape.actual_instance.size_x,
                        collider.shape.actual_instance.size_y,
                        collider.shape.actual_instance.size_z,
                    ],
                    colors=[(221, 193, 193, 255)],
                ),
            )

        elif isinstance(collider.shape.actual_instance, models.Capsule2):
            height = collider.shape.actual_instance.cylinder_height
            radius = collider.shape.actual_instance.radius

            # Generate trimesh capsule
            capsule = trimesh.creation.capsule(height=height, radius=radius, count=[6, 8])

            # Extract vertices and faces for solid visualization
            vertices = np.array(capsule.vertices)

            # Transform vertices to world position; the capsule is shifted down
            # by half its height so it is centered on the pose.
            transform = np.eye(4)
            if pose.position:
                transform[:3, 3] = [pose.position.x, pose.position.y, pose.position.z - height / 2]
            else:
                transform[:3, 3] = [0, 0, -height / 2]

            # NOTE(review): orientation is indexed positionally here ([0]..[3])
            # while other code uses .x/.y/.z/.w quaternion fields — confirm the
            # component order matches scipy's (x, y, z, w) convention.
            if collider.pose and collider.pose.orientation:
                rot_mat = Rotation.from_quat(
                    [
                        collider.pose.orientation[0],
                        collider.pose.orientation[1],
                        collider.pose.orientation[2],
                        collider.pose.orientation[3],
                    ]
                )
                transform[:3, :3] = rot_mat.as_matrix()

            vertices = np.array([transform @ np.append(v, 1) for v in vertices])[:, :3]

            polygons = HullVisualizer.compute_hull_outlines_from_points(vertices)

            if polygons:
                line_segments = [p.tolist() for p in polygons]
                rr.log(
                    f"{entity_path}",
                    rr.LineStrips3D(
                        line_segments,
                        radii=rr.Radius.ui_points(0.75),
                        colors=[[221, 193, 193, 255]],
                    ),
                    static=True,
                )

        elif isinstance(collider.shape.actual_instance, models.ConvexHull2):
            polygons = HullVisualizer.compute_hull_outlines_from_points(
                collider.shape.actual_instance.vertices
            )

            if polygons:
                line_segments = [p.tolist() for p in polygons]
                rr.log(
                    f"{entity_path}",
                    rr.LineStrips3D(
                        line_segments, radii=rr.Radius.ui_points(1.5), colors=[colors.colors[2]]
                    ),
                    static=True,
                )

                # Also log the solid hull mesh on the same entity path.
                vertices, triangles, normals = HullVisualizer.compute_hull_mesh(polygons)

                rr.log(
                    f"{entity_path}",
                    rr.Mesh3D(
                        vertex_positions=vertices,
                        triangle_indices=triangles,
                        vertex_normals=normals,
                        albedo_factor=colors.colors[0],
                    ),
                    static=True,
                )

        # Unknown shapes fall through and are still marked as logged.
        self.logged_meshes.add(entity_path)
359    def init_geometry(self, entity_path: str, capsule):
360        """Generic method to log a single geometry, either capsule or box."""
361
362        if entity_path in self.logged_meshes:
363            return
364
365        if capsule:
366            radius = capsule.radius
367            height = capsule.cylinder_height
368
369            # Slightly shrink the capsule if static to reduce z-fighting
370            if self.static_transform:
371                radius *= 0.99
372                height *= 0.99
373
374            # Create capsule and retrieve normals
375            cap_mesh = trimesh.creation.capsule(radius=radius, height=height)
376            vertex_normals = cap_mesh.vertex_normals.tolist()
377
378            rr.log(
379                entity_path,
380                rr.Mesh3D(
381                    vertex_positions=cap_mesh.vertices.tolist(),
382                    triangle_indices=cap_mesh.faces.tolist(),
383                    vertex_normals=vertex_normals,
384                    albedo_factor=self.albedo_factor,
385                ),
386            )
387            self.logged_meshes.add(entity_path)
388        else:
389            # fallback to a box
390            rr.log(entity_path, rr.Boxes3D(half_sizes=[[50, 50, 50]]))
391            self.logged_meshes.add(entity_path)
392
393    def log_robot_geometry(self, joint_position):
394        transforms = self.compute_forward_kinematics(joint_position)
395
396        def log_geometry(entity_path, transform):
397            translation = transform[:3, 3]
398            Rm = transform[:3, :3]
399            axis, angle = self.rotation_matrix_to_axis_angle(Rm)
400            rr.log(
401                entity_path,
402                rr.InstancePoses3D(
403                    translations=[translation.tolist()],
404                    rotation_axis_angles=[
405                        rr.RotationAxisAngle(axis=axis.tolist(), angle=float(angle))
406                    ],
407                ),
408                static=self.static_transform,
409            )
410
411        # Log robot joint geometries
412        if self.mesh_loaded:
413            for link_index, joint_name in enumerate(self.joint_names):
414                link_transform = transforms[link_index]
415
416                # Get nodes on same layer using dictionary
417                same_layer_nodes = self.layer_nodes_dict.get(joint_name)
418                if not same_layer_nodes:
419                    continue
420
421                filtered_geoms = []
422                for node_name in same_layer_nodes:
423                    if node_name in self.scene.geometry:
424                        geom = self.scene.geometry[node_name]
425                        # Add metadata that would normally come from dump
426                        geom.metadata = {"node": node_name}
427                        filtered_geoms.append(geom)
428
429                for geom in filtered_geoms:
430                    entity_path = f"{self.base_entity_path}/visual/links/link_{link_index}/mesh/{geom.metadata.get('node')}"
431
432                    # calculate the inverse transform to get the mesh in the correct position
433                    cumulative_transform, _ = self.scene.graph.get(frame_to=joint_name)
434                    ctransform = cumulative_transform.copy()
435                    inverse_transform = np.linalg.inv(ctransform)
436
437                    # DH theta is rotated, rotate mesh around z in direction of theta
438                    rotation_matrix_z_4x4 = np.eye(4)
439                    if len(self.robot.dh_parameters) > link_index:
440                        rotation_z_minus_90 = Rotation.from_euler(
441                            "z", self.robot.dh_parameters[link_index].theta, degrees=False
442                        ).as_matrix()
443                        rotation_matrix_z_4x4[:3, :3] = rotation_z_minus_90
444
445                    # scale positions to mm
446                    inverse_transform[:3, 3] *= 1000
447
448                    root_transform = self.get_transform_matrix()
449
450                    transform = root_transform @ inverse_transform
451
452                    final_transform = link_transform @ rotation_matrix_z_4x4 @ transform
453
454                    self.init_mesh(entity_path, geom, joint_name)
455                    log_geometry(entity_path, final_transform)
456
457        # Log link geometries
458        for link_index, geometries in self.link_geometries.items():
459            link_transform = transforms[link_index]
460            for i, geom in enumerate(geometries):
461                entity_path = f"{self.base_entity_path}/safety_from_controller/links/link_{link_index}/geometry_{i}"
462                final_transform = link_transform @ self.geometry_pose_to_matrix(geom.init_pose)
463
464                self.init_geometry(entity_path, geom.capsule)
465                log_geometry(entity_path, final_transform)
466
467        # Log TCP geometries
468        if self.tcp_geometries:
469            tcp_transform = transforms[-1]  # the final frame transform
470            for i, geom in enumerate(self.tcp_geometries):
471                entity_path = f"{self.base_entity_path}/safety_from_controller/tcp/geometry_{i}"
472                final_transform = tcp_transform @ self.geometry_pose_to_matrix(geom.init_pose)
473
474                self.init_geometry(entity_path, geom.capsule)
475                log_geometry(entity_path, final_transform)
476
    def log_robot_geometries(self, trajectory: list[models.TrajectorySample], times_column):
        """
        Log the robot geometries for each link and TCP as separate entities.

        Poses for the whole trajectory are accumulated per entity path and sent
        in one batched `rr.send_columns` call per entity at the end.

        Args:
            trajectory (List[wb.models.TrajectorySample]): The list of trajectory sample points.
            times_column (rr.TimeSecondsColumn): The time column associated with the trajectory points.
        """
        # entity path -> list of translations / rotations, one entry per trajectory point
        link_positions = {}
        link_rotations = {}

        def collect_geometry_data(entity_path, transform):
            """Helper to collect geometry data for a given entity."""
            translation = transform[:3, 3].tolist()
            Rm = transform[:3, :3]
            axis, angle = self.rotation_matrix_to_axis_angle(Rm)
            if entity_path not in link_positions:
                link_positions[entity_path] = []
                link_rotations[entity_path] = []
            link_positions[entity_path].append(translation)
            link_rotations[entity_path].append(rr.RotationAxisAngle(axis=axis, angle=angle))

        for point in trajectory:
            transforms = self.compute_forward_kinematics(point.joint_position)

            # Log robot joint geometries
            if self.mesh_loaded:
                for link_index, joint_name in enumerate(self.joint_names):
                    # More discovered joints than DH frames: stop.
                    if link_index >= len(transforms):
                        break
                    link_transform = transforms[link_index]

                    # Get nodes on same layer using dictionary
                    same_layer_nodes = self.layer_nodes_dict.get(joint_name)
                    if not same_layer_nodes:
                        continue

                    filtered_geoms = []
                    for node_name in same_layer_nodes:
                        if node_name in self.scene.geometry:
                            geom = self.scene.geometry[node_name]
                            # Add metadata that would normally come from dump
                            geom.metadata = {"node": node_name}
                            filtered_geoms.append(geom)

                    for geom in filtered_geoms:
                        entity_path = f"{self.base_entity_path}/visual/links/link_{link_index}/mesh/{geom.metadata.get('node')}"

                        # calculate the inverse transform to get the mesh in the correct position
                        cumulative_transform, _ = self.scene.graph.get(frame_to=joint_name)
                        ctransform = cumulative_transform.copy()
                        inverse_transform = np.linalg.inv(ctransform)

                        # DH theta is rotated, rotate mesh around z in direction of theta
                        rotation_matrix_z_4x4 = np.eye(4)
                        if len(self.robot.dh_parameters) > link_index:
                            rotation_z_minus_90 = Rotation.from_euler(
                                "z", self.robot.dh_parameters[link_index].theta, degrees=False
                            ).as_matrix()
                            rotation_matrix_z_4x4[:3, :3] = rotation_z_minus_90

                        # scale positions to mm
                        inverse_transform[:3, 3] *= 1000

                        root_transform = self.get_transform_matrix()

                        transform = root_transform @ inverse_transform

                        final_transform = link_transform @ rotation_matrix_z_4x4 @ transform

                        self.init_mesh(entity_path, geom, joint_name)
                        collect_geometry_data(entity_path, final_transform)

            # Collect data for link geometries
            for link_index, geometries in self.link_geometries.items():
                link_transform = transforms[link_index]
                for i, geom in enumerate(geometries):
                    entity_path = f"{self.base_entity_path}/safety_from_controller/links/link_{link_index}/geometry_{i}"
                    final_transform = link_transform @ self.geometry_pose_to_matrix(geom.init_pose)
                    self.init_geometry(entity_path, geom.capsule)
                    collect_geometry_data(entity_path, final_transform)

            # Collect data for TCP geometries
            if self.tcp_geometries:
                tcp_transform = transforms[-1]  # End-effector transform
                for i, geom in enumerate(self.tcp_geometries):
                    entity_path = f"{self.base_entity_path}/safety_from_controller/tcp/geometry_{i}"
                    final_transform = tcp_transform @ self.geometry_pose_to_matrix(geom.init_pose)
                    self.init_geometry(entity_path, geom.capsule)
                    collect_geometry_data(entity_path, final_transform)

            # Collect data for collision link geometries
            # NOTE(review): assumes collision_link_geometries is a non-None
            # sequence of per-link mappings (geom_id -> collider); iterating a
            # None collision_link_chain would raise — confirm callers.
            for link_index, geometries in enumerate(self.collision_link_geometries):
                link_transform = transforms[link_index]
                for i, geom_id in enumerate(geometries):
                    entity_path = f"{self.base_entity_path}/collision/links/link_{link_index}/geometry_{geom_id}"

                    pose = normalize_pose(geometries[geom_id].pose)

                    final_transform = link_transform @ self.geometry_pose_to_matrix(pose)
                    self.init_collision_geometry(entity_path, geometries[geom_id], pose)
                    collect_geometry_data(entity_path, final_transform)

            # Collect data for collision TCP geometries
            if self.collision_tcp_geometries:
                tcp_transform = transforms[-1]  # End-effector transform
                for i, geom_id in enumerate(self.collision_tcp_geometries):
                    entity_path = f"{self.base_entity_path}/collision/tcp/geometry_{geom_id}"

                    pose = normalize_pose(self.collision_tcp_geometries[geom_id].pose)
                    final_transform = tcp_transform @ self.geometry_pose_to_matrix(pose)

                    # tcp collision geometries are defined in flange frame
                    identity_pose = models.PlannerPose(
                        position=models.Vector3d(x=0, y=0, z=0),
                        orientation=models.Quaternion(x=0, y=0, z=0, w=1),
                    )
                    self.init_collision_geometry(
                        entity_path, self.collision_tcp_geometries[geom_id], identity_pose
                    )
                    collect_geometry_data(entity_path, final_transform)

        # Send collected columns for all geometries
        for entity_path, positions in link_positions.items():
            rr.send_columns(
                entity_path,
                indexes=[times_column],
                columns=[
                    *rr.Transform3D.columns(
                        translation=positions, rotation_axis_angle=link_rotations[entity_path]
                    )
                ],
            )
RobotVisualizer( robot: nova_rerun_bridge.dh_robot.DHRobot, robot_model_geometries, tcp_geometries, static_transform: bool = True, base_entity_path: str = 'robot', albedo_factor: list = [255, 255, 255], collision_link_chain=None, collision_tcp=None, model_from_controller='')
23    def __init__(
24        self,
25        robot: DHRobot,
26        robot_model_geometries,
27        tcp_geometries,
28        static_transform: bool = True,
29        base_entity_path: str = "robot",
30        albedo_factor: list = [255, 255, 255],
31        collision_link_chain=None,
32        collision_tcp=None,
33        model_from_controller="",
34    ):
35        """
36        :param robot: DHRobot instance
37        :param robot_model_geometries: List of geometries for each link
38        :param tcp_geometries: TCP geometries (similar structure to link geometries)
39        :param static_transform: If True, transforms are logged as static, else temporal.
40        :param base_entity_path: A base path prefix for logging the entities (e.g. motion group name)
41        :param albedo_factor: A list representing the RGB values [R, G, B] to apply as the albedo factor.
 42        :param model_from_controller: Name of the robot model used to locate its GLB file.
43        """
44        self.robot = robot
45        self.link_geometries: dict[int, list[trimesh.Trimesh]] = {}
46        self.tcp_geometries = tcp_geometries
47        self.logged_meshes: set[str] = set()
48        self.static_transform = static_transform
49        self.base_entity_path = base_entity_path.rstrip("/")
50        self.albedo_factor = albedo_factor
51        self.mesh_loaded = False
52        self.collision_link_geometries = {}
53        self.collision_tcp_geometries = collision_tcp
54
55        # This will hold the names of discovered joints (e.g. ["robot_J00", "robot_J01", ...])
56        self.joint_names: list[str] = []
57        self.layer_nodes_dict: dict[str, list[str]] = {}
58        self.parent_nodes_dict: dict[str, str] = {}
59
60        # load mesh
61        try:
62            glb_path = get_model_path(model_from_controller)
63            self.scene = trimesh.load_scene(glb_path, file_type="glb")
64            self.mesh_loaded = True
65            self.edge_data = self.scene.graph.transforms.edge_data
66
67            # After loading, auto-discover any child nodes that match *_J0n
68            self.discover_joints()
69        except Exception as e:
70            print(f"Failed to load mesh: {e}")
71
72        # Group geometries by link
73        for gm in robot_model_geometries:
74            self.link_geometries.setdefault(gm.link_index, []).append(gm.geometry)
75
76        # Group geometries by link
77        self.collision_link_geometries = collision_link_chain
Parameters
  • robot: DHRobot instance
  • robot_model_geometries: List of geometries for each link
  • tcp_geometries: TCP geometries (similar structure to link geometries)
  • static_transform: If True, transforms are logged as static, else temporal.
  • base_entity_path: A base path prefix for logging the entities (e.g. motion group name)
  • albedo_factor: A list representing the RGB values [R, G, B] to apply as the albedo factor.
  • model_from_controller: Name of the robot model used to locate its GLB file.
robot
tcp_geometries
logged_meshes: set[str]
static_transform
base_entity_path
albedo_factor
mesh_loaded
collision_tcp_geometries
joint_names: list[str]
layer_nodes_dict: dict[str, list[str]]
parent_nodes_dict: dict[str, str]
def discover_joints(self):
 79    def discover_joints(self):
 80        """
 81        Find all child node names that contain '_J0' followed by digits or '_FLG'.
 82        Store joints with their parent nodes and print layer information.
 83        """
 84        joint_pattern = re.compile(r"_J0(\d+)")
 85        flg_pattern = re.compile(r"_FLG")
 86        matches = []
 87        flg_nodes = []
 88        joint_parents = {}  # Store parent for each joint/FLG
 89
 90        for (parent, child), data in self.edge_data.items():
 91            # Check for joints
 92            joint_match = joint_pattern.search(child)
 93            if joint_match:
 94                j_idx = int(joint_match.group(1))
 95                matches.append((j_idx, child))
 96                joint_parents[child] = parent
 97
 98            # Check for FLG
 99            flg_match = flg_pattern.search(child)
100            if flg_match:
101                flg_nodes.append(child)
102                joint_parents[child] = parent
103
104        matches.sort(key=lambda x: x[0])
105        self.joint_names = [name for _, name in matches] + flg_nodes
106
107        # print("Discovered nodes:", self.joint_names)
108        # Print layer information for each joint
109        for joint in self.joint_names:
110            self.get_nodes_on_same_layer(joint_parents[joint], joint)
111            # print(f"\nNodes on same layer as {joint}:")
112            # print(f"Parent node: {joint_parents[joint]}")
113            # print(f"Layer nodes: {same_layer_nodes}")

Find all child node names that contain '_J0' followed by digits or '_FLG'. Store joints with their parent nodes and print layer information.

def get_nodes_on_same_layer(self, parent_node, joint):
115    def get_nodes_on_same_layer(self, parent_node, joint):
116        """
117        Find nodes on same layer and only add descendants of link nodes.
118        """
119        same_layer = []
120        # First get immediate layer nodes
121        for (parent, child), data in self.edge_data.items():
122            if parent == parent_node:
123                if child == joint:
124                    continue
125                if "geometry" in data:
126                    same_layer.append(data["geometry"])
127                    self.parent_nodes_dict[data["geometry"]] = child
128
129                # Get all descendants for this link
130                parentChild = child
131                stack = [child]
132                while stack:
133                    current = stack.pop()
134                    for (p, c), data in self.edge_data.items():
135                        if p == current:
136                            if "geometry" in data:
137                                same_layer.append(data["geometry"])
138                                self.parent_nodes_dict[data["geometry"]] = parentChild
139                            stack.append(c)
140
141        self.layer_nodes_dict[joint] = same_layer
142        return same_layer

Find nodes on same layer and only add descendants of link nodes.

def geometry_pose_to_matrix(self, init_pose: models.PlannerPose):
    """Convert a geometry's initial pose into a 4x4 matrix via the robot's converter."""
    pose_matrix = self.robot.pose_to_matrix(init_pose)
    return pose_matrix
def compute_forward_kinematics(self, joint_values):
147    def compute_forward_kinematics(self, joint_values):
148        """Compute link transforms using the robot's methods."""
149        accumulated = self.robot.pose_to_matrix(self.robot.mounting)
150        transforms = [accumulated.copy()]
151        for dh_param, joint_rot in zip(self.robot.dh_parameters, joint_values.joints, strict=False):
152            transform = self.robot.dh_transform(dh_param, joint_rot)
153            accumulated = accumulated @ transform
154            transforms.append(accumulated.copy())
155        return transforms

Compute link transforms using the robot's methods.

def rotation_matrix_to_axis_angle(self, Rm):
    """Decompose a 3x3 rotation matrix into a unit axis and an angle.

    Falls back to the x-axis when the rotation is numerically the
    identity, where the axis is undefined.
    """
    rotation = Rotation.from_matrix(Rm)
    angle = rotation.magnitude()
    if angle > 1e-8:
        # The rotation vector's direction is the axis; normalize by the angle.
        axis = rotation.as_rotvec() / angle
    else:
        axis = np.array([1.0, 0.0, 0.0])
    return axis, angle
def gamma_lift_single_color(self, color: np.ndarray, gamma: float = 0.8) -> np.ndarray:
    """
    Apply gamma correction to a single RGBA color in-place.

    :param color: Array of shape (4,) holding [R, G, B, A] in 0..255, dtype=uint8.
    :param gamma: Exponent; values < 1.0 brighten midtones, > 1.0 darken them.
    :return: The same array, with its RGB channels gamma-corrected (alpha untouched).
    """
    # Normalize RGB to 0..1, apply the power curve, write back as uint8.
    normalized = color[:3].astype(np.float32) / 255.0
    corrected = normalized**gamma
    color[:3] = (corrected * 255.0).astype(np.uint8)
    return color
def get_transform_matrix(self):
    """
    Build the 4x4 matrix converting glTF's right-handed Y-up
    coordinates into Rerun's right-handed Z-up coordinates.

    Returns:
        np.ndarray: A 4x4 transformation matrix.
    """
    # Start from identity, then swap the Y/Z axes: y_new = -z, z_new = y.
    gltf_to_rerun = np.eye(4)
    gltf_to_rerun[1, 1] = 0.0
    gltf_to_rerun[2, 2] = 0.0
    gltf_to_rerun[1, 2] = -1.0
    gltf_to_rerun[2, 1] = 1.0
    return gltf_to_rerun
def init_mesh(self, entity_path: str, geom, joint_name):
    """Log one visual mesh from the GLB scene to Rerun, once per entity path.

    The mesh is baked relative to its joint frame: translations are
    scaled by 1000 (to mm, matching the file's other transforms) and the
    node's pivot offset is applied, so later logs only need to update the
    entity's pose.

    :param entity_path: Rerun entity path under which to log the mesh.
    :param geom: trimesh geometry; ``geom.metadata["node"]`` identifies its scene node.
    :param joint_name: Name of the joint frame this mesh belongs to.
    """

    if entity_path not in self.logged_meshes:
        # Only nodes registered by get_nodes_on_same_layer have a known
        # owning link; skip anything else.
        if geom.metadata.get("node") not in self.parent_nodes_dict:
            return

        base_transform = np.eye(4)
        # if the dh parameters are not at 0,0,0 from the mesh we have to move the first mesh joint
        if "J00" in joint_name:
            base_transform_, _ = self.scene.graph.get(frame_to=joint_name)
            base_transform = base_transform_.copy()
        # scale translation to mm
        base_transform[:3, 3] *= 1000

        # if the mesh has the pivot not in the center, we need to adjust the transform
        cumulative_transform, _ = self.scene.graph.get(
            frame_to=self.parent_nodes_dict[geom.metadata.get("node")]
        )
        ctransform = cumulative_transform.copy()

        # scale positions to mm
        ctransform[:3, 3] *= 1000

        # scale mesh to mm
        transform = base_transform @ ctransform
        mesh_scale_matrix = np.eye(4)
        mesh_scale_matrix[:3, :3] *= 1000
        transform = transform @ mesh_scale_matrix
        # Work on a copy so the shared scene geometry stays untouched.
        transformed_mesh = geom.copy()

        transformed_mesh.apply_transform(transform)

        # Convert material/texture visuals into plain color visuals.
        if transformed_mesh.visual is not None:
            transformed_mesh.visual = transformed_mesh.visual.to_color()

        vertex_colors = None
        if transformed_mesh.visual and hasattr(transformed_mesh.visual, "vertex_colors"):
            vertex_colors = transformed_mesh.visual.vertex_colors

        # NOTE(review): vertex_colors from trimesh is typically a per-vertex
        # (N, 4) array, but gamma_lift_single_color documents a single (4,)
        # RGBA color and albedo_factor expects one color — confirm intent.
        rr.log(
            entity_path,
            rr.Mesh3D(
                vertex_positions=transformed_mesh.vertices,
                triangle_indices=transformed_mesh.faces,
                vertex_normals=getattr(transformed_mesh, "vertex_normals", None),
                albedo_factor=self.gamma_lift_single_color(vertex_colors, gamma=0.5)
                if vertex_colors is not None
                else None,
            ),
        )

        self.logged_meshes.add(entity_path)
def init_collision_geometry(
    self, entity_path: str, collider: models.Collider, pose: models.PlannerPose
):
    """Log a collision shape once, keyed by entity path.

    Dispatches on ``collider.shape.actual_instance``: spheres and boxes
    become rerun primitives, capsules become hull outlines, and convex
    hulls are logged as both an outline and a solid mesh.

    :param entity_path: Rerun entity path under which to log the shape.
    :param collider: Collider carrying the concrete shape instance.
    :param pose: Placement pose for sphere/box centers and capsule vertices.
    """
    if entity_path in self.logged_meshes:
        return

    if isinstance(collider.shape.actual_instance, models.Sphere2):
        rr.log(
            f"{entity_path}",
            rr.Ellipsoids3D(
                # Equal radii on all three axes: a sphere.
                radii=[
                    collider.shape.actual_instance.radius,
                    collider.shape.actual_instance.radius,
                    collider.shape.actual_instance.radius,
                ],
                # NOTE(review): fallback is a flat [0, 0, 0] while the main
                # branch passes nested [[x, y, z]] — confirm rerun accepts both.
                centers=[[pose.position.x, pose.position.y, pose.position.z]]
                if pose.position
                else [0, 0, 0],
                colors=[(221, 193, 193, 255)],
            ),
        )

    elif isinstance(collider.shape.actual_instance, models.Box2):
        rr.log(
            f"{entity_path}",
            rr.Boxes3D(
                # NOTE(review): same flat-vs-nested fallback shape as above.
                centers=[[pose.position.x, pose.position.y, pose.position.z]]
                if pose.position
                else [0, 0, 0],
                sizes=[
                    collider.shape.actual_instance.size_x,
                    collider.shape.actual_instance.size_y,
                    collider.shape.actual_instance.size_z,
                ],
                colors=[(221, 193, 193, 255)],
            ),
        )

    elif isinstance(collider.shape.actual_instance, models.Capsule2):
        height = collider.shape.actual_instance.cylinder_height
        radius = collider.shape.actual_instance.radius

        # Generate trimesh capsule
        capsule = trimesh.creation.capsule(height=height, radius=radius, count=[6, 8])

        # Extract vertices and faces for solid visualization
        vertices = np.array(capsule.vertices)

        # Transform vertices to world position; the capsule is shifted down
        # by half its height so the pose marks its center.
        transform = np.eye(4)
        if pose.position:
            transform[:3, 3] = [pose.position.x, pose.position.y, pose.position.z - height / 2]
        else:
            transform[:3, 3] = [0, 0, -height / 2]

        # Orientation comes from the collider's own pose, read as an
        # (x, y, z, w) quaternion.
        if collider.pose and collider.pose.orientation:
            rot_mat = Rotation.from_quat(
                [
                    collider.pose.orientation[0],
                    collider.pose.orientation[1],
                    collider.pose.orientation[2],
                    collider.pose.orientation[3],
                ]
            )
            transform[:3, :3] = rot_mat.as_matrix()

        # Apply the homogeneous transform to every vertex, then drop w.
        vertices = np.array([transform @ np.append(v, 1) for v in vertices])[:, :3]

        polygons = HullVisualizer.compute_hull_outlines_from_points(vertices)

        if polygons:
            line_segments = [p.tolist() for p in polygons]
            rr.log(
                f"{entity_path}",
                rr.LineStrips3D(
                    line_segments,
                    radii=rr.Radius.ui_points(0.75),
                    colors=[[221, 193, 193, 255]],
                ),
                static=True,
            )

    elif isinstance(collider.shape.actual_instance, models.ConvexHull2):
        polygons = HullVisualizer.compute_hull_outlines_from_points(
            collider.shape.actual_instance.vertices
        )

        if polygons:
            line_segments = [p.tolist() for p in polygons]
            # Outline of the hull...
            rr.log(
                f"{entity_path}",
                rr.LineStrips3D(
                    line_segments, radii=rr.Radius.ui_points(1.5), colors=[colors.colors[2]]
                ),
                static=True,
            )

            vertices, triangles, normals = HullVisualizer.compute_hull_mesh(polygons)

            # ...followed by a solid mesh. NOTE(review): both archetypes are
            # logged to the same entity path — confirm they coexist in rerun.
            rr.log(
                f"{entity_path}",
                rr.Mesh3D(
                    vertex_positions=vertices,
                    triangle_indices=triangles,
                    vertex_normals=normals,
                    albedo_factor=colors.colors[0],
                ),
                static=True,
            )

    self.logged_meshes.add(entity_path)
def init_geometry(self, entity_path: str, capsule):
359    def init_geometry(self, entity_path: str, capsule):
360        """Generic method to log a single geometry, either capsule or box."""
361
362        if entity_path in self.logged_meshes:
363            return
364
365        if capsule:
366            radius = capsule.radius
367            height = capsule.cylinder_height
368
369            # Slightly shrink the capsule if static to reduce z-fighting
370            if self.static_transform:
371                radius *= 0.99
372                height *= 0.99
373
374            # Create capsule and retrieve normals
375            cap_mesh = trimesh.creation.capsule(radius=radius, height=height)
376            vertex_normals = cap_mesh.vertex_normals.tolist()
377
378            rr.log(
379                entity_path,
380                rr.Mesh3D(
381                    vertex_positions=cap_mesh.vertices.tolist(),
382                    triangle_indices=cap_mesh.faces.tolist(),
383                    vertex_normals=vertex_normals,
384                    albedo_factor=self.albedo_factor,
385                ),
386            )
387            self.logged_meshes.add(entity_path)
388        else:
389            # fallback to a box
390            rr.log(entity_path, rr.Boxes3D(half_sizes=[[50, 50, 50]]))
391            self.logged_meshes.add(entity_path)

Generic method to log a single geometry, either capsule or box.

def log_robot_geometry(self, joint_position):
    """Log robot visual meshes and safety geometries at one joint configuration.

    Computes forward kinematics for *joint_position*, then logs pose
    updates for (a) the GLB visual meshes per link when a model is
    loaded, (b) the controller-reported safety link geometries, and
    (c) the TCP safety geometries. Meshes are initialized lazily on
    first use via init_mesh/init_geometry.
    """
    transforms = self.compute_forward_kinematics(joint_position)

    def log_geometry(entity_path, transform):
        # Decompose the 4x4 matrix into translation + axis-angle for rerun.
        translation = transform[:3, 3]
        Rm = transform[:3, :3]
        axis, angle = self.rotation_matrix_to_axis_angle(Rm)
        rr.log(
            entity_path,
            rr.InstancePoses3D(
                translations=[translation.tolist()],
                rotation_axis_angles=[
                    rr.RotationAxisAngle(axis=axis.tolist(), angle=float(angle))
                ],
            ),
            static=self.static_transform,
        )

    # Log robot joint geometries
    if self.mesh_loaded:
        for link_index, joint_name in enumerate(self.joint_names):
            link_transform = transforms[link_index]

            # Get nodes on same layer using dictionary
            same_layer_nodes = self.layer_nodes_dict.get(joint_name)
            if not same_layer_nodes:
                continue

            filtered_geoms = []
            for node_name in same_layer_nodes:
                if node_name in self.scene.geometry:
                    geom = self.scene.geometry[node_name]
                    # Add metadata that would normally come from dump
                    # (overwrites any existing metadata on the shared geometry).
                    geom.metadata = {"node": node_name}
                    filtered_geoms.append(geom)

            for geom in filtered_geoms:
                entity_path = f"{self.base_entity_path}/visual/links/link_{link_index}/mesh/{geom.metadata.get('node')}"

                # calculate the inverse transform to get the mesh in the correct position
                cumulative_transform, _ = self.scene.graph.get(frame_to=joint_name)
                ctransform = cumulative_transform.copy()
                inverse_transform = np.linalg.inv(ctransform)

                # DH theta is rotated, rotate mesh around z in direction of theta
                rotation_matrix_z_4x4 = np.eye(4)
                if len(self.robot.dh_parameters) > link_index:
                    rotation_z_minus_90 = Rotation.from_euler(
                        "z", self.robot.dh_parameters[link_index].theta, degrees=False
                    ).as_matrix()
                    rotation_matrix_z_4x4[:3, :3] = rotation_z_minus_90

                # scale positions to mm
                inverse_transform[:3, 3] *= 1000

                # Y-up (glTF) to Z-up (rerun) conversion.
                root_transform = self.get_transform_matrix()

                transform = root_transform @ inverse_transform

                # Link frame, then DH-theta correction, then mesh-local offset.
                final_transform = link_transform @ rotation_matrix_z_4x4 @ transform

                self.init_mesh(entity_path, geom, joint_name)
                log_geometry(entity_path, final_transform)

    # Log link geometries
    for link_index, geometries in self.link_geometries.items():
        link_transform = transforms[link_index]
        for i, geom in enumerate(geometries):
            entity_path = f"{self.base_entity_path}/safety_from_controller/links/link_{link_index}/geometry_{i}"
            final_transform = link_transform @ self.geometry_pose_to_matrix(geom.init_pose)

            self.init_geometry(entity_path, geom.capsule)
            log_geometry(entity_path, final_transform)

    # Log TCP geometries
    if self.tcp_geometries:
        tcp_transform = transforms[-1]  # the final frame transform
        for i, geom in enumerate(self.tcp_geometries):
            entity_path = f"{self.base_entity_path}/safety_from_controller/tcp/geometry_{i}"
            final_transform = tcp_transform @ self.geometry_pose_to_matrix(geom.init_pose)

            self.init_geometry(entity_path, geom.capsule)
            log_geometry(entity_path, final_transform)
def log_robot_geometries(self, trajectory: list[models.TrajectorySample], times_column):
    """
    Log the robot geometries for each link and TCP as separate entities.

    Evaluates forward kinematics for every trajectory sample, collects
    each entity's pose per sample, and finally sends all poses as
    time-indexed columns in one rr.send_columns call per entity.

    Args:
        trajectory (List[wb.models.TrajectorySample]): The list of trajectory sample points.
        times_column (rr.TimeSecondsColumn): The time column associated with the trajectory points.
    """
    # Per-entity accumulators filled while walking the trajectory.
    link_positions = {}
    link_rotations = {}

    def collect_geometry_data(entity_path, transform):
        """Helper to collect geometry data for a given entity."""
        translation = transform[:3, 3].tolist()
        Rm = transform[:3, :3]
        axis, angle = self.rotation_matrix_to_axis_angle(Rm)
        if entity_path not in link_positions:
            link_positions[entity_path] = []
            link_rotations[entity_path] = []
        link_positions[entity_path].append(translation)
        link_rotations[entity_path].append(rr.RotationAxisAngle(axis=axis, angle=angle))

    for point in trajectory:
        transforms = self.compute_forward_kinematics(point.joint_position)

        # Log robot joint geometries
        if self.mesh_loaded:
            for link_index, joint_name in enumerate(self.joint_names):
                # Guard: there can be more discovered nodes than FK frames.
                if link_index >= len(transforms):
                    break
                link_transform = transforms[link_index]

                # Get nodes on same layer using dictionary
                same_layer_nodes = self.layer_nodes_dict.get(joint_name)
                if not same_layer_nodes:
                    continue

                filtered_geoms = []
                for node_name in same_layer_nodes:
                    if node_name in self.scene.geometry:
                        geom = self.scene.geometry[node_name]
                        # Add metadata that would normally come from dump
                        geom.metadata = {"node": node_name}
                        filtered_geoms.append(geom)

                for geom in filtered_geoms:
                    entity_path = f"{self.base_entity_path}/visual/links/link_{link_index}/mesh/{geom.metadata.get('node')}"

                    # calculate the inverse transform to get the mesh in the correct position
                    cumulative_transform, _ = self.scene.graph.get(frame_to=joint_name)
                    ctransform = cumulative_transform.copy()
                    inverse_transform = np.linalg.inv(ctransform)

                    # DH theta is rotated, rotate mesh around z in direction of theta
                    rotation_matrix_z_4x4 = np.eye(4)
                    if len(self.robot.dh_parameters) > link_index:
                        rotation_z_minus_90 = Rotation.from_euler(
                            "z", self.robot.dh_parameters[link_index].theta, degrees=False
                        ).as_matrix()
                        rotation_matrix_z_4x4[:3, :3] = rotation_z_minus_90

                    # scale positions to mm
                    inverse_transform[:3, 3] *= 1000

                    # Y-up (glTF) to Z-up (rerun) conversion.
                    root_transform = self.get_transform_matrix()

                    transform = root_transform @ inverse_transform

                    final_transform = link_transform @ rotation_matrix_z_4x4 @ transform

                    self.init_mesh(entity_path, geom, joint_name)
                    collect_geometry_data(entity_path, final_transform)

        # Collect data for link geometries
        for link_index, geometries in self.link_geometries.items():
            link_transform = transforms[link_index]
            for i, geom in enumerate(geometries):
                entity_path = f"{self.base_entity_path}/safety_from_controller/links/link_{link_index}/geometry_{i}"
                final_transform = link_transform @ self.geometry_pose_to_matrix(geom.init_pose)
                self.init_geometry(entity_path, geom.capsule)
                collect_geometry_data(entity_path, final_transform)

        # Collect data for TCP geometries
        if self.tcp_geometries:
            tcp_transform = transforms[-1]  # End-effector transform
            for i, geom in enumerate(self.tcp_geometries):
                entity_path = f"{self.base_entity_path}/safety_from_controller/tcp/geometry_{i}"
                final_transform = tcp_transform @ self.geometry_pose_to_matrix(geom.init_pose)
                self.init_geometry(entity_path, geom.capsule)
                collect_geometry_data(entity_path, final_transform)

        # Collect data for collision link geometries.
        # NOTE(review): `geometries` is indexed by geom_id (mapping-style
        # access below) and the enumerate index `i` is unused — confirm the
        # container type of collision_link_geometries.
        for link_index, geometries in enumerate(self.collision_link_geometries):
            link_transform = transforms[link_index]
            for i, geom_id in enumerate(geometries):
                entity_path = f"{self.base_entity_path}/collision/links/link_{link_index}/geometry_{geom_id}"

                pose = normalize_pose(geometries[geom_id].pose)

                final_transform = link_transform @ self.geometry_pose_to_matrix(pose)
                self.init_collision_geometry(entity_path, geometries[geom_id], pose)
                collect_geometry_data(entity_path, final_transform)

        # Collect data for collision TCP geometries
        if self.collision_tcp_geometries:
            tcp_transform = transforms[-1]  # End-effector transform
            for i, geom_id in enumerate(self.collision_tcp_geometries):
                entity_path = f"{self.base_entity_path}/collision/tcp/geometry_{geom_id}"

                pose = normalize_pose(self.collision_tcp_geometries[geom_id].pose)
                final_transform = tcp_transform @ self.geometry_pose_to_matrix(pose)

                # tcp collision geometries are defined in flange frame, so the
                # static shape is logged at identity and only posed via columns.
                identity_pose = models.PlannerPose(
                    position=models.Vector3d(x=0, y=0, z=0),
                    orientation=models.Quaternion(x=0, y=0, z=0, w=1),
                )
                self.init_collision_geometry(
                    entity_path, self.collision_tcp_geometries[geom_id], identity_pose
                )
                collect_geometry_data(entity_path, final_transform)

    # Send collected columns for all geometries
    for entity_path, positions in link_positions.items():
        rr.send_columns(
            entity_path,
            indexes=[times_column],
            columns=[
                *rr.Transform3D.columns(
                    translation=positions, rotation_axis_angle=link_rotations[entity_path]
                )
            ],
        )