Convert some tests to pytest #1693

Open · wants to merge 8 commits into main · showing changes from 5 commits
@@ -1,7 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import unittest

import pytest
from threatexchange.signal_type import (
md5,
raw_text,
@@ -13,38 +12,39 @@
from threatexchange.signal_type.signal_base import TextHasher


class SignalTypeHashTest(unittest.TestCase):
# List of signal types to test
SIGNAL_TYPES_TO_TEST = [
md5.VideoMD5Signal,
signal.PdqSignal,
raw_text.RawTextSignal,
trend_query.TrendQuerySignal,
url_md5.UrlMD5Signal,
url.URLSignal,
]


@pytest.mark.parametrize("signal_type", SIGNAL_TYPES_TO_TEST)
def test_signal_names_unique(signal_type):
"""
Verify that each signal type has a unique name.
"""
name = signal_type.get_name()
seen = set() # Using a set to automatically manage unique entries
assert name not in seen, f"Two signal types share the same name: {signal_type!r} and {seen}"
seen.add(name)


@pytest.mark.parametrize("signal_type", SIGNAL_TYPES_TO_TEST)
def test_signal_types_have_content(signal_type):
"""
Sanity check for signal type hashing methods.
Ensure that each signal type has associated content types.
"""
assert signal_type.get_content_types(), f"{signal_type!r} has no content types"


# TODO - maybe make a metaclass for this to automatically detect?
SIGNAL_TYPES_TO_TEST = [
md5.VideoMD5Signal,
signal.PdqSignal,
raw_text.RawTextSignal,
trend_query.TrendQuerySignal,
url_md5.UrlMD5Signal,
url.URLSignal,
]

def test_signal_names_unique(self):
seen = {}
for s in self.SIGNAL_TYPES_TO_TEST:
name = s.get_name()
assert (
name not in seen
), f"Two signal types share the same name: {s!r} and {seen[name]!r}"

def test_signal_types_have_content(self):
for s in self.SIGNAL_TYPES_TO_TEST:
assert s.get_content_types(), "{s!r} has no content types"

def test_str_hashers_have_impl(self):
text_hashers = [
s for s in self.SIGNAL_TYPES_TO_TEST if isinstance(s, TextHasher)
]
for s in text_hashers:
assert s.hash_from_str(
"test string"
), "{s!r} produced no output from hasher"
@pytest.mark.parametrize("signal_type", [s for s in SIGNAL_TYPES_TO_TEST if isinstance(s, TextHasher)])
def test_str_hashers_have_impl(signal_type):
"""
Check that each TextHasher has an implementation that produces output.
"""
assert signal_type.hash_from_str("test string"), f"{signal_type!r} produced no output from hasher"
@@ -1,23 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import unittest
import pathlib

import pytest
from threatexchange.signal_type.md5 import VideoMD5Signal

# Define the test file path
TEST_FILE = pathlib.Path(__file__).parent.parent.parent.parent.joinpath(
"data", "sample-b.jpg"
)

@pytest.fixture
def file_content():
"""
Fixture to open and yield file content for testing,
then close the file after the test.
"""
with open(TEST_FILE, "rb") as f:
yield f.read()
Contributor comment on lines +13 to +20:

I am not sure that you need a fixture for this - the test can just open the file itself.

Fixtures are helpful when you are sharing setup between tests, and here there is only one test.
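
A minimal sketch of the reviewer's suggestion, reusing TEST_FILE, VideoMD5Signal, and the expected hash from this diff (illustration only, not part of the PR as written):

def test_can_hash_simple_files():
    # Open the file inline instead of going through a fixture.
    with open(TEST_FILE, "rb") as f:
        file_content = f.read()
    expected_hash = "d35c785545392755e7e4164457657269"
    assert VideoMD5Signal.hash_from_bytes(file_content) == expected_hash, "MD5 hash does not match"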


class VideoMD5SignalTestCase(unittest.TestCase):
def setUp(self):
self.a_file = open(TEST_FILE, "rb")

def tearDown(self):
self.a_file.close()

def test_can_hash_simple_files(self):
assert "d35c785545392755e7e4164457657269" == VideoMD5Signal.hash_from_bytes(
self.a_file.read()
), "MD5 hash does not match"
def test_can_hash_simple_files(file_content):
"""
Test that the VideoMD5Signal produces the expected hash.
"""
expected_hash = "d35c785545392755e7e4164457657269"
computed_hash = VideoMD5Signal.hash_from_bytes(file_content)
assert computed_hash == expected_hash, "MD5 hash does not match"
@@ -1,8 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import unittest
import pickle
import typing as t
import pytest
import functools

from threatexchange.signal_type.index import (
@@ -13,145 +11,79 @@
test_entries = [
(
"0000000000000000000000000000000000000000000000000000000000000000",
dict(
{
"hash_type": "pdq",
"system_id": 9,
}
),
{"hash_type": "pdq", "system_id": 9},
),
(
"000000000000000000000000000000000000000000000000000000000000ffff",
dict(
{
"hash_type": "pdq",
"system_id": 8,
}
),
{"hash_type": "pdq", "system_id": 8},
),
(
"0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f",
dict(
{
"hash_type": "pdq",
"system_id": 7,
}
),
{"hash_type": "pdq", "system_id": 7},
),
(
"f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0",
dict(
{
"hash_type": "pdq",
"system_id": 6,
}
),
{"hash_type": "pdq", "system_id": 6},
),
(
"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
dict(
{
"hash_type": "pdq",
"system_id": 5,
}
),
{"hash_type": "pdq", "system_id": 5},
),
]


class TestPDQIndex(unittest.TestCase):
def setUp(self):
self.index = PDQIndex.build(test_entries)

def assertEqualPDQIndexMatchResults(
self, result: t.List[PDQIndexMatch], expected: t.List[PDQIndexMatch]
):
self.assertEqual(
len(result), len(expected), "search results not of expected length"
)

accum_type = t.Dict[int, t.Set[int]]

# Between python 3.8.6 and 3.8.11, something caused the order of results
# from the index to change. This was noticed for items which had the
# same distance. To allow for this, we convert result and expected
# arrays from
# [PDQIndexMatch, PDQIndexMatch] to { distance: <set of PDQIndexMatch.metadata hash> }
# This allows you to compare [PDQIndexMatch A, PDQIndexMatch B] with
# [PDQIndexMatch B, PDQIndexMatch A] as long as A.distance == B.distance.
def quality_indexed_dict_reducer(
acc: accum_type, item: PDQIndexMatch
) -> accum_type:
acc[item.similarity_info.distance] = acc.get(
item.similarity_info.distance, set()
)
# Instead of storing the unhashable item.metadata dict, store its
# hash so we can compare using self.assertSetEqual
acc[item.similarity_info.distance].add(hash(frozenset(item.metadata)))
return acc

# Convert results to distance -> set of metadata map
distance_to_result_items_map: accum_type = functools.reduce(
quality_indexed_dict_reducer, result, {}
)

# Convert expected to distance -> set of metadata map
distance_to_expected_items_map: accum_type = functools.reduce(
quality_indexed_dict_reducer, expected, {}
)

assert len(distance_to_expected_items_map) == len(
distance_to_result_items_map
), "Unequal number of items in expected and results."

for distance, result_items in distance_to_result_items_map.items():
assert (
distance in distance_to_expected_items_map
), f"Unexpected distance {distance} found"
self.assertSetEqual(result_items, distance_to_expected_items_map[distance])

def test_search_index_for_matches(self):
entry_hash = test_entries[1][0]
result = self.index.query(entry_hash)
self.assertEqualPDQIndexMatchResults(
result,
[
PDQIndexMatch(
SignalSimilarityInfoWithIntDistance(0), test_entries[1][1]
),
PDQIndexMatch(
SignalSimilarityInfoWithIntDistance(16), test_entries[0][1]
),
],
)

def test_search_index_with_no_match(self):
query_hash = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
result = self.index.query(query_hash)
self.assertEqualPDQIndexMatchResults(result, [])

def test_supports_pickling(self):
pickled_data = pickle.dumps(self.index)
assert pickled_data != None, "index does not support pickling to a data stream"

reconstructed_index = pickle.loads(pickled_data)
assert (
reconstructed_index != None
), "index does not support unpickling from data stream"
assert (
reconstructed_index.index.faiss_index != self.index.index.faiss_index
), "unpickling should create it's own faiss index in memory"

query = test_entries[0][0]
result = reconstructed_index.query(query)
self.assertEqualPDQIndexMatchResults(
result,
[
PDQIndexMatch(
SignalSimilarityInfoWithIntDistance(0), test_entries[1][1]
),
PDQIndexMatch(
SignalSimilarityInfoWithIntDistance(16), test_entries[0][1]
),
],
)
@pytest.fixture
def index():
Contributor comment:

ignorable: While this might be a misuse of the feature, fixtures are basically a 1:1 mapping for setUp, so I think this is a faithful translation.

return PDQIndex.build(test_entries)
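
An illustrative aside on the reviewer's setUp/fixture point: a yield fixture also covers tearDown, so stateful unittest setup translates directly. The fixture name below is hypothetical and not part of this PR.

@pytest.fixture
def index_with_teardown():  # hypothetical name, for illustration only
    idx = PDQIndex.build(test_entries)  # setUp equivalent
    yield idx  # the test body runs here
    # tearDown equivalent would go after the yield (PDQIndex needs no cleanup)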

def assert_equal_pdq_index_match_results(
    result: t.List[PDQIndexMatch], expected: t.List[PDQIndexMatch]
):
    assert len(result) == len(expected), "search results not of expected length"

    # Group matches by distance (hashing the metadata dicts so they fit in a set),
    # which lets equal-distance results be compared without depending on their order.
    def quality_indexed_dict_reducer(
        acc: t.Dict[int, t.Set[int]], item: PDQIndexMatch
    ) -> t.Dict[int, t.Set[int]]:
        acc.setdefault(item.similarity_info.distance, set()).add(
            hash(frozenset(item.metadata))
        )
        return acc

    distance_to_result_items_map = functools.reduce(
        quality_indexed_dict_reducer, result, {}
    )
    distance_to_expected_items_map = functools.reduce(
        quality_indexed_dict_reducer, expected, {}
    )

    assert len(distance_to_expected_items_map) == len(
        distance_to_result_items_map
    ), "Unequal number of items"

    for distance, result_items in distance_to_result_items_map.items():
        assert (
            distance in distance_to_expected_items_map
        ), f"Unexpected distance {distance} found"
        assert result_items == distance_to_expected_items_map[distance]

def test_search_index_for_matches(index):
entry_hash = test_entries[1][0]
result = index.query(entry_hash)
assert_equal_pdq_index_match_results(
result,
[
PDQIndexMatch(SignalSimilarityInfoWithIntDistance(0), test_entries[1][1]),
PDQIndexMatch(SignalSimilarityInfoWithIntDistance(16), test_entries[0][1]),
],
)

def test_search_index_with_no_match(index):
query_hash = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
result = index.query(query_hash)
assert_equal_pdq_index_match_results(result, [])

def test_supports_pickling(index):
    pickled_data = pickle.dumps(index)
    assert pickled_data is not None, "index does not support pickling to a data stream"

    reconstructed_index = pickle.loads(pickled_data)
    assert (
        reconstructed_index is not None
    ), "index does not support unpickling from data stream"
    assert (
        reconstructed_index.index.faiss_index != index.index.faiss_index
    ), "unpickling should create its own faiss index in memory"

query = test_entries[0][0]
result = reconstructed_index.query(query)
assert_equal_pdq_index_match_results(
result,
[
PDQIndexMatch(SignalSimilarityInfoWithIntDistance(0), test_entries[1][1]),
PDQIndexMatch(SignalSimilarityInfoWithIntDistance(16), test_entries[0][1]),
],
)
@@ -1,5 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

import pytest
from threatexchange.signal_type.raw_text import RawTextSignal
from threatexchange.signal_type.tests.signal_type_test_helper import MatchesStrAutoTest

from threatexchange.signal_type.raw_text import RawTextSignal
@@ -8,16 +10,19 @@
class TestRawTextSignal(MatchesStrAutoTest):
TYPE = RawTextSignal

def get_validate_hash_cases(self):
@pytest.fixture
def validate_hash_cases(self):
Contributor comment on lines +11 to +12:

blocking: This doesn't look like it needs to be a fixture - it looks like a better fit for pytest.mark.parametrize.

Contributor follow-up:

Any response on this unaddressed blocking comment?

return [
("a", "a"),
("a ", "a"),
]
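
A minimal sketch of the parametrize suggestion from the review comment above, reusing the validate_hash call that the PR's test_validate_hash makes below (illustration only, not part of the PR):

@pytest.mark.parametrize(
    ("input_val", "expected_hash"),
    [
        ("a", "a"),
        ("a ", "a"),
    ],
)
def test_validate_hash(input_val, expected_hash):
    # One test per case, with no fixture indirection.
    assert RawTextSignal.validate_hash(input_val) == expected_hash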

def get_compare_hash_cases(self):
@pytest.fixture
Contributor comment:

blocking: This definitely doesn't need a fixture! You can just use a list!

Contributor follow-up:

Any response on this unaddressed blocking comment?

def compare_hash_cases(self):
return []
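
A minimal sketch of that suggestion: a plain constant replaces the fixture. The constant name is illustrative, and the call mirrors the shape used by the PR's test_compare_hash below.

COMPARE_HASH_CASES: list = []  # illustrative; currently no compare-hash cases for RawTextSignal

def test_compare_hash():
    for input_val, expected_result in COMPARE_HASH_CASES:
        assert RawTextSignal.compare_hash(input_val) == expected_result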

def get_matches_str_cases(self):
@pytest.fixture
Contributor comment:

blocking: Ditto that this looks like a better match for pytest.mark.parametrize (see the sketch after the case list below).

Contributor follow-up:

Any response on this unaddressed blocking comment?

def matches_str_cases(self):
return [
("", ""),
("a", "a"),
Expand All @@ -30,3 +35,18 @@ def get_matches_str_cases(self):
("a" * 19, "a" * 18 + "b", False, 1),
("a" * 20, "a" * 19 + "b", True, 1),
]
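
A sketch of that parametrize suggestion, using only the cases visible in this diff and the call shape from the PR's test_matches_str below; the True/0 defaults padded onto the two-element cases are an assumption, not taken from MatchesStrAutoTest.

MATCHES_STR_CASES = [
    ("", "", True, 0),  # assumed defaults for the short cases
    ("a", "a", True, 0),
    ("a" * 19, "a" * 18 + "b", False, 1),
    ("a" * 20, "a" * 19 + "b", True, 1),
]

@pytest.mark.parametrize(
    ("input_str", "match_str", "expected_match", "threshold"), MATCHES_STR_CASES
)
def test_matches_str(input_str, match_str, expected_match, threshold):
    assert RawTextSignal.matches_str(input_str, match_str, threshold) == expected_match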

def test_validate_hash(self, validate_hash_cases):
for case in validate_hash_cases:
input_val, expected_hash = case
assert self.TYPE.validate_hash(input_val) == expected_hash

def test_compare_hash(self, compare_hash_cases):
for case in compare_hash_cases:
input_val, expected_result = case
assert self.TYPE.compare_hash(input_val) == expected_result

    def test_matches_str(self, matches_str_cases):
        for case in matches_str_cases:
            # Short cases omit expected_match/threshold; the (True, 0) defaults here
            # are assumed so the two-element tuples unpack without error.
            input_str, match_str, expected_match, threshold = (*case, True, 0)[:4]
            assert self.TYPE.matches_str(input_str, match_str, threshold) == expected_match