Skip to content

Commit

Permalink
fix(object): Object of type WhistleBlowerDict is not JSON serializable (
Browse files Browse the repository at this point in the history
  • Loading branch information
mgmacias95 authored Jan 30, 2024
1 parent 944e58c commit 011039a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def test_get_object(httpserver):

def test_patch_object(httpserver):
obj = Object("dummy_type", "dummy_id", {"foo": 1, "bar": 2})
obj._context_attributes = {"a": "b"} # pylint: disable=protected-access
obj.foo = 2

httpserver.expect_request(
Expand All @@ -191,6 +192,7 @@ def test_patch_object(httpserver):
"attributes": {
"foo": 2,
},
"context_attributes": {"a": "b"}
}
}
)
Expand Down
3 changes: 3 additions & 0 deletions vt/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import asyncio
import base64
import functools
import io
import json
import typing
Expand All @@ -25,6 +26,7 @@
from .feed import Feed, FeedType
from .iterator import Iterator
from .object import Object
from .object import UserDictJsonEncoder
from .utils import make_sync
from .version import __version__

Expand Down Expand Up @@ -267,6 +269,7 @@ def _get_session(self) -> aiohttp.ClientSession:
headers=headers,
trust_env=self._trust_env,
timeout=aiohttp.ClientTimeout(total=self._timeout),
json_serialize=functools.partial(json.dumps, cls=UserDictJsonEncoder)
)
return self._session

Expand Down
10 changes: 10 additions & 0 deletions vt/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import collections
import datetime
import functools
import json
import re
import typing

Expand Down Expand Up @@ -51,6 +52,15 @@ def __delitem__(self, item):
super().__delitem__(item)


class UserDictJsonEncoder(json.JSONEncoder):
"""Custom json encoder for UserDict objects."""

def default(self, o):
if isinstance(o, collections.UserDict):
return o.data
return super().default(o)


class Object:
"""This class encapsulates any type of object in the VirusTotal API.
Expand Down

0 comments on commit 011039a

Please sign in to comment.