from collections.abc import Callable
import numpy as np
from numba import jit
from numpy.typing import NDArray
import pooltool.constants as const
from pooltool.ptmath.roots.core import (
get_real_positive_smallest_roots,
)
from pooltool.utils.strenum import StrEnum, auto
[docs]
class QuarticSolver(StrEnum):
    """Names of the strategies available for solving batches of quartics.

    The module-level ``_quartic_routine`` table maps each member to the function
    that implements it.

    Attributes:
        HYBRID:
            Dispatches to ``solve_many``.
        NUMERIC:
            Dispatches to ``solve_many_numerical``.
    """

    HYBRID = auto()
    NUMERIC = auto()
[docs]
def solve_quartics(
    ps: NDArray[np.float64], solver: QuarticSolver = QuarticSolver.HYBRID
) -> NDArray[np.float64]:
    """Returns the smallest positive and real root for each quartic polynomial.

    Args:
        ps:
            A mx5 array of polynomial coefficients, where m is the number of equations.
            The columns are in the order a, b, c, d, e, where these coefficients make up
            the quartic polynomial equation at^4 + bt^3 + ct^2 + dt + e = 0.
        solver:
            The method used to calculate the roots. See
            pooltool.ptmath.roots.quartic.QuarticSolver.

    Returns:
        roots:
            An array of shape m. Each value is the smallest root that is real and
            positive. If no such root exists (e.g. all roots are complex), then
            `np.inf` is returned.

    Raises:
        ValueError:
            If ``solver`` is not a valid ``QuarticSolver`` (raised by the enum
            lookup).
    """
    # Coerce explicitly instead of asserting: an ``assert`` is stripped under
    # ``python -O``, which would let an invalid solver surface as a confusing
    # KeyError from the dispatch table. Using the coerced member for the lookup
    # also makes the validation and the dispatch agree.
    routine = _quartic_routine[QuarticSolver(solver)]

    # Get the roots for the polynomials
    roots = routine(ps)  # Shape m x 4

    return get_real_positive_smallest_roots(roots)  # Shape m, dtype float64
[docs]
def solve_many_numerical(p: NDArray[np.float64]) -> NDArray[np.complex128]:
    """Find the roots of many polynomials at once via companion-matrix eigenvalues.

    This is a batched counterpart of ``numpy.roots``: one companion matrix is built
    per polynomial and all of them are handed to ``np.linalg.eigvals`` in a single
    stacked call. The approach comes from https://stackoverflow.com/a/35853977

    Args:
        p:
            A mxn array of polynomial coefficients, where m is the number of equations
            and n-1 is the order of the polynomial. If n is 5 (4th order polynomial),
            the columns are in the order a, b, c, d, e, where these coefficients make up
            the polynomial equation at^4 + bt^3 + ct^2 + dt + e = 0

    Returns:
        The eigenvalues of each companion matrix, i.e. the n-1 roots of each
        polynomial, one row per equation.

    Notes:
        - Not yet amenable to numbaization (0.56.4): numba's ``np.linalg.eigvals``
          only supports 2D arrays, whereas this function passes a stacked 3D array.
          A numba-compatible variant that builds the same companion matrices and
          calls ``np.linalg.eigvals`` on each 2D matrix in a loop was measured to be
          only slightly slower (about 7%).
    """
    degree = p.shape[-1] - 1

    # One (degree x degree) companion matrix per equation.
    companion = np.zeros((*p.shape[:1], degree, degree), dtype=np.float64)

    # Ones on the sub-diagonal.
    companion[..., 1:, :-1] = np.eye(degree - 1)

    # First row holds the negated coefficients normalised by the leading one.
    leading = p[..., :1]
    companion[..., 0, :] = -p[..., 1:] / leading

    return np.linalg.eigvals(companion)  # type: ignore
