diff --git a/ceo/__main__.py b/ceo/__main__.py new file mode 100644 index 0000000..aa3acee --- /dev/null +++ b/ceo/__main__.py @@ -0,0 +1,4 @@ +from .cli import cli + +if __name__ == '__main__': + cli(obj={}) diff --git a/ceo/cli.py b/ceo/cli.py new file mode 100644 index 0000000..e590b2f --- /dev/null +++ b/ceo/cli.py @@ -0,0 +1,43 @@ +import importlib.resources +import os +import socket + +import click +from zope import component + +from .krb_check import krb_check +from .members import members +from ceo_common.interfaces import IConfig, IHTTPClient +from ceo_common.model import Config, HTTPClient + + +@click.group() +@click.pass_context +def cli(ctx): + # ensure ctx exists and is a dict + ctx.ensure_object(dict) + + princ = krb_check() + user = princ[:princ.index('@')] + ctx.obj['user'] = user + + register_services() + + +cli.add_command(members) + + +def register_services(): + # Config + # This is a hack to determine if we're in the dev env or not + if socket.getfqdn().endswith('.csclub.internal'): + with importlib.resources.path('tests', 'ceo_dev.ini') as p: + config_file = p.__fspath__() + else: + config_file = os.environ.get('CEO_CONFIG', '/etc/csc/ceo.ini') + cfg = Config(config_file) + component.provideUtility(cfg, IConfig) + + # HTTPService + http_client = HTTPClient() + component.provideUtility(http_client, IHTTPClient) diff --git a/ceo/krb_check.py b/ceo/krb_check.py new file mode 100644 index 0000000..f9f9ce8 --- /dev/null +++ b/ceo/krb_check.py @@ -0,0 +1,28 @@ +import subprocess + +import gssapi + + +def krb_check(): + """ + Spawns a `kinit` process if no credentials are available or the + credentials have expired. + Returns the principal string 'user@REALM'. + """ + try: + creds = gssapi.Credentials(usage='initiate') + except gssapi.raw.misc.GSSError: + kinit() + creds = gssapi.Credentials(usage='initiate') + + try: + result = creds.inquire() + except gssapi.raw.exceptions.ExpiredCredentialsError: + kinit() + result = creds.inquire() + + return str(result.name) + + +def kinit(): + subprocess.run(['kinit'], check=True) diff --git a/ceo/members.py b/ceo/members.py new file mode 100644 index 0000000..b16caaa --- /dev/null +++ b/ceo/members.py @@ -0,0 +1,206 @@ +import socket +import sys +from typing import Dict + +import click +from zope import component + +from .utils import http_post, http_get, http_patch, http_delete, \ + handle_stream_response, handle_sync_response, print_colon_kv, \ + get_failed_operations +from ceo_common.interfaces import IConfig +from ceo_common.utils import get_current_term, add_term, get_max_term +from ceod.transactions.members import ( + AddMemberTransaction, + DeleteMemberTransaction, +) + + +@click.group() +def members(): + pass + + +@members.command(short_help='Add a new member or club rep') +@click.argument('username') +@click.option('--cn', help='Full name', prompt='Full name') +@click.option('--program', required=False, help='Academic program') +@click.option('--terms', 'num_terms', type=click.IntRange(1, 100), + help='Number of terms to add', prompt='Number of terms') +@click.option('--clubrep', is_flag=True, default=False, + help='Add non-member terms instead of member terms') +@click.option('--forwarding-address', required=False, + help=('Forwarding address to set in ~/.forward. ' + 'Default is UW address. ' + 'Set to the empty string to disable forwarding.')) +def add(username, cn, program, num_terms, clubrep, forwarding_address): + cfg = component.getUtility(IConfig) + uw_domain = cfg.get('uw_domain') + + current_term = get_current_term() + terms = [current_term] + for _ in range(1, num_terms): + term = add_term(terms[-1]) + terms.append(term) + if forwarding_address is None: + forwarding_address = username + '@' + uw_domain + + click.echo("The following user will be created:") + lines = [ + ('uid', username), + ('cn', cn), + ] + if program is not None: + lines.append(('program', program)) + if clubrep: + lines.append(('non-member terms', ','.join(terms))) + else: + lines.append(('member terms', ','.join(terms))) + if forwarding_address != '': + lines.append(('forwarding address', forwarding_address)) + print_colon_kv(lines) + + click.confirm('Do you want to continue?', abort=True) + + body = { + 'uid': username, + 'cn': cn, + } + if program is not None: + body['program'] = program + if clubrep: + body['terms'] = terms + else: + body['non_member_terms'] = terms + if forwarding_address != '': + body['forwarding_addresses'] = [forwarding_address] + resp = http_post('/api/members', json=body) + data = handle_stream_response(resp, AddMemberTransaction.operations) + result = data[-1]['result'] + print_user_lines(result) + + failed_operations = get_failed_operations(data) + if 'send_welcome_message' in failed_operations: + click.echo(click.style( + 'Warning: welcome message was not sent. You now need to manually ' + 'send the user their password.', fg='yellow')) + + +def print_user_lines(result: Dict): + """Pretty-print a user JSON response.""" + lines = [ + ('uid', result['uid']), + ('cn', result['cn']), + ('program', result.get('program', 'Unknown')), + ('UID number', result['uid_number']), + ('GID number', result['gid_number']), + ('login shell', result['login_shell']), + ('home directory', result['home_directory']), + ('is a club', result['is_club']), + ] + if 'forwarding_addresses' in result: + lines.append(('forwarding addresses', ','.join(result['forwarding_addresses']))) + if 'terms' in result: + lines.append(('terms', ','.join(result['terms']))) + if 'non_member_terms' in result: + lines.append(('non-member terms', ','.join(result['non_member_terms']))) + if 'password' in result: + lines.append(('password', result['password'])) + print_colon_kv(lines) + + +@members.command(short_help='Get info about a user') +@click.argument('username') +def get(username): + resp = http_get('/api/members/' + username) + result = handle_sync_response(resp) + print_user_lines(result) + + +@members.command(short_help="Replace a user's login shell or forwarding addresses") +@click.argument('username') +@click.option('--login-shell', required=False, help='Login shell') +@click.option('--forwarding-addresses', required=False, + help='Comma-separated list of forwarding addresses') +def modify(username, login_shell, forwarding_addresses): + if login_shell is None and forwarding_addresses is None: + click.echo('Nothing to do.') + sys.exit() + operations = [] + body = {} + if login_shell is not None: + body['login_shell'] = login_shell + operations.append('replace_login_shell') + click.echo('Login shell will be set to: ' + login_shell) + if forwarding_addresses is not None: + forwarding_addresses = forwarding_addresses.split(',') + body['forwarding_addresses'] = forwarding_addresses + operations.append('replace_forwarding_addresses') + prefix = '~/.forward will be set to: ' + click.echo(prefix + forwarding_addresses[0]) + for address in forwarding_addresses[1:]: + click.echo((' ' * len(prefix)) + address) + + click.confirm('Do you want to continue?', abort=True) + + resp = http_patch('/api/members/' + username, json=body) + handle_stream_response(resp, operations) + + +@members.command(short_help="Renew a member or club rep's membership") +@click.argument('username') +@click.option('--terms', 'num_terms', type=click.IntRange(1, 100), + help='Number of terms to add', prompt='Number of terms') +@click.option('--clubrep', is_flag=True, default=False, + help='Add non-member terms instead of member terms') +def renew(username, num_terms, clubrep): + resp = http_get('/api/members/' + username) + result = handle_sync_response(resp) + max_term = None + if clubrep and 'non_member_terms' in result: + max_term = get_max_term(result['non_member_terms']) + elif not clubrep and 'terms' in result: + max_term = get_max_term(result['terms']) + if max_term is not None: + max_term = get_max_term([max_term, get_current_term()]) + else: + max_term = get_current_term() + + terms = [add_term(max_term)] + for _ in range(1, num_terms): + term = add_term(terms[-1]) + terms.append(term) + + if clubrep: + body = {'non_member_terms': terms} + click.echo('The following non-member terms will be added: ' + ','.join(terms)) + else: + body = {'terms': terms} + click.echo('The following member terms will be added: ' + ','.join(terms)) + + click.confirm('Do you want to continue?', abort=True) + + resp = http_post(f'/api/members/{username}/renew', json=body) + handle_sync_response(resp) + click.echo('Done.') + + +@members.command(short_help="Reset a user's password") +@click.argument('username') +def pwreset(username): + click.confirm(f"Are you sure you want to reset {username}'s password?", abort=True) + resp = http_post(f'/api/members/{username}/pwreset') + result = handle_sync_response(resp) + click.echo('New password: ' + result['password']) + + +@members.command(short_help="Delete a user") +@click.argument('username') +def delete(username): + # a hack to determine if we're in the dev environment + if not socket.getfqdn().endswith('.csclub.internal'): + click.echo('This command may only be called during development.') + sys.exit(1) + click.confirm(f"Are you sure you want to delete {username}?", abort=True) + resp = http_delete(f'/api/members/{username}') + handle_stream_response(resp, DeleteMemberTransaction.operations) diff --git a/ceo/operation_strings.py b/ceo/operation_strings.py new file mode 100644 index 0000000..48e7d30 --- /dev/null +++ b/ceo/operation_strings.py @@ -0,0 +1,19 @@ +# These descriptions are printed to the console while a transaction +# is performed, in real time. +descriptions = { + 'add_user_to_ldap': 'Add user to LDAP', + 'add_group_to_ldap': 'Add group to LDAP', + 'add_user_to_kerberos': 'Add user to Kerberos', + 'create_home_dir': 'Create home directory', + 'set_forwarding_addresses': 'Set forwarding addresses', + 'send_welcome_message': 'Send welcome message', + 'subscribe_to_mailing_list': 'Subscribe to mailing list', + 'announce_new_user': 'Announce new user to mailing list', + 'replace_login_shell': 'Replace login shell', + 'replace_forwarding_addresses': 'Replace forwarding addresses', + 'remove_user_from_ldap': 'Remove user from LDAP', + 'remove_group_from_ldap': 'Remove group from LDAP', + 'remove_user_from_kerberos': 'Remove user from Kerberos', + 'delete_home_dir': 'Delete home directory', + 'unsubscribe_from_mailing_list': 'Unsubscribe from mailing list', +} diff --git a/ceo/utils.py b/ceo/utils.py new file mode 100644 index 0000000..b84386c --- /dev/null +++ b/ceo/utils.py @@ -0,0 +1,149 @@ +import json +import sys +from typing import List, Tuple, Dict + +import click +import requests +from zope import component + +from .operation_strings import descriptions as op_desc +from ceo_common.interfaces import IHTTPClient, IConfig + + +def http_request(method: str, path: str, **kwargs) -> requests.Response: + client = component.getUtility(IHTTPClient) + cfg = component.getUtility(IConfig) + if path.startswith('/api/db'): + host = cfg.get('ceod_db_host') + need_cred = False + else: + host = cfg.get('ceod_admin_host') + # The forwarded TGT is only needed for endpoints which write to LDAP + need_cred = method != 'GET' + return client.request( + host, path, method, principal=None, need_cred=need_cred, + stream=True, **kwargs) + + +def http_get(path: str, **kwargs) -> requests.Response: + return http_request('GET', path, **kwargs) + + +def http_post(path: str, **kwargs) -> requests.Response: + return http_request('POST', path, **kwargs) + + +def http_patch(path: str, **kwargs) -> requests.Response: + return http_request('PATCH', path, **kwargs) + + +def http_delete(path: str, **kwargs) -> requests.Response: + return http_request('DELETE', path, **kwargs) + + +def handle_stream_response(resp: requests.Response, operations: List[str]) -> List[Dict]: + """ + Print output to the console while operations are being streamed + from the server over HTTP. + Returns the parsed JSON data streamed from the server. + """ + if resp.status_code != 200: + click.echo('An error occurred:') + click.echo(resp.text) + sys.exit(1) + click.echo(op_desc[operations[0]] + '... ', nl=False) + idx = 0 + data = [] + for line in resp.iter_lines(decode_unicode=True, chunk_size=8): + d = json.loads(line) + data.append(d) + if d['status'] == 'aborted': + click.echo(click.style('ABORTED', fg='red')) + click.echo('The transaction was rolled back.') + click.echo('The error was: ' + d['error']) + click.echo('Please check the ceod logs.') + sys.exit(1) + elif d['status'] == 'completed': + click.echo('Transaction successfully completed.') + return data + + operation = d['operation'] + oper_failed = False + err_msg = None + prefix = 'failed_to_' + if operation.startswith(prefix): + operation = operation[len(prefix):] + oper_failed = True + # sometimes the operation looks like + # "failed_to_do_something: error message" + if ':' in operation: + operation, err_msg = operation.split(': ', 1) + + while idx < len(operations) and operations[idx] != operation: + click.echo('Skipped') + idx += 1 + if idx == len(operations): + break + click.echo(op_desc[operations[idx]] + '... ', nl=False) + if idx == len(operations): + click.echo('Unrecognized operation: ' + operation) + continue + if oper_failed: + click.echo(click.style('Failed', fg='red')) + if err_msg is not None: + click.echo(' Error message: ' + err_msg) + else: + click.echo(click.style('Done', fg='green')) + idx += 1 + if idx < len(operations): + click.echo(op_desc[operations[idx]] + '... ', nl=False) + + raise Exception('server response ended abruptly') + + +def handle_sync_response(resp: requests.Response): + """ + Exit the program if the request was not successful. + Returns the parsed JSON response. + """ + if resp.status_code // 100 != 2: + click.echo('An error occurred:') + click.echo(resp.text) + sys.exit(1) + return resp.json() + + +def print_colon_kv(pairs: List[Tuple[str, str]]): + """ + Pretty-print a list of key-value pairs such that the key and value + columns align. + Example: + key1: value1 + key1000: value2 + """ + maxlen = max(len(key) for key, val in pairs) + for key, val in pairs: + click.echo(key + ': ', nl=False) + extra_space = ' ' * (maxlen - len(key)) + click.echo(extra_space, nl=False) + click.echo(val) + + +def get_failed_operations(data: List[Dict]) -> List[str]: + """ + Get a list of the failed operations using the JSON objects + streamed from the server. + """ + prefix = 'failed_to_' + failed = [] + for d in data: + if 'operation' not in d: + continue + operation = d['operation'] + if not operation.startswith(prefix): + continue + operation = operation[len(prefix):] + if ':' in operation: + operation = operation[:operation.index(':')] + failed.append(operation) + return failed diff --git a/ceo_common/interfaces/IHTTPClient.py b/ceo_common/interfaces/IHTTPClient.py index 0bec342..61c9599 100644 --- a/ceo_common/interfaces/IHTTPClient.py +++ b/ceo_common/interfaces/IHTTPClient.py @@ -1,14 +1,27 @@ +from typing import Union + from zope.interface import Interface class IHTTPClient(Interface): """A helper class for HTTP requests to ceod.""" - def get(host: str, api_path: str, **kwargs): + def request(host: str, api_path: str, method: str, principal: str, + need_cred: bool, **kwargs): + """Make an HTTP request.""" + + def get(host: str, api_path: str, principal: Union[str, None] = None, + need_cred: bool = True, **kwargs): """Make a GET request.""" - def post(host: str, api_path: str, **kwargs): + def post(host: str, api_path: str, principal: Union[str, None] = None, + need_cred: bool = True, **kwargs): """Make a POST request.""" - def delete(host: str, api_path: str, **kwargs): + def patch(host: str, api_path: str, principal: Union[str, None] = None, + need_cred: bool = True, **kwargs): + """Make a PATCH request.""" + + def delete(host: str, api_path: str, principal: Union[str, None] = None, + need_cred: bool = True, **kwargs): """Make a DELETE request.""" diff --git a/ceo_common/model/HTTPClient.py b/ceo_common/model/HTTPClient.py index 81d4bc5..4da38fc 100644 --- a/ceo_common/model/HTTPClient.py +++ b/ceo_common/model/HTTPClient.py @@ -1,10 +1,13 @@ -from flask import g +from base64 import b64encode +from typing import Union + import gssapi import requests from requests_gssapi import HTTPSPNEGOAuth from zope import component from zope.interface import implementer +from ceo_common.krb5.utils import get_fwd_tgt from ceo_common.interfaces import IConfig, IHTTPClient @@ -20,32 +23,48 @@ class HTTPClient: self.ceod_port = cfg.get('ceod_port') self.base_domain = cfg.get('base_domain') - self.krb_realm = cfg.get('ldap_sasl_realm') - - def request(self, host: str, api_path: str, method='GET', **kwargs): - principal = g.sasl_user - gssapi_name = gssapi.Name(principal) - creds = gssapi.Credentials(name=gssapi_name, usage='initiate') + def request(self, host: str, api_path: str, method: str, principal: str, + need_cred: bool, **kwargs): # always use the FQDN if '.' not in host: host = host + '.' + self.base_domain + + # SPNEGO + if principal is not None: + gssapi_name = gssapi.Name(principal) + creds = gssapi.Credentials(name=gssapi_name, usage='initiate') + else: + creds = None auth = HTTPSPNEGOAuth( opportunistic_auth=True, target_name=gssapi.Name('ceod/' + host), creds=creds, ) + + # Forwarded TGT (X-KRB5-CRED) + headers = {} + if need_cred: + b = get_fwd_tgt('ceod/' + host) + headers['X-KRB5-CRED'] = b64encode(b).decode() + return requests.request( method, f'{self.scheme}://{host}:{self.ceod_port}{api_path}', - auth=auth, - **kwargs, + auth=auth, headers=headers, **kwargs, ) - def get(self, host: str, api_path: str, **kwargs): - return self.request(host, api_path, 'GET', **kwargs) + def get(self, host: str, api_path: str, principal: Union[str, None] = None, + need_cred: bool = False, **kwargs): + return self.request(host, api_path, 'GET', principal, need_cred, **kwargs) - def post(self, host: str, api_path: str, **kwargs): - return self.request(host, api_path, 'POST', **kwargs) + def post(self, host: str, api_path: str, principal: Union[str, None] = None, + need_cred: bool = False, **kwargs): + return self.request(host, api_path, 'POST', principal, need_cred, **kwargs) - def delete(self, host: str, api_path: str, **kwargs): - return self.request(host, api_path, 'DELETE', **kwargs) + def patch(self, host: str, api_path: str, principal: Union[str, None] = None, + need_cred: bool = False, **kwargs): + return self.request(host, api_path, 'PATCH', principal, need_cred, **kwargs) + + def delete(self, host: str, api_path: str, principal: Union[str, None] = None, + need_cred: bool = False, **kwargs): + return self.request(host, api_path, 'DELETE', principal, need_cred, **kwargs) diff --git a/ceo_common/model/RemoteMailmanService.py b/ceo_common/model/RemoteMailmanService.py index 03b07ee..d9e70e0 100644 --- a/ceo_common/model/RemoteMailmanService.py +++ b/ceo_common/model/RemoteMailmanService.py @@ -1,3 +1,4 @@ +from flask import g from zope import component from zope.interface import implementer @@ -14,7 +15,9 @@ class RemoteMailmanService: self.http_client = component.getUtility(IHTTPClient) def subscribe(self, address: str, mailing_list: str): - resp = self.http_client.post(self.mailman_host, f'/api/mailman/{mailing_list}/{address}') + resp = self.http_client.post( + self.mailman_host, f'/api/mailman/{mailing_list}/{address}', + principal=g.sasl_user) if not resp.ok: if resp.status_code == 409: raise UserAlreadySubscribedError() @@ -23,7 +26,9 @@ class RemoteMailmanService: raise Exception(resp.json()) def unsubscribe(self, address: str, mailing_list: str): - resp = self.http_client.delete(self.mailman_host, f'/api/mailman/{mailing_list}/{address}') + resp = self.http_client.delete( + self.mailman_host, f'/api/mailman/{mailing_list}/{address}', + principal=g.sasl_user) if not resp.ok: if resp.status_code == 404: raise UserNotSubscribedError() diff --git a/ceo_common/utils.py b/ceo_common/utils.py new file mode 100644 index 0000000..a7d858b --- /dev/null +++ b/ceo_common/utils.py @@ -0,0 +1,40 @@ +import datetime +from typing import List + + +def get_current_term() -> str: + """ + Get the current term as formatted in the CSC LDAP (e.g. 's2021'). + """ + dt = datetime.datetime.now() + c = 'w' + if 5 <= dt.month <= 8: + c = 's' + elif 9 <= dt.month: + c = 'f' + return c + str(dt.year) + + +def add_term(term: str) -> str: + """ + Add one term to the given term and return the string. + Example: add_term('s2021') -> 'f2021' + """ + c = term[0] + s_year = term[1:] + if c == 'w': + return 's' + s_year + elif c == 's': + return 'f' + s_year + year = int(s_year) + return 'w' + str(year + 1) + + +def get_max_term(terms: List[str]) -> str: + """Get the maximum (latest) term.""" + max_year = max(term[1:] for term in terms) + if 'f' + max_year in terms: + return 'f' + max_year + elif 's' + max_year in terms: + return 's' + max_year + return 'w' + max_year diff --git a/ceod/api/utils.py b/ceod/api/utils.py index 42f900e..3db9056 100644 --- a/ceod/api/utils.py +++ b/ceod/api/utils.py @@ -84,9 +84,10 @@ def create_streaming_response(txn: AbstractTransaction): indicating the progress of the transaction. """ def generate(): + generator = txn.execute_iter() try: - for operation in txn.execute_iter(): - operation = yield json.dumps({ + for operation in generator: + yield json.dumps({ 'status': 'in progress', 'operation': operation, }) + '\n' @@ -94,6 +95,15 @@ def create_streaming_response(txn: AbstractTransaction): 'status': 'completed', 'result': txn.result, }) + '\n' + except GeneratorExit: + # Keep on going. Even if the client closes the connection, we don't + # want to give up half way through. + try: + for operation in generator: + pass + except Exception: + logger.warning('Transaction failed:\n' + traceback.format_exc()) + txn.rollback() except Exception as err: logger.warning('Transaction failed:\n' + traceback.format_exc()) txn.rollback() diff --git a/ceod/transactions/members/AddMemberTransaction.py b/ceod/transactions/members/AddMemberTransaction.py index 2479864..38cf539 100644 --- a/ceod/transactions/members/AddMemberTransaction.py +++ b/ceod/transactions/members/AddMemberTransaction.py @@ -22,8 +22,8 @@ class AddMemberTransaction(AbstractTransaction): 'add_user_to_kerberos', 'create_home_dir', 'set_forwarding_addresses', - 'subscribe_to_mailing_list', 'send_welcome_message', + 'subscribe_to_mailing_list', 'announce_new_user', ] @@ -78,7 +78,7 @@ class AddMemberTransaction(AbstractTransaction): if self.forwarding_addresses: user.set_forwarding_addresses(self.forwarding_addresses) - yield 'set_forwarding_addresses' + yield 'set_forwarding_addresses' # The following operations can't/shouldn't be rolled back because the # user has already seen the email diff --git a/requirements.txt b/requirements.txt index 7023ac8..d416050 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +click==8.0.1 Flask==2.0.1 gssapi==1.6.14 Jinja2==3.0.1 diff --git a/tests/ceo_dev.ini b/tests/ceo_dev.ini new file mode 100644 index 0000000..e74895e --- /dev/null +++ b/tests/ceo_dev.ini @@ -0,0 +1,9 @@ +[DEFAULT] +base_domain = csclub.internal +uw_domain = uwaterloo.internal + +[ceod] +# this is the host with the ceod/admin Kerberos key +admin_host = phosphoric-acid +use_https = false +port = 9987 diff --git a/tests/ceod/api/test_members.py b/tests/ceod/api/test_members.py index d82b2af..79e4c6f 100644 --- a/tests/ceod/api/test_members.py +++ b/tests/ceod/api/test_members.py @@ -18,6 +18,7 @@ def create_user_resp(client, mocks_for_create_user): 'cn': 'Test One', 'program': 'Math', 'terms': ['s2021'], + 'forwarding_addresses': ['test_1@uwaterloo.internal'], }) assert status == 200 assert data[-1]['status'] == 'completed' @@ -56,7 +57,7 @@ def test_api_create_user(cfg, create_user_resp): "is_club": False, "program": "Math", "terms": ["s2021"], - "forwarding_addresses": [], + "forwarding_addresses": ['test_1@uwaterloo.internal'], "password": "krb5" }}, ] @@ -209,5 +210,5 @@ def test_authz_check(client, create_user_result): # If we're syscom but we don't pass credentials, the request should fail _, data = client.post('/api/members', json={ 'uid': 'test_1', 'cn': 'Test One', 'terms': ['s2021'], - }, principal='ctdalek', no_creds=True) + }, principal='ctdalek', need_cred=False) assert data[-1]['status'] == 'aborted' diff --git a/tests/conftest.py b/tests/conftest.py index 5b252a1..3388fe6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -52,8 +52,6 @@ def delete_test_princs(krb_srv): 'kadmin', '-k', '-p', krb_srv.admin_principal, 'listprincs', 'test_*', ], text=True, capture_output=True, check=True) princs = [line.strip() for line in proc.stdout.splitlines()] - # remove the password prompt - princs = princs[1:] for princ in princs: krb_srv.delprinc(princ) diff --git a/tests/conftest_ceod_api.py b/tests/conftest_ceod_api.py index 4dbeb95..19f16ba 100644 --- a/tests/conftest_ceod_api.py +++ b/tests/conftest_ceod_api.py @@ -72,7 +72,7 @@ class CeodTestClient: text=True, input='krb5', check=True, stdout=subprocess.DEVNULL, env={'KRB5CCNAME': self.principal_ccaches[principal]}) - def get_headers(self, principal: str, no_creds: bool): + def get_headers(self, principal: str, need_cred: bool): if principal not in self.principal_ccaches: _, filename = tempfile.mkstemp(dir=self.cache_dir) self.principal_ccaches[principal] = filename @@ -82,7 +82,7 @@ class CeodTestClient: # the header using req.prepare(). req = Request('GET', self.base_url, auth=self.get_auth(principal)) headers = list(req.prepare().headers.items()) - if not no_creds: + if need_cred: # Get the X-KRB5-CRED header (forwarded TGT). cred = b64encode(get_fwd_tgt( 'ceod/' + socket.getfqdn(), self.principal_ccaches[principal] @@ -90,14 +90,14 @@ class CeodTestClient: headers.append(('X-KRB5-CRED', cred)) return headers - def request(self, method: str, path: str, principal: str, no_creds: bool, **kwargs): + def request(self, method: str, path: str, principal: str, need_cred: bool, **kwargs): # Make sure that we're not already in a request context, otherwise # g will get overridden with pytest.raises(RuntimeError): '' in g if principal is None: principal = self.syscom_principal - headers = self.get_headers(principal, no_creds) + headers = self.get_headers(principal, need_cred) resp = self.client.open(path, method=method, headers=headers, **kwargs) status = int(resp.status.split(' ', 1)[0]) if resp.headers['content-type'] == 'application/json': @@ -106,14 +106,14 @@ class CeodTestClient: data = [json.loads(line) for line in resp.data.splitlines()] return status, data - def get(self, path, principal=None, no_creds=False, **kwargs): - return self.request('GET', path, principal, no_creds, **kwargs) + def get(self, path, principal=None, need_cred=True, **kwargs): + return self.request('GET', path, principal, need_cred, **kwargs) - def post(self, path, principal=None, no_creds=False, **kwargs): - return self.request('POST', path, principal, no_creds, **kwargs) + def post(self, path, principal=None, need_cred=True, **kwargs): + return self.request('POST', path, principal, need_cred, **kwargs) - def patch(self, path, principal=None, no_creds=False, **kwargs): - return self.request('PATCH', path, principal, no_creds, **kwargs) + def patch(self, path, principal=None, need_cred=True, **kwargs): + return self.request('PATCH', path, principal, need_cred, **kwargs) - def delete(self, path, principal=None, no_creds=False, **kwargs): - return self.request('DELETE', path, principal, no_creds, **kwargs) + def delete(self, path, principal=None, need_cred=True, **kwargs): + return self.request('DELETE', path, principal, need_cred, **kwargs)