"""Houses all components that make up a table (pockets, cushions, etc)"""
from __future__ import annotations
import copy
from functools import cached_property
from typing import TypeVar
import numpy as np
from attrs import define, evolve, field
from numpy.typing import NDArray
import pooltool.ptmath as ptmath
from pooltool.utils.dataclasses import are_dataclasses_equal
class CushionDirection:
    """Integer constants for the direction of a cushion.

    Important for constructing cushions if simulation performance speed is required.

    For most table geometries, the playing surface only exists on one side of the
    cushion, so collisions only need to be checked for one direction. This direction can
    be specified with this class's attributes.

    Attributes:
        SIDE1: Use side 1.
        SIDE2: Use side 2.
        BOTH: Use sides 1 and 2.

    Unfortunately, the rule governing whether to use :attr:`SIDE1` or :attr:`SIDE2` is
    not clear and instead requires experimentation.

    If :attr:`BOTH` is used, both collision checks are performed which makes collision
    checks twice as slow.

    Note:
        This used to inherit from ``Enum``, but accessing the cushion direction in
        ball-vs-linear-cushion-segment detection took up 20% of the function's
        runtime, so it was removed. The attributes are therefore plain ``int``
        values, not enum members.
    """

    SIDE1 = 0
    SIDE2 = 1
    BOTH = 2
