144 lines
4.5 KiB

import functools
import grp
import json
import os
import pwd
import traceback
from typing import Callable, List
from flask import current_app, stream_with_context
from zope import component
from .spnego import requires_authentication
from ceo_common.errors import InvalidMembershipError
from ceo_common.interfaces import IUser, ILDAPService
from ceo_common.logger_factory import logger_factory
from ceod.transactions import AbstractTransaction
# Module-level logger, created from this module's __name__.
logger = logger_factory(__name__)
def get_valid_member_or_throw(username: str) -> IUser:
    """
    Fetch `username` from the LDAP service and return it, provided the
    user's membership is currently valid.

    Raises:
        InvalidMembershipError: if the user's membership is not valid.
    """
    candidate = component.getUtility(ILDAPService).get_user(username)
    if candidate.membership_is_valid():
        return candidate
    raise InvalidMembershipError()
def requires_authentication_no_realm(f: Callable) -> Callable:
    """
    Like requires_authentication, but strips the realm out of the principal string.
    e.g. user1@CSCLUB.UWATERLOO.CA -> user1

    The decorated function receives the bare username as its first positional
    argument, followed by the arguments the endpoint was called with.
    """

    # functools.wraps preserves f's __name__; Flask derives a view's default
    # endpoint name from it, so without wraps every decorated view would be
    # registered under the same name 'wrapper'.
    @requires_authentication
    @functools.wraps(f)
    def wrapper(principal: str, *args, **kwargs):
        # Keep everything before the first '@'. index() raises ValueError for
        # a principal with no realm instead of silently accepting it.
        user = principal[:principal.index('@')]
        logger.debug(f'received request from {user}')
        return f(user, *args, **kwargs)

    return wrapper
def user_is_in_group(user: str, group: str) -> bool:
    """
    Returns True if `user` is in `group`, False otherwise.

    Both supplementary membership (the group's member list) and primary
    membership (the user's passwd entry pointing at this group's gid) count,
    consistent with the os.getgrouplist() check in authz_restrict_to_groups.

    Raises:
        KeyError: if `group` does not exist.
    """
    group_entry = grp.getgrnam(group)
    # gr_mem lists only supplementary members, never users whose primary
    # group this is.
    if user in group_entry.gr_mem:
        return True
    try:
        primary_gid = pwd.getpwnam(user).pw_gid
    except KeyError:
        # No passwd entry, so this cannot be the user's primary group.
        return False
    return primary_gid == group_entry.gr_gid
def authz_restrict_to_groups(f: Callable, allowed_groups: List[str]) -> Callable:
    """
    Restrict an endpoint to users who belong to one or more of the
    specified groups.

    Group names are resolved to gids once, when the decorator is applied, so
    an unknown group name raises KeyError at that point. Callers whose
    username starts with 'ceod/' always pass. Any other user who is not in
    one of the groups, including a user with no passwd entry, gets a 403
    error response.
    """
    allowed_group_ids = {grp.getgrnam(g).gr_gid for g in allowed_groups}

    @requires_authentication_no_realm
    @functools.wraps(f)
    def wrapper(_username: str, *args, **kwargs):
        # we need to call the argument _username to avoid name clashes with
        # the arguments of f
        username = _username
        if username.startswith('ceod/'):
            # ceod services are always allowed to make internal calls
            return f(*args, **kwargs)
        try:
            primary_gid = pwd.getpwnam(username).pw_gid
        except KeyError:
            # With no passwd entry there are no memberships to check. Deny
            # the request instead of letting KeyError become a server error.
            primary_gid = None
        if primary_gid is not None:
            # getgrouplist() returns the primary gid plus every supplementary
            # group the user belongs to.
            user_gids = os.getgrouplist(username, primary_gid)
            if any(gid in allowed_group_ids for gid in user_gids):
                return f(*args, **kwargs)
        logger.info(
            f"User '{username}' denied since they are not in one of {allowed_groups}"
        )
        return {
            'error': f'You must be in one of {allowed_groups}'
        }, 403

    return wrapper
def authz_restrict_to_staff(f: Callable) -> Callable:
    """
    Decorator limiting an endpoint to staff, i.e. members of any of the
    syscom, exec, office, staff or adm groups.
    """
    return authz_restrict_to_groups(
        f, ['syscom', 'exec', 'office', 'staff', 'adm'])
def authz_restrict_to_syscom(f: Callable) -> Callable:
    """Decorator limiting an endpoint to members of the syscom group."""
    return authz_restrict_to_groups(f, ['syscom'])
def create_streaming_response(txn: AbstractTransaction):
    """
    Returns a plain text response with one JSON object per line,
    indicating the progress of the transaction.

    Lines emitted:
      {"status": "in progress", "operation": ...}  once per item yielded by
                                                   txn.execute_iter()
      {"status": "completed", "result": ...}       holding txn.result, after
                                                   the last step
      {"status": "aborted", "error": "..."}        if the transaction raised
    """

    def generate():
        try:
            generator = txn.execute_iter()
            for operation in generator:
                yield json.dumps({
                    'status': 'in progress',
                    'operation': operation,
                }) + '\n'
            yield json.dumps({
                'status': 'completed',
                'result': txn.result,
            }) + '\n'
        except GeneratorExit:
            # Keep on going. Even if the client closes the connection, we don't
            # want to give up half way through. Nothing may be yielded here:
            # a generator that yields after GeneratorExit makes close() raise
            # RuntimeError.
            try:
                for _ in generator:
                    pass
            except Exception:
                logger.warning('Transaction failed:\n' + traceback.format_exc())
        except Exception as err:
            logger.warning('Transaction failed:\n' + traceback.format_exc())
            yield json.dumps({
                'status': 'aborted',
                'error': str(err),
            }) + '\n'

    # stream_with_context keeps the request context available while the
    # response body is being generated.
    return current_app.response_class(
        stream_with_context(generate()), mimetype='text/plain')
def development_only(f: Callable) -> Callable:
    """
    Decorator allowing an endpoint to run only when the app's ENV config is
    'development' (or TESTING is set). Otherwise the caller gets a 403 error
    response.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # NOTE(review): TESTING is accepted alongside ENV == 'development';
        # confirm this is the intended second condition. Also note that
        # recent Flask versions no longer populate config['ENV'] themselves,
        # so it must be set explicitly for the first check to pass.
        if current_app.config.get('ENV') == 'development' or \
                current_app.config.get('TESTING'):
            return f(*args, **kwargs)
        return {
            'error': 'This endpoint may only be called in development'
        }, 403

    return wrapper
def is_truthy(s: str) -> bool:
    """
    Interpret a string flag: True for 'yes', 'true' or '1' in any letter
    case, False for anything else. Surrounding whitespace is not stripped.
    """
    normalized = s.lower()
    return normalized in ('yes', 'true', '1')