Source code for utilities.testtools

"""Test utilities. Common functions used across various GstLAL unittests
"""

import gi

gi.require_version('Gst', '1.0')
from gi.repository import GObject, Gst

# Module import side effects: GObject threading and GStreamer are initialised
# as soon as this module is imported.
# NOTE(review): GObject.threads_init() is understood to be deprecated (a no-op)
# in recent PyGObject releases — confirm the minimum supported version before removing.
GObject.threads_init()
# Initialise GStreamer; None is passed in place of an argument list, so no
# command-line options are forwarded to it.
Gst.init(None)
import os
import pathlib
import string
import sys
import tempfile
import types
from typing import Tuple, Dict
from unittest import mock

import pytest

# Platform identifier captured once at import time; it is the default
# argument of is_osx().
PLATFORM = sys.platform

# Default value of GstLALTestManager's ``patch_info`` argument. Each entry is a
# (target, kwargs) pair that the manager expands as mock.patch(target, **kwargs).
DEFAULT_MOCK_PATCHES = (
	# Ordered sequence of (target, {kwarg: value}) pairs for passing into unittest.mock.patch
	('gstlal.datafind.load_frame_cache', {'return_value': [1, 2, 3]}),
)
# Translation table for str.translate: maps the ordinal of every character in
# string.whitespace to None, which deletes that character. Used by clean_str().
CLEAN_TRANSLATION = {ord(c): None for c in string.whitespace}


def clean_str(c: str):
	"""Strip whitespace from a copyright string so that it can be compared

	Args:
		c:
			str, the text to normalise

	Returns:
		str, ``c`` with every character listed in CLEAN_TRANSLATION removed
	"""
	# A single translate pass deletes all characters mapped to None in the table
	stripped = c.translate(CLEAN_TRANSLATION)
	return stripped
def is_osx(platform: str = PLATFORM):
	"""Report whether a platform identifier denotes OSX

	Args:
		platform:
			str, platform identifier to test; defaults to the module-level
			PLATFORM constant

	Returns:
		bool, True when the lower-cased identifier equals 'darwin'
	"""
	normalised = platform.lower()
	return 'darwin' == normalised
def skip_osx(f: types.FunctionType) -> types.FunctionType:
	"""Decorator that applies a pytest ``skipif`` mark for OSX to a test

	The platform check is evaluated once, when the decorator is applied.

	Args:
		f:
			the test function to mark

	Returns:
		the result of applying the ``skipif`` mark to ``f``
	"""
	osx_mark = pytest.mark.skipif(is_osx(), reason='Test not supported on OSX')
	return osx_mark(f)
def requires_full_build(f: types.FunctionType):
	"""Decorator that tags a test with the ``requires_full_build`` pytest mark

	Args:
		f:
			the test function to mark

	Returns:
		the result of applying the ``requires_full_build`` mark to ``f``
	"""
	full_build_mark = pytest.mark.requires_full_build
	return full_build_mark(f)
def broken(reason: str):
	"""Build a decorator that skips a test and tags it with the ``broken`` mark

	Args:
		reason:
			str, explanation attached to the skip mark

	Returns:
		a decorator taking the test function and returning it with the
		``skip`` and ``broken`` marks applied
	"""

	def wrapper(f: types.FunctionType):
		# The reason must be given by keyword so that the skip mark is built
		# first and then applied to f. Calling the mark with (f, reason)
		# positionally does not decorate f: it returns a new mark decorator
		# carrying both values as arguments, which would replace the test.
		func = pytest.mark.skip(reason=reason)(f)
		func = pytest.mark.broken(func)
		return func

	return wrapper
def impl_deprecated(f):
	"""Decorator marking a test as broken because its implementation is not built

	Args:
		f:
			the test function to mark

	Returns:
		whatever the decorator produced by broken() returns for ``f``
	"""
	mark_broken = broken('Underlying implementation not included in build')
	return mark_broken(f)
