From 8524bfb54db977c74c7da31765fa36e49753eb3e Mon Sep 17 00:00:00 2001 From: Bryan Forbes Date: Mon, 28 Mar 2022 15:24:03 -0500 Subject: [PATCH] Fix caching of UnionType instead of resolved typing.Union --- discord/utils.py | 4 +-- tests/test_utils.py | 59 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/discord/utils.py b/discord/utils.py index 3b71b4051..aaa18b780 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -1032,9 +1032,9 @@ def evaluate_annotation( if implicit_str and isinstance(tp, str): if tp in cache: return cache[tp] - evaluated = eval(tp, globals, locals) + evaluated = evaluate_annotation(eval(tp, globals, locals), globals, locals, cache) cache[tp] = evaluated - return evaluate_annotation(evaluated, globals, locals, cache) + return evaluated if hasattr(tp, '__args__'): implicit_str = True diff --git a/tests/test_utils.py b/tests/test_utils.py index b1977da98..5f8060d52 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -164,7 +164,9 @@ def test_resolve_template(url, code): assert utils.resolve_template(url) == code -@pytest.mark.parametrize('mention', ['@everyone', '@here', '<@80088516616269824>', '<@!80088516616269824>', '<@&381978264698224660>']) +@pytest.mark.parametrize( + 'mention', ['@everyone', '@here', '<@80088516616269824>', '<@!80088516616269824>', '<@&381978264698224660>'] +) def test_escape_mentions(mention): assert mention not in utils.escape_mentions(mention) assert mention not in utils.escape_mentions(f"one {mention} two") @@ -198,6 +200,37 @@ def test_resolve_annotation(annotation, resolved): assert resolved == utils.resolve_annotation(annotation, globals(), locals(), None) +@pytest.mark.parametrize( + ('annotation', 'resolved', 'check_cache'), + [ + (datetime.datetime, datetime.datetime, False), + ('datetime.datetime', datetime.datetime, True), + ( + 'typing.Union[typing.Literal["a"], typing.Literal["b"]]', + typing.Union[typing.Literal["a"], typing.Literal["b"]], + True, + ), + ('typing.Union[typing.Union[int, str], typing.Union[bool, dict]]', typing.Union[int, str, bool, dict], True), + ], +) +def test_resolve_annotation_with_cache(annotation, resolved, check_cache): + cache = {} + + assert resolved == utils.resolve_annotation(annotation, globals(), locals(), cache) + + if check_cache: + assert len(cache) == 1 + + cached_item = cache[annotation] + + latest = utils.resolve_annotation(annotation, globals(), locals(), cache) + + assert latest is cached_item + assert typing.get_origin(latest) is typing.get_origin(resolved) + else: + assert len(cache) == 0 + + def test_resolve_annotation_optional_normalisation(): value = utils.resolve_annotation('typing.Union[None, int]', globals(), locals(), None) assert value.__args__ == (int, type(None)) @@ -216,6 +249,30 @@ def test_resolve_annotation_310(annotation, resolved): assert resolved == utils.resolve_annotation(annotation, globals(), locals(), None) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="3.10 union syntax") +@pytest.mark.parametrize( + ('annotation', 'resolved'), + [ + ('int | None', typing.Optional[int]), + ('str | int', typing.Union[str, int]), + ('str | int | None', typing.Optional[typing.Union[str, int]]), + ], +) +def test_resolve_annotation_with_cache_310(annotation, resolved): + cache = {} + + assert resolved == utils.resolve_annotation(annotation, globals(), locals(), cache) + assert typing.get_origin(resolved) is typing.Union + + assert len(cache) == 1 + + cached_item = cache[annotation] + + latest = utils.resolve_annotation(annotation, globals(), locals(), cache) + assert latest is cached_item + assert typing.get_origin(latest) is typing.get_origin(resolved) + + # is_inside_class tests