@define(eq=False, frozen=True, slots=False)
class LinearCushionSegment:
    """A linear cushion segment defined by a skinny cylinder between points :math:`p_1` and :math:`p_2`.

    Attributes:
        id:
            The ID of the cushion segment.
        p1:
            The 3D coordinate where the cushion segment starts.

            Note:
                - ``p1`` and ``p2`` must share the same height (``p1[2] == p2[2]``).
        p2:
            The 3D coordinate where the cushion segment ends.

            Note:
                - ``p1`` and ``p2`` must share the same height (``p1[2] == p2[2]``).
        nose_radius:
            The radius of the cylinder defined between points ``p1`` and ``p2``. This is
            referred to as the cushion nose radius.
        direction:
            The cushion direction (*default* =
            :attr:`pooltool.objects.CushionDirection.BOTH`).

            See :class:`pooltool.objects.CushionDirection` for explanation.
    """

    id: str
    p1: NDArray[np.float64]
    p2: NDArray[np.float64]
    nose_radius: float
    direction: int = field(default=CushionDirection.BOTH)

    def __attrs_post_init__(self):
        # Segment must have constant height
        assert self.p1[2] == self.p2[2]

        # p1 and p2 are read only. The fields have no converter, so these are the
        # very arrays the caller passed in; they become read-only for the caller too.
        self.p1.flags["WRITEABLE"] = False
        self.p2.flags["WRITEABLE"] = False

    def __eq__(self, other):
        # Equality is delegated to the project's dataclass comparison helper.
        return are_dataclasses_equal(self, other)

    @cached_property
    def height(self) -> float:
        """The height of the cushion."""
        return self.p1[2]

    @cached_property
    def lx(self) -> float:
        """The x-coefficient (:math:`l_x`) of the cushion's 2D general form line equation.

        The cushion's general form line equation in the :math:`XY` plane (*i.e.*
        dismissing the z-component) is

        .. math::

            l_x x + l_y y + l_0 = 0

        where

        .. math::

            \\begin{align*}
            l_x &= -\\frac{p_{2y} - p_{1y}}{p_{2x} - p_{1x}} \\\\
            l_y &= 1 \\\\
            l_0 &= \\frac{p_{2y} - p_{1y}}{p_{2x} - p_{1x}} p_{1x} - p_{1y} \\\\
            \\end{align*}

        For a vertical segment (:math:`p_{2x} = p_{1x}`) the slope is undefined, so
        :math:`l_x = 1`, :math:`l_y = 0`, and :math:`l_0 = -p_{1x}` are used
        instead, which describes the line :math:`x = p_{1x}`.
        """
        p1x, p1y, _ = self.p1
        p2x, p2y, _ = self.p2
        # Vertical segment: fall back to the line x - p1x = 0.
        return 1 if (p2x - p1x) == 0 else -(p2y - p1y) / (p2x - p1x)

    @cached_property
    def ly(self) -> float:
        """The y-coefficient (:math:`l_y`) of the cushion's 2D general form line equation.

        This is 1 for every segment except a vertical one, where it is 0.

        See :meth:`lx` for definition.
        """
        return 0 if (self.p2[0] - self.p1[0]) == 0 else 1

    @cached_property
    def l0(self) -> float:
        """The constant term (:math:`l_0`) of the cushion's 2D general form line equation.

        See :meth:`lx` for definition.
        """
        p1x, p1y, _ = self.p1
        p2x, p2y, _ = self.p2
        return -p1x if (p2x - p1x) == 0 else (p2y - p1y) / (p2x - p1x) * p1x - p1y

    @cached_property
    def unit_axis(self) -> NDArray[np.float64]:
        """The unit vector :math:`\\frac{p_2 - p_1}{\\|p_2 - p_1\\|}`."""
        axis = self.p2 - self.p1
        return axis / ptmath.norm3d(axis)

    @cached_property
    def normal(self) -> NDArray[np.float64]:
        """The line's normal vector, with the z-component zeroed prior to normalization.

        Warning:
            The returned normal vector is arbitrarily directed, meaning it may point
            away from the table surface, rather than towards it. This nonideality is
            properly handled in downstream simulation logic, however if you're using
            this method for custom purposes, you may want to reverse the direction of
            this vector by negating it.
        """
        return ptmath.unit_vector(np.array([self.lx, self.ly, 0]))

    def get_normal_xy(self, xyz: NDArray[np.float64]) -> NDArray[np.float64]:
        """Calculates the normal vector for a ball contacting the cushion.

        Warning:
            The returned normal vector is arbitrarily directed, meaning it may point
            away from the table surface, rather than towards it. This nonideality is
            properly handled in downstream simulation logic, however if you're using
            this method for custom purposes, you may want to reverse the direction of
            this vector by negating it.

        Args:
            xyz:
                The position of the ball (unused; the normal of a line does not depend
                on the contact point).

                See ``xyz`` property of :class:`pooltool.objects.BallState`.

        Returns:
            NDArray[np.float64]:
                The line's normal vector, with the z-component zeroed prior to normalization.

        Note:
            - This method only exists for call signature parity with
              :meth:`pooltool.objects.CircularCushionSegment.get_normal_xy`. Consider using
              :meth:`normal` instead.
        """
        return self.normal

    def get_normal_3d(self, xyz: NDArray[np.float64]) -> NDArray[np.float64]:
        """Calculates the 3D normal vector for a point contacting the cushion.

        This method computes the normal by finding the component of the vector from
        :attr:`p1` to the contact point that is perpendicular to the cushion's
        :attr:`unit_axis`. Mathematically, this is achieved by subtracting the
        projection of the position vector onto the cushion's axis from the position
        vector itself, yielding the perpendicular component which defines the normal
        direction.

        Warning:
            The returned normal vector is arbitrarily directed, meaning it may point
            away from the table surface, rather than towards it. This nonideality is
            properly handled in downstream simulation logic, however if you're using
            this method for custom purposes, you may want to reverse the direction of
            this vector by negating it.

        Args:
            xyz:
                The 3D coordinate of the contacting point.

        Returns:
            NDArray[np.float64]:
                The 3D normal vector pointing outward from the cushion surface.
        """
        r = xyz - self.p1
        return ptmath.unit_vector(r - np.dot(r, self.unit_axis) * self.unit_axis)

    def copy(self) -> LinearCushionSegment:
        """Create a copy"""
        # LinearCushionSegment is a frozen instance, and its attributes are either (a)
        # immutable, or (b) have read-only flags set. It is sufficient to simply return
        # oneself.
        return self

    @staticmethod
    def dummy() -> LinearCushionSegment:
        """Create a placeholder segment from (0, 0, 1) to (1, 1, 1)."""
        # float64 to match the declared NDArray[np.float64] field types.
        return LinearCushionSegment(
            id="dummy",
            p1=np.array([0, 0, 1], dtype=np.float64),
            p2=np.array([1, 1, 1], dtype=np.float64),
            nose_radius=0.005,
        )
