pyceo/tests/conftest_ceod_api.py

98 lines
3.3 KiB
Python

from base64 import b64encode
import json
import socket
import subprocess
import tempfile
from flask import g
from flask.testing import FlaskClient
import gssapi
import pytest
from requests import Request
from requests_gssapi import HTTPSPNEGOAuth
from ceo_common.krb5.utils import get_fwd_tgt
__all__ = ['client']


@pytest.fixture(scope='session')
def client(app):
    """Provide one Kerberos-aware ceod test client for the whole session.

    A Flask test client is created from the ``app`` fixture. Every Kerberos
    credential acquired through the returned ``CeodTestClient`` is stored
    under a temporary directory. That directory is removed when this
    fixture is torn down at the end of the session.
    """
    flask_client = app.test_client()
    with tempfile.TemporaryDirectory() as ccache_dir:
        yield CeodTestClient(flask_client, ccache_dir)
class CeodTestClient:
    """Flask test client that authenticates requests to ceod with Kerberos.

    Every request is sent with two extra headers:

    * a SPNEGO ``Authorization`` header for the requesting principal;
    * an ``X-KRB5-CRED`` header carrying a base64-encoded forwarded TGT.

    Credentials are kept in a ``DIR:`` credential cache rooted at the
    directory passed to the constructor.
    """

    def __init__(self, app_client: FlaskClient, cache_dir: str):
        self.client = app_client
        # principal used when a request does not name one explicitly
        self.syscom_principal = 'ctdalek'
        # Resolve the FQDN once. It is needed for both the SPNEGO base URL and
        # the ceod service principal, and getfqdn() may perform a DNS lookup.
        self.fqdn = socket.getfqdn()
        # this is only used for the HTTPSPNEGOAuth
        self.base_url = f'http://{self.fqdn}'
        # keep a list of all of the principals for which we acquired a TGT
        self.principals = []
        # this is where we'll store the credentials for each principal
        self.ccache = 'DIR:' + cache_dir

    def get_auth(self, principal):
        """Return an ``HTTPSPNEGOAuth`` that initiates as *principal*.

        The initiator credentials are looked up in ``self.ccache``. The
        SPNEGO target service name is ``ceod``.
        """
        name = gssapi.Name(principal)
        creds = gssapi.Credentials(
            name=name,
            usage='initiate',
            store={'ccache': self.ccache},
        )
        auth = HTTPSPNEGOAuth(
            opportunistic_auth=True,
            target_name='ceod',
            creds=creds,
        )
        return auth

    def get_headers(self, principal):
        """Return the authentication headers for a request as *principal*.

        The first time a principal is seen, a TGT is acquired with ``kinit``
        into ``self.ccache``, using the fixed test password ``krb5``.

        Returns:
            A list of ``(name, value)`` tuples. It holds the headers
            produced by preparing a request with SPNEGO auth, followed by
            ``X-KRB5-CRED``: a base64-encoded forwarded TGT for
            ``ceod/<fqdn>``.

        Raises:
            subprocess.CalledProcessError: if ``kinit`` fails.
        """
        if principal not in self.principals:
            # Acquire the initial TGT; the password is fed on stdin.
            subprocess.run(
                ['kinit', '-c', self.ccache, principal],
                text=True, input='krb5',
                check=True, stdout=subprocess.DEVNULL)
            self.principals.append(principal)
        # Get the Authorization header (SPNEGO).
        # The method doesn't matter here because we just need to extract
        # the header using req.prepare().
        req = Request('GET', self.base_url, auth=self.get_auth(principal))
        headers = list(req.prepare().headers.items())
        # Get the X-KRB5-CRED header (forwarded TGT).
        # NOTE(review): the SPNEGO credentials above are selected by principal
        # name, but get_fwd_tgt() only receives the ccache collection name.
        # Confirm it uses the cache for *principal* rather than the
        # collection's primary cache, which kinit may have switched to the
        # most recently initialised principal.
        cred = b64encode(
            get_fwd_tgt('ceod/' + self.fqdn, self.ccache)
        ).decode()
        headers.append(('X-KRB5-CRED', cred))
        return headers

    def request(self, method, path, principal, **kwargs):
        """Send an authenticated request to the app and decode the body.

        Args:
            method: HTTP method name.
            path: request path passed to the Flask test client.
            principal: Kerberos principal to authenticate as; ``None``
                selects ``self.syscom_principal``.
            **kwargs: forwarded unchanged to ``FlaskClient.open``.

        Returns:
            A ``(status, data)`` tuple. *status* is the integer HTTP status
            code. *data* is the decoded JSON body when the response
            mimetype is ``application/json``. Otherwise it is a list with
            one decoded JSON value per non-empty line (streamed responses).
        """
        # Make sure that we're not already in a request context, otherwise
        # g will get overridden. Accessing g with no app context active
        # raises RuntimeError.
        with pytest.raises(RuntimeError):
            '' in g
        if principal is None:
            principal = self.syscom_principal
        resp = self.client.open(
            path, method=method, headers=self.get_headers(principal), **kwargs)
        # mimetype drops parameters such as "; charset=utf-8" and is falsy
        # when no Content-Type header is present. An exact comparison
        # against the raw header would misclassify the first case and
        # raise KeyError in the second.
        if resp.mimetype == 'application/json':
            data = json.loads(resp.data)
        else:
            # JSON-lines stream: skip blank separator lines.
            data = [
                json.loads(line)
                for line in resp.data.splitlines()
                if line.strip()
            ]
        return resp.status_code, data

    def get(self, path, principal=None, **kwargs):
        """Send an authenticated GET; see ``request``."""
        return self.request('GET', path, principal, **kwargs)

    def post(self, path, principal=None, **kwargs):
        """Send an authenticated POST; see ``request``."""
        return self.request('POST', path, principal, **kwargs)

    def patch(self, path, principal=None, **kwargs):
        """Send an authenticated PATCH; see ``request``."""
        return self.request('PATCH', path, principal, **kwargs)

    def delete(self, path, principal=None, **kwargs):
        """Send an authenticated DELETE; see ``request``."""
        return self.request('DELETE', path, principal, **kwargs)