From 08a3faaefc881c7491e0cff3997f175a4dc4813e Mon Sep 17 00:00:00 2001 From: Max Erenberg Date: Mon, 23 Aug 2021 23:01:24 +0000 Subject: [PATCH] add unit tests for members CLI --- ceo/cli/__init__.py | 1 + ceo/{cli.py => cli/entrypoint.py} | 5 +- ceo/{ => cli}/members.py | 66 +++++---- ceo/cli/utils.py | 107 ++++++++++++++ ceo/krb_check.py | 20 ++- ceo/utils.py | 96 +------------ ceo_common/model/Term.py | 67 +++++++++ ceo_common/model/__init__.py | 1 + ceo_common/utils.py | 40 ------ tests/ceo/__init__.py | 0 tests/ceo/cli/__init__.py | 0 tests/ceo/cli/test_members.py | 136 ++++++++++++++++++ tests/ceo_common/model/test_remote_mailman.py | 48 ++----- tests/ceod_test_local.ini | 2 + tests/conftest.py | 104 ++++++++------ tests/conftest_ceo.py | 18 +++ tests/conftest_ceod_api.py | 63 ++------ tests/utils.py | 34 +++++ 18 files changed, 502 insertions(+), 306 deletions(-) create mode 100644 ceo/cli/__init__.py rename ceo/{cli.py => cli/entrypoint.py} (90%) rename ceo/{ => cli}/members.py (80%) create mode 100644 ceo/cli/utils.py create mode 100644 ceo_common/model/Term.py delete mode 100644 ceo_common/utils.py create mode 100644 tests/ceo/__init__.py create mode 100644 tests/ceo/cli/__init__.py create mode 100644 tests/ceo/cli/test_members.py create mode 100644 tests/conftest_ceo.py create mode 100644 tests/utils.py diff --git a/ceo/cli/__init__.py b/ceo/cli/__init__.py new file mode 100644 index 0000000..d3a9b7d --- /dev/null +++ b/ceo/cli/__init__.py @@ -0,0 +1 @@ +from .entrypoint import cli diff --git a/ceo/cli.py b/ceo/cli/entrypoint.py similarity index 90% rename from ceo/cli.py rename to ceo/cli/entrypoint.py index e590b2f..9389c56 100644 --- a/ceo/cli.py +++ b/ceo/cli/entrypoint.py @@ -5,7 +5,7 @@ import socket import click from zope import component -from .krb_check import krb_check +from ..krb_check import krb_check from .members import members from ceo_common.interfaces import IConfig, IHTTPClient from ceo_common.model import Config, HTTPClient @@ -21,7 +21,8 @@ def cli(ctx): user = princ[:princ.index('@')] ctx.obj['user'] = user - register_services() + if os.environ.get('PYTEST') != '1': + register_services() cli.add_command(members) diff --git a/ceo/members.py b/ceo/cli/members.py similarity index 80% rename from ceo/members.py rename to ceo/cli/members.py index b16caaa..67cd049 100644 --- a/ceo/members.py +++ b/ceo/cli/members.py @@ -5,11 +5,10 @@ from typing import Dict import click from zope import component -from .utils import http_post, http_get, http_patch, http_delete, \ - handle_stream_response, handle_sync_response, print_colon_kv, \ - get_failed_operations +from ..utils import http_post, http_get, http_patch, http_delete, get_failed_operations +from .utils import handle_stream_response, handle_sync_response, print_colon_kv from ceo_common.interfaces import IConfig -from ceo_common.utils import get_current_term, add_term, get_max_term +from ceo_common.model import Term from ceod.transactions.members import ( AddMemberTransaction, DeleteMemberTransaction, @@ -37,11 +36,10 @@ def add(username, cn, program, num_terms, clubrep, forwarding_address): cfg = component.getUtility(IConfig) uw_domain = cfg.get('uw_domain') - current_term = get_current_term() - terms = [current_term] - for _ in range(1, num_terms): - term = add_term(terms[-1]) - terms.append(term) + current_term = Term.current() + terms = [current_term + i for i in range(num_terms)] + terms = list(map(str, terms)) + if forwarding_address is None: forwarding_address = username + '@' + uw_domain @@ -69,13 +67,18 @@ def add(username, cn, program, num_terms, clubrep, forwarding_address): if program is not None: body['program'] = program if clubrep: - body['terms'] = terms - else: body['non_member_terms'] = terms + else: + body['terms'] = terms if forwarding_address != '': body['forwarding_addresses'] = [forwarding_address] + operations = AddMemberTransaction.operations + if forwarding_address == '': + # don't bother displaying this because it won't be run + operations.remove('set_forwarding_addresses') + resp = http_post('/api/members', json=body) - data = handle_stream_response(resp, AddMemberTransaction.operations) + data = handle_stream_response(resp, operations) result = data[-1]['result'] print_user_lines(result) @@ -121,7 +124,10 @@ def get(username): @click.argument('username') @click.option('--login-shell', required=False, help='Login shell') @click.option('--forwarding-addresses', required=False, - help='Comma-separated list of forwarding addresses') + help=( + 'Comma-separated list of forwarding addresses. ' + 'Set to the empty string to disable forwarding.' + )) def modify(username, login_shell, forwarding_addresses): if login_shell is None and forwarding_addresses is None: click.echo('Nothing to do.') @@ -133,13 +139,19 @@ def modify(username, login_shell, forwarding_addresses): operations.append('replace_login_shell') click.echo('Login shell will be set to: ' + login_shell) if forwarding_addresses is not None: - forwarding_addresses = forwarding_addresses.split(',') + if forwarding_addresses == '': + forwarding_addresses = [] + else: + forwarding_addresses = forwarding_addresses.split(',') body['forwarding_addresses'] = forwarding_addresses operations.append('replace_forwarding_addresses') prefix = '~/.forward will be set to: ' - click.echo(prefix + forwarding_addresses[0]) - for address in forwarding_addresses[1:]: - click.echo((' ' * len(prefix)) + address) + if len(forwarding_addresses) > 0: + click.echo(prefix + forwarding_addresses[0]) + for address in forwarding_addresses[1:]: + click.echo((' ' * len(prefix)) + address) + else: + click.echo(prefix) click.confirm('Do you want to continue?', abort=True) @@ -157,19 +169,19 @@ def renew(username, num_terms, clubrep): resp = http_get('/api/members/' + username) result = handle_sync_response(resp) max_term = None + current_term = Term.current() if clubrep and 'non_member_terms' in result: - max_term = get_max_term(result['non_member_terms']) + max_term = max(Term(s) for s in result['non_member_terms']) elif not clubrep and 'terms' in result: - max_term = get_max_term(result['terms']) - if max_term is not None: - max_term = get_max_term([max_term, get_current_term()]) - else: - max_term = get_current_term() + max_term = max(Term(s) for s in result['terms']) - terms = [add_term(max_term)] - for _ in range(1, num_terms): - term = add_term(terms[-1]) - terms.append(term) + if max_term is not None and max_term >= current_term: + next_term = max_term + 1 + else: + next_term = Term.current() + + terms = [next_term + i for i in range(num_terms)] + terms = list(map(str, terms)) if clubrep: body = {'non_member_terms': terms} diff --git a/ceo/cli/utils.py b/ceo/cli/utils.py new file mode 100644 index 0000000..cc25e84 --- /dev/null +++ b/ceo/cli/utils.py @@ -0,0 +1,107 @@ +import json +import sys +from typing import List, Tuple, Dict + +import click +import requests + +from ..operation_strings import descriptions as op_desc + + +class Abort(click.ClickException): + """Abort silently.""" + + def __init__(self, exit_code=1): + super().__init__('') + self.exit_code = exit_code + + def show(self): + pass + + +def print_colon_kv(pairs: List[Tuple[str, str]]): + """ + Pretty-print a list of key-value pairs such that the key and value + columns align. + Example: + key1: value1 + key1000: value2 + """ + maxlen = max(len(key) for key, val in pairs) + for key, val in pairs: + click.echo(key + ': ', nl=False) + extra_space = ' ' * (maxlen - len(key)) + click.echo(extra_space, nl=False) + click.echo(val) + + +def handle_stream_response(resp: requests.Response, operations: List[str]) -> List[Dict]: + """ + Print output to the console while operations are being streamed + from the server over HTTP. + Returns the parsed JSON data streamed from the server. + """ + if resp.status_code != 200: + click.echo('An error occurred:') + click.echo(resp.text.rstrip()) + raise Abort() + click.echo(op_desc[operations[0]] + '... ', nl=False) + idx = 0 + data = [] + for line in resp.iter_lines(decode_unicode=True, chunk_size=8): + d = json.loads(line) + data.append(d) + if d['status'] == 'aborted': + click.echo(click.style('ABORTED', fg='red')) + click.echo('The transaction was rolled back.') + click.echo('The error was: ' + d['error']) + click.echo('Please check the ceod logs.') + sys.exit(1) + elif d['status'] == 'completed': + click.echo('Transaction successfully completed.') + return data + + operation = d['operation'] + oper_failed = False + err_msg = None + prefix = 'failed_to_' + if operation.startswith(prefix): + operation = operation[len(prefix):] + oper_failed = True + # sometimes the operation looks like + # "failed_to_do_something: error message" + if ':' in operation: + operation, err_msg = operation.split(': ', 1) + + while idx < len(operations) and operations[idx] != operation: + click.echo('Skipped') + idx += 1 + if idx == len(operations): + break + click.echo(op_desc[operations[idx]] + '... ', nl=False) + if idx == len(operations): + click.echo('Unrecognized operation: ' + operation) + continue + if oper_failed: + click.echo(click.style('Failed', fg='red')) + if err_msg is not None: + click.echo(' Error message: ' + err_msg) + else: + click.echo(click.style('Done', fg='green')) + idx += 1 + if idx < len(operations): + click.echo(op_desc[operations[idx]] + '... ', nl=False) + + raise Exception('server response ended abruptly') + + +def handle_sync_response(resp: requests.Response): + """ + Exit the program if the request was not successful. + Returns the parsed JSON response. + """ + if resp.status_code != 200: + click.echo('An error occurred:') + click.echo(resp.text.rstrip()) + raise Abort() + return resp.json() diff --git a/ceo/krb_check.py b/ceo/krb_check.py index f9f9ce8..7824da7 100644 --- a/ceo/krb_check.py +++ b/ceo/krb_check.py @@ -9,19 +9,15 @@ def krb_check(): credentials have expired. Returns the principal string 'user@REALM'. """ - try: - creds = gssapi.Credentials(usage='initiate') - except gssapi.raw.misc.GSSError: - kinit() - creds = gssapi.Credentials(usage='initiate') + for _ in range(2): + try: + creds = gssapi.Credentials(usage='initiate') + result = creds.inquire() + return str(result.name) + except (gssapi.raw.misc.GSSError, gssapi.raw.exceptions.ExpiredCredentialsError): + kinit() - try: - result = creds.inquire() - except gssapi.raw.exceptions.ExpiredCredentialsError: - kinit() - result = creds.inquire() - - return str(result.name) + raise Exception('could not acquire GSSAPI credentials') def kinit(): diff --git a/ceo/utils.py b/ceo/utils.py index b84386c..fc84265 100644 --- a/ceo/utils.py +++ b/ceo/utils.py @@ -1,12 +1,8 @@ -import json -import sys -from typing import List, Tuple, Dict +from typing import List, Dict -import click import requests from zope import component -from .operation_strings import descriptions as op_desc from ceo_common.interfaces import IHTTPClient, IConfig @@ -41,94 +37,6 @@ def http_delete(path: str, **kwargs) -> requests.Response: return http_request('DELETE', path, **kwargs) -def handle_stream_response(resp: requests.Response, operations: List[str]) -> List[Dict]: - """ - Print output to the console while operations are being streamed - from the server over HTTP. - Returns the parsed JSON data streamed from the server. - """ - if resp.status_code != 200: - click.echo('An error occurred:') - click.echo(resp.text) - sys.exit(1) - click.echo(op_desc[operations[0]] + '... ', nl=False) - idx = 0 - data = [] - for line in resp.iter_lines(decode_unicode=True, chunk_size=8): - d = json.loads(line) - data.append(d) - if d['status'] == 'aborted': - click.echo(click.style('ABORTED', fg='red')) - click.echo('The transaction was rolled back.') - click.echo('The error was: ' + d['error']) - click.echo('Please check the ceod logs.') - sys.exit(1) - elif d['status'] == 'completed': - click.echo('Transaction successfully completed.') - return data - - operation = d['operation'] - oper_failed = False - err_msg = None - prefix = 'failed_to_' - if operation.startswith(prefix): - operation = operation[len(prefix):] - oper_failed = True - # sometimes the operation looks like - # "failed_to_do_something: error message" - if ':' in operation: - operation, err_msg = operation.split(': ', 1) - - while idx < len(operations) and operations[idx] != operation: - click.echo('Skipped') - idx += 1 - if idx == len(operations): - break - click.echo(op_desc[operations[idx]] + '... ', nl=False) - if idx == len(operations): - click.echo('Unrecognized operation: ' + operation) - continue - if oper_failed: - click.echo(click.style('Failed', fg='red')) - if err_msg is not None: - click.echo(' Error message: ' + err_msg) - else: - click.echo(click.style('Done', fg='green')) - idx += 1 - if idx < len(operations): - click.echo(op_desc[operations[idx]] + '... ', nl=False) - - raise Exception('server response ended abruptly') - - -def handle_sync_response(resp: requests.Response): - """ - Exit the program if the request was not successful. - Returns the parsed JSON response. - """ - if resp.status_code // 100 != 2: - click.echo('An error occurred:') - click.echo(resp.text) - sys.exit(1) - return resp.json() - - -def print_colon_kv(pairs: List[Tuple[str, str]]): - """ - Pretty-print a list of key-value pairs such that the key and value - columns align. - Example: - key1: value1 - key1000: value2 - """ - maxlen = max(len(key) for key, val in pairs) - for key, val in pairs: - click.echo(key + ': ', nl=False) - extra_space = ' ' * (maxlen - len(key)) - click.echo(extra_space, nl=False) - click.echo(val) - - def get_failed_operations(data: List[Dict]) -> List[str]: """ Get a list of the failed operations using the JSON objects @@ -144,6 +52,8 @@ def get_failed_operations(data: List[Dict]) -> List[str]: continue operation = operation[len(prefix):] if ':' in operation: + # sometimes the operation looks like + # "failed_to_do_something: error message" operation = operation[:operation.index(':')] failed.append(operation) return failed diff --git a/ceo_common/model/Term.py b/ceo_common/model/Term.py new file mode 100644 index 0000000..84dc1ad --- /dev/null +++ b/ceo_common/model/Term.py @@ -0,0 +1,67 @@ +import datetime + + +class Term: + """A representation of a term in the CSC LDAP, e.g. 's2021'.""" + + seasons = ['w', 's', 'f'] + + def __init__(self, s_term: str): + assert len(s_term) == 5 and s_term[0] in self.seasons and \ + s_term[1:].isdigit() + self.s_term = s_term + + def __repr__(self): + return self.s_term + + @staticmethod + def current(): + """Get a Term object for the current date.""" + dt = datetime.datetime.now() + c = 'w' + if 5 <= dt.month <= 8: + c = 's' + elif 9 <= dt.month: + c = 'f' + s_term = c + str(dt.year) + return Term(s_term) + + def __add__(self, other): + assert type(other) is int and other >= 0 + c = self.s_term[0] + season_idx = self.seasons.index(c) + year = int(self.s_term[1:]) + year += other // 3 + season_idx += other % 3 + if season_idx >= 3: + year += 1 + season_idx -= 3 + s_term = self.seasons[season_idx] + str(year) + return Term(s_term) + + def __eq__(self, other): + return isinstance(other, Term) and self.s_term == other.s_term + + def __lt__(self, other): + if not isinstance(other, Term): + return NotImplemented + c1, c2 = self.s_term[0], other.s_term[0] + year1, year2 = int(self.s_term[1:]), int(other.s_term[1:]) + return year1 < year2 or ( + year1 == year2 and self.seasons.index(c1) < self.seasons.index(c2) + ) + + def __gt__(self, other): + if not isinstance(other, Term): + return NotImplemented + c1, c2 = self.s_term[0], other.s_term[0] + year1, year2 = int(self.s_term[1:]), int(other.s_term[1:]) + return year1 > year2 or ( + year1 == year2 and self.seasons.index(c1) > self.seasons.index(c2) + ) + + def __ge__(self, other): + return self > other or self == other + + def __le__(self, other): + return self < other or self == other diff --git a/ceo_common/model/__init__.py b/ceo_common/model/__init__.py index 382fcef..14967e6 100644 --- a/ceo_common/model/__init__.py +++ b/ceo_common/model/__init__.py @@ -1,3 +1,4 @@ from .Config import Config from .HTTPClient import HTTPClient from .RemoteMailmanService import RemoteMailmanService +from .Term import Term diff --git a/ceo_common/utils.py b/ceo_common/utils.py deleted file mode 100644 index a7d858b..0000000 --- a/ceo_common/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -import datetime -from typing import List - - -def get_current_term() -> str: - """ - Get the current term as formatted in the CSC LDAP (e.g. 's2021'). - """ - dt = datetime.datetime.now() - c = 'w' - if 5 <= dt.month <= 8: - c = 's' - elif 9 <= dt.month: - c = 'f' - return c + str(dt.year) - - -def add_term(term: str) -> str: - """ - Add one term to the given term and return the string. - Example: add_term('s2021') -> 'f2021' - """ - c = term[0] - s_year = term[1:] - if c == 'w': - return 's' + s_year - elif c == 's': - return 'f' + s_year - year = int(s_year) - return 'w' + str(year + 1) - - -def get_max_term(terms: List[str]) -> str: - """Get the maximum (latest) term.""" - max_year = max(term[1:] for term in terms) - if 'f' + max_year in terms: - return 'f' + max_year - elif 's' + max_year in terms: - return 's' + max_year - return 'w' + max_year diff --git a/tests/ceo/__init__.py b/tests/ceo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/ceo/cli/__init__.py b/tests/ceo/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/ceo/cli/test_members.py b/tests/ceo/cli/test_members.py new file mode 100644 index 0000000..73e852f --- /dev/null +++ b/tests/ceo/cli/test_members.py @@ -0,0 +1,136 @@ +import os +import re +import shutil + +from click.testing import CliRunner + +from ceo.cli import cli +from ceo_common.model import Term + + +def test_members_get(cli_setup, ldap_user): + runner = CliRunner() + result = runner.invoke(cli, ['members', 'get', ldap_user.uid]) + expected = ( + f"uid: {ldap_user.uid}\n" + f"cn: {ldap_user.cn}\n" + f"program: {ldap_user.program}\n" + f"UID number: {ldap_user.uid_number}\n" + f"GID number: {ldap_user.gid_number}\n" + f"login shell: {ldap_user.login_shell}\n" + f"home directory: {ldap_user.home_directory}\n" + f"is a club: {ldap_user.is_club()}\n" + "forwarding addresses: \n" + f"terms: {','.join(ldap_user.terms)}\n" + ) + assert result.exit_code == 0 + assert result.output == expected + + +def test_members_add(cli_setup): + runner = CliRunner() + result = runner.invoke(cli, [ + 'members', 'add', 'test_1', '--cn', 'Test One', '--program', 'Math', + '--terms', '1', + ], input='y\n') + expected_pat = re.compile(( + "^The following user will be created:\n" + "uid: test_1\n" + "cn: Test One\n" + "program: Math\n" + "member terms: [sfw]\\d{4}\n" + "forwarding address: test_1@uwaterloo.internal\n" + "Do you want to continue\\? \\[y/N\\]: y\n" + "Add user to LDAP... Done\n" + "Add group to LDAP... Done\n" + "Add user to Kerberos... Done\n" + "Create home directory... Done\n" + "Set forwarding addresses... Done\n" + "Send welcome message... Done\n" + "Subscribe to mailing list... Done\n" + "Announce new user to mailing list... Done\n" + "Transaction successfully completed.\n" + "uid: test_1\n" + "cn: Test One\n" + "program: Math\n" + "UID number: \\d{5}\n" + "GID number: \\d{5}\n" + "login shell: /bin/bash\n" + "home directory: [a-z0-9/_-]+/test_1\n" + "is a club: False\n" + "forwarding addresses: test_1@uwaterloo.internal\n" + "terms: [sfw]\\d{4}\n" + "password: \\S+\n$" + ), re.MULTILINE) + assert result.exit_code == 0 + assert expected_pat.match(result.output) is not None + + result = runner.invoke(cli, ['members', 'delete', 'test_1'], input='y\n') + assert result.exit_code == 0 + + +def test_members_modify(cli_setup, ldap_user): + # The homedir needs to exist so the API can write to ~/.forward + os.makedirs(ldap_user.home_directory) + try: + runner = CliRunner() + result = runner.invoke(cli, [ + 'members', 'modify', ldap_user.uid, '--login-shell', '/bin/sh', + '--forwarding-addresses', 'jdoe@test1.internal,jdoe@test2.internal', + ], input='y\n') + expected = ( + "Login shell will be set to: /bin/sh\n" + "~/.forward will be set to: jdoe@test1.internal\n" + " jdoe@test2.internal\n" + "Do you want to continue? [y/N]: y\n" + "Replace login shell... Done\n" + "Replace forwarding addresses... Done\n" + "Transaction successfully completed.\n" + ) + assert result.exit_code == 0 + assert result.output == expected + finally: + shutil.rmtree(ldap_user.home_directory) + + +def test_members_renew(cli_setup, ldap_user, g_admin_ctx): + # set the user's last term to something really old + with g_admin_ctx(), ldap_user.ldap_srv.entry_ctx_for_user(ldap_user) as entry: + entry.term = ['s1999', 'f1999'] + current_term = Term.current() + + runner = CliRunner() + result = runner.invoke(cli, [ + 'members', 'renew', ldap_user.uid, '--terms', '1', + ], input='y\n') + expected = ( + f"The following member terms will be added: {current_term}\n" + "Do you want to continue? [y/N]: y\n" + "Done.\n" + ) + assert result.exit_code == 0 + assert result.output == expected + + runner = CliRunner() + result = runner.invoke(cli, [ + 'members', 'renew', ldap_user.uid, '--terms', '2', + ], input='y\n') + expected = ( + f"The following member terms will be added: {current_term+1},{current_term+2}\n" + "Do you want to continue? [y/N]: y\n" + "Done.\n" + ) + assert result.exit_code == 0 + assert result.output == expected + + +def test_members_pwreset(cli_setup, ldap_user, krb_user): + runner = CliRunner() + result = runner.invoke( + cli, ['members', 'pwreset', ldap_user.uid], input='y\n') + expected_pat = re.compile(( + f"^Are you sure you want to reset {ldap_user.uid}'s password\\? \\[y/N\\]: y\n" + "New password: \\S+\n$" + ), re.MULTILINE) + assert result.exit_code == 0 + assert expected_pat.match(result.output) is not None diff --git a/tests/ceo_common/model/test_remote_mailman.py b/tests/ceo_common/model/test_remote_mailman.py index 44b61fb..8518e5b 100644 --- a/tests/ceo_common/model/test_remote_mailman.py +++ b/tests/ceo_common/model/test_remote_mailman.py @@ -1,42 +1,12 @@ -from multiprocessing import Process -import socket -import sys -import time - -import requests - from ceo_common.model import RemoteMailmanService -def test_remote_mailman(cfg, http_client, app, mock_mailman_server, g_syscom): - port = cfg.get('ceod_port') - hostname = socket.gethostname() - - def server_start(): - sys.stdout = open('/dev/null', 'w') - sys.stderr = sys.stdout - app.run(debug=False, host='0.0.0.0', port=port) - - proc = Process(target=server_start) - proc.start() - - for _ in range(5): - try: - http_client.get(hostname, '/ping') - except requests.exceptions.ConnectionError: - time.sleep(1) - continue - break - - try: - mailman_srv = RemoteMailmanService() - assert mock_mailman_server.subscriptions['csc-general'] == [] - # RemoteMailmanService -> app -> MailmanService -> MockMailmanServer - address = 'test_1@csclub.internal' - mailman_srv.subscribe(address, 'csc-general') - assert mock_mailman_server.subscriptions['csc-general'] == [address] - mailman_srv.unsubscribe(address, 'csc-general') - assert mock_mailman_server.subscriptions['csc-general'] == [] - finally: - proc.terminate() - proc.join() +def test_remote_mailman(app_process, mock_mailman_server, g_syscom): + mailman_srv = RemoteMailmanService() + assert mock_mailman_server.subscriptions['csc-general'] == [] + # RemoteMailmanService -> app -> MailmanService -> MockMailmanServer + address = 'test_1@csclub.internal' + mailman_srv.subscribe(address, 'csc-general') + assert mock_mailman_server.subscriptions['csc-general'] == [address] + mailman_srv.unsubscribe(address, 'csc-general') + assert mock_mailman_server.subscriptions['csc-general'] == [] diff --git a/tests/ceod_test_local.ini b/tests/ceod_test_local.ini index 04fdbd9..6a8f7bd 100644 --- a/tests/ceod_test_local.ini +++ b/tests/ceod_test_local.ini @@ -1,5 +1,7 @@ [DEFAULT] base_domain = csclub.internal +# merge ceod.ini and ceo.ini values together to make testing easier +uw_domain = uwaterloo.internal [ceod] admin_host = phosphoric-acid diff --git a/tests/conftest.py b/tests/conftest.py index 3388fe6..d51f230 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,20 +1,23 @@ import contextlib import grp import importlib.resources +from multiprocessing import Process import os import pwd import shutil import subprocess -from subprocess import DEVNULL -import tempfile +import sys +import time from unittest.mock import patch, Mock import flask import ldap3 import pytest +import requests import socket from zope import component +from .utils import krb5ccname_ctx from ceo_common.interfaces import IConfig, IKerberosService, ILDAPService, \ IFileService, IMailmanService, IHTTPClient, IUWLDAPService, IMailService from ceo_common.model import Config, HTTPClient @@ -25,6 +28,7 @@ import ceod.utils as utils from .MockSMTPServer import MockSMTPServer from .MockMailmanServer import MockMailmanServer from .conftest_ceod_api import client # noqa: F401 +from .conftest_ceo import cli_setup # noqa: F401 @pytest.fixture(scope='session', autouse=True) @@ -47,6 +51,18 @@ def cfg(_drone_hostname_mock): return _cfg +@pytest.fixture(scope='session', autouse=True) +def _delete_ccaches(): + # I've noticed when pytest finishes, the temporary files + # created by tempfile.NamedTemporaryFile() aren't destroyed. + # So, we clean them up here. + from .utils import _ccaches + yield + # forcefully decrement the reference counts, which will trigger + # the destructors + _ccaches.clear() + + def delete_test_princs(krb_srv): proc = subprocess.run([ 'kadmin', '-k', '-p', krb_srv.admin_principal, 'listprincs', 'test_*', @@ -86,20 +102,8 @@ def delete_subtree(conn: ldap3.Connection, base_dn: str): pass -@pytest.fixture(scope='session') -def ceod_admin_creds(cfg, krb_srv): - """ - Acquire credentials for ceod/admin and store them - in the default ccache. - """ - subprocess.run( - ['kinit', '-k', cfg.get('ldap_admin_principal')], - check=True, - ) - - @pytest.fixture -def g_admin_ctx(cfg, ceod_admin_creds, app): +def g_admin_ctx(app): """ Store the principal for ceod/admin in flask.g. This context manager should be used any time LDAP is modified via the @@ -109,62 +113,45 @@ def g_admin_ctx(cfg, ceod_admin_creds, app): """ @contextlib.contextmanager def wrapper(): - admin_principal = cfg.get('ldap_admin_principal') - with app.app_context(): + with krb5ccname_ctx('ceod/admin'), app.app_context(): try: - flask.g.sasl_user = admin_principal + flask.g.sasl_user = 'ceod/admin' yield finally: flask.g.pop('sasl_user') return wrapper -@pytest.fixture(scope='session') -def syscom_creds(): - """ - Acquire credentials for a syscom member and store them in a ccache. - Yields the name of the ccache file. - """ - with tempfile.NamedTemporaryFile() as f: - subprocess.run( - ['kinit', '-c', f.name, 'ctdalek'], - check=True, text=True, input='krb5', stdout=DEVNULL, - ) - yield f.name - - @pytest.fixture -def g_syscom(syscom_creds, app): +def g_syscom(app): """ Store the principal for the syscom member in flask.g, and point KRB5CCNAME to the file where the TGT is stored. Use this fixture if you need syscom credentials for an HTTP request to a different process. """ - filename = syscom_creds - with app.app_context(): - old_krb5ccname = os.environ['KRB5CCNAME'] - os.environ['KRB5CCNAME'] = 'FILE:' + filename + with krb5ccname_ctx('ctdalek'), app.app_context(): try: flask.g.sasl_user = 'ctdalek' - yield filename + yield finally: - os.environ['KRB5CCNAME'] = old_krb5ccname flask.g.pop('sasl_user') @pytest.fixture(scope='session') -def ldap_conn(cfg, ceod_admin_creds) -> ldap3.Connection: +def ldap_conn(cfg) -> ldap3.Connection: # Assume that the same server URL is being used for the CSC # and UWLDAP during the tests. cfg = component.getUtility(IConfig) server_url = cfg.get('ldap_server_url') # sanity check assert server_url == cfg.get('uwldap_server_url') - return ldap3.Connection( - server_url, auto_bind=True, raise_exceptions=True, - authentication=ldap3.SASL, sasl_mechanism=ldap3.KERBEROS, - user=cfg.get('ldap_admin_principal')) + with krb5ccname_ctx('ceod/admin'): + conn = ldap3.Connection( + server_url, auto_bind=True, raise_exceptions=True, + authentication=ldap3.SASL, sasl_mechanism=ldap3.KERBEROS, + user='ceod/admin') + return conn @pytest.fixture(scope='session') @@ -375,3 +362,32 @@ def uwldap_user(cfg, uwldap_srv, ldap_conn): ) yield user conn.delete(dn) + + +@pytest.fixture(scope='module') +def app_process(cfg, app, http_client): + port = cfg.get('ceod_port') + hostname = socket.gethostname() + + def server_start(): + sys.stdout = open('/dev/null', 'w') + sys.stderr = sys.stdout + app.run(debug=False, host='0.0.0.0', port=port) + + proc = Process(target=server_start) + proc.start() + + try: + with krb5ccname_ctx('ctdalek'): + for i in range(5): + try: + http_client.get(hostname, '/ping') + except requests.exceptions.ConnectionError: + time.sleep(1) + continue + break + assert i != 5, 'Timed out' + yield + finally: + proc.terminate() + proc.join() diff --git a/tests/conftest_ceo.py b/tests/conftest_ceo.py new file mode 100644 index 0000000..be1f63d --- /dev/null +++ b/tests/conftest_ceo.py @@ -0,0 +1,18 @@ +import os + +import pytest + +from .utils import krb5ccname_ctx + + +@pytest.fixture(scope='module') +def cli_setup(app_process): + # This tells the CLI entrypoint not to register additional zope services. + os.environ['PYTEST'] = '1' + + # Running the client and the server in the same process would be very + # messy because they would be sharing the same environment variables, + # Kerberos cache, and registered utilities (via zope). So we're just + # going to start the app in a child process intead. + with krb5ccname_ctx('ctdalek'): + yield diff --git a/tests/conftest_ceod_api.py b/tests/conftest_ceod_api.py index 19f16ba..1962d37 100644 --- a/tests/conftest_ceod_api.py +++ b/tests/conftest_ceod_api.py @@ -1,10 +1,6 @@ from base64 import b64encode -import contextlib -import os import json import socket -import subprocess -import tempfile from flask import g from flask.testing import FlaskClient @@ -14,6 +10,7 @@ from requests import Request from requests_gssapi import HTTPSPNEGOAuth from ceo_common.krb5.utils import get_fwd_tgt +from .utils import krb5ccname_ctx __all__ = ['client'] @@ -21,40 +18,23 @@ __all__ = ['client'] @pytest.fixture(scope='session') def client(app): app_client = app.test_client() - with tempfile.TemporaryDirectory() as cache_dir: - yield CeodTestClient(app_client, cache_dir) + yield CeodTestClient(app_client) class CeodTestClient: - def __init__(self, app_client: FlaskClient, cache_dir: str): + def __init__(self, app_client: FlaskClient): self.client = app_client self.syscom_principal = 'ctdalek' # this is only used for the HTTPSNEGOAuth self.base_url = f'http://{socket.getfqdn()}' - # for each principal for which we acquired a TGT, map their - # username to a file (ccache) storing their TGT - self.principal_ccaches = {} - # this is where we'll store the credentials for each principal - self.cache_dir = cache_dir # for SPNEGO self.target_name = gssapi.Name('ceod/' + socket.getfqdn()) - @contextlib.contextmanager - def krb5ccname_env(self, principal): - """Temporarily change KRB5CCNAME to the ccache of the principal.""" - old_krb5ccname = os.environ['KRB5CCNAME'] - os.environ['KRB5CCNAME'] = self.principal_ccaches[principal] - try: - yield - finally: - os.environ['KRB5CCNAME'] = old_krb5ccname - def get_auth(self, principal): """Acquire a HTTPSPNEGOAuth instance for the principal.""" name = gssapi.Name(principal) # the 'store' arg doesn't seem to work for DIR ccaches - with self.krb5ccname_env(principal): - creds = gssapi.Credentials(name=name, usage='initiate') + creds = gssapi.Credentials(name=name, usage='initiate') auth = HTTPSPNEGOAuth( opportunistic_auth=True, target_name=self.target_name, @@ -62,32 +42,17 @@ class CeodTestClient: ) return auth - def kinit(self, principal): - """Acquire an initial TGT for the principal.""" - # For some reason, kinit with the '-c' option deletes the other - # credentials in the cache collection, so we need to override the - # env variable - subprocess.run( - ['kinit', principal], - text=True, input='krb5', check=True, stdout=subprocess.DEVNULL, - env={'KRB5CCNAME': self.principal_ccaches[principal]}) - def get_headers(self, principal: str, need_cred: bool): - if principal not in self.principal_ccaches: - _, filename = tempfile.mkstemp(dir=self.cache_dir) - self.principal_ccaches[principal] = filename - self.kinit(principal) - # Get the Authorization header (SPNEGO). - # The method doesn't matter here because we just need to extract - # the header using req.prepare(). - req = Request('GET', self.base_url, auth=self.get_auth(principal)) - headers = list(req.prepare().headers.items()) - if need_cred: - # Get the X-KRB5-CRED header (forwarded TGT). - cred = b64encode(get_fwd_tgt( - 'ceod/' + socket.getfqdn(), self.principal_ccaches[principal] - )).decode() - headers.append(('X-KRB5-CRED', cred)) + with krb5ccname_ctx(principal): + # Get the Authorization header (SPNEGO). + # The method doesn't matter here because we just need to extract + # the header using req.prepare(). + req = Request('GET', self.base_url, auth=self.get_auth(principal)) + headers = list(req.prepare().headers.items()) + if need_cred: + # Get the X-KRB5-CRED header (forwarded TGT). + cred = b64encode(get_fwd_tgt('ceod/' + socket.getfqdn())).decode() + headers.append(('X-KRB5-CRED', cred)) return headers def request(self, method: str, path: str, principal: str, need_cred: bool, **kwargs): diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..324aedd --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,34 @@ +import contextlib +import os +import subprocess +from subprocess import DEVNULL +import tempfile + + +# map principals to files storing credentials +_ccaches = {} + + +@contextlib.contextmanager +def krb5ccname_ctx(principal: str): + """ + Temporarily set KRB5CCNAME to a ccache storing credentials + for the specified user. + """ + old_krb5ccname = os.environ['KRB5CCNAME'] + try: + if principal not in _ccaches: + f = tempfile.NamedTemporaryFile() + os.environ['KRB5CCNAME'] = 'FILE:' + f.name + args = ['kinit', principal] + if principal == 'ceod/admin': + args = ['kinit', '-k', principal] + subprocess.run( + args, stdout=DEVNULL, text=True, input='krb5', + check=True) + _ccaches[principal] = f + else: + os.environ['KRB5CCNAME'] = 'FILE:' + _ccaches[principal].name + yield + finally: + os.environ['KRB5CCNAME'] = old_krb5ccname