Skip to content

Commit d1f8627

Browse files
committed
angelos's commit files
1 parent 8ca7f9e commit d1f8627

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+12350
-0
lines changed

.gitignore

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Metal libraries
*.metallib

# Distribution / packaging
python/mlx/share
python/mlx/include
.Python
# Also covers the CMake build directory.
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# vim
*.swp

# Prerequisites
*.d

# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Precompiled Headers
*.gch
*.pch

# Compiled Dynamic libraries (*.so is already listed under "C extensions")
*.dylib
*.dll

# Fortran module files
*.mod
*.smod

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app

# VSCode
.vscode/

# macOS
.DS_Store

.pre-commit-config.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# pre-commit hook configuration: one formatter hook per repository entry.
repos:
  # clang-format hook, pinned to the v14.0.6 mirror tag.
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v14.0.6
    hooks:
      - id: clang-format
  # black hook, pinned to release 22.10.0.
  - repo: https://github.com/psf/black
    rev: 22.10.0
    hooks:
      - id: black

CMakeLists.txt

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
cmake_minimum_required(VERSION 3.24)

project(mlx LANGUAGES CXX)

# ----------------------------- Setup -----------------------------
# Project-local CMake modules live in <source>/cmake.
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")

# Require C++17 for every target; do not fall back to an older standard.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Build position-independent code by default.
set(CMAKE_POSITION_INDEPENDENT_CODE ON)

# Suppress the per-file messages printed during `install`.
set(CMAKE_INSTALL_MESSAGE NEVER)

# ----------------------------- Configuration -----------------------------
# User-facing switches; override with -D<NAME>=ON|OFF at configure time.
option(MLX_BUILD_TESTS           "Build tests for mlx"           ON)
option(MLX_BUILD_EXAMPLES        "Build examples for mlx"        ON)
option(MLX_BUILD_BENCHMARKS      "Build benchmarks for mlx"      OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL           "Build metal backend"           ON)
option(BUILD_SHARED_LIBS         "Build mlx as a shared library" OFF)

# Fall back to a default version when the caller did not supply MLX_VERSION.
if(NOT MLX_VERSION)
  set(MLX_VERSION 0.0.1)
endif()
23+
24+
# ----------------------------- Lib -----------------------------

# The library target; no sources are attached to it in this file.
add_library(mlx)

include(FetchContent)
# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
cmake_policy(SET CMP0135 NEW)

# Look up the frameworks used by the Metal backend, only when it is requested.
# Each result variable holds the library path, or <VAR>-NOTFOUND on failure.
if(MLX_BUILD_METAL)
  find_library(METAL_LIB      Metal)
  find_library(FOUNDATION_LIB Foundation)
  find_library(QUARTZ_LIB     QuartzCore)
endif()
37+
38+
# Configure the Metal backend: pick the metal-cpp release that matches the host
# macOS version, fetch it, and wire its headers and the Apple frameworks into
# the mlx target.
if (MLX_BUILD_METAL AND NOT METAL_LIB)
  message(STATUS "Metal not found. Unable to build GPU")
elseif (MLX_BUILD_METAL)
  message(STATUS "Building METAL sources")
  add_compile_definitions(_METAL_)

  # Query the product version directly (no shell pipeline needed) and strip
  # the trailing newline so the value is usable in comparisons and messages.
  execute_process(
    COMMAND /usr/bin/sw_vers -productVersion
    OUTPUT_VARIABLE MACOS_VERSION_FULL
    OUTPUT_STRIP_TRAILING_WHITESPACE)

  # Keep only <major>[.<minor>], as MACOS_VERSION held before.
  string(REGEX MATCH "^[0-9]+(\\.[0-9]+)?" MACOS_VERSION "${MACOS_VERSION_FULL}")
  if ("${MACOS_VERSION}" STREQUAL "")
    message(FATAL_ERROR
      "Unable to detect the macOS version with /usr/bin/sw_vers "
      "(got: '${MACOS_VERSION_FULL}')")
  endif()

  message(STATUS "Detected macOS version ${MACOS_VERSION}")

  # Version-aware comparison: components are compared numerically one by one,
  # unlike GREATER_EQUAL which treats the whole string as a real number.
  if (MACOS_VERSION VERSION_GREATER_EQUAL 14.0)
    set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS14_iOS17-beta.zip)
  elseif (MACOS_VERSION VERSION_GREATER_EQUAL 13.3)
    set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13.3_iOS16.4.zip)
  else()
    set(METAL_CPP_URL https://developer.apple.com/metal/cpp/files/metal-cpp_macOS13_iOS16.zip)
  endif()

  FetchContent_Declare(
    metal_cpp
    URL ${METAL_CPP_URL}
  )

  FetchContent_MakeAvailable(metal_cpp)

  # Build tree uses the fetched sources; installed packages use the copy that
  # the install rules place under include/metal_cpp.
  target_include_directories(
    mlx PUBLIC
    $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
    $<INSTALL_INTERFACE:include/metal_cpp>
  )

  # Keyword-less signature kept on purpose: CMake forbids mixing the plain and
  # keyword signatures on one target, and other calls on mlx use the plain form.
  target_link_libraries(
    mlx
    ${METAL_LIB}
    ${FOUNDATION_LIB}
    ${QUARTZ_LIB})
endif()
73+
74+
# Select the CPU linear-algebra backend: Apple Accelerate when available,
# otherwise whatever BLAS implementation FindBLAS locates.
find_library(ACCELERATE_LIBRARY Accelerate)
if (ACCELERATE_LIBRARY)
  message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
  set(MLX_BUILD_ACCELERATE ON)
  target_link_libraries(mlx ${ACCELERATE_LIBRARY})
  add_compile_definitions(ACCELERATE_NEW_LAPACK)
else()
  message(STATUS "Accelerate not found, using default backend.")
  set(MLX_BUILD_ACCELERATE OFF)

  # REQUIRED already aborts configuration when BLAS is missing, so no separate
  # BLAS_FOUND check is needed. Set BLA_VENDOR (e.g. Generic) before this call
  # to force a particular implementation.
  find_package(BLAS REQUIRED)

  # TODO find a cleaner way to do this
  find_path(BLAS_INCLUDE_DIRS cblas.h
    /usr/include
    /usr/local/include
    $ENV{BLAS_HOME}/include)
  # Fail here with a clear message instead of a generate-time NOTFOUND error.
  if (NOT BLAS_INCLUDE_DIRS)
    message(FATAL_ERROR
      "Must have BLAS installed: cblas.h was not found "
      "(set BLAS_HOME to the BLAS installation prefix)")
  endif()

  # Quoted so that list items are not concatenated without a separator.
  message(STATUS "BLAS libraries: ${BLAS_LIBRARIES}")
  message(STATUS "BLAS include dirs: ${BLAS_INCLUDE_DIRS}")
  target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
  target_link_libraries(mlx ${BLAS_LIBRARIES})
endif()
98+
99+
# Core library sources.
add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)

