from base64 import b64encode
import json
import socket
import subprocess
import tempfile

from flask import g
from flask.testing import FlaskClient
import gssapi
import pytest
from requests import Request
from requests_gssapi import HTTPSPNEGOAuth

from ceo_common.krb5.utils import get_fwd_tgt

__all__ = ['client']


@pytest.fixture(scope='session')
def client(app):
    """Yield a CeodTestClient wrapping the Flask test client of ``app``.

    Kerberos credentials acquired during the session are stored in a
    temporary directory, which is removed when the fixture is torn down.
    """
    app_client = app.test_client()
    with tempfile.TemporaryDirectory() as cache_dir:
        yield CeodTestClient(app_client, cache_dir)


class CeodTestClient:
    """Wrapper around a Flask test client that authenticates every request.

    Each request carries an SPNEGO ``Authorization`` header for the given
    principal plus an ``X-KRB5-CRED`` header holding a base64-encoded
    forwarded TGT for the ``ceod/<fqdn>`` service.
    """

    def __init__(self, app_client: FlaskClient, cache_dir: str):
        self.client = app_client
        self.syscom_principal = 'ctdalek'
        # Resolve the FQDN once so the base URL and the ceod service
        # principal used for TGT forwarding always agree.
        self.fqdn = socket.getfqdn()
        # this is only used for the HTTPSPNEGOAuth
        self.base_url = f'http://{self.fqdn}'
        # keep a list of all of the principals for which we acquired a TGT
        self.principals = []
        # this is where we'll store the credentials for each principal
        self.ccache = 'DIR:' + cache_dir

    def get_auth(self, principal):
        """Return an HTTPSPNEGOAuth for ``principal`` using this client's ccache."""
        name = gssapi.Name(principal)
        creds = gssapi.Credentials(
            name=name,
            usage='initiate',
            store={'ccache': self.ccache},
        )
        auth = HTTPSPNEGOAuth(
            opportunistic_auth=True,
            target_name='ceod',
            creds=creds,
        )
        return auth

    def get_headers(self, principal):
        """Return a list of (name, value) header tuples authenticating ``principal``.

        On first use of a principal, a TGT is obtained with ``kinit`` (the
        password ``krb5`` is fed on stdin) into this client's ccache.
        """
        if principal not in self.principals:
            # Acquire the initial TGT
            subprocess.run(
                ['kinit', '-c', self.ccache, principal],
                text=True, input='krb5', check=True, stdout=subprocess.DEVNULL)
            self.principals.append(principal)

        # Get the Authorization header (SPNEGO).
        # The method doesn't matter here because we just need to extract
        # the header using req.prepare().
        req = Request('GET', self.base_url, auth=self.get_auth(principal))
        headers = list(req.prepare().headers.items())

        # Get the X-KRB5-CRED header (forwarded TGT).
        cred = b64encode(
            get_fwd_tgt('ceod/' + self.fqdn, self.ccache)
        ).decode()
        headers.append(('X-KRB5-CRED', cred))

        return headers

    def request(self, method, path, principal, **kwargs):
        """Send an authenticated request and return ``(status_code, data)``.

        ``principal`` defaults to the syscom principal when None. If the
        response mimetype is ``application/json``, ``data`` is the decoded
        document; otherwise the body is parsed as one JSON value per
        non-blank line and ``data`` is the list of decoded values.
        """
        # Make sure that we're not already in a request context, otherwise
        # g will get overridden. Touching g outside of a context raises.
        with pytest.raises(RuntimeError):
            '' in g  # noqa: B015

        if principal is None:
            principal = self.syscom_principal
        resp = self.client.open(
            path, method=method, headers=self.get_headers(principal), **kwargs)
        status = resp.status_code
        # mimetype drops parameters such as "; charset=utf-8" and is '' when
        # the Content-Type header is absent, so neither case misroutes or
        # raises KeyError.
        if resp.mimetype == 'application/json':
            data = json.loads(resp.data)
        else:
            data = [
                json.loads(line)
                for line in resp.data.splitlines()
                if line.strip()
            ]
        return status, data

    def get(self, path, principal=None, **kwargs):
        """Send an authenticated GET request."""
        return self.request('GET', path, principal, **kwargs)

    def post(self, path, principal=None, **kwargs):
        """Send an authenticated POST request."""
        return self.request('POST', path, principal, **kwargs)

    def patch(self, path, principal=None, **kwargs):
        """Send an authenticated PATCH request."""
        return self.request('PATCH', path, principal, **kwargs)

    def delete(self, path, principal=None, **kwargs):
        """Send an authenticated DELETE request."""
        return self.request('DELETE', path, principal, **kwargs)