@define(frozen=True, eq=False, slots=False)
class CircularCushionSegment:
    """A circular cushion segment defined by a circle center and radius

    Attributes:
        id:
            The ID of the cushion segment.
        center:
            A length-3 array specifying the circular cushion's center.

            ``center[0]``, ``center[1]``, and ``center[2]`` are the x-, y-, and
            z-coordinates of the cushion's center. The circle is assumed to be parallel to
            the XY plane, which makes ``center[2]`` the height of the cushion.
        radius:
            The radius of the cushion segment.
    """

    id: str
    center: NDArray[np.float64]
    radius: float

    def __eq__(self, other):
        # Equality is delegated to the project's dataclass comparison helper.
        return are_dataclasses_equal(self, other)

    def __attrs_post_init__(self):
        assert len(self.center) == 3

        # center is read only. The field has no converter, so this is the array the
        # caller passed in; it becomes read-only for the caller too.
        self.center.flags["WRITEABLE"] = False

    @cached_property
    def height(self) -> float:
        """The height of the cushion."""
        return self.center[2]

    @cached_property
    def a(self) -> float:
        """The x-coordinate of the cushion's center."""
        return self.center[0]

    @cached_property
    def b(self) -> float:
        """The y-coordinate of the cushion's center."""
        return self.center[1]

    def get_normal_xy(self, xyz: NDArray[np.float64]) -> NDArray[np.float64]:
        """Calculates the normal vector for a ball contacting the cushion

        Assumes that the ball is in fact in contact with the cushion.

        Args:
            xyz:
                The position of the ball. (See ``xyz`` property of class
                :class:`pooltool.objects.BallState`).

        Returns:
            NDArray[np.float64]:
                The normal vector, with the z-component zeroed prior to normalization.
        """
        # The subtraction allocates a new array, so zeroing z does not touch xyz.
        normal = xyz - self.center
        normal[2] = 0  # remove z-component
        return ptmath.unit_vector(normal)

    def get_normal_3d(self, xyz: NDArray[np.float64]) -> NDArray[np.float64]:
        """Calculates the 3D normal vector for a point contacting the cushion.

        Args:
            xyz:
                The 3D coordinate of the contacting point.

        Returns:
            NDArray[np.float64]:
                The 3D normal vector pointing outward from the cushion surface.
        """
        return ptmath.unit_vector(xyz - self.center)

    def copy(self) -> CircularCushionSegment:
        """Create a copy"""
        # CircularCushionSegment is a frozen instance, and its attributes are either (a)
        # immutable, or (b) have read-only flags set. It is sufficient to simply return
        # oneself.
        return self

    @staticmethod
    def dummy() -> CircularCushionSegment:
        """Create a placeholder segment centered at the origin with radius 10."""
        return CircularCushionSegment(
            id="dummy", center=np.array([0, 0, 0], dtype=np.float64), radius=10.0
        )
