Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't do 2 SQL updates per editing #11232

Merged
merged 1 commit into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions geoportal/c2cgeoportal_geoportal/lib/dbreflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,19 @@ def get_table(
# create table and reflect it
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "Did not recognize type 'geometry' of column", SAWarning)
args = [tablename, metadata]
args = []
if primary_key is not None:
# Ensure we have a primary key to be able to edit views
args.append(Column(primary_key, Integer, primary_key=True))
with _get_table_lock:
table = Table(*args, schema=schema, autoload_with=engine) # type: ignore[arg-type]
print(f"Table {tablename} loaded")
print([c.name for c in table.columns])
table = Table(
tablename,
metadata,
*args,
schema=schema,
autoload_with=engine,
keep_existing=True,
)
return table


Expand Down
261 changes: 136 additions & 125 deletions geoportal/c2cgeoportal_geoportal/views/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import shapely.geometry
import sqlalchemy.ext.declarative
import sqlalchemy.orm
import sqlalchemy.orm.query
from geoalchemy2 import Geometry
from geoalchemy2.shape import from_shape, to_shape
from geojson.feature import Feature, FeatureCollection
Expand Down Expand Up @@ -68,10 +69,120 @@

if TYPE_CHECKING:
from c2cgeoportal_commons.models import main # pylint: disable=ungrouped-imports.useless-suppression


_LOG = logging.getLogger(__name__)
_CACHE_REGION = get_region("std")


class _BaseCallback:
def __init__(self, layer: "main.Layer"):
sbrunner marked this conversation as resolved.
Show resolved Hide resolved
self.layer = layer

def update(self, request: pyramid.request.Request, obj: Any) -> None:
last_update_date = Layers.get_metadata(self.layer, "lastUpdateDateColumn")
if last_update_date is not None:
setattr(obj, last_update_date, datetime.now())

last_update_user = Layers.get_metadata(self.layer, "lastUpdateUserColumn")
if last_update_user is not None:
setattr(obj, last_update_user, request.user.id)

def _get_geometry_check_base_query(
self, request: pyramid.request.Request
) -> sqlalchemy.orm.query.RowReturningQuery[tuple[int]]:
from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel
Layer,
RestrictionArea,
Role,
)

assert models.DBSession is not None
allowed = models.DBSession.query(func.count(RestrictionArea.id)) # pylint: disable=not-callable
allowed = allowed.join(RestrictionArea.roles)
allowed = allowed.join(RestrictionArea.layers)
allowed = allowed.filter(RestrictionArea.readwrite.is_(True))
allowed = allowed.filter(Role.id.in_(get_roles_id(request)))
allowed = allowed.filter(Layer.id == self.layer.id)
return allowed


class _InsertCallback(_BaseCallback):
def __call__(self, request: pyramid.request.Request, feature: Feature, obj: Any) -> None:
from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel
RestrictionArea,
)

assert models.DBSession is not None

geom = feature.geometry
if geom and not isinstance(geom, geojson.geometry.Default):
shape = shapely.geometry.shape(geom)
srid = Layers._get_geom_col_info(self.layer)[1]
spatial_elt = from_shape(shape, srid=srid)
allowed = self._get_geometry_check_base_query(request)
allowed = allowed.filter(
or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(spatial_elt))
)
if allowed.scalar() == 0:
raise HTTPForbidden()

# Check if geometry is valid
if Layers._get_validation_setting(self.layer, request):
Layers._validate_geometry(spatial_elt)
sbrunner marked this conversation as resolved.
Show resolved Hide resolved

self.update(request, obj)


class _UpdateCallback(_BaseCallback):
def __call__(self, request: pyramid.request.Request, feature: Feature, obj: Any) -> None:
from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel
RestrictionArea,
)

assert models.DBSession is not None

# we need both the "original" and "new" geometry to be
# within the restriction area
geom_attr, srid = Layers._get_geom_col_info(self.layer)
geom_attr = getattr(obj, geom_attr)
geom = feature.geometry
allowed = self._get_geometry_check_base_query(request)
allowed = allowed.filter(
or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(geom_attr))
)
spatial_elt = None
if geom and not isinstance(geom, geojson.geometry.Default):
shape = shapely.geometry.shape(geom)
spatial_elt = from_shape(shape, srid=srid)
allowed = allowed.filter(
or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(spatial_elt))
)
if allowed.scalar() == 0:
raise HTTPForbidden()

# Check is geometry is valid
if Layers._get_validation_setting(self.layer, request):
Layers._validate_geometry(spatial_elt)

self.update(request, obj)


class _DeleteCallback(_BaseCallback):
def __call__(self, request: pyramid.request.Request, obj: Any) -> None:
from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel
RestrictionArea,
)

