# This code is part of Qiskit.
#
# (C) Copyright IBM 2017.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""
A module for using quaternions.
"""
from __future__ import annotations
import math
import numpy as np
[docs]class Quaternion:
"""A class representing a Quaternion."""
def __init__(self, data):
self.data = np.asarray(data, dtype=float)
def __call__(self, idx):
return self.data[idx]
def __repr__(self):
return np.array_str(self.data)
def __str__(self):
return np.array_str(self.data)
def __mul__(self, r):
if isinstance(r, Quaternion):
q = self
out_data = np.zeros(4, dtype=float)
out_data[0] = r(0) * q(0) - r(1) * q(1) - r(2) * q(2) - r(3) * q(3)
out_data[1] = r(0) * q(1) + r(1) * q(0) - r(2) * q(3) + r(3) * q(2)
out_data[2] = r(0) * q(2) + r(1) * q(3) + r(2) * q(0) - r(3) * q(1)
out_data[3] = r(0) * q(3) - r(1) * q(2) + r(2) * q(1) + r(3) * q(0)
return Quaternion(out_data)
else:
raise Exception("Multiplication by other not supported.")
[docs] def norm(self):
"""Norm of quaternion."""
import scipy.linalg as la
return la.norm(self.data)
[docs] def normalize(self, inplace: bool = False) -> Quaternion:
"""Normalizes a Quaternion to unit length
so that it represents a valid rotation.
Args:
inplace (bool): Do an inplace normalization.
Returns:
Quaternion: Normalized quaternion.
"""
if inplace:
nrm = self.norm()
self.data /= nrm
return None
nrm = self.norm()
data_copy = np.array(self.data, copy=True)
data_copy /= nrm
return Quaternion(data_copy)
[docs] def to_matrix(self) -> np.ndarray:
"""Converts a unit-length quaternion to a rotation matrix.
Returns:
ndarray: Rotation matrix.
"""
w, x, y, z = self.normalize().data
mat = np.array(
[
[1 - 2 * y**2 - 2 * z**2, 2 * x * y - 2 * z * w, 2 * x * z + 2 * y * w],
[2 * x * y + 2 * z * w, 1 - 2 * x**2 - 2 * z**2, 2 * y * z - 2 * x * w],
[2 * x * z - 2 * y * w, 2 * y * z + 2 * x * w, 1 - 2 * x**2 - 2 * y**2],
],
dtype=float,
)
return mat
[docs] def to_zyz(self) -> np.ndarray:
"""Converts a unit-length quaternion to a sequence
of ZYZ Euler angles.
Returns:
ndarray: Array of Euler angles.
"""
mat = self.to_matrix()
euler = np.zeros(3, dtype=float)
if mat[2, 2] < 1:
if mat[2, 2] > -1:
euler[0] = math.atan2(mat[1, 2], mat[0, 2])
euler[1] = math.acos(mat[2, 2])
euler[2] = math.atan2(mat[2, 1], -mat[2, 0])
else:
euler[0] = -math.atan2(mat[1, 0], mat[1, 1])
euler[1] = np.pi
else:
euler[0] = math.atan2(mat[1, 0], mat[1, 1])
return euler
[docs] @classmethod
def from_axis_rotation(cls, angle: float, axis: str) -> Quaternion:
"""Return quaternion for rotation about given axis.
Args:
angle (float): Angle in radians.
axis (str): Axis for rotation
Returns:
Quaternion: Quaternion for axis rotation.
Raises:
ValueError: Invalid input axis.
"""
out = np.zeros(4, dtype=float)
if axis == "x":
out[1] = 1
elif axis == "y":
out[2] = 1
elif axis == "z":
out[3] = 1
else:
raise ValueError("Invalid axis input.")
out *= math.sin(angle / 2.0)
out[0] = math.cos(angle / 2.0)
return cls(out)
[docs] @classmethod
def from_euler(cls, angles: list | np.ndarray, order: str = "yzy") -> Quaternion:
"""Generate a quaternion from a set of Euler angles.
Args:
angles (array_like): Array of Euler angles.
order (str): Order of Euler rotations. 'yzy' is default.
Returns:
Quaternion: Quaternion representation of Euler rotation.
"""
angles = np.asarray(angles, dtype=float)
quat = (
cls.from_axis_rotation(angles[0], order[0])
* cls.from_axis_rotation(angles[1], order[1])
* cls.from_axis_rotation(angles[2], order[2])
)
quat.normalize(inplace=True)
return quat