Source code for pooltool.ptmath.roots._quartic_numba

"""Translation of the "1010" quartic root-finding algorithm with modifications.

The original implementation is written in C, and this module was written by Claude Code.

Modifications from the original algorithm:
    - The threshold for falling back to an alternative factorization method has been
      loosened by a safety factor (d2_safety_factor). The original threshold was too
      strict for certain edge cases where the discriminant d2 is very small but nonzero,
      causing the algorithm to produce duplicate roots instead of four distinct roots.
      This is particularly relevant for Newton's cradle-like simulations where balls
      move with very similar velocities.

Solve speed:

* Numba (this implementation): 2.8 million quartics / second
* C (original implementation): 3.3 million quartics / second

References:
    @article{10.1145/3386241,
        author = {Orellana, Alberto Giacomo and Michele, Cristiano De},
        title = {Algorithm 1010: Boosting Efficiency in Solving Quartic Equations with No Compromise in Accuracy},
        year = {2020},
        issue_date = {June 2020},
        publisher = {Association for Computing Machinery},
        address = {New York, NY, USA},
        volume = {46},
        number = {2},
        issn = {0098-3500},
        url = {https://doi.org/10.1145/3386241},
        doi = {10.1145/3386241},
        abstract = {Aiming to provide a very accurate, efficient, and robust quartic
        equation solver for physical applications, we have proposed an algorithm that
        builds on the previous works of P. Strobach and S. L. Shmakov. It is based on
        the decomposition of the quartic polynomial into two quadratics, whose
        coefficients are first accurately estimated by handling carefully numerical
        errors and afterward refined through the use of the Newton-Raphson method. Our
        algorithm is very accurate in comparison with other state-of-the-art solvers
        that can be found in the literature, but (most importantly) it turns out to be
        very efficient according to our timing tests. A crucial issue for us is the
        robustness of the algorithm, i.e., its ability to cope with the detrimental
        effect of round-off errors, no matter what set of quartic coefficients is
        provided in a practical application. In this respect, we extensively tested our
        algorithm in comparison to other quartic equation solvers both by considering
        specific extreme cases and by carrying out a statistical analysis over a very
        large set of quartics. Our algorithm has also been heavily tested in a physical
        application, i.e., simulations of hard cylinders, where it proved its absolute
        reliability as well as its efficiency.},
        journal = {ACM Trans. Math. Softw.},
        month = may,
        articleno = {20},
        numpages = {28},
        keywords = {Newton-Raphson scheme, Quartic equation, factorization into quadratics, numerical solver design, performance}
    }
"""

import math

import numpy as np
from numba import jit
from numpy.typing import NDArray

import pooltool.constants as const

cubic_rescal_fact = 3.488062113727083e102
quart_rescal_fact = 7.156344627944542e76
macheps = 2.2204460492503131e-16
d2_safety_factor = 100


@jit(nopython=True, cache=const.use_numba_cache)
def oqs_solve_cubic_analytic_depressed_handle_inf(b, c):
    PI2 = math.pi / 2.0
    TWOPI = 2.0 * math.pi
    Q = -b / 3.0
    R = 0.5 * c
    if R == 0:
        if b <= 0:
            sol = math.sqrt(-b)
        else:
            sol = 0.0
        return sol

    if abs(Q) < abs(R):
        QR = Q / R
        QRSQ = QR * QR
        KK = 1.0 - Q * QRSQ
    else:
        RQ = R / Q
        KK = math.copysign(1.0, Q) * (RQ * RQ / Q - 1.0)

    if KK < 0.0:
        sqrtQ = math.sqrt(Q)
        theta = math.acos((R / abs(Q)) / sqrtQ)
        if theta < PI2:
            sol = -2.0 * sqrtQ * math.cos(theta / 3.0)
        else:
            sol = -2.0 * sqrtQ * math.cos((theta + TWOPI) / 3.0)
    else:
        if abs(Q) < abs(R):
            A = -math.copysign(1.0, R) * math.pow(
                abs(R) * (1.0 + math.sqrt(KK)), 1.0 / 3.0
            )
        else:
            A = -math.copysign(1.0, R) * math.pow(
                abs(R) + math.sqrt(abs(Q)) * abs(Q) * math.sqrt(KK), 1.0 / 3.0
            )
        if A == 0.0:
            B = 0.0
        else:
            B = Q / A
        sol = A + B
    return sol


