diff --git a/bindings/python/CHANGELOG.md b/bindings/python/CHANGELOG.md index 8416acff..eeb65a54 100644 --- a/bindings/python/CHANGELOG.md +++ b/bindings/python/CHANGELOG.md @@ -3,6 +3,15 @@ --- +# Changes in Version 1.9.0 (2025/XX/XX) + +- Providing a schema now enforces strict type adherence for data. + If a result contains a field whose value does not match the schema's type for that field, a TypeError will be raised. + Note that null values (``None``/``NaN``) are valid for fields of every type. + To suppress these errors and instead silently convert such mismatches to ``NaN``, pass the ``allow_invalid=True`` argument to your ``pymongoarrow`` API call. + For example, a result with a field of type ``int`` but with a string value will now raise a TypeError, + unless ``allow_invalid=True`` is passed, in which case the result's field will have a value of ``NaN``. + # Changes in Version 1.8.0 (2025/05/12) - Add support for PyArrow 20.0. diff --git a/bindings/python/pymongoarrow/api.py b/bindings/python/pymongoarrow/api.py index 11a43b88..ecb4a794 100644 --- a/bindings/python/pymongoarrow/api.py +++ b/bindings/python/pymongoarrow/api.py @@ -72,7 +72,7 @@ _MAX_WRITE_BATCH_SIZE = max(100000, MAX_WRITE_BATCH_SIZE) -def find_arrow_all(collection, query, *, schema=None, **kwargs): +def find_arrow_all(collection, query, *, schema=None, allow_invalid=False, **kwargs): """Method that returns the results of a find query as a :class:`pyarrow.Table` instance. @@ -83,6 +83,8 @@ def find_arrow_all(collection, query, *, schema=None, **kwargs): - `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`. If the schema is not given, it will be inferred using the data in the result set. + - `allow_invalid` (optional): If set to ``True``, + results will have all fields that do not conform to the schema silently converted to NaN. Additional keyword-arguments passed to this method will be passed directly to the underlying ``find`` operation. 
@@ -90,7 +92,9 @@ def find_arrow_all(collection, query, *, schema=None, **kwargs): :Returns: An instance of class:`pyarrow.Table`. """ - context = PyMongoArrowContext(schema, codec_options=collection.codec_options) + context = PyMongoArrowContext( + schema, codec_options=collection.codec_options, allow_invalid=allow_invalid + ) for opt in ("cursor_type",): if kwargs.pop(opt, None): @@ -110,7 +114,7 @@ def find_arrow_all(collection, query, *, schema=None, **kwargs): return context.finish() -def aggregate_arrow_all(collection, pipeline, *, schema=None, **kwargs): +def aggregate_arrow_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs): """Method that returns the results of an aggregation pipeline as a :class:`pyarrow.Table` instance. @@ -121,6 +125,8 @@ def aggregate_arrow_all(collection, pipeline, *, schema=None, **kwargs): - `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`. If the schema is not given, it will be inferred using the data in the result set. + - `allow_invalid` (optional): If set to ``True``, + results will have all fields that do not conform to the schema silently converted to NaN. Additional keyword-arguments passed to this method will be passed directly to the underlying ``aggregate`` operation. @@ -128,7 +134,9 @@ def aggregate_arrow_all(collection, pipeline, *, schema=None, **kwargs): :Returns: An instance of class:`pyarrow.Table`. 
""" - context = PyMongoArrowContext(schema, codec_options=collection.codec_options) + context = PyMongoArrowContext( + schema, codec_options=collection.codec_options, allow_invalid=allow_invalid + ) if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]): msg = ( @@ -165,7 +173,7 @@ def _arrow_to_pandas(arrow_table): return arrow_table.to_pandas(split_blocks=True, self_destruct=True) -def find_pandas_all(collection, query, *, schema=None, **kwargs): +def find_pandas_all(collection, query, *, schema=None, allow_invalid=False, **kwargs): """Method that returns the results of a find query as a :class:`pandas.DataFrame` instance. @@ -176,6 +184,8 @@ def find_pandas_all(collection, query, *, schema=None, **kwargs): - `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`. If the schema is not given, it will be inferred using the data in the result set. + - `allow_invalid` (optional): If set to ``True``, + results will have all fields that do not conform to the schema silently converted to NaN. Additional keyword-arguments passed to this method will be passed directly to the underlying ``find`` operation. @@ -183,10 +193,12 @@ def find_pandas_all(collection, query, *, schema=None, **kwargs): :Returns: An instance of class:`pandas.DataFrame`. """ - return _arrow_to_pandas(find_arrow_all(collection, query, schema=schema, **kwargs)) + return _arrow_to_pandas( + find_arrow_all(collection, query, schema=schema, allow_invalid=allow_invalid, **kwargs) + ) -def aggregate_pandas_all(collection, pipeline, *, schema=None, **kwargs): +def aggregate_pandas_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs): """Method that returns the results of an aggregation pipeline as a :class:`pandas.DataFrame` instance. @@ -197,6 +209,8 @@ def aggregate_pandas_all(collection, pipeline, *, schema=None, **kwargs): - `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`. 
If the schema is not given, it will be inferred using the data in the result set. + - `allow_invalid` (optional): If set to ``True``, + results will have all fields that do not conform to the schema silently converted to NaN. Additional keyword-arguments passed to this method will be passed directly to the underlying ``aggregate`` operation. @@ -204,7 +218,11 @@ def aggregate_pandas_all(collection, pipeline, *, schema=None, **kwargs): :Returns: An instance of class:`pandas.DataFrame`. """ - return _arrow_to_pandas(aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs)) + return _arrow_to_pandas( + aggregate_arrow_all( + collection, pipeline, schema=schema, allow_invalid=allow_invalid, **kwargs + ) + ) def _arrow_to_numpy(arrow_table, schema=None): @@ -227,7 +245,7 @@ def _arrow_to_numpy(arrow_table, schema=None): return container -def find_numpy_all(collection, query, *, schema=None, **kwargs): +def find_numpy_all(collection, query, *, schema=None, allow_invalid=False, **kwargs): """Method that returns the results of a find query as a :class:`dict` instance whose keys are field names and values are :class:`~numpy.ndarray` instances bearing the appropriate dtype. @@ -239,6 +257,8 @@ def find_numpy_all(collection, query, *, schema=None, **kwargs): - `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`. If the schema is not given, it will be inferred using the data in the result set. + - `allow_invalid` (optional): If set to ``True``, + results will have all fields that do not conform to the schema silently converted to NaN. Additional keyword-arguments passed to this method will be passed directly to the underlying ``find`` operation. @@ -255,10 +275,13 @@ def find_numpy_all(collection, query, *, schema=None, **kwargs): :Returns: An instance of :class:`dict`. 
""" - return _arrow_to_numpy(find_arrow_all(collection, query, schema=schema, **kwargs), schema) + return _arrow_to_numpy( + find_arrow_all(collection, query, schema=schema, allow_invalid=allow_invalid, **kwargs), + schema, + ) -def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs): +def aggregate_numpy_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs): """Method that returns the results of an aggregation pipeline as a :class:`dict` instance whose keys are field names and values are :class:`~numpy.ndarray` instances bearing the appropriate dtype. @@ -270,6 +293,8 @@ def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs): - `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`. If the schema is not given, it will be inferred using the data in the result set. + - `allow_invalid` (optional): If set to ``True``, + results will have all fields that do not conform to the schema silently converted to NaN. Additional keyword-arguments passed to this method will be passed directly to the underlying ``aggregate`` operation. @@ -287,7 +312,10 @@ def aggregate_numpy_all(collection, pipeline, *, schema=None, **kwargs): An instance of :class:`dict`. """ return _arrow_to_numpy( - aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs), schema + aggregate_arrow_all( + collection, pipeline, schema=schema, allow_invalid=allow_invalid, **kwargs + ), + schema, ) @@ -326,7 +354,7 @@ def _arrow_to_polars(arrow_table: pa.Table): return pl.from_arrow(arrow_table_without_extensions) -def find_polars_all(collection, query, *, schema=None, **kwargs): +def find_polars_all(collection, query, *, schema=None, allow_invalid=False, **kwargs): """Method that returns the results of a find query as a :class:`polars.DataFrame` instance. @@ -337,6 +365,8 @@ def find_polars_all(collection, query, *, schema=None, **kwargs): - `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`. 
If the schema is not given, it will be inferred using the data in the result set. + - `allow_invalid` (optional): If set to ``True``, + results will have all fields that do not conform to the schema silently converted to NaN. Additional keyword-arguments passed to this method will be passed directly to the underlying ``find`` operation. @@ -346,10 +376,12 @@ def find_polars_all(collection, query, *, schema=None, **kwargs): .. versionadded:: 1.3 """ - return _arrow_to_polars(find_arrow_all(collection, query, schema=schema, **kwargs)) + return _arrow_to_polars( + find_arrow_all(collection, query, schema=schema, allow_invalid=allow_invalid, **kwargs) + ) -def aggregate_polars_all(collection, pipeline, *, schema=None, **kwargs): +def aggregate_polars_all(collection, pipeline, *, schema=None, allow_invalid=False, **kwargs): """Method that returns the results of an aggregation pipeline as a :class:`polars.DataFrame` instance. @@ -360,6 +392,8 @@ def aggregate_polars_all(collection, pipeline, *, schema=None, **kwargs): - `schema` (optional): Instance of :class:`~pymongoarrow.schema.Schema`. If the schema is not given, it will be inferred using the data in the result set. + - `allow_invalid` (optional): If set to ``True``, + results will have all fields that do not conform to the schema silently converted to NaN. Additional keyword-arguments passed to this method will be passed directly to the underlying ``aggregate`` operation. @@ -367,7 +401,11 @@ def aggregate_polars_all(collection, pipeline, *, schema=None, **kwargs): :Returns: An instance of class:`polars.DataFrame`. 
""" - return _arrow_to_polars(aggregate_arrow_all(collection, pipeline, schema=schema, **kwargs)) + return _arrow_to_polars( + aggregate_arrow_all( + collection, pipeline, schema=schema, allow_invalid=allow_invalid, **kwargs + ) + ) def _transform_bwe(bwe, offset): diff --git a/bindings/python/pymongoarrow/context.py b/bindings/python/pymongoarrow/context.py index 4c7c8f79..d8e07bb6 100644 --- a/bindings/python/pymongoarrow/context.py +++ b/bindings/python/pymongoarrow/context.py @@ -19,13 +19,15 @@ class PyMongoArrowContext: """A context for converting BSON-formatted data to an Arrow Table.""" - def __init__(self, schema, codec_options=None): + def __init__(self, schema, codec_options=None, allow_invalid=False): """Initialize the context. :Parameters: - `schema`: Instance of :class:`~pymongoarrow.schema.Schema`. - `builder_map`: Mapping of utf-8-encoded field names to :class:`~pymongoarrow.builders._BuilderBase` instances. + - `allow_invalid`: If set to ``True``, + results will have all fields that do not conform to the schema silently converted to NaN. """ self.schema = schema if self.schema is None and codec_options is not None: @@ -40,7 +42,9 @@ def __init__(self, schema, codec_options=None): # Delayed import to prevent import errors for unbuilt library. from pymongoarrow.lib import BuilderManager - self.manager = BuilderManager(schema_map, self.schema is not None, self.tzinfo) + self.manager = BuilderManager( + schema_map, self.schema is not None, self.tzinfo, allow_invalid=allow_invalid + ) self.schema_map = schema_map def process_bson_stream(self, stream): diff --git a/bindings/python/pymongoarrow/lib.pyx b/bindings/python/pymongoarrow/lib.pyx index 4ff35c3e..f6c02003 100644 --- a/bindings/python/pymongoarrow/lib.pyx +++ b/bindings/python/pymongoarrow/lib.pyx @@ -56,6 +56,59 @@ cdef const bson_t* bson_reader_read_safe(bson_reader_t* stream_reader) except? 
N raise InvalidBSON("Could not read BSON document stream") return doc +cdef str _bson_type_name(uint8_t type_t): + if type_t == BSON_TYPE_EOD: + result = "EOD" + elif type_t == BSON_TYPE_UTF8: + result = "string" + elif type_t == BSON_TYPE_DOUBLE: + result = "double" + elif type_t == BSON_TYPE_DOCUMENT: + result = "document" + elif type_t == BSON_TYPE_ARRAY: + result = "array" + elif type_t == BSON_TYPE_BINARY: + result = "binary" + elif type_t == BSON_TYPE_UNDEFINED: + result = "undefined" + elif type_t == BSON_TYPE_OID: + result = "objectId" + elif type_t == BSON_TYPE_BOOL: + result = "boolean" + elif type_t == BSON_TYPE_DATE_TIME: + result = "datetime" + elif type_t == BSON_TYPE_NULL: + result = "null" + elif type_t == BSON_TYPE_REGEX: + result = "regex" + elif type_t == BSON_TYPE_DBPOINTER: + result = "dbpointer" + elif type_t == BSON_TYPE_CODE: + result = "code" + elif type_t == BSON_TYPE_SYMBOL: + result = "symbol" + elif type_t == BSON_TYPE_CODEWSCOPE: + result = "codewscope" + elif type_t == BSON_TYPE_INT32: + result = "int32" + elif type_t == BSON_TYPE_TIMESTAMP: + result = "timestamp" + elif type_t == BSON_TYPE_INT64: + result = "int64" + elif type_t == BSON_TYPE_DECIMAL128: + result = "decimal128" + elif type_t == BSON_TYPE_MAXKEY: + result = "maxkey" + elif type_t == BSON_TYPE_MINKEY: + result = "minkey" + elif type_t == ARROW_TYPE_DATE32: + result = "date32" + elif type_t == ARROW_TYPE_DATE64: + result = "date64" + else: + result = f"Unknown type: {str(type_t)}" + return result + cdef class BuilderManager: cdef: @@ -66,8 +119,9 @@ cdef class BuilderManager: bint has_schema object tzinfo object pool + bint allow_invalid - def __cinit__(self, dict schema_map, bint has_schema, object tzinfo): + def __cinit__(self, dict schema_map, bint has_schema, object tzinfo, bint allow_invalid): self.has_schema = has_schema self.tzinfo = tzinfo self.count = 0 @@ -75,6 +129,7 @@ cdef class BuilderManager: self.parent_names = {} self.parent_types = {} self.pool = 
default_memory_pool() + self.allow_invalid = allow_invalid # Unpack the schema map. for fname, (ftype, arrow_type) in schema_map.items(): name = fname.encode('utf-8') @@ -82,14 +137,14 @@ cdef class BuilderManager: if ftype == BSON_TYPE_DATE_TIME: if tzinfo is not None and arrow_type.tz is None: arrow_type = timestamp(arrow_type.unit, tz=tzinfo) # noqa: PLW2901 - self.builder_map[name] = DatetimeBuilder(dtype=arrow_type, memory_pool=self.pool) + self.builder_map[name] = DatetimeBuilder(dtype=arrow_type, memory_pool=self.pool, allow_invalid=allow_invalid) elif ftype == BSON_TYPE_BINARY: - self.builder_map[name] = BinaryBuilder(arrow_type.subtype, memory_pool=self.pool) + self.builder_map[name] = BinaryBuilder(arrow_type.subtype, memory_pool=self.pool, allow_invalid=allow_invalid) else: # We only use the doc_iter for binary arrays, which are handled already. - self.get_builder(name, ftype, nullptr) + self.get_builder(name, ftype, nullptr, allow_invalid=allow_invalid) - cdef _ArrayBuilderBase get_builder(self, cstring key, bson_type_t value_t, bson_iter_t * doc_iter): + cdef _ArrayBuilderBase get_builder(self, cstring key, bson_type_t value_t, bson_iter_t * doc_iter, bint allow_invalid): cdef _ArrayBuilderBase builder = None cdef bson_subtype_t subtype cdef const uint8_t *val_buf = NULL @@ -119,27 +174,27 @@ cdef class BuilderManager: raise ValueError('Did not pass a doc_iter!') bson_iter_binary (doc_iter, &subtype, &val_buf_len, &val_buf) - builder = BinaryBuilder(subtype, memory_pool=self.pool) + builder = BinaryBuilder(subtype, memory_pool=self.pool, allow_invalid=allow_invalid) elif value_t == ARROW_TYPE_DATE32: - builder = Date32Builder(memory_pool=self.pool) + builder = Date32Builder(memory_pool=self.pool, allow_invalid=allow_invalid) elif value_t == ARROW_TYPE_DATE64: - builder = Date64Builder(memory_pool=self.pool) + builder = Date64Builder(memory_pool=self.pool, allow_invalid=allow_invalid) elif value_t == BSON_TYPE_INT32: - builder = 
Int32Builder(memory_pool=self.pool) + builder = Int32Builder(memory_pool=self.pool, allow_invalid=allow_invalid) elif value_t == BSON_TYPE_INT64: - builder = Int64Builder(memory_pool=self.pool) + builder = Int64Builder(memory_pool=self.pool, allow_invalid=allow_invalid) elif value_t == BSON_TYPE_DOUBLE: - builder = DoubleBuilder(memory_pool=self.pool) + builder = DoubleBuilder(memory_pool=self.pool, allow_invalid=allow_invalid) elif value_t == BSON_TYPE_OID: - builder = ObjectIdBuilder(memory_pool=self.pool) + builder = ObjectIdBuilder(memory_pool=self.pool, allow_invalid=allow_invalid) elif value_t == BSON_TYPE_UTF8: - builder = StringBuilder(memory_pool=self.pool) + builder = StringBuilder(memory_pool=self.pool, allow_invalid=allow_invalid) elif value_t == BSON_TYPE_BOOL: - builder = BoolBuilder(memory_pool=self.pool) + builder = BoolBuilder(memory_pool=self.pool, allow_invalid=allow_invalid) elif value_t == BSON_TYPE_DECIMAL128: - builder = Decimal128Builder(memory_pool=self.pool) + builder = Decimal128Builder(memory_pool=self.pool, allow_invalid=allow_invalid) elif value_t == BSON_TYPE_CODE: - builder = CodeBuilder(memory_pool=self.pool) + builder = CodeBuilder(memory_pool=self.pool, allow_invalid=allow_invalid) self.builder_map[key] = builder return builder @@ -178,7 +233,7 @@ cdef class BuilderManager: # Get the builder. 
builder = <_ArrayBuilderBase>self.builder_map.get(full_key, None) if builder is None and not self.has_schema: - builder = self.get_builder(full_key, value_t, doc_iter) + builder = self.get_builder(full_key, value_t, doc_iter, True) if builder is None: continue @@ -279,6 +334,7 @@ cdef class BuilderManager: cdef class _ArrayBuilderBase: cdef: public uint8_t type_marker + public bint allow_invalid def append_values(self, values): for value in values: @@ -354,10 +410,11 @@ cdef class StringBuilder(_ArrayBuilderBase): cdef: shared_ptr[CStringBuilder] builder - def __cinit__(self, MemoryPool memory_pool=None): + def __cinit__(self, MemoryPool memory_pool=None, bint allow_invalid=False): cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) self.builder.reset(new CStringBuilder(pool)) self.type_marker = BSON_TYPE_UTF8 + self.allow_invalid = allow_invalid cdef CStatus append_raw(self, bson_iter_t * doc_iter, bson_type_t value_t): cdef const char* value @@ -365,17 +422,21 @@ cdef class StringBuilder(_ArrayBuilderBase): if value_t == BSON_TYPE_UTF8: value = bson_iter_utf8(doc_iter, &str_len) return self.builder.get().Append(value, str_len) - return self.builder.get().AppendNull() + if not self.allow_invalid and value_t != BSON_TYPE_NULL: + raise TypeError(f"Got unexpected type `{_bson_type_name(value_t)}` instead of expected type `{_bson_type_name(self.type_marker)}`") + else: + return self.builder.get().AppendNull() cdef shared_ptr[CArrayBuilder] get_builder(self): return self.builder cdef class CodeBuilder(StringBuilder): - def __cinit__(self, MemoryPool memory_pool=None): + def __cinit__(self, MemoryPool memory_pool=None, bint allow_invalid=False): cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) self.builder.reset(new CStringBuilder(pool)) self.type_marker = BSON_TYPE_CODE + self.allow_invalid = allow_invalid cdef CStatus append_raw(self, bson_iter_t * doc_iter, bson_type_t value_t): cdef const char * bson_str @@ -383,7 +444,10 @@ cdef class 
CodeBuilder(StringBuilder): if value_t == BSON_TYPE_CODE: bson_str = bson_iter_code(doc_iter, &str_len) return self.builder.get().Append(bson_str, str_len) - return self.builder.get().AppendNull() + if not self.allow_invalid and value_t != BSON_TYPE_NULL: + raise TypeError(f"Got unexpected type `{_bson_type_name(value_t)}` instead of expected type `{_bson_type_name(self.type_marker)}`") + else: + return self.builder.get().AppendNull() cdef shared_ptr[CArrayBuilder] get_builder(self): return self.builder @@ -395,16 +459,20 @@ cdef class CodeBuilder(StringBuilder): cdef class ObjectIdBuilder(_ArrayBuilderBase): cdef shared_ptr[CFixedSizeBinaryBuilder] builder - def __cinit__(self, MemoryPool memory_pool=None): + def __cinit__(self, MemoryPool memory_pool=None, bint allow_invalid=False): cdef shared_ptr[CDataType] dtype = fixed_size_binary(12) cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) self.builder.reset(new CFixedSizeBinaryBuilder(dtype, pool)) self.type_marker = BSON_TYPE_OID + self.allow_invalid = allow_invalid cdef CStatus append_raw(self, bson_iter_t * doc_iter, bson_type_t value_t): if value_t == BSON_TYPE_OID: return self.builder.get().Append(bson_iter_oid(doc_iter).bytes) - return self.builder.get().AppendNull() + if not self.allow_invalid and value_t != BSON_TYPE_NULL: + raise TypeError(f"Got unexpected type `{_bson_type_name(value_t)}` instead of expected type `{_bson_type_name(self.type_marker)}`") + else: + return self.builder.get().AppendNull() cdef shared_ptr[CArrayBuilder] get_builder(self): return self.builder @@ -416,10 +484,11 @@ cdef class ObjectIdBuilder(_ArrayBuilderBase): cdef class Int32Builder(_ArrayBuilderBase): cdef shared_ptr[CInt32Builder] builder - def __cinit__(self, MemoryPool memory_pool=None): + def __cinit__(self, MemoryPool memory_pool=None, bint allow_invalid=False): cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) self.builder.reset(new CInt32Builder(pool)) self.type_marker = BSON_TYPE_INT32 + 
self.allow_invalid = allow_invalid cdef CStatus append_raw(self, bson_iter_t * doc_iter, bson_type_t value_t) except *: cdef double dvalue @@ -435,13 +504,16 @@ cdef class Int32Builder(_ArrayBuilderBase): # Treat nan as null. dvalue = bson_iter_as_double(doc_iter) if isnan(dvalue): - return self.builder.get().AppendNull() + return self.builder.get().AppendNull() # Check for overflow errors. ivalue = bson_iter_as_int64(doc_iter) if ivalue > INT_MAX or ivalue < INT_MIN: raise OverflowError("Overflowed Int32 value") return self.builder.get().Append(ivalue) - return self.builder.get().AppendNull() + if not self.allow_invalid and value_t != BSON_TYPE_NULL: + raise TypeError(f"Got unexpected type `{_bson_type_name(value_t)}` instead of expected type `{_bson_type_name(self.type_marker)}`") + else: + return self.builder.get().AppendNull() cdef shared_ptr[CArrayBuilder] get_builder(self): return self.builder @@ -450,10 +522,11 @@ cdef class Int32Builder(_ArrayBuilderBase): cdef class Int64Builder(_ArrayBuilderBase): cdef shared_ptr[CInt64Builder] builder - def __cinit__(self, MemoryPool memory_pool=None): + def __cinit__(self, MemoryPool memory_pool=None, bint allow_invalid=False): cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) self.builder.reset(new CInt64Builder(pool)) self.type_marker = BSON_TYPE_INT64 + self.allow_invalid = allow_invalid cdef CStatus append_raw(self, bson_iter_t * doc_iter, bson_type_t value_t): cdef double dvalue @@ -468,7 +541,10 @@ cdef class Int64Builder(_ArrayBuilderBase): if isnan(dvalue): return self.builder.get().AppendNull() return self.builder.get().Append(bson_iter_as_int64(doc_iter)) - return self.builder.get().AppendNull() + if not self.allow_invalid and value_t != BSON_TYPE_NULL: + raise TypeError(f"Got unexpected type `{_bson_type_name(value_t)}` instead of expected type `{_bson_type_name(self.type_marker)}`") + else: + return self.builder.get().AppendNull() cdef shared_ptr[CArrayBuilder] get_builder(self): return 
self.builder @@ -477,10 +553,11 @@ cdef class Int64Builder(_ArrayBuilderBase): cdef class DoubleBuilder(_ArrayBuilderBase): cdef shared_ptr[CDoubleBuilder] builder - def __cinit__(self, MemoryPool memory_pool=None): + def __cinit__(self, MemoryPool memory_pool=None, bint allow_invalid=False): cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) self.builder.reset(new CDoubleBuilder(pool)) self.type_marker = BSON_TYPE_DOUBLE + self.allow_invalid = allow_invalid cdef CStatus append_raw(self, bson_iter_t * doc_iter, bson_type_t value_t): if (value_t == BSON_TYPE_DOUBLE or @@ -488,7 +565,10 @@ cdef class DoubleBuilder(_ArrayBuilderBase): value_t == BSON_TYPE_INT32 or value_t == BSON_TYPE_INT64): return self.builder.get().Append(bson_iter_as_double(doc_iter)) - return self.builder.get().AppendNull() + if not self.allow_invalid and value_t != BSON_TYPE_NULL: + raise TypeError(f"Got unexpected type `{_bson_type_name(value_t)}` instead of expected type `{_bson_type_name(self.type_marker)}`") + else: + return self.builder.get().AppendNull() cdef shared_ptr[CArrayBuilder] get_builder(self): return self.builder @@ -500,7 +580,7 @@ cdef class DatetimeBuilder(_ArrayBuilderBase): shared_ptr[CTimestampBuilder] builder def __cinit__(self, TimestampType dtype=timestamp('ms'), - MemoryPool memory_pool=None): + MemoryPool memory_pool=None, bint allow_invalid=False): cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) if dtype.unit != 'ms': raise TypeError("PyMongoArrow only supports millisecond " @@ -510,6 +590,7 @@ cdef class DatetimeBuilder(_ArrayBuilderBase): self.builder.reset(new CTimestampBuilder( pyarrow_unwrap_data_type(self.dtype), pool)) self.type_marker = BSON_TYPE_DATE_TIME + self.allow_invalid = allow_invalid @property def unit(self): @@ -518,7 +599,10 @@ cdef class DatetimeBuilder(_ArrayBuilderBase): cdef CStatus append_raw(self, bson_iter_t * doc_iter, bson_type_t value_t): if value_t == BSON_TYPE_DATE_TIME: return 
self.builder.get().Append(bson_iter_date_time(doc_iter)) - return self.builder.get().AppendNull() + if not self.allow_invalid and value_t != BSON_TYPE_NULL: + raise TypeError(f"Got unexpected type `{_bson_type_name(value_t)}` instead of expected type `{_bson_type_name(self.type_marker)}`") + else: + return self.builder.get().AppendNull() cdef shared_ptr[CArrayBuilder] get_builder(self): return self.builder @@ -528,15 +612,19 @@ cdef class Date64Builder(_ArrayBuilderBase): DataType dtype shared_ptr[CDate64Builder] builder - def __cinit__(self, MemoryPool memory_pool=None): + def __cinit__(self, MemoryPool memory_pool=None, bint allow_invalid=False): cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) self.builder.reset(new CDate64Builder(pool)) self.type_marker = ARROW_TYPE_DATE64 + self.allow_invalid = allow_invalid cdef CStatus append_raw(self, bson_iter_t * doc_iter, bson_type_t value_t): if value_t == BSON_TYPE_DATE_TIME: return self.builder.get().Append(bson_iter_date_time(doc_iter)) - return self.builder.get().AppendNull() + if not self.allow_invalid and value_t != BSON_TYPE_NULL: + raise TypeError(f"Got unexpected type `{_bson_type_name(value_t)}` instead of expected type `{_bson_type_name(self.type_marker)}`") + else: + return self.builder.get().AppendNull() @property def unit(self): @@ -551,10 +639,11 @@ cdef class Date32Builder(_ArrayBuilderBase): DataType dtype shared_ptr[CDate32Builder] builder - def __cinit__(self, MemoryPool memory_pool=None): + def __cinit__(self, MemoryPool memory_pool=None, bint allow_invalid=False): cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) self.builder.reset(new CDate32Builder(pool)) self.type_marker = ARROW_TYPE_DATE32 + self.allow_invalid = allow_invalid cdef CStatus append_raw(self, bson_iter_t * doc_iter, bson_type_t value_t): cdef int64_t value @@ -565,7 +654,10 @@ cdef class Date32Builder(_ArrayBuilderBase): # Convert from milliseconds to days (1000*60*60*24) seconds_val = value // 86400000 
return self.builder.get().Append(seconds_val) - return self.builder.get().AppendNull() + if not self.allow_invalid and value_t != BSON_TYPE_NULL: + raise TypeError(f"Got unexpected type `{_bson_type_name(value_t)}` instead of expected type `{_bson_type_name(self.type_marker)}`") + else: + return self.builder.get().AppendNull() @property def unit(self): @@ -592,15 +684,19 @@ cdef class NullBuilder(_ArrayBuilderBase): cdef class BoolBuilder(_ArrayBuilderBase): cdef shared_ptr[CBooleanBuilder] builder - def __cinit__(self, MemoryPool memory_pool=None): + def __cinit__(self, MemoryPool memory_pool=None, bint allow_invalid=False): cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) self.builder.reset(new CBooleanBuilder(pool)) self.type_marker = BSON_TYPE_BOOL + self.allow_invalid = allow_invalid cdef CStatus append_raw(self, bson_iter_t * doc_iter, bson_type_t value_t): if value_t == BSON_TYPE_BOOL: return self.builder.get().Append(bson_iter_bool(doc_iter)) - return self.builder.get().AppendNull() + if not self.allow_invalid and value_t != BSON_TYPE_NULL: + raise TypeError(f"Got unexpected type `{_bson_type_name(value_t)}` instead of expected type `{_bson_type_name(self.type_marker)}`") + else: + return self.builder.get().AppendNull() cdef shared_ptr[CArrayBuilder] get_builder(self): return self.builder @@ -609,7 +705,7 @@ cdef class Decimal128Builder(_ArrayBuilderBase): cdef shared_ptr[CFixedSizeBinaryBuilder] builder cdef uint8_t supported - def __cinit__(self, MemoryPool memory_pool=None): + def __cinit__(self, MemoryPool memory_pool=None, bint allow_invalid=False): cdef shared_ptr[CDataType] dtype = fixed_size_binary(16) cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) self.builder.reset(new CFixedSizeBinaryBuilder(dtype, pool)) @@ -618,6 +714,7 @@ cdef class Decimal128Builder(_ArrayBuilderBase): self.supported = 1 else: self.supported = 0 + self.allow_invalid = allow_invalid cdef CStatus append_raw(self, bson_iter_t * doc_iter, bson_type_t 
value_t): cdef uint8_t dec128_buf[16] @@ -625,14 +722,17 @@ cdef class Decimal128Builder(_ArrayBuilderBase): if self.supported == 0: # We do not support big-endian systems. - return self.builder.get().AppendNull() + raise TypeError(f"Big-endian systems are not supported for `{_bson_type_name(self.type_marker)}`") if value_t == BSON_TYPE_DECIMAL128: bson_iter_decimal128(doc_iter, &dec128) memcpy(dec128_buf, &dec128.low, 8); memcpy(dec128_buf + 8, &dec128.high, 8) return self.builder.get().Append(dec128_buf) - return self.builder.get().AppendNull() + if not self.allow_invalid and value_t != BSON_TYPE_NULL: + raise TypeError(f"Got unexpected type `{_bson_type_name(value_t)}` instead of expected type `{_bson_type_name(self.type_marker)}`") + else: + return self.builder.get().AppendNull() cdef shared_ptr[CArrayBuilder] get_builder(self): return self.builder @@ -646,11 +746,12 @@ cdef class BinaryBuilder(_ArrayBuilderBase): uint8_t _subtype shared_ptr[CStringBuilder] builder - def __cinit__(self, uint8_t subtype, MemoryPool memory_pool=None): + def __cinit__(self, uint8_t subtype, MemoryPool memory_pool=None, bint allow_invalid=False): cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool) self._subtype = subtype self.builder.reset(new CStringBuilder(pool)) self.type_marker = BSON_TYPE_BINARY + self.allow_invalid = allow_invalid @property def subtype(self): @@ -664,9 +765,16 @@ cdef class BinaryBuilder(_ArrayBuilderBase): if value_t == BSON_TYPE_BINARY: bson_iter_binary(doc_iter, &subtype, &val_buf_len, &val_buf) if subtype != self._subtype: - return self.builder.get().AppendNull() + if not self.allow_invalid: + raise TypeError( + f"Got unexpected subtype `{subtype}` instead of expected subtype `{self._subtype}`") + else: + return self.builder.get().AppendNull() return self.builder.get().Append(val_buf, val_buf_len) - return self.builder.get().AppendNull() + if not self.allow_invalid and value_t != BSON_TYPE_NULL: + raise TypeError(f"Got unexpected type 
`{_bson_type_name(value_t)}` instead of expected type `{_bson_type_name(self.type_marker)}`") + else: + return self.builder.get().AppendNull() cdef shared_ptr[CArrayBuilder] get_builder(self): return self.builder diff --git a/bindings/python/test/test_arrow.py b/bindings/python/test/test_arrow.py index 34d9a01a..e0f16b06 100644 --- a/bindings/python/test/test_arrow.py +++ b/bindings/python/test/test_arrow.py @@ -544,13 +544,22 @@ def test_schema_missing_field(self): out = func(self.coll, {} if func == find_arrow_all else [], schema=schema).drop(["_id"]) self.assertEqual(out["list_field"].to_pylist(), expected) - def test_schema_incorrect_data_type(self): + def test_schema_incorrect_data_type_allow_invalid(self): # From https://github.com/mongodb-labs/mongo-arrow/issues/260. self.coll.delete_many({}) self.coll.insert_one({"x": {"y": 1}}) - out = find_arrow_all(self.coll, {}, schema=Schema({"x": str})) + out = find_arrow_all(self.coll, {}, schema=Schema({"x": str}), allow_invalid=True) assert out.to_pylist() == [{"x": None}] + def test_schema_incorrect_data_type(self): + # From https://github.com/mongodb-labs/mongo-arrow/issues/260. + self.coll.delete_many({}) + self.coll.insert_one({"x": {"y": 1}}) + with self.assertRaisesRegex( + TypeError, "Got unexpected type `document` instead of expected type `string`" + ): + find_arrow_all(self.coll, {}, schema=Schema({"x": str})) + def test_schema_arrays_of_documents(self): # From https://github.com/mongodb-labs/mongo-arrow/issues/258. 
coll = self.coll @@ -824,7 +833,7 @@ def test_empty_nested_objects(self): data = Table.from_pydict(raw_data, ArrowSchema(schema)) self.round_trip(data, Schema(schema)) - def test_malformed_embedded_documents(self): + def test_malformed_embedded_documents_allow_invalid(self): schema = Schema({"data": struct([field("a", int32()), field("b", bool_())])}) data = [ dict(data=dict(a=1, b=True)), @@ -834,7 +843,7 @@ def test_malformed_embedded_documents(self): ] self.coll.drop() self.coll.insert_many(data) - res = find_arrow_all(self.coll, {}, schema=schema)["data"].to_pylist() + res = find_arrow_all(self.coll, {}, schema=schema, allow_invalid=True)["data"].to_pylist() self.assertEqual( res, [ @@ -845,7 +854,22 @@ def test_malformed_embedded_documents(self): ], ) - def test_mixed_subtype(self): + def test_malformed_embedded_documents(self): + schema = Schema({"data": struct([field("a", int32()), field("b", bool_())])}) + data = [ + dict(data=dict(a=1, b=True)), + dict(data=dict(a=1, b=True, c="bar")), + dict(data=dict(a=1)), + dict(data=dict(a="str", b=False)), + ] + self.coll.drop() + self.coll.insert_many(data) + with self.assertRaisesRegex( + TypeError, "Got unexpected type `string` instead of expected type `int32`" + ): + find_arrow_all(self.coll, {}, schema=schema)["data"].to_pylist() + + def test_mixed_subtype_allow_invalid(self): schema = Schema({"data": BinaryType(10)}) coll = self.client.pymongoarrow_test.get_collection( "test", write_concern=WriteConcern(w="majority") @@ -853,9 +877,22 @@ def test_mixed_subtype(self): coll.drop() coll.insert_many([{"data": Binary(b"1", 10)}, {"data": Binary(b"2", 20)}]) - res = find_arrow_all(coll, {}, schema=schema) + res = find_arrow_all(coll, {}, schema=schema, allow_invalid=True) self.assertEqual(res["data"].to_pylist(), [Binary(b"1", 10), None]) + def test_mixed_subtype(self): + schema = Schema({"data": BinaryType(10)}) + coll = self.client.pymongoarrow_test.get_collection( + "test", write_concern=WriteConcern(w="majority") 
+ ) + + coll.drop() + coll.insert_many([{"data": Binary(b"1", 10)}, {"data": Binary(b"2", 20)}]) + with self.assertRaisesRegex( + TypeError, "Got unexpected subtype `20` instead of expected subtype `10`" + ): + find_arrow_all(coll, {}, schema=schema) + def _test_mixed_types_int(self, inttype): docs = [ {"a": 1}, @@ -869,7 +906,15 @@ def _test_mixed_types_int(self, inttype): ] self.coll.delete_many({}) self.coll.insert_many(docs) - table = find_arrow_all(self.coll, {}, projection={"_id": 0}, schema=Schema({"a": inttype})) + # Test with strict schema + with self.assertRaisesRegex( + TypeError, f"Got unexpected type `string` instead of expected type `{inttype}`" + ): + find_arrow_all(self.coll, {}, projection={"_id": 0}, schema=Schema({"a": inttype})) + # Test with allow_invalid + table = find_arrow_all( + self.coll, {}, projection={"_id": 0}, schema=Schema({"a": inttype}), allow_invalid=True + ) expected = Table.from_pylist( [ {"a": 1}, @@ -891,12 +936,24 @@ def test_mixed_types_int32(self): self.coll.delete_many({}) self.coll.insert_one({"a": 2 << 34}) with self.assertRaises(OverflowError): - find_arrow_all(self.coll, {}, projection={"_id": 0}, schema=Schema({"a": int32()})) + find_arrow_all( + self.coll, + {}, + projection={"_id": 0}, + schema=Schema({"a": int32()}), + allow_invalid=True, + ) # Test double overflowing int32 self.coll.delete_many({}) self.coll.insert_one({"a": float(2 << 34)}) with self.assertRaises(OverflowError): - find_arrow_all(self.coll, {}, projection={"_id": 0}, schema=Schema({"a": int32()})) + find_arrow_all( + self.coll, + {}, + projection={"_id": 0}, + schema=Schema({"a": int32()}), + allow_invalid=True, + ) def test_mixed_types_int64(self): self._test_mixed_types_int(int64()) diff --git a/bindings/python/test/test_bson.py b/bindings/python/test/test_bson.py index a293b900..1022e63a 100644 --- a/bindings/python/test/test_bson.py +++ b/bindings/python/test/test_bson.py @@ -167,7 +167,8 @@ def setUp(self): self.schema = Schema({"data": 
bool}) self.context = PyMongoArrowContext(self.schema) - def test_simple(self): + def test_simple_allow_invalid(self): + self.context = PyMongoArrowContext(self.schema, allow_invalid=True) docs = [ {"data": True}, {"data": False}, @@ -179,6 +180,21 @@ def test_simple(self): as_dict = {"data": [True, False, None, None, False, True]} self._run_test(docs, as_dict) + def test_simple(self): + docs = [ + {"data": True}, + {"data": False}, + {"data": 19}, + {"data": "string"}, + {"data": False}, + {"data": True}, + ] + as_dict = {"data": [True, False, None, None, False, True]} + with self.assertRaisesRegex( + TypeError, "Got unexpected type `int32` instead of expected type `boolean`" + ): + self._run_test(docs, as_dict) + class TestStringType(TestBsonToArrowConversionBase): def setUp(self): @@ -214,6 +230,29 @@ def setUp(self): self.schema = Schema({"data": dict(x=bool)}) self.context = PyMongoArrowContext(self.schema) + def test_simple_allow_invalid(self): + self.context = PyMongoArrowContext(self.schema, allow_invalid=True) + + docs = [ + {"data": dict(x=True)}, + {"data": dict(x=False)}, + {"data": dict(x=19)}, + {"data": dict(x="string")}, + {"data": dict(x=False)}, + {"data": dict(x=True)}, + ] + as_dict = { + "data": [ + dict(x=True), + dict(x=False), + dict(x=None), + dict(x=None), + dict(x=False), + dict(x=True), + ] + } + self._run_test(docs, as_dict) + def test_simple(self): docs = [ {"data": dict(x=True)}, @@ -233,6 +272,33 @@ def test_simple(self): dict(x=True), ] } + with self.assertRaisesRegex( + TypeError, "Got unexpected type `int32` instead of expected type `boolean`" + ): + self._run_test(docs, as_dict) + + def test_nested_allow_invalid(self): + self.schema = Schema({"data": dict(x=bool, y=dict(a=int))}) + self.context = PyMongoArrowContext(self.schema, allow_invalid=True) + + docs = [ + {"data": dict(x=True, y=dict(a=1))}, + {"data": dict(x=False, y=dict(a=1))}, + {"data": dict(x=19, y=dict(a=1))}, + {"data": dict(x="string", y=dict(a=1))}, + {"data": 
dict(x=False, y=dict(a=1))}, + {"data": dict(x=True, y=dict(a=1))}, + ] + as_dict = { + "data": [ + dict(x=True, y=dict(a=1)), + dict(x=False, y=dict(a=1)), + dict(x=None, y=dict(a=1)), + dict(x=None, y=dict(a=1)), + dict(x=False, y=dict(a=1)), + dict(x=True, y=dict(a=1)), + ] + } self._run_test(docs, as_dict) def test_nested(self): @@ -257,4 +323,7 @@ def test_nested(self): dict(x=True, y=dict(a=1)), ] } - self._run_test(docs, as_dict) + with self.assertRaisesRegex( + TypeError, "Got unexpected type `int32` instead of expected type `boolean`" + ): + self._run_test(docs, as_dict) diff --git a/bindings/python/test/test_builders.py b/bindings/python/test/test_builders.py index 373d4c8a..0db38af2 100644 --- a/bindings/python/test/test_builders.py +++ b/bindings/python/test/test_builders.py @@ -38,8 +38,8 @@ class IntBuildersTestMixin: - def test_simple(self): - builder = self.builder_cls() + def test_simple_allow_invalid(self): + builder = self.builder_cls(allow_invalid=True) builder.append(0) builder.append_values([1, 2, 3, 4]) builder.append("a") @@ -52,6 +52,15 @@ def test_simple(self): self.assertEqual(arr.to_pylist(), [0, 1, 2, 3, 4, None, None]) self.assertEqual(arr.type, self.data_type) + def test_simple(self): + builder = self.builder_cls() + builder.append(0) + builder.append_values([1, 2, 3, 4]) + with self.assertRaisesRegex( + TypeError, f"Got unexpected type `string` instead of expected type `{self.data_type}`" + ): + builder.append("a") + class TestInt32Builder(TestCase, IntBuildersTestMixin): def setUp(self): @@ -71,10 +80,10 @@ def test_default_unit(self): builder = DatetimeBuilder() self.assertEqual(builder.unit, timestamp("ms")) - def test_simple(self): + def test_simple_allow_invalid(self): self.maxDiff = None - builder = DatetimeBuilder(dtype=timestamp("ms")) + builder = DatetimeBuilder(dtype=timestamp("ms"), allow_invalid=True) datetimes = [datetime.now(timezone.utc) + timedelta(days=k * 100) for k in range(5)] builder.append(datetimes[0]) 
builder.append_values(datetimes[1:]) @@ -92,6 +101,18 @@ def test_simple(self): self.assertIsNone(expected) self.assertEqual(arr.type, timestamp("ms")) + def test_simple(self): + self.maxDiff = None + + builder = DatetimeBuilder(dtype=timestamp("ms")) + datetimes = [datetime.now(timezone.utc) + timedelta(days=k * 100) for k in range(5)] + builder.append(datetimes[0]) + builder.append_values(datetimes[1:]) + with self.assertRaisesRegex( + TypeError, "Got unexpected type `int32` instead of expected type `datetime`" + ): + builder.append(1) + def test_unsupported_units(self): for unit in ("s", "us", "ns"): with self.assertRaises(TypeError): @@ -99,8 +120,8 @@ def test_unsupported_units(self): class TestDoubleBuilder(TestCase): - def test_simple(self): - builder = DoubleBuilder() + def test_simple_allow_invalid(self): + builder = DoubleBuilder(allow_invalid=True) values = [0.123, 1.234, 2.345, 3.456, 4.567, 1] builder.append(values[0]) builder.append_values(values[1:]) @@ -113,11 +134,21 @@ def test_simple(self): self.assertEqual(len(arr), 8) self.assertEqual(arr.to_pylist(), values + [None, None]) + def test_simple(self): + builder = DoubleBuilder() + values = [0.123, 1.234, 2.345, 3.456, 4.567, 1] + builder.append(values[0]) + builder.append_values(values[1:]) + with self.assertRaisesRegex( + TypeError, "Got unexpected type `string` instead of expected type `double`" + ): + builder.append("a") + class TestObjectIdBuilder(TestCase): - def test_simple(self): - ids = [ObjectId() for i in range(5)] - builder = ObjectIdBuilder() + def test_simple_allow_invalid(self): + ids = [ObjectId() for _ in range(5)] + builder = ObjectIdBuilder(allow_invalid=True) builder.append(ids[0]) builder.append_values(ids[1:]) builder.append(b"123456789123") @@ -129,14 +160,24 @@ def test_simple(self): self.assertEqual(len(arr), 7) self.assertEqual(arr.to_pylist(), ids + [None, None]) + def test_simple(self): + ids = [ObjectId() for _ in range(5)] + builder = ObjectIdBuilder() + 
builder.append(ids[0]) + builder.append_values(ids[1:]) + with self.assertRaisesRegex( + TypeError, "Got unexpected type `binary` instead of expected type `objectId`" + ): + builder.append(b"123456789123") + class TestStringBuilder(TestCase): - def test_simple(self): + def test_simple_allow_invalid(self): # Greetings in various languages, from # https://www.w3.org/2001/06/utf-8-test/UTF-8-demo.html values = ["Hello world", "Καλημέρα κόσμε", "コンニチハ"] values += ["hello\u0000world"] - builder = StringBuilder() + builder = StringBuilder(allow_invalid=True) builder.append(values[0]) builder.append_values(values[1:]) builder.append(b"1") @@ -148,6 +189,19 @@ def test_simple(self): self.assertEqual(len(arr), 6) self.assertEqual(arr.to_pylist(), values + [None, None]) + def test_simple(self): + # Greetings in various languages, from + # https://www.w3.org/2001/06/utf-8-test/UTF-8-demo.html + values = ["Hello world", "Καλημέρα κόσμε", "コンニチハ"] + values += ["hello\u0000world"] + builder = StringBuilder() + builder.append(values[0]) + builder.append_values(values[1:]) + with self.assertRaisesRegex( + TypeError, "Got unexpected type `binary` instead of expected type `string`" + ): + builder.append(b"1") + class TestDocumentBuilder(TestCase): def test_simple(self): @@ -182,7 +236,7 @@ def test_simple(self): class TestBuilderManager(TestCase): def test_simple(self): - manager = BuilderManager({}, False, None) + manager = BuilderManager({}, False, None, False) data = b"".join(encode(d) for d in [dict(a=1), dict(a=2), dict(a=None), dict(a=4)]) manager.process_bson_stream(data, len(data)) array_map = manager.finish() @@ -200,7 +254,7 @@ def test_nested_object(self): inner = inner_values[0].copy() inner["c"] = 1.0 values.append(dict(c=inner, e=ObjectId(), f=None, g=[])) - manager = BuilderManager({}, False, None) + manager = BuilderManager({}, False, None, False) data = b"".join(encode(v) for v in values) manager.process_bson_stream(data, len(data)) array_map = manager.finish() @@ 
-239,9 +293,9 @@ def test_nested_object(self): class TestBinaryBuilder(TestCase): - def test_simple(self): + def test_simple_allow_invalid(self): data = [Binary(bytes(i), 10) for i in range(5)] - builder = BinaryBuilder(10) + builder = BinaryBuilder(10, allow_invalid=True) builder.append(data[0]) builder.append_values(data[1:]) builder.append(1) @@ -253,11 +307,21 @@ def test_simple(self): self.assertEqual(len(arr), 7) self.assertEqual(arr.to_pylist(), data + [None, None]) + def test_simple(self): + data = [Binary(bytes(i), 10) for i in range(5)] + builder = BinaryBuilder(10) + builder.append(data[0]) + builder.append_values(data[1:]) + with self.assertRaisesRegex( + TypeError, "Got unexpected type `int32` instead of expected type `binary`" + ): + builder.append(1) + class TestDecimal128Builder(TestCase): - def test_simple(self): + def test_simple_allow_invalid(self): data = [Decimal128([i, i]) for i in range(5)] - builder = Decimal128Builder() + builder = Decimal128Builder(allow_invalid=True) builder.append(data[0]) builder.append_values(data[1:]) builder.append(1) @@ -269,10 +333,20 @@ def test_simple(self): self.assertEqual(len(arr), 7) self.assertEqual(arr.to_pylist(), data + [None, None]) - -class BoolBuilderTestMixin: def test_simple(self): - builder = BoolBuilder() + data = [Decimal128([i, i]) for i in range(5)] + builder = Decimal128Builder() + builder.append(data[0]) + builder.append_values(data[1:]) + with self.assertRaisesRegex( + TypeError, "Got unexpected type `int32` instead of expected type `decimal128`" + ): + builder.append(1) + + +class TestBoolBuilder(TestCase): + def test_simple_allow_invalid(self): + builder = BoolBuilder(allow_invalid=True) builder.append(False) builder.append_values([True, False, True, False, True, False]) builder.append(1) @@ -285,22 +359,25 @@ def test_simple(self): self.assertEqual( arr.to_pylist(), [False, True, False, True, False, True, False, None, None] ) - self.assertEqual(arr.type, self.data_type) - + 
self.assertEqual(arr.type, bool_()) -class TestBoolBuilder(TestCase, BoolBuilderTestMixin): - def setUp(self): - self.builder_cls = BoolBuilder - self.data_type = bool_() + def test_simple(self): + builder = BoolBuilder() + builder.append(False) + builder.append_values([True, False, True, False, True, False]) + with self.assertRaisesRegex( + TypeError, "Got unexpected type `int32` instead of expected type `boolean`" + ): + builder.append(1) class TestCodeBuilder(TestCase): - def test_simple(self): + def test_simple_allow_invalid(self): # Greetings in various languages, from # https://www.w3.org/2001/06/utf-8-test/UTF-8-demo.html values = ["Hello world", "Καλημέρα κόσμε", "コンニチハ"] values += ["hello\u0000world"] - builder = CodeBuilder() + builder = CodeBuilder(allow_invalid=True) builder.append(Code(values[0])) builder.append_values([Code(v) for v in values[1:]]) builder.append("foo") @@ -314,11 +391,24 @@ def test_simple(self): self.assertEqual(len(arr), 6) self.assertEqual(arr.to_pylist(), codes + [None, None]) + def test_simple(self): + # Greetings in various languages, from + # https://www.w3.org/2001/06/utf-8-test/UTF-8-demo.html + values = ["Hello world", "Καλημέρα κόσμε", "コンニチハ"] + values += ["hello\u0000world"] + builder = CodeBuilder() + builder.append(Code(values[0])) + builder.append_values([Code(v) for v in values[1:]]) + with self.assertRaisesRegex( + TypeError, "Got unexpected type `string` instead of expected type `code`" + ): + builder.append("foo") + class TestDate32Builder(TestCase): - def test_simple(self): + def test_simple_allow_invalid(self): values = [datetime(1970 + i, 1, 1) for i in range(3)] - builder = Date32Builder() + builder = Date32Builder(allow_invalid=True) builder.append(values[0]) builder.append_values(values[1:]) builder.append(1) @@ -331,11 +421,21 @@ def test_simple(self): dates = [v.date() for v in values] self.assertEqual(arr.to_pylist(), dates + [None, None]) + def test_simple(self): + values = [datetime(1970 + i, 1, 1) for 
i in range(3)] + builder = Date32Builder() + builder.append(values[0]) + builder.append_values(values[1:]) + with self.assertRaisesRegex( + TypeError, "Got unexpected type `int32` instead of expected type `date32`" + ): + builder.append(1) + class TestDate64Builder(TestCase): - def test_simple(self): + def test_simple_allow_invalid(self): values = [datetime(1970 + i, 1, 1) for i in range(3)] - builder = Date64Builder() + builder = Date64Builder(allow_invalid=True) builder.append(values[0]) builder.append_values(values[1:]) builder.append(1) @@ -347,3 +447,13 @@ def test_simple(self): self.assertEqual(len(arr), 5) dates = [v.date() for v in values] self.assertEqual(arr.to_pylist(), dates + [None, None]) + + def test_simple(self): + values = [datetime(1970 + i, 1, 1) for i in range(3)] + builder = Date64Builder() + builder.append(values[0]) + builder.append_values(values[1:]) + with self.assertRaisesRegex( + TypeError, "Got unexpected type `int32` instead of expected type `date64`" + ): + builder.append(1)