Skip to content

Commit

Permalink
Merge pull request #294 from RDFLib/david/check-header-clean
Browse files Browse the repository at this point in the history
David/check header
  • Loading branch information
lalewis1 authored Nov 1, 2024
2 parents 2a17d2c + 6d815b2 commit f43a0dc
Show file tree
Hide file tree
Showing 7 changed files with 650 additions and 612 deletions.
1,183 changes: 586 additions & 597 deletions poetry.lock

Large diffs are not rendered by default.

21 changes: 16 additions & 5 deletions prez/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from pathlib import Path
from contextlib import asynccontextmanager
from functools import partial
from pathlib import Path
from textwrap import dedent
from typing import Optional, Dict, Union, Any

Expand All @@ -28,14 +28,16 @@
URINotFoundException,
NoProfilesException,
InvalidSPARQLQueryException,
PrefixNotFoundException, NoEndpointNodeshapeException,
PrefixNotFoundException,
NoEndpointNodeshapeException
)
from prez.middleware import create_validate_header_middleware
from prez.repositories import RemoteSparqlRepo, PyoxigraphRepo, OxrdflibRepo
from prez.routers.base_router import router as base_prez_router
from prez.routers.custom_endpoints import create_dynamic_router
from prez.routers.identifier import router as identifier_router
from prez.routers.management import router as management_router, config_router
from prez.routers.ogc_features_router import features_subapi
from prez.routers.base_router import router as base_prez_router
from prez.routers.sparql import router as sparql_router
from prez.services.app_service import (
healthcheck_sparql_endpoints,
Expand All @@ -54,7 +56,8 @@
catch_uri_not_found_exception,
catch_no_profiles_exception,
catch_invalid_sparql_query,
catch_prefix_not_found_exception, catch_no_endpoint_nodeshape_exception,
catch_prefix_not_found_exception,
catch_no_endpoint_nodeshape_exception,
)
from prez.services.generate_profiles import create_profiles_graph
from prez.services.prez_logging import setup_logger
Expand Down Expand Up @@ -186,7 +189,11 @@ def assemble_app(
app.include_router(sparql_router)
if _settings.configuration_mode:
app.include_router(config_router)
app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
app.mount(
"/static",
StaticFiles(directory=Path(__file__).parent / "static"),
name="static",
)
if _settings.enable_ogc_features:
app.mount(
_settings.ogc_features_mount_path,
Expand All @@ -213,6 +220,10 @@ def assemble_app(
allow_headers=["*"],
expose_headers=["*"],
)
validate_header_middleware = create_validate_header_middleware(
settings.required_header
)
app.middleware("http")(validate_header_middleware)

return app

Expand Down
7 changes: 5 additions & 2 deletions prez/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import toml
from pydantic import field_validator
from pydantic import field_validator, Field
from pydantic_settings import BaseSettings
from rdflib import DCTERMS, RDFS, SDO, URIRef
from rdflib.namespace import SKOS
Expand Down Expand Up @@ -81,12 +81,15 @@ class Settings(BaseSettings):
]
enable_sparql_endpoint: bool = False
enable_ogc_features: bool = True
ogc_features_mount_path: str = "/catalogs/{catalogId}/collections/{recordsCollectionId}/features"
ogc_features_mount_path: str = (
"/catalogs/{catalogId}/collections/{recordsCollectionId}/features"
)
custom_endpoints: bool = False
configuration_mode: bool = False
temporal_predicate: Optional[URIRef] = SDO.temporal
endpoint_to_template_query_filename: Optional[Dict[str, str]] = {}
prez_ui_url: Optional[str] = None
required_header: dict[str, str] | None = None

@field_validator("prez_version")
@classmethod
Expand Down
23 changes: 23 additions & 0 deletions prez/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from fastapi import Request
from fastapi.responses import JSONResponse


def create_validate_header_middleware(required_header: dict[str, str] | None):
async def validate_header(request: Request, call_next):
if required_header:
header_name, expected_value = next(iter(required_header.items()))
if (
header_name not in request.headers
or request.headers[header_name] != expected_value
):
return JSONResponse( # attempted to use Exception and although it was caught it did not propagate
status_code=400,
content={
"error": "Header Validation Error",
"message": f"Missing or invalid header: {header_name}",
"code": "HEADER_VALIDATION_ERROR",
},
)
return await call_next(request)

return validate_header
4 changes: 2 additions & 2 deletions prez/repositories/pyoxigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def _sparql(self, query: str) -> dict | Graph | bool:
elif isinstance(results, pyoxigraph.QueryTriples): # a CONSTRUCT query result
result_graph = self._handle_query_triples_results(results)
return result_graph
elif isinstance(results, bool):
results_dict = {"head": {}, "boolean": results}
elif isinstance(results, pyoxigraph.QueryBoolean):
results_dict = {"head": {}, "boolean": bool(results)}
return results_dict
else:
raise TypeError(f"Unexpected result class {type(results)}")
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ aiocache = "^0.12.2"
sparql-grammar-pydantic = "^0.1.2"
rdf2geojson = {git = "https://github.com/ashleysommer/rdf2geojson.git", rev = "v0.2.1"}
python-multipart = "^0.0.9"
pyoxigraph = "^0.3.22"
oxrdflib = "^0.3.7"
pyoxigraph = "^0.4.2"
oxrdflib = {git = "https://github.com/oxigraph/oxrdflib.git", rev = "main"}

[tool.poetry.extras]
server = ["uvicorn"]
Expand Down
20 changes: 16 additions & 4 deletions tests/test_sparql.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pytest


def test_select(client):
"""check that a valid select query returns a 200 response."""
r = client.get(
Expand All @@ -14,11 +17,20 @@ def test_construct(client):
assert r.status_code == 200


def test_ask(client):
"""check that a valid ask query returns a 200 response."""
r = client.get(
"/sparql?query=PREFIX%20ex%3A%20%3Chttp%3A%2F%2Fexample.com%2Fdatasets%2F%3E%0APREFIX%20dcterms%3A%20%3Chttp%3A%2F%2Fpurl.org%2Fdc%2Fterms%2F%3E%0A%0AASK%0AWHERE%20%7B%0A%20%20%3Fsubject%20dcterms%3Atitle%20%3Ftitle%20.%0A%20%20FILTER%20CONTAINS(LCASE(%3Ftitle)%2C%20%22sandgate%22)%0A%7D"
@pytest.mark.parametrize("query,expected_result", [
(
"/sparql?query=PREFIX%20ex%3A%20%3Chttp%3A%2F%2Fexample.com%2Fdatasets%2F%3E%0APREFIX%20dcterms%3A%20%3Chttp%3A%2F%2Fpurl.org%2Fdc%2Fterms%2F%3E%0A%0AASK%0AWHERE%20%7B%0A%20%20%3Fsubject%20dcterms%3Atitle%20%3Ftitle%20.%0A%20%20FILTER%20CONTAINS(LCASE(%3Ftitle)%2C%20%22sandgate%22)%0A%7D",
True
),
(
"/sparql?query=ASK%20%7B%20%3Chttps%3A%2F%2Ffake%3E%20%3Fp%20%3Fo%20%7D",
False
)
])
def test_ask(client, query, expected_result):
"""Check that valid ASK queries return a 200 response with the expected boolean result."""
r = client.get(query)

assert r.status_code == 200


Expand Down

0 comments on commit f43a0dc

Please sign in to comment.