diff --git a/docs/source/structs.rst b/docs/source/structs.rst index d90bf2fa..516f5f81 100644 --- a/docs/source/structs.rst +++ b/docs/source/structs.rst @@ -334,11 +334,12 @@ A struct's tagging configuration is determined as follows. a union. - If a struct is tagged, ``tag`` defaults to the class name (e.g. ``"Get"``) if - not provided or inherited. This can be overridden by passing a tag value - explicitly (e.g. ``tag="get"``). or a callable from class name to ``str`` - (e.g. ``tag=lambda name: name.lower()`` to lowercase the class name - automatically). Note that the tag value must be unique for all struct types - in a union. + not provided or inherited. This can be overridden by passing a string (or + less commonly an integer) value explicitly (e.g. ``tag="get"``). ``tag`` can + also be passed a callable that takes the class name and returns a valid tag + value (e.g. ``tag=str.lower``). Note that tag values must be unique for all + struct types in a union, and ``str`` and ``int`` tag types cannot both be + used within the same union. If you like subclassing, both ``tag_field`` and ``tag`` are inheritable by subclasses, allowing configuration to be set once on a base class and reused @@ -353,7 +354,7 @@ for all struct types you wish to tag. >>> # Create a base class for tagged structs, where: ... # - the tag field is "op" ... # - the tag is the class name lowercased - ... class TaggedBase(msgspec.Struct, tag_field="op", tag=lambda name: name.lower()): + ... class TaggedBase(msgspec.Struct, tag_field="op", tag=str.lower): ... pass >>> # Use the base class to pass on the configuration diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 33fbef51..e3765893 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -338,7 +338,7 @@ or doesn't match any valid `enum.IntEnum` member. 
>>> msgspec.json.decode(b'4', type=JobState) Traceback (most recent call last): File "", line 1, in - msgspec.DecodeError: Invalid enum value `4` + msgspec.DecodeError: Invalid enum value 4 ``Enum`` ~~~~~~~~ @@ -565,7 +565,7 @@ values, or doesn't match any of their component types. >>> msgspec.json.decode(b'4', type=Literal[1, 2, 3]) Traceback (most recent call last): File "", line 1, in - msgspec.DecodeError: Invalid enum value `4` + msgspec.DecodeError: Invalid enum value 4 >>> msgspec.json.decode(b'"bad"', type=Literal[1, 2, 3]) Traceback (most recent call last): diff --git a/msgspec/__init__.pyi b/msgspec/__init__.pyi index 2acabc6b..c597dbd4 100644 --- a/msgspec/__init__.pyi +++ b/msgspec/__init__.pyi @@ -38,7 +38,7 @@ class Struct(metaclass=__StructMeta): def __init__(self, *args: Any, **kwargs: Any) -> None: ... def __init_subclass__( cls, - tag: Union[None, bool, str, Callable[[str], str]] = None, + tag: Union[None, bool, str, int, Callable[[str], Union[str, int]]] = None, tag_field: Union[None, str] = None, rename: Union[ None, Literal["lower", "upper", "camel", "pascal"], Callable[[str], str] @@ -59,7 +59,7 @@ def defstruct( bases: Tuple[Type[Struct], ...] 
= (), module: Optional[str] = None, namespace: Optional[Dict[str, Any]] = None, - tag: Union[None, bool, str, Callable[[str], str]] = None, + tag: Union[None, bool, str, int, Callable[[str], Union[str, int]]] = None, tag_field: Union[None, str] = None, rename: Union[ None, Literal["lower", "upper", "camel", "pascal"], Callable[[str], str] diff --git a/msgspec/_core.c b/msgspec/_core.c index bd67039e..c4ceb544 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -595,116 +595,92 @@ strbuilder_build(strbuilder *self) { * Lookup Tables for ints & strings * *************************************************************************/ +typedef struct Lookup { + PyObject_VAR_HEAD + PyObject *tag_field; /* used for struct lookup table only */ + bool array_like; +} Lookup; + static PyTypeObject IntLookup_Type; +static PyTypeObject StrLookup_Type; -typedef struct IntLookupObject { - PyObject_VAR_HEAD - int64_t offset; +typedef struct IntLookup { + Lookup common; bool compact; +} IntLookup; + +typedef struct IntLookupEntry { + int64_t key; + PyObject *value; +} IntLookupEntry; + +typedef struct IntLookupHashmap { + IntLookup base; + IntLookupEntry table[]; +} IntLookupHashmap; + +typedef struct IntLookupCompact { + IntLookup base; + int64_t offset; PyObject* table[]; -} IntLookupObject; +} IntLookupCompact; -static PyObject** -_IntLookup_lookup_int64(IntLookupObject *self, int64_t key) -{ - int64_t entry_val; - PyObject **entry, **table = self->table; - size_t hash = key; - size_t perturb = hash; - size_t mask = Py_SIZE(self) - 1; - size_t i = hash & mask; +typedef struct StrLookupEntry { + PyObject *key; + PyObject *value; +} StrLookupEntry; - while (true) { - entry = &table[i]; - if (*entry == NULL) return entry; - int overflow = 0; - entry_val = PyLong_AsLongLongAndOverflow(*entry, &overflow); - if (!overflow) { - if (entry_val == -1 && PyErr_Occurred()) return NULL; - if (entry_val == key) return entry; - } - /* Collision, perturb and try again */ - perturb >>= 5; - i = 
mask & (i*5 + perturb + 1); - } - /* Unreachable */ - return NULL; -} +typedef struct StrLookup { + Lookup common; + StrLookupEntry table[]; +} StrLookup; -static PyObject** -_IntLookup_lookup_uint64(IntLookupObject *self, uint64_t key) -{ - uint64_t entry_val; - PyObject **entry, **table = self->table; - size_t hash = key; - size_t perturb = hash; +#define Lookup_array_like(obj) ((Lookup *)(obj))->array_like +#define Lookup_tag_field(obj) ((Lookup *)(obj))->tag_field +#define Lookup_IsStrLookup(obj) (Py_TYPE(obj) == &StrLookup_Type) +#define Lookup_IsIntLookup(obj) (Py_TYPE(obj) == &IntLookup_Type) + +static IntLookupEntry * +_IntLookupHashmap_lookup(IntLookupHashmap *self, int64_t key) { + IntLookupEntry *table = self->table; size_t mask = Py_SIZE(self) - 1; - size_t i = hash & mask; + size_t i = key & mask; while (true) { - entry = &table[i]; - if (*entry == NULL) return entry; - entry_val = PyLong_AsUnsignedLongLong(*entry); - if (entry_val == ((uint64_t)-1) && PyErr_Occurred()) { - /* Negative value, can't match key */ - PyErr_Clear(); - } - else if (entry_val == key) { - return entry; - } - /* Collision, perturb and try again */ - perturb >>= 5; - i = mask & (i*5 + perturb + 1); + IntLookupEntry *entry = &table[i]; + if (MS_LIKELY(entry->key == key)) return entry; + if (entry->value == NULL) return entry; + i = (i + 1) & mask; } /* Unreachable */ return NULL; } -static int -IntLookup_Add(IntLookupObject *self, PyObject *item) { - int overflow = 0; - int64_t ival = PyLong_AsLongLongAndOverflow(item, &overflow); - if (!overflow) { - if (ival == -1 && PyErr_Occurred()) return -1; - PyObject **entry = _IntLookup_lookup_int64(self, ival); - if (entry == NULL) return -1; - if (*entry == NULL) { - *entry = item; - Py_INCREF(item); - } - } - else { - uint64_t uval = PyLong_AsUnsignedLongLong(item); - if (uval == ((uint64_t)-1) && PyErr_Occurred()) return -1; - PyObject **entry = _IntLookup_lookup_uint64(self, uval); - if (entry == NULL) return -1; - if (*entry == 
NULL) { - *entry = item; - Py_INCREF(item); - } - } - return 0; +static void +_IntLookupHashmap_Set(IntLookupHashmap *self, int64_t key, PyObject *value) { + IntLookupEntry *entry = _IntLookupHashmap_lookup(self, key); + Py_XDECREF(entry->value); + Py_INCREF(value); + entry->key = key; + entry->value = value; } -static IntLookupObject * -IntLookup_New(PyObject *arg) -{ +static PyObject * +IntLookup_New(PyObject *arg, PyObject *tag_field, bool array_like) { Py_ssize_t nitems; - PyObject *item, *items; - IntLookupObject *self = NULL; + PyObject *item, *items = NULL; + IntLookup *self = NULL; int64_t imin = LLONG_MAX, imax = LLONG_MIN; - bool uint64_present = false; - if (!PyTuple_CheckExact(arg)) { - items = PySequence_Tuple(arg); - if (items == NULL) return NULL; + if (PyDict_CheckExact(arg)) { + nitems = PyDict_GET_SIZE(arg); } else { - items = arg; + items = PySequence_Tuple(arg); + if (items == NULL) return NULL; + nitems = PyTuple_GET_SIZE(items); } - nitems = PyTuple_GET_SIZE(items); - /* Must have at least one item */ if (nitems == 0) { PyErr_Format( @@ -717,28 +693,40 @@ IntLookup_New(PyObject *arg) /* Find the min/max of items, and error if any item isn't an integer or is * out of range */ - for (Py_ssize_t i = 0; i < nitems; i++) { - int64_t ival; - int overflow = 0; - PyObject* item; - - item = PyTuple_GET_ITEM(items, i); - ival = PyLong_AsLongLongAndOverflow(item, &overflow); - if (overflow) { - uint64_present = true; - } - else if (ival == -1 && PyErr_Occurred()) { - goto cleanup; +#define handle(key) \ + do { \ + int overflow = 0; \ + int64_t ival = PyLong_AsLongLongAndOverflow(key, &overflow); \ + if (overflow) { \ + PyErr_SetString( \ + PyExc_NotImplementedError, \ + "Integer values > (2**63 - 1) are not currently supported for " \ + "IntEnum/Literal/integer tags. If you need this feature, please " \ + "open an issue on GitHub." 
\ + ); \ + goto cleanup; \ + } \ + if (ival == -1 && PyErr_Occurred()) goto cleanup; \ + if (ival < imin) { \ + imin = ival; \ + } \ + if (ival > imax) { \ + imax = ival; \ + } \ + } while (false) + if (PyDict_CheckExact(arg)) { + PyObject *key, *val; + Py_ssize_t pos = 0; + while (PyDict_Next(arg, &pos, &key, &val)) { + handle(key); } - else { - if (ival < imin) { - imin = ival; - } - if (ival > imax) { - imax = ival; - } + } + else { + for (Py_ssize_t i = 0; i < nitems; i++) { + handle(PyTuple_GET_ITEM(items, i)); } } +#undef handle /* Calculate range without overflow */ uint64_t range; @@ -750,79 +738,141 @@ IntLookup_New(PyObject *arg) range = imax - imin; } - if (!uint64_present && (range < 1.4 * nitems)) { - size_t size = range + 1; + if (range < 1.4 * nitems) { /* Use compact representation */ - self = PyObject_GC_NewVar(IntLookupObject, &IntLookup_Type, size); - if (self == NULL) goto cleanup; - self->offset = imin; - self->compact = true; + size_t size = range + 1; + + IntLookupCompact *out = (IntLookupCompact *) _PyObject_GC_Malloc( + sizeof(IntLookupCompact) + (size + 1) * sizeof(PyObject *) + ); + if (out == NULL) { + PyErr_NoMemory(); + goto cleanup; + } + PyObject_InitVar((PyVarObject *)out, &IntLookup_Type, size); + + out->offset = imin; for (size_t i = 0; i < size; i++) { - self->table[i] = NULL; + out->table[i] = NULL; } - for (Py_ssize_t i = 0; i < nitems; i++) { - item = PyTuple_GET_ITEM(items, i); - int64_t ival = PyLong_AsLongLong(item); - if (ival == -1 && PyErr_Occurred()) { - Py_CLEAR(self); - goto cleanup; + +#define setitem(key, val) \ + do { \ + int64_t ikey = PyLong_AsLongLong(key); \ + out->table[ikey - imin] = val; \ + Py_INCREF(val); \ + } while (false) + + if (PyDict_CheckExact(arg)) { + PyObject *key, *val; + Py_ssize_t pos = 0; + while (PyDict_Next(arg, &pos, &key, &val)) { + setitem(key, val); + } + } + else { + for (Py_ssize_t i = 0; i < nitems; i++) { + item = PyTuple_GET_ITEM(items, i); + setitem(item, item); } - 
self->table[ival - imin] = item; - Py_INCREF(item); } + +#undef setitem + + self = (IntLookup *)out; + self->compact = true; } else { /* Use hashtable */ size_t needed = nitems * 4 / 3; size_t size = 4; - while (size < (size_t)needed) { - size <<= 1; + while (size < (size_t)needed) { size <<= 1; } + + IntLookupHashmap *out = (IntLookupHashmap *) _PyObject_GC_Malloc( + sizeof(IntLookupHashmap) + (size + 1) * sizeof(IntLookupEntry) + ); + if (out == NULL) { + PyErr_NoMemory(); + goto cleanup; } - self = PyObject_GC_NewVar(IntLookupObject, &IntLookup_Type, size); - if (self == NULL) goto cleanup; - self->compact = false; + PyObject_InitVar((PyVarObject *)out, &IntLookup_Type, size); + for (size_t i = 0; i < size; i++) { - self->table[i] = NULL; + out->table[i].key = 0; + out->table[i].value = NULL; } - for (Py_ssize_t i = 0; i < nitems; i++) { - item = PyTuple_GET_ITEM(items, i); - if (IntLookup_Add(self, item) < 0) { - Py_CLEAR(self); - goto cleanup; + + if (PyDict_CheckExact(arg)) { + PyObject *key, *val; + Py_ssize_t pos = 0; + while (PyDict_Next(arg, &pos, &key, &val)) { + int64_t ival = PyLong_AsLongLong(key); + _IntLookupHashmap_Set(out, ival, val); + } + } + else { + for (Py_ssize_t i = 0; i < nitems; i++) { + PyObject *val = PyTuple_GET_ITEM(items, i); + int64_t ival = PyLong_AsLongLong(val); + _IntLookupHashmap_Set(out, ival, val); } } + self = (IntLookup *)out; + self->compact = false; } + /* Store the tag field & array_like status (struct lookup only) */ + Py_XINCREF(tag_field); + self->common.tag_field = tag_field; + self->common.array_like = array_like; + cleanup: - if (arg != items) { - Py_DECREF(items); - } + Py_XDECREF(items); if (self != NULL) { PyObject_GC_Track(self); } - return self; + return (PyObject *)self; } static int -IntLookup_traverse(IntLookupObject *self, visitproc visit, void *arg) +IntLookup_traverse(IntLookup *self, visitproc visit, void *arg) { - for (Py_ssize_t i = 0; i < Py_SIZE(self); i++) { - Py_VISIT(self->table[i]); + if 
(self->compact) { + IntLookupCompact *lk = (IntLookupCompact *)self; + for (Py_ssize_t i = 0; i < Py_SIZE(lk); i++) { + Py_VISIT(lk->table[i]); + } + } + else { + IntLookupHashmap *lk = (IntLookupHashmap *)self; + for (Py_ssize_t i = 0; i < Py_SIZE(lk); i++) { + Py_VISIT(lk->table[i].value); + } } return 0; } static int -IntLookup_clear(IntLookupObject *self) +IntLookup_clear(IntLookup *self) { - for (Py_ssize_t i = 0; i < Py_SIZE(self); i++) { - Py_CLEAR(self->table[i]); + if (self->compact) { + IntLookupCompact *lk = (IntLookupCompact *)self; + for (Py_ssize_t i = 0; i < Py_SIZE(lk); i++) { + Py_CLEAR(lk->table[i]); + } + } + else { + IntLookupHashmap *lk = (IntLookupHashmap *)self; + for (Py_ssize_t i = 0; i < Py_SIZE(lk); i++) { + Py_CLEAR(lk->table[i].value); + } } + Py_CLEAR(self->common.tag_field); return 0; } static void -IntLookup_dealloc(IntLookupObject *self) +IntLookup_dealloc(IntLookup *self) { PyObject_GC_UnTrack(self); IntLookup_clear(self); @@ -830,69 +880,37 @@ IntLookup_dealloc(IntLookupObject *self) } static PyObject * -_IntLookup_get_compact(IntLookupObject *self, int64_t key) { - Py_ssize_t index = key - self->offset; - if (index >= 0 && index < Py_SIZE(self)) { - return self->table[index]; - } - return NULL; -} - -static PyObject * -IntLookup_GetInt64(IntLookupObject *self, int64_t key) { - PyObject *out; - if (self->compact) { - out = _IntLookup_get_compact(self, key); - } - else { - PyObject **entry = _IntLookup_lookup_int64(self, key); - if (entry == NULL) return NULL; - out = *entry; +IntLookup_GetInt64(IntLookup *self, int64_t key) { + if (MS_LIKELY(self->compact)) { + IntLookupCompact *lk = (IntLookupCompact *)self; + Py_ssize_t index = key - lk->offset; + if (index >= 0 && index < Py_SIZE(lk)) { + return lk->table[index]; + } + return NULL; } - return out; + return _IntLookupHashmap_lookup((IntLookupHashmap *)self, key)->value; } static PyObject * -IntLookup_GetUInt64(IntLookupObject *self, uint64_t key) { - PyObject *out; - if 
(self->compact) { - out = _IntLookup_get_compact(self, key); - } - else { - PyObject **entry = _IntLookup_lookup_uint64(self, key); - if (entry == NULL) return NULL; - out = *entry; - } - return out; +IntLookup_GetUInt64(IntLookup *self, uint64_t key) { + if (key > LLONG_MAX) return NULL; + return IntLookup_GetInt64(self, key); } static PyTypeObject IntLookup_Type = { PyVarObject_HEAD_INIT(NULL, 0) .tp_name = "msgspec._core.IntLookup", - .tp_basicsize = sizeof(IntLookupObject), - .tp_itemsize = sizeof(PyObject *), + .tp_basicsize = sizeof(IntLookup), + .tp_itemsize = 0, .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, - .tp_dealloc = (destructor) IntLookup_dealloc, + .tp_dealloc = (destructor)IntLookup_dealloc, .tp_clear = (inquiry)IntLookup_clear, - .tp_traverse = (traverseproc) IntLookup_traverse, + .tp_traverse = (traverseproc)IntLookup_traverse, }; -static PyTypeObject StrLookup_Type; - -typedef struct StrLookupEntry { - PyObject *key; - PyObject *value; -} StrLookupEntry; - -typedef struct StrLookupObject { - PyObject_VAR_HEAD - PyObject *tag_field; /* used for struct lookup table only */ - bool array_like; - StrLookupEntry table[]; -} StrLookupObject; - static StrLookupEntry * -_StrLookup_lookup(StrLookupObject *self, const char *key, Py_ssize_t size) +_StrLookup_lookup(StrLookup *self, const char *key, Py_ssize_t size) { StrLookupEntry *table = self->table; size_t hash = murmur2(key, size); @@ -915,7 +933,7 @@ _StrLookup_lookup(StrLookupObject *self, const char *key, Py_ssize_t size) } static int -StrLookup_Set(StrLookupObject *self, PyObject *key, PyObject *value) { +StrLookup_Set(StrLookup *self, PyObject *key, PyObject *value) { Py_ssize_t key_size; const char *key_str = unicode_str_and_size(key, &key_size); if (key_str == NULL) return -1; @@ -929,10 +947,10 @@ StrLookup_Set(StrLookupObject *self, PyObject *key, PyObject *value) { } static PyObject * -StrLookup_NewFullArgs(PyObject *arg, PyObject *tag_field, bool array_like) { +StrLookup_New(PyObject 
*arg, PyObject *tag_field, bool array_like) { Py_ssize_t nitems; PyObject *item, *items = NULL; - StrLookupObject *self = NULL; + StrLookup *self = NULL; if (PyDict_CheckExact(arg)) { nitems = PyDict_GET_SIZE(arg); @@ -958,7 +976,7 @@ StrLookup_NewFullArgs(PyObject *arg, PyObject *tag_field, bool array_like) { while (size < (size_t)needed) { size <<= 1; } - self = PyObject_GC_NewVar(StrLookupObject, &StrLookup_Type, size); + self = PyObject_GC_NewVar(StrLookup, &StrLookup_Type, size); if (self == NULL) goto cleanup; /* Zero out memory */ for (size_t i = 0; i < size; i++) { @@ -966,11 +984,6 @@ StrLookup_NewFullArgs(PyObject *arg, PyObject *tag_field, bool array_like) { self->table[i].value = NULL; } - /* Store the tag field & array_like status (struct lookup only) */ - Py_XINCREF(tag_field); - self->tag_field = tag_field; - self->array_like = array_like; - if (PyDict_CheckExact(arg)) { PyObject *key, *val; Py_ssize_t pos = 0; @@ -1002,6 +1015,11 @@ StrLookup_NewFullArgs(PyObject *arg, PyObject *tag_field, bool array_like) { } } + /* Store the tag field & array_like status (struct lookup only) */ + Py_XINCREF(tag_field); + self->common.tag_field = tag_field; + self->common.array_like = array_like; + cleanup: Py_XDECREF(items); if (self != NULL) { @@ -1010,13 +1028,8 @@ StrLookup_NewFullArgs(PyObject *arg, PyObject *tag_field, bool array_like) { return (PyObject *)self; } -static PyObject * -StrLookup_New(PyObject *arg) { - return StrLookup_NewFullArgs(arg, NULL, false); -} - static int -StrLookup_traverse(StrLookupObject *self, visitproc visit, void *arg) +StrLookup_traverse(StrLookup *self, visitproc visit, void *arg) { for (Py_ssize_t i = 0; i < Py_SIZE(self); i++) { Py_VISIT(self->table[i].key); @@ -1026,18 +1039,18 @@ StrLookup_traverse(StrLookupObject *self, visitproc visit, void *arg) } static int -StrLookup_clear(StrLookupObject *self) +StrLookup_clear(StrLookup *self) { for (Py_ssize_t i = 0; i < Py_SIZE(self); i++) { Py_CLEAR(self->table[i].key); 
Py_CLEAR(self->table[i].value); } - Py_CLEAR(self->tag_field); + Py_CLEAR(self->common.tag_field); return 0; } static void -StrLookup_dealloc(StrLookupObject *self) +StrLookup_dealloc(StrLookup *self) { PyObject_GC_UnTrack(self); StrLookup_clear(self); @@ -1045,7 +1058,7 @@ StrLookup_dealloc(StrLookupObject *self) } static PyObject * -StrLookup_Get(StrLookupObject *self, const char *key, Py_ssize_t size) { +StrLookup_Get(StrLookup *self, const char *key, Py_ssize_t size) { StrLookupEntry *entry = _StrLookup_lookup(self, key, size); return entry->value; } @@ -1053,7 +1066,7 @@ StrLookup_Get(StrLookupObject *self, const char *key, Py_ssize_t size) { static PyTypeObject StrLookup_Type = { PyVarObject_HEAD_INIT(NULL, 0) .tp_name = "msgspec._core.StrLookup", - .tp_basicsize = sizeof(StrLookupObject), + .tp_basicsize = sizeof(StrLookup), .tp_itemsize = sizeof(StrLookupEntry), .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, .tp_dealloc = (destructor) StrLookup_dealloc, @@ -1398,7 +1411,7 @@ TypeNode_get_struct(TypeNode *type) { return ((TypeNodeExtra *)type)->extra[0]; } -static MS_INLINE StrLookupObject * +static MS_INLINE Lookup * TypeNode_get_struct_union(TypeNode *type) { /* Struct union types are always first */ return ((TypeNodeExtra *)type)->extra[0]; @@ -1410,7 +1423,7 @@ TypeNode_get_custom(TypeNode *type) { return ((TypeNodeExtra *)type)->extra[0]; } -static MS_INLINE IntLookupObject * +static MS_INLINE IntLookup * TypeNode_get_int_enum_or_literal(TypeNode *type) { Py_ssize_t i = ms_popcount( type->types & ( @@ -1421,7 +1434,7 @@ TypeNode_get_int_enum_or_literal(TypeNode *type) { return ((TypeNodeExtra *)type)->extra[i]; } -static MS_INLINE StrLookupObject * +static MS_INLINE StrLookup * TypeNode_get_str_enum_or_literal(TypeNode *type) { Py_ssize_t i = ms_popcount( type->types & ( @@ -1633,14 +1646,14 @@ typenode_from_collect_state(TypeNodeCollectState *state, bool err_not_json, bool if (lookup == NULL) { /* IntLookup isn't created yet, create and store on 
enum class */ PyErr_Clear(); - lookup = (PyObject *)IntLookup_New(state->intenum_obj); + lookup = IntLookup_New(state->intenum_obj, NULL, false); if (lookup == NULL) goto error; if (PyObject_SetAttr(state->intenum_obj, state->mod->str___msgspec_lookup__, lookup) < 0) { Py_DECREF(lookup); goto error; } } - else if (Py_TYPE(lookup) != &IntLookup_Type) { + else if (!Lookup_IsIntLookup(lookup)) { /* the lookup attribute has been overwritten, error */ Py_DECREF(lookup); PyErr_Format( @@ -1663,7 +1676,7 @@ typenode_from_collect_state(TypeNodeCollectState *state, bool err_not_json, bool PyErr_Clear(); PyObject *member_map = PyObject_GetAttr(state->enum_obj, state->mod->str__member_map_); if (member_map == NULL) goto error; - lookup = StrLookup_New(member_map); + lookup = StrLookup_New(member_map, NULL, false); Py_DECREF(member_map); if (lookup == NULL) goto error; if (PyObject_SetAttr(state->enum_obj, state->mod->str___msgspec_lookup__, lookup) < 0) { @@ -1984,16 +1997,16 @@ typenode_collect_convert_literals(TypeNodeCollectState *state) { PyObject *str_lookup = PyTuple_GET_ITEM(cached, 1); if ( ( - (int_lookup == Py_None || Py_TYPE(int_lookup) == &IntLookup_Type) && - (str_lookup == Py_None || Py_TYPE(str_lookup) == &StrLookup_Type) + (int_lookup == Py_None || Lookup_IsIntLookup(int_lookup)) && + (str_lookup == Py_None || Lookup_IsStrLookup(str_lookup)) ) ) { - if (Py_TYPE(int_lookup) == &IntLookup_Type) { + if (Lookup_IsIntLookup(int_lookup)) { Py_INCREF(int_lookup); state->types |= MS_TYPE_INTLITERAL; state->int_literal_lookup = int_lookup; } - if (Py_TYPE(str_lookup) == &StrLookup_Type) { + if (Lookup_IsStrLookup(str_lookup)) { Py_INCREF(str_lookup); state->types |= MS_TYPE_STRLITERAL; state->str_literal_lookup = str_lookup; @@ -2020,12 +2033,12 @@ typenode_collect_convert_literals(TypeNodeCollectState *state) { /* Convert values to lookup objects (if values exist for each type) */ if (state->int_literal_values != NULL) { state->types |= MS_TYPE_INTLITERAL; - 
state->int_literal_lookup = (PyObject *)IntLookup_New(state->int_literal_values); + state->int_literal_lookup = IntLookup_New(state->int_literal_values, NULL, false); if (state->int_literal_lookup == NULL) return -1; } if (state->str_literal_values != NULL) { state->types |= MS_TYPE_STRLITERAL; - state->str_literal_lookup = StrLookup_New(state->str_literal_values); + state->str_literal_lookup = StrLookup_New(state->str_literal_values, NULL, false); if (state->str_literal_lookup == NULL) return -1; } @@ -2050,12 +2063,12 @@ typenode_collect_convert_literals(TypeNodeCollectState *state) { /* Convert values to lookup objects (if values exist for each type) */ if (state->int_literal_values != NULL) { state->types |= MS_TYPE_INTLITERAL; - state->int_literal_lookup = (PyObject *)IntLookup_New(state->int_literal_values); + state->int_literal_lookup = IntLookup_New(state->int_literal_values, NULL, false); if (state->int_literal_lookup == NULL) return -1; } if (state->str_literal_values != NULL) { state->types |= MS_TYPE_STRLITERAL; - state->str_literal_lookup = StrLookup_New(state->str_literal_values); + state->str_literal_lookup = StrLookup_New(state->str_literal_values, NULL, false); if (state->str_literal_lookup == NULL) return -1; } return 0; @@ -2096,7 +2109,7 @@ typenode_collect_convert_structs( Py_INCREF(lookup); state->structs_lookup = lookup; - if (((StrLookupObject *)lookup)->array_like) { + if (Lookup_array_like(lookup)) { state->types |= MS_TYPE_STRUCT_ARRAY_UNION; } else { @@ -2120,6 +2133,7 @@ typenode_collect_convert_structs( Py_ssize_t set_pos = 0; Py_hash_t set_hash; bool array_like = false; + bool tags_are_strings = true; int status = -1; tag_mapping = PyDict_New(); @@ -2149,10 +2163,10 @@ typenode_collect_convert_structs( ); goto cleanup; } - if (tag_field == NULL) { array_like = struct_type->array_like == OPT_TRUE; tag_field = struct_type->struct_tag_field; + tags_are_strings = PyUnicode_CheckExact(item_tag_value); } else { if (array_like != 
item_array_like) { @@ -2164,6 +2178,15 @@ typenode_collect_convert_structs( ); goto cleanup; } + if (tags_are_strings != PyUnicode_CheckExact(item_tag_value)) { + PyErr_Format( + PyExc_TypeError, + "Type unions may not contain Struct types with both `int` " + "and `str` tags - type `%R` is not supported", + state->context + ); + goto cleanup; + } int compare = PyUnicode_Compare(item_tag_field, tag_field); if (compare == -1 && PyErr_Occurred()) goto cleanup; @@ -2191,7 +2214,12 @@ typenode_collect_convert_structs( } } /* Build a lookup from tag_value -> struct_type */ - lookup = StrLookup_NewFullArgs(tag_mapping, tag_field, array_like); + if (tags_are_strings) { + lookup = StrLookup_New(tag_mapping, tag_field, array_like); + } + else { + lookup = IntLookup_New(tag_mapping, tag_field, array_like); + } if (lookup == NULL) goto cleanup; state->structs_lookup = lookup; @@ -2641,6 +2669,18 @@ ms_invalid_cstr_value(const char *cstr, Py_ssize_t size, PathNode *path) { return NULL; } +static PyObject * +ms_invalid_cint_value(int64_t val, PathNode *path) { + ms_raise_validation_error(path, "Invalid value %lld%U", val); + return NULL; +} + +static PyObject * +ms_invalid_cuint_value(uint64_t val, PathNode *path) { + ms_raise_validation_error(path, "Invalid value %llu%U", val); + return NULL; +} + /* Same as ms_raise_validation_error, except doesn't require any format arguments. 
*/ static PyObject * ms_error_with_path(const char *msg, PathNode *path) { @@ -3320,8 +3360,18 @@ StructMeta_new_inner( Py_INCREF(tag_temp); tag_value = tag_temp; } - if (!PyUnicode_CheckExact(tag_value)) { - PyErr_SetString(PyExc_TypeError, "`tag` must be a `str`"); + if (PyLong_CheckExact(tag_value)) { + int64_t val = PyLong_AsLongLong(tag_value); + if (val == -1 && PyErr_Occurred()) { + PyErr_SetString( + PyExc_ValueError, + "Integer `tag` values must be within [-2**63, 2**63 - 1]" + ); + goto error; + } + } + else if (!PyUnicode_CheckExact(tag_value)) { + PyErr_SetString(PyExc_TypeError, "`tag` must be a `str` or an `int`"); goto error; } } @@ -4435,7 +4485,7 @@ PyDoc_STRVAR(Struct__doc__, " Whether fields should be omitted from encoding if the corresponding value\n" " is the default for that field. Enabling this may reduce message size, and\n" " often also improve encoding & decoding performance.\n" -"tag: str, bool, callable, or None, default None\n" +"tag: str, int, bool, callable, or None, default None\n" " Used along with ``tag_field`` for configuring tagged union support. If\n" " either are non-None, then the struct is considered \"tagged\". In this case,\n" " an extra field (the ``tag_field``) and value (the ``tag``) are added to the\n" @@ -4444,10 +4494,10 @@ PyDoc_STRVAR(Struct__doc__, "\n" " Set ``tag=True`` to enable the default tagged configuration (``tag_field``\n" " is ``\"type\"``, ``tag`` is the class name). Alternatively, you can provide\n" -" a string value directly to be used as the tag (``tag=\"my-tag-value\"``).\n" -" ``tag`` can also be passed a callable that takes the class name and returns\n" -" a new string to use for the tag value (``tag=str.lower`` for example) See\n" -" the docs for more information.\n" +" a string (or less commonly int) value directly to be used as the tag\n" +" (e.g. ``tag=\"my-tag-value\"``). ``tag`` can also be passed a callable that\n" +" takes the class name and returns a valid tag value (e.g. 
``tag=str.lower``).\n" +" See the docs for more information.\n" "tag_field: str or None, default None\n" " The field name to use for tagged union support. If ``tag`` is non-None,\n" " then this defaults to ``\"type\"``. See the ``tag`` docs above for more\n" @@ -5018,7 +5068,7 @@ static PyMemberDef Encoder_members[] = { static MS_NOINLINE PyObject * ms_decode_str_enum_or_literal(const char *name, Py_ssize_t size, TypeNode *type, PathNode *path) { - StrLookupObject *lookup = TypeNode_get_str_enum_or_literal(type); + StrLookup *lookup = TypeNode_get_str_enum_or_literal(type); PyObject *out = StrLookup_Get(lookup, name, size); if (out == NULL) { PyObject *val = PyUnicode_DecodeUTF8(name, size, NULL); @@ -5033,10 +5083,10 @@ ms_decode_str_enum_or_literal(const char *name, Py_ssize_t size, TypeNode *type, static MS_NOINLINE PyObject * ms_decode_int_enum_or_literal_int64(int64_t val, TypeNode *type, PathNode *path) { - IntLookupObject *lookup = TypeNode_get_int_enum_or_literal(type); + IntLookup *lookup = TypeNode_get_int_enum_or_literal(type); PyObject *out = IntLookup_GetInt64(lookup, val); if (out == NULL) { - ms_raise_validation_error(path, "Invalid enum value `%lld`%U", val); + ms_raise_validation_error(path, "Invalid enum value %lld%U", val); return NULL; } Py_INCREF(out); @@ -5045,10 +5095,10 @@ ms_decode_int_enum_or_literal_int64(int64_t val, TypeNode *type, PathNode *path) static MS_NOINLINE PyObject * ms_decode_int_enum_or_literal_uint64(uint64_t val, TypeNode *type, PathNode *path) { - IntLookupObject *lookup = TypeNode_get_int_enum_or_literal(type); + IntLookup *lookup = TypeNode_get_int_enum_or_literal(type); PyObject *out = IntLookup_GetUInt64(lookup, val); if (out == NULL) { - ms_raise_validation_error(path, "Invalid enum value `%llu`%U", val); + ms_raise_validation_error(path, "Invalid enum value %llu%U", val); return NULL; } Py_INCREF(out); @@ -5550,7 +5600,7 @@ mpack_encode_struct(EncoderState *self, PyObject *obj) if (struct_type->array_like == 
OPT_TRUE) { if (mpack_encode_array_header(self, len, "structs") < 0) return -1; if (tagged) { - if (mpack_encode_str(self, tag_value) < 0) goto cleanup; + if (mpack_encode(self, tag_value) < 0) goto cleanup; } for (i = 0; i < nfields; i++) { val = Struct_get_index(obj, i); @@ -5565,7 +5615,7 @@ mpack_encode_struct(EncoderState *self, PyObject *obj) if (tagged) { if (mpack_encode_str(self, tag_field) < 0) goto cleanup; - if (mpack_encode_str(self, tag_value) < 0) goto cleanup; + if (mpack_encode(self, tag_value) < 0) goto cleanup; } if (struct_type->omit_defaults == OPT_TRUE) { @@ -6483,7 +6533,7 @@ json_encode_struct_default( if (tag_value != NULL) { if (json_encode_str(self, tag_field) < 0) goto cleanup; if (ms_write(self, ":", 1) < 0) goto cleanup; - if (json_encode_str(self, tag_value) < 0) goto cleanup; + if (json_encode(self, tag_value) < 0) goto cleanup; if (ms_write(self, ",", 1) < 0) goto cleanup; } for (i = 0; i < nfields; i++) { @@ -6523,7 +6573,7 @@ json_encode_struct_omit_defaults( if (tag_value != NULL) { if (json_encode_str(self, tag_field) < 0) goto cleanup; if (ms_write(self, ":", 1) < 0) goto cleanup; - if (json_encode_str(self, tag_value) < 0) goto cleanup; + if (json_encode(self, tag_value) < 0) goto cleanup; if (ms_write(self, ",", 1) < 0) goto cleanup; } @@ -6575,7 +6625,7 @@ json_encode_struct_array_like( if (ms_write(self, "[", 1) < 0) return -1; if (Py_EnterRecursiveCall(" while serializing an object")) return -1; if (tag_value != NULL) { - if (json_encode_str(self, tag_value) < 0) goto cleanup; + if (json_encode(self, tag_value) < 0) goto cleanup; if (ms_write(self, ",", 1) < 0) goto cleanup; } for (Py_ssize_t i = 0; i < nfields; i++) { @@ -7025,7 +7075,7 @@ mpack_error_expected(char op, char *expected, PathNode *path) { break; } } - ms_raise_validation_error(path, "Expected `str`, got `%s`%U", got); + ms_raise_validation_error(path, "Expected `%s`, got `%s`%U", expected, got); return NULL; } @@ -7056,6 +7106,64 @@ 
mpack_decode_cstr(DecoderState *self, char ** out, PathNode *path) { return size; } +/* Decode an integer. If the value fits in an int64_t, it will be stored in + * `out`, otherwise it will be stored in `uout`. A return value of -1 indicates + * an error. */ +static int +mpack_decode_cint(DecoderState *self, int64_t *out, uint64_t *uout, PathNode *path) { + char op = 0; + char *s = NULL; + + if (mpack_read1(self, &op) < 0) return -1; + + if (('\x00' <= op && op <= '\x7f') || ('\xe0' <= op && op <= '\xff')) { + *out = *((int8_t *)(&op)); + } + else if (op == MP_UINT8) { + if (MS_UNLIKELY(mpack_read(self, &s, 1) < 0)) return -1; + *out = *(uint8_t *)s; + } + else if (op == MP_UINT16) { + if (MS_UNLIKELY(mpack_read(self, &s, 2) < 0)) return -1; + *out = _msgspec_load16(uint16_t, s); + } + else if (op == MP_UINT32) { + if (MS_UNLIKELY(mpack_read(self, &s, 4) < 0)) return -1; + *out = _msgspec_load32(uint32_t, s); + } + else if (op == MP_UINT64) { + if (MS_UNLIKELY(mpack_read(self, &s, 8) < 0)) return -1; + uint64_t ux = _msgspec_load64(uint64_t, s); + if (ux > LLONG_MAX) { + *uout = ux; + } + else { + *out = ux; + } + } + else if (op == MP_INT8) { + if (MS_UNLIKELY(mpack_read(self, &s, 1) < 0)) return -1; + *out = *(int8_t *)s; + } + else if (op == MP_INT16) { + if (MS_UNLIKELY(mpack_read(self, &s, 2) < 0)) return -1; + *out = _msgspec_load16(int16_t, s); + } + else if (op == MP_INT32) { + if (MS_UNLIKELY(mpack_read(self, &s, 4) < 0)) return -1; + *out = _msgspec_load32(int32_t, s); + } + else if (op == MP_INT64) { + if (MS_UNLIKELY(mpack_read(self, &s, 8) < 0)) return -1; + *out = _msgspec_load64(int64_t, s); + } + else { + mpack_error_expected(op, "int", path); + return -1; + } + return 0; +} + static PyObject * mpack_decode_datetime( @@ -7434,6 +7542,82 @@ mpack_decode_fixtuple( return res; } +static int +mpack_ensure_tag_matches( + DecoderState *self, PathNode *path, PyObject *expected_tag +) { + if (PyUnicode_CheckExact(expected_tag)) { + char *tag = NULL; + 
Py_ssize_t tag_size; + tag_size = mpack_decode_cstr(self, &tag, path); + if (tag_size < 0) return -1; + + /* Check that tag matches expected tag value */ + Py_ssize_t expected_size; + const char *expected_str = unicode_str_and_size( + expected_tag, &expected_size + ); + if (tag_size != expected_size || memcmp(tag, expected_str, expected_size) != 0) { + /* Tag doesn't match the expected value, error nicely */ + ms_invalid_cstr_value(tag, tag_size, path); + return -1; + } + } + else { + int64_t tag = 0; + uint64_t utag = 0; + if (mpack_decode_cint(self, &tag, &utag, path) < 0) return -1; + int64_t expected = PyLong_AsLongLong(expected_tag); + /* Tags must be int64s, if utag != 0 then we know the tags don't match. + * We parse the full uint64 value only to validate the message and + * raise a nice error */ + if (utag != 0) { + ms_invalid_cuint_value(utag, path); + return -1; + } + if (tag != expected) { + ms_invalid_cint_value(tag, path); + return -1; + } + } + return 0; +} + +static StructMetaObject * +mpack_decode_tag_and_lookup_type( + DecoderState *self, Lookup *lookup, PathNode *path +) { + StructMetaObject *out = NULL; + if (Lookup_IsStrLookup(lookup)) { + Py_ssize_t tag_size; + char *tag = NULL; + tag_size = mpack_decode_cstr(self, &tag, path); + if (tag_size < 0) return NULL; + out = (StructMetaObject *)StrLookup_Get((StrLookup *)lookup, tag, tag_size); + if (out == NULL) { + ms_invalid_cstr_value(tag, tag_size, path); + } + } + else { + int64_t tag = 0; + uint64_t utag = 0; + if (mpack_decode_cint(self, &tag, &utag, path) < 0) return NULL; + if (utag == 0) { + out = (StructMetaObject *)IntLookup_GetInt64((IntLookup *)lookup, tag); + if (out == NULL) { + ms_invalid_cint_value(tag, path); + } + } + else { + out = (StructMetaObject *)IntLookup_GetUInt64((IntLookup *)lookup, utag); + if (out == NULL) { + ms_invalid_cuint_value(utag, path); + } + } + } + return out; +} + static PyObject * mpack_decode_struct_array_inner( DecoderState *self, Py_ssize_t size, bool 
tag_already_read, @@ -7461,20 +7645,8 @@ mpack_decode_struct_array_inner( if (tagged) { if (!tag_already_read) { - /* Decode tag */ - char *tag = NULL; - Py_ssize_t tag_size; - tag_size = mpack_decode_cstr(self, &tag, &item_path); - if (tag_size < 0) return NULL; - - /* Check that tag matches expected tag value */ - Py_ssize_t expected_size; - const char *expected = unicode_str_and_size( - st_type->struct_tag_value, &expected_size - ); - if (tag_size != expected_size || memcmp(tag, expected, expected_size) != 0) { - /* Tag doesn't match the expected value, error nicely */ - return ms_invalid_cstr_value(tag, tag_size, &item_path); + if (mpack_ensure_tag_matches(self, &item_path, st_type->struct_tag_value) < 0) { + return NULL; } } size--; @@ -7536,25 +7708,17 @@ static PyObject * mpack_decode_struct_array_union( DecoderState *self, Py_ssize_t size, TypeNode *type, PathNode *path, bool is_key ) { - StrLookupObject *lookup = TypeNode_get_struct_union(type); + Lookup *lookup = TypeNode_get_struct_union(type); if (size == 0) { return ms_error_with_path( "Expected `array` of at least length 1, got 0%U", path ); } - /* Decode tag */ + /* Decode and lookup tag */ PathNode tag_path = {path, 0}; - char *tag = NULL; - Py_ssize_t tag_size; - tag_size = mpack_decode_cstr(self, &tag, &tag_path); - if (MS_UNLIKELY(tag_size < 0)) return NULL; - - /* Lookup Struct type from tag */ - StructMetaObject *struct_type = (StructMetaObject *)StrLookup_Get(lookup, tag, tag_size); - if (struct_type == NULL) { - return ms_invalid_cstr_value(tag, tag_size, &tag_path); - } + StructMetaObject *struct_type = mpack_decode_tag_and_lookup_type(self, lookup, &tag_path); + if (struct_type == NULL) return NULL; /* Finish decoding the rest of the struct */ return mpack_decode_struct_array_inner(self, size, true, struct_type, path, is_key); @@ -7718,22 +7882,9 @@ mpack_decode_struct_map( field_index = StructMeta_get_field_index(st_type, key, key_size, &pos); if (field_index < 0) { if 
(MS_UNLIKELY(field_index == -2)) { - /* Matches the tag field */ - Py_ssize_t tag_size, expected_size; - char *tag = NULL; PathNode tag_path = {path, PATH_TAG, st_type->struct_tag_field}; - - /* Decode the tag value */ - tag_size = mpack_decode_cstr(self, &tag, &tag_path); - if (tag_size < 0) goto error; - - /* Check that the tag value matches the expected value */ - const char *expected = unicode_str_and_size( - st_type->struct_tag_value, &expected_size - ); - if (tag_size != expected_size || memcmp(tag, expected, expected_size) != 0) { - /* Tag doesn't match the expected value, error nicely */ - return ms_invalid_cstr_value(tag, tag_size, &tag_path); + if (mpack_ensure_tag_matches(self, &tag_path, st_type->struct_tag_value) < 0) { + return NULL; } } else { @@ -7765,10 +7916,12 @@ mpack_decode_struct_union( DecoderState *self, Py_ssize_t size, TypeNode *type, PathNode *path, bool is_key ) { - StrLookupObject *lookup = TypeNode_get_struct_union(type); + Lookup *lookup = TypeNode_get_struct_union(type); PathNode key_path = {path, PATH_KEY, NULL}; Py_ssize_t tag_field_size; - const char *tag_field = unicode_str_and_size(lookup->tag_field, &tag_field_size); + const char *tag_field = unicode_str_and_size( + Lookup_tag_field(lookup), &tag_field_size + ); /* Cache the current input position in case we need to reset it once the * tag is found */ @@ -7782,16 +7935,10 @@ mpack_decode_struct_union( if (key_size < 0) return NULL; if (key_size == tag_field_size && memcmp(key, tag_field, key_size) == 0) { - Py_ssize_t tag_size; - char *tag = NULL; - PathNode tag_path = {path, PATH_TAG, lookup->tag_field}; - tag_size = mpack_decode_cstr(self, &tag, &tag_path); - if (tag_size < 0) return NULL; - - StructMetaObject *struct_type = (StructMetaObject *)StrLookup_Get(lookup, tag, tag_size); - if (struct_type == NULL) { - return ms_invalid_cstr_value(tag, tag_size, &tag_path); - } + /* Decode and lookup tag */ + PathNode tag_path = {path, PATH_TAG, Lookup_tag_field(lookup)}; + 
StructMetaObject *struct_type = mpack_decode_tag_and_lookup_type(self, lookup, &tag_path); + if (struct_type == NULL) return NULL; if (i == 0) { /* Common case, tag is first field. No need to reset, just mark * that the first field has been read. */ @@ -7811,7 +7958,7 @@ mpack_decode_struct_union( ms_raise_validation_error( path, "Object missing required field `%U`%U", - lookup->tag_field + Lookup_tag_field(lookup) ); return NULL; } @@ -9578,6 +9725,97 @@ json_decode_struct_array_inner( return NULL; } +/* Decode an integer. If the value fits in an int64_t, it will be stored in + * `out`, otherwise it will be stored in `uout`. A return value of -1 indicates + * an error. */ +static int +json_decode_cint(JSONDecoderState *self, int64_t *out, uint64_t *uout, PathNode *path) { + uint64_t mantissa = 0; + bool is_negative = false; + unsigned char c; + unsigned char *orig_input_pos = self->input_pos; + + if (MS_UNLIKELY(!json_peek_skip_ws(self, &c))) return -1; + + /* Parse minus sign (if present) */ + if (c == '-') { + self->input_pos++; + c = json_peek_or_null(self); + is_negative = true; + } + + /* Parse integer */ + if (MS_UNLIKELY(c == '0')) { + /* Ensure at most one leading zero */ + self->input_pos++; + c = json_peek_or_null(self); + if (MS_UNLIKELY(is_digit(c))) { + json_err_invalid(self, "invalid number"); + return -1; + } + } + else { + /* Parse the integer part of the number. + * + * We can read the first 19 digits safely into a uint64 without + * checking for overflow. Removing overflow checks from the loop gives + * a measurable performance boost. */ + size_t remaining = self->input_end - self->input_pos; + size_t n_safe = Py_MIN(19, remaining); + while (n_safe) { + c = *self->input_pos; + if (!is_digit(c)) goto end_integer; + self->input_pos++; + n_safe--; + mantissa = mantissa * 10 + (uint64_t)(c - '0'); + } + if (MS_UNLIKELY(remaining > 19)) { + /* Reading a 20th digit may or may not cause overflow. Any + * additional digits definitely will. 
Read the 20th digit (and + * check for a 21st), taking the slow path upon overflow. */ + c = *self->input_pos; + if (MS_UNLIKELY(is_digit(c))) { + self->input_pos++; + uint64_t mantissa2 = mantissa * 10 + (uint64_t)(c - '0'); + bool overflowed = (mantissa2 < mantissa) || ((mantissa2 - (uint64_t)(c - '0')) / 10) != mantissa; + if (MS_UNLIKELY(overflowed || is_digit(json_peek_or_null(self)))) { + goto error_not_int; + } + mantissa = mantissa2; + c = json_peek_or_null(self); + } + } + +end_integer: + /* There must be at least one digit */ + if (MS_UNLIKELY(mantissa == 0)) goto error_not_int; + } + + if (c == '.' || c == 'e' || c == 'E') goto error_not_int; + + if (is_negative) { + if (mantissa > 1ull << 63) goto error_not_int; + *out = -1 * (int64_t)mantissa; + } + else { + if (mantissa > LLONG_MAX) { + *uout = mantissa; + } + else { + *out = mantissa; + } + } + return 0; + +error_not_int: + /* Use skip to catch malformed JSON */ + self->input_pos = orig_input_pos; + if (json_skip(self) < 0) return -1; + + ms_error_with_path("Expected `int`%U", path); + return -1; +} + static Py_ssize_t json_decode_cstr(JSONDecoderState *self, char **out, PathNode *path) { unsigned char c; @@ -9593,11 +9831,10 @@ json_decode_cstr(JSONDecoderState *self, char **out, PathNode *path) { return json_decode_string_view(self, out, &is_ascii); } -static Py_ssize_t -json_decode_struct_array_tag( - JSONDecoderState *self, StructMetaObject *st_type, char **tag, PathNode *path +static int +json_ensure_array_nonempty( + JSONDecoderState *self, StructMetaObject *st_type, PathNode *path ) { - PathNode tag_path = {path, 0}; unsigned char c; /* Check for an early end to the array */ if (MS_UNLIKELY(!json_peek_skip_ws(self, &c))) return -1; @@ -9621,7 +9858,81 @@ json_decode_struct_array_tag( ); return -1; } - return json_decode_cstr(self, tag, &tag_path); + return 0; +} + +static int +json_ensure_tag_matches( + JSONDecoderState *self, PathNode *path, PyObject *expected_tag +) { + if 
(PyUnicode_CheckExact(expected_tag)) { + char *tag = NULL; + Py_ssize_t tag_size; + tag_size = json_decode_cstr(self, &tag, path); + if (tag_size < 0) return -1; + + /* Check that tag matches expected tag value */ + Py_ssize_t expected_size; + const char *expected_str = unicode_str_and_size( + expected_tag, &expected_size + ); + if (tag_size != expected_size || memcmp(tag, expected_str, expected_size) != 0) { + /* Tag doesn't match the expected value, error nicely */ + ms_invalid_cstr_value(tag, tag_size, path); + return -1; + } + } + else { + int64_t tag = 0; + uint64_t utag = 0; + if (json_decode_cint(self, &tag, &utag, path) < 0) return -1; + int64_t expected = PyLong_AsLongLong(expected_tag); + /* Tags must be int64s, if utag != 0 then we know the tags don't match. + * We parse the full uint64 value only to validate the message and + * raise a nice error */ + if (utag != 0) { + ms_invalid_cuint_value(utag, path); + return -1; + } + if (tag != expected) { + ms_invalid_cint_value(tag, path); + return -1; + } + } + return 0; +} + +static StructMetaObject * +json_decode_tag_and_lookup_type( + JSONDecoderState *self, Lookup *lookup, PathNode *path +) { + StructMetaObject *out = NULL; + if (Lookup_IsStrLookup(lookup)) { + Py_ssize_t tag_size; + char *tag = NULL; + tag_size = json_decode_cstr(self, &tag, path); + if (tag_size < 0) return NULL; + out = (StructMetaObject *)StrLookup_Get((StrLookup *)lookup, tag, tag_size); + if (out == NULL) { + ms_invalid_cstr_value(tag, tag_size, path); + } + } + else { + int64_t tag = 0; + uint64_t utag = 0; + if (json_decode_cint(self, &tag, &utag, path) < 0) return NULL; + if (utag == 0) { + out = (StructMetaObject *)IntLookup_GetInt64((IntLookup *)lookup, tag); + if (out == NULL) { + ms_invalid_cint_value(tag, path); + } + } + else { + /* tags can't be uint64 values, we only decode to give a nice error */ + ms_invalid_cuint_value(utag, path); + } + } + return out; } static PyObject * @@ -9635,18 +9946,9 @@ 
json_decode_struct_array( /* If this is a tagged struct, first read and validate the tag */ if (st_type->struct_tag_value != NULL) { - Py_ssize_t tag_size, expected_size; - char *tag = NULL; - const char *expected = unicode_str_and_size( - st_type->struct_tag_value, &expected_size - ); - tag_size = json_decode_struct_array_tag(self, st_type, &tag, path); - if (tag_size < 0) return NULL; - if (tag_size != expected_size || memcmp(tag, expected, expected_size) != 0) { - /* Tag doesn't match the expected value, error nicely */ - PathNode tag_path = {path, 0}; - return ms_invalid_cstr_value(tag, tag_size, &tag_path); - } + PathNode tag_path = {path, 0}; + if (json_ensure_array_nonempty(self, st_type, path) < 0) return NULL; + if (json_ensure_tag_matches(self, &tag_path, st_type->struct_tag_value) < 0) return NULL; starting_index = 1; } @@ -9658,22 +9960,14 @@ static PyObject * json_decode_struct_array_union( JSONDecoderState *self, TypeNode *type, PathNode *path ) { - char *tag = NULL; - Py_ssize_t tag_size; - StrLookupObject *lookup = TypeNode_get_struct_union(type); + PathNode tag_path = {path, 0}; + Lookup *lookup = TypeNode_get_struct_union(type); self->input_pos++; /* Skip '[' */ - - /* Decode the tag */ - tag_size = json_decode_struct_array_tag(self, NULL, &tag, path); - if (tag_size < 0) return NULL; - - /* Lookup Struct type from tag */ - StructMetaObject *struct_type = (StructMetaObject *)StrLookup_Get(lookup, tag, tag_size); - if (struct_type == NULL) { - PathNode tag_path = {path, 0}; - return ms_invalid_cstr_value(tag, tag_size, &tag_path); - } + /* Decode & lookup struct type from tag */ + if (json_ensure_array_nonempty(self, NULL, path) < 0) return NULL; + StructMetaObject *struct_type = json_decode_tag_and_lookup_type(self, lookup, &tag_path); + if (struct_type == NULL) return NULL; /* Finish decoding the rest of the struct */ return json_decode_struct_array_inner(self, struct_type, path, 1); @@ -9864,22 +10158,10 @@ json_decode_struct_map_inner( 
Struct_set_index(out, field_index, val); } else if (MS_UNLIKELY(field_index == -2)) { - /* Matches the tag field */ - Py_ssize_t tag_size, expected_size; - char *tag = NULL; + /* Decode and check that the tag value matches the expected value */ PathNode tag_path = {path, PATH_TAG, st_type->struct_tag_field}; - - /* Decode the tag value */ - tag_size = json_decode_cstr(self, &tag, &tag_path); - if (tag_size < 0) goto error; - - /* Check that the tag value matches the expected value */ - const char *expected = unicode_str_and_size( - st_type->struct_tag_value, &expected_size - ); - if (tag_size != expected_size || memcmp(tag, expected, expected_size) != 0) { - /* Tag doesn't match the expected value, error nicely */ - return ms_invalid_cstr_value(tag, tag_size, &tag_path); + if (json_ensure_tag_matches(self, &tag_path, st_type->struct_tag_value) < 0) { + return NULL; } } else { @@ -9912,10 +10194,12 @@ static PyObject * json_decode_struct_union( JSONDecoderState *self, TypeNode *type, PathNode *path ) { - StrLookupObject *lookup = TypeNode_get_struct_union(type); - PathNode tag_path = {path, PATH_TAG, lookup->tag_field}; + Lookup *lookup = TypeNode_get_struct_union(type); + PathNode tag_path = {path, PATH_TAG, Lookup_tag_field(lookup)}; Py_ssize_t tag_field_size; - const char *tag_field = unicode_str_and_size(lookup->tag_field, &tag_field_size); + const char *tag_field = unicode_str_and_size( + Lookup_tag_field(lookup), &tag_field_size + ); self->input_pos++; /* Skip '{' */ @@ -9970,14 +10254,9 @@ json_decode_struct_union( /* Parse value */ if (tag_found) { - /* Parse tag string */ - char *tag = NULL; - Py_ssize_t tag_size = json_decode_cstr(self, &tag, &tag_path); - if (tag_size < 0) return NULL; - StructMetaObject *st_type = (StructMetaObject *)StrLookup_Get(lookup, tag, tag_size); - if (st_type == NULL) { - return ms_invalid_cstr_value(tag, tag_size, &tag_path); - } + /* Decode & lookup struct type from tag */ + StructMetaObject *st_type = 
json_decode_tag_and_lookup_type(self, lookup, &tag_path); + if (st_type == NULL) return NULL; if (i != 0) { /* tag wasn't first field, reset decoder position */ self->input_pos = orig_input_pos; @@ -9992,7 +10271,7 @@ json_decode_struct_union( ms_raise_validation_error( path, "Object missing required field `%U`%U", - lookup->tag_field + Lookup_tag_field(lookup) ); return NULL; } diff --git a/tests/basic_typing_examples.py b/tests/basic_typing_examples.py index b36cb8bd..563ba3ae 100644 --- a/tests/basic_typing_examples.py +++ b/tests/basic_typing_examples.py @@ -154,13 +154,19 @@ class Test3(msgspec.Struct, tag=False): class Test4(msgspec.Struct, tag="mytag"): pass - class Test5(msgspec.Struct, tag=lambda n: n.lower()): + class Test5(msgspec.Struct, tag=123): pass - class Test6(msgspec.Struct, tag_field=None): + class Test6(msgspec.Struct, tag=str.lower): pass - class Test7(msgspec.Struct, tag_field="type"): + class Test7(msgspec.Struct, tag=lambda n: len(n)): + pass + + class Test8(msgspec.Struct, tag_field=None): + pass + + class Test9(msgspec.Struct, tag_field="type"): pass diff --git a/tests/test_common.py b/tests/test_common.py index 1ad5a4ae..3dad0948 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -129,15 +129,14 @@ class Test(enum.IntEnum): @pytest.mark.parametrize( "values", [ - [0, 1, 2, 2**64], [0, 1, 2, -(2**63) - 1], - [0, 1, 2, 2**63 + 1, -(2**64)], + [0, 1, 2, 2**63], ], ) def test_int_lookup_values_out_of_range(self, values): myenum = enum.IntEnum("myenum", [(f"x{i}", v) for i, v in enumerate(values)]) - with pytest.raises(OverflowError): + with pytest.raises(NotImplementedError): msgspec.msgpack.Decoder(myenum) def test_msgspec_lookup_overwritten(self): @@ -183,9 +182,8 @@ def test_compact(self, values): "values", [ [-(2**63), 2**63 - 1, 0], - [-(2**63), 2**64 - 1, 0], - [2**64 - 2, 2**64 - 3, 2**64 - 1], - [2**64 - 2, 2**64 - 3, 2**64 - 1, 0, 2, 3, 4, 5, 6], + [2**63 - 2, 2**63 - 3, 2**63 - 1], + [2**63 - 2, 2**63 - 3, 2**63 - 1, 
0, 2, 3, 4, 5, 6], ], ) def test_hashtable(self, values): @@ -346,15 +344,14 @@ def test_empty_errors(self): @pytest.mark.parametrize( "values", [ - [0, 1, 2, 2**64], + [0, 1, 2, 2**63], [0, 1, 2, -(2**63) - 1], - [0, 1, 2, 2**63 + 1, -(2**64)], ], ) def test_int_literal_values_out_of_range(self, values): literal = Literal[tuple(values)] - with pytest.raises(OverflowError): + with pytest.raises(NotImplementedError): msgspec.msgpack.Decoder(literal) @pytest.mark.parametrize( @@ -437,7 +434,7 @@ def test_multiple_literals(self): for val in [-1, -2, -3, "apple", "banana"]: assert dec.decode(msgspec.msgpack.encode(val)) == val - with pytest.raises(msgspec.DecodeError, match="Invalid enum value `4`"): + with pytest.raises(msgspec.DecodeError, match="Invalid enum value 4"): dec.decode(msgspec.msgpack.encode(4)) with pytest.raises(msgspec.DecodeError, match="Invalid enum value 'carrot'"): @@ -457,7 +454,7 @@ def test_nested_literals(self): for val in [-1, -2, -3, "apple", "banana"]: assert dec.decode(msgspec.msgpack.encode(val)) == val - with pytest.raises(msgspec.DecodeError, match="Invalid enum value `4`"): + with pytest.raises(msgspec.DecodeError, match="Invalid enum value 4"): dec.decode(msgspec.msgpack.encode(4)) with pytest.raises(msgspec.DecodeError, match="Invalid enum value 'carrot'"): @@ -623,9 +620,34 @@ class Test2(msgspec.Struct, tag_field="bar", array_like=array_like): assert "the same `tag_field`" in str(rec.value) assert repr(typ) in str(rec.value) + @pytest.mark.parametrize("array_like", [False, True]) + def test_err_union_struct_mix_int_str_tags(self, proto, array_like): + class Test1(msgspec.Struct, tag=1, array_like=array_like): + x: int + + class Test2(msgspec.Struct, tag="two", array_like=array_like): + x: int + + typ = Union[Test1, Test2] + + with pytest.raises(TypeError) as rec: + proto.Decoder(typ) + + assert "not supported" in str(rec.value) + assert "both `int` and `str` tags" in str(rec.value) + assert repr(typ) in str(rec.value) + 
@pytest.mark.parametrize("array_like", [False, True]) @pytest.mark.parametrize( - "tags", [("a", "b", "b"), ("a", "a", "b"), ("a", "b", "a")] + "tags", + [ + ("a", "b", "b"), + ("a", "a", "b"), + ("a", "b", "a"), + (1, 2, 2), + (1, 1, 2), + (1, 2, 1), + ], ) def test_err_union_struct_non_unique_tag_values(self, proto, array_like, tags): class Test1(msgspec.Struct, tag=tags[0], array_like=array_like): @@ -646,13 +668,20 @@ class Test3(msgspec.Struct, tag=tags[2], array_like=array_like): assert "unique `tag`" in str(rec.value) assert repr(typ) in str(rec.value) - def test_decode_struct_union(self, proto): - class Test1(msgspec.Struct, tag=True): + @pytest.mark.parametrize( + "tag1, tag2, unknown", + [ + ("Test1", "Test2", "Test3"), + (123, -123, 0), + ], + ) + def test_decode_struct_union(self, proto, tag1, tag2, unknown): + class Test1(msgspec.Struct, tag=tag1): a: int b: int c: int = 0 - class Test2(msgspec.Struct, tag=True): + class Test2(msgspec.Struct, tag=tag2): x: int y: int @@ -660,22 +689,22 @@ class Test2(msgspec.Struct, tag=True): enc = proto.Encoder() # Tag can be in any position - assert dec.decode(enc.encode({"type": "Test1", "a": 1, "b": 2})) == Test1(1, 2) - assert dec.decode(enc.encode({"a": 1, "type": "Test1", "b": 2})) == Test1(1, 2) - assert dec.decode(enc.encode({"x": 1, "y": 2, "type": "Test2"})) == Test2(1, 2) + assert dec.decode(enc.encode({"type": tag1, "a": 1, "b": 2})) == Test1(1, 2) + assert dec.decode(enc.encode({"a": 1, "type": tag1, "b": 2})) == Test1(1, 2) + assert dec.decode(enc.encode({"x": 1, "y": 2, "type": tag2})) == Test2(1, 2) # Optional fields still work - assert dec.decode( - enc.encode({"type": "Test1", "a": 1, "b": 2, "c": 3}) - ) == Test1(1, 2, 3) - assert dec.decode( - enc.encode({"a": 1, "b": 2, "c": 3, "type": "Test1"}) - ) == Test1(1, 2, 3) + assert dec.decode(enc.encode({"type": tag1, "a": 1, "b": 2, "c": 3})) == Test1( + 1, 2, 3 + ) + assert dec.decode(enc.encode({"a": 1, "b": 2, "c": 3, "type": tag1})) == Test1( + 1, 
2, 3 + ) # Extra fields still ignored - assert dec.decode( - enc.encode({"a": 1, "b": 2, "d": 4, "type": "Test1"}) - ) == Test1(1, 2) + assert dec.decode(enc.encode({"a": 1, "b": 2, "d": 4, "type": tag1})) == Test1( + 1, 2 + ) # Tag missing with pytest.raises(msgspec.DecodeError) as rec: @@ -684,48 +713,55 @@ class Test2(msgspec.Struct, tag=True): # Tag wrong type with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(enc.encode({"type": 1, "a": 1, "b": 2})) - assert "Expected `str`" in str(rec.value) + dec.decode(enc.encode({"type": 123.456, "a": 1, "b": 2})) + assert f"Expected `{type(tag1).__name__}`" in str(rec.value) assert "`$.type`" in str(rec.value) # Tag unknown with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(enc.encode({"type": "bad", "a": 1, "b": 2})) - assert "Invalid value 'bad' - at `$.type`" == str(rec.value) + dec.decode(enc.encode({"type": unknown, "a": 1, "b": 2})) + assert f"Invalid value {unknown!r} - at `$.type`" == str(rec.value) - def test_decode_struct_array_union(self, proto): - class Test1(msgspec.Struct, tag=True, array_like=True): + @pytest.mark.parametrize( + "tag1, tag2, tag3, unknown", + [ + ("Test1", "Test2", "Test3", "Test4"), + (123, -123, 0, -1), + ], + ) + def test_decode_struct_array_union(self, proto, tag1, tag2, tag3, unknown): + class Test1(msgspec.Struct, tag=tag1, array_like=True): a: int b: int c: int = 0 - class Test2(msgspec.Struct, tag=True, array_like=True): + class Test2(msgspec.Struct, tag=tag2, array_like=True): x: int y: int - class Test3(msgspec.Struct, tag=True, array_like=True): + class Test3(msgspec.Struct, tag=tag3, array_like=True): pass dec = proto.Decoder(Union[Test1, Test2, Test3]) enc = proto.Encoder() # Decoding works - assert dec.decode(enc.encode(["Test1", 1, 2])) == Test1(1, 2) - assert dec.decode(enc.encode(["Test2", 3, 4])) == Test2(3, 4) - assert dec.decode(enc.encode(["Test3"])) == Test3() + assert dec.decode(enc.encode([tag1, 1, 2])) == Test1(1, 2) + assert 
dec.decode(enc.encode([tag2, 3, 4])) == Test2(3, 4) + assert dec.decode(enc.encode([tag3])) == Test3() # Optional & Extra fields still respected - assert dec.decode(enc.encode(["Test1", 1, 2, 3])) == Test1(1, 2, 3) - assert dec.decode(enc.encode(["Test1", 1, 2, 3, 4])) == Test1(1, 2, 3) + assert dec.decode(enc.encode([tag1, 1, 2, 3])) == Test1(1, 2, 3) + assert dec.decode(enc.encode([tag1, 1, 2, 3, 4])) == Test1(1, 2, 3) # Missing required field with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(enc.encode(["Test1", 1])) + dec.decode(enc.encode([tag1, 1])) assert "Expected `array` of at least length 3, got 2" in str(rec.value) # Type error has correct field index with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(enc.encode(["Test1", 1, "bad", 2])) + dec.decode(enc.encode([tag1, 1, "bad", 2])) assert "Expected `int`, got `str` - at `$[2]`" == str(rec.value) # Tag missing @@ -735,14 +771,14 @@ class Test3(msgspec.Struct, tag=True, array_like=True): # Tag wrong type with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(enc.encode([1, 2, 3, 4])) - assert "Expected `str`" in str(rec.value) + dec.decode(enc.encode([123.456, 2, 3, 4])) + assert f"Expected `{type(tag1).__name__}`" in str(rec.value) assert "`$[0]`" in str(rec.value) # Tag unknown with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(enc.encode(["bad", 1, 2, 3])) - assert "Invalid value 'bad' - at `$[0]`" == str(rec.value) + dec.decode(enc.encode([unknown, 1, 2, 3])) + assert f"Invalid value {unknown!r} - at `$[0]`" == str(rec.value) @pytest.mark.parametrize("array_like", [False, True]) def test_decode_struct_union_with_non_struct_types(self, array_like, proto): diff --git a/tests/test_json.py b/tests/test_json.py index c2e2de2e..2ceb160e 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1125,7 +1125,7 @@ def test_decode_intenum(self): assert x == FruitInt.APPLE def test_decode_intenum_invalid_value(self): - with pytest.raises(msgspec.DecodeError, match="Invalid 
enum value `3`"): + with pytest.raises(msgspec.DecodeError, match="Invalid enum value 3"): msgspec.json.decode(b"3", type=FruitInt) def test_decode_intenum_invalid_value_nested(self): @@ -1133,7 +1133,7 @@ class Test(msgspec.Struct): fruit: FruitInt with pytest.raises( - msgspec.DecodeError, match=r"Invalid enum value `3` - at `\$.fruit`" + msgspec.DecodeError, match=r"Invalid enum value 3 - at `\$.fruit`" ): msgspec.json.decode(b'{"fruit": 3}', type=Test) @@ -1162,7 +1162,7 @@ def test_literal(self, values): def test_int_literal_errors(self): dec = msgspec.json.Decoder(Literal[1, 2, 3]) - with pytest.raises(msgspec.DecodeError, match="Invalid enum value `4`"): + with pytest.raises(msgspec.DecodeError, match="Invalid enum value 4"): dec.decode(b"4") with pytest.raises(msgspec.DecodeError, match="Expected `int`, got `str`"): @@ -1763,29 +1763,31 @@ def test_decode_dict_malformed(self, s, error, type): class TestStruct: - @pytest.mark.parametrize("tag", [False, True]) + @pytest.mark.parametrize("tag", [False, "Test", 123]) def test_encode_empty_struct(self, tag): class Test(msgspec.Struct, tag=tag): pass s = msgspec.json.encode(Test()) if tag: - assert s == b'{"type":"Test"}' + expected = msgspec.json.encode({"type": tag}) + assert s == expected else: assert s == b"{}" - @pytest.mark.parametrize("tag", [False, True]) + @pytest.mark.parametrize("tag", [False, "Test", 123]) def test_encode_one_field_struct(self, tag): class Test(msgspec.Struct, tag=tag): a: int s = msgspec.json.encode(Test(a=1)) if tag: - assert s == b'{"type":"Test","a":1}' + expected = msgspec.json.encode({"type": tag, "a": 1}) + assert s == expected else: assert s == b'{"a":1}' - @pytest.mark.parametrize("tag", [False, True]) + @pytest.mark.parametrize("tag", [False, "Test", 123]) def test_encode_two_field_struct(self, tag): class Test(msgspec.Struct, tag=tag): a: int @@ -1793,7 +1795,8 @@ class Test(msgspec.Struct, tag=tag): s = msgspec.json.encode(Test(a=1, b="two")) if tag: - assert s == 
b'{"type":"Test","a":1,"b":"two"}' + expected = msgspec.json.encode({"type": tag, "a": 1, "b": "two"}) + assert s == expected else: assert s == b'{"a":1,"b":"two"}' @@ -1956,8 +1959,9 @@ def test_struct_recursive_definition(self): res = dec.decode(s) assert res == x - def test_decode_tagged_struct(self): - class Test(msgspec.Struct, tag=True): + @pytest.mark.parametrize("tag", ["Test", 0, 2**63 - 1, -(2**63)]) + def test_decode_tagged_struct(self, tag): + class Test(msgspec.Struct, tag=tag): a: int b: int @@ -1966,26 +1970,29 @@ class Test(msgspec.Struct, tag=True): # Test decode with and without tag for msg in [ {"a": 1, "b": 2}, - {"type": "Test", "a": 1, "b": 2}, - {"a": 1, "type": "Test", "b": 2}, + {"type": tag, "a": 1, "b": 2}, + {"a": 1, "type": tag, "b": 2}, ]: res = dec.decode(msgspec.json.encode(msg)) assert res == Test(1, 2) # Tag incorrect type - with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(msgspec.json.encode({"type": 1})) - assert "Expected `str`" in str(rec.value) - assert "`$.type`" in str(rec.value) + for bad in [False, 123.456]: + with pytest.raises(msgspec.DecodeError) as rec: + dec.decode(msgspec.json.encode({"type": bad})) + assert f"Expected `{type(tag).__name__}`" in str(rec.value) + assert "`$.type`" in str(rec.value) # Tag incorrect value + bad = -3 if isinstance(tag, int) else "bad" with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(msgspec.json.encode({"type": "bad"})) - assert "Invalid value 'bad'" in str(rec.value) + dec.decode(msgspec.json.encode({"type": bad})) + assert f"Invalid value {bad!r}" in str(rec.value) assert "`$.type`" in str(rec.value) - def test_decode_tagged_empty_struct(self): - class Test(msgspec.Struct, tag=True): + @pytest.mark.parametrize("tag", ["Test", 123, -123]) + def test_decode_tagged_empty_struct(self, tag): + class Test(msgspec.Struct, tag=tag): pass dec = msgspec.json.Decoder(Test) @@ -1995,7 +2002,7 @@ class Test(msgspec.Struct, tag=True): assert res == Test() # Tag present - res 
= dec.decode(msgspec.json.encode({"type": "Test"})) + res = dec.decode(msgspec.json.encode({"type": tag})) assert res == Test() @pytest.mark.parametrize( @@ -2016,7 +2023,7 @@ class Test(msgspec.Struct, tag=True): (b'{"a": 1 "b"}', r"expected ',' or '}'"), ], ) - def test_decode_tagged_struct_malformed(self, s, error): + def test_decode_struct_tag_malformed(self, s, error): class Test1(msgspec.Struct, tag=True): a: int b: int @@ -2024,6 +2031,65 @@ class Test1(msgspec.Struct, tag=True): with pytest.raises(msgspec.DecodeError, match=error): msgspec.json.decode(s, type=Test1) + @pytest.mark.parametrize("ndigits", range(19)) + @pytest.mark.parametrize("negative", [False, True]) + def test_decode_tagged_struct_int_tag(self, ndigits, negative): + if ndigits == 0: + s = b"0" + else: + s = "".join( + itertools.islice(itertools.cycle("123456789"), ndigits) + ).encode() + + tag = int(s) + if negative: + tag = -tag + + class Test(msgspec.Struct, tag=tag): + x: int + + t = Test(1) + msg = msgspec.json.encode(t) + assert msgspec.json.decode(msg, type=Test) == t + + def test_decode_tagged_struct_int_tag_uint64_always_invalid(self): + """Uint64 values aren't currently valid tag values, but we still want + to raise a good error message.""" + + class Test(msgspec.Struct, tag=123): + pass + + with pytest.raises(msgspec.DecodeError) as rec: + msgspec.json.decode(msgspec.json.encode({"type": 2**64 - 1}), type=Test) + assert f"Invalid value {2**64 - 1}" in str(rec.value) + assert "`$.type`" in str(rec.value) + + @pytest.mark.parametrize( + "s, error", + [ + (b'{"type": 00}', "invalid number"), + (b'{"type": -n123}', "invalid character"), + (b'{"type": 123n}', "expected ',' or '}'"), + (b'{"type": 123.}', "invalid number"), + (b'{"type": 123.n}', "invalid number"), + (b'{"type": 123e}', "invalid number"), + (b'{"type": 123en}', "invalid number"), + (b'{"type": 123, }', "trailing comma in object"), + (b'{"type": 123, "a" 1}', "expected ':'"), + (b'{"type": 123, "a": 1 "b"}', "expected 
',' or '}'"), + (b'{"type": nulp}', "invalid character"), + (b'{"type": "bad}', "truncated"), + (b'{"type": bad}', "invalid character"), + ], + ) + def test_decode_struct_int_tag_malformed(self, s, error): + class Test1(msgspec.Struct, tag=123): + a: int + b: int + + with pytest.raises(msgspec.DecodeError, match=error): + msgspec.json.decode(s, type=Test1) + class TestStructUnion: """Most functionality is tested in `test_common.py:TestStructUnion`, this only @@ -2057,6 +2123,35 @@ class Test2(msgspec.Struct, tag=True): with pytest.raises(msgspec.DecodeError, match=error): msgspec.json.decode(s, type=Union[Test1, Test2]) + @pytest.mark.parametrize( + "s, error", + [ + (b'{"type": 00}', "invalid number"), + (b'{"type": -n123}', "invalid character"), + (b'{"type": 123n}', "expected ',' or '}'"), + (b'{"type": 123.}', "invalid number"), + (b'{"type": 123.n}', "invalid number"), + (b'{"type": 123e}', "invalid number"), + (b'{"type": 123en}', "invalid number"), + (b'{"type": 123, }', "trailing comma in object"), + (b'{"type": 123, "a" 1}', "expected ':'"), + (b'{"type": 123, "a": 1 "b"}', "expected ',' or '}'"), + (b'{"type": nulp}', "invalid character"), + (b'{"type": "bad}', "truncated"), + (b'{"type": bad}', "invalid character"), + ], + ) + def test_decode_struct_union_int_tag_malformed(self, s, error): + class Test1(msgspec.Struct, tag=-123): + a: int + b: int + + class Test2(msgspec.Struct, tag=123): + pass + + with pytest.raises(msgspec.DecodeError, match=error): + msgspec.json.decode(s, type=Union[Test1, Test2]) + @pytest.mark.parametrize( "s", [ @@ -2076,6 +2171,25 @@ class Test2(msgspec.Struct, tag=True): res = msgspec.json.decode(s, type=Union[Test1, Test2]) assert res == Test1(1, 2) + @pytest.mark.parametrize( + "s", + [ + b' { "type" : -123 , "a" : 1 , "b" : 2 } ', + b' { "a" : 1 , "type" : -123 , "b" : 2 } ', + b' { "a" : 1 , "b" : 2 , "type" : -123 } ', + ], + ) + def test_decode_struct_union_int_tag_ignores_whitespace(self, s): + class 
Test1(msgspec.Struct, tag=-123): + a: int + b: int + + class Test2(msgspec.Struct, tag=123): + pass + + res = msgspec.json.decode(s, type=Union[Test1, Test2]) + assert res == Test1(1, 2) + class TestStructArray: @pytest.mark.parametrize("tag", [False, True]) @@ -2198,8 +2312,9 @@ class Point(msgspec.Struct, array_like=True): with pytest.raises(msgspec.DecodeError, match=error): msgspec.json.decode(s, type=Point) - def test_decode_tagged_struct(self): - class Test(msgspec.Struct, tag=True, array_like=True): + @pytest.mark.parametrize("tag", ["Test", 123]) + def test_decode_tagged_struct(self, tag): + class Test(msgspec.Struct, tag=tag, array_like=True): a: int b: int c: int = 0 @@ -2207,18 +2322,18 @@ class Test(msgspec.Struct, tag=True, array_like=True): dec = msgspec.json.Decoder(Test) # Decode with tag - res = dec.decode(msgspec.json.encode(["Test", 1, 2])) + res = dec.decode(msgspec.json.encode([tag, 1, 2])) assert res == Test(1, 2) - res = dec.decode(msgspec.json.encode(["Test", 1, 2, 3])) + res = dec.decode(msgspec.json.encode([tag, 1, 2, 3])) assert res == Test(1, 2, 3) # Trailing fields ignored - res = dec.decode(msgspec.json.encode(["Test", 1, 2, 3, 4])) + res = dec.decode(msgspec.json.encode([tag, 1, 2, 3, 4])) assert res == Test(1, 2, 3) # Missing required field errors with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(msgspec.json.encode(["Test", 1])) + dec.decode(msgspec.json.encode([tag, 1])) assert "Expected `array` of at least length 3, got 2" in str(rec.value) # Tag missing @@ -2228,30 +2343,32 @@ class Test(msgspec.Struct, tag=True, array_like=True): # Tag incorrect type with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(msgspec.json.encode([1, 2, 3])) - assert "Expected `str`" in str(rec.value) + dec.decode(msgspec.json.encode([123.456, 2, 3])) + assert f"Expected `{type(tag).__name__}`" in str(rec.value) assert "`$[0]`" in str(rec.value) # Tag incorrect value + bad = 0 if isinstance(tag, int) else "bad" with 
pytest.raises(msgspec.DecodeError) as rec: - dec.decode(msgspec.json.encode(["bad", 1, 2])) - assert "Invalid value 'bad'" in str(rec.value) + dec.decode(msgspec.json.encode([bad, 1, 2])) + assert f"Invalid value {bad!r}" in str(rec.value) assert "`$[0]`" in str(rec.value) # Field incorrect type correct index with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(msgspec.json.encode(["Test", "a", 2])) + dec.decode(msgspec.json.encode([tag, "a", 2])) assert "Expected `int`, got `str`" in str(rec.value) assert "`$[1]`" in str(rec.value) - def test_decode_tagged_empty_struct(self): - class Test(msgspec.Struct, tag=True, array_like=True): + @pytest.mark.parametrize("tag", ["Test", 123]) + def test_decode_tagged_empty_struct(self, tag): + class Test(msgspec.Struct, tag=tag, array_like=True): pass dec = msgspec.json.Decoder(Test) # Decode with tag - res = dec.decode(msgspec.json.encode(["Test", 1, 2])) + res = dec.decode(msgspec.json.encode([tag, 1, 2])) assert res == Test() # Tag missing @@ -2294,6 +2411,35 @@ class Test2(msgspec.Struct, tag=True, array_like=True): with pytest.raises(msgspec.DecodeError, match=error): msgspec.json.decode(s, type=Union[Test1, Test2]) + @pytest.mark.parametrize( + "s, error", + [ + (b"[,]", "invalid character"), + (b"[, 1]", "invalid character"), + (b"[nulp]", "invalid character"), + (b"[123, nulp]", "invalid character"), + (b"[", "truncated"), + (b"[123.n,", "invalid number"), + (b"[123en,", "invalid number"), + (b"[123", "truncated"), + (b"[123,", "truncated"), + (b"[123, ]", "trailing comma in array"), + (b"[123 g", r"expected ',' or ']'"), + (b"[123, 1 g", r"expected ',' or ']'"), + ], + ) + def test_decode_struct_array_like_union_int_tag_malformed(self, s, error): + class Test1(msgspec.Struct, tag=123, array_like=True): + x: int + y: int + z: int + + class Test2(msgspec.Struct, tag=-123, array_like=True): + pass + + with pytest.raises(msgspec.DecodeError, match=error): + msgspec.json.decode(s, type=Union[Test1, Test2]) + def 
test_decode_struct_array_union_ignores_whitespace(self): s = b' [ "Test1" , 1 , 2 ] ' @@ -2307,6 +2453,19 @@ class Test2(msgspec.Struct, tag=True, array_like=True): res = msgspec.json.decode(s, type=Union[Test1, Test2]) assert res == Test1(1, 2) + def test_decode_struct_array_union_int_tag_ignores_whitespace(self): + s = b" [ 123 , 1 , 2 ] " + + class Test1(msgspec.Struct, tag=123, array_like=True): + a: int + b: int + + class Test2(msgspec.Struct, tag=-123, array_like=True): + pass + + res = msgspec.json.decode(s, type=Union[Test1, Test2]) + assert res == Test1(1, 2) + class TestRaw: def test_encode_raw(self): diff --git a/tests/test_msgpack.py b/tests/test_msgpack.py index 37a8b8d8..ec81a523 100644 --- a/tests/test_msgpack.py +++ b/tests/test_msgpack.py @@ -98,6 +98,7 @@ class Point(NamedTuple): 2**16, 2**32 - 1, 2**32, + 2**63 - 1, 2**64 - 1, ] @@ -968,11 +969,11 @@ def test_int_enum(self): with pytest.raises(msgspec.DecodeError, match="truncated"): dec.decode(a[:-2]) - with pytest.raises(msgspec.DecodeError, match="Invalid enum value `1000`"): + with pytest.raises(msgspec.DecodeError, match="Invalid enum value 1000"): dec.decode(enc.encode(1000)) with pytest.raises( - msgspec.DecodeError, match=r"Invalid enum value `1000` - at `\$\[0\]`" + msgspec.DecodeError, match=r"Invalid enum value 1000 - at `\$\[0\]`" ): msgspec.msgpack.decode(enc.encode([1000]), type=List[FruitInt]) @@ -1001,11 +1002,11 @@ def test_int_literal(self): assert dec.decode(enc.encode(1)) == 1 - with pytest.raises(msgspec.DecodeError, match="Invalid enum value `1000`"): + with pytest.raises(msgspec.DecodeError, match="Invalid enum value 1000"): dec.decode(enc.encode(1000)) with pytest.raises( - msgspec.DecodeError, match=r"Invalid enum value `1000` - at `\$\[0\]`" + msgspec.DecodeError, match=r"Invalid enum value 1000 - at `\$\[0\]`" ): msgspec.msgpack.decode(enc.encode([1000]), type=List[literal]) @@ -1438,40 +1439,40 @@ def check(self, x): class TestStruct: - @pytest.mark.parametrize("tag", 
[False, True]) + @pytest.mark.parametrize("tag", [False, "Test", 123]) def test_encode_empty_struct(self, tag): class Test(msgspec.Struct, tag=tag): pass if tag: - msg = {"type": "Test"} + msg = {"type": tag} else: msg = {} s = msgspec.msgpack.encode(Test()) s2 = msgspec.msgpack.encode(msg) assert s == s2 - @pytest.mark.parametrize("tag", [False, True]) + @pytest.mark.parametrize("tag", [False, "Test", 123]) def test_encode_one_field_struct(self, tag): class Test(msgspec.Struct, tag=tag): a: int if tag: - msg = {"type": "Test", "a": 1} + msg = {"type": tag, "a": 1} else: msg = {"a": 1} s = msgspec.msgpack.encode(Test(a=1)) s2 = msgspec.msgpack.encode(msg) assert s == s2 - @pytest.mark.parametrize("tag", [False, True]) + @pytest.mark.parametrize("tag", [False, "Test", 123]) def test_encode_two_field_struct(self, tag): class Test(msgspec.Struct, tag=tag): a: int b: str if tag: - msg = {"type": "Test", "a": 1, "b": "two"} + msg = {"type": tag, "a": 1, "b": "two"} else: msg = {"a": 1, "b": "two"} s = msgspec.msgpack.encode(Test(a=1, b="two")) @@ -1608,8 +1609,9 @@ def test_struct_recursive_definition(self): res = dec.decode(s) assert res == x - def test_decode_tagged_struct(self): - class Test(msgspec.Struct, tag=True): + @pytest.mark.parametrize("tag", ["Test", 123, -123]) + def test_decode_tagged_struct(self, tag): + class Test(msgspec.Struct, tag=tag): a: int b: int @@ -1618,26 +1620,52 @@ class Test(msgspec.Struct, tag=True): # Test decode with and without tag for msg in [ {"a": 1, "b": 2}, - {"type": "Test", "a": 1, "b": 2}, - {"a": 1, "type": "Test", "b": 2}, + {"type": tag, "a": 1, "b": 2}, + {"a": 1, "type": tag, "b": 2}, ]: res = dec.decode(msgspec.msgpack.encode(msg)) assert res == Test(1, 2) # Tag incorrect type with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(msgspec.msgpack.encode({"type": 1})) - assert "Expected `str`" in str(rec.value) + dec.decode(msgspec.msgpack.encode({"type": 123.456})) + assert f"Expected `{type(tag).__name__}`" in 
str(rec.value) assert "`$.type`" in str(rec.value) # Tag incorrect value + bad = -3 if isinstance(tag, int) else "bad" + with pytest.raises(msgspec.DecodeError) as rec: + dec.decode(msgspec.msgpack.encode({"type": bad})) + assert f"Invalid value {bad!r}" in str(rec.value) + assert "`$.type`" in str(rec.value) + + @pytest.mark.parametrize("tag", [i for i in INTS if -(2**63) <= i < 2**63]) + def test_decode_tagged_struct_int_ranges(self, tag): + class Test(msgspec.Struct, tag=tag): + a: int + b: int + + dec = msgspec.msgpack.Decoder(Test) + t = Test(1, 2) + assert dec.decode(msgspec.msgpack.encode(t)) + + def test_decode_tagged_struct_int_tag_uint64_always_invalid(self): + """Uint64 values aren't currently valid tag values, but we still want + to raise a good error message.""" + + class Test(msgspec.Struct, tag=123): + pass + with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(msgspec.msgpack.encode({"type": "bad"})) - assert "Invalid value 'bad'" in str(rec.value) + msgspec.msgpack.decode( + msgspec.msgpack.encode({"type": 2**64 - 1}), type=Test + ) + assert f"Invalid value {2**64 - 1}" in str(rec.value) assert "`$.type`" in str(rec.value) - def test_decode_tagged_empty_struct(self): - class Test(msgspec.Struct, tag=True): + @pytest.mark.parametrize("tag", ["Test", 123, -123]) + def test_decode_tagged_empty_struct(self, tag): + class Test(msgspec.Struct, tag=tag): pass dec = msgspec.msgpack.Decoder(Test) @@ -1647,38 +1675,38 @@ class Test(msgspec.Struct, tag=True): assert res == Test() # Tag present - res = dec.decode(msgspec.msgpack.encode({"type": "Test"})) + res = dec.decode(msgspec.msgpack.encode({"type": tag})) assert res == Test() class TestStructArray: - @pytest.mark.parametrize("tag", [False, True]) + @pytest.mark.parametrize("tag", [False, "Test", 123]) def test_encode_empty_struct(self, tag): class Test(msgspec.Struct, array_like=True, tag=tag): pass s = msgspec.msgpack.encode(Test()) if tag: - msg = ["Test"] + msg = [tag] else: msg = [] s2 = 
msgspec.msgpack.encode(msg) assert s == s2 - @pytest.mark.parametrize("tag", [False, True]) + @pytest.mark.parametrize("tag", [False, "Test", 123]) def test_encode_one_field_struct(self, tag): class Test(msgspec.Struct, array_like=True, tag=tag): a: int s = msgspec.msgpack.encode(Test(a=1)) if tag: - msg = ["Test", 1] + msg = [tag, 1] else: msg = [1] s2 = msgspec.msgpack.encode(msg) assert s == s2 - @pytest.mark.parametrize("tag", [False, True]) + @pytest.mark.parametrize("tag", [False, "Test", 123]) def test_encode_two_field_struct(self, tag): class Test(msgspec.Struct, array_like=True, tag=tag): a: int @@ -1686,7 +1714,7 @@ class Test(msgspec.Struct, array_like=True, tag=tag): s = msgspec.msgpack.encode(Test(a=1, b="two")) if tag: - msg = ["Test", 1, "two"] + msg = [tag, 1, "two"] else: msg = [1, "two"] s2 = msgspec.msgpack.encode(msg) @@ -1760,8 +1788,9 @@ def test_struct_map_and_array_like_messages_cant_mix(self): with pytest.raises(msgspec.DecodeError, match="Expected `array`, got `object`"): array_dec.decode(map_msg) - def test_decode_tagged_struct(self): - class Test(msgspec.Struct, tag=True, array_like=True): + @pytest.mark.parametrize("tag", ["Test", -123, 123]) + def test_decode_tagged_struct(self, tag): + class Test(msgspec.Struct, tag=tag, array_like=True): a: int b: int c: int = 0 @@ -1769,18 +1798,18 @@ class Test(msgspec.Struct, tag=True, array_like=True): dec = msgspec.msgpack.Decoder(Test) # Decode with tag - res = dec.decode(msgspec.msgpack.encode(["Test", 1, 2])) + res = dec.decode(msgspec.msgpack.encode([tag, 1, 2])) assert res == Test(1, 2) - res = dec.decode(msgspec.msgpack.encode(["Test", 1, 2, 3])) + res = dec.decode(msgspec.msgpack.encode([tag, 1, 2, 3])) assert res == Test(1, 2, 3) # Trailing fields ignored - res = dec.decode(msgspec.msgpack.encode(["Test", 1, 2, 3, 4])) + res = dec.decode(msgspec.msgpack.encode([tag, 1, 2, 3, 4])) assert res == Test(1, 2, 3) # Missing required field errors with pytest.raises(msgspec.DecodeError) as rec: - 
dec.decode(msgspec.msgpack.encode(["Test", 1])) + dec.decode(msgspec.msgpack.encode([tag, 1])) assert "Expected `array` of at least length 3, got 2" in str(rec.value) # Tag missing @@ -1790,30 +1819,32 @@ class Test(msgspec.Struct, tag=True, array_like=True): # Tag incorrect type with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(msgspec.msgpack.encode([1, 2, 3])) - assert "Expected `str`" in str(rec.value) + dec.decode(msgspec.msgpack.encode([123.456, 2, 3])) + assert f"Expected `{type(tag).__name__}`" in str(rec.value) assert "`$[0]`" in str(rec.value) # Tag incorrect value + bad = -3 if isinstance(tag, int) else "bad" with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(msgspec.msgpack.encode(["bad", 1, 2])) - assert "Invalid value 'bad'" in str(rec.value) + dec.decode(msgspec.msgpack.encode([bad, 1, 2])) + assert f"Invalid value {bad!r}" in str(rec.value) assert "`$[0]`" in str(rec.value) # Field incorrect type correct index with pytest.raises(msgspec.DecodeError) as rec: - dec.decode(msgspec.msgpack.encode(["Test", "a", 2])) + dec.decode(msgspec.msgpack.encode([tag, "a", 2])) assert "Expected `int`, got `str`" in str(rec.value) assert "`$[1]`" in str(rec.value) - def test_decode_tagged_empty_struct(self): - class Test(msgspec.Struct, tag=True, array_like=True): + @pytest.mark.parametrize("tag", ["Test", 123, -123]) + def test_decode_tagged_empty_struct(self, tag): + class Test(msgspec.Struct, tag=tag, array_like=True): pass dec = msgspec.msgpack.Decoder(Test) # Decode with tag - res = dec.decode(msgspec.msgpack.encode(["Test", 1, 2])) + res = dec.decode(msgspec.msgpack.encode([tag, 1, 2])) assert res == Test() # Tag missing diff --git a/tests/test_struct.py b/tests/test_struct.py index e13f0f37..349d3b85 100644 --- a/tests/test_struct.py +++ b/tests/test_struct.py @@ -1040,6 +1040,9 @@ class TestTagAndTagField: # tag str ({"tag": "test"}, "type", "test"), (dict(tag="test", tag_field="kind"), "kind", "test"), + # tag int + ({"tag": 1}, "type", 
1), + (dict(tag=1, tag_field="kind"), "kind", 1), # tag callable (dict(tag=lambda n: n.lower()), "type", "test"), (dict(tag=lambda n: n.lower(), tag_field="kind"), "kind", "test"), @@ -1068,6 +1071,11 @@ class Test(Struct, **opts): ({"tag": "test"}, {"tag": "test2"}, "type", "test2"), ({"tag": "test"}, {"tag": None}, "type", "test"), ({"tag": "test"}, {"tag_field": "foo"}, "foo", "test"), + # tag int + ({"tag": 1}, {}, "type", 1), + ({"tag": 1}, {"tag": "test2"}, "type", "test2"), + ({"tag": 1}, {"tag": None}, "type", 1), + ({"tag": 1}, {"tag_field": "foo"}, "foo", 1), # tag callable ({"tag": lambda n: n.lower()}, {}, "type", "s2"), ({"tag": lambda n: n.lower()}, {"tag": False}, None, None), @@ -1087,7 +1095,14 @@ class S2(S1, **opts2): @pytest.mark.parametrize("tag", [b"bad", lambda n: b"bad"]) def test_tag_wrong_type(self, tag): - with pytest.raises(TypeError, match="`tag` must be a `str`"): + with pytest.raises(TypeError, match="`tag` must be a `str` or an `int`"): + + class Test(Struct, tag=tag): + pass + + @pytest.mark.parametrize("tag", [-(2**63) - 1, 2**63]) + def test_tag_integer_out_of_range(self, tag): + with pytest.raises(ValueError, match="Integer `tag` values must be"): class Test(Struct, tag=tag): pass