Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add similarity to raw_lookup endpoint #1639

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from werkzeug.exceptions import HTTPException

from OpenMediaMatch.blueprints.hashing import hash_media
from OpenMediaMatch.blueprints.matching import lookup_signal
from OpenMediaMatch.blueprints.matching import (
lookup_signal,
lookup_signal_with_distance,
)
from OpenMediaMatch.utils.flask_utils import api_error_handler

from OpenMediaMatch.utils import dev_utils
Expand Down Expand Up @@ -50,6 +53,11 @@ def query_media():
return signal_type_to_signal_map
abort(500, "Something went wrong while hashing the provided media.")

include_distance = bool(request.args.get("include_distance", False)) == True
lookup_signal_func = (
lookup_signal_with_distance if include_distance else lookup_signal
)

# Check if signal_type is an option in the map of hashes
signal_type_name = request.args.get("signal_type")
if signal_type_name is not None:
Expand All @@ -59,14 +67,14 @@ def query_media():
f"Requested signal type '{signal_type_name}' is not supported for the provided "
"media.",
)
return lookup_signal(
return lookup_signal_func(
signal_type_to_signal_map[signal_type_name], signal_type_name
)
return {
"matches": list(
itertools.chain(
*map(
lambda x: lookup_signal(x[1], x[0])["matches"],
lambda x: lookup_signal_func(x[1], x[0])["matches"],
signal_type_to_signal_map.items(),
),
)
Expand Down
41 changes: 38 additions & 3 deletions hasher-matcher-actioner/src/OpenMediaMatch/blueprints/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dataclasses import dataclass
import datetime
import random
import sys
import typing as t
import time

Expand All @@ -16,6 +17,7 @@

from threatexchange.signal_type.signal_base import SignalType
from threatexchange.signal_type.index import SignalTypeIndex
from threatexchange.signal_type.index import IndexMatch

from OpenMediaMatch.background_tasks.development import get_apscheduler
from OpenMediaMatch.storage import interface
Expand Down Expand Up @@ -95,14 +97,19 @@ def raw_lookup():
* Signal value (the hash)
* Optional list of banks to restrict search to
Output:
* List of matching content items
* List of matches, each with its content_id and, if requested, its distance value
"""
signal = require_request_param("signal")
signal_type_name = require_request_param("signal_type")
return lookup_signal(signal, signal_type_name)
include_distance = bool(request.args.get("include_distance", False)) == True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works for me as a solution, though I think this bool parsing differs from how we parse other flags.

blocking: Can you include it in the list of inputs in the docstring?

lookup_signal_func = (
lookup_signal_with_distance if include_distance else lookup_signal
)

return lookup_signal_func(signal, signal_type_name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Both of these functions return the same shape, so it might be slightly easier to have them return just the list and then do:

return {
  "matches": lookup_signal_func()
}


def lookup_signal(signal: str, signal_type_name: str) -> dict[str, list[int]]:

def query_index(signal: str, signal_type_name: str) -> IndexMatch:
storage = get_storage()
signal_type = _validate_and_transform_signal_type(signal_type_name, storage)

Expand All @@ -118,9 +125,29 @@ def lookup_signal(signal: str, signal_type_name: str) -> dict[str, list[int]]:
current_app.logger.debug("[lookup_signal] querying index")
results = index.query(signal)
current_app.logger.debug("[lookup_signal] query complete")
return results


def lookup_signal(signal: str, signal_type_name: str) -> dict[str, list[int]]:
    """
    Query the index for `signal` and return only the matched metadata.

    The lookup itself (signal type validation, index retrieval, querying)
    is delegated to `query_index`; this function only reshapes the result.

    Returns a dict with a single "matches" key holding the `metadata`
    attribute of every match, in the order the index returned them.
    """
    matched_metadata = [
        match.metadata for match in query_index(signal, signal_type_name)
    ]
    return {"matches": matched_metadata}


def lookup_signal_with_distance(
    signal: str, signal_type_name: str
) -> dict[str, list[dict[str, t.Any]]]:
    """
    Query the index for `signal` and return each match with its distance.

    The lookup itself is delegated to `query_index`; this function only
    reshapes the result.

    Returns a dict with a single "matches" key holding a list with one
    entry per match, in the order the index returned them. Each entry has:
      * "content_id": the match's `metadata`
      * "distance": the match's `similarity_info.pretty_str()`
    """
    results = query_index(signal, signal_type_name)
    return {
        "matches": [
            {
                "content_id": m.metadata,
                "distance": m.similarity_info.pretty_str(),
            }
            for m in results
        ]
    }


def _validate_and_transform_signal_type(
signal_type_name: str, storage: interface.ISignalTypeConfigStore
) -> type[SignalType]:
Expand Down Expand Up @@ -300,9 +327,17 @@ def index_cache_is_stale() -> bool:

def _get_index(signal_type: t.Type[SignalType]) -> SignalTypeIndex[int] | None:
entry = _get_index_cache().get(signal_type.get_name())

if entry is not None and is_in_pytest():
entry.reload_if_needed(get_storage())
Comment on lines +331 to +332
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

blocking q: Hmm, I'm very suspicious of any branching in the code that specifically changes behavior in tests.

Is there another solution that you can use in the test itself that can achieve the same results? (e.g. disabling background loading in the fixture for the test? Manually calling a load_all_indices() function?)


if entry is None:
current_app.logger.debug("[lookup_signal] no cache, loading index")
return get_storage().get_signal_type_index(signal_type)
if entry.is_ready:
return entry.index
return None


def is_in_pytest():
    """
    Return True if a module named "pytest" is registered in `sys.modules`.

    NOTE: this only shows that pytest has been imported somewhere in the
    current process, not that a test is actually executing.
    """
    loaded_modules = sys.modules
    return "pytest" in loaded_modules
Comment on lines +342 to +343
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Danger! I don't think this works consistently, and someone could accidentally include a test library in the build that silently changes the behavior of HMA. This is why I am generally suspicious of this kind of branching logic!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ha agreed, glad we found the root cause!

Loading