from base64 import b64encode
import json
import socket

from flask import g
from flask.testing import FlaskClient
import gssapi
import pytest
from requests import Request
from requests_gssapi import HTTPSPNEGOAuth

from ceo_common.krb5.utils import get_fwd_tgt
from .utils import krb5ccname_ctx

__all__ = ['client']


@pytest.fixture(scope='session')
def client(app):
    """Session-scoped fixture yielding a CeodTestClient for ``app``.

    ``app`` is a fixture defined elsewhere (presumably the ceod Flask app).
    """
    app_client = app.test_client()
    yield CeodTestClient(app_client)


class CeodTestClient:
    """Wrapper around a Flask test client that authenticates every request.

    Each request carries the ``Authorization`` header produced by
    requests-gssapi (SPNEGO) for the chosen principal and, optionally, an
    ``X-KRB5-CRED`` header containing a base64-encoded forwarded TGT.
    """

    def __init__(self, app_client: FlaskClient):
        self.client = app_client
        # Principal used when a request does not specify one.
        self.syscom_principal = 'ctdalek'
        # Resolve the FQDN once; it is used for the base URL and for the
        # ceod service principal name.
        self.fqdn = socket.getfqdn()
        # This is only used to build the requests.Request from which the
        # HTTPSPNEGOAuth header is extracted; the Flask client uses `path`.
        self.base_url = f'http://{self.fqdn}'
        # GSSAPI target name of the ceod service (for SPNEGO).
        self.target_name = gssapi.Name('ceod/' + self.fqdn)

    def get_auth(self, principal: str) -> HTTPSPNEGOAuth:
        """Acquire a HTTPSPNEGOAuth instance for the principal.

        Initiator credentials are acquired without an explicit ``store``,
        so they come from whichever ccache is currently selected;
        ``get_headers`` wraps this call in ``krb5ccname_ctx(principal)``.
        """
        name = gssapi.Name(principal)
        # the 'store' arg doesn't seem to work for DIR ccaches
        creds = gssapi.Credentials(name=name, usage='initiate')
        return HTTPSPNEGOAuth(
            opportunistic_auth=True,
            target_name=self.target_name,
            creds=creds,
        )

    def get_headers(self, principal: str, need_cred: bool):
        """Build the list of (name, value) auth headers for ``principal``.

        :param principal: Kerberos principal to authenticate as.
        :param need_cred: if true, also attach an ``X-KRB5-CRED`` header
            holding the base64-encoded result of ``get_fwd_tgt`` for the
            ceod service principal.
        :returns: list of ``(header_name, header_value)`` tuples.
        """
        with krb5ccname_ctx(principal):
            # Get the Authorization header (SPNEGO).
            # The method doesn't matter here because we just need to extract
            # the header using req.prepare().
            req = Request('GET', self.base_url, auth=self.get_auth(principal))
            headers = list(req.prepare().headers.items())
            if need_cred:
                # Get the X-KRB5-CRED header (forwarded TGT).
                cred = b64encode(get_fwd_tgt('ceod/' + self.fqdn)).decode()
                headers.append(('X-KRB5-CRED', cred))
            return headers

    def request(self, method: str, path: str, principal: str,
                need_cred: bool, **kwargs):
        """Send an authenticated request through the Flask test client.

        :param method: HTTP method name.
        :param path: request path passed to the Flask client.
        :param principal: principal to authenticate as; ``None`` means
            ``self.syscom_principal``.
        :param need_cred: whether to attach a forwarded TGT header.
        :param kwargs: forwarded to ``FlaskClient.open``.
        :returns: ``(status_code, data)`` where ``data`` is the decoded JSON
            body for ``application/json`` responses, otherwise a list with
            one decoded JSON value per non-blank line of the body.
        """
        # Make sure that we're not already in a request context, otherwise
        # g will get overridden. Touching g outside an app context is
        # expected to raise RuntimeError.
        with pytest.raises(RuntimeError):
            '' in g
        if principal is None:
            principal = self.syscom_principal
        headers = self.get_headers(principal, need_cred)
        resp = self.client.open(path, method=method, headers=headers, **kwargs)
        status = resp.status_code
        # `mimetype` strips parameters such as "; charset=utf-8" and is ''
        # when no Content-Type header is present (no KeyError).
        if resp.mimetype == 'application/json':
            data = json.loads(resp.data)
        else:
            # Non-JSON responses are treated as one JSON document per line;
            # blank lines are skipped so they don't break json.loads.
            data = [
                json.loads(line)
                for line in resp.data.splitlines()
                if line.strip()
            ]
        return status, data

    def get(self, path, principal=None, need_cred=True, **kwargs):
        """Send an authenticated GET request; see ``request``."""
        return self.request('GET', path, principal, need_cred, **kwargs)

    def post(self, path, principal=None, need_cred=True, **kwargs):
        """Send an authenticated POST request; see ``request``."""
        return self.request('POST', path, principal, need_cred, **kwargs)

    def patch(self, path, principal=None, need_cred=True, **kwargs):
        """Send an authenticated PATCH request; see ``request``."""
        return self.request('PATCH', path, principal, need_cred, **kwargs)

    def delete(self, path, principal=None, need_cred=True, **kwargs):
        """Send an authenticated DELETE request; see ``request``."""
        return self.request('DELETE', path, principal, need_cred, **kwargs)