-
Notifications
You must be signed in to change notification settings - Fork 321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add similarity to raw_lookup endpoint #1639
Changes from all commits
43a705b
f43280e
982ae6a
e9e823d
f27623d
89f5c7d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
from dataclasses import dataclass | ||
import datetime | ||
import random | ||
import sys | ||
import typing as t | ||
import time | ||
|
||
|
@@ -16,6 +17,7 @@ | |
|
||
from threatexchange.signal_type.signal_base import SignalType | ||
from threatexchange.signal_type.index import SignalTypeIndex | ||
from threatexchange.signal_type.index import IndexMatch | ||
|
||
from OpenMediaMatch.background_tasks.development import get_apscheduler | ||
from OpenMediaMatch.storage import interface | ||
|
@@ -95,14 +97,19 @@ def raw_lookup(): | |
* Signal value (the hash) | ||
* Optional list of banks to restrict search to | ||
Output: | ||
* List of matching content items | ||
* List of matching with content_id and, if included, distance values | ||
""" | ||
signal = require_request_param("signal") | ||
signal_type_name = require_request_param("signal_type") | ||
return lookup_signal(signal, signal_type_name) | ||
include_distance = bool(request.args.get("include_distance", False)) == True | ||
lookup_signal_func = ( | ||
lookup_signal_with_distance if include_distance else lookup_signal | ||
) | ||
|
||
return lookup_signal_func(signal, signal_type_name) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Both of these functions return the same shape, so it might be slightly easier to have this return the list and do
|
||
|
||
def lookup_signal(signal: str, signal_type_name: str) -> dict[str, list[int]]: | ||
|
||
def query_index(signal: str, signal_type_name: str) -> IndexMatch: | ||
storage = get_storage() | ||
signal_type = _validate_and_transform_signal_type(signal_type_name, storage) | ||
|
||
|
@@ -118,9 +125,29 @@ def lookup_signal(signal: str, signal_type_name: str) -> dict[str, list[int]]: | |
current_app.logger.debug("[lookup_signal] querying index") | ||
results = index.query(signal) | ||
current_app.logger.debug("[lookup_signal] query complete") | ||
return results | ||
|
||
|
||
def lookup_signal(signal: str, signal_type_name: str) -> dict[str, list[int]]:
    """
    Query the index for `signal` and return only the metadata of each match.

    The response is a dict with a single "matches" key whose value holds one
    metadata entry per index match, in the order the query returned them.
    """
    index_matches = query_index(signal, signal_type_name)
    content_ids = [match.metadata for match in index_matches]
    return {"matches": content_ids}
|
||
|
||
def lookup_signal_with_distance(
    signal: str, signal_type_name: str
) -> dict[str, list[dict[str, int | str]]]:
    """
    Query the index for `signal` and return each match with its distance.

    The response is a dict with a single "matches" key whose value is a list
    with one entry per index match, in the order the query returned them.
    Each entry is a dict with:
     * "content_id": the match's metadata
     * "distance": the match's similarity info rendered via pretty_str()
    """
    results = query_index(signal, signal_type_name)
    return {
        "matches": [
            {
                "content_id": m.metadata,
                "distance": m.similarity_info.pretty_str(),
            }
            for m in results
        ]
    }
|
||
|
||
def _validate_and_transform_signal_type( | ||
signal_type_name: str, storage: interface.ISignalTypeConfigStore | ||
) -> type[SignalType]: | ||
|
@@ -300,9 +327,17 @@ def index_cache_is_stale() -> bool: | |
|
||
def _get_index(signal_type: t.Type[SignalType]) -> SignalTypeIndex[int] | None: | ||
entry = _get_index_cache().get(signal_type.get_name()) | ||
|
||
if entry is not None and is_in_pytest(): | ||
entry.reload_if_needed(get_storage()) | ||
Comment on lines
+331
to
+332
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
blocking q: Hmm, I'm very suspicious of any branching in the code that specifically changes behavior in tests. Is there another solution that you can use in the test itself that can achieve the same results? (e.g. disabling background loading in the fixture for the test? Manually calling a load_all_indices() function?) |
||
|
||
if entry is None: | ||
current_app.logger.debug("[lookup_signal] no cache, loading index") | ||
return get_storage().get_signal_type_index(signal_type) | ||
if entry.is_ready: | ||
return entry.index | ||
return None | ||
|
||
|
||
def is_in_pytest():
    """
    Report whether the `pytest` module has been imported into this process.

    This only checks for a "pytest" entry in `sys.modules`, so it is true for
    any process that has imported pytest, not just during an active test run.
    """
    loaded_modules = sys.modules
    return "pytest" in loaded_modules
Comment on lines
+342
to
+343
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Danger! I don't think this works consistently, and someone could accidentally include a test library in the build that silently changes the behavior of HMA. This is why I am generally suspicious of this kind of branching logic!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ha agreed, glad we found the root cause! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This works for me as a solution, though I think this bool parsing is different than what we do for other flags.
blocking: Can you include it in the list of inputs in the docstring?