add unit tests for members CLI

pull/13/head
Max Erenberg 1 year ago
parent 7a8751fd8f
commit 08a3faaefc
  1. 1
      ceo/cli/__init__.py
  2. 5
      ceo/cli/entrypoint.py
  3. 64
      ceo/cli/members.py
  4. 107
      ceo/cli/utils.py
  5. 20
      ceo/krb_check.py
  6. 96
      ceo/utils.py
  7. 67
      ceo_common/model/Term.py
  8. 1
      ceo_common/model/__init__.py
  9. 40
      ceo_common/utils.py
  10. 0
      tests/ceo/__init__.py
  11. 0
      tests/ceo/cli/__init__.py
  12. 136
      tests/ceo/cli/test_members.py
  13. 48
      tests/ceo_common/model/test_remote_mailman.py
  14. 2
      tests/ceod_test_local.ini
  15. 104
      tests/conftest.py
  16. 18
      tests/conftest_ceo.py
  17. 63
      tests/conftest_ceod_api.py
  18. 34
      tests/utils.py

@ -0,0 +1 @@
from .entrypoint import cli

@ -5,7 +5,7 @@ import socket
import click
from zope import component
from .krb_check import krb_check
from ..krb_check import krb_check
from .members import members
from ceo_common.interfaces import IConfig, IHTTPClient
from ceo_common.model import Config, HTTPClient
@ -21,7 +21,8 @@ def cli(ctx):
user = princ[:princ.index('@')]
ctx.obj['user'] = user
register_services()
if os.environ.get('PYTEST') != '1':
register_services()
cli.add_command(members)

@ -5,11 +5,10 @@ from typing import Dict
import click
from zope import component
from .utils import http_post, http_get, http_patch, http_delete, \
handle_stream_response, handle_sync_response, print_colon_kv, \
get_failed_operations
from ..utils import http_post, http_get, http_patch, http_delete, get_failed_operations
from .utils import handle_stream_response, handle_sync_response, print_colon_kv
from ceo_common.interfaces import IConfig
from ceo_common.utils import get_current_term, add_term, get_max_term
from ceo_common.model import Term
from ceod.transactions.members import (
AddMemberTransaction,
DeleteMemberTransaction,
@ -37,11 +36,10 @@ def add(username, cn, program, num_terms, clubrep, forwarding_address):
cfg = component.getUtility(IConfig)
uw_domain = cfg.get('uw_domain')
current_term = get_current_term()
terms = [current_term]
for _ in range(1, num_terms):
term = add_term(terms[-1])
terms.append(term)
current_term = Term.current()
terms = [current_term + i for i in range(num_terms)]
terms = list(map(str, terms))
if forwarding_address is None:
forwarding_address = username + '@' + uw_domain
@ -69,13 +67,18 @@ def add(username, cn, program, num_terms, clubrep, forwarding_address):
if program is not None:
body['program'] = program
if clubrep:
body['terms'] = terms
else:
body['non_member_terms'] = terms
else:
body['terms'] = terms
if forwarding_address != '':
body['forwarding_addresses'] = [forwarding_address]
operations = AddMemberTransaction.operations
if forwarding_address == '':
# don't bother displaying this because it won't be run
operations.remove('set_forwarding_addresses')
resp = http_post('/api/members', json=body)
data = handle_stream_response(resp, AddMemberTransaction.operations)
data = handle_stream_response(resp, operations)
result = data[-1]['result']
print_user_lines(result)
@ -121,7 +124,10 @@ def get(username):
@click.argument('username')
@click.option('--login-shell', required=False, help='Login shell')
@click.option('--forwarding-addresses', required=False,
help='Comma-separated list of forwarding addresses')
help=(
'Comma-separated list of forwarding addresses. '
'Set to the empty string to disable forwarding.'
))
def modify(username, login_shell, forwarding_addresses):
if login_shell is None and forwarding_addresses is None:
click.echo('Nothing to do.')
@ -133,13 +139,19 @@ def modify(username, login_shell, forwarding_addresses):
operations.append('replace_login_shell')
click.echo('Login shell will be set to: ' + login_shell)
if forwarding_addresses is not None:
forwarding_addresses = forwarding_addresses.split(',')
if forwarding_addresses == '':
forwarding_addresses = []
else:
forwarding_addresses = forwarding_addresses.split(',')
body['forwarding_addresses'] = forwarding_addresses
operations.append('replace_forwarding_addresses')
prefix = '~/.forward will be set to: '
click.echo(prefix + forwarding_addresses[0])
for address in forwarding_addresses[1:]:
click.echo((' ' * len(prefix)) + address)
if len(forwarding_addresses) > 0:
click.echo(prefix + forwarding_addresses[0])
for address in forwarding_addresses[1:]:
click.echo((' ' * len(prefix)) + address)
else:
click.echo(prefix)
click.confirm('Do you want to continue?', abort=True)
@ -157,19 +169,19 @@ def renew(username, num_terms, clubrep):
resp = http_get('/api/members/' + username)
result = handle_sync_response(resp)
max_term = None
current_term = Term.current()
if clubrep and 'non_member_terms' in result:
max_term = get_max_term(result['non_member_terms'])
max_term = max(Term(s) for s in result['non_member_terms'])
elif not clubrep and 'terms' in result:
max_term = get_max_term(result['terms'])
if max_term is not None:
max_term = get_max_term([max_term, get_current_term()])
max_term = max(Term(s) for s in result['terms'])
if max_term is not None and max_term >= current_term:
next_term = max_term + 1
else:
max_term = get_current_term()
next_term = Term.current()
terms = [add_term(max_term)]
for _ in range(1, num_terms):
term = add_term(terms[-1])
terms.append(term)
terms = [next_term + i for i in range(num_terms)]
terms = list(map(str, terms))
if clubrep:
body = {'non_member_terms': terms}