@jit(nopython=True, cache=const.use_numba_cache)
def oqs_solve_cubic_analytic_depressed(b, c):
    Q = -b / 3.0
    R = 0.5 * c
    if abs(Q) > 1e102 or abs(R) > 1e154:
        sol = oqs_solve_cubic_analytic_depressed_handle_inf(b, c)
        return sol
    Q3 = Q * Q * Q
    R2 = R * R
    if R2 < Q3:
        theta = math.acos(R / math.sqrt(Q3))
        sqrtQ = -2.0 * math.sqrt(Q)
        if theta < math.pi / 2:
            sol = sqrtQ * math.cos(theta / 3.0)
        else:
            sol = sqrtQ * math.cos((theta + 2.0 * math.pi) / 3.0)
    else:
        A = -math.copysign(1.0, R) * math.pow(abs(R) + math.sqrt(R2 - Q3), 1.0 / 3.0)
        if A == 0.0:
            B = 0.0
        else:
            B = Q / A
        sol = A + B
    return sol


@jit(nopython=True, cache=const.use_numba_cache)
def oqs_calc_phi0(a, b, c, d, scaled):
    diskr = 9 * a * a - 24 * b
    if diskr > 0.0:
        diskr = math.sqrt(diskr)
        if a > 0.0:
            s = -2 * b / (3 * a + diskr)
        else:
            s = -2 * b / (3 * a - diskr)
    else:
        s = -a / 4

    aq = a + 4 * s
    bq = b + 3 * s * (a + 2 * s)
    cq = c + s * (2 * b + s * (3 * a + 4 * s))
    dq = d + s * (c + s * (b + s * (a + s)))
    gg = bq * bq / 9
    hh = aq * cq

    g = hh - 4 * dq - 3 * gg
    h = (8 * dq + hh - 2 * gg) * bq / 3 - cq * cq - dq * aq * aq
    rmax = oqs_solve_cubic_analytic_depressed(g, h)
    if math.isnan(rmax) or math.isinf(rmax):
        rmax = oqs_solve_cubic_analytic_depressed_handle_inf(g, h)
        if (math.isnan(rmax) or math.isinf(rmax)) and scaled:
            rfact = cubic_rescal_fact
            rfactsq = rfact * rfact
            ggss = gg / rfactsq
            hhss = hh / rfactsq
            dqss = dq / rfactsq
            aqs = aq / rfact
            bqs = bq / rfact
            cqs = cq / rfact
            ggss = bqs * bqs / 9.0
            hhss = aqs * cqs
            g = hhss - 4.0 * dqss - 3.0 * ggss
            h = (
                (8.0 * dqss + hhss - 2.0 * ggss) * bqs / 3
                - cqs * (cqs / rfact)
                - (dq / rfact) * aqs * aqs
            )
            rmax = oqs_solve_cubic_analytic_depressed(g, h)
            if math.isnan(rmax) or math.isinf(rmax):
                rmax = oqs_solve_cubic_analytic_depressed_handle_inf(g, h)
            rmax *= rfact

    x = rmax
    xsq = x * x
    xxx = x * xsq
    gx = g * x
    f = x * (xsq + g) + h
    if abs(xxx) > abs(gx):
        maxtt = abs(xxx)
    else:
        maxtt = abs(gx)
    if abs(h) > maxtt:
        maxtt = abs(h)

    if abs(f) > macheps * maxtt:
        for iter in range(8):
            df = 3.0 * xsq + g
            if df == 0:
                break
            xold = x
            x += -f / df
            fold = f
            xsq = x * x
            f = x * (xsq + g) + h
            if f == 0:
                break

            if abs(f) >= abs(fold):
                x = xold
                break
    phi0 = x
    return phi0


@jit(nopython=True, cache=const.use_numba_cache)
def oqs_calc_err_ldlt(b, c, d, d2, l1, l2, l3):
    if b == 0:
        sum = abs(d2 + l1 * l1 + 2.0 * l3)
    else:
        sum = abs(((d2 + l1 * l1 + 2.0 * l3) - b) / b)
    if c == 0:
        sum += abs(2.0 * d2 * l2 + 2.0 * l1 * l3)
    else:
        sum += abs(((2.0 * d2 * l2 + 2.0 * l1 * l3) - c) / c)
    if d == 0:
        sum += abs(d2 * l2 * l2 + l3 * l3)
    else:
        sum += abs(((d2 * l2 * l2 + l3 * l3) - d) / d)
    return sum


