Source code for pooltool.utils.dataclasses
import numpy as np
from attrs import astuple
def _array_safe_eq(a, b) -> bool:
    """Return whether ``a`` and ``b`` are equal, tolerating numpy arrays.

    The same object is always equal to itself. Two numpy arrays are compared
    with ``np.array_equal`` (NaNs in matching positions count as equal), which
    avoids the ambiguous element-wise result of ``==``. Any other pair falls
    back to ``a == b``.

    Args:
        a: First value.
        b: Second value.

    Returns:
        True if the values compare equal, False otherwise. A pair whose ``==``
        raises ``TypeError`` is reported as unequal.
    """
    if a is b:
        return True

    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        return bool(np.array_equal(a, b, equal_nan=True))

    try:
        return a == b
    except TypeError:
        # This is a plain helper whose result gets truth-tested, not a
        # rich-comparison method. NotImplemented would be truthy here (so the
        # pair would count as equal), and evaluating it as a bool is deprecated
        # since Python 3.9 and an error on 3.14.
        return False
def are_dataclasses_equal(dc1, dc2) -> bool:
    """Compare two dataclass instances whose fields may be numpy arrays.

    A plain ``==`` on array-valued fields yields an element-wise array whose
    truth value is ambiguous (numpy asks for ``any()``/``all()``), so each pair
    of field values is compared with ``_array_safe_eq`` instead.

    Args:
        dc1: First instance.
        dc2: Second instance.

    Returns:
        True when both arguments are the same object, or when every pair of
        corresponding field values compares equal. False as soon as one pair
        differs. ``NotImplemented`` when the two arguments are not instances
        of exactly the same class.
    """
    if dc1 is dc2:
        return True

    if dc1.__class__ is not dc2.__class__:
        # Deliberately not False. Presumably this lets an ``__eq__`` that
        # delegates here defer to the other operand -- confirm against callers.
        return NotImplemented

    # Flatten both instances to tuples of field values, then walk them in
    # lockstep and stop at the first mismatch.
    for left, right in zip(astuple(dc1), astuple(dc2)):
        if not _array_safe_eq(left, right):
            return False
    return True