import glob
import ipaddress
import os
import re
import shutil
import subprocess
from typing import List, Dict, Tuple

import jinja2
from zope import component
from zope.interface import implementer

from .utils import rate_limit
from ceo_common.errors import InvalidDomainError, InvalidIPError
from ceo_common.logger_factory import logger_factory
from ceo_common.interfaces import IVHostManager, IConfig

# Matches the `proxy_pass` line of a vhost file and captures the proxied
# target (an IP, optionally with a port, or an upstream name) as 'ip_address'.
PROXY_PASS_IP_RE = re.compile(
    r'^\s+proxy_pass\s+http://(?P<ip_address>[\w.:-]+);$'
)
# Vhost record filenames have the form "<username>_<domain>"
# (see VHostManager._vhost_filename).
VHOST_FILENAME_RE = re.compile(r'^(?P<username>[0-9a-z-]+)_(?P<domain>[0-9a-z.-]+)$')
VALID_DOMAIN_RE = re.compile(r'^(?:[0-9a-z-]+\.)+[a-z]+$')
# Dotted-decimal address, optionally followed by ":<port>" (2 to 5 digits).
IP_WITH_PORT_RE = re.compile(r'^(?P<ip_address>[\d.]+)(:\d{2,5})?$')

logger = logger_factory(__name__)