[docs]
def solve_many(ps: NDArray[np.float64]) -> NDArray[np.complex128]:
    """Solve multiple quartic equations, preferring closed-form solutions.

    Closed-form solutions exist for the quartic, but they can be numerically
    unstable. Their quality can be checked by substituting the computed roots back
    into the polynomial and confirming it evaluates to (nearly) zero.

    For each equation the roots are first computed analytically. If they fail that
    check, the isomorphic (coefficient-reversed) polynomial is solved analytically
    instead. If those roots fail as well, the companion matrix eigenvalue method is
    used, which is very reliable, but slower.

    Args:
        ps:
            A mx5 array of polynomial coefficients, where m is the number of equations.
            The columns are in the order a, b, c, d, e, where these coefficients make up
            the polynomial equation at^4 + bt^3 + ct^2 + dt + e = 0

    Returns:
        The roots computed for each equation, one row per equation. The per-equation
        strategy indicators produced alongside them are discarded.
    """
    coefficients = ps.astype(np.complex128)
    return _solve_many(coefficients)[0]
[docs]
@jit(nopython=True, cache=const.use_numba_cache)
def solve(a: float, b: float, c: float, d: float, e: float) -> NDArray[np.complex128]:
    """Solve the single quartic at^4 + bt^3 + ct^2 + dt + e = 0.

    The coefficients are packed into a complex array and passed to ``_solve``; only
    the roots are returned, the strategy indicator is dropped.
    """
    coefficients = np.array([a, b, c, d, e], dtype=np.complex128)
    roots, _ = _solve(coefficients)
    return roots
@jit(nopython=True, cache=const.use_numba_cache)
def _solve_many(
    ps: NDArray[np.complex128],
) -> tuple[NDArray[np.complex128], NDArray[np.uint8]]:
    """Apply ``_solve`` to every row of ``ps``.

    Args:
        ps:
            Array with one quartic's coefficients (a, b, c, d, e) per row.

    Returns:
        A pair ``(all_roots, indicators)``. ``all_roots`` has one row of 4 roots per
        equation and ``indicators`` holds, per equation, the integer code returned
        by ``_solve`` identifying which strategy produced the roots.
    """
    count = ps.shape[0]
    indicators = np.zeros(count, dtype=np.uint8)
    all_roots = np.zeros((count, 4), dtype=np.complex128)

    for row in range(count):
        row_roots, strategy = _solve(ps[row, :])
        all_roots[row, :] = row_roots
        indicators[row] = strategy

    return all_roots, indicators
@jit(nopython=True, cache=const.use_numba_cache)
def _roots_within_tolerance(
    p: NDArray[np.complex128], roots: NDArray[np.complex128], ftol: float
) -> bool:
    """Return False as soon as a root's residual magnitude in ``p`` exceeds ``ftol``."""
    for root in roots:
        if abs(evaluate(p, root)) > ftol:
            return False
    return True


@jit(nopython=True, cache=const.use_numba_cache)
def _solve(
    p: NDArray[np.complex128], ftol: float = 1e-5
) -> tuple[NDArray[np.complex128], int]:
    """Solve a quartic with mixed strategy

    Args:
        p:
            The coefficients a, b, c, d, e of at^4 + bt^3 + ct^2 + dt + e = 0.
        ftol:
            This is a very sensitive parameter and controls whether or not the
            analytically calculated roots sufficiently satisfy the polynomial. From
            the original author's testing, roots whose residual is <1e-5 are nearly
            always numerically similar to the true root, and no bad roots were found
            below that threshold. Some good/decent roots do evaluate to >1e-5, but
            the conservative choice is kept since those cases are caught by the
            numerical solution.

    Returns:
        A pair ``(roots, indicator)`` where ``indicator`` records how the roots were
        obtained: 0 means the constant term was exactly zero (all-zero roots are
        returned), 1 means the first analytic attempt was accepted, 2 means the
        second analytic attempt was accepted, and 3 means the numeric solver was used.
    """
    constant_term = p[-1].real

    # A zero constant term means t=0 is a root. The other roots are not solved
    # for; all zeros are returned.
    if constant_term == 0.0:
        return np.zeros(4, dtype=np.complex128), 0

    # Round-off error seems to be especially problematic for the analytic solutions
    # when the constant term is small, and the analytic solutions don't like zero
    # coefficients either. Both cases go straight to the numeric solver.
    if abs(constant_term) < 1e-7 or (p == 0).any():
        return numeric(p), 3

    # Guess which of the two isomorphic polynomials (original or coefficient-reversed)
    # is more likely to be numerically stable, and try that one first. Roots of the
    # reversed polynomial are the reciprocals of the roots of the original.
    flipped = p[::-1]
    prefer_flipped = instability(flipped) < instability(p)

    if prefer_flipped:
        first_try = 1.0 / analytic(flipped)
    else:
        first_try = analytic(p)
    if _roots_within_tolerance(p, first_try, ftol):
        return first_try, 1

    # The first attempt was rejected. Only now solve the other polynomial.
    if prefer_flipped:
        second_try = analytic(p)
    else:
        second_try = 1.0 / analytic(flipped)
    if _roots_within_tolerance(p, second_try, ftol):
        return second_try, 2

    # Both analytic attempts were rejected. Resort to companion matrix eigenvalues.
    return numeric(p), 3
