import gc
import time
import unittest
import weakref
from ast import Or
from functools import partial
from threading import Thread
from unittest import TestCase
try:
import _testcapi
except ImportError:
_testcapi = None
from test.support import threading_helper
@threading_helper.requires_working_threading()
class TestDict(TestCase):
def test_racing_creation_shared_keys(self):
"""Verify that creating dictionaries is thread safe when we
have a type with shared keys"""
class C(int):
pass
self.racing_creation(C)
def test_racing_creation_no_shared_keys(self):
"""Verify that creating dictionaries is thread safe when we
have a type with an ordinary dict"""
self.racing_creation(Or)
def test_racing_creation_inline_values_invalid(self):
"""Verify that re-creating a dict after we have invalid inline values
is thread safe"""
class C:
pass
def make_obj():
a = C()
# Make object, make inline values invalid, and then delete dict
a.__dict__ = {}
del a.__dict__
return a
self.racing_creation(make_obj)
def test_racing_creation_nonmanaged_dict(self):
"""Verify that explicit creation of an unmanaged dict is thread safe
outside of the normal attribute setting code path"""
def make_obj():
def f(): pass
return f
def set(func, name, val):
# Force creation of the dict via PyObject_GenericGetDict
func.__dict__[name] = val
self.racing_creation(make_obj, set)
def racing_creation(self, cls, set=setattr):
objects = []
processed = []
OBJECT_COUNT = 100
THREAD_COUNT = 10
CUR = 0
for i in range(OBJECT_COUNT):
objects.append(cls())
def writer_func(name):
last = -1
while True:
if CUR == last:
continue
elif CUR == OBJECT_COUNT:
break
obj = objects[CUR]
set(obj, name, name)
last = CUR
processed.append(name)
writers = []
for x in range(THREAD_COUNT):
writer = Thread(target=partial(writer_func, f"a{x:02}"))
writers.append(writer)
writer.start()
for i in range(OBJECT_COUNT):
CUR = i
while len(processed) != THREAD_COUNT:
time.sleep(0.001)
processed.clear()
CUR = OBJECT_COUNT
for writer in writers:
writer.join()
for obj_idx, obj in enumerate(objects):
assert (
len(obj.__dict__) == THREAD_COUNT
), f"{len(obj.__dict__)} {obj.__dict__!r} {obj_idx}"
for i in range(THREAD_COUNT):
assert f"a{i:02}" in obj.__dict__, f"a{i:02} missing at {obj_idx}"
def test_racing_set_dict(self):
"""Races assigning to __dict__ should be thread safe"""
def f(): pass
l = []
THREAD_COUNT = 10
class MyDict(dict): pass
def writer_func(l):
for i in range(1000):
d = MyDict()
l.append(weakref.ref(d))
f.__dict__ = d
lists = []
writers = []
for x in range(THREAD_COUNT):
thread_list = []
lists.append(thread_list)
writer = Thread(target=partial(writer_func, thread_list))
writers.append(writer)
for writer in writers:
writer.start()
for writer in writers:
writer.join()
f.__dict__ = {}
gc.collect()
for thread_list in lists:
for ref in thread_list:
self.assertIsNone(ref())
@unittest.skipIf(_testcapi is None, 'need _testcapi module')
def test_dict_version(self):
dict_version = _testcapi.dict_version
THREAD_COUNT = 10
DICT_COUNT = 10000
lists = []
writers = []
def writer_func(thread_list):
for i in range(DICT_COUNT):
thread_list.append(dict_version({}))
for x in range(THREAD_COUNT):
thread_list = []
lists.append(thread_list)
writer = Thread(target=partial(writer_func, thread_list))
writers.append(writer)
for writer in writers:
writer.start()
for writer in writers:
writer.join()
total_len = 0
values = set()
for thread_list in lists:
for v in thread_list:
if v in values:
print('dup', v, (v/4096)%256)
values.add(v)
total_len += len(thread_list)
versions = set(dict_version for thread_list in lists for dict_version in thread_list)
self.assertEqual(len(versions), THREAD_COUNT*DICT_COUNT)
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()