@jit(nopython=True, cache=const.use_numba_cache)
def oqs_calc_err_abcd_cmplx(a, b, c, d, aq, bq, cq, dq):
    if d == 0:
        sum = abs(bq * dq)
    else:
        sum = abs((bq * dq - d) / d)
    if c == 0:
        sum += abs(bq * cq + aq * dq)
    else:
        sum += abs(((bq * cq + aq * dq) - c) / c)
    if b == 0:
        sum += abs(bq + aq * cq + dq)
    else:
        sum += abs(((bq + aq * cq + dq) - b) / b)
    if a == 0:
        sum += abs(aq + cq)
    else:
        sum += abs(((aq + cq) - a) / a)
    return sum


@jit(nopython=True, cache=const.use_numba_cache)
def oqs_calc_err_abcd(a, b, c, d, aq, bq, cq, dq):
    if d == 0:
        sum = abs(bq * dq)
    else:
        sum = abs((bq * dq - d) / d)
    if c == 0:
        sum += abs(bq * cq + aq * dq)
    else:
        sum += abs(((bq * cq + aq * dq) - c) / c)
    if b == 0:
        sum += abs(bq + aq * cq + dq)
    else:
        sum += abs(((bq + aq * cq + dq) - b) / b)
    if a == 0:
        sum += abs(aq + cq)
    else:
        sum += abs(((aq + cq) - a) / a)
    return sum


@jit(nopython=True, cache=const.use_numba_cache)
def oqs_calc_err_abc(a, b, c, aq, bq, cq, dq):
    if c == 0:
        sum = abs(bq * cq + aq * dq)
    else:
        sum = abs(((bq * cq + aq * dq) - c) / c)
    if b == 0:
        sum += abs(bq + aq * cq + dq)
    else:
        sum += abs(((bq + aq * cq + dq) - b) / b)
    if a == 0:
        sum += abs(aq + cq)
    else:
        sum += abs(((aq + cq) - a) / a)
    return sum


@jit(nopython=True, cache=const.use_numba_cache)
def oqs_NRabcd(a, b, c, d, AQ, BQ, CQ, DQ):
    x0, x1, x2, x3 = AQ, BQ, CQ, DQ

    fvec0 = x1 * x3 - d
    fvec1 = x1 * x2 + x0 * x3 - c
    fvec2 = x1 + x0 * x2 + x3 - b
    fvec3 = x0 + x2 - a

    if d == 0:
        errf = abs(fvec0)
    else:
        errf = abs(fvec0 / d)
    if c == 0:
        errf += abs(fvec1)
    else:
        errf += abs(fvec1 / c)
    if b == 0:
        errf += abs(fvec2)
    else:
        errf += abs(fvec2 / b)
    if a == 0:
        errf += abs(fvec3)
    else:
        errf += abs(fvec3 / a)

    for iter in range(8):
        x02 = x0 - x2
        det = x1 * x1 + x1 * (-x2 * x02 - 2.0 * x3) + x3 * (x0 * x02 + x3)
        if det == 0.0:
            break

        J00 = x02
        J01 = x3 - x1
        J02 = x1 * x2 - x0 * x3
        J03 = -x1 * J01 - x0 * J02
        J10 = x0 * J00 + J01
        J11 = -x1 * J00
        J12 = -x1 * J01
        J13 = -x1 * J02
        J20 = -J00
        J21 = -J01
        J22 = -J02
        J23 = J02 * x2 + J01 * x3
        J30 = -x2 * J00 - J01
        J31 = J00 * x3
        J32 = x3 * J01
        J33 = x3 * J02

        dx0 = J00 * fvec0 + J01 * fvec1 + J02 * fvec2 + J03 * fvec3
        dx1 = J10 * fvec0 + J11 * fvec1 + J12 * fvec2 + J13 * fvec3
        dx2 = J20 * fvec0 + J21 * fvec1 + J22 * fvec2 + J23 * fvec3
        dx3 = J30 * fvec0 + J31 * fvec1 + J32 * fvec2 + J33 * fvec3

        xold0, xold1, xold2, xold3 = x0, x1, x2, x3

        x0 += -dx0 / det
        x1 += -dx1 / det
        x2 += -dx2 / det
        x3 += -dx3 / det

        fvec0 = x1 * x3 - d
        fvec1 = x1 * x2 + x0 * x3 - c
        fvec2 = x1 + x0 * x2 + x3 - b
        fvec3 = x0 + x2 - a

        errfold = errf
        if d == 0:
            errf = abs(fvec0)
        else:
            errf = abs(fvec0 / d)
        if c == 0:
            errf += abs(fvec1)
        else:
            errf += abs(fvec1 / c)
        if b == 0:
            errf += abs(fvec2)
        else:
            errf += abs(fvec2 / b)
        if a == 0:
            errf += abs(fvec3)
        else:
            errf += abs(fvec3 / a)

        if errf == 0:
            break
        if errf >= errfold:
            x0, x1, x2, x3 = xold0, xold1, xold2, xold3
            break

    return x0, x1, x2, x3