@implementer(IVHostManager)
class VHostManager:
    """Manage per-user NGINX vhost records and their TLS certificates.

    Each vhost is stored as one file named "<username>_<domain>" inside the
    configured vhost directory. Certificates for domains not covered by the
    configured wildcard certificates are obtained with acme.sh.
    """

    def __init__(self):
        """Load settings from the registered IConfig utility.

        Creates the vhost and SSL directories if they do not exist yet, and
        wraps `create_vhost` with the `rate_limit` decorator.
        """
        cfg = component.getUtility(IConfig)
        self.vhost_dir = cfg.get('cloud vhosts_vhost_dir')
        self.ssl_dir = cfg.get('cloud vhosts_ssl_dir')
        if not os.path.exists(self.vhost_dir):
            os.makedirs(self.vhost_dir)
        if not os.path.exists(self.ssl_dir):
            os.makedirs(self.ssl_dir)
        self.default_ssl_cert = cfg.get('cloud vhosts_default_ssl_cert')
        self.default_ssl_key = cfg.get('cloud vhosts_default_ssl_key')
        self.vhost_domain = cfg.get('cloud vhosts_members_domain')
        # Matches exactly one label directly under the members domain.
        self.vhost_domain_re = re.compile(
            r'^[a-z0-9-]+\.' + self.vhost_domain.replace('.', r'\.') + '$'
        )
        self.k8s_vhost_domain = cfg.get('cloud vhosts_k8s_members_domain')
        self.k8s_vhost_domain_re = re.compile(
            r'^[a-z0-9-]+\.' + self.k8s_vhost_domain.replace('.', r'\.') + '$'
        )
        self.k8s_ssl_cert = cfg.get('cloud vhosts_k8s_ssl_cert')
        self.k8s_ssl_key = cfg.get('cloud vhosts_k8s_ssl_key')
        self.max_vhosts_per_account = cfg.get('cloud vhosts_max_vhosts_per_account')
        self.vhost_ip_min = ipaddress.ip_address(cfg.get('cloud vhosts_ip_range_min'))
        self.vhost_ip_max = ipaddress.ip_address(cfg.get('cloud vhosts_ip_range_max'))

        self.acme_challenge_dir = cfg.get('cloud vhosts_acme_challenge_dir')
        self.acme_dir = '/root/.acme.sh'
        self.acme_sh = os.path.join(self.acme_dir, 'acme.sh')

        self.jinja_env = jinja2.Environment(
            loader=jinja2.PackageLoader('ceod.model'),
            keep_trailing_newline=True,
        )

        rate_limit_secs = cfg.get('cloud vhosts_rate_limit_seconds')
        # Replace the bound method on this instance with a rate-limited one.
        # NOTE(review): exact rate_limit semantics live in .utils - confirm.
        self.create_vhost = \
            rate_limit('create_vhost', rate_limit_secs)(self.create_vhost)

    @staticmethod
    def _vhost_filename(username: str, domain: str) -> str:
        """Generate a filename for the vhost record"""
        # sanity check: the domain must not be able to escape the vhost dir
        assert '..' not in domain and '/' not in domain
        return username + '_' + domain

    def _vhost_filepath(self, username: str, domain: str) -> str:
        """Generate an absolute path for the vhost record"""
        return os.path.join(self.vhost_dir, self._vhost_filename(username, domain))

    def _vhost_files(self, username: str) -> List[str]:
        """Return a list of all vhost files for this user."""
        return glob.glob(os.path.join(self.vhost_dir, username + '_*'))

    def _run(self, args: List[str]):
        """Run a command (no shell); raise CalledProcessError on failure."""
        subprocess.run(args, check=True)

    def _reload_web_server(self):
        """Ask systemd to reload NGINX so config changes take effect."""
        logger.debug('Reloading NGINX')
        self._run(['systemctl', 'reload', 'nginx'])

    def is_valid_domain(self, username: str, domain: str) -> bool:
        """Return True if `username` is allowed to claim `domain`.

        The domain must match VALID_DOMAIN_RE, be at most 80 characters long,
        and be a subdomain of either members domain. The label immediately to
        the left of the members domain must be the username itself or end
        with '-<username>'.
        """
        if VALID_DOMAIN_RE.match(domain) is None:
            return False
        if len(domain) > 80:
            return False
        if domain.endswith('.' + self.k8s_vhost_domain):
            prefix = domain[:len(domain) - len(self.k8s_vhost_domain) - 1]
        elif domain.endswith('.' + self.vhost_domain):
            prefix = domain[:len(domain) - len(self.vhost_domain) - 1]
        else:
            return False
        last_part = prefix.split('.')[-1]
        if last_part == username:
            return True
        if last_part.endswith('-' + username):
            return True
        return False

    def is_valid_ip_address(self, ip_address: str) -> bool:
        """Return True if `ip_address` is an allowed proxy target.

        Accepts the literal 'k8s', or an address (optionally with ':<port>')
        that lies within the configured [vhost_ip_min, vhost_ip_max] range.
        """
        if ip_address == 'k8s':
            # special case - this is an NGINX upstream
            return True
        # strip off the port number, if there is one
        # fullmatch (rather than match) so that a trailing newline, which '$'
        # would tolerate, cannot slip through into the rendered NGINX config.
        match = IP_WITH_PORT_RE.fullmatch(ip_address)
        if match is None:
            return False
        ip_address = match.group('ip_address')
        # make sure the IP is in the allowed range
        try:
            addr = ipaddress.ip_address(ip_address)
        except ValueError:
            return False
        return self.vhost_ip_min <= addr <= self.vhost_ip_max

    def _get_cert_and_key_path(self, domain: str) -> Tuple[str, str]:
        """Return the (certificate, key) file paths to use for `domain`."""
        # Use the wildcard certs, if possible
        if self.vhost_domain_re.match(domain) is not None:
            return self.default_ssl_cert, self.default_ssl_key
        elif self.k8s_vhost_domain_re.match(domain) is not None:
            return self.k8s_ssl_cert, self.k8s_ssl_key
        # Otherwise, obtain a new cert with acme.sh
        cert_path = f'{self.ssl_dir}/{domain}.chain'
        key_path = f'{self.ssl_dir}/{domain}.key'
        return cert_path, key_path

    def _acquire_new_cert(self, domain: str, cert_path: str, key_path: str):
        """Issue a certificate for `domain` with acme.sh and install it."""
        logger.info(f'issuing new certificate for {domain}')
        self._run([
            self.acme_sh, '--issue', '-d', domain,
            '-w', self.acme_challenge_dir,
        ])
        logger.info(f'installing new certificate for {domain}')
        self._run([
            self.acme_sh, '--install-cert', '-d', domain,
            '--key-file', key_path,
            '--fullchain-file', cert_path,
            '--reloadcmd', 'systemctl reload nginx',
        ])

    def _delete_cert(self, domain: str, cert_path: str, key_path: str):
        """Remove `domain` from acme.sh and delete its cert and key files."""
        logger.info(f'removing certificate for {domain}')
        self._run([self.acme_sh, '--remove', '-d', domain])
        # Also drop acme.sh's per-domain directory, if it is still there.
        if os.path.exists(os.path.join(self.acme_dir, domain)):
            shutil.rmtree(os.path.join(self.acme_dir, domain))
        os.unlink(cert_path)
        os.unlink(key_path)

    def create_vhost(self, username: str, domain: str, ip_address: str):
        """Create a vhost record proxying `domain` to `ip_address`.

        Acquires a certificate first if the cert/key files for the domain do
        not exist, writes the rendered NGINX config, then reloads NGINX.

        :raises Exception: the user already has the maximum number of vhosts
        :raises InvalidDomainError: `domain` is not valid for `username`
        :raises InvalidIPError: `ip_address` is not an allowed target
        """
        if self.get_num_vhosts(username) >= self.max_vhosts_per_account:
            raise Exception(f'Only {self.max_vhosts_per_account} vhosts '
                            'allowed per account')
        if not self.is_valid_domain(username, domain):
            raise InvalidDomainError()
        if not self.is_valid_ip_address(ip_address):
            raise InvalidIPError()
        cert_path, key_path = self._get_cert_and_key_path(domain)
        if not (os.path.exists(cert_path) and os.path.exists(key_path)):
            self._acquire_new_cert(domain, cert_path, key_path)
        template = self.jinja_env.get_template('nginx_cloud_vhost_config.j2')
        body = template.render(
            username=username, domain=domain, ip_address=ip_address,
            ssl_cert_path=cert_path, ssl_key_path=key_path)
        filepath = self._vhost_filepath(username, domain)
        logger.info(f'Writing a new vhost ({domain} -> {ip_address}) to {filepath}')
        with open(filepath, 'w') as fo:
            fo.write(body)
        self._reload_web_server()

    def delete_vhost(self, username: str, domain: str):
        """Delete the vhost record for `domain` and reload NGINX.

        A certificate is removed only if it is not one of the shared wildcard
        certificates, lives under the SSL directory, and both of its files
        exist.

        :raises InvalidDomainError: `domain` is not valid for `username`
        """
        if not self.is_valid_domain(username, domain):
            raise InvalidDomainError()
        cert_path, key_path = self._get_cert_and_key_path(domain)
        if cert_path not in [self.default_ssl_cert, self.k8s_ssl_cert] \
                and cert_path.startswith(self.ssl_dir) \
                and os.path.exists(cert_path) and os.path.exists(key_path):
            self._delete_cert(domain, cert_path, key_path)
        filepath = self._vhost_filepath(username, domain)
        logger.info(f'Deleting {filepath}')
        os.unlink(filepath)
        self._reload_web_server()

    def get_num_vhosts(self, username: str) -> int:
        """Return the number of vhost files belonging to `username`."""
        return len(self._vhost_files(username))

    def get_vhosts(self, username: str) -> List[Dict]:
        """Return a {'domain', 'ip_address'} dict for each of the user's vhosts.

        The domain comes from the filename; the IP address is taken from the
        first `proxy_pass` line found in the file.
        """
        vhosts = []
        for filepath in self._vhost_files(username):
            filename = os.path.basename(filepath)
            match = VHOST_FILENAME_RE.match(filename)
            assert match is not None, f"'{filename}' does not match expected pattern"
            domain = match.group('domain')
            ip_address = None
            # Use a context manager so the file handle is closed promptly.
            with open(filepath) as fi:
                for line in fi:
                    # .match is intentional: each line still carries its
                    # trailing newline, which '$' tolerates.
                    line_match = PROXY_PASS_IP_RE.match(line)
                    if line_match is None:
                        continue
                    ip_address = line_match.group('ip_address')
                    break
            assert ip_address is not None, f"Could not find IP address in {filename}"
            vhosts.append({'domain': domain, 'ip_address': ip_address})
        return vhosts

    def delete_all_vhosts_for_user(self, username: str):
        """Delete every vhost file of `username`; reload NGINX if any existed."""
        filepaths = self._vhost_files(username)
        if not filepaths:
            return
        for filepath in filepaths:
            logger.info(f'Deleting {filepath}')
            os.unlink(filepath)
        self._reload_web_server()

    def get_accounts(self) -> List[str]:
        """Return the distinct usernames that own at least one vhost file."""
        vhost_files = os.listdir(self.vhost_dir)
        # The username is everything before the first underscore.
        usernames = list({
            filename.split('_', 1)[0]
            for filename in vhost_files
            if '_' in filename
        })
        return usernames