@define
class CushionSegments:
    """A collection of cushion segments

    Cushion segments can be either linear (see
    :class:`pooltool.objects.LinearCushionSegment`) or circular (see
    :class:`pooltool.objects.CircularCushionSegment`). This class stores both.

    Attributes:
        linear:
            A dictionary of linear cushion segments.

            Warning:
                Keys must match the value IDs, *e.g.* ``{"2":
                LinearCushionSegment(id="2", ...)}``
        circular:
            A dictionary of circular cushion segments.

            Warning:
                Keys must match the value IDs, *e.g.* ``{"2t":
                CircularCushionSegment(id="2t", ...)}``
    """

    linear: dict[str, LinearCushionSegment] = field()
    circular: dict[str, CircularCushionSegment] = field()

    @linear.validator  # type: ignore
    @circular.validator  # type: ignore
    def _keys_match_value_ids(self, attribute, value) -> None:
        """attrs validator: every dict key must equal its segment's ``id``.

        attrs calls validators as ``(instance, attribute, value)``; ``value`` is the
        ``linear`` or ``circular`` dict being validated.
        """
        for key, segment in value.items():
            assert key == segment.id, f"Key '{key}' mismatch with ID '{segment.id}'"

    def copy(self) -> CushionSegments:
        """Create a copy"""
        # Delegates the deep-ish copying of LinearCushionSegment and
        # CircularCushionSegment elements to their respective copy() methods. Uses
        # dictionary comprehensions to construct equal but different `linear` and
        # `circular` attributes.
        return evolve(
            self,
            linear={k: v.copy() for k, v in self.linear.items()},
            circular={k: v.copy() for k, v in self.circular.items()},
        )
@define(eq=False, frozen=True, slots=False)
class Pocket:
    """A circular pocket

    Attributes:
        id:
            The ID of the pocket.
        center:
            A length-3 array specifying the pocket's position.

            - ``center[0]`` is the x-coordinate of the pocket's center
            - ``center[1]`` is the y-coordinate of the pocket's center
            - ``center[2]`` must be 0.0
        radius:
            The radius of the pocket.
        depth:
            How deep the pocket is (*default* = ``0.08``).
        contains:
            Stores the ball IDs of pocketed balls (*default* = ``set()``).
    """

    id: str
    center: NDArray[np.float64]
    radius: float
    depth: float = field(default=0.08)
    contains: set[str] = field(factory=set)

    def __attrs_post_init__(self):
        assert len(self.center) == 3
        assert self.center[2] == 0

        # center is read only. The field has no converter, so this is the array the
        # caller passed in; it becomes read-only for the caller too.
        self.center.flags["WRITEABLE"] = False

    def __eq__(self, other):
        # Equality is delegated to the project's dataclass comparison helper.
        return are_dataclasses_equal(self, other)

    @cached_property
    def a(self) -> float:
        """The x-coordinate of the pocket's center."""
        return self.center[0]

    @cached_property
    def b(self) -> float:
        """The y-coordinate of the pocket's center."""
        return self.center[1]

    def add(self, ball_id: str) -> None:
        """Add a ball ID to :attr:`contains`"""
        # The instance is frozen, but `contains` is a mutable set and can be
        # modified in place.
        self.contains.add(ball_id)

    def remove(self, ball_id: str) -> None:
        """Remove a ball ID from :attr:`contains`

        Raises:
            KeyError: If ``ball_id`` is not in :attr:`contains`.
        """
        self.contains.remove(ball_id)

    def copy(self) -> Pocket:
        """Create a copy"""
        # Pocket is a frozen instance, and except for `contains`, its attributes are
        # either (a) immutable, or (b) have read-only flags set. Therefore, only a copy
        # of `contains` needs to be made. Since its members are strs (immutable), a
        # shallow copy suffices.
        return evolve(self, contains=copy.copy(self.contains))

    @staticmethod
    def dummy() -> Pocket:
        """Create a placeholder pocket centered at the origin with radius 10."""
        # float64 to match the declared NDArray[np.float64] field type.
        return Pocket(
            id="dummy", center=np.array([0, 0, 0], dtype=np.float64), radius=10.0
        )
# Constrained type variable: binds to exactly LinearCushionSegment or
# CircularCushionSegment (one or the other, not their union).
Cushion = TypeVar("Cushion", LinearCushionSegment, CircularCushionSegment)