pyceo/ceod/api/utils.py

123 lines
3.7 KiB
Python

import functools
import grp
import json
import socket
from typing import Callable, List
from flask import current_app
from flask_kerberos import requires_authentication
from zope import component
from ceo_common.logger_factory import logger_factory
from ceo_common.interfaces import IConfig
from ceod.transactions import AbstractTransaction
# Module-level logger for this file, obtained from ceo_common's logger_factory.
logger = logger_factory(__name__)
def restrict_host(role: str) -> Callable[[Callable], Callable]:
    """
    Build a decorator which pins an endpoint to one particular host.

    The hostname the endpoint may be served from is read from the app's
    config under ``role``. If this machine's hostname equals it, the
    decorator hands the view function back unchanged. Otherwise, every
    request to the view is answered with a JSON error and HTTP 400.

    The hostname comparison happens once, when the decorator is created,
    not on each request.

    :param role: a key in the app's config (e.g. 'ceod_admin_host')
                 which maps to a specific hostname

    Example:
        @app.route('/<mailing_list>/<username>', methods=['POST'])
        @restrict_host('mailman_host')
        def subscribe(mailing_list, username):
            ....
    """
    this_host = socket.gethostname()
    config = component.getUtility(IConfig)
    expected_host = config.get(role)

    if this_host == expected_host:
        # Correct host: leave the view function untouched.
        return lambda view: view

    def reject_wrong_host(view: Callable):
        @functools.wraps(view)
        def wrapper(*args, **kwargs):
            return {'error': f'Wrong host! Use {expected_host} instead'}, 400
        return wrapper

    return reject_wrong_host
def authz_restrict_to_groups(f: Callable, allowed_groups: List[str]) -> Callable:
    """
    Restrict an endpoint to users who belong to one or more of the
    specified groups.

    The returned view is wrapped with ``requires_authentication``. That
    wrapper supplies the client's Kerberos principal as the first
    positional argument, ``user``.

    - Principals whose name part starts with 'ceod/' are always allowed.
    - Any other user is allowed only if listed in the member list
      (``gr_mem``) of at least one of ``allowed_groups``.
    - Otherwise a JSON error with HTTP 403 is returned.

    Groups that do not exist on this host are skipped (with a warning)
    instead of making the request fail with an unhandled error.

    NOTE(review): only the group's explicit member list is consulted.
    Users for whom the group is merely their primary group are presumably
    not listed in ``gr_mem`` -- confirm this is intended.

    :param f: the view function to protect
    :param allowed_groups: names of the Unix groups whose members may
                           call the endpoint
    :returns: the wrapped view function
    """
    # TODO: cache group members, but place a time limit on the cache validity

    @requires_authentication
    @functools.wraps(f)
    def wrapper(user: str, *args, **kwargs):
        """
        :param user: a Kerberos principal (e.g. 'user1@CSCLUB.UWATERLOO.CA')
        """
        logger.debug(f'received request from {user}')
        # Strip the realm. Unlike str.index, partition does not raise if
        # the principal has no '@'; the whole string is then the name.
        username = user.partition('@')[0]
        if username.startswith('ceod/'):
            # ceod services are always allowed to make internal calls
            return f(*args, **kwargs)
        for group in allowed_groups:
            try:
                members = grp.getgrnam(group).gr_mem
            except KeyError:
                # grp.getgrnam raises KeyError for an unknown group; treat
                # it as having no members instead of producing a 500.
                logger.warning(f"Group '{group}' does not exist; skipping it")
                continue
            if username in members:
                return f(*args, **kwargs)
        logger.debug(
            f"User '{username}' denied since they are not in one of {allowed_groups}"
        )
        return {
            'error': f'You must be in one of {allowed_groups}'
        }, 403
    return wrapper
def authz_restrict_to_staff(f: Callable) -> Callable:
    """
    Decorator limiting an endpoint to staff: members of the 'office',
    'staff' or 'adm' groups (ceod service principals are also admitted).
    """
    return authz_restrict_to_groups(f, ['office', 'staff', 'adm'])
def authz_restrict_to_syscom(f: Callable) -> Callable:
    """
    Decorator limiting an endpoint to members of the 'syscom' group
    (ceod service principals are also admitted).
    """
    return authz_restrict_to_groups(f, ['syscom'])
def create_streaming_response(txn: AbstractTransaction):
    """
    Returns a plain text response with one JSON object per line,
    indicating the progress of the transaction.

    - Each value yielded by ``txn.execute_iter()`` is sent as
      ``{"status": "in progress", "operation": <value>}``.
    - On success, a final ``{"status": "completed", "result": txn.result}``
      line is sent.
    - If any exception is raised along the way, the transaction is rolled
      back and a final ``{"status": "aborted", "error": str(err)}`` line is
      sent instead.

    :param txn: the transaction to execute
    :returns: a streaming response (mimetype 'text/plain') built from the
              app's ``response_class``
    """
    def generate():
        try:
            for operation in txn.execute_iter():
                yield json.dumps({
                    'status': 'in progress',
                    'operation': operation,
                }) + '\n'
            # NOTE(review): this is inside the try, so a failure to
            # serialize txn.result also triggers a rollback -- confirm
            # that is intended.
            yield json.dumps({
                'status': 'completed',
                'result': txn.result,
            }) + '\n'
        except Exception as err:
            logger.exception('Transaction failed; rolling back')
            try:
                txn.rollback()
            except Exception:
                # Keep going so the client still receives a final status
                # line instead of a truncated stream.
                logger.exception('Rollback failed')
            yield json.dumps({
                'status': 'aborted',
                'error': str(err),
            }) + '\n'

    return current_app.response_class(generate(), mimetype='text/plain')