Source code for idq.io.triggers.snax

import logging
import os
import time
import timeit
import warnings

from astropy.table import vstack
import numpy as np
import ujson

from confluent_kafka import Consumer, KafkaException, TopicPartition
from gwpy.table import EventTable
from gwpy.table.filters import in_segmentlist
from ligo.segments import segment, segmentlist

from ... import exceptions
from ... import names
from ... import utils
from ... import hookimpl
from . import DataLoader


DEFAULT_CHUNK_SIZE = 60

logger = logging.getLogger("idq")


class SNAXDataLoader(DataLoader):
    """an extension meant to read SNAX features off disk

    We assume the following directory structure:

        ${rootdir}/${gpsMOD1e5}/${basename}-${start}-${dur}.h5

    """

    _default_columns = utils.SNAX_COLUMNS
    _allowed_columns = utils.SNAX_COLUMNS
    _required_kwargs = ["rootdir", "basename"]

    def _load_table(self, cache, channels, selection, verbose=False):
        data = {}

        # load features
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            try:
                table = EventTable.read(
                    cache,
                    channels=channels,
                    format="hdf5.snax",
                    columns=self.columns + ["channel"],
                    selection=selection,
                    nproc=self.kwargs.get("nproc", 1),
                    on_missing=self.kwargs.get("on_missing", "error"),
                    compact=True,
                    verbose=verbose,
                )
            except (KeyError, TypeError, ValueError):
                # empty table prior to selection
                raise exceptions.NoDataError(self.start, self.end - self.start)

        # skip groupby if table is empty after selection
        if not table:
            raise exceptions.NoDataError(self.start, self.end - self.start)

        # group by channel and drop channel column
        table = table.group_by("channel")
        table.remove_column("channel")
        for key, group in zip(table.groups.keys, table.groups):
            channel = table.meta["channel_map"][key["channel"]]
            data[channel] = EventTable(group, copy=True)

        return data

    def _query(self, channels=None, bounds=None, verbose=False, **kwargs):
        """Workhorse data discovery method for SNAX hdf5 files."""
        segs = self.kwargs.get("segs", self.segs)

        # set up cache, filtering by segments
        filename = names.basename2snax_filename(self.kwargs["basename"])
        cache = utils.path2cache(self.kwargs["rootdir"], os.path.join("*", filename))
        cache = [entry for entry in cache if segs.intersects_segment(entry.segment)]

        # set up filters for selection
        selection = [(utils.SNAX_TIMECOL, in_segmentlist, segs)]
        if bounds:
            for col, (min_, max_) in bounds.items():
                selection.extend([f"{col} >= {min_}", f"{col} <= {max_}"])

        # iteratively load features, combining their results
        data = {}
        for subcache in utils.chunker(cache, DEFAULT_CHUNK_SIZE):
            chunk = self._load_table(subcache, channels, selection, verbose=verbose)
            for channel, table in chunk.items():
                data.setdefault(channel, []).append(table)
        for channel in data.keys():
            data[channel] = vstack(data[channel], join_type="exact")

        # fill in missing channels
        if channels:
            dtype = [(col, "float") for col in self.columns]
            for channel in channels:
                if channel not in data:
                    data[channel] = EventTable(data=np.array([], dtype=dtype))

        return data
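

# Illustrative sketch: how the assumed directory layout above maps a GPS start
# time to an on-disk path. The helper name and example values are hypothetical
# and nothing below is used by SNAXDataLoader itself; "gpsMOD1e5" is interpreted
# here as the leading GPS digits (gps // 1e5), which is an assumption.
def _example_snax_path(rootdir, basename, start, dur):
    """Build ${rootdir}/${gpsMOD1e5}/${basename}-${start}-${dur}.h5 (sketch only)."""
    gps_prefix = int(start) // 100000  # leading GPS digits, i.e. gps // 1e5
    return os.path.join(rootdir, str(gps_prefix), f"{basename}-{int(start)}-{int(dur)}.h5")


# e.g. _example_snax_path("/data/snax", "H1-SNAX_FEATURES", 1264316418, 20)
#      -> "/data/snax/12643/H1-SNAX_FEATURES-1264316418-20.h5"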


class SNAXKafkaDataLoader(DataLoader):
    """an extension meant to load streaming SNAX features from Kafka.

    Intended to keep a running current timestamp, and has the ability to poll
    for new data and fill its own ClassifierData objects for use when triggers
    are to be retrieved.

    NOTE: when called, this will cache all triggers regardless of the bounds.
    This is done to avoid issues with re-querying data from rolling buffers,
    which is not guaranteed to return consistent results. Instead, we record
    everything we query and filter.

    """

    _skip_filter = True
    _default_columns = utils.SNAX_COLUMNS
    _allowed_columns = utils.SNAX_COLUMNS
    _required_kwargs = [
        "group",
        "url",
        "topic",
        "poll_timeout",
        "retry_cadence",
        "sample_rate",
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._topic = self.kwargs["topic"]
        self._timestamp = self.kwargs.get("timestamp", 0)

        # create a Kafka consumer
        if "consumer" not in self.kwargs:
            self.kwargs["consumer"] = Consumer(
                {
                    "group.id": self.kwargs["group"],
                    "bootstrap.servers": self.kwargs["url"],
                    **self.kwargs.get("kafka_kwargs", {}),
                }
            )

        # subscribe to topic and assign to specific partition
        tp = TopicPartition(self._topic, partition=0)
        self.kwargs["consumer"].subscribe([self._topic])
        self.kwargs["consumer"].assign([tp])

        # dtype for query results
        self._dtype = [(col, "float") for col in self.columns]

    def _seek_timestamp(self, topic, timestamp, stride):
        """seek to the offset in a given topic corresponding to the timestamp specified"""
        if self._timestamp + stride != timestamp:
            # only seek if offset needs to be changed
            offset_timestamp = int(timestamp * 1e3)  # NOTE: convert to ms
            logger.debug(
                "next timestamp not contiguous (%f != %f), seeking to offset"
                % (self._timestamp + stride, timestamp)
            )
            time_partition = TopicPartition(topic, partition=0, offset=offset_timestamp)
            offset = self.kwargs["consumer"].offsets_for_times([time_partition])[0]
            self.kwargs["consumer"].seek(offset)

    def _retrieve(self, **kwargs):
        """retrieves messages in a timestamp-aware way

        NOTE: timestamp field in message is required
        """
        data_buffer = self.kwargs["consumer"].poll(**kwargs)

        # retrieve only if there is a message and there isn't an error involved
        if data_buffer and not data_buffer.error():
            data_buffer = ujson.loads(data_buffer.value())
            data_timestamp = data_buffer["timestamp"]
            self._timestamp = data_timestamp
            self.kwargs["timestamp"] = data_timestamp
            return data_timestamp, data_buffer
        else:
            return None, None

    def _query(self, channels=None, **kwargs):
        """NOTE: we intentionally ignore bounds here and just store everything"""
        # NOTE: this forces data to only be populated once
        if self._data:
            return self._data

        if not utils.livetime(self.segs):
            # we'll never find any data in the segments because they have
            # zero livetime
            self._data = {}
            if channels is not None:
                for channel in channels:
                    self._data[channel] = np.array([], dtype=self._dtype)
            return self._data

        # assume last processed timestamp is pointed just before this span
        stride = 1.0 / self.kwargs["sample_rate"]
        timestamp = self.start - stride
        segs = segmentlist([segment(self.start, self.end)])

        # attempt to seek to correct offset
        try:
            self._seek_timestamp(self._topic, self.start, stride)
        except KafkaException as e:
            logger.warning("could not seek to offset: %s" % repr(e))

        # check if timestamps fall within span of classifier data
        data = {}
        stop = self.end - stride
        while timestamp < stop:
            # poll for data at requested offset
            start_time = timeit.default_timer()
            new_timestamp, data_buffer = self._retrieve(
                timeout=self.kwargs["poll_timeout"]
            )

            if data_buffer:
                logger.debug("timestamp of buffer: %.6f" % new_timestamp)
                logger.debug("latency: %.6f" % utils.gps2latency(new_timestamp))

                if new_timestamp == timestamp + stride:
                    # data is contiguous, which is GOOD
                    if utils.time_in_segments(new_timestamp, self.segs):
                        # fill in all data from buffer
                        for channel, rows in data_buffer["features"].items():
                            buf_data = [
                                {col: row[col] for col in self.columns}
                                for row in rows
                                if row
                            ]
                            data.setdefault(channel, []).extend(buf_data)
                else:
                    # data is NOT contiguous, which could be BAD
                    gap = segment(timestamp + stride, new_timestamp)

                    # check whether the gap touched things we care about
                    if utils.livetime(self.segs & segmentlist([gap])):
                        # move to the next stride
                        newer_timestamp = max(
                            utils.floor_div(timestamp, self.end - self.start),
                            self.end,
                        )
                        # reset the consumer's position
                        self._seek_timestamp(self._topic, newer_timestamp, stride)
                        raise exceptions.IncontiguousDataError(
                            gap[0], gap[1], newer_timestamp
                        )

                # update the timestamp to be the new thing
                timestamp = new_timestamp

            # wait to poll for new data
            elapsed = timeit.default_timer() - start_time
            sleep_time = self.kwargs["retry_cadence"] - min(
                elapsed, self.kwargs["retry_cadence"]
            )
            time.sleep(sleep_time)

        # timestamp >= stop
        if not data:
            raise exceptions.NoDataError(self.start, self.end - self.start)

        # combine rows together
        for channel, rows in data.items():
            data[channel] = EventTable(rows=rows, names=self.columns)

        # filter by segs if needed
        if utils.livetime(self.segs) != utils.livetime(self.segs & segs):
            for channel, table in data.items():
                data[channel] = table.filter(
                    (utils.SNAX_TIMECOL, in_segmentlist, self.segs)
                )

        self._data = data

        # fill in missing channels
        if channels:
            for channel in channels:
                if channel not in self._data:
                    # ignore bounds here
                    self._data[channel] = EventTable(
                        data=np.array([], dtype=self._dtype)
                    )

        # make sure we record that we cached everything for all channels
        # present; we do not impose any bounds
        bounds = {}
        for channel in self._data:
            # update what is considered cached
            if channel in self._cached:
                # record that everything has been cached for this channel
                self._cached[channel] = bounds

        return self._data


@hookimpl
def get_dataloaders():
    return {
        "snax": SNAXDataLoader,
        "snax:hdf5": SNAXDataLoader,
        "snax:kafka": SNAXKafkaDataLoader,
    }
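

# Illustrative sketch (comments only): one way a pluggy-based host could
# collect the registry returned by get_dataloaders() above. The hookspec,
# the "idq" project name passed to pluggy, and the aggregation loop are
# assumptions for illustration, not part of this module's API.
#
#     import pluggy
#
#     spec = pluggy.HookspecMarker("idq")
#
#     class DataLoaderSpec:
#         @spec
#         def get_dataloaders(self):
#             """Return a mapping of loader names to DataLoader classes."""
#
#     pm = pluggy.PluginManager("idq")
#     pm.add_hookspecs(DataLoaderSpec)
#     pm.register(snax)  # register this module (imported as ``snax``)
#
#     registry = {}
#     for mapping in pm.hook.get_dataloaders():
#         registry.update(mapping)
#     loader_cls = registry["snax:kafka"]  # -> SNAXKafkaDataLoader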