add unit tests for members CLI
continuous-integration/drone/push Build is passing
Details
continuous-integration/drone/push Build is passing
Details
This commit is contained in:
parent
7a8751fd8f
commit
08a3faaefc
|
@ -0,0 +1 @@
|
|||
from .entrypoint import cli
|
|
@ -5,7 +5,7 @@ import socket
|
|||
import click
|
||||
from zope import component
|
||||
|
||||
from .krb_check import krb_check
|
||||
from ..krb_check import krb_check
|
||||
from .members import members
|
||||
from ceo_common.interfaces import IConfig, IHTTPClient
|
||||
from ceo_common.model import Config, HTTPClient
|
||||
|
@ -21,7 +21,8 @@ def cli(ctx):
|
|||
user = princ[:princ.index('@')]
|
||||
ctx.obj['user'] = user
|
||||
|
||||
register_services()
|
||||
if os.environ.get('PYTEST') != '1':
|
||||
register_services()
|
||||
|
||||
|
||||
cli.add_command(members)
|
|
@ -5,11 +5,10 @@ from typing import Dict
|
|||
import click
|
||||
from zope import component
|
||||
|
||||
from .utils import http_post, http_get, http_patch, http_delete, \
|
||||
handle_stream_response, handle_sync_response, print_colon_kv, \
|
||||
get_failed_operations
|
||||
from ..utils import http_post, http_get, http_patch, http_delete, get_failed_operations
|
||||
from .utils import handle_stream_response, handle_sync_response, print_colon_kv
|
||||
from ceo_common.interfaces import IConfig
|
||||
from ceo_common.utils import get_current_term, add_term, get_max_term
|
||||
from ceo_common.model import Term
|
||||
from ceod.transactions.members import (
|
||||
AddMemberTransaction,
|
||||
DeleteMemberTransaction,
|
||||
|
@ -37,11 +36,10 @@ def add(username, cn, program, num_terms, clubrep, forwarding_address):
|
|||
cfg = component.getUtility(IConfig)
|
||||
uw_domain = cfg.get('uw_domain')
|
||||
|
||||
current_term = get_current_term()
|
||||
terms = [current_term]
|
||||
for _ in range(1, num_terms):
|
||||
term = add_term(terms[-1])
|
||||
terms.append(term)
|
||||
current_term = Term.current()
|
||||
terms = [current_term + i for i in range(num_terms)]
|
||||
terms = list(map(str, terms))
|
||||
|
||||
if forwarding_address is None:
|
||||
forwarding_address = username + '@' + uw_domain
|
||||
|
||||
|
@ -69,13 +67,18 @@ def add(username, cn, program, num_terms, clubrep, forwarding_address):
|
|||
if program is not None:
|
||||
body['program'] = program
|
||||
if clubrep:
|
||||
body['terms'] = terms
|
||||
else:
|
||||
body['non_member_terms'] = terms
|
||||
else:
|
||||
body['terms'] = terms
|
||||
if forwarding_address != '':
|
||||
body['forwarding_addresses'] = [forwarding_address]
|
||||
operations = AddMemberTransaction.operations
|
||||
if forwarding_address == '':
|
||||
# don't bother displaying this because it won't be run
|
||||
operations.remove('set_forwarding_addresses')
|
||||
|
||||
resp = http_post('/api/members', json=body)
|
||||
data = handle_stream_response(resp, AddMemberTransaction.operations)
|
||||
data = handle_stream_response(resp, operations)
|
||||
result = data[-1]['result']
|
||||
print_user_lines(result)
|
||||
|
||||
|
@ -121,7 +124,10 @@ def get(username):
|
|||
@click.argument('username')
|
||||
@click.option('--login-shell', required=False, help='Login shell')
|
||||
@click.option('--forwarding-addresses', required=False,
|
||||
help='Comma-separated list of forwarding addresses')
|
||||
help=(
|
||||
'Comma-separated list of forwarding addresses. '
|
||||
'Set to the empty string to disable forwarding.'
|
||||
))
|
||||
def modify(username, login_shell, forwarding_addresses):
|
||||
if login_shell is None and forwarding_addresses is None:
|
||||
click.echo('Nothing to do.')
|
||||
|
@ -133,13 +139,19 @@ def modify(username, login_shell, forwarding_addresses):
|
|||
operations.append('replace_login_shell')
|
||||
click.echo('Login shell will be set to: ' + login_shell)
|
||||
if forwarding_addresses is not None:
|
||||
forwarding_addresses = forwarding_addresses.split(',')
|
||||
if forwarding_addresses == '':
|
||||
forwarding_addresses = []
|
||||
else:
|
||||
forwarding_addresses = forwarding_addresses.split(',')
|
||||
body['forwarding_addresses'] = forwarding_addresses
|
||||
operations.append('replace_forwarding_addresses')
|
||||
prefix = '~/.forward will be set to: '
|
||||
click.echo(prefix + forwarding_addresses[0])
|
||||
for address in forwarding_addresses[1:]:
|
||||
click.echo((' ' * len(prefix)) + address)
|
||||
if len(forwarding_addresses) > 0:
|
||||
click.echo(prefix + forwarding_addresses[0])
|
||||
for address in forwarding_addresses[1:]:
|
||||
click.echo((' ' * len(prefix)) + address)
|
||||
else:
|
||||
click.echo(prefix)
|
||||
|
||||
click.confirm('Do you want to continue?', abort=True)
|
||||
|
||||
|
@ -157,19 +169,19 @@ def renew(username, num_terms, clubrep):
|
|||
resp = http_get('/api/members/' + username)
|
||||
result = handle_sync_response(resp)
|
||||
max_term = None
|
||||
current_term = Term.current()
|
||||
if clubrep and 'non_member_terms' in result:
|
||||
max_term = get_max_term(result['non_member_terms'])
|
||||
max_term = max(Term(s) for s in result['non_member_terms'])
|
||||
elif not clubrep and 'terms' in result:
|
||||
max_term = get_max_term(result['terms'])
|
||||
if max_term is not None:
|
||||
max_term = get_max_term([max_term, get_current_term()])
|
||||
else:
|
||||
max_term = get_current_term()
|
||||
max_term = max(Term(s) for s in result['terms'])
|
||||
|
||||
terms = [add_term(max_term)]
|
||||
for _ in range(1, num_terms):
|
||||
term = add_term(terms[-1])
|
||||
terms.append(term)
|
||||
if max_term is not None and max_term >= current_term:
|
||||
next_term = max_term + 1
|
||||
else:
|
||||
next_term = Term.current()
|
||||
|
||||
terms = [next_term + i for i in range(num_terms)]
|
||||
terms = list(map(str, terms))
|
||||
|
||||
if clubrep:
|
||||
body = {'non_member_terms': terms}
|
|
@ -0,0 +1,107 @@
|
|||
import json
|
||||
import sys
|
||||
from typing import List, Tuple, Dict
|
||||
|
||||
import click
|
||||
import requests
|
||||
|
||||
from ..operation_strings import descriptions as op_desc
|
||||
|
||||
|
||||
class Abort(click.ClickException):
|
||||
"""Abort silently."""
|
||||
|
||||
def __init__(self, exit_code=1):
|
||||
super().__init__('')
|
||||
self.exit_code = exit_code
|
||||
|
||||
def show(self):
|
||||
pass
|
||||
|
||||
|
||||
def print_colon_kv(pairs: List[Tuple[str, str]]):
|
||||
"""
|
||||
Pretty-print a list of key-value pairs such that the key and value
|
||||
columns align.
|
||||
Example:
|
||||
key1: value1
|
||||
key1000: value2
|
||||
"""
|
||||
maxlen = max(len(key) for key, val in pairs)
|
||||
for key, val in pairs:
|
||||
click.echo(key + ': ', nl=False)
|
||||
extra_space = ' ' * (maxlen - len(key))
|
||||
click.echo(extra_space, nl=False)
|
||||
click.echo(val)
|
||||
|
||||
|
||||
def handle_stream_response(resp: requests.Response, operations: List[str]) -> List[Dict]:
|
||||
"""
|
||||
Print output to the console while operations are being streamed
|
||||
from the server over HTTP.
|
||||
Returns the parsed JSON data streamed from the server.
|
||||
"""
|
||||
if resp.status_code != 200:
|
||||
click.echo('An error occurred:')
|
||||
click.echo(resp.text.rstrip())
|
||||
raise Abort()
|
||||
click.echo(op_desc[operations[0]] + '... ', nl=False)
|
||||
idx = 0
|
||||
data = []
|
||||
for line in resp.iter_lines(decode_unicode=True, chunk_size=8):
|
||||
d = json.loads(line)
|
||||
data.append(d)
|
||||
if d['status'] == 'aborted':
|
||||
click.echo(click.style('ABORTED', fg='red'))
|
||||
click.echo('The transaction was rolled back.')
|
||||
click.echo('The error was: ' + d['error'])
|
||||
click.echo('Please check the ceod logs.')
|
||||
sys.exit(1)
|
||||
elif d['status'] == 'completed':
|
||||
click.echo('Transaction successfully completed.')
|
||||
return data
|
||||
|
||||
operation = d['operation']
|
||||
oper_failed = False
|
||||
err_msg = None
|
||||
prefix = 'failed_to_'
|
||||
if operation.startswith(prefix):
|
||||
operation = operation[len(prefix):]
|
||||
oper_failed = True
|
||||
# sometimes the operation looks like
|
||||
# "failed_to_do_something: error message"
|
||||
if ':' in operation:
|
||||
operation, err_msg = operation.split(': ', 1)
|
||||
|
||||
while idx < len(operations) and operations[idx] != operation:
|
||||
click.echo('Skipped')
|
||||
idx += 1
|
||||
if idx == len(operations):
|
||||
break
|
||||
click.echo(op_desc[operations[idx]] + '... ', nl=False)
|
||||
if idx == len(operations):
|
||||
click.echo('Unrecognized operation: ' + operation)
|
||||
continue
|
||||
if oper_failed:
|
||||
click.echo(click.style('Failed', fg='red'))
|
||||
if err_msg is not None:
|
||||
click.echo(' Error message: ' + err_msg)
|
||||
else:
|
||||
click.echo(click.style('Done', fg='green'))
|
||||
idx += 1
|
||||
if idx < len(operations):
|
||||
click.echo(op_desc[operations[idx]] + '... ', nl=False)
|
||||
|
||||
raise Exception('server response ended abruptly')
|
||||
|
||||
|
||||
def handle_sync_response(resp: requests.Response):
|
||||
"""
|
||||
Exit the program if the request was not successful.
|
||||
Returns the parsed JSON response.
|
||||
"""
|
||||
if resp.status_code != 200:
|
||||
click.echo('An error occurred:')
|
||||
click.echo(resp.text.rstrip())
|
||||
raise Abort()
|
||||
return resp.json()
|
|
@ -9,19 +9,15 @@ def krb_check():
|
|||
credentials have expired.
|
||||
Returns the principal string 'user@REALM'.
|
||||
"""
|
||||
try:
|
||||
creds = gssapi.Credentials(usage='initiate')
|
||||
except gssapi.raw.misc.GSSError:
|
||||
kinit()
|
||||
creds = gssapi.Credentials(usage='initiate')
|
||||
for _ in range(2):
|
||||
try:
|
||||
creds = gssapi.Credentials(usage='initiate')
|
||||
result = creds.inquire()
|
||||
return str(result.name)
|
||||
except (gssapi.raw.misc.GSSError, gssapi.raw.exceptions.ExpiredCredentialsError):
|
||||
kinit()
|
||||
|
||||
try:
|
||||
result = creds.inquire()
|
||||
except gssapi.raw.exceptions.ExpiredCredentialsError:
|
||||
kinit()
|
||||
result = creds.inquire()
|
||||
|
||||
return str(result.name)
|
||||
raise Exception('could not acquire GSSAPI credentials')
|
||||
|
||||
|
||||
def kinit():
|
||||
|
|
96
ceo/utils.py
96
ceo/utils.py
|
@ -1,12 +1,8 @@
|
|||
import json
|
||||
import sys
|
||||
from typing import List, Tuple, Dict
|
||||
from typing import List, Dict
|
||||
|
||||
import click
|
||||
import requests
|
||||
from zope import component
|
||||
|
||||
from .operation_strings import descriptions as op_desc
|
||||
from ceo_common.interfaces import IHTTPClient, IConfig
|
||||
|
||||
|
||||
|
@ -41,94 +37,6 @@ def http_delete(path: str, **kwargs) -> requests.Response:
|
|||
return http_request('DELETE', path, **kwargs)
|
||||
|
||||
|
||||
def handle_stream_response(resp: requests.Response, operations: List[str]) -> List[Dict]:
|
||||
"""
|
||||
Print output to the console while operations are being streamed
|
||||
from the server over HTTP.
|
||||
Returns the parsed JSON data streamed from the server.
|
||||
"""
|
||||
if resp.status_code != 200:
|
||||
click.echo('An error occurred:')
|
||||
click.echo(resp.text)
|
||||
sys.exit(1)
|
||||
click.echo(op_desc[operations[0]] + '... ', nl=False)
|
||||
idx = 0
|
||||
data = []
|
||||
for line in resp.iter_lines(decode_unicode=True, chunk_size=8):
|
||||
d = json.loads(line)
|
||||
data.append(d)
|
||||
if d['status'] == 'aborted':
|
||||
click.echo(click.style('ABORTED', fg='red'))
|
||||
click.echo('The transaction was rolled back.')
|
||||
click.echo('The error was: ' + d['error'])
|
||||
click.echo('Please check the ceod logs.')
|
||||
sys.exit(1)
|
||||
elif d['status'] == 'completed':
|
||||
click.echo('Transaction successfully completed.')
|
||||
return data
|
||||
|
||||
operation = d['operation']
|
||||
oper_failed = False
|
||||
err_msg = None
|
||||
prefix = 'failed_to_'
|
||||
if operation.startswith(prefix):
|
||||
operation = operation[len(prefix):]
|
||||
oper_failed = True
|
||||
# sometimes the operation looks like
|
||||
# "failed_to_do_something: error message"
|
||||
if ':' in operation:
|
||||
operation, err_msg = operation.split(': ', 1)
|
||||
|
||||
while idx < len(operations) and operations[idx] != operation:
|
||||
click.echo('Skipped')
|
||||
idx += 1
|
||||
if idx == len(operations):
|
||||
break
|
||||
click.echo(op_desc[operations[idx]] + '... ', nl=False)
|
||||
if idx == len(operations):
|
||||
click.echo('Unrecognized operation: ' + operation)
|
||||
continue
|
||||
if oper_failed:
|
||||
click.echo(click.style('Failed', fg='red'))
|
||||
if err_msg is not None:
|
||||
click.echo(' Error message: ' + err_msg)
|
||||
else:
|
||||
click.echo(click.style('Done', fg='green'))
|
||||
idx += 1
|
||||
if idx < len(operations):
|
||||
click.echo(op_desc[operations[idx]] + '... ', nl=False)
|
||||
|
||||
raise Exception('server response ended abruptly')
|
||||
|
||||
|
||||
def handle_sync_response(resp: requests.Response):
|
||||
"""
|
||||
Exit the program if the request was not successful.
|
||||
Returns the parsed JSON response.
|
||||
"""
|
||||
if resp.status_code // 100 != 2:
|
||||
click.echo('An error occurred:')
|
||||
click.echo(resp.text)
|
||||
sys.exit(1)
|
||||
return resp.json()
|
||||
|
||||
|
||||
def print_colon_kv(pairs: List[Tuple[str, str]]):
|
||||
"""
|
||||
Pretty-print a list of key-value pairs such that the key and value
|
||||
columns align.
|
||||
Example:
|
||||
key1: value1
|
||||
key1000: value2
|
||||
"""
|
||||
maxlen = max(len(key) for key, val in pairs)
|
||||
for key, val in pairs:
|
||||
click.echo(key + ': ', nl=False)
|
||||
extra_space = ' ' * (maxlen - len(key))
|
||||
click.echo(extra_space, nl=False)
|
||||
click.echo(val)
|
||||
|
||||
|
||||
def get_failed_operations(data: List[Dict]) -> List[str]:
|
||||
"""
|
||||
Get a list of the failed operations using the JSON objects
|
||||
|
@ -144,6 +52,8 @@ def get_failed_operations(data: List[Dict]) -> List[str]:
|
|||
continue
|
||||
operation = operation[len(prefix):]
|
||||
if ':' in operation:
|
||||
# sometimes the operation looks like
|
||||
# "failed_to_do_something: error message"
|
||||
operation = operation[:operation.index(':')]
|
||||
failed.append(operation)
|
||||
return failed
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
import datetime
|
||||
|
||||
|
||||
class Term:
|
||||
"""A representation of a term in the CSC LDAP, e.g. 's2021'."""
|
||||
|
||||
seasons = ['w', 's', 'f']
|
||||
|
||||
def __init__(self, s_term: str):
|
||||
assert len(s_term) == 5 and s_term[0] in self.seasons and \
|
||||
s_term[1:].isdigit()
|
||||
self.s_term = s_term
|
||||
|
||||
def __repr__(self):
|
||||
return self.s_term
|
||||
|
||||
@staticmethod
|
||||
def current():
|
||||
"""Get a Term object for the current date."""
|
||||
dt = datetime.datetime.now()
|
||||
c = 'w'
|
||||
if 5 <= dt.month <= 8:
|
||||
c = 's'
|
||||
elif 9 <= dt.month:
|
||||
c = 'f'
|
||||
s_term = c + str(dt.year)
|
||||
return Term(s_term)
|
||||
|
||||
def __add__(self, other):
|
||||
assert type(other) is int and other >= 0
|
||||
c = self.s_term[0]
|
||||
season_idx = self.seasons.index(c)
|
||||
year = int(self.s_term[1:])
|
||||
year += other // 3
|
||||
season_idx += other % 3
|
||||
if season_idx >= 3:
|
||||
year += 1
|
||||
season_idx -= 3
|
||||
s_term = self.seasons[season_idx] + str(year)
|
||||
return Term(s_term)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, Term) and self.s_term == other.s_term
|
||||
|
||||
def __lt__(self, other):
|
||||
if not isinstance(other, Term):
|
||||
return NotImplemented
|
||||
c1, c2 = self.s_term[0], other.s_term[0]
|
||||
year1, year2 = int(self.s_term[1:]), int(other.s_term[1:])
|
||||
return year1 < year2 or (
|
||||
year1 == year2 and self.seasons.index(c1) < self.seasons.index(c2)
|
||||
)
|
||||
|
||||
def __gt__(self, other):
|
||||
if not isinstance(other, Term):
|
||||
return NotImplemented
|
||||
c1, c2 = self.s_term[0], other.s_term[0]
|
||||
year1, year2 = int(self.s_term[1:]), int(other.s_term[1:])
|
||||
return year1 > year2 or (
|
||||
year1 == year2 and self.seasons.index(c1) > self.seasons.index(c2)
|
||||
)
|
||||
|
||||
def __ge__(self, other):
|
||||
return self > other or self == other
|
||||
|
||||
def __le__(self, other):
|
||||
return self < other or self == other
|
|
@ -1,3 +1,4 @@
|
|||
from .Config import Config
|
||||
from .HTTPClient import HTTPClient
|
||||
from .RemoteMailmanService import RemoteMailmanService
|
||||
from .Term import Term
|
||||
|
|
|
@ -1,40 +0,0 @@
|
|||
import datetime
|
||||
from typing import List
|
||||
|
||||
|
||||
def get_current_term() -> str:
|
||||
"""
|
||||
Get the current term as formatted in the CSC LDAP (e.g. 's2021').
|
||||
"""
|
||||
dt = datetime.datetime.now()
|
||||
c = 'w'
|
||||
if 5 <= dt.month <= 8:
|
||||
c = 's'
|
||||
elif 9 <= dt.month:
|
||||
c = 'f'
|
||||
return c + str(dt.year)
|
||||
|
||||
|
||||
def add_term(term: str) -> str:
|
||||
"""
|
||||
Add one term to the given term and return the string.
|
||||
Example: add_term('s2021') -> 'f2021'
|
||||
"""
|
||||
c = term[0]
|
||||
s_year = term[1:]
|
||||
if c == 'w':
|
||||
return 's' + s_year
|
||||
elif c == 's':
|
||||
return 'f' + s_year
|
||||
year = int(s_year)
|
||||
return 'w' + str(year + 1)
|
||||
|
||||
|
||||
def get_max_term(terms: List[str]) -> str:
|
||||
"""Get the maximum (latest) term."""
|
||||
max_year = max(term[1:] for term in terms)
|
||||
if 'f' + max_year in terms:
|
||||
return 'f' + max_year
|
||||
elif 's' + max_year in terms:
|
||||
return 's' + max_year
|
||||
return 'w' + max_year
|
|
@ -0,0 +1,136 @@
|
|||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
from click.testing import CliRunner
|
||||
|
||||
from ceo.cli import cli
|
||||
from ceo_common.model import Term
|
||||
|
||||
|
||||
def test_members_get(cli_setup, ldap_user):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ['members', 'get', ldap_user.uid])
|
||||
expected = (
|
||||
f"uid: {ldap_user.uid}\n"
|
||||
f"cn: {ldap_user.cn}\n"
|
||||
f"program: {ldap_user.program}\n"
|
||||
f"UID number: {ldap_user.uid_number}\n"
|
||||
f"GID number: {ldap_user.gid_number}\n"
|
||||
f"login shell: {ldap_user.login_shell}\n"
|
||||
f"home directory: {ldap_user.home_directory}\n"
|
||||
f"is a club: {ldap_user.is_club()}\n"
|
||||
"forwarding addresses: \n"
|
||||
f"terms: {','.join(ldap_user.terms)}\n"
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert result.output == expected
|
||||
|
||||
|
||||
def test_members_add(cli_setup):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, [
|
||||
'members', 'add', 'test_1', '--cn', 'Test One', '--program', 'Math',
|
||||
'--terms', '1',
|
||||
], input='y\n')
|
||||
expected_pat = re.compile((
|
||||
"^The following user will be created:\n"
|
||||
"uid: test_1\n"
|
||||
"cn: Test One\n"
|
||||
"program: Math\n"
|
||||
"member terms: [sfw]\\d{4}\n"
|
||||
"forwarding address: test_1@uwaterloo.internal\n"
|
||||
"Do you want to continue\\? \\[y/N\\]: y\n"
|
||||
"Add user to LDAP... Done\n"
|
||||
"Add group to LDAP... Done\n"
|
||||
"Add user to Kerberos... Done\n"
|
||||
"Create home directory... Done\n"
|
||||
"Set forwarding addresses... Done\n"
|
||||
"Send welcome message... Done\n"
|
||||
"Subscribe to mailing list... Done\n"
|
||||
"Announce new user to mailing list... Done\n"
|
||||
"Transaction successfully completed.\n"
|
||||
"uid: test_1\n"
|
||||
"cn: Test One\n"
|
||||
"program: Math\n"
|
||||
"UID number: \\d{5}\n"
|
||||
"GID number: \\d{5}\n"
|
||||
"login shell: /bin/bash\n"
|
||||
"home directory: [a-z0-9/_-]+/test_1\n"
|
||||
"is a club: False\n"
|
||||
"forwarding addresses: test_1@uwaterloo.internal\n"
|
||||
"terms: [sfw]\\d{4}\n"
|
||||
"password: \\S+\n$"
|
||||
), re.MULTILINE)
|
||||
assert result.exit_code == 0
|
||||
assert expected_pat.match(result.output) is not None
|
||||
|
||||
result = runner.invoke(cli, ['members', 'delete', 'test_1'], input='y\n')
|
||||
assert result.exit_code == 0
|
||||
|
||||
|
||||
def test_members_modify(cli_setup, ldap_user):
|
||||
# The homedir needs to exist so the API can write to ~/.forward
|
||||
os.makedirs(ldap_user.home_directory)
|
||||
try:
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, [
|
||||
'members', 'modify', ldap_user.uid, '--login-shell', '/bin/sh',
|
||||
'--forwarding-addresses', 'jdoe@test1.internal,jdoe@test2.internal',
|
||||
], input='y\n')
|
||||
expected = (
|
||||
"Login shell will be set to: /bin/sh\n"
|
||||
"~/.forward will be set to: jdoe@test1.internal\n"
|
||||
" jdoe@test2.internal\n"
|
||||
"Do you want to continue? [y/N]: y\n"
|
||||
"Replace login shell... Done\n"
|
||||
"Replace forwarding addresses... Done\n"
|
||||
"Transaction successfully completed.\n"
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert result.output == expected
|
||||
finally:
|
||||
shutil.rmtree(ldap_user.home_directory)
|
||||
|
||||
|
||||
def test_members_renew(cli_setup, ldap_user, g_admin_ctx):
|
||||
# set the user's last term to something really old
|
||||
with g_admin_ctx(), ldap_user.ldap_srv.entry_ctx_for_user(ldap_user) as entry:
|
||||
entry.term = ['s1999', 'f1999']
|
||||
current_term = Term.current()
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, [
|
||||
'members', 'renew', ldap_user.uid, '--terms', '1',
|
||||
], input='y\n')
|
||||
expected = (
|
||||
f"The following member terms will be added: {current_term}\n"
|
||||
"Do you want to continue? [y/N]: y\n"
|
||||
"Done.\n"
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert result.output == expected
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, [
|
||||
'members', 'renew', ldap_user.uid, '--terms', '2',
|
||||
], input='y\n')
|
||||
expected = (
|
||||
f"The following member terms will be added: {current_term+1},{current_term+2}\n"
|
||||
"Do you want to continue? [y/N]: y\n"
|
||||
"Done.\n"
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert result.output == expected
|
||||
|
||||
|
||||
def test_members_pwreset(cli_setup, ldap_user, krb_user):
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli, ['members', 'pwreset', ldap_user.uid], input='y\n')
|
||||
expected_pat = re.compile((
|
||||
f"^Are you sure you want to reset {ldap_user.uid}'s password\\? \\[y/N\\]: y\n"
|
||||
"New password: \\S+\n$"
|
||||
), re.MULTILINE)
|
||||
assert result.exit_code == 0
|
||||
assert expected_pat.match(result.output) is not None
|
|
@ -1,42 +1,12 @@
|
|||
from multiprocessing import Process
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from ceo_common.model import RemoteMailmanService
|
||||
|
||||
|
||||
def test_remote_mailman(cfg, http_client, app, mock_mailman_server, g_syscom):
|
||||
port = cfg.get('ceod_port')
|
||||
hostname = socket.gethostname()
|
||||
|
||||
def server_start():
|
||||
sys.stdout = open('/dev/null', 'w')
|
||||
sys.stderr = sys.stdout
|
||||
app.run(debug=False, host='0.0.0.0', port=port)
|
||||
|
||||
proc = Process(target=server_start)
|
||||
proc.start()
|
||||
|
||||
for _ in range(5):
|
||||
try:
|
||||
http_client.get(hostname, '/ping')
|
||||
except requests.exceptions.ConnectionError:
|
||||
time.sleep(1)
|
||||
continue
|
||||
break
|
||||
|
||||
try:
|
||||
mailman_srv = RemoteMailmanService()
|
||||
assert mock_mailman_server.subscriptions['csc-general'] == []
|
||||
# RemoteMailmanService -> app -> MailmanService -> MockMailmanServer
|
||||
address = 'test_1@csclub.internal'
|
||||
mailman_srv.subscribe(address, 'csc-general')
|
||||
assert mock_mailman_server.subscriptions['csc-general'] == [address]
|
||||
mailman_srv.unsubscribe(address, 'csc-general')
|
||||
assert mock_mailman_server.subscriptions['csc-general'] == []
|
||||
finally:
|
||||
proc.terminate()
|
||||
proc.join()
|
||||
def test_remote_mailman(app_process, mock_mailman_server, g_syscom):
|
||||
mailman_srv = RemoteMailmanService()
|
||||
assert mock_mailman_server.subscriptions['csc-general'] == []
|
||||
# RemoteMailmanService -> app -> MailmanService -> MockMailmanServer
|
||||
address = 'test_1@csclub.internal'
|
||||
mailman_srv.subscribe(address, 'csc-general')
|
||||
assert mock_mailman_server.subscriptions['csc-general'] == [address]
|
||||
mailman_srv.unsubscribe(address, 'csc-general')
|
||||
assert mock_mailman_server.subscriptions['csc-general'] == []
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
[DEFAULT]
|
||||
base_domain = csclub.internal
|
||||
# merge ceod.ini and ceo.ini values together to make testing easier
|
||||
uw_domain = uwaterloo.internal
|
||||
|
||||
[ceod]
|
||||
admin_host = phosphoric-acid
|
||||
|
|
|
@ -1,20 +1,23 @@
|
|||
import contextlib
|
||||
import grp
|
||||
import importlib.resources
|
||||
from multiprocessing import Process
|
||||
import os
|
||||
import pwd
|
||||
import shutil
|
||||
import subprocess
|
||||
from subprocess import DEVNULL
|
||||
import tempfile
|
||||
import sys
|
||||
import time
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
import flask
|
||||
import ldap3
|
||||
import pytest
|
||||
import requests
|
||||
import socket
|
||||
from zope import component
|
||||
|
||||
from .utils import krb5ccname_ctx
|
||||
from ceo_common.interfaces import IConfig, IKerberosService, ILDAPService, \
|
||||
IFileService, IMailmanService, IHTTPClient, IUWLDAPService, IMailService
|
||||
from ceo_common.model import Config, HTTPClient
|
||||
|
@ -25,6 +28,7 @@ import ceod.utils as utils
|
|||
from .MockSMTPServer import MockSMTPServer
|
||||
from .MockMailmanServer import MockMailmanServer
|
||||
from .conftest_ceod_api import client # noqa: F401
|
||||
from .conftest_ceo import cli_setup # noqa: F401
|
||||
|
||||
|
||||
@pytest.fixture(scope='session', autouse=True)
|
||||
|
@ -47,6 +51,18 @@ def cfg(_drone_hostname_mock):
|
|||
return _cfg
|
||||
|
||||
|
||||
@pytest.fixture(scope='session', autouse=True)
|
||||
def _delete_ccaches():
|
||||
# I've noticed when pytest finishes, the temporary files
|
||||
# created by tempfile.NamedTemporaryFile() aren't destroyed.
|
||||
# So, we clean them up here.
|
||||
from .utils import _ccaches
|
||||
yield
|
||||
# forcefully decrement the reference counts, which will trigger
|
||||
# the destructors
|
||||
_ccaches.clear()
|
||||
|
||||
|
||||
def delete_test_princs(krb_srv):
|
||||
proc = subprocess.run([
|
||||
'kadmin', '-k', '-p', krb_srv.admin_principal, 'listprincs', 'test_*',
|
||||
|
@ -86,20 +102,8 @@ def delete_subtree(conn: ldap3.Connection, base_dn: str):
|
|||
pass
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def ceod_admin_creds(cfg, krb_srv):
|
||||
"""
|
||||
Acquire credentials for ceod/admin and store them
|
||||
in the default ccache.
|
||||
"""
|
||||
subprocess.run(
|
||||
['kinit', '-k', cfg.get('ldap_admin_principal')],
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def g_admin_ctx(cfg, ceod_admin_creds, app):
|
||||
def g_admin_ctx(app):
|
||||
"""
|
||||
Store the principal for ceod/admin in flask.g.
|
||||
This context manager should be used any time LDAP is modified via the
|
||||
|
@ -109,62 +113,45 @@ def g_admin_ctx(cfg, ceod_admin_creds, app):
|
|||
"""
|
||||
@contextlib.contextmanager
|
||||
def wrapper():
|
||||
admin_principal = cfg.get('ldap_admin_principal')
|
||||
with app.app_context():
|
||||
with krb5ccname_ctx('ceod/admin'), app.app_context():
|
||||
try:
|
||||
flask.g.sasl_user = admin_principal
|
||||
flask.g.sasl_user = 'ceod/admin'
|
||||
yield
|
||||
finally:
|
||||
flask.g.pop('sasl_user')
|
||||
return wrapper
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def syscom_creds():
|
||||
"""
|
||||
Acquire credentials for a syscom member and store them in a ccache.
|
||||
Yields the name of the ccache file.
|
||||
"""
|
||||
with tempfile.NamedTemporaryFile() as f:
|
||||
subprocess.run(
|
||||
['kinit', '-c', f.name, 'ctdalek'],
|
||||
check=True, text=True, input='krb5', stdout=DEVNULL,
|
||||
)
|
||||
yield f.name
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def g_syscom(syscom_creds, app):
|
||||
def g_syscom(app):
|
||||
"""
|
||||
Store the principal for the syscom member in flask.g, and point
|
||||
KRB5CCNAME to the file where the TGT is stored.
|
||||
Use this fixture if you need syscom credentials for an HTTP request
|
||||
to a different process.
|
||||
"""
|
||||
filename = syscom_creds
|
||||
with app.app_context():
|
||||
old_krb5ccname = os.environ['KRB5CCNAME']
|
||||
os.environ['KRB5CCNAME'] = 'FILE:' + filename
|
||||
with krb5ccname_ctx('ctdalek'), app.app_context():
|
||||
try:
|
||||
flask.g.sasl_user = 'ctdalek'
|
||||
yield filename
|
||||
yield
|
||||
finally:
|
||||
os.environ['KRB5CCNAME'] = old_krb5ccname
|
||||
flask.g.pop('sasl_user')
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def ldap_conn(cfg, ceod_admin_creds) -> ldap3.Connection:
|
||||
def ldap_conn(cfg) -> ldap3.Connection:
|
||||
# Assume that the same server URL is being used for the CSC
|
||||
# and UWLDAP during the tests.
|
||||
cfg = component.getUtility(IConfig)
|
||||
server_url = cfg.get('ldap_server_url')
|
||||
# sanity check
|
||||
assert server_url == cfg.get('uwldap_server_url')
|
||||
return ldap3.Connection(
|
||||
server_url, auto_bind=True, raise_exceptions=True,
|
||||
authentication=ldap3.SASL, sasl_mechanism=ldap3.KERBEROS,
|
||||
user=cfg.get('ldap_admin_principal'))
|
||||
with krb5ccname_ctx('ceod/admin'):
|
||||
conn = ldap3.Connection(
|
||||
server_url, auto_bind=True, raise_exceptions=True,
|
||||
authentication=ldap3.SASL, sasl_mechanism=ldap3.KERBEROS,
|
||||
user='ceod/admin')
|
||||
return conn
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
|
@ -375,3 +362,32 @@ def uwldap_user(cfg, uwldap_srv, ldap_conn):
|
|||
)
|
||||
yield user
|
||||
conn.delete(dn)
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def app_process(cfg, app, http_client):
|
||||
port = cfg.get('ceod_port')
|
||||
hostname = socket.gethostname()
|
||||
|
||||
def server_start():
|
||||
sys.stdout = open('/dev/null', 'w')
|
||||
sys.stderr = sys.stdout
|
||||
app.run(debug=False, host='0.0.0.0', port=port)
|
||||
|
||||
proc = Process(target=server_start)
|
||||
proc.start()
|
||||
|
||||
try:
|
||||
with krb5ccname_ctx('ctdalek'):
|
||||
for i in range(5):
|
||||
try:
|
||||
http_client.get(hostname, '/ping')
|
||||
except requests.exceptions.ConnectionError:
|
||||
time.sleep(1)
|
||||
continue
|
||||
break
|
||||
assert i != 5, 'Timed out'
|
||||
yield
|
||||
finally:
|
||||
proc.terminate()
|
||||
proc.join()
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from .utils import krb5ccname_ctx
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def cli_setup(app_process):
|
||||
# This tells the CLI entrypoint not to register additional zope services.
|
||||
os.environ['PYTEST'] = '1'
|
||||
|
||||
# Running the client and the server in the same process would be very
|
||||
# messy because they would be sharing the same environment variables,
|
||||
# Kerberos cache, and registered utilities (via zope). So we're just
|
||||
# going to start the app in a child process intead.
|
||||
with krb5ccname_ctx('ctdalek'):
|
||||
yield
|
|
@ -1,10 +1,6 @@
|
|||
from base64 import b64encode
|
||||
import contextlib
|
||||
import os
|
||||
import json
|
||||
import socket
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
from flask import g
|
||||
from flask.testing import FlaskClient
|
||||
|
@ -14,6 +10,7 @@ from requests import Request
|
|||
from requests_gssapi import HTTPSPNEGOAuth
|
||||
|
||||
from ceo_common.krb5.utils import get_fwd_tgt
|
||||
from .utils import krb5ccname_ctx
|
||||
|
||||
__all__ = ['client']
|
||||
|
||||
|
@ -21,40 +18,23 @@ __all__ = ['client']
|
|||
@pytest.fixture(scope='session')
|
||||
def client(app):
|
||||
app_client = app.test_client()
|
||||
with tempfile.TemporaryDirectory() as cache_dir:
|
||||
yield CeodTestClient(app_client, cache_dir)
|
||||
yield CeodTestClient(app_client)
|
||||
|
||||
|
||||
class CeodTestClient:
|
||||
def __init__(self, app_client: FlaskClient, cache_dir: str):
|
||||
def __init__(self, app_client: FlaskClient):
|
||||
self.client = app_client
|
||||
self.syscom_principal = 'ctdalek'
|
||||
# this is only used for the HTTPSNEGOAuth
|
||||
self.base_url = f'http://{socket.getfqdn()}'
|
||||
# for each principal for which we acquired a TGT, map their
|
||||
# username to a file (ccache) storing their TGT
|
||||
self.principal_ccaches = {}
|
||||
# this is where we'll store the credentials for each principal
|
||||
self.cache_dir = cache_dir
|
||||
# for SPNEGO
|
||||
self.target_name = gssapi.Name('ceod/' + socket.getfqdn())
|
||||
|
||||
@contextlib.contextmanager
|
||||
def krb5ccname_env(self, principal):
|
||||
"""Temporarily change KRB5CCNAME to the ccache of the principal."""
|
||||
old_krb5ccname = os.environ['KRB5CCNAME']
|
||||
os.environ['KRB5CCNAME'] = self.principal_ccaches[principal]
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
os.environ['KRB5CCNAME'] = old_krb5ccname
|
||||
|
||||
def get_auth(self, principal):
|
||||
"""Acquire a HTTPSPNEGOAuth instance for the principal."""
|
||||
name = gssapi.Name(principal)
|
||||
# the 'store' arg doesn't seem to work for DIR ccaches
|
||||
with self.krb5ccname_env(principal):
|
||||
creds = gssapi.Credentials(name=name, usage='initiate')
|
||||
creds = gssapi.Credentials(name=name, usage='initiate')
|
||||
auth = HTTPSPNEGOAuth(
|
||||
opportunistic_auth=True,
|
||||
target_name=self.target_name,
|
||||
|
@ -62,32 +42,17 @@ class CeodTestClient:
|
|||
)
|
||||
return auth
|
||||
|
||||
def kinit(self, principal):
|
||||
"""Acquire an initial TGT for the principal."""
|
||||
# For some reason, kinit with the '-c' option deletes the other
|
||||
# credentials in the cache collection, so we need to override the
|
||||
# env variable
|
||||
subprocess.run(
|
||||
['kinit', principal],
|
||||
text=True, input='krb5', check=True, stdout=subprocess.DEVNULL,
|
||||
env={'KRB5CCNAME': self.principal_ccaches[principal]})
|
||||
|
||||
def get_headers(self, principal: str, need_cred: bool):
|
||||
if principal not in self.principal_ccaches:
|
||||
_, filename = tempfile.mkstemp(dir=self.cache_dir)
|
||||
self.principal_ccaches[principal] = filename
|
||||
self.kinit(principal)
|
||||
# Get the Authorization header (SPNEGO).
|
||||
# The method doesn't matter here because we just need to extract
|
||||
# the header using req.prepare().
|
||||
req = Request('GET', self.base_url, auth=self.get_auth(principal))
|
||||
headers = list(req.prepare().headers.items())
|
||||
if need_cred:
|
||||
# Get the X-KRB5-CRED header (forwarded TGT).
|
||||
cred = b64encode(get_fwd_tgt(
|
||||
'ceod/' + socket.getfqdn(), self.principal_ccaches[principal]
|
||||
)).decode()
|
||||
headers.append(('X-KRB5-CRED', cred))
|
||||
with krb5ccname_ctx(principal):
|
||||
# Get the Authorization header (SPNEGO).
|
||||
# The method doesn't matter here because we just need to extract
|
||||
# the header using req.prepare().
|
||||
req = Request('GET', self.base_url, auth=self.get_auth(principal))
|
||||
headers = list(req.prepare().headers.items())
|
||||
if need_cred:
|
||||
# Get the X-KRB5-CRED header (forwarded TGT).
|
||||
cred = b64encode(get_fwd_tgt('ceod/' + socket.getfqdn())).decode()
|
||||
headers.append(('X-KRB5-CRED', cred))
|
||||
return headers
|
||||
|
||||
def request(self, method: str, path: str, principal: str, need_cred: bool, **kwargs):
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
import contextlib
|
||||
import os
|
||||
import subprocess
|
||||
from subprocess import DEVNULL
|
||||
import tempfile
|
||||
|
||||
|
||||
# map principals to files storing credentials
|
||||
_ccaches = {}
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def krb5ccname_ctx(principal: str):
|
||||
"""
|
||||
Temporarily set KRB5CCNAME to a ccache storing credentials
|
||||
for the specified user.
|
||||
"""
|
||||
old_krb5ccname = os.environ['KRB5CCNAME']
|
||||
try:
|
||||
if principal not in _ccaches:
|
||||
f = tempfile.NamedTemporaryFile()
|
||||
os.environ['KRB5CCNAME'] = 'FILE:' + f.name
|
||||
args = ['kinit', principal]
|
||||
if principal == 'ceod/admin':
|
||||
args = ['kinit', '-k', principal]
|
||||
subprocess.run(
|
||||
args, stdout=DEVNULL, text=True, input='krb5',
|
||||
check=True)
|
||||
_ccaches[principal] = f
|
||||
else:
|
||||
os.environ['KRB5CCNAME'] = 'FILE:' + _ccaches[principal].name
|
||||
yield
|
||||
finally:
|
||||
os.environ['KRB5CCNAME'] = old_krb5ccname
|
Loading…
Reference in New Issue