import pooltool.constants as const
from pooltool.events.datatypes import EventType
from pooltool.events.filter import by_ball, by_type, filter_events, filter_type
from pooltool.objects.ball.datatypes import Ball, BallState
from pooltool.ruleset.datatypes import ShotConstraints
from pooltool.system.datatypes import System
from pooltool.utils.strenum import StrEnum, auto
[docs]
def get_pocketed_ball_ids_during_shot(
    shot: System, exclude: set[str] | None = None
) -> list[str]:
    """Get list of ball IDs pocketed during the shot

    Walks the shot's BALL_POCKET events in order and collects the ID of the
    first agent of each one, skipping any ID listed in ``exclude``.

    See also get_pocketed_ball_ids
    """
    skip = exclude if exclude is not None else set()

    pocketed: list[str] = []
    for event in filter_type(shot.events, EventType.BALL_POCKET):
        ball_id = event.agents[0].id
        if ball_id in skip:
            continue
        pocketed.append(ball_id)
    return pocketed
[docs]
def get_pocketed_ball_ids(shot: System) -> list[str]:
    """Get list of ball IDs that are in the pocketed state (by end of shot)

    Only each ball's current ``state`` is inspected, not the event list.

    See also get_pocketed_ball_ids_during_shot
    """
    pocketed: list[str] = []
    for ball in shot.balls.values():
        if ball.state.s != const.pocketed:
            continue
        pocketed.append(ball.id)
    return pocketed
[docs]
def get_id_of_first_ball_hit(shot: System, cue: str = "cue") -> str | None:
    """Return the ID of the first ball the cue ball collides with

    Args:
        shot: The system whose events are searched.
        cue: ID of the cue ball.

    Returns:
        The other ball's ID in the earliest BALL_BALL event involving ``cue``,
        or None if the cue ball touched no other ball.
    """
    collisions = filter_events(
        shot.events,
        by_ball(cue),
        by_type(EventType.BALL_BALL),
    )
    if len(collisions) == 0:
        return None

    first_a, first_b = collisions[0].ids
    # Whichever participant is not the cue ball is the object ball.
    return first_b if first_a == cue else first_a
[docs]
def is_ball_pocketed(shot: System, ball_id: str) -> bool:
    """Return whether the given ball was pocketed during the shot

    Args:
        shot: The system whose events are searched.
        ball_id: ID of the ball to look for.

    Returns:
        True if any BALL_POCKET event has ``ball_id`` as its first agent.
    """
    # Exact ID comparison: a substring test (``ball_id in agent_id``) would
    # wrongly report ball "1" as pocketed when "10".."15" is pocketed.
    return any(
        event.agents[0].id == ball_id
        for event in filter_type(shot.events, EventType.BALL_POCKET)
    )
[docs]
def is_ball_pocketed_in_pocket(shot: System, ball_id: str, pocket_id: str) -> bool:
    """Return whether ``ball_id`` was pocketed into ``pocket_id`` during the shot

    Each BALL_POCKET event's ``ids`` is unpacked as (ball, pocket); the check
    succeeds on the first event matching both.
    """
    pocket_events = filter_type(shot.events, EventType.BALL_POCKET)
    return any(
        ball_id == ball and pocket_id == pocket
        for ball, pocket in (event.ids for event in pocket_events)
    )
[docs]
def is_target_group_hit_first(
    shot: System, target_balls: tuple[str, ...], cue: str
) -> bool:
    """Return whether the cue ball's first contact was one of ``target_balls``

    False when the cue ball hit no ball at all.
    """
    first_contact = get_id_of_first_ball_hit(shot, cue=cue)
    return first_contact in target_balls
[docs]
def respot(
    shot: System, ball_id: str, x: float, y: float, z: float | None = None
) -> None:
    """Respot a ball

    Writes the new position into the ball's state (in place) and marks the
    ball stationary.

    Args:
        shot: The system containing the ball; modified in place.
        ball_id: ID of the ball to move.
        x: New x coordinate.
        y: New y coordinate.
        z:
            If not provided, z is set to the ball's radius

    Raises:
        NotImplementedError: If ``z`` exceeds the ball's radius, since no
            airborne state exists.

    Notes
    =====
    - FIXME check if respot position overlaps with ball
    - Only ``state.rvw[0]`` (the position row) is written; the other rows of
      ``rvw`` are left as they were.
    """
    ball = shot.balls[ball_id]
    radius = ball.params.R
    height = radius if z is None else z

    if height > radius:
        # An "airborne" state would be needed here; none exists yet.
        raise NotImplementedError("No airborne state exists")

    ball.state.rvw[0] = [x, y, height]
    ball.state.s = const.stationary
[docs]
def get_ball_ids_on_table(
    shot: System, at_start: bool, exclude: set[str] | None = None
) -> set[str]:
    """Return the IDs of balls whose state is on-table at the start or end of the shot

    Args:
        shot: The system to inspect.
        at_start: If True, use each ball's first history entry; otherwise its last.
        exclude: Ball IDs to leave out of the result.
    """
    snapshot = 0 if at_start else -1

    on_table: set[str] = set()
    for ball in shot.balls.values():
        # History is consulted before the exclusion check, matching the
        # original evaluation order.
        if ball.history[snapshot].s not in const.on_table:
            continue
        if exclude is not None and ball.id in exclude:
            continue
        on_table.add(ball.id)
    return on_table
