PK ! flama/__init__.pyPK ! S| flama/applications.pyimport typing from starlette.applications import Starlette from starlette.exceptions import ExceptionMiddleware from starlette.middleware.errors import ServerErrorMiddleware from starlette.types import ASGIApp from flama import exceptions from flama.components import Component from flama.exceptions import HTTPException from flama.http import Request, Response from flama.injection import Injector from flama.responses import APIErrorResponse from flama.routing import Router from flama.schemas import SchemaMixin if typing.TYPE_CHECKING: from flama.resources import BaseResource __all__ = ["Flama"] class Flama(Starlette, SchemaMixin): def __init__( self, components: typing.Optional[typing.List[Component]] = None, debug: bool = False, title: typing.Optional[str] = "", version: typing.Optional[str] = "", description: typing.Optional[str] = "", schema: typing.Optional[str] = "/schema/", docs: typing.Optional[str] = "/docs/", redoc: typing.Optional[str] = None, *args, **kwargs ) -> None: super().__init__(debug=debug, *args, **kwargs) if components is None: components = [] # Initialize injector self.components = components self.router = Router(components=components) self.app = self.router self.exception_middleware = ExceptionMiddleware(self.router, debug=debug) self.error_middleware = ServerErrorMiddleware(self.exception_middleware, debug=debug) # Add exception handler for API exceptions self.add_exception_handler(exceptions.HTTPException, self.api_http_exception_handler) # Add schema and docs routes self.add_schema_docs_routes( title=title, version=version, description=description, schema=schema, docs=docs, redoc=redoc ) @property def injector(self): return Injector(components=self.components) def mount(self, path: str, app: ASGIApp, name: str = None) -> None: self.components += getattr(app, "components", []) self.router.mount(path, app=app, name=name) def add_resource(self, path: str, resource: "BaseResource"): self.router.add_resource(path, resource=resource) def resource(self, path: str) -> typing.Callable: def decorator(resource: "BaseResource") -> "BaseResource": self.router.add_resource(path, resource=resource) return resource return decorator def api_http_exception_handler(self, request: Request, exc: HTTPException) -> Response: return APIErrorResponse(detail=exc.detail, status_code=exc.status_code, exception=exc) PK ! y~ ~ flama/codecs/__init__.pyfrom flama.codecs.base import * # noqa from flama.codecs.http import * # noqa from flama.codecs.websockets import * # noqa PK ! ~ flama/codecs/base.pyimport typing from flama import http, websockets class Codec: async def decode(self, item: typing.Any, **options): raise NotImplementedError() async def encode(self, item: typing.Any, **options): raise NotImplementedError() class HTTPCodec(Codec): media_type = None async def decode(self, request: http.Request, **options): raise NotImplementedError() async def encode(self, item: typing.Any, **options): raise NotImplementedError() class WebsocketsCodec(Codec): encoding = None async def decode(self, item: websockets.Message, **options): raise NotImplementedError() async def encode(self, item: typing.Any, **options): raise NotImplementedError() PK ! xX flama/codecs/http/__init__.pyfrom flama.codecs.http.jsondata import JSONDataCodec # noqa from flama.codecs.http.multipart import MultiPartCodec # noqa from flama.codecs.http.urlencoded import URLEncodedCodec # noqa PK ! |K flama/codecs/http/jsondata.pyfrom flama import exceptions, http from flama.codecs.base import HTTPCodec __all__ = ["JSONDataCodec"] class JSONDataCodec(HTTPCodec): media_type = "application/json" format = "json" async def decode(self, request: http.Request, **options): try: if await request.body() == b"": return None return await request.json() except ValueError as exc: raise exceptions.DecodeError(f"Malformed JSON. {exc}") from None PK ! B flama/codecs/http/multipart.pyfrom flama import http from flama.codecs.base import HTTPCodec __all__ = ["MultiPartCodec"] class MultiPartCodec(HTTPCodec): media_type = "multipart/form-data" async def decode(self, request: http.Request, **options): return await request.form() PK ! 5" " flama/codecs/http/urlencoded.pyfrom flama import http from flama.codecs.base import HTTPCodec __all__ = ["URLEncodedCodec"] class URLEncodedCodec(HTTPCodec): media_type = "application/x-www-form-urlencoded" async def decode(self, request: http.Request, **options): return await request.form() or None PK ! pvy # flama/codecs/websockets/__init__.pyfrom flama.codecs.websockets.bytes import BytesCodec from flama.codecs.websockets.json import JSONCodec from flama.codecs.websockets.text import TextCodec __all__ = ["BytesCodec", "TextCodec", "JSONCodec"] PK ! - flama/codecs/websockets/bytes.pyfrom flama import exceptions, websockets from flama.codecs.base import WebsocketsCodec __all__ = ["BytesCodec"] class BytesCodec(WebsocketsCodec): encoding = "bytes" async def decode(self, message: websockets.Message, **options): if "bytes" not in message: raise exceptions.DecodeError("Expected bytes websocket messages") return message["bytes"] PK ! :%5 5 flama/codecs/websockets/json.pyimport json from flama import exceptions, websockets from flama.codecs.base import WebsocketsCodec __all__ = ["JSONCodec"] class JSONCodec(WebsocketsCodec): encoding = "json" async def decode(self, message: websockets.Message, **options): if message.get("text") is not None: text = message["text"] else: text = message["bytes"].decode("utf-8") try: return json.loads(text) except json.decoder.JSONDecodeError: raise exceptions.DecodeError("Malformed JSON data received") PK ! H~ ~ flama/codecs/websockets/text.pyfrom flama import exceptions, websockets from flama.codecs.base import WebsocketsCodec __all__ = ["TextCodec"] class TextCodec(WebsocketsCodec): encoding = "text" async def decode(self, message: websockets.Message, **options): if "text" not in message: raise exceptions.DecodeError("Expected text websocket messages") return message["text"] PK ! 4 4 flama/components/__init__.pyfrom flama.components.base import Component # noqa PK ! %]c[ flama/components/asgi.pyimport typing from inspect import Parameter from urllib.parse import parse_qsl from flama import http from flama.components import Component ASGIScope = typing.NewType("ASGIScope", dict) ASGIReceive = typing.NewType("ASGIReceive", typing.Callable) ASGISend = typing.NewType("ASGISend", typing.Callable) class MethodComponent(Component): def resolve(self, scope: ASGIScope) -> http.Method: return http.Method(scope["method"]) class URLComponent(Component): def resolve(self, scope: ASGIScope) -> http.URL: return http.URL(scope=scope) class SchemeComponent(Component): def resolve(self, scope: ASGIScope) -> http.Scheme: return http.Scheme(scope["scheme"]) class HostComponent(Component): def resolve(self, scope: ASGIScope) -> http.Host: return http.Host(scope["server"][0]) class PortComponent(Component): def resolve(self, scope: ASGIScope) -> http.Port: return http.Port(scope["server"][1]) class PathComponent(Component): def resolve(self, scope: ASGIScope) -> http.Path: return http.Path(scope.get("root_path", "") + scope["path"]) class QueryStringComponent(Component): def resolve(self, scope: ASGIScope) -> http.QueryString: return http.QueryString(scope["query_string"].decode()) class QueryParamsComponent(Component): def resolve(self, scope: ASGIScope) -> http.QueryParams: query_string = scope["query_string"].decode() return http.QueryParams(parse_qsl(query_string)) class QueryParamComponent(Component): def resolve(self, parameter: Parameter, query_params: http.QueryParams) -> http.QueryParam: name = parameter.name if name not in query_params: return None return http.QueryParam(query_params[name]) class HeadersComponent(Component): def resolve(self, scope: ASGIScope) -> http.Headers: return http.Headers(scope=scope) class HeaderComponent(Component): def resolve(self, parameter: Parameter, headers: http.Headers) -> http.Header: name = parameter.name.replace("_", "-") if name not in headers: return None return http.Header(headers[name]) class BodyComponent(Component): async def resolve(self, receive: ASGIReceive) -> http.Body: body = b"" while True: message = await receive() if not message["type"] == "http.request": raise Exception(f"Unexpected ASGI message type '{message['type']}'.") body += message.get("body", b"") if not message.get("more_body", False): break return http.Body(body) ASGI_COMPONENTS = ( MethodComponent(), URLComponent(), SchemeComponent(), HostComponent(), PortComponent(), PathComponent(), QueryStringComponent(), QueryParamsComponent(), QueryParamComponent(), HeadersComponent(), HeaderComponent(), BodyComponent(), ) PK ! vS flama/components/base.pyimport inspect from abc import ABCMeta, abstractmethod from flama import exceptions class Component(metaclass=ABCMeta): def identity(self, parameter: inspect.Parameter) -> str: """ Each component needs a unique identifier string that we use for lookups from the `state` dictionary when we run the dependency injection. :param parameter: The parameter to check if that component can handle it. :return: Unique identifier. """ parameter_name = parameter.name.lower() try: annotation_name = parameter.annotation.__name__.lower() except AttributeError: annotation_name = parameter.annotation.__args__[0].__name__.lower() # If `resolve_parameter` includes `Parameter` then we use an identifier that is additionally parameterized by # the parameter name. args = inspect.signature(self.resolve).parameters.values() if inspect.Parameter in [arg.annotation for arg in args]: return annotation_name + ":" + parameter_name # Standard case is to use the class name, lowercased. return annotation_name def can_handle_parameter(self, parameter: inspect.Parameter) -> bool: """ The default behavior is for components to handle whatever class is used as the return annotation by the `resolve` method. You can override this for more customized styles, for example if you wanted name-based parameter resolution, or if you want to provide a value for a range of different types. Eg. Include the `Request` instance for any parameter named `request`. :param parameter: The parameter to check if that component can handle it. :return: True if this component can handle the given parameter. """ return_annotation = inspect.signature(self.resolve).return_annotation if return_annotation is inspect.Signature.empty: msg = ( f'Component "{self.__class__.__name__}" must include a return annotation on the `resolve()` method, or ' f"override `can_handle_parameter`" ) raise exceptions.ConfigurationError(msg) return parameter.annotation is return_annotation @abstractmethod def resolve(self): pass PK ! ptΝ flama/components/validation.pyimport datetime import inspect import typing import uuid import marshmallow from starlette import status from flama import codecs, exceptions, http, websockets from flama.components import Component from flama.exceptions import WebSocketException from flama.negotiation import ContentTypeNegotiator, WebSocketEncodingNegotiator from flama.routing import Route from flama.types import OptBool, OptDate, OptDateTime, OptFloat, OptInt, OptStr, OptUUID ValidatedPathParams = typing.NewType("ValidatedPathParams", dict) ValidatedQueryParams = typing.NewType("ValidatedQueryParams", dict) ValidatedRequestData = typing.TypeVar("ValidatedRequestData") class RequestDataComponent(Component): def __init__(self): self.negotiator = ContentTypeNegotiator( [codecs.JSONDataCodec(), codecs.URLEncodedCodec(), codecs.MultiPartCodec()] ) def can_handle_parameter(self, parameter: inspect.Parameter): return parameter.annotation is http.RequestData async def resolve(self, request: http.Request): content_type = request.headers.get("Content-Type") try: codec = self.negotiator.negotiate(content_type) except exceptions.NoCodecAvailable: raise exceptions.HTTPException(415) try: return await codec.decode(request) except exceptions.DecodeError as exc: raise exceptions.HTTPException(400, detail=str(exc)) class WebSocketMessageDataComponent(Component): def __init__(self): self.negotiator = WebSocketEncodingNegotiator([codecs.BytesCodec(), codecs.TextCodec(), codecs.JSONCodec()]) def can_handle_parameter(self, parameter: inspect.Parameter): return parameter.annotation is websockets.Data async def resolve(self, message: websockets.Message, websocket_encoding: websockets.Encoding): try: codec = self.negotiator.negotiate(websocket_encoding) return await codec.decode(message) except (exceptions.NoCodecAvailable, exceptions.DecodeError): raise WebSocketException(close_code=status.WS_1003_UNSUPPORTED_DATA) class ValidatePathParamsComponent(Component): async def resolve(self, request: http.Request, route: Route, path_params: http.PathParams) -> ValidatedPathParams: validator = type( "Validator", (marshmallow.Schema,), {f.name: f.schema for f in route.path_fields[request.method].values()} ) try: path_params = validator().load(path_params) except marshmallow.ValidationError as exc: raise exceptions.InputValidationError(detail=exc.normalized_messages()) return ValidatedPathParams(path_params) class ValidateQueryParamsComponent(Component): def resolve(self, request: http.Request, route: Route, query_params: http.QueryParams) -> ValidatedQueryParams: validator = type( "Validator", (marshmallow.Schema,), {f.name: f.schema for f in route.query_fields[request.method].values()} ) try: query_params = validator().load(dict(query_params), unknown=marshmallow.EXCLUDE) except marshmallow.ValidationError as exc: raise exceptions.InputValidationError(detail=exc.normalized_messages()) return ValidatedQueryParams(query_params) class ValidateRequestDataComponent(Component): def can_handle_parameter(self, parameter: inspect.Parameter): return parameter.annotation is ValidatedRequestData def resolve(self, request: http.Request, route: Route, data: http.RequestData): if not route.body_field[request.method] or not route.body_field[request.method].schema: return data validator = route.body_field[request.method].schema try: return validator.load(data) except marshmallow.ValidationError as exc: raise exceptions.InputValidationError(detail=exc.normalized_messages()) class PrimitiveParamComponent(Component): def can_handle_parameter(self, parameter: inspect.Parameter): return parameter.annotation in ( str, int, float, bool, OptStr, OptInt, OptFloat, OptBool, parameter.empty, http.QueryParam, http.PathParam, uuid.UUID, datetime.date, datetime.datetime, ) def resolve( self, parameter: inspect.Parameter, path_params: ValidatedPathParams, query_params: ValidatedQueryParams ): params = path_params if (parameter.name in path_params) else query_params if parameter.annotation in (OptInt, OptFloat, OptBool, OptStr) or parameter.default is not parameter.empty: kwargs = {"missing": parameter.default if parameter.default is not parameter.empty else None} else: kwargs = {"required": True} param_validator = { inspect.Signature.empty: marshmallow.fields.Field, int: marshmallow.fields.Integer, float: marshmallow.fields.Number, bool: marshmallow.fields.Boolean, str: marshmallow.fields.String, uuid.UUID: marshmallow.fields.UUID, datetime.date: marshmallow.fields.Date, datetime.datetime: marshmallow.fields.DateTime, OptInt: marshmallow.fields.Integer, OptFloat: marshmallow.fields.Number, OptBool: marshmallow.fields.Boolean, OptStr: marshmallow.fields.String, http.QueryParam: marshmallow.fields.String, http.PathParam: marshmallow.fields.String, OptUUID: marshmallow.fields.UUID, OptDate: marshmallow.fields.Date, OptDateTime: marshmallow.fields.DateTime, }[parameter.annotation](**kwargs) validator = type("Validator", (marshmallow.Schema,), {parameter.name: param_validator}) try: params = validator().load(params, unknown=marshmallow.EXCLUDE) except marshmallow.ValidationError as exc: raise exceptions.InputValidationError(detail=exc.normalized_messages()) return params.get(parameter.name, parameter.default) class CompositeParamComponent(Component): def can_handle_parameter(self, parameter: inspect.Parameter): return inspect.isclass(parameter.annotation) and issubclass(parameter.annotation, marshmallow.Schema) def resolve(self, parameter: inspect.Parameter, data: ValidatedRequestData): return data VALIDATION_COMPONENTS = ( RequestDataComponent(), WebSocketMessageDataComponent(), ValidatePathParamsComponent(), ValidateQueryParamsComponent(), ValidateRequestDataComponent(), PrimitiveParamComponent(), CompositeParamComponent(), ) PK ! X flama/endpoints.pyimport asyncio import typing from starlette import status from starlette.concurrency import run_in_threadpool from starlette.endpoints import HTTPEndpoint as BaseHTTPEndpoint from starlette.endpoints import WebSocketEndpoint as BaseWebSocketEndpoint from starlette.requests import Request from starlette.responses import Response from starlette.types import Receive, Send from starlette.websockets import WebSocket, WebSocketState from flama import exceptions, websockets from flama.responses import APIResponse from flama.validation import get_output_schema __all__ = ["HTTPEndpoint", "WebSocketEndpoint"] class HTTPEndpoint(BaseHTTPEndpoint): async def __call__(self, receive: Receive, send: Send): request = Request(self.scope, receive=receive) app = self.scope["app"] kwargs = self.scope.get("kwargs", {}) route, route_scope = app.router.get_route_from_scope(self.scope) state = { "scope": self.scope, "receive": receive, "send": send, "exc": None, "app": app, "path_params": route_scope["path_params"], "route": route, "request": request, } response = await self.dispatch(request, state, **kwargs) return await response(receive, send) async def dispatch(self, request: Request, state: typing.Dict, **kwargs) -> Response: handler_name = "get" if request.method == "HEAD" else request.method.lower() handler = getattr(self, handler_name, self.method_not_allowed) app = state["app"] injected_func = await app.injector.inject(handler, state) if asyncio.iscoroutinefunction(handler): response = await injected_func() else: response = await run_in_threadpool(injected_func) # Wrap response data with a proper response class if isinstance(response, (dict, list)): response = APIResponse(content=response, schema=get_output_schema(handler)) elif isinstance(response, str): response = APIResponse(content=response) elif response is None: response = APIResponse(content="") return response class WebSocketEndpoint(BaseWebSocketEndpoint): async def __call__(self, receive: Receive, send: Send) -> None: app = self.scope["app"] websocket = WebSocket(self.scope, receive, send) route, route_scope = app.router.get_route_from_scope(self.scope) state = { "scope": self.scope, "receive": receive, "send": send, "exc": None, "app": app, "path_params": route_scope["path_params"], "route": route, "websocket": websocket, "websocket_encoding": self.encoding, "websocket_code": status.WS_1000_NORMAL_CLOSURE, "websocket_message": None, } try: on_connect = await app.injector.inject(self.on_connect, state) await on_connect() except Exception as e: raise exceptions.WebSocketConnectionException("Error connecting socket") from e try: state["websocket_message"] = await websocket.receive() while websocket.client_state == WebSocketState.CONNECTED: on_receive = await app.injector.inject(self.on_receive, state) await on_receive() state["websocket_message"] = await websocket.receive() state["websocket_code"] = int(state["websocket_message"].get("code", status.WS_1000_NORMAL_CLOSURE)) except exceptions.WebSocketException as e: state["websocket_code"] = e.close_code except Exception as e: state["websocket_code"] = status.WS_1011_INTERNAL_ERROR raise e from None finally: on_disconnect = await app.injector.inject(self.on_disconnect, state) await on_disconnect() async def on_connect(self, websocket: websockets.WebSocket) -> None: """Override to handle an incoming websocket connection""" await websocket.accept() async def on_receive(self, websocket: websockets.WebSocket, data: websockets.Data) -> None: """Override to handle an incoming websocket message""" async def on_disconnect(self, websocket: websockets.WebSocket, websocket_code: websockets.Code) -> None: """Override to handle a disconnecting websocket""" await websocket.close(websocket_code) PK ! @K K flama/exceptions.pyimport typing import starlette.exceptions class DecodeError(Exception): """ Raised by a Codec when `decode` fails due to malformed syntax. """ def __init__(self, message, marker=None, base_format=None): Exception.__init__(self, message) self.message = message self.marker = marker self.base_format = base_format class NoReverseMatch(Exception): """ Raised by a Router when `reverse_url` is passed an invalid handler name. """ ... class NoCodecAvailable(Exception): ... class ConfigurationError(Exception): ... class ComponentNotFound(ConfigurationError): def __init__(self, parameter: str, resolver: typing.Optional[str] = None, *args, **kwargs): self.parameter = parameter self.resolver = resolver super().__init__(*args, **kwargs) def __str__(self): msg = f'No component able to handle parameter "{self.parameter}"' if self.resolver: msg += f' in function "{self.resolver}"' return msg class WebSocketException(Exception): def __init__(self, close_code: int): self.close_code = close_code class WebSocketConnectionException(Exception): ... class HTTPException(starlette.exceptions.HTTPException): ... class ValidationError(HTTPException): def __init__(self, detail: typing.Union[str, typing.Dict[str, typing.List[str]]], status_code: int = 400): super().__init__(status_code=status_code, detail=detail) class InputValidationError(ValidationError): ... class OutputValidationError(ValidationError): ... PK ! >n flama/http.pyimport typing from starlette.datastructures import URL, Headers, MutableHeaders, QueryParams from starlette.requests import Request from starlette.responses import ( FileResponse, HTMLResponse, JSONResponse, PlainTextResponse, RedirectResponse, Response, StreamingResponse, ) __all__ = [ "Method", "Scheme", "Host", "Port", "Path", "QueryString", "QueryParam", "Header", "Body", "PathParams", "PathParam", "RequestData", "URL", "QueryParams", "Headers", "MutableHeaders", "Request", "Response", "PlainTextResponse", "HTMLResponse", "JSONResponse", "FileResponse", "RedirectResponse", "StreamingResponse", "ReturnValue", ] Method = typing.NewType("Method", str) Scheme = typing.NewType("Scheme", str) Host = typing.NewType("Host", str) Port = typing.NewType("Port", int) Path = typing.NewType("Path", str) QueryString = typing.NewType("QueryString", str) QueryParam = typing.NewType("QueryParam", str) Header = typing.NewType("Header", str) Body = typing.NewType("Body", bytes) PathParams = typing.NewType("PathParams", dict) PathParam = typing.NewType("PathParam", str) RequestData = typing.TypeVar("RequestData") ReturnValue = typing.TypeVar("ReturnValue") PK ! 81 1 flama/injection.pyimport asyncio import functools import inspect import typing from flama import http, websockets from flama.components.asgi import ASGI_COMPONENTS, ASGIReceive, ASGIScope, ASGISend from flama.components.validation import VALIDATION_COMPONENTS from flama.exceptions import ComponentNotFound from flama.routing import Route __all__ = ["Injector"] class Injector: def __init__(self, components): from flama.applications import Flama self.components = list(ASGI_COMPONENTS + VALIDATION_COMPONENTS) + components self.initial = { "scope": ASGIScope, "receive": ASGIReceive, "send": ASGISend, "exc": Exception, "app": Flama, "path_params": http.PathParams, "route": Route, "request": http.Request, "response": http.Response, "websocket": websockets.WebSocket, "websocket_message": websockets.Message, "websocket_encoding": websockets.Encoding, "websocket_code": websockets.Code, } self.reverse_initial = {val: key for key, val in self.initial.items()} self.resolver_cache = {} def resolve_parameter( self, parameter, kwargs: typing.Dict, consts: typing.Dict, seen_state: typing.Set, parent_parameter=None ) -> typing.List[typing.Tuple]: """ Resolve a parameter by inferring the component that suits it or by adding a value to kwargs or consts. The list of steps returned consists of a resolver function, a boolean that indicates if the function is async, function kwargs and consts and the output name. :param parameter: parameter to be resolved. :param kwargs: kwargs that defines current context. :param consts: consts that defines current context. :param seen_state: cached state. :param parent_parameter: parent parameter. :return: list of steps to resolve the component. """ if parameter.annotation is http.ReturnValue: kwargs[parameter.name] = "return_value" return [] # Check if the parameter class exists in 'initial'. if parameter.annotation in self.reverse_initial: initial_kwarg = self.reverse_initial[parameter.annotation] kwargs[parameter.name] = initial_kwarg return [] # The 'Parameter' annotation can be used to get the parameter # itself. Used for example in 'Header' components that need the # parameter name in order to lookup a particular value. if parameter.annotation is inspect.Parameter: consts[parameter.name] = parent_parameter return [] for component in self.components: if component.can_handle_parameter(parameter): identity = component.identity(parameter) kwargs[parameter.name] = identity if identity not in seen_state: seen_state.add(identity) return self.resolve_component( resolver=component.resolve, output_name=identity, seen_state=seen_state, parent_parameter=parameter, ) return [] else: raise ComponentNotFound(parameter.name) def resolve_component( self, resolver, output_name: str, seen_state: typing.Set, parent_parameter=None ) -> typing.List[typing.Tuple]: """ Resolve a component injecting all dependencies needed in its resolver function. The list of steps returned consists of a resolver function, a boolean that indicates if the function is async, function kwargs and consts and the output name. :param resolver: component resolver function. :param output_name: name used for that component in application status. :param seen_state: cached status. :param parent_parameter: parent parameter. :return: list of steps to resolve the component. """ signature = inspect.signature(resolver) steps = [] kwargs = {} consts = {} if output_name is None: if signature.return_annotation in self.reverse_initial: output_name = self.reverse_initial[signature.return_annotation] else: output_name = "return_value" for parameter in signature.parameters.values(): try: steps += self.resolve_parameter( parameter, kwargs, consts, seen_state=seen_state, parent_parameter=parent_parameter ) except ComponentNotFound: raise ComponentNotFound(parameter=parameter.name, resolver=resolver.__class__.__name__) is_async = asyncio.iscoroutinefunction(resolver) step = (resolver, is_async, kwargs, consts, output_name) steps.append(step) return steps def resolve(self, func) -> typing.Tuple[typing.Dict, typing.Dict, typing.List]: """ Inspects a function and creates a resolution list of all components needed to run it. returning :param func: function to resolve. :return: the keyword arguments, consts for that function and the steps to resolve all components. """ seen_state = set(self.initial) steps = [] kwargs = {} consts = {} signature = inspect.signature(func) for parameter in signature.parameters.values(): try: steps += self.resolve_parameter(parameter, kwargs, consts, seen_state=seen_state) except ComponentNotFound: raise ComponentNotFound(parameter=parameter.name, resolver=func.__name__) return kwargs, consts, steps async def inject(self, func, state: typing.Dict) -> typing.Callable: """ Given a function, injects all components defined in its signature and returns the partialized function. :param func: function to be partialized. :param state: mapping of current application state to infer components state. :return: partialized function. """ try: func_kwargs, func_consts, steps = self.resolver_cache[func] except KeyError: func_kwargs, func_consts, steps = self.resolve(func) self.resolver_cache[func] = (func_kwargs, func_consts, steps) for resolver, is_async, kwargs, consts, output_name in steps: kw = {key: state[val] for key, val in kwargs.items()} kw.update(consts) if is_async: state[output_name] = await resolver(**kw) else: state[output_name] = resolver(**kw) kw = {key: state[val] for key, val in func_kwargs.items()} kw.update(func_consts) return functools.partial(func, **kw) PK ! 9W2y y flama/negotiation.pyimport typing from flama import exceptions from flama.codecs.base import HTTPCodec, WebsocketsCodec from flama.codecs.websockets import BytesCodec __all__ = ["ContentTypeNegotiator", "WebSocketEncodingNegotiator"] class ContentTypeNegotiator: def __init__(self, codecs: typing.Optional[typing.List[HTTPCodec]] = None): self.codecs = codecs or [] def negotiate(self, content_type: str = None) -> WebsocketsCodec: """ Given the value of a 'Content-Type' header, return the appropriate codec for decoding the request content. """ if content_type is None: return self.codecs[0] content_type = content_type.split(";")[0].strip().lower() main_type = content_type.split("/")[0] + "/*" wildcard_type = "*/*" for codec in self.codecs: if codec.media_type in (content_type, main_type, wildcard_type): return codec raise exceptions.NoCodecAvailable(f"Unsupported media in Content-Type header '{content_type}'") class WebSocketEncodingNegotiator: def __init__(self, codecs: typing.Optional[typing.List[WebsocketsCodec]] = None): self.codecs = codecs or [BytesCodec()] def negotiate(self, encoding: str = None) -> WebsocketsCodec: """ Given a websocket encoding, return the appropriate codec for decoding the request content. """ if encoding is None: return self.codecs[0] for codec in self.codecs: if codec.encoding == encoding: return codec raise exceptions.NoCodecAvailable(f"Unsupported websocket encoding '{encoding}'") PK ! '!kJ J flama/pagination/__init__.pyfrom flama.pagination.paginator import Paginator __all__ = ["Paginator"] PK ! @u^ flama/pagination/limit_offset.pyimport typing import marshmallow from flama.responses import APIResponse __all__ = ["LimitOffsetSchema", "LimitOffsetResponse"] class LimitOffsetMeta(marshmallow.Schema): limit = marshmallow.fields.Integer(title="limit", description="Number of retrieved items") offset = marshmallow.fields.Integer(title="offset", description="Collection offset") count = marshmallow.fields.Integer(title="count", description="Total number of items", allow_none=True) class LimitOffsetSchema(marshmallow.Schema): meta = marshmallow.fields.Nested(LimitOffsetMeta) data = marshmallow.fields.List(marshmallow.fields.Dict()) class LimitOffsetResponse(APIResponse): """ Response paginated based on a limit of elements and an offset. First 10 elements: /resource?offset=0&limit=10 Elements 20-30: /resource?offset=20&limit=10 """ default_limit = 10 def __init__( self, schema: marshmallow.Schema, offset: typing.Optional[typing.Union[int, str]] = None, limit: typing.Optional[typing.Union[int, str]] = None, count: typing.Optional[bool] = True, **kwargs ): self.offset = int(offset) if offset is not None else 0 self.limit = int(limit) if limit is not None else self.default_limit self.count = count super().__init__(schema=schema, **kwargs) def render(self, content: typing.Sequence): init = self.offset end = self.offset + self.limit return super().render( { "meta": {"limit": self.limit, "offset": self.offset, "count": len(content) if self.count else None}, "data": content[init:end], } ) PK ! Eg g flama/pagination/page_number.pyimport typing import marshmallow from flama.responses import APIResponse __all__ = ["PageNumberSchema", "PageNumberResponse"] class PageNumberMeta(marshmallow.Schema): page = marshmallow.fields.Integer(title="page", description="Current page number") page_size = marshmallow.fields.Integer(title="page_size", description="Page size") count = marshmallow.fields.Integer(title="count", description="Total number of items", allow_none=True) class PageNumberSchema(marshmallow.Schema): meta = marshmallow.fields.Nested(PageNumberMeta) data = marshmallow.fields.List(marshmallow.fields.Dict()) class PageNumberResponse(APIResponse): """ Response paginated based on a page number and a page size. First 10 elements: /resource?page=1 Third 10 elements: /resource?page=3 First 20 elements: /resource?page=1&page_size=20 """ default_page_size = 10 def __init__( self, schema: marshmallow.Schema, page: typing.Optional[typing.Union[int, str]] = None, page_size: typing.Optional[typing.Union[int, str]] = None, count: typing.Optional[bool] = True, **kwargs ): self.page_number = int(page) if page is not None else 1 self.page_size = int(page_size) if page_size is not None else self.default_page_size self.count = count super().__init__(schema=schema, **kwargs) def render(self, content: typing.Sequence): init = (self.page_number - 1) * self.page_size end = self.page_number * self.page_size return super().render( { "meta": { "page": self.page_number, "page_size": self.page_size, "count": len(content) if self.count else None, }, "data": content[init:end], } ) PK ! | flama/pagination/paginator.pyimport asyncio import functools import marshmallow from flama.pagination.limit_offset import LimitOffsetResponse, LimitOffsetSchema from flama.pagination.page_number import PageNumberResponse, PageNumberSchema from flama.validation import get_output_schema try: import forge except Exception: # pragma: no cover forge = None # type: ignore __all__ = ["Paginator"] class Paginator: @classmethod def page_number(cls, func): """ Decorator for adding pagination behavior to a view. That decorator produces a view based on page numbering and it adds three query parameters to control the pagination: page, page_size and count. Page has a default value of first page, page_size default value is defined in :class:`PageNumberResponse` and count defines if the response will define the total number of elements. The output schema is also modified by :class:`PageNumberSchema`, creating a new schema based on it but using the old output schema as the content of its data field. :param func: View to be decorated. :return: Decorated view. """ assert forge is not None, "`python-forge` must be installed to use OpenAPIResponse." resource_schema = get_output_schema(func) schema = type( "PageNumberPaginated" + resource_schema.__class__.__name__, # Add a prefix to avoid collision (PageNumberSchema,), {"data": marshmallow.fields.Nested(resource_schema, many=True)}, # Replace generic with resource schema )() forge_revision_list = ( forge.copy(func), forge.insert(forge.arg("page", default=None, type=int), index=-1), forge.insert(forge.arg("page_size", default=None, type=int), index=-1), forge.insert(forge.arg("count", default=True, type=bool), index=-1), forge.delete("kwargs"), forge.returns(schema), ) try: if asyncio.iscoroutinefunction(func): @forge.compose(*forge_revision_list) @functools.wraps(func) async def decorator(*args, page: int = None, page_size: int = None, count: bool = True, **kwargs): return PageNumberResponse( schema=schema, page=page, page_size=page_size, count=count, content=await func(*args, **kwargs) ) else: @forge.compose(*forge_revision_list) @functools.wraps(func) def decorator(*args, page: int = None, page_size: int = None, count: bool = True, **kwargs): return PageNumberResponse( schema=schema, page=page, page_size=page_size, count=count, content=func(*args, **kwargs) ) except ValueError as e: raise TypeError("Paginated views must define **kwargs param") from e return decorator @classmethod def limit_offset(cls, func): """ Decorator for adding pagination behavior to a view. That decorator produces a view based on limit-offset and it adds three query parameters to control the pagination: limit, offset and count. Offset has a default value of zero to start with the first element of the collection, limit default value is defined in :class:`LimitOffsetResponse` and count defines if the response will define the total number of elements. The output schema is also modified by :class:`LimitOffsetSchema`, creating a new schema based on it but using the old output schema as the content of its data field. :param func: View to be decorated. :return: Decorated view. """ assert forge is not None, "`python-forge` must be installed to use OpenAPIResponse." resource_schema = get_output_schema(func) schema = type( "LimitOffsetPaginated" + resource_schema.__class__.__name__, # Add a prefix to avoid collision (LimitOffsetSchema,), {"data": marshmallow.fields.Nested(resource_schema, many=True)}, # Replace generic with resource schema )() forge_revision_list = ( forge.copy(func), forge.insert(forge.arg("limit", default=None, type=int), index=-1), forge.insert(forge.arg("offset", default=None, type=int), index=-1), forge.insert(forge.arg("count", default=True, type=bool), index=-1), forge.delete("kwargs"), forge.returns(schema), ) try: if asyncio.iscoroutinefunction(func): @forge.compose(*forge_revision_list) @functools.wraps(func) async def decorator(*args, limit: int = None, offset: int = None, count: bool = True, **kwargs): return LimitOffsetResponse( schema=schema, limit=limit, offset=offset, count=count, content=await func(*args, **kwargs) ) else: @forge.compose(*forge_revision_list) @functools.wraps(func) def decorator(*args, limit: int = None, offset: int = None, count: bool = True, **kwargs): return LimitOffsetResponse( schema=schema, limit=limit, offset=offset, count=count, content=func(*args, **kwargs) ) except ValueError as e: raise TypeError("Paginated views must define **kwargs param") from e return decorator PK ! gV"> "> flama/resources.pyimport datetime import logging import re import typing import uuid import marshmallow from flama.exceptions import HTTPException from flama.pagination import Paginator from flama.responses import APIResponse from flama.types import Model, PrimaryKey, ResourceMeta, ResourceMethodMeta try: import sqlalchemy from sqlalchemy.dialects import postgresql except Exception: # pragma: no cover raise AssertionError("`sqlalchemy` must be installed to use resources") from None try: import databases except Exception: # pragma: no cover raise AssertionError("`databases` must be installed to use resources") from None logger = logging.getLogger(__name__) __all__ = ["resource_method", "BaseResource", "CRUDResource", "CRUDListResource", "CRUDListDropResource"] PK_MAPPING = { sqlalchemy.Integer: int, sqlalchemy.String: str, sqlalchemy.Date: datetime.date, sqlalchemy.DateTime: datetime.datetime, postgresql.UUID: uuid.UUID, } class DropCollection(marshmallow.Schema): deleted = marshmallow.fields.Integer(title="deleted", description="Number of deleted elements", required=True) def resource_method(path: str, methods: typing.List[str] = None, name: str = None, **kwargs) -> typing.Callable: def wrapper(func: typing.Callable) -> typing.Callable: func._meta = ResourceMethodMeta( path=path, methods=methods if methods is not None else ["GET"], name=name, kwargs=kwargs ) return func return wrapper class ResourceRoutes: """Routes descriptor""" def __init__(self, methods: typing.Dict[str, typing.Callable]): self.methods = methods def __get__(self, instance, owner) -> typing.Dict[str, typing.Callable]: return self.methods class BaseResource(type): METHODS = () # type: typing.Sequence[str] def __new__(mcs, name, bases, namespace): # Get database and replace it with a read-only descriptor database = mcs._get_attribute("database", name, namespace, bases) namespace["database"] = property(lambda self: self._meta.database) # Get model and replace it with a read-only descriptor model = mcs._get_model(name, namespace, bases) namespace["model"] = property(lambda self: self._meta.model.table) # Define resource names resource_name, verbose_name = mcs._get_resource_name(name, namespace) # Default columns and order for admin interface columns = namespace.pop("columns", [model.primary_key.name]) order = namespace.pop("order", model.primary_key.name) # Get input and output schemas input_schema, output_schema = mcs._get_schemas(name, namespace, bases) namespace["_meta"] = ResourceMeta( database=database, model=model, name=resource_name, verbose_name=verbose_name, input_schema=input_schema, output_schema=output_schema, columns=columns, order=order, ) # Create CRUD methods and routes mcs._add_methods(resource_name, verbose_name, namespace, database, input_schema, output_schema, model) mcs._add_routes(namespace) return super().__new__(mcs, name, bases, namespace) @classmethod def _get_attribute( mcs, attribute: str, name: str, namespace: typing.Dict[str, typing.Any], bases: typing.Sequence[typing.Any] ) -> typing.Any: try: return namespace.pop(attribute) except KeyError: for base in bases: if hasattr(base, "_meta") and hasattr(base._meta, attribute): return getattr(base._meta, attribute) elif hasattr(base, attribute): return getattr(base, attribute) raise AttributeError(f"{name} needs to define attribute '{attribute}'") @classmethod def _get_resource_name(mcs, name: str, namespace: typing.Dict[str, typing.Any]) -> typing.Tuple[str, str]: resource_name = namespace.pop("name", name.lower()) # Check resource name validity if re.match("[a-zA-Z][-_a-zA-Z]", resource_name) is None: raise AttributeError(f"Invalid resource name '{resource_name}'") return resource_name, namespace.pop("verbose_name", resource_name) @classmethod def _get_model( mcs, name: str, namespace: typing.Dict[str, typing.Any], bases: typing.Sequence[typing.Any] ) -> Model: model = mcs._get_attribute("model", name, namespace, bases) # Already defined model probably because resource inheritance, so no need to create it if isinstance(model, Model): return model # Resource define model as a sqlalchemy Table, so extract necessary info from it elif isinstance(model, sqlalchemy.Table): # Get model primary key model_pk = list(sqlalchemy.inspect(model).primary_key.columns.values()) # Check primary key exists and is a single column if len(model_pk) != 1: raise AttributeError(f"{name} model must define a single-column primary key") model_pk = model_pk[0] model_pk_name = model_pk.name # Check primary key is a valid type try: model_pk_type = PK_MAPPING[model_pk.type.__class__] except KeyError: raise AttributeError( f"{name} model primary key must be any of {', '.join((i.__name__ for i in PK_MAPPING.keys()))}" ) return Model(table=model, primary_key=PrimaryKey(model_pk_name, model_pk_type)) raise AttributeError(f"{name} model must be a valid SQLAlchemy Table instance or a Model instance") @classmethod def _get_schemas( mcs, name: str, namespace: typing.Dict[str, typing.Any], bases: typing.Sequence[typing.Any] ) -> typing.Tuple[marshmallow.Schema, marshmallow.Schema]: try: schema = mcs._get_attribute("schema", name, namespace, bases) input_schema = schema output_schema = schema except AttributeError: try: input_schema = mcs._get_attribute("input_schema", name, namespace, bases) output_schema = mcs._get_attribute("output_schema", name, namespace, bases) except AttributeError: raise AttributeError( f"{name} needs to define attribute 'schema' or the pair 'input_schema' and 'output_schema'" ) return input_schema, output_schema @classmethod def _add_routes(mcs, namespace: typing.Dict[str, typing.Any]): methods = {name: m for name, m in namespace.items() if getattr(m, "_meta", False) and not name.startswith("_")} routes = ResourceRoutes(methods) namespace["routes"] = routes @classmethod def _add_methods( mcs, name: str, verbose_name: str, namespace: typing.Dict[str, typing.Any], database: "databases.Database", input_schema: marshmallow.Schema, output_schema: marshmallow.Schema, model: Model, ): # Get available methods methods = [getattr(mcs, f"_add_{method}") for method in mcs.METHODS if hasattr(mcs, f"_add_{method}")] # Generate CRUD methods crud_namespace = { func_name: func for method in methods for func_name, func in method( name=name, verbose_name=verbose_name, database=database, input_schema=input_schema, output_schema=output_schema, model=model, ).items() } # Preserve already defined methods crud_namespace.update( {method: crud_namespace[f"_{method}"] for method in mcs.METHODS if method not in namespace} ) namespace.update(crud_namespace) class CreateMixin: @classmethod def _add_create( mcs, name: str, verbose_name: str, database: databases.Database, input_schema: marshmallow.Schema, output_schema: marshmallow.Schema, **kwargs, ) -> typing.Dict[str, typing.Any]: @resource_method("/", methods=["POST"], name=f"{name}-create") @database.transaction() async def create(self, element: input_schema) -> output_schema: query = self.model.insert().values(**element) await self.database.execute(query) return APIResponse(schema=output_schema(), content=element, status_code=201) create.__doc__ = f""" tags: - {verbose_name} summary: Create a new document. description: Create a new document in this resource. responses: 201: description: Document created successfully. """ return {"_create": create} class RetrieveMixin: @classmethod def _add_retrieve( mcs, name: str, verbose_name: str, output_schema: marshmallow.Schema, model: Model, **kwargs ) -> typing.Dict[str, typing.Any]: @resource_method("/{element_id}/", methods=["GET"], name=f"{name}-retrieve") async def retrieve(self, element_id: model.primary_key.type) -> output_schema: query = self.model.select().where(self.model.c[model.primary_key.name] == element_id) element = await self.database.fetch_one(query) if element is None: raise HTTPException(status_code=404) return dict(element) retrieve.__doc__ = f""" tags: - {verbose_name} summary: Retrieve a document. description: Retrieve a document from this resource. responses: 200: description: Document found. 404: description: Document not found. """ return {"_retrieve": retrieve} class UpdateMixin: @classmethod def _add_update( mcs, name: str, verbose_name: str, database: databases.Database, input_schema: marshmallow.Schema, output_schema: marshmallow.Schema, model: Model, **kwargs, ) -> typing.Dict[str, typing.Any]: @resource_method("/{element_id}/", methods=["PUT"], name=f"{name}-update") @database.transaction() async def update(self, element_id: model.primary_key.type, element: input_schema) -> output_schema: query = sqlalchemy.select([sqlalchemy.exists().where(self.model.c[model.primary_key.name] == element_id)]) exists = next((i for i in (await self.database.fetch_one(query)).values())) if not exists: raise HTTPException(status_code=404) query = self.model.update().where(self.model.c[model.primary_key.name] == element_id).values(**element) await self.database.execute(query) return {model.primary_key.name: element_id, **element} update.__doc__ = f""" tags: - {verbose_name} summary: Update a document. description: Update a document in this resource. responses: 200: description: Document updated successfully. 404: description: Document not found. """ return {"_update": update} class DeleteMixin: @classmethod def _add_delete( mcs, name: str, verbose_name: str, database: databases.Database, model: Model, **kwargs ) -> typing.Dict[str, typing.Any]: @resource_method("/{element_id}/", methods=["DELETE"], name=f"{name}-delete") @database.transaction() async def delete(self, element_id: model.primary_key.type): query = sqlalchemy.select([sqlalchemy.exists().where(self.model.c[model.primary_key.name] == element_id)]) exists = next((i for i in (await self.database.fetch_one(query)).values())) if not exists: raise HTTPException(status_code=404) query = self.model.delete().where(self.model.c[model.primary_key.name] == element_id) await self.database.execute(query) return APIResponse(status_code=204) delete.__doc__ = f""" tags: - {verbose_name} summary: Delete a document. description: Delete a document in this resource. responses: 204: description: Document deleted successfully. 404: description: Document not found. """ return {"_delete": delete} class ListMixin: @classmethod def _add_list( mcs, name: str, verbose_name: str, output_schema: marshmallow.Schema, **kwargs ) -> typing.Dict[str, typing.Any]: async def filter(self, *clauses, **filters) -> typing.List[typing.Dict]: query = self.model.select() where_clauses = tuple(clauses) + tuple(self.model.c[k] == v for k, v in filters.items()) if where_clauses: query = query.where(sqlalchemy.and_(*where_clauses)) return [dict(row) for row in await self.database.fetch_all(query)] @resource_method("/", methods=["GET"], name=f"{name}-list") @Paginator.page_number async def list(self, **kwargs) -> output_schema(many=True): return await self._filter() # noqa list.__doc__ = f""" tags: - {verbose_name} summary: List collection. description: List resource collection. responses: 200: description: List collection items. """ return {"_list": list, "_filter": filter} class DropMixin: @classmethod def _add_drop( mcs, name: str, verbose_name: str, database: databases.Database, model: Model, **kwargs ) -> typing.Dict[str, typing.Any]: @resource_method("/", methods=["DELETE"], name=f"{name}-drop") @database.transaction() async def drop(self) -> DropCollection: query = sqlalchemy.select([sqlalchemy.func.count(self.model.c[model.primary_key.name])]) result = next((i for i in (await self.database.fetch_one(query)).values())) query = self.model.delete() await self.database.execute(query) return APIResponse(schema=DropCollection(), content={"deleted": result}, status_code=204) drop.__doc__ = f""" tags: - {verbose_name} summary: Drop collection. description: Drop resource collection. responses: 204: description: Collection dropped successfully. """ return {"_drop": drop} class CRUDResource(BaseResource, CreateMixin, RetrieveMixin, UpdateMixin, DeleteMixin): METHODS = ("create", "retrieve", "update", "delete") class CRUDListResource(BaseResource, CreateMixin, RetrieveMixin, UpdateMixin, DeleteMixin, ListMixin): METHODS = ("create", "retrieve", "update", "delete", "list") class CRUDListDropResource(BaseResource, CreateMixin, RetrieveMixin, UpdateMixin, DeleteMixin, ListMixin, DropMixin): METHODS = ("create", "retrieve", "update", "delete", "list", "drop") PK ! JT flama/responses.pyimport typing import marshmallow from starlette.responses import JSONResponse __all__ = ["APIResponse", "APIErrorResponse", "APIError"] class APIError(marshmallow.Schema): status_code = marshmallow.fields.Integer(title="status_code", description="HTTP status code", required=True) detail = marshmallow.fields.Raw(title="detail", description="Error detail", required=True) error = marshmallow.fields.String(title="type", description="Exception or error type") class APIResponse(JSONResponse): media_type = "application/json" def __init__(self, schema: typing.Optional[marshmallow.Schema] = None, *args, **kwargs): self.schema = schema super().__init__(*args, **kwargs) def render(self, content: typing.Any): # Use output schema to validate and format data if self.schema is not None: content = self.schema.dump(content) return super().render(content) class APIErrorResponse(APIResponse): def __init__( self, detail: typing.Any, status_code: int = 400, exception: typing.Optional[Exception] = None, *args, **kwargs ): content = { "status_code": status_code, "detail": detail, "error": str(exception.__class__.__name__) if exception is not None else None, } super().__init__(schema=APIError(), content=content, status_code=status_code, *args, **kwargs) self.detail = detail self.exception = exception PK ! sV/ / flama/routing.pyimport asyncio import inspect import typing from functools import wraps import marshmallow import starlette.routing from starlette.concurrency import run_in_threadpool from starlette.routing import Match, Mount from starlette.types import ASGIApp, ASGIInstance, Receive, Scope, Send from flama import http, websockets from flama.components import Component from flama.responses import APIResponse from flama.types import Field, FieldLocation, HTTPMethod, OptBool, OptFloat, OptInt, OptStr from flama.validation import get_output_schema if typing.TYPE_CHECKING: from flama.resources import BaseResource __all__ = ["Route", "WebSocketRoute", "Router"] FieldsMap = typing.Dict[str, Field] MethodsMap = typing.Dict[str, FieldsMap] PATH_SCHEMA_MAPPING = { inspect.Signature.empty: lambda *args, **kwargs: None, int: marshmallow.fields.Integer, float: marshmallow.fields.Number, str: marshmallow.fields.String, bool: marshmallow.fields.Boolean, http.PathParam: marshmallow.fields.String, } QUERY_SCHEMA_MAPPING = { inspect.Signature.empty: lambda *args, **kwargs: None, int: marshmallow.fields.Integer, float: marshmallow.fields.Number, bool: marshmallow.fields.Boolean, str: marshmallow.fields.String, OptInt: marshmallow.fields.Integer, OptFloat: marshmallow.fields.Number, OptBool: marshmallow.fields.Boolean, OptStr: marshmallow.fields.String, http.QueryParam: marshmallow.fields.String, } class FieldsMixin: def _get_fields( self, router: "Router" ) -> typing.Tuple[MethodsMap, MethodsMap, typing.Dict[str, Field], typing.Dict[str, typing.Any]]: query_fields: MethodsMap = {} path_fields: MethodsMap = {} body_field: typing.Dict[str, Field] = {} output_field: typing.Dict[str, typing.Any] = {} if hasattr(self, "methods") and self.methods is not None: if inspect.isclass(self.endpoint): # HTTP endpoint methods = [(m, getattr(self.endpoint, m.lower() if m != "HEAD" else "get")) for m in self.methods] else: # HTTP function methods = [(m, self.endpoint) for m in self.methods] if self.methods else [] else: # Websocket methods = [("GET", self.endpoint)] for m, h in methods: query_fields[m], path_fields[m], body_field[m], output_field[m] = self._get_fields_from_handler(h, router) return query_fields, path_fields, body_field, output_field def _get_parameters_from_handler( self, handler: typing.Callable, router: "Router" ) -> typing.Dict[str, inspect.Parameter]: parameters = {} for name, parameter in inspect.signature(handler).parameters.items(): for component in router.components: if component.can_handle_parameter(parameter): parameters.update(self._get_parameters_from_handler(component.resolve, router)) break else: parameters[name] = parameter return parameters def _get_fields_from_handler( self, handler: typing.Callable, router: "Router" ) -> typing.Tuple[FieldsMap, FieldsMap, Field, typing.Any]: query_fields: FieldsMap = {} path_fields: FieldsMap = {} body_field: Field = None # Iterate over all params for name, param in self._get_parameters_from_handler(handler, router).items(): if name in ("self", "cls"): continue # Matches as path param if name in self.param_convertors.keys(): try: schema = PATH_SCHEMA_MAPPING[param.annotation] except KeyError: schema = marshmallow.fields.String path_fields[name] = Field( name=name, location=FieldLocation.path, schema=schema(required=True), required=True ) # Matches as query param elif param.annotation in QUERY_SCHEMA_MAPPING: if param.annotation in (OptInt, OptFloat, OptBool, OptStr) or param.default is not param.empty: required = False kwargs = {"missing": param.default if param.default is not param.empty else None} else: required = True kwargs = {"required": True} query_fields[name] = Field( name=name, location=FieldLocation.query, schema=QUERY_SCHEMA_MAPPING[param.annotation](**kwargs), required=required, ) # Body params elif inspect.isclass(param.annotation) and issubclass(param.annotation, marshmallow.Schema): body_field = Field(name=name, location=FieldLocation.body, schema=param.annotation()) output_field = inspect.signature(handler).return_annotation return query_fields, path_fields, body_field, output_field class Route(starlette.routing.Route, FieldsMixin): def __init__(self, path: str, endpoint: typing.Callable, router: "Router", *args, **kwargs): super().__init__(path, endpoint=endpoint, **kwargs) # Replace function with another wrapper that uses the injector if inspect.isfunction(endpoint) or inspect.ismethod(endpoint): self.app = self.endpoint_wrapper(endpoint) if self.methods is None: self.methods = [m for m in HTTPMethod.__members__.keys() if hasattr(self, m.lower())] self.query_fields, self.path_fields, self.body_field, self.output_field = self._get_fields(router) def endpoint_wrapper(self, endpoint: typing.Callable) -> ASGIApp: """ Wraps a http function into ASGI application. """ @wraps(endpoint) def _app(scope: Scope) -> ASGIInstance: async def awaitable(receive: Receive, send: Send) -> None: app = scope["app"] route, route_scope = app.router.get_route_from_scope(scope) state = { "scope": scope, "receive": receive, "send": send, "exc": None, "app": app, "path_params": route_scope["path_params"], "route": route, "request": http.Request(scope, receive), } injected_func = await app.injector.inject(endpoint, state) if asyncio.iscoroutinefunction(endpoint): response = await injected_func() else: response = await run_in_threadpool(injected_func) # Wrap response data with a proper response class if isinstance(response, (dict, list)): response = APIResponse(content=response, schema=get_output_schema(endpoint)) elif isinstance(response, str): response = APIResponse(content=response) elif response is None: response = APIResponse(content="") await response(receive, send) return awaitable return _app class WebSocketRoute(starlette.routing.WebSocketRoute, FieldsMixin): def __init__(self, path: str, endpoint: typing.Callable, router: "Router", *args, **kwargs): super().__init__(path, endpoint=endpoint, **kwargs) # Replace function with another wrapper that uses the injector if inspect.isfunction(endpoint): self.app = self.endpoint_wrapper(endpoint) self.query_fields, self.path_fields, self.body_field, self.output_field = self._get_fields(router) def endpoint_wrapper(self, endpoint: typing.Callable) -> ASGIApp: """ Wraps websocket function into ASGI application. """ @wraps(endpoint) def _app(scope: Scope) -> ASGIInstance: async def awaitable(receive: Receive, send: Send) -> None: app = scope["app"] route, route_scope = app.router.get_route_from_scope(scope) state = { "scope": scope, "receive": receive, "send": send, "exc": None, "app": app, "path_params": route_scope["path_params"], "route": route, "websocket": websockets.WebSocket(scope, receive, send), } injected_func = await app.injector.inject(endpoint, state) kwargs = scope.get("kwargs", {}) await injected_func(**kwargs) return awaitable return _app class Router(starlette.routing.Router): def __init__(self, components: typing.Optional[typing.List[Component]] = None, *args, **kwargs): super().__init__(*args, **kwargs) if components is None: components = [] self.components = components def add_route( self, path: str, endpoint: typing.Callable, methods: typing.List[str] = None, name: str = None, include_in_schema: bool = True, ): self.routes.append( Route(path, endpoint=endpoint, methods=methods, name=name, include_in_schema=include_in_schema, router=self) ) def add_websocket_route(self, path: str, endpoint: typing.Callable, name: str = None): self.routes.append(WebSocketRoute(path, endpoint=endpoint, name=name, router=self)) def add_resource(self, path: str, resource: "BaseResource"): # Handle class or instance objects if inspect.isclass(resource): # noqa resource = resource() for name, route in resource.routes.items(): route_path = path + resource._meta.name + route._meta.path route_func = getattr(resource, name) name = route._meta.name if route._meta.name is not None else f"{resource._meta.name}-{route.__name__}" self.add_route(route_path, route_func, route._meta.methods, name, **route._meta.kwargs) def mount(self, path: str, app: ASGIApp, name: str = None) -> None: if isinstance(app, Router): app.components = self.components path = path.rstrip("/") route = Mount(path, app=app, name=name) self.routes.append(route) def route( self, path: str, methods: typing.List[str] = None, name: str = None, include_in_schema: bool = True ) -> typing.Callable: def decorator(func: typing.Callable) -> typing.Callable: self.add_route(path, func, methods=methods, name=name, include_in_schema=include_in_schema) return func return decorator def websocket_route(self, path: str, name: str = None) -> typing.Callable: def decorator(func: typing.Callable) -> typing.Callable: self.add_websocket_route(path, func, name=name) return func return decorator def resource(self, path: str) -> typing.Callable: def decorator(resource: "BaseResource") -> "BaseResource": self.add_resource(path, resource=resource) return resource return decorator def get_route_from_scope(self, scope) -> typing.Tuple[Route, typing.Optional[typing.Dict]]: partial = None for route in self.routes: if isinstance(route, Mount): path = scope.get("path", "") root_path = scope.pop("root_path", "") scope["path"] = root_path + path match, child_scope = route.matches(scope) if match == Match.FULL: scope.update(child_scope) if isinstance(route, Mount): route, mount_scope = route.app.get_route_from_scope(scope) return route, mount_scope return route, scope elif match == Match.PARTIAL and partial is None: partial = route partial_scope = child_scope if partial is not None: scope.update(partial_scope) return partial, scope return self.not_found, None PK ! K#( #( flama/schemas.pyimport inspect import itertools import os import typing from collections import defaultdict from string import Template import marshmallow from starlette import routing, schemas from starlette.responses import HTMLResponse from flama.responses import APIError from flama.types import EndpointInfo from flama.utils import dict_safe_add try: import apispec except Exception: # pragma: no cover apispec = None # type: ignore try: import yaml except Exception: # pragma: no cover yaml = None # type: ignore __all__ = ["OpenAPIResponse", "SchemaGenerator", "SchemaMixin"] if yaml is not None and apispec is not None: from apispec.core import YAMLDumper as BaseYAMLDumper class YAMLDumper(BaseYAMLDumper): def ignore_aliases(self, data): return True class OpenAPIResponse(schemas.OpenAPIResponse): def render(self, content: typing.Any) -> bytes: assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse." assert apispec is not None, "`apispec` must be installed to use OpenAPIResponse." assert isinstance(content, dict), "The schema passed to OpenAPIResponse should be a dictionary." return yaml.dump(content, default_flow_style=False, Dumper=YAMLDumper).encode("utf-8") class SchemaRegistry(dict): def __init__(self, spec, *args, **kwargs): super().__init__(*args, **kwargs) self.spec = spec self.openapi = self.spec.plugins[0].openapi def __getitem__(self, item): try: schema = super().__getitem__(item) except KeyError: component_schema = item if inspect.isclass(item) else item.__class__ self.spec.definition(name=component_schema.__name__, schema=component_schema) schema = self.openapi.resolve_schema_dict(item) super().__setitem__(item, schema) return schema class SchemaGenerator(schemas.BaseSchemaGenerator): def __init__(self, title: str, version: str, description: str, openapi_version="3.0.0"): assert apispec is not None, "`apispec` must be installed to use SchemaGenerator." from apispec.ext.marshmallow import MarshmallowPlugin self.spec = apispec.APISpec( title=title, version=version, openapi_version=openapi_version, info={"description": description}, plugins=[MarshmallowPlugin()], ) self.openapi = self.spec.plugins[0].openapi # Builtin definitions self.schemas = SchemaRegistry(self.spec) def get_endpoints( self, routes: typing.List[routing.BaseRoute], base_path: str = "" ) -> typing.Dict[str, typing.Sequence[EndpointInfo]]: """ Given the routes, yields the following information: - path eg: /users/ - http_method one of 'get', 'post', 'put', 'patch', 'delete', 'options' - func method ready to extract the docstring """ endpoints_info: typing.Dict[str, typing.Sequence[EndpointInfo]] = defaultdict(list) for route in routes: if isinstance(route, routing.Route) and route.include_in_schema: _, path, _ = routing.compile_path(base_path + route.path) if inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint): for method in route.methods or ["GET"]: if method == "HEAD": continue endpoints_info[path].append( EndpointInfo( path=path, method=method.lower(), func=route.endpoint, query_fields=route.query_fields.get(method, {}), path_fields=route.path_fields.get(method, {}), body_field=route.body_field.get(method), output_field=route.output_field.get(method), ) ) else: for method in ["get", "post", "put", "patch", "delete", "options"]: if not hasattr(route.endpoint, method): continue func = getattr(route.endpoint, method) endpoints_info[path].append( EndpointInfo( path=path, method=method.lower(), func=func, query_fields=route.query_fields.get(method.upper(), {}), path_fields=route.path_fields.get(method.upper(), {}), body_field=route.body_field.get(method.upper()), output_field=route.output_field.get(method.upper()), ) ) elif isinstance(route, routing.Mount): endpoints_info.update(self.get_endpoints(route.routes, base_path=route.path)) return endpoints_info def _add_endpoint_parameters(self, endpoint: EndpointInfo, schema: typing.Dict): schema["parameters"] = [ self.openapi.field2parameter(field.schema, name=field.name, default_in=field.location.name) for field in itertools.chain(endpoint.query_fields.values(), endpoint.path_fields.values()) ] def _add_endpoint_body(self, endpoint: EndpointInfo, schema: typing.Dict): component_schema = ( endpoint.body_field.schema if inspect.isclass(endpoint.body_field.schema) else endpoint.body_field.schema.__class__ ) self.spec.definition(name=component_schema.__name__, schema=component_schema) dict_safe_add( schema, self.openapi.schema2jsonschema(endpoint.body_field.schema), "requestBody", "content", "application/json", "schema", ) def _add_endpoint_response(self, endpoint: EndpointInfo, schema: typing.Dict): response_codes = list(schema.get("responses", {}).keys()) main_response = response_codes[0] if response_codes else 200 dict_safe_add( schema, self.schemas[endpoint.output_field], "responses", main_response, "content", "application/json", "schema", ) def _add_endpoint_default_response(self, schema: typing.Dict): dict_safe_add(schema, self.schemas[APIError], "responses", "default", "content", "application/json", "schema") # Default description schema["responses"]["default"]["description"] = schema["responses"]["default"].get( "description", "Unexpected error." ) def get_endpoint_schema(self, endpoint: EndpointInfo) -> typing.Dict[str, typing.Any]: schema = self.parse_docstring(endpoint.func) # Query and Path parameters if endpoint.query_fields or endpoint.path_fields: self._add_endpoint_parameters(endpoint, schema) # Body if endpoint.body_field: self._add_endpoint_body(endpoint, schema) # Response if endpoint.output_field and ( (inspect.isclass(endpoint.output_field) and issubclass(endpoint.output_field, marshmallow.Schema)) or isinstance(endpoint.output_field, marshmallow.Schema) ): self._add_endpoint_response(endpoint, schema) # Default response self._add_endpoint_default_response(schema) return schema def get_schema(self, routes: typing.List[routing.BaseRoute]) -> typing.Dict[str, typing.Any]: endpoints_info = self.get_endpoints(routes) for path, endpoints in endpoints_info.items(): self.spec.add_path(path=path, operations={e.method: self.get_endpoint_schema(e) for e in endpoints}) return self.spec.to_dict() class SchemaMixin: def add_schema_docs_routes( self, title: str = "", version: str = "", description: str = "", schema: typing.Optional[str] = "/schema/", docs: typing.Optional[str] = "/docs/", redoc: typing.Optional[str] = None, ): # Schema self.title = title self.version = version self.description = description self.schema_url = schema if self.schema_url: self.add_schema_route() # Docs (Swagger UI) self.docs_url = docs if self.docs_url: self.add_docs_route() # Redoc self.redoc_url = redoc if self.redoc_url: self.add_redoc_route() @property def schema_generator(self): if not hasattr(self, "_schema_generator"): self._schema_generator = SchemaGenerator( title=self.title, version=self.version, description=self.description ) return self._schema_generator @property def schema(self): return self.schema_generator.get_schema(self.routes) def add_schema_route(self): def schema(): return OpenAPIResponse(self.schema) self.add_route(path=self.schema_url, route=schema, methods=["GET"], include_in_schema=False) def add_docs_route(self): def swagger_ui() -> HTMLResponse: with open(os.path.join(os.path.dirname(__file__), "templates/swagger_ui.html")) as f: content = Template(f.read()).substitute(title=self.title, schema_url=self.schema_url) return HTMLResponse(content) self.add_route(path=self.docs_url, route=swagger_ui, methods=["GET"], include_in_schema=False) def add_redoc_route(self): def redoc() -> HTMLResponse: with open(os.path.join(os.path.dirname(__file__), "templates/redoc.html")) as f: content = Template(f.read()).substitute(title=self.title, schema_url=self.schema_url) return HTMLResponse(content) self.add_route(path=self.redoc_url, route=redoc, methods=["GET"], include_in_schema=False) PK ! BK} } flama/templates/redoc.html
vѴ&8iR5_Ujv(4)s\'cctف зhwA2&sۃA/UM9`S2~GF懋o>Dmwl1U?arW
F;%KiPK !H= flama-0.8.2.dist-info/RECORDuǒXD-Ѓ7Y `C}:F5'{y3o/by?y䴾)4l2k^5"ܘO }_Q0];+mgP
ЋZ{U>ɍtIhbn@ BBhn4(Ϲ}Gezq`[٩)Z #W^Ө8Rfꥪ[kI0sҁÆZ _Y4
WF5(\[١Pb:u
'~G2s L@˱k`
ޠYFQaT7NR!LbL#טNU}mz*`ިxV9\V3qY[5蜛UCۆG2[7yC=9B3\62J mBjrb$Jobw(|LˎLSSn]")y_SCd_/)z߸%$9 AbJ*x
@)e8뀔UhLФb-.4mkm8;%ۻnx
z#؋Z Ku009-mM_ԋ mWk{t?Owe
x\ÕvIR
AcVyڤ