-
Notifications
You must be signed in to change notification settings - Fork 321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Convert some tests to pytest #1693
base: main
Are you sure you want to change the base?
Changes from 5 commits
a45f986
890ac4b
a4ccca2
485f8f7
e71df22
29c27d4
f303965
be67fa4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,28 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
||
import unittest | ||
import pathlib | ||
|
||
import pytest | ||
from threatexchange.signal_type.md5 import VideoMD5Signal | ||
|
||
# Define the test file path | ||
TEST_FILE = pathlib.Path(__file__).parent.parent.parent.parent.joinpath( | ||
"data", "sample-b.jpg" | ||
) | ||
|
||
@pytest.fixture
def file_content():
    """
    Provide the raw bytes of the sample file at TEST_FILE.

    The file is read in binary mode and its contents are yielded from inside
    the ``with`` block, so the handle is closed only once the generator is
    resumed after the requesting test has finished.
    """
    with TEST_FILE.open("rb") as sample:
        payload = sample.read()
        yield payload
|
||
class VideoMD5SignalTestCase(unittest.TestCase): | ||
def setUp(self): | ||
self.a_file = open(TEST_FILE, "rb") | ||
|
||
def tearDown(self): | ||
self.a_file.close() | ||
|
||
def test_can_hash_simple_files(self): | ||
assert "d35c785545392755e7e4164457657269" == VideoMD5Signal.hash_from_bytes( | ||
self.a_file.read() | ||
), "MD5 hash does not match" | ||
def test_can_hash_simple_files(file_content):
    """
    Check that VideoMD5Signal.hash_from_bytes returns the known MD5 digest
    for the bytes supplied by the ``file_content`` fixture.
    """
    actual = VideoMD5Signal.hash_from_bytes(file_content)
    assert actual == "d35c785545392755e7e4164457657269", "MD5 hash does not match"
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,6 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
||
import unittest | ||
import pickle | ||
import typing as t | ||
import pytest | ||
import functools | ||
|
||
from threatexchange.signal_type.index import ( | ||
|
@@ -13,145 +11,79 @@ | |
test_entries = [ | ||
( | ||
"0000000000000000000000000000000000000000000000000000000000000000", | ||
dict( | ||
{ | ||
"hash_type": "pdq", | ||
"system_id": 9, | ||
} | ||
), | ||
{"hash_type": "pdq", "system_id": 9}, | ||
), | ||
( | ||
"000000000000000000000000000000000000000000000000000000000000ffff", | ||
dict( | ||
{ | ||
"hash_type": "pdq", | ||
"system_id": 8, | ||
} | ||
), | ||
{"hash_type": "pdq", "system_id": 8}, | ||
), | ||
( | ||
"0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f", | ||
dict( | ||
{ | ||
"hash_type": "pdq", | ||
"system_id": 7, | ||
} | ||
), | ||
{"hash_type": "pdq", "system_id": 7}, | ||
), | ||
( | ||
"f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0f0", | ||
dict( | ||
{ | ||
"hash_type": "pdq", | ||
"system_id": 6, | ||
} | ||
), | ||
{"hash_type": "pdq", "system_id": 6}, | ||
), | ||
( | ||
"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", | ||
dict( | ||
{ | ||
"hash_type": "pdq", | ||
"system_id": 5, | ||
} | ||
), | ||
{"hash_type": "pdq", "system_id": 5}, | ||
), | ||
] | ||
|
||
|
||
class TestPDQIndex(unittest.TestCase): | ||
def setUp(self): | ||
self.index = PDQIndex.build(test_entries) | ||
|
||
def assertEqualPDQIndexMatchResults( | ||
self, result: t.List[PDQIndexMatch], expected: t.List[PDQIndexMatch] | ||
): | ||
self.assertEqual( | ||
len(result), len(expected), "search results not of expected length" | ||
) | ||
|
||
accum_type = t.Dict[int, t.Set[int]] | ||
|
||
# Between python 3.8.6 and 3.8.11, something caused the order of results | ||
# from the index to change. This was noticed for items which had the | ||
# same distance. To allow for this, we convert result and expected | ||
# arrays from | ||
# [PDQIndexMatch, PDQIndexMatch] to { distance: <set of PDQIndexMatch.metadata hash> } | ||
# This allows you to compare [PDQIndexMatch A, PDQIndexMatch B] with | ||
# [PDQIndexMatch B, PDQIndexMatch A] as long as A.distance == B.distance. | ||
def quality_indexed_dict_reducer( | ||
acc: accum_type, item: PDQIndexMatch | ||
) -> accum_type: | ||
acc[item.similarity_info.distance] = acc.get( | ||
item.similarity_info.distance, set() | ||
) | ||
# Instead of storing the unhashable item.metadata dict, store its | ||
# hash so we can compare using self.assertSetEqual | ||
acc[item.similarity_info.distance].add(hash(frozenset(item.metadata))) | ||
return acc | ||
|
||
# Convert results to distance -> set of metadata map | ||
distance_to_result_items_map: accum_type = functools.reduce( | ||
quality_indexed_dict_reducer, result, {} | ||
) | ||
|
||
# Convert expected to distance -> set of metadata map | ||
distance_to_expected_items_map: accum_type = functools.reduce( | ||
quality_indexed_dict_reducer, expected, {} | ||
) | ||
|
||
assert len(distance_to_expected_items_map) == len( | ||
distance_to_result_items_map | ||
), "Unequal number of items in expected and results." | ||
|
||
for distance, result_items in distance_to_result_items_map.items(): | ||
assert ( | ||
distance in distance_to_expected_items_map | ||
), f"Unexpected distance {distance} found" | ||
self.assertSetEqual(result_items, distance_to_expected_items_map[distance]) | ||
|
||
def test_search_index_for_matches(self): | ||
entry_hash = test_entries[1][0] | ||
result = self.index.query(entry_hash) | ||
self.assertEqualPDQIndexMatchResults( | ||
result, | ||
[ | ||
PDQIndexMatch( | ||
SignalSimilarityInfoWithIntDistance(0), test_entries[1][1] | ||
), | ||
PDQIndexMatch( | ||
SignalSimilarityInfoWithIntDistance(16), test_entries[0][1] | ||
), | ||
], | ||
) | ||
|
||
def test_search_index_with_no_match(self): | ||
query_hash = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" | ||
result = self.index.query(query_hash) | ||
self.assertEqualPDQIndexMatchResults(result, []) | ||
|
||
def test_supports_pickling(self): | ||
pickled_data = pickle.dumps(self.index) | ||
assert pickled_data != None, "index does not support pickling to a data stream" | ||
|
||
reconstructed_index = pickle.loads(pickled_data) | ||
assert ( | ||
reconstructed_index != None | ||
), "index does not support unpickling from data stream" | ||
assert ( | ||
reconstructed_index.index.faiss_index != self.index.index.faiss_index | ||
), "unpickling should create it's own faiss index in memory" | ||
|
||
query = test_entries[0][0] | ||
result = reconstructed_index.query(query) | ||
self.assertEqualPDQIndexMatchResults( | ||
result, | ||
[ | ||
PDQIndexMatch( | ||
SignalSimilarityInfoWithIntDistance(0), test_entries[1][1] | ||
), | ||
PDQIndexMatch( | ||
SignalSimilarityInfoWithIntDistance(16), test_entries[0][1] | ||
), | ||
], | ||
) | ||
@pytest.fixture | ||
def index(): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ignorable: While this might be a misuse of the feature, fixtures are basically a 1:1 mapping for setUp, so I think this is a faithful translation |
||
return PDQIndex.build(test_entries) | ||
|
||
def assert_equal_pdq_index_match_results(
    result: t.List[PDQIndexMatch], expected: t.List[PDQIndexMatch]
):
    """
    Assert that two lists of PDQIndexMatch contain the same matches.

    Matches are grouped by ``similarity_info.distance`` and each group is
    compared as a set, so matches that share a distance may appear in either
    order in ``result`` and ``expected``.

    Raises:
        AssertionError: if the lengths differ, a distance appears on only one
            side, or the metadata found at a distance differs.
    """
    assert len(result) == len(expected), "search results not of expected length"

    def quality_indexed_dict_reducer(
        acc: t.Dict[int, t.Set[t.FrozenSet[t.Tuple[t.Any, t.Any]]]],
        item: PDQIndexMatch,
    ) -> t.Dict[int, t.Set[t.FrozenSet[t.Tuple[t.Any, t.Any]]]]:
        # Freeze the metadata's (key, value) pairs rather than the dict itself.
        # Iterating a dict yields only its keys, so frozenset(item.metadata)
        # made every entry with the same field names compare equal regardless
        # of which system_id it carried.
        # The frozenset is stored directly (not its hash) so a failing
        # comparison shows the differing metadata.
        acc.setdefault(item.similarity_info.distance, set()).add(
            frozenset(item.metadata.items())
        )
        return acc

    distance_to_result_items_map = functools.reduce(
        quality_indexed_dict_reducer, result, {}
    )
    distance_to_expected_items_map = functools.reduce(
        quality_indexed_dict_reducer, expected, {}
    )

    assert len(distance_to_expected_items_map) == len(
        distance_to_result_items_map
    ), "Unequal number of items"

    for distance, result_items in distance_to_result_items_map.items():
        assert (
            distance in distance_to_expected_items_map
        ), f"Unexpected distance {distance} found"
        assert result_items == distance_to_expected_items_map[distance]


def test_search_index_for_matches(index):
    """Querying with entry 1's hash returns entry 1 (distance 0) and entry 0 (distance 16)."""
    entry_hash = test_entries[1][0]
    result = index.query(entry_hash)
    assert_equal_pdq_index_match_results(
        result,
        [
            PDQIndexMatch(SignalSimilarityInfoWithIntDistance(0), test_entries[1][1]),
            PDQIndexMatch(SignalSimilarityInfoWithIntDistance(16), test_entries[0][1]),
        ],
    )


def test_search_index_with_no_match(index):
    """A query hash that the index reports no matches for yields an empty result list."""
    query_hash = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
    result = index.query(query_hash)
    assert_equal_pdq_index_match_results(result, [])


def test_supports_pickling(index):
    """The index survives a pickle round trip and the copy answers queries."""
    pickled_data = pickle.dumps(index)
    assert pickled_data is not None, "index does not support pickling to a data stream"

    reconstructed_index = pickle.loads(pickled_data)
    assert (
        reconstructed_index is not None
    ), "index does not support unpickling from data stream"
    assert (
        reconstructed_index.index.faiss_index != index.index.faiss_index
    ), "unpickling should create its own faiss index in memory"

    # The query is entry 0's own hash, so entry 0 is the exact match and
    # entry 1 is the neighbour at distance 16 (the mirror image of
    # test_search_index_for_matches). The previous expectation had the two
    # metadata dicts swapped, which went unnoticed while the comparison helper
    # ignored metadata values.
    query = test_entries[0][0]
    result = reconstructed_index.query(query)
    assert_equal_pdq_index_match_results(
        result,
        [
            PDQIndexMatch(SignalSimilarityInfoWithIntDistance(0), test_entries[0][1]),
            PDQIndexMatch(SignalSimilarityInfoWithIntDistance(16), test_entries[1][1]),
        ],
    )
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
||
import pytest | ||
from threatexchange.signal_type.raw_text import RawTextSignal | ||
from threatexchange.signal_type.tests.signal_type_test_helper import MatchesStrAutoTest | ||
|
||
from threatexchange.signal_type.raw_text import RawTextSignal | ||
|
@@ -8,16 +10,19 @@ | |
class TestRawTextSignal(MatchesStrAutoTest): | ||
TYPE = RawTextSignal | ||
|
||
def get_validate_hash_cases(self): | ||
@pytest.fixture | ||
def validate_hash_cases(self): | ||
Comment on lines
+11
to
+12
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. blocking: This doesn't look like it needs to be a fixture - it looks like a better fit for pytest.mark.parametrize. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Any response on this unaddressed blocking comment? |
||
return [ | ||
("a", "a"), | ||
("a ", "a"), | ||
] | ||
|
||
def get_compare_hash_cases(self): | ||
@pytest.fixture | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. blocking: This definitely doesn't need a fixture! You can just use a list! There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Any response on this unaddressed blocking comment? |
||
def compare_hash_cases(self): | ||
return [] | ||
|
||
def get_matches_str_cases(self): | ||
@pytest.fixture | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. blocking: Ditto that this looks like a better match for pytest.mark.parametrize. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Any response on this unaddressed blocking comment? |
||
def matches_str_cases(self): | ||
return [ | ||
("", ""), | ||
("a", "a"), | ||
|
@@ -30,3 +35,18 @@ def get_matches_str_cases(self): | |
("a" * 19, "a" * 18 + "b", False, 1), | ||
("a" * 20, "a" * 19 + "b", True, 1), | ||
] | ||
|
||
def test_validate_hash(self, validate_hash_cases): | ||
for case in validate_hash_cases: | ||
input_val, expected_hash = case | ||
assert self.TYPE.validate_hash(input_val) == expected_hash | ||
|
||
def test_compare_hash(self, compare_hash_cases): | ||
for case in compare_hash_cases: | ||
input_val, expected_result = case | ||
assert self.TYPE.compare_hash(input_val) == expected_result | ||
|
||
def test_matches_str(self, matches_str_cases): | ||
for case in matches_str_cases: | ||
input_str, match_str, expected_match, threshold = case | ||
assert self.TYPE.matches_str(input_str, match_str, threshold) == expected_match |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure that you need a fixture for this - the test can just open the file itself.
Fixtures are helpful when you are sharing setup between tests, and here there is only one test.