Skip to content

Commit 6074eb7

Browse files
Add BSplineSE3 class. (#128)
1 parent 43d9a95 commit 6074eb7

File tree

4 files changed

+144
-3
lines changed

4 files changed

+144
-3
lines changed

spatialmath/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
)
1616
from spatialmath.quaternion import Quaternion, UnitQuaternion
1717
from spatialmath.DualQuaternion import DualQuaternion, UnitDualQuaternion
18+
from spatialmath.spline import BSplineSE3
1819

1920
# from spatialmath.Plucker import *
2021
# from spatialmath import base as smb
@@ -43,6 +44,7 @@
4344
"LineSegment2",
4445
"Polygon2",
4546
"Ellipse",
47+
"BSplineSE3",
4648
]
4749

4850
try:

spatialmath/base/animate.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def update(frame, animation):
212212
# assume it is an SO(3) or SE(3)
213213
T = frame
214214
# ensure result is SE(3)
215+
215216
if T.shape == (3, 3):
216217
T = smb.r2t(T)
217218

@@ -308,7 +309,7 @@ def __init__(self, anim: Animate, h, xs, ys, zs):
308309
self.anim = anim
309310

310311
def draw(self, T):
311-
p = T @ self.p
312+
p = T.A @ self.p
312313
self.h.set_data(p[0, :], p[1, :])
313314
self.h.set_3d_properties(p[2, :])
314315

@@ -365,7 +366,8 @@ def __init__(self, anim, h):
365366
self.anim = anim
366367

367368
def draw(self, T):
368-
p = T @ self.p
369+
# import ipdb; ipdb.set_trace()
370+
p = T.A @ self.p
369371

370372
# reshape it
371373
p = p[0:3, :].T.reshape(3, 2, 3)
@@ -419,7 +421,7 @@ def __init__(self, anim, h, x, y, z):
419421
self.anim = anim
420422

421423
def draw(self, T):
422-
p = T @ self.p
424+
p = T.A @ self.p
423425
# x2, y2, _ = proj3d.proj_transform(
424426
# p[0], p[1], p[2], self.anim.ax.get_proj())
425427
# self.h.set_position((x2, y2))

spatialmath/spline.py

+105
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright (c) 2024 Boston Dynamics AI Institute LLC.
2+
# MIT Licence, see details in top-level file: LICENCE
3+
4+
"""
5+
Classes for parameterizing a trajectory in SE3 with B-splines.
6+
7+
Copies parts of the API from scipy's B-spline class.
8+
"""
9+
10+
from typing import Any, Dict, List, Optional
11+
from scipy.interpolate import BSpline
12+
from spatialmath import SE3
13+
import numpy as np
14+
import matplotlib.pyplot as plt
15+
from spatialmath.base.transforms3d import tranimate, trplot
16+
17+
18+
class BSplineSE3:
19+
"""A class to parameterize a trajectory in SE3 with a 6-dimensional B-spline.
20+
21+
The SE3 control poses are converted to se3 twists (the lie algebra) and a B-spline
22+
is created for each dimension of the twist, using the corresponding element of the twists
23+
as the control point for the spline.
24+
25+
For detailed information about B-splines, please see this wikipedia article.
26+
https://en.wikipedia.org/wiki/Non-uniform_rational_B-spline
27+
"""
28+
29+
def __init__(
30+
self,
31+
control_poses: List[SE3],
32+
degree: int = 3,
33+
knots: Optional[List[float]] = None,
34+
) -> None:
35+
"""Construct BSplineSE3 object. The default arguments generate a cubic B-spline
36+
with uniformly spaced knots.
37+
38+
- control_poses: list of SE3 objects that govern the shape of the spline.
39+
- degree: int that controls degree of the polynomial that governs any given point on the spline.
40+
- knots: list of floats that govern which control points are active during evaluating the spline
41+
at a given t input. If none, they are automatically, uniformly generated based on number of control poses and
42+
degree of spline.
43+
"""
44+
45+
self.control_poses = control_poses
46+
47+
# a matrix where each row is a control pose as a twist
48+
# (so each column is a vector of control points for that dim of the twist)
49+
self.control_pose_matrix = np.vstack(
50+
[np.array(element.twist()) for element in control_poses]
51+
)
52+
53+
self.degree = degree
54+
55+
if knots is None:
56+
knots = np.linspace(0, 1, len(control_poses) - degree + 1, endpoint=True)
57+
knots = np.append(
58+
[0.0] * degree, knots
59+
) # ensures the curve starts on the first control pose
60+
knots = np.append(
61+
knots, [1] * degree
62+
) # ensures the curve ends on the last control pose
63+
self.knots = knots
64+
65+
self.splines = [
66+
BSpline(knots, self.control_pose_matrix[:, i], degree)
67+
for i in range(0, 6) # twists are length 6
68+
]
69+
70+
def __call__(self, t: float) -> SE3:
71+
"""Returns pose of spline at t.
72+
73+
t: Normalized time value [0,1] to evaluate the spline at.
74+
"""
75+
twist = np.hstack([spline(t) for spline in self.splines])
76+
return SE3.Exp(twist)
77+
78+
def visualize(
79+
self,
80+
num_samples: int,
81+
length: float = 1.0,
82+
repeat: bool = False,
83+
ax: Optional[plt.Axes] = None,
84+
kwargs_trplot: Dict[str, Any] = {"color": "green"},
85+
kwargs_tranimate: Dict[str, Any] = {"wait": True},
86+
kwargs_plot: Dict[str, Any] = {},
87+
) -> None:
88+
"""Displays an animation of the trajectory with the control poses."""
89+
out_poses = [self(t) for t in np.linspace(0, 1, num_samples)]
90+
x = [pose.x for pose in out_poses]
91+
y = [pose.y for pose in out_poses]
92+
z = [pose.z for pose in out_poses]
93+
94+
if ax is None:
95+
fig = plt.figure(figsize=(10, 10))
96+
ax = fig.add_subplot(projection="3d")
97+
98+
trplot(
99+
[np.array(self.control_poses)], ax=ax, length=length, **kwargs_trplot
100+
) # plot control points
101+
ax.plot(x, y, z, **kwargs_plot) # plot x,y,z trajectory
102+
103+
tranimate(
104+
out_poses, repeat=repeat, length=length, **kwargs_tranimate
105+
) # animate pose along trajectory

tests/test_spline.py

+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import numpy.testing as nt
2+
import numpy as np
3+
import matplotlib.pyplot as plt
4+
import unittest
5+
import sys
6+
import pytest
7+
8+
from spatialmath import BSplineSE3, SE3
9+
10+
11+
class TestBSplineSE3(unittest.TestCase):
12+
control_poses = [
13+
SE3.Trans([e, 2 * np.cos(e / 2 * np.pi), 2 * np.sin(e / 2 * np.pi)])
14+
* SE3.Ry(e / 8 * np.pi)
15+
for e in range(0, 8)
16+
]
17+
18+
@classmethod
19+
def tearDownClass(cls):
20+
plt.close("all")
21+
22+
def test_constructor(self):
23+
BSplineSE3(self.control_poses)
24+
25+
def test_evaluation(self):
26+
spline = BSplineSE3(self.control_poses)
27+
nt.assert_almost_equal(spline(0).A, self.control_poses[0].A)
28+
nt.assert_almost_equal(spline(1).A, self.control_poses[-1].A)
29+
30+
def test_visualize(self):
31+
spline = BSplineSE3(self.control_poses)
32+
spline.visualize(num_samples=100, repeat=False)

0 commit comments

Comments
 (0)