@ -0,0 +1,107 @@
import json
import sys
from typing import List, Tuple, Dict
import click
import requests
from ..operation_strings import descriptions as op_desc
class Abort(click.ClickException):
"""Abort silently."""
def __init__(self, exit_code=1):
super().__init__('')
self.exit_code = exit_code
def show(self):
pass
def print_colon_kv(pairs: List[Tuple[str, str]]):
"""
Pretty-print a list of key-value pairs such that the key and value
columns align.
Example:
key1: value1
key1000: value2
"""
maxlen = max(len(key) for key, val in pairs)
for key, val in pairs:
click.echo(key + ': ', nl=False)
extra_space = ' ' * (maxlen - len(key))
click.echo(extra_space, nl=False)
click.echo(val)
def handle_stream_response(resp: requests.Response, operations: List[str]) -> List[Dict]:
"""
Print output to the console while operations are being streamed
from the server over HTTP.
Returns the parsed JSON data streamed from the server.
"""
if resp.status_code != 200:
click.echo('An error occurred:')
click.echo(resp.text.rstrip())
raise Abort()
click.echo(op_desc[operations[0]] + '... ', nl=False)
idx = 0
data = []
for line in resp.iter_lines(decode_unicode=True, chunk_size=8):
d = json.loads(line)
data.append(d)
if d['status'] == 'aborted':
click.echo(click.style('ABORTED', fg='red'))
click.echo('The transaction was rolled back.')
click.echo('The error was: ' + d['error'])
click.echo('Please check the ceod logs.')
sys.exit(1)
elif d['status'] == 'completed':
click.echo('Transaction successfully completed.')
return data
operation = d['operation']
oper_failed = False
err_msg = None
prefix = 'failed_to_'
if operation.startswith(prefix):
operation = operation[len(prefix):]
oper_failed = True
# sometimes the operation looks like
# "failed_to_do_something: error message"
if ':' in operation:
operation, err_msg = operation.split(': ', 1)
while idx < len(operations) and operations[idx] != operation:
click.echo('Skipped')
idx += 1
if idx == len(operations):
break
click.echo(op_desc[operations[idx]] + '... ', nl=False)
if idx == len(operations):
click.echo('Unrecognized operation: ' + operation)
continue
if oper_failed:
click.echo(click.style('Failed', fg='red'))
if err_msg is not None:
click.echo(' Error message: ' + err_msg)
else:
click.echo(click.style('Done', fg='green'))
idx += 1
if idx < len(operations):
click.echo(op_desc[operations[idx]] + '... ', nl=False)
raise Exception('server response ended abruptly')
def handle_sync_response(resp: requests.Response):
"""
Exit the program if the request was not successful.
Returns the parsed JSON response.
"""
if resp.status_code != 200:
click.echo('An error occurred:')
click.echo(resp.text.rstrip())
raise Abort()
return resp.json()

