# Source code for fastapi_icontract.openapi

"""Define data structures to be added to the OpenAPI specs."""
import functools
from typing import List, Any, Optional, Dict

import fastapi
import fastapi.openapi.utils

from fastapi_icontract._globals import CallableT


class Contract:
    """Describe a contract of an operation.

    Plain value holder: every constructor argument is stored unchanged as an
    attribute of the same name (``enforced``, ``text``, ``status_code`` and the
    optional ``description``).
    """

    def __init__(
        self, enforced: bool, text: str, status_code: int, description: Optional[str]
    ) -> None:
        """Store the given values as attributes."""
        (
            self.enforced,
            self.text,
            self.status_code,
            self.description,
        ) = (enforced, text, status_code, description)


class Snapshot:
    """Describe a snapshot involved in an operation.

    Plain value holder for the snapshot's ``name``, its ``enabled`` flag and
    its ``text``; the constructor arguments are stored unchanged.
    """

    def __init__(self, name: str, enabled: bool, text: str) -> None:
        """Store the given values as attributes."""
        self.name, self.enabled, self.text = name, enabled, text


class Contracts:
    """Describe all the contracts of an operation.

    Groups the preconditions, the snapshots and the postconditions. Every
    instance starts with its own, independent empty lists.
    """

    def __init__(self) -> None:
        """Start with no preconditions, no snapshots and no postconditions."""
        self.preconditions = list()  # type: List[Contract]
        self.snapshots = list()  # type: List[Snapshot]
        self.postconditions = list()  # type: List[Contract]


def get_or_attach(func: CallableT) -> Contracts:
    """Return the contracts recorded on ``func``, attaching empty ones first if needed.

    The contracts are stored in the ``__fastapi_icontract_openapi__`` attribute
    of ``func``, so repeated calls on the same object return the same instance.
    """
    try:
        return getattr(func, "__fastapi_icontract_openapi__")
    except AttributeError:
        pass

    # Attach outside of the ``except`` clause so that a failing ``setattr``
    # is reported on its own, not chained to the lookup failure.
    fresh = Contracts()
    setattr(func, "__fastapi_icontract_openapi__", fresh)
    return fresh


def _contract_to_jsonable(contract: Contract) -> Dict[str, Any]:
    """Convert the contract to a JSON-able structure.

    The ``description`` key is added only when the description is truthy,
    i.e. neither ``None`` nor an empty string.
    """
    jsonable = dict(
        enforced=contract.enforced,
        text=contract.text,
        language="python3",
        statusCode=contract.status_code,
    )  # type: Dict[str, Any]

    description = contract.description
    if description:
        jsonable["description"] = description

    return jsonable


def _snapshot_to_jsonable(snapshot: Snapshot) -> Dict[str, Any]:
    """Convert the snapshot to a JSON-able structure.

    The result carries the snapshot's name, enabled flag and text, tagged with
    the language ``python3``.
    """
    return {
        "name": snapshot.name,
        "enabled": snapshot.enabled,
        "text": snapshot.text,
        "language": "python3",
    }


def contracts_to_jsonable(contracts: Contracts) -> Dict[str, Any]:
    """Convert the contracts to a JSON-able structure.

    The result maps ``preconditions``, ``snapshots`` and ``postconditions`` to
    lists of converted entries, each list in its original order.
    """
    preconditions = list(map(_contract_to_jsonable, contracts.preconditions))
    snapshots = list(map(_snapshot_to_jsonable, contracts.snapshots))
    postconditions = list(map(_contract_to_jsonable, contracts.postconditions))

    return {
        "preconditions": preconditions,
        "snapshots": snapshots,
        "postconditions": postconditions,
    }


def _collect_operation_contracts(app: fastapi.FastAPI) -> Dict[str, Contracts]:
    """Map the operation IDs of the ``app``'s schema-visible routes to their contracts.

    Only :class:`fastapi.routing.APIRoute` routes included in the schema whose
    endpoint carries a ``__fastapi_icontract_openapi__`` attribute are considered.

    :raise AssertionError: if a route yields no operation ID, or if two
        *different* sets of contracts claim the same operation ID
    """
    operation_contracts = dict()  # type: Dict[str, Contracts]

    for route in app.routes:
        if (
            not isinstance(route, fastapi.routing.APIRoute)
            or not route.include_in_schema
        ):
            continue

        contracts = getattr(route.endpoint, "__fastapi_icontract_openapi__", None)
        if contracts is None:
            continue

        assert isinstance(contracts, Contracts)

        for method in route.methods:
            operation_id = fastapi.openapi.utils.generate_operation_id(
                route=route, method=method
            )
            assert (
                operation_id is not None
            ), f"Unexpected None operation ID for endpoint {route.endpoint}"

            # A route with an explicit ``operation_id`` reports that very ID for
            # every one of its methods, so re-registering the *same* contracts
            # is not a conflict; only different contracts sharing an ID are.
            registered = operation_contracts.get(operation_id, None)
            assert (
                registered is None or registered is contracts
            ), f"Unexpected duplicate contracts for operation ID: {operation_id}"

            operation_contracts[operation_id] = contracts

    return operation_contracts


def _attach_contracts_to_schema(
    openapi_schema: Dict[str, Any], operation_contracts: Dict[str, Contracts]
) -> None:
    """Add ``x-contracts`` in place to every schema operation with known contracts."""
    for path_item in openapi_schema.get("paths", {}).values():
        for operation in path_item.values():
            # Besides the operations, an OpenAPI path item may hold other fields
            # (e.g. a "summary" string or a "parameters" list); skip those.
            if not isinstance(operation, dict):
                continue

            operation_id = operation.get("operationId", None)
            if operation_id is not None and operation_id in operation_contracts:
                operation["x-contracts"] = contracts_to_jsonable(
                    operation_contracts[operation_id]
                )


def wrap_openapi_with_contracts(app: fastapi.FastAPI) -> None:
    """
    Wrap the ``openapi`` method of the ``app`` to include the contracts in the schema.

    The currently cached ``app.openapi_schema`` is discarded. The installed
    replacement returns ``app.openapi_schema`` whenever it is set; otherwise it
    builds the schema with the previous ``openapi`` method, adds an
    ``x-contracts`` entry to every operation whose operation ID belongs to an
    endpoint with contracts, caches the result in ``app.openapi_schema`` and
    returns it.

    :param app: application whose ``openapi`` method is replaced in place
    """
    old_openapi_func = app.openapi

    # Delete the cached openapi_schema so that we can re-create it
    app.openapi_schema = None

    def wrapper() -> Optional[Dict[str, Any]]:
        # Retrieve from cache, if possible
        if app.openapi_schema is not None:
            return app.openapi_schema

        # Collect (and validate) the contracts *before* generating the schema:
        # the previous ``openapi`` method may already store its result in
        # ``app.openapi_schema``, and a check failing afterwards would leave a
        # schema without contracts cached for all subsequent calls.
        operation_contracts = _collect_operation_contracts(app)

        openapi_schema = old_openapi_func()
        if openapi_schema is None:
            # Nothing to annotate (the return type admits ``None``).
            return None

        _attach_contracts_to_schema(openapi_schema, operation_contracts)

        # Cache
        app.openapi_schema = openapi_schema

        return openapi_schema

    functools.update_wrapper(wrapper=wrapper, wrapped=old_openapi_func)
    app.openapi = wrapper  # type: ignore