import functools
import grp
import json
import os
import pwd
import traceback
from typing import Callable, List

from flask import current_app, stream_with_context
from flask_kerberos import requires_authentication

from ceo_common.logger_factory import logger_factory
from ceod.transactions import AbstractTransaction

logger = logger_factory(__name__)


def requires_authentication_no_realm(f: Callable) -> Callable:
    """
    Like requires_authentication, but strips the realm out of the
    principal string.
    e.g. user1@CSCLUB.UWATERLOO.CA -> user1

    The stripped name is passed to `f` as its first positional argument.
    A principal that contains no '@' is passed through unchanged.
    """

    @requires_authentication
    @functools.wraps(f)
    def wrapper(principal: str, *args, **kwargs):
        # Everything before the first '@' is the name; unlike str.index,
        # partition() does not raise when no realm is present.
        user = principal.partition('@')[0]
        logger.debug(f'received request from {user}')
        return f(user, *args, **kwargs)

    return wrapper


def user_is_in_group(user: str, group: str) -> bool:
    """
    Returns True if `user` is in `group`, False otherwise.

    A user is considered to be in the group if they are listed as a
    supplementary member of it, or if it is their primary group (the same
    notion of membership that os.getgrouplist() uses).

    Raises KeyError if `group` does not exist.
    """
    group_entry = grp.getgrnam(group)
    if user in group_entry.gr_mem:
        return True
    try:
        primary_gid = pwd.getpwnam(user).pw_gid
    except KeyError:
        # No passwd entry: the user can only have been a supplementary
        # member, which was already ruled out above.
        return False
    return primary_gid == group_entry.gr_gid


def authz_restrict_to_groups(f: Callable, allowed_groups: List[str]) -> Callable:
    """
    Restrict an endpoint to users who belong to one or more of the
    specified groups.

    Group names are resolved to GIDs once, when the decorator is applied,
    so a nonexistent group name raises KeyError at that point.
    Principals whose name starts with 'ceod/' are always allowed.
    Any other user who is not in one of the groups (including a user with
    no passwd entry) gets a JSON error body with HTTP status 403.
    """
    allowed_group_ids = {grp.getgrnam(g).gr_gid for g in allowed_groups}

    @requires_authentication_no_realm
    @functools.wraps(f)
    def wrapper(_username: str, *args, **kwargs):
        # we need to call the argument _username to avoid name clashes with
        # the arguments of f
        username = _username
        if username.startswith('ceod/'):
            # ceod services are always allowed to make internal calls
            return f(*args, **kwargs)
        try:
            primary_gid = pwd.getpwnam(username).pw_gid
        except KeyError:
            # An authenticated principal without a local account cannot be
            # in any group; deny it (403) rather than crashing with a 500.
            user_gids = []
        else:
            user_gids = os.getgrouplist(username, primary_gid)
        if not allowed_group_ids.isdisjoint(user_gids):
            return f(*args, **kwargs)
        logger.debug(
            f"User '{username}' denied since they are not in one of {allowed_groups}"
        )
        return {
            'error': f'You must be in one of {allowed_groups}'
        }, 403

    return wrapper


def authz_restrict_to_staff(f: Callable) -> Callable:
    """A decorator to restrict an endpoint to staff members."""
    allowed_groups = ['office', 'staff', 'adm']
    return authz_restrict_to_groups(f, allowed_groups)


def authz_restrict_to_syscom(f: Callable) -> Callable:
    """A decorator to restrict an endpoint to syscom members."""
    allowed_groups = ['syscom']
    return authz_restrict_to_groups(f, allowed_groups)


def create_streaming_response(txn: AbstractTransaction):
    """
    Returns a plain text response with one JSON object per line,
    indicating the progress of the transaction.

    Each line has a 'status' key:
      - 'in progress', with the 'operation' yielded by txn.execute_iter();
      - 'completed', with 'result' set to txn.result;
      - 'aborted', with 'error' set to str() of the exception, after
        txn.rollback() has been attempted.
    """

    def generate():
        try:
            for operation in txn.execute_iter():
                yield json.dumps({
                    'status': 'in progress',
                    'operation': operation,
                }) + '\n'
            yield json.dumps({
                'status': 'completed',
                'result': txn.result,
            }) + '\n'
        except Exception as err:
            logger.warning('Transaction failed:\n' + traceback.format_exc())
            try:
                txn.rollback()
            except Exception:
                # Still tell the client the transaction was aborted even if
                # the rollback itself fails.
                logger.warning('Rollback failed:\n' + traceback.format_exc())
            yield json.dumps({
                'status': 'aborted',
                'error': str(err),
            }) + '\n'

    return current_app.response_class(
        stream_with_context(generate()), mimetype='text/plain')


def development_only(f: Callable) -> Callable:
    """
    A decorator that only lets the endpoint run when the app config has
    ENV == 'development' or TESTING set; otherwise it returns a JSON error
    body with HTTP status 403.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        if current_app.config.get('ENV') == 'development' or \
                current_app.config.get('TESTING'):
            return f(*args, **kwargs)
        return {
            'error': 'This endpoint may only be called in development'
        }, 403

    return wrapper


def is_truthy(s: str) -> bool:
    """Returns True if `s` is 'yes' or 'true' (case-insensitive)."""
    return s.lower() in ['yes', 'true']