@ -9,19 +9,15 @@ def krb_check():
credentials have expired.
Returns the principal string 'user@REALM'.
"""
try:
creds = gssapi.Credentials(usage='initiate')
except gssapi.raw.misc.GSSError:
kinit()
creds = gssapi.Credentials(usage='initiate')
for _ in range(2):
try:
creds = gssapi.Credentials(usage='initiate')
result = creds.inquire()
return str(result.name)
except (gssapi.raw.misc.GSSError, gssapi.raw.exceptions.ExpiredCredentialsError):
kinit()
try:
result = creds.inquire()
except gssapi.raw.exceptions.ExpiredCredentialsError:
kinit()
result = creds.inquire()
return str(result.name)
raise Exception('could not acquire GSSAPI credentials')
def kinit():

@ -1,12 +1,8 @@
import json
import sys
from typing import List, Tuple, Dict
from typing import List, Dict
import click
import requests
from zope import component
from .operation_strings import descriptions as op_desc
from ceo_common.interfaces import IHTTPClient, IConfig
@ -41,94 +37,6 @@ def http_delete(path: str, **kwargs) -> requests.Response:
return http_request('DELETE', path, **kwargs)
def handle_stream_response(resp: requests.Response, operations: List[str]) -> List[Dict]:
"""
Print output to the console while operations are being streamed
from the server over HTTP.
Returns the parsed JSON data streamed from the server.
"""
if resp.status_code != 200:
click.echo('An error occurred:')
click.echo(resp.text)
sys.exit(1)
click.echo(op_desc[operations[0]] + '... ', nl=False)
idx = 0
data = []
for line in resp.iter_lines(decode_unicode=True, chunk_size=8):
d = json.loads(line)
data.append(d)
if d['status'] == 'aborted':
click.echo(click.style('ABORTED', fg='red'))
click.echo('The transaction was rolled back.')
click.echo('The error was: ' + d['error'])
click.echo('Please check the ceod logs.')
sys.exit(1)
elif d['status'] == 'completed':
click.echo('Transaction successfully completed.')
return data
operation = d['operation']
oper_failed = False
err_msg = None
prefix = 'failed_to_'
if operation.startswith(prefix):
operation = operation[len(prefix):]
oper_failed = True
# sometimes the operation looks like
# "failed_to_do_something: error message"
if ':' in operation:
operation, err_msg = operation.split(': ', 1)
while idx < len(operations) and operations[idx] != operation:
click.echo('Skipped')
idx += 1
if idx == len(operations):
break
click.echo(op_desc[operations[idx]] + '... ', nl=False)
if idx == len(operations):
click.echo('Unrecognized operation: ' + operation)
continue
if oper_failed:
click.echo(click.style('Failed', fg='red'))
if err_msg is not None:
click.echo(' Error message: ' + err_msg)
else:
click.echo(click.style('Done', fg='green'))
idx += 1
if idx < len(operations):
click.echo(op_desc[operations[idx]] + '... ', nl=False)
raise Exception('server response ended abruptly')
def handle_sync_response(resp: requests.Response):
"""
Exit the program if the request was not successful.
Returns the parsed JSON response.
"""
if resp.status_code // 100 != 2:
click.echo('An error occurred:')
click.echo(resp.text)
sys.exit(1)
return resp.json()
def print_colon_kv(pairs: List[Tuple[str, str]]):
"""
Pretty-print a list of key-value pairs such that the key and value
columns align.
Example:
key1: value1
key1000: value2
"""
maxlen = max(len(key) for key, val in pairs)
for key, val in pairs:
click.echo(key + ': ', nl=False)
extra_space = ' ' * (maxlen - len(key))
click.echo(extra_space, nl=False)
click.echo(val)
def get_failed_operations(data: List[Dict]) -> List[str]:
"""
Get a list of the failed operations using the JSON objects
@ -144,6 +52,8 @@ def get_failed_operations(data: List[Dict]) -> List[str]:
continue
operation = operation[len(prefix):]
if ':' in operation:
# sometimes the operation looks like
# "failed_to_do_something: error message"
operation = operation[:operation.index(':')]
failed.append(operation)
return failed