class GstLALTestManager:
	"""Context manager for GstLAL tests

	On entry the manager enters its temporary directory, starts every
	configured ``mock.patch``, applies the environment overrides and, if
	requested, creates a ``Gst.Pipeline``. On exit the temporary directory is
	removed, the patches are undone and the environment is restored.
	"""

	def __init__(self, patch_info: Tuple[Tuple[str, Dict], ...] = DEFAULT_MOCK_PATCHES, env_overrides: dict = None, with_pipeline: bool = False):
		"""Create the manager and its temporary directory

		Args:
			patch_info:
				sequence of (target, kwargs) pairs, each expanded as
				``mock.patch(target, **kwargs)`` on entry
			env_overrides:
				dict, default None, environment variables (name -> value) to
				set on entry. The mapping is copied, so the caller's dict is
				never modified.
			with_pipeline:
				bool, default False, if True a ``Gst.Pipeline`` is created on
				entry and stored in ``self.pipeline``
		"""
		self.tmp_dir = tempfile.TemporaryDirectory()
		self.tmp_path = pathlib.Path(self.tmp_dir.name)
		self.patch_info = patch_info
		self._patches = []
		self._env_originals = {}
		# Private copy: reset_env_var() pops entries from this dict, which
		# would otherwise empty the dict object owned by the caller.
		self._env_overrides = {} if env_overrides is None else dict(env_overrides)
		self._with_pipeline = with_pipeline
		# Defined up front so the attribute exists before the context is entered
		self.pipeline = None

	def __enter__(self):
		"""Enter the GstLAL testing context

		Returns:
			GstLALTestManager, this manager

		Raises:
			Any exception raised while starting a patch, overriding an
			environment variable or building the pipeline. Everything set up
			so far is rolled back before the exception propagates.
		"""
		# Create temporary directory
		self.tmp_dir.__enter__()

		try:
			# Set all mocks; a patcher is recorded only once it has started,
			# so teardown never tries to stop a patcher that never ran
			for target, kwargs in self.patch_info:
				patcher = mock.patch(target, **kwargs)
				patcher.__enter__()
				self._patches.append(patcher)

			# Set env overrides (iterate over a snapshot, override_env_var
			# writes back into the same dict)
			for key, val in list(self._env_overrides.items()):
				self.override_env_var(key, val)

			# Set pipeline
			self.set_pipeline()
		except BaseException:
			# __exit__ is not called when __enter__ raises, so undo the
			# partial setup here before re-raising
			self._teardown(*sys.exc_info())
			raise

		return self

	def __exit__(self, exc_type, exc_val, exc_tb):
		"""Exit GstLAL context

		Returns None, so an exception raised inside the ``with`` body is never
		suppressed.
		"""
		self._teardown(exc_type, exc_val, exc_tb)

	def _teardown(self, exc_type, exc_val, exc_tb):
		"""Remove the tmp dir, undo the patches and restore the environment

		Each stage runs even if an earlier stage raised.
		"""
		try:
			# Remove tmp dir
			self.tmp_dir.__exit__(exc_type, exc_val, exc_tb)
		finally:
			try:
				# Unset all mocks, most recent first: when two patches share a
				# target, only last-in-first-out order restores the original
				while self._patches:
					self._patches.pop().__exit__(exc_type, exc_val, exc_tb)
			finally:
				# Undo all env overrides (snapshot, reset_env_var pops entries)
				for key in list(self._env_originals):
					self.reset_env_var(key)

	@property
	def cache_path(self):
		"""POSIX-style path of ``cache.txt`` inside the temporary directory"""
		return (self.tmp_path / 'cache.txt').as_posix()

	def override_env_var(self, key: str, val: str):
		"""Set an environment variable, remembering its previous value

		Args:
			key:
				str, name of the environment variable
			val:
				str, value to assign

		Raises:
			ValueError:
				if ``key`` is already overridden by this manager
		"""
		if key in self._env_originals:
			raise ValueError('That env var is already overridden')
		# None records that the variable was not set beforehand
		self._env_originals[key] = os.environ.get(key, None)
		self._env_overrides[key] = val
		os.environ[key] = val

	def reset_env_var(self, key: str):
		"""Restore an environment variable overridden by this manager

		Args:
			key:
				str, name of the environment variable

		Raises:
			ValueError:
				if ``key`` is not currently overridden by this manager
		"""
		if key not in self._env_originals:
			raise ValueError('Unable to reset var: {}'.format(key))
		original = self._env_originals.pop(key)
		if original is None:
			# The variable did not exist before the override. Use a default so
			# that a variable already removed by the code under test does not
			# raise KeyError and abort the rest of the teardown.
			os.environ.pop(key, None)
		else:
			os.environ[key] = original
		self._env_overrides.pop(key, None)

	def set_pipeline(self):
		"""Create the pipeline if one was requested, otherwise store None

		The pipeline is named after the file-name component of ``sys.argv[0]``.
		"""
		if self._with_pipeline:
			self.pipeline = Gst.Pipeline(name=os.path.basename(sys.argv[0]))
		else:
			self.pipeline = None