pyceo/ceod/model/VHostManager.py

240 lines
9.3 KiB
Python

import glob
import ipaddress
import os
import re
import shutil
import subprocess
from typing import List, Dict, Tuple
import jinja2
from zope import component
from zope.interface import implementer
from .utils import rate_limit
from ceo_common.errors import InvalidDomainError, InvalidIPError
from ceo_common.logger_factory import logger_factory
from ceo_common.interfaces import IVHostManager, IConfig
# Matches the ``proxy_pass`` line of a rendered vhost file and captures the
# upstream target (an address with optional port, or an upstream name).
# It is applied to lines read straight from a file, so it relies on ``$``
# also matching just before the trailing newline.
PROXY_PASS_IP_RE = re.compile(
    r'^\s+proxy_pass\s+http://(?P<ip_address>[\w.:-]+);$'
)
# Splits a vhost filename of the form ``<username>_<domain>``.
VHOST_FILENAME_RE = re.compile(r'^(?P<username>[0-9a-z-]+)_(?P<domain>[0-9a-z.-]+)$')
# Lower-case dotted domain name ending in an alphabetic label.  Anchored with
# ``\Z`` (not ``$``) so that input with a trailing newline is rejected.
VALID_DOMAIN_RE = re.compile(r'^(?:[0-9a-z-]+\.)+[a-z]+\Z')
# Dotted-decimal address with an optional ``:port`` of 2 to 5 digits.  Uses
# ``[0-9]`` instead of ``\d`` so that only ASCII digits are accepted, and
# ``\Z`` so that a trailing newline cannot slip through validation.
IP_WITH_PORT_RE = re.compile(r'^(?P<ip_address>[0-9.]+)(:[0-9]{2,5})?\Z')
# Module-level logger, obtained from the project's logger_factory using this
# module's name.
logger = logger_factory(__name__)
@implementer(IVHostManager)
class VHostManager:
    """Manage per-user NGINX virtual host files and their TLS certificates.

    Each vhost is a rendered NGINX config file named ``<username>_<domain>``
    inside ``vhost_dir``.  Domains consisting of a single label under one of
    the two configured member domains reuse pre-provisioned certificates;
    every other accepted domain gets a certificate issued through acme.sh.
    """

    def __init__(self):
        """Load settings from the registered IConfig utility.

        Creates the vhost and SSL directories if they are missing, prepares
        the Jinja2 environment and wraps ``create_vhost`` in the configured
        rate limiter.
        """
        cfg = component.getUtility(IConfig)
        self.vhost_dir = cfg.get('cloud vhosts_vhost_dir')
        self.ssl_dir = cfg.get('cloud vhosts_ssl_dir')
        # exist_ok avoids the race between checking for and creating the
        # directory.
        os.makedirs(self.vhost_dir, exist_ok=True)
        os.makedirs(self.ssl_dir, exist_ok=True)

        self.default_ssl_cert = cfg.get('cloud vhosts_default_ssl_cert')
        self.default_ssl_key = cfg.get('cloud vhosts_default_ssl_key')
        self.vhost_domain = cfg.get('cloud vhosts_members_domain')
        # Exactly one label directly under the members domain.  re.escape
        # neutralises every regex metacharacter, not only the dots.
        self.vhost_domain_re = re.compile(
            r'^[a-z0-9-]+\.' + re.escape(self.vhost_domain) + '$'
        )

        self.k8s_vhost_domain = cfg.get('cloud vhosts_k8s_members_domain')
        self.k8s_vhost_domain_re = re.compile(
            r'^[a-z0-9-]+\.' + re.escape(self.k8s_vhost_domain) + '$'
        )
        self.k8s_ssl_cert = cfg.get('cloud vhosts_k8s_ssl_cert')
        self.k8s_ssl_key = cfg.get('cloud vhosts_k8s_ssl_key')

        self.max_vhosts_per_account = cfg.get('cloud vhosts_max_vhosts_per_account')
        self.vhost_ip_min = ipaddress.ip_address(cfg.get('cloud vhosts_ip_range_min'))
        self.vhost_ip_max = ipaddress.ip_address(cfg.get('cloud vhosts_ip_range_max'))

        self.acme_challenge_dir = cfg.get('cloud vhosts_acme_challenge_dir')
        self.acme_dir = '/root/.acme.sh'
        self.acme_sh = os.path.join(self.acme_dir, 'acme.sh')

        self.jinja_env = jinja2.Environment(
            loader=jinja2.PackageLoader('ceod.model'),
            keep_trailing_newline=True,
        )

        # Shadow the method on the instance with a rate-limited wrapper.
        rate_limit_secs = cfg.get('cloud vhosts_rate_limit_seconds')
        self.create_vhost = \
            rate_limit('create_vhost', rate_limit_secs)(self.create_vhost)

    @staticmethod
    def _vhost_filename(username: str, domain: str) -> str:
        """Generate a filename for the vhost record"""
        # Sanity check against path traversal.  Raised explicitly instead of
        # via ``assert`` so the guard is still active under ``python -O``.
        if '..' in domain or '/' in domain:
            raise AssertionError(f"unsafe domain for vhost filename: '{domain}'")
        return username + '_' + domain

    def _vhost_filepath(self, username: str, domain: str) -> str:
        """Generate an absolute path for the vhost record"""
        return os.path.join(self.vhost_dir, self._vhost_filename(username, domain))

    def _vhost_files(self, username: str) -> List[str]:
        """Return a list of all vhost files for this user."""
        # Escape both parts so glob metacharacters in the directory or the
        # username are matched literally; only the trailing '*' is a wildcard.
        pattern = os.path.join(
            glob.escape(self.vhost_dir), glob.escape(username) + '_*'
        )
        return glob.glob(pattern)

    def _run(self, args: List[str]):
        """Run a command; a non-zero exit raises CalledProcessError."""
        subprocess.run(args, check=True)

    def _reload_web_server(self):
        """Ask systemd to reload NGINX so config changes take effect."""
        logger.debug('Reloading NGINX')
        self._run(['systemctl', 'reload', 'nginx'])

    def is_valid_domain(self, username: str, domain: str) -> bool:
        """Return True if ``username`` is allowed to register ``domain``.

        The domain must be well-formed, at most 80 characters long and lie
        under one of the two configured member domains.  In addition, the
        label right before the member domain must either equal the username
        or end with ``-<username>``.
        """
        # fullmatch so that a trailing newline cannot pass the syntax check.
        if VALID_DOMAIN_RE.fullmatch(domain) is None:
            return False
        if len(domain) > 80:
            return False
        # The k8s domain is tried first, as in the original ordering
        # (presumably because it may itself lie under the members domain).
        for parent in (self.k8s_vhost_domain, self.vhost_domain):
            suffix = '.' + parent
            if domain.endswith(suffix):
                prefix = domain[:-len(suffix)]
                break
        else:
            return False
        last_part = prefix.split('.')[-1]
        return last_part == username or last_part.endswith('-' + username)

    def is_valid_ip_address(self, ip_address: str) -> bool:
        """Return True if ``ip_address`` is an acceptable proxy target.

        Accepts the literal ``'k8s'`` or an address (optionally followed by
        ``:port``) that lies within the configured IP range, bounds included.
        """
        if ip_address == 'k8s':
            # special case - this is an NGINX upstream
            return True
        # strip off the port number, if there is one; fullmatch rejects
        # input carrying a trailing newline, which would otherwise end up in
        # the rendered config
        match = IP_WITH_PORT_RE.fullmatch(ip_address)
        if match is None:
            return False
        # make sure the IP is in the allowed range
        try:
            addr = ipaddress.ip_address(match.group('ip_address'))
        except ValueError:
            return False
        return self.vhost_ip_min <= addr <= self.vhost_ip_max

    def _get_cert_and_key_path(self, domain: str) -> Tuple[str, str]:
        """Return the (certificate, key) paths to use for ``domain``."""
        # Use the wildcard certs, if possible
        if self.vhost_domain_re.match(domain) is not None:
            return self.default_ssl_cert, self.default_ssl_key
        if self.k8s_vhost_domain_re.match(domain) is not None:
            return self.k8s_ssl_cert, self.k8s_ssl_key
        # Otherwise, the cert is one obtained with acme.sh
        cert_path = f'{self.ssl_dir}/{domain}.chain'
        key_path = f'{self.ssl_dir}/{domain}.key'
        return cert_path, key_path

    def _acquire_new_cert(self, domain: str, cert_path: str, key_path: str):
        """Issue a certificate for ``domain`` with acme.sh and install it.

        The key is installed to ``key_path`` and the full chain to
        ``cert_path``.  Raises CalledProcessError if either command fails.
        """
        logger.info(f'issuing new certificate for {domain}')
        self._run([
            self.acme_sh, '--issue', '-d', domain,
            '-w', self.acme_challenge_dir,
        ])
        logger.info(f'installing new certificate for {domain}')
        self._run([
            self.acme_sh, '--install-cert', '-d', domain,
            '--key-file', key_path,
            '--fullchain-file', cert_path,
            '--reloadcmd', 'systemctl reload nginx',
        ])

    def _delete_cert(self, domain: str, cert_path: str, key_path: str):
        """Remove ``domain`` from acme.sh and delete its installed files."""
        logger.info(f'removing certificate for {domain}')
        self._run([self.acme_sh, '--remove', '-d', domain])
        # Also drop the per-domain directory under the acme.sh home, if any.
        domain_dir = os.path.join(self.acme_dir, domain)
        if os.path.exists(domain_dir):
            shutil.rmtree(domain_dir)
        os.unlink(cert_path)
        os.unlink(key_path)

    def create_vhost(self, username: str, domain: str, ip_address: str):
        """Create a vhost that proxies ``domain`` to ``ip_address``.

        Obtains a certificate first if the expected cert/key files do not
        exist, then writes the rendered config and reloads NGINX.  Note that
        ``__init__`` replaces this method on each instance with a
        rate-limited wrapper.

        Raises:
            Exception: the user already has the maximum number of vhosts.
            InvalidDomainError: ``domain`` is not allowed for this user.
            InvalidIPError: ``ip_address`` is not an allowed target.
        """
        if self.get_num_vhosts(username) >= self.max_vhosts_per_account:
            raise Exception(f'Only {self.max_vhosts_per_account} vhosts '
                            'allowed per account')
        if not self.is_valid_domain(username, domain):
            raise InvalidDomainError()
        if not self.is_valid_ip_address(ip_address):
            raise InvalidIPError()

        cert_path, key_path = self._get_cert_and_key_path(domain)
        if not (os.path.exists(cert_path) and os.path.exists(key_path)):
            self._acquire_new_cert(domain, cert_path, key_path)

        template = self.jinja_env.get_template('nginx_cloud_vhost_config.j2')
        body = template.render(
            username=username, domain=domain, ip_address=ip_address,
            ssl_cert_path=cert_path, ssl_key_path=key_path)
        filepath = self._vhost_filepath(username, domain)
        logger.info(f'Writing a new vhost ({domain} -> {ip_address}) to {filepath}')
        with open(filepath, 'w') as fo:
            fo.write(body)
        self._reload_web_server()

    def delete_vhost(self, username: str, domain: str):
        """Delete one vhost and, where applicable, its acme.sh certificate.

        Raises:
            InvalidDomainError: ``domain`` is not allowed for this user.
        """
        if not self.is_valid_domain(username, domain):
            raise InvalidDomainError()

        cert_path, key_path = self._get_cert_and_key_path(domain)
        # Never delete the shared certs; only per-domain files inside ssl_dir
        # that actually exist are removed.
        is_shared_cert = cert_path in [self.default_ssl_cert, self.k8s_ssl_cert]
        if not is_shared_cert \
                and cert_path.startswith(self.ssl_dir) \
                and os.path.exists(cert_path) and os.path.exists(key_path):
            self._delete_cert(domain, cert_path, key_path)

        filepath = self._vhost_filepath(username, domain)
        logger.info(f'Deleting {filepath}')
        os.unlink(filepath)
        self._reload_web_server()

    def get_num_vhosts(self, username: str) -> int:
        """Return how many vhost files exist for ``username``."""
        return len(self._vhost_files(username))

    def get_vhosts(self, username: str) -> List[Dict]:
        """Return the user's vhosts as ``{'domain', 'ip_address'}`` dicts.

        The domain comes from the filename and the address from the first
        ``proxy_pass`` line in the file.  Raises AssertionError if a file
        name or its contents do not have the expected form.
        """
        vhosts = []
        for filepath in self._vhost_files(username):
            filename = os.path.basename(filepath)
            match = VHOST_FILENAME_RE.match(filename)
            # Explicit raises (not ``assert``) so the checks survive -O.
            if match is None:
                raise AssertionError(f"'{filename}' does not match expected pattern")
            domain = match.group('domain')

            ip_address = None
            # The context manager closes the file even when we stop early.
            with open(filepath) as fi:
                for line in fi:
                    match = PROXY_PASS_IP_RE.match(line)
                    if match is not None:
                        ip_address = match.group('ip_address')
                        break
            if ip_address is None:
                raise AssertionError(f"Could not find IP address in {filename}")
            vhosts.append({'domain': domain, 'ip_address': ip_address})
        return vhosts

    def delete_all_vhosts_for_user(self, username: str):
        """Delete every vhost file of ``username`` and reload NGINX.

        Does nothing, and does not reload, when the user has no vhosts.
        """
        # NOTE(review): unlike delete_vhost, this does not remove any
        # acme.sh-issued certificates for the deleted domains - confirm
        # whether that is intended.
        filepaths = self._vhost_files(username)
        if not filepaths:
            return
        for filepath in filepaths:
            logger.info(f'Deleting {filepath}')
            os.unlink(filepath)
        self._reload_web_server()

    def get_accounts(self) -> List[str]:
        """Return the distinct usernames that own a file in ``vhost_dir``.

        The username is the part of the filename before the first
        underscore; files without an underscore are ignored.
        """
        owners = {
            filename.split('_', 1)[0]
            for filename in os.listdir(self.vhost_dir)
            if '_' in filename
        }
        return list(owners)