geom_attr = getattr(obj, Layers._get_geom_col_info(self.layer)[0])
allowed = self._get_geometry_check_base_query(request)
allowed = allowed.filter(
or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(geom_attr))
)
if allowed.scalar() == 0:
raise HTTPForbidden()


class Layers:
"""
All the layers view (editing).
Expand All @@ -81,8 +192,12 @@ class Layers:

def __init__(self, request: pyramid.request.Request):
self.request = request
self.settings = request.registry.settings.get("layers", {})
self.layers_enum_config = self.settings.get("enum")
self.settings = self._get_settings(request)
self.layers_enum_config = self.settings.get("enum", {})

@staticmethod
def _get_settings(request: pyramid.request.Request) -> dict[str, Any]:
return cast(dict[str, Any], request.registry.settings.get("layers", {}))

@staticmethod
def _get_geom_col_info(layer: "main.Layer") -> tuple[str, int]:
Expand Down Expand Up @@ -145,16 +260,24 @@ def _get_layer_for_request(self) -> "main.Layer":
"""Return a ``Layer`` object for the first layer id found in the ``layer_id`` matchdict."""
return next(self._get_layers_for_request())

def _get_protocol_for_layer(self, layer: "main.Layer", **kwargs: Any) -> Protocol:
def _get_protocol_for_layer(self, layer: "main.Layer") -> Protocol:
"""Return a papyrus ``Protocol`` for the ``Layer`` object."""
cls = get_layer_class(layer)
geom_attr = self._get_geom_col_info(layer)[0]
return Protocol(models.DBSession, cls, geom_attr, **kwargs)

def _get_protocol_for_request(self, **kwargs: Any) -> Protocol:
return Protocol(
models.DBSession,
cls,
geom_attr,
before_insert=_InsertCallback(layer),
before_update=_UpdateCallback(layer),
before_delete=_DeleteCallback(layer),
)

def _get_protocol_for_request(self) -> Protocol:
"""Return a papyrus ``Protocol`` for the first layer id found in the ``layer_id`` matchdict."""
layer = self._get_layer_for_request()
return self._get_protocol_for_layer(layer, **kwargs)
return self._get_protocol_for_layer(layer)

def _proto_read(self, layer: "main.Layer") -> FeatureCollection:
"""Read features for the layer based on the self.request."""
Expand Down Expand Up @@ -265,56 +388,18 @@ def count(self) -> int:

@view_config(route_name="layers_create", renderer="geojson") # type: ignore
def create(self) -> FeatureCollection | None:
from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel
Layer,
RestrictionArea,
Role,
)

set_common_headers(self.request, "layers", Cache.PRIVATE_NO)

if self.request.user is None:
raise HTTPForbidden()

self.request.response.cache_control.no_cache = True

layer = self._get_layer_for_request()

def check_geometry(_: Any, feature: Feature, obj: Any) -> None:
del obj # unused
assert models.DBSession is not None

geom = feature.geometry
if geom and not isinstance(geom, geojson.geometry.Default):
shape = shapely.geometry.shape(geom)
srid = self._get_geom_col_info(layer)[1]
spatial_elt = from_shape(shape, srid=srid)
allowed = models.DBSession.query(
func.count(RestrictionArea.id) # pylint: disable=not-callable
)
allowed = allowed.join(RestrictionArea.roles)
allowed = allowed.join(RestrictionArea.layers)
allowed = allowed.filter(RestrictionArea.readwrite.is_(True))
allowed = allowed.filter(Role.id.in_(get_roles_id(self.request)))
allowed = allowed.filter(Layer.id == layer.id)
allowed = allowed.filter(
or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(spatial_elt))
)
if allowed.scalar() == 0:
raise HTTPForbidden()

# Check if geometry is valid
if self._get_validation_setting(layer):
self._validate_geometry(spatial_elt)

protocol = self._get_protocol_for_layer(layer, before_create=check_geometry)
protocol = self._get_protocol_for_request()
try:
features = protocol.create(self.request)
if isinstance(features, HTTPException):
raise features
if features is not None:
for feature in features.features: # pylint: disable=no-member
self._log_last_update(layer, feature)
return features
except TopologicalError as e:
self.request.response.status_int = 400
Expand All @@ -327,12 +412,6 @@ def check_geometry(_: Any, feature: Feature, obj: Any) -> None:

@view_config(route_name="layers_update", renderer="geojson") # type: ignore
def update(self) -> Feature:
from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel
Layer,
RestrictionArea,
Role,
)

set_common_headers(self.request, "layers", Cache.PRIVATE_NO)

if self.request.user is None:
Expand All @@ -341,45 +420,11 @@ def update(self) -> Feature:
self.request.response.cache_control.no_cache = True

feature_id = self.request.matchdict.get("feature_id")
layer = self._get_layer_for_request()

def check_geometry(_: Any, feature: Feature, obj: Any) -> None:
assert models.DBSession is not None

# we need both the "original" and "new" geometry to be
# within the restriction area
geom_attr, srid = self._get_geom_col_info(layer)
geom_attr = getattr(obj, geom_attr)
geom = feature.geometry
allowed = models.DBSession.query(func.count(RestrictionArea.id)) # pylint: disable=not-callable
allowed = allowed.join(RestrictionArea.roles)
allowed = allowed.join(RestrictionArea.layers)
allowed = allowed.filter(RestrictionArea.readwrite.is_(True))
allowed = allowed.filter(Role.id.in_(get_roles_id(self.request)))
allowed = allowed.filter(Layer.id == layer.id)
allowed = allowed.filter(
or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(geom_attr))
)
spatial_elt = None
if geom and not isinstance(geom, geojson.geometry.Default):
shape = shapely.geometry.shape(geom)
spatial_elt = from_shape(shape, srid=srid)
allowed = allowed.filter(
or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(spatial_elt))
)
if allowed.scalar() == 0:
raise HTTPForbidden()

# Check is geometry is valid
if self._get_validation_setting(layer):
self._validate_geometry(spatial_elt)

protocol = self._get_protocol_for_layer(layer, before_update=check_geometry)
protocol = self._get_protocol_for_request()
try:
feature = protocol.update(self.request, feature_id)
if isinstance(feature, HTTPException):
raise feature
self._log_last_update(layer, feature)
return cast(Feature, feature)
except TopologicalError as e:
self.request.response.status_int = 400
Expand All @@ -403,15 +448,6 @@ def _validate_geometry(geom: geoalchemy2.elements.WKBElement | None) -> None:
reason = models.DBSession.query(func.ST_IsValidReason(func.ST_GeomFromEWKB(geom))).scalar()
raise TopologicalError(reason)

def _log_last_update(self, layer: "main.Layer", feature: Feature) -> None:
last_update_date = self.get_metadata(layer, "lastUpdateDateColumn")
if last_update_date is not None:
setattr(feature, last_update_date, datetime.now())

last_update_user = self.get_metadata(layer, "lastUpdateUserColumn")
if last_update_user is not None:
setattr(feature, last_update_user, self.request.user.id)

@staticmethod
def get_metadata(layer: "main.Layer", key: str, default: str | None = None) -> str | None:
metadata = layer.get_metadata(key)
Expand All @@ -420,46 +456,21 @@ def get_metadata(layer: "main.Layer", key: str, default: str | None = None) -> s
return metadata.value
return default

def _get_validation_setting(self, layer: "main.Layer") -> bool:
@classmethod
def _get_validation_setting(cls, layer: "main.Layer", request: pyramid.request.Request) -> bool:
# The validation UIMetadata is stored as a string, not a boolean
should_validate = self.get_metadata(layer, "geometryValidation", None)
should_validate = cls.get_metadata(layer, "geometryValidation", None)
if should_validate:
return should_validate.lower() != "false"
return cast(bool, self.settings.get("geometry_validation", False))
return cast(bool, cls._get_settings(request).get("geometry_validation", False))

@view_config(route_name="layers_delete") # type: ignore
def delete(self) -> pyramid.response.Response:
assert models.DBSession is not None

from c2cgeoportal_commons.models.main import ( # pylint: disable=import-outside-toplevel
Layer,
RestrictionArea,
Role,
)

if self.request.user is None:
raise HTTPForbidden()

feature_id = self.request.matchdict.get("feature_id")
layer = self._get_layer_for_request()

def security_cb(_: Any, obj: Any) -> None:
assert models.DBSession is not None

geom_attr = getattr(obj, self._get_geom_col_info(layer)[0])
allowed = models.DBSession.query(func.count(RestrictionArea.id)) # pylint: disable=not-callable
allowed = allowed.join(RestrictionArea.roles)
allowed = allowed.join(RestrictionArea.layers)
allowed = allowed.filter(RestrictionArea.readwrite.is_(True))
allowed = allowed.filter(Role.id.in_(get_roles_id(self.request)))
allowed = allowed.filter(Layer.id == layer.id)
allowed = allowed.filter(
or_(RestrictionArea.area.is_(None), RestrictionArea.area.ST_Contains(geom_attr))
)
if allowed.scalar() == 0:
raise HTTPForbidden()

protocol = self._get_protocol_for_layer(layer, before_delete=security_cb)
protocol = self._get_protocol_for_request()
response = protocol.delete(self.request, feature_id)
if isinstance(response, HTTPException):
raise response
Expand Down
Loading