@ -0,0 +1,67 @@
import datetime
class Term:
"""A representation of a term in the CSC LDAP, e.g. 's2021'."""
seasons = ['w', 's', 'f']
def __init__(self, s_term: str):
assert len(s_term) == 5 and s_term[0] in self.seasons and \
s_term[1:].isdigit()
self.s_term = s_term
def __repr__(self):
return self.s_term
@staticmethod
def current():
"""Get a Term object for the current date."""
dt = datetime.datetime.now()
c = 'w'
if 5 <= dt.month <= 8:
c = 's'
elif 9 <= dt.month:
c = 'f'
s_term = c + str(dt.year)
return Term(s_term)
def __add__(self, other):
assert type(other) is int and other >= 0
c = self.s_term[0]
season_idx = self.seasons.index(c)
year = int(self.s_term[1:])
year += other // 3
season_idx += other % 3
if season_idx >= 3:
year += 1
season_idx -= 3
s_term = self.seasons[season_idx] + str(year)
return Term(s_term)
def __eq__(self, other):
return isinstance(other, Term) and self.s_term == other.s_term
def __lt__(self, other):
if not isinstance(other, Term):
return NotImplemented
c1, c2 = self.s_term[0], other.s_term[0]
year1, year2 = int(self.s_term[1:]), int(other.s_term[1:])
return year1 < year2 or (
year1 == year2 and self.seasons.index(c1) < self.seasons.index(c2)
)
def __gt__(self, other):
if not isinstance(other, Term):
return NotImplemented
c1, c2 = self.s_term[0], other.s_term[0]
year1, year2 = int(self.s_term[1:]), int(other.s_term[1:])
return year1 > year2 or (
year1 == year2 and self.seasons.index(c1) > self.seasons.index(c2)
)
def __ge__(self, other):
return self > other or self == other
def __le__(self, other):
return self < other or self == other

@ -1,3 +1,4 @@
from .Config import Config
from .HTTPClient import HTTPClient
from .RemoteMailmanService import RemoteMailmanService
from .Term import Term

@ -1,40 +0,0 @@
import datetime
from typing import List
def get_current_term() -> str:
"""
Get the current term as formatted in the CSC LDAP (e.g. 's2021').
"""
dt = datetime.datetime.now()
c = 'w'
if 5 <= dt.month <= 8:
c = 's'
elif 9 <= dt.month:
c = 'f'
return c + str(dt.year)
def add_term(term: str) -> str:
"""
Add one term to the given term and return the string.
Example: add_term('s2021') -> 'f2021'
"""
c = term[0]
s_year = term[1:]
if c == 'w':
return 's' + s_year
elif c == 's':
return 'f' + s_year
year = int(s_year)
return 'w' + str(year + 1)
def get_max_term(terms: List[str]) -> str:
"""Get the maximum (latest) term."""
max_year = max(term[1:] for term in terms)
if 'f' + max_year in terms:
return 'f' + max_year
elif 's' + max_year in terms:
return 's' + max_year
return 'w' + max_year

