Source code for pooltool.events.filter

from typing import Callable, List, Union

from pooltool.events.datatypes import AgentType, Event, EventType

# Signature shared by the filters in this module: a callable that takes a list
# of events and returns a list of events.
FilterFunc = Callable[[List[Event]], List[Event]]


def by_type(types: Union[EventType, List[EventType]]) -> FilterFunc:
    """Build a filter that keeps only events of the requested type(s).

    Args:
        types:
            The event type, or list of event types, to retain. Events of any
            other type are dropped.

    Returns:
        FilterFunc:
            A function that takes a list of events and returns a new list
            holding only the events whose type is among ``types``, in their
            original order.
    """

    def func(events: List[Event]) -> List[Event]:
        # A lone type is detected with a ``str`` check and wrapped in a list so
        # the membership test below compares whole values.
        wanted: List[EventType] = [types] if isinstance(types, str) else types
        return [event for event in events if event.event_type in wanted]

    return func
def by_ball(ball_ids: Union[str, List[str]], keep_nonevent: bool = False) -> FilterFunc:
    """Build a filter that keeps only events involving the given ball(s).

    Args:
        ball_ids:
            A single ball ID or a list of ball IDs.
        keep_nonevent:
            If ``True``, events of type :attr:`EventType.NONE` are retained
            regardless of which agents they involve.

    Returns:
        FilterFunc:
            A function that takes a list of events and returns a new list
            holding only the events that have at least one ball agent whose ID
            is among ``ball_ids`` (plus non-events when ``keep_nonevent`` is
            set), in their original order.
    """

    def func(events: List[Event]) -> List[Event]:
        # A lone ID is wrapped in a list so membership compares whole IDs
        # rather than substrings.
        wanted: List[str] = [ball_ids] if isinstance(ball_ids, str) else ball_ids

        def involves_wanted_ball(event: Event) -> bool:
            # An agent only counts if its ID matches *and* it is a ball.
            return any(
                agent.id in wanted and agent.agent_type == AgentType.BALL
                for agent in event.agents
            )

        return [
            event
            for event in events
            if (keep_nonevent and event.event_type == EventType.NONE)
            or involves_wanted_ball(event)
        ]

    return func
def by_time(t: float, after: bool = True) -> FilterFunc:
    """Build a filter that keeps events on one side of a time cutoff.

    Args:
        t:
            The cutoff time.
        after:
            If ``True``, keep events strictly after ``t``. If ``False``, keep
            events strictly before ``t``. Events exactly at ``t`` are dropped
            in both cases.

    Returns:
        FilterFunc:
            A function that takes a chronological list of events and returns a
            new list holding only the events before or after the cutoff,
            non-inclusive. The function raises ``ValueError`` if the list it
            receives is not in chronological order.
    """

    def func(events: List[Event]) -> List[Event]:
        # The input must already equal its time-sorted self.
        chronological = sorted(events, key=lambda ev: ev.time)
        if not events == chronological:
            raise ValueError("Event lists must be chronological")

        if after:
            return [ev for ev in events if ev.time > t]
        return [ev for ev in events if ev.time < t]

    return func
def _chain(*funcs: FilterFunc) -> FilterFunc:
    """Compose filters into one that applies them left to right.

    Args:
        *funcs:
            Filters to apply, in order. Each receives the output of the
            previous one.

    Returns:
        FilterFunc:
            A function that feeds its input through every filter in turn and
            returns the final result. With no filters, the input is returned
            unchanged.
    """

    def composed(events: List[Event]) -> List[Event]:
        current = events
        for apply_filter in funcs:
            current = apply_filter(current)
        return current

    return composed
