diff --git a/ceo/cli/members.py b/ceo/cli/members.py index 7972df3..44e2ecd 100644 --- a/ceo/cli/members.py +++ b/ceo/cli/members.py @@ -4,12 +4,13 @@ from typing import Dict import click from zope import component -from ..term_utils import get_terms_for_new_user, get_terms_for_renewal +from ..term_utils import get_terms_for_renewal_for_user from ..utils import http_post, http_get, http_patch, http_delete, \ get_failed_operations, user_dict_lines, get_adduser_operations from .utils import handle_stream_response, handle_sync_response, print_lines, \ check_if_in_development from ceo_common.interfaces import IConfig +from ceo_common.model.Term import get_terms_for_new_user from ceod.transactions.members import DeleteMemberTransaction @@ -155,7 +156,7 @@ def modify(username, login_shell, forwarding_addresses): @click.option('--clubrep', is_flag=True, default=False, help='Add non-member terms instead of member terms') def renew(username, num_terms, clubrep): - terms = get_terms_for_renewal(username, num_terms, clubrep) + terms = get_terms_for_renewal_for_user(username, num_terms, clubrep) if clubrep: body = {'non_member_terms': terms} diff --git a/ceo/term_utils.py b/ceo/term_utils.py index 56b635a..d400a5e 100644 --- a/ceo/term_utils.py +++ b/ceo/term_utils.py @@ -1,38 +1,22 @@ from typing import List from .utils import http_get -from ceo_common.model import Term +from ceo_common.model.Term import get_terms_for_renewal import ceo.cli.utils as cli_utils import ceo.tui.utils as tui_utils -# Had to put these in a separate file to avoid a circular import. - -def get_terms_for_new_user(num_terms: int) -> List[str]: - current_term = Term.current() - terms = [current_term + i for i in range(num_terms)] - return list(map(str, terms)) - - -def get_terms_for_renewal( +def get_terms_for_renewal_for_user( username: str, num_terms: int, clubrep: bool, tui_controller=None, ) -> List[str]: resp = http_get('/api/members/' + username) + # FIXME: this is ugly, we shouldn't need a hacky if statement like this if tui_controller is None: result = cli_utils.handle_sync_response(resp) else: result = tui_utils.handle_sync_response(resp, tui_controller) - max_term = None - current_term = Term.current() - if clubrep and 'non_member_terms' in result: - max_term = max(Term(s) for s in result['non_member_terms']) - elif not clubrep and 'terms' in result: - max_term = max(Term(s) for s in result['terms']) - if max_term is not None and max_term >= current_term: - next_term = max_term + 1 + if clubrep: + return get_terms_for_renewal(result.get('non_member_terms'), num_terms) else: - next_term = Term.current() - - terms = [next_term + i for i in range(num_terms)] - return list(map(str, terms)) + return get_terms_for_renewal(result.get('terms'), num_terms) diff --git a/ceo/tui/controllers/AddUserController.py b/ceo/tui/controllers/AddUserController.py index b0b7604..93a64c3 100644 --- a/ceo/tui/controllers/AddUserController.py +++ b/ceo/tui/controllers/AddUserController.py @@ -3,9 +3,9 @@ from threading import Thread from ...utils import http_get from .Controller import Controller from .AddUserTransactionController import AddUserTransactionController -import ceo.term_utils as term_utils from ceo.tui.models import TransactionModel from ceo.tui.views import AddUserConfirmationView, TransactionView +from ceo_common.model.Term import get_terms_for_new_user from ceod.transactions.members import AddMemberTransaction @@ -26,7 +26,7 @@ class AddUserController(Controller): body['program'] = self.model.program if self.model.forwarding_address: body['forwarding_addresses'] = [self.model.forwarding_address] - new_terms = term_utils.get_terms_for_new_user(self.model.num_terms) + new_terms = get_terms_for_new_user(self.model.num_terms) if self.model.membership_type == 'club_rep': body['non_member_terms'] = new_terms else: diff --git a/ceo/tui/controllers/RenewUserController.py b/ceo/tui/controllers/RenewUserController.py index f836f9f..786ba53 100644 --- a/ceo/tui/controllers/RenewUserController.py +++ b/ceo/tui/controllers/RenewUserController.py @@ -28,7 +28,7 @@ class RenewUserController(SyncRequestController): def _get_next_terms(self): try: - self.model.new_terms = term_utils.get_terms_for_renewal( + self.model.new_terms = term_utils.get_terms_for_renewal_for_user( self.model.username, self.model.num_terms, self.model.membership_type == 'club_rep', diff --git a/ceo_common/model/Term.py b/ceo_common/model/Term.py index 3839b63..4a544c5 100644 --- a/ceo_common/model/Term.py +++ b/ceo_common/model/Term.py @@ -1,4 +1,5 @@ import datetime +from typing import List, Union import ceo_common.utils as utils @@ -81,3 +82,39 @@ class Term: month = self.seasons.index(c) * 4 + 1 day = 1 return datetime.datetime(year, month, day) + + +# Utility functions + +def get_terms_for_new_user(num_terms: int) -> List[str]: + current_term = Term.current() + terms = [current_term + i for i in range(num_terms)] + return list(map(str, terms)) + + +def get_terms_for_renewal( + existing_terms: Union[List[str], None], + num_terms: int, +) -> List[str]: + """Calculates the terms for which a member or club rep should be renewed. + + :param terms: The existing terms for the user being renewed. If the user + is being renewed as a regular member, these should be the + member terms. If they are being renewed as a club rep, these + should be the non-member terms. + This may be None if the user does not have any terms of the + appropriate type (an empty list is also acceptable). + :param num_terms: The number of terms for which the user is being renewed. + """ + max_term = None + current_term = Term.current() + if existing_terms: + max_term = max(map(Term, existing_terms)) + + if max_term is not None and max_term >= current_term: + next_term = max_term + 1 + else: + next_term = current_term + + terms = [next_term + i for i in range(num_terms)] + return list(map(str, terms)) diff --git a/ceod/api/members.py b/ceod/api/members.py index 8394ee7..0fe36fb 100644 --- a/ceod/api/members.py +++ b/ceod/api/members.py @@ -8,6 +8,7 @@ from .utils import authz_restrict_to_staff, authz_restrict_to_syscom, \ from ceo_common.errors import BadRequest, UserAlreadySubscribedError, UserNotSubscribedError from ceo_common.interfaces import ILDAPService, IConfig, IMailService from ceo_common.logger_factory import logger_factory +from ceo_common.model.Term import get_terms_for_new_user, get_terms_for_renewal from ceod.transactions.members import ( AddMemberTransaction, ModifyMemberTransaction, @@ -31,6 +32,10 @@ def create_user(): non_member_terms = body.get('non_member_terms') if (terms and non_member_terms) or not (terms or non_member_terms): raise BadRequest('Must specify either terms or non-member terms') + if type(terms) is int: + terms = get_terms_for_new_user(terms) + elif type(non_member_terms) is int: + non_member_terms = get_terms_for_new_user(non_member_terms) for attr in ['uid', 'cn', 'given_name', 'sn']: if not body.get(attr): raise BadRequest(f"Attribute '{attr}' is missing or empty") @@ -104,6 +109,11 @@ def renew_user(username: str): user = ldap_srv.get_user(username) member_list = cfg.get('mailman3_new_member_list') + if type(terms) is int: + terms = get_terms_for_renewal(user.terms, terms) + elif type(non_member_terms) is int: + non_member_terms = get_terms_for_renewal(user.non_member_terms, non_member_terms) + def unexpire(user): if user.shadowExpire: user.set_expired(False) @@ -113,16 +123,16 @@ def renew_user(username: str): except UserAlreadySubscribedError: logger.debug(f'{user.uid} is already unsubscribed from {member_list}') - if body.get('terms'): - logger.info(f"Renewing member {username} for terms {body['terms']}") - user.add_terms(body['terms']) + if terms: + logger.info(f"Renewing member {username} for terms {terms}") + user.add_terms(terms) unexpire(user) - return {'terms_added': body['terms']} - elif body.get('non_member_terms'): - logger.info(f"Renewing club rep {username} for non-member terms {body['non_member_terms']}") - user.add_non_member_terms(body['non_member_terms']) + return {'terms_added': terms} + elif non_member_terms: + logger.info(f"Renewing club rep {username} for non-member terms {non_member_terms}") + user.add_non_member_terms(non_member_terms) unexpire(user) - return {'non_member_terms_added': body['non_member_terms']} + return {'non_member_terms_added': non_member_terms} else: raise BadRequest('Must specify either terms or non-member terms') diff --git a/dev-requirements.txt b/dev-requirements.txt index 8392fb6..e92273e 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,6 +1,6 @@ -flake8==3.9.2 -setuptools==40.8.0 -wheel==0.36.2 -pytest==6.2.4 +flake8==5.0.4 +setuptools==65.4.1 +wheel==0.37.1 +pytest==7.1.3 aiosmtpd==1.4.2 -aiohttp==3.7.4.post0 +aiohttp==3.8.3 diff --git a/tests/ceod/api/test_members.py b/tests/ceod/api/test_members.py index 8998a6a..2814265 100644 --- a/tests/ceod/api/test_members.py +++ b/tests/ceod/api/test_members.py @@ -83,6 +83,25 @@ def test_api_create_user(cfg, create_user_resp, mock_mail_server): mock_mail_server.messages.clear() +def test_api_create_user_with_num_terms(client): + status, data = client.post('/api/members', json={ + 'uid': 'test2', + 'cn': 'Test Two', + 'given_name': 'Test', + 'sn': 'Two', + 'program': 'Math', + 'terms': 2, + 'forwarding_addresses': ['test2@uwaterloo.internal'], + }) + assert status == 200 + assert data[-1]['status'] == 'completed' + current_term = Term.current() + assert data[-1]['result']['terms'] == [str(current_term), str(current_term + 1)] + status, data = client.delete('/api/members/test2') + assert status == 200 + assert data[-1]['status'] == 'completed' + + def test_api_next_uid(cfg, client, create_user_result): min_uid = cfg.get('members_min_id') _, data = client.post('/api/members', json={ @@ -202,6 +221,20 @@ def test_api_renew_user(cfg, client, create_user_result, ldap_conn): ldap_conn.modify(dn, changes) +def test_api_renew_user_with_num_terms(client, ldap_user): + uid = ldap_user.uid + status, data = client.post(f'/api/members/{uid}/renew', json={'terms': 2}) + assert status == 200 + _, data = client.get(f'/api/members/{uid}') + current_term = Term.current() + assert data['terms'] == [str(current_term), str(current_term + 1), str(current_term + 2)] + + status, data = client.post(f'/api/members/{uid}/renew', json={'non_member_terms': 2}) + assert status == 200 + _, data = client.get(f'/api/members/{uid}') + assert data['non_member_terms'] == [str(current_term), str(current_term + 1)] + + def test_api_reset_password(client, create_user_result): uid = create_user_result['uid'] with patch.object(ceod.utils, 'gen_password') as gen_password_mock: