Source code for fastapi_icontract._decorators

"""Provide decorators for the endpoints."""

import inspect
import traceback
from typing import Optional

import fastapi
import icontract
import icontract._checkers
import icontract._represent
import icontract._types

import fastapi_icontract.openapi
from fastapi_icontract._globals import CallableT


def _func_body_as_text(func: CallableT) -> str:
    """
    Render a condition (or capture) function as text for the specs.

    A function that is not a lambda is represented by its ``__name__``.
    For a lambda, the source file is inspected and the text reported by
    icontract's lambda inspection of the enclosing decorator is returned.

    :param func: condition or capture function to be represented
    :return: textual representation of ``func``
    """
    if icontract._represent.is_lambda(a_function=func):
        # Lambdas have no meaningful name, so we recover their text from
        # the source of the decorator in which they were written.
        source_lines, lineno = inspect.findsource(func)
        source_path = inspect.getsourcefile(func)
        assert (
            source_path is not None
        ), f"Unexpected None filename for condition: {func}"

        inspected_decorator = icontract._represent.inspect_decorator(
            lines=source_lines, lineno=lineno, filename=source_path
        )
        inspected_lambda = icontract._represent.find_lambda_condition(
            decorator_inspection=inspected_decorator
        )
        assert (
            inspected_lambda is not None
        ), f"Expected lambda_inspection to be non-None if is_lambda is True on: {func}"

        return inspected_lambda.text

    return func.__name__