# Expose the repository root while building and <prefix>/include once installed.
target_include_directories(
  mlx PUBLIC
  $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
  $<INSTALL_INTERFACE:include>)

# Optional components, each gated by its MLX_BUILD_* option.
if(MLX_BUILD_PYTHON_BINDINGS)
  message(STATUS "Building Python bindings.")
  find_package(Python COMPONENTS Interpreter Development)
  find_package(pybind11 CONFIG REQUIRED)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()

if(MLX_BUILD_TESTS)
  include(CTest)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
endif()

if(MLX_BUILD_EXAMPLES)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
endif()

if(MLX_BUILD_BENCHMARKS)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif()
127+
128+
# ----------------------------- Installation -----------------------------
include(GNUInstallDirs)

# Install library
install(
  TARGETS mlx
  EXPORT MLXTargets
  LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
  ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
  RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
  INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
)

# Install headers (only *.h files, keeping the mlx/ directory layout)
install(
  DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
  COMPONENT headers
  FILES_MATCHING PATTERN "*.h"
)

# Install metal dependencies.
# metal_cpp is only fetched when Metal was requested AND found. Without the
# METAL_LIB check, metal_cpp_SOURCE_DIR would be empty and the rule below would
# expand to install(DIRECTORY / ...).
if (MLX_BUILD_METAL AND METAL_LIB)

  # Install metal cpp
  install(
    DIRECTORY ${metal_cpp_SOURCE_DIR}/
    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
    COMPONENT metal_cpp_source
  )

endif()
161+
162+
# Install cmake config
# Generated files are written to the build tree, then installed next to the
# exported targets under share/cmake/MLX.
set(MLX_CMAKE_BUILD_CONFIG ${CMAKE_BINARY_DIR}/MLXConfig.cmake)
set(MLX_CMAKE_BUILD_VERSION_CONFIG ${CMAKE_BINARY_DIR}/MLXConfigVersion.cmake)
set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)

# Export file describing the installed mlx target.
install(
  EXPORT MLXTargets
  FILE MLXTargets.cmake
  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})

include(CMakePackageConfigHelpers)

# Version file: a consumer's requested version must share the major version.
write_basic_package_version_file(
  ${MLX_CMAKE_BUILD_VERSION_CONFIG}
  VERSION ${MLX_VERSION}
  COMPATIBILITY SameMajorVersion)

# Package config generated from the mlx.pc.in template.
configure_package_config_file(
  ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in
  ${MLX_CMAKE_BUILD_CONFIG}
  INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
  NO_CHECK_REQUIRED_COMPONENTS_MACRO
  PATH_VARS
    CMAKE_INSTALL_LIBDIR
    CMAKE_INSTALL_INCLUDEDIR
    MLX_CMAKE_INSTALL_MODULE_DIR)

install(
  FILES
    ${MLX_CMAKE_BUILD_CONFIG}
    ${MLX_CMAKE_BUILD_VERSION_CONFIG}
  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})

# Ship the project's own CMake modules alongside the package config.
install(
  DIRECTORY ${CMAKE_MODULE_PATH}/
  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})

