Compare commits

...

2 Commits

Author SHA1 Message Date
Sean Zhang 0985ed5d3a postgres/mysql service updates
updated to use sql safe format strings
2023-05-22 18:35:12 -04:00
Sean Zhang 0c125de1d7 passing parameters to sql queries without fstrings 2022-11-20 17:19:55 -05:00
3 changed files with 43 additions and 16 deletions

View File

@ -63,7 +63,7 @@ class MySQLService:
cursor.execute(create_local_user_cmd)
cursor.execute(create_user_cmd, {'password': password})
if response_is_empty(search_for_db, con):
cursor.execute(create_database)
cursor.execute(create_database, multi=True)
return password
def reset_db_passwd(self, username: str) -> str:
@ -81,11 +81,14 @@ class MySQLService:
def delete_db(self, username: str):
drop_db = f"DROP DATABASE IF EXISTS {username}"
drop_user = f"""
DROP USER IF EXISTS '{username}'@'localhost';
DROP USER IF EXISTS '{username}'@'%';
"""
drop_user = f"DROP USER IF EXISTS %(username)s@'localhost'"
drop_user_2 = f"DROP USER IF EXISTS %(username)s@'%'"
args_dict = {
'username': username
}
with self.mysql_connection() as con, con.cursor() as cursor:
cursor.execute(drop_db)
cursor.execute(drop_user)
cursor.execute(drop_user, args_dict)
cursor.execute(drop_user_2, args_dict)

View File

@ -9,7 +9,7 @@ from ceo_common.logger_factory import logger_factory
from ceod.utils import gen_password
from ceod.db.utils import response_is_empty
from psycopg2 import connect, OperationalError, ProgrammingError
from psycopg2 import connect, OperationalError, ProgrammingError, sql
from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT
logger = logger_factory(__name__)
@ -53,11 +53,21 @@ class PostgreSQLService:
def create_db(self, username: str) -> str:
password = gen_password()
search_for_user = f"SELECT FROM pg_roles WHERE rolname='{username}'"
search_for_db = f"SELECT FROM pg_database WHERE datname='{username}'"
create_user = f"CREATE USER {username} WITH PASSWORD %(password)s"
create_database = f"CREATE DATABASE {username} OWNER {username}"
revoke_perms = f"REVOKE ALL ON DATABASE {username} FROM PUBLIC"
search_for_user = sql.SQL(
"SELECT FROM pg_roles WHERE rolname={username}"
).format(username=sql.Literal(username))
search_for_db = sql.SQL(
"SELECT FROM pg_database WHERE datname={username}"
).format(username=sql.Literal(username))
create_user = sql.SQL(
"CREATE USER {username} WITH PASSWORD %(password)s"
).format(username=sql.Identifier(username))
create_database = sql.SQL(
"CREATE DATABASE {username} OWNER {username}"
).format(username=sql.Identifier(username))
revoke_perms = sql.SQL(
"REVOKE ALL ON DATABASE {username} FROM PUBLIC"
).format(username=sql.Identifier(username))
with self.psql_connection() as con, con.cursor() as cursor:
if not response_is_empty(search_for_user, con):
@ -70,8 +80,12 @@ class PostgreSQLService:
def reset_db_passwd(self, username: str) -> str:
password = gen_password()
search_for_user = f"SELECT FROM pg_roles WHERE rolname='{username}'"
reset_password = f"ALTER USER {username} WITH PASSWORD %(password)s"
search_for_user = sql.SQL(
"SELECT FROM pg_roles WHERE rolname={username}"
).format(username=sql.Literal(username))
reset_password = sql.SQL(
"ALTER USER {username} WITH PASSWORD %(password)s"
).format(username=sql.Identifier(username))
with self.psql_connection() as con, con.cursor() as cursor:
if response_is_empty(search_for_user, con):
@ -80,8 +94,12 @@ class PostgreSQLService:
return password
def delete_db(self, username: str):
drop_db = f"DROP DATABASE IF EXISTS {username}"
drop_user = f"DROP USER IF EXISTS {username}"
drop_db = sql.SQL(
"DROP DATABASE IF EXISTS {username}"
).format(username=sql.Identifier(username))
drop_user = sql.SQL(
"DROP USER IF EXISTS {username}"
).format(username=sql.Identifier(username))
with self.psql_connection() as con, con.cursor() as cursor:
cursor.execute(drop_db)

View File

@ -3,3 +3,9 @@ def response_is_empty(query: str, connection) -> bool:
cursor.execute(query)
response = cursor.fetchall()
return len(response) == 0
def mysql_response_is_empty(query: str, args_dict, connection) -> bool:
with connection.cursor() as cursor:
cursor.execute(query, args_dict)
response = cursor.fetchall()
return len(response) == 0