@ -0,0 +1,136 @@
import os
import re
import shutil
from click.testing import CliRunner
from ceo.cli import cli
from ceo_common.model import Term
def test_members_get(cli_setup, ldap_user):
runner = CliRunner()
result = runner.invoke(cli, ['members', 'get', ldap_user.uid])
expected = (
f"uid: {ldap_user.uid}\n"
f"cn: {ldap_user.cn}\n"
f"program: {ldap_user.program}\n"
f"UID number: {ldap_user.uid_number}\n"
f"GID number: {ldap_user.gid_number}\n"
f"login shell: {ldap_user.login_shell}\n"
f"home directory: {ldap_user.home_directory}\n"
f"is a club: {ldap_user.is_club()}\n"
"forwarding addresses: \n"
f"terms: {','.join(ldap_user.terms)}\n"
)
assert result.exit_code == 0
assert result.output == expected
def test_members_add(cli_setup):
runner = CliRunner()
result = runner.invoke(cli, [
'members', 'add', 'test_1', '--cn', 'Test One', '--program', 'Math',
'--terms', '1',
], input='y\n')
expected_pat = re.compile((
"^The following user will be created:\n"
"uid: test_1\n"
"cn: Test One\n"
"program: Math\n"
"member terms: [sfw]\\d{4}\n"
"forwarding address: test_1@uwaterloo.internal\n"
"Do you want to continue\\? \\[y/N\\]: y\n"
"Add user to LDAP... Done\n"
"Add group to LDAP... Done\n"
"Add user to Kerberos... Done\n"
"Create home directory... Done\n"
"Set forwarding addresses... Done\n"
"Send welcome message... Done\n"
"Subscribe to mailing list... Done\n"
"Announce new user to mailing list... Done\n"
"Transaction successfully completed.\n"
"uid: test_1\n"
"cn: Test One\n"
"program: Math\n"
"UID number: \\d{5}\n"
"GID number: \\d{5}\n"
"login shell: /bin/bash\n"
"home directory: [a-z0-9/_-]+/test_1\n"
"is a club: False\n"
"forwarding addresses: test_1@uwaterloo.internal\n"
"terms: [sfw]\\d{4}\n"
"password: \\S+\n$"
), re.MULTILINE)
assert result.exit_code == 0
assert expected_pat.match(result.output) is not None
result = runner.invoke(cli, ['members', 'delete', 'test_1'], input='y\n')
assert result.exit_code == 0
def test_members_modify(cli_setup, ldap_user):
# The homedir needs to exist so the API can write to ~/.forward
os.makedirs(ldap_user.home_directory)
try:
runner = CliRunner()
result = runner.invoke(cli, [
'members', 'modify', ldap_user.uid, '--login-shell', '/bin/sh',
'--forwarding-addresses', 'jdoe@test1.internal,jdoe@test2.internal',
], input='y\n')
expected = (
"Login shell will be set to: /bin/sh\n"
"~/.forward will be set to: jdoe@test1.internal\n"
" jdoe@test2.internal\n"
"Do you want to continue? [y/N]: y\n"
"Replace login shell... Done\n"
"Replace forwarding addresses... Done\n"
"Transaction successfully completed.\n"
)
assert result.exit_code == 0
assert result.output == expected
finally:
shutil.rmtree(ldap_user.home_directory)
def test_members_renew(cli_setup, ldap_user, g_admin_ctx):
# set the user's last term to something really old
with g_admin_ctx(), ldap_user.ldap_srv.entry_ctx_for_user(ldap_user) as entry:
entry.term = ['s1999', 'f1999']
current_term = Term.current()
runner = CliRunner()
result = runner.invoke(cli, [
'members', 'renew', ldap_user.uid, '--terms', '1',
], input='y\n')
expected = (
f"The following member terms will be added: {current_term}\n"
"Do you want to continue? [y/N]: y\n"
"Done.\n"
)
assert result.exit_code == 0
assert result.output == expected
runner = CliRunner()
result = runner.invoke(cli, [
'members', 'renew', ldap_user.uid, '--terms', '2',
], input='y\n')
expected = (
f"The following member terms will be added: {current_term+1},{current_term+2}\n"
"Do you want to continue? [y/N]: y\n"
"Done.\n"
)
assert result.exit_code == 0
assert result.output == expected
def test_members_pwreset(cli_setup, ldap_user, krb_user):
runner = CliRunner()
result = runner.invoke(
cli, ['members', 'pwreset', ldap_user.uid], input='y\n')
expected_pat = re.compile((
f"^Are you sure you want to reset {ldap_user.uid}'s password\\? \\[y/N\\]: y\n"
"New password: \\S+\n$"
), re.MULTILINE)
assert result.exit_code == 0
assert expected_pat.match(result.output) is not None

