Source code for solrat.atom_model.shared.utility.wigner_3j_6j_9j

from functools import lru_cache

import numpy as np
from numpy import abs, sqrt

from solrat.engine.functions.general import fact2


def _w3j_doubled_argument(j1_doubled, j2_doubled, j3_doubled, m1_doubled, m2_doubled, m3_doubled):
    """
    float Wigner 3J symbol where all arguments are doubled to be integer

    j1_doubled = 2 * J1: int
    j2_doubled = 2 * J2: int
    j3_doubled = 2 * J3: int
    m1_doubled = 2 * M1: int
    m2_doubled = 2 * M2: int
    m3_doubled = 2 * M3: int

    Reference: Appendix A1

    ( J1 J2 J3 )
    ( M1 M2 M3 )
    """

    assert m1_doubled + m2_doubled + m3_doubled == 0, "M1 + M2 + M3 != 0."

    if (abs(m1_doubled) > j1_doubled) or (abs(m2_doubled) > j2_doubled) or (abs(m3_doubled) > j3_doubled):
        return 0.0
    a = j1_doubled + j2_doubled
    if j3_doubled > a:
        return 0.0
    b = j1_doubled - j2_doubled
    if j3_doubled < abs(b):
        return 0.0
    j_sum = j3_doubled + a
    c = j1_doubled - m1_doubled
    d = j2_doubled - m2_doubled

    assert j_sum % 2 == 0, "J1 + J2 + J3 != even."
    assert c % 2 == 0, "J1 - M1 != even."
    assert d % 2 == 0, "J2 - M2 != even."

    e = j3_doubled - j2_doubled + m1_doubled
    f = j3_doubled - j1_doubled - m2_doubled
    z_min = max(0, -e, -f)
    g = a - j3_doubled
    h = j2_doubled + m2_doubled
    z_max = min(g, h, c)
    result = 0.0
    for z in range(int(z_min), int(z_max) + 1, 2):
        denominator = fact2(z) * fact2(g - z) * fact2(c - z) * fact2(h - z) * fact2(e + z) * fact2(f + z)
        if z % 4 != 0:
            denominator = -denominator
        result += 1 / denominator
    cc1 = fact2(g) * fact2(j3_doubled + b) * fact2(j3_doubled - b) / fact2(j_sum + 2)
    cc2 = (
        fact2(j1_doubled + m1_doubled)
        * fact2(c)
        * fact2(h)
        * fact2(d)
        * fact2(j3_doubled - m3_doubled)
        * fact2(j3_doubled + m3_doubled)
    )
    result *= sqrt(cc1 * cc2)
    if (b - m3_doubled) % 4 != 0:
        result = -result
    return result


def _w6j_doubled_argument(j1_doubled, j2_doubled, j3_doubled, l1_doubled, l2_doubled, l3_doubled):
    """
    float Wigner 6J symbol where all arguments are doubled to be integer

    j1_doubled = 2 * J1: int
    j2_doubled = 2 * J2: int
    j3_doubled = 2 * J3: int
    l1_doubled = 2 * L1: int
    l2_doubled = 2 * L2: int
    l3_doubled = 2 * L3: int

    Reference: Appendix A1

    { J1 J2 J3 }
    { L1 L2 L3 }
    """
    a = j1_doubled + j2_doubled
    b = j1_doubled - j2_doubled
    c = j1_doubled + l2_doubled
    d = j1_doubled - l2_doubled
    e = l1_doubled + j2_doubled
    f = l1_doubled - j2_doubled
    g = l1_doubled + l2_doubled
    h = l1_doubled - l2_doubled

    if (a < j3_doubled) or (c < l3_doubled) or (e < l3_doubled) or (g < j3_doubled):
        # logging.warning("Performance warning: J1 + J2 < J3 or L1 + L2 < L3 or L1 + J2 < L3 or L1 + L2 < J3")
        return 0.0

    if (abs(b) > j3_doubled) or (abs(d) > l3_doubled) or (abs(f) > l3_doubled) or (abs(h) > j3_doubled):
        # logging.warning("Performance warning: J1 - J2 > J3 or J1 - L2 > L3 or L1 - J2 > L3 or L1 - L2 > J3")
        return 0.0

    sum_1 = a + j3_doubled
    sum_2 = c + l3_doubled
    sum_3 = e + l3_doubled
    sum_4 = g + j3_doubled

    assert sum_1 % 2 == 0, "J1 + J2 + J3 != even."
    assert sum_2 % 2 == 0, "J1 + L2 + L3 != even."
    assert sum_3 % 2 == 0, "L1 + J2 + L3 != even."

    w_min = max(sum_1, sum_2, sum_3, sum_4)
    i = a + g
    j = j2_doubled + j3_doubled + l2_doubled + l3_doubled
    k = j3_doubled + j1_doubled + l3_doubled + l1_doubled
    w_max = min(i, j, k)

    result = 0.0
    for w in range(int(w_min), int(w_max) + 1, 2):
        denominator = (
            fact2(w - sum_1)
            * fact2(w - sum_2)
            * fact2(w - sum_3)
            * fact2(w - sum_4)
            * fact2(i - w)
            * fact2(j - w)
            * fact2(k - w)
        )
        if w % 4 != 0:
            denominator = -denominator
        result += fact2(w + 2) / denominator

    theta1 = fact2(a - j3_doubled) * fact2(j3_doubled + b) * fact2(j3_doubled - b) / fact2(sum_1 + 2)
    theta2 = fact2(c - l3_doubled) * fact2(l3_doubled + d) * fact2(l3_doubled - d) / fact2(sum_2 + 2)
    theta3 = fact2(e - l3_doubled) * fact2(l3_doubled + f) * fact2(l3_doubled - f) / fact2(sum_3 + 2)
    theta4 = fact2(g - j3_doubled) * fact2(j3_doubled + h) * fact2(j3_doubled - h) / fact2(sum_4 + 2)
    result = result * sqrt(theta1 * theta2 * theta3 * theta4)
    return result


