Source code for pooltool.objects.table.components

"""Houses all components that make up a table (pockets, cushions, etc)"""

from __future__ import annotations

import copy
from functools import cached_property
from typing import TypeVar

import numpy as np
from attrs import define, evolve, field
from numpy.typing import NDArray

import pooltool.ptmath as ptmath
from pooltool.utils.dataclasses import are_dataclasses_equal


class CushionDirection:
    """Integer constants naming which side(s) of a cushion take part in collisions.

    These constants matter when cushions are constructed for a simulation where
    speed is important. On most table layouts the playing surface lies on only one
    side of a cushion, so collisions only have to be checked on that side. The
    attributes below select the side(s) to check.

    Attributes:
        SIDE1:
            Check side 1 only.
        SIDE2:
            Check side 2 only.
        BOTH:
            Check sides 1 and 2.

    There is no simple rule for choosing between :attr:`SIDE1` and :attr:`SIDE2`;
    it has to be determined by experimentation. Choosing :attr:`BOTH` performs both
    collision checks, which makes collision checks twice as slow.

    Note:
        This class once subclassed ``Enum``. That was dropped because reading the
        cushion direction consumed about 20% of the runtime of
        ball-vs-linear-cushion-segment detection.
    """

    # Plain ints (not Enum members) so attribute access stays cheap in hot loops.
    SIDE1, SIDE2, BOTH = 0, 1, 2
