diff --git a/tests/test_client.py b/tests/test_client.py index 4f513ac..a2ce5a8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -176,6 +176,7 @@ def test_get_object(httpserver): def test_patch_object(httpserver): obj = Object("dummy_type", "dummy_id", {"foo": 1, "bar": 2}) + obj._context_attributes = {"a": "b"} # pylint: disable=protected-access obj.foo = 2 httpserver.expect_request( @@ -191,6 +192,7 @@ def test_patch_object(httpserver): "attributes": { "foo": 2, }, + "context_attributes": {"a": "b"} } } ) diff --git a/vt/client.py b/vt/client.py index fee750e..4fc7c5f 100644 --- a/vt/client.py +++ b/vt/client.py @@ -15,6 +15,7 @@ import asyncio import base64 +import functools import io import json import typing @@ -25,6 +26,7 @@ from .feed import Feed, FeedType from .iterator import Iterator from .object import Object +from .object import UserDictJsonEncoder from .utils import make_sync from .version import __version__ @@ -267,6 +269,7 @@ def _get_session(self) -> aiohttp.ClientSession: headers=headers, trust_env=self._trust_env, timeout=aiohttp.ClientTimeout(total=self._timeout), + json_serialize=functools.partial(json.dumps, cls=UserDictJsonEncoder) ) return self._session diff --git a/vt/object.py b/vt/object.py index fd3a668..58153e1 100644 --- a/vt/object.py +++ b/vt/object.py @@ -16,6 +16,7 @@ import collections import datetime import functools +import json import re import typing @@ -51,6 +52,15 @@ def __delitem__(self, item): super().__delitem__(item) +class UserDictJsonEncoder(json.JSONEncoder): + """Custom json encoder for UserDict objects.""" + + def default(self, o): + if isinstance(o, collections.UserDict): + return o.data + return super().default(o) + + class Object: """This class encapsulates any type of object in the VirusTotal API.