Skip to content

Commit 8fad4cc

Browse files
authored
Fix precision values for decimals (#91)
1 parent 510ca2f commit 8fad4cc

File tree

3 files changed

+40
-25
lines changed

3 files changed

+40
-25
lines changed

singlestoredb/http/connection.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -647,23 +647,24 @@ def json_to_str(x: Any) -> Optional[str]:
647647
type_code = types.ColumnType.get_code(data_type)
648648
prec, scale = get_precision_scale(col['dataType'])
649649
converter = http_converters.get(type_code, None)
650+
650651
if 'UNSIGNED' in data_type:
651652
flags = 32
653+
652654
if data_type.endswith('BLOB') or data_type.endswith('BINARY'):
653655
converter = functools.partial(
654656
b64decode_converter, converter, # type: ignore
655657
)
656658
charset = 63 # BINARY
659+
657660
if type_code == 0: # DECIMAL
658661
type_code = types.ColumnType.get_code('NEWDECIMAL')
659662
elif type_code == 15: # VARCHAR / VARBINARY
660663
type_code = types.ColumnType.get_code('VARSTRING')
661-
if type_code == 246 and prec is not None: # NEWDECIMAL
662-
prec += 1 # for sign
663-
if scale is not None and scale > 0:
664-
prec += 1 # for decimal
664+
665665
if converter is not None:
666666
convs.append((i, None, converter))
667+
667668
description.append(
668669
Description(
669670
str(col['name']), type_code,
@@ -673,6 +674,7 @@ def json_to_str(x: Any) -> Optional[str]:
673674
),
674675
)
675676
pymy_res.append(PyMyField(col['name'], flags, charset))
677+
676678
self._descriptions.append(description)
677679
self._schemas.append(get_schema(self._results_type, description))
678680

@@ -936,7 +938,7 @@ def next(self) -> Optional[Result]:
936938

937939
def __iter__(self) -> Iterable[Tuple[Any, ...]]:
938940
"""Return result iterator."""
939-
return iter(self._rows)
941+
return iter(self._rows[self._row_idx:])
940942

941943
def __enter__(self) -> 'Cursor':
942944
"""Enter a context."""

singlestoredb/mysql/protocol.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,13 +324,26 @@ def _parse_field_descriptor(self, encoding):
324324
raise TypeError(f'unrecognized extended data type: {ext_type_code}')
325325

326326
def description(self):
327-
"""Provides a 7-item tuple compatible with the Python PEP249 DB Spec."""
327+
"""
328+
Provides a 9-item tuple.
329+
330+
Standard descriptions only have 7 fields according to the Python
331+
PEP249 DB Spec, but we need to surface information about unsigned
332+
types and charsetnr for proper type handling.
333+
334+
"""
335+
precision = self.get_column_length()
336+
if self.type_code in (FIELD_TYPE.DECIMAL, FIELD_TYPE.NEWDECIMAL):
337+
if precision:
338+
precision -= 1 # for the sign
339+
if self.scale > 0:
340+
precision -= 1 # for the decimal point
328341
return Description(
329342
self.name,
330343
self.type_code,
331344
None, # TODO: display_length; should this be self.length?
332345
self.get_column_length(), # 'internal_size'
333-
self.get_column_length(), # 'precision' # TODO: why!?!?
346+
precision, # 'precision'
334347
self.scale,
335348
self.flags % 2 == 0,
336349
self.flags,

singlestoredb/tests/test_connection.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,7 +1446,7 @@ def test_alltypes_polars(self):
14461446
# Recent versions of polars have a problem with decimals
14471447
class FixCompare(str):
14481448
def __eq__(self, other):
1449-
return super().__eq__(other.replace('precision=None', 'precision=22'))
1449+
return super().__eq__(other.replace('precision=None', 'precision=20'))
14501450

14511451
dtypes = [
14521452
('id', 'Int32'),
@@ -1469,10 +1469,10 @@ def __eq__(self, other):
14691469
('float', 'Float32'),
14701470
('double', 'Float64'),
14711471
('real', 'Float64'),
1472-
('decimal', FixCompare('Decimal(precision=22, scale=6)')),
1473-
('dec', FixCompare('Decimal(precision=22, scale=6)')),
1474-
('fixed', FixCompare('Decimal(precision=22, scale=6)')),
1475-
('numeric', FixCompare('Decimal(precision=22, scale=6)')),
1472+
('decimal', FixCompare('Decimal(precision=20, scale=6)')),
1473+
('dec', FixCompare('Decimal(precision=20, scale=6)')),
1474+
('fixed', FixCompare('Decimal(precision=20, scale=6)')),
1475+
('numeric', FixCompare('Decimal(precision=20, scale=6)')),
14761476
('date', 'Date'),
14771477
('time', "Duration(time_unit='us')"),
14781478
('time_6', "Duration(time_unit='us')"),
@@ -1593,7 +1593,7 @@ def test_alltypes_no_nulls_polars(self):
15931593
# Recent versions of polars have a problem with decimals
15941594
class FixCompare(str):
15951595
def __eq__(self, other):
1596-
return super().__eq__(other.replace('precision=None', 'precision=22'))
1596+
return super().__eq__(other.replace('precision=None', 'precision=20'))
15971597

15981598
dtypes = [
15991599
('id', 'Int32'),
@@ -1616,10 +1616,10 @@ def __eq__(self, other):
16161616
('float', 'Float32'),
16171617
('double', 'Float64'),
16181618
('real', 'Float64'),
1619-
('decimal', FixCompare('Decimal(precision=22, scale=6)')),
1620-
('dec', FixCompare('Decimal(precision=22, scale=6)')),
1621-
('fixed', FixCompare('Decimal(precision=22, scale=6)')),
1622-
('numeric', FixCompare('Decimal(precision=22, scale=6)')),
1619+
('decimal', FixCompare('Decimal(precision=20, scale=6)')),
1620+
('dec', FixCompare('Decimal(precision=20, scale=6)')),
1621+
('fixed', FixCompare('Decimal(precision=20, scale=6)')),
1622+
('numeric', FixCompare('Decimal(precision=20, scale=6)')),
16231623
('date', 'Date'),
16241624
('time', "Duration(time_unit='us')"),
16251625
('time_6', "Duration(time_unit='us')"),
@@ -1825,10 +1825,10 @@ def test_alltypes_arrow(self):
18251825
('float', 'float'),
18261826
('double', 'double'),
18271827
('real', 'double'),
1828-
('decimal', 'decimal128(22, 6)'),
1829-
('dec', 'decimal128(22, 6)'),
1830-
('fixed', 'decimal128(22, 6)'),
1831-
('numeric', 'decimal128(22, 6)'),
1828+
('decimal', 'decimal128(20, 6)'),
1829+
('dec', 'decimal128(20, 6)'),
1830+
('fixed', 'decimal128(20, 6)'),
1831+
('numeric', 'decimal128(20, 6)'),
18321832
('date', 'date64[ms]'),
18331833
('time', 'duration[us]'),
18341834
('time_6', 'duration[us]'),
@@ -1964,10 +1964,10 @@ def test_alltypes_no_nulls_arrow(self):
19641964
('float', 'float'),
19651965
('double', 'double'),
19661966
('real', 'double'),
1967-
('decimal', 'decimal128(22, 6)'),
1968-
('dec', 'decimal128(22, 6)'),
1969-
('fixed', 'decimal128(22, 6)'),
1970-
('numeric', 'decimal128(22, 6)'),
1967+
('decimal', 'decimal128(20, 6)'),
1968+
('dec', 'decimal128(20, 6)'),
1969+
('fixed', 'decimal128(20, 6)'),
1970+
('numeric', 'decimal128(20, 6)'),
19711971
('date', 'date64[ms]'),
19721972
('time', 'duration[us]'),
19731973
('time_6', 'duration[us]'),

0 commit comments

Comments
 (0)