diff --git a/ceo/cli/database.py b/ceo/cli/database.py index 6a7571b..dc9bede 100644 --- a/ceo/cli/database.py +++ b/ceo/cli/database.py @@ -4,54 +4,22 @@ from typing import Dict import click from zope import component -from ..utils import http_post, http_get, http_delete -from .utils import handle_sync_response, check_file_path, check_if_in_development +from ..utils import http_post, http_get, http_delete, write_db_creds +from .utils import handle_sync_response, check_if_in_development from ceo_common.interfaces import IConfig def db_cli_response(filename: str, user_dict: Dict, password: str, db_type: str, op: str): cfg_srv = component.getUtility(IConfig) db_host = cfg_srv.get(f'{db_type}_host') - username = user_dict['uid'] if db_type == 'mysql': db_type_name = 'MySQL' - db_cli_local_cmd = f'mysql {username}' - db_cli_cmd = f'mysql {username} -h {db_host} -u {username} -p' else: db_type_name = 'PostgreSQL' - db_cli_local_cmd = f'psql {username}' - db_cli_cmd = f'psql -d {username} -h {db_host} -U {username} -W' - username = user_dict['uid'] - info = f"""{db_type_name} Database Information for {username} - -Your new {db_type_name} database was created. To connect, use the following options: - -Database: {username} -Username: {username} -Password: {password} -Host: {db_host} - -On {db_host} to connect using the {db_type_name} command-line client use - - {db_cli_local_cmd} - -From other CSC machines you can connect using - - {db_cli_cmd} -""" - wrote_to_file = False - try: - # TODO: use phosphoric-acid to write to file (phosphoric-acid makes - # internal API call to caffeine) - with click.open_file(filename, "w") as f: - f.write(info) - os.chown(filename, user_dict['uid_number'], user_dict['gid_number']) - os.chmod(filename, 0o640) - wrote_to_file = True - except PermissionError: - pass + wrote_to_file = write_db_creds(filename, user_dict, password, db_type, db_host) if op == 'create': click.echo(f'{db_type_name} database created.') + username = user_dict['uid'] click.echo(f'''Connection Information: Database: {username} @@ -71,7 +39,6 @@ def create(username: str, db_type: str): click.confirm(f'Are you sure you want to create a {db_type_name} database for {username}?', abort=True) info_file_path = os.path.join(user_dict['home_directory'], f"ceo-{db_type}-info") - check_file_path(info_file_path) resp = http_post(f'/api/db/{db_type}/{username}') result = handle_sync_response(resp) @@ -87,7 +54,6 @@ def pwreset(username: str, db_type: str): click.confirm(f'Are you sure you want reset the {db_type_name} password for {username}?', abort=True) info_file_path = os.path.join(user_dict['home_directory'], f"ceo-{db_type}-info") - check_file_path(info_file_path) resp = http_post(f'/api/db/{db_type}/{username}/pwreset') result = handle_sync_response(resp) diff --git a/ceo/cli/utils.py b/ceo/cli/utils.py index 9423a2c..294400d 100644 --- a/ceo/cli/utils.py +++ b/ceo/cli/utils.py @@ -52,15 +52,6 @@ def handle_sync_response(resp: requests.Response): return resp.json() -def check_file_path(file): - if os.path.isfile(file): - click.echo(f"{file} will be overwritten") - click.confirm('Do you want to continue?', abort=True) - elif os.path.isdir(file): - click.echo(f"Error: there exists a directory at {file}") - raise Abort() - - def check_if_in_development() -> bool: """Aborts if we are not currently in the dev environment.""" if not socket.getfqdn().endswith('.csclub.internal'): diff --git a/ceo/tui/Model.py b/ceo/tui/Model.py index e662273..1173dde 100644 --- a/ceo/tui/Model.py +++ b/ceo/tui/Model.py @@ -61,6 +61,12 @@ class Model: 'uid': '', 'unsubscribe': True, }, + 'CreateDatabase': { + 'uid': '', + }, + 'ResetDatabasePassword': { + 'uid': '', + }, } self.viewdata = deepcopy(self._initial_viewdata) # data which is shared between multiple views @@ -69,6 +75,8 @@ class Model: self.operations = None self.deferred_req = None self.resp = None + self.db_type = None + self.user_dict = None def reset(self): self.viewdata = deepcopy(self._initial_viewdata) @@ -77,6 +85,8 @@ class Model: self.operations = None self.deferred_req = None self.resp = None + self.db_type = None + self.user_dict = None self.title = None self.error_message = None self.scene_stack.clear() diff --git a/ceo/tui/ResultView.py b/ceo/tui/ResultView.py index 0a414bb..f2b64a0 100644 --- a/ceo/tui/ResultView.py +++ b/ceo/tui/ResultView.py @@ -44,7 +44,7 @@ class ResultView(CeoFrame): def _resultview_on_load(self): self._add_text() resp = self._model.resp - if resp.status_code != 200: + if not resp.ok: self._add_text('An error occurred:') if resp.headers.get('content-type') == 'application/json': err_msg = resp.json()['error'] diff --git a/ceo/tui/WelcomeView.py b/ceo/tui/WelcomeView.py index 5e4fce1..8c30d63 100644 --- a/ceo/tui/WelcomeView.py +++ b/ceo/tui/WelcomeView.py @@ -28,10 +28,10 @@ class WelcomeView(CeoFrame): ('Remove member from group', 'RemoveMemberFromGroup'), ] db_menu_items = [ - ('Create MySQL database', 'CreateMySQL'), - ('Reset MySQL password', 'ResetMySQLPassword'), - ('Create PostgreSQL database', 'CreatePostgreSQL'), - ('Reset PostgreSQL password', 'ResetPostgreSQLPassword'), + ('Create MySQL database', 'CreateDatabase'), + ('Reset MySQL password', 'ResetDatabasePassword'), + ('Create PostgreSQL database', 'CreateDatabase'), + ('Reset PostgreSQL password', 'ResetDatabasePassword'), ] positions_menu_items = [ ('Get positions', 'GetPositions'), @@ -90,6 +90,11 @@ class WelcomeView(CeoFrame): if name == 'members': if desc.endswith('club rep'): self._model.is_club_rep = True + elif name == 'databases': + if 'MySQL' in desc: + self._model.db_type = 'mysql' + else: + self._model.db_type = 'postgresql' self._welcomeview_go_to_next_scene(desc, view) def _welcomeview_go_to_next_scene(self, desc, view): diff --git a/ceo/tui/databases/CreateDatabaseResultView.py b/ceo/tui/databases/CreateDatabaseResultView.py new file mode 100644 index 0000000..1b8ffda --- /dev/null +++ b/ceo/tui/databases/CreateDatabaseResultView.py @@ -0,0 +1,34 @@ +import os + +import requests +from zope import component + +from ...utils import write_db_creds +from ..ResultView import ResultView +from ceo_common.interfaces import IConfig + + +class CreateDatabaseResultView(ResultView): + def show_result(self, resp: requests.Response): + password = resp.json()['password'] + db_type = self._model.db_type + db_type_name = 'MySQL' if db_type == 'mysql' else 'PostgreSQL' + db_host = component.getUtility(IConfig).get(f'{db_type}_host') + user_dict = self._model.user_dict + username = user_dict['uid'] + filename = os.path.join(user_dict['home_directory'], f"ceo-{db_type}-info") + wrote_to_file = write_db_creds( + filename, user_dict, password, db_type, db_host) + self._add_text(f'{db_type_name} database created.', center=True) + self._add_text() + self._add_text((f'''Connection Information: + +Database: {username} +Username: {username} +Password: {password} +Host: {db_host}''')) + self._add_text() + if wrote_to_file: + self._add_text(f"These settings have been written to {filename}.") + else: + self._add_text(f"We were unable to write these settings to {filename}.") diff --git a/ceo/tui/databases/CreateDatabaseView.py b/ceo/tui/databases/CreateDatabaseView.py new file mode 100644 index 0000000..b8ff17f --- /dev/null +++ b/ceo/tui/databases/CreateDatabaseView.py @@ -0,0 +1,44 @@ +from asciimatics.widgets import Layout, Text + +from ...utils import http_post, http_get, defer +from ..CeoFrame import CeoFrame + + +class CreateDatabaseView(CeoFrame): + def __init__(self, screen, width, height, model): + super().__init__( + screen, height, width, model, 'CreateDatabase', + save_data=True, + ) + layout = Layout([100], fill_frame=True) + self.add_layout(layout) + self._username = Text("Username:", "uid") + layout.add_widget(self._username) + self.add_buttons( + back_btn=True, next_scene='Confirm', + on_next=self._next) + self.fix() + + def _target(self): + username = self._username.value + db_type = self._model.db_type + resp = http_get(f'/api/members/{username}') + if not resp.ok: + return resp + user_dict = resp.json() + self._model.user_dict = user_dict + return http_post(f'/api/db/{db_type}/{username}') + + def _next(self): + username = self._username.value + if not username: + return + if self._model.db_type == 'mysql': + db_type_name = 'MySQL' + else: + db_type_name = 'PostgreSQL' + self._model.confirm_lines = [ + f'Are you sure you want to create a {db_type_name} database for {username}?', + ] + self._model.deferred_req = defer(self._target) + self._model.result_view_name = 'CreateDatabaseResult' diff --git a/ceo/tui/databases/ResetDatabasePasswordResultView.py b/ceo/tui/databases/ResetDatabasePasswordResultView.py new file mode 100644 index 0000000..12b45fb --- /dev/null +++ b/ceo/tui/databases/ResetDatabasePasswordResultView.py @@ -0,0 +1,29 @@ +import os + +import requests +from zope import component + +from ...utils import write_db_creds +from ..ResultView import ResultView +from ceo_common.interfaces import IConfig + + +class ResetDatabasePasswordResultView(ResultView): + def show_result(self, resp: requests.Response): + password = resp.json()['password'] + db_type = self._model.db_type + db_type_name = 'MySQL' if db_type == 'mysql' else 'PostgreSQL' + db_host = component.getUtility(IConfig).get(f'{db_type}_host') + user_dict = self._model.user_dict + username = user_dict['uid'] + filename = os.path.join(user_dict['home_directory'], f"ceo-{db_type}-info") + wrote_to_file = write_db_creds( + filename, user_dict, password, db_type, db_host) + self._add_text(f'The new {db_type_name} password for {username} is:') + self._add_text() + self._add_text(password) + self._add_text() + if wrote_to_file: + self._add_text(f"The settings in {filename} have been updated.") + else: + self._add_text(f"We were unable to update the settings in {filename}.") diff --git a/ceo/tui/databases/ResetDatabasePasswordView.py b/ceo/tui/databases/ResetDatabasePasswordView.py new file mode 100644 index 0000000..8e5075c --- /dev/null +++ b/ceo/tui/databases/ResetDatabasePasswordView.py @@ -0,0 +1,44 @@ +from asciimatics.widgets import Layout, Text + +from ...utils import http_post, http_get, defer +from ..CeoFrame import CeoFrame + + +class ResetDatabasePasswordView(CeoFrame): + def __init__(self, screen, width, height, model): + super().__init__( + screen, height, width, model, 'ResetDatabasePassword', + save_data=True, + ) + layout = Layout([100], fill_frame=True) + self.add_layout(layout) + self._username = Text("Username:", "uid") + layout.add_widget(self._username) + self.add_buttons( + back_btn=True, next_scene='Confirm', + on_next=self._next) + self.fix() + + def _target(self): + username = self._username.value + db_type = self._model.db_type + resp = http_get(f'/api/members/{username}') + if not resp.ok: + return resp + user_dict = resp.json() + self._model.user_dict = user_dict + return http_post(f'/api/db/{db_type}/{username}/pwreset') + + def _next(self): + username = self._username.value + if not username: + return + if self._model.db_type == 'mysql': + db_type_name = 'MySQL' + else: + db_type_name = 'PostgreSQL' + self._model.confirm_lines = [ + f'Are you sure you want to reset the {db_type_name} password for {username}?', + ] + self._model.deferred_req = defer(self._target) + self._model.result_view_name = 'ResetDatabasePasswordResult' diff --git a/ceo/tui/databases/__init__.py b/ceo/tui/databases/__init__.py new file mode 100644 index 0000000..8d1c8b6 --- /dev/null +++ b/ceo/tui/databases/__init__.py @@ -0,0 +1 @@ + diff --git a/ceo/tui/start.py b/ceo/tui/start.py index 2bf1825..df88af2 100644 --- a/ceo/tui/start.py +++ b/ceo/tui/start.py @@ -10,6 +10,10 @@ from .Model import Model from .ResultView import ResultView from .TransactionView import TransactionView from .WelcomeView import WelcomeView +from .databases.CreateDatabaseView import CreateDatabaseView +from .databases.CreateDatabaseResultView import CreateDatabaseResultView +from .databases.ResetDatabasePasswordView import ResetDatabasePasswordView +from .databases.ResetDatabasePasswordResultView import ResetDatabasePasswordResultView from .groups.AddGroupView import AddGroupView from .groups.AddMemberToGroupView import AddMemberToGroupView from .groups.GetGroupView import GetGroupView @@ -56,6 +60,10 @@ def screen_wrapper(screen, last_scene, model): ('GetGroupResult', GetGroupResultView(screen, width, height, model)), ('AddMemberToGroup', AddMemberToGroupView(screen, width, height, model)), ('RemoveMemberFromGroup', RemoveMemberFromGroupView(screen, width, height, model)), + ('CreateDatabase', CreateDatabaseView(screen, width, height, model)), + ('CreateDatabaseResult', CreateDatabaseResultView(screen, width, height, model)), + ('ResetDatabasePassword', ResetDatabasePasswordView(screen, width, height, model)), + ('ResetDatabasePasswordResult', ResetDatabasePasswordResultView(screen, width, height, model)), ] scenes = [ Scene([view], -1, name=name) for name, view in views diff --git a/ceo/utils.py b/ceo/utils.py index 28e4a13..a23ecd5 100644 --- a/ceo/utils.py +++ b/ceo/utils.py @@ -1,5 +1,6 @@ import functools import json +import os from typing import List, Dict, Tuple, Callable import requests @@ -193,3 +194,50 @@ def defer(f: Callable, *args, **kwargs): def wrapper(): return f(*args, **kwargs) return wrapper + + +def write_db_creds( + filename: str, + user_dict: Dict, + password: str, + db_type: str, + db_host: str, +) -> bool: + username = user_dict['uid'] + if db_type == 'mysql': + db_type_name = 'MySQL' + db_cli_local_cmd = f'mysql {username}' + db_cli_cmd = f'mysql {username} -h {db_host} -u {username} -p' + else: + db_type_name = 'PostgreSQL' + db_cli_local_cmd = f'psql {username}' + db_cli_cmd = f'psql -d {username} -h {db_host} -U {username} -W' + info = f"""{db_type_name} Database Information for {username} + +Your new {db_type_name} database was created. To connect, use the following options: + +Database: {username} +Username: {username} +Password: {password} +Host: {db_host} + +On {db_host} to connect using the {db_type_name} command-line client use + + {db_cli_local_cmd} + +From other CSC machines you can connect using + + {db_cli_cmd} +""" + try: + # TODO: use phosphoric-acid to write to file (phosphoric-acid makes + # internal API call to caffeine) + if os.path.isfile(filename): + os.rename(filename, filename + '.bak') + with open(filename, "w") as f: + f.write(info) + os.chown(filename, user_dict['uid_number'], user_dict['gid_number']) + os.chmod(filename, 0o640) + return True + except PermissionError: + return False diff --git a/tests/ceo/cli/test_db_mysql.py b/tests/ceo/cli/test_db_mysql.py index 6418b0c..569c3c2 100644 --- a/tests/ceo/cli/test_db_mysql.py +++ b/tests/ceo/cli/test_db_mysql.py @@ -1,4 +1,5 @@ import os +import shutil from click.testing import CliRunner from mysql.connector import connect @@ -33,6 +34,7 @@ def test_mysql(cli_setup, cfg, ldap_user): # create database for user result = runner.invoke(cli, ['mysql', 'create', username], input='y\n') + print(result.output) assert result.exit_code == 0 assert os.path.isfile(info_file_path) @@ -57,8 +59,7 @@ These settings have been written to {info_file_path}. mysql_attempt_connection(host, username, passwd) # perform password reset for user - # confirm once to reset password, another to overwrite the file - result = runner.invoke(cli, ['mysql', 'pwreset', username], input="y\ny\n") + result = runner.invoke(cli, ['mysql', 'pwreset', username], input="y\n") assert result.exit_code == 0 response_arr = result.output.split() @@ -78,5 +79,4 @@ These settings have been written to {info_file_path}. with pytest.raises(ProgrammingError): mysql_attempt_connection(host, username, passwd) - os.remove(info_file_path) - os.rmdir(ldap_user.home_directory) + shutil.rmtree(ldap_user.home_directory) diff --git a/tests/ceo/cli/test_db_postgresql.py b/tests/ceo/cli/test_db_postgresql.py index ddd2353..1421bdc 100644 --- a/tests/ceo/cli/test_db_postgresql.py +++ b/tests/ceo/cli/test_db_postgresql.py @@ -1,5 +1,6 @@ import pytest import os +import shutil from click.testing import CliRunner from ceo.cli import cli @@ -59,8 +60,7 @@ These settings have been written to {info_file_path}. psql_attempt_connection(host, username, passwd) # perform password reset for user - # confirm once to reset password, another to overwrite the file - result = runner.invoke(cli, ['postgresql', 'pwreset', username], input="y\ny\n") + result = runner.invoke(cli, ['postgresql', 'pwreset', username], input="y\n") assert result.exit_code == 0 response_arr = result.output.split() @@ -80,5 +80,4 @@ These settings have been written to {info_file_path}. with pytest.raises(OperationalError): psql_attempt_connection(host, username, passwd) - os.remove(info_file_path) - os.rmdir(ldap_user.home_directory) + shutil.rmtree(ldap_user.home_directory)