class require:  # pylint: disable=invalid-name
    """
    Decorate a FastAPI endpoint with a pre-condition.

    The decorator optionally enforces the condition through an icontract
    checker and optionally records it for the OpenAPI schema.
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        condition: CallableT,
        status_code: int = 422,
        description: Optional[str] = None,
        enforced: bool = True,
        undocument: bool = False,
    ) -> None:
        """
        Store the settings and, if enforced, prepare the underlying contract.

        :param condition:
            pre-condition predicate.

            Its arguments are expected to be a subset of the endpoint
            arguments. It may be a sync function, a lambda or an async
            function. If it returns a coroutine, the coroutine is awaited
            first and the awaited value is checked for truthiness.
        :param status_code:
            status code of the :class:`fastapi.HTTPException` which is
            raised if the pre-condition is violated.
        :param description:
            textual description of the pre-condition.

            If given, it is included in the detail of the exception raised
            on violation.
        :param enforced:
            If set, the pre-condition is enforced.

            Otherwise, it is only added to the OpenAPI schema, but not
            verified. Typically, slow pre-conditions are enforced during
            testing and disabled in production; an unenforced pre-condition
            still serves the client as formal documentation.
        :param undocument:
            If set, the pre-condition is left out of the OpenAPI schema.
        """
        # pylint: disable=too-many-arguments
        self.condition = condition
        self.status_code = status_code
        self.description = description
        self.enforced = enforced
        self.undocument = undocument
        self._contract: Optional[icontract._types.Contract] = None

        if not enforced:
            # Nothing to check at run time, hence no contract is created.
            return

        # NOTE: the stack must be extracted directly in ``__init__``; with
        # ``limit=2`` the first entry is the frame which applied the decorator.
        tb_stack = traceback.extract_stack(limit=2)[:1]

        location: Optional[str] = None
        if tb_stack:
            frame = tb_stack[0]
            location = f"File {frame.filename}, line {frame.lineno} in {frame.name}"

        detail: Optional[str] = None
        if description:
            detail = f"Pre-condition violated: {description}"

        violation_error = fastapi.HTTPException(
            status_code=status_code, detail=detail
        )

        self._contract = icontract._types.Contract(
            condition=condition,
            description=description,
            error=violation_error,
            location=location,
        )

    def __call__(self, func: CallableT) -> CallableT:
        """
        Add the pre-condition to the checker of a FastAPI endpoint.

        If the pre-condition is enforced and the endpoint has no checker yet,
        the endpoint is wrapped with a checker first.

        :param func: endpoint function to be wrapped
        :return:
            the checker wrapping ``func`` if the pre-condition is enforced,
            otherwise ``func`` itself
        """
        result = func

        if self.enforced:
            assert self._contract is not None, "Expected a contract if enforced."

            contract_checker = icontract._checkers.find_checker(func=func)
            if contract_checker is None:
                contract_checker = icontract._checkers.decorate_with_checker(
                    func=func
                )

            result = contract_checker

            icontract._checkers.add_precondition_to_checker(
                checker=contract_checker, contract=self._contract
            )

        if self.undocument:
            return result

        # Record the pre-condition so that it can appear in the OpenAPI schema.
        openapi_contracts = fastapi_icontract.openapi.get_or_attach(func=result)
        condition_text = _func_body_as_text(func=self.condition)

        documented = fastapi_icontract.openapi.Contract(
            enforced=self.enforced,
            text=condition_text,
            status_code=self.status_code,
            description=self.description,
        )
        openapi_contracts.preconditions.append(documented)

        return result
class snapshot:  # pylint: disable=invalid-name
    """
    Add a snapshot to the checker of a FastAPI endpoint.

    The endpoint is decorated with a snapshot of argument values captured
    *prior* to the invocation.

    A snapshot is defined by a capture function (usually a lambda) which
    accepts one or more arguments of the function. The captured values are
    supplied to post-conditions through the ``OLD`` argument of the condition.
    """

    def __init__(
        self,
        capture: CallableT,
        name: str,
        enabled: bool = True,
        undocument: bool = False,
    ) -> None:
        """
        Store the settings and, if enabled, prepare the underlying snapshot.

        :param capture:
            function capturing the snapshot; it accepts one or more arguments
            of the original function *prior* to the execution.

            It may be a lambda, a sync function or an async function. If it
            returns a coroutine, the coroutine is awaited first and the
            awaited value is stored in the ``OLD`` structure.
        :param name:
            name under which the snapshot is stored in the ``OLD`` structure.
        :param enabled:
            The snapshot is applied only if ``enabled`` is set.

            Otherwise, the snapshot is disabled and there is no run-time
            overhead. Usually, snapshots are enabled and disabled together
            with their related post-conditions.
        :param undocument:
            If set, the snapshot is left out of the OpenAPI schema.
        """
        self.capture = capture
        self.name = name
        self.enabled = enabled
        self.undocument = undocument
        self._snapshot: Optional[icontract._types.Snapshot] = None

        if not enabled:
            # Resolve the snapshot only if enabled so that no overhead
            # is incurred.
            return

        # NOTE: the stack must be extracted directly in ``__init__``; with
        # ``limit=2`` the first entry is the frame which applied the decorator.
        tb_stack = traceback.extract_stack(limit=2)[:1]

        location: Optional[str] = None
        if tb_stack:
            frame = tb_stack[0]
            location = f"File {frame.filename}, line {frame.lineno} in {frame.name}"

        self._snapshot = icontract._types.Snapshot(
            capture=capture, name=name, location=location
        )

    def __call__(self, func: CallableT) -> CallableT:
        """
        Add the snapshot to the checker of a FastAPI endpoint ``func``.

        If the snapshot is enabled, ``func`` is expected to have been
        decorated with at least one post-condition beforehand, so that
        a checker already exists.

        :param func: function whose arguments we need to snapshot
        :return:
            the checker found on ``func`` if the snapshot is enabled,
            otherwise ``func`` itself
        :raise ValueError:
            if the snapshot is enabled, but no checker is found on ``func``
        """
        result = func

        if self.enabled:
            # A snapshot can only be attached to an already existing checker.
            contract_checker = icontract._checkers.find_checker(func=func)
            if contract_checker is None:
                raise ValueError(
                    "You are decorating a function with a snapshot, "
                    "but no postcondition was defined on the function before. "
                    "(If you defined one or more post-conditions, can it be that "
                    "they are not enforced at the same time as this snapshot?)"
                )

            result = contract_checker

            assert (
                self._snapshot is not None
            ), "Expected the enabled snapshot to have the property ``snapshot`` set."

            icontract._checkers.add_snapshot_to_checker(
                checker=contract_checker, snapshot=self._snapshot
            )

        if self.undocument:
            return result

        # Record the snapshot so that it can appear in the OpenAPI schema.
        openapi_contracts = fastapi_icontract.openapi.get_or_attach(func=result)
        capture_text = _func_body_as_text(func=self.capture)

        documented = fastapi_icontract.openapi.Snapshot(
            name=self.name, enabled=self.enabled, text=capture_text
        )
        openapi_contracts.snapshots.append(documented)

        return result
class ensure:  # pylint: disable=invalid-name
    """
    Decorate a FastAPI endpoint with a post-condition.

    The decorator optionally enforces the condition through an icontract
    checker and optionally records it for the OpenAPI schema.
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        condition: CallableT,
        status_code: int = 500,
        description: Optional[str] = None,
        enforced: bool = True,
        undocument: bool = False,
    ) -> None:
        """
        Store the settings and, if enforced, prepare the underlying contract.

        :param condition:
            post-condition predicate.

            Its arguments are expected to be a subset of the endpoint
            arguments. It may be a sync function, a lambda or an async
            function. If it returns a coroutine, the coroutine is awaited
            first and the awaited value is checked for truthiness.
        :param status_code:
            status code of the :class:`fastapi.HTTPException` which is
            raised if the post-condition is violated.
        :param description:
            textual description of the post-condition.

            If given, it is included in the detail of the exception raised
            on violation.
        :param enforced:
            If set, the post-condition is enforced.

            Otherwise, it is only added to the OpenAPI schema, but not
            verified. Typically, post-conditions are enforced during testing
            and all disabled in production; an unenforced post-condition
            still serves the client as formal documentation.
        :param undocument:
            If set, the post-condition is left out of the OpenAPI schema.
        """
        # pylint: disable=too-many-arguments
        self.condition = condition
        self.status_code = status_code
        self.description = description
        self.enforced = enforced
        self.undocument = undocument
        self._contract: Optional[icontract._types.Contract] = None

        if not enforced:
            # Nothing to check at run time, hence no contract is created.
            return

        # NOTE: the stack must be extracted directly in ``__init__``; with
        # ``limit=2`` the first entry is the frame which applied the decorator.
        tb_stack = traceback.extract_stack(limit=2)[:1]

        location: Optional[str] = None
        if tb_stack:
            frame = tb_stack[0]
            location = f"File {frame.filename}, line {frame.lineno} in {frame.name}"

        detail: Optional[str] = None
        if description:
            detail = f"Post-condition violated: {description}"

        violation_error = fastapi.HTTPException(
            status_code=status_code, detail=detail
        )

        self._contract = icontract._types.Contract(
            condition=condition,
            description=description,
            error=violation_error,
            location=location,
        )

    def __call__(self, func: CallableT) -> CallableT:
        """
        Add the post-condition to the checker of a FastAPI endpoint.

        If the post-condition is enforced and the endpoint has not been
        wrapped with a checker yet, it is wrapped with a checker first.

        :param func: endpoint function to be wrapped
        :return:
            the checker wrapping ``func`` if the post-condition is enforced,
            otherwise ``func`` itself
        """
        result = func

        if self.enforced:
            assert self._contract is not None, "Expected a contract if enforced."

            contract_checker = icontract._checkers.find_checker(func=func)
            if contract_checker is None:
                contract_checker = icontract._checkers.decorate_with_checker(
                    func=func
                )

            result = contract_checker

            icontract._checkers.add_postcondition_to_checker(
                checker=contract_checker, contract=self._contract
            )

        if self.undocument:
            return result

        # Record the post-condition so that it can appear in the OpenAPI schema.
        openapi_contracts = fastapi_icontract.openapi.get_or_attach(func=result)
        condition_text = _func_body_as_text(func=self.condition)

        documented = fastapi_icontract.openapi.Contract(
            enforced=self.enforced,
            text=condition_text,
            status_code=self.status_code,
            description=self.description,
        )
        openapi_contracts.postconditions.append(documented)

        return result