from base64 import b64encode
import contextlib
import os
import json
import socket
import subprocess
import tempfile

from flask import g
from flask.testing import FlaskClient
import gssapi
import pytest
from requests import Request
from requests_gssapi import HTTPSPNEGOAuth

from ceo_common.krb5.utils import get_fwd_tgt

__all__ = ['client']


@pytest.fixture(scope='session')
def client(app):
    """Yield a session-wide CeodTestClient backed by a temporary ccache dir.

    The temporary directory (and every credential cache created inside it)
    is removed when the session ends.
    """
    app_client = app.test_client()
    with tempfile.TemporaryDirectory() as cache_dir:
        yield CeodTestClient(app_client, cache_dir)


class CeodTestClient:
    """Wrapper around a Flask test client which authenticates each request.

    For every request, an SPNEGO Authorization header (and optionally a
    forwarded TGT in the X-KRB5-CRED header) is generated for the chosen
    Kerberos principal and attached to the request.
    """

    def __init__(self, app_client: FlaskClient, cache_dir: str):
        """Store the Flask client and set up per-principal ccache tracking.

        :param app_client: the Flask test client used to issue requests
        :param cache_dir: directory in which the ccache files are created
        """
        self.client = app_client
        # principal used when a request does not specify one
        self.syscom_principal = 'ctdalek'
        # this is only used for the HTTPSPNEGOAuth
        self.base_url = f'http://{socket.getfqdn()}'
        # for each principal for which we acquired a TGT, map their
        # username to a file (ccache) storing their TGT
        self.principal_ccaches = {}
        # this is where we'll store the credentials for each principal
        self.cache_dir = cache_dir
        # for SPNEGO
        self.target_name = gssapi.Name('ceod/' + socket.getfqdn())

    @contextlib.contextmanager
    def krb5ccname_env(self, principal):
        """Temporarily change KRB5CCNAME to the ccache of the principal.

        On exit the variable is put back exactly as it was found: restored
        to its previous value, or removed again if it was not set at all.

        :raises KeyError: if no ccache has been registered for the principal
        """
        ccache = self.principal_ccaches[principal]
        # None means "was not set", which must be restored by deletion
        previous = os.environ.get('KRB5CCNAME')
        os.environ['KRB5CCNAME'] = ccache
        try:
            yield
        finally:
            if previous is None:
                os.environ.pop('KRB5CCNAME', None)
            else:
                os.environ['KRB5CCNAME'] = previous

    def get_auth(self, principal):
        """Acquire a HTTPSPNEGOAuth instance for the principal."""
        name = gssapi.Name(principal)
        # the 'store' arg doesn't seem to work for DIR ccaches
        with self.krb5ccname_env(principal):
            creds = gssapi.Credentials(name=name, usage='initiate')
        return HTTPSPNEGOAuth(
            opportunistic_auth=True,
            target_name=self.target_name,
            creds=creds,
        )

    def kinit(self, principal):
        """Acquire an initial TGT for the principal.

        The password 'krb5' is fed to kinit on stdin.

        :raises subprocess.CalledProcessError: if kinit exits non-zero
        """
        # For some reason, kinit with the '-c' option deletes the other
        # credentials in the cache collection, so we need to override the
        # env variable
        # NOTE(review): env= replaces the whole environment, so the child
        # sees only KRB5CCNAME - confirm that this is intended.
        subprocess.run(
            ['kinit', principal],
            text=True,
            input='krb5',
            check=True,
            stdout=subprocess.DEVNULL,
            env={'KRB5CCNAME': self.principal_ccaches[principal]})

    def get_headers(self, principal: str, need_cred: bool):
        """Build the list of (name, value) headers authenticating a request.

        The first time a principal is seen, a ccache file is created for it
        in the cache directory and a TGT is acquired with kinit.

        :param principal: the Kerberos principal to authenticate as
        :param need_cred: whether to also attach the X-KRB5-CRED header
        :returns: a list of (header name, header value) tuples
        """
        if principal not in self.principal_ccaches:
            fd, filename = tempfile.mkstemp(dir=self.cache_dir)
            # mkstemp hands back an open descriptor; only the path is
            # needed, so close it right away instead of leaking it
            os.close(fd)
            self.principal_ccaches[principal] = filename
            self.kinit(principal)
        # Get the Authorization header (SPNEGO).
        # The method doesn't matter here because we just need to extract
        # the header using req.prepare().
        req = Request('GET', self.base_url, auth=self.get_auth(principal))
        headers = list(req.prepare().headers.items())
        if need_cred:
            # Get the X-KRB5-CRED header (forwarded TGT).
            cred = b64encode(get_fwd_tgt(
                'ceod/' + socket.getfqdn(), self.principal_ccaches[principal]
            )).decode()
            headers.append(('X-KRB5-CRED', cred))
        return headers

    def request(self, method: str, path: str, principal: str, need_cred: bool, **kwargs):
        """Send an authenticated request through the Flask test client.

        :param method: the HTTP method, e.g. 'GET'
        :param path: the path to request
        :param principal: the principal to authenticate as; None means
            the default syscom principal
        :param need_cred: whether to forward a TGT in X-KRB5-CRED
        :param kwargs: passed through to the Flask client's open()
        :returns: a (status, data) tuple. data is the decoded body when the
            content type is exactly 'application/json'; otherwise it is a
            list with one decoded JSON value per line of the body.
        """
        # Make sure that we're not already in a request context, otherwise
        # g will get overridden
        with pytest.raises(RuntimeError):
            '' in g
        if principal is None:
            principal = self.syscom_principal
        headers = self.get_headers(principal, need_cred)
        resp = self.client.open(path, method=method, headers=headers, **kwargs)
        status = resp.status_code
        if resp.headers['content-type'] == 'application/json':
            data = json.loads(resp.data)
        else:
            data = [json.loads(line) for line in resp.data.splitlines()]
        return status, data

    def get(self, path, principal=None, need_cred=True, **kwargs):
        """Send a GET request; see request() for the arguments."""
        return self.request('GET', path, principal, need_cred, **kwargs)

    def post(self, path, principal=None, need_cred=True, **kwargs):
        """Send a POST request; see request() for the arguments."""
        return self.request('POST', path, principal, need_cred, **kwargs)

    def patch(self, path, principal=None, need_cred=True, **kwargs):
        """Send a PATCH request; see request() for the arguments."""
        return self.request('PATCH', path, principal, need_cred, **kwargs)

    def delete(self, path, principal=None, need_cred=True, **kwargs):
        """Send a DELETE request; see request() for the arguments."""
        return self.request('DELETE', path, principal, need_cred, **kwargs)