def filter_events(events: List[Event], *funcs: FilterFunc) -> List[Event]:
    """Filter events using multiple criteria.

    A convenient way to filter based on multiple filtering criteria.

    Args:
        events:
            A list of chronological events.
        *funcs:
            An arbitrary number of functions that each take a list of events
            as input and return a subset of that list as output. It sounds
            laborious--it's not. See *Examples*.

    Returns:
        List[Event]:
            A filtered event list containing only events passing the supplied
            criteria.

    Examples:
        Generate a list of events.

        >>> import pooltool as pt
        >>> system = pt.System.example()
        >>> system.cue.set_state(a=0.68)
        >>> pt.simulate(system, inplace=True)
        >>> events = system.events

        In this shot, both the cue-ball and the 1-ball are potted. We are
        interested in filtering for the cue-ball pocket event.

        Option 1 is to call :func:`filter_type` and then :func:`filter_ball`:

        >>> filtered_events = pt.events.filter_type(events, pt.EventType.BALL_POCKET)
        >>> filtered_events = pt.events.filter_ball(filtered_events, "cue")
        >>> event_of_interest = filtered_events[0]
        >>> event_of_interest
        <Event object at 0x7fa855e7e6c0>
         ├── type   : ball_pocket
         ├── time   : 3.231130101576186
         └── agents : ('cue', 'rt')

        Option 2, the better option, is to use :func:`filter_events`:

        >>> filtered_events = pt.events.filter_events(
        ...     events,
        ...     pt.events.by_type(pt.EventType.BALL_POCKET),
        ...     pt.events.by_ball("cue"),
        ... )
        >>> event_of_interest = filtered_events[0]
        >>> event_of_interest
        <Event object at 0x7fa855e7e6c0>
         ├── type   : ball_pocket
         ├── time   : 3.231130101576186
         └── agents : ('cue', 'rt')

    See Also:
        - If you're filtering based on a single criterion, you can consider
          using :func:`filter_type`, :func:`filter_ball`, :func:`filter_time`,
          etc.
    """
    # Apply each filter in the order given, feeding every filter the output of
    # the one before it.
    remaining = events
    for criterion in funcs:
        remaining = criterion(remaining)
    return remaining
def filter_type(
    events: List[Event], types: Union[EventType, List[EventType]]
) -> List[Event]:
    """Filter events based on event type.

    Args:
        events:
            A list of chronological events.
        types:
            The event type, or list of event types, to keep. All other events
            are filtered out.

    Returns:
        List[Event]:
            A filtered event list containing only events matching the passed
            event type(s).

    See Also:
        - If you're filtering based on multiple criteria, you can (and
          should!) use :func:`filter_events`.
    """
    keep_matching_types = by_type(types)
    return keep_matching_types(events)
def filter_ball(
    events: List[Event], ball_ids: Union[str, List[str]], keep_nonevent: bool = False
) -> List[Event]:
    """Filter events based on ball IDs.

    Args:
        events:
            A list of chronological events.
        ball_ids:
            A single ball ID or a list of ball IDs.
        keep_nonevent:
            Retain non-events (:attr:`EventType.NONE`).

    Returns:
        List[Event]:
            A filtered event list containing only events that involve balls
            matching the passed ball ID(s).

    See Also:
        - If you're filtering based on multiple criteria, you can (and
          should!) use :func:`filter_events`.
    """
    keep_matching_balls = by_ball(ball_ids, keep_nonevent=keep_nonevent)
    return keep_matching_balls(events)
def filter_time(events: List[Event], t: float, after: bool = True) -> List[Event]:
    """Filter events with respect to a time cutoff.

    Args:
        events:
            A list of chronological events.
        t:
            The cutoff time for filtering events.
        after:
            If ``True``, return events after time ``t`` (non-inclusive). If
            ``False``, return events before time ``t`` (non-inclusive).

    Returns:
        List[Event]:
            A filtered event list containing only events before or after the
            cutoff time, non-inclusive.

    See Also:
        - If you're filtering based on multiple criteria, you can (and
          should!) use :func:`filter_events`.
    """
    keep_by_cutoff = by_time(t, after=after)
    return keep_by_cutoff(events)