import logging
import os
import time
import timeit
import warnings
from astropy.table import vstack
import numpy as np
import ujson
from confluent_kafka import Consumer, KafkaException, TopicPartition
from gwpy.table import EventTable
from gwpy.table.filters import in_segmentlist
from ligo.segments import segment, segmentlist
from ... import exceptions
from ... import names
from ... import utils
from ... import hookimpl
from . import DataLoader
DEFAULT_CHUNK_SIZE = 60
logger = logging.getLogger("idq")
class SNAXDataLoader(DataLoader):
"""an extension meant to read SNAX features off disk
We assume the following directory structure:
${rootdir}/${gpsMOD1e5}/${basename}-${start}-${dur}.h5
"""
_default_columns = utils.SNAX_COLUMNS
_allowed_columns = utils.SNAX_COLUMNS
_required_kwargs = ["rootdir", "basename"]
def _load_table(self, cache, channels, selection, verbose=False):
data = {}
# load features
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
try:
table = EventTable.read(
cache,
channels=channels,
format="hdf5.snax",
columns=self.columns + ["channel"],
selection=selection,
nproc=self.kwargs.get("nproc", 1),
on_missing=self.kwargs.get("on_missing", "error"),
compact=True,
verbose=verbose,
)
except (KeyError, TypeError, ValueError):
# empty table prior to selection
raise exceptions.NoDataError(self.start, self.end - self.start)
# skip groupby if table is empty after selection
if not table:
raise exceptions.NoDataError(self.start, self.end - self.start)
# group by channel and drop channel column
table = table.group_by("channel")
table.remove_column("channel")
for key, group in zip(table.groups.keys, table.groups):
channel = table.meta["channel_map"][key["channel"]]
data[channel] = EventTable(group, copy=True)
return data
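
    # The mapping returned by _load_table() is keyed by the resolved channel
    # name, e.g. (channel name hypothetical):
    #
    #     {"H1:SOME_CHANNEL": <EventTable with self.columns>, ...}
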
def _query(self, channels=None, bounds=None, verbose=False, **kwargs):
"""Workhorse data discovery method for SNAX hdf5 files."""
segs = self.kwargs.get("segs", self.segs)
# set up cache, filtering by segments
filename = names.basename2snax_filename(self.kwargs["basename"])
cache = utils.path2cache(self.kwargs["rootdir"], os.path.join("*", filename))
cache = [entry for entry in cache if segs.intersects_segment(entry.segment)]
# set up filters for selection
selection = [(utils.SNAX_TIMECOL, in_segmentlist, segs)]
if bounds:
for col, (min_, max_) in bounds.items():
selection.extend([f"{col} >= {min_}", f"{col} <= {max_}"])
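        # e.g. a hypothetical bounds={"snr": (5.5, 100.0)} would extend the
        # selection above with the row filters "snr >= 5.5" and "snr <= 100.0"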
# iteratively load features, combining their results
data = {}
for subcache in utils.chunker(cache, DEFAULT_CHUNK_SIZE):
chunk = self._load_table(subcache, channels, selection, verbose=verbose)
for channel, table in chunk.items():
data.setdefault(channel, []).append(table)
for channel in data.keys():
data[channel] = vstack(data[channel], join_type="exact")
# fill in missing channels
if channels:
dtype = [(col, "float") for col in self.columns]
for channel in channels:
if channel not in data:
data[channel] = EventTable(data=np.array([], dtype=dtype))
return data
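
# A hedged usage sketch for SNAXDataLoader. The rootdir/basename values, the
# channel name, and the exact constructor/query signatures of the DataLoader
# base class are assumptions for illustration and are not defined in this module:
#
#     loader = SNAXDataLoader(
#         start, end,
#         rootdir="/data/snax/features",
#         basename="H1-SNAX_FEATURES",
#     )
#     data = loader.query(
#         channels=["H1:SOME_CHANNEL"],
#         bounds={"snr": (5.5, 100.0)},
#     )
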
class SNAXKafkaDataLoader(DataLoader):
"""an extension meant to load streaming SNAX features from Kafka.
Intended to keep a running current timestamp, and has the ability to poll
for new data and fill its own ClassifierData objects for use when triggers
are to be retrieved.
NOTE: when called, this will cache all trigger regardless of the bounds.
This is done to avoid issues with re-querying data from rolling buffers,
which is not guaranteed to return consistent results. Instead, we record
everything we query and filter.
"""
_skip_filter = True
_default_columns = utils.SNAX_COLUMNS
_allowed_columns = utils.SNAX_COLUMNS
_required_kwargs = [
"group",
"url",
"topic",
"poll_timeout",
"retry_cadence",
"sample_rate",
]
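
    # The required kwargs above might be configured as follows (all values are
    # hypothetical, for illustration only):
    #
    #     {
    #         "group": "idq-snax",            # Kafka consumer group id
    #         "url": "localhost:9092",        # bootstrap.servers for the consumer
    #         "topic": "snax_features",       # topic carrying SNAX feature buffers
    #         "poll_timeout": 0.2,            # seconds passed to Consumer.poll()
    #         "retry_cadence": 1.0,           # seconds between successive polls
    #         "sample_rate": 1,               # buffers per second, i.e. a 1 s stride
    #     }
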
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._topic = self.kwargs["topic"]
self._timestamp = self.kwargs.get("timestamp", 0)
# create a Kafka consumer
if "consumer" not in self.kwargs:
self.kwargs["consumer"] = Consumer(
{
"group.id": self.kwargs["group"],
"bootstrap.servers": self.kwargs["url"],
**self.kwargs.get("kafka_kwargs", {}),
}
)
# subscribe to topic and assign to specific partition
tp = TopicPartition(self._topic, partition=0)
self.kwargs["consumer"].subscribe([self._topic])
self.kwargs["consumer"].assign([tp])
# dtype for query results
self._dtype = [(col, "float") for col in self.columns]
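        # e.g. if self.columns were ["time", "snr", "frequency"] (illustrative
        # names only), self._dtype would be
        # [("time", "float"), ("snr", "float"), ("frequency", "float")]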
def _seek_timestamp(self, topic, timestamp, stride):
"""
seek to the offset in a given topic corresponding to the timestamp specified
"""
if not (self._timestamp + stride == timestamp):
# only seek if offset needs to be changed
offset_timestamp = int(timestamp * 1e3) # NOTE: convert to ms
logger.debug(
"next timestamp not contiguous (%f != %f), seeking to offset"
% (self._timestamp + stride, timestamp)
)
time_partition = TopicPartition(topic, partition=0, offset=offset_timestamp)
offset = self.kwargs["consumer"].offsets_for_times([time_partition])[0]
self.kwargs["consumer"].seek(offset)
def _retrieve(self, **kwargs):
"""
retrieves messages in a timestamp-aware way
NOTE: timestamp field in message is required
"""
data_buffer = self.kwargs["consumer"].poll(**kwargs)
# retrieve only if there is a message and there isn't an error involved
if data_buffer and not data_buffer.error():
data_buffer = ujson.loads(data_buffer.value())
data_timestamp = data_buffer["timestamp"]
self._timestamp = data_timestamp
self.kwargs["timestamp"] = data_timestamp
return data_timestamp, data_buffer
else:
return None, None
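
    # Based on the accesses in _retrieve() and _query(), a decoded message is
    # expected to look roughly like the following (channel name and values
    # hypothetical):
    #
    #     {
    #         "timestamp": 1262304000.0,
    #         "features": {
    #             "H1:SOME_CHANNEL": [
    #                 {"time": 1262304000.1, "snr": 6.2, ...},  # one dict per feature row
    #             ],
    #         },
    #     }
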
def _query(self, channels=None, **kwargs):
"""NOTE: we intentionally ignore bounds here and just store everything"""
# NOTE: this forces data to only be populated once
if self._data:
return self._data
if not utils.livetime(self.segs):
# we'll never find any data in the segments because they have
# zero livetime
self._data = {}
if channels is not None:
for channel in channels:
self._data[channel] = np.array([], dtype=self._dtype)
return self._data
        # assume the last processed timestamp points to just before this span
stride = 1.0 / self.kwargs["sample_rate"]
timestamp = self.start - stride
segs = segmentlist([segment(self.start, self.end)])
# attempt to seek to correct offset
try:
self._seek_timestamp(self._topic, self.start, stride)
except KafkaException as e:
logger.warning("could not seek to offset: %s" % repr(e))
# check if timestamps fall within span of classifier data
data = {}
stop = self.end - stride
while timestamp < stop:
# poll for data at requested offset
start_time = timeit.default_timer()
new_timestamp, data_buffer = self._retrieve(
timeout=self.kwargs["poll_timeout"]
)
if data_buffer:
logger.debug("timestamp of buffer: %.6f" % new_timestamp)
logger.debug("latency: %.6f" % utils.gps2latency(new_timestamp))
if new_timestamp == timestamp + stride:
# data is contiguous, which is GOOD
if utils.time_in_segments(new_timestamp, self.segs):
# fill in all data from buffer
for channel, rows in data_buffer["features"].items():
buf_data = [
{col: row[col] for col in self.columns}
for row in rows
if row
]
data.setdefault(channel, []).extend(buf_data)
else:
# data is NOT contiguous, which could be BAD
gap = segment(timestamp + stride, new_timestamp)
# check whether the gap touched things we care about
if utils.livetime(self.segs & segmentlist([gap])):
# move to the next stride
newer_timestamp = max(
utils.floor_div(timestamp, self.end - self.start),
self.end,
)
# reset the consumer's position
self._seek_timestamp(self._topic, newer_timestamp, stride)
raise exceptions.IncontiguousDataError(
gap[0], gap[1], newer_timestamp
)
                # update the timestamp to the newly retrieved one
timestamp = new_timestamp
# wait to poll for new data
elapsed = timeit.default_timer() - start_time
sleep_time = self.kwargs["retry_cadence"] - min(
elapsed, self.kwargs["retry_cadence"]
)
time.sleep(sleep_time)
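            # e.g. with a hypothetical retry_cadence of 1.0 s and an elapsed poll
            # time of 0.25 s we sleep for 0.75 s; if the poll took longer than the
            # cadence, sleep_time is 0 and we retry immediately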
# timestamp >= stop
if not data:
raise exceptions.NoDataError(self.start, self.end - self.start)
# combine rows together
for channel, rows in data.items():
data[channel] = EventTable(rows=rows, names=self.columns)
# filter by segs if needed
        if utils.livetime(self.segs) != utils.livetime(self.segs & segs):
for channel, table in data.items():
data[channel] = table.filter(
(utils.SNAX_TIMECOL, in_segmentlist, self.segs)
)
self._data = data
# fill in missing channels
if channels:
for channel in channels:
if channel not in self._data: # ignore bounds here
self._data[channel] = EventTable(
data=np.array([], dtype=self._dtype)
)
        # make sure we record that we cached everything for all channels
        # present; we do not impose any bounds
        bounds = {}
        for channel in self._data:
            # update what is considered cached
            if channel in self._cached:
                # mark this channel as having everything cached, with no bounds imposed
                self._cached[channel] = bounds
return self._data
@hookimpl
def get_dataloaders():
return {
"snax": SNAXDataLoader,
"snax:hdf5": SNAXDataLoader,
"snax:kafka": SNAXKafkaDataLoader,
}
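
# A minimal sketch of how the hook's return value could be consumed (the actual
# plugin-manager wiring lives elsewhere and is not shown here):
#
#     loaders = get_dataloaders()
#     loader_cls = loaders["snax:kafka"]  # -> SNAXKafkaDataLoader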