Skip to content

Commit 5890549

Browse files
committed
macos: reduce precision to float32
Handle: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead
1 parent 099fded commit 5890549

1 file changed

Lines changed: 14 additions & 3 deletions

File tree

src/damast/data_handling/accessors.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import logging
66
import random
7+
import sys
78
import time
89
from typing import Any, List, Optional, Union
910

@@ -22,6 +23,16 @@
2223
logger = logging.getLogger("damast")
2324

2425

26+
if sys.platform == "darwin":
27+
# Handle "Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead"
28+
def _mps_precision(data):
29+
if data.dtype == np.float64:
30+
return data.astype(np.float32)
31+
return data
32+
else:
33+
def _mps_precision(data):
34+
return data
35+
2536
# https://www.tensorflow.org/tutorials/structured_data/time_series
2637
class GroupSequenceAccessor:
2738
"""
@@ -289,12 +300,12 @@ def _generator(features: List[str], target: Optional[List[str]],
289300
# target it the last step in the timeline, so the last
290301
target_chunk.append(target_window.to_numpy())
291302

292-
X = np.array(chunk)
303+
X = _mps_precision(np.array(chunk))
293304
if use_target:
294305
if np.lib.NumpyVersion(np.__version__) >= '2.0.0':
295-
y = np.array(target_chunk)
306+
y = _mps_precision(np.array(target_chunk))
296307
else:
297-
y = np.array(target_chunk, copy=False)
308+
y = _mps_precision(np.array(target_chunk, copy=False))
298309
yield (X, y)
299310
else:
300311
yield (X,)

0 commit comments

Comments
 (0)