@jit(nopython=True, cache=const.use_numba_cache)
def oqs_solve_quadratic(a, b):
    diskr = a * a - 4 * b
    if diskr >= 0.0:
        if a >= 0.0:
            div = -a - math.sqrt(diskr)
        else:
            div = -a + math.sqrt(diskr)

        zmax = div / 2

        if zmax == 0.0:
            zmin = 0.0
        else:
            zmin = b / zmax

        root0 = zmax + 0.0j
        root1 = zmin + 0.0j
    else:
        sqrtd = math.sqrt(-diskr)
        root0 = -a / 2 + sqrtd / 2 * 1j
        root1 = -a / 2 - sqrtd / 2 * 1j
    return (root0, root1)


[docs] @jit(nopython=True, cache=const.use_numba_cache) def solve_many(ps: NDArray[np.float64]) -> NDArray[np.complex128]: num_eqn = ps.shape[0] roots = np.zeros((num_eqn, 4), dtype=np.complex128) for i in range(num_eqn): p = ps[i, :] roots[i, :] = solve(p[0], p[1], p[2], p[3], p[4]) return roots
[docs] @jit(nopython=True, cache=const.use_numba_cache) def solve(a: float, b: float, c: float, d: float, e: float) -> NDArray[np.complex128]: """Solve quartic equation ``a*t^4 + b*t^3 + c*t^2 + d*t + e = 0``. Args: a: Coefficient of ``t^4``. Must be nonzero. b: Coefficient of ``t^3``. c: Coefficient of ``t^2``. d: Coefficient of ``t``. e: Constant term. Returns: Array of 4 complex roots. Raises: ValueError: If ``a == 0`` (the polynomial is not a quartic). Note: TODO: Currently, this handles genuine quartics only, i.e. ``a != 0``. For all collision-detection contexts that we're aware of, ``a == 0`` always implies ``b == 0``. Thus, the cubic case (``a == 0, b != 0``) is unreachable. Thus, until this function handles ``a == 0``, the workaround is to detect ``a == 0`` upstream and dispatch to the quadratic solver instead (after also verifying that ``b==0``). """ if a == 0.0: raise ValueError( "Leading coefficient a is zero; this is not a quartic. " "Use a cubic/quadratic/linear solver, or handle the degenerate case at " "the call site." ) roots = np.zeros(4, dtype=np.complex128) a_p = b / a b_p = c / a c_p = d / a d_p = e / a phi0 = oqs_calc_phi0(a_p, b_p, c_p, d_p, 0) rfact = 1.0 if math.isnan(phi0) or math.isinf(phi0): rfact = quart_rescal_fact a_p /= rfact rfactsq = rfact * rfact b_p /= rfactsq c_p /= rfactsq * rfact d_p /= rfactsq * rfactsq phi0 = oqs_calc_phi0(a_p, b_p, c_p, d_p, 1) l1 = a_p / 2 l3 = b_p / 6 + phi0 / 2 del2 = c_p - a_p * l3 nsol = 0 bl311 = 2.0 * b_p / 3.0 - phi0 - l1 * l1 dml3l3 = d_p - l3 * l3 d2m_0, d2m_1, d2m_2 = 0.0, 0.0, 0.0 l2m_0, l2m_1, l2m_2 = 0.0, 0.0, 0.0 res_0, res_1, res_2 = 0.0, 0.0, 0.0 if bl311 != 0.0: d2m_0 = bl311 l2m_0 = del2 / (2.0 * d2m_0) res_0 = oqs_calc_err_ldlt(b_p, c_p, d_p, d2m_0, l1, l2m_0, l3) nsol = 1 if del2 != 0: if nsol == 0: l2m_0 = 2 * dml3l3 / del2 if l2m_0 != 0: d2m_0 = del2 / (2 * l2m_0) res_0 = oqs_calc_err_ldlt(b_p, c_p, d_p, d2m_0, l1, l2m_0, l3) nsol = 1 elif nsol == 1: l2m_1 = 2 * dml3l3 / del2 if l2m_1 != 0: d2m_1 = del2 / (2 * l2m_1) res_1 = oqs_calc_err_ldlt(b_p, c_p, d_p, d2m_1, l1, l2m_1, l3) nsol = 2 if nsol == 1: d2m_1 = bl311 l2m_1 = 2.0 * dml3l3 / del2 res_1 = oqs_calc_err_ldlt(b_p, c_p, d_p, d2m_1, l1, l2m_1, l3) nsol = 2 elif nsol == 2: d2m_2 = bl311 l2m_2 = 2.0 * dml3l3 / del2 res_2 = oqs_calc_err_ldlt(b_p, c_p, d_p, d2m_2, l1, l2m_2, l3) nsol = 3 if nsol == 0: l2 = 0.0 d2 = 0.0 elif nsol == 1: d2 = d2m_0 l2 = l2m_0 elif nsol == 2: if res_0 <= res_1: d2 = d2m_0 l2 = l2m_0 else: d2 = d2m_1 l2 = l2m_1 else: if res_0 <= res_1 and res_0 <= res_2: d2 = d2m_0 l2 = l2m_0 elif res_1 <= res_2: d2 = d2m_1 l2 = l2m_1 else: d2 = d2m_2 l2 = l2m_2 whichcase = 0 realcase_0 = -1 realcase_1 = -1 aq = 0.0 bq = 0.0 cq = 0.0 dq = 0.0 aq1 = 0.0 bq1 = 0.0 cq1 = 0.0 dq1 = 0.0 acx = 0.0 + 0.0j bcx = 0.0 + 0.0j ccx = 0.0 + 0.0j dcx = 0.0 + 0.0j acx1 = 0.0 + 0.0j bcx1 = 0.0 + 0.0j ccx1 = 0.0 + 0.0j dcx1 = 0.0 + 0.0j err0 = 0.0 err1 = 0.0 if d2 < 0.0: gamma = math.sqrt(-d2) aq = l1 + gamma bq = l3 + gamma * l2 cq = l1 - gamma dq = l3 - gamma * l2 if abs(dq) < abs(bq): dq = d_p / bq elif abs(dq) > abs(bq): bq = d_p / dq if abs(aq) < abs(cq): aqv_0, aqv_1, aqv_2 = 0.0, 0.0, 0.0 errv_0, errv_1, errv_2 = 0.0, 0.0, 0.0 nsol = 0 if dq != 0: aqv_0 = (c_p - bq * cq) / dq errv_0 = oqs_calc_err_abc(a_p, b_p, c_p, aqv_0, bq, cq, dq) nsol = 1 if cq != 0: if nsol == 0: aqv_0 = (b_p - dq - bq) / cq errv_0 = oqs_calc_err_abc(a_p, b_p, c_p, aqv_0, bq, cq, dq) nsol = 1 else: aqv_1 = (b_p - dq - bq) / cq errv_1 = oqs_calc_err_abc(a_p, b_p, c_p, aqv_1, bq, cq, dq) nsol = 2 if nsol == 0: aqv_0 = a_p - cq errv_0 = oqs_calc_err_abc(a_p, b_p, c_p, aqv_0, bq, cq, dq) aq = aqv_0 elif nsol == 1: aqv_1 = a_p - cq errv_1 = oqs_calc_err_abc(a_p, b_p, c_p, aqv_1, bq, cq, dq) if errv_0 <= errv_1: aq = aqv_0 else: aq = aqv_1 else: aqv_2 = a_p - cq errv_2 = oqs_calc_err_abc(a_p, b_p, c_p, aqv_2, bq, cq, dq) if errv_0 <= errv_1 and errv_0 <= errv_2: aq = aqv_0 elif errv_1 <= errv_2: aq = aqv_1 else: aq = aqv_2 else: cqv_0, cqv_1, cqv_2 = 0.0, 0.0, 0.0 errv_0, errv_1, errv_2 = 0.0, 0.0, 0.0 nsol = 0 if bq != 0: cqv_0 = (c_p - aq * dq) / bq errv_0 = oqs_calc_err_abc(a_p, b_p, c_p, aq, bq, cqv_0, dq) nsol = 1 if aq != 0: if nsol == 0: cqv_0 = (b_p - bq - dq) / aq errv_0 = oqs_calc_err_abc(a_p, b_p, c_p, aq, bq, cqv_0, dq) nsol = 1 else: cqv_1 = (b_p - bq - dq) / aq errv_1 = oqs_calc_err_abc(a_p, b_p, c_p, aq, bq, cqv_1, dq) nsol = 2 if nsol == 0: cqv_0 = a_p - aq errv_0 = oqs_calc_err_abc(a_p, b_p, c_p, aq, bq, cqv_0, dq) cq = cqv_0 elif nsol == 1: cqv_1 = a_p - aq errv_1 = oqs_calc_err_abc(a_p, b_p, c_p, aq, bq, cqv_1, dq) if errv_0 <= errv_1: cq = cqv_0 else: cq = cqv_1 else: cqv_2 = a_p - aq errv_2 = oqs_calc_err_abc(a_p, b_p, c_p, aq, bq, cqv_2, dq) if errv_0 <= errv_1 and errv_0 <= errv_2: cq = cqv_0 elif errv_1 <= errv_2: cq = cqv_1 else: cq = cqv_2 realcase_0 = 1 elif d2 > 0: gamma = math.sqrt(d2) acx = complex(l1, gamma) bcx = complex(l3, gamma * l2) ccx = acx.conjugate() dcx = bcx.conjugate() realcase_0 = 0 else: realcase_0 = -1 if realcase_0 == -1 or ( abs(d2) <= d2_safety_factor * macheps * max(abs(2.0 * b_p / 3.0), abs(phi0), l1 * l1) ): d3 = d_p - l3 * l3 if realcase_0 == 1: err0 = oqs_calc_err_abcd(a_p, b_p, c_p, d_p, aq, bq, cq, dq) elif realcase_0 == 0: err0 = oqs_calc_err_abcd_cmplx(a_p, b_p, c_p, d_p, acx, bcx, ccx, dcx) if d3 <= 0: realcase_1 = 1 aq1 = l1 bq1 = l3 + math.sqrt(-d3) cq1 = l1 dq1 = l3 - math.sqrt(-d3) if abs(dq1) < abs(bq1): dq1 = d_p / bq1 elif abs(dq1) > abs(bq1): bq1 = d_p / dq1 err1 = oqs_calc_err_abcd(a_p, b_p, c_p, d_p, aq1, bq1, cq1, dq1) else: realcase_1 = 0 acx1 = complex(l1, 0.0) bcx1 = complex(l3, math.sqrt(d3)) ccx1 = complex(l1, 0.0) dcx1 = bcx1.conjugate() err1 = oqs_calc_err_abcd_cmplx(a_p, b_p, c_p, d_p, acx1, bcx1, ccx1, dcx1) if realcase_0 == -1 or err1 < err0: whichcase = 1 if realcase_1 == 1: aq = aq1 bq = bq1 cq = cq1 dq = dq1 else: acx = acx1 bcx = bcx1 ccx = ccx1 dcx = dcx1 if (whichcase == 0 and realcase_0 == 1) or (whichcase == 1 and realcase_1 == 1): aq, bq, cq, dq = oqs_NRabcd(a_p, b_p, c_p, d_p, aq, bq, cq, dq) roots[0], roots[1] = oqs_solve_quadratic(aq, bq) roots[2], roots[3] = oqs_solve_quadratic(cq, dq) else: if whichcase == 0: cdiskr = acx * acx / 4 - bcx zx1 = -acx / 2 + np.sqrt(cdiskr) zx2 = -acx / 2 - np.sqrt(cdiskr) if abs(zx1) > abs(zx2): zxmax = zx1 else: zxmax = zx2 zxmin = bcx / zxmax roots[0] = zxmin roots[1] = zxmin.conjugate() roots[2] = zxmax roots[3] = zxmax.conjugate() else: cdiskr = np.sqrt(acx * acx - 4.0 * bcx) zx1 = -0.5 * (acx + cdiskr) zx2 = -0.5 * (acx - cdiskr) if abs(zx1) > abs(zx2): zxmax = zx1 else: zxmax = zx2 zxmin = bcx / zxmax roots[0] = zxmax roots[1] = zxmin cdiskr = np.sqrt(ccx * ccx - 4.0 * dcx) zx1 = -0.5 * (ccx + cdiskr) zx2 = -0.5 * (ccx - cdiskr) if abs(zx1) > abs(zx2): zxmax = zx1 else: zxmax = zx2 zxmin = dcx / zxmax roots[2] = zxmax roots[3] = zxmin if rfact != 1.0: for k in range(4): roots[k] *= rfact return roots