241 lines
9.4 KiB
Python
241 lines
9.4 KiB
Python
import glob
|
|
import ipaddress
|
|
import os
|
|
import re
|
|
import shutil
|
|
import subprocess
|
|
from typing import List, Dict, Tuple, Union
|
|
|
|
import jinja2
|
|
from zope import component
|
|
from zope.interface import implementer
|
|
|
|
from .utils import rate_limit
|
|
from ceo_common.errors import InvalidDomainError, InvalidIPError
|
|
from ceo_common.logger_factory import logger_factory
|
|
from ceo_common.interfaces import IVHostManager, IConfig
|
|
|
|
# A syntactically valid domain: one or more dot-terminated labels made of
# lowercase letters, digits and hyphens, followed by an alphabetic TLD.
VALID_DOMAIN_RE = re.compile(r'^(?:[0-9a-z-]+\.)+[a-z]+$')

# An address made of digits and dots, optionally followed by ":<port>" where
# the port is 2 to 5 digits long. Only the address is captured by name.
IP_WITH_PORT_RE = re.compile(r'^(?P<ip_address>[\d.]+)(:\d{2,5})?$')

# The name of a vhost record on disk: "<username>_<domain>".
VHOST_FILENAME_RE = re.compile(
    r'^(?P<username>[0-9a-z-]+)_(?P<domain>[0-9a-z.-]+)$'
)

