forked from public/pyceo
123 lines
3.7 KiB
Python
123 lines
3.7 KiB
Python
|
import functools
|
||
|
import grp
|
||
|
import json
|
||
|
import socket
|
||
|
from typing import Callable, List
|
||
|
|
||
|
from flask import current_app
|
||
|
from flask_kerberos import requires_authentication
|
||
|
from zope import component
|
||
|
|
||
|
from ceo_common.logger_factory import logger_factory
|
||
|
from ceo_common.interfaces import IConfig
|
||
|
from ceod.transactions import AbstractTransaction
|
||
|
|
||
|
# Module-level logger for the request helpers below, created through the
# project's logger_factory and named after this module.
logger = logger_factory(__name__)
|
||
|
|
||
|
|
||
|
def restrict_host(role: str) -> Callable[[Callable], Callable]:
    """
    Build a decorator that pins an endpoint to one particular host.

    The hostname configured under ``role`` is looked up once, when
    ``restrict_host`` itself is called (i.e. at decoration time), and is
    compared with ``socket.gethostname()``. On the matching host the view
    is returned unchanged. On any other host, every request to the view
    gets a 400 response telling the client which host to use instead.

    :param role: a key in the app's config (e.g. 'ceod_admin_host')
                 which maps to a specific hostname

    Example:
        @app.route('/<mailing_list>/<username>', methods=['POST'])
        @restrict_host('mailman_host')
        def subscribe(mailing_list, username):
            ....
    """
    local_host = socket.gethostname()
    config = component.getUtility(IConfig)
    desired_hostname = config.get(role)

    if local_host != desired_hostname:
        def reject_wrong_host(view: Callable) -> Callable:
            # The real view is never invoked on this host; every call is
            # answered with an error pointing at the configured host.
            @functools.wraps(view)
            def wrong_host_response(*args, **kwargs):
                return {'error': f'Wrong host! Use {desired_hostname} instead'}, 400
            return wrong_host_response
        return reject_wrong_host

    # Running on the designated host: leave the view untouched.
    return lambda view: view
|
||
|
|
||
|
|
||
|
def authz_restrict_to_groups(f: Callable, allowed_groups: List[str]) -> Callable:
    """
    Restrict an endpoint to users who belong to one or more of the
    specified groups.

    The returned view is wrapped with flask_kerberos'
    ``requires_authentication``, which passes the authenticated Kerberos
    principal to ``wrapper`` as its first positional argument. That argument
    is consumed here and is NOT forwarded to ``f``.

    Principals whose name part starts with ``ceod/`` are always allowed.
    Everyone else must appear in the member list (``gr_mem``) of at least
    one of ``allowed_groups``. A group that does not exist on this host is
    skipped with a warning instead of failing the whole request.

    :param f: the view function to protect
    :param allowed_groups: names of Unix groups whose members may call ``f``
    :returns: the wrapped view. It returns ``f``'s result for authorized
              callers, or ``({'error': ...}, 403)`` otherwise.
    """

    # TODO: cache group members, but place a time limit on the cache validity

    @requires_authentication
    @functools.wraps(f)
    def wrapper(user: str, *args, **kwargs):
        """
        :param user: a Kerberos principal (e.g. 'user1@CSCLUB.UWATERLOO.CA')
        """
        logger.debug(f'received request from {user}')
        # Strip the realm (everything from the first '@'). Unlike
        # str.index, split does not raise when there is no '@'; the whole
        # principal is then used as the username.
        username = user.split('@', 1)[0]
        if username.startswith('ceod/'):
            # ceod services are always allowed to make internal calls
            return f(*args, **kwargs)
        for group in allowed_groups:
            try:
                members = grp.getgrnam(group).gr_mem
            except KeyError:
                # getgrnam raises KeyError for an unknown group. A missing
                # group has no members, so check the remaining groups rather
                # than turning the request into a server error.
                logger.warning(f"Group '{group}' does not exist; skipping it")
                continue
            # NOTE(review): gr_mem lists only explicit (supplementary)
            # members; users whose primary GID is this group are not
            # included - confirm this is intended.
            if username in members:
                return f(*args, **kwargs)
        logger.debug(
            f"User '{username}' denied since they are not in one of {allowed_groups}"
        )
        return {
            'error': f'You must be in one of {allowed_groups}'
        }, 403

    return wrapper
|
||
|
|
||
|
|
||
|
def authz_restrict_to_staff(f: Callable) -> Callable:
    """
    Decorator limiting an endpoint to staff: callers must belong to the
    'office', 'staff' or 'adm' group (checked by authz_restrict_to_groups).
    """
    return authz_restrict_to_groups(f, ['office', 'staff', 'adm'])
|
||
|
|
||
|
|
||
|
def authz_restrict_to_syscom(f: Callable) -> Callable:
    """
    Decorator limiting an endpoint to syscom: callers must belong to the
    'syscom' group (checked by authz_restrict_to_groups).
    """
    return authz_restrict_to_groups(f, ['syscom'])
|
||
|
|
||
|
|
||
|
def create_streaming_response(txn: AbstractTransaction):
    """
    Returns a plain text response with one JSON object per line,
    indicating the progress of the transaction.

    The body is produced by a generator, so ``txn`` is executed as the
    response is iterated. Lines emitted:

    - ``{"status": "in progress", "operation": ...}`` for each value
      yielded by ``txn.execute_iter()``
    - ``{"status": "completed", "result": txn.result}`` once all
      operations have finished
    - ``{"status": "aborted", "error": str(err)}`` if any of the above
      raised an Exception; ``txn.rollback()`` is attempted first

    :param txn: the transaction to execute
    :returns: a ``current_app.response_class`` instance with mimetype
              'text/plain'
    """
    def generate():
        # NOTE(review): if the server closes this generator early (e.g. on
        # client disconnect), GeneratorExit is raised at a yield. It is not
        # an Exception subclass, so rollback() is NOT called in that case -
        # confirm a partially applied transaction is acceptable.
        try:
            for operation in txn.execute_iter():
                yield json.dumps({
                    'status': 'in progress',
                    'operation': operation,
                }) + '\n'
            yield json.dumps({
                'status': 'completed',
                'result': txn.result,
            }) + '\n'
        except Exception as err:
            # The client only receives str(err); keep the full traceback
            # in the server log.
            logger.exception('transaction failed; rolling back')
            try:
                txn.rollback()
            except Exception:
                # A failing rollback must not stop the client from
                # receiving the 'aborted' line below.
                logger.exception('rollback failed')
            yield json.dumps({
                'status': 'aborted',
                'error': str(err),
            }) + '\n'

    return current_app.response_class(generate(), mimetype='text/plain')
|