def _w9j_doubled_argument(
    j1_doubled,
    j2_doubled,
    j3_doubled,
    j4_doubled,
    j5_doubled,
    j6_doubled,
    j7_doubled,
    j8_doubled,
    j9_doubled,
):
    """
    float Wigner 9J symbol where all arguments are doubled to be integer

    j1_doubled = 2 * J1: int
    j2_doubled = 2 * J2: int
    j3_doubled = 2 * J3: int
    j4_doubled = 2 * J4: int
    j5_doubled = 2 * J5: int
    j6_doubled = 2 * J6: int
    j7_doubled = 2 * J7: int
    j8_doubled = 2 * J8: int
    j9_doubled = 2 * J9: int

    Reference: Appendix A1

    { J1 J2 J3 }
    { J4 J5 J6 }
    { J7 J8 J9 }
    """
    k_min = max(
        abs(j1_doubled - j9_doubled),
        abs(j4_doubled - j8_doubled),
        abs(j2_doubled - j6_doubled),
    )

    k_max = min(
        abs(j1_doubled + j9_doubled),
        abs(j4_doubled + j8_doubled),
        abs(j2_doubled + j6_doubled),
    )
    result = 0
    for k in range(int(k_min), int(k_max) + 1, 2):
        s = -1 if k % 2 != 0 else 1
        x1 = _w6j_doubled_argument(j1_doubled, j9_doubled, k, j8_doubled, j4_doubled, j7_doubled)
        x2 = _w6j_doubled_argument(j2_doubled, j6_doubled, k, j4_doubled, j8_doubled, j5_doubled)
        x3 = _w6j_doubled_argument(j1_doubled, j9_doubled, k, j6_doubled, j2_doubled, j3_doubled)
        result += s * x1 * x2 * x3 * (k + 1)
    return result


# vectorize
_w3j_doubled_argument_vec = np.vectorize(_w3j_doubled_argument)
_w6j_doubled_argument_vec = np.vectorize(_w6j_doubled_argument)

# Todo investigate caching vs vectorization performance.


[docs] def wigner_3j(j1, j2, j3, m1, m2, m3): return _w3j_doubled_argument_vec(j1 * 2, j2 * 2, j3 * 2, m1 * 2, m2 * 2, m3 * 2)
[docs] def wigner_6j(j1, j2, j3, l1, l2, l3): return _w6j_doubled_argument_vec(j1 * 2, j2 * 2, j3 * 2, l1 * 2, l2 * 2, l3 * 2)
[docs] @lru_cache(maxsize=None) def wigner_9j(j1, j2, j3, j4, j5, j6, j7, j8, j9): return _w9j_doubled_argument(j1 * 2, j2 * 2, j3 * 2, j4 * 2, j5 * 2, j6 * 2, j7 * 2, j8 * 2, j9 * 2)