class StateProbe(StrEnum):
    """Selects which snapshot of a ball's state ``_probe_ball_state`` returns."""

    CURRENT = auto()  # the ball's live ``ball.state``
    START = auto()  # ``ball.history[0]`` when the shot was simulated
    END = auto()  # ``ball.history[-1]`` when the shot was simulated
def _probe_ball_state(ball: Ball, when: StateProbe, simulated: bool) -> BallState:
    """Return the ball state selected by ``when``.

    Unsimulated shots fall back to the current state for every probe. For
    simulated shots, START/END read the first/last history entry.
    """
    if not simulated or when is StateProbe.CURRENT:
        return ball.state
    return ball.history[0] if when is StateProbe.START else ball.history[-1]
[docs]
def get_lowest_ball(shot: System, when: StateProbe) -> Ball:
    """Get the lowest-numbered ball on the table at a given point of the shot

    Args:
        shot: The system to inspect.
        when:
            Which state snapshot decides whether a ball is still on the table.
            ``StateProbe.START`` uses the state at t=0 and ``StateProbe.END``
            the final state (t=inf); the latter returns a different result if
            the lowest ball on the table was pocketed. ``StateProbe.CURRENT``
            uses each ball's current state, as does any probe when the shot
            has not been simulated.

    Returns:
        The non-cue ball with the smallest integer ID that is not pocketed.

    Raises:
        AssertionError: If no non-cue, non-pocketed ball exists.
        ValueError: If a non-cue ball ID cannot be parsed by ``int``.
    """
    # Sentinel ID larger than any real ball number, so any real ball wins.
    _dummy = "10000"
    lowest = Ball.dummy(id=_dummy)
    for ball in shot.balls.values():
        if ball.id == "cue":
            continue
        if _probe_ball_state(ball, when, shot.simulated).s == const.pocketed:
            continue
        if int(ball.id) < int(lowest.id):
            lowest = ball
    assert lowest.id != _dummy, "No numbered balls on table"
    return lowest
[docs]
def get_highest_ball(shot: System, at_start: bool) -> Ball:
    """Get the highest-numbered ball on the table at start or end of shot

    Args:
        shot: The system to inspect.
        at_start:
            If True, the highest ball on the table at t=0 is calculated. If False,
            the highest ball at the end of the shot (t=inf) is calculated. The latter
            returns a different result if the highest ball on the table was pocketed.
            If the shot has not been simulated, each ball's current state is used
            in both cases.

    Returns:
        The non-cue ball with the largest integer ID that is not pocketed.

    Raises:
        AssertionError: If no non-cue, non-pocketed ball exists.
        ValueError: If a non-cue ball ID cannot be parsed by ``int``.
    """
    # Route through _probe_ball_state (as get_lowest_ball does) so that an
    # unsimulated shot, whose history may be empty, reads the current state
    # instead of indexing into history.
    when = StateProbe.START if at_start else StateProbe.END

    # Sentinel ID smaller than any real ball number, so any real ball wins.
    _dummy = "0"
    highest = Ball.dummy(id=_dummy)
    for ball in shot.balls.values():
        if ball.id == "cue":
            continue
        if _probe_ball_state(ball, when, shot.simulated).s == const.pocketed:
            continue
        if int(ball.id) > int(highest.id):
            highest = ball
    assert highest.id != _dummy, "No numbered balls on table"
    return highest
[docs]
def is_lowest_hit_first(shot: System) -> bool:
    """Return whether the cue ball first struck the lowest ball on the table at t=0

    False when the cue ball hit no ball.
    """
    first_contact = get_id_of_first_ball_hit(shot, cue="cue")
    if first_contact is None:
        return False

    lowest = get_lowest_ball(shot, when=StateProbe.START)
    return lowest.id == first_contact
[docs]
def balls_that_hit_cushion(shot: System, exclude: set[str] | None = None) -> set[str]:
    """Return the IDs of balls that touched a linear or circular cushion

    Args:
        shot: The system whose events are searched.
        exclude: Ball IDs to ignore (the cue ball is included unless listed here).
    """
    skip = set() if exclude is None else exclude
    candidate_ids = [ball.id for ball in shot.balls.values() if ball.id not in skip]

    cushion_types = [EventType.BALL_LINEAR_CUSHION, EventType.BALL_CIRCULAR_CUSHION]
    cushion_events = filter_events(
        shot.events,
        by_type(cushion_types),
        by_ball(candidate_ids),
    )
    return {event.agents[0].id for event in cushion_events}
[docs]
def is_ball_hit(shot: System) -> bool:
    """Return whether any ball-ball collision occurred during the shot."""
    ball_ball_events = filter_events(shot.events, by_type(EventType.BALL_BALL))
    return len(ball_ball_events) > 0
[docs]
def is_numbered_ball_pocketed(shot: System) -> bool:
    """Return whether any ball other than "cue" was pocketed during the shot."""
    pocketed_non_cue = get_pocketed_ball_ids_during_shot(shot, exclude={"cue"})
    return len(pocketed_non_cue) != 0
[docs]
def is_shot_called_if_required(shot_constraints: ShotConstraints) -> bool:
    """Return whether the call-shot requirement, if any, is satisfied

    When ``call_shot`` is falsy no call is required and the result is True.
    Otherwise both ``ball_call`` and ``pocket_call`` must be set (not None).
    """
    if shot_constraints.call_shot:
        return (
            shot_constraints.ball_call is not None
            and shot_constraints.pocket_call is not None
        )
    return True