@ -1,42 +1,12 @@
from multiprocessing import Process
import socket
import sys
import time
import requests
from ceo_common.model import RemoteMailmanService
def test_remote_mailman(cfg, http_client, app, mock_mailman_server, g_syscom):
port = cfg.get('ceod_port')
hostname = socket.gethostname()
def server_start():
sys.stdout = open('/dev/null', 'w')
sys.stderr = sys.stdout
app.run(debug=False, host='0.0.0.0', port=port)
proc = Process(target=server_start)
proc.start()
for _ in range(5):
try:
http_client.get(hostname, '/ping')
except requests.exceptions.ConnectionError:
time.sleep(1)
continue
break
try:
mailman_srv = RemoteMailmanService()
assert mock_mailman_server.subscriptions['csc-general'] == []
# RemoteMailmanService -> app -> MailmanService -> MockMailmanServer
address = 'test_1@csclub.internal'
mailman_srv.subscribe(address, 'csc-general')
assert mock_mailman_server.subscriptions['csc-general'] == [address]
mailman_srv.unsubscribe(address, 'csc-general')
assert mock_mailman_server.subscriptions['csc-general'] == []
finally:
proc.terminate()
proc.join()
def test_remote_mailman(app_process, mock_mailman_server, g_syscom):
mailman_srv = RemoteMailmanService()
assert mock_mailman_server.subscriptions['csc-general'] == []
# RemoteMailmanService -> app -> MailmanService -> MockMailmanServer
address = 'test_1@csclub.internal'
mailman_srv.subscribe(address, 'csc-general')
assert mock_mailman_server.subscriptions['csc-general'] == [address]
mailman_srv.unsubscribe(address, 'csc-general')
assert mock_mailman_server.subscriptions['csc-general'] == []

@ -1,5 +1,7 @@
[DEFAULT]
base_domain = csclub.internal
# merge ceod.ini and ceo.ini values together to make testing easier
uw_domain = uwaterloo.internal
[ceod]
admin_host = phosphoric-acid