# An indented "proxy_pass http://<target>;" line inside a vhost record; the
# captured target may be an address (with or without a port) or a name.
PROXY_PASS_IP_RE = re.compile(r'^\s+proxy_pass\s+http://(?P<ip_address>[\w.:-]+);$')
|
|
# Module-level logger, created by the project's logger factory from this
# module's name.
logger = logger_factory(__name__)
|
|
|
|
|
|
@implementer(IVHostManager)
class VHostManager:
    """Manage per-user NGINX vhost records and the TLS certificates they use.

    Every vhost is a rendered config file stored in ``vhost_dir`` under the
    name ``<username>_<domain>``. Domains directly below one of the two
    configured members domains use the configured wildcard certificates;
    any other accepted domain gets its own certificate through acme.sh.
    """

    def __init__(self):
        """Read the ``cloud vhosts_*`` settings from the IConfig utility.

        Creates the vhost and SSL directories if they are missing and
        replaces ``create_vhost`` on this instance with a rate-limited
        wrapper.
        """
        cfg = component.getUtility(IConfig)

        self.vhost_dir = cfg.get('cloud vhosts_vhost_dir')
        self.ssl_dir = cfg.get('cloud vhosts_ssl_dir')
        # exist_ok avoids the race between checking for the directory and
        # creating it.
        os.makedirs(self.vhost_dir, exist_ok=True)
        os.makedirs(self.ssl_dir, exist_ok=True)

        self.default_ssl_cert = cfg.get('cloud vhosts_default_ssl_cert')
        self.default_ssl_key = cfg.get('cloud vhosts_default_ssl_key')
        self.vhost_domain = cfg.get('cloud vhosts_members_domain')
        # Matches exactly one label directly below the members domain, i.e.
        # the names which the default (wildcard) certificate is used for.
        # re.escape() neutralises every regex metacharacter, not only dots.
        self.vhost_domain_re = re.compile(
            r'^[a-z0-9-]+\.' + re.escape(self.vhost_domain) + '$'
        )

        self.k8s_vhost_domain = cfg.get('cloud vhosts_k8s_members_domain')
        self.k8s_vhost_domain_re = re.compile(
            r'^[a-z0-9-]+\.' + re.escape(self.k8s_vhost_domain) + '$'
        )
        self.k8s_ssl_cert = cfg.get('cloud vhosts_k8s_ssl_cert')
        self.k8s_ssl_key = cfg.get('cloud vhosts_k8s_ssl_key')

        self.max_vhosts_per_account = cfg.get('cloud vhosts_max_vhosts_per_account')
        self.vhost_ip_min = ipaddress.ip_address(cfg.get('cloud vhosts_ip_range_min'))
        self.vhost_ip_max = ipaddress.ip_address(cfg.get('cloud vhosts_ip_range_max'))
        self.reload_web_server_cmd = cfg.get('cloud vhosts_reload_web_server_cmd')

        self.acme_challenge_dir = cfg.get('cloud vhosts_acme_challenge_dir')
        self.acme_dir = '/root/.acme.sh'
        self.acme_sh = os.path.join(self.acme_dir, 'acme.sh')

        self.jinja_env = jinja2.Environment(
            loader=jinja2.PackageLoader('ceod.model'),
            keep_trailing_newline=True,
        )

        # Wrap the bound method so that calls made through this instance go
        # through the rate limiter.
        # NOTE(review): the exact semantics of rate_limit() live in .utils;
        # presumably it limits calls to one per `rate_limit_secs` - confirm.
        rate_limit_secs = cfg.get('cloud vhosts_rate_limit_seconds')
        self.create_vhost = \
            rate_limit('create_vhost', rate_limit_secs)(self.create_vhost)

    @staticmethod
    def _vhost_filename(username: str, domain: str) -> str:
        """Generate a filename for the vhost record"""
        # Sanity check: the domain becomes part of a file path, so it must
        # not be able to climb out of the vhost directory.
        assert '..' not in domain and '/' not in domain
        return username + '_' + domain

    def _vhost_filepath(self, username: str, domain: str) -> str:
        """Generate an absolute path for the vhost record"""
        return os.path.join(self.vhost_dir, self._vhost_filename(username, domain))

    def _vhost_files(self, username: str) -> List[str]:
        """Return a list of all vhost files for this user."""
        # Escape the literal parts so that glob metacharacters in the
        # directory or the username cannot widen the match.
        pattern = os.path.join(
            glob.escape(self.vhost_dir), glob.escape(username) + '_*'
        )
        return glob.glob(pattern)

    def _run(self, args: Union[List[str], str], **kwargs):
        """Run a command, raising CalledProcessError on a non-zero exit."""
        subprocess.run(args, check=True, **kwargs)

    def _reload_web_server(self):
        """Run the configured web server reload command through the shell."""
        logger.debug('Reloading NGINX')
        # The command is a single string taken from the configuration,
        # hence shell=True.
        self._run(self.reload_web_server_cmd, shell=True)

    def is_valid_domain(self, username: str, domain: str) -> bool:
        """Return True if `username` is allowed to use `domain`.

        The domain must be well formed, at most 80 characters long and
        lie below one of the two members domains. The label immediately
        in front of the members domain must be the username itself or end
        with ``-<username>``.
        """
        if VALID_DOMAIN_RE.match(domain) is None:
            return False
        if len(domain) > 80:
            return False

        # The k8s members domain is tried first, as it may itself be a
        # subdomain of the regular members domain.
        for base_domain in (self.k8s_vhost_domain, self.vhost_domain):
            suffix = '.' + base_domain
            if domain.endswith(suffix):
                prefix = domain[:-len(suffix)]
                break
        else:
            return False

        last_part = prefix.split('.')[-1]
        return last_part == username or last_part.endswith('-' + username)

    def is_valid_ip_address(self, ip_address: str) -> bool:
        """Return True if `ip_address` is an acceptable proxy target.

        Accepted values are the literal ``k8s`` or an address inside the
        configured range, optionally followed by ``:<port>`` with a port
        between 1 and 65535.
        """
        if ip_address == 'k8s':
            # special case - this is an NGINX upstream
            return True

        # fullmatch() rejects a trailing newline, which `$` would let
        # through and which would then end up in the rendered config.
        match = IP_WITH_PORT_RE.fullmatch(ip_address)
        if match is None:
            return False

        # Whatever follows the address is either empty or ":<digits>".
        port_text = ip_address[match.end('ip_address'):]
        if port_text and not 1 <= int(port_text[1:]) <= 65535:
            return False

        # make sure the IP is in the allowed range
        try:
            addr = ipaddress.ip_address(match.group('ip_address'))
        except ValueError:
            return False
        return self.vhost_ip_min <= addr <= self.vhost_ip_max

    def _get_cert_and_key_path(self, domain: str) -> Tuple[str, str]:
        """Return the (certificate, key) paths to use for `domain`."""
        # Use the wildcard certs, if possible
        if self.vhost_domain_re.match(domain) is not None:
            return self.default_ssl_cert, self.default_ssl_key
        if self.k8s_vhost_domain_re.match(domain) is not None:
            return self.k8s_ssl_cert, self.k8s_ssl_key
        # Otherwise, the domain has (or will get) its own acme.sh cert
        cert_path = f'{self.ssl_dir}/{domain}.chain'
        key_path = f'{self.ssl_dir}/{domain}.key'
        return cert_path, key_path

    def _acquire_new_cert(self, domain: str, cert_path: str, key_path: str):
        """Issue a certificate for `domain` with acme.sh and install it.

        The full chain is installed to `cert_path` and the key to
        `key_path`; the web server reload command is passed to acme.sh as
        its reload command.
        """
        logger.info(f'issuing new certificate for {domain}')
        self._run([
            self.acme_sh, '--issue', '-d', domain,
            '-w', self.acme_challenge_dir,
        ])
        logger.info(f'installing new certificate for {domain}')
        self._run([
            self.acme_sh, '--install-cert', '-d', domain,
            '--key-file', key_path,
            '--fullchain-file', cert_path,
            '--reloadcmd', self.reload_web_server_cmd,
        ])

    def _delete_cert(self, domain: str, cert_path: str, key_path: str):
        """Remove the certificate of `domain` from acme.sh and from disk."""
        logger.info(f'removing certificate for {domain}')
        self._run([self.acme_sh, '--remove', '-d', domain])
        # Also delete the per-domain directory below the acme.sh home, if
        # one is still there after the removal.
        acme_domain_dir = os.path.join(self.acme_dir, domain)
        if os.path.exists(acme_domain_dir):
            shutil.rmtree(acme_domain_dir)
        os.unlink(cert_path)
        os.unlink(key_path)

    def create_vhost(self, username: str, domain: str, ip_address: str):
        """Create a vhost for `domain` pointing at `ip_address`.

        Obtains a certificate first if the domain needs its own and none
        exists yet, then writes the rendered config and reloads the web
        server.

        Raises:
            Exception: if the user already has the maximum number of vhosts.
            InvalidDomainError: if the user may not use this domain.
            InvalidIPError: if the target address is not acceptable.
        """
        if self.get_num_vhosts(username) >= self.max_vhosts_per_account:
            raise Exception(f'Only {self.max_vhosts_per_account} vhosts '
                            'allowed per account')
        if not self.is_valid_domain(username, domain):
            raise InvalidDomainError()
        if not self.is_valid_ip_address(ip_address):
            raise InvalidIPError()

        cert_path, key_path = self._get_cert_and_key_path(domain)
        if not (os.path.exists(cert_path) and os.path.exists(key_path)):
            self._acquire_new_cert(domain, cert_path, key_path)

        template = self.jinja_env.get_template('nginx_cloud_vhost_config.j2')
        body = template.render(
            username=username, domain=domain, ip_address=ip_address,
            ssl_cert_path=cert_path, ssl_key_path=key_path)
        filepath = self._vhost_filepath(username, domain)
        logger.info(f'Writing a new vhost ({domain} -> {ip_address}) to {filepath}')
        with open(filepath, 'w') as fo:
            fo.write(body)
        self._reload_web_server()

    def delete_vhost(self, username: str, domain: str):
        """Delete the vhost for `domain`, along with its own certificate.

        The shared wildcard certificates are never deleted.

        Raises:
            InvalidDomainError: if the user may not use this domain.
        """
        if not self.is_valid_domain(username, domain):
            raise InvalidDomainError()

        cert_path, key_path = self._get_cert_and_key_path(domain)
        # Only remove certificates which live in the SSL directory and are
        # not one of the shared wildcard certificates.
        is_shared_cert = cert_path in (self.default_ssl_cert, self.k8s_ssl_cert)
        if not is_shared_cert \
                and cert_path.startswith(self.ssl_dir) \
                and os.path.exists(cert_path) and os.path.exists(key_path):
            self._delete_cert(domain, cert_path, key_path)

        filepath = self._vhost_filepath(username, domain)
        logger.info(f'Deleting {filepath}')
        os.unlink(filepath)
        self._reload_web_server()

    def get_num_vhosts(self, username: str) -> int:
        """Return the number of vhost records which belong to `username`."""
        return len(self._vhost_files(username))

    def get_vhosts(self, username: str) -> List[Dict]:
        """Return the vhosts of `username`.

        Each entry is a dict with the keys ``domain`` (taken from the
        record's filename) and ``ip_address`` (taken from the first
        ``proxy_pass`` line in the record).
        """
        vhosts = []
        for filepath in self._vhost_files(username):
            filename = os.path.basename(filepath)
            name_match = VHOST_FILENAME_RE.match(filename)
            assert name_match is not None, f"'{filename}' does not match expected pattern"

            ip_address = None
            # The context manager closes the file even when the loop stops
            # early at the first proxy_pass line.
            with open(filepath) as fi:
                for line in fi:
                    line_match = PROXY_PASS_IP_RE.match(line)
                    if line_match is not None:
                        ip_address = line_match.group('ip_address')
                        break
            assert ip_address is not None, f"Could not find IP address in {filename}"

            vhosts.append({
                'domain': name_match.group('domain'),
                'ip_address': ip_address,
            })
        return vhosts

    def delete_all_vhosts_for_user(self, username: str):
        """Delete every vhost record of `username` and reload the web server.

        Does nothing, and does not reload, if the user has no records.
        """
        filepaths = self._vhost_files(username)
        if not filepaths:
            return
        for filepath in filepaths:
            logger.info(f'Deleting {filepath}')
            os.unlink(filepath)
        self._reload_web_server()

    def get_accounts(self) -> List[str]:
        """Return the distinct usernames which own at least one vhost record.

        The username is whatever precedes the first underscore of each
        filename in the vhost directory; the order is unspecified.
        """
        usernames = {
            filename.split('_', 1)[0]
            for filename in os.listdir(self.vhost_dir)
            if '_' in filename
        }
        return list(usernames)