pyceo/ceod/api/utils.py

158 lines
5.1 KiB
Python

import functools
import json
import traceback
from typing import Callable, List
from flask import current_app, g, stream_with_context
from zope import component
from .spnego import requires_authentication
from ceo_common.errors import InvalidMembershipError
from ceo_common.interfaces import IUser, ILDAPService
from ceo_common.logger_factory import logger_factory
from ceod.transactions import AbstractTransaction
# Module-level logger for this file, obtained from the project's logger_factory.
logger = logger_factory(__name__)
def get_valid_member_or_throw(username: str) -> IUser:
    """
    Fetch `username` through the registered ILDAPService utility and return it.

    Raises InvalidMembershipError if the user's membership_is_valid() check
    fails. Any error raised by the LDAP lookup itself propagates unchanged.
    """
    member = component.getUtility(ILDAPService).get_user(username)
    if member.membership_is_valid():
        return member
    raise InvalidMembershipError()
def requires_authentication_no_realm(f: Callable) -> Callable:
    """
    Like requires_authentication, but strips the realm out of the principal string.
    e.g. user1@CSCLUB.UWATERLOO.CA -> user1

    The realm-less name is passed to `f` as its first positional argument,
    followed by any other arguments the wrapper received. A principal that
    contains no '@' is passed through unchanged rather than raising ValueError.
    """
    @requires_authentication
    @functools.wraps(f)
    def wrapper(principal: str, *args, **kwargs):
        # partition() splits on the first '@' like index() did, but it returns
        # the whole string instead of raising when there is no separator.
        user, _, _realm = principal.partition('@')
        logger.debug(f'received request from {user}')
        return f(user, *args, **kwargs)
    return wrapper
def requires_admin_creds(f: Callable) -> Callable:
    """
    Decorator that sets `g.need_admin_creds = True` before calling `f`, so the
    next LDAP connection uses the admin Kerberos credentials.

    This must be used BEFORE any of the authz decorators, so the flag is
    already set when they run: they may open an LDAP connection, and that
    connection gets cached for later use.
    """
    @functools.wraps(f)
    def decorated(*args, **kwargs):
        g.need_admin_creds = True
        return f(*args, **kwargs)

    return decorated
def user_is_in_group(username: str, group_name: str) -> bool:
    """
    Report whether `username` belongs to the LDAP group `group_name`.

    Names starting with "ceod/" always count as members, and no LDAP lookup
    is made for them. Otherwise the group is fetched through the registered
    ILDAPService utility and its `members` are checked.
    """
    # Short-circuit: the group is only looked up for non-ceod callers.
    return username.startswith("ceod/") or (
        username in component.getUtility(ILDAPService).get_group(group_name).members
    )
def authz_restrict_to_groups(f: Callable, allowed_groups: List[str]) -> Callable:
    """
    Wrap `f` so it only runs for callers in at least one of `allowed_groups`.

    The wrapper is itself decorated with requires_authentication_no_realm,
    which supplies the realm-stripped username as the first argument. That
    argument is consumed here and is NOT forwarded to `f`. Callers whose name
    starts with 'ceod/' skip the group check. Everyone else receives a 403
    JSON error when none of their groups appears in `allowed_groups`.
    """
    @requires_authentication_no_realm
    @functools.wraps(f)
    def wrapper(_username: str, *args, **kwargs):
        # The parameter is deliberately named _username so that it cannot
        # clash with the names of f's own arguments.
        if _username.startswith('ceod/'):
            # ceod services are always allowed to make internal calls
            return f(*args, **kwargs)
        memberships = component.getUtility(ILDAPService).get_groups_for_user(_username)
        if any(group in allowed_groups for group in memberships):
            return f(*args, **kwargs)
        logger.debug(
            f"User '{_username}' denied since they are not in one of {allowed_groups}"
        )
        return {'error': f'You must be in one of {allowed_groups}'}, 403

    return wrapper
def authz_restrict_to_staff(f: Callable) -> Callable:
    """
    A decorator to restrict an endpoint to staff members, i.e. members of
    syscom, exec, office, staff or adm.
    """
    return authz_restrict_to_groups(
        f, ['syscom', 'exec', 'office', 'staff', 'adm'])
def authz_restrict_to_syscom(f: Callable) -> Callable:
    """A decorator to restrict an endpoint to members of the syscom group."""
    return authz_restrict_to_groups(f, ['syscom'])
def create_streaming_response(txn: AbstractTransaction):
    """
    Execute `txn` and stream its progress as a plain-text response with one
    JSON object per line.

    Each operation yielded by txn.execute_iter() produces a line
    {"status": "in progress", "operation": ...}. On success, a final line
    {"status": "completed", "result": txn.result} follows. If an exception
    escapes, the failure is logged, txn.rollback() is called, and a final
    {"status": "aborted", "error": str(exc)} line is sent instead.
    """
    def generate():
        steps = txn.execute_iter()
        try:
            for step in steps:
                yield json.dumps({
                    'status': 'in progress',
                    'operation': step,
                }) + '\n'
            yield json.dumps({
                'status': 'completed',
                'result': txn.result,
            }) + '\n'
        except GeneratorExit:
            # The response was closed early, e.g. because the client closed
            # the connection. Keep running the remaining steps anyway so the
            # transaction is not abandoned half way through. Nothing more can
            # be sent, so failures are only logged and rolled back.
            try:
                for _ in steps:
                    pass
            except Exception:
                logger.warning('Transaction failed:\n' + traceback.format_exc())
                txn.rollback()
        except Exception as err:
            logger.warning('Transaction failed:\n' + traceback.format_exc())
            txn.rollback()
            yield json.dumps({
                'status': 'aborted',
                'error': str(err),
            }) + '\n'

    return current_app.response_class(
        stream_with_context(generate()), mimetype='text/plain')
def development_only(f: Callable) -> Callable:
    """
    Decorator that lets `f` run only when the current Flask app's config has
    a truthy DEBUG or TESTING value. Otherwise it returns a 403 JSON error.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        config = current_app.config
        if not (config.get('DEBUG') or config.get('TESTING')):
            return {
                'error': 'This endpoint may only be called in development'
            }, 403
        return f(*args, **kwargs)

    return wrapper
def is_truthy(s: str) -> bool:
    """Return True if `s`, compared case-insensitively, is 'yes', 'true' or '1'."""
    normalized = s.lower()
    return normalized in ('yes', 'true', '1')