# Source code for _gettsim.shared

import inspect
import re
import textwrap
from datetime import date
from typing import Callable, Optional

from _gettsim.config import SUPPORTED_GROUPINGS


class KeyErrorMessage(str):
    """String type whose ``repr`` is its plain text.

    ``KeyError`` displays its argument via ``repr``; returning the raw string
    here lets line breaks in the message show up as real line breaks.
    """

    def __repr__(self) -> str:
        # Hand back the unquoted, unescaped text instead of the default repr.
        return self.__str__()


def add_rounding_spec(params_key):
    """Decorator recording where a function's rounding specification lives.

    Parameters
    ----------
    params_key : str
        Key of the parameters dictionary where rounding specifications are
        found. For functions that are not user-written this is just the name of
        the respective .yaml file.

    Returns
    -------
    func : function
        Function with __info__["rounding_params_key"] attribute

    """

    def inner(func):
        # Reuse an existing ``__info__`` mapping if another decorator already
        # attached one; otherwise start a fresh one.
        try:
            info = func.__info__
        except AttributeError:
            info = func.__info__ = {}
        info["rounding_params_key"] = params_key
        return func

    return inner
# Module-level registry of time-dependent functions: maps a DAG key to the list
# of functions registered under that key. `dates_active` appends to it and
# `_check_for_conflicts_in_time_dependent_functions` reads it.
TIME_DEPENDENT_FUNCTIONS: dict[str, list[Callable]] = {}
def dates_active(
    start: str = "0001-01-01",
    end: str = "9999-12-31",
    change_name: Optional[str] = None,
) -> Callable:
    """Specify that a function is only active between `start` and `end`.

    By using the `change_name` argument, you can specify a different name for
    the function in the DAG.

    Note that even if you use this decorator with the `change_name` argument,
    you must ensure that the function name is unique in the file where it is
    defined. Otherwise, the function will be overwritten by the last function
    with the same name.

    Parameters
    ----------
    start
        The start date (inclusive) in the format YYYY-MM-DD (part of ISO 8601).
    end
        The end date (inclusive) in the format YYYY-MM-DD (part of ISO 8601).
    change_name
        The name that should be used as the key for the function in the DAG.
        If omitted, we use the name of the function as defined.

    Returns
    -------
    The function with attributes __info__["dates_active_start"],
    __info__["dates_active_end"], and __info__["dates_active_dag_key"].

    Raises
    ------
    ValueError
        If a date string is not in YYYY-MM-DD format or `start` is after `end`.

    """
    # Check the textual format of both bounds before parsing either of them.
    for bound in (start, end):
        _validate_dashed_iso_date(bound)

    start_date = date.fromisoformat(start)
    end_date = date.fromisoformat(end)
    _validate_date_range(start_date, end_date)

    def inner(func: Callable) -> Callable:
        dag_key = change_name or func.__name__

        # Refuse registration if another function already covers an
        # overlapping period under the same DAG key.
        _check_for_conflicts_in_time_dependent_functions(
            dag_key, func.__name__, start_date, end_date
        )

        # Remember data from decorator, reusing ``__info__`` if it exists.
        try:
            info = func.__info__
        except AttributeError:
            info = func.__info__ = {}
        info["dates_active_start"] = start_date
        info["dates_active_end"] = end_date
        info["dates_active_dag_key"] = dag_key

        # Register time-dependent function under its DAG key.
        TIME_DEPENDENT_FUNCTIONS.setdefault(dag_key, []).append(func)

        return func

    return inner
# Matches dates written as YYYY-MM-DD (digits and dashes only).
_dashed_iso_date = re.compile(r"\d{4}-\d{2}-\d{2}")


def _validate_dashed_iso_date(date_str: str):
    """Raise ``ValueError`` unless ``date_str`` has the shape YYYY-MM-DD.

    The whole string must match, so trailing characters are rejected as well.
    """
    if not _dashed_iso_date.fullmatch(date_str):
        raise ValueError(  # noqa: TRY003
            f"Date {date_str} does not match the format YYYY-MM-DD."
        )


def _validate_date_range(start: date, end: date):
    """Raise ``ValueError`` if ``start`` lies after ``end``.

    Equal start and end dates are accepted.
    """
    if start > end:
        raise ValueError(  # noqa: TRY003
            f"The start date {start} must be before the end date {end}."
        )


def _check_for_conflicts_in_time_dependent_functions(
    dag_key: str, function_name: str, start: date, end: date
):
    """
    Raises an error if a different time-dependent function has already been
    registered for the given dag_key and their date ranges overlap.

    Raises
    ------
    ConflictingTimeDependentFunctionsError
        If a differently named function registered under ``dag_key`` is active
        during any part of the period from ``start`` to ``end``.
    """
    if dag_key not in TIME_DEPENDENT_FUNCTIONS:
        return

    for f in TIME_DEPENDENT_FUNCTIONS[dag_key]:
        # A function is not conflicting with itself. We compare names instead of
        # identities since functions might get wrapped, which would change their
        # identity but not their name.
        if f.__name__ != function_name and (
            # The other function starts inside the new range ...
            start <= f.__info__["dates_active_start"] <= end
            # ... or the new range starts inside the other function's range.
            or f.__info__["dates_active_start"]
            <= start
            <= f.__info__["dates_active_end"]
        ):
            raise ConflictingTimeDependentFunctionsError(
                dag_key,
                function_name,
                start,
                end,
                f.__name__,
                f.__info__["dates_active_start"],
                f.__info__["dates_active_end"],
            )


class ConflictingTimeDependentFunctionsError(Exception):
    """Raised when two time-dependent functions have overlapping time ranges."""

    def __init__(  # noqa: PLR0913
        self,
        dag_key: str,
        function_name_1: str,
        start_1: date,
        end_1: date,
        function_name_2: str,
        start_2: date,
        end_2: date,
    ):
        # The overlap runs from the later start to the earlier end.
        super().__init__(
            f"Conflicting functions for key {dag_key!r}: "
            f"{function_name_1!r} ({start_1} to {end_1}) vs. "
            f"{function_name_2!r} ({start_2} to {end_2}).\n\n"
            f"Overlapping from {max(start_1, start_2)} to {min(end_1, end_2)}."
        )


def format_list_linewise(list_):
    """Format an iterable of strings as a list-like block, one item per line.

    Each item is wrapped in double quotes and followed by a comma; the result
    is enclosed in square brackets and surrounded by newlines.
    """
    formatted_list = '",\n    "'.join(list_)
    return textwrap.dedent(
        """
        [
            "{formatted_list}",
        ]
        """
    ).format(formatted_list=formatted_list)


def parse_to_list_of_strings(user_input, name):
    """Parse None, str, and list of strings to list of strings.

    Note that the function automatically removes duplicates and returns the
    entries in sorted order.

    Parameters
    ----------
    user_input : None, str or list of str
        The value to normalize.
    name : str
        Name of the input, used in the error message.

    Returns
    -------
    list of str
        Sorted list of unique strings.

    Raises
    ------
    NotImplementedError
        If ``user_input`` is neither None, a string, nor a list of strings.

    """
    if user_input is None:
        user_input = []
    elif isinstance(user_input, str):
        user_input = [user_input]
    elif isinstance(user_input, list) and all(isinstance(i, str) for i in user_input):
        pass
    else:
        # The exception must actually be raised; merely constructing it would
        # let invalid input fall through to the return statement.
        raise NotImplementedError(
            f"{name!r} needs to be None, a string or a list of strings."
        )

    return sorted(set(user_input))


def format_errors_and_warnings(text, width=79):
    """Format our own exception messages and warnings by dedenting paragraphs and
    wrapping at the specified width. Mainly required because messages are written
    as part of indented blocks in our source code.

    Parameters
    ----------
    text : str
        The text which can include multiple paragraphs separated by two newlines.
    width : int
        The text will be wrapped by `width` characters.

    Returns
    -------
    formatted_text : str
        Correctly dedented, wrapped text.

    """
    # Drop leading newlines so the first paragraph dedents like the others.
    text = text.lstrip("\n")
    paragraphs = text.split("\n\n")
    wrapped_paragraphs = [
        textwrap.fill(textwrap.dedent(paragraph), width=width)
        for paragraph in paragraphs
    ]
    return "\n\n".join(wrapped_paragraphs)


def get_names_of_arguments_without_defaults(function):
    """Get argument names without defaults.

    The detection of argument names also works for partialed functions.

    Examples
    --------
    >>> def func(a, b): pass
    >>> get_names_of_arguments_without_defaults(func)
    ['a', 'b']
    >>> import functools
    >>> func_ = functools.partial(func, a=1)
    >>> get_names_of_arguments_without_defaults(func_)
    ['b']

    """
    parameters = inspect.signature(function).parameters
    # Identity check against the sentinel: ``==`` would call the default
    # value's own ``__eq__``, which may return something other than a plain
    # bool or claim equality with everything.
    return [
        name
        for name, parameter in parameters.items()
        if parameter.default is inspect.Parameter.empty
    ]


def remove_group_suffix(col):
    """Strip supported grouping suffixes from the end of a column name.

    For each entry ``g`` of ``SUPPORTED_GROUPINGS``, in iteration order, a
    trailing ``_{g}`` is removed from the current result if present.
    """
    out = col
    for g in SUPPORTED_GROUPINGS:
        out = out.removesuffix(f"_{g}")

    return out