@define(eq=False, frozen=True, slots=False)
class LinearCushionSegment:
    """A linear cushion segment defined by a skinny cylinder between points :math:`p_1` and :math:`p_2`.

    Attributes:
        id:
            The ID of the cushion segment.
        p1:
            The 3D coordinate where the cushion segment starts.

            Note:
                - ``p1`` and ``p2`` must share the same height (``p1[2] == p2[2]``).
                - The array is marked read-only in place on construction (no copy
                  is taken), so the caller's array becomes read-only too.
        p2:
            The 3D coordinate where the cushion segment ends.

            Note:
                - ``p1`` and ``p2`` must share the same height (``p1[2] == p2[2]``).
                - The array is marked read-only in place on construction (no copy
                  is taken), so the caller's array becomes read-only too.
        nose_radius:
            The radius of the cylinder defined between points ``p1`` and ``p2``.
            This is referred to as the cushion nose radius.
        direction:
            The cushion direction (*default* =
            :attr:`pooltool.objects.CushionDirection.BOTH`).

            See :class:`pooltool.objects.CushionDirection` for explanation.
    """

    id: str
    p1: NDArray[np.float64]
    p2: NDArray[np.float64]
    nose_radius: float
    direction: int = field(default=CushionDirection.BOTH)

    def __attrs_post_init__(self):
        # Segment must have constant height
        assert self.p1[2] == self.p2[2]

        # p1 and p2 are read only. This is what lets copy() return ``self``.
        # The flag is set on the exact array objects that were passed in.
        self.p1.flags["WRITEABLE"] = False
        self.p2.flags["WRITEABLE"] = False

    def __eq__(self, other):
        return are_dataclasses_equal(self, other)

    @cached_property
    def height(self) -> float:
        """The height of the cushion."""
        return self.p1[2]

    @cached_property
    def lx(self) -> float:
        """The x-coefficient (:math:`l_x`) of the cushion's 2D general form line equation.

        The cushion's general form line equation in the :math:`XY` plane (*i.e.*
        dismissing the z-component) is

        .. math::

            l_x x + l_y y + l_0 = 0

        where

        .. math::

            \\begin{align*}
            l_x &= -\\frac{p_{2y} - p_{1y}}{p_{2x} - p_{1x}} \\\\
            l_y &= 1 \\\\
            l_0 &= \\frac{p_{2y} - p_{1y}}{p_{2x} - p_{1x}} p_{1x} - p_{1y} \\\\
            \\end{align*}

        For a vertical segment (:math:`p_{2x} = p_{1x}`) the slope is undefined, and
        the line is instead written as :math:`x - p_{1x} = 0`. That is,
        :math:`l_x = 1`, :math:`l_y = 0`, and :math:`l_0 = -p_{1x}`.
        """
        p1x, p1y, _ = self.p1
        p2x, p2y, _ = self.p2
        if (p2x - p1x) == 0:
            # Vertical segment: x - p1x = 0. Return a float (not the int 1) so the
            # vector built from (lx, ly, 0) in ``normal`` is float-typed in every case.
            return 1.0
        return -(p2y - p1y) / (p2x - p1x)

    @cached_property
    def ly(self) -> float:
        """The y-coefficient (:math:`l_y`) of the cushion's 2D general form line equation.

        Equals ``1.0``, except for a vertical segment (:math:`p_{2x} = p_{1x}`),
        where it is ``0.0``. See :meth:`lx` for definition.
        """
        return 0.0 if (self.p2[0] - self.p1[0]) == 0 else 1.0

    @cached_property
    def l0(self) -> float:
        """The constant term (:math:`l_0`) of the cushion's 2D general form line equation.

        Equals :math:`-p_{1x}` for a vertical segment. See :meth:`lx` for definition.
        """
        p1x, p1y, _ = self.p1
        p2x, p2y, _ = self.p2
        return -p1x if (p2x - p1x) == 0 else (p2y - p1y) / (p2x - p1x) * p1x - p1y

    @cached_property
    def unit_axis(self) -> NDArray[np.float64]:
        """The unit vector :math:`\\frac{p_2 - p_1}{\\|p_2 - p_1\\|}`."""
        axis = self.p2 - self.p1
        return axis / ptmath.norm3d(axis)

    @cached_property
    def normal(self) -> NDArray[np.float64]:
        """The line's normal vector, with the z-component zeroed prior to normalization.

        Warning:
            The returned normal vector is arbitrarily directed, meaning it may
            point away from the table surface, rather than towards it. This
            nonideality is properly handled in downstream simulation logic,
            however if you're using this method for custom purposes, you may want
            to reverse the direction of this vector by negating it.
        """
        return ptmath.unit_vector(np.array([self.lx, self.ly, 0]))

    def get_normal_xy(self, xyz: NDArray[np.float64]) -> NDArray[np.float64]:
        """Calculates the normal vector for a ball contacting the cushion.

        Warning:
            The returned normal vector is arbitrarily directed, meaning it may
            point away from the table surface, rather than towards it. This
            nonideality is properly handled in downstream simulation logic,
            however if you're using this method for custom purposes, you may want
            to reverse the direction of this vector by negating it.

        Args:
            xyz:
                The position of the ball. Unused; the normal of a straight line
                does not depend on the contact point. See the ``xyz`` property of
                :class:`pooltool.objects.BallState`.

        Returns:
            NDArray[np.float64]:
                The line's normal vector, with the z-component zeroed prior to
                normalization.

        Note:
            - This method only exists for call signature parity with
              :meth:`pooltool.objects.CircularCushionSegment.get_normal_xy`.
              Consider using :meth:`normal` instead.
        """
        return self.normal

    def get_normal_3d(self, xyz: NDArray[np.float64]) -> NDArray[np.float64]:
        """Calculates the 3D normal vector for a point contacting the cushion.

        This method computes the normal by finding the component of the vector
        from :attr:`p1` to the contact point that is perpendicular to the
        cushion's :attr:`unit_axis`. In other words, it subtracts the projection
        of ``xyz - p1`` onto the axis from ``xyz - p1`` itself.

        The result points from the nearest point on the cushion's axis (the line
        through ``p1`` and ``p2``) toward ``xyz``. That is outward from the
        cushion's cylindrical surface, on whichever side ``xyz`` lies. If ``xyz``
        lies exactly on the axis, the perpendicular component is zero and the
        normal is undefined.

        Args:
            xyz:
                The 3D coordinate of the contacting point.

        Returns:
            NDArray[np.float64]:
                The 3D unit normal vector, directed from the cushion axis toward
                ``xyz``.
        """
        r = xyz - self.p1
        return ptmath.unit_vector(r - np.dot(r, self.unit_axis) * self.unit_axis)

    def copy(self) -> LinearCushionSegment:
        """Create a copy"""
        # LinearCushionSegment is a frozen instance, and its attributes are either (a)
        # immutable, or (b) have read-only flags set. It is sufficient to simply return
        # oneself.
        return self

    @staticmethod
    def dummy() -> LinearCushionSegment:
        """Return a placeholder segment from (0, 0, 1) to (1, 1, 1)."""
        # Explicit float64 to match the ``NDArray[np.float64]`` annotations and
        # CircularCushionSegment.dummy().
        return LinearCushionSegment(
            id="dummy",
            p1=np.array([0, 0, 1], dtype=np.float64),
            p2=np.array([1, 1, 1], dtype=np.float64),
            nose_radius=0.005,
        )
