Expire member cli and api (#33)
Closes #23 Co-authored-by: Rio Liu <rio.liu@r26.me> Co-authored-by: Max Erenberg <> Reviewed-on: #33 Co-authored-by: Rio <r345liu@localhost> Co-committed-by: Rio <r345liu@localhost>
This commit is contained in:
parent
f1c0ce3dd6
commit
b4110d887d
|
@ -187,3 +187,18 @@ def delete(username):
|
|||
click.confirm(f"Are you sure you want to delete {username}?", abort=True)
|
||||
resp = http_delete(f'/api/members/{username}')
|
||||
handle_stream_response(resp, DeleteMemberTransaction.operations)
|
||||
|
||||
|
||||
@members.command(short_help="Check for and mark expired members")
|
||||
@click.option('--dry-run', is_flag=True, default=False)
|
||||
def expire(dry_run):
|
||||
resp = http_post(f'/api/members/expire?dry_run={dry_run and "yes" or "no"}')
|
||||
result = handle_sync_response(resp)
|
||||
|
||||
if len(result) > 0:
|
||||
if dry_run:
|
||||
click.echo("The following members will be marked as expired:")
|
||||
else:
|
||||
click.echo("The following members has been marked as expired:")
|
||||
for username in result:
|
||||
click.echo(username)
|
||||
|
|
|
@ -87,3 +87,9 @@ class ILDAPService(Interface):
|
|||
be returned along with their new programs, in the same format
|
||||
described above.
|
||||
"""
|
||||
|
||||
def get_expiring_users(self) -> List[IUser]:
|
||||
"""
|
||||
Retrieves members whose term or nonMemberTerm does not contain the
|
||||
current or the last term.
|
||||
"""
|
||||
|
|
|
@ -22,6 +22,7 @@ class IUser(Interface):
|
|||
'a club rep')
|
||||
mail_local_addresses = Attribute('email aliases')
|
||||
is_club_rep = Attribute('whether this user is a club rep or not')
|
||||
shadowExpire = Attribute('whether the user is marked as expired')
|
||||
|
||||
# Non-LDAP attributes
|
||||
ldap3_entry = Attribute('cached ldap3.Entry instance for this user')
|
||||
|
|
|
@ -16,17 +16,22 @@ class Term:
|
|||
def __repr__(self):
|
||||
return self.s_term
|
||||
|
||||
@staticmethod
|
||||
def from_datetime(dt: datetime.datetime):
|
||||
"""Get a Term object for the given date."""
|
||||
idx = (dt.month - 1) // 4
|
||||
c = Term.seasons[idx]
|
||||
s_term = c + str(dt.year)
|
||||
return Term(s_term)
|
||||
|
||||
@staticmethod
|
||||
def current():
|
||||
"""Get a Term object for the current date."""
|
||||
dt = utils.get_current_datetime()
|
||||
c = 'w'
|
||||
if 5 <= dt.month <= 8:
|
||||
c = 's'
|
||||
elif 9 <= dt.month:
|
||||
c = 'f'
|
||||
s_term = c + str(dt.year)
|
||||
return Term(s_term)
|
||||
return Term.from_datetime(dt)
|
||||
|
||||
def start_month(self):
|
||||
return self.seasons.index(self.s_term[0]) * 4 + 1
|
||||
|
||||
def __add__(self, other):
|
||||
assert type(other) is int
|
||||
|
@ -40,6 +45,7 @@ class Term:
|
|||
return Term(s_term)
|
||||
|
||||
def __sub__(self, other):
|
||||
assert type(other) is int
|
||||
return self.__add__(-other)
|
||||
|
||||
def __eq__(self, other):
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from flask import Blueprint, request
|
||||
from flask import Blueprint, request, json
|
||||
from zope import component
|
||||
|
||||
from .utils import authz_restrict_to_staff, authz_restrict_to_syscom, \
|
||||
user_is_in_group, requires_authentication_no_realm, \
|
||||
create_streaming_response, development_only
|
||||
create_streaming_response, development_only, is_truthy
|
||||
from ceo_common.errors import BadRequest
|
||||
from ceo_common.interfaces import ILDAPService
|
||||
from ceod.transactions.members import (
|
||||
|
@ -86,9 +86,11 @@ def renew_user(username: str):
|
|||
user = ldap_srv.get_user(username)
|
||||
if body.get('terms'):
|
||||
user.add_terms(body['terms'])
|
||||
user.set_expired(False)
|
||||
return {'terms_added': body['terms']}
|
||||
elif body.get('non_member_terms'):
|
||||
user.add_non_member_terms(body['non_member_terms'])
|
||||
user.set_expired(False)
|
||||
return {'non_member_terms_added': body['non_member_terms']}
|
||||
else:
|
||||
raise BadRequest('Must specify either terms or non-member terms')
|
||||
|
@ -110,3 +112,18 @@ def reset_user_password(username: str):
|
|||
def delete_user(username: str):
|
||||
txn = DeleteMemberTransaction(username)
|
||||
return create_streaming_response(txn)
|
||||
|
||||
|
||||
@bp.route('/expire', methods=['POST'])
|
||||
@authz_restrict_to_syscom
|
||||
def expire_users():
|
||||
dry_run = is_truthy(request.args.get('dry_run', 'false'))
|
||||
|
||||
ldap_srv = component.getUtility(ILDAPService)
|
||||
members = ldap_srv.get_expiring_users()
|
||||
|
||||
if not dry_run:
|
||||
for member in members:
|
||||
member.set_expired(True)
|
||||
|
||||
return json.jsonify([member.uid for member in members])
|
||||
|
|
|
@ -140,4 +140,4 @@ def development_only(f: Callable) -> Callable:
|
|||
|
||||
|
||||
def is_truthy(s: str) -> bool:
|
||||
return s.lower() in ['yes', 'true']
|
||||
return s.lower() in ['yes', 'true', '1']
|
||||
|
|
|
@ -13,6 +13,8 @@ from ceo_common.errors import UserNotFoundError, GroupNotFoundError, \
|
|||
UserAlreadyExistsError, GroupAlreadyExistsError
|
||||
from ceo_common.interfaces import ILDAPService, IConfig, \
|
||||
IUser, IGroup, IUWLDAPService
|
||||
from ceo_common.model import Term
|
||||
import ceo_common.utils as ceo_common_utils
|
||||
from .User import User
|
||||
from .Group import Group
|
||||
|
||||
|
@ -243,6 +245,30 @@ class LDAPService:
|
|||
except ldap3.core.exceptions.LDAPEntryAlreadyExistsResult:
|
||||
raise GroupAlreadyExistsError()
|
||||
|
||||
def get_expiring_users(self) -> List[IUser]:
|
||||
query = []
|
||||
|
||||
term = Term.current()
|
||||
query.append(f'term={term}')
|
||||
query.append(f'nonMemberTerm={term}')
|
||||
|
||||
# Include last term too if the new term just started
|
||||
dt = ceo_common_utils.get_current_datetime()
|
||||
if dt.month == term.start_month():
|
||||
last_term = term - 1
|
||||
query.append(f'term={last_term}')
|
||||
query.append(f'nonMemberTerm={last_term}')
|
||||
|
||||
query = '(!(|(shadowExpire=1)(' + ')('.join(query) + ')))'
|
||||
|
||||
conn = self._get_ldap_conn()
|
||||
conn.search(
|
||||
self.ldap_users_base,
|
||||
query,
|
||||
attributes=ldap3.ALL_ATTRIBUTES,
|
||||
search_scope=ldap3.LEVEL)
|
||||
return [User.deserialize_from_ldap(entry) for entry in conn.entries]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def entry_ctx_for_group(self, group: IGroup):
|
||||
entry = self._get_writable_entry_for_group(group)
|
||||
|
|
|
@ -33,6 +33,7 @@ class User:
|
|||
is_club_rep: Union[bool, None] = None,
|
||||
is_club: bool = False,
|
||||
ldap3_entry: Union[ldap3.Entry, None] = None,
|
||||
shadowExpire: Union[int, None] = None,
|
||||
):
|
||||
cfg = component.getUtility(IConfig)
|
||||
|
||||
|
@ -66,6 +67,7 @@ class User:
|
|||
else:
|
||||
self.is_club_rep = is_club_rep
|
||||
self.ldap3_entry = ldap3_entry
|
||||
self.shadowExpire = shadowExpire
|
||||
|
||||
self.ldap_srv = component.getUtility(ILDAPService)
|
||||
self.krb_srv = component.getUtility(IKerberosService)
|
||||
|
@ -82,6 +84,7 @@ class User:
|
|||
'is_club': self.is_club(),
|
||||
'is_club_rep': self.is_club_rep,
|
||||
'program': self.program or 'Unknown',
|
||||
'shadowExpire': self.shadowExpire,
|
||||
}
|
||||
if self.sn and self.given_name:
|
||||
data['sn'] = self.sn
|
||||
|
@ -155,6 +158,7 @@ class User:
|
|||
mail_local_addresses=attrs.get('mailLocalAddress'),
|
||||
is_club_rep=attrs.get('isClubRep', [False])[0],
|
||||
is_club=('club' in attrs['objectClass']),
|
||||
shadowExpire=attrs.get('shadowExpire'),
|
||||
ldap3_entry=entry,
|
||||
)
|
||||
|
||||
|
@ -205,3 +209,12 @@ class User:
|
|||
current_term = Term.current()
|
||||
most_recent_term = max(map(Term, self.terms))
|
||||
return most_recent_term >= current_term
|
||||
|
||||
def set_expired(self, expired: bool):
|
||||
with self.ldap_srv.entry_ctx_for_user(self) as entry:
|
||||
if expired:
|
||||
entry.shadowExpire = 1
|
||||
self.shadowExpire = 1
|
||||
else:
|
||||
entry.shadowExpire.remove()
|
||||
self.shadowExpire = None
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
import os
|
||||
import re
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
|
||||
from click.testing import CliRunner
|
||||
from unittest.mock import patch
|
||||
|
||||
from ceo.cli import cli
|
||||
from ceo_common.model import Term
|
||||
import ceo_common.utils
|
||||
|
||||
|
||||
def test_members_get(cli_setup, ldap_user):
|
||||
|
@ -142,3 +145,26 @@ def test_members_pwreset(cli_setup, ldap_user, krb_user):
|
|||
), re.MULTILINE)
|
||||
assert result.exit_code == 0
|
||||
assert expected_pat.match(result.output) is not None
|
||||
|
||||
|
||||
def test_members_expire(cli_setup, ldap_user):
|
||||
runner = CliRunner()
|
||||
|
||||
with patch.object(ceo_common.utils, 'get_current_datetime') as datetime_mock:
|
||||
# use a time that we know for sure will expire
|
||||
datetime_mock.return_value = datetime(4000, 4, 1)
|
||||
|
||||
result = runner.invoke(cli, ['members', 'expire', '--dry-run'])
|
||||
assert result.exit_code == 0
|
||||
assert result.output == f"The following members will be marked as expired:\n{ldap_user.uid}\n"
|
||||
|
||||
result = runner.invoke(cli, ['members', 'expire'])
|
||||
assert result.exit_code == 0
|
||||
assert result.output == f"The following members has been marked as expired:\n{ldap_user.uid}\n"
|
||||
|
||||
runner.invoke(cli, ['members', 'renew', ldap_user.uid, '--terms', '1'])
|
||||
assert result.exit_code == 0
|
||||
|
||||
result = runner.invoke(cli, ['members', 'expire', '--dry-run'])
|
||||
assert result.exit_code == 0
|
||||
assert result.output == ''
|
||||
|
|
|
@ -2,8 +2,11 @@ from unittest.mock import patch
|
|||
|
||||
import ldap3
|
||||
import pytest
|
||||
import datetime
|
||||
|
||||
import ceod.utils as utils
|
||||
import ceod.utils
|
||||
import ceo_common.utils
|
||||
from ceo_common.model import Term
|
||||
|
||||
|
||||
def test_api_user_not_found(client):
|
||||
|
@ -12,26 +15,26 @@ def test_api_user_not_found(client):
|
|||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def create_user_resp(client, mocks_for_create_user, mock_mail_server):
|
||||
def create_user_resp(client, mocks_for_create_user_module, mock_mail_server):
|
||||
mock_mail_server.messages.clear()
|
||||
status, data = client.post('/api/members', json={
|
||||
'uid': 'test_1',
|
||||
'uid': 'test1',
|
||||
'cn': 'Test One',
|
||||
'given_name': 'Test',
|
||||
'sn': 'One',
|
||||
'program': 'Math',
|
||||
'terms': ['s2021'],
|
||||
'forwarding_addresses': ['test_1@uwaterloo.internal'],
|
||||
'forwarding_addresses': ['test1@uwaterloo.internal'],
|
||||
})
|
||||
assert status == 200
|
||||
assert data[-1]['status'] == 'completed'
|
||||
yield status, data
|
||||
status, data = client.delete('/api/members/test_1')
|
||||
status, data = client.delete('/api/members/test1')
|
||||
assert status == 200
|
||||
assert data[-1]['status'] == 'completed'
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
@pytest.fixture(scope='function')
|
||||
def create_user_result(create_user_resp):
|
||||
# convenience method
|
||||
_, data = create_user_resp
|
||||
|
@ -54,25 +57,26 @@ def test_api_create_user(cfg, create_user_resp, mock_mail_server):
|
|||
"cn": "Test One",
|
||||
"given_name": "Test",
|
||||
"sn": "One",
|
||||
"uid": "test_1",
|
||||
"uid": "test1",
|
||||
"uid_number": min_uid,
|
||||
"gid_number": min_uid,
|
||||
"login_shell": "/bin/bash",
|
||||
"home_directory": "/tmp/test_users/test_1",
|
||||
"home_directory": "/tmp/test_users/test1",
|
||||
"is_club": False,
|
||||
"is_club_rep": False,
|
||||
"program": "Math",
|
||||
"terms": ["s2021"],
|
||||
"mail_local_addresses": ["test_1@csclub.internal"],
|
||||
"forwarding_addresses": ['test_1@uwaterloo.internal'],
|
||||
"password": "krb5"
|
||||
"mail_local_addresses": ["test1@csclub.internal"],
|
||||
"forwarding_addresses": ['test1@uwaterloo.internal'],
|
||||
"password": "krb5",
|
||||
"shadowExpire": None,
|
||||
}},
|
||||
]
|
||||
assert data == expected
|
||||
# Two messages should have been sent: a welcome message to the new member,
|
||||
# and an announcement to the ceo mailing list
|
||||
assert len(mock_mail_server.messages) == 2
|
||||
assert mock_mail_server.messages[0]['to'] == 'test_1@csclub.internal'
|
||||
assert mock_mail_server.messages[0]['to'] == 'test1@csclub.internal'
|
||||
assert mock_mail_server.messages[1]['to'] == 'ceo@csclub.internal,ctdalek@csclub.internal'
|
||||
mock_mail_server.messages.clear()
|
||||
|
||||
|
@ -198,7 +202,7 @@ def test_api_renew_user(cfg, client, create_user_result, ldap_conn):
|
|||
|
||||
def test_api_reset_password(client, create_user_result):
|
||||
uid = create_user_result['uid']
|
||||
with patch.object(utils, 'gen_password') as gen_password_mock:
|
||||
with patch.object(ceod.utils, 'gen_password') as gen_password_mock:
|
||||
gen_password_mock.return_value = 'new_password'
|
||||
status, data = client.post(f'/api/members/{uid}/pwreset')
|
||||
assert status == 200
|
||||
|
@ -212,7 +216,7 @@ def test_api_reset_password(client, create_user_result):
|
|||
def test_authz_check(client, create_user_result):
|
||||
# non-staff members may not create users
|
||||
status, data = client.post('/api/members', json={
|
||||
'uid': 'test_1', 'cn': 'Test One', 'given_name': 'Test',
|
||||
'uid': 'test1', 'cn': 'Test One', 'given_name': 'Test',
|
||||
'sn': 'One', 'terms': ['s2021'],
|
||||
}, principal='regular1')
|
||||
assert status == 403
|
||||
|
@ -227,7 +231,56 @@ def test_authz_check(client, create_user_result):
|
|||
|
||||
# If we're syscom but we don't pass credentials, the request should fail
|
||||
_, data = client.post('/api/members', json={
|
||||
'uid': 'test_1', 'cn': 'Test One', 'given_name': 'Test',
|
||||
'uid': 'test1', 'cn': 'Test One', 'given_name': 'Test',
|
||||
'sn': 'One', 'terms': ['s2021'],
|
||||
}, principal='ctdalek', delegate=False)
|
||||
assert data[-1]['status'] == 'aborted'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('term_attr', ['terms', 'non_member_terms'])
|
||||
def test_expire(client, new_user_gen, term_attr):
|
||||
start_of_current_term = Term.current().to_datetime()
|
||||
# test_date, should_expire
|
||||
test_cases = [
|
||||
# same term, membership is still valid
|
||||
(start_of_current_term + datetime.timedelta(days=90), False),
|
||||
# first month of next term, grace period is activated
|
||||
(start_of_current_term + datetime.timedelta(days=130), False),
|
||||
# second month of next term, membership is now invalid
|
||||
(start_of_current_term + datetime.timedelta(days=160), True),
|
||||
# next next term, membership is definitely invalid
|
||||
(start_of_current_term + datetime.timedelta(days=250), True),
|
||||
]
|
||||
|
||||
for test_date, should_expire in test_cases:
|
||||
with new_user_gen() as user_obj, \
|
||||
patch.object(ceo_common.utils, 'get_current_datetime') as datetime_mock:
|
||||
user = user_obj.to_dict()
|
||||
uid = user['uid']
|
||||
datetime_mock.return_value = test_date
|
||||
|
||||
assert user['shadowExpire'] is None
|
||||
|
||||
status, data = client.post('/api/members/expire?dry_run=yes')
|
||||
assert status == 200
|
||||
assert (data == [uid]) == should_expire
|
||||
|
||||
_, user = client.get(f'/api/members/{uid}')
|
||||
assert user['shadowExpire'] is None
|
||||
|
||||
status, data = client.post('/api/members/expire')
|
||||
assert status == 200
|
||||
assert (data == [uid]) == should_expire
|
||||
|
||||
_, user = client.get(f'/api/members/{uid}')
|
||||
assert (user['shadowExpire'] is not None) == should_expire
|
||||
|
||||
if not should_expire:
|
||||
continue
|
||||
|
||||
term = Term.from_datetime(test_date)
|
||||
status, _ = client.post(f'/api/members/{uid}/renew', json={term_attr: [str(term)]})
|
||||
assert status == 200
|
||||
|
||||
_, user = client.get(f'/api/members/{uid}')
|
||||
assert user['shadowExpire'] is None
|
||||
|
|
|
@ -171,6 +171,7 @@ def test_user_to_dict(cfg):
|
|||
'home_directory': user.home_directory,
|
||||
'is_club': False,
|
||||
'is_club_rep': False,
|
||||
'shadowExpire': None,
|
||||
}
|
||||
assert user.to_dict() == expected
|
||||
|
||||
|
|
|
@ -1,15 +1,13 @@
|
|||
import contextlib
|
||||
import grp
|
||||
import importlib.resources
|
||||
from multiprocessing import Process
|
||||
import os
|
||||
import pwd
|
||||
import shutil
|
||||
import subprocess
|
||||
from subprocess import DEVNULL
|
||||
import sys
|
||||
import time
|
||||
from unittest.mock import patch, Mock
|
||||
from unittest.mock import Mock
|
||||
|
||||
import flask
|
||||
import gssapi
|
||||
|
@ -19,7 +17,12 @@ import requests
|
|||
import socket
|
||||
from zope import component
|
||||
|
||||
from .utils import gssapi_token_ctx, ccache_cleanup # noqa: F401
|
||||
# noqa: F401
|
||||
from .utils import ( # noqa: F401
|
||||
gssapi_token_ctx,
|
||||
ccache_cleanup,
|
||||
mocks_for_create_user_ctx,
|
||||
)
|
||||
from ceo_common.interfaces import IConfig, IKerberosService, ILDAPService, \
|
||||
IFileService, IMailmanService, IHTTPClient, IUWLDAPService, IMailService, \
|
||||
IDatabaseService, ICloudService
|
||||
|
@ -29,7 +32,6 @@ from ceod.db import MySQLService, PostgreSQLService
|
|||
from ceod.model import KerberosService, LDAPService, FileService, User, \
|
||||
MailmanService, Group, UWLDAPService, UWLDAPRecord, MailService, \
|
||||
CloudService
|
||||
import ceod.utils as utils
|
||||
from .MockSMTPServer import MockSMTPServer
|
||||
from .MockMailmanServer import MockMailmanServer
|
||||
from .MockCloudStackServer import MockCloudStackServer
|
||||
|
@ -301,17 +303,15 @@ def app(
|
|||
return app
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
@pytest.fixture
|
||||
def mocks_for_create_user():
|
||||
with patch.object(utils, 'gen_password') as gen_password_mock, \
|
||||
patch.object(pwd, 'getpwuid') as getpwuid_mock, \
|
||||
patch.object(grp, 'getgrgid') as getgrgid_mock:
|
||||
gen_password_mock.return_value = 'krb5'
|
||||
# Normally, if getpwuid or getgrgid do *not* raise a KeyError,
|
||||
# then LDAPService will skip that UID. Therefore, by raising a
|
||||
# KeyError, we are making sure that the UID will *not* be skipped.
|
||||
getpwuid_mock.side_effect = KeyError()
|
||||
getgrgid_mock.side_effect = KeyError()
|
||||
with mocks_for_create_user_ctx():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture(scope='module')
|
||||
def mocks_for_create_user_module():
|
||||
with mocks_for_create_user_ctx():
|
||||
yield
|
||||
|
||||
|
||||
|
@ -356,10 +356,15 @@ def krb_user(simple_user):
|
|||
simple_user.remove_from_kerberos()
|
||||
|
||||
|
||||
_new_user_id_counter = 10001
|
||||
@pytest.fixture # noqa: E302
|
||||
def new_user(client, g_admin_ctx, ldap_srv_session): # noqa: F811
|
||||
global _new_user_id_counter
|
||||
@pytest.fixture
|
||||
def new_user_gen(
|
||||
client, g_admin_ctx, ldap_srv_session, mocks_for_create_user, # noqa: F811
|
||||
):
|
||||
_new_user_id_counter = 11001
|
||||
|
||||
@contextlib.contextmanager
|
||||
def wrapper():
|
||||
nonlocal _new_user_id_counter
|
||||
uid = 'test' + str(_new_user_id_counter)
|
||||
_new_user_id_counter += 1
|
||||
status, data = client.post('/api/members', json={
|
||||
|
@ -372,17 +377,25 @@ def new_user(client, g_admin_ctx, ldap_srv_session): # noqa: F811
|
|||
})
|
||||
assert status == 200
|
||||
assert data[-1]['status'] == 'completed'
|
||||
subprocess.run([
|
||||
'kadmin', '-k', '-p', 'ceod/admin',
|
||||
'modprinc', '-needchange', uid,
|
||||
], check=True)
|
||||
with g_admin_ctx():
|
||||
user = ldap_srv_session.get_user(uid)
|
||||
subprocess.run([
|
||||
'kadmin', '-k', '-p', 'ceod/admin', 'cpw',
|
||||
'-pw', 'krb5', uid,
|
||||
], check=True)
|
||||
yield user
|
||||
status, data = client.delete(f'/api/members/{uid}')
|
||||
assert status == 200
|
||||
assert data[-1]['status'] == 'completed'
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def new_user(new_user_gen):
|
||||
with new_user_gen() as user:
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_group():
|
||||
|
|
|
@ -1,8 +1,12 @@
|
|||
import ceod.utils as ceod_utils
|
||||
import contextlib
|
||||
import os
|
||||
import grp
|
||||
import pwd
|
||||
import subprocess
|
||||
from subprocess import DEVNULL
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
|
||||
import gssapi
|
||||
import pytest
|
||||
|
@ -45,3 +49,23 @@ def ccache_cleanup():
|
|||
"""Make sure the ccache files get deleted at the end of the tests."""
|
||||
yield
|
||||
_cache.clear()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def gen_password_mock_ctx():
|
||||
with patch.object(ceod_utils, 'gen_password') as mock:
|
||||
mock.return_value = 'krb5'
|
||||
yield
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mocks_for_create_user_ctx():
|
||||
with gen_password_mock_ctx(), \
|
||||
patch.object(pwd, 'getpwuid') as getpwuid_mock, \
|
||||
patch.object(grp, 'getgrgid') as getgrgid_mock:
|
||||
# Normally, if getpwuid or getgrgid do *not* raise a KeyError,
|
||||
# then LDAPService will skip that UID. Therefore, by raising a
|
||||
# KeyError, we are making sure that the UID will *not* be skipped.
|
||||
getpwuid_mock.side_effect = KeyError()
|
||||
getgrgid_mock.side_effect = KeyError()
|
||||
yield
|
||||
|
|
Loading…
Reference in New Issue