From bb56870652d4f2fc68cd7cba19efd7e71a2ca723 Mon Sep 17 00:00:00 2001 From: Max Erenberg Date: Sun, 29 Aug 2021 03:09:02 +0000 Subject: [PATCH 1/2] add skeleton for TUI --- ceo/StreamResponseHandler.py | 43 +++++++ ceo/__main__.py | 39 +++++- ceo/cli/CLIStreamResponseHandler.py | 67 ++++++++++ ceo/cli/entrypoint.py | 37 +----- ceo/cli/members.py | 72 +++-------- ceo/cli/utils.py | 87 ++----------- ceo/krb_check.py | 15 ++- ceo/tui/ConfirmView.py | 59 +++++++++ ceo/tui/Model.py | 12 ++ ceo/tui/TUIStreamResponseHandler.py | 98 +++++++++++++++ ceo/tui/TransactionView.py | 81 ++++++++++++ ceo/tui/WelcomeView.py | 57 +++++++++ ceo/tui/__init__.py | 0 ceo/tui/members/AddUserView.py | 105 ++++++++++++++++ ceo/tui/members/__init__.py | 0 ceo/tui/start.py | 46 +++++++ ceo/utils.py | 152 ++++++++++++++++++++++- ceo_common/interfaces/IHTTPClient.py | 13 +- ceo_common/model/HTTPClient.py | 58 +++++---- ceo_common/model/RemoteMailmanService.py | 6 +- gen_cred.py | 14 --- requirements.txt | 1 + tests/ceo/cli/test_groups.py | 2 + tests/ceo/cli/test_members.py | 31 ++--- tests/conftest.py | 17 ++- tests/conftest_ceod_api.py | 26 ++-- 26 files changed, 882 insertions(+), 256 deletions(-) create mode 100644 ceo/StreamResponseHandler.py create mode 100644 ceo/cli/CLIStreamResponseHandler.py create mode 100644 ceo/tui/ConfirmView.py create mode 100644 ceo/tui/Model.py create mode 100644 ceo/tui/TUIStreamResponseHandler.py create mode 100644 ceo/tui/TransactionView.py create mode 100644 ceo/tui/WelcomeView.py create mode 100644 ceo/tui/__init__.py create mode 100644 ceo/tui/members/AddUserView.py create mode 100644 ceo/tui/members/__init__.py create mode 100644 ceo/tui/start.py delete mode 100755 gen_cred.py diff --git a/ceo/StreamResponseHandler.py b/ceo/StreamResponseHandler.py new file mode 100644 index 0000000..00187fa --- /dev/null +++ b/ceo/StreamResponseHandler.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod +from typing import Union + +import requests + + +class StreamResponseHandler(ABC): + """ + An abstract class to handle stream responses from the server. + The CLI and TUI should implement a child class. + """ + + @abstractmethod + def handle_non_200(self, resp: requests.Response): + """Handle a non-200 response.""" + + @abstractmethod + def begin(self): + """Begin the transaction.""" + + @abstractmethod + def handle_aborted(self, err_msg: str): + """Handle an aborted transaction.""" + + @abstractmethod + def handle_completed(self): + """Handle a completed transaction.""" + + @abstractmethod + def handle_successful_operation(self): + """Handle a successful operation.""" + + @abstractmethod + def handle_failed_operation(self, err_msg: Union[str, None]): + """Handle a failed operation.""" + + @abstractmethod + def handle_skipped_operation(self): + """Handle a skipped operation.""" + + @abstractmethod + def handle_unrecognized_operation(self, operation: str): + """Handle an unrecognized operation.""" diff --git a/ceo/__main__.py b/ceo/__main__.py index aa3acee..af29d1c 100644 --- a/ceo/__main__.py +++ b/ceo/__main__.py @@ -1,4 +1,41 @@ +import importlib.resources +import os +import socket +import sys + +from zope import component + from .cli import cli +from .krb_check import krb_check +from .tui.start import main as tui_main +from ceo_common.interfaces import IConfig, IHTTPClient +from ceo_common.model import Config, HTTPClient + + +def register_services(): + # Config + # This is a hack to determine if we're in the dev env or not + if socket.getfqdn().endswith('.csclub.internal'): + with importlib.resources.path('tests', 'ceo_dev.ini') as p: + config_file = p.__fspath__() + else: + config_file = os.environ.get('CEO_CONFIG', '/etc/csc/ceo.ini') + cfg = Config(config_file) + component.provideUtility(cfg, IConfig) + + # HTTPService + http_client = HTTPClient() + component.provideUtility(http_client, IHTTPClient) + + +def main(): + krb_check() + register_services() + if len(sys.argv) > 1: + cli(obj={}) + else: + tui_main() + if __name__ == '__main__': - cli(obj={}) + main() diff --git a/ceo/cli/CLIStreamResponseHandler.py b/ceo/cli/CLIStreamResponseHandler.py new file mode 100644 index 0000000..de0dae0 --- /dev/null +++ b/ceo/cli/CLIStreamResponseHandler.py @@ -0,0 +1,67 @@ +from typing import List, Union + +import click +import requests + +from ..StreamResponseHandler import StreamResponseHandler +from ..operation_strings import descriptions as op_desc + + +class Abort(click.ClickException): + """Abort silently.""" + + def __init__(self, exit_code=1): + super().__init__('') + self.exit_code = exit_code + + def show(self): + pass + + +class CLIStreamResponseHandler(StreamResponseHandler): + def __init__(self, operations: List[str]): + self.operations = operations + self.idx = 0 + + def handle_non_200(self, resp: requests.Response): + click.echo('An error occurred:') + click.echo(resp.text.rstrip()) + raise Abort() + + def begin(self): + click.echo(op_desc[self.operations[0]] + '... ', nl=False) + + def handle_aborted(self, err_msg: str): + click.echo(click.style('ABORTED', fg='red')) + click.echo('The transaction was rolled back.') + click.echo('The error was: ' + err_msg) + click.echo('Please check the ceod logs.') + + def handle_completed(self): + click.echo('Transaction successfully completed.') + + def _go_to_next_op(self): + """ + Increment the operation index and print the next operation, if + there is one. + """ + self.idx += 1 + if self.idx < len(self.operations): + click.echo(op_desc[self.operations[self.idx]] + '... ', nl=False) + + def handle_successful_operation(self): + click.echo(click.style('Done', fg='green')) + self._go_to_next_op() + + def handle_failed_operation(self, err_msg: Union[str, None]): + click.echo(click.style('Failed', fg='red')) + if err_msg is not None: + click.echo(' Error message: ' + err_msg) + self._go_to_next_op() + + def handle_skipped_operation(self): + click.echo('Skipped') + self._go_to_next_op() + + def handle_unrecognized_operation(self, operation: str): + click.echo('Unrecognized operation: ' + operation) diff --git a/ceo/cli/entrypoint.py b/ceo/cli/entrypoint.py index 04f5306..1221144 100644 --- a/ceo/cli/entrypoint.py +++ b/ceo/cli/entrypoint.py @@ -1,48 +1,15 @@ -import importlib.resources -import os -import socket - import click -from zope import component -from ..krb_check import krb_check from .members import members from .groups import groups from .updateprograms import updateprograms -from ceo_common.interfaces import IConfig, IHTTPClient -from ceo_common.model import Config, HTTPClient @click.group() -@click.pass_context -def cli(ctx): - # ensure ctx exists and is a dict - ctx.ensure_object(dict) - - princ = krb_check() - user = princ[:princ.index('@')] - ctx.obj['user'] = user - - if os.environ.get('PYTEST') != '1': - register_services() +def cli(): + pass cli.add_command(members) cli.add_command(groups) cli.add_command(updateprograms) - - -def register_services(): - # Config - # This is a hack to determine if we're in the dev env or not - if socket.getfqdn().endswith('.csclub.internal'): - with importlib.resources.path('tests', 'ceo_dev.ini') as p: - config_file = p.__fspath__() - else: - config_file = os.environ.get('CEO_CONFIG', '/etc/csc/ceo.ini') - cfg = Config(config_file) - component.provideUtility(cfg, IConfig) - - # HTTPService - http_client = HTTPClient() - component.provideUtility(http_client, IHTTPClient) diff --git a/ceo/cli/members.py b/ceo/cli/members.py index dd6a3e8..6ad0243 100644 --- a/ceo/cli/members.py +++ b/ceo/cli/members.py @@ -4,15 +4,14 @@ from typing import Dict import click from zope import component -from ..utils import http_post, http_get, http_patch, http_delete, get_failed_operations -from .utils import handle_stream_response, handle_sync_response, print_colon_kv, \ +from ..utils import http_post, http_get, http_patch, http_delete, \ + get_failed_operations, get_terms_for_new_user, user_dict_lines, \ + get_adduser_operations +from .utils import handle_stream_response, handle_sync_response, print_lines, \ check_if_in_development from ceo_common.interfaces import IConfig from ceo_common.model import Term -from ceod.transactions.members import ( - AddMemberTransaction, - DeleteMemberTransaction, -) +from ceod.transactions.members import DeleteMemberTransaction @click.group(short_help='Perform operations on CSC members and club reps') @@ -36,30 +35,12 @@ def add(username, cn, program, num_terms, clubrep, forwarding_address): cfg = component.getUtility(IConfig) uw_domain = cfg.get('uw_domain') - current_term = Term.current() - terms = [current_term + i for i in range(num_terms)] - terms = list(map(str, terms)) + terms = get_terms_for_new_user(num_terms) + # TODO: get email address from UWLDAP if forwarding_address is None: forwarding_address = username + '@' + uw_domain - click.echo("The following user will be created:") - lines = [ - ('uid', username), - ('cn', cn), - ] - if program is not None: - lines.append(('program', program)) - if clubrep: - lines.append(('non-member terms', ','.join(terms))) - else: - lines.append(('member terms', ','.join(terms))) - if forwarding_address != '': - lines.append(('forwarding address', forwarding_address)) - print_colon_kv(lines) - - click.confirm('Do you want to continue?', abort=True) - body = { 'uid': username, 'cn': cn, @@ -72,10 +53,14 @@ def add(username, cn, program, num_terms, clubrep, forwarding_address): body['terms'] = terms if forwarding_address != '': body['forwarding_addresses'] = [forwarding_address] - operations = AddMemberTransaction.operations - if forwarding_address == '': - # don't bother displaying this because it won't be run - operations.remove('set_forwarding_addresses') + else: + body['forwarding_addresses'] = [] + + click.echo("The following user will be created:") + print_user_lines(body) + click.confirm('Do you want to continue?', abort=True) + + operations = get_adduser_operations(body) resp = http_post('/api/members', json=body) data = handle_stream_response(resp, operations) @@ -89,30 +74,9 @@ def add(username, cn, program, num_terms, clubrep, forwarding_address): 'send the user their password.', fg='yellow')) -def print_user_lines(result: Dict): - """Pretty-print a user JSON response.""" - lines = [ - ('uid', result['uid']), - ('cn', result['cn']), - ('program', result.get('program', 'Unknown')), - ('UID number', result['uid_number']), - ('GID number', result['gid_number']), - ('login shell', result['login_shell']), - ('home directory', result['home_directory']), - ('is a club', result['is_club']), - ] - if 'forwarding_addresses' in result: - if len(result['forwarding_addresses']) != 0: - lines.append(('forwarding addresses', result['forwarding_addresses'][0])) - for address in result['forwarding_addresses'][1:]: - lines.append(('', address)) - if 'terms' in result: - lines.append(('terms', ','.join(result['terms']))) - if 'non_member_terms' in result: - lines.append(('non-member terms', ','.join(result['non_member_terms']))) - if 'password' in result: - lines.append(('password', result['password'])) - print_colon_kv(lines) +def print_user_lines(d: Dict): + """Pretty-print a serialized User.""" + print_lines(user_dict_lines(d)) @members.command(short_help='Get info about a user') diff --git a/ceo/cli/utils.py b/ceo/cli/utils.py index 98aa23f..76389c9 100644 --- a/ceo/cli/utils.py +++ b/ceo/cli/utils.py @@ -7,6 +7,8 @@ import click import requests from ..operation_strings import descriptions as op_desc +from ..utils import space_colon_kv, generic_handle_stream_response +from .CLIStreamResponseHandler import CLIStreamResponseHandler class Abort(click.ClickException): @@ -20,86 +22,23 @@ class Abort(click.ClickException): pass +def print_lines(lines: List[str]): + """Print multiple lines to stdout.""" + for line in lines: + click.echo(line) + + def print_colon_kv(pairs: List[Tuple[str, str]]): """ - Pretty-print a list of key-value pairs such that the key and value - columns align. - Example: - key1: value1 - key1000: value2 + Pretty-print a list of key-value pairs. """ - maxlen = max(len(key) for key, val in pairs) - for key, val in pairs: - if key != '': - click.echo(key + ': ', nl=False) - else: - # assume this is a continuation from the previous line - click.echo(' ', nl=False) - extra_space = ' ' * (maxlen - len(key)) - click.echo(extra_space, nl=False) - click.echo(val) + for line in space_colon_kv(pairs): + click.echo(line) def handle_stream_response(resp: requests.Response, operations: List[str]) -> List[Dict]: - """ - Print output to the console while operations are being streamed - from the server over HTTP. - Returns the parsed JSON data streamed from the server. - """ - if resp.status_code != 200: - click.echo('An error occurred:') - click.echo(resp.text.rstrip()) - raise Abort() - click.echo(op_desc[operations[0]] + '... ', nl=False) - idx = 0 - data = [] - for line in resp.iter_lines(decode_unicode=True, chunk_size=8): - d = json.loads(line) - data.append(d) - if d['status'] == 'aborted': - click.echo(click.style('ABORTED', fg='red')) - click.echo('The transaction was rolled back.') - click.echo('The error was: ' + d['error']) - click.echo('Please check the ceod logs.') - sys.exit(1) - elif d['status'] == 'completed': - if idx < len(operations): - click.echo('Skipped') - click.echo('Transaction successfully completed.') - return data - - operation = d['operation'] - oper_failed = False - err_msg = None - prefix = 'failed_to_' - if operation.startswith(prefix): - operation = operation[len(prefix):] - oper_failed = True - # sometimes the operation looks like - # "failed_to_do_something: error message" - if ':' in operation: - operation, err_msg = operation.split(': ', 1) - - while idx < len(operations) and operations[idx] != operation: - click.echo('Skipped') - idx += 1 - if idx == len(operations): - break - click.echo(op_desc[operations[idx]] + '... ', nl=False) - if idx == len(operations): - click.echo('Unrecognized operation: ' + operation) - continue - if oper_failed: - click.echo(click.style('Failed', fg='red')) - if err_msg is not None: - click.echo(' Error message: ' + err_msg) - else: - click.echo(click.style('Done', fg='green')) - idx += 1 - if idx < len(operations): - click.echo(op_desc[operations[idx]] + '... ', nl=False) - - raise Exception('server response ended abruptly') + handler = CLIStreamResponseHandler(operations) + return generic_handle_stream_response(resp, operations, handler) def handle_sync_response(resp: requests.Response): diff --git a/ceo/krb_check.py b/ceo/krb_check.py index 7824da7..fcecbdc 100644 --- a/ceo/krb_check.py +++ b/ceo/krb_check.py @@ -3,17 +3,28 @@ import subprocess import gssapi +_username = None + + +def get_username(): + """Get the user currently logged into CEO.""" + return _username + + def krb_check(): """ Spawns a `kinit` process if no credentials are available or the credentials have expired. - Returns the principal string 'user@REALM'. + Stores the username for later use by get_username(). """ + global _username for _ in range(2): try: creds = gssapi.Credentials(usage='initiate') result = creds.inquire() - return str(result.name) + princ = str(result.name) + _username = princ[:princ.index('@')] + return except (gssapi.raw.misc.GSSError, gssapi.raw.exceptions.ExpiredCredentialsError): kinit() diff --git a/ceo/tui/ConfirmView.py b/ceo/tui/ConfirmView.py new file mode 100644 index 0000000..ebf3fbe --- /dev/null +++ b/ceo/tui/ConfirmView.py @@ -0,0 +1,59 @@ +from asciimatics.exceptions import NextScene +from asciimatics.widgets import Frame, Layout, Button, Divider, Label + + +class ConfirmView(Frame): + def __init__(self, screen, width, height, model): + super().__init__( + screen, + height, + width, + can_scroll=False, + on_load=self._on_load, + title='Confirmation', + ) + self._model = model + + def _add_buttons(self): + layout = Layout([100]) + self.add_layout(layout) + layout.add_widget(Divider()) + + layout = Layout([1, 1]) + self.add_layout(layout) + layout.add_widget(Button('No', self._back), 0) + layout.add_widget(Button('Yes', self._next), 1) + + def _add_line(self, text: str = ''): + layout = Layout([100]) + self.add_layout(layout) + layout.add_widget(Label(text, align='^')) + + def _add_pair(self, key: str, val: str): + layout = Layout([10, 1, 10]) + self.add_layout(layout) + layout.add_widget(Label(key + ':', align='>'), 0) + layout.add_widget(Label(val, align='<'), 2) + + def _on_load(self): + for _ in range(2): + self._add_line() + for line in self._model.confirm_lines: + if isinstance(line, str): + self._add_line(line) + else: + # assume tuple + key, val = line + self._add_pair(key, val) + # fill the rest of the space + self.add_layout(Layout([100], fill_frame=True)) + + self._add_buttons() + self.fix() + + def _back(self): + raise NextScene(self._model.scene_stack.pop()) + + def _next(self): + self._model.scene_stack.append('Confirm') + raise NextScene('Transaction') diff --git a/ceo/tui/Model.py b/ceo/tui/Model.py new file mode 100644 index 0000000..01c4453 --- /dev/null +++ b/ceo/tui/Model.py @@ -0,0 +1,12 @@ +class Model: + """A convenient place to share data beween views.""" + + def __init__(self): + # simple key-value pairs + self.screen = None + self.title = None + self.for_member = True + self.scene_stack = [] + self.confirm_lines = None + self.operations = None + self.deferred_req = None diff --git a/ceo/tui/TUIStreamResponseHandler.py b/ceo/tui/TUIStreamResponseHandler.py new file mode 100644 index 0000000..d400448 --- /dev/null +++ b/ceo/tui/TUIStreamResponseHandler.py @@ -0,0 +1,98 @@ +from typing import Dict, Union + +from asciimatics.widgets import Label, Button, Layout, Frame +import requests + +from .Model import Model +from ..StreamResponseHandler import StreamResponseHandler + + +class TUIStreamResponseHandler(StreamResponseHandler): + def __init__( + self, + model: Model, + labels: Dict[str, Label], + next_btn: Button, + msg_layout: Layout, + frame: Frame, + ): + self.screen = model.screen + self.operations = model.operations + self.idx = 0 + self.labels = labels + self.next_btn = next_btn + self.msg_layout = msg_layout + self.frame = frame + self.error_messages = [] + + def _update(self): + # Since we're running in a separate thread, we need to force the + # screen to update. See + # https://github.com/peterbrittain/asciimatics/issues/56 + self.frame.fix() + self.screen.force_update() + + def _enable_next_btn(self): + self.next_btn.disabled = False + self.frame.reset() + + def _show_msg(self, msg: str = ''): + for line in msg.splitlines(): + self.msg_layout.add_widget(Label(line, align='^')) + + def _abort(self): + for operation in self.operations[self.idx:]: + self.labels[operation].text = 'ABORTED' + self._enable_next_btn() + + def handle_non_200(self, resp: requests.Response): + self._abort() + self._show_msg('An error occurred:') + self._show_msg(resp.text) + self._update() + + def begin(self): + pass + + def handle_aborted(self, err_msg: str): + self._abort() + self._show_msg('The transaction was rolled back.') + self._show_msg('The error was:') + self._show_msg(err_msg) + self._show_msg('Please check the ceod logs.') + self._update() + + def handle_completed(self): + self._show_msg('Transaction successfully completed.') + if len(self.error_messages) > 0: + self._show_msg('There were some errors, please check the ' + 'ceod logs.') + # we don't have enough space in the TUI to actually + # show the error messages + self._enable_next_btn() + self._update() + + def handle_successful_operation(self): + operation = self.operations[self.idx] + self.labels[operation].text = 'Done' + self.idx += 1 + self._update() + + def handle_failed_operation(self, err_msg: Union[str, None]): + operation = self.operations[self.idx] + self.labels[operation].text = 'Failed' + if err_msg is not None: + self.error_messages.append(err_msg) + self.idx += 1 + self._update() + + def handle_skipped_operation(self): + operation = self.operations[self.idx] + self.labels[operation].text = 'Skipped' + self.idx += 1 + self._update() + + def handle_unrecognized_operation(self, operation: str): + self.error_messages.append('Unrecognized operation: ' + operation) + self.idx += 1 + self._update() diff --git a/ceo/tui/TransactionView.py b/ceo/tui/TransactionView.py new file mode 100644 index 0000000..1773cfd --- /dev/null +++ b/ceo/tui/TransactionView.py @@ -0,0 +1,81 @@ +from threading import Thread + +from asciimatics.exceptions import NextScene +from asciimatics.widgets import Frame, Layout, Button, Divider, Label + +from ..operation_strings import descriptions as op_desc +from ..utils import generic_handle_stream_response +from .TUIStreamResponseHandler import TUIStreamResponseHandler + + +class TransactionView(Frame): + def __init__(self, screen, width, height, model): + super().__init__( + screen, + height, + width, + can_scroll=False, + on_load=self._on_load, + title='Running Transaction', + ) + self._model = model + # map operation names to label widgets + self._labels = {} + # this is an ugly hack to get around the fact that _on_load() + # will be called again when we reset() in the TUIStreamResponseHandler + self._loaded = False + + def _add_buttons(self): + layout = Layout([100]) + self.add_layout(layout) + layout.add_widget(Divider()) + + layout = Layout([1, 1]) + self.add_layout(layout) + self._next_btn = Button('Next', self._next) + self._next_btn.disabled = True + layout.add_widget(self._next_btn, 1) + + def _add_line(self, text: str = ''): + layout = Layout([100]) + self.add_layout(layout) + layout.add_widget(Label(text, align='^')) + + def _on_load(self): + if self._loaded: + return + self._loaded = True + + for _ in range(2): + self._add_line() + for operation in self._model.operations: + desc = op_desc[operation] + layout = Layout([10, 1, 10]) + self.add_layout(layout) + layout.add_widget(Label(desc + '...', align='>'), 0) + desc_label = Label('', align='<') + layout.add_widget(desc_label, 2) + self._labels[operation] = desc_label + self._add_line() + self._msg_layout = Layout([100]) + self.add_layout(self._msg_layout) + self.add_layout(Layout([100], fill_frame=True)) + + self._add_buttons() + self.fix() + Thread(target=self._do_txn).start() + + def _do_txn(self): + resp = self._model.deferred_req() + handler = TUIStreamResponseHandler( + model=self._model, + labels=self._labels, + next_btn=self._next_btn, + msg_layout=self._msg_layout, + frame=self, + ) + generic_handle_stream_response(resp, self._model.operations, handler) + + def _next(self): + self._model.scene_stack.clear() + raise NextScene('Welcome') diff --git a/ceo/tui/WelcomeView.py b/ceo/tui/WelcomeView.py new file mode 100644 index 0000000..52a269a --- /dev/null +++ b/ceo/tui/WelcomeView.py @@ -0,0 +1,57 @@ +from asciimatics.widgets import Frame, ListBox, Layout, Divider, \ + Button, Widget +from asciimatics.exceptions import NextScene, StopApplication + + +class WelcomeView(Frame): + def __init__(self, screen, width, height, model): + super().__init__( + screen, + height, + width, + can_scroll=False, + title='CSC Electronic Office', + ) + self._model = model + self._members_menu_items = [ + ('Add member', 'AddUser'), + ('Add club rep', 'AddUser'), + ('Renew member', 'RenewUser'), + ('Renew club rep', 'RenewUser'), + ('Get user info', 'GetUserInfo'), + ('Reset password', 'ResetPassword'), + ('Modify user', 'ModifyUser'), + ] + self._members_menu = ListBox( + Widget.FILL_FRAME, + [ + (desc, i) for i, (desc, view) in + enumerate(self._members_menu_items) + ], + name='members', + label='Members', + on_select=self._members_menu_select, + ) + layout = Layout([100], fill_frame=True) + self.add_layout(layout) + layout.add_widget(self._members_menu) + layout.add_widget(Divider()) + + layout = Layout([1, 1, 1]) + self.add_layout(layout) + layout.add_widget(Button("Quit", self._quit), 2) + self.fix() + + def _members_menu_select(self): + self.save() + item_id = self.data['members'] + desc, view = self._members_menu_items[item_id] + if desc.endswith('club rep'): + self._model.for_member = False + self._model.title = desc + self._model.scene_stack.append('Welcome') + raise NextScene(view) + + @staticmethod + def _quit(): + raise StopApplication("User pressed quit") diff --git a/ceo/tui/__init__.py b/ceo/tui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ceo/tui/members/AddUserView.py b/ceo/tui/members/AddUserView.py new file mode 100644 index 0000000..ee2d747 --- /dev/null +++ b/ceo/tui/members/AddUserView.py @@ -0,0 +1,105 @@ +from asciimatics.exceptions import NextScene +from asciimatics.widgets import Frame, Layout, Text, Button, Divider + +from ...utils import http_get, http_post, defer, user_dict_kv, \ + get_terms_for_new_user, get_adduser_operations + + +class AddUserView(Frame): + def __init__(self, screen, width, height, model): + super().__init__( + screen, + height, + width, + can_scroll=False, + on_load=self._on_load, + ) + self._model = model + self._username_changed = False + layout = Layout([100], fill_frame=True) + self.add_layout(layout) + self._username = Text( + "Username:", "uid", + on_change=self._on_username_change, + on_blur=self._on_username_blur, + ) + layout.add_widget(self._username) + self._full_name = Text("Full name:", "cn") + layout.add_widget(self._full_name) + self._program = Text("Program:", "program") + layout.add_widget(self._program) + self._forwarding_address = Text("Forwarding address:", "forwarding_address") + layout.add_widget(self._forwarding_address) + self._num_terms = Text( + "Number of terms:", "num_terms", + validator=lambda s: s.isdigit() and s[0] != '0') + self._num_terms.value = '1' + layout.add_widget(self._num_terms) + + layout = Layout([100]) + self.add_layout(layout) + layout.add_widget(Divider()) + + layout = Layout([1, 1]) + self.add_layout(layout) + layout.add_widget(Button('Back', self._back), 0) + layout.add_widget(Button("Next", self._next), 1) + self.fix() + + def _on_load(self): + self.title = self._model.title + + def _on_username_change(self): + self._username_changed = True + + def _on_username_blur(self): + if not self._username_changed: + return + self._username_changed = False + username = self._username.value + if username == '': + return + self._get_uwldap_info(username) + + def _get_uwldap_info(self, username): + resp = http_get('/api/uwldap/' + username) + if resp.status_code != 200: + return + data = resp.json() + self._full_name.value = data['cn'] + self._program.value = data.get('program', '') + if data.get('mail_local_addresses'): + self._forwarding_address.value = data['mail_local_addresses'][0] + + def _back(self): + raise NextScene(self._model.scene_stack.pop()) + + def _next(self): + self._model.prev_scene = 'AddUser' + body = { + 'uid': self._username.value, + 'cn': self._full_name.value, + } + if self._program.value: + body['program'] = self._program.value + if self._forwarding_address.value: + body['forwarding_addresses'] = [self._forwarding_address.value] + new_terms = get_terms_for_new_user(int(self._num_terms.value)) + if self._model.for_member: + body['terms'] = new_terms + else: + body['non_member_terms'] = new_terms + pairs = user_dict_kv(body) + self._model.confirm_lines = [ + 'The following user will be created:', + '', + ] + pairs + [ + '', + 'Are you sure you want to continue?', + ] + + self._model.deferred_req = defer(http_post, '/api/members', json=body) + self._model.operations = get_adduser_operations(body) + + self._model.scene_stack.append('AddUser') + raise NextScene('Confirm') diff --git a/ceo/tui/members/__init__.py b/ceo/tui/members/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ceo/tui/start.py b/ceo/tui/start.py new file mode 100644 index 0000000..d5e8031 --- /dev/null +++ b/ceo/tui/start.py @@ -0,0 +1,46 @@ +import sys + +from asciimatics.event import KeyboardEvent +from asciimatics.exceptions import ResizeScreenError, StopApplication +from asciimatics.scene import Scene +from asciimatics.screen import Screen + +from .ConfirmView import ConfirmView +from .Model import Model +from .TransactionView import TransactionView +from .WelcomeView import WelcomeView +from .members.AddUserView import AddUserView + + +def unhandled(event): + if isinstance(event, KeyboardEvent): + c = event.key_code + # Stop on 'q' or 'Esc' + if c in (113, 27): + raise StopApplication("User terminated app") + + +def screen_wrapper(screen, scene, model): + model.screen = screen + width = min(screen.width, 90) + height = min(screen.height, 24) + scenes = [ + Scene([WelcomeView(screen, width, height, model)], -1, name='Welcome'), + Scene([AddUserView(screen, width, height, model)], -1, name='AddUser'), + Scene([ConfirmView(screen, width, height, model)], -1, name='Confirm'), + Scene([TransactionView(screen, width, height, model)], -1, name='Transaction'), + ] + screen.play( + scenes, stop_on_resize=True, start_scene=scene, allow_int=True, + unhandled_input=unhandled) + + +def main(): + last_scene = None + model = Model() + while True: + try: + Screen.wrapper(screen_wrapper, arguments=[last_scene, model]) + sys.exit(0) + except ResizeScreenError as e: + last_scene = e.scene diff --git a/ceo/utils.py b/ceo/utils.py index 5673d8c..32ef574 100644 --- a/ceo/utils.py +++ b/ceo/utils.py @@ -1,9 +1,15 @@ -from typing import List, Dict +import functools +import json +import sys +from typing import List, Dict, Tuple, Callable import requests from zope import component +from .StreamResponseHandler import StreamResponseHandler from ceo_common.interfaces import IHTTPClient, IConfig +from ceo_common.model import Term +from ceod.transactions.members import AddMemberTransaction def http_request(method: str, path: str, **kwargs) -> requests.Response: @@ -11,13 +17,10 @@ def http_request(method: str, path: str, **kwargs) -> requests.Response: cfg = component.getUtility(IConfig) if path.startswith('/api/db'): host = cfg.get('ceod_db_host') - delegate = False else: host = cfg.get('ceod_admin_host') - # The forwarded TGT is only needed for endpoints which write to LDAP - delegate = method != 'GET' return client.request( - host, path, method, delegate=delegate, stream=True, **kwargs) + method, host, path, stream=True, **kwargs) def http_get(path: str, **kwargs) -> requests.Response: @@ -56,3 +59,142 @@ def get_failed_operations(data: List[Dict]) -> List[str]: operation = operation[:operation.index(':')] failed.append(operation) return failed + + +def space_colon_kv(pairs: List[Tuple[str, str]]) -> List[str]: + """ + Pretty-format the lines so that the keys and values + are aligned into columns. + Example: + key1: val1 + key2: val2 + key1000: val3 + val4 + """ + lines = [] + maxlen = max(len(key) for key, val in pairs) + for key, val in pairs: + if key != '': + prefix = key + ': ' + else: + # assume this is a continuation from the previous line + prefix = ' ' + extra_space = ' ' * (maxlen - len(key)) + line = prefix + extra_space + str(val) + lines.append(line) + return lines + + +def get_terms_for_new_user(num_terms: int) -> List[str]: + current_term = Term.current() + terms = [current_term + i for i in range(num_terms)] + return list(map(str, terms)) + + +def user_dict_kv(d: Dict) -> List[Tuple[str]]: + """Pretty-format a serialized User as (key, value) pairs.""" + pairs = [ + ('uid', d['uid']), + ('cn', d['cn']), + ('program', d.get('program', 'Unknown')), + ] + if 'uid_number' in d: + pairs.append(('UID number', d['uid_number'])) + if 'gid_number' in d: + pairs.append(('GID number', d['gid_number'])) + if 'login_shell' in d: + pairs.append(('login shell', d['login_shell'])) + if 'home_directory' in d: + pairs.append(('home directory', d['home_directory'])) + if 'is_club' in d: + pairs.append(('is a club', str(d['is_club']))) + if 'forwarding_addresses' in d: + if len(d['forwarding_addresses']) > 0: + pairs.append(('forwarding addresses', d['forwarding_addresses'][0])) + for address in d['forwarding_addresses'][1:]: + pairs.append(('', address)) + else: + pairs.append(('forwarding addresses', '')) + if 'terms' in d: + pairs.append(('member terms', ','.join(d['terms']))) + if 'non_member_terms' in d: + pairs.append(('non-member terms', ','.join(d['non_member_terms']))) + if 'password' in d: + pairs.append(('password', d['password'])) + return pairs + + +def user_dict_lines(d: Dict) -> List[str]: + """Pretty-format a serialized User.""" + return space_colon_kv(user_dict_kv(d)) + + +def get_adduser_operations(body: Dict): + operations = AddMemberTransaction.operations.copy() + if not body.get('forwarding_addresses'): + # don't bother displaying this because it won't be run + operations.remove('set_forwarding_addresses') + return operations + + +def generic_handle_stream_response( + resp: requests.Response, + operations: List[str], + handler: StreamResponseHandler, +) -> List[Dict]: + """ + Print output to the console while operations are being streamed + from the server over HTTP. + Returns the parsed JSON data streamed from the server. + """ + if resp.status_code != 200: + handler.handle_non_200(resp) + handler.begin() + idx = 0 + data = [] + for line in resp.iter_lines(decode_unicode=True, chunk_size=8): + d = json.loads(line) + data.append(d) + if d['status'] == 'aborted': + handler.handle_aborted(d['error']) + sys.exit(1) + elif d['status'] == 'completed': + while idx < len(operations): + handler.handle_skipped_operation() + idx += 1 + handler.handle_completed() + return data + + operation = d['operation'] + oper_failed = False + err_msg = None + prefix = 'failed_to_' + if operation.startswith(prefix): + operation = operation[len(prefix):] + oper_failed = True + # sometimes the operation looks like + # "failed_to_do_something: error message" + if ':' in operation: + operation, err_msg = operation.split(': ', 1) + + while idx < len(operations) and operations[idx] != operation: + handler.handle_skipped_operation() + idx += 1 + if idx == len(operations): + handler.handle_unrecognized_operation(operation) + continue + if oper_failed: + handler.handle_failed_operation(err_msg) + else: + handler.handle_successful_operation() + idx += 1 + + raise Exception('server response ended abruptly') + + +def defer(f: Callable, *args, **kwargs): + """Defer a function's execution.""" + @functools.wraps(f) + def wrapper(): + return f(*args, **kwargs) + return wrapper diff --git a/ceo_common/interfaces/IHTTPClient.py b/ceo_common/interfaces/IHTTPClient.py index b65eec6..f42564a 100644 --- a/ceo_common/interfaces/IHTTPClient.py +++ b/ceo_common/interfaces/IHTTPClient.py @@ -4,21 +4,20 @@ from zope.interface import Interface class IHTTPClient(Interface): """A helper class for HTTP requests to ceod.""" - def request(host: str, api_path: str, method: str, delegate: bool, **kwargs): + def request(host: str, path: str, method: str, **kwargs): """ Make an HTTP request. - If `delegate` is True, GSSAPI credentials will be forwarded to the - remote. + **kwargs are passed to requests.request(). """ - def get(host: str, api_path: str, delegate: bool = True, **kwargs): + def get(host: str, path: str, **kwargs): """Make a GET request.""" - def post(host: str, api_path: str, delegate: bool = True, **kwargs): + def post(host: str, path: str, **kwargs): """Make a POST request.""" - def patch(host: str, api_path: str, delegate: bool = True, **kwargs): + def patch(host: str, path: str, **kwargs): """Make a PATCH request.""" - def delete(host: str, api_path: str, delegate: bool = True, **kwargs): + def delete(host: str, path: str, **kwargs): """Make a DELETE request.""" diff --git a/ceo_common/model/HTTPClient.py b/ceo_common/model/HTTPClient.py index d63b85a..67fbc11 100644 --- a/ceo_common/model/HTTPClient.py +++ b/ceo_common/model/HTTPClient.py @@ -1,4 +1,5 @@ import flask +from flask import g import gssapi import requests from requests_gssapi import HTTPSPNEGOAuth @@ -20,40 +21,51 @@ class HTTPClient: self.ceod_port = cfg.get('ceod_port') self.base_domain = cfg.get('base_domain') - def request(self, host: str, api_path: str, method: str, delegate: bool, **kwargs): + def request(self, method: str, host: str, path: str, **kwargs): # always use the FQDN if '.' not in host: host = host + '.' + self.base_domain + if method == 'GET': + # This is the only GET endpoint which requires auth + need_auth = path.startswith('/api/members') + delegate = False + else: + need_auth = True + delegate = True + # SPNEGO - spnego_kwargs = { - 'opportunistic_auth': True, - 'target_name': gssapi.Name('ceod/' + host), - } - if flask.has_request_context() and 'client_token' in flask.g: - # This is reached when we are the server and the client has forwarded - # their credentials to us. - spnego_kwargs['creds'] = gssapi.Credentials(token=flask.g.client_token) - if delegate: - # This is reached when we are the client and we want to forward our - # credentials to the server. - spnego_kwargs['delegate'] = True - auth = HTTPSPNEGOAuth(**spnego_kwargs) + if need_auth: + spnego_kwargs = { + 'opportunistic_auth': True, + 'target_name': gssapi.Name('ceod/' + host), + } + if flask.has_request_context() and 'client_token' in g: + # This is reached when we are the server and the client has + # forwarded their credentials to us. + spnego_kwargs['creds'] = gssapi.Credentials(token=flask.g.client_token) + elif delegate: + # This is reached when we are the client and we want to + # forward our credentials to the server. + spnego_kwargs['delegate'] = True + auth = HTTPSPNEGOAuth(**spnego_kwargs) + else: + auth = None return requests.request( method, - f'{self.scheme}://{host}:{self.ceod_port}{api_path}', + f'{self.scheme}://{host}:{self.ceod_port}{path}', auth=auth, **kwargs, ) - def get(self, host: str, api_path: str, delegate: bool = True, **kwargs): - return self.request(host, api_path, 'GET', delegate, **kwargs) + def get(self, host: str, path: str, **kwargs): + return self.request('GET', host, path, **kwargs) - def post(self, host: str, api_path: str, delegate: bool = True, **kwargs): - return self.request(host, api_path, 'POST', delegate, **kwargs) + def post(self, host: str, path: str, **kwargs): + return self.request('POST', host, path, **kwargs) - def patch(self, host: str, api_path: str, delegate: bool = True, **kwargs): - return self.request(host, api_path, 'PATCH', delegate, **kwargs) + def patch(self, host: str, path: str, **kwargs): + return self.request('PATCH', host, path, **kwargs) - def delete(self, host: str, api_path: str, delegate: bool = True, **kwargs): - return self.request(host, api_path, 'DELETE', delegate, **kwargs) + def delete(self, host: str, path: str, **kwargs): + return self.request('DELETE', host, path, **kwargs) diff --git a/ceo_common/model/RemoteMailmanService.py b/ceo_common/model/RemoteMailmanService.py index f6f23de..c1de3a2 100644 --- a/ceo_common/model/RemoteMailmanService.py +++ b/ceo_common/model/RemoteMailmanService.py @@ -15,8 +15,7 @@ class RemoteMailmanService: def subscribe(self, address: str, mailing_list: str): resp = self.http_client.post( - self.mailman_host, f'/api/mailman/{mailing_list}/{address}', - delegate=False) + self.mailman_host, f'/api/mailman/{mailing_list}/{address}') if not resp.ok: if resp.status_code == 409: raise UserAlreadySubscribedError() @@ -26,8 +25,7 @@ class RemoteMailmanService: def unsubscribe(self, address: str, mailing_list: str): resp = self.http_client.delete( - self.mailman_host, f'/api/mailman/{mailing_list}/{address}', - delegate=False) + self.mailman_host, f'/api/mailman/{mailing_list}/{address}') if not resp.ok: if resp.status_code == 404: raise UserNotSubscribedError() diff --git a/gen_cred.py b/gen_cred.py deleted file mode 100755 index f4de443..0000000 --- a/gen_cred.py +++ /dev/null @@ -1,14 +0,0 @@ -#!/usr/bin/env python3 - -from base64 import b64encode -import sys - -from ceo_common.krb5.utils import get_fwd_tgt - -if len(sys.argv) != 2: - print(f'Usage: {sys.argv[0]} ', file=sys.stderr) - sys.exit(1) - -b = get_fwd_tgt('ceod/' + sys.argv[1]) -with open('cred', 'wb') as f: - f.write(b64encode(b)) diff --git a/requirements.txt b/requirements.txt index d416050..537266e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +asciimatics==1.13.0 click==8.0.1 Flask==2.0.1 gssapi==1.6.14 diff --git a/tests/ceo/cli/test_groups.py b/tests/ceo/cli/test_groups.py index c81802e..8db4f6e 100644 --- a/tests/ceo/cli/test_groups.py +++ b/tests/ceo/cli/test_groups.py @@ -45,6 +45,7 @@ def test_groups(cli_setup, ldap_user): f"Are you sure you want to add {ldap_user.uid} to test_group_1? [y/N]: y\n" "Add user to group... Done\n" "Add user to auxiliary groups... Skipped\n" + "Subscribe user to auxiliary mailing lists... Skipped\n" "Transaction successfully completed.\n" "Added to groups: test_group_1\n" ) @@ -65,6 +66,7 @@ def test_groups(cli_setup, ldap_user): f"Are you sure you want to remove {ldap_user.uid} from test_group_1? [y/N]: y\n" "Remove user from group... Done\n" "Remove user from auxiliary groups... Skipped\n" + "Unsubscribe user from auxiliary mailing lists... Skipped\n" "Transaction successfully completed.\n" "Removed from groups: test_group_1\n" ) diff --git a/tests/ceo/cli/test_members.py b/tests/ceo/cli/test_members.py index a4fce2a..2670e69 100644 --- a/tests/ceo/cli/test_members.py +++ b/tests/ceo/cli/test_members.py @@ -12,15 +12,16 @@ def test_members_get(cli_setup, ldap_user): runner = CliRunner() result = runner.invoke(cli, ['members', 'get', ldap_user.uid]) expected = ( - f"uid: {ldap_user.uid}\n" - f"cn: {ldap_user.cn}\n" - f"program: {ldap_user.program}\n" - f"UID number: {ldap_user.uid_number}\n" - f"GID number: {ldap_user.gid_number}\n" - f"login shell: {ldap_user.login_shell}\n" - f"home directory: {ldap_user.home_directory}\n" - f"is a club: {ldap_user.is_club()}\n" - f"terms: {','.join(ldap_user.terms)}\n" + f"uid: {ldap_user.uid}\n" + f"cn: {ldap_user.cn}\n" + f"program: {ldap_user.program}\n" + f"UID number: {ldap_user.uid_number}\n" + f"GID number: {ldap_user.gid_number}\n" + f"login shell: {ldap_user.login_shell}\n" + f"home directory: {ldap_user.home_directory}\n" + f"is a club: {ldap_user.is_club()}\n" + "forwarding addresses: \n" + f"member terms: {','.join(ldap_user.terms)}\n" ) assert result.exit_code == 0 assert result.output == expected @@ -34,11 +35,11 @@ def test_members_add(cli_setup): ], input='y\n') expected_pat = re.compile(( "^The following user will be created:\n" - "uid: test_1\n" - "cn: Test One\n" - "program: Math\n" - "member terms: [sfw]\\d{4}\n" - "forwarding address: test_1@uwaterloo.internal\n" + "uid: test_1\n" + "cn: Test One\n" + "program: Math\n" + "forwarding addresses: test_1@uwaterloo.internal\n" + "member terms: [sfw]\\d{4}\n" "Do you want to continue\\? \\[y/N\\]: y\n" "Add user to LDAP... Done\n" "Add group to LDAP... Done\n" @@ -58,7 +59,7 @@ def test_members_add(cli_setup): "home directory: [a-z0-9/_-]+/test_1\n" "is a club: False\n" "forwarding addresses: test_1@uwaterloo.internal\n" - "terms: [sfw]\\d{4}\n" + "member terms: [sfw]\\d{4}\n" "password: \\S+\n$" ), re.MULTILINE) assert result.exit_code == 0 diff --git a/tests/conftest.py b/tests/conftest.py index e6c51d1..d659e93 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -369,16 +369,13 @@ def app_process(cfg, app, http_client): proc.start() try: - # Currently the HTTPClient uses SPNEGO for all requests, - # even GETs - with gssapi_token_ctx('ctdalek'): - for i in range(5): - try: - http_client.get(hostname, '/ping', delegate=False) - except requests.exceptions.ConnectionError: - time.sleep(1) - continue - break + for i in range(5): + try: + http_client.get(hostname, '/ping') + except requests.exceptions.ConnectionError: + time.sleep(1) + continue + break assert i != 5, 'Timed out' yield finally: diff --git a/tests/conftest_ceod_api.py b/tests/conftest_ceod_api.py index c2941ab..f60c37b 100644 --- a/tests/conftest_ceod_api.py +++ b/tests/conftest_ceod_api.py @@ -46,12 +46,14 @@ class CeodTestClient: headers = list(req.prepare().headers.items()) return headers - def request(self, method: str, path: str, principal: str, delegate: bool, **kwargs): + def request(self, method, path, principal, need_auth, delegate, **kwargs): # make sure that we're not already in a Flask context assert not flask.has_app_context() - if principal is None: - principal = self.syscom_principal - headers = self.get_headers(principal, delegate) + if need_auth: + principal = principal or self.syscom_principal + headers = self.get_headers(principal, delegate) + else: + headers = [] resp = self.client.open(path, method=method, headers=headers, **kwargs) status = int(resp.status.split(' ', 1)[0]) if resp.headers['content-type'] == 'application/json': @@ -60,14 +62,14 @@ class CeodTestClient: data = [json.loads(line) for line in resp.data.splitlines()] return status, data - def get(self, path, principal=None, delegate=True, **kwargs): - return self.request('GET', path, principal, delegate, **kwargs) + def get(self, path, principal=None, need_auth=True, delegate=True, **kwargs): + return self.request('GET', path, principal, need_auth, delegate, **kwargs) - def post(self, path, principal=None, delegate=True, **kwargs): - return self.request('POST', path, principal, delegate, **kwargs) + def post(self, path, principal=None, need_auth=True, delegate=True, **kwargs): + return self.request('POST', path, principal, need_auth, delegate, **kwargs) - def patch(self, path, principal=None, delegate=True, **kwargs): - return self.request('PATCH', path, principal, delegate, **kwargs) + def patch(self, path, principal=None, need_auth=True, delegate=True, **kwargs): + return self.request('PATCH', path, principal, need_auth, delegate, **kwargs) - def delete(self, path, principal=None, delegate=True, **kwargs): - return self.request('DELETE', path, principal, delegate, **kwargs) + def delete(self, path, principal=None, need_auth=True, delegate=True, **kwargs): + return self.request('DELETE', path, principal, need_auth, delegate, **kwargs) From c6c01d8720f5f2fb2fd29ef818caf2f8a2d88ea9 Mon Sep 17 00:00:00 2001 From: Andrew Wang Date: Sat, 4 Sep 2021 22:25:37 -0400 Subject: [PATCH 2/2] allow mysql connections from unix socket (#14) Co-authored-by: Andrew Wang Co-authored-by: Max Erenberg Reviewed-on: https://git.csclub.uwaterloo.ca/public/pyceo/pulls/14 Co-authored-by: Andrew Wang Co-committed-by: Andrew Wang --- ceo/cli/utils.py | 3 --- ceo/tui/TransactionView.py | 2 +- ceod/db/MySQLService.py | 16 +++++++++++----- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/ceo/cli/utils.py b/ceo/cli/utils.py index 76389c9..d50f38e 100644 --- a/ceo/cli/utils.py +++ b/ceo/cli/utils.py @@ -1,12 +1,9 @@ -import json import socket -import sys from typing import List, Tuple, Dict import click import requests -from ..operation_strings import descriptions as op_desc from ..utils import space_colon_kv, generic_handle_stream_response from .CLIStreamResponseHandler import CLIStreamResponseHandler diff --git a/ceo/tui/TransactionView.py b/ceo/tui/TransactionView.py index 1773cfd..7adc26b 100644 --- a/ceo/tui/TransactionView.py +++ b/ceo/tui/TransactionView.py @@ -45,7 +45,7 @@ class TransactionView(Frame): if self._loaded: return self._loaded = True - + for _ in range(2): self._add_line() for operation in self._model.operations: diff --git a/ceod/db/MySQLService.py b/ceod/db/MySQLService.py index 043a906..e6a194d 100644 --- a/ceod/db/MySQLService.py +++ b/ceod/db/MySQLService.py @@ -46,17 +46,21 @@ class MySQLService: password = gen_password() search_for_user = f"SELECT user FROM mysql.user WHERE user='{username}'" search_for_db = f"SHOW DATABASES LIKE '{username}'" - create_user = f""" - CREATE USER '{username}'@'%' IDENTIFIED BY %(password)s; - """ + # CREATE USER can't be used in a query with multiple statements + create_user_commands = [ + f"CREATE USER '{username}'@'localhost' IDENTIFIED BY %(password)s", + f"CREATE USER '{username}'@'%' IDENTIFIED BY %(password)s", + ] create_database = f""" CREATE DATABASE {username}; + GRANT ALL PRIVILEGES ON {username}.* TO '{username}'@'localhost'; GRANT ALL PRIVILEGES ON {username}.* TO '{username}'@'%'; """ with self.mysql_connection() as con, con.cursor() as cursor: if response_is_empty(search_for_user, con): - cursor.execute(create_user, {'password': password}) + for cmd in create_user_commands: + cursor.execute(cmd, {'password': password}) if response_is_empty(search_for_db, con): cursor.execute(create_database) else: @@ -67,7 +71,8 @@ class MySQLService: password = gen_password() search_for_user = f"SELECT user FROM mysql.user WHERE user='{username}'" reset_password = f""" - ALTER USER '{username}'@'%' IDENTIFIED BY %(password)s + ALTER USER '{username}'@'localhost' IDENTIFIED BY %(password)s; + ALTER USER '{username}'@'%' IDENTIFIED BY %(password)s; """ with self.mysql_connection() as con, con.cursor() as cursor: @@ -80,6 +85,7 @@ class MySQLService: def delete_db(self, username: str): drop_db = f"DROP DATABASE IF EXISTS {username}" drop_user = f""" + DROP USER IF EXISTS '{username}'@'localhost'; DROP USER IF EXISTS '{username}'@'%'; """