@define(frozen=True, eq=False, slots=False)
class CircularCushionSegment:
    """A circular cushion segment, described by a center point and a radius.

    Attributes:
        id:
            Identifier of the cushion segment.
        center:
            Length-3 array holding the x-, y-, and z-coordinates of the circle's
            center (``center[0]``, ``center[1]``, ``center[2]``). The circle is
            assumed to lie in a plane parallel to XY, so ``center[2]`` is the
            height of the cushion.
        radius:
            Radius of the cushion segment.
    """

    id: str
    center: NDArray[np.float64]
    radius: float

    def __attrs_post_init__(self):
        assert len(self.center) == 3

        # Lock the center array so this frozen instance cannot be mutated through it.
        self.center.flags["WRITEABLE"] = False

    def __eq__(self, other):
        return are_dataclasses_equal(self, other)

    @cached_property
    def height(self) -> float:
        """Height of the cushion (z-coordinate of the center)."""
        return self.center[2]

    @cached_property
    def a(self) -> float:
        """x-coordinate of the cushion's center."""
        return self.center[0]

    @cached_property
    def b(self) -> float:
        """y-coordinate of the cushion's center."""
        return self.center[1]

    def get_normal_xy(self, xyz: NDArray[np.float64]) -> NDArray[np.float64]:
        """Return the in-plane normal for a ball touching the cushion.

        The ball is assumed to actually be in contact with the cushion.

        Args:
            xyz:
                Position of the ball (see the ``xyz`` property of
                :class:`pooltool.objects.BallState`).

        Returns:
            NDArray[np.float64]:
                Unit vector from the circle's center toward the ball, computed
                after discarding the z-component.
        """
        offset = xyz - self.center
        offset[2] = 0  # project onto the XY plane before normalizing
        return ptmath.unit_vector(offset)

    def get_normal_3d(self, xyz: NDArray[np.float64]) -> NDArray[np.float64]:
        """Return the 3D normal for a point touching the cushion.

        Args:
            xyz:
                3D coordinate of the contacting point.

        Returns:
            NDArray[np.float64]:
                Unit vector from the circle's center toward ``xyz``.
        """
        offset = xyz - self.center
        return ptmath.unit_vector(offset)

    def copy(self) -> CircularCushionSegment:
        """Create a copy"""
        # The instance is frozen and ``center`` is flagged read-only, so nothing
        # about it can change; sharing the same object is equivalent to a copy.
        return self

    @staticmethod
    def dummy() -> CircularCushionSegment:
        """Return a placeholder segment centered at the origin with radius 10."""
        origin = np.array([0, 0, 0], dtype=np.float64)
        return CircularCushionSegment(id="dummy", center=origin, radius=10.0)
@define
class CushionSegments:
    """A collection of cushion segments

    Cushion segments can be either linear (see
    :class:`pooltool.objects.LinearCushionSegment`) or circular (see
    :class:`pooltool.objects.CircularCushionSegment`). This class stores both.

    Attributes:
        linear:
            A dictionary of linear cushion segments.

            Warning:
                Keys must match the value IDs, *e.g.* ``{"2":
                LinearCushionSegment(id="2", ...)}``
        circular:
            A dictionary of circular cushion segments.

            Warning:
                Keys must match the value IDs, *e.g.* ``{"2t":
                CircularCushionSegment(id="2t", ...)}``
    """

    linear: dict[str, LinearCushionSegment] = field()
    circular: dict[str, CircularCushionSegment] = field()

    @linear.validator  # type: ignore
    @circular.validator  # type: ignore
    def _keys_match_value_ids(self, attribute, value) -> None:
        """Check that every dictionary key equals the ``id`` of its segment.

        attrs calls validators as ``(instance, attribute, value)``. ``attribute``
        is the attrs field descriptor (unused here), and ``value`` is the dict
        being validated.

        Raises:
            AssertionError: If a key differs from its segment's ``id``. Note that
                ``assert`` statements are skipped when Python runs with ``-O``.
        """
        for key, segment in value.items():
            assert key == segment.id, f"Key '{key}' mismatch with ID '{segment.id}'"

    def copy(self) -> CushionSegments:
        """Create a copy"""
        # Delegates the deep-ish copying of LinearCushionSegment and
        # CircularCushionSegment elements to their respective copy() methods. Uses
        # dictionary comprehensions to construct equal but different `linear` and
        # `circular` attributes.
        return evolve(
            self,
            linear={k: v.copy() for k, v in self.linear.items()},
            circular={k: v.copy() for k, v in self.circular.items()},
        )
@define(eq=False, frozen=True, slots=False)
class Pocket:
    """A circular pocket

    Attributes:
        id:
            The ID of the pocket.
        center:
            A length-3 array specifying the pocket's position.

                - ``center[0]`` is the x-coordinate of the pocket's center
                - ``center[1]`` is the y-coordinate of the pocket's center
                - ``center[2]`` must be 0.0

            The array is marked read-only in place on construction (no copy is
            taken).
        radius:
            The radius of the pocket.
        depth:
            How deep the pocket is (*default* = ``0.08``).
        contains:
            Stores the ball IDs of pocketed balls (*default* = ``set()``).
    """

    id: str
    center: NDArray[np.float64]
    radius: float
    depth: float = field(default=0.08)
    contains: set = field(factory=set)

    def __attrs_post_init__(self):
        assert len(self.center) == 3
        assert self.center[2] == 0

        # center is read only
        self.center.flags["WRITEABLE"] = False

    def __eq__(self, other):
        return are_dataclasses_equal(self, other)

    @cached_property
    def a(self) -> float:
        """The x-coordinate of the pocket's center."""
        return self.center[0]

    @cached_property
    def b(self) -> float:
        """The y-coordinate of the pocket's center."""
        return self.center[1]

    def add(self, ball_id: str) -> None:
        """Add a ball ID to :attr:`contains`

        Adding an ID that is already present has no effect (set semantics).
        """
        self.contains.add(ball_id)

    def remove(self, ball_id: str) -> None:
        """Remove a ball ID from :attr:`contains`

        Raises:
            KeyError: If ``ball_id`` is not in :attr:`contains`.
        """
        self.contains.remove(ball_id)

    def copy(self) -> Pocket:
        """Create a copy"""
        # Pocket is a frozen instance, and except for `contains`, its attributes are
        # either (a) immutable, or (b) have read-only flags set. Therefore, only a copy
        # of `contains` needs to be made. Since its members are strs (immutable), a
        # shallow copy suffices.
        return evolve(self, contains=copy.copy(self.contains))

    @staticmethod
    def dummy() -> Pocket:
        """Return a placeholder pocket centered at the origin with radius 10."""
        # Explicit float64 / float to match the attribute annotations and
        # CircularCushionSegment.dummy().
        return Pocket(
            id="dummy",
            center=np.array([0, 0, 0], dtype=np.float64),
            radius=10.0,
        )
# Constrained type variable for code that accepts either kind of cushion segment.
Cushion = TypeVar(
    "Cushion",
    LinearCushionSegment,
    CircularCushionSegment,
)