[docs]
@jit(nopython=True, cache=const.use_numba_cache)
def evaluate(p: NDArray[np.complex128], val: complex) -> complex:
    """Evaluate the quartic with coefficients ``p`` (highest order first) at ``val``."""
    quartic_term = p[0] * val**4
    cubic_term = p[1] * val**3
    quadratic_term = p[2] * val**2
    linear_term = p[3] * val
    # Summed left to right, highest order first.
    return quartic_term + cubic_term + quadratic_term + linear_term + p[4]
[docs]
@jit(nopython=True, cache=const.use_numba_cache)
def instability(p: NDArray[np.complex128]) -> float:
    """Score a polynomial from its two leading coefficients; lower means more stable.

    With ``t = |p[0] / p[1]|`` the score is ``t + 1 / t``, which is smallest (2) when
    the two coefficients have equal magnitude and grows without bound as they
    diverge. If either coefficient is zero, 0.0 is returned instead.
    """
    lead, second = p[:2]
    if lead != 0 and second != 0:
        ratio = abs(lead / second)
        return ratio + 1 / ratio
    return 0.0
[docs]
@jit(nopython=True, cache=const.use_numba_cache)
def numeric(p: NDArray[np.complex128]) -> NDArray[np.complex128]:
    """Compute the roots of the polynomial ``p`` with ``np.roots``, cast to complex128."""
    roots = np.roots(p)
    return roots.astype(np.complex128)