@ -1,20 +1,23 @@
import contextlib
import grp
import importlib.resources
from multiprocessing import Process
import os
import pwd
import shutil
import subprocess
from subprocess import DEVNULL
import tempfile
import sys
import time
from unittest.mock import patch, Mock
import flask
import ldap3
import pytest
import requests
import socket
from zope import component
from .utils import krb5ccname_ctx
from ceo_common.interfaces import IConfig, IKerberosService, ILDAPService, \
IFileService, IMailmanService, IHTTPClient, IUWLDAPService, IMailService
from ceo_common.model import Config, HTTPClient
@ -25,6 +28,7 @@ import ceod.utils as utils
from .MockSMTPServer import MockSMTPServer
from .MockMailmanServer import MockMailmanServer
from .conftest_ceod_api import client # noqa: F401
from .conftest_ceo import cli_setup # noqa: F401
@pytest.fixture(scope='session', autouse=True)
@ -47,6 +51,18 @@ def cfg(_drone_hostname_mock):
return _cfg
@pytest.fixture(scope='session', autouse=True)
def _delete_ccaches():
# I've noticed when pytest finishes, the temporary files
# created by tempfile.NamedTemporaryFile() aren't destroyed.
# So, we clean them up here.
from .utils import _ccaches
yield
# forcefully decrement the reference counts, which will trigger
# the destructors
_ccaches.clear()
def delete_test_princs(krb_srv):
proc = subprocess.run([
'kadmin', '-k', '-p', krb_srv.admin_principal, 'listprincs', 'test_*',
@ -86,20 +102,8 @@ def delete_subtree(conn: ldap3.Connection, base_dn: str):
pass
@pytest.fixture(scope='session')
def ceod_admin_creds(cfg, krb_srv):
"""
Acquire credentials for ceod/admin and store them
in the default ccache.
"""
subprocess.run(
['kinit', '-k', cfg.get('ldap_admin_principal')],
check=True,
)
@pytest.fixture
def g_admin_ctx(cfg, ceod_admin_creds, app):
def g_admin_ctx(app):
"""
Store the principal for ceod/admin in flask.g.
This context manager should be used any time LDAP is modified via the
@ -109,62 +113,45 @@ def g_admin_ctx(cfg, ceod_admin_creds, app):
"""
@contextlib.contextmanager
def wrapper():
admin_principal = cfg.get('ldap_admin_principal')
with app.app_context():
with krb5ccname_ctx('ceod/admin'), app.app_context():
try:
flask.g.sasl_user = admin_principal
flask.g.sasl_user = 'ceod/admin'
yield
finally:
flask.g.pop('sasl_user')
return wrapper
@pytest.fixture(scope='session')
def syscom_creds():
"""
Acquire credentials for a syscom member and store them in a ccache.
Yields the name of the ccache file.
"""
with tempfile.NamedTemporaryFile() as f:
subprocess.run(
['kinit', '-c', f.name, 'ctdalek'],
check=True, text=True, input='krb5', stdout=DEVNULL,
)
yield f.name
@pytest.fixture
def g_syscom(syscom_creds, app):
def g_syscom(app):
"""
Store the principal for the syscom member in flask.g, and point
KRB5CCNAME to the file where the TGT is stored.
Use this fixture if you need syscom credentials for an HTTP request
to a different process.
"""
filename = syscom_creds
with app.app_context():
old_krb5ccname = os.environ['KRB5CCNAME']
os.environ['KRB5CCNAME'] = 'FILE:' + filename
with krb5ccname_ctx('ctdalek'), app.app_context():
try:
flask.g.sasl_user = 'ctdalek'
yield filename
yield
finally:
os.environ['KRB5CCNAME'] = old_krb5ccname
flask.g.pop('sasl_user')
@pytest.fixture(scope='session')
def ldap_conn(cfg, ceod_admin_creds) -> ldap3.Connection:
def ldap_conn(cfg) -> ldap3.Connection:
# Assume that the same server URL is being used for the CSC
# and UWLDAP during the tests.
cfg = component.getUtility(IConfig)
server_url = cfg.get('ldap_server_url')
# sanity check
assert server_url == cfg.get('uwldap_server_url')
return ldap3.Connection(
server_url, auto_bind=True, raise_exceptions=True,
authentication=ldap3.SASL, sasl_mechanism=ldap3.KERBEROS,
user=cfg.get('ldap_admin_principal'))
with krb5ccname_ctx('ceod/admin'):
conn = ldap3.Connection(
server_url, auto_bind=True, raise_exceptions=True,
authentication=ldap3.SASL, sasl_mechanism=ldap3.KERBEROS,
user='ceod/admin')
return conn
@pytest.fixture(scope='session')
@ -375,3 +362,32 @@ def uwldap_user(cfg, uwldap_srv, ldap_conn):
)
yield user
conn.delete(dn)
@pytest.fixture(scope='module')
def app_process(cfg, app, http_client):
port = cfg.get('ceod_port')
hostname = socket.gethostname()
def server_start():
sys.stdout = open('/dev/null', 'w')
sys.stderr = sys.stdout
app.run(debug=False, host='0.0.0.0', port=port)
proc = Process(target=server_start)
proc.start()
try:
with krb5ccname_ctx('ctdalek'):
for i in range(5):
try:
http_client.get(hostname, '/ping')
except requests.exceptions.ConnectionError:
time.sleep(1)
continue
break
assert i != 5, 'Timed out'
yield
finally:
proc.terminate()
proc.join()

@ -0,0 +1,18 @@
import os
import pytest
from .utils import krb5ccname_ctx
@pytest.fixture(scope='module')
def cli_setup(app_process):
# This tells the CLI entrypoint not to register additional zope services.
os.environ['PYTEST'] = '1'
# Running the client and the server in the same process would be very
# messy because they would be sharing the same environment variables,
# Kerberos cache, and registered utilities (via zope). So we're just
# going to start the app in a child process intead.
with krb5ccname_ctx('ctdalek'):
yield

@ -1,10 +1,6 @@
from base64 import b64encode
import contextlib
import os
import json
import socket
import subprocess
import tempfile
from flask import g
from flask.testing import FlaskClient
@ -14,6 +10,7 @@ from requests import Request
from requests_gssapi import HTTPSPNEGOAuth
from ceo_common.krb5.utils import get_fwd_tgt
from .utils import krb5ccname_ctx
__all__ = ['client']
@ -21,40 +18,23 @@ __all__ = ['client']
@pytest.fixture(scope='session')
def client(app):
app_client = app.test_client()
with tempfile.TemporaryDirectory() as cache_dir:
yield CeodTestClient(app_client, cache_dir)
yield CeodTestClient(app_client)
class CeodTestClient:
def __init__(self, app_client: FlaskClient, cache_dir: str):
def __init__(self, app_client: FlaskClient):
self.client = app_client
self.syscom_principal = 'ctdalek'
# this is only used for the HTTPSNEGOAuth
self.base_url = f'http://{socket.getfqdn()}'
# for each principal for which we acquired a TGT, map their
# username to a file (ccache) storing their TGT
self.principal_ccaches = {}
# this is where we'll store the credentials for each principal
self.cache_dir = cache_dir
# for SPNEGO
self.target_name = gssapi.Name('ceod/' + socket.getfqdn())
@contextlib.contextmanager
def krb5ccname_env(self, principal):
"""Temporarily change KRB5CCNAME to the ccache of the principal."""
old_krb5ccname = os.environ['KRB5CCNAME']
os.environ['KRB5CCNAME'] = self.principal_ccaches[principal]
try:
yield
finally:
os.environ['KRB5CCNAME'] = old_krb5ccname
def get_auth(self, principal):
"""Acquire a HTTPSPNEGOAuth instance for the principal."""
name = gssapi.Name(principal)
# the 'store' arg doesn't seem to work for DIR ccaches
with self.krb5ccname_env(principal):
creds = gssapi.Credentials(name=name, usage='initiate')
creds = gssapi.Credentials(name=name, usage='initiate')
auth = HTTPSPNEGOAuth(
opportunistic_auth=True,
target_name=self.target_name,
@ -62,32 +42,17 @@ class CeodTestClient:
)
return auth
def kinit(self, principal):
"""Acquire an initial TGT for the principal."""
# For some reason, kinit with the '-c' option deletes the other
# credentials in the cache collection, so we need to override the
# env variable
subprocess.run(
['kinit', principal],
text=True, input='krb5', check=True, stdout=subprocess.DEVNULL,
env={'KRB5CCNAME': self.principal_ccaches[principal]})
def get_headers(self, principal: str, need_cred: bool):
if principal not in self.principal_ccaches:
_, filename = tempfile.mkstemp(dir=self.cache_dir)
self.principal_ccaches[principal] = filename
self.kinit(principal)
# Get the Authorization header (SPNEGO).
# The method doesn't matter here because we just need to extract
# the header using req.prepare().
req = Request('GET', self.base_url, auth=self.get_auth(principal))
headers = list(req.prepare().headers.items())
if need_cred:
# Get the X-KRB5-CRED header (forwarded TGT).
cred = b64encode(get_fwd_tgt(
'ceod/' + socket.getfqdn(), self.principal_ccaches[principal]
)).decode()
headers.append(('X-KRB5-CRED', cred))
with krb5ccname_ctx(principal):
# Get the Authorization header (SPNEGO).
# The method doesn't matter here because we just need to extract
# the header using req.prepare().
req = Request('GET', self.base_url, auth=self.get_auth(principal))
headers = list(req.prepare().headers.items())
if need_cred:
# Get the X-KRB5-CRED header (forwarded TGT).
cred = b64encode(get_fwd_tgt('ceod/' + socket.getfqdn())).decode()
headers.append(('X-KRB5-CRED', cred))
return headers
def request(self, method: str, path: str, principal: str, need_cred: bool, **kwargs):

@ -0,0 +1,34 @@
import contextlib
import os
import subprocess
from subprocess import DEVNULL
import tempfile
# map principals to files storing credentials
_ccaches = {}
@contextlib.contextmanager
def krb5ccname_ctx(principal: str):
"""
Temporarily set KRB5CCNAME to a ccache storing credentials
for the specified user.
"""
old_krb5ccname = os.environ['KRB5CCNAME']
try:
if principal not in _ccaches:
f = tempfile.NamedTemporaryFile()
os.environ['KRB5CCNAME'] = 'FILE:' + f.name
args = ['kinit', principal]
if principal == 'ceod/admin':
args = ['kinit', '-k', principal]
subprocess.run(
args, stdout=DEVNULL, text=True, input='krb5',
check=True)
_ccaches[principal] = f
else:
os.environ['KRB5CCNAME'] = 'FILE:' + _ccaches[principal].name
yield
finally:
os.environ['KRB5CCNAME'] = old_krb5ccname
Loading…
Cancel
Save