benchmarks/cpp/time_utils.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#pragma once
2+
3+
#include <chrono>
4+
#include <iomanip>
5+
#include <iostream>
6+
7+
#include "mlx/mlx.h"
8+
9+
// Convert a std::chrono duration to fractional milliseconds.
// NOTE: this function-like macro also rewrites any later
// `std::chrono::milliseconds(...)` expression in the including file.
#define milliseconds(x) \
  (std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e6)
// Current time point of the high-resolution clock.
#define time_now() std::chrono::high_resolution_clock::now()

// Time FUNC(args...) with time_fn and print the mean time per call in msec.
// The expansion already ends with ';'.
#define TIME(FUNC, ...)                                                        \
  std::cout << "Timing " << #FUNC << " ... " << std::flush                     \
            << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
            << std::endl;

// Same as TIME, with an extra message MSG printed in parentheses.
#define TIMEM(MSG, FUNC, ...)                                                  \
  std::cout << "Timing "                                                       \
            << "(" << MSG << ") " << #FUNC << " ... " << std::flush            \
            << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
            << std::endl;

// Measure the mean wall-clock time of eval(fn(args...)).
//
// Runs 5 untimed warmup calls, then 100 timed calls, and returns the average
// time per timed call in milliseconds.
//
// The arguments are owned by this function (taken by value) and are passed to
// fn as lvalues on every call. They must not be forwarded/moved here: the same
// objects are reused across all iterations, and moving would leave them in a
// moved-from state after the first call.
//
// `eval` is not declared in this header; presumably it comes from "mlx/mlx.h"
// and is resolved when the template is instantiated -- confirm at call sites.
template <typename F, typename... Args>
double time_fn(F fn, Args... args) {
  constexpr int num_warmup = 5;
  constexpr int num_iters = 100;

  // warmup
  for (int i = 0; i < num_warmup; ++i) {
    eval(fn(args...));
  }

  auto start = time_now();
  for (int i = 0; i < num_iters; ++i) {
    eval(fn(args...));
  }
  auto end = time_now();
  return milliseconds(end - start) / static_cast<double>(num_iters);
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import argparse
2+
import mlx.core as mx
3+
4+
from time_utils import time_fn
5+
6+
# Problem sizes shared by the benchmarks below: the batched operands are
# (B, T, D), the flattened operands are (B * T, D), and the weight is (D, D).
B, T, D = 8, 1024, 512
9+
10+
11+
def time_batch_matmul():
    """Benchmark matmul with a (B, T, D) left operand and its two VJPs.

    Seeds the RNG with 3, draws uniform ``a`` (B, T, D), ``b`` (D, D) and
    ``c`` (B, T, D), evaluates them, then hands three callables to
    ``time_fn``: the forward ``mx.matmul(a, b)``, and the first and second
    entries of the VJP result of ``mx.vjp(mx.matmul, [a, b], [c])``.
    """
    mx.random.seed(3)
    batched_shape = (B, T, D)
    a = mx.random.uniform(shape=batched_shape)
    b = mx.random.uniform(shape=(D, D))
    c = mx.random.uniform(shape=batched_shape)
    # Evaluate the inputs before any timing starts.
    mx.eval(a, b, c)

    # Forward product.
    time_fn(mx.matmul, a, b)

    def batch_vjp_first():
        # Entry 0 of element [1] of the vjp result (taken w.r.t. `a`).
        result = mx.vjp(mx.matmul, [a, b], [c])
        return result[1][0]

    time_fn(batch_vjp_first)

    def batch_vjp_second():
        # Entry 1 of element [1] of the vjp result (taken w.r.t. `b`).
        result = mx.vjp(mx.matmul, [a, b], [c])
        return result[1][1]

    time_fn(batch_vjp_second)
29+
30+
31+
def time_unbatch_matmul(key=None):
    """Benchmark matmul with a flattened (B * T, D) left operand.

    Seeds the RNG with 3, draws uniform ``a`` (B * T, D), ``b`` (D, D) and
    ``c`` (B * T, D), evaluates them, then hands three callables to
    ``time_fn``: the forward ``mx.matmul(a, b)`` and two hand-written
    products, ``c @ b^T`` and ``a^T @ c``.

    Args:
        key: Unused. Kept (now optional) for backward compatibility; it used
            to be required even though the script calls this function
            without arguments, which raised ``TypeError``.
    """
    mx.random.seed(3)
    a = mx.random.uniform(shape=(B * T, D))
    b = mx.random.uniform(shape=(D, D))
    c = mx.random.uniform(shape=(B * T, D))
    # Evaluate the inputs before any timing starts.
    mx.eval(a, b, c)
    time_fn(mx.matmul, a, b)

    def unbatch_vjp_first():
        return mx.matmul(c, mx.transpose(b))

    time_fn(unbatch_vjp_first)

    def unbatch_vjp_second():
        return mx.matmul(mx.transpose(a), c)

    time_fn(unbatch_vjp_second)
48+
49+
50+
if __name__ == "__main__":
    # Command line: a single optional --gpu flag selects the device.
    parser = argparse.ArgumentParser("MLX benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()

    # Default to the CPU unless --gpu was given.
    mx.set_default_device(mx.gpu if args.gpu else mx.cpu)

    time_batch_matmul()
    time_unbatch_matmul()

0 commit comments

Comments
 (0)