[docs]
@jit(nopython=True, cache=const.use_numba_cache)
def analytic(p: NDArray[np.complex128]) -> NDArray[np.complex128]:
    """Calculate a quartic's roots using the closed-form solution

    This function was created with the help of sympy.

    To start, I solved the general quartic polynomial roots:

    >>> from sympy import symbols, Eq, solve
    >>> x, a, b, c, d, e = symbols('x a b c d e')
    >>> general_solution = solve(a*x**4 + b*x**3 + c*x**2 + d*x + e, x)

    This yields 4 expressions, one for each root. Each expression is piecewise
    conditional, where if the following equality is true, the first expression is used,
    and otherwise the second expression is used.

    >>> general_solution[0].args[0][1]
    Eq(e/a - b*d/(4*a**2) + c**2/(12*a**2), 0)

    So in total there are 8 expressions, 2 for each root, and the expression used for
    each root is determined based on whether the above equality holds true. These
    expressions are huge, so to better digest them, I used the following common
    subexpression elimination:

    >>> from sympy import cse
    >>> cse(
    >>>     [
    >>>         general_solution[0].args[0][0],
    >>>         general_solution[0].args[1][0],
    >>>         general_solution[1].args[0][0],
    >>>         general_solution[1].args[1][0],
    >>>         general_solution[2].args[0][0],
    >>>         general_solution[2].args[1][0],
    >>>         general_solution[3].args[0][0],
    >>>         general_solution[3].args[1][0],
    >>>     ]
    >>> )

    Then I used a vim macro to convert these subexpressions into the lines of python
    code you see below.

    Args:
        p:
            The coefficients a, b, c, d, e of at^4 + bt^3 + ct^2 + dt + e = 0.

    Returns:
        An array of the 4 computed roots. If ``e`` is zero, the array
        ``[0, nan, nan, nan]`` is returned without evaluating the closed form.
    """
    # The coefficients are annotated as complex, which is what lets the fractional
    # powers and square roots below operate on negative values.
    a, b, c, d, e = p

    # Zero constant term: report the t=0 root and leave the rest as NaN.
    if e == 0:
        return np.array([0, np.nan, np.nan, np.nan], dtype=np.complex128)

    # x0..x40 are the common subexpressions produced by sympy's cse (see docstring).
    x0 = 1 / a
    x1 = c * x0
    x2 = a ** (-2)
    x3 = b**2
    x4 = x2 * x3
    x5 = x1 - 3 * x4 / 8
    x6 = x5**3
    x7 = d * x0
    x8 = b * x2
    x9 = c * x8
    x10 = a ** (-3)
    x11 = b**3 * x10
    x12 = (x11 / 8 + x7 - x9 / 2) ** 2
    x13 = -d * x8 / 4 + e * x0
    x14 = c * x10 * x3 / 16 + x13 - 3 * b**4 / (256 * a**4)
    x15 = -x12 / 8 + x14 * x5 / 3 - x6 / 108
    x16 = 2 * x15 ** (1 / 3)
    x17 = x11 / 4 + 2 * x7 - x9
    x18 = 2 * x1 / 3 - x2 * x3 / 4
    x19 = np.sqrt(-x16 - x18)
    x20 = x17 / x19
    x21 = 4 * x1 / 3
    x22 = -x21 + x4 / 2
    x23 = np.sqrt(x16 + x20 + x22) / 2
    x24 = x19 / 2
    x25 = b * x0 / 4
    x26 = x24 + x25
    x27 = -(c**2) * x2 / 12 - x13
    # ``or const.EPS`` swaps in const.EPS when the value is exactly zero, since x28
    # is used as a divisor in x30.
    x28 = (x12 / 16 - x14 * x5 / 6 + x6 / 216 + np.sqrt(x15**2 / 4 + x27**3 / 27)) ** (
        1 / 3
    ) or const.EPS
    x29 = 2 * x28
    x30 = 2 * x27 / (3 * x28)
    x31 = -x29 + x30
    # Same zero guard as x28: x32 is used as a divisor in x33.
    x32 = np.sqrt(-x18 - x31) or const.EPS
    x33 = x17 / x32
    x34 = np.sqrt(x22 + x31 + x33) / 2
    x35 = x32 / 2
    x36 = x25 + x35
    x37 = -x2 * x3 / 2 + x21
    x38 = np.sqrt(x16 - x20 - x37) / 2
    x39 = np.sqrt(-x29 + x30 - x33 - x37) / 2
    x40 = -x25

    # Piecewise condition from the sympy solution (see docstring), tested against
    # the tolerance const.EPS rather than exact equality with zero.
    if abs(e / a - b * d / (4 * a**2) + c**2 / (12 * a**2)) < const.EPS:
        # First expression of each piecewise root.
        roots = (
            -x23 - x26,
            x23 - x26,
            x24 - x25 - x38,
            x24 + x38 + x40,
        )
    else:
        # Second expression of each piecewise root.
        roots = (
            -x34 - x36,
            x34 - x36,
            -x25 + x35 - x39,
            x35 + x39 + x40,
        )

    return np.array(roots, dtype=np.complex128)
def _truth(a_val, b_val, c_val, d_val, e_val, digits=50):
    """Evaluate sympy's symbolic quartic roots for the given coefficient values.

    The general quartic a*x**4 + b*x**3 + c*x**2 + d*x + e is solved symbolically
    with ``sympy.solve`` and each resulting expression is evaluated with ``evalf``
    to ``digits`` digits after substituting the supplied coefficients.

    Returns:
        A list with one evaluated value per symbolic solution.
    """
    import sympy  # type: ignore

    x, a, b, c, d, e = sympy.symbols("x a b c d e")
    quartic = a * x**4 + b * x**3 + c * x**2 + d * x + e
    symbolic_roots = sympy.solve(quartic, x)

    evaluated = []
    for root_expr in symbolic_roots:
        # A fresh substitution mapping is built for every evaluation.
        values = {a: a_val, b: b_val, c: c_val, d: d_val, e: e_val}
        evaluated.append(root_expr.evalf(digits, subs=values))
    return evaluated
# Dispatch table used by ``solve_quartics``: maps each QuarticSolver member to the
# function that computes the roots for a batch of quartics.
_quartic_routine: dict[QuarticSolver, Callable] = {
    QuarticSolver.NUMERIC: solve_many_numerical,
    QuarticSolver.HYBRID: solve_many,
}