add unit tests for members CLI
continuous-integration/drone/push Build is passing
Details
continuous-integration/drone/push Build is passing
Details
This commit is contained in:
parent
7a8751fd8f
commit
08a3faaefc
|
@ -0,0 +1 @@
|
||||||
|
from .entrypoint import cli
|
|
@ -5,7 +5,7 @@ import socket
|
||||||
import click
|
import click
|
||||||
from zope import component
|
from zope import component
|
||||||
|
|
||||||
from .krb_check import krb_check
|
from ..krb_check import krb_check
|
||||||
from .members import members
|
from .members import members
|
||||||
from ceo_common.interfaces import IConfig, IHTTPClient
|
from ceo_common.interfaces import IConfig, IHTTPClient
|
||||||
from ceo_common.model import Config, HTTPClient
|
from ceo_common.model import Config, HTTPClient
|
||||||
|
@ -21,7 +21,8 @@ def cli(ctx):
|
||||||
user = princ[:princ.index('@')]
|
user = princ[:princ.index('@')]
|
||||||
ctx.obj['user'] = user
|
ctx.obj['user'] = user
|
||||||
|
|
||||||
register_services()
|
if os.environ.get('PYTEST') != '1':
|
||||||
|
register_services()
|
||||||
|
|
||||||
|
|
||||||
cli.add_command(members)
|
cli.add_command(members)
|
|
@ -5,11 +5,10 @@ from typing import Dict
|
||||||
import click
|
import click
|
||||||
from zope import component
|
from zope import component
|
||||||
|
|
||||||
from .utils import http_post, http_get, http_patch, http_delete, \
|
from ..utils import http_post, http_get, http_patch, http_delete, get_failed_operations
|
||||||
handle_stream_response, handle_sync_response, print_colon_kv, \
|
from .utils import handle_stream_response, handle_sync_response, print_colon_kv
|
||||||
get_failed_operations
|
|
||||||
from ceo_common.interfaces import IConfig
|
from ceo_common.interfaces import IConfig
|
||||||
from ceo_common.utils import get_current_term, add_term, get_max_term
|
from ceo_common.model import Term
|
||||||
from ceod.transactions.members import (
|
from ceod.transactions.members import (
|
||||||
AddMemberTransaction,
|
AddMemberTransaction,
|
||||||
DeleteMemberTransaction,
|
DeleteMemberTransaction,
|
||||||
|
@ -37,11 +36,10 @@ def add(username, cn, program, num_terms, clubrep, forwarding_address):
|
||||||
cfg = component.getUtility(IConfig)
|
cfg = component.getUtility(IConfig)
|
||||||
uw_domain = cfg.get('uw_domain')
|
uw_domain = cfg.get('uw_domain')
|
||||||
|
|
||||||
current_term = get_current_term()
|
current_term = Term.current()
|
||||||
terms = [current_term]
|
terms = [current_term + i for i in range(num_terms)]
|
||||||
for _ in range(1, num_terms):
|
terms = list(map(str, terms))
|
||||||
term = add_term(terms[-1])
|
|
||||||
terms.append(term)
|
|
||||||
if forwarding_address is None:
|
if forwarding_address is None:
|
||||||
forwarding_address = username + '@' + uw_domain
|
forwarding_address = username + '@' + uw_domain
|
||||||
|
|
||||||
|
@ -69,13 +67,18 @@ def add(username, cn, program, num_terms, clubrep, forwarding_address):
|
||||||
if program is not None:
|
if program is not None:
|
||||||
body['program'] = program
|
body['program'] = program
|
||||||
if clubrep:
|
if clubrep:
|
||||||
body['terms'] = terms
|
|
||||||
else:
|
|
||||||
body['non_member_terms'] = terms
|
body['non_member_terms'] = terms
|
||||||
|
else:
|
||||||
|
body['terms'] = terms
|
||||||
if forwarding_address != '':
|
if forwarding_address != '':
|
||||||
body['forwarding_addresses'] = [forwarding_address]
|
body['forwarding_addresses'] = [forwarding_address]
|
||||||
|
operations = AddMemberTransaction.operations
|
||||||
|
if forwarding_address == '':
|
||||||
|
# don't bother displaying this because it won't be run
|
||||||
|
operations.remove('set_forwarding_addresses')
|
||||||
|
|
||||||
resp = http_post('/api/members', json=body)
|
resp = http_post('/api/members', json=body)
|
||||||
data = handle_stream_response(resp, AddMemberTransaction.operations)
|
data = handle_stream_response(resp, operations)
|
||||||
result = data[-1]['result']
|
result = data[-1]['result']
|
||||||
print_user_lines(result)
|
print_user_lines(result)
|
||||||
|
|
||||||
|
@ -121,7 +124,10 @@ def get(username):
|
||||||
@click.argument('username')
|
@click.argument('username')
|
||||||
@click.option('--login-shell', required=False, help='Login shell')
|
@click.option('--login-shell', required=False, help='Login shell')
|
||||||
@click.option('--forwarding-addresses', required=False,
|
@click.option('--forwarding-addresses', required=False,
|
||||||
help='Comma-separated list of forwarding addresses')
|
help=(
|
||||||
|
'Comma-separated list of forwarding addresses. '
|
||||||
|
'Set to the empty string to disable forwarding.'
|
||||||
|
))
|
||||||
def modify(username, login_shell, forwarding_addresses):
|
def modify(username, login_shell, forwarding_addresses):
|
||||||
if login_shell is None and forwarding_addresses is None:
|
if login_shell is None and forwarding_addresses is None:
|
||||||
click.echo('Nothing to do.')
|
click.echo('Nothing to do.')
|
||||||
|
@ -133,13 +139,19 @@ def modify(username, login_shell, forwarding_addresses):
|
||||||
operations.append('replace_login_shell')
|
operations.append('replace_login_shell')
|
||||||
click.echo('Login shell will be set to: ' + login_shell)
|
click.echo('Login shell will be set to: ' + login_shell)
|
||||||
if forwarding_addresses is not None:
|
if forwarding_addresses is not None:
|
||||||
forwarding_addresses = forwarding_addresses.split(',')
|
if forwarding_addresses == '':
|
||||||
|
forwarding_addresses = []
|
||||||
|
else:
|
||||||
|
forwarding_addresses = forwarding_addresses.split(',')
|
||||||
body['forwarding_addresses'] = forwarding_addresses
|
body['forwarding_addresses'] = forwarding_addresses
|
||||||
operations.append('replace_forwarding_addresses')
|
operations.append('replace_forwarding_addresses')
|
||||||
prefix = '~/.forward will be set to: '
|
prefix = '~/.forward will be set to: '
|
||||||
click.echo(prefix + forwarding_addresses[0])
|
if len(forwarding_addresses) > 0:
|
||||||
for address in forwarding_addresses[1:]:
|
click.echo(prefix + forwarding_addresses[0])
|
||||||
click.echo((' ' * len(prefix)) + address)
|
for address in forwarding_addresses[1:]:
|
||||||
|
click.echo((' ' * len(prefix)) + address)
|
||||||
|
else:
|
||||||
|
click.echo(prefix)
|
||||||
|
|
||||||
click.confirm('Do you want to continue?', abort=True)
|
click.confirm('Do you want to continue?', abort=True)
|
||||||
|
|
||||||
|
@ -157,19 +169,19 @@ def renew(username, num_terms, clubrep):
|
||||||
resp = http_get('/api/members/' + username)
|
resp = http_get('/api/members/' + username)
|
||||||
result = handle_sync_response(resp)
|
result = handle_sync_response(resp)
|
||||||
max_term = None
|
max_term = None
|
||||||
|
current_term = Term.current()
|
||||||
if clubrep and 'non_member_terms' in result:
|
if clubrep and 'non_member_terms' in result:
|
||||||
max_term = get_max_term(result['non_member_terms'])
|
max_term = max(Term(s) for s in result['non_member_terms'])
|
||||||
elif not clubrep and 'terms' in result:
|
elif not clubrep and 'terms' in result:
|
||||||
max_term = get_max_term(result['terms'])
|
max_term = max(Term(s) for s in result['terms'])
|
||||||
if max_term is not None:
|
|
||||||
max_term = get_max_term([max_term, get_current_term()])
|
|
||||||
else:
|
|
||||||
max_term = get_current_term()
|
|
||||||
|
|
||||||
terms = [add_term(max_term)]
|
if max_term is not None and max_term >= current_term:
|
||||||
for _ in range(1, num_terms):
|
next_term = max_term + 1
|
||||||
term = add_term(terms[-1])
|
else:
|
||||||
terms.append(term)
|
next_term = Term.current()
|
||||||
|
|
||||||
|
terms = [next_term + i for i in range(num_terms)]
|
||||||
|
terms = list(map(str, terms))
|
||||||
|
|
||||||
if clubrep:
|
if clubrep:
|
||||||
body = {'non_member_terms': terms}
|
body = {'non_member_terms': terms}
|
|
@ -0,0 +1,107 @@
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from typing import List, Tuple, Dict
|
||||||
|
|
||||||
|
import click
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from ..operation_strings import descriptions as op_desc
|
||||||
|
|
||||||
|
|
||||||
|
class Abort(click.ClickException):
|
||||||
|
"""Abort silently."""
|
||||||
|
|
||||||
|
def __init__(self, exit_code=1):
|
||||||
|
super().__init__('')
|
||||||
|
self.exit_code = exit_code
|
||||||
|
|
||||||
|
def show(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def print_colon_kv(pairs: List[Tuple[str, str]]):
|
||||||
|
"""
|
||||||
|
Pretty-print a list of key-value pairs such that the key and value
|
||||||
|
columns align.
|
||||||
|
Example:
|
||||||
|
key1: value1
|
||||||
|
key1000: value2
|
||||||
|
"""
|
||||||
|
maxlen = max(len(key) for key, val in pairs)
|
||||||
|
for key, val in pairs:
|
||||||
|
click.echo(key + ': ', nl=False)
|
||||||
|
extra_space = ' ' * (maxlen - len(key))
|
||||||
|
click.echo(extra_space, nl=False)
|
||||||
|
click.echo(val)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_stream_response(resp: requests.Response, operations: List[str]) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Print output to the console while operations are being streamed
|
||||||
|
from the server over HTTP.
|
||||||
|
Returns the parsed JSON data streamed from the server.
|
||||||
|
"""
|
||||||
|
if resp.status_code != 200:
|
||||||
|
click.echo('An error occurred:')
|
||||||
|
click.echo(resp.text.rstrip())
|
||||||
|
raise Abort()
|
||||||
|
click.echo(op_desc[operations[0]] + '... ', nl=False)
|
||||||
|
idx = 0
|
||||||
|
data = []
|
||||||
|
for line in resp.iter_lines(decode_unicode=True, chunk_size=8):
|
||||||
|
d = json.loads(line)
|
||||||
|
data.append(d)
|
||||||
|
if d['status'] == 'aborted':
|
||||||
|
click.echo(click.style('ABORTED', fg='red'))
|
||||||
|
click.echo('The transaction was rolled back.')
|
||||||
|
click.echo('The error was: ' + d['error'])
|
||||||
|
click.echo('Please check the ceod logs.')
|
||||||
|
sys.exit(1)
|
||||||
|
elif d['status'] == 'completed':
|
||||||
|
click.echo('Transaction successfully completed.')
|
||||||
|
return data
|
||||||
|
|
||||||
|
operation = d['operation']
|
||||||
|
oper_failed = False
|
||||||
|
err_msg = None
|
||||||
|
prefix = 'failed_to_'
|
||||||
|
if operation.startswith(prefix):
|
||||||
|
operation = operation[len(prefix):]
|
||||||
|
oper_failed = True
|
||||||
|
# sometimes the operation looks like
|
||||||
|
# "failed_to_do_something: error message"
|
||||||
|
if ':' in operation:
|
||||||
|
operation, err_msg = operation.split(': ', 1)
|
||||||
|
|
||||||
|
while idx < len(operations) and operations[idx] != operation:
|
||||||
|
click.echo('Skipped')
|
||||||
|
idx += 1
|
||||||
|
if idx == len(operations):
|
||||||
|
break
|
||||||
|
click.echo(op_desc[operations[idx]] + '... ', nl=False)
|
||||||
|
if idx == len(operations):
|
||||||
|
click.echo('Unrecognized operation: ' + operation)
|
||||||
|
continue
|
||||||
|
if oper_failed:
|
||||||
|
click.echo(click.style('Failed', fg='red'))
|
||||||
|
if err_msg is not None:
|
||||||
|
click.echo(' Error message: ' + err_msg)
|
||||||
|
else:
|
||||||
|
click.echo(click.style('Done', fg='green'))
|
||||||
|
idx += 1
|
||||||
|
if idx < len(operations):
|
||||||
|
click.echo(op_desc[operations[idx]] + '... ', nl=False)
|
||||||
|
|
||||||
|
raise Exception('server response ended abruptly')
|
||||||
|
|
||||||
|
|
||||||
|
def handle_sync_response(resp: requests.Response):
|
||||||
|
"""
|
||||||
|
Exit the program if the request was not successful.
|
||||||
|
Returns the parsed JSON response.
|
||||||
|
"""
|
||||||
|
if resp.status_code != 200:
|
||||||
|
click.echo('An error occurred:')
|
||||||
|
click.echo(resp.text.rstrip())
|
||||||
|
raise Abort()
|
||||||
|
return resp.json()
|
|
@ -9,19 +9,15 @@ def krb_check():
|
||||||
credentials have expired.
|
credentials have expired.
|
||||||
Returns the principal string 'user@REALM'.
|
Returns the principal string 'user@REALM'.
|
||||||
"""
|
"""
|
||||||
try:
|
for _ in range(2):
|
||||||
creds = gssapi.Credentials(usage='initiate')
|
try:
|
||||||
except gssapi.raw.misc.GSSError:
|
creds = gssapi.Credentials(usage='initiate')
|
||||||
kinit()
|
result = creds.inquire()
|
||||||
creds = gssapi.Credentials(usage='initiate')
|
return str(result.name)
|
||||||
|
except (gssapi.raw.misc.GSSError, gssapi.raw.exceptions.ExpiredCredentialsError):
|
||||||
|
kinit()
|
||||||
|
|
||||||
try:
|
raise Exception('could not acquire GSSAPI credentials')
|
||||||
result = creds.inquire()
|
|
||||||
except gssapi.raw.exceptions.ExpiredCredentialsError:
|
|
||||||
kinit()
|
|
||||||
result = creds.inquire()
|
|
||||||
|
|
||||||
return str(result.name)
|
|
||||||
|
|
||||||
|
|
||||||
def kinit():
|
def kinit():
|
||||||
|
|
96
ceo/utils.py
96
ceo/utils.py
|
@ -1,12 +1,8 @@
|
||||||
import json
|
from typing import List, Dict
|
||||||
import sys
|
|
||||||
from typing import List, Tuple, Dict
|
|
||||||
|
|
||||||
import click
|
|
||||||
import requests
|
import requests
|
||||||
from zope import component
|
from zope import component
|
||||||
|
|
||||||
from .operation_strings import descriptions as op_desc
|
|
||||||
from ceo_common.interfaces import IHTTPClient, IConfig
|
from ceo_common.interfaces import IHTTPClient, IConfig
|
||||||
|
|
||||||
|
|
||||||
|
@ -41,94 +37,6 @@ def http_delete(path: str, **kwargs) -> requests.Response:
|
||||||
return http_request('DELETE', path, **kwargs)
|
return http_request('DELETE', path, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def handle_stream_response(resp: requests.Response, operations: List[str]) -> List[Dict]:
|
|
||||||
"""
|
|
||||||
Print output to the console while operations are being streamed
|
|
||||||
from the server over HTTP.
|
|
||||||
Returns the parsed JSON data streamed from the server.
|
|
||||||
"""
|
|
||||||
if resp.status_code != 200:
|
|
||||||
click.echo('An error occurred:')
|
|
||||||
click.echo(resp.text)
|
|
||||||
sys.exit(1)
|
|
||||||
click.echo(op_desc[operations[0]] + '... ', nl=False)
|
|
||||||
idx = 0
|
|
||||||
data = []
|
|
||||||
for line in resp.iter_lines(decode_unicode=True, chunk_size=8):
|
|
||||||
d = json.loads(line)
|
|
||||||
data.append(d)
|
|
||||||
if d['status'] == 'aborted':
|
|
||||||
click.echo(click.style('ABORTED', fg='red'))
|
|
||||||
click.echo('The transaction was rolled back.')
|
|
||||||
click.echo('The error was: ' + d['error'])
|
|
||||||
click.echo('Please check the ceod logs.')
|
|
||||||
sys.exit(1)
|
|
||||||
elif d['status'] == 'completed':
|
|
||||||
click.echo('Transaction successfully completed.')
|
|
||||||
return data
|
|
||||||
|
|
||||||
operation = d['operation']
|
|
||||||
oper_failed = False
|
|
||||||
err_msg = None
|
|
||||||
prefix = 'failed_to_'
|
|
||||||
if operation.startswith(prefix):
|
|
||||||
operation = operation[len(prefix):]
|
|
||||||
oper_failed = True
|
|
||||||
# sometimes the operation looks like
|
|
||||||
# "failed_to_do_something: error message"
|
|
||||||
if ':' in operation:
|
|
||||||
operation, err_msg = operation.split(': ', 1)
|
|
||||||
|
|
||||||
while idx < len(operations) and operations[idx] != operation:
|
|
||||||
click.echo('Skipped')
|
|
||||||
idx += 1
|
|
||||||
if idx == len(operations):
|
|
||||||
break
|
|
||||||
click.echo(op_desc[operations[idx]] + '... ', nl=False)
|
|
||||||
if idx == len(operations):
|
|
||||||
click.echo('Unrecognized operation: ' + operation)
|
|
||||||
continue
|
|
||||||
if oper_failed:
|
|
||||||
click.echo(click.style('Failed', fg='red'))
|
|
||||||
if err_msg is not None:
|
|
||||||
click.echo(' Error message: ' + err_msg)
|
|
||||||
else:
|
|
||||||
click.echo(click.style('Done', fg='green'))
|
|
||||||
idx += 1
|
|
||||||
if idx < len(operations):
|
|
||||||
click.echo(op_desc[operations[idx]] + '... ', nl=False)
|
|
||||||
|
|
||||||
raise Exception('server response ended abruptly')
|
|
||||||
|
|
||||||
|
|
||||||
def handle_sync_response(resp: requests.Response):
|
|
||||||
"""
|
|
||||||
Exit the program if the request was not successful.
|
|
||||||
Returns the parsed JSON response.
|
|
||||||
"""
|
|
||||||
if resp.status_code // 100 != 2:
|
|
||||||
click.echo('An error occurred:')
|
|
||||||
click.echo(resp.text)
|
|
||||||
sys.exit(1)
|
|
||||||
return resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
def print_colon_kv(pairs: List[Tuple[str, str]]):
|
|
||||||
"""
|
|
||||||
Pretty-print a list of key-value pairs such that the key and value
|
|
||||||
columns align.
|
|
||||||
Example:
|
|
||||||
key1: value1
|
|
||||||
key1000: value2
|
|
||||||
"""
|
|
||||||
maxlen = max(len(key) for key, val in pairs)
|
|
||||||
for key, val in pairs:
|
|
||||||
click.echo(key + ': ', nl=False)
|
|
||||||
extra_space = ' ' * (maxlen - len(key))
|
|
||||||
click.echo(extra_space, nl=False)
|
|
||||||
click.echo(val)
|
|
||||||
|
|
||||||
|
|
||||||
def get_failed_operations(data: List[Dict]) -> List[str]:
|
def get_failed_operations(data: List[Dict]) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Get a list of the failed operations using the JSON objects
|
Get a list of the failed operations using the JSON objects
|
||||||
|
@ -144,6 +52,8 @@ def get_failed_operations(data: List[Dict]) -> List[str]:
|
||||||
continue
|
continue
|
||||||
operation = operation[len(prefix):]
|
operation = operation[len(prefix):]
|
||||||
if ':' in operation:
|
if ':' in operation:
|
||||||
|
# sometimes the operation looks like
|
||||||
|
# "failed_to_do_something: error message"
|
||||||
operation = operation[:operation.index(':')]
|
operation = operation[:operation.index(':')]
|
||||||
failed.append(operation)
|
failed.append(operation)
|
||||||
return failed
|
return failed
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
class Term:
|
||||||
|
"""A representation of a term in the CSC LDAP, e.g. 's2021'."""
|
||||||
|
|
||||||
|
seasons = ['w', 's', 'f']
|
||||||
|
|
||||||
|
def __init__(self, s_term: str):
|
||||||
|
assert len(s_term) == 5 and s_term[0] in self.seasons and \
|
||||||
|
s_term[1:].isdigit()
|
||||||
|
self.s_term = s_term
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return self.s_term
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def current():
|
||||||
|
"""Get a Term object for the current date."""
|
||||||
|
dt = datetime.datetime.now()
|
||||||
|
c = 'w'
|
||||||
|
if 5 <= dt.month <= 8:
|
||||||
|
c = 's'
|
||||||
|
elif 9 <= dt.month:
|
||||||
|
c = 'f'
|
||||||
|
s_term = c + str(dt.year)
|
||||||
|
return Term(s_term)
|
||||||
|
|
||||||
|
def __add__(self, other):
|
||||||
|
assert type(other) is int and other >= 0
|
||||||
|
c = self.s_term[0]
|
||||||
|
season_idx = self.seasons.index(c)
|
||||||
|
year = int(self.s_term[1:])
|
||||||
|
year += other // 3
|
||||||
|
season_idx += other % 3
|
||||||
|
if season_idx >= 3:
|
||||||
|
year += 1
|
||||||
|
season_idx -= 3
|
||||||
|
s_term = self.seasons[season_idx] + str(year)
|
||||||
|
return Term(s_term)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return isinstance(other, Term) and self.s_term == other.s_term
|
||||||
|
|
||||||
|
def __lt__(self, other):
|
||||||
|
if not isinstance(other, Term):
|
||||||
|
return NotImplemented
|
||||||
|
c1, c2 = self.s_term[0], other.s_term[0]
|
||||||
|
year1, year2 = int(self.s_term[1:]), int(other.s_term[1:])
|
||||||
|
return year1 < year2 or (
|
||||||
|
year1 == year2 and self.seasons.index(c1) < self.seasons.index(c2)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __gt__(self, other):
|
||||||
|
if not isinstance(other, Term):
|
||||||
|
return NotImplemented
|
||||||
|
c1, c2 = self.s_term[0], other.s_term[0]
|
||||||
|
year1, year2 = int(self.s_term[1:]), int(other.s_term[1:])
|
||||||
|
return year1 > year2 or (
|
||||||
|
year1 == year2 and self.seasons.index(c1) > self.seasons.index(c2)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __ge__(self, other):
|
||||||
|
return self > other or self == other
|
||||||
|
|
||||||
|
def __le__(self, other):
|
||||||
|
return self < other or self == other
|
|
@ -1,3 +1,4 @@
|
||||||
from .Config import Config
|
from .Config import Config
|
||||||
from .HTTPClient import HTTPClient
|
from .HTTPClient import HTTPClient
|
||||||
from .RemoteMailmanService import RemoteMailmanService
|
from .RemoteMailmanService import RemoteMailmanService
|
||||||
|
from .Term import Term
|
||||||
|
|
|
@ -1,40 +0,0 @@
|
||||||
import datetime
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_term() -> str:
|
|
||||||
"""
|
|
||||||
Get the current term as formatted in the CSC LDAP (e.g. 's2021').
|
|
||||||
"""
|
|
||||||
dt = datetime.datetime.now()
|
|
||||||
c = 'w'
|
|
||||||
if 5 <= dt.month <= 8:
|
|
||||||
c = 's'
|
|
||||||
elif 9 <= dt.month:
|
|
||||||
c = 'f'
|
|
||||||
return c + str(dt.year)
|
|
||||||
|
|
||||||
|
|
||||||
def add_term(term: str) -> str:
|
|
||||||
"""
|
|
||||||
Add one term to the given term and return the string.
|
|
||||||
Example: add_term('s2021') -> 'f2021'
|
|
||||||
"""
|
|
||||||
c = term[0]
|
|
||||||
s_year = term[1:]
|
|
||||||
if c == 'w':
|
|
||||||
return 's' + s_year
|
|
||||||
elif c == 's':
|
|
||||||
return 'f' + s_year
|
|
||||||
year = int(s_year)
|
|
||||||
return 'w' + str(year + 1)
|
|
||||||
|
|
||||||
|
|
||||||
def get_max_term(terms: List[str]) -> str:
|
|
||||||
"""Get the maximum (latest) term."""
|
|
||||||
max_year = max(term[1:] for term in terms)
|
|
||||||
if 'f' + max_year in terms:
|
|
||||||
return 'f' + max_year
|
|
||||||
elif 's' + max_year in terms:
|
|
||||||
return 's' + max_year
|
|
||||||
return 'w' + max_year
|
|
|
@ -0,0 +1,136 @@
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
from ceo.cli import cli
|
||||||
|
from ceo_common.model import Term
|
||||||
|
|
||||||
|
|
||||||
|
def test_members_get(cli_setup, ldap_user):
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, ['members', 'get', ldap_user.uid])
|
||||||
|
expected = (
|
||||||
|
f"uid: {ldap_user.uid}\n"
|
||||||
|
f"cn: {ldap_user.cn}\n"
|
||||||
|
f"program: {ldap_user.program}\n"
|
||||||
|
f"UID number: {ldap_user.uid_number}\n"
|
||||||
|
f"GID number: {ldap_user.gid_number}\n"
|
||||||
|
f"login shell: {ldap_user.login_shell}\n"
|
||||||
|
f"home directory: {ldap_user.home_directory}\n"
|
||||||
|
f"is a club: {ldap_user.is_club()}\n"
|
||||||
|
"forwarding addresses: \n"
|
||||||
|
f"terms: {','.join(ldap_user.terms)}\n"
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.output == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_members_add(cli_setup):
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, [
|
||||||
|
'members', 'add', 'test_1', '--cn', 'Test One', '--program', 'Math',
|
||||||
|
'--terms', '1',
|
||||||
|
], input='y\n')
|
||||||
|
expected_pat = re.compile((
|
||||||
|
"^The following user will be created:\n"
|
||||||
|
"uid: test_1\n"
|
||||||
|
"cn: Test One\n"
|
||||||
|
"program: Math\n"
|
||||||
|
"member terms: [sfw]\\d{4}\n"
|
||||||
|
"forwarding address: test_1@uwaterloo.internal\n"
|
||||||
|
"Do you want to continue\\? \\[y/N\\]: y\n"
|
||||||
|
"Add user to LDAP... Done\n"
|
||||||
|
"Add group to LDAP... Done\n"
|
||||||
|
"Add user to Kerberos... Done\n"
|
||||||
|
"Create home directory... Done\n"
|
||||||
|
"Set forwarding addresses... Done\n"
|
||||||
|
"Send welcome message... Done\n"
|
||||||
|
"Subscribe to mailing list... Done\n"
|
||||||
|
"Announce new user to mailing list... Done\n"
|
||||||
|
"Transaction successfully completed.\n"
|
||||||
|
"uid: test_1\n"
|
||||||
|
"cn: Test One\n"
|
||||||
|
"program: Math\n"
|
||||||
|
"UID number: \\d{5}\n"
|
||||||
|
"GID number: \\d{5}\n"
|
||||||
|
"login shell: /bin/bash\n"
|
||||||
|
"home directory: [a-z0-9/_-]+/test_1\n"
|
||||||
|
"is a club: False\n"
|
||||||
|
"forwarding addresses: test_1@uwaterloo.internal\n"
|
||||||
|
"terms: [sfw]\\d{4}\n"
|
||||||
|
"password: \\S+\n$"
|
||||||
|
), re.MULTILINE)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert expected_pat.match(result.output) is not None
|
||||||
|
|
||||||
|
result = runner.invoke(cli, ['members', 'delete', 'test_1'], input='y\n')
|
||||||
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_members_modify(cli_setup, ldap_user):
|
||||||
|
# The homedir needs to exist so the API can write to ~/.forward
|
||||||
|
os.makedirs(ldap_user.home_directory)
|
||||||
|
try:
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, [
|
||||||
|
'members', 'modify', ldap_user.uid, '--login-shell', '/bin/sh',
|
||||||
|
'--forwarding-addresses', 'jdoe@test1.internal,jdoe@test2.internal',
|
||||||
|
], input='y\n')
|
||||||
|
expected = (
|
||||||
|
"Login shell will be set to: /bin/sh\n"
|
||||||
|
"~/.forward will be set to: jdoe@test1.internal\n"
|
||||||
|
" jdoe@test2.internal\n"
|
||||||
|
"Do you want to continue? [y/N]: y\n"
|
||||||
|
"Replace login shell... Done\n"
|
||||||
|
"Replace forwarding addresses... Done\n"
|
||||||
|
"Transaction successfully completed.\n"
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.output == expected
|
||||||
|
finally:
|
||||||
|
shutil.rmtree(ldap_user.home_directory)
|
||||||
|
|
||||||
|
|
||||||
|
def test_members_renew(cli_setup, ldap_user, g_admin_ctx):
|
||||||
|
# set the user's last term to something really old
|
||||||
|
with g_admin_ctx(), ldap_user.ldap_srv.entry_ctx_for_user(ldap_user) as entry:
|
||||||
|
entry.term = ['s1999', 'f1999']
|
||||||
|
current_term = Term.current()
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, [
|
||||||
|
'members', 'renew', ldap_user.uid, '--terms', '1',
|
||||||
|
], input='y\n')
|
||||||
|
expected = (
|
||||||
|
f"The following member terms will be added: {current_term}\n"
|
||||||
|
"Do you want to continue? [y/N]: y\n"
|
||||||
|
"Done.\n"
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.output == expected
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(cli, [
|
||||||
|
'members', 'renew', ldap_user.uid, '--terms', '2',
|
||||||
|
], input='y\n')
|
||||||
|
expected = (
|
||||||
|
f"The following member terms will be added: {current_term+1},{current_term+2}\n"
|
||||||
|
"Do you want to continue? [y/N]: y\n"
|
||||||
|
"Done.\n"
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.output == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_members_pwreset(cli_setup, ldap_user, krb_user):
|
||||||
|
runner = CliRunner()
|
||||||
|
result = runner.invoke(
|
||||||
|
cli, ['members', 'pwreset', ldap_user.uid], input='y\n')
|
||||||
|
expected_pat = re.compile((
|
||||||
|
f"^Are you sure you want to reset {ldap_user.uid}'s password\\? \\[y/N\\]: y\n"
|
||||||
|
"New password: \\S+\n$"
|
||||||
|
), re.MULTILINE)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert expected_pat.match(result.output) is not None
|
|
@ -1,42 +1,12 @@
|
||||||
from multiprocessing import Process
|
|
||||||
import socket
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from ceo_common.model import RemoteMailmanService
|
from ceo_common.model import RemoteMailmanService
|
||||||
|
|
||||||
|
|
||||||
def test_remote_mailman(cfg, http_client, app, mock_mailman_server, g_syscom):
|
def test_remote_mailman(app_process, mock_mailman_server, g_syscom):
|
||||||
port = cfg.get('ceod_port')
|
mailman_srv = RemoteMailmanService()
|
||||||
hostname = socket.gethostname()
|
assert mock_mailman_server.subscriptions['csc-general'] == []
|
||||||
|
# RemoteMailmanService -> app -> MailmanService -> MockMailmanServer
|
||||||
def server_start():
|
address = 'test_1@csclub.internal'
|
||||||
sys.stdout = open('/dev/null', 'w')
|
mailman_srv.subscribe(address, 'csc-general')
|
||||||
sys.stderr = sys.stdout
|
assert mock_mailman_server.subscriptions['csc-general'] == [address]
|
||||||
app.run(debug=False, host='0.0.0.0', port=port)
|
mailman_srv.unsubscribe(address, 'csc-general')
|
||||||
|
assert mock_mailman_server.subscriptions['csc-general'] == []
|
||||||
proc = Process(target=server_start)
|
|
||||||
proc.start()
|
|
||||||
|
|
||||||
for _ in range(5):
|
|
||||||
try:
|
|
||||||
http_client.get(hostname, '/ping')
|
|
||||||
except requests.exceptions.ConnectionError:
|
|
||||||
time.sleep(1)
|
|
||||||
continue
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
|
||||||
mailman_srv = RemoteMailmanService()
|
|
||||||
assert mock_mailman_server.subscriptions['csc-general'] == []
|
|
||||||
# RemoteMailmanService -> app -> MailmanService -> MockMailmanServer
|
|
||||||
address = 'test_1@csclub.internal'
|
|
||||||
mailman_srv.subscribe(address, 'csc-general')
|
|
||||||
assert mock_mailman_server.subscriptions['csc-general'] == [address]
|
|
||||||
mailman_srv.unsubscribe(address, 'csc-general')
|
|
||||||
assert mock_mailman_server.subscriptions['csc-general'] == []
|
|
||||||
finally:
|
|
||||||
proc.terminate()
|
|
||||||
proc.join()
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
[DEFAULT]
|
[DEFAULT]
|
||||||
base_domain = csclub.internal
|
base_domain = csclub.internal
|
||||||
|
# merge ceod.ini and ceo.ini values together to make testing easier
|
||||||
|
uw_domain = uwaterloo.internal
|
||||||
|
|
||||||
[ceod]
|
[ceod]
|
||||||
admin_host = phosphoric-acid
|
admin_host = phosphoric-acid
|
||||||
|
|
|
@ -1,20 +1,23 @@
|
||||||
import contextlib
|
import contextlib
|
||||||
import grp
|
import grp
|
||||||
import importlib.resources
|
import importlib.resources
|
||||||
|
from multiprocessing import Process
|
||||||
import os
|
import os
|
||||||
import pwd
|
import pwd
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
from subprocess import DEVNULL
|
import sys
|
||||||
import tempfile
|
import time
|
||||||
from unittest.mock import patch, Mock
|
from unittest.mock import patch, Mock
|
||||||
|
|
||||||
import flask
|
import flask
|
||||||
import ldap3
|
import ldap3
|
||||||
import pytest
|
import pytest
|
||||||
|
import requests
|
||||||
import socket
|
import socket
|
||||||
from zope import component
|
from zope import component
|
||||||
|
|
||||||
|
from .utils import krb5ccname_ctx
|
||||||
from ceo_common.interfaces import IConfig, IKerberosService, ILDAPService, \
|
from ceo_common.interfaces import IConfig, IKerberosService, ILDAPService, \
|
||||||
IFileService, IMailmanService, IHTTPClient, IUWLDAPService, IMailService
|
IFileService, IMailmanService, IHTTPClient, IUWLDAPService, IMailService
|
||||||
from ceo_common.model import Config, HTTPClient
|
from ceo_common.model import Config, HTTPClient
|
||||||
|
@ -25,6 +28,7 @@ import ceod.utils as utils
|
||||||
from .MockSMTPServer import MockSMTPServer
|
from .MockSMTPServer import MockSMTPServer
|
||||||
from .MockMailmanServer import MockMailmanServer
|
from .MockMailmanServer import MockMailmanServer
|
||||||
from .conftest_ceod_api import client # noqa: F401
|
from .conftest_ceod_api import client # noqa: F401
|
||||||
|
from .conftest_ceo import cli_setup # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='session', autouse=True)
|
@pytest.fixture(scope='session', autouse=True)
|
||||||
|
@ -47,6 +51,18 @@ def cfg(_drone_hostname_mock):
|
||||||
return _cfg
|
return _cfg
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='session', autouse=True)
|
||||||
|
def _delete_ccaches():
|
||||||
|
# I've noticed when pytest finishes, the temporary files
|
||||||
|
# created by tempfile.NamedTemporaryFile() aren't destroyed.
|
||||||
|
# So, we clean them up here.
|
||||||
|
from .utils import _ccaches
|
||||||
|
yield
|
||||||
|
# forcefully decrement the reference counts, which will trigger
|
||||||
|
# the destructors
|
||||||
|
_ccaches.clear()
|
||||||
|
|
||||||
|
|
||||||
def delete_test_princs(krb_srv):
|
def delete_test_princs(krb_srv):
|
||||||
proc = subprocess.run([
|
proc = subprocess.run([
|
||||||
'kadmin', '-k', '-p', krb_srv.admin_principal, 'listprincs', 'test_*',
|
'kadmin', '-k', '-p', krb_srv.admin_principal, 'listprincs', 'test_*',
|
||||||
|
@ -86,20 +102,8 @@ def delete_subtree(conn: ldap3.Connection, base_dn: str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='session')
|
|
||||||
def ceod_admin_creds(cfg, krb_srv):
|
|
||||||
"""
|
|
||||||
Acquire credentials for ceod/admin and store them
|
|
||||||
in the default ccache.
|
|
||||||
"""
|
|
||||||
subprocess.run(
|
|
||||||
['kinit', '-k', cfg.get('ldap_admin_principal')],
|
|
||||||
check=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def g_admin_ctx(cfg, ceod_admin_creds, app):
|
def g_admin_ctx(app):
|
||||||
"""
|
"""
|
||||||
Store the principal for ceod/admin in flask.g.
|
Store the principal for ceod/admin in flask.g.
|
||||||
This context manager should be used any time LDAP is modified via the
|
This context manager should be used any time LDAP is modified via the
|
||||||
|
@ -109,62 +113,45 @@ def g_admin_ctx(cfg, ceod_admin_creds, app):
|
||||||
"""
|
"""
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def wrapper():
|
def wrapper():
|
||||||
admin_principal = cfg.get('ldap_admin_principal')
|
with krb5ccname_ctx('ceod/admin'), app.app_context():
|
||||||
with app.app_context():
|
|
||||||
try:
|
try:
|
||||||
flask.g.sasl_user = admin_principal
|
flask.g.sasl_user = 'ceod/admin'
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
flask.g.pop('sasl_user')
|
flask.g.pop('sasl_user')
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='session')
|
|
||||||
def syscom_creds():
|
|
||||||
"""
|
|
||||||
Acquire credentials for a syscom member and store them in a ccache.
|
|
||||||
Yields the name of the ccache file.
|
|
||||||
"""
|
|
||||||
with tempfile.NamedTemporaryFile() as f:
|
|
||||||
subprocess.run(
|
|
||||||
['kinit', '-c', f.name, 'ctdalek'],
|
|
||||||
check=True, text=True, input='krb5', stdout=DEVNULL,
|
|
||||||
)
|
|
||||||
yield f.name
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def g_syscom(syscom_creds, app):
|
def g_syscom(app):
|
||||||
"""
|
"""
|
||||||
Store the principal for the syscom member in flask.g, and point
|
Store the principal for the syscom member in flask.g, and point
|
||||||
KRB5CCNAME to the file where the TGT is stored.
|
KRB5CCNAME to the file where the TGT is stored.
|
||||||
Use this fixture if you need syscom credentials for an HTTP request
|
Use this fixture if you need syscom credentials for an HTTP request
|
||||||
to a different process.
|
to a different process.
|
||||||
"""
|
"""
|
||||||
filename = syscom_creds
|
with krb5ccname_ctx('ctdalek'), app.app_context():
|
||||||
with app.app_context():
|
|
||||||
old_krb5ccname = os.environ['KRB5CCNAME']
|
|
||||||
os.environ['KRB5CCNAME'] = 'FILE:' + filename
|
|
||||||
try:
|
try:
|
||||||
flask.g.sasl_user = 'ctdalek'
|
flask.g.sasl_user = 'ctdalek'
|
||||||
yield filename
|
yield
|
||||||
finally:
|
finally:
|
||||||
os.environ['KRB5CCNAME'] = old_krb5ccname
|
|
||||||
flask.g.pop('sasl_user')
|
flask.g.pop('sasl_user')
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='session')
|
@pytest.fixture(scope='session')
|
||||||
def ldap_conn(cfg, ceod_admin_creds) -> ldap3.Connection:
|
def ldap_conn(cfg) -> ldap3.Connection:
|
||||||
# Assume that the same server URL is being used for the CSC
|
# Assume that the same server URL is being used for the CSC
|
||||||
# and UWLDAP during the tests.
|
# and UWLDAP during the tests.
|
||||||
cfg = component.getUtility(IConfig)
|
cfg = component.getUtility(IConfig)
|
||||||
server_url = cfg.get('ldap_server_url')
|
server_url = cfg.get('ldap_server_url')
|
||||||
# sanity check
|
# sanity check
|
||||||
assert server_url == cfg.get('uwldap_server_url')
|
assert server_url == cfg.get('uwldap_server_url')
|
||||||
return ldap3.Connection(
|
with krb5ccname_ctx('ceod/admin'):
|
||||||
server_url, auto_bind=True, raise_exceptions=True,
|
conn = ldap3.Connection(
|
||||||
authentication=ldap3.SASL, sasl_mechanism=ldap3.KERBEROS,
|
server_url, auto_bind=True, raise_exceptions=True,
|
||||||
user=cfg.get('ldap_admin_principal'))
|
authentication=ldap3.SASL, sasl_mechanism=ldap3.KERBEROS,
|
||||||
|
user='ceod/admin')
|
||||||
|
return conn
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='session')
|
@pytest.fixture(scope='session')
|
||||||
|
@ -375,3 +362,32 @@ def uwldap_user(cfg, uwldap_srv, ldap_conn):
|
||||||
)
|
)
|
||||||
yield user
|
yield user
|
||||||
conn.delete(dn)
|
conn.delete(dn)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='module')
|
||||||
|
def app_process(cfg, app, http_client):
|
||||||
|
port = cfg.get('ceod_port')
|
||||||
|
hostname = socket.gethostname()
|
||||||
|
|
||||||
|
def server_start():
|
||||||
|
sys.stdout = open('/dev/null', 'w')
|
||||||
|
sys.stderr = sys.stdout
|
||||||
|
app.run(debug=False, host='0.0.0.0', port=port)
|
||||||
|
|
||||||
|
proc = Process(target=server_start)
|
||||||
|
proc.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with krb5ccname_ctx('ctdalek'):
|
||||||
|
for i in range(5):
|
||||||
|
try:
|
||||||
|
http_client.get(hostname, '/ping')
|
||||||
|
except requests.exceptions.ConnectionError:
|
||||||
|
time.sleep(1)
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
assert i != 5, 'Timed out'
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
proc.terminate()
|
||||||
|
proc.join()
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from .utils import krb5ccname_ctx
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope='module')
|
||||||
|
def cli_setup(app_process):
|
||||||
|
# This tells the CLI entrypoint not to register additional zope services.
|
||||||
|
os.environ['PYTEST'] = '1'
|
||||||
|
|
||||||
|
# Running the client and the server in the same process would be very
|
||||||
|
# messy because they would be sharing the same environment variables,
|
||||||
|
# Kerberos cache, and registered utilities (via zope). So we're just
|
||||||
|
# going to start the app in a child process intead.
|
||||||
|
with krb5ccname_ctx('ctdalek'):
|
||||||
|
yield
|
|
@ -1,10 +1,6 @@
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
import contextlib
|
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
import socket
|
import socket
|
||||||
import subprocess
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
from flask import g
|
from flask import g
|
||||||
from flask.testing import FlaskClient
|
from flask.testing import FlaskClient
|
||||||
|
@ -14,6 +10,7 @@ from requests import Request
|
||||||
from requests_gssapi import HTTPSPNEGOAuth
|
from requests_gssapi import HTTPSPNEGOAuth
|
||||||
|
|
||||||
from ceo_common.krb5.utils import get_fwd_tgt
|
from ceo_common.krb5.utils import get_fwd_tgt
|
||||||
|
from .utils import krb5ccname_ctx
|
||||||
|
|
||||||
__all__ = ['client']
|
__all__ = ['client']
|
||||||
|
|
||||||
|
@ -21,40 +18,23 @@ __all__ = ['client']
|
||||||
@pytest.fixture(scope='session')
|
@pytest.fixture(scope='session')
|
||||||
def client(app):
|
def client(app):
|
||||||
app_client = app.test_client()
|
app_client = app.test_client()
|
||||||
with tempfile.TemporaryDirectory() as cache_dir:
|
yield CeodTestClient(app_client)
|
||||||
yield CeodTestClient(app_client, cache_dir)
|
|
||||||
|
|
||||||
|
|
||||||
class CeodTestClient:
|
class CeodTestClient:
|
||||||
def __init__(self, app_client: FlaskClient, cache_dir: str):
|
def __init__(self, app_client: FlaskClient):
|
||||||
self.client = app_client
|
self.client = app_client
|
||||||
self.syscom_principal = 'ctdalek'
|
self.syscom_principal = 'ctdalek'
|
||||||
# this is only used for the HTTPSNEGOAuth
|
# this is only used for the HTTPSNEGOAuth
|
||||||
self.base_url = f'http://{socket.getfqdn()}'
|
self.base_url = f'http://{socket.getfqdn()}'
|
||||||
# for each principal for which we acquired a TGT, map their
|
|
||||||
# username to a file (ccache) storing their TGT
|
|
||||||
self.principal_ccaches = {}
|
|
||||||
# this is where we'll store the credentials for each principal
|
|
||||||
self.cache_dir = cache_dir
|
|
||||||
# for SPNEGO
|
# for SPNEGO
|
||||||
self.target_name = gssapi.Name('ceod/' + socket.getfqdn())
|
self.target_name = gssapi.Name('ceod/' + socket.getfqdn())
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def krb5ccname_env(self, principal):
|
|
||||||
"""Temporarily change KRB5CCNAME to the ccache of the principal."""
|
|
||||||
old_krb5ccname = os.environ['KRB5CCNAME']
|
|
||||||
os.environ['KRB5CCNAME'] = self.principal_ccaches[principal]
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
os.environ['KRB5CCNAME'] = old_krb5ccname
|
|
||||||
|
|
||||||
def get_auth(self, principal):
|
def get_auth(self, principal):
|
||||||
"""Acquire a HTTPSPNEGOAuth instance for the principal."""
|
"""Acquire a HTTPSPNEGOAuth instance for the principal."""
|
||||||
name = gssapi.Name(principal)
|
name = gssapi.Name(principal)
|
||||||
# the 'store' arg doesn't seem to work for DIR ccaches
|
# the 'store' arg doesn't seem to work for DIR ccaches
|
||||||
with self.krb5ccname_env(principal):
|
creds = gssapi.Credentials(name=name, usage='initiate')
|
||||||
creds = gssapi.Credentials(name=name, usage='initiate')
|
|
||||||
auth = HTTPSPNEGOAuth(
|
auth = HTTPSPNEGOAuth(
|
||||||
opportunistic_auth=True,
|
opportunistic_auth=True,
|
||||||
target_name=self.target_name,
|
target_name=self.target_name,
|
||||||
|
@ -62,32 +42,17 @@ class CeodTestClient:
|
||||||
)
|
)
|
||||||
return auth
|
return auth
|
||||||
|
|
||||||
def kinit(self, principal):
|
|
||||||
"""Acquire an initial TGT for the principal."""
|
|
||||||
# For some reason, kinit with the '-c' option deletes the other
|
|
||||||
# credentials in the cache collection, so we need to override the
|
|
||||||
# env variable
|
|
||||||
subprocess.run(
|
|
||||||
['kinit', principal],
|
|
||||||
text=True, input='krb5', check=True, stdout=subprocess.DEVNULL,
|
|
||||||
env={'KRB5CCNAME': self.principal_ccaches[principal]})
|
|
||||||
|
|
||||||
def get_headers(self, principal: str, need_cred: bool):
|
def get_headers(self, principal: str, need_cred: bool):
|
||||||
if principal not in self.principal_ccaches:
|
with krb5ccname_ctx(principal):
|
||||||
_, filename = tempfile.mkstemp(dir=self.cache_dir)
|
# Get the Authorization header (SPNEGO).
|
||||||
self.principal_ccaches[principal] = filename
|
# The method doesn't matter here because we just need to extract
|
||||||
self.kinit(principal)
|
# the header using req.prepare().
|
||||||
# Get the Authorization header (SPNEGO).
|
req = Request('GET', self.base_url, auth=self.get_auth(principal))
|
||||||
# The method doesn't matter here because we just need to extract
|
headers = list(req.prepare().headers.items())
|
||||||
# the header using req.prepare().
|
if need_cred:
|
||||||
req = Request('GET', self.base_url, auth=self.get_auth(principal))
|
# Get the X-KRB5-CRED header (forwarded TGT).
|
||||||
headers = list(req.prepare().headers.items())
|
cred = b64encode(get_fwd_tgt('ceod/' + socket.getfqdn())).decode()
|
||||||
if need_cred:
|
headers.append(('X-KRB5-CRED', cred))
|
||||||
# Get the X-KRB5-CRED header (forwarded TGT).
|
|
||||||
cred = b64encode(get_fwd_tgt(
|
|
||||||
'ceod/' + socket.getfqdn(), self.principal_ccaches[principal]
|
|
||||||
)).decode()
|
|
||||||
headers.append(('X-KRB5-CRED', cred))
|
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
def request(self, method: str, path: str, principal: str, need_cred: bool, **kwargs):
|
def request(self, method: str, path: str, principal: str, need_cred: bool, **kwargs):
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
import contextlib
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from subprocess import DEVNULL
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
|
||||||
|
# map principals to files storing credentials
|
||||||
|
_ccaches = {}
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def krb5ccname_ctx(principal: str):
|
||||||
|
"""
|
||||||
|
Temporarily set KRB5CCNAME to a ccache storing credentials
|
||||||
|
for the specified user.
|
||||||
|
"""
|
||||||
|
old_krb5ccname = os.environ['KRB5CCNAME']
|
||||||
|
try:
|
||||||
|
if principal not in _ccaches:
|
||||||
|
f = tempfile.NamedTemporaryFile()
|
||||||
|
os.environ['KRB5CCNAME'] = 'FILE:' + f.name
|
||||||
|
args = ['kinit', principal]
|
||||||
|
if principal == 'ceod/admin':
|
||||||
|
args = ['kinit', '-k', principal]
|
||||||
|
subprocess.run(
|
||||||
|
args, stdout=DEVNULL, text=True, input='krb5',
|
||||||
|
check=True)
|
||||||
|
_ccaches[principal] = f
|
||||||
|
else:
|
||||||
|
os.environ['KRB5CCNAME'] = 'FILE:' + _ccaches[principal].name
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
os.environ['KRB5CCNAME'] = old_krb5ccname
|
Loading…
Reference in New Issue