Skip to content

Commit 0372fc0

Browse files
committed
build MLX, Cmlx via xcodeproj as a framework
tests passing with ml-explore/mlx#2702 locally applied
1 parent 8f9f747 commit 0372fc0

40 files changed

+5907
-5
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ playground.xcworkspace
4141
Packages/
4242
Package.pins
4343
Package.resolved
44-
*.xcodeproj
4544
#
4645
# Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata
4746
# hence it is not needed unless you have added a package configuration file to your project
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#include <Cmlx/mlx-mlx.h>
2+
#include <Cmlx/mlx-transforms_impl.h>
3+
#include <Cmlx/mlx-linalg.h>
4+
#include <Cmlx/mlx-fast.h>
Lines changed: 379 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,379 @@
1+
/* Copyright © 2023-2024 Apple Inc. */
2+
3+
#ifndef MLX_ARRAY_H
4+
#define MLX_ARRAY_H
5+
6+
#include <Cmlx/mlx-string.h>
7+
8+
#include <float.h>
9+
#include <stdbool.h>
10+
#include <stdint.h>
11+
#include <stdlib.h>
12+
13+
#include <Cmlx/mlx-half.h>
14+
15+
#ifdef __cplusplus
16+
extern "C" {
17+
#endif
18+
19+
/**
20+
* \defgroup mlx_array Array
21+
* MLX N-dimensional array object.
22+
*/
23+
/**@{*/
24+
25+
/**
26+
* A N-dimensional array object.
27+
*/
28+
typedef struct mlx_array_ {
29+
void* ctx;
30+
} mlx_array;
31+
32+
static mlx_array mlx_array_empty;
33+
34+
/**
35+
* Array element type.
36+
*/
37+
typedef enum mlx_dtype_ {
38+
MLX_BOOL,
39+
MLX_UINT8,
40+
MLX_UINT16,
41+
MLX_UINT32,
42+
MLX_UINT64,
43+
MLX_INT8,
44+
MLX_INT16,
45+
MLX_INT32,
46+
MLX_INT64,
47+
MLX_FLOAT16,
48+
MLX_FLOAT32,
49+
MLX_FLOAT64,
50+
MLX_BFLOAT16,
51+
MLX_COMPLEX64,
52+
} mlx_dtype;
53+
54+
/**
55+
* Size of given mlx_dtype datatype in bytes.
56+
*/
57+
size_t mlx_dtype_size(mlx_dtype dtype);
58+
59+
/**
60+
* Get array description.
61+
*/
62+
int mlx_array_tostring(mlx_string* str, const mlx_array arr);
63+
64+
/**
65+
* New empty array.
66+
*/
67+
mlx_array mlx_array_new(void);
68+
69+
/**
70+
* Free an array.
71+
*/
72+
int mlx_array_free(mlx_array arr);
73+
74+
/**
75+
* New array from a bool scalar.
76+
*/
77+
mlx_array mlx_array_new_bool(bool val);
78+
/**
79+
* New array from a int scalar.
80+
*/
81+
mlx_array mlx_array_new_int(int val);
82+
/**
83+
* New array from a float32 scalar.
84+
*/
85+
mlx_array mlx_array_new_float32(float val);
86+
/**
87+
* New array from a float scalar.
88+
* Same as float32.
89+
*/
90+
mlx_array mlx_array_new_float(float val);
91+
/**
92+
* New array from a float64 scalar.
93+
*/
94+
mlx_array mlx_array_new_float64(double val);
95+
/**
96+
* New array from a double scalar.
97+
* Same as float64.
98+
*/
99+
mlx_array mlx_array_new_double(double val);
100+
/**
101+
* New array from a complex scalar.
102+
*/
103+
mlx_array mlx_array_new_complex(float real_val, float imag_val);
104+
/**
105+
* New array from existing buffer.
106+
* @param data A buffer which will be copied.
107+
* @param shape Shape of the array.
108+
* @param dim Number of dimensions (size of `shape`).
109+
* @param dtype Type of array elements.
110+
*/
111+
mlx_array mlx_array_new_data(
112+
const void* data,
113+
const int* shape,
114+
int dim,
115+
mlx_dtype dtype);
116+
/**
117+
* Set array to provided src array.
118+
*/
119+
int mlx_array_set(mlx_array* arr, const mlx_array src);
120+
/**
121+
* Set array to a bool scalar.
122+
*/
123+
int mlx_array_set_bool(mlx_array* arr, bool val);
124+
/**
125+
* Set array to a int scalar.
126+
*/
127+
int mlx_array_set_int(mlx_array* arr, int val);
128+
/**
129+
* Set array to a float32 scalar.
130+
*/
131+
int mlx_array_set_float32(mlx_array* arr, float val);
132+
/**
133+
* Set array to a float scalar.
134+
*/
135+
int mlx_array_set_float(mlx_array* arr, float val);
136+
/**
137+
* Set array to a float64 scalar.
138+
*/
139+
int mlx_array_set_float64(mlx_array* arr, double val);
140+
/**
141+
* Set array to a double scalar.
142+
*/
143+
int mlx_array_set_double(mlx_array* arr, double val);
144+
/**
145+
* Set array to a complex scalar.
146+
*/
147+
int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val);
148+
/**
149+
* Set array to specified data and shape.
150+
* @param arr Destination array.
151+
* @param data A buffer which will be copied.
152+
* @param shape Shape of the array.
153+
* @param dim Number of dimensions (size of `shape`).
154+
* @param dtype Type of array elements.
155+
*/
156+
int mlx_array_set_data(
157+
mlx_array* arr,
158+
const void* data,
159+
const int* shape,
160+
int dim,
161+
mlx_dtype dtype);
162+
163+
/**
164+
* The size of the array's datatype in bytes.
165+
*/
166+
size_t mlx_array_itemsize(const mlx_array arr);
167+
/**
168+
* Number of elements in the array.
169+
*/
170+
size_t mlx_array_size(const mlx_array arr);
171+
/**
172+
* The number of bytes in the array.
173+
*/
174+
size_t mlx_array_nbytes(const mlx_array arr);
175+
/**
176+
* The array's dimension.
177+
*/
178+
size_t mlx_array_ndim(const mlx_array arr);
179+
/**
180+
* The shape of the array.
181+
* Returns: a pointer to the sizes of each dimension.
182+
*/
183+
const int* mlx_array_shape(const mlx_array arr);
184+
/**
185+
* The strides of the array.
186+
* Returns: a pointer to the sizes of each dimension.
187+
*/
188+
const size_t* mlx_array_strides(const mlx_array arr);
189+
/**
190+
* The shape of the array in a particular dimension.
191+
*/
192+
int mlx_array_dim(const mlx_array arr, int dim);
193+
/**
194+
* The array element type.
195+
*/
196+
mlx_dtype mlx_array_dtype(const mlx_array arr);
197+
198+
/**
199+
* Evaluate the array.
200+
*/
201+
int mlx_array_eval(mlx_array arr);
202+
203+
/**
204+
* Access the value of a scalar array.
205+
*/
206+
int mlx_array_item_bool(bool* res, const mlx_array arr);
207+
/**
208+
* Access the value of a scalar array.
209+
*/
210+
int mlx_array_item_uint8(uint8_t* res, const mlx_array arr);
211+
/**
212+
* Access the value of a scalar array.
213+
*/
214+
int mlx_array_item_uint16(uint16_t* res, const mlx_array arr);
215+
/**
216+
* Access the value of a scalar array.
217+
*/
218+
int mlx_array_item_uint32(uint32_t* res, const mlx_array arr);
219+
/**
220+
* Access the value of a scalar array.
221+
*/
222+
int mlx_array_item_uint64(uint64_t* res, const mlx_array arr);
223+
/**
224+
* Access the value of a scalar array.
225+
*/
226+
int mlx_array_item_int8(int8_t* res, const mlx_array arr);
227+
/**
228+
* Access the value of a scalar array.
229+
*/
230+
int mlx_array_item_int16(int16_t* res, const mlx_array arr);
231+
/**
232+
* Access the value of a scalar array.
233+
*/
234+
int mlx_array_item_int32(int32_t* res, const mlx_array arr);
235+
/**
236+
* Access the value of a scalar array.
237+
*/
238+
int mlx_array_item_int64(int64_t* res, const mlx_array arr);
239+
/**
240+
* Access the value of a scalar array.
241+
*/
242+
int mlx_array_item_float32(float* res, const mlx_array arr);
243+
/**
244+
* Access the value of a scalar array.
245+
*/
246+
int mlx_array_item_float64(double* res, const mlx_array arr);
247+
/**
248+
* Access the value of a scalar array.
249+
*/
250+
int mlx_array_item_complex64(float _Complex* res, const mlx_array arr);
251+
252+
#ifdef HAS_FLOAT16
253+
/**
254+
* Access the value of a scalar array.
255+
*/
256+
int mlx_array_item_float16(float16_t* res, const mlx_array arr);
257+
#endif
258+
259+
#ifdef HAS_BFLOAT16
260+
/**
261+
* Access the value of a scalar array.
262+
*/
263+
int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr);
264+
#endif
265+
266+
/**
267+
* Returns a pointer to the array data, cast to `bool*`.
268+
* Array must be evaluated, otherwise returns NULL.
269+
*/
270+
const bool* mlx_array_data_bool(const mlx_array arr);
271+
/**
272+
* Returns a pointer to the array data, cast to `uint8_t*`.
273+
* Array must be evaluated, otherwise returns NULL.
274+
*/
275+
const uint8_t* mlx_array_data_uint8(const mlx_array arr);
276+
/**
277+
* Returns a pointer to the array data, cast to `uint16_t*`.
278+
* Array must be evaluated, otherwise returns NULL.
279+
*/
280+
const uint16_t* mlx_array_data_uint16(const mlx_array arr);
281+
/**
282+
* Returns a pointer to the array data, cast to `uint32_t*`.
283+
* Array must be evaluated, otherwise returns NULL.
284+
*/
285+
const uint32_t* mlx_array_data_uint32(const mlx_array arr);
286+
/**
287+
* Returns a pointer to the array data, cast to `uint64_t*`.
288+
* Array must be evaluated, otherwise returns NULL.
289+
*/
290+
const uint64_t* mlx_array_data_uint64(const mlx_array arr);
291+
/**
292+
* Returns a pointer to the array data, cast to `int8_t*`.
293+
* Array must be evaluated, otherwise returns NULL.
294+
*/
295+
const int8_t* mlx_array_data_int8(const mlx_array arr);
296+
/**
297+
* Returns a pointer to the array data, cast to `int16_t*`.
298+
* Array must be evaluated, otherwise returns NULL.
299+
*/
300+
const int16_t* mlx_array_data_int16(const mlx_array arr);
301+
/**
302+
* Returns a pointer to the array data, cast to `int32_t*`.
303+
* Array must be evaluated, otherwise returns NULL.
304+
*/
305+
const int32_t* mlx_array_data_int32(const mlx_array arr);
306+
/**
307+
* Returns a pointer to the array data, cast to `int64_t*`.
308+
* Array must be evaluated, otherwise returns NULL.
309+
*/
310+
const int64_t* mlx_array_data_int64(const mlx_array arr);
311+
/**
312+
* Returns a pointer to the array data, cast to `float32*`.
313+
* Array must be evaluated, otherwise returns NULL.
314+
*/
315+
const float* mlx_array_data_float32(const mlx_array arr);
316+
/**
317+
* Returns a pointer to the array data, cast to `float64*`.
318+
* Array must be evaluated, otherwise returns NULL.
319+
*/
320+
const double* mlx_array_data_float64(const mlx_array arr);
321+
/**
322+
* Returns a pointer to the array data, cast to `_Complex*`.
323+
* Array must be evaluated, otherwise returns NULL.
324+
*/
325+
const float _Complex* mlx_array_data_complex64(const mlx_array arr);
326+
327+
#ifdef HAS_FLOAT16
328+
/**
329+
* Returns a pointer to the array data, cast to `float16_t*`.
330+
* Array must be evaluated, otherwise returns NULL.
331+
*/
332+
const float16_t* mlx_array_data_float16(const mlx_array arr);
333+
#endif
334+
335+
#ifdef HAS_BFLOAT16
336+
/**
337+
* Returns a pointer to the array data, cast to `bfloat16_t*`.
338+
* Array must be evaluated, otherwise returns NULL.
339+
*/
340+
const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr);
341+
#endif
342+
343+
/**
344+
* Check if the array is available.
345+
* Internal function: use at your own risk.
346+
*/
347+
int _mlx_array_is_available(bool* res, const mlx_array arr);
348+
349+
/**
350+
* Wait on the array to be available. After this `_mlx_array_is_available`
351+
* returns `true`. Internal function: use at your own risk.
352+
*/
353+
int _mlx_array_wait(const mlx_array arr);
354+
355+
/**
356+
* Whether the array is contiguous in memory.
357+
* Internal function: use at your own risk.
358+
*/
359+
int _mlx_array_is_contiguous(bool* res, const mlx_array arr);
360+
361+
/**
362+
* Whether the array's rows are contiguous in memory.
363+
* Internal function: use at your own risk.
364+
*/
365+
int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr);
366+
367+
/**
368+
* Whether the array's columns are contiguous in memory.
369+
* Internal function: use at your own risk.
370+
*/
371+
int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr);
372+
373+
/**@}*/
374+
375+
#ifdef __cplusplus
376+
}
377+
#endif
378+
379+
#endif

0 commit comments

Comments
 (0)