158 lines
5.1 KiB
Python
158 lines
5.1 KiB
Python
import functools
|
|
import json
|
|
import traceback
|
|
from typing import Callable, List
|
|
|
|
from flask import current_app, g, stream_with_context
|
|
from zope import component
|
|
|
|
from .spnego import requires_authentication
|
|
from ceo_common.errors import InvalidMembershipError
|
|
from ceo_common.interfaces import IUser, ILDAPService
|
|
from ceo_common.logger_factory import logger_factory
|
|
from ceod.transactions import AbstractTransaction
|
|
|
|
logger = logger_factory(__name__)
|
|
|
|
|
|
def get_valid_member_or_throw(username: str) -> IUser:
    """
    Look up ``username`` via the registered ILDAPService and return the user,
    but only if ``membership_is_valid()`` reports a valid membership.

    Raises:
        InvalidMembershipError: if the user's membership is not valid.
        Any exception raised by ``ILDAPService.get_user`` propagates unchanged.
    """
    member = component.getUtility(ILDAPService).get_user(username)
    if member.membership_is_valid():
        return member
    raise InvalidMembershipError()
|
|
|
|
|
|
def requires_authentication_no_realm(f: Callable) -> Callable:
    """
    Decorator: like requires_authentication, but hands ``f`` the principal
    with its realm removed.

    e.g. user1@CSCLUB.UWATERLOO.CA -> user1

    Everything from the first '@' onwards is dropped. A principal with no
    '@' makes ``str.index`` raise ValueError (i.e. the request fails rather
    than being treated as a bare username).
    """

    @requires_authentication
    @functools.wraps(f)
    def strip_realm(principal: str, *args, **kwargs):
        at_pos = principal.index('@')
        username = principal[:at_pos]
        logger.debug(f'received request from {username}')
        return f(username, *args, **kwargs)

    return strip_realm
|
|
|
|
|
|
def requires_admin_creds(f: Callable) -> Callable:
    """
    Decorator: make the next LDAP connection use the admin Kerberos
    credentials by setting ``g.need_admin_creds`` before calling ``f``.

    This must be applied so that it runs BEFORE any of the authz decorators:
    those may open an LDAP connection, which then gets cached for later use.
    """

    @functools.wraps(f)
    def with_admin_flag(*args, **kwargs):
        # The flag lives on flask.g and is read by whatever opens the
        # LDAP connection.
        g.need_admin_creds = True
        return f(*args, **kwargs)

    return with_admin_flag
|
|
|
|
|
|
def user_is_in_group(username: str, group_name: str) -> bool:
    """
    Report whether ``username`` appears in the members of LDAP group
    ``group_name``.

    Usernames starting with "ceod/" are treated as members of every group,
    and no LDAP lookup is made for them.
    """
    if not username.startswith("ceod/"):
        group = component.getUtility(ILDAPService).get_group(group_name)
        return username in group.members
    return True
|
|
|
|
|
|
def authz_restrict_to_groups(f: Callable, allowed_groups: List[str]) -> Callable:
    """
    Wrap endpoint ``f`` so that only users belonging to at least one of
    ``allowed_groups`` may call it.

    The caller is authenticated via requires_authentication_no_realm. The
    resulting username is consumed here and is NOT forwarded to ``f``.
    Usernames starting with 'ceod/' skip the group check entirely. Denied
    callers receive a ``({'error': ...}, 403)`` response.
    """

    @requires_authentication_no_realm
    @functools.wraps(f)
    def guarded(_username: str, *args, **kwargs):
        # The argument is called _username to avoid name clashes with
        # the arguments of f.
        if _username.startswith('ceod/'):
            # ceod services are always allowed to make internal calls
            return f(*args, **kwargs)

        ldap_srv = component.getUtility(ILDAPService)
        user_groups = ldap_srv.get_groups_for_user(_username)
        if any(name in allowed_groups for name in user_groups):
            return f(*args, **kwargs)

        logger.debug(
            f"User '{_username}' denied since they are not in one of {allowed_groups}"
        )
        return {'error': f'You must be in one of {allowed_groups}'}, 403

    return guarded
|
|
|
|
|
|
def authz_restrict_to_staff(f: Callable) -> Callable:
    """
    Decorator: restrict an endpoint to staff members, i.e. users in any of
    the syscom, exec, office, staff or adm groups.
    """
    return authz_restrict_to_groups(
        f, ['syscom', 'exec', 'office', 'staff', 'adm'])
|
|
|
|
|
|
def authz_restrict_to_syscom(f: Callable) -> Callable:
    """Decorator: restrict an endpoint to members of the syscom group."""
    return authz_restrict_to_groups(f, ['syscom'])
|
|
|
|
|
|
def create_streaming_response(txn: AbstractTransaction):
    """
    Run ``txn`` while streaming its progress back to the client.

    The response is plain text with one JSON object per line:
      * ``{"status": "in progress", "operation": ...}`` for every value
        yielded by ``txn.execute_iter()``;
      * then ``{"status": "completed", "result": txn.result}`` on success,
        or ``{"status": "aborted", "error": str(err)}`` after logging the
        failure and calling ``txn.rollback()``.

    If the client disconnects mid-stream, the remaining operations are still
    executed, with no further output. A failure during that phase is logged
    and rolled back.
    """

    def as_line(record: dict) -> str:
        # One newline-terminated JSON document per status record.
        return json.dumps(record) + '\n'

    def log_failure_and_rollback():
        # Must be called from inside an ``except`` block so that
        # format_exc() sees the exception currently being handled.
        logger.warning('Transaction failed:\n' + traceback.format_exc())
        txn.rollback()

    def generate():
        steps = txn.execute_iter()
        try:
            for step in steps:
                yield as_line({
                    'status': 'in progress',
                    'operation': step,
                })
            yield as_line({
                'status': 'completed',
                'result': txn.result,
            })
        except GeneratorExit:
            # Keep on going. Even if the client closes the connection, we
            # don't want to give up half way through. Nothing may be yielded
            # here, since the generator is being closed.
            try:
                for _ in steps:
                    pass
            except Exception:
                log_failure_and_rollback()
        except Exception as err:
            log_failure_and_rollback()
            yield as_line({
                'status': 'aborted',
                'error': str(err),
            })

    response_class = current_app.response_class
    return response_class(stream_with_context(generate()), mimetype='text/plain')
|
|
|
|
|
|
def development_only(f: Callable) -> Callable:
    """
    Decorator: allow ``f`` to run only when the Flask app's DEBUG or TESTING
    config value is truthy; otherwise respond with a 403 error.
    """

    @functools.wraps(f)
    def dev_guard(*args, **kwargs):
        cfg = current_app.config
        if not (cfg.get('DEBUG') or cfg.get('TESTING')):
            return {
                'error': 'This endpoint may only be called in development'
            }, 403
        return f(*args, **kwargs)

    return dev_guard
|
|
|
|
|
|
def is_truthy(s: str) -> bool:
    """
    Interpret a string flag as a boolean.

    Returns True only if ``s`` lower-cases to exactly 'yes', 'true' or '1'.
    Anything else, including values with surrounding whitespace, is False.
    """
    normalized = s.lower()
    return normalized == 'yes' or normalized == 'true' or normalized == '1'
|