Source code for pooltool.events.filter

from collections.abc import Callable

from pooltool.events.datatypes import AgentType, Event, EventType

FilterFunc = Callable[[list[Event]], list[Event]]


[docs] def by_type(types: EventType | list[EventType]) -> FilterFunc: """Returns a function that filters events based on event type. Args: types: Event type(s) you want to include in your result. All others will be filtered. Returns: FilterFunc: A function that when passed a list of events, returns a filtered list containing only events matching the passed event type(s). """ def func(events: list[Event]) -> list[Event]: _types: EventType | list[EventType] if isinstance(types, str): _types = [types] else: _types = types new: list[Event] = [] for event in events: if event.event_type in _types: new.append(event) return new return func
[docs] def by_ball(ball_ids: str | list[str], keep_nonevent: bool = False) -> FilterFunc: """Returns a function that filters events based on ball IDs. Args: ball_ids: A collection of ball IDs. keep_nonevent: Retain non-events (:attr:`pooltool.events.EventType.NONE`). Returns: FilterFunc: A function that when passed a list of events, returns a filtered list containing only events that involve balls matching the passed ball ID(s). """ def func(events: list[Event]) -> list[Event]: _ball_ids: str | list[str] if isinstance(ball_ids, str): _ball_ids = [ball_ids] else: _ball_ids = ball_ids new: list[Event] = [] for event in events: if keep_nonevent and event.event_type == EventType.NONE: new.append(event) else: for agent in event.agents: if agent.id in _ball_ids and agent.agent_type == AgentType.BALL: new.append(event) break return new return func
[docs] def by_time(t: float, after: bool = True) -> FilterFunc: """Returns a function that filter events with respect to a time cutoff. Args: t: The cutoff time for filtering events. after: If ``True``, return events after time ``t`` (non-inclusive). If ``False``, return events before time ``t`` (non-inclusive). Returns: FilterFunc: A function that when passed a list of events, returns a filtered list containing only events before or after the cutoff time, non-inclusive. """ def func(events: list[Event]) -> list[Event]: if not events == sorted(events, key=lambda event: event.time): raise ValueError("Event lists must be chronological") new: list[Event] = [] for event in events: if after and event.time > t: new.append(event) elif not after and event.time < t: new.append(event) return new return func
def _chain(*funcs: FilterFunc) -> FilterFunc: def func(events: list[Event]) -> list[Event]: result = events for f in funcs: result = f(result) return result return func
[docs] def filter_events(events: list[Event], *funcs: FilterFunc) -> list[Event]: """Filter events using multiple criteria. A convenient way to filter based multiple filtering criteria. Args: events: A list of chronological events. *funcs: An arbitrary number of functions that take a list of events as input, and gives a subset of that list as input. It sounds laborious--it's not. See *Examples*. Returns: List[Event]: A filtered event list containing only events passing the supplied criteria. Examples: Generate a list of events. >>> import pooltool as pt >>> system = pt.System.example() >>> system.cue.set_state(a=0.68) >>> pt.simulate(system, inplace=True) >>> events = system.events In this shot, both the cue-ball and the 1-ball are potted. We are interested in filtering for the cue-ball pocket event. Option 1 is to call :func:`pooltool.events.filter_type` and then :func:`pooltool.events.filter_ball`: >>> filtered_events = pt.events.filter_type(events, pt.EventType.BALL_POCKET) >>> filtered_events = pt.events.filter_ball(filtered_events, "cue") >>> event_of_interest = filtered_events[0] >>> event_of_interest <Event object at 0x7fa855e7e6c0> ├── type : ball_pocket ├── time : 3.231130101576186 └── agents : ('cue', 'rt') Option 2, the better option, is to use :func:`pooltool.events.filter_events`: >>> filtered_events = pt.events.filter_events( >>> events, >>> pt.events.by_type(pt.EventType.BALL_POCKET), >>> pt.events.by_ball("cue"), >>> ) >>> event_of_interest = filtered_events[0] >>> event_of_interest <Event object at 0x7fa855e7e6c0> ├── type : ball_pocket ├── time : 3.231130101576186 └── agents : ('cue', 'rt') See Also: - If you're filtering based on a single criterion, you can consider using :func:`pooltool.events.filter_type`, :func:`pooltool.events.filter_ball`, :func:`pooltool.events.filter_time`, etc. """ return _chain(*funcs)(events)
[docs] def filter_type(events: list[Event], types: EventType | list[EventType]) -> list[Event]: """Filter events based on event type. Args: events: A list of chronological events. types: Event type(s) you want to include in your result. All others will be filtered. Returns: List[Event]: A filtered event list containing only events matching the passed event type(s). See Also: - If you're filtering based on multiple criteria, you can (and should!) use :func:`pooltool.events.filter_events`. """ return by_type(types)(events)
[docs] def filter_ball( events: list[Event], ball_ids: str | list[str], keep_nonevent: bool = False ) -> list[Event]: """Filter events based on ball IDs. Args: events: A list of chronological events. ball_ids: A collection of ball IDs. keep_nonevent: Retain non-events (:attr:`pooltool.events.EventType.NONE`). Returns: List[Event]: A filtered event list containing only events that involve balls matching the passed ball ID(s). See Also: - If you're filtering based on multiple criteria, you can (and should!) use :func:`pooltool.events.filter_events`. """ return by_ball(ball_ids, keep_nonevent)(events)
[docs] def filter_time(events: list[Event], t: float, after: bool = True) -> list[Event]: """Filter events with respect to a time cutoff. Args: events: A list of chronological events. t: The cutoff time for filtering events. after: If ``True``, return events after time ``t`` (non-inclusive). If ``False``, return events before time ``t`` (non-inclusive). Returns: List[Event]: A filtered event list containing only events before or after the cutoff time, non-inclusive. See Also: - If you're filtering based on multiple criteria, you can (and should!) use :func:`pooltool.events.filter_events`. """ return by_time(t, after)(events)