# Custom Metal Kernels
MLX supports writing custom Metal kernels through the Python and C++ APIs.
## Simple Example
Let’s write a custom kernel that computes `exp` elementwise:
    
    import mlx.core as mx
    
    source = """
        uint elem = thread_position_in_grid.x;
        T tmp = inp[elem];
        out[elem] = metal::exp(tmp);
    """
    
    kernel = mx.fast.metal_kernel(
        name="myexp",
        input_names=["inp"],
        output_names=["out"],
        source=source,
    )
    
    def exp_elementwise(a: mx.array):
        outputs = kernel(
            inputs=[a],
            template=[("T", mx.float32)],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
        )
        return outputs[0]
    
    a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
    b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))
    
Every time you make a kernel, a new Metal library is created and possibly JIT compiled. To reduce the overhead from that, build the kernel once with `fast.metal_kernel()` and then use it many times.
Note: Only pass the body of the Metal kernel in `source`. The function signature is generated automatically.
The full function signature will be generated using:
  * The shapes/dtypes of `inputs`. In the above, `a` is an `mx.array` of type `mx.float16` and we pass it with the key `inp` so we will add `const device float16_t* inp` to the signature. `inp_shape`, `inp_strides` and `inp_ndim` are also added for convenience if they are present in `source`.
  * The list of `output_dtypes`. In the above, `out` is an `mx.array` of type `mx.float16` so we add `device float16_t* out`.
  * Template parameters passed using `template`. In the above, `template=[("T", mx.float32)]` adds a template of `template <typename T>` to the function and instantiates the template with `custom_kernel_myexp_float<float>`. Template parameters can be `mx.core.Dtype`, `int` or `bool`.
  * Metal attributes used in `source` such as `[[thread_position_in_grid]]`. These will be added as function arguments. All the attributes defined in Table 5.8 of the Metal Shading Language Specification are supported.


Putting this all together, the generated function signature for `myexp` is as follows:
    
    template <typename T>
    [[kernel]] void custom_kernel_myexp_float(
      const device float16_t* inp [[buffer(0)]],
      device float16_t* out [[buffer(1)]],
      uint3 thread_position_in_grid [[thread_position_in_grid]]) {
    
            uint elem = thread_position_in_grid.x;
            T tmp = inp[elem];
            out[elem] = metal::exp(tmp);
    
    }
    
    template [[host_name("custom_kernel_myexp_float")]] [[kernel]] decltype(custom_kernel_myexp_float<float>) custom_kernel_myexp_float<float>;
    
Note: `grid` and `threadgroup` are parameters to the Metal dispatchThreads function. This means we will launch `mx.prod(grid)` threads, subdivided into `threadgroup` size threadgroups. For optimal performance, each thread group dimension should be less than or equal to the corresponding grid dimension.
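To make the launch geometry concrete, here is a minimal pure-Python sketch. It does not call MLX or Metal, and `launch_geometry` is a hypothetical helper introduced only for illustration. It counts the threads a `grid` launches and clamps each `threadgroup` dimension to the corresponding grid dimension, as recommended above.

```python
import math

def launch_geometry(grid, threadgroup):
    """Model a dispatchThreads-style launch (illustrative helper, not an MLX API)."""
    # The total number of threads is the product of the grid dimensions.
    total_threads = math.prod(grid)
    # Clamp each threadgroup dimension so it never exceeds the grid dimension.
    clamped = tuple(min(t, g) for t, g in zip(threadgroup, grid))
    # Threadgroups needed along each dimension (the last one may be partial).
    groups = tuple(math.ceil(g / t) for g, t in zip(grid, clamped))
    return total_threads, clamped, groups

# The (4, 16) example above has 64 elements.
print(launch_geometry((64, 1, 1), (256, 1, 1)))
# A larger 1D launch whose last threadgroup is only partially filled.
print(launch_geometry((1000, 1, 1), (256, 1, 1)))
```

For the 64-element example, a requested threadgroup of 256 threads is clamped to 64, so the whole launch fits in a single threadgroup.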
Passing `verbose=True` to `fast.metal_kernel.__call__()` will print the generated code for debugging purposes.
## Using Shape/Strides
`fast.metal_kernel()` supports an argument `ensure_row_contiguous` which is `True` by default. This will copy the array inputs if needed before the kernel is launched to ensure that the memory layout is row contiguous. Generally this makes writing the kernel easier, since we don’t have to worry about gaps or the ordering of the dims when indexing.
If we want to avoid this copy, `fast.metal_kernel()` automatically passes `a_shape`, `a_strides` and `a_ndim` for each input array `a` if any are present in `source`. We can then use MLX’s built-in indexing utilities to fetch the right elements for each thread.
Let’s convert `myexp` above to support arbitrarily strided arrays without relying on a copy from `ensure_row_contiguous`:
    
    source = """
        uint elem = thread_position_in_grid.x;
        // Utils from `mlx/backend/metal/kernels/utils.h` are automatically included
        uint loc = elem_to_loc(elem, inp_shape, inp_strides, inp_ndim);
        T tmp = inp[loc];
        // Output arrays are always row contiguous
        out[elem] = metal::exp(tmp);
    """
    
    kernel = mx.fast.metal_kernel(
        name="myexp_strided",
        input_names=["inp"],
        output_names=["out"],
        source=source,
        ensure_row_contiguous=False,
    )
    
    def exp_elementwise(a: mx.array):
        outputs = kernel(
            inputs=[a],
            template=[("T", mx.float32)],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
        )
        return outputs[0]
    
    a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
    # make non-contiguous
    a = a[::2]
    b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))
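
The index mapping used in the strided kernel can be modelled in plain Python. The sketch below is an illustrative reimplementation, not the Metal utility itself, and it assumes that slicing a row-contiguous `(4, 16)` array with `a[::2]` yields a view of shape `(2, 16)` with element strides `(32, 1)`.

```python
def elem_to_loc(elem, shape, strides):
    """Map a row-major linear index to a memory offset (pure-Python model)."""
    loc = 0
    # Peel off the index along each axis, starting from the innermost one.
    for dim, stride in zip(reversed(shape), reversed(strides)):
        loc += (elem % dim) * stride
        elem //= dim
    return loc

# Assumed layout of a[::2]: shape (2, 16), strides (32, 1) in elements.
shape, strides = (2, 16), (32, 1)
offsets = [elem_to_loc(i, shape, strides) for i in (0, 15, 16, 17)]
print(offsets)
```

Linear index 16 is the first element of the second logical row, so it lands at offset 32 in memory, skipping the row that the slice dropped.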
    
## Complex Example
Let’s implement a more complex example: `grid_sample` in `"bilinear"` mode.
We’ll start with the following MLX implementation using standard ops:
    
    def grid_sample_ref(x, grid):
        N, H_in, W_in, _ = x.shape
        ix = ((grid[..., 0] + 1) * W_in - 1) / 2
        iy = ((grid[..., 1] + 1) * H_in - 1) / 2
    
        ix_nw = mx.floor(ix).astype(mx.int32)
        iy_nw = mx.floor(iy).astype(mx.int32)
    
        ix_ne = ix_nw + 1
        iy_ne = iy_nw
    
        ix_sw = ix_nw
        iy_sw = iy_nw + 1
    
        ix_se = ix_nw + 1
        iy_se = iy_nw + 1
    
        nw = (ix_se - ix)    * (iy_se - iy)
        ne = (ix    - ix_sw) * (iy_sw - iy)
        sw = (ix_ne - ix)    * (iy    - iy_ne)
        se = (ix    - ix_nw) * (iy    - iy_nw)
    
        I_nw = x[mx.arange(N)[:, None, None], iy_nw, ix_nw, :]
        I_ne = x[mx.arange(N)[:, None, None], iy_ne, ix_ne, :]
        I_sw = x[mx.arange(N)[:, None, None], iy_sw, ix_sw, :]
        I_se = x[mx.arange(N)[:, None, None], iy_se, ix_se, :]
    
        mask_nw = (iy_nw >= 0) & (iy_nw <= H_in - 1) & (ix_nw >= 0) & (ix_nw <= W_in - 1)
        mask_ne = (iy_ne >= 0) & (iy_ne <= H_in - 1) & (ix_ne >= 0) & (ix_ne <= W_in - 1)
        mask_sw = (iy_sw >= 0) & (iy_sw <= H_in - 1) & (ix_sw >= 0) & (ix_sw <= W_in - 1)
        mask_se = (iy_se >= 0) & (iy_se <= H_in - 1) & (ix_se >= 0) & (ix_se <= W_in - 1)
    
        I_nw *= mask_nw[..., None]
        I_ne *= mask_ne[..., None]
        I_sw *= mask_sw[..., None]
        I_se *= mask_se[..., None]
    
        output = nw[..., None] * I_nw + ne[..., None] * I_ne + sw[..., None] * I_sw + se[..., None] * I_se
    
        return output
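
As a scalar illustration of the reference implementation, the following pure-Python sketch evaluates the coordinate mapping and the four bilinear weights for a single sample point. The helper name `bilinear_weights` is hypothetical. The formulas mirror the ones above, written with `ix_sw = ix_nw` and `iy_sw = iy_se` substituted in.

```python
import math

def bilinear_weights(gx, gy, W, H):
    """Return pixel coordinates and the four corner weights for one sample."""
    # Map normalized coordinates in [-1, 1] to pixel coordinates.
    ix = ((gx + 1) * W - 1) / 2
    iy = ((gy + 1) * H - 1) / 2
    ix_nw, iy_nw = math.floor(ix), math.floor(iy)
    ix_se, iy_se = ix_nw + 1, iy_nw + 1
    # Each corner is weighted by the area of the opposite sub-rectangle.
    nw = (ix_se - ix) * (iy_se - iy)
    ne = (ix - ix_nw) * (iy_se - iy)
    sw = (ix_se - ix) * (iy - iy_nw)
    se = (ix - ix_nw) * (iy - iy_nw)
    return (ix, iy), (nw, ne, sw, se)

coords, weights = bilinear_weights(0.0, 0.0, W=4, H=4)
print(coords, weights, sum(weights))
```

A sample at the center of a 4x4 image falls exactly between four pixels, so each corner receives the same weight and the weights sum to one.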
    
Now let’s use `custom_function()` together with `fast.metal_kernel()` to write a fast GPU kernel for both the forward and backward passes.
First we’ll implement the forward pass as a fused kernel:
    
    import numpy as np
    
    source = """
        uint elem = thread_position_in_grid.x;
        int H = x_shape[1];
        int W = x_shape[2];
        int C = x_shape[3];
        int gH = grid_shape[1];
        int gW = grid_shape[2];
    
        int w_stride = C;
        int h_stride = W * w_stride;
        int b_stride = H * h_stride;
    
        uint grid_idx = elem / C * 2;
        float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
        float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
    
        int ix_nw = floor(ix);
        int iy_nw = floor(iy);
    
        int ix_ne = ix_nw + 1;
        int iy_ne = iy_nw;
    
        int ix_sw = ix_nw;
        int iy_sw = iy_nw + 1;
    
        int ix_se = ix_nw + 1;
        int iy_se = iy_nw + 1;
    
        T nw = (ix_se - ix)    * (iy_se - iy);
        T ne = (ix    - ix_sw) * (iy_sw - iy);
        T sw = (ix_ne - ix)    * (iy    - iy_ne);
        T se = (ix    - ix_nw) * (iy    - iy_nw);
    
        int batch_idx = elem / C / gH / gW * b_stride;
        int channel_idx = elem % C;
        int base_idx = batch_idx + channel_idx;
    
        T I_nw = x[base_idx + iy_nw * h_stride + ix_nw * w_stride];
        T I_ne = x[base_idx + iy_ne * h_stride + ix_ne * w_stride];
        T I_sw = x[base_idx + iy_sw * h_stride + ix_sw * w_stride];
        T I_se = x[base_idx + iy_se * h_stride + ix_se * w_stride];
    
        I_nw = iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1 ? I_nw : 0;
        I_ne = iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1 ? I_ne : 0;
        I_sw = iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1 ? I_sw : 0;
        I_se = iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1 ? I_se : 0;
    
        out[elem] = nw * I_nw + ne * I_ne + sw * I_sw + se * I_se;
    """
    
    kernel = mx.fast.metal_kernel(
        name="grid_sample",
        input_names=["x", "grid"],
        output_names=["out"],
        source=source,
    )
    
    @mx.custom_function
    def grid_sample(x, grid):
    
        assert x.ndim == 4, "`x` must be 4D."
        assert grid.ndim == 4, "`grid` must be 4D."
    
        B, _, _, C = x.shape
        _, gN, gM, D = grid.shape
        out_shape = (B, gN, gM, C)
    
        assert D == 2, "Last dim of `grid` must be size 2."
    
        outputs = kernel(
            inputs=[x, grid],
            template=[("T", x.dtype)],
            output_shapes=[out_shape],
            output_dtypes=[x.dtype],
            grid=(np.prod(out_shape), 1, 1),
            threadgroup=(256, 1, 1),
        )
        return outputs[0]
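
The index arithmetic at the top of the forward kernel can be checked in plain Python. This sketch uses a hypothetical helper `decompose` and assumes, as the kernel does, that the output has row-major shape `(B, gH, gW, C)` and that `grid` is row contiguous with shape `(B, gH, gW, 2)`.

```python
def decompose(elem, gH, gW, C):
    """Split a flat output index into the pieces the kernel uses."""
    # Offset of the (x, y) pair in the flattened grid of shape (B, gH, gW, 2).
    grid_idx = elem // C * 2
    # Batch and channel that this thread is responsible for.
    batch = elem // C // gH // gW
    channel = elem % C
    return grid_idx, batch, channel

B, gH, gW, C = 2, 3, 4, 5
# Flat index of output element (b=1, h=2, w=3, c=4) in row-major order.
elem = ((1 * gH + 2) * gW + 3) * C + 4
print(decompose(elem, gH, gW, C))
```

In the kernel the batch number is additionally multiplied by `b_stride` to give the offset of that batch inside `x`.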
    
For a reasonably sized input such as:
    
    x.shape = (8, 1024, 1024, 64)
    grid.shape = (8, 256, 256, 2)
    
On an M1 Max, we see a big performance improvement:
`55.7ms -> 6.7ms => 8x speed up`
## Grid Sample VJP
Since we decorated `grid_sample` with `custom_function()`, we can now define its custom vjp transform so MLX can differentiate it.
The backwards pass requires atomically updating `x_grad`/`grid_grad` and so requires a few extra `fast.metal_kernel()` features:
  * `init_value=0`: Initialize all of the kernel’s outputs to this value before it runs. This allows us to update only part of the output arrays with the kernel.
  * `atomic_outputs=True`: Designate all of the kernel outputs as `atomic` in the function signature. This means we can use Metal’s `atomic` features to simultaneously update the `x_grad` and `grid_grad` arrays from multiple threadgroups. See section 6.15 of the Metal Shading Language Specification for more details.
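The combined effect of these two options can be modelled sequentially in plain Python. This is only a sketch of the semantics with a hypothetical helper `scatter_add`. On the GPU the additions happen concurrently, which is why they must be atomic.

```python
def scatter_add(size, updates, init_value=0.0):
    """Accumulate (index, value) pairs into a buffer pre-filled with init_value."""
    # Mirrors `init_value=0`: entries that are never touched keep the initial value.
    out = [init_value] * size
    for index, value in updates:
        # On the GPU this is the atomic add; several threads may hit one index.
        out[index] += value
    return out

# Three "threads" contribute and two of them target the same location.
print(scatter_add(4, [(1, 0.5), (1, 0.25), (3, 2.0)]))
```

Without the initialization, the untouched entries would hold whatever happened to be in the freshly allocated buffer. Without atomics, the two updates to index 1 could race.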


We can then implement the backwards pass as follows:
    
    source = """
        uint elem = thread_position_in_grid.x;
        int H = x_shape[1];
        int W = x_shape[2];
        int C = x_shape[3];
        // Pad C to the nearest larger simdgroup size multiple
        int C_padded = ceildiv(C, threads_per_simdgroup) * threads_per_simdgroup;
    
        int gH = grid_shape[1];
        int gW = grid_shape[2];
    
        int w_stride = C;
        int h_stride = W * w_stride;
        int b_stride = H * h_stride;
    
        uint grid_idx = elem / C_padded * 2;
        float ix = ((grid[grid_idx] + 1) * W - 1) / 2;
        float iy = ((grid[grid_idx + 1] + 1) * H - 1) / 2;
    
        int ix_nw = floor(ix);
        int iy_nw = floor(iy);
    
        int ix_ne = ix_nw + 1;
        int iy_ne = iy_nw;
    
        int ix_sw = ix_nw;
        int iy_sw = iy_nw + 1;
    
        int ix_se = ix_nw + 1;
        int iy_se = iy_nw + 1;
    
        T nw = (ix_se - ix)    * (iy_se - iy);
        T ne = (ix    - ix_sw) * (iy_sw - iy);
        T sw = (ix_ne - ix)    * (iy    - iy_ne);
        T se = (ix    - ix_nw) * (iy    - iy_nw);
    
        int batch_idx = elem / C_padded / gH / gW * b_stride;
        int channel_idx = elem % C_padded;
        int base_idx = batch_idx + channel_idx;
    
        T gix = T(0);
        T giy = T(0);
        if (channel_idx < C) {
            int cot_index = elem / C_padded * C + channel_idx;
            T cot = cotangent[cot_index];
            if (iy_nw >= 0 && iy_nw <= H - 1 && ix_nw >= 0 && ix_nw <= W - 1) {
                int offset = base_idx + iy_nw * h_stride + ix_nw * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], nw * cot, memory_order_relaxed);
    
                T I_nw = x[offset];
                gix -= I_nw * (iy_se - iy) * cot;
                giy -= I_nw * (ix_se - ix) * cot;
            }
            if (iy_ne >= 0 && iy_ne <= H - 1 && ix_ne >= 0 && ix_ne <= W - 1) {
                int offset = base_idx + iy_ne * h_stride + ix_ne * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], ne * cot, memory_order_relaxed);
    
                T I_ne = x[offset];
                gix += I_ne * (iy_sw - iy) * cot;
                giy -= I_ne * (ix - ix_sw) * cot;
            }
            if (iy_sw >= 0 && iy_sw <= H - 1 && ix_sw >= 0 && ix_sw <= W - 1) {
                int offset = base_idx + iy_sw * h_stride + ix_sw * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], sw * cot, memory_order_relaxed);
    
                T I_sw = x[offset];
                gix -= I_sw * (iy - iy_ne) * cot;
                giy += I_sw * (ix_ne - ix) * cot;
            }
            if (iy_se >= 0 && iy_se <= H - 1 && ix_se >= 0 && ix_se <= W - 1) {
                int offset = base_idx + iy_se * h_stride + ix_se * w_stride;
                atomic_fetch_add_explicit(&x_grad[offset], se * cot, memory_order_relaxed);
    
                T I_se = x[offset];
                gix += I_se * (iy - iy_nw) * cot;
                giy += I_se * (ix - ix_nw) * cot;
            }
        }
    
        T gix_mult = W / 2;
        T giy_mult = H / 2;
    
        // Reduce across each simdgroup first.
        // This is much faster than relying purely on atomics.
        gix = simd_sum(gix);
        giy = simd_sum(giy);
    
        if (thread_index_in_simdgroup == 0) {
            atomic_fetch_add_explicit(&grid_grad[grid_idx], gix * gix_mult, memory_order_relaxed);
            atomic_fetch_add_explicit(&grid_grad[grid_idx + 1], giy * giy_mult, memory_order_relaxed);
        }
    """
    kernel = mx.fast.metal_kernel(
        name="grid_sample_grad",
        input_names=["x", "grid", "cotangent"],
        output_names=["x_grad", "grid_grad"],
        source=source,
        atomic_outputs=True,
    )
    
    @grid_sample.vjp
    def grid_sample_vjp(primals, cotangent, _):
        x, grid = primals
        B, _, _, C = x.shape
        _, gN, gM, D = grid.shape
    
        assert D == 2, "Last dim of `grid` must be size 2."
    
        # pad the output channels to simd group size
        # so that our `simd_sum`s don't overlap.
        simdgroup_size = 32
        C_padded = (C + simdgroup_size - 1) // simdgroup_size * simdgroup_size
        grid_size = B * gN * gM * C_padded
        outputs = kernel(
            inputs=[x, grid, cotangent],
            template=[("T", x.dtype)],
            output_shapes=[x.shape, grid.shape],
            output_dtypes=[x.dtype, x.dtype],
            grid=(grid_size, 1, 1),
            threadgroup=(256, 1, 1),
            init_value=0,
        )
        return outputs[0], outputs[1]
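
The reason for padding the channel dimension can be seen with a small pure-Python model. The helpers `pad_to_simd` and `straddles` are hypothetical, and a simdgroup size of 32 is assumed, as in the example above. The check asks whether any group of 32 consecutive threads covers more than one grid position, which would make `simd_sum` mix gradients from different positions.

```python
SIMD = 32  # simdgroup size assumed in the example above

def pad_to_simd(C, simd=SIMD):
    """Round C up to the next multiple of the simdgroup size."""
    return (C + simd - 1) // simd * simd

def straddles(stride, positions=4, simd=SIMD):
    """True if any simdgroup covers threads from more than one grid position."""
    total = positions * stride
    for start in range(0, total, simd):
        lanes = range(start, min(start + simd, total))
        # Each thread belongs to grid position `elem // stride`.
        if len({elem // stride for elem in lanes}) > 1:
            return True
    return False

C = 48
C_padded = pad_to_simd(C)
print(C_padded, straddles(C), straddles(C_padded))
```

With 48 threads per grid position, the second simdgroup covers threads from two positions. After padding to 64 threads per position, every simdgroup stays within one position.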
    
There’s an even larger speed up for the vjp:
`676.4ms -> 16.7ms => 40x speed up`
# Custom Extensions in MLX
You can extend MLX with custom operations on the CPU or GPU. This guide explains how to do that with a simple example.
## Introducing the Example
Let’s say you would like an operation that takes in two arrays, `x` and `y`, scales them both by coefficients `alpha` and `beta` respectively, and then adds them together to get the result `z = alpha * x + beta * y`. You can do that in MLX directly:
    
    import mlx.core as mx
    
    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        return alpha * x + beta * y
    
This function performs that operation while leaving the implementation and function transformations to MLX.
However, you may want to customize the underlying implementation, perhaps to make it faster. In this tutorial we will go through adding custom extensions. It will cover:
  * The structure of the MLX library.
  * Implementing a CPU operation.
  * Implementing a GPU operation using Metal.
  * Adding the `vjp` and `jvp` function transformations.
  * Building a custom extension and binding it to Python.


## Operations and Primitives
Operations in MLX build the computation graph. Primitives provide the rules for evaluating and transforming the graph. Let’s start by discussing operations in more detail.
### Operations
Operations are the front-end functions that operate on arrays. They are defined in the C++ API (Operations), and the Python API (Operations) binds them.
We would like an operation `axpby()` that takes in two arrays, `x` and `y`, and two scalars, `alpha` and `beta`. This is how to define it in C++:
    
    /**
    *  Scale and sum two vectors element-wise
    *  z = alpha * x + beta * y
    *
    *  Use NumPy-style broadcasting between x and y
    *  Inputs are upcasted to floats if needed
    **/
    array axpby(
        const array& x, // Input array x
        const array& y, // Input array y
        const float alpha, // Scaling factor for x
        const float beta, // Scaling factor for y
        StreamOrDevice s = {} // Stream on which to schedule the operation
    );
    
The simplest way to implement this is with existing operations:
    
    array axpby(
        const array& x, // Input array x
        const array& y, // Input array y
        const float alpha, // Scaling factor for x
        const float beta, // Scaling factor for y
        StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
    ) {
        // Scale x and y on the provided stream
        auto ax = multiply(array(alpha), x, s);
        auto by = multiply(array(beta), y, s);
    
        // Add and return
        return add(ax, by, s);
    }
    
The operations themselves do not contain the implementations that act on the data, nor do they contain the rules of transformations. Rather, they are an easy-to-use interface built on `Primitive` building blocks.
### Primitives
A `Primitive` is part of the computation graph of an `array`. It defines how to create output arrays given input arrays. Further, a `Primitive` has methods to run on the CPU or GPU and for function transformations such as `vjp` and `jvp`. Let’s go back to our example to be more concrete:
    
    class Axpby : public Primitive {
      public:
        explicit Axpby(Stream stream, float alpha, float beta)
            : Primitive(stream), alpha_(alpha), beta_(beta){};
    
        /**
        * A primitive must know how to evaluate itself on the CPU/GPU
        * for the given inputs and populate the output array.
        *
        * To avoid unnecessary allocations, the evaluation function
        * is responsible for allocating space for the array.
        */
        void eval_cpu(
            const std::vector<array>& inputs,
            std::vector<array>& outputs) override;
        void eval_gpu(
            const std::vector<array>& inputs,
            std::vector<array>& outputs) override;
    
        /** The Jacobian-vector product. */
        std::vector<array> jvp(
            const std::vector<array>& primals,
            const std::vector<array>& tangents,
            const std::vector<int>& argnums) override;
    
        /** The vector-Jacobian product. */
        std::vector<array> vjp(
            const std::vector<array>& primals,
            const std::vector<array>& cotangents,
            const std::vector<int>& argnums,
            const std::vector<array>& outputs) override;
    
        /**
        * The primitive must know how to vectorize itself across
        * the given axes. The output is a pair containing the array
        * representing the vectorized computation and the axis which
        * corresponds to the output vectorized dimension.
        */
        std::pair<std::vector<array>, std::vector<int>> vmap(
            const std::vector<array>& inputs,
            const std::vector<int>& axes) override;
    
        /** The name of primitive. */
        const char* name() const override {
          return "Axpby";
        }
    
        /** Equivalence check **/
        bool is_equivalent(const Primitive& other) const override;
    
      private:
        float alpha_;
        float beta_;
    };
    
The `Axpby` class derives from the base `Primitive` class and treats `alpha` and `beta` as parameters. It provides implementations of how the output array is produced from the inputs through `Axpby::eval_cpu()` and `Axpby::eval_gpu()`. It also provides the transformation rules in `Axpby::jvp()`, `Axpby::vjp()`, and `Axpby::vmap()`.
### Using the Primitive
Operations can use this `Primitive` to add a new `array` to the computation graph. An `array` can be constructed by providing its data type, shape, the `Primitive` that computes it, and the `array` inputs that are passed to the primitive.
Let’s reimplement our operation now in terms of our `Axpby` primitive.
    
    array axpby(
        const array& x, // Input array x
        const array& y, // Input array y
        const float alpha, // Scaling factor for x
        const float beta, // Scaling factor for y
        StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
    ) {
        // Promote dtypes between x and y as needed
        auto promoted_dtype = promote_types(x.dtype(), y.dtype());
    
        // Upcast to float32 for non-floating point inputs x and y
        auto out_dtype = issubdtype(promoted_dtype, float32)
            ? promoted_dtype
            : promote_types(promoted_dtype, float32);
    
        // Cast x and y up to the determined dtype (on the same stream s)
        auto x_casted = astype(x, out_dtype, s);
        auto y_casted = astype(y, out_dtype, s);
    
        // Broadcast the shapes of x and y (on the same stream s)
        auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
        auto out_shape = broadcasted_inputs[0].shape();
    
        // Construct the array as the output of the Axpby primitive
        // with the broadcasted and upcasted arrays as inputs
        return array(
            /* const std::vector<int>& shape = */ out_shape,
            /* Dtype dtype = */ out_dtype,
            /* std::shared_ptr<Primitive> primitive = */
            std::make_shared<Axpby>(to_stream(s), alpha, beta),
            /* const std::vector<array>& inputs = */ broadcasted_inputs);
    }
    
This operation now handles the following:
  1. Upcast inputs and resolve the output data type.
  2. Broadcast the inputs and resolve the output shape.
  3. Construct the primitive `Axpby` using the given stream, `alpha`, and `beta`.
  4. Construct the output `array` using the primitive and the inputs.
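Step 2 above relies on NumPy-style broadcasting. The following pure-Python sketch illustrates that shape rule only. `broadcast_shape` is a hypothetical helper, not the MLX `broadcast_arrays` implementation, and dtype promotion from step 1 is not modelled.

```python
def broadcast_shape(a, b):
    """NumPy-style broadcast of two shapes (raises if they are incompatible)."""
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with ones so both have the same rank.
    a = (1,) * (ndim - len(a)) + tuple(a)
    b = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for da, db in zip(a, b):
        if da != db and 1 not in (da, db):
            raise ValueError(f"cannot broadcast {a} with {b}")
        # A dimension of size 1 stretches to match the other operand.
        out.append(db if da == 1 else da)
    return tuple(out)

print(broadcast_shape((3,), (2, 3)))
print(broadcast_shape((4, 1), (1, 5)))
```

The resulting shape is the one the output `array` is constructed with, and both inputs are viewed with that shape before they reach the primitive.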


## Implementing the Primitive
No computation happens when we call the operation alone. The operation only builds the computation graph. When we evaluate the output array, MLX schedules the execution of the computation graph, and calls `Axpby::eval_cpu()` or `Axpby::eval_gpu()` depending on the stream/device specified by the user.
Warning: When `Primitive::eval_cpu()` or `Primitive::eval_gpu()` is called, no memory has been allocated for the output array. It falls on the implementation of these functions to allocate memory as needed.
### Implementing the CPU Back-end
Let’s start by implementing `Axpby::eval_cpu()`.
The method will go over each element of the output array, find the corresponding input elements of `x` and `y` and perform the operation point-wise. This is captured in the templated function `axpby_impl()`.
    
    template <typename T>
    void axpby_impl(
        const mx::array& x,
        const mx::array& y,
        mx::array& out,
        float alpha_,
        float beta_,
        mx::Stream stream) {
      out.set_data(mx::allocator::malloc(out.nbytes()));
    
      // Get the CPU command encoder and register input and output arrays
      auto& encoder = mx::cpu::get_command_encoder(stream);
      encoder.set_input_array(x);
      encoder.set_input_array(y);
      encoder.set_output_array(out);
    
      // Launch the CPU kernel
      encoder.dispatch([x_ptr = x.data<T>(),
                        y_ptr = y.data<T>(),
                        out_ptr = out.data<T>(),
                        size = out.size(),
                        shape = out.shape(),
                        x_strides = x.strides(),
                        y_strides = y.strides(),
                        alpha_,
                        beta_]() {
    
        // Cast alpha and beta to the relevant types
        T alpha = static_cast<T>(alpha_);
        T beta = static_cast<T>(beta_);
    
        // Do the element-wise operation for each output
        for (size_t out_idx = 0; out_idx < size; out_idx++) {
          // Map linear indices to offsets in x and y
          auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
          auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
    
          // We allocate the output to be contiguous and regularly strided
          // (defaults to row major) and hence it doesn't need additional mapping
          out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
        }
      });
    }
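
The loop inside the dispatched lambda can be modelled in plain Python to show how strides handle broadcasting. This is a sketch with hypothetical names, using flat lists as buffers, and it assumes that a broadcast dimension is represented by a stride of 0.

```python
def elem_to_loc(elem, shape, strides):
    """Map a row-major linear index to an offset in a strided buffer."""
    loc = 0
    for dim, stride in zip(reversed(shape), reversed(strides)):
        loc += (elem % dim) * stride
        elem //= dim
    return loc

def axpby_loop(x, x_strides, y, y_strides, shape, alpha, beta):
    """Element-wise alpha * x + beta * y over a contiguous output."""
    size = 1
    for dim in shape:
        size *= dim
    out = [0.0] * size
    for out_idx in range(size):
        # Map the output index to offsets in the (possibly broadcast) inputs.
        x_offset = elem_to_loc(out_idx, shape, x_strides)
        y_offset = elem_to_loc(out_idx, shape, y_strides)
        out[out_idx] = alpha * x[x_offset] + beta * y[y_offset]
    return out

x = [1.0, 2.0, 3.0]                       # shape (3,), broadcast to (2, 3)
y = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0]  # shape (2, 3), row contiguous
result = axpby_loop(x, (0, 1), y, (3, 1), (2, 3), alpha=2.0, beta=0.5)
print(result)
```

The zero stride on the first axis of `x` makes both output rows read the same three values, so no copy of `x` is needed.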
    
Our implementation should work for all incoming floating point arrays. Accordingly, we add dispatches for `float32`, `float16`, `bfloat16` and `complex64`. We throw an error if we encounter an unexpected type.
    
    void Axpby::eval_cpu(
        const std::vector<mx::array>& inputs,
        std::vector<mx::array>& outputs) {
      auto& x = inputs[0];
      auto& y = inputs[1];
      auto& out = outputs[0];
    
      // Dispatch to the correct dtype
      if (out.dtype() == mx::float32) {
        return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
      } else if (out.dtype() == mx::float16) {
        return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
      } else if (out.dtype() == mx::bfloat16) {
        return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
      } else if (out.dtype() == mx::complex64) {
        return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
      } else {
        throw std::runtime_error(
            "Axpby is only supported for floating point types.");
      }
    }
    
Just this much is enough to run the operation `axpby()` on a CPU stream! If you do not plan on running the operation on the GPU or using transforms on computation graphs that contain `Axpby`, you can stop implementing the primitive here.
### Implementing the GPU Back-end
Apple silicon devices address their GPUs using the Metal shading language, and GPU kernels in MLX are written using Metal.
Note: Here are some helpful resources if you are new to Metal:
  * A walkthrough of the Metal compute pipeline: Metal Example
  * Documentation for the Metal shading language: Metal Specification
  * Using Metal from C++: Metal-cpp


Let’s keep the GPU kernel simple. We will launch exactly as many threads as there are elements in the output. Each thread will pick the element it needs from `x` and `y`, do the point-wise operation, and update its assigned element in the output.
    
    template <typename T>
    [[kernel]] void axpby_general(
            device const T* x [[buffer(0)]],
            device const T* y [[buffer(1)]],
            device T* out [[buffer(2)]],
            constant const float& alpha [[buffer(3)]],
            constant const float& beta [[buffer(4)]],
            constant const int* shape [[buffer(5)]],
            constant const int64_t* x_strides [[buffer(6)]],
            constant const int64_t* y_strides [[buffer(7)]],
            constant const int& ndim [[buffer(8)]],
            uint index [[thread_position_in_grid]]) {
        // Convert linear indices to offsets in array
        auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
        auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
    
        // Do the operation and update the output
        out[index] =
            static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
    }
    
We then need to instantiate this template for all floating point types and give each instantiation a unique host name so we can identify it.
    
    instantiate_kernel("axpby_general_float32", axpby_general, float)
    instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
    instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
    instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
    
The logic to determine the kernel, set the inputs, resolve the grid dimensions, and dispatch to the GPU is contained in `Axpby::eval_gpu()` as shown below.
    
    /** Evaluate primitive on GPU */
    void Axpby::eval_gpu(
      const std::vector<array>& inputs,
      std::vector<array>& outputs) {
        // Prepare inputs
        assert(inputs.size() == 2);
        auto& x = inputs[0];
        auto& y = inputs[1];
        auto& out = outputs[0];
    
        // Each primitive carries the stream it should execute on
        // and each stream carries its device identifiers
        auto& s = stream();
        // We get the needed metal device using the stream
        auto& d = metal::device(s.device);
    
        // Allocate output memory
        out.set_data(allocator::malloc(out.nbytes()));
    
        // Resolve name of kernel
        std::string kname = "axpby_general_" + type_to_name(out);
    
        // Load the metal library
        auto lib = d.get_library("mlx_ext", current_binary_dir());
    
        // Make a kernel from this metal library
        auto kernel = d.get_kernel(kname, lib);
    
        // Prepare to encode kernel
        auto& compute_encoder = d.get_command_encoder(s.index);
        compute_encoder.set_compute_pipeline_state(kernel);
    
        // Kernel parameters are registered with buffer indices corresponding to
        // those in the kernel declaration at axpby.metal
        int ndim = out.ndim();
        size_t nelem = out.size();
    
        // Encode input arrays to kernel
        compute_encoder.set_input_array(x, 0);
        compute_encoder.set_input_array(y, 1);
    
        // Encode output arrays to kernel
        compute_encoder.set_output_array(out, 2);
    
        // Encode alpha and beta
        compute_encoder.set_bytes(alpha_, 3);
        compute_encoder.set_bytes(beta_, 4);
    
        // Encode shape, strides and ndim
        compute_encoder.set_vector_bytes(x.shape(), 5);
        compute_encoder.set_vector_bytes(x.strides(), 6);
        compute_encoder.set_vector_bytes(y.strides(), 7);
        compute_encoder.set_bytes(ndim, 8);
    
        // We launch 1 thread for each output element and make sure that the number of
        // threads in any given threadgroup is not higher than the max allowed
        size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
    
        // Fix the 3D size of each threadgroup (in terms of threads)
        MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
    
        // Fix the 3D size of the launch grid (in terms of threads)
        MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
    
        // Launch the grid with the given number of threads divided among
        // the given threadgroups
        compute_encoder.dispatch_threads(grid_dims, group_dims);
    }
    
We can now call the `axpby()` operation on both the CPU and the GPU!
A few things to note about MLX and Metal before moving on. MLX keeps track of the active `command_buffer` and the `MTLCommandBuffer` with which it is associated. We rely on `d.get_command_encoder()` to give us the active Metal compute command encoder instead of building a new one and calling `compute_encoder->end_encoding()` at the end. MLX adds kernels (compute pipelines) to the active command buffer until some specified limit is hit or the command buffer needs to be flushed for synchronization.
### Primitive Transforms
Next, let’s add implementations for transformations in a `Primitive`. These transformations can be built on top of other operations, including the one we just defined:
    
    /** The Jacobian-vector product. */
    std::vector<array> Axpby::jvp(
            const std::vector<array>& primals,
            const std::vector<array>& tangents,
            const std::vector<int>& argnums) {
        // Forward mode diff that pushes along the tangents
        // The jvp transform on the primitive can be built with ops
        // that are scheduled on the same stream as the primitive
    
        // If argnums = {0}, we only push along x in which case the
        // jvp is just the tangent scaled by alpha
        // Similarly, if argnums = {1}, the jvp is just the tangent
        // scaled by beta
        if (argnums.size() == 1) {
            auto scale = argnums[0] == 0 ? alpha_ : beta_;
            auto scale_arr = array(scale, tangents[0].dtype());
            return {multiply(scale_arr, tangents[0], stream())};
        }
        // If argnums = {0, 1}, we take contributions from both
        // which gives us jvp = tangent_x * alpha + tangent_y * beta
        else {
            return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
        }
    }
    
    
    /** The vector-Jacobian product. */
    std::vector<array> Axpby::vjp(
            const std::vector<array>& primals,
            const std::vector<array>& cotangents,
            const std::vector<int>& argnums,
            const std::vector<array>& /* unused */) {
        // Reverse mode diff
        std::vector<array> vjps;
        for (auto arg : argnums) {
            auto scale = arg == 0 ? alpha_ : beta_;
            auto scale_arr = array(scale, cotangents[0].dtype());
            vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
        }
        return vjps;
    }
    
Note, a transformation does not need to be fully defined to start using the `Primitive`.
    
    /** Vectorize primitive along given axis */
    std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
            const std::vector<array>& inputs,
            const std::vector<int>& axes) {
        throw std::runtime_error("[Axpby] vmap not implemented.");
    }
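Since `axpby` is linear in both `x` and `y`, the `jvp` and `vjp` rules reduce to scaling by `alpha` or `beta`. The following is a minimal pure-Python sketch of the same case analysis, using plain lists instead of MLX arrays. The function names are illustrative and are not part of MLX.

```python
def axpby(x, y, alpha, beta):
    return [alpha * xi + beta * yi for xi, yi in zip(x, y)]

def axpby_jvp(tangents, argnums, alpha, beta):
    """Forward mode: push the tangents of the selected inputs through axpby."""
    if len(argnums) == 1:
        # Only one input is differentiated: scale its tangent
        scale = alpha if argnums[0] == 0 else beta
        return [scale * t for t in tangents[0]]
    # Both inputs are differentiated: alpha * tangent_x + beta * tangent_y
    return axpby(tangents[0], tangents[1], alpha, beta)

def axpby_vjp(cotangent, argnums, alpha, beta):
    """Reverse mode: one scaled copy of the cotangent per selected input."""
    return [
        [(alpha if arg == 0 else beta) * c for c in cotangent]
        for arg in argnums
    ]

print(axpby_jvp([[1.0, 1.0]], [0], 4.0, 2.0))
print(axpby_jvp([[1.0, 2.0], [3.0, 4.0]], [0, 1], 4.0, 2.0))
print(axpby_vjp([1.0, 1.0], [0, 1], 4.0, 2.0))
```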
    
## Building and Binding
Let’s look at the overall directory structure first.
    
    extensions
    ├── axpby
    │   ├── axpby.cpp
    │   ├── axpby.h
    │   └── axpby.metal
    ├── mlx_sample_extensions
    │   └── __init__.py
    ├── bindings.cpp
    ├── CMakeLists.txt
    └── setup.py
    
  * `extensions/axpby/` defines the C++ extension library
  * `extensions/mlx_sample_extensions` sets out the structure for the associated Python package
  * `extensions/bindings.cpp` provides Python bindings for our operation
  * `extensions/CMakeLists.txt` holds CMake rules to build the library and Python bindings
  * `extensions/setup.py` holds the `setuptools` rules to build and install the Python package


### Binding to Python
We use nanobind to build a Python API for the C++ library. Since bindings for components such as `mlx.core.array`, `mlx.core.stream`, etc. are already provided, adding our `axpby()` is simple.
    
    NB_MODULE(_ext, m) {
         m.doc() = "Sample extension for MLX";
    
         m.def(
             "axpby",
             &axpby,
             "x"_a,
             "y"_a,
             "alpha"_a,
             "beta"_a,
             nb::kw_only(),
             "stream"_a = nb::none(),
             R"(
                 Scale and sum two vectors element-wise
                 ``z = alpha * x + beta * y``
    
                 Follows numpy style broadcasting between ``x`` and ``y``
                 Inputs are upcasted to floats if needed
    
                 Args:
                     x (array): Input array.
                     y (array): Input array.
                     alpha (float): Scaling factor for ``x``.
                     beta (float): Scaling factor for ``y``.
    
                 Returns:
                     array: ``alpha * x + beta * y``
             )");
     }
    
Most of the complexity in the above example comes from additional bells and whistles such as the literal names and doc-strings.
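From the Python side, the `nb::kw_only()` marker means that every argument after it, here `stream`, can only be passed by keyword. The sketch below is a pure-Python stand-in with the same call signature. It is not the compiled binding, and it operates on plain lists rather than `mlx.core.array`.

```python
def axpby(x, y, alpha, beta, *, stream=None):
    """Pure-Python stand-in mirroring the signature of the binding above."""
    return [alpha * xi + beta * yi for xi, yi in zip(x, y)]

# Passing stream by keyword is accepted
print(axpby([1.0, 1.0], [1.0, 1.0], 4.0, 2.0, stream=None))

# Passing it positionally is rejected
try:
    axpby([1.0], [1.0], 4.0, 2.0, None)
except TypeError as err:
    print("rejected:", err)
```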
Warning
`mlx.core` must be imported before importing `mlx_sample_extensions` as defined by the nanobind module above to ensure that the casters for `mlx.core` components like `mlx.core.array` are available.
### Building with CMake
Building the C++ extension library only requires that you `find_package(MLX CONFIG)` and then link it to your library.
    
    # Add library
    add_library(mlx_ext)
    
    # Add sources
    target_sources(
        mlx_ext
        PUBLIC
        ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
    )
    
    # Add include headers
    target_include_directories(
        mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
    )
    
    # Link to mlx
    target_link_libraries(mlx_ext PUBLIC mlx)
    
We also need to build the attached Metal library. For convenience, we provide a `mlx_build_metallib()` function that builds a `.metallib` target given sources, headers, destinations, etc. (defined in `cmake/extension.cmake` and automatically imported with the MLX package).
Here is what that looks like in practice:
    
    # Build metallib
    if(MLX_BUILD_METAL)
    
    mlx_build_metallib(
        TARGET mlx_ext_metallib
        TITLE mlx_ext
        SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
        INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
        OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
    )
    
    add_dependencies(
        mlx_ext
        mlx_ext_metallib
    )
    
    endif()
    
Finally, we build the nanobind bindings:
    
    nanobind_add_module(
      _ext
      NB_STATIC STABLE_ABI LTO NOMINSIZE
      NB_DOMAIN mlx
      ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
    )
    target_link_libraries(_ext PRIVATE mlx_ext)
    
    if(BUILD_SHARED_LIBS)
      target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
    endif()
    
### Building with `setuptools`
Once we have set out the CMake build rules as described above, we can use the build utilities defined in `mlx.extension`:
    
    from mlx import extension
    from setuptools import setup
    
    if __name__ == "__main__":
        setup(
            name="mlx_sample_extensions",
            version="0.0.0",
            description="Sample C++ and Metal extensions for MLX primitives.",
            ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
            cmdclass={"build_ext": extension.CMakeBuild},
            packages=["mlx_sample_extensions"],
            package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
            extras_require={"dev":[]},
            zip_safe=False,
            python_requires=">=3.8",
        )
    
Note
We treat `extensions/mlx_sample_extensions` as the package directory even though it only contains a `__init__.py` to ensure the following:
  * `mlx.core` must be imported before importing `_ext`
  * The C++ extension library and the metal library are co-located with the python bindings and copied together if the package is installed


To build the package, first install the build dependencies with `pip install -r requirements.txt`. You can then build inplace for development using `python setup.py build_ext -j8 --inplace` (in `extensions/`)
This results in the directory structure:
    
    extensions
    ├── mlx_sample_extensions
    │   ├── __init__.py
    │   ├── libmlx_ext.dylib # C++ extension library
    │   ├── mlx_ext.metallib # Metal library
    │   └── _ext.cpython-3x-darwin.so # Python Binding
    …
    
When you try to install using the command `python -m pip install .` (in `extensions/`), the package will be installed with the same structure as `extensions/mlx_sample_extensions` and the C++ and Metal library will be copied along with the Python binding since they are specified as `package_data`.
## Usage
After installing the extension as described above, you should be able to simply import the Python package and play with it as you would any other MLX operation.
Let’s look at a simple script and its results:
    
    import mlx.core as mx
    from mlx_sample_extensions import axpby
    
    a = mx.ones((3, 4))
    b = mx.ones((3, 4))
    c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
    
    print(f"c shape: {c.shape}")
    print(f"c dtype: {c.dtype}")
    print(f"c is correct: {mx.all(c == 6.0).item()}")
    
Output:
    
    c shape: [3, 4]
    c dtype: float32
    c is correct: True
    
### Results
Let’s run a quick benchmark and see how our new `axpby` operation compares with the naive `simple_axpby()` we first defined.
    
    import mlx.core as mx
    from mlx_sample_extensions import axpby
    import time
    
    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
        return alpha * x + beta * y
    
    M = 4096
    N = 4096
    
    x = mx.random.normal((M, N))
    y = mx.random.normal((M, N))
    alpha = 4.0
    beta = 2.0
    
    mx.eval(x, y)
    
    def bench(f):
        # Warm up
        for i in range(5):
            z = f(x, y, alpha, beta)
            mx.eval(z)
    
        # Timed run
        s = time.time()
        for i in range(100):
            z = f(x, y, alpha, beta)
            mx.eval(z)
        e = time.time()
        return 1000 * (e - s) / 100
    
    simple_time = bench(simple_axpby)
    custom_time = bench(axpby)
    
    print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")
    
The results are `Simple axpby: 1.559 ms | Custom axpby: 0.774 ms`. We see modest improvements right away!
This operation can now be used to build other operations, in `mlx.nn.Module` calls, and as part of graph transformations like `grad()`.
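The benchmark pattern used here (a few warm-up calls, then a timed loop in which every result is forced to evaluate) can be factored into a reusable helper. The following is a sketch under stated assumptions: `bench` and its `sync` parameter are hypothetical names, the demo runs on plain Python lists, and with MLX one would pass `mx.eval` as `sync`. It uses `time.perf_counter()` rather than `time.time()` because the former is intended for measuring short intervals.

```python
import time

def bench(fn, *args, warmup=5, iters=100, sync=lambda out: out):
    """Return the mean wall-clock time per call in milliseconds."""
    # Warm up so that one-time costs are not counted
    for _ in range(warmup):
        sync(fn(*args))
    # Timed run; sync forces each result to be computed inside the loop
    start = time.perf_counter()
    for _ in range(iters):
        sync(fn(*args))
    return 1000 * (time.perf_counter() - start) / iters

def simple_axpby(x, y, alpha, beta):
    return [alpha * xi + beta * yi for xi, yi in zip(x, y)]

x = [1.0] * 10_000
y = [1.0] * 10_000
print(f"{bench(simple_axpby, x, y, 4.0, 2.0):.3f} ms per call")
```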
## Scripts
Download the code
The full example code is available in mlx.
# Metal Debugger
Profiling is a key step for performance optimization. You can build MLX with the `MLX_METAL_DEBUG` option to improve the Metal debugging and optimization workflow. The `MLX_METAL_DEBUG` debug option:
  * Records source during Metal compilation, for later inspection while debugging.
  * Labels Metal objects such as command queues, improving capture readability.


To build with debugging enabled in Python, prepend `CMAKE_ARGS="-DMLX_METAL_DEBUG=ON"` to the build call.
The `metal.start_capture()` function initiates a capture of all MLX GPU work.
Note
To capture a GPU trace you must run the application with `MTL_CAPTURE_ENABLED=1`.
    
    import mlx.core as mx
    
    a = mx.random.uniform(shape=(512, 512))
    b = mx.random.uniform(shape=(512, 512))
    mx.eval(a, b)
    
    trace_file = "mlx_trace.gputrace"
    
    # Make sure to run with MTL_CAPTURE_ENABLED=1 and
    # that the path trace_file does not already exist.
    mx.metal.start_capture(trace_file)
    
    for _ in range(10):
      mx.eval(mx.add(a, b))
    
    mx.metal.stop_capture()
    
You can open and replay the GPU trace in Xcode. The `Dependencies` view has a great overview of all operations. Check out the Metal debugger documentation for more information.
## Xcode Workflow
You can skip saving to a path by running within Xcode. First, generate an Xcode project using CMake.
    
    mkdir build && cd build
    cmake .. -DMLX_METAL_DEBUG=ON -G Xcode
    open mlx.xcodeproj
    
Select the `metal_capture` example scheme and run.
# Using MLX in C++
You can use MLX in a C++ project with CMake.
Note
This guide is based on the following example using MLX in C++.
First install MLX:
    
    pip install -U mlx
    
You can also install the MLX Python package from source or just the C++ library. For more information see the documentation on installing MLX.
Next make an example program in `example.cpp`:
    
    #include <iostream>
    
    #include "mlx/mlx.h"
    
    namespace mx = mlx::core;
    
    int main() {
      auto x = mx::array({1, 2, 3});
      auto y = mx::array({1, 2, 3});
      std::cout << x + y << std::endl;
      return 0;
    }
    
The next step is to set up a CMake file in `CMakeLists.txt`:
    
    cmake_minimum_required(VERSION 3.27)
    
    project(example LANGUAGES CXX)
    
    set(CMAKE_CXX_STANDARD 17)
    set(CMAKE_CXX_STANDARD_REQUIRED ON)
    
Depending on how you installed MLX, you may need to tell CMake where to find it.
If you installed MLX with Python, then add the following to the CMake file:
    
    find_package(
      Python 3.9
      COMPONENTS Interpreter Development.Module
      REQUIRED)
    execute_process(
      COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
      OUTPUT_STRIP_TRAILING_WHITESPACE
      OUTPUT_VARIABLE MLX_ROOT)
    
If you installed the MLX C++ package to a system path, then CMake should be able to find it. If you installed it to a non-standard location or CMake can’t find MLX then set `MLX_ROOT` to the location where MLX is installed:
    
    set(MLX_ROOT "/path/to/mlx/")
    
Next, instruct CMake to find MLX:
    
    find_package(MLX CONFIG REQUIRED)
    
Finally, add the `example.cpp` program as an executable and link MLX.
    
    add_executable(example example.cpp)
    target_link_libraries(example PRIVATE mlx)
    
You can build the example with:
    
    cmake -B build -DCMAKE_BUILD_TYPE=Release
    cmake --build build
    
And run it with:
    
    ./build/example
    
Note that `find_package(MLX CONFIG REQUIRED)` sets the following variables:

Package Variables

| Variable | Description |
| --- | --- |
| `MLX_FOUND` | `True` if MLX is found |
| `MLX_INCLUDE_DIRS` | Include directory |
| `MLX_LIBRARIES` | Libraries to link against |
| `MLX_CXX_FLAGS` | Additional compiler flags |
| `MLX_BUILD_ACCELERATE` | `True` if MLX was built with Accelerate |
| `MLX_BUILD_METAL` | `True` if MLX was built with Metal |

# Linear Regression
Let’s implement a basic linear regression model as a starting point to learn MLX. First import the core package and set up some problem metadata:
    
    import mlx.core as mx
    
    num_features = 100
    num_examples = 1_000
    num_iters = 10_000  # iterations of SGD
    lr = 0.01  # learning rate for SGD
    
We’ll generate a synthetic dataset by:
  1. Sampling the design matrix `X`.
  2. Sampling a ground truth parameter vector `w_star`.
  3. Computing the dependent values `y` by adding Gaussian noise to `X @ w_star`.


    
    # True parameters
    w_star = mx.random.normal((num_features,))
    
    # Input examples (design matrix)
    X = mx.random.normal((num_examples, num_features))
    
    # Noisy labels
    eps = 1e-2 * mx.random.normal((num_examples,))
    y = X @ w_star + eps
    
We will use SGD to find the optimal weights. To start, define the squared loss and get the gradient function of the loss with respect to the parameters.
    
    def loss_fn(w):
        return 0.5 * mx.mean(mx.square(X @ w - y))
    
    grad_fn = mx.grad(loss_fn)
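For this particular loss the gradient also has a closed form: the average over examples of the residual `X @ w - y` times the corresponding row of `X`. The following pure-Python sketch, which uses plain lists and a tiny made-up dataset rather than MLX arrays, checks that closed form against a central finite difference. It is meant to illustrate what `mx.grad(loss_fn)` computes here, not how MLX computes it.

```python
def predict(X, w):
    return [sum(xij * wj for xij, wj in zip(row, w)) for row in X]

def loss_fn(w, X, y):
    residuals = [p - yi for p, yi in zip(predict(X, w), y)]
    return 0.5 * sum(r * r for r in residuals) / len(y)

def grad_fn(w, X, y):
    # Closed form: (1 / N) * X^T (X w - y)
    residuals = [p - yi for p, yi in zip(predict(X, w), y)]
    return [
        sum(r * row[j] for r, row in zip(residuals, X)) / len(y)
        for j in range(len(w))
    ]

def numerical_grad(w, X, y, h=1e-6):
    # Central finite difference, one coordinate at a time
    grads = []
    for j in range(len(w)):
        w_plus = w[:j] + [w[j] + h] + w[j + 1:]
        w_minus = w[:j] + [w[j] - h] + w[j + 1:]
        grads.append((loss_fn(w_plus, X, y) - loss_fn(w_minus, X, y)) / (2 * h))
    return grads

# Tiny made-up dataset, for illustration only
X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
y = [1.0, 2.0, 3.0]
w = [0.5, -0.5]
print(grad_fn(w, X, y))
print(numerical_grad(w, X, y))
```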
    
Start the optimization by initializing the parameters `w` randomly. Then repeatedly update the parameters for `num_iters` iterations.
    
    w = 1e-2 * mx.random.normal((num_features,))
    
    for _ in range(num_iters):
        grad = grad_fn(w)
        w = w - lr * grad
        mx.eval(w)
    
Finally, compute the loss of the learned parameters and verify that they are close to the ground truth parameters.
    
    loss = loss_fn(w)
    error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
    
    print(
        f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
    )
    # Should print something close to: Loss 0.00005, |w-w*| = 0.00364
    
Complete linear regression and logistic regression examples are available in the MLX GitHub repo.
# LLM inference
MLX enables efficient inference of large-ish transformers on Apple silicon without compromising on ease of use. In this example we will create an inference script for the Llama family of transformer models in which the model is defined in less than 200 lines of python.
## Implementing the model
We will use the neural network building blocks defined in the `mlx.nn` module to concisely define the model architecture.
### Attention layer
We will start with the Llama attention layer which notably uses the RoPE positional encoding. [1] In addition, our attention layer will optionally use a key/value cache that will be concatenated with the provided keys and values to support efficient inference.
Our implementation uses `mlx.nn.Linear` for all the projections and `mlx.nn.RoPE` for the positional encoding.
    
    import math
    
    import mlx.core as mx
    import mlx.nn as nn
    
    class LlamaAttention(nn.Module):
        def __init__(self, dims: int, num_heads: int):
            super().__init__()
    
            self.num_heads = num_heads
    
            self.rope = nn.RoPE(dims // num_heads, traditional=True)
            self.query_proj = nn.Linear(dims, dims, bias=False)
            self.key_proj = nn.Linear(dims, dims, bias=False)
            self.value_proj = nn.Linear(dims, dims, bias=False)
            self.out_proj = nn.Linear(dims, dims, bias=False)
    
        def __call__(self, queries, keys, values, mask=None, cache=None):
            queries = self.query_proj(queries)
            keys = self.key_proj(keys)
            values = self.value_proj(values)
    
            # Extract some shapes
            num_heads = self.num_heads
            B, L, D = queries.shape
    
            # Prepare the queries, keys and values for the attention computation
            queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
            keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
            values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
    
            # Add RoPE to the queries and keys and combine them with the cache
            if cache is not None:
                key_cache, value_cache = cache
                queries = self.rope(queries, offset=key_cache.shape[2])
                keys = self.rope(keys, offset=key_cache.shape[2])
                keys = mx.concatenate([key_cache, keys], axis=2)
                values = mx.concatenate([value_cache, values], axis=2)
            else:
                queries = self.rope(queries)
                keys = self.rope(keys)
    
            # Finally perform the attention computation
            scale = math.sqrt(1 / queries.shape[-1])
            scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
            if mask is not None:
                scores = scores + mask
            scores = mx.softmax(scores, axis=-1)
            values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
    
            # Note that we return the keys and values to possibly be used as a cache
            return self.out_proj(values_hat), (keys, values)
    
### Encoder layer
The other component of the Llama model is the encoder layer which uses RMS normalization [2] and SwiGLU. [3] For RMS normalization we will use `mlx.nn.RMSNorm` that is already provided in `mlx.nn`.
    
    class LlamaEncoderLayer(nn.Module):
        def __init__(self, dims: int, mlp_dims: int, num_heads: int):
            super().__init__()
    
            self.attention = LlamaAttention(dims, num_heads)
    
            self.norm1 = nn.RMSNorm(dims)
            self.norm2 = nn.RMSNorm(dims)
    
            self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
            self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
            self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
    
        def __call__(self, x, mask=None, cache=None):
            y = self.norm1(x)
            y, cache = self.attention(y, y, y, mask, cache)
            x = x + y
    
            y = self.norm2(x)
            a = self.linear1(y)
            b = self.linear2(y)
            y = a * mx.sigmoid(a) * b
            y = self.linear3(y)
            x = x + y
    
            return x, cache
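The line `y = a * mx.sigmoid(a) * b` is the gated part of the feed-forward block: the output of `linear1` is passed through SiLU, that is `a * sigmoid(a)`, and then multiplied elementwise by the output of `linear2`. The following is a scalar pure-Python sketch of that gate. The helper names are illustrative and the code does not use MLX.

```python
import math

def silu(a: float) -> float:
    # a * sigmoid(a), written as a / (1 + exp(-a)); fine for moderate |a|
    return a / (1.0 + math.exp(-a))

def swiglu_gate(a, b):
    """Elementwise silu(a) * b for two equally sized lists."""
    return [silu(ai) * bi for ai, bi in zip(a, b)]

a = [-2.0, 0.0, 2.0]   # stand-in for linear1(y)
b = [1.0, 5.0, 1.0]    # stand-in for linear2(y)
print(swiglu_gate(a, b))
```

A gate input of zero suppresses the corresponding value entirely, while large positive gate inputs pass it through almost unchanged.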
    
### Full model
To implement any Llama model we simply have to combine `LlamaEncoderLayer` instances with an `mlx.nn.Embedding` to embed the input tokens.
    
    class Llama(nn.Module):
        def __init__(
            self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
        ):
            super().__init__()
    
            self.embedding = nn.Embedding(vocab_size, dims)
            self.layers = [
                LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
            ]
            self.norm = nn.RMSNorm(dims)
            self.out_proj = nn.Linear(dims, vocab_size, bias=False)
    
        def __call__(self, x):
            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
            mask = mask.astype(self.embedding.weight.dtype)
    
            x = self.embedding(x)
            for l in self.layers:
                x, _ = l(x, mask)
            x = self.norm(x)
            return self.out_proj(x)
    
Note that in the implementation above we use a simple list to hold the encoder layers but using `model.parameters()` will still consider these layers.
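To see what the additive causal mask does to the attention scores, here is a small pure-Python sketch. It fills the masked positions with negative infinity for clarity. The fill value that `create_additive_causal_mask` actually uses is not something this sketch relies on, and the helper names are illustrative.

```python
import math

def additive_causal_mask(n: int):
    """n x n mask: 0 where key j <= query i, -inf where j > i."""
    return [[0.0 if j <= i else -math.inf for j in range(n)] for i in range(n)]

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

scores = [[1.0, 2.0, 3.0]] * 3  # same raw scores for every query position
mask = additive_causal_mask(3)
for i in range(3):
    masked = [s + m for s, m in zip(scores[i], mask[i])]
    # Query i only places weight on keys 0..i
    print([round(p, 3) for p in softmax(masked)])
```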
### Generation
Our `Llama` module can be used for training but not inference as the `__call__` method above processes one input, completely ignores the cache and performs no sampling whatsoever. In the rest of this subsection, we will implement the inference function as a python generator that processes the prompt and then autoregressively yields tokens one at a time.
    
    class Llama(nn.Module):
        ...
    
        def generate(self, x, temp=1.0):
            cache = []
    
            # Make an additive causal mask. We will need that to process the prompt.
            mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
            mask = mask.astype(self.embedding.weight.dtype)
    
            # First we process the prompt x the same way as in __call__ but
            # save the caches in cache
            x = self.embedding(x)
            for l in self.layers:
                x, c = l(x, mask=mask)
                cache.append(c)  # <--- we store the per layer cache in a
                                 #      simple python list
            x = self.norm(x)
            y = self.out_proj(x[:, -1])  # <--- we only care about the last logits
                                         #      that generate the next token
            y = mx.random.categorical(y * (1/temp))
    
            # y now has size [1]
            # Since MLX is lazily evaluated nothing is computed yet.
            # Calling y.item() would force the computation to happen at
            # this point but we can also choose not to do that and let the
            # user choose when to start the computation.
            yield y
    
            # Now we parsed the prompt and generated the first token we
            # need to feed it back into the model and loop to generate the
            # rest.
            while True:
                # Unsqueezing the last dimension to add a sequence length
                # dimension of 1
                x = y[:, None]
    
                x = self.embedding(x)
                for i in range(len(cache)):
                    # We are overwriting the arrays in the cache list. When
                    # the computation will happen, MLX will be discarding the
                    # old cache the moment it is not needed anymore.
                    x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
                x = self.norm(x)
                y = self.out_proj(x[:, -1])
                y = mx.random.categorical(y * (1/temp))
    
                yield y
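The sampling step `mx.random.categorical(y * (1/temp))` draws a token from the softmax of the temperature-scaled logits. The following pure-Python sketch shows how the temperature reshapes that distribution. It uses the standard `random` module in place of MLX, and `sample` is a hypothetical helper.

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample(logits, temp=1.0, rng=random):
    """Draw one index from softmax(logits / temp)."""
    probs = softmax([v / temp for v in logits])
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.0]
for temp in (2.0, 1.0, 0.5):
    # Lower temperatures concentrate probability on the largest logit
    print(temp, [round(p, 3) for p in softmax([v / temp for v in logits])])

print(sample(logits, temp=0.8, rng=random.Random(0)))
```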
    
### Putting it all together
We now have everything we need to create a Llama model and sample tokens from it. In the following code, we randomly initialize a small Llama model, process 6 tokens of prompt and generate 10 tokens.
    
    model = Llama(num_layers=12, vocab_size=8192, dims=512, mlp_dims=1024, num_heads=8)
    
    # Since MLX is lazily evaluated nothing has actually been materialized yet.
    # We could have set the `dims` to 20_000 on a machine with 8GB of RAM and the
    # code above would still run. Let's actually materialize the model.
    mx.eval(model.parameters())
    
    prompt = mx.array([[1, 10, 8, 32, 44, 7]])  # <-- Note the double brackets because we
                                                #     have a batch dimension even
                                                #     though it is 1 in this case
    
    generated = [t for i, t in zip(range(10), model.generate(prompt, 0.8))]
    
    # Since we haven't evaluated anything, nothing is computed yet. The list
    # `generated` contains the arrays that hold the computation graph for the
    # full processing of the prompt and the generation of 10 tokens.
    #
    # We can evaluate them one at a time, or all together. Concatenate them or
    # print them. They would all result in very similar runtimes and give exactly
    # the same results.
    mx.eval(generated)
    
## Converting the weights
This section assumes that you have access to the original Llama weights and the SentencePiece model that comes with them. We will write a small script to convert the PyTorch weights to MLX-compatible ones and write them to an NPZ file that can be loaded directly by MLX.
    
    import argparse
    from itertools import starmap
    
    import numpy as np
    import torch
    
    def map_torch_to_mlx(key, value):
        if "tok_embedding" in key:
            key = "embedding.weight"
    
        elif "norm" in key:
            key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
    
        elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
            key = key.replace("wq", "query_proj")
            key = key.replace("wk", "key_proj")
            key = key.replace("wv", "value_proj")
            key = key.replace("wo", "out_proj")
    
        elif "w1" in key or "w2" in key or "w3" in key:
            # The FFN is a separate submodule in PyTorch
            key = key.replace("feed_forward.w1", "linear1")
            key = key.replace("feed_forward.w3", "linear2")
            key = key.replace("feed_forward.w2", "linear3")
    
        elif "output" in key:
            key = key.replace("output", "out_proj")
    
        elif "rope" in key:
            return None, None
    
        return key, value.numpy()
    
    
    if __name__ == "__main__":
        parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
        parser.add_argument("torch_weights")
        parser.add_argument("output_file")
        args = parser.parse_args()
    
        state = torch.load(args.torch_weights)
        np.savez(
            args.output_file,
            **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
        )
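The renaming rules can be tried out without PyTorch by applying them to key strings alone. The sketch below repeats the string logic of `map_torch_to_mlx`. The sample keys are assumptions about what a Llama checkpoint typically contains, and `rename_key` is a hypothetical helper.

```python
def rename_key(key: str):
    """String-only version of the renaming rules in map_torch_to_mlx."""
    if "tok_embedding" in key:
        return "embedding.weight"
    if "norm" in key:
        return key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
    if any(name in key for name in ("wq", "wk", "wv", "wo")):
        for old, new in (("wq", "query_proj"), ("wk", "key_proj"),
                         ("wv", "value_proj"), ("wo", "out_proj")):
            key = key.replace(old, new)
        return key
    if any(name in key for name in ("w1", "w2", "w3")):
        # The FFN submodule prefix is dropped, matching our LlamaEncoderLayer
        for old, new in (("feed_forward.w1", "linear1"),
                         ("feed_forward.w3", "linear2"),
                         ("feed_forward.w2", "linear3")):
            key = key.replace(old, new)
        return key
    if "output" in key:
        return key.replace("output", "out_proj")
    if "rope" in key:
        return None  # dropped during conversion
    return key

# Assumed example keys, for illustration only
for k in ("tok_embeddings.weight", "layers.0.attention.wq.weight",
          "layers.0.feed_forward.w3.weight", "layers.0.ffn_norm.weight",
          "output.weight", "rope.freqs"):
    print(k, "->", rename_key(k))
```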
    
## Weight loading and benchmarking
After converting the weights to be compatible with our implementation, all that is left is to load them from disk, and we can finally use the LLM to generate text. We can load NumPy format files using the `mlx.core.load()` operation.
To create a parameter dictionary from the key/value representation of NPZ files we will use the `mlx.utils.tree_unflatten()` helper method as follows:
    
    from mlx.utils import tree_unflatten
    
    model.update(tree_unflatten(list(mx.load(weight_file).items())))
    
`mlx.utils.tree_unflatten()` will take keys from the NPZ file that look like `layers.2.attention.query_proj.weight` and will transform them to
    
    {"layers": [..., ..., {"attention": {"query_proj": {"weight": ...}}}]}
    
which can then be used to update the model. Note that the method above incurs several unnecessary copies from disk to numpy and then from numpy to MLX. It will be replaced in the future with direct loading to MLX.
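To make the key transformation concrete, here is a simplified pure-Python sketch of the idea: dotted keys are split into a nested dictionary, and levels whose keys are all integers become lists. This is an illustration under those assumptions, not the implementation of `mlx.utils.tree_unflatten()`, and the `unflatten` name is hypothetical.

```python
def unflatten(items):
    """Turn [("a.0.b", v), ...] into nested dicts and lists (simplified sketch)."""
    tree = {}
    for key, value in items:
        node = tree
        *parents, leaf = key.split(".")
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return _listify(tree)

def _listify(node):
    if not isinstance(node, dict):
        return node
    children = {k: _listify(v) for k, v in node.items()}
    if children and all(k.isdigit() for k in children):
        # Integer keys become list positions; gaps are padded with empty dicts
        out = [{} for _ in range(max(int(k) for k in children) + 1)]
        for k, v in children.items():
            out[int(k)] = v
        return out
    return children

print(unflatten([("layers.2.attention.query_proj.weight", 1.0)]))
# {'layers': [{}, {}, {'attention': {'query_proj': {'weight': 1.0}}}]}
```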
You can download the full example code in mlx-examples. Assuming the existence of `weights.pth` and `tokenizer.model` in the current working directory, we can play around with our inference script as follows (the timings are representative of an M1 Ultra and the 7B parameter Llama model):
    
    $ python convert.py weights.pth llama-7B.mlx.npz
    $ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely'
    [INFO] Loading model from disk: 5.247 s
    Press enter to start generation
    ------
    , having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down,
    ------
    [INFO] Prompt processing: 0.437 s
    [INFO] Full generation: 4.330 s
    
We observe that 4.3 seconds are required to generate 100 tokens and 0.4 seconds of those are spent processing the prompt. This amounts to roughly 39 ms per token.
By running with a much bigger prompt we can see that the per-token generation time as well as the prompt processing time remain almost constant.
    
    $ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
    [INFO] Loading model from disk: 5.247 s
    Press enter to start generation
    ------
    take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not
    ------
    [INFO] Prompt processing: 0.579 s
    [INFO] Full generation: 4.690 s
    $ python llama.py --num-tokens 500 llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not'
    [INFO] Loading model from disk: 5.628 s
    Press enter to start generation
    ------
    take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not reply, but still went on looking at the ground, and took hold of his bundle with a nervous trembling. I waited some time, and then resumed. “It is of no use to say you would not understand, if I were to tell you,” said he. “I have not told you why I am waiting for him,” said I. “And I am sure I should not understand,” replied he. “I will tell you then,” said I, “and, perhaps, you would not be surprised.” “No matter,” said he, “I shall be surprised anyhow; so tell me why you are waiting for him.” “He is my friend,” said I. “Yes,” said he, with a slight smile, “I know.” “He has been kind to me,” said I, “and I am waiting for him. I want to see him, and could have waited as I am now, for a much longer time.” “He will not soon come,” said he. “Unless he sees you here, he will not know of your having waited, and he will be very unlikely to come.” “No matter,” said I, “I shall wait for him.” “This is a strange thing,” said he, still with the same amused smile. “How did you know,” said I, “that he was coming? How should you be waiting?” “That is my secret,” said he. “And you expect him?” “Yes,” said I. “Are you disappointed then, if he does not come?” “No,” said I, “it is his secret, not mine.” “If he comes,” said he, “do you mean to go straight away?” “Yes,” said I, “I cannot be happy if I do not go straight away after him.” “Did you know this place before?” asked he. “Yes,” said I. “Is there any shop to buy food here?” “
    ------
    [INFO] Prompt processing: 0.633 s
    [INFO] Full generation: 21.475 s
    
## Scripts
Download the code
The full example code is available in mlx-examples.
[1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021. RoFormer: Enhanced transformer with rotary position embedding. arXiv preprint arXiv:2104.09864.
[2] Zhang, B. and Sennrich, R., 2019. Root mean square layer normalization. Advances in Neural Information Processing Systems, 32.
[3] Shazeer, N., 2020. GLU variants improve transformer. arXiv preprint arXiv:2002.05202.
# Multi-Layer Perceptron
In this example we’ll learn to use `mlx.nn` by implementing a simple multi-layer perceptron to classify MNIST.
As a first step import the MLX packages we need:
    
    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim
    
    import numpy as np
    
The model is defined as the `MLP` class which inherits from `mlx.nn.Module`. We follow the standard idiom to make a new module:
  1. Define an `__init__` where the parameters and/or submodules are set up. See the Module class docs for more information on how `mlx.nn.Module` registers parameters.
  2. Define a `__call__` where the computation is implemented.


    
    class MLP(nn.Module):
        def __init__(
            self, num_layers: int, input_dim: int, hidden_dim: int, output_dim: int
        ):
            super().__init__()
            layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
            self.layers = [
                nn.Linear(idim, odim)
                for idim, odim in zip(layer_sizes[:-1], layer_sizes[1:])
            ]
    
        def __call__(self, x):
            for l in self.layers[:-1]:
                x = mx.maximum(l(x), 0.0)
            return self.layers[-1](x)
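
To see which `nn.Linear` shapes the list comprehension in `__init__` produces, the pairing of consecutive entries in `layer_sizes` can be traced in plain Python. This is a minimal sketch that does not use MLX; the helper name `linear_shapes` and the example `input_dim` of 784 (a flattened 28x28 image) are assumptions made for illustration.

```python
def linear_shapes(num_layers, input_dim, hidden_dim, output_dim):
    """Return the (in_dim, out_dim) pair of every linear layer in the MLP."""
    layer_sizes = [input_dim] + [hidden_dim] * num_layers + [output_dim]
    # Consecutive entries become the input and output size of one layer.
    return list(zip(layer_sizes[:-1], layer_sizes[1:]))


shapes = linear_shapes(num_layers=2, input_dim=784, hidden_dim=32, output_dim=10)
print(shapes)  # [(784, 32), (32, 32), (32, 10)]
```

Note that `num_layers` hidden layers give `num_layers + 1` linear layers, and only the last one skips the ReLU in `__call__`.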
    
We define the loss function which takes the mean of the per-example cross entropy loss. The `mlx.nn.losses` sub-package has implementations of some commonly used loss functions.
    
    def loss_fn(model, X, y):
        return mx.mean(nn.losses.cross_entropy(model(X), y))
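
As a rough illustration of what `loss_fn` computes, the sketch below evaluates the per-example cross entropy and its mean in plain Python. It assumes the model outputs unnormalized logits and the targets are integer class indices, as in the call above; the helper names `cross_entropy` and `mean_loss` are hypothetical and this is not the MLX implementation.

```python
import math


def cross_entropy(logits, target):
    """Cross entropy of one example: -log(softmax(logits)[target])."""
    m = max(logits)  # subtract the max before exponentiating for stability
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]


def mean_loss(batch_logits, targets):
    """Mean of the per-example losses, mirroring mx.mean(...) above."""
    losses = [cross_entropy(row, t) for row, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


batch_logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
targets = [0, 2]
loss = mean_loss(batch_logits, targets)
print(loss)
```

The second example has uniform logits over three classes, so its loss is `log(3)` regardless of the target.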
    
We also need a function to compute the accuracy of the model on the validation set:
    
    def eval_fn(model, X, y):
        return mx.mean(mx.argmax(model(X), axis=1) == y)
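
The accuracy computation can be sketched the same way in plain Python: take the index of the largest logit in each row, compare it with the label, and average the matches. The helper name `accuracy` and the toy data are assumptions for illustration only.

```python
def accuracy(batch_logits, labels):
    """Fraction of rows whose arg-max logit equals the label."""
    predictions = [
        max(range(len(row)), key=row.__getitem__) for row in batch_logits
    ]
    correct = [p == y for p, y in zip(predictions, labels)]
    # Averaging booleans counts True as 1 and False as 0.
    return sum(correct) / len(correct)


acc = accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0])
print(acc)  # two of the three predictions match
```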
    
Next, set up the problem parameters and load the data. To load the data, you need our mnist data loader, which we will import as `mnist`.
    
    num_layers = 2
    hidden_dim = 32
    num_classes = 10
    batch_size = 256
    num_epochs = 10
    learning_rate = 1e-1
    
    # Load the data
    import mnist
    train_images, train_labels, test_images, test_labels = map(
        mx.array, mnist.mnist()
    )
    
Since we’re using SGD, we need an iterator which shuffles and constructs minibatches of examples in the training set:
    
    def batch_iterate(batch_size, X, y):
        perm = mx.array(np.random.permutation(y.size))
        for s in range(0, y.size, batch_size):
            ids = perm[s : s + batch_size]
            yield X[ids], y[ids]
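
The behavior of the iterator can be checked without MLX or NumPy. The sketch below is a plain-Python stand-in that shuffles indices once and yields aligned slices; the `seed` parameter and the toy data are assumptions added so the run is reproducible.

```python
import random


def batch_iterate(batch_size, X, y, seed=None):
    """Yield shuffled (inputs, labels) minibatches covering the data once."""
    rng = random.Random(seed)
    perm = list(range(len(y)))
    rng.shuffle(perm)
    for s in range(0, len(y), batch_size):
        ids = perm[s : s + batch_size]
        # Index inputs and labels with the same ids so they stay aligned.
        yield [X[i] for i in ids], [y[i] for i in ids]


X = [[i, i] for i in range(10)]
y = list(range(10))
batches = list(batch_iterate(4, X, y, seed=0))
sizes = [len(labels) for _, labels in batches]
print(sizes)  # [4, 4, 2]
```

Every example appears exactly once per pass, and the final batch is smaller when the batch size does not divide the dataset size.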
    
Finally, we put it all together by instantiating the model, the `mlx.optimizers.SGD` optimizer, and running the training loop:
    
    # Instantiate the model
    model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
    mx.eval(model.parameters())
    
    # Get a function which gives the loss and gradient of the
    # loss with respect to the model's trainable parameters
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    
    # Instantiate the optimizer
    optimizer = optim.SGD(learning_rate=learning_rate)
    
    for e in range(num_epochs):
        for X, y in batch_iterate(batch_size, train_images, train_labels):
            loss, grads = loss_and_grad_fn(model, X, y)
    
            # Update the optimizer state and model parameters
            # in a single call
            optimizer.update(model, grads)
    
            # Force a graph evaluation
            mx.eval(model.parameters(), optimizer.state)
    
        accuracy = eval_fn(model, test_images, test_labels)
        print(f"Epoch {e}: Test accuracy {accuracy.item():.3f}")
    
Note
The `mlx.nn.value_and_grad()` function is a convenience function to get the gradient of a loss with respect to the trainable parameters of a model. This should not be confused with `mlx.core.value_and_grad()`.
The model should train to a decent accuracy (about 95%) after just a few passes over the training set. The full example is available in the MLX GitHub repo.
# mlx.core.Device
class Device
    
A device to run operations on.
__init__(self, type: mlx.core.DeviceType, index: int = 0) → None
    
Methods
`__init__`(self, type[, index])  
Attributes
`type`
(self) -> mlx.core.DeviceType  
# mlx.core.Dtype
class Dtype
    
An object to hold the type of an `array`.
See the list of types for more details on available data types.
__init__(*args, **kwargs)
    
Methods
`__init__`(*args, **kwargs)  
Attributes
`size`
Size of the type in bytes.  
# mlx.core.DtypeCategory
class DtypeCategory(value)
    
Type to hold categories of `dtypes`.
  * `generic`
    * bool_
    * `number`
      * `integer`
        * `unsignedinteger`
          * uint8
          * uint16
          * uint32
          * uint64
        * `signedinteger`
          * int8
          * int16
          * int32
          * int64
      * `inexact`
        * `floating`
          * float16
          * bfloat16
          * float32
          * float64
        * `complexfloating`
          * complex64


See also `issubdtype()`.
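
One way to read the tree above is as a parent relation: a type belongs to a category if the category is one of its ancestors. The sketch below models that with strings in plain Python. The names `PARENT` and `is_sub` are hypothetical, only a few leaf types are listed, and the real `issubdtype()` operates on `Dtype` and `DtypeCategory` objects rather than strings.

```python
# Child -> parent edges copied from the hierarchy above (subset of leaves).
PARENT = {
    "bool_": "generic",
    "number": "generic",
    "integer": "number",
    "inexact": "number",
    "unsignedinteger": "integer",
    "signedinteger": "integer",
    "floating": "inexact",
    "complexfloating": "inexact",
    "uint8": "unsignedinteger",
    "int32": "signedinteger",
    "float16": "floating",
    "float32": "floating",
    "complex64": "complexfloating",
}


def is_sub(name, category):
    """Return True if `name` equals `category` or descends from it."""
    while name is not None:
        if name == category:
            return True
        name = PARENT.get(name)  # None once we walk past "generic"
    return False


print(is_sub("float32", "inexact"))  # True
print(is_sub("bool_", "number"))     # False
```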
__init__()
    
Attributes
`complexfloating`  
`floating`  
`inexact`  
`signedinteger`  
`unsignedinteger`  
`integer`  
`number`  
`generic`  
# mlx.core.abs
abs(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise absolute value.
Parameters:
    
a (array) – Input array.
Returns:
    
The absolute value of `a`.
Return type:
    
array
# mlx.core.add
add(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise addition.
Add two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The sum of `a` and `b`.
Return type:
    
array
# mlx.core.addmm
addmm(c: array, a: array, b: array, /, alpha: float = 1.0, beta: float = 1.0, *, stream: None | Stream | Device = None) → array
    
Matrix multiplication with addition and optional scaling.
Perform the (possibly batched) matrix multiplication of two arrays and add to the result with optional scaling factors.
Parameters:
    
  * c (array) – Input array or scalar.
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.
  * alpha (float, optional) – Scaling factor for the matrix product of `a` and `b` (default: `1`)
  * beta (float, optional) – Scaling factor for `c` (default: `1`)


Returns:
    
`alpha * (a @ b) + beta * c`
Return type:
    
array
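
The returned expression `alpha * (a @ b) + beta * c` can be spelled out for small 2-D inputs in plain Python. This is only a sketch of the arithmetic: the helper names `matmul` and `addmm_sketch` are hypothetical, and broadcasting of `c` and batched inputs are not modeled.

```python
def matmul(a, b):
    """Matrix product of two 2-D nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def addmm_sketch(c, a, b, alpha=1.0, beta=1.0):
    """Compute alpha * (a @ b) + beta * c element by element."""
    ab = matmul(a, b)
    return [
        [alpha * ab[i][j] + beta * c[i][j] for j in range(len(ab[0]))]
        for i in range(len(ab))
    ]


a = [[1.0, 2.0], [3.0, 4.0]]
b = [[5.0, 6.0], [7.0, 8.0]]
c = [[1.0, 1.0], [1.0, 1.0]]
result = addmm_sketch(c, a, b, alpha=0.5, beta=2.0)
print(result)  # [[11.5, 13.0], [23.5, 27.0]]
```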
# mlx.core.all
all(a: array, /, axis: None | int | Sequence[int] = None, keepdims: bool = False, *, stream: None | Stream | Device = None) → array
    
A logical AND reduction over the given axes.
Parameters:
    
  * a (array) – Input array.
  * axis (int or list(int), optional) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array.
  * keepdims (bool, optional) – Keep reduced axes as singleton dimensions, defaults to False.


Returns:
    
The output array with the corresponding axes reduced.
Return type:
    
array
# mlx.core.allclose
allclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: None | Stream | Device = None) → array
    
Approximate comparison of two arrays.
Infinite values are considered equal if they have the same sign, NaN values are not equal unless `equal_nan` is `True`.
The arrays are considered equal if:
    
    all(abs(a - b) <= (atol + rtol * abs(b)))
    
Note that, unlike `array_equal()`, this function supports numpy-style broadcasting.
Parameters:
    
  * a (array) – Input array.
  * b (array) – Input array.
  * rtol (float) – Relative tolerance.
  * atol (float) – Absolute tolerance.
  * equal_nan (bool) – If `True`, NaNs are considered equal. Defaults to `False`.


Returns:
    
The boolean output scalar indicating if the arrays are close.
Return type:
    
array
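
The comparison rule above can be sketched for flat lists of Python floats. The helper name `allclose_sketch` is hypothetical, broadcasting is not modeled, and the handling of infinities and NaNs follows the description above rather than the MLX source.

```python
import math


def allclose_sketch(a, b, rtol=1e-5, atol=1e-8, equal_nan=False):
    """Element-wise |a - b| <= atol + rtol * |b| over two equal-length lists."""

    def close(x, y):
        if math.isnan(x) or math.isnan(y):
            # NaNs only match each other, and only when equal_nan is set.
            return equal_nan and math.isnan(x) and math.isnan(y)
        if math.isinf(x) or math.isinf(y):
            # Infinities match only if both are infinite with the same sign.
            return x == y
        return abs(x - y) <= atol + rtol * abs(y)

    return all(close(x, y) for x, y in zip(a, b))


nan = float("nan")
inf = float("inf")
print(allclose_sketch([1.0, inf], [1.0 + 1e-9, inf]))   # True
print(allclose_sketch([nan], [nan]))                     # False
print(allclose_sketch([nan], [nan], equal_nan=True))     # True
```

Because the relative tolerance is scaled by `abs(b)` only, the test is not symmetric: with `rtol=0.1` and `atol=0.0`, comparing `1.0` against `1.1` passes while comparing `1.1` against `1.0` does not.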
# mlx.core.any
any(a: array, /, axis: None | int | Sequence[int] = None, keepdims: bool = False, *, stream: None | Stream | Device = None) → array
    
A logical OR reduction over the given axes.
Parameters:
    
  * a (array) – Input array.
  * axis (int or list(int), optional) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array.
  * keepdims (bool, optional) – Keep reduced axes as singleton dimensions, defaults to False.


Returns:
    
The output array with the corresponding axes reduced.
Return type:
    
array
# mlx.core.arange
arange(start: int | float, stop: int | float, step: None | int | float, dtype: Dtype | None = None, *, stream: None | Stream | Device = None) → array
arange(stop: int | float, step: None | int | float = None, dtype: Dtype | None = None, *, stream: None | Stream | Device = None) → array
    
Overloaded function.
Generates ranges of numbers.
Generate numbers in the half-open interval `[start, stop)` in increments of `step`.
Parameters:
    
  * start (float or int, optional) – Starting value which defaults to `0`.
  * stop (float or int) – Stopping value.
  * step (float or int, optional) – Increment which defaults to `1`.
  * dtype (Dtype, optional) – Specifies the data type of the output. If unspecified will default to `float32` if any of `start`, `stop`, or `step` are `float`. Otherwise will default to `int32`.


Returns:
    
The range of values.
Return type:
    
array
Note
Following the Numpy convention the actual increment used to generate numbers is `dtype(start + step) - dtype(start)`. This can lead to unexpected results, for example if `start + step` is a fractional value and the `dtype` is integral.
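
The increment rule `dtype(start + step) - dtype(start)` is easiest to see with a small plain-Python model. The sketch below is not MLX code: the helper name `arange_sketch` is hypothetical, Python's `int()` truncation stands in for a cast to an integral dtype, and the output length is assumed to be `ceil((stop - start) / step)` as in NumPy.

```python
import math


def arange_sketch(start, stop, step, cast):
    """Model a range whose increment is cast(start + step) - cast(start)."""
    n = max(0, math.ceil((stop - start) / step))
    first = cast(start)
    delta = cast(start + step) - cast(start)  # increment after casting
    return [first + i * delta for i in range(n)]


print(arange_sketch(0, 2, 0.5, float))  # [0.0, 0.5, 1.0, 1.5]
print(arange_sketch(0, 2, 0.5, int))    # [0, 0, 0, 0]
print(arange_sketch(1, 7, 2.5, int))    # [1, 3, 5]
```

In the integral cases the fractional step is lost when the increment is computed, so the values are not simply the truncated floating-point range.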


# mlx.core.arccos
arccos(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise inverse cosine.
Parameters:
    
a (array) – Input array.
Returns:
    
The inverse cosine of `a`.
Return type:
    
array
# mlx.core.arccosh
arccosh(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise inverse hyperbolic cosine.
Parameters:
    
a (array) – Input array.
Returns:
    
The inverse hyperbolic cosine of `a`.
Return type:
    
array
# mlx.core.arcsin
arcsin(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise inverse sine.
Parameters:
    
a (array) – Input array.
Returns:
    
The inverse sine of `a`.
Return type:
    
array
# mlx.core.arcsinh
arcsinh(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise inverse hyperbolic sine.
Parameters:
    
a (array) – Input array.
Returns:
    
The inverse hyperbolic sine of `a`.
Return type:
    
array
# mlx.core.arctan
arctan(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise inverse tangent.
Parameters:
    
a (array) – Input array.
Returns:
    
The inverse tangent of `a`.
Return type:
    
array
# mlx.core.arctan2
arctan2(a: array, b: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise inverse tangent of the ratio of two arrays.
Parameters:
    
  * a (array) – Input array.
  * b (array) – Input array.


Returns:
    
The inverse tangent of the ratio of `a` and `b`.
Return type:
    
array
# mlx.core.arctanh
arctanh(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise inverse hyperbolic tangent.
Parameters:
    
a (array) – Input array.
Returns:
    
The inverse hyperbolic tangent of `a`.
Return type:
    
array
# mlx.core.argmax
argmax(a: array, /, axis: None | int = None, keepdims: bool = False, *, stream: None | Stream | Device = None) → array
    
Indices of the maximum values along the axis.
Parameters:
    
  * a (array) – Input array.
  * axis (int, optional) – Optional axis to reduce over. If unspecified this defaults to reducing over the entire array.
  * keepdims (bool, optional) – Keep reduced axes as singleton dimensions, defaults to False.


Returns:
    
The `uint32` array with the indices of the maximum values.
Return type:
    
array
# mlx.core.argmin
argmin(a: array, /, axis: None | int = None, keepdims: bool = False, *, stream: None | Stream | Device = None) → array
    
Indices of the minimum values along the axis.
Parameters:
    
  * a (array) – Input array.
  * axis (int, optional) – Optional axis to reduce over. If unspecified this defaults to reducing over the entire array.
  * keepdims (bool, optional) – Keep reduced axes as singleton dimensions, defaults to False.


Returns:
    
The `uint32` array with the indices of the minimum values.
Return type:
    
array
# mlx.core.argpartition
argpartition(a: array, /, kth: int, axis: None | int = -1, *, stream: None | Stream | Device = None) → array
    
Returns the indices that partition the array.
The ordering of the elements within a partition given by the indices is undefined.
Parameters:
    
  * a (array) – Input array.
  * kth (int) – Element index at the `kth` position in the output will give the sorted position. All indices before the `kth` position will be of elements less than or equal to the element at the `kth` index and all indices after will be of elements greater than or equal to the element at the `kth` index.
  * axis (int or None, optional) – Optional axis to partition over. If `None`, this partitions over the flattened array. If unspecified, it defaults to `-1`.


Returns:
    
The `uint32` array containing indices that partition the input.
Return type:
    
array
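
Since the order within each partition is undefined, several index arrays can be valid results for the same input. The plain-Python sketch below checks the property described for `kth`; the helper name `is_valid_argpartition` and the sample data are assumptions, and no claim is made about which valid answer MLX actually returns.

```python
def is_valid_argpartition(a, indices, kth):
    """Check the partition property around position `kth`."""
    pivot = a[indices[kth]]
    before_ok = all(a[i] <= pivot for i in indices[:kth])
    after_ok = all(a[i] >= pivot for i in indices[kth + 1:])
    is_permutation = sorted(indices) == list(range(len(a)))
    return before_ok and after_ok and is_permutation


a = [7, 1, 5, 3, 9, 2]
kth = 2

# A full argsort is always one valid partition.
full_sort = sorted(range(len(a)), key=a.__getitem__)
# A different ordering inside each side is equally valid.
other = [5, 1, 3, 4, 0, 2]

print(full_sort)                                  # [1, 5, 3, 2, 0, 4]
print(is_valid_argpartition(a, full_sort, kth))   # True
print(is_valid_argpartition(a, other, kth))       # True
```

In both cases `a[indices[kth]]` is `3`, the value that would sit at position `kth` in the sorted array.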
# mlx.core.argsort
argsort(a: array, /, axis: None | int = -1, *, stream: None | Stream | Device = None) → array
    
Returns the indices that sort the array.
Parameters:
    
  * a (array) – Input array.
  * axis (int or None, optional) – Optional axis to sort over. If `None`, this sorts over the flattened array. If unspecified, it defaults to -1 (sorting over the last axis).


Returns:
    
The `uint32` array containing indices that sort the input.
Return type:
    
array
# mlx.core.array.T
property array.T
    
Equivalent to calling `self.transpose()` with no arguments.
# mlx.core.array.abs
array.abs(self, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `abs()`.
# mlx.core.array.all
array.all(self, axis: Optional[Union[int, Sequence[int]]] = None, keepdims: bool = False, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `all()`.
# mlx.core.array.any
array.any(self, axis: Optional[Union[int, Sequence[int]]] = None, keepdims: bool = False, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `any()`.
# mlx.core.array.argmax
array.argmax(self, axis: Optional[int] = None, keepdims: bool = False, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `argmax()`.
# mlx.core.array.argmin
array.argmin(self, axis: Optional[int] = None, keepdims: bool = False, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `argmin()`.
# mlx.core.array.astype
array.astype(self, dtype: Dtype, stream: Optional[Union[Stream, Device]] = None) → array
    
Cast the array to a specified type.
Parameters:
    
  * dtype (Dtype) – Type to which the array is cast.
  * stream (Stream) – Stream (or device) for the operation.


Returns:
    
The array with type `dtype`.
Return type:
    
array
# mlx.core.array.at
property array.at
    
Used to apply updates at the given indices.
Note
Regular in-place updates map to assignment. For instance `x[idx] += y` maps to `x[idx] = x[idx] + y`. As a result, assigning to the same index ignores all but one update. Using `x.at[idx].add(y)` will correctly apply all updates to all indices.
| array.at syntax | In-place syntax |
| --- | --- |
| `x = x.at[idx].add(y)` | `x[idx] += y` |
| `x = x.at[idx].subtract(y)` | `x[idx] -= y` |
| `x = x.at[idx].multiply(y)` | `x[idx] *= y` |
| `x = x.at[idx].divide(y)` | `x[idx] /= y` |
| `x = x.at[idx].maximum(y)` | `x[idx] = mx.maximum(x[idx], y)` |
| `x = x.at[idx].minimum(y)` | `x[idx] = mx.minimum(x[idx], y)` |
Example
    
    >>> a = mx.array([0, 0])
    >>> idx = mx.array([0, 1, 0, 1])
    >>> a[idx] += 1
    >>> a
    array([1, 1], dtype=int32)
    >>>
    >>> a = mx.array([0, 0])
    >>> a.at[idx].add(1)
    array([2, 2], dtype=int32)
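
The difference between the two results can be modeled in plain Python. The sketch below assumes the in-place form reads all values before writing any of them, which is what `x[idx] = x[idx] + y` implies; the helper names `inplace_add` and `at_add` are hypothetical and operate on lists rather than MLX arrays.

```python
def inplace_add(x, idx, y):
    """Model of `x[idx] += y`, i.e. `x[idx] = x[idx] + y`."""
    gathered = [x[i] + y for i in idx]  # read every value first
    out = list(x)
    for i, v in zip(idx, gathered):     # then write; duplicates overwrite
        out[i] = v
    return out


def at_add(x, idx, y):
    """Model of `x.at[idx].add(y)`: every update is accumulated."""
    out = list(x)
    for i in idx:
        out[i] += y
    return out


idx = [0, 1, 0, 1]
print(inplace_add([0, 0], idx, 1))  # [1, 1]
print(at_add([0, 0], idx, 1))       # [2, 2]
```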
    
# mlx.core.array.conj
array.conj(self, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `conj()`.
# mlx.core.array.cos
array.cos(self, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `cos()`.
# mlx.core.array.cummax
array.cummax(self, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Optional[Union[Stream, Device]] = None) → array
    
See `cummax()`.
# mlx.core.array.cummin
array.cummin(self, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Optional[Union[Stream, Device]] = None) → array
    
See `cummin()`.
# mlx.core.array.cumprod
array.cumprod(self, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Optional[Union[Stream, Device]] = None) → array
    
See `cumprod()`.
# mlx.core.array.cumsum
array.cumsum(self, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Optional[Union[Stream, Device]] = None) → array
    
See `cumsum()`.
# mlx.core.array.diag
array.diag(self, k: int = 0, *, stream: Optional[Union[Stream, Device]] = None) → array
    
Extract a diagonal or construct a diagonal matrix.
# mlx.core.array.diagonal
array.diagonal(self, offset: int = 0, axis1: int = 0, axis2: int = 1, stream: Optional[Union[Stream, Device]] = None) → array
    
See `diagonal()`.
# mlx.core.array.dtype
property array.dtype
    
The array’s `Dtype`.
# mlx.core.array.exp
array.exp(self, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `exp()`.
# mlx.core.array.flatten
array.flatten(self, start_axis: int = 0, end_axis: int = -1, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `flatten()`.
# mlx.core.array
class array
    
An N-dimensional array object.
__init__(self: array, val: scalar | list | tuple | ndarray | array, dtype: Dtype | None = None)
    
Methods
`__init__`(self, val[, dtype])  
`abs`(self, *[, stream])
See `abs()`.  
`all`(self[, axis, keepdims, stream])
See `all()`.  
`any`(self[, axis, keepdims, stream])
See `any()`.  
`argmax`(self[, axis, keepdims, stream])
See `argmax()`.  
`argmin`(self[, axis, keepdims, stream])
See `argmin()`.  
`astype`(self, dtype[, stream])
Cast the array to a specified type.  
`conj`(self, *[, stream])
See `conj()`.  
`cos`(self, *[, stream])
See `cos()`.  
`cummax`(self[, axis, reverse, inclusive, stream])
See `cummax()`.  
`cummin`(self[, axis, reverse, inclusive, stream])
See `cummin()`.  
`cumprod`(self[, axis, reverse, inclusive, stream])
See `cumprod()`.  
`cumsum`(self[, axis, reverse, inclusive, stream])
See `cumsum()`.  
`diag`(self[, k, stream])
Extract a diagonal or construct a diagonal matrix.  
`diagonal`(self[, offset, axis1, axis2, stream])
See `diagonal()`.  
`exp`(self, *[, stream])
See `exp()`.  
`flatten`(self[, start_axis, end_axis, stream])
See `flatten()`.  
`item`(self)
Access the value of a scalar array.  
`log`(self, *[, stream])
See `log()`.  
`log10`(self, *[, stream])
See `log10()`.  
`log1p`(self, *[, stream])
See `log1p()`.  
`log2`(self, *[, stream])
See `log2()`.  
`logcumsumexp`(self[, axis, reverse, ...])
See `logcumsumexp()`.  
`logsumexp`(self[, axis, keepdims, stream])
See `logsumexp()`.  
`max`(self[, axis, keepdims, stream])
See `max()`.  
`mean`(self[, axis, keepdims, stream])
See `mean()`.  
`min`(self[, axis, keepdims, stream])
See `min()`.  
`moveaxis`(self, source, destination, *[, stream])
See `moveaxis()`.  
`prod`(self[, axis, keepdims, stream])
See `prod()`.  
`reciprocal`(self, *[, stream])
See `reciprocal()`.  
`reshape`(self, *shape[, stream])
Equivalent to `reshape()` but the shape can be passed either as a `tuple` or as separate arguments.  
`round`(self[, decimals, stream])
See `round()`.  
`rsqrt`(self, *[, stream])
See `rsqrt()`.  
`sin`(self, *[, stream])
See `sin()`.  
`split`(self, indices_or_sections[, axis, stream])
See `split()`.  
`sqrt`(self, *[, stream])
See `sqrt()`.  
`square`(self, *[, stream])
See `square()`.  
`squeeze`(self[, axis, stream])
See `squeeze()`.  
`std`(self[, axis, keepdims, ddof, stream])
See `std()`.  
`sum`(self[, axis, keepdims, stream])
See `sum()`.  
`swapaxes`(self, axis1, axis2, *[, stream])
See `swapaxes()`.  
`tolist`(self)
Convert the array to a Python `list`.  
`transpose`(self, *axes[, stream])
Equivalent to `transpose()` but the axes can be passed either as a tuple or as separate arguments.  
`var`(self[, axis, keepdims, ddof, stream])
See `var()`.  
`view`(self, dtype, *[, stream])
See `view()`.  
Attributes
`T`
Equivalent to calling `self.transpose()` with no arguments.  
`at`
Used to apply updates at the given indices.  
`dtype`
The array's `Dtype`.  
`imag`
The imaginary part of a complex array.  
`itemsize`
The size of the array's datatype in bytes.  
`nbytes`
The number of bytes in the array.  
`ndim`
The number of dimensions of the array.  
`real`
The real part of a complex array.  
`shape`
The shape of the array as a Python tuple.  
`size`
Number of elements in the array.  
# mlx.core.array.imag
property array.imag
    
The imaginary part of a complex array.
# mlx.core.array.item
array.item(self) → scalar
    
Access the value of a scalar array.
Returns:
    
Standard Python scalar.
# mlx.core.array.itemsize
property array.itemsize
    
The size of the array’s datatype in bytes.
# mlx.core.array.log
array.log(self, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `log()`.
# mlx.core.array.log10
array.log10(self, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `log10()`.
# mlx.core.array.log1p
array.log1p(self, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `log1p()`.
# mlx.core.array.log2
array.log2(self, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `log2()`.
# mlx.core.array.logcumsumexp
array.logcumsumexp(self, axis: Optional[int] = None, *, reverse: bool = False, inclusive: bool = True, stream: Optional[Union[Stream, Device]] = None) → array
    
See `logcumsumexp()`.
# mlx.core.array.logsumexp
array.logsumexp(self, axis: Optional[Union[int, Sequence[int]]] = None, keepdims: bool = False, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `logsumexp()`.
# mlx.core.array.max
array.max(self, axis: Optional[Union[int, Sequence[int]]] = None, keepdims: bool = False, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `max()`.
# mlx.core.array.mean
array.mean(self, axis: Optional[Union[int, Sequence[int]]] = None, keepdims: bool = False, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `mean()`.
# mlx.core.array.min
array.min(self, axis: Optional[Union[int, Sequence[int]]] = None, keepdims: bool = False, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `min()`.
# mlx.core.array.moveaxis
array.moveaxis(self, source: int, destination: int, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `moveaxis()`.
# mlx.core.array.nbytes
property array.nbytes
    
The number of bytes in the array.
# mlx.core.array.ndim
property array.ndim
    
The number of dimensions of the array.
# mlx.core.array.prod
array.prod(self, axis: Optional[Union[int, Sequence[int]]] = None, keepdims: bool = False, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `prod()`.
# mlx.core.array.real
property array.real
    
The real part of a complex array.
# mlx.core.array.reciprocal
array.reciprocal(self, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `reciprocal()`.
# mlx.core.array.reshape
array.reshape(self, *shape, stream: Optional[Union[Stream, Device]] = None) → array
    
Equivalent to `reshape()` but the shape can be passed either as a `tuple` or as separate arguments.
See `reshape()` for full documentation.
# mlx.core.array.round
array.round(self, decimals: int = 0, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `round()`.
# mlx.core.array.rsqrt
array.rsqrt(self, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `rsqrt()`.
# mlx.core.array.shape
property array.shape
    
The shape of the array as a Python tuple.
Returns:
    
A tuple containing the sizes of each dimension.
Return type:
    
tuple(int)
# mlx.core.array.sin
array.sin(self, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `sin()`.
# mlx.core.array.size
property array.size
    
Number of elements in the array.
# mlx.core.array.split
array.split(self, indices_or_sections: Union[int, tuple[int, ...]], axis: int = 0, *, stream: Optional[Union[Stream, Device]] = None) → list[array]
    
See `split()`.
# mlx.core.array.sqrt
array.sqrt(self, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `sqrt()`.
# mlx.core.array.square
array.square(self, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `square()`.
# mlx.core.array.squeeze
array.squeeze(self, axis: Optional[Union[int, Sequence[int]]] = None, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `squeeze()`.
# mlx.core.array.std
array.std(self, axis: Optional[Union[int, Sequence[int]]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `std()`.
# mlx.core.array.sum
array.sum(self, axis: Optional[Union[int, Sequence[int]]] = None, keepdims: bool = False, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `sum()`.
# mlx.core.array.swapaxes
array.swapaxes(self, axis1: int, axis2: int, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `swapaxes()`.
# mlx.core.array.tolist
array.tolist(self) → list_or_scalar
    
Convert the array to a Python `list`.
Returns:
    
The Python list.
If the array is a scalar then a standard Python scalar is returned.
If the array has more than one dimension then the result is a nested list of lists.
The value type of the list corresponding to the last dimension is either `bool`, `int` or `float` depending on the `dtype` of the array.
Return type:
    
list
# mlx.core.array.transpose
array.transpose(self, *axes, stream: Optional[Union[Stream, Device]] = None) → array
    
Equivalent to `transpose()` but the axes can be passed either as a tuple or as separate arguments.
See `transpose()` for full documentation.
# mlx.core.array.var
array.var(self, axis: Optional[Union[int, Sequence[int]]] = None, keepdims: bool = False, ddof: int = 0, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `var()`.
# mlx.core.array.view
array.view(self, dtype: Dtype, *, stream: Optional[Union[Stream, Device]] = None) → array
    
See `view()`.
# mlx.core.array_equal
array_equal(a: scalar | array, b: scalar | array, equal_nan: bool = False, stream: None | Stream | Device = None) → array
    
Array equality check.
Compare two arrays for equality. Returns `True` if and only if the arrays have the same shape and their values are equal. The arrays need not have the same type to be considered equal.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.
  * equal_nan (bool) – If `True`, NaNs are considered equal. Defaults to `False`.


Returns:
    
A scalar boolean array.
Return type:
    
array
# mlx.core.as_strided
as_strided(a: array, /, shape: Sequence[int] | None = None, strides: Sequence[int] | None = None, offset: int = 0, *, stream: None | Stream | Device = None) → array
    
Create a view into the array with the given shape and strides.
The resulting array will always be as if the provided array was row contiguous regardless of the provided array's storage order and current strides.
Note
Use this function with caution, as it changes the shape and strides of the array directly. The resulting array can point to invalid memory locations, which can result in crashes.
Parameters:
    
  * a (array) – Input array
  * shape (list(int), optional) – The shape of the resulting array. If None it defaults to `a.shape`.
  * strides (list(int), optional) – The strides of the resulting array. If None it defaults to the reverse exclusive cumulative product of `a.shape`.
  * offset (int) – Skip that many elements from the beginning of the input array.


Returns:
    
The output array which is the strided view of the input.
Return type:
    
array
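
The default strides and the indexing rule can be illustrated on a flat Python list. This is a sketch under stated assumptions: strides and `offset` are counted in elements, the input is treated as row contiguous, and the helper names `default_strides` and `as_strided_sketch` are hypothetical. The view is returned flattened in row-major order to keep the code short.

```python
from itertools import product


def default_strides(shape):
    """Reverse exclusive cumulative product of the shape."""
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def as_strided_sketch(flat, shape, strides=None, offset=0):
    """Return the elements of the strided view in row-major order."""
    if strides is None:
        strides = default_strides(shape)
    return [
        flat[offset + sum(i * s for i, s in zip(index, strides))]
        for index in product(*(range(d) for d in shape))
    ]


flat = list(range(6))
print(default_strides((2, 3, 4)))                   # (12, 4, 1)
print(as_strided_sketch(flat, (2, 3)))              # [0, 1, 2, 3, 4, 5]
print(as_strided_sketch(flat, (4, 3), (1, 1)))      # overlapping windows
print(as_strided_sketch(flat, (2, 2), (3, 1), 1))   # [1, 2, 4, 5]
```

With shape `(4, 3)` and strides `(1, 1)` each row is a sliding window of three elements. In this sketch an out-of-range combination raises `IndexError`, which corresponds to the invalid memory access the caution above describes.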
# mlx.core.async_eval
async_eval(*args)
    
Asynchronously evaluate an `array` or tree of `array`.
Note
This is an experimental API and may change in future versions.
Parameters:
    
*args (arrays or trees of arrays) – Each argument can be a single array or a tree of arrays. If a tree is given the nodes can be a Python `list`, `tuple` or `dict`. Leaves which are not arrays are ignored.
Example
    
    >>> x = mx.array(1.0)
    >>> y = mx.exp(x)
    >>> mx.async_eval(y)
    >>> print(y)
    >>>
    >>> y = mx.exp(x)
    >>> mx.async_eval(y)
    >>> z = y + 3
    >>> mx.async_eval(z)
    >>> print(z)
    
# mlx.core.atleast_1d
atleast_1d(*arys: array, stream: None | Stream | Device = None) → array | list[array]
    
Convert all arrays to have at least one dimension.
Parameters:
    
  * *arys – Input arrays.
  * stream (Union[None, Stream, Device], optional) – The stream to execute the operation on.


Returns:
    
An array or list of arrays with at least one dimension.
Return type:
    
array or list(array)
# mlx.core.atleast_2d
atleast_2d(*arys: array, stream: None | Stream | Device = None) → array | list[array]
    
Convert all arrays to have at least two dimensions.
Parameters:
    
  * *arys – Input arrays.
  * stream (Union[None, Stream, Device], optional) – The stream to execute the operation on.


Returns:
    
An array or list of arrays with at least two dimensions.
Return type:
    
array or list(array)
# mlx.core.atleast_3d
atleast_3d(*arys: array, stream: None | Stream | Device = None) → array | list[array]
    
Convert all arrays to have at least three dimensions.
Parameters:
    
  * *arys – Input arrays.
  * stream (Union[None, Stream, Device], optional) – The stream to execute the operation on.


Returns:
    
An array or list of arrays with at least three dimensions.
Return type:
    
array or list(array)
# mlx.core.bitwise_and
bitwise_and(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise bitwise and.
Take the bitwise and of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The bitwise and `a & b`.
Return type:
    
array
# mlx.core.bitwise_invert
bitwise_invert(a: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise bitwise inverse.
Take the bitwise complement of the input.
Parameters:
    
a (array) – Input array or scalar.
Returns:
    
The bitwise inverse `~a`.
Return type:
    
array
# mlx.core.bitwise_or
bitwise_or(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise bitwise or.
Take the bitwise or of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The bitwise or `a | b`.
Return type:
    
array
# mlx.core.bitwise_xor
bitwise_xor(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise bitwise xor.
Take the bitwise exclusive or of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The bitwise xor `a ^ b`.
Return type:
    
array
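The three binary bitwise operations above share the same element-wise pattern with scalar broadcasting. The following pure-Python sketch mirrors that pattern on plain integer lists; it does not call MLX, and `elementwise` is a hypothetical helper that handles only the list/scalar cases, not general broadcasting.

```python
import operator


def elementwise(op, a, b):
    """Apply a binary op element-wise; an int operand is broadcast."""
    a_is_scalar = isinstance(a, int)
    b_is_scalar = isinstance(b, int)
    if a_is_scalar and b_is_scalar:
        return op(a, b)
    n = len(b) if a_is_scalar else len(a)
    a_items = [a] * n if a_is_scalar else a
    b_items = [b] * n if b_is_scalar else b
    return [op(x, y) for x, y in zip(a_items, b_items)]


a = [0b1100, 0b1010, 0b1111]
b = [0b1010, 0b0110, 0b0001]

and_result = elementwise(operator.and_, a, b)      # a & b
or_result = elementwise(operator.or_, a, b)        # a | b
xor_result = elementwise(operator.xor, a, b)       # a ^ b
masked = elementwise(operator.and_, a, 0b0100)     # scalar broadcast
print(and_result, or_result, xor_result, masked)
```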
# mlx.core.block_masked_mm
block_masked_mm(a: array, b: array, /, block_size: int = 64, mask_out: array | None = None, mask_lhs: array | None = None, mask_rhs: array | None = None, *, stream: None | Stream | Device = None) → array
    
Matrix multiplication with block masking.
Perform the (possibly batched) matrix multiplication of two arrays, with blocks of size `block_size x block_size` optionally masked out.
Assuming `a` with shape (…, M, K) and `b` with shape (…, K, N)
  * `mask_lhs` must have shape (…, \\(\lceil\\) M / `block_size` \\(\rceil\\), \\(\lceil\\) K / `block_size` \\(\rceil\\))
  * `mask_rhs` must have shape (…, \\(\lceil\\) K / `block_size` \\(\rceil\\), \\(\lceil\\) N / `block_size` \\(\rceil\\))
  * `mask_out` must have shape (…, \\(\lceil\\) M / `block_size` \\(\rceil\\), \\(\lceil\\) N / `block_size` \\(\rceil\\))


Note: Only `block_size=64` and `block_size=32` are currently supported
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.
  * block_size (int) – Size of blocks to be masked. Must be `32` or `64`. Default: `64`.
  * mask_out (array, optional) – Mask for output. Default: `None`.
  * mask_lhs (array, optional) – Mask for `a`. Default: `None`.
  * mask_rhs (array, optional) – Mask for `b`. Default: `None`.


Returns:
    
The output array.
Return type:
    
array
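The mask shape requirements reduce to a ceiling division of each matrix dimension by `block_size`. A small pure-Python sketch of that arithmetic follows; `block_mask_shapes` is a hypothetical helper for checking shapes before a call, not part of MLX.

```python
def block_mask_shapes(M, K, N, block_size=64, batch=()):
    """Expected mask shapes for a (..., M, K) @ (..., K, N) product."""
    if block_size not in (32, 64):
        raise ValueError("block_size must be 32 or 64")

    def blocks(n):
        # Ceiling division without floating point.
        return -(-n // block_size)

    return {
        "mask_lhs": (*batch, blocks(M), blocks(K)),
        "mask_rhs": (*batch, blocks(K), blocks(N)),
        "mask_out": (*batch, blocks(M), blocks(N)),
    }


shapes = block_mask_shapes(M=100, K=70, N=64, block_size=32)
print(shapes)
```

Dimensions that are not a multiple of `block_size` still get a (partial) trailing block, which is why the ceiling is used rather than the floor.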
# mlx.core.broadcast_arrays
broadcast_arrays(*arrays: array, stream: None | Stream | Device = None) → Tuple[array, ...]
    
Broadcast arrays against one another.
The broadcasting semantics are the same as Numpy.
Parameters:
    
*arrays (array) – The input arrays.
Returns:
    
The output arrays with the broadcasted shape.
Return type:
    
tuple(array)
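Since the broadcasting semantics are the same as NumPy's, the broadcast shape can be worked out by aligning shapes from the trailing dimension and letting size-1 dimensions stretch. The sketch below implements that rule in pure Python; `broadcast_shape` is a hypothetical helper and does not call MLX.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Return the NumPy-style broadcast of the given shapes."""
    result = []
    # Walk the dimensions right to left, padding short shapes with 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} cannot be broadcast together")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


print(broadcast_shape((3, 1), (1, 4)))
print(broadcast_shape((5, 1, 4), (3, 1)))
```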
# mlx.core.broadcast_to
broadcast_to(a: scalar | array, /, shape: Sequence[int], *, stream: None | Stream | Device = None) → array
    
Broadcast an array to the given shape.
The broadcasting semantics are the same as Numpy.
Parameters:
    
  * a (array) – Input array.
  * shape (list(int)) – The shape to broadcast to.


Returns:
    
The output array with the new shape.
Return type:
    
array
# mlx.core.ceil
ceil(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise ceil.
Parameters:
    
a (array) – Input array.
Returns:
    
The ceil of `a`.
Return type:
    
array
# mlx.core.clear_cache
clear_cache() → None
    
Clear the memory cache.
After calling this, `get_cache_memory()` should return `0`.
# mlx.core.clip
clip(a: array, /, a_min: scalar | array | None, a_max: scalar | array | None, *, stream: None | Stream | Device = None) → array
    
Clip the values of the array between the given minimum and maximum.
If either `a_min` or `a_max` is `None`, the corresponding edge is ignored. At least one of `a_min` and `a_max` must not be `None`. The input `a` and the limits must broadcast with one another.
Parameters:
    
  * a (array) – Input array.
  * a_min (scalar or array or None) – Minimum value to clip to.
  * a_max (scalar or array or None) – Maximum value to clip to.


Returns:
    
The clipped array.
Return type:
    
array
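The handling of `None` bounds can be summarised per element as follows. This is a pure-Python sketch of the described behaviour on scalars, not the MLX implementation; `clip_scalar` is a hypothetical name.

```python
def clip_scalar(value, a_min, a_max):
    """Clip one value; a bound of None leaves that edge open."""
    if a_min is None and a_max is None:
        raise ValueError("at least one of a_min and a_max must be given")
    if a_min is not None:
        value = max(value, a_min)
    if a_max is not None:
        value = min(value, a_max)
    return value


values = [-2, 0, 3, 7]
both = [clip_scalar(v, 0, 5) for v in values]
lower_only = [clip_scalar(v, 0, None) for v in values]
upper_only = [clip_scalar(v, None, 5) for v in values]
print(both, lower_only, upper_only)
```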
# mlx.core.compile
compile(fun: Callable, inputs: object | None = None, outputs: object | None = None, shapeless: bool = False) → Callable
    
Returns a compiled function which produces the same output as `fun`.
Parameters:
    
  * fun (Callable) – A function which takes a variable number of `array` or trees of `array` and returns a variable number of `array` or trees of `array`.
  * inputs (list or dict, optional) – These inputs will be captured during the function compilation along with the inputs to `fun`. The `inputs` can be a `list` or a `dict` containing arbitrarily nested lists, dictionaries, or arrays. Leaf nodes that are not `array` are ignored. Default: `None`
  * outputs (list or dict, optional) – These outputs will be captured and updated in a compiled function. The `outputs` can be a `list` or a `dict` containing arbitrarily nested lists, dictionaries, or arrays. Leaf nodes that are not `array` are ignored. Default: `None`
  * shapeless (bool, optional) – A function compiled with the `shapeless` option enabled will not be recompiled when the input shape changes. Not all functions can be compiled with `shapeless` enabled. Attempting to compile such functions with shapeless enabled will throw. Note, changing the number of dimensions or type of any input will result in a recompilation even with `shapeless` set to `True`. Default: `False`


Returns:
    
A compiled function which has the same input arguments as `fun` and returns the same output(s).
Return type:
    
Callable
# mlx.core.concatenate
concatenate(arrays: list[array], axis: int | None = 0, *, stream: None | Stream | Device = None) → array
    
Concatenate the arrays along the given axis.
Parameters:
    
  * arrays (list(array)) – Input `list` or `tuple` of arrays.
  * axis (int, optional) – Optional axis to concatenate along. If unspecified defaults to `0`.


Returns:
    
The concatenated array.
Return type:
    
array
# mlx.core.conj
conj(a: array, *, stream: None | Stream | Device = None) → array
    
Return the elementwise complex conjugate of the input. Alias for mx.conjugate.
Parameters:
    
a (array) – Input array
Returns:
    
The output array.
Return type:
    
array
# mlx.core.conjugate
conjugate(a: array, *, stream: None | Stream | Device = None) → array
    
Return the elementwise complex conjugate of the input. Alias for mx.conj.
Parameters:
    
a (array) – Input array
Returns:
    
The output array.
Return type:
    
array
# mlx.core.contiguous
contiguous(a: array, /, allow_col_major: bool = False, *, stream: None | Stream | Device = None) → array
    
Force an array to be row contiguous. Copy if necessary.
Parameters:
    
  * a (array) – The input to make contiguous
  * allow_col_major (bool) – Consider column major as contiguous and don’t copy


Returns:
    
The row or col contiguous output.
Return type:
    
array
# mlx.core.conv1d
conv1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, *, stream: None | Stream | Device = None) → array
    
1D convolution over an input with several channels
Parameters:
    
  * input (array) – Input array of shape `(N, L, C_in)`.
  * weight (array) – Weight array of shape `(C_out, K, C_in)`.
  * stride (int, optional) – Kernel stride. Default: `1`.
  * padding (int, optional) – Input padding. Default: `0`.
  * dilation (int, optional) – Kernel dilation. Default: `1`.
  * groups (int, optional) – Input feature groups. Default: `1`.


Returns:
    
The convolved array.
Return type:
    
array
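The parameters above determine the output shape through the usual convolution size relation. The sketch below computes it in pure Python under two assumptions that the text does not state: that the output length follows the standard formula `floor((L + 2*padding - dilation*(K - 1) - 1) / stride) + 1`, and that with `groups > 1` the last weight dimension is `C_in / groups`. `conv1d_output_shape` is a hypothetical helper.

```python
def conv1d_output_shape(input_shape, weight_shape, stride=1, padding=0,
                        dilation=1, groups=1):
    """Expected (N, L_out, C_out) for channels-last 1D convolution."""
    N, L, C_in = input_shape
    C_out, K, C_in_per_group = weight_shape
    if C_in_per_group * groups != C_in:
        raise ValueError("weight channels do not match input channels")
    # A dilated kernel spans dilation * (K - 1) + 1 input positions.
    effective_kernel = dilation * (K - 1) + 1
    L_out = (L + 2 * padding - effective_kernel) // stride + 1
    return (N, L_out, C_out)


print(conv1d_output_shape((4, 10, 8), (16, 3, 8)))
print(conv1d_output_shape((4, 10, 8), (16, 3, 8), stride=2, padding=1))
```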
# mlx.core.conv2d
conv2d(input: array, weight: array, /, stride: int | tuple[int, int] = 1, padding: int | tuple[int, int] = 0, dilation: int | tuple[int, int] = 1, groups: int = 1, *, stream: None | Stream | Device = None) → array
    
2D convolution over an input with several channels
Parameters:
    
  * input (array) – Input array of shape `(N, H, W, C_in)`.
  * weight (array) – Weight array of shape `(C_out, KH, KW, C_in)`.
  * stride (int or tuple(int), optional) – `tuple` of size 2 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: `1`.
  * padding (int or tuple(int), optional) – `tuple` of size 2 with symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: `0`.
  * dilation (int or tuple(int), optional) – `tuple` of size 2 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: `1`
  * groups (int, optional) – input feature groups. Default: `1`.


Returns:
    
The convolved array.
Return type:
    
array
# mlx.core.conv3d
conv3d(input: array, weight: array, /, stride: int | tuple[int, int, int] = 1, padding: int | tuple[int, int, int] = 0, dilation: int | tuple[int, int, int] = 1, groups: int = 1, *, stream: None | Stream | Device = None) → array
    
3D convolution over an input with several channels
Note: Only the default `groups=1` is currently supported.
Parameters:
    
  * input (array) – Input array of shape `(N, D, H, W, C_in)`.
  * weight (array) – Weight array of shape `(C_out, KD, KH, KW, C_in)`.
  * stride (int or tuple(int), optional) – `tuple` of size 3 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: `1`.
  * padding (int or tuple(int), optional) – `tuple` of size 3 with symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: `0`.
  * dilation (int or tuple(int), optional) – `tuple` of size 3 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: `1`
  * groups (int, optional) – input feature groups. Default: `1`.


Returns:
    
The convolved array.
Return type:
    
array
# mlx.core.conv_general
conv_general(input: array, weight: array, /, stride: int | Sequence[int] = 1, padding: int | Sequence[int] | tuple[Sequence[int], Sequence[int]] = 0, kernel_dilation: int | Sequence[int] = 1, input_dilation: int | Sequence[int] = 1, groups: int = 1, flip: bool = False, *, stream: None | Stream | Device = None) → array
    
General convolution over an input with several channels
Parameters:
    
  * input (array) – Input array of shape `(N, ..., C_in)`.
  * weight (array) – Weight array of shape `(C_out, ..., C_in)`.
  * stride (int or list(int), optional) – `list` with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: `1`.
  * padding (int, list(int), or tuple(list(int), list(int)), optional) – `list` with input padding. All spatial dimensions get the same padding if only one number is specified. Default: `0`.
  * kernel_dilation (int or list(int), optional) – `list` with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: `1`
  * input_dilation (int or list(int), optional) – `list` with input dilation. All spatial dimensions get the same dilation if only one number is specified. Default: `1`
  * groups (int, optional) – Input feature groups. Default: `1`.
  * flip (bool, optional) – Flip the order in which the spatial dimensions of the weights are processed. Performs the cross-correlation operator when `flip` is `False` and the convolution operator otherwise. Default: `False`.


Returns:
    
The convolved array.
Return type:
    
array
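The effect of `flip` is easiest to see in one dimension with a single channel. The pure-Python sketch below computes a sliding dot product without padding, stride, or dilation; with `flip=False` it is cross-correlation, and with `flip=True` the kernel is reversed first, which gives convolution. `sliding_dot_1d` is a hypothetical helper and does not call MLX.

```python
def sliding_dot_1d(signal, kernel, flip=False):
    """Cross-correlation (flip=False) or convolution (flip=True), no padding."""
    taps = kernel[::-1] if flip else kernel
    n_out = len(signal) - len(taps) + 1
    return [
        sum(signal[i + k] * taps[k] for k in range(len(taps)))
        for i in range(n_out)
    ]


signal = [1, 2, 3, 4]
kernel = [1, 0, -1]
cross_correlation = sliding_dot_1d(signal, kernel, flip=False)
convolution = sliding_dot_1d(signal, kernel, flip=True)
print(cross_correlation, convolution)
```

For a symmetric kernel the two results coincide, so the flag only matters when the kernel is asymmetric, as it is here.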
# mlx.core.conv_transpose1d
conv_transpose1d(input: array, weight: array, /, stride: int = 1, padding: int = 0, dilation: int = 1, output_padding: int = 0, groups: int = 1, *, stream: None | Stream | Device = None) → array
    
1D transposed convolution over an input with several channels
Parameters:
    
  * input (array) – Input array of shape `(N, L, C_in)`.
  * weight (array) – Weight array of shape `(C_out, K, C_in)`.
  * stride (int, optional) – Kernel stride. Default: `1`.
  * padding (int, optional) – Input padding. Default: `0`.
  * dilation (int, optional) – Kernel dilation. Default: `1`.
  * output_padding (int, optional) – Output padding. Default: `0`.
  * groups (int, optional) – Input feature groups. Default: `1`.


Returns:
    
The convolved array.
Return type:
    
array
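For planning shapes, the output length of a transposed convolution is commonly given by `(L - 1)*stride - 2*padding + dilation*(K - 1) + output_padding + 1`. The sketch below assumes MLX follows this standard relation, which the text above does not state explicitly; `conv_transpose1d_length` and `conv1d_length` are hypothetical helpers.

```python
def conv_transpose1d_length(L, K, stride=1, padding=0, dilation=1,
                            output_padding=0):
    """Output length under the standard transposed-convolution relation."""
    return ((L - 1) * stride - 2 * padding
            + dilation * (K - 1) + output_padding + 1)


def conv1d_length(L, K, stride=1, padding=0, dilation=1):
    """Output length of the corresponding forward convolution."""
    return (L + 2 * padding - dilation * (K - 1) - 1) // stride + 1


# A strided forward convolution maps several input lengths to one output
# length; output_padding selects which of them the transpose restores.
down = conv1d_length(10, K=3, stride=2, padding=1)
up = conv_transpose1d_length(down, K=3, stride=2, padding=1, output_padding=1)
print(down, up)
```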
# mlx.core.conv_transpose2d
conv_transpose2d(input: array, weight: array, /, stride: int | Tuple[int, int] = 1, padding: int | Tuple[int, int] = 0, dilation: int | Tuple[int, int] = 1, output_padding: int | Tuple[int, int] = 0, groups: int = 1, *, stream: None | Stream | Device = None) → array
    
2D transposed convolution over an input with several channels
Note: Only the default `groups=1` is currently supported.
Parameters:
    
  * input (array) – Input array of shape `(N, H, W, C_in)`.
  * weight (array) – Weight array of shape `(C_out, KH, KW, C_in)`.
  * stride (int or tuple(int), optional) – `tuple` of size 2 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: `1`.
  * padding (int or tuple(int), optional) – `tuple` of size 2 with symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: `0`.
  * dilation (int or tuple(int), optional) – `tuple` of size 2 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: `1`
  * output_padding (int or tuple(int), optional) – `tuple` of size 2 with output padding. All spatial dimensions get the same output padding if only one number is specified. Default: `0`.
  * groups (int, optional) – input feature groups. Default: `1`.


Returns:
    
The convolved array.
Return type:
    
array
# mlx.core.conv_transpose3d
conv_transpose3d(input: array, weight: array, /, stride: int | Tuple[int, int, int] = 1, padding: int | Tuple[int, int, int] = 0, dilation: int | Tuple[int, int, int] = 1, output_padding: int | Tuple[int, int, int] = 0, groups: int = 1, *, stream: None | Stream | Device = None) → array
    
3D transposed convolution over an input with several channels
Note: Only the default `groups=1` is currently supported.
Parameters:
    
  * input (array) – Input array of shape `(N, D, H, W, C_in)`.
  * weight (array) – Weight array of shape `(C_out, KD, KH, KW, C_in)`.
  * stride (int or tuple(int), optional) – `tuple` of size 3 with kernel strides. All spatial dimensions get the same stride if only one number is specified. Default: `1`.
  * padding (int or tuple(int), optional) – `tuple` of size 3 with symmetric input padding. All spatial dimensions get the same padding if only one number is specified. Default: `0`.
  * dilation (int or tuple(int), optional) – `tuple` of size 3 with kernel dilation. All spatial dimensions get the same dilation if only one number is specified. Default: `1`
  * output_padding (int or tuple(int), optional) – `tuple` of size 3 with output padding. All spatial dimensions get the same output padding if only one number is specified. Default: `0`.
  * groups (int, optional) – input feature groups. Default: `1`.


Returns:
    
The convolved array.
Return type:
    
array
# mlx.core.convolve
convolve(a: array, v: array, /, mode: str = 'full', *, stream: None | Stream | Device = None) → array
    
The discrete convolution of 1D arrays.
If `v` is longer than `a`, then they are swapped. The conv filter is flipped following signal processing convention.
Parameters:
    
  * a (array) – 1D Input array.
  * v (array) – 1D Input array.
  * mode (str, optional) – {‘full’, ‘valid’, ‘same’}


Returns:
    
The convolved array.
Return type:
    
array
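The three modes differ only in which slice of the full convolution is returned. The following pure-Python reference sketch assumes NumPy's conventions for the `same` and `valid` slices, which the text above does not spell out; `convolve_1d` is a hypothetical helper and does not call MLX.

```python
def convolve_1d(a, v, mode="full"):
    """Discrete 1D convolution with NumPy-style output modes."""
    if len(v) > len(a):
        a, v = v, a  # the longer input is treated as the signal
    n, m = len(a), len(v)
    full = [0] * (n + m - 1)
    for i, x in enumerate(a):
        for j, y in enumerate(v):
            full[i + j] += x * y  # flipped filter: index sum, not difference
    if mode == "full":
        return full
    if mode == "same":
        start = (m - 1) // 2
        return full[start:start + n]  # centered, length max(n, m)
    if mode == "valid":
        return full[m - 1:n]  # positions where the inputs fully overlap
    raise ValueError(f"unknown mode: {mode!r}")


a = [1, 2, 3]
v = [0, 1, 0.5]
full = convolve_1d(a, v, "full")
same = convolve_1d(a, v, "same")
valid = convolve_1d(a, v, "valid")
print(full, same, valid)
```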
# mlx.core.cos
cos(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise cosine.
Parameters:
    
a (array) – Input array.
Returns:
    
The cosine of `a`.
Return type:
    
array
# mlx.core.cosh
cosh(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise hyperbolic cosine.
Parameters:
    
a (array) – Input array.
Returns:
    
The hyperbolic cosine of `a`.
Return type:
    
array
# mlx.core.cuda.is_available
is_available() → bool
    
Check if the CUDA back-end is available.
# mlx.core.cummax
cummax(a: array, /, axis: int | None = None, *, reverse: bool = False, inclusive: bool = True, stream: None | Stream | Device = None) → array
    
Return the cumulative maximum of the elements along the given axis.
Parameters:
    
  * a (array) – Input array
  * axis (int, optional) – Optional axis to compute the cumulative maximum over. If unspecified the cumulative maximum of the flattened array is returned.
  * reverse (bool) – Perform the cumulative maximum in reverse.
  * inclusive (bool) – The i-th element of the output includes the i-th element of the input.


Returns:
    
The output array.
Return type:
    
array
# mlx.core.cummin
cummin(a: array, /, axis: int | None = None, *, reverse: bool = False, inclusive: bool = True, stream: None | Stream | Device = None) → array
    
Return the cumulative minimum of the elements along the given axis.
Parameters:
    
  * a (array) – Input array
  * axis (int, optional) – Optional axis to compute the cumulative minimum over. If unspecified the cumulative minimum of the flattened array is returned.
  * reverse (bool) – Perform the cumulative minimum in reverse.
  * inclusive (bool) – The i-th element of the output includes the i-th element of the input.


Returns:
    
The output array.
Return type:
    
array
# mlx.core.cumprod
cumprod(a: array, /, axis: int | None = None, *, reverse: bool = False, inclusive: bool = True, stream: None | Stream | Device = None) → array
    
Return the cumulative product of the elements along the given axis.
Parameters:
    
  * a (array) – Input array
  * axis (int, optional) – Optional axis to compute the cumulative product over. If unspecified the cumulative product of the flattened array is returned.
  * reverse (bool) – Perform the cumulative product in reverse.
  * inclusive (bool) – The i-th element of the output includes the i-th element of the input.


Returns:
    
The output array.
Return type:
    
array
# mlx.core.cumsum
cumsum(a: array, /, axis: int | None = None, *, reverse: bool = False, inclusive: bool = True, stream: None | Stream | Device = None) → array
    
Return the cumulative sum of the elements along the given axis.
Parameters:
    
  * a (array) – Input array
  * axis (int, optional) – Optional axis to compute the cumulative sum over. If unspecified the cumulative sum of the flattened array is returned.
  * reverse (bool) – Perform the cumulative sum in reverse.
  * inclusive (bool) – The i-th element of the output includes the i-th element of the input.


Returns:
    
The output array.
Return type:
    
array
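The `reverse` and `inclusive` flags shared by the cumulative operations combine as shown in this pure-Python sketch for a 1D sum. It assumes that an exclusive scan starts from the additive identity `0`; `cumsum_1d` is a hypothetical helper and does not call MLX.

```python
def cumsum_1d(values, reverse=False, inclusive=True):
    """Cumulative sum of a list with reverse/inclusive options."""
    items = list(values)[::-1] if reverse else list(values)
    out = []
    running = 0
    for x in items:
        if inclusive:
            running += x
            out.append(running)   # output i includes input i
        else:
            out.append(running)   # output i covers only earlier inputs
            running += x
    return out[::-1] if reverse else out


data = [1, 2, 3, 4]
forward = cumsum_1d(data)
backward = cumsum_1d(data, reverse=True)
exclusive = cumsum_1d(data, inclusive=False)
backward_exclusive = cumsum_1d(data, reverse=True, inclusive=False)
print(forward, backward, exclusive, backward_exclusive)
```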
# mlx.core.custom_function
class custom_function
    
Set up a function for custom gradient and vmap definitions.
This class is meant to be used as a function decorator. Instances are callables that behave identically to the wrapped function. However, when a function transformation is used (e.g. computing gradients using `value_and_grad()`) then the functions defined via `custom_function.vjp()`, `custom_function.jvp()` and `custom_function.vmap()` are used instead of the default transformation.
Note, all custom transformations are optional. Undefined transformations fall back to the default behaviour.
Example
    
    import mlx.core as mx
    
    @mx.custom_function
    def f(x, y):
        return mx.sin(x) * y
    
    @f.vjp
    def f_vjp(primals, cotangent, output):
        x, y = primals
        return cotangent * mx.cos(x) * y, cotangent * mx.sin(x)
    
    @f.jvp
    def f_jvp(primals, tangents):
      x, y = primals
      dx, dy = tangents
      return dx * mx.cos(x) * y + dy * mx.sin(x)
    
    @f.vmap
    def f_vmap(inputs, axes):
      x, y = inputs
      ax, ay = axes
      if ay != ax and ax is not None:
          y = y.swapaxes(ay, ax)
      return mx.sin(x) * y, (ax or ay)
    
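When writing custom `vjp` and `jvp` rules like the ones above, it can help to check the formulas numerically first. The sketch below does this for `sin(x) * y` with plain Python floats and central finite differences. It does not use MLX; the function names mirror the example, and `finite_difference` is a hypothetical helper.

```python
import math


def f(x, y):
    return math.sin(x) * y


def f_vjp(primals, cotangent):
    # Cotangent times the partial derivatives with respect to x and y.
    x, y = primals
    return cotangent * math.cos(x) * y, cotangent * math.sin(x)


def f_jvp(primals, tangents):
    # Directional derivative along (dx, dy).
    x, y = primals
    dx, dy = tangents
    return dx * math.cos(x) * y + dy * math.sin(x)


def finite_difference(fun, x, y, dx, dy, eps=1e-6):
    """Central difference of fun along the direction (dx, dy)."""
    return (fun(x + eps * dx, y + eps * dy)
            - fun(x - eps * dx, y - eps * dy)) / (2 * eps)


x, y = 0.7, 1.3
grad_x, grad_y = f_vjp((x, y), 1.0)
directional = f_jvp((x, y), (0.3, -0.2))
print(grad_x, grad_y, directional)
```

The two rules are consistent with each other when the `jvp` result equals the dot product of the tangents with the `vjp` result for a unit cotangent.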
All `custom_function` instances behave as pure functions. Namely, any variables captured will be treated as constants and no gradients will be computed with respect to the captured arrays. For instance:
    
    import mlx.core as mx
    
    def g(x, y):
      @mx.custom_function
      def f(x):
        return x * y
    
      @f.vjp
      def f_vjp(x, dx, fx):
        # Note that we have only x, dx and fx and nothing with respect to y
        raise ValueError("Abort!")
    
      return f(x)
    
    x = mx.array(2.0)
    y = mx.array(3.0)
    print(g(x, y))                     # prints 6.0
    print(mx.grad(g)(x, y))            # Raises exception
    print(mx.grad(g, argnums=1)(x, y)) # prints 0.0
    
__init__(self, f: Callable)
    
Methods
`__init__`(self, f)  
`jvp`(self, f)
Define a custom jvp for the wrapped function.  
`vjp`(self, f)
Define a custom vjp for the wrapped function.  
`vmap`(self, f)
Define a custom vectorization transformation for the wrapped function.  
# mlx.core.default_device
default_device() → Device
    
Get the default device.
# mlx.core.default_stream
default_stream(device: Device) → Stream
    
Get the device’s default stream.
# mlx.core.degrees
degrees(a: array, /, *, stream: None | Stream | Device = None) → array
    
Convert angles from radians to degrees.
Parameters:
    
a (array) – Input array.
Returns:
    
The angles in degrees.
Return type:
    
array
# mlx.core.dequantize
dequantize(w: array, /, scales: array, biases: array | None = None, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: None | Stream | Device = None) → array
    
Dequantize the matrix `w` using quantization parameters.
Parameters:
    
  * w (array) – Matrix to be dequantized
  * scales (array) – The scales to use per `group_size` elements of `w`.
  * biases (array, optional) – The biases to use per `group_size` elements of `w`. Default: `None`.
  * group_size (int, optional) – The size of the group in `w` that shares a scale and bias. Default: `64`.
  * bits (int, optional) – The number of bits occupied by each element in `w`. Default: `4`.
  * mode (str, optional) – The quantization mode. Default: `"affine"`.


Returns:
    
The dequantized version of `w`
Return type:
    
array
Notes
The currently supported quantization modes are `"affine"` and `"mxfp4"`.
For `affine` quantization, given the notation in `quantize()`, we compute \\(w_i\\) from \\(\hat{w_i}\\) and corresponding \\(s\\) and \\(\beta\\) as follows
\\[w_i = s \hat{w_i} + \beta\\]
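The affine formula applies one scale and one bias to each group of `group_size` consecutive elements. The pure-Python sketch below illustrates only that arithmetic on already-unpacked integer codes; it ignores how the codes are packed into `w` according to `bits`, and `dequantize_affine` is a hypothetical helper rather than the MLX implementation.

```python
def dequantize_affine(w_hat, scales, biases, group_size):
    """Apply w_i = s * w_hat_i + beta with one (s, beta) pair per group."""
    if len(w_hat) != len(scales) * group_size or len(scales) != len(biases):
        raise ValueError("inconsistent group layout")
    return [
        scales[i // group_size] * q + biases[i // group_size]
        for i, q in enumerate(w_hat)
    ]


codes = [0, 1, 2, 3, 0, 5, 10, 15]   # two groups of four 4-bit codes
scales = [0.5, 2.0]
biases = [-1.0, 1.0]
weights = dequantize_affine(codes, scales, biases, group_size=4)
print(weights)
```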
# mlx.core.diag
diag(a: array, /, k: int = 0, *, stream: None | Stream | Device = None) → array
    
Extract a diagonal or construct a diagonal matrix. If `a` is 1-D then a diagonal matrix is constructed with `a` on the \\(k\\)-th diagonal. If `a` is 2-D then the \\(k\\)-th diagonal is returned.
Parameters:
    
  * a (array) – 1-D or 2-D input array.
  * k (int, optional) – The diagonal to extract or construct. Default: `0`.


Returns:
    
The extracted diagonal or the constructed diagonal matrix.
Return type:
    
array
# mlx.core.diagonal
diagonal(a: array, offset: int = 0, axis1: int = 0, axis2: int = 1, stream: None | Stream | Device = None) → array
    
Return specified diagonals.
If `a` is 2-D, then a 1-D array containing the diagonal at the given `offset` is returned.
If `a` has more than two dimensions, then `axis1` and `axis2` determine the 2D subarrays from which diagonals are extracted. The new shape is the original shape with `axis1` and `axis2` removed and a new dimension inserted at the end corresponding to the diagonal.
Parameters:
    
  * a (array) – Input array
  * offset (int, optional) – Offset of the diagonal from the main diagonal. Can be positive or negative. Default: `0`.
  * axis1 (int, optional) – The first axis of the 2-D sub-arrays from which the diagonals should be taken. Default: `0`.
  * axis2 (int, optional) – The second axis of the 2-D sub-arrays from which the diagonals should be taken. Default: `1`.


Returns:
    
The diagonals of the array.
Return type:
    
array
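The shape rule described above (remove `axis1` and `axis2`, then append the diagonal length) can be sketched in pure Python. The diagonal length formula follows the usual NumPy convention, where a positive `offset` moves along `axis2` and a negative one along `axis1`; that convention is an assumption here. `diagonal_shape` is a hypothetical helper.

```python
def diagonal_shape(shape, offset=0, axis1=0, axis2=1):
    """Shape of the result of taking diagonals from an array of `shape`."""
    ndim = len(shape)
    axis1 %= ndim
    axis2 %= ndim
    if axis1 == axis2:
        raise ValueError("axis1 and axis2 must differ")
    rows, cols = shape[axis1], shape[axis2]
    if offset >= 0:
        length = max(0, min(rows, cols - offset))
    else:
        length = max(0, min(rows + offset, cols))
    rest = tuple(d for i, d in enumerate(shape) if i not in (axis1, axis2))
    return rest + (length,)


print(diagonal_shape((3, 4)))
print(diagonal_shape((2, 3, 4)))
print(diagonal_shape((2, 3, 4), offset=1, axis1=1, axis2=2))
```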
# mlx.core.disable_compile
disable_compile() → None
    
Globally disable compilation. Setting the environment variable `MLX_DISABLE_COMPILE` can also be used to disable compilation.
# mlx.core.distributed.Group
class Group
    
An `mlx.core.distributed.Group` represents a group of independent MLX processes that can communicate.
__init__(*args, **kwargs)
    
Methods
`__init__`(*args, **kwargs)  
`rank`(self)
Get the rank of this process  
`size`(self)
Get the size of the group  
`split`(self, color[, key])
Split the group to subgroups based on the provided color.  
# mlx.core.distributed.all_gather
all_gather(x: array, *, group: Group | None = None, stream: None | Stream | Device = None) → array
    
Gather arrays from all processes.
Gather the `x` arrays from all processes in the group and concatenate them along the first axis. The arrays should all have the same shape.
Parameters:
    
  * x (array) – Input array.
  * group (Group) – The group of processes that will participate in the gather. If set to `None` the global group is used. Default: `None`.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
The concatenation of all `x` arrays.
Return type:
    
array
# mlx.core.distributed.all_sum
all_sum(x: array, *, group: Group | None = None, stream: None | Stream | Device = None) → array
    
All reduce sum.
Sum the `x` arrays from all processes in the group.
Parameters:
    
  * x (array) – Input array.
  * group (Group) – The group of processes that will participate in the reduction. If set to `None` the global group is used. Default: `None`.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
The sum of all `x` arrays.
Return type:
    
array
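The difference between the two collectives can be simulated in a single process by treating a Python list as "one entry per rank". This is only a sketch of the semantics described above: it performs no communication, it assumes the gathered pieces are ordered by rank, and `simulate_all_gather` and `simulate_all_sum` are hypothetical names.

```python
def simulate_all_gather(per_rank):
    """Concatenate each rank's rows along the first axis."""
    gathered = []
    for rows in per_rank:
        gathered.extend(rows)
    return gathered


def simulate_all_sum(per_rank):
    """Element-wise sum of the 1D contributions from every rank."""
    return [sum(values) for values in zip(*per_rank)]


# Three ranks, each holding an array of shape (1, 2).
local_rows = [[[1, 2]], [[3, 4]], [[5, 6]]]
gathered = simulate_all_gather(local_rows)        # shape (3, 2)

# Three ranks, each holding an array of shape (2,).
local_vectors = [[1, 2], [3, 4], [5, 6]]
summed = simulate_all_sum(local_vectors)          # shape (2,)
print(gathered, summed)
```

Every rank receives the same result in both cases; the gather grows the first axis by the group size while the sum keeps the input shape.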
# mlx.core.distributed.init
init(strict: bool = False, backend: str = 'any') → Group
    
Initialize the communication backend and create the global communication group.
Example
    
    import mlx.core as mx
    
    group = mx.distributed.init(backend="ring")
    
Parameters:
    
  * strict (bool, optional) – If `False`, a singleton group is returned when `mx.distributed.is_available()` returns `False`. If `True`, a runtime error is thrown in that case instead. Default: `False`
  * backend (str, optional) – Which distributed backend to initialize. Possible values `mpi`, `ring`, `nccl`, `any`. If set to `any` all available backends are tried and the first one that succeeds becomes the global group which will be returned in subsequent calls. Default: `any`


Returns:
    
The group representing all the launched processes.
Return type:
    
Group
# mlx.core.distributed.is_available
is_available() → bool
    
Check if a communication backend is available.
# mlx.core.distributed.recv
recv(shape: Sequence[int], dtype: Dtype, src: int, *, group: Group | None = None, stream: None | Stream | Device = None) → array
    
Recv an array with shape `shape` and dtype `dtype` from process with rank `src`.
Parameters:
    
  * shape (Tuple[int]) – The shape of the array we are receiving.
  * dtype (Dtype) – The data type of the array we are receiving.
  * src (int) – Rank of the source process in the group.
  * group (Group) – The group of processes that will participate in the recv. If set to `None` the global group is used. Default: `None`.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
The array that was received from `src`.
Return type:
    
array
# mlx.core.distributed.recv_like
recv_like(x: array, src: int, *, group: Group | None = None, stream: None | Stream | Device = None) → array
    
Recv an array with shape and type like `x` from process with rank `src`.
It is equivalent to calling `mx.distributed.recv(x.shape, x.dtype, src)`.
Parameters:
    
  * x (array) – An array defining the shape and dtype of the array we are receiving.
  * src (int) – Rank of the source process in the group.
  * group (Group) – The group of processes that will participate in the recv. If set to `None` the global group is used. Default: `None`.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
The array that was received from `src`.
Return type:
    
array
# mlx.core.distributed.send
send(x: array, dst: int, *, group: Group | None = None, stream: None | Stream | Device = None) → array
    
Send an array from the current process to the process that has rank `dst` in the group.
Parameters:
    
  * x (array) – Input array.
  * dst (int) – Rank of the destination process in the group.
  * group (Group) – The group of processes that will participate in the send. If set to `None` the global group is used. Default: `None`.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
An array identical to `x`. The send is performed when this array is evaluated.
Return type:
    
array
# mlx.core.divide
divide(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise division.
Divide two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The quotient `a / b`.
Return type:
    
array
# mlx.core.divmod
divmod(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise quotient and remainder.
The function `divmod(a, b)` is equivalent to but faster than `(a // b, a % b)`. The function uses numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The quotient `a // b` and remainder `a % b`.
Return type:
    
tuple(array, array)
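The stated equivalence with `(a // b, a % b)` can be illustrated with Python's built-in `divmod` on plain integers. The sketch below uses non-negative operands only, so it makes no claim about sign conventions; `divmod_elementwise` is a hypothetical helper and does not call MLX.

```python
def divmod_elementwise(a, b):
    """Return (quotients, remainders) for two equal-length lists."""
    quotients, remainders = [], []
    for x, y in zip(a, b):
        q, r = divmod(x, y)  # one pass produces both results
        quotients.append(q)
        remainders.append(r)
    return quotients, remainders


a = [7, 9, 10]
b = [2, 4, 5]
quotients, remainders = divmod_elementwise(a, b)
print(quotients, remainders)
```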
# mlx.core.einsum
einsum(subscripts: str, *operands, stream: None | Stream | Device = None) → array
    
Perform the Einstein summation convention on the operands.
Parameters:
    
  * subscripts (str) – The Einstein summation convention equation.
  * *operands (array) – The input arrays.


Returns:
    
The output array.
Return type:
    
array
# mlx.core.einsum_path
einsum_path(subscripts: str, *operands)
    
Compute the contraction order for the given Einstein summation.
Parameters:
    
  * subscripts (str) – The Einstein summation convention equation.
  * *operands (array) – The input arrays.


Returns:
    
The einsum path and a string containing information about the chosen path.
Return type:
    
tuple(list(tuple(int, int)), str)
# mlx.core.enable_compile
enable_compile() → None
    
Globally enable compilation. This will override the environment variable `MLX_DISABLE_COMPILE` if set.
# mlx.core.equal
equal(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise equality.
Equality comparison on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The element-wise comparison `a == b`.
Return type:
    
array
# mlx.core.erf
erf(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise error function.
\\[\mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^x e^{-t^2} \, dt\\]
Parameters:
    
a (array) – Input array.
Returns:
    
The error function of `a`.
Return type:
    
array
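The integral definition above can be checked numerically. The sketch below uses only Python's standard library (not MLX) and a midpoint rule; the helper name `erf_midpoint` and the step count are choices made for this illustration.

```python
import math


def erf_midpoint(x, steps=10_000):
    """Approximate erf(x) by applying the midpoint rule to its integral definition."""
    h = x / steps
    total = sum(math.exp(-((i + 0.5) * h) ** 2) for i in range(steps))
    return 2.0 / math.sqrt(math.pi) * total * h


approx = erf_midpoint(0.5)
exact = math.erf(0.5)
```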
# mlx.core.erfinv
erfinv(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise inverse of `erf()`.
Parameters:
    
a (array) – Input array.
Returns:
    
The inverse error function of `a`.
Return type:
    
array
# mlx.core.eval
eval(*args) → None
    
Evaluate an `array` or tree of `array`.
Parameters:
    
*args (arrays or trees of arrays) – Each argument can be a single array or a tree of arrays. If a tree is given the nodes can be a Python `list`, `tuple` or `dict`. Leaves which are not arrays are ignored.
# mlx.core.exp
exp(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise exponential.
Parameters:
    
a (array) – Input array.
Returns:
    
The exponential of `a`.
Return type:
    
array
# mlx.core.expand_dims
expand_dims(a: array, /, axis: int | Sequence[int], *, stream: None | Stream | Device = None) → array
    
Add a size one dimension at the given axis.
Parameters:
    
  * a (array) – Input array.
  * axis (int or tuple(int)) – The index or indices of the inserted dimensions.


Returns:
    
The array with inserted dimensions.
Return type:
    
array
# mlx.core.expm1
expm1(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise exponential minus 1.
Computes `exp(x) - 1` with greater precision for small `x`.
Parameters:
    
a (array) – Input array.
Returns:
    
The expm1 of `a`.
Return type:
    
array
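The precision claim above can be seen with Python's `math` module, used here as a stand-in for the MLX function. The sketch runs in double precision; for small `x`, `exp(x)` rounds to a value very close to 1 and the subtraction discards most of the significant digits.

```python
import math

x = 1e-10
naive = math.exp(x) - 1.0   # suffers from cancellation near 1.0
accurate = math.expm1(x)    # computed without forming exp(x) first

# For tiny x the true value is x + x**2 / 2 + ..., so x itself is a good reference.
naive_rel_error = abs(naive - x) / x
accurate_rel_error = abs(accurate - x) / x
```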
# mlx.core.export_function
export_function(file: str, fun: Callable, *args, shapeless: bool = False, **kwargs) → None
    
Export a function to a file.
Example input arrays must be provided to export a function. The example inputs can be variable `*args` and `**kwargs` or a tuple of arrays and/or dictionary of string keys with array values.
Warning
This is part of an experimental API which is likely to change in future versions of MLX. Functions exported with older versions of MLX may not be compatible with future versions.
Parameters:
    
  * file (str) – File path to export the function to.
  * fun (Callable) – A function which takes as input zero or more `array` and returns one or more `array`.
  * *args (array) – Example array inputs to the function.
  * shapeless (bool, optional) – Whether or not the function allows inputs with variable shapes. Default: `False`.
  * **kwargs (array) – Additional example keyword array inputs to the function.


Example
    
    def fun(x, y):
        return x + y
    
    x = mx.array(1)
    y = mx.array([1, 2, 3])
    mx.export_function("fun.mlxfn", fun, x, y=y)
    
# mlx.core.export_to_dot
export_to_dot(file: object, *args, **kwargs) → None
    
Export a graph to DOT format for visualization.
A variable number of output arrays can be provided for exporting. The exported graph will recursively include all unevaluated inputs of the provided outputs.
Parameters:
    
  * file (str) – The file path to export to.
  * *args (array) – The output arrays.
  * **kwargs (dict[str, array]) – Provide some names for arrays in the graph to make the result easier to parse.


Example
    
    >>> a = mx.array(1) + mx.array(2)
    >>> mx.export_to_dot("graph.dot", a)
    >>> x = mx.array(1)
    >>> y = mx.array(2)
    >>> mx.export_to_dot("graph.dot", x + y, x=x, y=y)
    
# mlx.core.exporter
exporter(file: str, fun: Callable, *, shapeless: bool = False) → mlx.core.FunctionExporter
    
Make a callable object to export multiple traces of a function to a file.
Warning
This is part of an experimental API which is likely to change in future versions of MLX. Functions exported with older versions of MLX may not be compatible with future versions.
Parameters:
    
  * file (str) – File path to export the function to.
  * shapeless (bool, optional) – Whether or not the function allows inputs with variable shapes. Default: `False`.


Example
    
    def fun(*args):
        return sum(args)
    
    with mx.exporter("fun.mlxfn", fun) as exporter:
        exporter(mx.array(1))
        exporter(mx.array(1), mx.array(2))
        exporter(mx.array(1), mx.array(2), mx.array(3))
    
# mlx.core.eye
eye(n: int, m: int | None = None, k: int = 0, dtype: Dtype | None = float32, *, stream: None | Stream | Device = None) → array
    
Create an identity matrix or a general diagonal matrix.
Parameters:
    
  * n (int) – The number of rows in the output.
  * m (int, optional) – The number of columns in the output. Defaults to n.
  * k (int, optional) – Index of the diagonal. Defaults to 0 (main diagonal).
  * dtype (Dtype, optional) – Data type of the output array. Defaults to float32.
  * stream (Stream, optional) – Stream or device. Defaults to None.


Returns:
    
An array where all elements are equal to zero, except for the k-th diagonal, whose values are equal to one.
Return type:
    
array
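The roles of `n`, `m` and `k` can be sketched with nested lists in plain Python. This is an illustration, not MLX code: `eye_reference` is a hypothetical name, and it assumes the NumPy convention that a positive `k` selects a diagonal above the main one.

```python
def eye_reference(n, m=None, k=0):
    """Hypothetical reference: ones on the k-th diagonal, zeros elsewhere."""
    m = n if m is None else m
    return [[1.0 if j - i == k else 0.0 for j in range(m)] for i in range(n)]


upper = eye_reference(3, k=1)  # 3x3 with ones one step above the main diagonal
wide = eye_reference(2, 3)     # 2x3 with ones on the main diagonal
```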
# mlx.core.fast.cuda_kernel
cuda_kernel(name: str, input_names: Sequence[str], output_names: Sequence[str], source: str, header: str = '', ensure_row_contiguous: bool = True, shared_memory: int = 0) → object
    
A jit-compiled custom CUDA kernel defined from a source string.
This is the CUDA equivalent of Custom Metal Kernels.
Parameters:
    
  * name (str) – Name for the kernel.
  * input_names (List[str]) – The parameter names of the inputs in the function signature.
  * output_names (List[str]) – The parameter names of the outputs in the function signature.
  * source (str) – Source code. This is the body of a function in CUDA, the function signature will be automatically generated.
  * header (str) – Header source code to include before the main function. Useful for helper functions or includes that should live outside of the main function body.
  * ensure_row_contiguous (bool) – Whether to ensure the inputs are row contiguous before the kernel runs. Default: `True`.
  * shared_memory (int) – The dynamic shared memory to request for the kernel. A value of 0 means no dynamic shared memory. Default: `0`.


Returns:
    
Callable `cuda_kernel`.
Example
    
    def exp_elementwise(a: mx.array):
        source = '''
            auto elem = cooperative_groups::this_grid().thread_rank();
            T tmp = inp[elem];
            out[elem] = exp(tmp);
        '''
    
        kernel = mx.fast.cuda_kernel(
            name="myexp",
            input_names=["inp"],
            output_names=["out"],
            source=source
        )
        outputs = kernel(
            inputs=[a],
            template=[("T", mx.float32)],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
            verbose=True,
        )
        return outputs[0]
    
    a = mx.random.normal(shape=(16, 16)).astype(mx.float16)
    b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))
    
# mlx.core.fast.layer_norm
layer_norm(x: array, weight: array | None, bias: array | None, eps: float, *, stream: None | Stream | Device = None) → array
    
Layer normalization.
The normalization is with respect to the last axis of the input `x`.
Parameters:
    
  * x (array) – Input array.
  * weight (array, optional) – A multiplicative weight to scale the result by. The `weight` should be one-dimensional with the same size as the last axis of `x`. If set to `None` then no scaling happens.
  * bias (array, optional) – An additive offset to be added to the result. The `bias` should be one-dimensional with the same size as the last axis of `x`. If set to `None` then no translation happens.
  * eps (float) – A small additive constant for numerical stability.


Returns:
    
The output array.
Return type:
    
array
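As a rough guide to what is computed along the last axis, the following plain-Python sketch normalizes a single row. It assumes the common layer-norm definition (subtract the mean, divide by the square root of the biased variance plus `eps`, then apply `weight` and `bias`); the name `layer_norm_reference` and the default `eps` are made up for the example.

```python
import math


def layer_norm_reference(x, weight=None, bias=None, eps=1e-5):
    """Hypothetical reference: layer normalization of one row (a Python list)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    y = [(v - mean) / math.sqrt(var + eps) for v in x]
    if weight is not None:  # optional scaling
        y = [w * v for w, v in zip(weight, y)]
    if bias is not None:    # optional translation
        y = [b + v for b, v in zip(bias, y)]
    return y


row = [1.0, 2.0, 3.0, 4.0]
out = layer_norm_reference(row)
shifted = layer_norm_reference(row, weight=[2.0] * 4, bias=[1.0] * 4)
```

Without `weight` and `bias` the output has zero mean and approximately unit variance.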
# mlx.core.fast.metal_kernel
metal_kernel(name: str, input_names: Sequence[str], output_names: Sequence[str], source: str, header: str = '', ensure_row_contiguous: bool = True, atomic_outputs: bool = False) → object
    
A jit-compiled custom Metal kernel defined from a source string.
Full documentation: Custom Metal Kernels.
Parameters:
    
  * name (str) – Name for the kernel.
  * input_names (List[str]) – The parameter names of the inputs in the function signature.
  * output_names (List[str]) – The parameter names of the outputs in the function signature.
  * source (str) – Source code. This is the body of a function in Metal, the function signature will be automatically generated.
  * header (str) – Header source code to include before the main function. Useful for helper functions or includes that should live outside of the main function body.
  * ensure_row_contiguous (bool) – Whether to ensure the inputs are row contiguous before the kernel runs. Default: `True`.
  * atomic_outputs (bool) – Whether to use atomic outputs in the function signature e.g. `device atomic<float>`. Default: `False`.


Returns:
    
Callable `metal_kernel`.
Example
    
    def exp_elementwise(a: mx.array):
        source = '''
            uint elem = thread_position_in_grid.x;
            T tmp = inp[elem];
            out[elem] = metal::exp(tmp);
        '''
    
        kernel = mx.fast.metal_kernel(
            name="myexp",
            input_names=["inp"],
            output_names=["out"],
            source=source
        )
        outputs = kernel(
            inputs=[a],
            template=[("T", mx.float32)],
            grid=(a.size, 1, 1),
            threadgroup=(256, 1, 1),
            output_shapes=[a.shape],
            output_dtypes=[a.dtype],
            verbose=True,
        )
        return outputs[0]
    
    a = mx.random.normal(shape=(4, 16)).astype(mx.float16)
    b = exp_elementwise(a)
    assert mx.allclose(b, mx.exp(a))
    
# mlx.core.fast.rms_norm
rms_norm(x: array, weight: array | None, eps: float, *, stream: None | Stream | Device = None) → array
    
Root Mean Square normalization (RMS norm).
The normalization is with respect to the last axis of the input `x`.
Parameters:
    
  * x (array) – Input array.
  * weight (array, optional) – A multiplicative weight to scale the result by. The `weight` should be one-dimensional with the same size as the last axis of `x`. If set to `None` then no scaling happens.
  * eps (float) – A small additive constant for numerical stability.


Returns:
    
The output array.
Return type:
    
array
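For comparison with layer normalization, here is a plain-Python sketch of RMS normalization on a single row. It assumes the usual definition (divide by the square root of the mean of squares plus `eps`, then apply `weight`); `rms_norm_reference` is a hypothetical name and the code is not MLX's implementation.

```python
import math


def rms_norm_reference(x, weight=None, eps=1e-5):
    """Hypothetical reference: RMS normalization of one row (a Python list)."""
    mean_sq = sum(v * v for v in x) / len(x)
    y = [v / math.sqrt(mean_sq + eps) for v in x]
    if weight is not None:  # optional scaling
        y = [w * v for w, v in zip(weight, y)]
    return y


out = rms_norm_reference([3.0, 4.0], eps=0.0)
mean_sq_out = sum(v * v for v in out) / len(out)
```

Unlike layer normalization, the mean is not subtracted, so only the scale of the row changes.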
# mlx.core.fast.rope
rope(a: array, dims: int, *, traditional: bool, base: float | None, scale: float, offset: int | array, freqs: array | None = None, stream: None | Stream | Device = None) → array
    
Apply rotary positional encoding to the input.
The input is expected to be at least 3D with shape `(B, *, T, D)` where:
    
  * `B` is the batch size.
  * `T` is the sequence length.
  * `D` is the feature dimension.


Parameters:
    
  * a (array) – The input array.
  * dims (int) – The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged.
  * traditional (bool) – If set to `True` choose the traditional implementation which rotates consecutive dimensions.
  * base (float, optional) – The base used to compute angular frequency for each dimension in the positional encodings. Exactly one of `base` and `freqs` must be `None`.
  * scale (float) – The scale used to scale the positions.
  * offset (int or array) – The position offset to start at. If an `array` is given it can be a scalar or vector of `B` offsets for each example in the batch.
  * freqs (array, optional) – Optional frequencies to use with RoPE. If set, the `base` parameter must be `None`. Default: `None`.


Returns:
    
The output array.
Return type:
    
array
# mlx.core.fast.scaled_dot_product_attention
scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: None | str | array = None, sinks: array | None = None, stream: None | Stream | Device = None) → array
    
A fast implementation of multi-head attention: `O = softmax(scale * (Q @ K.T), dim=-1) @ V`.
Supports:
  * Multi-Head Attention
  * Grouped Query Attention
  * Multi-Query Attention


Note
  * The softmax operation is performed in `float32` regardless of the input precision.
  * For Grouped Query Attention and Multi-Query Attention, the `k` and `v` inputs should not be pre-tiled to match `q`.


In the following the dimensions are given by:
  * `B`: The batch size.
  * `N_q`: The number of query heads.
  * `N_kv`: The number of key and value heads.
  * `T_q`: The number of queries per example.
  * `T_kv`: The number of keys and values per example.
  * `D`: The per-head dimension.


Parameters:
    
  * q (array) – Queries with shape `[B, N_q, T_q, D]`.
  * k (array) – Keys with shape `[B, N_kv, T_kv, D]`.
  * v (array) – Values with shape `[B, N_kv, T_kv, D]`.
  * scale (float) – Scale for queries (typically `1.0 / sqrt(q.shape[-1])`).
  * mask (str or array, optional) – The mask to apply to the query-key scores. The mask can be an array or a string indicating the mask type. The only supported string type is `"causal"`. If the mask is an array it can be a boolean or additive mask. The mask can have at most 4 dimensions and must be broadcast-compatible with the shape `[B, N, T_q, T_kv]`. If an additive mask is given its type must promote to the promoted type of `q`, `k`, and `v`.
  * sinks (array, optional) – An optional array of attention sinks. Default: `None`.


Returns:
    
The output array.
Return type:
    
array
Example
    
    B = 2
    N_q = N_kv = 32
    T_q = T_kv = 1000
    D = 128
    
    q = mx.random.normal(shape=(B, N_q, T_q, D))
    k = mx.random.normal(shape=(B, N_kv, T_kv, D))
    v = mx.random.normal(shape=(B, N_kv, T_kv, D))
    scale = D ** -0.5
    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal")
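
To show what the operation computes for one head, here is an unmasked plain-Python sketch on nested lists. It is not the MLX kernel: `attention_reference` is a hypothetical name, and applying `scale` to the query-key scores is an assumption drawn from the description of `scale` as a scale for the queries.

```python
import math


def attention_reference(q, k, v, scale):
    """Hypothetical single-head reference: softmax(scale * q @ k^T) @ v."""
    out = []
    for q_row in q:
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        peak = max(scores)  # subtract the maximum for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        out.append(
            [sum(w * v_row[d] for w, v_row in zip(weights, v)) for d in range(len(v[0]))]
        )
    return out


q = [[1.0, 0.0]]              # one query, D = 2
k = [[1.0, 0.0], [0.0, 1.0]]  # two keys
v = [[1.0, 2.0], [3.0, 4.0]]  # two values
out = attention_reference(q, k, v, scale=2 ** -0.5)
```

The query matches the first key more closely, so the output row lies closer to the first value row than to the second.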
    
# mlx.core.fft.fft
fft(a: array, n: Optional[int] = None, axis: int = -1, stream: Optional[Union[Stream, Device]] = None) → array
    
One dimensional discrete Fourier Transform.
Parameters:
    
  * a (array) – The input array.
  * n (int, optional) – Size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to match `n`. The default value is `a.shape[axis]`.
  * axis (int, optional) – Axis along which to perform the FFT. The default is `-1`.


Returns:
    
The DFT of the input along the given axis.
Return type:
    
array
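The effect of `n` can be sketched with a naive transform in plain Python. This is an O(n^2) illustration rather than an FFT; `dft_reference` is a hypothetical name, and the forward sign convention `exp(-2πi f t / n)` is assumed.

```python
import cmath


def dft_reference(x, n=None):
    """Hypothetical naive DFT that truncates or zero-pads the input to length n."""
    n = len(x) if n is None else n
    padded = (list(x) + [0.0] * n)[:n]
    return [
        sum(padded[t] * cmath.exp(-2j * cmath.pi * f * t / n) for t in range(n))
        for f in range(n)
    ]


signal = [1.0, 2.0, 3.0]
same = dft_reference(signal)         # n defaults to len(signal)
padded = dft_reference(signal, n=4)  # input treated as [1, 2, 3, 0]
cut = dft_reference(signal, n=2)     # input treated as [1, 2]
```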
# mlx.core.fft.fft2
fft2(a: array, s: Optional[tuple[int, ...]] = None, axes: Optional[Sequence[int]] = [-2, -1], stream: Optional[Union[Stream, Device]] = None) → array
    
Two dimensional discrete Fourier Transform.
Parameters:
    
  * a (array) – The input array.
  * s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`.
  * axes (list(int), optional) – Axes along which to perform the FFT. The default is `[-2, -1]`.


Returns:
    
The DFT of the input along the given axes.
Return type:
    
array
# mlx.core.fft.fftn
fftn(a: array, s: Optional[tuple[int, ...]] = None, axes: Optional[Sequence[int]] = None, stream: Optional[Union[Stream, Device]] = None) → array
    
n-dimensional discrete Fourier Transform.
Parameters:
    
  * a (array) – The input array.
  * s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`.
  * axes (list(int), optional) – Axes along which to perform the FFT. The default is `None` in which case the FFT is over the last `len(s)` axes or all axes if `s` is also `None`.


Returns:
    
The DFT of the input along the given axes.
Return type:
    
array
# mlx.core.fft.fftshift
fftshift(a: array, axes: Optional[Sequence[int]] = None, stream: Optional[Union[Stream, Device]] = None) → array
    
Shift the zero-frequency component to the center of the spectrum.
Parameters:
    
  * a (array) – The input array.
  * axes (list(int), optional) – Axes over which to perform the shift. If `None`, shift all axes.


Returns:
    
The shifted array with the same shape as the input.
Return type:
    
array
# mlx.core.fft.ifft
ifft(a: array, n: Optional[int] = None, axis: int = -1, stream: Optional[Union[Stream, Device]] = None) → array
    
One dimensional inverse discrete Fourier Transform.
Parameters:
    
  * a (array) – The input array.
  * n (int, optional) – Size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to match `n`. The default value is `a.shape[axis]`.
  * axis (int, optional) – Axis along which to perform the FFT. The default is `-1`.


Returns:
    
The inverse DFT of the input along the given axis.
Return type:
    
array
# mlx.core.fft.ifft2
ifft2(a: array, s: Optional[tuple[int, ...]] = None, axes: Optional[Sequence[int]] = [-2, -1], stream: Optional[Union[Stream, Device]] = None) → array
    
Two dimensional inverse discrete Fourier Transform.
Parameters:
    
  * a (array) – The input array.
  * s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`.
  * axes (list(int), optional) – Axes along which to perform the FFT. The default is `[-2, -1]`.


Returns:
    
The inverse DFT of the input along the given axes.
Return type:
    
array
# mlx.core.fft.ifftn
ifftn(a: array, s: Optional[tuple[int, ...]] = None, axes: Optional[Sequence[int]] = None, stream: Optional[Union[Stream, Device]] = None) → array
    
n-dimensional inverse discrete Fourier Transform.
Parameters:
    
  * a (array) – The input array.
  * s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`.
  * axes (list(int), optional) – Axes along which to perform the FFT. The default is `None` in which case the FFT is over the last `len(s)` axes or all axes if `s` is also `None`.


Returns:
    
The inverse DFT of the input along the given axes.
Return type:
    
array
# mlx.core.fft.ifftshift
ifftshift(a: array, axes: Optional[Sequence[int]] = None, stream: Optional[Union[Stream, Device]] = None) → array
    
The inverse of `fftshift()`. While identical to `fftshift()` for even-length axes, the behavior differs for odd-length axes.
Parameters:
    
  * a (array) – The input array.
  * axes (list(int), optional) – Axes over which to perform the inverse shift. If `None`, shift all axes.


Returns:
    
The inverse-shifted array with the same shape as the input.
Return type:
    
array
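The difference between the two shifts on odd-length axes can be sketched in one dimension with Python lists. The helper names are hypothetical, and the rotation amounts follow the NumPy convention (`len(x) // 2` positions), which is assumed here.

```python
def fftshift_reference(x):
    """Hypothetical 1-D reference: rotate right by len(x) // 2."""
    split = len(x) - len(x) // 2
    return x[split:] + x[:split]


def ifftshift_reference(x):
    """Hypothetical 1-D reference: rotate left by len(x) // 2."""
    shift = len(x) // 2
    return x[shift:] + x[:shift]


even = [0, 1, 2, 3]
odd = [0, 1, 2, 3, 4]

even_shifted = fftshift_reference(even)
odd_shifted = fftshift_reference(odd)
odd_restored = ifftshift_reference(odd_shifted)
```

For the even-length list both helpers produce the same result; for the odd-length list only `ifftshift_reference` undoes `fftshift_reference`.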
# mlx.core.fft.irfft
irfft(a: array, n: Optional[int] = None, axis: int = -1, stream: Optional[Union[Stream, Device]] = None) → array
    
The inverse of `rfft()`.
The output has the same shape as the input except along `axis` in which case it has size `n`.
Parameters:
    
  * a (array) – The input array.
  * n (int, optional) – Size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to match `n // 2 + 1`. The default value is `a.shape[axis] // 2 + 1`.
  * axis (int, optional) – Axis along which to perform the FFT. The default is `-1`.


Returns:
    
The real array containing the inverse of `rfft()`.
Return type:
    
array
# mlx.core.fft.irfft2
irfft2(a: array, s: Optional[tuple[int, ...]] = None, axes: Optional[Sequence[int]] = [-2, -1], stream: Optional[Union[Stream, Device]] = None) → array
    
The inverse of `rfft2()`.
Note the input is generally complex. The dimensions of the input specified in `axes` are padded or truncated to match the sizes from `s`. The last axis in `axes` is treated as the real axis and will have size `s[-1] // 2 + 1`.
Parameters:
    
  * a (array) – The input array.
  * s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s` except for the last axis which has size `s[-1] // 2 + 1`. The default value is the sizes of `a` along `axes`.
  * axes (list(int), optional) – Axes along which to perform the FFT. The default is `[-2, -1]`.


Returns:
    
The real array containing the inverse of `rfft2()`.
Return type:
    
array
# mlx.core.fft.irfftn
irfftn(a: array, s: Optional[tuple[int, ...]] = None, axes: Optional[Sequence[int]] = None, stream: Optional[Union[Stream, Device]] = None) → array
    
The inverse of `rfftn()`.
Note the input is generally complex. The dimensions of the input specified in `axes` are padded or truncated to match the sizes from `s`. The last axis in `axes` is treated as the real axis and will have size `s[-1] // 2 + 1`.
Parameters:
    
  * a (array) – The input array.
  * s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`.
  * axes (list(int), optional) – Axes along which to perform the FFT. The default is `None` in which case the FFT is over the last `len(s)` axes or all axes if `s` is also `None`.


Returns:
    
The real array containing the inverse of `rfftn()`.
Return type:
    
array
# mlx.core.fft.rfft
rfft(a: array, n: Optional[int] = None, axis: int = -1, stream: Optional[Union[Stream, Device]] = None) → array
    
One dimensional discrete Fourier Transform on a real input.
The output has the same shape as the input except along `axis` in which case it has size `n // 2 + 1`.
Parameters:
    
  * a (array) – The input array. If the array is complex it will be silently cast to a real type.
  * n (int, optional) – Size of the transformed axis. The corresponding axis in the input is truncated or padded with zeros to match `n`. The default value is `a.shape[axis]`.
  * axis (int, optional) – Axis along which to perform the FFT. The default is `-1`.


Returns:
    
The DFT of the input along the given axis. The output data type will be complex.
Return type:
    
array
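The output size `n // 2 + 1` follows from the conjugate symmetry of the DFT of a real signal: the remaining bins carry no new information. The plain-Python sketch below (a naive DFT with the hypothetical name `dft`, not MLX code) keeps the first `n // 2 + 1` bins and rebuilds the rest from them.

```python
import cmath


def dft(x):
    """Hypothetical naive DFT of a list of numbers."""
    n = len(x)
    return [
        sum(x[t] * cmath.exp(-2j * cmath.pi * f * t / n) for t in range(n))
        for f in range(n)
    ]


x = [1.0, 2.0, 0.0, -1.0, 3.0, 0.5]
full = dft(x)
half = full[: len(x) // 2 + 1]  # the bins a real-input transform keeps

# Rebuild the dropped bins using X[n - f] == conj(X[f]) for real input.
rebuilt = half + [half[len(x) - f].conjugate() for f in range(len(half), len(x))]
```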
# mlx.core.fft.rfft2
rfft2(a: array, s: Optional[tuple[int, ...]] = None, axes: Optional[Sequence[int]] = [-2, -1], stream: Optional[Union[Stream, Device]] = None) → array
    
Two dimensional real discrete Fourier Transform.
The output has the same shape as the input except along the dimensions in `axes` in which case it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size `s[-1] // 2 + 1`.
Parameters:
    
  * a (array) – The input array. If the array is complex it will be silently cast to a real type.
  * s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`.
  * axes (list(int), optional) – Axes along which to perform the FFT. The default is `[-2, -1]`.


Returns:
    
The real DFT of the input along the given axes. The output data type will be complex.
Return type:
    
array
# mlx.core.fft.rfftn
rfftn(a: array, s: Optional[tuple[int, ...]] = None, axes: Optional[Sequence[int]] = None, stream: Optional[Union[Stream, Device]] = None) → array
    
n-dimensional real discrete Fourier Transform.
The output has the same shape as the input except along the dimensions in `axes` in which case it has sizes from `s`. The last axis in `axes` is treated as the real axis and will have size `s[-1] // 2 + 1`.
Parameters:
    
  * a (array) – The input array. If the array is complex it will be silently cast to a real type.
  * s (list(int), optional) – Sizes of the transformed axes. The corresponding axes in the input are truncated or padded with zeros to match the sizes in `s`. The default value is the sizes of `a` along `axes`.
  * axes (list(int), optional) – Axes along which to perform the FFT. The default is `None` in which case the FFT is over the last `len(s)` axes or all axes if `s` is also `None`.


Returns:
    
The real DFT of the input along the given axes. The output data type will be complex.
Return type:
    
array
# mlx.core.finfo
class finfo
    
Get information on floating-point types.
__init__(self, arg: Dtype, /) → None
    
Methods
`__init__`(self, arg, /)  
Attributes
`dtype`
The `Dtype`.  
`eps`
The difference between 1.0 and the next smallest representable number larger than 1.0.  
`max`
The largest representable number.  
`min`
The smallest representable number.  
# mlx.core.flatten
flatten(a: array, /, start_axis: int = 0, end_axis: int = -1, *, stream: None | Stream | Device = None) → array
    
Flatten an array.
The axes flattened will be between `start_axis` and `end_axis`, inclusive. Negative axes are supported. After converting negative axis to positive, axes outside the valid range will be clamped to a valid value, `start_axis` to `0` and `end_axis` to `ndim - 1`.
Parameters:
    
  * a (array) – Input array.
  * start_axis (int, optional) – The first dimension to flatten. Defaults to `0`.
  * end_axis (int, optional) – The last dimension to flatten. Defaults to `-1`.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
The flattened array.
Return type:
    
array
Example
    
    >>> a = mx.array([[1, 2], [3, 4]])
    >>> mx.flatten(a)
    array([1, 2, 3, 4], dtype=int32)
    >>>
    >>> mx.flatten(a, start_axis=0, end_axis=-1)
    array([1, 2, 3, 4], dtype=int32)
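
The axis handling described above (convert negative axes, then clamp) can be sketched as a shape computation in plain Python. `flatten_shape` is a hypothetical helper that only works out the resulting shape; it does not touch any data and does not cover the case where `start_axis` ends up after `end_axis`.

```python
def flatten_shape(shape, start_axis=0, end_axis=-1):
    """Hypothetical helper: shape after flattening axes start_axis..end_axis."""
    ndim = len(shape)
    # Convert negative axes to positive ones.
    start = start_axis + ndim if start_axis < 0 else start_axis
    end = end_axis + ndim if end_axis < 0 else end_axis
    # Clamp out-of-range axes as described in the text.
    start = max(start, 0)
    end = min(end, ndim - 1)
    merged = 1
    for dim in shape[start : end + 1]:
        merged *= dim
    return shape[:start] + (merged,) + shape[end + 1 :]


all_axes = flatten_shape((2, 3, 4))
last_two = flatten_shape((2, 3, 4), start_axis=1, end_axis=2)
clamped = flatten_shape((2, 3, 4), start_axis=-2, end_axis=10)
```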
    
# mlx.core.floor
floor(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise floor.
Parameters:
    
a (array) – Input array.
Returns:
    
The floor of `a`.
Return type:
    
array
# mlx.core.floor_divide
floor_divide(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise integer division.
If either array is a floating point type then it is equivalent to calling `floor()` after `divide()`.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The quotient `a // b`.
Return type:
    
array
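The floating-point equivalence mentioned above can be illustrated with Python scalars, used here as a stand-in for MLX arrays. The sample values are arbitrary; the negative pair shows that flooring rounds toward negative infinity rather than toward zero.

```python
import math

pairs = [(7.5, 2.0), (-7.5, 2.0), (9.0, 4.0)]

# floor() applied after true division ...
floor_of_quotient = [math.floor(a / b) for a, b in pairs]
# ... compared with Python's floor-division operator on the same floats.
floor_division = [a // b for a, b in pairs]
```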
# mlx.core.full
full(shape: int | Sequence[int], vals: scalar | array, dtype: Dtype | None = None, *, stream: None | Stream | Device = None) → array
    
Construct an array with the given value.
Constructs an array of size `shape` filled with `vals`. If `vals` is an `array` it must be broadcastable to the given `shape`.
Parameters:
    
  * shape (int or list(int)) – The shape of the output array.
  * vals (float or int or array) – Values to fill the array with.
  * dtype (Dtype, optional) – Data type of the output array. If unspecified the output type is inferred from `vals`.


Returns:
    
The output array with the specified shape and values.
Return type:
    
array
# mlx.core.gather_mm
gather_mm(a: array, b: array, /, lhs_indices: array, rhs_indices: array, *, sorted_indices: bool = False, stream: None | Stream | Device = None) → array
    
Matrix multiplication with matrix-level gather.
Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays. This operation is more efficient than explicitly applying a `take()` followed by a `matmul()`.
The indices `lhs_indices` and `rhs_indices` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of `a` and `b` respectively.
For `a` with shape `(A1, A2, ..., AS, M, K)`, `lhs_indices` contains indices from the range `[0, A1 * A2 * ... * AS)`.
For `b` with shape `(B1, B2, ..., BS, K, N)`, `rhs_indices` contains indices from the range `[0, B1 * B2 * ... * BS)`.
If only one index is passed and it is sorted, the `sorted_indices` flag can be passed for a possibly faster implementation.
Parameters:
    
  * a (array) – Input array.
  * b (array) – Input array.
  * lhs_indices (array, optional) – Integer indices for `a`. Default: `None`
  * rhs_indices (array, optional) – Integer indices for `b`. Default: `None`
  * sorted_indices (bool, optional) – May allow a faster implementation if the passed indices are sorted. Default: `False`.


Returns:
    
The output array.
Return type:
    
array
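The role of the flat batch indices can be sketched with nested lists in plain Python. This is an illustration under simplifying assumptions, not MLX's implementation: the batch dimensions are already flattened into a single list of matrices, both index lists have the same length, and the names `matmul` and `gather_mm_reference` are hypothetical.

```python
def matmul(a, b):
    """Multiply two matrices stored as nested lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def gather_mm_reference(a_batch, b_batch, lhs_indices, rhs_indices):
    """Hypothetical reference: pick matrices by flat batch index, then multiply."""
    return [matmul(a_batch[i], b_batch[j]) for i, j in zip(lhs_indices, rhs_indices)]


a_batch = [[[1, 0], [0, 1]], [[2, 0], [0, 2]]]  # two 2x2 matrices
b_batch = [[[1, 2], [3, 4]]]                    # one 2x2 matrix
out = gather_mm_reference(a_batch, b_batch, lhs_indices=[1, 0, 1], rhs_indices=[0, 0, 0])
```

Each output matrix is the product of one gathered matrix from `a_batch` and one from `b_batch`, so the output batch size follows the index lists rather than the inputs.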
# mlx.core.gather_qmm
gather_qmm(x: array, w: array, /, scales: array, biases: array | None = None, lhs_indices: array | None = None, rhs_indices: array | None = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, sorted_indices: bool = False, stream: None | Stream | Device = None) → array
    
Perform quantized matrix multiplication with matrix-level gather.
This operation is the quantized equivalent to `gather_mm()`. Similar to `gather_mm()`, the indices `lhs_indices` and `rhs_indices` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of `x` and `w` respectively.
Note that `scales` and `biases` must have the same batch dimensions as `w` since they represent the same quantized matrix.
Parameters:
    
  * x (array) – Input array
  * w (array) – Quantized matrix packed in unsigned integers
  * scales (array) – The scales to use per `group_size` elements of `w`
  * biases (array, optional) – The biases to use per `group_size` elements of `w`. Default: `None`.
  * lhs_indices (array, optional) – Integer indices for `x`. Default: `None`.
  * rhs_indices (array, optional) – Integer indices for `w`. Default: `None`.
  * transpose (bool, optional) – Defines whether to multiply with the transposed `w` or not, namely whether we are performing `x @ w.T` or `x @ w`. Default: `True`.
  * group_size (int, optional) – The size of the group in `w` that shares a scale and bias. Default: `64`.
  * bits (int, optional) – The number of bits occupied by each element in `w`. Default: `4`.
  * mode (str, optional) – The quantization mode. Default: `"affine"`.
  * sorted_indices (bool, optional) – May allow a faster implementation if the passed indices are sorted. Default: `False`.


Returns:
    
The result of the multiplication of `x` with `w` after gathering using `lhs_indices` and `rhs_indices`.
Return type:
    
array
# mlx.core.get_active_memory
get_active_memory() → int
    
Get the actively used memory in bytes.
Note, this will not always match memory use reported by the system because it does not include cached memory buffers.
# mlx.core.get_cache_memory
get_cache_memory() → int
    
Get the cache size in bytes.
The cache includes memory not currently used that has not been returned to the system allocator.
# mlx.core.get_peak_memory
get_peak_memory() → int
    
Get the peak amount of used memory in bytes.
The maximum memory use recorded since the beginning of program execution or since the last call to `reset_peak_memory()`.
# mlx.core.grad
grad(fun: Callable, argnums: int | Sequence[int] | None = None, argnames: str | Sequence[str] = []) → Callable
    
Returns a function which computes the gradient of `fun`.
Parameters:
    
  * fun (Callable) – A function which takes a variable number of `array` or trees of `array` and returns a scalar output `array`.
  * argnums (int or list(int), optional) – Specify the index (or indices) of the positional arguments of `fun` to compute the gradient with respect to. If neither `argnums` nor `argnames` are provided `argnums` defaults to `0` indicating `fun`’s first argument.
  * argnames (str or list(str), optional) – Specify keyword arguments of `fun` to compute gradients with respect to. Defaults to `[]`, so no gradients are computed for keyword arguments by default.


Returns:
    
A function which has the same input arguments as `fun` and returns the gradient(s).
Return type:
    
Callable
# mlx.core.greater
greater(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise greater than.
Strict greater than on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The element-wise comparison `a > b`.
Return type:
    
array
# mlx.core.greater_equal
greater_equal(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise greater or equal.
Greater than or equal on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The element-wise comparison `a >= b`.
Return type:
    
array
# mlx.core.hadamard_transform
hadamard_transform(a: array, scale: float | None = None, stream: None | Stream | Device = None) → array
    
Perform the Walsh-Hadamard transform along the final axis.
Equivalent to:
    
    from scipy.linalg import hadamard
    
    y = (hadamard(len(x)) @ x) * scale
    
Supports sizes `n = m*2^k` for `m` in `(1, 12, 20, 28)` and `2^k <= 8192` for float32 and `2^k <= 16384` for float16/bfloat16.
Parameters:
    
  * a (array) – Input array or scalar.
  * scale (float) – Scale the output by this factor. Defaults to `1/sqrt(a.shape[-1])` so that the Hadamard matrix is orthonormal.


Returns:
    
The transformed array.
Return type:
    
array
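As a rough illustration of the transform described above, the following pure-Python sketch applies the butterfly form of the Walsh-Hadamard transform to a plain list. It is not MLX's implementation: the name `hadamard_transform_ref` is hypothetical, and the sketch assumes a power-of-two length only (the `m` in `(12, 20, 28)` cases are not covered).

```python
import math


def hadamard_transform_ref(x, scale=None):
    """Walsh-Hadamard transform of a list whose length is a power of two."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("this sketch only handles power-of-two lengths")
    if scale is None:
        # Orthonormal scaling, matching the documented default 1/sqrt(n).
        scale = 1.0 / math.sqrt(n)
    y = list(x)
    h = 1
    while h < n:
        # Butterfly step: combine entries that are h apart.
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                u, v = y[i], y[i + h]
                y[i], y[i + h] = u + v, u - v
        h *= 2
    return [scale * value for value in y]
```

With the default scale the transform is its own inverse, so applying the sketch twice should return the original values up to rounding.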
# mlx.core.identity
identity(n: int, dtype: Dtype | None = float32, *, stream: None | Stream | Device = None) → array
    
Create a square identity matrix.
Parameters:
    
  * n (int) – The number of rows and columns in the output.
  * dtype (Dtype, optional) – Data type of the output array. Defaults to float32.
  * stream (Stream, optional) – Stream or device. Defaults to None.


Returns:
    
An identity matrix of size n x n.
Return type:
    
array
# mlx.core.imag
imag(a: array, /, *, stream: None | Stream | Device = None) → array
    
Returns the imaginary part of a complex array.
Parameters:
    
a (array) – Input array.
Returns:
    
The imaginary part of `a`.
Return type:
    
array
# mlx.core.import_function
import_function(file: str) → Callable
    
Import a function from a file.
The imported function can be called either with `*args` and `**kwargs` or with a tuple of arrays and/or dictionary of string keys with array values. Imported functions always return a tuple of arrays.
Warning
This is part of an experimental API which is likely to change in future versions of MLX. Functions exported with older versions of MLX may not be compatible with future versions.
Parameters:
    
file (str) – The file path to import the function from.
Returns:
    
The imported function.
Return type:
    
Callable
Example
    
    >>> fn = mx.import_function("function.mlxfn")
    >>> out = fn(a, b, x=x, y=y)[0]
    >>>
    >>> out = fn((a, b), {"x": x, "y": y})[0]
    
# mlx.core.inner
inner(a: array, b: array, /, *, stream: None | Stream | Device = None) → array
    
Ordinary inner product of vectors for 1-D arrays; in higher dimensions, a sum product over the last axes.
Parameters:
    
  * a (array) – Input array
  * b (array) – Input array


Returns:
    
The inner product.
Return type:
    
array
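To make the "sum product over the last axes" wording concrete, here is a small pure-Python sketch on nested lists. The helper names are hypothetical, and the 2-D case assumes the NumPy `inner` convention, where the last axis of `a` is contracted with the last axis of `b`.

```python
def inner_1d(a, b):
    """Inner product of two equal-length 1-D sequences."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return sum(x * y for x, y in zip(a, b))


def inner_2d(a, b):
    """Contract the last axis of a with the last axis of b (assumed convention)."""
    return [[inner_1d(row_a, row_b) for row_b in b] for row_a in a]
```

Under that assumption, entry `[i][j]` of the 2-D result is the 1-D inner product of row `i` of `a` with row `j` of `b`.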
# mlx.core.isclose
isclose(a: array, b: array, /, rtol: float = 1e-05, atol: float = 1e-08, *, equal_nan: bool = False, stream: None | Stream | Device = None) → array
    
Returns a boolean array where two arrays are element-wise equal within a tolerance.
Infinite values are considered equal if they have the same sign, NaN values are not equal unless `equal_nan` is `True`.
Two values are considered equal if:
    
    abs(a - b) <= (atol + rtol * abs(b))
    
Note unlike `array_equal()`, this function supports numpy-style broadcasting.
Parameters:
    
  * a (array) – Input array.
  * b (array) – Input array.
  * rtol (float) – Relative tolerance.
  * atol (float) – Absolute tolerance.
  * equal_nan (bool) – If `True`, NaNs are considered equal. Defaults to `False`.


Returns:
    
The boolean array indicating which elements are close.
Return type:
    
array
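The tolerance rule above can be sketched for a single pair of Python floats as follows. The function name `isclose_ref` is hypothetical, and the sketch only mirrors the documented formula and the stated handling of infinities and NaNs; it does not perform broadcasting.

```python
import math


def isclose_ref(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Scalar version of the documented closeness test."""
    if math.isnan(a) or math.isnan(b):
        # NaNs only match each other, and only when equal_nan is set.
        return equal_nan and math.isnan(a) and math.isnan(b)
    if math.isinf(a) or math.isinf(b):
        # Infinities are equal only when they have the same sign.
        return a == b
    return abs(a - b) <= (atol + rtol * abs(b))
```

Because the relative term scales with `abs(b)` only, the test is not symmetric: swapping the two arguments can change the result when the values differ by an amount near the tolerance.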
# mlx.core.isfinite
isfinite(a: array, stream: None | Stream | Device = None) → array
    
Return a boolean array indicating which elements are finite.
An element is finite if it is not infinite or NaN.
Parameters:
    
a (array) – Input array.
Returns:
    
The boolean array indicating which elements are finite.
Return type:
    
array
# mlx.core.isinf
isinf(a: array, stream: None | Stream | Device = None) → array
    
Return a boolean array indicating which elements are +/- infinity.
Parameters:
    
a (array) – Input array.
Returns:
    
The boolean array indicating which elements are +/- infinity.
Return type:
    
array
# mlx.core.isnan
isnan(a: array, stream: None | Stream | Device = None) → array
    
Return a boolean array indicating which elements are NaN.
Parameters:
    
a (array) – Input array.
Returns:
    
The boolean array indicating which elements are NaN.
Return type:
    
array
# mlx.core.isneginf
isneginf(a: array, stream: None | Stream | Device = None) → array
    
Return a boolean array indicating which elements are negative infinity.
Parameters:
    
  * a (array) – Input array.
  * stream (Union[None, Stream, Device]) – Optional stream or device.


Returns:
    
The boolean array indicating which elements are negative infinity.
Return type:
    
array
# mlx.core.isposinf
isposinf(a: array, stream: None | Stream | Device = None) → array
    
Return a boolean array indicating which elements are positive infinity.
Parameters:
    
  * a (array) – Input array.
  * stream (Union[None, Stream, Device]) – Optional stream or device.


Returns:
    
The boolean array indicating which elements are positive infinity.
Return type:
    
array
# mlx.core.issubdtype
issubdtype(arg1: Dtype | DtypeCategory, arg2: Dtype | DtypeCategory) → bool
    
Check if a `Dtype` or `DtypeCategory` is a subtype of another.
Parameters:
    
  * arg1 (Union[Dtype, DtypeCategory]) – First dtype or category.
  * arg2 (Union[Dtype, DtypeCategory]) – Second dtype or category.


Returns:
    
A boolean indicating if the first input is a subtype of the second input.
Return type:
    
bool
Example
    
    >>> ints = mx.array([1, 2, 3], dtype=mx.int32)
    >>> mx.issubdtype(ints.dtype, mx.integer)
    True
    >>> mx.issubdtype(ints.dtype, mx.floating)
    False
    
    
    >>> floats = mx.array([1, 2, 3], dtype=mx.float32)
    >>> mx.issubdtype(floats.dtype, mx.integer)
    False
    >>> mx.issubdtype(floats.dtype, mx.floating)
    True
    
Similar types of different sizes are not subdtypes of each other:
    
    >>> mx.issubdtype(mx.float64, mx.float32)
    False
    >>> mx.issubdtype(mx.float32, mx.float64)
    False
    
but both are subtypes of floating:
    
    >>> mx.issubdtype(mx.float64, mx.floating)
    True
    >>> mx.issubdtype(mx.float32, mx.floating)
    True
    
For convenience, dtype-like objects are allowed too:
    
    >>> mx.issubdtype(mx.float32, mx.inexact)
    True
    >>> mx.issubdtype(mx.signedinteger, mx.floating)
    False
    
# mlx.core.jvp
jvp(fun: Callable, primals: list[array], tangents: list[array]) → tuple[list[array], list[array]]
    
Compute the Jacobian-vector product.
This computes the product of the Jacobian of a function `fun` evaluated at `primals` with the `tangents`.
Parameters:
    
  * fun (Callable) – A function which takes a variable number of `array` and returns a single `array` or list of `array`.
  * primals (list(array)) – A list of `array` at which to evaluate the Jacobian.
  * tangents (list(array)) – A list of `array` which are the “vector” in the Jacobian-vector product. The `tangents` should be the same in number, shape, and type as the inputs of `fun` (i.e. the `primals`).


Returns:
    
A tuple of two lists: the outputs of `fun` evaluated at `primals`, and the Jacobian-vector products, which are the same in number, shape, and type as the outputs of `fun`.
Return type:
    
tuple(list(array), list(array))
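For intuition about what a Jacobian-vector product measures, the sketch below approximates it with central finite differences on plain Python floats. This is a swapped-in numerical technique rather than the automatic differentiation that `jvp` performs, and the names `jvp_finite_difference` and `example_fun` are hypothetical.

```python
def jvp_finite_difference(fun, primals, tangents, eps=1e-6):
    """Approximate J(primals) @ tangents for a function of scalar arguments.

    `fun` takes scalars and returns a list of scalars.
    """
    plus = [p + eps * t for p, t in zip(primals, tangents)]
    minus = [p - eps * t for p, t in zip(primals, tangents)]
    out_plus = fun(*plus)
    out_minus = fun(*minus)
    # Central difference along the tangent direction, one value per output.
    return [(hi - lo) / (2 * eps) for hi, lo in zip(out_plus, out_minus)]


def example_fun(x, y):
    """Two outputs, so the Jacobian has two rows."""
    return [x * y, x + y]
```

The result has one entry per output of the function, which is the directional derivative of that output along the tangent.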
# mlx.core.kron
kron(a: array, b: array, *, stream: None | Stream | Device = None) → array
    
Compute the Kronecker product of two arrays `a` and `b`.
Parameters:
    
  * a (array) – The first input array.
  * b (array) – The second input array.
  * stream (Union[None, Stream, Device], optional) – Optional stream or device for execution. Default: `None`.


Returns:
    
The Kronecker product of `a` and `b`.
Return type:
    
array
Examples
    
    >>> a = mx.array([[1, 2], [3, 4]])
    >>> b = mx.array([[0, 5], [6, 7]])
    >>> result = mx.kron(a, b)
    >>> print(result)
    array([[0, 5, 0, 10],
           [6, 7, 12, 14],
           [0, 15, 0, 20],
           [18, 21, 24, 28]], dtype=int32)
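The block structure of the result above can be reproduced with a short pure-Python sketch on nested lists. The name `kron_ref` is hypothetical and the sketch handles 2-D inputs only.

```python
def kron_ref(a, b):
    """Kronecker product of two 2-D nested lists."""
    rows_b, cols_b = len(b), len(b[0])
    return [
        # Each output row pairs one row of a with one row of b.
        [a_val * b[i][j] for a_val in a_row for j in range(cols_b)]
        for a_row in a
        for i in range(rows_b)
    ]
```

Each entry of `a` scales a full copy of `b`, so a 2 x 2 input pair yields a 4 x 4 result.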
    
# mlx.core.left_shift
left_shift(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise left shift.
Shift the bits of the first input to the left by the second using numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The bitwise left shift `a << b`.
Return type:
    
array
# mlx.core.less
less(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise less than.
Strict less than on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The element-wise comparison `a < b`.
Return type:
    
array
# mlx.core.less_equal
less_equal(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise less than or equal.
Less than or equal on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The element-wise comparison `a <= b`.
Return type:
    
array
# mlx.core.linalg.cholesky
cholesky(a: array, upper: bool = False, *, stream: None | Stream | Device = None) → array
    
Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix.
This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the Cholesky decomposition is computed for each matrix in the last two dimensions of `a`.
If the input matrix is not symmetric positive semi-definite, behaviour is undefined.
Parameters:
    
  * a (array) – Input array.
  * upper (bool, optional) – If `True`, return the upper triangular Cholesky factor. If `False`, return the lower triangular Cholesky factor. Default: `False`.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
If `upper = False`, it returns a lower triangular `L` matrix such that `L @ L.T = a`. If `upper = True`, it returns an upper triangular `U` matrix such that `U.T @ U = a`.
Return type:
    
array
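The relation `L @ L.T = a` can be checked with a minimal pure-Python sketch of the textbook Cholesky-Banachiewicz recurrence. This is not how MLX computes the factor; the name `cholesky_lower` is hypothetical, and the sketch assumes a strictly positive-definite input so that no division by zero occurs.

```python
import math


def cholesky_lower(a):
    """Lower-triangular factor L of a symmetric positive-definite matrix."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            # Contribution of the columns already computed.
            partial = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - partial)
            else:
                L[i][j] = (a[i][j] - partial) / L[j][j]
    return L
```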
# mlx.core.linalg.cholesky_inv
cholesky_inv(L: array, upper: bool = False, *, stream: None | Stream | Device = None) → array
    
Compute the inverse of a real symmetric positive semi-definite matrix using its Cholesky decomposition.
Let \\(\mathbf{A}\\) be a real symmetric positive semi-definite matrix and \\(\mathbf{L}\\) its Cholesky decomposition such that:
\\[\begin{aligned} \mathbf{A} = \mathbf{L}\mathbf{L}^T \end{aligned}\\]
This function computes \\(\mathbf{A}^{-1}\\).
This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the Cholesky inverse is computed for each matrix in the last two dimensions of \\(\mathbf{L}\\).
If the input matrix is not a triangular matrix, behaviour is undefined.
Parameters:
    
  * L (array) – Input array.
  * upper (bool, optional) – If `True`, return the upper triangular Cholesky factor. If `False`, return the lower triangular Cholesky factor. Default: `False`.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
\\(\mathbf{A^{-1}}\\) where \\(\mathbf{A} = \mathbf{L}\mathbf{L}^T\\).
Return type:
    
array
# mlx.core.linalg.cross
cross(a: array, b: array, axis: int = -1, *, stream: None | Stream | Device = None) → array
    
Compute the cross product of two arrays along a specified axis.
The cross product is defined for arrays with size 2 or 3 in the specified axis. If the size is 2 then the third value is assumed to be zero.
Parameters:
    
  * a (array) – Input array.
  * b (array) – Input array.
  * axis (int, optional) – Axis along which to compute the cross product. Default: `-1`.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
The cross product of `a` and `b` along the specified axis.
Return type:
    
array
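The zero-padding rule for size-2 inputs can be illustrated with a pure-Python sketch for a single pair of vectors. The name `cross_ref` is hypothetical. The sketch always returns three components, which is an assumption made for illustration; the text above does not state the output shape MLX uses for size-2 inputs.

```python
def cross_ref(a, b):
    """Cross product of two vectors of size 2 or 3."""

    def pad(v):
        # A size-2 vector is treated as having a zero third component.
        if len(v) == 2:
            return [v[0], v[1], 0]
        if len(v) == 3:
            return list(v)
        raise ValueError("cross product needs vectors of size 2 or 3")

    ax, ay, az = pad(a)
    bx, by, bz = pad(b)
    return [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
```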
# mlx.core.linalg.eig
eig(a: array, *, stream: None | Stream | Device = None) → Tuple[array, array]
    
Compute the eigenvalues and eigenvectors of a square matrix.
This function differs from `numpy.linalg.eig()` in that the return type is always complex even if the eigenvalues are all real.
This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the eigenvalues and eigenvectors are computed for each matrix in the last two dimensions.
Parameters:
    
  * a (array) – The input array.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
A tuple containing the eigenvalues and the normalized right eigenvectors. The column `v[:, i]` is the eigenvector corresponding to the i-th eigenvalue.
Return type:
    
Tuple[array, array]
Example
    
    >>> A = mx.array([[1., -2.], [-2., 1.]])
    >>> w, v = mx.linalg.eig(A, stream=mx.cpu)
    >>> w
    array([3+0j, -1+0j], dtype=complex64)
    >>> v
    array([[0.707107+0j, 0.707107+0j],
           [-0.707107+0j, 0.707107+0j]], dtype=complex64)
    
# mlx.core.linalg.eigh
eigh(a: array, UPLO: str = 'L', *, stream: None | Stream | Device = None) → Tuple[array, array]
    
Compute the eigenvalues and eigenvectors of a complex Hermitian or real symmetric matrix.
This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the eigenvalues and eigenvectors are computed for each matrix in the last two dimensions.
Parameters:
    
  * a (array) – Input array. Must be a real symmetric or complex Hermitian matrix.
  * UPLO (str, optional) – Whether to use the upper (`"U"`) or lower (`"L"`) triangle of the matrix. Default: `"L"`.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
A tuple containing the eigenvalues in ascending order and the normalized eigenvectors. The column `v[:, i]` is the eigenvector corresponding to the i-th eigenvalue.
Return type:
    
Tuple[array, array]
Note
The input matrix is assumed to be symmetric (or Hermitian). Only the selected triangle is used. No checks for symmetry are performed.
Example
    
    >>> A = mx.array([[1., -2.], [-2., 1.]])
    >>> w, v = mx.linalg.eigh(A, stream=mx.cpu)
    >>> w
    array([-1., 3.], dtype=float32)
    >>> v
    array([[ 0.707107, -0.707107],
           [ 0.707107,  0.707107]], dtype=float32)
    
# mlx.core.linalg.eigvals
eigvals(a: array, *, stream: Optional[Union[Stream, Device]] = None) → array
    
Compute the eigenvalues of a square matrix.
This function differs from `numpy.linalg.eigvals()` in that the return type is always complex even if the eigenvalues are all real.
This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the eigenvalues are computed for each matrix in the last two dimensions.
Parameters:
    
  * a (array) – The input array.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
The eigenvalues (not necessarily in order).
Return type:
    
array
Example
    
    >>> A = mx.array([[1., -2.], [-2., 1.]])
    >>> eigenvalues = mx.linalg.eigvals(A, stream=mx.cpu)
    >>> eigenvalues
    array([3+0j, -1+0j], dtype=complex64)
    
# mlx.core.linalg.eigvalsh
eigvalsh(a: array, UPLO: str = 'L', *, stream: Optional[Union[Stream, Device]] = None) → array
    
Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the eigenvalues are computed for each matrix in the last two dimensions.
Parameters:
    
  * a (array) – Input array. Must be a real symmetric or complex Hermitian matrix.
  * UPLO (str, optional) – Whether to use the upper (`"U"`) or lower (`"L"`) triangle of the matrix. Default: `"L"`.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
The eigenvalues in ascending order.
Return type:
    
array
Note
The input matrix is assumed to be symmetric (or Hermitian). Only the selected triangle is used. No checks for symmetry are performed.
Example
    
    >>> A = mx.array([[1., -2.], [-2., 1.]])
    >>> eigenvalues = mx.linalg.eigvalsh(A, stream=mx.cpu)
    >>> eigenvalues
    array([-1., 3.], dtype=float32)
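For a real symmetric 2 x 2 matrix the eigenvalues have a closed form, which gives a quick way to check the example above and to see what "only the selected triangle is used" means. The sketch is pure Python, the name `eigvalsh_2x2` is hypothetical, and it covers only the real 2 x 2 case.

```python
import math


def eigvalsh_2x2(a, UPLO="L"):
    """Eigenvalues, in ascending order, of a real symmetric 2 x 2 matrix."""
    p, r = a[0][0], a[1][1]
    # Read the off-diagonal entry from the selected triangle only.
    q = a[1][0] if UPLO == "L" else a[0][1]
    mean = (p + r) / 2
    radius = math.hypot((p - r) / 2, q)
    return [mean - radius, mean + radius]
```

Because only one triangle is read, an entry in the other triangle that breaks symmetry is silently ignored rather than reported.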
    
# mlx.core.linalg.inv
inv(a: array, *, stream: None | Stream | Device = None) → array
    
Compute the inverse of a square matrix.
This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the inverse is computed for each matrix in the last two dimensions of `a`.
Parameters:
    
  * a (array) – Input array.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
`ainv` such that `dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])`
Return type:
    
array
# mlx.core.linalg.lu
lu(a: array, *, stream: None | Stream | Device = None) → Tuple[array, array, array]
    
Compute the LU factorization of the given matrix `A`.
Note, unlike the default behavior of `scipy.linalg.lu`, the pivots are indices. To reconstruct the input use `L[p, :] @ U` for 2 dimensions or `mx.take_along_axis(L, p[..., None], axis=-2) @ U` for more than 2 dimensions.
To construct the full permutation matrix do:
    
    P = mx.put_along_axis(mx.zeros_like(L), p[..., None], mx.array(1.0), axis=-1)
    
Parameters:
    
  * a (array) – Input array.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
The `p`, `L`, and `U` arrays, such that `A = L[p, :] @ U`
Return type:
    
tuple(array, array, array)
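The link between pivot indices and the full permutation matrix can be sketched in pure Python. The helper names are hypothetical and the values are made up for illustration: row `i` of the permutation matrix has a single one in column `p[i]`, so multiplying by it selects the same rows as indexing with the pivots.

```python
def permutation_matrix(p):
    """Build P with P[i][p[i]] = 1, mirroring the put_along_axis recipe."""
    n = len(p)
    P = [[0.0] * n for _ in range(n)]
    for row, col in enumerate(p):
        P[row][col] = 1.0
    return P


def matmul_ref(a, b):
    """Plain matrix product of two 2-D nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def take_rows(m, p):
    """Equivalent of m[p, :] for a nested list."""
    return [list(m[i]) for i in p]
```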
# mlx.core.linalg.lu_factor
lu_factor(a: array, *, stream: None | Stream | Device = None) → Tuple[array, array]
    
Computes a compact representation of the LU factorization.
Parameters:
    
  * a (array) – Input array.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
The `LU` matrix and `pivots` array.
Return type:
    
tuple(array, array)
# mlx.core.linalg.norm
norm(a: array, /, ord: None | int | float | str = None, axis: None | int | list[int] = None, keepdims: bool = False, *, stream: None | Stream | Device = None) → array
    
Matrix or vector norm.
This function computes vector or matrix norms depending on the value of the `ord` and `axis` parameters.
Parameters:
    
  * a (array) – Input array. If `axis` is `None`, `a` must be 1-D or 2-D, unless `ord` is `None`. If both `axis` and `ord` are `None`, the 2-norm of `a.flatten` will be returned.
  * ord (int, float or str, optional) – Order of the norm (see table under `Notes`). If `None`, the 2-norm (or Frobenius norm for matrices) will be computed along the given `axis`. Default: `None`.
  * axis (int or list(int), optional) – If `axis` is an integer, it specifies the axis of `a` along which to compute the vector norms. If `axis` is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If axis is `None` then either a vector norm (when `a` is 1-D) or a matrix norm (when `a` is 2-D) is returned. Default: `None`.
  * keepdims (bool, optional) – If `True`, the axes which are normed over are left in the result as dimensions with size one. Default: `False`.


Returns:
    
The output containing the norm(s).
Return type:
    
array
Notes
For values of `ord < 1`, the result is, strictly speaking, not a mathematical norm, but it may still be useful for various numerical purposes.
The following norms can be calculated:

| ord | norm for matrices | norm for vectors |
|---|---|---|
| None | Frobenius norm | 2-norm |
| ‘fro’ | Frobenius norm | – |
| ‘nuc’ | nuclear norm | – |
| inf | `max(sum(abs(x), axis=1))` | `max(abs(x))` |
| -inf | `min(sum(abs(x), axis=1))` | `min(abs(x))` |
| 0 | – | `sum(x != 0)` |
| 1 | `max(sum(abs(x), axis=0))` | as below |
| -1 | `min(sum(abs(x), axis=0))` | as below |
| 2 | 2-norm (largest sing. value) | as below |
| -2 | smallest singular value | as below |
| other | – | `sum(abs(x)**ord)**(1./ord)` |

The Frobenius norm is given by [1]:
> \\(||A||_F = [\sum_{i,j} abs(a_{i,j})^2]^{1/2}\\)
The nuclear norm is the sum of the singular values.
Both the Frobenius and nuclear norm orders are only defined for matrices and raise a `ValueError` when `a.ndim != 2`.
References
[1] G. H. Golub and C. F. Van Loan, Matrix Computations, Baltimore, MD, Johns Hopkins University Press, 1985, pg. 15
Examples
    
    >>> import mlx.core as mx
    >>> from mlx.core import linalg as la
    >>> a = mx.arange(9) - 4
    >>> a
    array([-4, -3, -2, ..., 2, 3, 4], dtype=int32)
    >>> b = a.reshape((3,3))
    >>> b
    array([[-4, -3, -2],
           [-1,  0,  1],
           [ 2,  3,  4]], dtype=int32)
    >>> la.norm(a)
    array(7.74597, dtype=float32)
    >>> la.norm(b)
    array(7.74597, dtype=float32)
    >>> la.norm(b, 'fro')
    array(7.74597, dtype=float32)
    >>> la.norm(a, float("inf"))
    array(4, dtype=float32)
    >>> la.norm(b, float("inf"))
    array(9, dtype=float32)
    >>> la.norm(a, -float("inf"))
    array(0, dtype=float32)
    >>> la.norm(b, -float("inf"))
    array(2, dtype=float32)
    >>> la.norm(a, 1)
    array(20, dtype=float32)
    >>> la.norm(b, 1)
    array(7, dtype=float32)
    >>> la.norm(a, -1)
    array(0, dtype=float32)
    >>> la.norm(b, -1)
    array(6, dtype=float32)
    >>> la.norm(a, 2)
    array(7.74597, dtype=float32)
    >>> la.norm(a, 3)
    array(5.84804, dtype=float32)
    >>> la.norm(a, -3)
    array(0, dtype=float32)
    >>> c = mx.array([[ 1, 2, 3],
    ...               [-1, 1, 4]])
    >>> la.norm(c, axis=0)
    array([1.41421, 2.23607, 5], dtype=float32)
    >>> la.norm(c, axis=1)
    array([3.74166, 4.24264], dtype=float32)
    >>> la.norm(c, ord=1, axis=1)
    array([6, 6], dtype=float32)
    >>> m = mx.arange(8).reshape(2,2,2)
    >>> la.norm(m, axis=(1,2))
    array([3.74166, 11.225], dtype=float32)
    >>> la.norm(m[0, :, :]), la.norm(m[1, :, :])
    (array(3.74166, dtype=float32), array(11.225, dtype=float32))
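The vector column of the table in the Notes can be sketched in pure Python. The name `vector_norm` is hypothetical, and the parameter is called `order` to avoid shadowing Python's built-in `ord`. The special case for negative orders with zero entries is an assumption chosen to match the `0` results shown in the examples above.

```python
import math


def vector_norm(x, order=None):
    """Vector norm of a 1-D sequence for the orders listed in the table."""
    mags = [abs(v) for v in x]
    if order is None or order == 2:
        return math.sqrt(sum(m * m for m in mags))
    if order == math.inf:
        return max(mags)
    if order == -math.inf:
        return min(mags)
    if order == 0:
        return float(sum(1 for m in mags if m != 0))
    if order < 0 and any(m == 0 for m in mags):
        # abs(x)**order is infinite for a zero entry, so the result collapses to 0.
        return 0.0
    return sum(m ** order for m in mags) ** (1.0 / order)
```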
    
# mlx.core.linalg.pinv
pinv(a: array, *, stream: None | Stream | Device = None) → array
    
Compute the (Moore-Penrose) pseudo-inverse of a matrix.
This function calculates a generalized inverse of a matrix using its singular-value decomposition. This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the inverse is computed for each matrix in the last two dimensions of `a`.
Parameters:
    
  * a (array) – Input array.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
`aplus` such that `a @ aplus @ a = a`
Return type:
    
array
# mlx.core.linalg.qr
qr(a: array, *, stream: None | Stream | Device = None) → Tuple[array, array]
    
The QR factorization of the input matrix.
This function supports arrays with at least 2 dimensions. The matrices which are factorized are assumed to be in the last two dimensions of the input.
Parameters:
    
  * a (array) – Input array.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
`Q` and `R` matrices such that `Q @ R = a`.
Return type:
    
tuple(array, array)
Example
    
    >>> A = mx.array([[2., 3.], [1., 2.]])
    >>> Q, R = mx.linalg.qr(A, stream=mx.cpu)
    >>> Q
    array([[-0.894427, -0.447214],
           [-0.447214, 0.894427]], dtype=float32)
    >>> R
    array([[-2.23607, -3.57771],
           [0, 0.447214]], dtype=float32)
    
# mlx.core.linalg.solve
solve(a: array, b: array, *, stream: None | Stream | Device = None) → array
    
Compute the solution to a system of linear equations `AX = B`.
Parameters:
    
  * a (array) – Input array.
  * b (array) – Input array.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
The unique solution to the system `AX = B`.
Return type:
    
array
# mlx.core.linalg.solve_triangular
solve_triangular(a: array, b: array, *, upper: bool = False, stream: None | Stream | Device = None) → array
    
Computes the solution of a triangular system of linear equations `AX = B`.
Parameters:
    
  * a (array) – Input array.
  * b (array) – Input array.
  * upper (bool, optional) – Whether the array is upper or lower triangular. Default: `False`.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
The unique solution to the system `AX = B`.
Return type:
    
array
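For the lower-triangular case (`upper=False`), the system can be solved by forward substitution. The pure-Python sketch below shows the idea for a single right-hand-side vector; the name `solve_lower_triangular` is hypothetical, and the sketch assumes a non-zero diagonal.

```python
def solve_lower_triangular(a, b):
    """Solve a @ x = b where a is lower triangular and b is a vector."""
    n = len(a)
    x = [0.0] * n
    for i in range(n):
        # Everything left of the diagonal is already known at step i.
        known = sum(a[i][j] * x[j] for j in range(i))
        x[i] = (b[i] - known) / a[i][i]
    return x
```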
# mlx.core.linalg.svd
svd(a: array, compute_uv: bool = True, *, stream: None | Stream | Device = None) → Tuple[array, array, array]
    
The Singular Value Decomposition (SVD) of the input matrix.
This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the function iterates over all indices of the first `a.ndim - 2` dimensions and for each combination SVD is applied to the last two indices.
Parameters:
    
  * a (array) – Input array.
  * compute_uv (bool, optional) – If `True`, return the `U`, `S`, and `Vt` components. If `False`, return only the `S` array. Default: `True`.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
If compute_uv is `True` returns the `U`, `S`, and `Vt` matrices, such that `A = U @ diag(S) @ Vt`. If compute_uv is `False` returns singular values array `S`.
Return type:
    
Union[tuple(array, …), array]
# mlx.core.linalg.tri_inv
tri_inv(a: array, upper: bool = False, *, stream: None | Stream | Device = None) → array
    
Compute the inverse of a triangular square matrix.
This function supports arrays with at least 2 dimensions. When the input has more than two dimensions, the inverse is computed for each matrix in the last two dimensions of `a`.
Parameters:
    
  * a (array) – Input array.
  * upper (bool, optional) – Whether the array is upper or lower triangular. Defaults to `False`.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
`ainv` such that `dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])`
Return type:
    
array
# mlx.core.linspace
linspace(start, stop, num: int | None = 50, dtype: Dtype | None = float32, stream: None | Stream | Device = None) → array
    
Generate `num` evenly spaced numbers over interval `[start, stop]`.
Parameters:
    
  * start (scalar) – Starting value.
  * stop (scalar) – Stopping value.
  * num (int, optional) – Number of samples, defaults to `50`.
  * dtype (Dtype, optional) – Specifies the data type of the output, default to `float32`.


Returns:
    
The range of values.
Return type:
    
array
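The closed interval `[start, stop]` means both endpoints are included, so the spacing is `(stop - start) / (num - 1)`. A pure-Python sketch, with the hypothetical name `linspace_ref`, follows. Returning `[start]` for `num == 1` is an assumption borrowed from NumPy's behaviour.

```python
def linspace_ref(start, stop, num=50):
    """Evenly spaced values over the closed interval [start, stop]."""
    if num < 0:
        raise ValueError("num must be non-negative")
    if num == 0:
        return []
    if num == 1:
        return [float(start)]  # assumed convention, as in NumPy
    step = (stop - start) / (num - 1)
    values = [start + i * step for i in range(num)]
    values[-1] = float(stop)  # pin the endpoint against rounding drift
    return values
```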
# mlx.core.load
load(file: file | str | Path, /, format: str | None = None, return_metadata: bool = False, *, stream: None | Stream | Device = None) → array | dict[str, array]
    
Load array(s) from a binary file.
The supported formats are `.npy`, `.npz`, `.safetensors`, and `.gguf`.
Parameters:
    
  * file (file, str, Path) – File in which the array is saved.
  * format (str, optional) – Format of the file. If `None`, the format is inferred from the file extension. Supported formats: `npy`, `npz`, and `safetensors`. Default: `None`.
  * return_metadata (bool, optional) – Load the metadata for formats which support metadata. The metadata will be returned as an additional dictionary. Default: `False`.


Returns:
    
A single array if loading from a `.npy` file or a dict mapping names to arrays if loading from a `.npz` or `.safetensors` file. If `return_metadata` is `True` an additional dictionary of metadata will be returned.
Return type:
    
array or dict
Warning
When loading unsupported quantization formats from GGUF, tensors will automatically be cast to `mx.float16`.
# mlx.core.log
log(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise natural logarithm.
Parameters:
    
a (array) – Input array.
Returns:
    
The natural logarithm of `a`.
Return type:
    
array
# mlx.core.log10
log10(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise base-10 logarithm.
Parameters:
    
a (array) – Input array.
Returns:
    
The base-10 logarithm of `a`.
Return type:
    
array
# mlx.core.log1p
log1p(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise natural log of one plus the array.
Parameters:
    
a (array) – Input array.
Returns:
    
The natural logarithm of one plus `a`.
Return type:
    
array
# mlx.core.log2
log2(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise base-2 logarithm.
Parameters:
    
a (array) – Input array.
Returns:
    
The base-2 logarithm of `a`.
Return type:
    
array
# mlx.core.logaddexp
logaddexp(a: scalar | array, b: scalar | array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise log-add-exp.
This is a numerically stable log-add-exp of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
The computation is a numerically stable version of `log(exp(a) + exp(b))`.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The log-add-exp of `a` and `b`.
Return type:
    
array
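One common way to make `log(exp(a) + exp(b))` numerically stable is to factor out the larger argument, so the exponential never overflows. The scalar sketch below uses that identity. It is an illustration with the hypothetical name `logaddexp_ref`, and it is not necessarily the exact formulation MLX uses.

```python
import math


def logaddexp_ref(a, b):
    """Stable log(exp(a) + exp(b)) for two Python floats."""
    if a == b and math.isinf(a):
        return a  # avoids computing inf - inf, which would give nan
    hi, lo = max(a, b), min(a, b)
    # log(exp(hi) + exp(lo)) = hi + log(1 + exp(lo - hi)), with lo - hi <= 0.
    return hi + math.log1p(math.exp(lo - hi))
```

With inputs around `1000.0`, the naive expression fails in Python because `math.exp(1000.0)` overflows, while the factored form stays finite.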
# mlx.core.logcumsumexp
logcumsumexp(a: array, /, axis: int | None = None, *, reverse: bool = False, inclusive: bool = True, stream: None | Stream | Device = None) → array
    
Return the cumulative logsumexp of the elements along the given axis.
Parameters:
    
  * a (array) – Input array
  * axis (int, optional) – Optional axis to compute the cumulative logsumexp over. If unspecified the cumulative logsumexp of the flattened array is returned.
  * reverse (bool) – Perform the cumulative logsumexp in reverse.
  * inclusive (bool) – The i-th element of the output includes the i-th element of the input.


Returns:
    
The output array.
Return type:
    
array
# mlx.core.logical_and
logical_and(a: array, b: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise logical and.
Parameters:
    
  * a (array) – First input array or scalar.
  * b (array) – Second input array or scalar.


Returns:
    
The boolean array containing the logical and of `a` and `b`.
Return type:
    
array
# mlx.core.logical_not
logical_not(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise logical not.
Parameters:
    
a (array) – Input array or scalar.
Returns:
    
The boolean array containing the logical not of `a`.
Return type:
    
array
# mlx.core.logical_or
logical_or(a: array, b: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise logical or.
Parameters:
    
  * a (array) – First input array or scalar.
  * b (array) – Second input array or scalar.


Returns:
    
The boolean array containing the logical or of `a` and `b`.
Return type:
    
array
# mlx.core.logsumexp
logsumexp(a: array, /, axis: None | int | Sequence[int] = None, keepdims: bool = False, *, stream: None | Stream | Device = None) → array
    
A log-sum-exp reduction over the given axes.
The log-sum-exp reduction is a numerically stable version of:
    
    log(sum(exp(a), axis))
    
Parameters:
    
  * a (array) – Input array.
  * axis (int or list(int), optional) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array.
  * keepdims (bool, optional) – Keep reduced axes as singleton dimensions, defaults to False.


Returns:
    
The output array with the corresponding axes reduced.
Return type:
    
array
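The stabilization can be sketched in plain Python for a flat, non-empty list of floats. This is an illustration of the usual max-subtraction trick, not MLX's implementation; `logsumexp_stable` is a hypothetical helper name.

```python
import math
from typing import Sequence


def logsumexp_stable(values: Sequence[float]) -> float:
    """Sketch of log(sum(exp(v))) over a non-empty sequence of floats."""
    m = max(values)
    if math.isinf(m):
        # +inf dominates the sum; -inf means every term is zero.
        return m
    # Every exponent is <= 0 after subtracting the maximum, so exp() cannot overflow.
    return m + math.log(sum(math.exp(v - m) for v in values))


total = logsumexp_stable([1000.0, 1000.0, 1000.0])
```

Subtracting the maximum before exponentiating leaves the result unchanged mathematically, because the maximum is added back outside the logarithm.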
# mlx.core.matmul
matmul(a: array, b: array, /, *, stream: None | Stream | Device = None) → array
    
Matrix multiplication.
Perform the (possibly batched) matrix multiplication of two arrays. This function supports broadcasting for arrays with more than two dimensions.
  * If the first array is 1-D then a 1 is prepended to its shape to make it a matrix. Similarly if the second array is 1-D then a 1 is appended to its shape to make it a matrix. In either case the singleton dimension is removed from the result.
  * A batched matrix multiplication is performed if the arrays have more than 2 dimensions. The matrix dimensions for the matrix product are the last two dimensions of each input.
  * All but the last two dimensions of each input are broadcast with one another using standard numpy-style broadcasting semantics.


Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The matrix product of `a` and `b`.
Return type:
    
array
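The shape rules above can be sketched as a small plain-Python function that computes only the result shape. It is an illustration under the stated rules, not MLX code; `matmul_shape` is a hypothetical helper and inputs are assumed to have at least one dimension.

```python
from typing import Sequence, Tuple


def matmul_shape(a_shape: Sequence[int], b_shape: Sequence[int]) -> Tuple[int, ...]:
    """Result shape of a (possibly batched) matrix product, per the rules above."""
    a, b = list(a_shape), list(b_shape)
    if not a or not b:
        raise ValueError("inputs must have at least one dimension")
    a_is_vector, b_is_vector = len(a) == 1, len(b) == 1
    if a_is_vector:
        a = [1] + a  # prepend a 1 so the vector becomes a row matrix
    if b_is_vector:
        b = b + [1]  # append a 1 so the vector becomes a column matrix
    if a[-1] != b[-2]:
        raise ValueError("inner dimensions do not match")

    # Broadcast all but the last two dimensions, aligning from the right.
    batch_a, batch_b = a[:-2], b[:-2]
    n = max(len(batch_a), len(batch_b))
    batch_a = [1] * (n - len(batch_a)) + batch_a
    batch_b = [1] * (n - len(batch_b)) + batch_b
    batch = []
    for x, y in zip(batch_a, batch_b):
        if x != y and x != 1 and y != 1:
            raise ValueError("batch dimensions are not broadcastable")
        batch.append(y if x == 1 else x)

    # Drop the singleton dimensions that were added for 1-D inputs.
    rows = [] if a_is_vector else [a[-2]]
    cols = [] if b_is_vector else [b[-1]]
    return tuple(batch + rows + cols)
```

For example, shapes `(5, 1, 2, 3)` and `(7, 3, 4)` broadcast their batch dimensions to `(5, 7)` and multiply `(2, 3)` by `(3, 4)` matrices.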
# mlx.core.max
max(a: array, /, axis: None | int | Sequence[int] = None, keepdims: bool = False, *, stream: None | Stream | Device = None) → array
    
A max reduction over the given axes.
Parameters:
    
  * a (array) – Input array.
  * axis (int or list(int), optional) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array.
  * keepdims (bool, optional) – Keep reduced axes as singleton dimensions, defaults to False.


Returns:
    
The output array with the corresponding axes reduced.
Return type:
    
array
# mlx.core.maximum
maximum(a: scalar | array, b: scalar | array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise maximum.
Take the element-wise max of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The max of `a` and `b`.
Return type:
    
array
# mlx.core.mean
mean(a: array, /, axis: None | int | Sequence[int] = None, keepdims: bool = False, *, stream: None | Stream | Device = None) → array
    
Compute the mean(s) over the given axes.
Parameters:
    
  * a (array) – Input array.
  * axis (int or list(int), optional) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array.
  * keepdims (bool, optional) – Keep reduced axes as singleton dimensions, defaults to False.


Returns:
    
The output array of means.
Return type:
    
array
# mlx.core.meshgrid
meshgrid(*arrays: array, sparse: bool | None = False, indexing: str | None = 'xy', stream: None | Stream | Device = None) → array
    
Generate multidimensional coordinate grids from 1-D coordinate arrays.
Parameters:
    
  * *arrays (array) – Input arrays.
  * sparse (bool, optional) – If `True`, a sparse grid is returned in which each output array has a single non-zero element. If `False`, a dense grid is returned. Defaults to `False`.
  * indexing (str, optional) – Cartesian (‘xy’) or matrix (‘ij’) indexing of the output arrays. Defaults to `'xy'`.


Returns:
    
The output arrays.
Return type:
    
list(array)
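The difference between the two `indexing` modes can be sketched for two 1-D inputs with nested Python lists. This sketch assumes the NumPy convention for `'xy'` and `'ij'` and covers only the dense case; `meshgrid_2d` is a hypothetical helper, not part of MLX.

```python
from typing import List, Sequence, Tuple

Grid = List[List[float]]


def meshgrid_2d(x: Sequence[float], y: Sequence[float], indexing: str = "xy") -> Tuple[Grid, Grid]:
    """Dense coordinate grids for two 1-D inputs."""
    if indexing == "xy":
        # Cartesian: outputs have shape (len(y), len(x)).
        grid_x = [[xi for xi in x] for _ in y]
        grid_y = [[yj for _ in x] for yj in y]
    elif indexing == "ij":
        # Matrix: outputs have shape (len(x), len(y)).
        grid_x = [[xi for _ in y] for xi in x]
        grid_y = [[yj for yj in y] for _ in x]
    else:
        raise ValueError("indexing must be 'xy' or 'ij'")
    return grid_x, grid_y


gx, gy = meshgrid_2d([1, 2, 3], [10, 20])
# gx == [[1, 2, 3], [1, 2, 3]] and gy == [[10, 10, 10], [20, 20, 20]]
```

With `'ij'` indexing the same inputs give grids of shape `(3, 2)`, which are the transposes of the `'xy'` grids.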
# mlx.core.metal.device_info
device_info() → dict[str, Union[str, int]]
    
Get information about the GPU device and system settings.
Currently returns:
  * `architecture`
  * `max_buffer_size`
  * `max_recommended_working_set_size`
  * `memory_size`
  * `resource_limit`


Returns:
    
A dictionary with string keys and string or integer values.
Return type:
    
dict
# mlx.core.metal.is_available
is_available() → bool
    
Check if the Metal back-end is available.
# mlx.core.metal.start_capture
start_capture(path: str) → None
    
Start a Metal capture.
Parameters:
    
path (str) – The path to save the capture which should have the extension `.gputrace`.
# mlx.core.metal.stop_capture
stop_capture() → None
    
Stop a Metal capture.
# mlx.core.min
min(a: array, /, axis: None | int | Sequence[int] = None, keepdims: bool = False, *, stream: None | Stream | Device = None) → array
    
A min reduction over the given axes.
Parameters:
    
  * a (array) – Input array.
  * axis (int or list(int), optional) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array.
  * keepdims (bool, optional) – Keep reduced axes as singleton dimensions, defaults to False.


Returns:
    
The output array with the corresponding axes reduced.
Return type:
    
array
# mlx.core.minimum
minimum(a: scalar | array, b: scalar | array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise minimum.
Take the element-wise min of two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The min of `a` and `b`.
Return type:
    
array
# mlx.core.moveaxis
moveaxis(a: array, /, source: int, destination: int, *, stream: None | Stream | Device = None) → array
    
Move an axis to a new position.
Parameters:
    
  * a (array) – Input array.
  * source (int) – Specifies the source axis.
  * destination (int) – Specifies the destination axis.


Returns:
    
The array with the axis moved.
Return type:
    
array
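As an illustration of how moving an axis reorders dimensions, the following plain-Python sketch computes only the resulting shape. It assumes negative axes count from the end, as in NumPy; `moveaxis_shape` is a hypothetical helper.

```python
from typing import Sequence, Tuple


def moveaxis_shape(shape: Sequence[int], source: int, destination: int) -> Tuple[int, ...]:
    """Shape after moving one axis to a new position."""
    ndim = len(shape)
    if not (-ndim <= source < ndim and -ndim <= destination < ndim):
        raise ValueError("axis out of range")
    src, dst = source % ndim, destination % ndim
    # Remove the source axis, then reinsert it at the destination position.
    order = [axis for axis in range(ndim) if axis != src]
    order.insert(dst, src)
    return tuple(shape[axis] for axis in order)


new_shape = moveaxis_shape((2, 3, 4), 0, -1)  # (3, 4, 2)
```

The remaining axes keep their relative order; only the moved axis changes position.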
# mlx.core.multiply
multiply(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise multiplication.
Multiply two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The multiplication `a * b`.
Return type:
    
array
# mlx.core.nan_to_num
nan_to_num(a: scalar | array, nan: float = 0, posinf: float | None = None, neginf: float | None = None, *, stream: None | Stream | Device = None) → array
    
Replace NaN and Inf values with finite numbers.
Parameters:
    
  * a (array) – Input array
  * nan (float, optional) – Value to replace NaN with. Default: `0`.
  * posinf (float, optional) – Value to replace positive infinities with. If `None`, defaults to largest finite value for the given data type. Default: `None`.
  * neginf (float, optional) – Value to replace negative infinities with. If `None`, defaults to the negative of the largest finite value for the given data type. Default: `None`.


Returns:
    
Output array with NaN and Inf replaced.
Return type:
    
array
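The replacement rules can be sketched for a single Python float. This sketch uses the largest finite double (`sys.float_info.max`) as the default replacement for infinities, which is an assumption that stands in for "largest finite value for the given data type"; `nan_to_num_scalar` is a hypothetical helper.

```python
import math
import sys
from typing import Optional


def nan_to_num_scalar(
    x: float,
    nan: float = 0.0,
    posinf: Optional[float] = None,
    neginf: Optional[float] = None,
) -> float:
    """Replace NaN and +/-inf in a single float, following the defaults above."""
    largest = sys.float_info.max  # assumption: 64-bit float, unlike MLX's float32 default
    if math.isnan(x):
        return nan
    if x == math.inf:
        return largest if posinf is None else posinf
    if x == -math.inf:
        return -largest if neginf is None else neginf
    return x
```

Finite values pass through unchanged, so the function is safe to apply to every element.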
# mlx.core.negative
negative(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise negation.
Parameters:
    
a (array) – Input array.
Returns:
    
The negative of `a`.
Return type:
    
array
# mlx.core.new_stream
new_stream(device: Device) → Stream
    
Make a new stream on the given device.
# mlx.core.not_equal
not_equal(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise not equal.
Not equal comparison on two arrays with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The element-wise comparison `a != b`.
Return type:
    
array
# mlx.core.ones
ones(shape: int | Sequence[int], dtype: Dtype | None = float32, *, stream: None | Stream | Device = None) → array
    
Construct an array of ones.
Parameters:
    
  * shape (int or list(int)) – The shape of the output array.
  * dtype (Dtype, optional) – Data type of the output array. If unspecified the output type defaults to `float32`.


Returns:
    
The array of ones with the specified shape.
Return type:
    
array
# mlx.core.ones_like
ones_like(a: array, /, *, stream: None | Stream | Device = None) → array
    
An array of ones like the input.
Parameters:
    
a (array) – The input to take the shape and type from.
Returns:
    
The output array filled with ones.
Return type:
    
array
# mlx.core.outer
outer(a: array, b: array, /, *, stream: None | Stream | Device = None) → array
    
Compute the outer product of two 1-D arrays. If the arrays passed are not 1-D, they are flattened beforehand.
Parameters:
    
  * a (array) – Input array
  * b (array) – Input array


Returns:
    
The outer product.
Return type:
    
array
# mlx.core.pad
pad(a: array, pad_width: int | tuple[int] | tuple[int, int] | list[tuple[int, int]], mode: Literal['constant', 'edge'] = 'constant', constant_values: scalar | array = 0, *, stream: None | Stream | Device = None) → array
    
Pad an array with a constant value or with its edge values.
Parameters:
    
  * a (array) – Input array.
  * pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))) – Number of padded values to add to the edges of each axis: `((before_1, after_1), (before_2, after_2), ..., (before_N, after_N))`. If a single pair of integers is passed then `(before_i, after_i)` are all the same. If a single integer or tuple with a single integer is passed then all axes are extended by the same number on each side.
  * mode – Padding mode. One of the following strings: “constant” (default): Pads with a constant value. “edge”: Pads with the edge values of array.
  * constant_values (array or scalar, optional) – Optional constant value to pad the edges of the array with.


Returns:
    
The padded array.
Return type:
    
array
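The accepted forms of `pad_width` can be sketched as a normalization step followed by a shape computation. This is a plain-Python illustration of the rules above, not MLX code; `normalize_pad_width` and `padded_shape` are hypothetical helpers.

```python
from typing import List, Sequence, Tuple


def normalize_pad_width(pad_width, ndim: int) -> List[Tuple[int, int]]:
    """Expand any accepted pad_width form to one (before, after) pair per axis."""
    if isinstance(pad_width, int):
        return [(pad_width, pad_width)] * ndim
    items = list(pad_width)
    if len(items) == 1 and isinstance(items[0], int):
        # Tuple with a single integer: same padding on every side of every axis.
        return [(items[0], items[0])] * ndim
    if len(items) == 2 and all(isinstance(p, int) for p in items):
        # Single (before, after) pair: reused for every axis.
        return [(items[0], items[1])] * ndim
    pairs = [tuple(p) for p in items]
    if len(pairs) != ndim or any(len(p) != 2 for p in pairs):
        raise ValueError("expected one (before, after) pair per axis")
    return pairs


def padded_shape(shape: Sequence[int], pad_width) -> Tuple[int, ...]:
    """Shape of the output after padding each axis."""
    pairs = normalize_pad_width(pad_width, len(shape))
    return tuple(dim + before + after for dim, (before, after) in zip(shape, pairs))


out_shape = padded_shape((2, 3), [(0, 1), (2, 0)])  # (3, 5)
```

Normalizing first keeps the shape computation independent of which input form the caller used.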
# mlx.core.partition
partition(a: array, /, kth: int, axis: None | int = -1, *, stream: None | Stream | Device = None) → array
    
Returns a partitioned copy of the array such that the smallest `kth` elements come first.
The ordering of the elements in partitions is undefined.
Parameters:
    
  * a (array) – Input array.
  * kth (int) – Element at the `kth` index will be in its sorted position in the output. All elements before the kth index will be less or equal to the `kth` element and all elements after will be greater or equal to the `kth` element in the output.
  * axis (int or None, optional) – Optional axis to partition over. If `None`, this partitions over the flattened array. If unspecified, it defaults to `-1`.


Returns:
    
The partitioned array.
Return type:
    
array
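Because the ordering inside each partition is undefined, it is easier to check the guarantee than to predict the exact output. The following plain-Python sketch tests the property described for `kth` on a 1-D list, assuming a non-negative `kth`; `is_partitioned` is a hypothetical helper.

```python
from typing import Sequence


def is_partitioned(values: Sequence[float], kth: int) -> bool:
    """Check that values[kth] separates smaller-or-equal from greater-or-equal elements."""
    pivot = values[kth]
    before_ok = all(v <= pivot for v in values[:kth])
    after_ok = all(v >= pivot for v in values[kth + 1:])
    return before_ok and after_ok


# Many different orderings satisfy the guarantee for the same input.
ok_a = is_partitioned([2, 1, 3, 5, 4], 2)
ok_b = is_partitioned([1, 2, 3, 4, 5], 2)
```

A fully sorted array is always a valid partition for every `kth`, but a valid partition need not be sorted.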
# mlx.core.power
power(a: scalar | array, b: scalar | array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise power operation.
Raise the elements of a to the powers in elements of b with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
Bases of `a` raised to powers in `b`.
Return type:
    
array
# mlx.core.prod
prod(a: array, /, axis: None | int | Sequence[int] = None, keepdims: bool = False, *, stream: None | Stream | Device = None) → array
    
A product reduction over the given axes.
Parameters:
    
  * a (array) – Input array.
  * axis (int or list(int), optional) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array.
  * keepdims (bool, optional) – Keep reduced axes as singleton dimensions, defaults to False.


Returns:
    
The output array with the corresponding axes reduced.
Return type:
    
array
# mlx.core.put_along_axis
put_along_axis(a: array, /, indices: array, values: array, axis: int | None = None, *, stream: None | Stream | Device = None) → array
    
Put values along an axis at the specified indices.
Parameters:
    
  * a (array) – Destination array.
  * indices (array) – Indices array. These should be broadcastable with the input array excluding the axis dimension.
  * values (array) – Values array. These should be broadcastable with the indices.
  * axis (int or None) – Axis in the destination to put the values to. If `axis == None` the destination is flattened prior to the put operation.


Returns:
    
The output array.
Return type:
    
array
# mlx.core.quantize
quantize(w: array, /, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: None | Stream | Device = None) → tuple[array, array, array]
    
Quantize the matrix `w` using `bits` bits per element.
Note, every `group_size` elements in a row of `w` are quantized together. Hence, the number of columns of `w` should be divisible by `group_size`. In particular, the rows of `w` are divided into groups of size `group_size` which are quantized together.
Warning
`quantize` currently only supports 2D inputs with the second dimension divisible by `group_size`.
The supported quantization modes are `"affine"` and `"mxfp4"`. They are described in more detail below.
Parameters:
    
  * w (array) – Matrix to be quantized
  * group_size (int, optional) – The size of the group in `w` that shares a scale and bias. Default: `64`.
  * bits (int, optional) – The number of bits occupied by each element of `w` in the returned quantized matrix. Default: `4`.
  * mode (str, optional) – The quantization mode. Default: `"affine"`.


Returns:
    
A tuple with either two or three elements containing:
  * w_q (array): The quantized version of `w`
  * scales (array): The quantization scales
  * biases (array): The quantization biases (returned for `mode=="affine"`).


Return type:
    
tuple
Notes
The `affine` mode quantizes groups of \\(g\\) consecutive elements in a row of `w`. For each group the quantized representation of each element \\(\hat{w_i}\\) is computed as follows:
\\[\begin{split}\begin{aligned} \alpha &= \max_i w_i \\\ \beta &= \min_i w_i \\\ s &= \frac{\alpha - \beta}{2^b - 1} \\\ \hat{w_i} &= \textrm{round}\left( \frac{w_i - \beta}{s}\right). \end{aligned}\end{split}\\]
After the above computation, \\(\hat{w_i}\\) fits in \\(b\\) bits and is packed in an unsigned 32-bit integer from the lower to upper bits. For instance, for 4-bit quantization we fit 8 elements in an unsigned 32 bit integer where the 1st element occupies the 4 least significant bits, the 2nd bits 4-7 etc.
To dequantize the elements of `w`, we also save \\(s\\) and \\(\beta\\) which are the returned `scales` and `biases` respectively.
The `mxfp4` mode similarly quantizes groups of \\(g\\) elements of `w`. For `mxfp4` the group size must be `32`. The elements are quantized to 4-bit precision floating-point values (E2M1) with a shared 8-bit scale per group. Unlike `affine` quantization, `mxfp4` does not have a bias value. More details on the format can be found in the specification.
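The `affine` formulas above can be traced for a single group with a plain-Python sketch. It is an illustration, not MLX's kernel: it assumes `bits` divides 32, uses Python's `round` (which rounds ties to even), and all three helper names are hypothetical.

```python
from typing import List, Sequence, Tuple


def quantize_group_affine(group: Sequence[float], bits: int = 4) -> Tuple[List[int], float, float]:
    """Quantize one group: returns (integer codes, scale s, bias beta)."""
    alpha, beta = max(group), min(group)
    scale = (alpha - beta) / (2**bits - 1)
    if scale == 0.0:
        # Constant group: every element maps to code 0 (sketch-only edge case).
        return [0] * len(group), scale, beta
    codes = [round((w - beta) / scale) for w in group]
    return codes, scale, beta


def pack_uint32(codes: Sequence[int], bits: int = 4) -> List[int]:
    """Pack codes into 32-bit words from the lower to the upper bits."""
    per_word = 32 // bits
    words = []
    for start in range(0, len(codes), per_word):
        word = 0
        for i, code in enumerate(codes[start:start + per_word]):
            word |= code << (bits * i)
        words.append(word)
    return words


def dequantize_group_affine(codes: Sequence[int], scale: float, bias: float) -> List[float]:
    """Approximate reconstruction: w is recovered as s * w_hat + beta."""
    return [scale * code + bias for code in codes]


group = [float(i) for i in range(16)]
codes, scale, bias = quantize_group_affine(group, bits=4)
packed = pack_uint32(codes, bits=4)  # [0x76543210, 0xFEDCBA98]
```

In this example the group spans exactly 16 evenly spaced values, so the scale is `1.0`, the codes equal the inputs, and dequantization is exact. In general the reconstruction error per element is at most about half a scale step.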
# mlx.core.quantized_matmul
quantized_matmul(x: array, w: array, /, scales: array, biases: array | None = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine', *, stream: None | Stream | Device = None) → array
    
Perform the matrix multiplication with the quantized matrix `w`. The quantization uses one floating point scale and bias per `group_size` of elements. Each element in `w` takes `bits` bits and is packed in an unsigned 32 bit integer.
Parameters:
    
  * x (array) – Input array
  * w (array) – Quantized matrix packed in unsigned integers
  * scales (array) – The scales to use per `group_size` elements of `w`
  * biases (array, optional) – The biases to use per `group_size` elements of `w`. Default: `None`.
  * transpose (bool, optional) – Defines whether to multiply with the transposed `w` or not, namely whether we are performing `x @ w.T` or `x @ w`. Default: `True`.
  * group_size (int, optional) – The size of the group in `w` that shares a scale and bias. Default: `64`.
  * bits (int, optional) – The number of bits occupied by each element in `w`. Default: `4`.
  * mode (str, optional) – The quantization mode. Default: `"affine"`.


Returns:
    
The result of the multiplication of `x` with `w`.
Return type:
    
array
# mlx.core.radians
radians(a: array, /, *, stream: None | Stream | Device = None) → array
    
Convert angles from degrees to radians.
Parameters:
    
a (array) – Input array.
Returns:
    
The angles in radians.
Return type:
    
array
# mlx.core.random.bernoulli
bernoulli(p: scalar | array = 0.5, shape: Sequence[int] | None = None, key: array | None = None, stream: None | Stream | Device = None) → array
    
Generate Bernoulli random values.
The values are sampled from the bernoulli distribution with parameter `p`. The parameter `p` can be a `float` or `array` and must be broadcastable to `shape`.
Parameters:
    
  * p (float or array, optional) – Parameter of the Bernoulli distribution. Default: `0.5`.
  * shape (list(int), optional) – Shape of the output. Default: `p.shape`.
  * key (array, optional) – A PRNG key. Default: `None`.


Returns:
    
The array of random integers.
Return type:
    
array
# mlx.core.random.categorical
categorical(logits: array, axis: int = -1, shape: Sequence[int] | None = None, num_samples: int | None = None, key: array | None = None, stream: None | Stream | Device = None) → array
    
Sample from a categorical distribution.
The values are sampled from the categorical distribution specified by the unnormalized values in `logits`. Note, at most one of `shape` or `num_samples` can be specified. If both are `None`, the output has the same shape as `logits` with the `axis` dimension removed.
Parameters:
    
  * logits (array) – The unnormalized categorical distribution(s).
  * axis (int, optional) – The axis which specifies the distribution. Default: `-1`.
  * shape (list(int), optional) – The shape of the output. This must be broadcast compatible with `logits.shape` with the `axis` dimension removed. Default: `None`
  * num_samples (int, optional) – The number of samples to draw from each of the categorical distributions in `logits`. The output will have `num_samples` in the last dimension. Default: `None`.
  * key (array, optional) – A PRNG key. Default: `None`.


Returns:
    
The `shape`-sized output array with type `uint32`.
Return type:
    
array
# mlx.core.random.gumbel
gumbel(shape: Sequence[int] = [], dtype: Dtype | None = float32, key: array | None = None, stream: None | Stream | Device = None) → array
    
Sample from the standard Gumbel distribution.
The values are sampled from a standard Gumbel distribution whose CDF is `exp(-exp(-x))`.
Parameters:
    
  * shape (list(int)) – The shape of the output.
  * dtype (Dtype, optional) – The data type of the output. Default: `float32`.
  * key (array, optional) – A PRNG key. Default: `None`.


Returns:
    
The `array` with shape `shape` and distributed according to the Gumbel distribution.
Return type:
    
array
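Given the CDF `exp(-exp(-x))`, a standard Gumbel value can be obtained from a uniform value `u` in `(0, 1)` by inverting the CDF, which gives `x = -log(-log(u))`. The plain-Python sketch below illustrates this inverse-CDF relationship; it is not a description of how MLX generates its samples, and the helper names are hypothetical.

```python
import math
import random


def gumbel_cdf(x: float) -> float:
    """CDF of the standard Gumbel distribution."""
    return math.exp(-math.exp(-x))


def gumbel_from_uniform(u: float) -> float:
    """Invert the CDF: map u in (0, 1) to a standard Gumbel value."""
    return -math.log(-math.log(u))


def sample_gumbel(rng: random.Random) -> float:
    """Draw one sample using inverse-CDF sampling."""
    u = rng.random()
    while u == 0.0:  # random() returns values in [0, 1); avoid log(0)
        u = rng.random()
    return gumbel_from_uniform(u)


sample = sample_gumbel(random.Random(0))
```

Applying `gumbel_cdf` to the output of `gumbel_from_uniform(u)` returns `u` up to floating-point error, which confirms the two functions are inverses.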
# mlx.core.random.key
key(seed: int) → array
    
Get a PRNG key from a seed.
Parameters:
    
seed (int) – Seed for the PRNG.
Returns:
    
The PRNG key array.
Return type:
    
array
# mlx.core.random.laplace
laplace(shape: Sequence[int] = [], dtype: Dtype | None = float32, loc: float = 0.0, scale: float = 1.0, key: array | None = None, stream: None | Stream | Device = None) → array
    
Sample numbers from a Laplace distribution.
Parameters:
    
  * shape (list(int), optional) – Shape of the output. Default: `()`.
  * dtype (Dtype, optional) – Type of the output. Default: `float32`.
  * loc (float, optional) – Mean of the distribution. Default: `0.0`.
  * scale (float, optional) – The scale “b” of the Laplace distribution. Default: `1.0`.
  * key (array, optional) – A PRNG key. Default: `None`.


Returns:
    
The output array of random values.
Return type:
    
array
# mlx.core.random.multivariate_normal
multivariate_normal(mean: array, cov: array, shape: Sequence[int] = [], dtype: Dtype | None = float32, key: array | None = None, stream: None | Stream | Device = None) → array
    
Generate jointly-normal random samples given a mean and covariance.
The matrix `cov` must be positive semi-definite. The behavior is undefined if it is not. The only supported `dtype` is `float32`.
Parameters:
    
  * mean (array) – array of shape `(..., n)`, the mean of the distribution.
  * cov (array) – array of shape `(..., n, n)`, the covariance matrix of the distribution. The batch shape `...` must be broadcast-compatible with that of `mean`.
  * shape (list(int), optional) – The output shape must be broadcast-compatible with `mean.shape[:-1]` and `cov.shape[:-2]`. If empty, the result shape is determined by broadcasting the batch shapes of `mean` and `cov`. Default: `[]`.
  * dtype (Dtype, optional) – The output type. Default: `float32`.
  * key (array, optional) – A PRNG key. Default: `None`.


Returns:
    
The output array of random values.
Return type:
    
array
# mlx.core.random.normal
normal(shape: Sequence[int] = [], dtype: Dtype | None = float32, loc: scalar | array | None = None, scale: scalar | array | None = None, key: array | None = None, stream: None | Stream | Device = None) → array
    
Generate normally distributed random numbers.
If `loc` and `scale` are not provided the “standard” normal distribution is used. That means $x \sim \mathcal{N}(0, 1)$ for real numbers and $\text{Re}(x), \text{Im}(x) \sim \mathcal{N}(0, \frac{1}{2})$ for complex numbers.
Parameters:
    
  * shape (list(int), optional) – Shape of the output. Default: `()`.
  * dtype (Dtype, optional) – Type of the output. Default: `float32`.
  * loc (scalar or array, optional) – Mean of the distribution. Default: `None`.
  * scale (scalar or array, optional) – Standard deviation of the distribution. Default: `None`.
  * key (array, optional) – A PRNG key. Default: `None`.


Returns:
    
The output array of random values.
Return type:
    
array
# mlx.core.random.permutation
permutation(x: int | array, axis: int = 0, key: array | None = None, stream: None | Stream | Device = None) → array
    
Generate a random permutation or permute the entries of an array.
Parameters:
    
  * x (int or array, optional) – If an integer is provided a random permutation of `mx.arange(x)` is returned. Otherwise the entries of `x` along the given axis are randomly permuted.
  * axis (int, optional) – The axis to permute along. Default: `0`.
  * key (array, optional) – A PRNG key. Default: `None`.


Returns:
    
The generated random permutation or randomly permuted input array.
Return type:
    
array
# mlx.core.random.randint
randint(low: scalar | array, high: scalar | array, shape: Sequence[int] = [], dtype: Dtype | None = int32, key: array | None = None, stream: None | Stream | Device = None) → array
    
Generate random integers from the given interval.
The values are sampled with equal probability from the integers in half-open interval `[low, high)`. The lower and upper bound can be scalars or arrays and must be broadcastable to `shape`.
Parameters:
    
  * low (scalar or array) – Lower bound of the interval.
  * high (scalar or array) – Upper bound of the interval.
  * shape (list(int), optional) – Shape of the output. Default: `()`.
  * dtype (Dtype, optional) – Type of the output. Default: `int32`.
  * key (array, optional) – A PRNG key. Default: `None`.


Returns:
    
The array of random integers.
Return type:
    
array
# mlx.core.random.seed
seed(seed: int) → None
    
Seed the global PRNG.
Parameters:
    
seed (int) – Seed for the global PRNG.
# mlx.core.random.split
split(key: array, num: int = 2, stream: None | Stream | Device = None) → array
    
Split a PRNG key into sub keys.
Parameters:
    
  * key (array) – Input key to split.
  * num (int, optional) – Number of sub keys. Default: `2`.


Returns:
    
The array of sub keys with `num` as its first dimension.
Return type:
    
array
# mlx.core.random.truncated_normal
truncated_normal(lower: scalar | array, upper: scalar | array, shape: Sequence[int] | None = None, dtype: Dtype | None = float32, key: array | None = None, stream: None | Stream | Device = None) → array
    
Generate values from a truncated normal distribution.
The values are sampled from the truncated normal distribution on the domain `(lower, upper)`. The bounds `lower` and `upper` can be scalars or arrays and must be broadcastable to `shape`.
Parameters:
    
  * lower (scalar or array) – Lower bound of the domain.
  * upper (scalar or array) – Upper bound of the domain.
  * shape (list(int), optional) – The shape of the output. Default: `()`.
  * dtype (Dtype, optional) – The data type of the output. Default: `float32`.
  * key (array, optional) – A PRNG key. Default: `None`.


Returns:
    
The output array of random values.
Return type:
    
array
# mlx.core.random.uniform
uniform(low: scalar | array = 0, high: scalar | array = 1, shape: Sequence[int] = [], dtype: Dtype | None = float32, key: array | None = None, stream: None | Stream | Device = None) → array
    
Generate uniformly distributed random numbers.
The values are sampled uniformly in the half-open interval `[low, high)`. The lower and upper bound can be scalars or arrays and must be broadcastable to `shape`.
Parameters:
    
  * low (scalar or array, optional) – Lower bound of the distribution. Default: `0`.
  * high (scalar or array, optional) – Upper bound of the distribution. Default: `1`.
  * shape (list(int), optional) – Shape of the output. Default: `()`.
  * dtype (Dtype, optional) – Type of the output. Default: `float32`.
  * key (array, optional) – A PRNG key. Default: `None`.


Returns:
    
The output array of random values.
Return type:
    
array
# mlx.core.real
real(a: array, /, *, stream: None | Stream | Device = None) → array
    
Returns the real part of a complex array.
Parameters:
    
a (array) – Input array.
Returns:
    
The real part of `a`.
Return type:
    
array
# mlx.core.reciprocal
reciprocal(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise reciprocal.
Parameters:
    
a (array) – Input array.
Returns:
    
The reciprocal of `a`.
Return type:
    
array
# mlx.core.remainder
remainder(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise remainder of division.
Computes the remainder of dividing a with b with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The remainder of `a // b`.
Return type:
    
array
# mlx.core.repeat
repeat(array: array, repeats: int, axis: int | None = None, *, stream: None | Stream | Device = None) → array
    
Repeat an array along a specified axis.
Parameters:
    
  * array (array) – Input array.
  * repeats (int) – The number of repetitions for each element.
  * axis (int, optional) – The axis along which to repeat the array. If unspecified it uses the flattened array of the input and repeats along axis 0.
  * stream (Stream, optional) – Stream or device. Defaults to `None`.


Returns:
    
The resulting repeated array.
Return type:
    
array
# mlx.core.reset_peak_memory
reset_peak_memory() → None
    
Reset the peak memory to zero.
# mlx.core.reshape
reshape(a: array, /, shape: Sequence[int], *, stream: None | Stream | Device = None) → array
    
Reshape an array while preserving the size.
Parameters:
    
  * a (array) – Input array.
  * shape (tuple(int)) – New shape.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
The reshaped array.
Return type:
    
array
# mlx.core.right_shift
right_shift(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise right shift.
Shift the bits of the first input to the right by the second using numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The bitwise right shift `a >> b`.
Return type:
    
array
# mlx.core.roll
roll(a: array, shift: int | Tuple[int], axis: None | int | Tuple[int] = None, /, *, stream: None | Stream | Device = None) → array
    
Roll array elements along a given axis.
Elements that are rolled beyond the end of the array are introduced at the beginning and vice-versa.
If the axis is not provided the array is flattened, rolled and then the shape is restored.
Parameters:
    
  * a (array) – Input array
  * shift (int or tuple(int)) – The number of places by which elements are shifted. If positive the array is rolled to the right, if negative it is rolled to the left. If an int is provided but the axis is a tuple then the same value is used for all axes.
  * axis (int or tuple(int), optional) – The axis or axes along which to roll the elements.
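The wrap-around behavior can be sketched for a flat Python list, where a positive `shift` moves elements to the right and a negative one to the left. This is an illustration of the semantics described above, not MLX code; `roll_1d` is a hypothetical helper.

```python
from typing import List, Sequence


def roll_1d(values: Sequence[int], shift: int) -> List[int]:
    """Roll a 1-D sequence; elements pushed past one end reappear at the other."""
    values = list(values)
    n = len(values)
    if n == 0:
        return []
    k = shift % n  # negative shifts become the equivalent right shift
    if k == 0:
        return values
    return values[-k:] + values[:-k]


right = roll_1d([0, 1, 2, 3, 4], 2)   # [3, 4, 0, 1, 2]
left = roll_1d([0, 1, 2, 3, 4], -1)   # [1, 2, 3, 4, 0]
```

Shifts larger than the length wrap around, so a shift of 7 on a length-5 list behaves like a shift of 2.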


# mlx.core.round
round(a: array, /, decimals: int = 0, stream: None | Stream | Device = None) → array
    
Round to the given number of decimals.
Basically performs:
    
    s = 10**decimals
    x = round(x * s) / s
    
Parameters:
    
  * a (array) – Input array
  * decimals (int) – Number of decimal places to round to. (default: 0)


Returns:
    
An array of the same type as `a` rounded to the given number of decimals.
Return type:
    
array
# mlx.core.rsqrt
rsqrt(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise reciprocal square root.
Parameters:
    
a (array) – Input array.
Returns:
    
One over the square root of `a`.
Return type:
    
array
# mlx.core.save
save(file: file | str | Path, arr: array) → None
    
Save the array to a binary file in `.npy` format.
Parameters:
    
  * file (str, Path, file) – File to which the array is saved
  * arr (array) – Array to be saved.


# mlx.core.save_gguf
save_gguf(file: file | str | Path, arrays: dict[str, array], metadata: dict[str, array | str | list[str]])
    
Save array(s) to a binary file in `.gguf` format.
See the GGUF documentation for more information on the format.
Parameters:
    
  * file (file, str, Path) – File in which the array is saved.
  * arrays (dict(str, array)) – The dictionary of names to arrays to be saved.
  * metadata (dict(str, Union[array, str, list(str)])) – The dictionary of metadata to be saved. The values can be a scalar or 1D `array`, a `str`, or a `list` of `str`.


# mlx.core.save_safetensors
save_safetensors(file: file | str | Path, arrays: dict[str, array], metadata: dict[str, str] | None = None)
    
Save array(s) to a binary file in `.safetensors` format.
See the Safetensors documentation for more information on the format.
Parameters:
    
  * file (file, str, Path) – File in which the array is saved.
  * arrays (dict(str, array)) – The dictionary of names to arrays to be saved.
  * metadata (dict(str, str), optional) – The dictionary of metadata to be saved.


# mlx.core.savez
savez(file: file | str | Path, *args, **kwargs)
    
Save several arrays to a binary file in uncompressed `.npz` format.
    
    import mlx.core as mx
    
    x = mx.ones((10, 10))
    mx.savez("my_path.npz", x=x)
    
    import mlx.nn as nn
    from mlx.utils import tree_flatten
    
    model = nn.TransformerEncoder(6, 128, 4)
    flat_params = tree_flatten(model.parameters())
    mx.savez("model.npz", **dict(flat_params))
    
Parameters:
    
  * file (file, str, Path) – Path to file to which the arrays are saved.
  * *args (arrays) – Arrays to be saved.
  * **kwargs (arrays) – Arrays to be saved. Each array will be saved with the associated keyword as the output file name.


# mlx.core.savez_compressed
savez_compressed(file: file | str | Path, *args, **kwargs)
    
Save several arrays to a binary file in compressed `.npz` format.
Parameters:
    
  * file (file, str, Path) – Path to file to which the arrays are saved.
  * *args (arrays) – Arrays to be saved.
  * **kwargs (arrays) – Arrays to be saved. Each array will be saved with the associated keyword as the output file name.


# mlx.core.set_cache_limit
set_cache_limit(limit: int) → int
    
Set the free cache limit.
If using more than the given limit, free memory will be reclaimed from the cache on the next allocation. To disable the cache, set the limit to `0`.
The cache limit defaults to the memory limit. See `set_memory_limit()` for more details.
Parameters:
    
limit (int) – The cache limit in bytes.
Returns:
    
The previous cache limit in bytes.
Return type:
    
int
# mlx.core.set_default_device
set_default_device(device: Device) → None
    
Set the default device.
# mlx.core.set_default_stream
set_default_stream(stream: Stream) → None
    
Set the default stream.
This will make the given stream the default for the stream's device. It will not change the default device.
Parameters:
    
stream (Stream) – Stream to make the default.
# mlx.core.set_memory_limit
set_memory_limit(limit: int) → int
    
Set the memory limit.
The memory limit is a guideline for the maximum amount of memory to use during graph evaluation. If the memory limit is exceeded and there is no more RAM (including swap when available), allocations will result in an exception.
When Metal is available, the memory limit defaults to 1.5 times the maximum recommended working set size reported by the device.
Parameters:
    
limit (int) – Memory limit in bytes.
Returns:
    
The previous memory limit in bytes.
Return type:
    
int
# mlx.core.set_wired_limit
set_wired_limit(limit: int) → int
    
Set the wired size limit.
Note
  * This function is only useful on macOS 15.0 or higher.
  * The wired limit should remain strictly less than the total memory size.


The wired limit is the total size in bytes of memory that will be kept resident. The default value is `0`.
Setting a wired limit larger than the system wired limit is an error. You can increase the system wired limit with:
    
    sudo sysctl iogpu.wired_limit_mb=<size_in_megabytes>
    
Use `device_info()` to query the system wired limit (`"max_recommended_working_set_size"`) and the total memory size (`"memory_size"`).
Parameters:
    
limit (int) – The wired limit in bytes.
Returns:
    
The previous wired limit in bytes.
Return type:
    
int
# mlx.core.sigmoid
sigmoid(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise logistic sigmoid.
The logistic sigmoid function is:
\[\mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}\]
Parameters:
    
a (array) – Input array.
Returns:
    
The logistic sigmoid of `a`.
Return type:
    
array
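The formula above can overflow in floating point when evaluated naively for large negative `x`. The pure-Python sketch below shows one common way to evaluate it safely on a scalar; it is an illustration of the formula only and makes no claim about how MLX implements `sigmoid`.

```python
import math


def sigmoid(x):
    """Logistic sigmoid 1 / (1 + exp(-x)) evaluated without overflow."""
    if x >= 0:
        # exp(-x) is at most 1 here, so it cannot overflow.
        return 1.0 / (1.0 + math.exp(-x))
    # For negative x, rewrite as exp(x) / (1 + exp(x)) so the exponent stays negative.
    z = math.exp(x)
    return z / (1.0 + z)


mid = sigmoid(0.0)
tail = sigmoid(-1000.0)
```

The two branches are algebraically identical; only the direction of rounding error differs.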
# mlx.core.sign
sign(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise sign.
Parameters:
    
a (array) – Input array.
Returns:
    
The sign of `a`.
Return type:
    
array
# mlx.core.sin
sin(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise sine.
Parameters:
    
a (array) – Input array.
Returns:
    
The sine of `a`.
Return type:
    
array
# mlx.core.sinh
sinh(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise hyperbolic sine.
Parameters:
    
a (array) – Input array.
Returns:
    
The hyperbolic sine of `a`.
Return type:
    
array
# mlx.core.slice
slice(a: array, start_indices: array, axes: Sequence[int], slice_size: Sequence[int], *, stream: None | Stream | Device = None) → array
    
Extract a sub-array from the input array.
Parameters:
    
  * a (array) – Input array
  * start_indices (array) – The index location to start the slice at.
  * axes (tuple(int)) – The axes corresponding to the indices in `start_indices`.
  * slice_size (tuple(int)) – The size of the slice.


Returns:
    
The sliced output array.
Return type:
    
array
Example
    
    >>> a = mx.array([[1, 2, 3], [4, 5, 6]])
    >>> mx.slice(a, start_indices=mx.array(1), axes=(0,), slice_size=(1, 2))
    array([[4, 5]], dtype=int32)
    >>>
    >>> mx.slice(a, start_indices=mx.array(1), axes=(1,), slice_size=(2, 1))
    array([[2],
           [5]], dtype=int32)
    
# mlx.core.slice_update
slice_update(a: array, update: array, start_indices: array, axes: Sequence[int], *, stream: None | Stream | Device = None) → array
    
Update a sub-array of the input array.
Parameters:
    
  * a (array) – The input array to update
  * update (array) – The update array.
  * start_indices (array) – The index location to start the slice at.
  * axes (tuple(int)) – The axes corresponding to the indices in `start_indices`.


Returns:
    
The output array with the same shape and type as the input.
Return type:
    
array
Example
    
    >>> a = mx.zeros((3, 3))
    >>> mx.slice_update(a, mx.ones((1, 2)), start_indices=mx.array([1, 1]), axes=(0, 1))
    array([[0, 0, 0],
           [0, 1, 1],
           [0, 0, 0]], dtype=float32)
    
# mlx.core.softmax
softmax(a: array, /, axis: None | int | Sequence[int] = None, *, stream: None | Stream | Device = None) → array
    
Perform the softmax along the given axis.
This operation is a numerically stable version of:
    
    exp(a) / sum(exp(a), axis, keepdims=True)
    
Parameters:
    
  * a (array) – Input array.
  * axis (int or list(int), optional) – Optional axis or axes to compute the softmax over. If unspecified this performs the softmax over the full array.


Returns:
    
The output of the softmax.
Return type:
    
array
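"Numerically stable" usually means subtracting the maximum before exponentiating, which leaves the result unchanged but keeps every exponent at or below zero. The following pure-Python sketch shows that trick on a flat list; it is an assumption about the general technique, not a description of the MLX kernel.

```python
import math


def softmax(values):
    """Softmax of a flat list, shifted by the maximum to avoid overflow."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]  # every exponent is <= 0
    total = sum(exps)
    return [e / total for e in exps]


# math.exp(1000.0) alone would overflow, but the shifted form is fine.
probs = softmax([1000.0, 1001.0, 1002.0])
```

Because softmax is invariant to adding a constant to every input, `probs` matches `softmax([0.0, 1.0, 2.0])`.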
# mlx.core.sort
sort(a: array, /, axis: None | int = -1, *, stream: None | Stream | Device = None) → array
    
Returns a sorted copy of the array.
Parameters:
    
  * a (array) – Input array.
  * axis (int or None, optional) – Optional axis to sort over. If `None`, this sorts over the flattened array. If unspecified, it defaults to -1 (sorting over the last axis).


Returns:
    
The sorted array.
Return type:
    
array
# mlx.core.split
split(a: array, /, indices_or_sections: int | Sequence[int], axis: int = 0, *, stream: None | Stream | Device = None) → list[array]
    
Split an array along a given axis.
Parameters:
    
  * a (array) – Input array.
  * indices_or_sections (int or list(int)) – If `indices_or_sections` is an integer the array is split into that many sections of equal size. An error is raised if this is not possible. If `indices_or_sections` is a list, the list contains the indices of the start of each subarray along the given axis.
  * axis (int, optional) – Axis to split along, defaults to 0.


Returns:
    
A list of split arrays.
Return type:
    
list(array)
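The two forms of `indices_or_sections` can be sketched in pure Python on a flat list. `split_list` is a hypothetical helper for illustration only; it mirrors the description above for a single axis.

```python
def split_list(values, indices_or_sections):
    """Split a flat list by section count (int) or by start indices (list)."""
    n = len(values)
    if isinstance(indices_or_sections, int):
        sections = indices_or_sections
        if n % sections != 0:
            raise ValueError("cannot split into sections of equal size")
        step = n // sections
        starts = list(range(step, n, step))
    else:
        starts = list(indices_or_sections)
    # Each consecutive pair of bounds delimits one sub-list.
    bounds = [0] + starts + [n]
    return [values[lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]


data = [0, 1, 2, 3, 4, 5]
equal_parts = split_list(data, 3)        # three sections of size 2
by_indices = split_list(data, [1, 4])    # sub-lists start at 0, 1 and 4
```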
# mlx.core.sqrt
sqrt(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise square root.
Parameters:
    
a (array) – Input array.
Returns:
    
The square root of `a`.
Return type:
    
array
# mlx.core.square
square(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise square.
Parameters:
    
a (array) – Input array.
Returns:
    
The square of `a`.
Return type:
    
array
# mlx.core.squeeze
squeeze(a: array, /, axis: None | int | Sequence[int] = None, *, stream: None | Stream | Device = None) → array
    
Remove length one axes from an array.
Parameters:
    
  * a (array) – Input array.
  * axis (int or tuple(int), optional) – Axes to remove. Defaults to `None` in which case all size one axes are removed.


Returns:
    
The output array with size one axes removed.
Return type:
    
array
# mlx.core.stack
stack(arrays: list[array], axis: int | None = 0, *, stream: None | Stream | Device = None) → array
    
Stacks the arrays along a new axis.
Parameters:
    
  * arrays (list(array)) – A list of arrays to stack.
  * axis (int, optional) – The axis in the result array along which the input arrays are stacked. Defaults to `0`.
  * stream (Stream, optional) – Stream or device. Defaults to `None`.


Returns:
    
The resulting stacked array.
Return type:
    
array
# mlx.core.std
std(a: array, /, axis: None | int | Sequence[int] = None, keepdims: bool = False, ddof: int = 0, *, stream: None | Stream | Device = None) → array
    
Compute the standard deviation(s) over the given axes.
Parameters:
    
  * a (array) – Input array.
  * axis (int or list(int), optional) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array.
  * keepdims (bool, optional) – Keep reduced axes as singleton dimensions, defaults to False.
  * ddof (int, optional) – The divisor to compute the variance is `N - ddof`, defaults to 0.


Returns:
    
The output array of standard deviations.
Return type:
    
array
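The effect of `ddof` on the divisor can be checked by hand with a short pure-Python sketch. The helper names below are hypothetical and operate on a flat list rather than an `array`.

```python
import math


def variance(values, ddof=0):
    """Variance with divisor N - ddof."""
    n = len(values)
    mean = sum(values) / n
    return sum((v - mean) ** 2 for v in values) / (n - ddof)


def std(values, ddof=0):
    """Standard deviation: square root of the variance above."""
    return math.sqrt(variance(values, ddof))


data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
population_std = std(data)          # divisor N = 8
sample_var = variance(data, ddof=1)  # divisor N - 1 = 7
```

With `ddof=0` the divisor is `N` (population statistics); `ddof=1` gives the usual unbiased sample variance.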
# mlx.core.stop_gradient
stop_gradient(a: array, /, *, stream: None | Stream | Device = None) → array
    
Stop gradients from being computed.
The operation is the identity but it prevents gradients from flowing through the array.
Parameters:
    
a (array) – Input array.
Returns:
    
The unchanged input `a` but without gradient flowing through it.
Return type:
    
array
# mlx.core.stream
stream(s: Union[Stream, Device]) → mlx.core.StreamContext
    
Create a context manager to set the default device and stream.
Parameters:
    
s – The `Stream` or `Device` to set as the default.
Returns:
    
A context manager that sets the default device and stream.
# mlx.core.subtract
subtract(a: scalar | array, b: scalar | array, stream: None | Stream | Device = None) → array
    
Element-wise subtraction.
Subtract one array from another with numpy-style broadcasting semantics. Either or both input arrays can also be scalars.
Parameters:
    
  * a (array) – Input array or scalar.
  * b (array) – Input array or scalar.


Returns:
    
The difference `a - b`.
Return type:
    
array
# mlx.core.sum
sum(a: array, /, axis: None | int | Sequence[int] = None, keepdims: bool = False, *, stream: None | Stream | Device = None) → array
    
Sum reduce the array over the given axes.
Parameters:
    
  * a (array) – Input array.
  * axis (int or list(int), optional) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array.
  * keepdims (bool, optional) – Keep reduced axes as singleton dimensions, defaults to False.


Returns:
    
The output array with the corresponding axes reduced.
Return type:
    
array
# mlx.core.swapaxes
swapaxes(a: array, /, axis1: int, axis2: int, *, stream: None | Stream | Device = None) → array
    
Swap two axes of an array.
Parameters:
    
  * a (array) – Input array.
  * axis1 (int) – Specifies the first axis.
  * axis2 (int) – Specifies the second axis.


Returns:
    
The array with swapped axes.
Return type:
    
array
# mlx.core.synchronize
synchronize(stream: Optional[Stream] = None) → None
    
Synchronize with the given stream.
Parameters:
    
stream (Stream, optional) – The stream to synchronize with. If `None` then the default stream of the default device is used. Default: `None`.
# mlx.core.take
take(a: array, /, indices: int | array, axis: int | None = None, *, stream: None | Stream | Device = None) → array
    
Take elements along an axis.
The elements are taken from `indices` along the specified axis. If the axis is not specified the array is treated as a flattened 1-D array prior to performing the take.
As an example, if `axis=1`, this is equivalent to `a[:, indices, ...]`.
Parameters:
    
  * a (array) – Input array.
  * indices (int or array) – Integer index or input array with integral type.
  * axis (int, optional) – Axis along which to perform the take. If unspecified the array is treated as a flattened 1-D vector.


Returns:
    
The indexed values of `a`.
Return type:
    
array
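The difference between taking along an axis and taking from the flattened array can be sketched with nested Python lists. `take_2d` is a hypothetical helper for a two-dimensional input and is not part of MLX.

```python
def take_2d(rows, indices, axis=None):
    """Take elements from a list of rows along an axis, or from the flattened data."""
    if axis is None:
        flat = [v for row in rows for v in row]  # row-major flattening
        return [flat[i] for i in indices]
    if axis == 0:
        return [list(rows[i]) for i in indices]
    if axis == 1:
        # Equivalent to a[:, indices] for a 2-D array.
        return [[row[i] for i in indices] for row in rows]
    raise ValueError("axis must be None, 0 or 1 in this sketch")


a = [[1, 2, 3], [4, 5, 6]]
columns = take_2d(a, [0, 2], axis=1)
flat_pick = take_2d(a, [0, 5])
```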
# mlx.core.take_along_axis
take_along_axis(a: array, /, indices: array, axis: int | None = None, *, stream: None | Stream | Device = None) → array
    
Take values along an axis at the specified indices.
Parameters:
    
  * a (array) – Input array.
  * indices (array) – Indices array. These should be broadcastable with the input array excluding the axis dimension.
  * axis (int or None) – Axis in the input to take the values from. If `axis == None` the array is flattened to 1D prior to the indexing operation.


Returns:
    
The output array.
Return type:
    
array
# mlx.core.tan
tan(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise tangent.
Parameters:
    
a (array) – Input array.
Returns:
    
The tangent of `a`.
Return type:
    
array
# mlx.core.tanh
tanh(a: array, /, *, stream: None | Stream | Device = None) → array
    
Element-wise hyperbolic tangent.
Parameters:
    
a (array) – Input array.
Returns:
    
The hyperbolic tangent of `a`.
Return type:
    
array
# mlx.core.tensordot
tensordot(a: array, b: array, /, axes: int | list[Sequence[int]] = 2, *, stream: None | Stream | Device = None) → array
    
Compute the tensor dot product along the specified axes.
Parameters:
    
  * a (array) – Input array
  * b (array) – Input array
  * axes (int or list(list(int)), optional) – The number of dimensions to sum over. If an integer is provided, then sum over the last `axes` dimensions of `a` and the first `axes` dimensions of `b`. If a list of lists is provided, then sum over the corresponding dimensions of `a` and `b`. Default: 2.


Returns:
    
The tensor dot product.
Return type:
    
array
# mlx.core.tile
tile(a: array, reps: int | Sequence[int], /, *, stream: None | Stream | Device = None) → array
    
Construct an array by repeating `a` the number of times given by `reps`.
Parameters:
    
  * a (array) – Input array
  * reps (int or list(int)) – The number of times to repeat `a` along each axis.


Returns:
    
The tiled array.
Return type:
    
array
# mlx.core.topk
topk(a: array, /, k: int, axis: None | int = -1, *, stream: None | Stream | Device = None) → array
    
Returns the `k` largest elements from the input along a given axis.
The elements will not necessarily be in sorted order.
Parameters:
    
  * a (array) – Input array.
  * k (int) – `k` top elements to be returned
  * axis (int or None, optional) – Optional axis to select over. If `None`, this selects the top `k` elements over the flattened array. If unspecified, it defaults to `-1`.


Returns:
    
The top `k` elements from the input.
Return type:
    
array
# mlx.core.trace
trace(a: array, /, offset: int = 0, axis1: int = 0, axis2: int = 1, dtype: Dtype | None = None, *, stream: None | Stream | Device = None) → array
    
Return the sum along a specified diagonal in the given array.
Parameters:
    
  * a (array) – Input array
  * offset (int, optional) – Offset of the diagonal from the main diagonal. Can be positive or negative. Default: `0`.
  * axis1 (int, optional) – The first axis of the 2-D sub-arrays from which the diagonals should be taken. Default: `0`.
  * axis2 (int, optional) – The second axis of the 2-D sub-arrays from which the diagonals should be taken. Default: `1`.
  * dtype (Dtype, optional) – Data type of the output array. If unspecified the output type is inferred from the input array.


Returns:
    
Sum of specified diagonal.
Return type:
    
array
# mlx.core.transpose
transpose(a: array, /, axes: Sequence[int] | None = None, *, stream: None | Stream | Device = None) → array
    
Transpose the dimensions of the array.
Parameters:
    
  * a (array) – Input array.
  * axes (list(int), optional) – Specifies the source axis for each axis in the new array. The default is to reverse the axes.


Returns:
    
The transposed array.
Return type:
    
array
# mlx.core.tri
tri(n: int, m: int, k: int, dtype: Dtype | None = None, *, stream: None | Stream | Device = None) → array
    
An array with ones at and below the given diagonal and zeros elsewhere.
Parameters:
    
  * n (int) – The number of rows in the output.
  * m (int, optional) – The number of cols in the output. Defaults to `None`.
  * k (int, optional) – The diagonal of the 2-D array. Defaults to `0`.
  * dtype (Dtype, optional) – Data type of the output array. Defaults to `float32`.
  * stream (Stream, optional) – Stream or device. Defaults to `None`.


Returns:
    
Array with its lower triangle filled with ones and zeros elsewhere
Return type:
    
array
# mlx.core.tril
tril(x: array, k: int, *, stream: None | Stream | Device = None) → array
    
Zeros the array above the given diagonal.
Parameters:
    
  * x (array) – Input array.
  * k (int, optional) – The diagonal of the 2-D array. Defaults to `0`.
  * stream (Stream, optional) – Stream or device. Defaults to `None`.


Returns:
    
Array zeroed above the given diagonal
Return type:
    
array
# mlx.core.triu
triu(x: array, k: int, *, stream: None | Stream | Device = None) → array
    
Zeros the array below the given diagonal.
Parameters:
    
  * x (array) – Input array.
  * k (int, optional) – The diagonal of the 2-D array. Defaults to `0`.
  * stream (Stream, optional) – Stream or device. Defaults to `None`.


Returns:
    
Array zeroed below the given diagonal
Return type:
    
array
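`tri`, `tril` and `triu` share one diagonal convention: entry `(i, j)` lies on diagonal `j - i`. The pure-Python sketch below spells that out on nested lists. It assumes the NumPy convention, including `m` defaulting to `n`, which the text above does not state explicitly.

```python
def tri(n, m=None, k=0):
    """Ones at and below diagonal k, zeros elsewhere."""
    m = n if m is None else m  # assumption: m defaults to n, as in NumPy
    return [[1.0 if j - i <= k else 0.0 for j in range(m)] for i in range(n)]


def tril(x, k=0):
    """Keep entries at and below diagonal k; zero the rest."""
    return [[v if j - i <= k else 0 for j, v in enumerate(row)] for i, row in enumerate(x)]


def triu(x, k=0):
    """Keep entries at and above diagonal k; zero the rest."""
    return [[v if j - i >= k else 0 for j, v in enumerate(row)] for i, row in enumerate(x)]


x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
lower = tril(x)
upper = triu(x, k=1)
```

A positive `k` moves the cut-off above the main diagonal and a negative `k` moves it below.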
# mlx.core.unflatten
unflatten(a: array, /, axis: int, shape: Sequence[int], *, stream: None | Stream | Device = None) → array
    
Unflatten an axis of an array to a shape.
Parameters:
    
  * a (array) – Input array.
  * axis (int) – The axis to unflatten.
  * shape (tuple(int)) – The shape to unflatten to. At most one entry can be `-1` in which case the corresponding size will be inferred.
  * stream (Stream, optional) – Stream or device. Defaults to `None` in which case the default stream of the default device is used.


Returns:
    
The unflattened array.
Return type:
    
array
Example
    
    >>> a = mx.array([1, 2, 3, 4])
    >>> mx.unflatten(a, 0, (2, -1))
    array([[1, 2], [3, 4]], dtype=int32)
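How a `-1` entry in `shape` is resolved can be sketched as a small shape calculation. `infer_unflatten_shape` is a hypothetical helper that works only on sizes, not on array data.

```python
def infer_unflatten_shape(axis_size, shape):
    """Resolve at most one -1 entry so that the product equals axis_size."""
    shape = list(shape)
    if shape.count(-1) > 1:
        raise ValueError("at most one entry can be -1")
    known = 1
    for s in shape:
        if s != -1:
            known *= s
    if -1 in shape:
        if known == 0 or axis_size % known != 0:
            raise ValueError("cannot infer the missing size")
        shape[shape.index(-1)] = axis_size // known
    elif known != axis_size:
        raise ValueError("shape does not match the axis size")
    return tuple(shape)


inferred = infer_unflatten_shape(4, (2, -1))
inferred_3d = infer_unflatten_shape(12, (3, -1, 2))
```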
    
# mlx.core.value_and_grad
value_and_grad(fun: Callable, argnums: int | Sequence[int] | None = None, argnames: str | Sequence[str] = []) → Callable
    
Returns a function which computes the value and gradient of `fun`.
The function passed to `value_and_grad()` should return either a scalar loss or a tuple in which the first element is a scalar loss and the remaining elements can be anything.
    
    import mlx.core as mx
    
    def mse(params, inputs, targets):
        outputs = forward(params, inputs)
        lvalue = (outputs - targets).square().mean()
        return lvalue
    
    # Returns lvalue, dlvalue/dparams
    lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)
    
    def lasso(params, inputs, targets, a=1.0, b=1.0):
        outputs = forward(params, inputs)
        mse = (outputs - targets).square().mean()
        l1 = mx.abs(outputs - targets).mean()
    
        loss = a*mse + b*l1
    
        return loss, mse, l1
    
    (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
    
Parameters:
    
  * fun (Callable) – A function which takes a variable number of `array` or trees of `array` and returns a scalar output `array` or a tuple the first element of which should be a scalar `array`.
  * argnums (int or list(int), optional) – Specify the index (or indices) of the positional arguments of `fun` to compute the gradient with respect to. If neither `argnums` nor `argnames` are provided `argnums` defaults to `0` indicating `fun`’s first argument.
  * argnames (str or list(str), optional) – Specify keyword arguments of `fun` to compute gradients with respect to. It defaults to [] so no gradients for keyword arguments by default.


Returns:
    
A function which returns a tuple where the first element is the output of `fun` and the second element is the gradient of the loss with respect to the arguments selected by `argnums` and `argnames`.
Return type:
    
Callable
# mlx.core.var
var(a: array, /, axis: None | int | Sequence[int] = None, keepdims: bool = False, ddof: int = 0, *, stream: None | Stream | Device = None) → array
    
Compute the variance(s) over the given axes.
Parameters:
    
  * a (array) – Input array.
  * axis (int or list(int), optional) – Optional axis or axes to reduce over. If unspecified this defaults to reducing over the entire array.
  * keepdims (bool, optional) – Keep reduced axes as singleton dimensions, defaults to False.
  * ddof (int, optional) – The divisor to compute the variance is `N - ddof`, defaults to 0.


Returns:
    
The output array of variances.
Return type:
    
array
# mlx.core.view
view(a: scalar | array, dtype: Dtype, stream: None | Stream | Device = None) → array
    
View the array as a different type.
The output shape changes along the last axis if the input array’s type and the input `dtype` do not have the same size.
Note: the view op does not imply that the input and output arrays share their underlying data. The view only guarantees that the binary representation of each element (or group of elements) is the same.
Parameters:
    
  * a (array) – Input array or scalar.
  * dtype (Dtype) – The data type to change to.


Returns:
    
The array with the new type.
Return type:
    
array
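The change in the last axis follows from reinterpreting the same bytes with a different element size. The sketch below uses Python's `struct` module to reinterpret `float32` values as `uint8` and back. The helper names are hypothetical, and little-endian byte order is an assumption of this illustration.

```python
import struct


def view_float32_as_uint8(values):
    """Reinterpret float32 values as bytes: the length grows by a factor of 4."""
    raw = struct.pack(f"<{len(values)}f", *values)
    return list(raw)


def view_uint8_as_float32(byte_values):
    """Reinterpret groups of 4 bytes as float32: the length shrinks by a factor of 4."""
    raw = bytes(byte_values)
    return list(struct.unpack(f"<{len(raw) // 4}f", raw))


as_bytes = view_float32_as_uint8([1.0, -2.5])
round_trip = view_uint8_as_float32(as_bytes)
```

Two `float32` elements become eight `uint8` elements, and viewing them as `float32` again recovers the original values bit for bit.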
# mlx.core.vjp
vjp(fun: Callable, primals: list[array], cotangents: list[array]) → tuple[list[array], list[array]]
    
Compute the vector-Jacobian product.
Computes the product of the `cotangents` with the Jacobian of a function `fun` evaluated at `primals`.
Parameters:
    
  * fun (Callable) – A function which takes a variable number of `array` and returns a single `array` or list of `array`.
  * primals (list(array)) – A list of `array` at which to evaluate the Jacobian.
  * cotangents (list(array)) – A list of `array` which are the “vector” in the vector-Jacobian product. The `cotangents` should be the same in number, shape, and type as the outputs of `fun`.


Returns:
    
A tuple of two lists: the outputs of `fun` evaluated at `primals`, and the vector-Jacobian products, which are the same in number, shape, and type as the `primals`.
Return type:
    
tuple(list(array), list(array))
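A hand-written example makes the shapes concrete. The pure-Python sketch below uses a hypothetical function `f` with two inputs and three outputs and a manually derived Jacobian; the cotangent has one entry per output, and the resulting product has one entry per input.

```python
def f(x):
    """Maps 2 inputs to 3 outputs."""
    x0, x1 = x
    return [x0 * x1, x0 + x1, x1 * x1]


def jacobian_f(x):
    """Jacobian of f, derived by hand: rows are outputs, columns are inputs."""
    x0, x1 = x
    return [[x1, x0], [1.0, 1.0], [0.0, 2.0 * x1]]


def vjp_f(primal, cotangent):
    """Compute cotangent^T @ J, one entry per input of f."""
    J = jacobian_f(primal)
    return [
        sum(cotangent[i] * J[i][j] for i in range(len(J)))
        for j in range(len(primal))
    ]


primal = [2.0, 3.0]
cotangent = [1.0, 1.0, 1.0]
outputs = f(primal)
product = vjp_f(primal, cotangent)
```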
# mlx.core.vmap
vmap(fun: Callable, in_axes: object = 0, out_axes: object = 0) → Callable
    
Returns a vectorized version of `fun`.
Parameters:
    
  * fun (Callable) – A function which takes a variable number of `array` or a tree of `array` and returns a variable number of `array` or a tree of `array`.
  * in_axes (int, optional) – An integer or a valid prefix tree of the inputs to `fun` where each node specifies the vmapped axis. If the value is `None` then the corresponding input(s) are not vmapped. Defaults to `0`.
  * out_axes (int, optional) – An integer or a valid prefix tree of the outputs of `fun` where each node specifies the vmapped axis. If the value is `None` then the corresponding output(s) are not vmapped. Defaults to `0`.


Returns:
    
The vectorized function.
Return type:
    
Callable
# mlx.core.where
where(condition: scalar | array, x: scalar | array, y: scalar | array, /, *, stream: None | Stream | Device = None) → array
    
Select from `x` or `y` according to `condition`.
The condition and input arrays must be the same shape or broadcastable with one another.
Parameters:
    
  * condition (array) – The condition array.
  * x (array) – The input selected from where condition is `True`.
  * y (array) – The input selected from where condition is `False`.


Returns:
    
The output containing elements selected from `x` and `y`.
Return type:
    
array
# mlx.core.zeros
zeros(shape: int | Sequence[int], dtype: Dtype | None = float32, *, stream: None | Stream | Device = None) → array
    
Construct an array of zeros.
Parameters:
    
  * shape (int or list(int)) – The shape of the output array.
  * dtype (Dtype, optional) – Data type of the output array. If unspecified the output type defaults to `float32`.


Returns:
    
The array of zeros with the specified shape.
Return type:
    
array
# mlx.core.zeros_like
zeros_like(a: array, /, *, stream: None | Stream | Device = None) → array
    
An array of zeros like the input.
Parameters:
    
a (array) – The input to take the shape and type from.
Returns:
    
The output array filled with zeros.
Return type:
    
array
# mlx.nn.average_gradients
average_gradients(gradients: Any, group: Group | None = None, all_reduce_size: int = 33554432, communication_type: Dtype | None = None, communication_stream: Stream | None = None)
    
Average the gradients across the distributed processes in the passed group.
This helper enables concatenating several gradients of small arrays to one big all reduce call for better networking performance.
Parameters:
    
  * gradients (Any) – The Python tree containing the gradients (it should have the same structure across processes)
  * group (Optional[Group]) – The group of processes to average the gradients. If set to `None` the global group is used. Default: `None`.
  * all_reduce_size (int) – Group arrays until their size in bytes exceeds this number. Perform one communication step per group of arrays. If less than or equal to 0, array grouping is disabled. Default: `32MiB`.
  * communication_type (Optional[Dtype]) – If provided cast to this type before performing the communication. Typically cast to a smaller float to reduce the communication size. Default: `None`.
  * communication_stream (Optional[Stream]) – The stream to use for the communication. If unspecified the default communication stream is used which can vary by back-end. Default: `None`.
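The grouping controlled by `all_reduce_size` can be pictured as a greedy pass over the gradient sizes. The following pure-Python sketch is one plausible reading of the description above; `group_indices_by_bytes` is a hypothetical helper, and the exact comparison used by the library is an assumption.

```python
def group_indices_by_bytes(sizes_in_bytes, all_reduce_size):
    """Greedily group array indices until a group's total size exceeds the limit."""
    if all_reduce_size <= 0:
        # Grouping disabled: one communication step per array.
        return [[i] for i in range(len(sizes_in_bytes))]
    groups, current, current_bytes = [], [], 0
    for i, size in enumerate(sizes_in_bytes):
        current.append(i)
        current_bytes += size
        if current_bytes > all_reduce_size:  # assumption: strict "exceeds"
            groups.append(current)
            current, current_bytes = [], 0
    if current:
        groups.append(current)  # leftover arrays form the final group
    return groups


groups = group_indices_by_bytes([10, 10, 10, 25, 5], all_reduce_size=20)
```

Each inner list corresponds to one all-reduce call, so many small gradients end up sharing a single communication step.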


# mlx.nn.quantize
quantize(model: Module, group_size: int = 64, bits: int = 4, *, mode: str = 'affine', class_predicate: Callable[[str, Module], bool | dict] | None = None)
    
Quantize the sub-modules of a module according to a predicate.
By default all layers that define a `to_quantized(group_size, bits)` method will be quantized. Both `Linear` and `Embedding` layers will be quantized. Note also, the module is updated in-place.
Parameters:
    
  * model (Module) – The model whose leaf modules may be quantized.
  * group_size (int) – The quantization group size (see `mlx.core.quantize()`). Default: `64`.
  * bits (int) – The number of bits per parameter (see `mlx.core.quantize()`). Default: `4`.
  * mode (str) – The quantization method to use (see `mlx.core.quantize()`). Default: `"affine"`.
  * class_predicate (Optional[Callable]) – A callable which receives the `Module` path and `Module` itself and returns `True` or a dict of params for to_quantized if it should be quantized and `False` otherwise. If `None`, then all layers that define a `to_quantized(group_size, bits)` method are quantized. Default: `None`.


# mlx.nn.value_and_grad
value_and_grad(model: Module, fn: Callable)
    
Transform the passed function `fn` to a function that computes the gradients of `fn` wrt the model’s trainable parameters and also its value.
Parameters:
    
  * model (Module) – The model whose trainable parameters to compute gradients for
  * fn (Callable) – The scalar function to compute gradients for


Returns:
    
A callable that returns the value of `fn` and the gradients wrt the trainable parameters of `model`
# mlx.optimizers.clip_grad_norm
clip_grad_norm(grads, max_norm)
    
Clips the global norm of the gradients.
This function ensures that the global norm of the gradients does not exceed `max_norm`. It scales down the gradients proportionally if their norm is greater than `max_norm`.
Example
    
    >>> grads = {"w1": mx.array([2, 3]), "w2": mx.array([1])}
    >>> clipped_grads, total_norm = clip_grad_norm(grads, max_norm=2.0)
    >>> print(clipped_grads)
    {"w1": mx.array([...]), "w2": mx.array([...])}
    
Parameters:
    
  * grads (dict) – A dictionary containing the gradient arrays.
  * max_norm (float) – The maximum allowed global norm of the gradients.


Returns:
    
The possibly rescaled gradients and the original gradient norm.
Return type:
    
(dict, float)
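The arithmetic behind global-norm clipping can be sketched in pure Python with lists standing in for arrays. `clip_global_norm` is a hypothetical helper written for this illustration; it follows the description above and is not the library's implementation.

```python
import math


def clip_global_norm(grads, max_norm):
    """Scale all gradients by one factor so their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for values in grads.values() for g in values))
    # Only scale down; gradients already within the limit are left unchanged.
    scale = max_norm / total_norm if total_norm > max_norm else 1.0
    clipped = {name: [g * scale for g in values] for name, values in grads.items()}
    return clipped, total_norm


grads = {"w1": [2.0, 3.0], "w2": [1.0]}
clipped, total_norm = clip_global_norm(grads, max_norm=2.0)
```

Here the original norm is `sqrt(2**2 + 3**2 + 1**2) = sqrt(14)`, so every gradient is multiplied by `2 / sqrt(14)` and the directions are preserved.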
# mlx.utils.tree_flatten
tree_flatten(tree: Any, prefix: str = '', is_leaf: Callable | None = None, destination: List[Tuple[str, Any]] | Dict[str, Any] | None = None) → List[Tuple[str, Any]] | Dict[str, Any]
    
Flattens a Python tree to a list of key, value tuples.
The keys are using the dot notation to define trees of arbitrary depth and complexity.
    
    from mlx.utils import tree_flatten
    
    print(tree_flatten([[[0]]]))
    # [("0.0.0", 0)]
    
    print(tree_flatten([[[0]]], prefix=".hello"))
    # [("hello.0.0.0", 0)]
    
    print(tree_flatten({"a": {"b": 1}}, destination={}))
    # {"a.b": 1}
    
Note
Dictionaries should have keys that are valid Python identifiers.
Parameters:
    
  * tree (Any) – The Python tree to be flattened.
  * prefix (str) – A prefix to use for the keys. The first character is always discarded.
  * is_leaf (callable) – An optional callable that returns True if the passed object is considered a leaf or False otherwise.
  * destination (list or dict, optional) – A list or dictionary to store the flattened tree. If None an empty list will be used. Default: `None`.


Returns:
    
The flat representation of the Python tree.
Return type:
    
Union[List[Tuple[str, Any]], Dict[str, Any]]
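The dot-notation keys and the discarded first character of `prefix` can be reproduced with a short recursive sketch. `flatten` is a hypothetical stand-in that handles only lists, tuples and dictionaries and ignores `is_leaf` and `destination`.

```python
def flatten(tree, prefix=""):
    """Flatten nested lists/tuples/dicts into (dotted_key, leaf) pairs."""
    if isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    elif isinstance(tree, dict):
        items = tree.items()
    else:
        # Leaf: drop the first character of the accumulated prefix (the leading dot).
        return [(prefix[1:], tree)]
    flat = []
    for key, value in items:
        flat.extend(flatten(value, f"{prefix}.{key}"))
    return flat


nested = flatten([[[0]]])
with_prefix = flatten([[[0]]], prefix=".hello")
from_dict = flatten({"a": {"b": 1}, "c": [2, 3]})
```

Every level contributes a `.key` segment, which is why the leading character of the final key is removed.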
# mlx.utils.tree_map
tree_map(fn: Callable, tree: Any, *rest: Any, is_leaf: Callable | None = None) → Any
    
Applies `fn` to the leaves of the Python tree `tree` and returns a new collection with the results.
If `rest` is provided, every item is assumed to be a superset of `tree` and the corresponding leaves are provided as extra positional arguments to `fn`. In that respect, `tree_map()` is closer to `itertools.starmap()` than to `map()`.
The keyword argument `is_leaf` decides what constitutes a leaf from `tree` similar to `tree_flatten()`.
    
    import mlx.nn as nn
    from mlx.utils import tree_map
    
    model = nn.Linear(10, 10)
    print(model.parameters().keys())
    # dict_keys(['weight', 'bias'])
    
    # square the parameters
    model.update(tree_map(lambda x: x*x, model.parameters()))
    
Parameters:
    
  * fn (callable) – The function that processes the leaves of the tree.
  * tree (Any) – The main Python tree that will be iterated upon.
  * rest (tuple[Any]) – Extra trees to be iterated together with `tree`.
  * is_leaf (callable, optional) – An optional callable that returns `True` if the passed object is considered a leaf or `False` otherwise.


Returns:
    
A Python tree with the new values returned by `fn`.
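The role of `rest` is easiest to see with two trees of the same structure. The pure-Python sketch below is a simplified stand-in (`map_tree` is hypothetical and ignores `is_leaf`) that walks `tree` and passes the matching leaves of each extra tree to `fn`.

```python
def map_tree(fn, tree, *rest):
    """Apply fn to each leaf of tree, together with the matching leaves of rest."""
    if isinstance(tree, (list, tuple)):
        return type(tree)(
            map_tree(fn, child, *(r[i] for r in rest))
            for i, child in enumerate(tree)
        )
    if isinstance(tree, dict):
        return {
            key: map_tree(fn, child, *(r[key] for r in rest))
            for key, child in tree.items()
        }
    return fn(tree, *rest)


params = {"w": [1.0, 2.0], "b": 3.0}
grads = {"w": [0.5, 0.5], "b": 1.0}
# A plain gradient step: each parameter leaf is paired with its gradient leaf.
updated = map_tree(lambda p, g: p - g, params, grads)
```

The structure of the result follows `tree`, which is why the extra trees must contain at least the same keys and indices.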
# mlx.utils.tree_map_with_path
tree_map_with_path(fn: Callable, tree: Any, *rest: Any, is_leaf: Callable | None = None, path: Any | None = None) → Any
    
Applies `fn` to the path and leaves of the Python tree `tree` and returns a new collection with the results.
This function is the same as `tree_map()` but the `fn` takes the path as the first argument followed by the remaining tree nodes.
Parameters:
    
  * fn (callable) – The function that processes the leaves of the tree.
  * tree (Any) – The main Python tree that will be iterated upon.
  * rest (tuple[Any]) – Extra trees to be iterated together with `tree`.
  * is_leaf (Optional[Callable]) – An optional callable that returns `True` if the passed object is considered a leaf or `False` otherwise.
  * path (Optional[Any]) – Prefix will be added to the result.


Returns:
    
A Python tree with the new values returned by `fn`.
Example
    
    >>> from mlx.utils import tree_map_with_path
    >>> tree = {"model": [{"w": 0, "b": 1}, {"w": 0, "b": 1}]}
    >>> new_tree = tree_map_with_path(lambda path, _: print(path), tree)
    model.0.w
    model.0.b
    model.1.w
    model.1.b
    
# mlx.utils.tree_reduce
tree_reduce(fn, tree, initializer=None, is_leaf=None)
    
Applies a reduction to the leaves of a Python tree.
This function reduces Python trees into an accumulated result by applying the provided function `fn` to the leaves of the tree.
Example
    
    >>> from mlx.utils import tree_reduce
    >>> tree = {"a": [1, 2, 3], "b": [4, 5]}
    >>> tree_reduce(lambda acc, x: acc + x, tree, 0)
    15
    
Parameters:
    
  * fn (callable) – The reducer function that takes two arguments (accumulator, current value) and returns the updated accumulator.
  * tree (Any) – The Python tree to reduce. It can be any nested combination of lists, tuples, or dictionaries.
  * initializer (Any, optional) – The initial value to start the reduction. If not provided, the first leaf value is used.
  * is_leaf (callable, optional) – A function to determine if an object is a leaf, returning `True` for leaf nodes and `False` otherwise.


Returns:
    
The accumulated value.
Return type:
    
Any
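The reduction order and the handling of a missing initializer can be sketched in plain Python. This is an illustration of the behaviour described above, not MLX's implementation: `tree_reduce_sketch` is a hypothetical name, it only walks lists, tuples and dictionaries, and it uses a private sentinel instead of `None` to detect a missing initializer.

```python
_MISSING = object()  # sentinel so that None can still be used as a real initializer


def tree_reduce_sketch(fn, tree, initializer=_MISSING, is_leaf=None):
    """Illustrative stand-in for tree_reduce (hypothetical helper, not MLX code)."""

    def leaves(node):
        # Yield leaves depth-first, in container order.
        if is_leaf is not None and is_leaf(node):
            yield node
        elif isinstance(node, (list, tuple)):
            for child in node:
                yield from leaves(child)
        elif isinstance(node, dict):
            for child in node.values():
                yield from leaves(child)
        else:
            yield node

    acc = initializer
    for leaf in leaves(tree):
        # Without an initializer, the first leaf seeds the accumulator.
        acc = leaf if acc is _MISSING else fn(acc, leaf)
    return acc


tree = {"a": [1, 2, 3], "b": [4, 5]}
total = tree_reduce_sketch(lambda acc, x: acc + x, tree, 0)
largest = tree_reduce_sketch(max, tree)  # no initializer: starts from the first leaf
# Treat each list as a single leaf and count how many lists there are.
num_lists = tree_reduce_sketch(
    lambda acc, _: acc + 1, tree, 0, is_leaf=lambda n: isinstance(n, list)
)
print(total, largest, num_lists)
```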
# mlx.utils.tree_unflatten
tree_unflatten(tree: List[Tuple[str, Any]] | Dict[str, Any]) → Any
    
Recreate a Python tree from its flat representation.
    
    from mlx.utils import tree_unflatten
    
    d = tree_unflatten([("hello.world", 42)])
    print(d)
    # {"hello": {"world": 42}}
    
    d = tree_unflatten({"hello.world": 42})
    print(d)
    # {"hello": {"world": 42}}
    
Parameters:
    
tree (list[tuple[str, Any]] or dict[str, Any]) – The flat representation of a Python tree. For instance as returned by `tree_flatten()`.
Returns:
    
A Python tree.
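The dotted-key convention used in the example above can be illustrated with a short pure-Python sketch. `unflatten_sketch` is a hypothetical helper, not the MLX function: it keeps every path component as a dictionary key and makes no attempt to rebuild lists.

```python
def unflatten_sketch(flat):
    """Rebuild nested dictionaries from dotted keys (illustrative only)."""
    # Accept either a list of (key, value) pairs or a flat dictionary.
    items = flat.items() if isinstance(flat, dict) else flat
    root = {}
    for dotted_key, value in items:
        *parents, last = dotted_key.split(".")
        node = root
        for part in parents:
            # Create intermediate dictionaries on demand.
            node = node.setdefault(part, {})
        node[last] = value
    return root


from_pairs = unflatten_sketch([("hello.world", 42)])
from_dict = unflatten_sketch({"layer.weight": 1, "layer.bias": 2, "scale": 3})
print(from_pairs)
print(from_dict)
```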
# mlx.core.Stream
class Stream
    
A stream for running operations on a given device.
__init__(*args, **kwargs)
    
Methods
`__init__`(*args, **kwargs)  
Attributes
`device`
(self) -> mlx.core.Device  
# Array
`array`
An N-dimensional array object.  
`array.astype`(self, dtype[, stream])
Cast the array to a specified type.  
`array.at`
Used to apply updates at the given indices.  
`array.item`(self)
Access the value of a scalar array.  
`array.tolist`(self)
Convert the array to a Python `list`.  
`array.dtype`
The array's `Dtype`.  
`array.itemsize`
The size of the array's datatype in bytes.  
`array.nbytes`
The number of bytes in the array.  
`array.ndim`
The array's dimension.  
`array.shape`
The shape of the array as a Python tuple.  
`array.size`
Number of elements in the array.  
`array.real`
The real part of a complex array.  
`array.imag`
The imaginary part of a complex array.  
`array.abs`(self, *[, stream])
See `abs()`.  
`array.all`(self[, axis, keepdims, stream])
See `all()`.  
`array.any`(self[, axis, keepdims, stream])
See `any()`.  
`array.argmax`(self[, axis, keepdims, stream])
See `argmax()`.  
`array.argmin`(self[, axis, keepdims, stream])
See `argmin()`.  
`array.conj`(self, *[, stream])
See `conj()`.  
`array.cos`(self, *[, stream])
See `cos()`.  
`array.cummax`(self[, axis, reverse, ...])
See `cummax()`.  
`array.cummin`(self[, axis, reverse, ...])
See `cummin()`.  
`array.cumprod`(self[, axis, reverse, ...])
See `cumprod()`.  
`array.cumsum`(self[, axis, reverse, ...])
See `cumsum()`.  
`array.diag`(self[, k, stream])
Extract a diagonal or construct a diagonal matrix.  
`array.diagonal`(self[, offset, axis1, axis2, ...])
See `diagonal()`.  
`array.exp`(self, *[, stream])
See `exp()`.  
`array.flatten`(self[, start_axis, end_axis, ...])
See `flatten()`.  
`array.log`(self, *[, stream])
See `log()`.  
`array.log10`(self, *[, stream])
See `log10()`.  
`array.log1p`(self, *[, stream])
See `log1p()`.  
`array.log2`(self, *[, stream])
See `log2()`.  
`array.logcumsumexp`(self[, axis, reverse, ...])
See `logcumsumexp()`.  
`array.logsumexp`(self[, axis, keepdims, stream])
See `logsumexp()`.  
`array.max`(self[, axis, keepdims, stream])
See `max()`.  
`array.mean`(self[, axis, keepdims, stream])
See `mean()`.  
`array.min`(self[, axis, keepdims, stream])
See `min()`.  
`array.moveaxis`(self, source, destination, *)
See `moveaxis()`.  
`array.prod`(self[, axis, keepdims, stream])
See `prod()`.  
`array.reciprocal`(self, *[, stream])
See `reciprocal()`.  
`array.reshape`(self, *shape[, stream])
Equivalent to `reshape()` but the shape can be passed either as a `tuple` or as separate arguments.  
`array.round`(self[, decimals, stream])
See `round()`.  
`array.rsqrt`(self, *[, stream])
See `rsqrt()`.  
`array.sin`(self, *[, stream])
See `sin()`.  
`array.split`(self, indices_or_sections[, ...])
See `split()`.  
`array.sqrt`(self, *[, stream])
See `sqrt()`.  
`array.square`(self, *[, stream])
See `square()`.  
`array.squeeze`(self[, axis, stream])
See `squeeze()`.  
`array.std`(self[, axis, keepdims, ddof, stream])
See `std()`.  
`array.sum`(self[, axis, keepdims, stream])
See `sum()`.  
`array.swapaxes`(self, axis1, axis2, *[, stream])
See `swapaxes()`.  
`array.transpose`(self, *axes[, stream])
Equivalent to `transpose()` but the axes can be passed either as a tuple or as separate arguments.  
`array.T`
Equivalent to calling `self.transpose()` with no arguments.  
`array.var`(self[, axis, keepdims, ddof, stream])
See `var()`.  
`array.view`(self, dtype, *[, stream])
See `view()`.  
# CUDA
`is_available`()
Check if the CUDA back-end is available.  
# Data Types
The default floating point type is `float32` and the default integer type is `int32`. The table below shows supported values for `Dtype`.
Supported Data Types

| Type | Bytes | Description |
|---|---|---|
| `bool_` | 1 | Boolean (`True`, `False`) data type |
| `uint8` | 1 | 8-bit unsigned integer |
| `uint16` | 2 | 16-bit unsigned integer |
| `uint32` | 4 | 32-bit unsigned integer |
| `uint64` | 8 | 64-bit unsigned integer |
| `int8` | 1 | 8-bit signed integer |
| `int16` | 2 | 16-bit signed integer |
| `int32` | 4 | 32-bit signed integer |
| `int64` | 8 | 64-bit signed integer |
| `bfloat16` | 2 | 16-bit brain float (e8, m7) |
| `float16` | 2 | 16-bit IEEE float (e5, m10) |
| `float32` | 4 | 32-bit float |
| `float64` | 8 | 64-bit double |
| `complex64` | 8 | 64-bit complex float |

Note
Arrays with type `float64` only work with CPU operations. Using `float64` arrays on the GPU will result in an exception.
Data types are arranged in a hierarchy. See the `DtypeCategory` object documentation for more information. Use `issubdtype()` to determine if one `dtype` (or category) is a subtype of another category.
`Dtype`
An object to hold the type of an `array`.  
`DtypeCategory`(value)
Type to hold categories of `dtypes`.  
`issubdtype`(arg1, arg2)
Check if a `Dtype` or `DtypeCategory` is a subtype of another.  
`finfo`
Get information on floating-point types.  
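The byte sizes listed in the table above make it easy to estimate how much memory an array needs: the element count multiplied by the size of one element, which is the quantity that `array.nbytes` is described as reporting. The sketch below is plain Python arithmetic and does not call MLX; `ITEMSIZE` and `estimate_nbytes` are hypothetical names, and only a few table entries are copied.

```python
import math

# Bytes per element for a few dtypes, copied from the table above.
ITEMSIZE = {
    "int8": 1,
    "float16": 2,
    "bfloat16": 2,
    "float32": 4,
    "complex64": 8,
}


def estimate_nbytes(shape, dtype):
    """Element count times bytes per element."""
    return math.prod(shape) * ITEMSIZE[dtype]


weights_shape = (4096, 4096)
for dtype in ("float32", "float16", "int8"):
    mib = estimate_nbytes(weights_shape, dtype) / 2**20
    print(f"{dtype}: {mib:.0f} MiB")
```

Halving the element size halves the footprint, which is the usual reason for casting weights with `array.astype` to a 16-bit type.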
# Devices and Streams
`Device`
A device to run operations on.  
`Stream`
A stream for running operations on a given device.  
`default_device`()
Get the default device.  
`set_default_device`(device)
Set the default device.  
`default_stream`(device)
Get the device's default stream.  
`new_stream`(device)
Make a new stream on the given device.  
`set_default_stream`(stream)
Set the default stream.  
`stream`(s)
Create a context manager to set the default device and stream.  
`synchronize`([stream])
Synchronize with the given stream.  
# Distributed Communication
MLX provides a distributed communication package using MPI. The MPI library is loaded at runtime; if MPI is available then distributed communication is also made available.
`Group`
An `mlx.core.distributed.Group` represents a group of independent mlx processes that can communicate.  
`is_available`()
Check if a communication backend is available.  
`init`([strict, backend])
Initialize the communication backend and create the global communication group.  
`all_sum`(x, *[, group, stream])
All reduce sum.  
`all_gather`(x, *[, group, stream])
Gather arrays from all processes.  
`send`(x, dst, *[, group, stream])
Send an array from the current process to the process that has rank `dst` in the group.  
`recv`(shape, dtype, src, *[, group, stream])
Receive an array with shape `shape` and dtype `dtype` from the process with rank `src`.  
`recv_like`(x, src, *[, group, stream])
Receive an array with the same shape and type as `x` from the process with rank `src`.  
# Export Functions
`export_function`(file, fun, *args[, shapeless])
Export a function to a file.  
`import_function`(file)
Import a function from a file.  
`exporter`(file, fun, *[, shapeless])
Make a callable object to export multiple traces of a function to a file.  
`export_to_dot`(file, *args, **kwargs)
Export a graph to DOT format for visualization.  
# Fast
`rms_norm`(x, weight, eps, *[, stream])
Root Mean Square normalization (RMS norm).  
`layer_norm`(x, weight, bias, eps, *[, stream])
Layer normalization.  
`rope`(a, dims, *, traditional, base, scale, ...)
Apply rotary positional encoding to the input.  
`scaled_dot_product_attention`(q, k, v, *, scale)
A fast implementation of multi-head attention: `O = softmax(Q @ K.T, dim=-1) @ V`.  
`metal_kernel`(name, input_names, ...[, ...])
A jit-compiled custom Metal kernel defined from a source string.  
`cuda_kernel`(name, input_names, output_names, ...)
A jit-compiled custom CUDA kernel defined from a source string.  
# FFT
`fft`(a[, n, axis, stream])
One dimensional discrete Fourier Transform.  
`ifft`(a[, n, axis, stream])
One dimensional inverse discrete Fourier Transform.  
`fft2`(a[, s, axes, stream])
Two dimensional discrete Fourier Transform.  
`ifft2`(a[, s, axes, stream])
Two dimensional inverse discrete Fourier Transform.  
`fftn`(a[, s, axes, stream])
n-dimensional discrete Fourier Transform.  
`ifftn`(a[, s, axes, stream])
n-dimensional inverse discrete Fourier Transform.  
`rfft`(a[, n, axis, stream])
One dimensional discrete Fourier Transform on a real input.  
`irfft`(a[, n, axis, stream])
The inverse of `rfft()`.  
`rfft2`(a[, s, axes, stream])
Two dimensional real discrete Fourier Transform.  
`irfft2`(a[, s, axes, stream])
The inverse of `rfft2()`.  
`rfftn`(a[, s, axes, stream])
n-dimensional real discrete Fourier Transform.  
`irfftn`(a[, s, axes, stream])
The inverse of `rfftn()`.  
`fftshift`(a[, axes, stream])
Shift the zero-frequency component to the center of the spectrum.  
`ifftshift`(a[, axes, stream])
The inverse of `fftshift()`.  
# Linear Algebra
`inv`(a, *[, stream])
Compute the inverse of a square matrix.  
`tri_inv`(a[, upper, stream])
Compute the inverse of a triangular square matrix.  
`norm`(a, /[, ord, axis, keepdims, stream])
Matrix or vector norm.  
`cholesky`(a[, upper, stream])
Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix.  
`cholesky_inv`(L[, upper, stream])
Compute the inverse of a real symmetric positive semi-definite matrix using its Cholesky decomposition.  
`cross`(a, b[, axis, stream])
Compute the cross product of two arrays along a specified axis.  
`qr`(a, *[, stream])
The QR factorization of the input matrix.  
`svd`(a[, compute_uv, stream])
The Singular Value Decomposition (SVD) of the input matrix.  
`eigvals`(a, *[, stream])
Compute the eigenvalues of a square matrix.  
`eig`(a, *[, stream])
Compute the eigenvalues and eigenvectors of a square matrix.  
`eigvalsh`(a[, UPLO, stream])
Compute the eigenvalues of a complex Hermitian or real symmetric matrix.  
`eigh`(a[, UPLO, stream])
Compute the eigenvalues and eigenvectors of a complex Hermitian or real symmetric matrix.  
`lu`(a, *[, stream])
Compute the LU factorization of the given matrix `A`.  
`lu_factor`(a, *[, stream])
Computes a compact representation of the LU factorization.  
`pinv`(a, *[, stream])
Compute the (Moore-Penrose) pseudo-inverse of a matrix.  
`solve`(a, b, *[, stream])
Compute the solution to a system of linear equations `AX = B`.  
`solve_triangular`(a, b, *[, upper, stream])
Computes the solution of a triangular system of linear equations `AX = B`.  
# Memory Management
`get_active_memory`()
Get the actively used memory in bytes.  
`get_peak_memory`()
Get the peak amount of used memory in bytes.  
`reset_peak_memory`()
Reset the peak memory to zero.  
`get_cache_memory`()
Get the cache size in bytes.  
`set_memory_limit`(limit)
Set the memory limit.  
`set_cache_limit`(limit)
Set the free cache limit.  
`set_wired_limit`(limit)
Set the wired size limit.  
`clear_cache`()
Clear the memory cache.  
# Metal
`is_available`()
Check if the Metal back-end is available.  
`device_info`()
Get information about the GPU device and system settings.  
`start_capture`(path)
Start a Metal capture.  
`stop_capture`()
Stop a Metal capture.  
# Neural Networks
Writing arbitrarily complex neural networks in MLX can be done using only `mlx.core.array` and `mlx.core.value_and_grad()`. However, this requires the user to rewrite the same simple neural network operations over and over, and to handle all parameter state and initialization manually and explicitly.
The module `mlx.nn` solves this problem by providing an intuitive way of composing neural network layers, initializing their parameters, freezing them for finetuning and more.
## Quick Start with Neural Networks
    
    import mlx.core as mx
    import mlx.nn as nn
    
    class MLP(nn.Module):
        def __init__(self, in_dims: int, out_dims: int):
            super().__init__()
    
            self.layers = [
                nn.Linear(in_dims, 128),
                nn.Linear(128, 128),
                nn.Linear(128, out_dims),
            ]
    
        def __call__(self, x):
            for i, l in enumerate(self.layers):
                x = mx.maximum(x, 0) if i > 0 else x
                x = l(x)
            return x
    
    # The model is created with all its parameters but nothing is initialized
    # yet because MLX is lazily evaluated
    mlp = MLP(2, 10)
    
    # We can access its parameters by calling mlp.parameters()
    params = mlp.parameters()
    print(params["layers"][0]["weight"].shape)
    
    # Printing a parameter will cause it to be evaluated and thus initialized
    print(params["layers"][0])
    
    # We can also force evaluate all parameters to initialize the model
    mx.eval(mlp.parameters())
    
    # A simple loss function.
    # NOTE: It doesn't matter how it uses the mlp model. It currently captures
    #       it from the local scope. It could be a positional argument or a
    #       keyword argument.
    def l2_loss(x, y):
        y_hat = mlp(x)
        return (y_hat - y).square().mean()
    
    # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the
    # gradient with respect to `mlp.trainable_parameters()`
    loss_and_grad = nn.value_and_grad(mlp, l2_loss)
    
## The Module Class
The workhorse of any neural network library is the `Module` class. In MLX the `Module` class is a container of `mlx.core.array` or `Module` instances. Its main function is to provide a way to recursively access and update its parameters and those of its submodules.
### Parameters
A parameter of a module is any public member of type `mlx.core.array` (its name should not start with `_`). It can be arbitrarily nested in other `Module` instances or lists and dictionaries.
`Module.parameters()` can be used to extract a nested dictionary with all the parameters of a module and its submodules.
A `Module` can also keep track of “frozen” parameters. See the `Module.freeze()` method for more details. When using `mlx.nn.value_and_grad()`, the gradients returned are with respect to the trainable (that is, non-frozen) parameters only.
### Updating the Parameters
MLX modules allow accessing and updating individual parameters. However, most of the time we need to update large subsets of a module’s parameters. This action is performed by `Module.update()`.
### Inspecting Modules
The simplest way to see the model architecture is to print it. Following along with the above example, you can print the `MLP` with:
    
    print(mlp)
    
This will display:
    
    MLP(
      (layers.0): Linear(input_dims=2, output_dims=128, bias=True)
      (layers.1): Linear(input_dims=128, output_dims=128, bias=True)
      (layers.2): Linear(input_dims=128, output_dims=10, bias=True)
    )
    
To get more detailed information on the arrays in a `Module` you can use `mlx.utils.tree_map()` on the parameters. For example, to see the shapes of all the parameters in a `Module` do:
    
    from mlx.utils import tree_map
    shapes = tree_map(lambda p: p.shape, mlp.parameters())
    
As another example, you can count the number of parameters in a `Module` with:
    
    from mlx.utils import tree_flatten
    num_params = sum(v.size for _, v in tree_flatten(mlp.parameters()))
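As a sanity check, the count for the `MLP(2, 10)` printed above can be worked out by hand. The arithmetic below is plain Python and assumes that each `Linear` layer with `bias=True` stores `input_dims * output_dims` weights plus `output_dims` biases; `layer_dims` and `expected_num_params` are names introduced only for this illustration.

```python
# (input_dims, output_dims) for the three Linear layers of MLP(2, 10).
layer_dims = [(2, 128), (128, 128), (128, 10)]

# Weight matrix entries plus one bias per output unit.
per_layer = [in_dims * out_dims + out_dims for in_dims, out_dims in layer_dims]
expected_num_params = sum(per_layer)

print(per_layer)            # [384, 16512, 1290]
print(expected_num_params)  # 18186
```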
    
## Value and Grad
Using a `Module` does not preclude using MLX’s higher-order function transformations (`mlx.core.value_and_grad()`, `mlx.core.grad()`, etc.). However, these function transformations assume pure functions, namely the parameters should be passed as an argument to the function being transformed.
There is an easy pattern to achieve that with MLX modules:
    
    model = ...
    
    def f(params, other_inputs):
        model.update(params)  # <---- Necessary to make the model use the passed parameters
        return model(other_inputs)
    
    f(model.trainable_parameters(), mx.zeros((10,)))
    
However, `mlx.nn.value_and_grad()` provides precisely this pattern and only computes the gradients with respect to the trainable parameters of the model.
In detail:
  * it wraps the passed function with a function that calls `Module.update()` to make sure the model is using the provided parameters.
  * it calls `mlx.core.value_and_grad()` to transform the function into a function that also computes the gradients with respect to the passed parameters.
  * it wraps the returned function with a function that passes the trainable parameters as the first argument to the function returned by `mlx.core.value_and_grad()`
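The three wrapping steps listed above can be sketched without MLX. In the block below, `value_and_grad_sketch` mirrors the steps, while `ToyModel` and `finite_difference_transform` are hypothetical stand-ins for a `Module` and for `mlx.core.value_and_grad()` so that the example runs with the standard library alone. It illustrates the pattern and is not the library's code.

```python
def value_and_grad_sketch(model, fn, transform):
    # Step 1: make the model use whatever parameters the transform passes in.
    def inner(params, *args, **kwargs):
        model.update(params)
        return fn(*args, **kwargs)

    # Step 2: differentiate with respect to the first argument (the parameters).
    value_grad_fn = transform(inner)

    # Step 3: supply the trainable parameters automatically on every call.
    def wrapped(*args, **kwargs):
        return value_grad_fn(model.trainable_parameters(), *args, **kwargs)

    return wrapped


class ToyModel:
    """Hypothetical stand-in for a Module with a single scalar parameter."""

    def __init__(self):
        self.w = 3.0

    def trainable_parameters(self):
        return {"w": self.w}

    def update(self, params):
        self.w = params["w"]

    def __call__(self, x):
        return self.w * x


def finite_difference_transform(f, h=1e-6):
    """Stand-in for a value-and-gradient transform, using central differences."""

    def value_and_grad(params, *args, **kwargs):
        grads = {}
        for key, v in params.items():
            up = f({**params, key: v + h}, *args, **kwargs)
            down = f({**params, key: v - h}, *args, **kwargs)
            grads[key] = (up - down) / (2 * h)
        # Evaluate last with the original parameters so the model is restored.
        value = f(params, *args, **kwargs)
        return value, grads

    return value_and_grad


model = ToyModel()


def loss(x, y):
    # The loss captures `model` from the enclosing scope.
    return (model(x) - y) ** 2


loss_and_grad = value_and_grad_sketch(model, loss, finite_difference_transform)
value, grads = loss_and_grad(2.0, 4.0)
print(value, grads)
```

The caller never passes parameters explicitly: the wrapper reads them from the model on each call, which is why the loss function only needs the data arguments.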


`value_and_grad`(model, fn)
Transform the passed function `fn` to a function that computes the gradients of `fn` wrt the model's trainable parameters and also its value.  
`quantize`(model[, group_size, bits, mode, ...])
Quantize the sub-modules of a module according to a predicate.  
`average_gradients`(gradients[, group, ...])
Average the gradients across the distributed processes in the passed group.  
  * Module
    * `Module`
    * mlx.nn.Module.training
      * `Module.training`
    * mlx.nn.Module.state
      * `Module.state`
    * mlx.nn.Module.apply
      * `Module.apply()`
    * mlx.nn.Module.apply_to_modules
      * `Module.apply_to_modules()`
    * mlx.nn.Module.children
      * `Module.children()`
    * mlx.nn.Module.eval
      * `Module.eval()`
    * mlx.nn.Module.filter_and_map
      * `Module.filter_and_map()`
    * mlx.nn.Module.freeze
      * `Module.freeze()`
    * mlx.nn.Module.leaf_modules
      * `Module.leaf_modules()`
    * mlx.nn.Module.load_weights
      * `Module.load_weights()`
    * mlx.nn.Module.modules
      * `Module.modules()`
    * mlx.nn.Module.named_modules
      * `Module.named_modules()`
    * mlx.nn.Module.parameters
      * `Module.parameters()`
    * mlx.nn.Module.save_weights
      * `Module.save_weights()`
    * mlx.nn.Module.set_dtype
      * `Module.set_dtype()`
    * mlx.nn.Module.train
      * `Module.train()`
    * mlx.nn.Module.trainable_parameters
      * `Module.trainable_parameters()`
    * mlx.nn.Module.unfreeze
      * `Module.unfreeze()`
    * mlx.nn.Module.update
      * `Module.update()`
    * mlx.nn.Module.update_modules
      * `Module.update_modules()`
  * Layers
    * mlx.nn.ALiBi
      * `ALiBi`
    * mlx.nn.AvgPool1d
      * `AvgPool1d`
    * mlx.nn.AvgPool2d
      * `AvgPool2d`
    * mlx.nn.AvgPool3d
      * `AvgPool3d`
    * mlx.nn.BatchNorm
      * `BatchNorm`
    * mlx.nn.CELU
      * `CELU`
    * mlx.nn.Conv1d
      * `Conv1d`
    * mlx.nn.Conv2d
      * `Conv2d`
    * mlx.nn.Conv3d
      * `Conv3d`
    * mlx.nn.ConvTranspose1d
      * `ConvTranspose1d`
    * mlx.nn.ConvTranspose2d
      * `ConvTranspose2d`
    * mlx.nn.ConvTranspose3d
      * `ConvTranspose3d`
    * mlx.nn.Dropout
      * `Dropout`
    * mlx.nn.Dropout2d
      * `Dropout2d`
    * mlx.nn.Dropout3d
      * `Dropout3d`
    * mlx.nn.Embedding
      * `Embedding`
    * mlx.nn.ELU
      * `ELU`
    * mlx.nn.GELU
      * `GELU`
    * mlx.nn.GLU
      * `GLU`
    * mlx.nn.GroupNorm
      * `GroupNorm`
    * mlx.nn.GRU
      * `GRU`
    * mlx.nn.HardShrink
      * `HardShrink`
    * mlx.nn.HardTanh
      * `HardTanh`
    * mlx.nn.Hardswish
      * `Hardswish`
    * mlx.nn.InstanceNorm
      * `InstanceNorm`
    * mlx.nn.LayerNorm
      * `LayerNorm`
    * mlx.nn.LeakyReLU
      * `LeakyReLU`
    * mlx.nn.Linear
      * `Linear`
    * mlx.nn.LogSigmoid
      * `LogSigmoid`
    * mlx.nn.LogSoftmax
      * `LogSoftmax`
    * mlx.nn.LSTM
      * `LSTM`
    * mlx.nn.MaxPool1d
      * `MaxPool1d`
    * mlx.nn.MaxPool2d
      * `MaxPool2d`
    * mlx.nn.MaxPool3d
      * `MaxPool3d`
    * mlx.nn.Mish
      * `Mish`
    * mlx.nn.MultiHeadAttention
      * `MultiHeadAttention`
    * mlx.nn.PReLU
      * `PReLU`
    * mlx.nn.QuantizedEmbedding
      * `QuantizedEmbedding`
    * mlx.nn.QuantizedLinear
      * `QuantizedLinear`
    * mlx.nn.RMSNorm
      * `RMSNorm`
    * mlx.nn.ReLU
      * `ReLU`
    * mlx.nn.ReLU2
      * `ReLU2`
    * mlx.nn.ReLU6
      * `ReLU6`
    * mlx.nn.RNN
      * `RNN`
    * mlx.nn.RoPE
      * `RoPE`
    * mlx.nn.SELU
      * `SELU`
    * mlx.nn.Sequential
      * `Sequential`
    * mlx.nn.Sigmoid
      * `Sigmoid`
    * mlx.nn.SiLU
      * `SiLU`
    * mlx.nn.SinusoidalPositionalEncoding
      * `SinusoidalPositionalEncoding`
    * mlx.nn.Softmin
      * `Softmin`
    * mlx.nn.Softshrink
      * `Softshrink`
    * mlx.nn.Softsign
      * `Softsign`
    * mlx.nn.Softmax
      * `Softmax`
    * mlx.nn.Softplus
      * `Softplus`
    * mlx.nn.Step
      * `Step`
    * mlx.nn.Tanh
      * `Tanh`
    * mlx.nn.Transformer
      * `Transformer`
    * mlx.nn.Upsample
      * `Upsample`
  * Functions
    * mlx.nn.elu
      * `elu`
    * mlx.nn.celu
      * `celu`
    * mlx.nn.gelu
      * `gelu`
    * mlx.nn.gelu_approx
      * `gelu_approx`
    * mlx.nn.gelu_fast_approx
      * `gelu_fast_approx`
    * mlx.nn.glu
      * `glu`
    * mlx.nn.hard_shrink
      * `hard_shrink`
    * mlx.nn.hard_tanh
      * `hard_tanh`
    * mlx.nn.hardswish
      * `hardswish`
    * mlx.nn.leaky_relu
      * `leaky_relu`
    * mlx.nn.log_sigmoid
      * `log_sigmoid`
    * mlx.nn.log_softmax
      * `log_softmax`
    * mlx.nn.mish
      * `mish`
    * mlx.nn.prelu
      * `prelu`
    * mlx.nn.relu
      * `relu`
    * mlx.nn.relu2
      * `relu2`
    * mlx.nn.relu6
      * `relu6`
    * mlx.nn.selu
      * `selu`
    * mlx.nn.sigmoid
      * `sigmoid`
    * mlx.nn.silu
      * `silu`
    * mlx.nn.softmax
      * `softmax`
    * mlx.nn.softmin
      * `softmin`
    * mlx.nn.softplus
      * `softplus`
    * mlx.nn.softshrink
      * `softshrink`
    * mlx.nn.step
      * `step`
    * mlx.nn.tanh
      * `tanh`
  * Loss Functions
    * mlx.nn.losses.binary_cross_entropy
      * `binary_cross_entropy`
    * mlx.nn.losses.cosine_similarity_loss
      * `cosine_similarity_loss`
    * mlx.nn.losses.cross_entropy
      * `cross_entropy`
    * mlx.nn.losses.gaussian_nll_loss
      * `gaussian_nll_loss`
    * mlx.nn.losses.hinge_loss
      * `hinge_loss`
    * mlx.nn.losses.huber_loss
      * `huber_loss`
    * mlx.nn.losses.kl_div_loss
      * `kl_div_loss`
    * mlx.nn.losses.l1_loss
      * `l1_loss`
    * mlx.nn.losses.log_cosh_loss
      * `log_cosh_loss`
    * mlx.nn.losses.margin_ranking_loss
      * `margin_ranking_loss`
    * mlx.nn.losses.mse_loss
      * `mse_loss`
    * mlx.nn.losses.nll_loss
      * `nll_loss`
    * mlx.nn.losses.smooth_l1_loss
      * `smooth_l1_loss`
    * mlx.nn.losses.triplet_loss
      * `triplet_loss`
  * Initializers
    * mlx.nn.init.constant
      * `constant()`
    * mlx.nn.init.normal
      * `normal()`
    * mlx.nn.init.uniform
      * `uniform()`
    * mlx.nn.init.identity
      * `identity()`
    * mlx.nn.init.glorot_normal
      * `glorot_normal()`
    * mlx.nn.init.glorot_uniform
      * `glorot_uniform()`
    * mlx.nn.init.he_normal
      * `he_normal()`
    * mlx.nn.init.he_uniform
      * `he_uniform()`


# mlx.nn.ALiBi
class ALiBi
    
Methods
`create_alibi_matrix`(q_sequence_length, ...)  
`create_alibi_slope`(num_heads)  
# mlx.nn.AvgPool1d
class AvgPool1d(kernel_size: int | Tuple[int], stride: int | Tuple[int] | None = None, padding: int | Tuple[int] = 0)
    
Applies 1-dimensional average pooling.
Spatially downsamples the input by taking the average of a sliding window of size `kernel_size` and sliding stride `stride`.
Parameters:
    
  * kernel_size (int or tuple(int)) – The size of the pooling window kernel.
  * stride (int or tuple(int), optional) – The stride of the pooling window. Default: `kernel_size`.
  * padding (int or tuple(int), optional) – How much zero padding to apply to the input. The padding amount is applied to both sides of the spatial axis. Default: `0`.


Examples
    
    >>> import mlx.core as mx
    >>> import mlx.nn.layers as nn
    >>> x = mx.random.normal(shape=(4, 16, 5))
    >>> pool = nn.AvgPool1d(kernel_size=2, stride=2)
    >>> pool(x)
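The shape of the pooled output in the example above can be predicted with the usual sliding-window arithmetic. The helper below is a plain-Python sketch under that assumption; `pooled_length` is a hypothetical name and the formula is the standard one rather than something quoted from the MLX source.

```python
def pooled_length(size, kernel_size, stride=None, padding=0):
    """Number of window positions along one spatial axis."""
    # The stride defaults to the kernel size, as documented above.
    stride = kernel_size if stride is None else stride
    return (size + 2 * padding - kernel_size) // stride + 1


# Input from the example: batch of 4, sequence length 16, 5 channels.
batch, length, channels = 4, 16, 5
out_shape = (batch, pooled_length(length, kernel_size=2, stride=2), channels)
print(out_shape)  # (4, 8, 5)
```

Only the spatial axis shrinks; the batch and channel axes are left unchanged.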
    
# mlx.nn.AvgPool2d
class AvgPool2d(kernel_size: int | Tuple[int, int], stride: int | Tuple[int, int] | None = None, padding: int | Tuple[int, int] | None = 0)
    
Applies 2-dimensional average pooling.
Spatially downsamples the input by taking the average of a sliding window of size `kernel_size` and sliding stride `stride`.
The parameters `kernel_size`, `stride`, and `padding` can either be:
  * a single `int` – in which case the same value is used for both the height and width axes.
  * a `tuple` of two `int`s – in which case the first `int` is used for the height axis and the second `int` for the width axis.


Parameters:
    
  * kernel_size (int or tuple(int, int)) – The size of the pooling window.
  * stride (int or tuple(int, int), optional) – The stride of the pooling window. Default: `kernel_size`.
  * padding (int or tuple(int, int), optional) – How much zero padding to apply to the input. The padding is applied on both sides of the height and width axis. Default: `0`.


Examples
    
    >>> import mlx.core as mx
    >>> import mlx.nn.layers as nn
    >>> x = mx.random.normal(shape=(8, 32, 32, 4))
    >>> pool = nn.AvgPool2d(kernel_size=2, stride=2)
    >>> pool(x)
    
# mlx.nn.AvgPool3d
class AvgPool3d(kernel_size: int | Tuple[int, int, int], stride: int | Tuple[int, int, int] | None = None, padding: int | Tuple[int, int, int] | None = 0)
    
Applies 3-dimensional average pooling.
Spatially downsamples the input by taking the average of a sliding window of size `kernel_size` and sliding stride `stride`.
The parameters `kernel_size`, `stride`, and `padding` can either be:
  * a single `int` – in which case the same value is used for the depth, height, and width axes.
  * a `tuple` of three `int`s – in which case the first `int` is used for the depth axis, the second `int` for the height axis, and the third `int` for the width axis.


Parameters:
    
  * kernel_size (int or tuple(int, int, int)) – The size of the pooling window.
  * stride (int or tuple(int, int, int), optional) – The stride of the pooling window. Default: `kernel_size`.
  * padding (int or tuple(int, int, int), optional) – How much zero padding to apply to the input. The padding is applied on both sides of the depth, height and width axis. Default: `0`.


Examples
    
    >>> import mlx.core as mx
    >>> import mlx.nn.layers as nn
    >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
    >>> pool = nn.AvgPool3d(kernel_size=2, stride=2)
    >>> pool(x)
    
# mlx.nn.BatchNorm
class BatchNorm(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True)
    
Applies Batch Normalization over a 2D or 3D input.
Computes
\\[y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,\\]
where \\(\gamma\\) and \\(\beta\\) are learned per feature dimension parameters initialized at 1 and 0 respectively.
The input shape is specified as `NC` or `NLC`, where `N` is the batch, `C` is the number of features or channels, and `L` is the sequence length. The output has the same shape as the input. For four-dimensional arrays, the shape is `NHWC`, where `H` and `W` are the height and width respectively.
For more information on Batch Normalization, see the original paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift.
Parameters:
    
  * num_features (int) – The feature dimension to normalize over.
  * eps (float, optional) – A small additive constant for numerical stability. Default: `1e-5`.
  * momentum (float, optional) – The momentum for updating the running mean and variance. Default: `0.1`.
  * affine (bool, optional) – If `True`, apply a learned affine transformation after the normalization. Default: `True`.
  * track_running_stats (bool, optional) – If `True`, track the running mean and variance. Default: `True`.


Examples
    
    >>> import mlx.core as mx
    >>> import mlx.nn as nn
    >>> x = mx.random.normal((5, 4))
    >>> bn = nn.BatchNorm(num_features=4, affine=True)
    >>> output = bn(x)
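The normalization can be traced by hand on a tiny `NC` batch. The sketch below is plain Python rather than MLX: `batch_norm_sketch` is a hypothetical helper, it uses only the statistics of the current batch (no running statistics), and it assumes that `eps` is added to the variance before the square root is taken.

```python
import math


def batch_norm_sketch(batch, eps=1e-5, gamma=1.0, beta=0.0):
    """Normalize each feature column of a list-of-rows batch (illustrative)."""
    n = len(batch)
    num_features = len(batch[0])
    out = [[0.0] * num_features for _ in range(n)]
    for c in range(num_features):
        column = [row[c] for row in batch]
        mean = sum(column) / n
        # Biased variance: divide by n, not n - 1.
        var = sum((v - mean) ** 2 for v in column) / n
        scale = gamma / math.sqrt(var + eps)
        for i, v in enumerate(column):
            out[i][c] = (v - mean) * scale + beta
    return out


# Three samples with two features on very different scales.
x = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
y = batch_norm_sketch(x)
for row in y:
    print([round(v, 3) for v in row])
```

After normalization both feature columns have approximately zero mean and unit variance, even though their original scales differ by a factor of ten.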
    
Methods
`unfreeze`(*args, **kwargs)
Wrap unfreeze to make sure that running_mean and var are always frozen parameters.  
# mlx.nn.CELU
class CELU(alpha=1.0)
    
Applies the Continuously Differentiable Exponential Linear Unit.
    
Applies \\(\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))\\) element wise.
See `celu()` for the functional equivalent.
Parameters:
    
alpha – the \\(\alpha\\) value for the CELU formulation. Default: `1.0`
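The element-wise formula above can be evaluated directly for scalars. The following is a plain-Python sketch, not the MLX kernel; `celu_sketch` is a hypothetical name and the code assumes `alpha > 0`.

```python
import math


def celu_sketch(x, alpha=1.0):
    """max(0, x) + min(0, alpha * (exp(x / alpha) - 1)) for a scalar x."""
    # Clamping x to 0 inside exp leaves the result unchanged for alpha > 0
    # (the second term is 0 for positive x) and avoids overflow for large x.
    negative_part = alpha * (math.exp(min(x, 0.0) / alpha) - 1.0)
    return max(0.0, x) + min(0.0, negative_part)


for x in (-2.0, 0.0, 3.0):
    print(x, celu_sketch(x), celu_sketch(x, alpha=2.0))
```

Positive inputs pass through unchanged, while negative inputs saturate smoothly towards `-alpha`.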
# mlx.nn.Conv1d
class Conv1d(in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1, groups: int = 1, bias: bool = True)
    
Applies a 1-dimensional convolution over the multi-channel input sequence.
The channels are expected to be last i.e. the input shape should be `NLC` where:
  * `N` is the batch dimension
  * `L` is the sequence length
  * `C` is the number of input channels


Parameters:
    
  * in_channels (int) – The number of input channels
  * out_channels (int) – The number of output channels
  * kernel_size (int) – The size of the convolution filters
  * stride (int, optional) – The stride when applying the filter. Default: `1`.
  * padding (int, optional) – How many positions to 0-pad the input with. Default: `0`.
  * dilation (int, optional) – The dilation of the convolution. Default: `1`.
  * groups (int, optional) – The number of groups for the convolution. Default: `1`.
  * bias (bool, optional) – If `True` add a learnable bias to the output. Default: `True`
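How `kernel_size`, `stride`, `padding` and `dilation` interact is easiest to see from the resulting sequence length. The helper below is a plain-Python sketch using the standard convolution arithmetic, which is an assumption here rather than something stated in this reference; `conv1d_output_length` is a hypothetical name.

```python
def conv1d_output_length(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output sequence length for an input of `length` positions."""
    # A dilated kernel spans dilation * (kernel_size - 1) + 1 input positions.
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (length + 2 * padding - effective_kernel) // stride + 1


L = 16
print(conv1d_output_length(L, kernel_size=3))                       # 14
print(conv1d_output_length(L, kernel_size=3, padding=1))            # 16
print(conv1d_output_length(L, kernel_size=3, stride=2, padding=1))  # 8
print(conv1d_output_length(L, kernel_size=3, dilation=2))           # 12
```

For an `NLC` input, only the `L` axis follows this rule; the last axis of the output has `out_channels` entries.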


# mlx.nn.Conv2d
class Conv2d(in_channels: int, out_channels: int, kernel_size: int | tuple, stride: int | tuple = 1, padding: int | tuple = 0, dilation: int | tuple = 1, groups: int = 1, bias: bool = True)
    
Applies a 2-dimensional convolution over the multi-channel input image.
The channels are expected to be last i.e. the input shape should be `NHWC` where:
  * `N` is the batch dimension
  * `H` is the input image height
  * `W` is the input image width
  * `C` is the number of input channels


Parameters:
    
  * in_channels (int) – The number of input channels.
  * out_channels (int) – The number of output channels.
  * kernel_size (int or tuple) – The size of the convolution filters.
  * stride (int or tuple, optional) – The size of the stride when applying the filter. Default: `1`.
  * padding (int or tuple, optional) – How many positions to 0-pad the input with. Default: `0`.
  * dilation (int or tuple, optional) – The dilation of the convolution. Default: `1`.
  * groups (int, optional) – The number of groups for the convolution. Default: `1`.
  * bias (bool, optional) – If `True` add a learnable bias to the output. Default: `True`


# mlx.nn.Conv3d
class Conv3d(in_channels: int, out_channels: int, kernel_size: int | tuple, stride: int | tuple = 1, padding: int | tuple = 0, dilation: int | tuple = 1, bias: bool = True)
    
Applies a 3-dimensional convolution over the multi-channel input image.
The channels are expected to be last i.e. the input shape should be `NDHWC` where:
  * `N` is the batch dimension
  * `D` is the input image depth
  * `H` is the input image height
  * `W` is the input image width
  * `C` is the number of input channels


Parameters:
    
  * in_channels (int) – The number of input channels.
  * out_channels (int) – The number of output channels.
  * kernel_size (int or tuple) – The size of the convolution filters.
  * stride (int or tuple, optional) – The size of the stride when applying the filter. Default: `1`.
  * padding (int or tuple, optional) – How many positions to 0-pad the input with. Default: `0`.
  * dilation (int or tuple, optional) – The dilation of the convolution. Default: `1`.
  * bias (bool, optional) – If `True` add a learnable bias to the output. Default: `True`


Methods
# mlx.nn.ConvTranspose1d
class ConvTranspose1d(in_channels: int, out_channels: int, kernel_size: int, stride: int = 1, padding: int = 0, dilation: int = 1, output_padding: int = 0, bias: bool = True)
    
Applies a 1-dimensional transposed convolution over the multi-channel input sequence.
The channels are expected to be last i.e. the input shape should be `NLC` where:
  * `N` is the batch dimension
  * `L` is the sequence length
  * `C` is the number of input channels


Parameters:
    
  * in_channels (int) – The number of input channels
  * out_channels (int) – The number of output channels
  * kernel_size (int) – The size of the convolution filters
  * stride (int, optional) – The stride when applying the filter. Default: `1`.
  * padding (int, optional) – How many positions to 0-pad the input with. Default: `0`.
  * dilation (int, optional) – The dilation of the convolution. Default: `1`.
  * output_padding (int, optional) – Additional size added to one side of the output shape. Default: `0`.
  * bias (bool, optional) – If `True` add a learnable bias to the output. Default: `True`


Methods
# mlx.nn.ConvTranspose2d
class ConvTranspose2d(in_channels: int, out_channels: int, kernel_size: int | tuple, stride: int | tuple = 1, padding: int | tuple = 0, dilation: int | tuple = 1, output_padding: int | tuple = 0, bias: bool = True)
    
Applies a 2-dimensional transposed convolution over the multi-channel input image.
The channels are expected to be last i.e. the input shape should be `NHWC` where:
  * `N` is the batch dimension
  * `H` is the input image height
  * `W` is the input image width
  * `C` is the number of input channels


Parameters:
    
  * in_channels (int) – The number of input channels.
  * out_channels (int) – The number of output channels.
  * kernel_size (int or tuple) – The size of the convolution filters.
  * stride (int or tuple, optional) – The size of the stride when applying the filter. Default: `1`.
  * padding (int or tuple, optional) – How many positions to 0-pad the input with. Default: `0`.
  * dilation (int or tuple, optional) – The dilation of the convolution. Default: `1`.
  * output_padding (int or tuple, optional) – Additional size added to one side of the output shape. Default: `0`.
  * bias (bool, optional) – If `True` add a learnable bias to the output. Default: `True`


Methods
# mlx.nn.ConvTranspose3d
class ConvTranspose3d(in_channels: int, out_channels: int, kernel_size: int | tuple, stride: int | tuple = 1, padding: int | tuple = 0, dilation: int | tuple = 1, output_padding: int | tuple = 0, bias: bool = True)
    
Applies a 3-dimensional transposed convolution over the multi-channel input image.
The channels are expected to be last i.e. the input shape should be `NDHWC` where:
  * `N` is the batch dimension
  * `D` is the input image depth
  * `H` is the input image height
  * `W` is the input image width
  * `C` is the number of input channels


Parameters:
    
  * in_channels (int) – The number of input channels.
  * out_channels (int) – The number of output channels.
  * kernel_size (int or tuple) – The size of the convolution filters.
  * stride (int or tuple, optional) – The size of the stride when applying the filter. Default: `1`.
  * padding (int or tuple, optional) – How many positions to 0-pad the input with. Default: `0`.
  * dilation (int or tuple, optional) – The dilation of the convolution. Default: `1`.
  * output_padding (int or tuple, optional) – Additional size added to one side of the output shape. Default: `0`.
  * bias (bool, optional) – If `True` add a learnable bias to the output. Default: `True`


Methods
# mlx.nn.Dropout
class Dropout(p: float = 0.5)
    
Randomly zero a portion of the elements during training.
The remaining elements are multiplied by \\(\frac{1}{1-p}\\) where \\(p\\) is the probability of zeroing an element. This keeps the expected value of each element unchanged.
Parameters:
    
p (float) – The probability to zero an element
Methods
# mlx.nn.Dropout2d
class Dropout2d(p: float = 0.5)
    
Apply 2D channel-wise dropout during training.
Randomly zero out entire channels independently with probability \\(p\\). This layer expects the channels to be last, i.e. the input shape should be `NWHC` or `WHC` where:
  * `N` is the batch dimension
  * `H` is the input image height
  * `W` is the input image width
  * `C` is the number of input channels


The remaining channels are scaled by \\(\frac{1}{1-p}\\) to maintain the expected value of each element. Unlike traditional dropout, which zeros individual entries, this layer zeros entire channels. This is beneficial for early convolution layers where adjacent pixels are correlated. In such case, traditional dropout may not effectively regularize activations. For more details, see [1].
[1]: Tompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler C., 2015. Efficient Object Localization Using Convolutional Networks. CVPR 2015.
Parameters:
    
p (float) – Probability of zeroing a channel during training.
Methods
# mlx.nn.Dropout3d
class Dropout3d(p: float = 0.5)
    
Apply 3D channel-wise dropout during training.
Randomly zero out entire channels independently with probability \\(p\\). This layer expects the channels to be last, i.e., the input shape should be `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth, `H` is the input image height, `W` is the input image width, and `C` is the number of input channels.
The remaining channels are scaled by \\(\frac{1}{1-p}\\) to maintain the expected value of each element. Unlike traditional dropout, which zeros individual entries, this layer zeros entire channels. This is often beneficial for convolutional layers processing 3D data, like in medical imaging or video processing.
Parameters:
    
p (float) – Probability of zeroing a channel during training.
Methods
# mlx.nn.ELU
class ELU(alpha=1.0)
    
Applies the Exponential Linear Unit.
    
Simply `mx.where(x > 0, x, alpha * (mx.exp(x) - 1))`.
See `elu()` for the functional equivalent.
Parameters:
    
alpha – the \\(\alpha\\) value for the ELU formulation. Default: `1.0`
Methods
# mlx.nn.Embedding
class Embedding(num_embeddings: int, dims: int)
    
Implements a simple lookup table that maps each input integer to a high-dimensional vector.
Typically used to embed discrete tokens for processing by neural networks.
Parameters:
    
  * num_embeddings (int) – How many possible discrete tokens can we embed. Usually called the vocabulary size.
  * dims (int) – The dimensionality of the embeddings.


Methods
`as_linear`(x)
Call the embedding layer as a linear layer.  
`to_quantized`([group_size, bits, mode])
Return a `QuantizedEmbedding` layer that approximates this embedding layer.  
# mlx.nn.GELU
class GELU(approx='none')
    
Applies the Gaussian Error Linear Units.
\\[\textrm{GELU}(x) = x * \Phi(x)\\]
where \\(\Phi(x)\\) is the Gaussian CDF.
However, if `approx` is set to ‘precise’ or ‘fast’ it applies
\\[\begin{split}\textrm{GELUApprox}(x) &= 0.5 * x * \left(1 + \text{Tanh}\left(\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right) \\\ \textrm{GELUFast}(x) &= x * \sigma\left(1.702 * x\right)\end{split}\\]
respectively.
Note
For compatibility with the PyTorch API, ‘tanh’ can be used as an alias for ‘precise’.
See `gelu()`, `gelu_approx()` and `gelu_fast_approx()` for the functional equivalents and information regarding error bounds.
Parameters:
    
approx ('none' | 'precise' | 'fast') – Which approximation to gelu to use if any.
Methods
# mlx.nn.GLU
class GLU(axis: int = -1)
    
Applies the gated linear unit function.
This function splits the `axis` dimension of the input into two halves (\\(a\\) and \\(b\\)) and applies \\(a * \sigma(b)\\).
\\[\textrm{GLU}(x) = a * \sigma(b)\\]
Parameters:
    
axis (int) – The dimension to split along. Default: `-1`
Methods
# mlx.nn.GRU
class GRU(input_size: int, hidden_size: int, bias: bool = True)
    
A gated recurrent unit (GRU) RNN layer.
The input has shape `NLD` or `LD` where:
  * `N` is the optional batch dimension
  * `L` is the sequence length
  * `D` is the input’s feature dimension


Concretely, for each element of the sequence, this layer computes:
\\[\begin{split}\begin{aligned} r_t &= \sigma (W_{xr}x_t + W_{hr}h_t + b_{r}) \\\ z_t &= \sigma (W_{xz}x_t + W_{hz}h_t + b_{z}) \\\ n_t &= \text{tanh}(W_{xn}x_t + b_{n} + r_t \odot (W_{hn}h_t + b_{hn})) \\\ h_{t + 1} &= (1 - z_t) \odot n_t + z_t \odot h_t \end{aligned}\end{split}\\]
The hidden state \\(h\\) has shape `NH` or `H` depending on whether the input is batched or not. Returns the hidden state at each time step of shape `NLH` or `LH`.
Parameters:
    
  * input_size (int) – Dimension of the input, `D`.
  * hidden_size (int) – Dimension of the hidden state, `H`.
  * bias (bool) – Whether to use biases or not. Default: `True`.


Methods
# mlx.nn.GroupNorm
class GroupNorm(num_groups: int, dims: int, eps: float = 1e-05, affine: bool = True, pytorch_compatible: bool = False)
    
Applies Group Normalization [1] to the inputs.
Computes the same normalization as layer norm, namely
\\[y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,\\]
where \\(\gamma\\) and \\(\beta\\) are learned per feature dimension parameters initialized at 1 and 0 respectively. However, the mean and variance are computed over the spatial dimensions and each group of features. In particular, the input is split into num_groups across the feature dimension.
The feature dimension is assumed to be the last dimension and the dimensions that precede it (except the first) are considered the spatial dimensions.
[1]: https://arxiv.org/abs/1803.08494
Parameters:
    
  * num_groups (int) – Number of groups to separate the features into
  * dims (int) – The feature dimensions of the input to normalize over
  * eps (float) – A small additive constant for numerical stability
  * affine (bool) – If True learn an affine transform to apply after the normalization.
  * pytorch_compatible (bool) – If True perform the group normalization in the same order/grouping as PyTorch.


Methods
# mlx.nn.HardShrink
class HardShrink
    
Applies the HardShrink function.
See `hard_shrink()` for the functional equivalent.
Parameters:
    
lambd – the \\(\lambda\\) value for Hardshrink. Default: `0.5`
Methods
# mlx.nn.HardTanh
class HardTanh
    
Applies the HardTanh function.
See `hard_tanh()` for the functional equivalent.
Methods
# mlx.nn.Hardswish
class Hardswish
    
Applies the hardswish function, element-wise.
See `hardswish()` for the functional equivalent.
Methods
# mlx.nn.InstanceNorm
class InstanceNorm(dims: int, eps: float = 1e-05, affine: bool = False)
    
Applies instance normalization [1] on the inputs.
Computes
\\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,\\]
where \\(\gamma\\) and \\(\beta\\) are learned per feature dimension parameters initialized at 1 and 0 respectively. Both are of size `dims`, if `affine` is `True`.
Parameters:
    
  * dims (int) – The number of features of the input.
  * eps (float) – A value added to the denominator for numerical stability. Default: `1e-5`.
  * affine (bool) – Default: `False`.


Shape:
    
  * Input: \\((..., C)\\) where \\(C\\) is equal to `dims`.
  * Output: Same shape as the input.


Examples
    
    >>> import mlx.core as mx
    >>> import mlx.nn as nn
    >>> x = mx.random.normal((8, 4, 4, 16))
    >>> inorm = nn.InstanceNorm(dims=16)
    >>> output = inorm(x)
    
References
[1]: https://arxiv.org/abs/1607.08022
Methods
# mlx.nn.LSTM
class LSTM(input_size: int, hidden_size: int, bias: bool = True)
    
An LSTM recurrent layer.
The input has shape `NLD` or `LD` where:
  * `N` is the optional batch dimension
  * `L` is the sequence length
  * `D` is the input’s feature dimension


Concretely, for each element of the sequence, this layer computes:
\\[\begin{split}\begin{aligned} i_t &= \sigma (W_{xi}x_t + W_{hi}h_t + b_{i}) \\\ f_t &= \sigma (W_{xf}x_t + W_{hf}h_t + b_{f}) \\\ g_t &= \text{tanh} (W_{xg}x_t + W_{hg}h_t + b_{g}) \\\ o_t &= \sigma (W_{xo}x_t + W_{ho}h_t + b_{o}) \\\ c_{t + 1} &= f_t \odot c_t + i_t \odot g_t \\\ h_{t + 1} &= o_t \text{tanh}(c_{t + 1}) \end{aligned}\end{split}\\]
The hidden state \\(h\\) and cell state \\(c\\) have shape `NH` or `H`, depending on whether the input is batched or not.
The layer returns two arrays, the hidden state and the cell state at each time step, both of shape `NLH` or `LH`.
Parameters:
    
  * input_size (int) – Dimension of the input, `D`.
  * hidden_size (int) – Dimension of the hidden state, `H`.
  * bias (bool) – Whether to use biases or not. Default: `True`.


Methods
# mlx.nn.LayerNorm
class LayerNorm(dims: int, eps: float = 1e-05, affine: bool = True, bias: bool = True)
    
Applies layer normalization [1] on the inputs.
Computes
\\[y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,\\]
where \\(\gamma\\) and \\(\beta\\) are learned per feature dimension parameters initialized at 1 and 0 respectively.
[1]: https://arxiv.org/abs/1607.06450
Parameters:
    
  * dims (int) – The feature dimension of the input to normalize over
  * eps (float) – A small additive constant for numerical stability
  * affine (bool) – If True learn an affine transform to apply after the normalization
  * bias (bool) – If True include a translation in the affine transformation. If False the transformation is a pure scaling rather than a true affine map.


Methods
# mlx.nn.LeakyReLU
class LeakyReLU(negative_slope=0.01)
    
Applies the Leaky Rectified Linear Unit.
Simply `mx.maximum(negative_slope * x, x)`.
Parameters:
    
negative_slope – Controls the angle of the negative slope. Default: `1e-2`
Methods
# mlx.nn.Linear
class Linear(input_dims: int, output_dims: int, bias: bool = True)
    
Applies an affine transformation to the input.
Concretely:
\\[y = x W^\top + b\\]
where \\(W\\) has shape `[output_dims, input_dims]` and \\(b\\) has shape `[output_dims]`.
The values are initialized from the uniform distribution \\(\mathcal{U}(-{k}, {k})\\), where \\(k = \frac{1}{\sqrt{D_i}}\\) and \\(D_i\\) is equal to `input_dims`.
Parameters:
    
  * input_dims (int) – The dimensionality of the input features
  * output_dims (int) – The dimensionality of the output features
  * bias (bool, optional) – If set to `False` then the layer will not use a bias. Default is `True`.


Methods
`to_quantized`([group_size, bits, mode])
Return a `QuantizedLinear` layer that approximates this layer.  
# mlx.nn.LogSigmoid
class LogSigmoid
    
Applies the Log Sigmoid function.
See `log_sigmoid()` for the functional equivalent.
Methods
# mlx.nn.LogSoftmax
class LogSoftmax
    
Applies the Log Softmax function.
See `log_softmax()` for the functional equivalent.
Methods
# mlx.nn.MaxPool1d
class MaxPool1d(kernel_size: int | Tuple[int], stride: int | Tuple[int] | None = None, padding: int | Tuple[int] = 0)
    
Applies 1-dimensional max pooling.
Spatially downsamples the input by taking the maximum of a sliding window of size `kernel_size` and sliding stride `stride`.
Parameters:
    
  * kernel_size (int or tuple(int)) – The size of the pooling window kernel.
  * stride (int or tuple(int), optional) – The stride of the pooling window. Default: `kernel_size`.
  * padding (int or tuple(int), optional) – How much negative infinity padding to apply to the input. The padding amount is applied to both sides of the spatial axis. Default: `0`.


Examples
    
    >>> import mlx.core as mx
    >>> import mlx.nn.layers as nn
    >>> x = mx.random.normal(shape=(4, 16, 5))
    >>> pool = nn.MaxPool1d(kernel_size=2, stride=2)
    >>> pool(x)
    
Methods
# mlx.nn.MaxPool2d
class MaxPool2d(kernel_size: int | Tuple[int, int], stride: int | Tuple[int, int] | None = None, padding: int | Tuple[int, int] | None = 0)
    
Applies 2-dimensional max pooling.
Spatially downsamples the input by taking the maximum of a sliding window of size `kernel_size` and sliding stride `stride`.
The parameters `kernel_size`, `stride`, and `padding` can either be:
  * a single `int` – in which case the same value is used for both the height and width axis.
  * a `tuple` of two `int`s – in which case, the first `int` is used for the height axis, the second `int` for the width axis.


Parameters:
    
  * kernel_size (int or tuple(int, int)) – The size of the pooling window.
  * stride (int or tuple(int, int), optional) – The stride of the pooling window. Default: `kernel_size`.
  * padding (int or tuple(int, int), optional) – How much negative infinity padding to apply to the input. The padding is applied on both sides of the height and width axis. Default: `0`.


Examples
    
    >>> import mlx.core as mx
    >>> import mlx.nn.layers as nn
    >>> x = mx.random.normal(shape=(8, 32, 32, 4))
    >>> pool = nn.MaxPool2d(kernel_size=2, stride=2)
    >>> pool(x)
    
Methods
# mlx.nn.MaxPool3d
class MaxPool3d(kernel_size: int | Tuple[int, int, int], stride: int | Tuple[int, int, int] | None = None, padding: int | Tuple[int, int, int] | None = 0)
    
Applies 3-dimensional max pooling.
Spatially downsamples the input by taking the maximum of a sliding window of size `kernel_size` and sliding stride `stride`.
The parameters `kernel_size`, `stride`, and `padding` can either be:
  * a single `int` – in which case the same value is used for the depth, height, and width axis.
  * a `tuple` of three `int`s – in which case, the first `int` is used for the depth axis, the second `int` for the height axis, and the third `int` for the width axis.


Parameters:
    
  * kernel_size (int or tuple(int, int, int)) – The size of the pooling window.
  * stride (int or tuple(int, int, int), optional) – The stride of the pooling window. Default: `kernel_size`.
  * padding (int or tuple(int, int, int), optional) – How much negative infinity padding to apply to the input. The padding is applied on both sides of the depth, height and width axis. Default: `0`.


Examples
    
    >>> import mlx.core as mx
    >>> import mlx.nn.layers as nn
    >>> x = mx.random.normal(shape=(8, 16, 32, 32, 4))
    >>> pool = nn.MaxPool3d(kernel_size=2, stride=2)
    >>> pool(x)
    
Methods
# mlx.nn.Mish
class Mish
    
Applies the Mish function, element-wise.
Reference: https://arxiv.org/abs/1908.08681
\\[\text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))\\]
Methods
# mlx.nn.Module.apply
Module.apply(map_fn: Callable[[array], array], filter_fn: Callable[[Module, str, Any], bool] | None = None) → Module
    
Map all the parameters using the provided `map_fn` and immediately update the module with the mapped parameters.
For instance running `model.apply(lambda x: x.astype(mx.float16))` casts all parameters to 16 bit floats.
Parameters:
    
  * map_fn (Callable) – Maps an array to another array
  * filter_fn (Callable, optional) – Filter to select which arrays to map (default: `Module.valid_parameter_filter()`).


Returns:
    
The module instance after updating the parameters.
# mlx.nn.Module.apply_to_modules
Module.apply_to_modules(apply_fn: Callable[[str, Module], Any]) → Module
    
Apply a function to all the modules in this instance (including this instance).
Parameters:
    
apply_fn (Callable) – The function to apply to the modules.
Returns:
    
The module instance after updating submodules.
# mlx.nn.Module.children
Module.children()
    
Return the direct descendants of this Module instance.
# mlx.nn.Module.eval
Module.eval() → Module
    
Set the model to evaluation mode.
See `train()`.
# mlx.nn.Module.filter_and_map
Module.filter_and_map(filter_fn: Callable[[Module, str, Any], bool], map_fn: Callable | None = None, is_leaf_fn: Callable[[Module, str, Any], bool] | None = None)
    
Recursively filter the contents of the module using `filter_fn`, namely only select keys and values where `filter_fn` returns true.
This is used to implement `parameters()` and `trainable_parameters()` but it can also be used to extract any subset of the module’s parameters.
Parameters:
    
  * filter_fn (Callable) – Called with the containing module, the key, and the value (in that order); returns whether to keep the value or drop it.
  * map_fn (Callable, optional) – Optionally transform the value before returning it.
  * is_leaf_fn (Callable, optional) – Called with the containing module, the key, and the value (in that order); returns whether the value is a leaf.


Returns:
    
A dictionary containing the contents of the module recursively filtered
# mlx.nn.Module.freeze
Module.freeze(*, recurse: bool = True, keys: str | List[str] | None = None, strict: bool = False) → Module
    
Freeze the Module’s parameters or some of them. Freezing a parameter means not computing gradients for it.
This function is idempotent i.e. freezing a frozen model is a no-op.
Example
For instance to only train the attention parameters from a Transformer:
    
    model = nn.Transformer()
    model.freeze()
    model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
    
Parameters:
    
  * recurse (bool, optional) – If True then freeze the parameters of the submodules as well. Default: `True`.
  * keys (str or list[str], optional) – If provided then only these parameters will be frozen otherwise all the parameters of a module. For instance freeze all biases by calling `module.freeze(keys="bias")`.
  * strict (bool, optional) – If set to `True` validate that the passed keys exist. Default: `False`.


Returns:
    
The module instance after freezing the parameters.
# mlx.nn.Module.leaf_modules
Module.leaf_modules()
    
Return the submodules that do not contain other modules.
# mlx.nn.Module.load_weights
Module.load_weights(file_or_weights: str | List[Tuple[str, array]], strict: bool = True) → Module
    
Update the model’s weights from a `.npz`, a `.safetensors` file, or a list.
Parameters:
    
  * file_or_weights (str or list(tuple(str, mx.array))) – The path to the weights file (`.npz` or `.safetensors`) or a list of pairs of parameter names and arrays.
  * strict (bool, optional) – If `True` then checks that the provided weights exactly match the parameters of the model. Otherwise, only the weights actually contained in the model are loaded and shapes are not checked. Default: `True`.


Returns:
    
The module instance after updating the weights.
Example
    
    import mlx.core as mx
    import mlx.nn as nn
    model = nn.Linear(10, 10)
    
    # Load from .npz file
    model.load_weights("weights.npz")
    
    # Load from .safetensors file
    model.load_weights("weights.safetensors")
    
    # Load from list
    weights = [
        ("weight", mx.random.uniform(shape=(10, 10))),
        ("bias",  mx.zeros((10,))),
    ]
    model.load_weights(weights)
    
    # Missing weight
    weights = [
        ("weight", mx.random.uniform(shape=(10, 10))),
    ]
    
    # Raises a ValueError exception
    model.load_weights(weights)
    
    # Ok, only updates the weight but not the bias
    model.load_weights(weights, strict=False)
    
# mlx.nn.Module.modules
Module.modules()
    
Return a list with all the modules in this instance.
Returns:
    
A list of `mlx.nn.Module` instances.
# mlx.nn.Module.named_modules
Module.named_modules()
    
Return a list with all the modules in this instance and their name with dot notation.
Returns:
    
A list of tuples (str, `mlx.nn.Module`).
# mlx.nn.Module.parameters
Module.parameters()
    
Recursively return all the `mlx.core.array` members of this Module as a dict of dicts and lists.
# mlx.nn.Module.save_weights
Module.save_weights(file: str)
    
Save the model’s weights to a file. The saving method is determined by the file extension:
  * `.npz` will use `mx.savez()`
  * `.safetensors` will use `mx.save_safetensors()`

# mlx.nn.Module.set_dtype
Module.set_dtype(dtype: Dtype, predicate: Callable[[Dtype], bool] | None = <function Module.<lambda>>)
    
Set the dtype of the module’s parameters.
Parameters:
    
  * dtype (Dtype) – The new dtype.
  * predicate (Callable, optional) – A predicate to select parameters to cast. By default, only parameters of type `floating` will be updated to avoid casting integer parameters to the new dtype.


# mlx.nn.Module.state
property Module.state
    
The module’s state dictionary.
It contains any attribute set on the module, including the parameters in `Module.parameters()`.
Unlike `Module.parameters()`, the `Module.state` property is a reference to the module’s state. Updates to it will be reflected in the original module.
# mlx.nn.Module.train
Module.train(mode: bool = True) → Module
    
Set the model in or out of training mode.
Training mode only applies to certain layers. For example `Dropout` applies a random mask in training mode, but is the identity in evaluation mode.
Parameters:
    
mode (bool) – Indicate if the model should be in training or evaluation mode. Default: `True`.
Returns:
    
The module instance after updating the training mode.
# mlx.nn.Module.trainable_parameters
Module.trainable_parameters()
    
Recursively return all the non frozen `mlx.core.array` members of this Module as a dict of dicts and lists.
# mlx.nn.Module.training
property Module.training
    
Boolean indicating if the model is in training mode.
# mlx.nn.Module.unfreeze
Module.unfreeze(*, recurse: bool = True, keys: str | List[str] | None = None, strict: bool = False) → Module
    
Unfreeze the Module’s parameters or some of them.
This function is idempotent, i.e. unfreezing a model that is not frozen is a no-op.
Example
For instance to only train the biases of a Transformer one can do:
    
    model = nn.Transformer()
    model.freeze()
    model.unfreeze(keys="bias")
    
Parameters:
    
  * recurse (bool, optional) – If True then unfreeze the parameters of the submodules as well. Default: `True`.
  * keys (str or list[str], optional) – If provided then only these parameters will be unfrozen otherwise all the parameters of a module. For instance unfreeze all biases by calling `module.unfreeze(keys="bias")`.
  * strict (bool, optional) – If set to `True` validate that the passed keys exist. Default: `False`.


Returns:
    
The module instance after unfreezing the parameters.
# mlx.nn.Module.update
Module.update(parameters: dict, strict: bool = True) → Module
    
Replace the parameters of this Module with the provided ones in the dict of dicts and lists.
Commonly used by the optimizer to change the model to the updated (optimized) parameters. Also used by the `mlx.nn.value_and_grad()` to set the tracers in the model in order to compute gradients.
The passed in parameters dictionary need not be a full dictionary similar to `parameters()`. Only the provided locations will be updated.
Parameters:
    
  * parameters (dict) – A complete or partial dictionary of the module’s parameters.
  * strict (bool) – If `True` checks that `parameters` is a subset of the module’s parameters. Default: `True`.


Returns:
    
The module instance after updating the parameters.
# mlx.nn.Module.update_modules
Module.update_modules(modules: dict, strict: bool = True) → Module
    
Replace the child modules of this `Module` instance with the provided ones in the dict of dicts and lists.
It is the equivalent of `Module.update()` but for modules instead of parameters and allows us to flexibly edit complex architectures by programmatically swapping layers.
The passed-in `modules` dictionary need not be complete. Only the provided locations will be updated.
Parameters:
    
  * modules (dict) – A complete or partial dictionary of the module’s submodules.
  * strict (bool) – If `True` checks that `modules` is a subset of the child modules of this instance. Default: `True`.


Returns:
    
The module instance after updating the submodules.
# mlx.nn.MultiHeadAttention
class MultiHeadAttention(dims: int, num_heads: int, query_input_dims: int | None = None, key_input_dims: int | None = None, value_input_dims: int | None = None, value_dims: int | None = None, value_output_dims: int | None = None, bias: bool = False)
    
Implements the scaled dot product attention with multiple heads.
Given inputs for queries, keys and values the `MultiHeadAttention` produces new values by aggregating information from the input values according to the similarities of the input queries and keys.
All inputs as well as the output are linearly projected without biases by default.
`MultiHeadAttention` also takes an optional additive attention mask that should be broadcastable with `(batch, num_heads, # queries, # keys)`. The mask should have `-inf` or very large negative numbers at the positions that should not be attended to.
Parameters:
    
  * dims (int) – The model dimensions. This is also the default value for the queries, keys, values, and the output.
  * num_heads (int) – The number of attention heads to use.
  * query_input_dims (int, optional) – The input dimensions of the queries. Default: `dims`.
  * key_input_dims (int, optional) – The input dimensions of the keys. Default: `dims`.
  * value_input_dims (int, optional) – The input dimensions of the values. Default: `key_input_dims`.
  * value_dims (int, optional) – The dimensions of the values after the projection. Default: `dims`.
  * value_output_dims (int, optional) – The dimensions the new values will be projected to. Default: `dims`.
  * bias (bool, optional) – Whether or not to use a bias in the projections. Default: `False`.


Methods
`create_additive_causal_mask`(N[, dtype])  
# mlx.nn.PReLU
class PReLU(num_parameters=1, init=0.25)
    
Applies the element-wise parametric ReLU.
    
Applies \\(\max(0, x) + a * \min(0, x)\\) element wise, where \\(a\\) is an array.
See `prelu()` for the functional equivalent.
Parameters:
    
  * num_parameters – number of \\(a\\) to learn. Default: `1`
  * init – the initial value of \\(a\\). Default: `0.25`


Methods
# mlx.nn.QuantizedEmbedding
class QuantizedEmbedding(num_embeddings: int, dims: int, group_size: int = 64, bits: int = 4, mode: str = 'affine')
    
The same as `Embedding` but with a quantized weight matrix.
`QuantizedEmbedding` also provides a `from_embedding()` classmethod to convert embedding layers to `QuantizedEmbedding` layers.
Parameters:
    
  * num_embeddings (int) – How many possible discrete tokens can we embed. Usually called the vocabulary size.
  * dims (int) – The dimensionality of the embeddings.
  * group_size (int, optional) – The group size to use for the quantized weight. See `quantize()`. Default: `64`.
  * bits (int, optional) – The bit width to use for the quantized weight. See `quantize()`. Default: `4`.
  * mode (str) – The quantization method to use (see `mlx.core.quantize()`). Default: `"affine"`.


Methods
`as_linear`(x)
Call the quantized embedding layer as a quantized linear layer.  
`from_embedding`(embedding_layer[, ...])
Create a `QuantizedEmbedding` layer from an `Embedding` layer.  
# mlx.nn.QuantizedLinear
class QuantizedLinear(input_dims: int, output_dims: int, bias: bool = True, group_size: int = 64, bits: int = 4, mode: str = 'affine')
    
Applies an affine transformation to the input using a quantized weight matrix.
It is the quantized equivalent of `mlx.nn.Linear`. For now its parameters are frozen and will not be included in any gradient computation but this will probably change in the future.
`QuantizedLinear` also provides a classmethod `from_linear()` to convert linear layers to `QuantizedLinear` layers.
Parameters:
    
  * input_dims (int) – The dimensionality of the input features.
  * output_dims (int) – The dimensionality of the output features.
  * bias (bool, optional) – If set to `False` then the layer will not use a bias. Default: `True`.
  * group_size (int, optional) – The group size to use for the quantized weight. See `quantize()`. Default: `64`.
  * bits (int, optional) – The bit width to use for the quantized weight. See `quantize()`. Default: `4`.
  * mode (str) – The quantization method to use (see `mlx.core.quantize()`). Default: `"affine"`.


Methods
`from_linear`(linear_layer[, group_size, ...])
Create a `QuantizedLinear` layer from a `Linear` layer.  
# mlx.nn.RMSNorm
class RMSNorm(dims: int, eps: float = 1e-05)
    
Applies Root Mean Square normalization [1] to the inputs.
Computes
\\[y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma\\]
where \\(\gamma\\) is a learned per feature dimension parameter initialized at 1.
Note the accumulation for the mean is done in 32-bit precision.
[1]: https://arxiv.org/abs/1910.07467
Parameters:
    
  * dims (int) – The feature dimension of the input to normalize over
  * eps (float) – A small additive constant for numerical stability
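
The formula above can be checked by hand on a single feature vector. The following is a plain-Python sketch of the computation, not MLX's implementation; the helper name `rms_norm` is hypothetical, and \\(\gamma\\) is passed as a plain list.

```python
import math

def rms_norm(x, gamma, eps=1e-5):
    """Scale one feature vector by the inverse of its root mean square."""
    mean_sq = sum(v * v for v in x) / len(x)      # E[x^2] over the feature dimension
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)      # 1 / sqrt(E[x^2] + eps)
    return [v * inv_rms * g for v, g in zip(x, gamma)]

# gamma initialized at 1, as described above
y = rms_norm([3.0, 4.0], gamma=[1.0, 1.0])
mean_sq_out = sum(v * v for v in y) / len(y)      # close to 1 after normalization
```

After normalization the mean of the squared outputs is close to 1, up to the effect of `eps`.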


Methods
# mlx.nn.RNN
class RNN(input_size: int, hidden_size: int, bias: bool = True, nonlinearity: Callable | None = None)
    
An Elman recurrent layer.
The input is a sequence of shape `NLD` or `LD` where:
  * `N` is the optional batch dimension
  * `L` is the sequence length
  * `D` is the input’s feature dimension


Concretely, for each element along the sequence length axis, this layer applies the function:
\\[h_{t + 1} = \text{tanh} (W_{ih}x_t + W_{hh}h_t + b)\\]
The hidden state \\(h\\) has shape `NH` or `H`, depending on whether the input is batched or not. Returns the hidden state at each time step, of shape `NLH` or `LH`.
Parameters:
    
  * input_size (int) – Dimension of the input, `D`.
  * hidden_size (int) – Dimension of the hidden state, `H`.
  * bias (bool, optional) – Whether to use a bias. Default: `True`.
  * nonlinearity (callable, optional) – Non-linearity to use. If `None`, then `tanh()` is used. Default: `None`.
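
The recurrence above can be unrolled explicitly. The sketch below is plain Python rather than MLX code; `matvec` and `elman_rnn` are hypothetical helpers, the input is unbatched (`LD`), and the initial hidden state is assumed to be all zeros.

```python
import math

def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def elman_rnn(xs, W_ih, W_hh, b):
    """Apply h_{t+1} = tanh(W_ih x_t + W_hh h_t + b) along the sequence."""
    h = [0.0] * len(W_hh)                 # assumed zero initial hidden state
    hidden_states = []
    for x_t in xs:
        pre = [a + c + d for a, c, d in zip(matvec(W_ih, x_t), matvec(W_hh, h), b)]
        h = [math.tanh(p) for p in pre]
        hidden_states.append(h)           # one hidden state per time step
    return hidden_states

# L = 3 time steps, D = 2 input features, H = 2 hidden units
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_ih = [[0.5, -0.5], [0.25, 0.75]]
W_hh = [[0.1, 0.0], [0.0, 0.1]]
b = [0.0, 0.0]
states = elman_rnn(xs, W_ih, W_hh, b)     # shape LH: 3 states of size 2
```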


Methods
# mlx.nn.ReLU
class ReLU
    
Applies the Rectified Linear Unit.
    
Simply `mx.maximum(x, 0)`.
See `relu()` for the functional equivalent.
Methods
# mlx.nn.ReLU2
class ReLU2
    
Applies the ReLU² activation function.
See `relu2()` for the functional equivalent.
Methods
# mlx.nn.ReLU6
class ReLU6
    
Applies the Rectified Linear Unit 6.
See `relu6()` for the functional equivalent.
Methods
# mlx.nn.RoPE
class RoPE(dims: int, traditional: bool = False, base: float = 10000, scale: float = 1.0)
    
Implements the rotary positional encoding.
The traditional implementation rotates consecutive pairs of elements in the feature dimension while the default implementation rotates pairs with stride half the feature dimensions for efficiency.
For more details see RoFormer: Enhanced Transformer with Rotary Position Embedding.
Parameters:
    
  * dims (int) – The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged.
  * traditional (bool, optional) – If set to `True` choose the traditional implementation which is slightly less efficient. Default: `False`.
  * base (float, optional) – The base used to compute angular frequency for each dimension in the positional encodings. Default: `10000`.
  * scale (float, optional) – The scale used to scale the positions. Default: `1.0`.
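
To make the difference between the two layouts concrete, the sketch below lists which feature indices are rotated together in each case. It only illustrates the pairing described above; `rope_pairs` is a hypothetical helper, not part of MLX.

```python
def rope_pairs(dims, traditional=False):
    """Return the index pairs rotated together within the first `dims` features."""
    half = dims // 2
    if traditional:
        # consecutive pairs: (0, 1), (2, 3), ...
        return [(2 * i, 2 * i + 1) for i in range(half)]
    # pairs with stride half the rotated feature dimensions: (0, half), (1, half + 1), ...
    return [(i, i + half) for i in range(half)]

traditional_pairs = rope_pairs(8, traditional=True)  # [(0, 1), (2, 3), (4, 5), (6, 7)]
default_pairs = rope_pairs(8)                        # [(0, 4), (1, 5), (2, 6), (3, 7)]
```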


Methods
# mlx.nn.SELU
class SELU
    
Applies the Scaled Exponential Linear Unit.
See `selu()` for the functional equivalent.
Methods
# mlx.nn.Sequential
class Sequential(*modules)
    
A layer that calls the passed callables in order.
We can pass either modules or plain callables to the Sequential module. If our functions have learnable parameters they should be implemented as `nn.Module` instances.
Parameters:
    
modules (tuple of Callables) – The modules to call in order
Methods
# mlx.nn.SiLU
class SiLU
    
Applies the Sigmoid Linear Unit. Also known as Swish.
See `silu()` for the functional equivalent.
Methods
# mlx.nn.Sigmoid
class Sigmoid
    
Applies the sigmoid function, element-wise.
\\[\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}\\]
Methods
# mlx.nn.SinusoidalPositionalEncoding
class SinusoidalPositionalEncoding(dims: int, min_freq: float = 0.0001, max_freq: float = 1, scale: float | None = None, cos_first: bool = False, full_turns: bool = False)
    
Implements sinusoidal positional encoding.
For more details see the paper Attention Is All You Need.
Parameters:
    
  * dims (int) – The dimensionality of the resulting positional embeddings.
  * min_freq (float, optional) – The minimum frequency expected. Default: `0.0001`.
  * max_freq (float, optional) – The maximum frequency expected. Default: `1`.
  * scale (float, optional) – A multiplicative scale for the embeddings. Default: `sqrt(2/dims)`.
  * cos_first (bool, optional) – If `True` embed using `[cos(x); sin(x)]` instead of the reverse. Default: `False`.
  * full_turns (bool, optional) – If `True` multiply the frequencies with \\(2\pi\\). Default: `False`.


Methods
# mlx.nn.Softmax
class Softmax
    
Applies the Softmax function.
See `softmax()` for the functional equivalent.
Methods
# mlx.nn.Softmin
class Softmin
    
Applies the Softmin function.
See `softmin()` for the functional equivalent.
Methods
# mlx.nn.Softplus
class Softplus
    
Applies the Softplus function.
See `softplus()` for the functional equivalent.
Methods
# mlx.nn.Softshrink
class Softshrink(lambd=0.5)
    
Applies the Softshrink function.
See `softshrink()` for the functional equivalent.
Parameters:
    
lambd – the \\(\lambda\\) value for Softshrink. Default: `0.5`
Methods
# mlx.nn.Softsign
class Softsign
    
Applies the Softsign function.
See `softsign()` for the functional equivalent.
Methods
# mlx.nn.Step
class Step(threshold: float = 0.0)
    
Applies the Step Activation Function.
This function implements a binary step activation, where the output is set to 1 if the input is greater than a specified threshold, and 0 otherwise.
\\[\begin{split}\text{step}(x) = \begin{cases} 0 & \text{if } x < \text{threshold} \\\ 1 & \text{if } x \geq \text{threshold} \end{cases}\end{split}\\]
Parameters:
    
threshold – The value to threshold at.
Methods
# mlx.nn.Tanh
class Tanh
    
Applies the hyperbolic tangent function.
See `tanh()` for the functional equivalent.
Methods
# mlx.nn.Transformer
class Transformer(dims: int = 512, num_heads: int = 8, num_encoder_layers: int = 6, num_decoder_layers: int = 6, mlp_dims: int | None = None, dropout: float = 0.0, activation: Callable[[Any], Any] = relu, custom_encoder: Any | None = None, custom_decoder: Any | None = None, norm_first: bool = True, checkpoint: bool = False)
    
Implements a standard Transformer model.
The implementation is based on Attention Is All You Need.
The Transformer model contains an encoder and a decoder. The encoder processes the input sequence and the decoder generates the output sequence. The interaction between encoder and decoder happens through the attention mechanism.
Parameters:
    
  * dims (int, optional) – The number of expected features in the encoder/decoder inputs. Default: `512`.
  * num_heads (int, optional) – The number of attention heads. Default: `8`.
  * num_encoder_layers (int, optional) – The number of encoder layers in the Transformer encoder. Default: `6`.
  * num_decoder_layers (int, optional) – The number of decoder layers in the Transformer decoder. Default: `6`.
  * mlp_dims (int, optional) – The hidden dimension of the MLP block in each Transformer layer. Defaults to `4*dims` if not provided. Default: `None`.
  * dropout (float, optional) – The dropout value for the Transformer encoder and decoder. Dropout is used after each attention layer and the activation in the MLP layer. Default: `0.0`.
  * activation (function, optional) – the activation function for the MLP hidden layer. Default: `mlx.nn.relu()`.
  * custom_encoder (Module, optional) – A custom encoder to replace the standard Transformer encoder. Default: `None`.
  * custom_decoder (Module, optional) – A custom decoder to replace the standard Transformer decoder. Default: `None`.
  * norm_first (bool, optional) – if `True`, encoder and decoder layers will perform layer normalization before attention and MLP operations, otherwise after. Default: `True`.
  * checkpoint (bool, optional) – if `True` perform gradient checkpointing to reduce the memory usage at the expense of more computation. Default: `False`.


Methods
# mlx.nn.Upsample
class Upsample(scale_factor: float | Tuple, mode: Literal['nearest', 'linear', 'cubic'] = 'nearest', align_corners: bool = False)
    
Upsample the input signal spatially.
The spatial dimensions are by convention dimensions `1` to `x.ndim - 2`. The first is the batch dimension and the last is the feature dimension.
For example, an audio signal would be 3D with 1 spatial dimension, an image 4D with 2 and so on and so forth.
Three upsampling algorithms are implemented: nearest neighbor upsampling, linear interpolation, and cubic interpolation. All can be applied to any number of spatial dimensions. Linear interpolation is bilinear, trilinear, etc. when applied to more than one spatial dimension, and cubic interpolation is bicubic when there are 2 spatial dimensions.
Note
When using one of the linear or cubic interpolation modes, the `align_corners` argument changes how the corners are treated in the input image. If `align_corners=True`, the top-left corner of the output matches the top-left corner of the input, and likewise for the bottom-right corner.
Parameters:
    
  * scale_factor (float or tuple) – The multiplier for the spatial size. If a `float` is provided, it is the multiplier for all spatial dimensions. Otherwise, the number of scale factors provided must match the number of spatial dimensions.
  * mode (str, optional) – The upsampling algorithm, either `"nearest"`, `"linear"` or `"cubic"`. Default: `"nearest"`.
  * align_corners (bool, optional) – Changes the way the corners are treated during `"linear"` and `"cubic"` upsampling. See the note above and the examples below for more details. Default: `False`.
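
The effect of `align_corners` is easiest to see in one dimension. The sketch below is plain Python, not MLX's implementation. The coordinate mapping is an assumption (pixel-center sampling when `align_corners=False`, end-point alignment when `True`), chosen because it reproduces the first rows of the linear outputs in this section's examples. `upsample_linear_1d` is a hypothetical helper.

```python
def upsample_linear_1d(values, scale, align_corners=False):
    """Linearly interpolate a 1-D signal to `scale` times its length."""
    n_in = len(values)
    n_out = int(n_in * scale)
    out = []
    for i in range(n_out):
        if align_corners:
            # first and last output samples land exactly on the input end points
            src = i * (n_in - 1) / (n_out - 1) if n_out > 1 else 0.0
        else:
            # assumed pixel-center convention
            src = (i + 0.5) / scale - 0.5
        src = min(max(src, 0.0), n_in - 1)        # clamp to the valid input range
        lo = int(src)
        hi = min(lo + 1, n_in - 1)
        w = src - lo                              # interpolation weight
        out.append((1 - w) * values[lo] + w * values[hi])
    return out

row = upsample_linear_1d([1.0, 2.0], 2)                              # [1.0, 1.25, 1.75, 2.0]
row_aligned = upsample_linear_1d([1.0, 2.0], 2, align_corners=True)  # about [1, 1.333, 1.667, 2]
```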


Examples
    
    >>> import mlx.core as mx
    >>> import mlx.nn as nn
    >>> x = mx.arange(1, 5).reshape((1, 2, 2, 1))
    >>> x
    array([[[[1],
             [2]],
            [[3],
             [4]]]], dtype=int32)
    >>> n = nn.Upsample(scale_factor=2, mode='nearest')
    >>> n(x).squeeze()
    array([[1, 1, 2, 2],
           [1, 1, 2, 2],
           [3, 3, 4, 4],
           [3, 3, 4, 4]], dtype=int32)
    >>> b = nn.Upsample(scale_factor=2, mode='linear')
    >>> b(x).squeeze()
    array([[1, 1.25, 1.75, 2],
           [1.5, 1.75, 2.25, 2.5],
           [2.5, 2.75, 3.25, 3.5],
           [3, 3.25, 3.75, 4]], dtype=float32)
    >>> b = nn.Upsample(scale_factor=2, mode='linear', align_corners=True)
    >>> b(x).squeeze()
    array([[1, 1.33333, 1.66667, 2],
           [1.66667, 2, 2.33333, 2.66667],
           [2.33333, 2.66667, 3, 3.33333],
           [3, 3.33333, 3.66667, 4]], dtype=float32)
    
Methods
# mlx.nn.init.constant
constant(value: float, dtype: Dtype = mlx.core.float32) → Callable[[array], array]
    
An initializer that returns an array filled with `value`.
Parameters:
    
  * value (float) – The value to fill the array with.
  * dtype (Dtype, optional) – The data type of the array. Default: `float32`.


Returns:
    
An initializer that returns an array with the same shape as the input, filled with `value`.
Return type:
    
Callable[[array], array]
Example
    
    >>> init_fn = nn.init.constant(0.5)
    >>> init_fn(mx.zeros((2, 2)))
    array([[0.5, 0.5],
           [0.5, 0.5]], dtype=float32)
    
# mlx.nn.init.glorot_normal
glorot_normal(dtype: Dtype = mlx.core.float32) → Callable[[array, float], array]
    
A Glorot normal initializer.
This initializer samples from a normal distribution with a standard deviation computed from the number of input (`fan_in`) and output (`fan_out`) units according to:
\\[\sigma = \gamma \sqrt{\frac{2.0}{\text{fan\\_in} + \text{fan\\_out}}}\\]
For more details see the original reference: Understanding the difficulty of training deep feedforward neural networks
Parameters:
    
dtype (Dtype, optional) – The data type of the array. Default: `float32`.
Returns:
    
An initializer that returns an array with the same shape as the input, filled with samples from the Glorot normal distribution.
Return type:
    
Callable[[array, float], array]
Example
    
    >>> init_fn = nn.init.glorot_normal()
    >>> init_fn(mx.zeros((2, 2)))
    array([[0.191107, 1.61278],
           [-0.150594, -0.363207]], dtype=float32)
    >>> init_fn(mx.zeros((2, 2)), gain=4.0)
    array([[1.89613, -4.53947],
           [4.48095, 0.995016]], dtype=float32)
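
As a quick check of the standard deviation formula for this initializer, the following plain-Python sketch evaluates \\(\sigma\\) for the `(2, 2)` array used in the example, assuming `fan_in = fan_out = 2`. `glorot_normal_std` is a hypothetical helper, not an MLX function.

```python
import math

def glorot_normal_std(fan_in, fan_out, gain=1.0):
    """sigma = gain * sqrt(2 / (fan_in + fan_out))"""
    return gain * math.sqrt(2.0 / (fan_in + fan_out))

std_default = glorot_normal_std(2, 2)            # sqrt(2 / 4)
std_gain4 = glorot_normal_std(2, 2, gain=4.0)    # four times larger
```

The `gain` argument scales the standard deviation linearly, which is why the second call in the example produces visibly larger samples.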
    
# mlx.nn.init.glorot_uniform
glorot_uniform(dtype: Dtype = mlx.core.float32) → Callable[[array, float], array]
    
A Glorot uniform initializer.
This initializer samples from a uniform distribution with a range computed from the number of input (`fan_in`) and output (`fan_out`) units according to:
\\[\sigma = \gamma \sqrt{\frac{6.0}{\text{fan\\_in} + \text{fan\\_out}}}\\]
For more details see the original reference: Understanding the difficulty of training deep feedforward neural networks
Parameters:
    
dtype (Dtype, optional) – The data type of the array. Default: `float32`.
Returns:
    
An initializer that returns an array with the same shape as the input, filled with samples from the Glorot uniform distribution.
Return type:
    
Callable[[array, float], array]
Example
    
    >>> init_fn = nn.init.glorot_uniform()
    >>> init_fn(mx.zeros((2, 2)))
    array([[0.223404, -0.890597],
           [-0.379159, -0.776856]], dtype=float32)
    >>> init_fn(mx.zeros((2, 2)), gain=4.0)
    array([[-1.90041, 3.02264],
           [-0.912766, 4.12451]], dtype=float32)
    
# mlx.nn.init.he_normal
he_normal(dtype: Dtype = mlx.core.float32) → Callable[[array, Literal['fan_in', 'fan_out'], float], array]
    
Build a He normal initializer.
This initializer samples from a normal distribution with a standard deviation computed from the number of input (`fan_in`) or output (`fan_out`) units according to:
\\[\sigma = \gamma \frac{1}{\sqrt{\text{fan}}}\\]
where \\(\text{fan}\\) is either the number of input units when the `mode` is `"fan_in"` or output units when the `mode` is `"fan_out"`.
For more details see the original reference: Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
Parameters:
    
dtype (Dtype, optional) – The data type of the array. Default: `float32`.
Returns:
    
An initializer that returns an array with the same shape as the input, filled with samples from the He normal distribution.
Return type:
    
Callable[[array, str, float], array]
Example
    
    >>> init_fn = nn.init.he_normal()
    >>> init_fn(mx.zeros((2, 2)))  # uses fan_in
    array([[-1.25211, 0.458835],
           [-0.177208, -0.0137595]], dtype=float32)
    >>> init_fn(mx.zeros((2, 2)), mode="fan_out", gain=5)
    array([[5.6967, 4.02765],
           [-4.15268, -2.75787]], dtype=float32)
    
# mlx.nn.init.he_uniform
he_uniform(dtype: Dtype = mlx.core.float32) → Callable[[array, Literal['fan_in', 'fan_out'], float], array]
    
A He uniform (Kaiming uniform) initializer.
This initializer samples from a uniform distribution with a range computed from the number of input (`fan_in`) or output (`fan_out`) units according to:
\\[\sigma = \gamma \sqrt{\frac{3.0}{\text{fan}}}\\]
where \\(\text{fan}\\) is either the number of input units when the `mode` is `"fan_in"` or output units when the `mode` is `"fan_out"`.
For more details see the original reference: Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
Parameters:
    
dtype (Dtype, optional) – The data type of the array. Default: `float32`.
Returns:
    
An initializer that returns an array with the same shape as the input, filled with samples from the He uniform distribution.
Return type:
    
Callable[[array, str, float], array]
Example
    
    >>> init_fn = nn.init.he_uniform()
    >>> init_fn(mx.zeros((2, 2)))  # uses fan_in
    array([[0.0300242, -0.0184009],
           [0.793615, 0.666329]], dtype=float32)
    >>> init_fn(mx.zeros((2, 2)), mode="fan_out", gain=5)
    array([[-1.64331, -2.16506],
           [1.08619, 5.79854]], dtype=float32)
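
The range formula for this initializer can be related to the He normal standard deviation. A uniform distribution on \\([-a, a]\\) has standard deviation \\(a / \sqrt{3}\\), so a limit of \\(\gamma \sqrt{3 / \text{fan}}\\) gives a standard deviation of \\(\gamma / \sqrt{\text{fan}}\\). The sketch below checks this numerically in plain Python; `he_uniform_limit` is a hypothetical helper, and `fan` is passed in directly rather than derived from an array shape.

```python
import math

def he_uniform_limit(fan, gain=1.0):
    """Half-width of the sampling interval: gain * sqrt(3 / fan)."""
    return gain * math.sqrt(3.0 / fan)

limit = he_uniform_limit(2)                  # samples lie in [-limit, limit]
std_of_uniform = limit / math.sqrt(3.0)      # standard deviation of U(-limit, limit)
he_normal_std = 1.0 / math.sqrt(2)           # gain / sqrt(fan) with gain=1, fan=2
```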
    
# mlx.nn.init.identity
identity(dtype: Dtype = mlx.core.float32) → Callable[[array], array]
    
An initializer that returns an identity matrix.
Parameters:
    
dtype (Dtype, optional) – The data type of the array. Default: `float32`.
Returns:
    
An initializer that returns an identity matrix with the same shape as the input.
Return type:
    
Callable[[array], array]
Example
    
    >>> init_fn = nn.init.identity()
    >>> init_fn(mx.zeros((2, 2)))
    array([[1, 0],
           [0, 1]], dtype=float32)
    
# mlx.nn.init.normal
normal(mean: float = 0.0, std: float = 1.0, dtype: Dtype = mlx.core.float32) → Callable[[array], array]
    
An initializer that returns samples from a normal distribution.
Parameters:
    
  * mean (float, optional) – Mean of the normal distribution. Default: `0.0`.
  * std (float, optional) – Standard deviation of the normal distribution. Default: `1.0`.
  * dtype (Dtype, optional) – The data type of the array. Default: `float32`.


Returns:
    
An initializer that returns an array with the same shape as the input, filled with samples from a normal distribution.
Return type:
    
Callable[[array], array]
Example
    
    >>> init_fn = nn.init.normal()
    >>> init_fn(mx.zeros((2, 2)))
    array([[-0.982273, -0.534422],
           [0.380709, 0.0645099]], dtype=float32)
    
# mlx.nn.init.uniform
uniform(low: float = 0.0, high: float = 1.0, dtype: Dtype = mlx.core.float32) → Callable[[array], array]
    
An initializer that returns samples from a uniform distribution.
Parameters:
    
  * low (float, optional) – The lower bound of the uniform distribution. Default: `0.0`.
  * high (float, optional) – The upper bound of the uniform distribution. Default: `1.0`.
  * dtype (Dtype, optional) – The data type of the array. Default: `float32`.


Returns:
    
An initializer that returns an array with the same shape as the input, filled with samples from a uniform distribution.
Return type:
    
Callable[[array], array]
Example
    
    >>> init_fn = nn.init.uniform(low=0, high=1)
    >>> init_fn(mx.zeros((2, 2)))
    array([[0.883935, 0.863726],
           [0.617261, 0.417497]], dtype=float32)
    
# mlx.nn.celu
class celu(x, alpha=1.0)
    
Applies the Continuously Differentiable Exponential Linear Unit.
Applies \\(\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))\\) element wise.
# mlx.nn.elu
class elu(x, alpha=1.0)
    
Applies the Exponential Linear Unit.
Simply `mx.where(x > 0, x, alpha * (mx.exp(x) - 1))`.
# mlx.nn.gelu
class gelu(x)
    
Applies the Gaussian Error Linear Units function.
\\[\textrm{GELU}(x) = x * \Phi(x)\\]
where \\(\Phi(x)\\) is the Gaussian CDF.
See also `gelu_approx()` and `gelu_fast_approx()` for faster approximations.
# mlx.nn.gelu_approx
class gelu_approx(x)
    
An approximation to Gaussian Error Linear Unit.
See `gelu()` for the exact computation.
This function approximates `gelu` with a maximum absolute error \\(< 0.0005\\) in the range \\([-6, 6]\\) using the following
\\[x = 0.5 * x * \left(1 + \text{Tanh}\left(\sqrt{2 / \pi} * \left(x + 0.044715 * x^3\right)\right)\right)\\]
# mlx.nn.gelu_fast_approx
class gelu_fast_approx(x)
    
A fast approximation to Gaussian Error Linear Unit.
See `gelu()` for the exact computation.
This function approximates `gelu` with a maximum absolute error \\(< 0.021\\) in the range \\([-6, 6]\\) using the following
\\[x = x \sigma\left(1.702 x\right)\\]
where \\(\sigma(\cdot)\\) is the logistic sigmoid.
References:

  * hendrycks/GELUs
  * https://arxiv.org/abs/1606.08415
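
The accuracy of the two approximations can be measured directly. The sketch below is plain Python using `math.erf` for the exact Gaussian CDF; it re-implements the three formulas for illustration and is not MLX code. The grid spacing of 0.001 is an arbitrary choice.

```python
import math

def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the Gaussian CDF."""
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_approx(x):
    """Tanh-based approximation."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

def gelu_fast_approx(x):
    """Sigmoid-based approximation: x * sigmoid(1.702 * x)."""
    return x / (1.0 + math.exp(-1.702 * x))

# Maximum absolute error on a dense grid over [-6, 6]
grid = [i / 1000.0 for i in range(-6000, 6001)]
err_tanh = max(abs(gelu(x) - gelu_approx(x)) for x in grid)
err_fast = max(abs(gelu(x) - gelu_fast_approx(x)) for x in grid)
print(err_tanh, err_fast)
```

The tanh-based form is considerably more accurate; the sigmoid-based form trades accuracy for a cheaper computation.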
# mlx.nn.glu
class glu(x: array, axis: int = -1)
    
Applies the gated linear unit function.
This function splits the `axis` dimension of the input into two halves (\\(a\\) and \\(b\\)) and applies \\(a * \sigma(b)\\).
\\[\textrm{GLU}(x) = a * \sigma(b)\\]
Parameters:
    
axis (int) – The dimension to split along. Default: `-1`
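
As a small worked case, the sketch below applies the split-and-gate rule to a 1-D list in plain Python. It assumes the first half of the axis is \\(a\\) and the second half is \\(b\\); `glu_1d` is a hypothetical helper, not the MLX function.

```python
import math

def glu_1d(x):
    """Split x into halves (a, b) and return a * sigmoid(b) element-wise."""
    half = len(x) // 2
    a, b = x[:half], x[half:]
    return [ai / (1.0 + math.exp(-bi)) for ai, bi in zip(a, b)]

# b = [0, 0] gives sigmoid(b) = 0.5, so the output is a / 2
gated = glu_1d([1.0, 2.0, 0.0, 0.0])   # [0.5, 1.0]
```

The output has half the size of the input along the split axis.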
# mlx.nn.hard_shrink
class hard_shrink(x, lambd=0.5)
    
Applies the HardShrink activation function.
\\[\begin{split}\text{hardshrink}(x) = \begin{cases} x & \text{if } x > \lambda \\\ x & \text{if } x < -\lambda \\\ 0 & \text{otherwise} \end{cases}\end{split}\\]
# mlx.nn.hard_tanh
class hard_tanh(x, min_val=-1.0, max_val=1.0)
    
Applies the HardTanh function.
Applies \\(\max(\min(x, \text{max\\_val}), \text{min\\_val})\\) element-wise.
# mlx.nn.hardswish
class hardswish(x)
    
Applies the hardswish function, element-wise.
\\[\text{Hardswish}(x) = x * \min(\max(x + 3, 0), 6) / 6\\]
# mlx.nn.leaky_relu
class leaky_relu(x, negative_slope=0.01)
    
Applies the Leaky Rectified Linear Unit.
Simply `mx.maximum(negative_slope * x, x)`.
# mlx.nn.log_sigmoid
class log_sigmoid(x)
    
Applies the Log Sigmoid function.
Applies \\(\log(\sigma(x)) = -\log(1 + e^{-x})\\) element wise.
# mlx.nn.log_softmax
class log_softmax(x, axis=-1)
    
Applies the Log Softmax function.
Applies \\(x - \log \sum_i e^{x_i}\\) element wise.
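
Log softmax subtracts the log of the summed exponentials from each input, so the exponentiated outputs sum to 1. A plain-Python sketch for a 1-D list follows. Subtracting the maximum before exponentiating is a common numerical-stability technique and an assumption here, not a statement about MLX's implementation.

```python
import math

def log_softmax(xs):
    """Return x_i - log(sum_j exp(x_j)) for each element of a 1-D list."""
    m = max(xs)                                          # shift for numerical stability
    lse = m + math.log(sum(math.exp(v - m) for v in xs)) # log-sum-exp
    return [v - lse for v in xs]

out = log_softmax([1.0, 2.0, 3.0])
total = sum(math.exp(v) for v in out)   # exponentiated outputs sum to 1
```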
# mlx.nn.losses.binary_cross_entropy
class binary_cross_entropy(inputs: array, targets: array, weights: array | None = None, with_logits: bool = True, reduction: Literal['none', 'mean', 'sum'] = 'mean')
    
Computes the binary cross entropy loss.
By default, this function takes the pre-sigmoid logits, which results in a faster and more precise loss. For improved numerical stability when `with_logits=False`, the loss calculation clips the input probabilities (in log-space) to a minimum value of `-100`.
Parameters:
    
  * inputs (array) – The predicted values. If `with_logits` is `True`, then `inputs` are unnormalized logits. Otherwise, `inputs` are probabilities.
  * targets (array) – The binary target values in {0, 1}.
  * with_logits (bool, optional) – Whether `inputs` are logits. Default: `True`.
  * weights (array, optional) – Optional weights for each target. Default: `None`.
  * reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'mean'`.


Returns:
    
The computed binary cross entropy loss.
Return type:
    
array
Examples
    
    >>> import mlx.core as mx
    >>> import mlx.nn as nn
    
    
    >>> logits = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
    >>> targets = mx.array([0, 0, 1, 1])
    >>> loss = nn.losses.binary_cross_entropy(logits, targets, reduction="mean")
    >>> loss
    array(0.539245, dtype=float32)
    
    
    >>> probs = mx.array([0.1, 0.1, 0.4, 0.4])
    >>> targets = mx.array([0, 0, 1, 1])
    >>> loss = nn.losses.binary_cross_entropy(probs, targets, with_logits=False, reduction="mean")
    >>> loss
    array(0.510826, dtype=float32)
    
# mlx.nn.losses.cosine_similarity_loss
class cosine_similarity_loss(x1: array, x2: array, axis: int = 1, eps: float = 1e-08, reduction: Literal['none', 'mean', 'sum'] = 'none')
    
Computes the cosine similarity between the two inputs.
The cosine similarity loss is given by
\\[\frac{x_1 \cdot x_2}{\max(\|x_1\| \cdot \|x_2\|, \epsilon)}\\]
Parameters:
    
  * x1 (mx.array) – The first set of inputs.
  * x2 (mx.array) – The second set of inputs.
  * axis (int, optional) – The embedding axis. Default: `1`.
  * eps (float, optional) – The minimum value of the denominator used for numerical stability. Default: `1e-8`.
  * reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`.


Returns:
    
The computed cosine similarity loss.
Return type:
    
mx.array
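
The role of `eps` in the denominator is visible in a scalar sketch. The following plain-Python function evaluates the formula above for two vectors; `cosine_similarity` is a hypothetical helper, not MLX code.

```python
import math

def cosine_similarity(x1, x2, eps=1e-8):
    """(x1 . x2) / max(||x1|| * ||x2||, eps)"""
    dot = sum(a * b for a, b in zip(x1, x2))
    norm1 = math.sqrt(sum(a * a for a in x1))
    norm2 = math.sqrt(sum(b * b for b in x2))
    return dot / max(norm1 * norm2, eps)

same = cosine_similarity([3.0, 4.0], [3.0, 4.0])         # parallel vectors
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 1.0])   # perpendicular vectors
zero_input = cosine_similarity([0.0, 0.0], [1.0, 0.0])   # eps avoids division by zero
```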
# mlx.nn.losses.cross_entropy
class cross_entropy(logits: array, targets: array, weights: array | None = None, axis: int = -1, label_smoothing: float = 0.0, reduction: Literal['none', 'mean', 'sum'] = 'none')
    
Computes the cross entropy loss.
Parameters:
    
  * logits (array) – The unnormalized logits.
  * targets (array) – The ground truth values. These can be class indices or probabilities for each class. If the `targets` are class indices, then `targets` shape should match the `logits` shape with the `axis` dimension removed. If the `targets` are probabilities (or one-hot encoded), then the `targets` shape should be the same as the `logits` shape.
  * weights (array, optional) – Optional weights for each target. Default: `None`.
  * axis (int, optional) – The axis over which to compute softmax. Default: `-1`.
  * label_smoothing (float, optional) – Label smoothing factor. Default: `0`.
  * reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`.


Returns:
    
The computed cross entropy loss.
Return type:
    
array
Examples
    
    >>> import mlx.core as mx
    >>> import mlx.nn as nn
    >>>
    >>> # Class indices as targets
    >>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
    >>> targets = mx.array([0, 1])
    >>> nn.losses.cross_entropy(logits, targets)
    array([0.0485873, 0.0485873], dtype=float32)
    >>>
    >>> # Probabilities (or one-hot vectors) as targets
    >>> logits = mx.array([[2.0, -1.0], [-1.0, 2.0]])
    >>> targets = mx.array([[0.9, 0.1], [0.1, 0.9]])
    >>> nn.losses.cross_entropy(logits, targets)
    array([0.348587, 0.348587], dtype=float32)
    
# mlx.nn.losses.gaussian_nll_loss
class gaussian_nll_loss(inputs: array, targets: array, vars: array, full: bool = False, eps: float = 1e-06, reduction: Literal['none', 'mean', 'sum'] = 'mean')
    
Computes the negative log likelihood loss for a Gaussian distribution.
The loss is given by:
\\[\frac{1}{2}\left(\log\left(\max\left(\text{vars}, \ \epsilon\right)\right) + \frac{\left(\text{inputs} - \text{targets} \right)^2} {\max\left(\text{vars}, \ \epsilon \right)}\right) + \text{const.}\\]
where `inputs` are the predicted means and `vars` are the predicted variances.
Parameters:
    
  * inputs (array) – The predicted expectation of the Gaussian distribution.
  * targets (array) – The target values (samples from the Gaussian distribution).
  * vars (array) – The predicted variance of the Gaussian distribution.
  * full (bool, optional) – Whether to include the constant term in the loss calculation. Default: `False`.
  * eps (float, optional) – Small positive constant for numerical stability. Default: `1e-6`.
  * reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'mean'`.


Returns:
    
The Gaussian NLL loss.
Return type:
    
array
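
A per-element sketch of the loss formula in plain Python is given below. `gaussian_nll` is a hypothetical helper. The constant added when `full=True` is assumed to be \\(\frac{1}{2}\log(2\pi)\\), the usual Gaussian normalization term; this section only writes it as "const.".

```python
import math

def gaussian_nll(inp, target, var, full=False, eps=1e-6):
    """0.5 * (log(max(var, eps)) + (inp - target)^2 / max(var, eps)) [+ const]."""
    var = max(var, eps)                       # clamp the variance for stability
    loss = 0.5 * (math.log(var) + (inp - target) ** 2 / var)
    if full:
        loss += 0.5 * math.log(2 * math.pi)   # assumed constant term
    return loss

unit_var = gaussian_nll(1.0, 0.0, 1.0)        # 0.5 * (log 1 + 1) = 0.5
clamped = gaussian_nll(0.0, 0.0, 0.0)         # variance clamped to eps, stays finite
```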
# mlx.nn.losses.hinge_loss
class hinge_loss(inputs: array, targets: array, reduction: Literal['none', 'mean', 'sum'] = 'none')
    
Computes the hinge loss between inputs and targets.
\\[\text{hinge}(y, y_{\text{pred}}) = \max(0, 1 - y \cdot y_{\text{pred}})\\]
Parameters:
    
  * inputs (array) – The predicted values.
  * targets (array) – The target values. They should be -1 or 1.
  * reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`.


Returns:
    
The computed hinge loss.
Return type:
    
array
# mlx.nn.losses.huber_loss
class huber_loss(inputs: array, targets: array, delta: float = 1.0, reduction: Literal['none', 'mean', 'sum'] = 'none')
    
Computes the Huber loss between inputs and targets.
\\[\begin{split}l_{\delta}(a) = \left\\{ \begin{array}{ll} \frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\\ \delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.} \end{array} \right.\end{split}\\]
Parameters:
    
  * inputs (array) – The predicted values.
  * targets (array) – The target values.
  * delta (float, optional) – The threshold at which to change between L1 and L2 loss. Default: `1.0`.
  * reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`.


Returns:
    
The computed Huber loss.
Return type:
    
array
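
The piecewise definition can be written out per element as follows. This is a plain-Python sketch with a hypothetical helper name `huber`, where \\(a\\) is the difference between prediction and target.

```python
def huber(pred, target, delta=1.0):
    """Quadratic for |a| <= delta, linear beyond it."""
    a = abs(pred - target)
    if a <= delta:
        return 0.5 * a * a
    return delta * (a - 0.5 * delta)

small_error = huber(0.5, 0.0)   # quadratic branch: 0.5 * 0.25 = 0.125
large_error = huber(3.0, 0.0)   # linear branch: 1.0 * (3.0 - 0.5) = 2.5
```

The two branches agree at \\(|a| = \delta\\), so the loss is continuous at the threshold.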
# mlx.nn.losses.kl_div_loss
class kl_div_loss(inputs: array, targets: array, axis: int = -1, reduction: Literal['none', 'mean', 'sum'] = 'none')
    
Computes the Kullback-Leibler divergence loss.
Computes the following when `reduction == 'none'`:
    
    (mx.exp(targets) * (targets - inputs)).sum(axis)
    
Parameters:
    
  * inputs (array) – Log probabilities for the predicted distribution.
  * targets (array) – Log probabilities for the target distribution.
  * axis (int, optional) – The distribution axis. Default: `-1`.
  * reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`.


Returns:
    
The computed Kullback-Leibler divergence loss.
Return type:
    
array
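
Because both arguments are log probabilities, the expression above reduces to \\(\sum_i p_i (\log p_i - \log q_i)\\), where \\(p\\) is the target distribution and \\(q\\) the predicted one. A plain-Python sketch for a single distribution follows; `kl_div` is a hypothetical helper.

```python
import math

def kl_div(log_q, log_p):
    """KL(p || q) from log-probabilities: sum_i exp(log_p_i) * (log_p_i - log_q_i)."""
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))

log_p = [math.log(0.5), math.log(0.5)]     # targets (log space)
log_q = [math.log(0.25), math.log(0.75)]   # inputs (log space)

divergence = kl_div(log_q, log_p)
identical = kl_div(log_p, log_p)           # zero when the distributions match
```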
# mlx.nn.losses.l1_loss
class l1_loss(predictions: array, targets: array, reduction: Literal['none', 'mean', 'sum'] = 'mean')
    
Computes the L1 loss.
Parameters:
    
  * predictions (array) – The predicted values.
  * targets (array) – The target values.
  * reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'mean'`.


Returns:
    
The computed L1 loss.
Return type:
    
array
# mlx.nn.losses.log_cosh_loss
class log_cosh_loss(inputs: array, targets: array, reduction: Literal['none', 'mean', 'sum'] = 'none')
    
Computes the log cosh loss between inputs and targets.
Logcosh acts like L2 loss for small errors, ensuring stable gradients, and like the L1 loss for large errors, reducing sensitivity to outliers. This dual behavior offers a balanced, robust approach for regression tasks.
\\[\text{logcosh}(y_{\text{true}}, y_{\text{pred}}) = \frac{1}{n} \sum_{i=1}^{n} \log(\cosh(y_{\text{pred}}^{(i)} - y_{\text{true}}^{(i)}))\\]
Parameters:
    
  * inputs (array) – The predicted values.
  * targets (array) – The target values.
  * reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`.


Returns:
    
The computed log cosh loss.
Return type:
    
array
# mlx.nn.losses.margin_ranking_loss
class margin_ranking_loss(inputs1: array, inputs2: array, targets: array, margin: float = 0.0, reduction: Literal['none', 'mean', 'sum'] = 'none')
    
Calculates the margin ranking loss given inputs \\(x_1\\), \\(x_2\\) and a label \\(y\\) (containing 1 or -1).
The loss is given by:
\\[\text{loss} = \max (0, -y * (x_1 - x_2) + \text{margin})\\]
Where \\(y\\) represents `targets`, \\(x_1\\) represents `inputs1` and \\(x_2\\) represents `inputs2`.
Parameters:
    
  * inputs1 (array) – Scores for the first input.
  * inputs2 (array) – Scores for the second input.
  * targets (array) – Labels indicating whether samples in `inputs1` should be ranked higher than samples in `inputs2`. Values should be 1 or -1.
  * margin (float, optional) – The margin by which the scores should be separated. Default: `0.0`.
  * reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`.


Returns:
    
The computed margin ranking loss.
Return type:
    
array
Examples
    
    >>> import mlx.core as mx
    >>> import mlx.nn as nn
    >>> targets = mx.array([1, 1, -1])
    >>> inputs1 = mx.array([-0.573409, -0.765166, -0.0638])
    >>> inputs2 = mx.array([0.75596, 0.225763, 0.256995])
    >>> loss = nn.losses.margin_ranking_loss(inputs1, inputs2, targets, reduction="mean")
    >>> loss
    array(0.773433, dtype=float32)
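The scalar in the example above equals the mean of the per-sample losses. As a cross-check, here is a plain-Python sketch of the formula applied to the same values; the helper name `margin_ranking_reference` is hypothetical and this is not the MLX implementation.

```python
def margin_ranking_reference(inputs1, inputs2, targets, margin=0.0):
    """Per-sample max(0, -y * (x1 - x2) + margin) in plain Python."""
    return [
        max(0.0, -y * (x1 - x2) + margin)
        for x1, x2, y in zip(inputs1, inputs2, targets)
    ]

targets = [1, 1, -1]
inputs1 = [-0.573409, -0.765166, -0.0638]
inputs2 = [0.75596, 0.225763, 0.256995]

# One loss value per sample (what reduction='none' corresponds to).
per_sample = margin_ranking_reference(inputs1, inputs2, targets)

# Averaging the per-sample values (what reduction='mean' corresponds to).
mean_loss = sum(per_sample) / len(per_sample)
```

The third sample contributes zero because `inputs2` is already ranked above `inputs1`, which is what the label -1 asks for.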
    
# mlx.nn.losses.mse_loss
class mse_loss(predictions: array, targets: array, reduction: Literal['none', 'mean', 'sum'] = 'mean')
    
Computes the mean squared error loss.
Parameters:
    
  * predictions (array) – The predicted values.
  * targets (array) – The target values.
  * reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'mean'`.


Returns:
    
The computed mean squared error loss.
Return type:
    
array
# mlx.nn.losses.nll_loss
class nll_loss(inputs: array, targets: array, axis: int = -1, reduction: Literal['none', 'mean', 'sum'] = 'none')
    
Computes the negative log likelihood loss.
Parameters:
    
  * inputs (array) – The predicted distribution in log space.
  * targets (array) – The target values.
  * axis (int, optional) – The distribution axis. Default: `-1`.
  * reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`.


Returns:
    
The computed NLL loss.
Return type:
    
array
# mlx.nn.losses.smooth_l1_loss
class smooth_l1_loss(predictions: array, targets: array, beta: float = 1.0, reduction: Literal['none', 'mean', 'sum'] = 'mean')
    
Computes the smooth L1 loss.
The smooth L1 loss is a variant of the L1 loss which replaces the absolute difference with a squared difference when the absolute difference is less than `beta`.
The formula for the smooth L1 Loss is:
\\[\begin{split}l = \begin{cases} 0.5 (x - y)^2 / \beta, & \text{if } |x - y| < \beta \\\ |x - y| - 0.5 \beta, & \text{otherwise} \end{cases}\end{split}\\]
Parameters:
    
  * predictions (array) – Predicted values.
  * targets (array) – Ground truth values.
  * beta (float, optional) – The threshold after which the loss changes from the squared to the absolute difference. Default: `1.0`.
  * reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'mean'`.


Returns:
    
The computed smooth L1 loss.
Return type:
    
array
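To make the two branches of the formula concrete, here is a plain-Python sketch that evaluates it per element. The helper name `smooth_l1_reference` is hypothetical, no reduction is applied, and this is not the MLX implementation.

```python
def smooth_l1_reference(predictions, targets, beta=1.0):
    """Per-element smooth L1 loss following the piecewise formula above."""
    losses = []
    for x, y in zip(predictions, targets):
        diff = abs(x - y)
        if diff < beta:
            # Quadratic branch for small differences.
            losses.append(0.5 * diff ** 2 / beta)
        else:
            # Linear branch for large differences.
            losses.append(diff - 0.5 * beta)
    return losses

# One element in each branch: |0.5| < beta and |3.0| >= beta.
losses = smooth_l1_reference([0.5, 3.0], [0.0, 0.0])
# losses == [0.125, 2.5]
```

Both branches evaluate to `0.5 * beta` when the absolute difference equals `beta`, so the loss is continuous at the threshold.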
# mlx.nn.losses.triplet_loss
class triplet_loss(anchors: array, positives: array, negatives: array, axis: int = -1, p: int = 2, margin: float = 1.0, eps: float = 1e-06, reduction: Literal['none', 'mean', 'sum'] = 'none')
    
Computes the triplet loss for a set of anchor, positive, and negative samples. In the formula below, \\(A\\), \\(P\\) and \\(N\\) are the anchor, positive and negative samples, and the margin is denoted by \\(\alpha\\).
\\[\max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)\\]
Parameters:
    
  * anchors (array) – The anchor samples.
  * positives (array) – The positive samples.
  * negatives (array) – The negative samples.
  * axis (int, optional) – The distribution axis. Default: `-1`.
  * p (int, optional) – The norm degree for pairwise distance. Default: `2`.
  * margin (float, optional) – Margin for the triplet loss. Defaults to `1.0`.
  * eps (float, optional) – Small positive constant to prevent numerical instability. Defaults to `1e-6`.
  * reduction (str, optional) – Specifies the reduction to apply to the output: `'none'` | `'mean'` | `'sum'`. Default: `'none'`.


Returns:
    
Computed triplet loss. If reduction is “none”, returns a tensor of the same shape as input; if reduction is “mean” or “sum”, returns a scalar tensor.
Return type:
    
array
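The formula can be illustrated for a single triplet with a short plain-Python sketch. The helper names `p_norm` and `triplet_reference` are hypothetical, and the sketch omits the `eps` term and the `axis` and `reduction` handling, so it is not the MLX implementation.

```python
def p_norm(vec, p=2):
    """p-norm of a flat list of numbers."""
    return sum(abs(v) ** p for v in vec) ** (1.0 / p)

def triplet_reference(anchor, positive, negative, p=2, margin=1.0):
    """max(||A - P||_p - ||A - N||_p + margin, 0) for one triplet."""
    d_pos = p_norm([a - b for a, b in zip(anchor, positive)], p)
    d_neg = p_norm([a - b for a, b in zip(anchor, negative)], p)
    return max(d_pos - d_neg + margin, 0.0)

# The positive is farther from the anchor than the negative: positive loss.
hard = triplet_reference([0.0, 0.0], [3.0, 4.0], [0.0, 1.0])

# The positive is closer than the negative by more than the margin: zero loss.
easy = triplet_reference([0.0, 0.0], [0.0, 1.0], [3.0, 4.0])
```

Here the distances are 5 and 1, so the first triplet gives 5 - 1 + 1 and the second gives a negative value that is clamped to zero.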
# mlx.nn.mish
class mish(x: array)
    
Applies the Mish function, element-wise.
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
Reference: https://arxiv.org/abs/1908.08681
\\[\text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))\\]
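As an illustration of the formula, a scalar plain-Python sketch follows. The helper names `softplus_reference` and `mish_reference` are hypothetical and this is not the MLX implementation.

```python
import math

def softplus_reference(x):
    """log(1 + exp(x)) for a single float."""
    return math.log1p(math.exp(x))

def mish_reference(x):
    """x * tanh(softplus(x)) for a single float."""
    return x * math.tanh(softplus_reference(x))

# Mish passes through the origin and approaches the identity for large x.
at_zero = mish_reference(0.0)
at_one = mish_reference(1.0)
at_large = mish_reference(20.0)
```

The sketch calls `math.exp` directly, so it overflows for very large inputs; it is meant only to show the shape of the function.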
# mlx.nn.prelu
class prelu(x: array, alpha: array)
    
Applies the element-wise parametric ReLU.
\\[\text{PReLU}(x) = \max(0,x) + a * \min(0,x)\\]
where \\(a\\) is the `alpha` array.
# mlx.nn.relu
class relu(x)
    
Applies the Rectified Linear Unit.
Simply `mx.maximum(x, 0)`.
# mlx.nn.relu2
class relu2(x)
    
Applies the ReLU² activation function.
Applies \\(\max(0, x)^2\\) element wise.
# mlx.nn.relu6
class relu6(x)
    
Applies the Rectified Linear Unit 6.
Applies \\(\min(\max(x, 0), 6)\\) element wise.
# mlx.nn.selu
class selu(x)
    
Applies the Scaled Exponential Linear Unit.
\\[\begin{split}\text{selu}(x) = \begin{cases} \lambda x & \text{if } x > 0 \\\ \lambda \alpha (\exp(x) - 1) & \text{if } x \leq 0 \end{cases}\end{split}\\]
where \\(\lambda = 1.0507\\) and \\(\alpha = 1.67326\\).
See also `elu()`.
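A scalar plain-Python sketch of the piecewise definition, using the rounded constants quoted above. The helper name `selu_reference` is hypothetical and this is not the MLX implementation.

```python
import math

LAMBDA = 1.0507   # rounded scale constant from the text above
ALPHA = 1.67326   # rounded alpha constant from the text above

def selu_reference(x):
    """SELU for a single float, following the two cases of the formula."""
    if x > 0:
        return LAMBDA * x
    return LAMBDA * ALPHA * (math.exp(x) - 1.0)

positive = selu_reference(1.0)      # linear branch
at_zero = selu_reference(0.0)       # exponential branch, evaluates to zero
saturated = selu_reference(-50.0)   # approaches -LAMBDA * ALPHA
```

For strongly negative inputs the output saturates near `-LAMBDA * ALPHA`, which is roughly -1.758 with these rounded constants.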
# mlx.nn.sigmoid
class sigmoid(x)
    
Applies the sigmoid function.
\\[\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}\\]
# mlx.nn.silu
class silu(x)
    
Applies the Sigmoid Linear Unit. Also known as Swish.
Applies \\(x \sigma(x)\\) element wise, where \\(\sigma(\cdot)\\) is the logistic sigmoid.
# mlx.nn.softmax
class softmax(x, axis=-1)
    
Applies the Softmax function.
Applies \\(\frac{e^{x_i}}{\sum_j e^{x_j}}\\) element wise.
# mlx.nn.softmin
class softmin(x, axis=-1)
    
Applies the Softmin function.
Applies \\(\frac{e^{-x_i}}{\sum_j e^{-x_j}}\\) element-wise.
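The softmax and softmin formulas differ only in the sign of the input, which the following plain-Python sketch makes explicit. The helper names are hypothetical, the max-subtraction is a standard stability trick added here as an assumption, and this is not the MLX implementation.

```python
import math

def softmax_reference(xs):
    """Softmax over a flat list; subtracting the max avoids overflow."""
    shift = max(xs)
    exps = [math.exp(x - shift) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def softmin_reference(xs):
    """Softmin is softmax applied to the negated inputs."""
    return softmax_reference([-x for x in xs])

probs_max = softmax_reference([1.0, 2.0, 3.0])
probs_min = softmin_reference([1.0, 2.0, 3.0])
```

Softmax assigns the largest weight to the largest input, while softmin assigns it to the smallest input; both outputs sum to one.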
# mlx.nn.softplus
class softplus(x)
    
Applies the Softplus function.
Applies \\(\log(1 + \exp(x))\\) element wise.
# mlx.nn.softshrink
class softshrink(x, lambd: float = 0.5)
    
Applies the Softshrink activation function.
\\[\begin{split}\text{softshrink}(x) = \begin{cases} x - \lambda & \text{if } x > \lambda \\\ x + \lambda & \text{if } x < -\lambda \\\ 0 & \text{otherwise} \end{cases}\end{split}\\]
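The three cases of the formula can be sketched for a single value in plain Python. The helper name `softshrink_reference` is hypothetical and this is not the MLX implementation.

```python
def softshrink_reference(x, lambd=0.5):
    """Softshrink for a single float, following the three cases above."""
    if x > lambd:
        return x - lambd
    if x < -lambd:
        return x + lambd
    return 0.0

# Values inside [-lambd, lambd] are zeroed; the rest are pulled toward zero.
shrunk = [softshrink_reference(v) for v in [-2.0, -0.25, 0.25, 2.0]]
# shrunk == [-1.5, 0.0, 0.0, 1.5]
```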
# mlx.nn.step
class step(x: array, threshold: float = 0.0)
    
Applies the Step Activation Function.
This function implements a binary step activation, where the output is set to 1 if the input is greater than a specified threshold, and 0 otherwise.
\\[\begin{split}\text{step}(x) = \begin{cases} 0 & \text{if } x < \text{threshold} \\\ 1 & \text{if } x \geq \text{threshold} \end{cases}\end{split}\\]
Parameters:
    
threshold – The value to threshold at.
# mlx.nn.tanh
class tanh(x)
    
Applies the hyperbolic tangent function.
Simply `mx.tanh(x)`.
# Functions
Layers without parameters (e.g. activation functions) are also provided as simple functions.
`elu`(x[, alpha])
Applies the Exponential Linear Unit.  
`celu`(x[, alpha])
Applies the Continuously Differentiable Exponential Linear Unit.  
`gelu`(x)
Applies the Gaussian Error Linear Units function.  
`gelu_approx`(x)
An approximation to Gaussian Error Linear Unit.  
`gelu_fast_approx`(x)
A fast approximation to Gaussian Error Linear Unit.  
`glu`(x[, axis])
Applies the gated linear unit function.  
`hard_shrink`(x[, lambd])
Applies the HardShrink activation function.  
`hard_tanh`(x[, min_val, max_val])
Applies the HardTanh function.  
`hardswish`(x)
Applies the hardswish function, element-wise.  
`leaky_relu`(x[, negative_slope])
Applies the Leaky Rectified Linear Unit.  
`log_sigmoid`(x)
Applies the Log Sigmoid function.  
`log_softmax`(x[, axis])
Applies the Log Softmax function.  
`mish`(x)
Applies the Mish function, element-wise.  
`prelu`(x, alpha)
Applies the element-wise parametric ReLU.  
`relu`(x)
Applies the Rectified Linear Unit.  
`relu2`(x)
Applies the ReLU² activation function.  
`relu6`(x)
Applies the Rectified Linear Unit 6.  
`selu`(x)
Applies the Scaled Exponential Linear Unit.  
`sigmoid`(x)
Applies the sigmoid function.  
`silu`(x)
Applies the Sigmoid Linear Unit.  
`softmax`(x[, axis])
Applies the Softmax function.  
`softmin`(x[, axis])
Applies the Softmin function.  
`softplus`(x)
Applies the Softplus function.  
`softshrink`(x[, lambd])
Applies the Softshrink activation function.  
`step`(x[, threshold])
Applies the Step Activation Function.  
`tanh`(x)
Applies the hyperbolic tangent function.  
# Initializers
The `mlx.nn.init` package contains commonly used initializers for neural network parameters. Initializers return a function which can be applied to any input `mlx.core.array` to produce an initialized output.
For example:
    
    import mlx.core as mx
    import mlx.nn as nn
    
    init_fn = nn.init.uniform()
    
    # Produces a [2, 2] uniform matrix
    param = init_fn(mx.zeros((2, 2)))
    
To re-initialize all the parameters in an `mlx.nn.Module` from, say, a uniform distribution, you can do:
    
    import mlx.nn as nn
    model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
    init_fn = nn.init.uniform(low=-0.1, high=0.1)
    model.apply(init_fn)
    
`constant`(value[, dtype])
An initializer that returns an array filled with `value`.  
`normal`([mean, std, dtype])
An initializer that returns samples from a normal distribution.  
`uniform`([low, high, dtype])
An initializer that returns samples from a uniform distribution.  
`identity`([dtype])
An initializer that returns an identity matrix.  
`glorot_normal`([dtype])
A Glorot normal initializer.  
`glorot_uniform`([dtype])
A Glorot uniform initializer.  
`he_normal`([dtype])
Build a He normal initializer.  
`he_uniform`([dtype])
A He uniform (Kaiming uniform) initializer.  
# Layers
`ALiBi`()  
`AvgPool1d`(kernel_size[, stride, padding])
Applies 1-dimensional average pooling.  
`AvgPool2d`(kernel_size[, stride, padding])
Applies 2-dimensional average pooling.  
`AvgPool3d`(kernel_size[, stride, padding])
Applies 3-dimensional average pooling.  
`BatchNorm`(num_features[, eps, momentum, ...])
Applies Batch Normalization over a 2D or 3D input.  
`CELU`([alpha])
Applies the Continuously Differentiable Exponential Linear Unit.  
`Conv1d`(in_channels, out_channels, kernel_size)
Applies a 1-dimensional convolution over the multi-channel input sequence.  
`Conv2d`(in_channels, out_channels, kernel_size)
Applies a 2-dimensional convolution over the multi-channel input image.  
`Conv3d`(in_channels, out_channels, kernel_size)
Applies a 3-dimensional convolution over the multi-channel input image.  
`ConvTranspose1d`(in_channels, out_channels, ...)
Applies a 1-dimensional transposed convolution over the multi-channel input sequence.  
`ConvTranspose2d`(in_channels, out_channels, ...)
Applies a 2-dimensional transposed convolution over the multi-channel input image.  
`ConvTranspose3d`(in_channels, out_channels, ...)
Applies a 3-dimensional transposed convolution over the multi-channel input image.  
`Dropout`([p])
Randomly zero a portion of the elements during training.  
`Dropout2d`([p])
Apply 2D channel-wise dropout during training.  
`Dropout3d`([p])
Apply 3D channel-wise dropout during training.  
`Embedding`(num_embeddings, dims)
Implements a simple lookup table that maps each input integer to a high-dimensional vector.  
`ELU`([alpha])
Applies the Exponential Linear Unit.  
`GELU`([approx])
Applies the Gaussian Error Linear Units.  
`GLU`([axis])
Applies the gated linear unit function.  
`GroupNorm`(num_groups, dims[, eps, affine, ...])
Applies Group Normalization [1] to the inputs.  
`GRU`(input_size, hidden_size[, bias])
A gated recurrent unit (GRU) RNN layer.  
`HardShrink`()
Applies the HardShrink function.  
`HardTanh`()
Applies the HardTanh function.  
`Hardswish`()
Applies the hardswish function, element-wise.  
`InstanceNorm`(dims[, eps, affine])
Applies instance normalization [1] on the inputs.  
`LayerNorm`(dims[, eps, affine, bias])
Applies layer normalization [1] on the inputs.  
`LeakyReLU`([negative_slope])
Applies the Leaky Rectified Linear Unit.  
`Linear`(input_dims, output_dims[, bias])
Applies an affine transformation to the input.  
`LogSigmoid`()
Applies the Log Sigmoid function.  
`LogSoftmax`()
Applies the Log Softmax function.  
`LSTM`(input_size, hidden_size[, bias])
An LSTM recurrent layer.  
`MaxPool1d`(kernel_size[, stride, padding])
Applies 1-dimensional max pooling.  
`MaxPool2d`(kernel_size[, stride, padding])
Applies 2-dimensional max pooling.  
`MaxPool3d`(kernel_size[, stride, padding])
Applies 3-dimensional max pooling.  
`Mish`()
Applies the Mish function, element-wise.  
`MultiHeadAttention`(dims, num_heads[, ...])
Implements the scaled dot product attention with multiple heads.  
`PReLU`([num_parameters, init])
Applies the element-wise parametric ReLU.  
`QuantizedEmbedding`(num_embeddings, dims[, ...])
The same as `Embedding` but with a quantized weight matrix.  
`QuantizedLinear`(input_dims, output_dims[, ...])
Applies an affine transformation to the input using a quantized weight matrix.  
`RMSNorm`(dims[, eps])
Applies Root Mean Square normalization [1] to the inputs.  
`ReLU`()
Applies the Rectified Linear Unit.  
`ReLU2`()
Applies the ReLU² activation function.  
`ReLU6`()
Applies the Rectified Linear Unit 6.  
`RNN`(input_size, hidden_size[, bias, ...])
An Elman recurrent layer.  
`RoPE`(dims[, traditional, base, scale])
Implements the rotary positional encoding.  
`SELU`()
Applies the Scaled Exponential Linear Unit.  
`Sequential`(*modules)
A layer that calls the passed callables in order.  
`Sigmoid`()
Applies the sigmoid function, element-wise.  
`SiLU`()
Applies the Sigmoid Linear Unit.  
`SinusoidalPositionalEncoding`(dims[, ...])
Implements sinusoidal positional encoding.  
`Softmin`()
Applies the Softmin function.  
`Softshrink`([lambd])
Applies the Softshrink function.  
`Softsign`()
Applies the Softsign function.  
`Softmax`()
Applies the Softmax function.  
`Softplus`()
Applies the Softplus function.  
`Step`([threshold])
Applies the Step Activation Function.  
`Tanh`()
Applies the hyperbolic tangent function.  
`Transformer`(dims, num_heads, ...)
Implements a standard Transformer model.  
`Upsample`(scale_factor[, mode, align_corners])
Upsample the input signal spatially.  
# Loss Functions
`binary_cross_entropy`(inputs, targets[, ...])
Computes the binary cross entropy loss.  
`cosine_similarity_loss`(x1, x2[, axis, eps, ...])
Computes the cosine similarity between the two inputs.  
`cross_entropy`(logits, targets[, weights, ...])
Computes the cross entropy loss.  
`gaussian_nll_loss`(inputs, targets, vars[, ...])
Computes the negative log likelihood loss for a Gaussian distribution.  
`hinge_loss`(inputs, targets[, reduction])
Computes the hinge loss between inputs and targets.  
`huber_loss`(inputs, targets[, delta, reduction])
Computes the Huber loss between inputs and targets.  
`kl_div_loss`(inputs, targets[, axis, reduction])
Computes the Kullback-Leibler divergence loss.  
`l1_loss`(predictions, targets[, reduction])
Computes the L1 loss.  
`log_cosh_loss`(inputs, targets[, reduction])
Computes the log cosh loss between inputs and targets.  
`margin_ranking_loss`(inputs1, inputs2, targets)
Calculate the margin ranking loss given inputs \\(x_1\\), \\(x_2\\) and a label \\(y\\) (containing 1 or -1).  
`mse_loss`(predictions, targets[, reduction])
Computes the mean squared error loss.  
`nll_loss`(inputs, targets[, axis, reduction])
Computes the negative log likelihood loss.  
`smooth_l1_loss`(predictions, targets[, beta, ...])
Computes the smooth L1 loss.  
`triplet_loss`(anchors, positives, negatives)
Computes the triplet loss for a set of anchor, positive, and negative samples.  
# Module
class Module
    
Base class for building neural networks with MLX.
All the layers provided in `mlx.nn.layers` subclass this class and your models should do the same.
A `Module` can contain other `Module` instances or `mlx.core.array` instances in arbitrary nesting of Python lists or dicts. The `Module` then allows recursively extracting all the `mlx.core.array` instances using `mlx.nn.Module.parameters()`.
In addition, the `Module` has the concept of trainable and non-trainable parameters (called “frozen”). When using `mlx.nn.value_and_grad()` the gradients are returned only with respect to the trainable parameters. All arrays in a module are trainable unless they are added to the “frozen” set by calling `freeze()`.
    
    import mlx.core as mx
    import mlx.nn as nn
    
    class MyMLP(nn.Module):
        def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
            super().__init__()
    
            self.in_proj = nn.Linear(in_dims, hidden_dims)
            self.out_proj = nn.Linear(hidden_dims, out_dims)
    
        def __call__(self, x):
            x = self.in_proj(x)
            x = mx.maximum(x, 0)
            return self.out_proj(x)
    
    model = MyMLP(2, 1)
    
    # All the model parameters are created but since MLX is lazy by
    # default, they are not evaluated yet. Calling `mx.eval` actually
    # allocates memory and initializes the parameters.
    mx.eval(model.parameters())
    
    # Setting a parameter to a new value is as simple as accessing that
    # parameter and assigning a new array to it.
    model.in_proj.weight = model.in_proj.weight * 2
    mx.eval(model.parameters())
    
Attributes
`Module.training`
Boolean indicating if the model is in training mode.  
`Module.state`
The module's state dictionary  
Methods
`Module.apply`(map_fn[, filter_fn])
Map all the parameters using the provided `map_fn` and immediately update the module with the mapped parameters.  
`Module.apply_to_modules`(apply_fn)
Apply a function to all the modules in this instance (including this instance).  
`Module.children`()
Return the direct descendants of this Module instance.  
`Module.eval`()
Set the model to evaluation mode.  
`Module.filter_and_map`(filter_fn[, map_fn, ...])
Recursively filter the contents of the module using `filter_fn`, namely only select keys and values where `filter_fn` returns true.  
`Module.freeze`(*[, recurse, keys, strict])
Freeze the Module's parameters or some of them.  
`Module.leaf_modules`()
Return the submodules that do not contain other modules.  
`Module.load_weights`(file_or_weights[, strict])
Update the model's weights from a `.npz`, a `.safetensors` file, or a list.  
`Module.modules`()
Return a list with all the modules in this instance.  
`Module.named_modules`()
Return a list with all the modules in this instance and their name with dot notation.  
`Module.parameters`()
Recursively return all the `mlx.core.array` members of this Module as a dict of dicts and lists.  
`Module.save_weights`(file)
Save the model's weights to a file.  
`Module.set_dtype`(dtype[, predicate])
Set the dtype of the module's parameters.  
`Module.train`([mode])
Set the model in or out of training mode.  
`Module.trainable_parameters`()
Recursively return all the non frozen `mlx.core.array` members of this Module as a dict of dicts and lists.  
`Module.unfreeze`(*[, recurse, keys, strict])
Unfreeze the Module's parameters or some of them.  
`Module.update`(parameters[, strict])
Replace the parameters of this Module with the provided ones in the dict of dicts and lists.  
`Module.update_modules`(modules[, strict])
Replace the child modules of this `Module` instance with the provided ones in the dict of dicts and lists.  
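Several of the methods above return or accept a "dict of dicts and lists". To illustrate that structure, here is a plain-Python sketch that flattens such a tree into dotted keys. The helper `flatten_tree` is hypothetical and written only for this illustration (MLX ships `mlx.utils.tree_flatten` for the real thing), and strings stand in for arrays.

```python
def flatten_tree(tree, prefix=""):
    """Flatten nested dicts/lists into a list of (dotted_key, leaf) pairs."""
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = ((str(i), v) for i, v in enumerate(tree))
    else:
        # Anything that is not a dict or list is treated as a leaf.
        return [(prefix, tree)]
    flat = []
    for key, value in items:
        child = f"{prefix}.{key}" if prefix else str(key)
        flat.extend(flatten_tree(value, child))
    return flat

# Placeholder strings stand in for parameter arrays.
params = {
    "in_proj": {"weight": "W1", "bias": "b1"},
    "layers": [{"weight": "W2"}, {"weight": "W3"}],
}
flat = flatten_tree(params)
```

List entries are addressed by their index, so the second layer's weight ends up under the key `layers.1.weight`.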
# Operations
`abs`(a, /, *[, stream])
Element-wise absolute value.  
`add`(a, b[, stream])
Element-wise addition.  
`addmm`(c, a, b, /[, alpha, beta, stream])
Matrix multiplication with addition and optional scaling.  
`all`(a, /[, axis, keepdims, stream])
An and reduction over the given axes.  
`allclose`(a, b, /[, rtol, atol, equal_nan, ...])
Approximate comparison of two arrays.  
`any`(a, /[, axis, keepdims, stream])
An or reduction over the given axes.  
`arange`(-> array)
Overloaded function.  
`arccos`(a, /, *[, stream])
Element-wise inverse cosine.  
`arccosh`(a, /, *[, stream])
Element-wise inverse hyperbolic cosine.  
`arcsin`(a, /, *[, stream])
Element-wise inverse sine.  
`arcsinh`(a, /, *[, stream])
Element-wise inverse hyperbolic sine.  
`arctan`(a, /, *[, stream])
Element-wise inverse tangent.  
`arctan2`(a, b, /, *[, stream])
Element-wise inverse tangent of the ratio of two arrays.  
`arctanh`(a, /, *[, stream])
Element-wise inverse hyperbolic tangent.  
`argmax`(a, /[, axis, keepdims, stream])
Indices of the maximum values along the axis.  
`argmin`(a, /[, axis, keepdims, stream])
Indices of the minimum values along the axis.  
`argpartition`(a, /, kth[, axis, stream])
Returns the indices that partition the array.  
`argsort`(a, /[, axis, stream])
Returns the indices that sort the array.  
`array_equal`(a, b[, equal_nan, stream])
Array equality check.  
`as_strided`(a, /[, shape, strides, offset, ...])
Create a view into the array with the given shape and strides.  
`atleast_1d`(*arys[, stream])
Convert all arrays to have at least one dimension.  
`atleast_2d`(*arys[, stream])
Convert all arrays to have at least two dimensions.  
`atleast_3d`(*arys[, stream])
Convert all arrays to have at least three dimensions.  
`bitwise_and`(a, b[, stream])
Element-wise bitwise and.  
`bitwise_invert`(a[, stream])
Element-wise bitwise inverse.  
`bitwise_or`(a, b[, stream])
Element-wise bitwise or.  
`bitwise_xor`(a, b[, stream])
Element-wise bitwise xor.  
`block_masked_mm`(a, b, /[, block_size, ...])
Matrix multiplication with block masking.  
`broadcast_arrays`(*arrays[, stream])
Broadcast arrays against one another.  
`broadcast_to`(a, /, shape, *[, stream])
Broadcast an array to the given shape.  
`ceil`(a, /, *[, stream])
Element-wise ceil.  
`clip`(a, /, a_min, a_max, *[, stream])
Clip the values of the array between the given minimum and maximum.  
`concatenate`(arrays[, axis, stream])
Concatenate the arrays along the given axis.  
`contiguous`(a, /[, allow_col_major, stream])
Force an array to be row contiguous.  
`conj`(a, *[, stream])
Return the elementwise complex conjugate of the input.  
`conjugate`(a, *[, stream])
Return the elementwise complex conjugate of the input.  
`convolve`(a, v, /[, mode, stream])
The discrete convolution of 1D arrays.  
`conv1d`(input, weight, /[, stride, padding, ...])
1D convolution over an input with several channels  
`conv2d`(input, weight, /[, stride, padding, ...])
2D convolution over an input with several channels  
`conv3d`(input, weight, /[, stride, padding, ...])
3D convolution over an input with several channels  
`conv_transpose1d`(input, weight, /[, stride, ...])
1D transposed convolution over an input with several channels  
`conv_transpose2d`(input, weight, /[, stride, ...])
2D transposed convolution over an input with several channels  
`conv_transpose3d`(input, weight, /[, stride, ...])
3D transposed convolution over an input with several channels  
`conv_general`(input, weight, /[, stride, ...])
General convolution over an input with several channels  
`cos`(a, /, *[, stream])
Element-wise cosine.  
`cosh`(a, /, *[, stream])
Element-wise hyperbolic cosine.  
`cummax`(a, /[, axis, reverse, inclusive, stream])
Return the cumulative maximum of the elements along the given axis.  
`cummin`(a, /[, axis, reverse, inclusive, stream])
Return the cumulative minimum of the elements along the given axis.  
`cumprod`(a, /[, axis, reverse, inclusive, stream])
Return the cumulative product of the elements along the given axis.  
`cumsum`(a, /[, axis, reverse, inclusive, stream])
Return the cumulative sum of the elements along the given axis.  
`degrees`(a, /, *[, stream])
Convert angles from radians to degrees.  
`dequantize`(w, /, scales[, biases, ...])
Dequantize the matrix `w` using quantization parameters.  
`diag`(a, /[, k, stream])
Extract a diagonal or construct a diagonal matrix.  
`diagonal`(a[, offset, axis1, axis2, stream])
Return specified diagonals.  
`divide`(a, b[, stream])
Element-wise division.  
`divmod`(a, b[, stream])
Element-wise quotient and remainder.  
`einsum`(subscripts, *operands[, stream])
Perform the Einstein summation convention on the operands.  
`einsum_path`(subscripts, *operands)
Compute the contraction order for the given Einstein summation.  
`equal`(a, b[, stream])
Element-wise equality.  
`erf`(a, /, *[, stream])
Element-wise error function.  
`erfinv`(a, /, *[, stream])
Element-wise inverse of `erf()`.  
`exp`(a, /, *[, stream])
Element-wise exponential.  
`expm1`(a, /, *[, stream])
Element-wise exponential minus 1.  
`expand_dims`(a, /, axis, *[, stream])
Add a size one dimension at the given axis.  
`eye`(n[, m, k, dtype, stream])
Create an identity matrix or a general diagonal matrix.  
`flatten`(a, /[, start_axis, end_axis, stream])
Flatten an array.  
`floor`(a, /, *[, stream])
Element-wise floor.  
`floor_divide`(a, b[, stream])
Element-wise integer division.  
`full`(shape, vals[, dtype, stream])
Construct an array with the given value.  
`gather_mm`(a, b, /, lhs_indices, rhs_indices, *)
Matrix multiplication with matrix-level gather.  
`gather_qmm`(x, w, /, scales[, biases, ...])
Perform quantized matrix multiplication with matrix-level gather.  
`greater`(a, b[, stream])
Element-wise greater than.  
`greater_equal`(a, b[, stream])
Element-wise greater or equal.  
`hadamard_transform`(a[, scale, stream])
Perform the Walsh-Hadamard transform along the final axis.  
`identity`(n[, dtype, stream])
Create a square identity matrix.  
`imag`(a, /, *[, stream])
Returns the imaginary part of a complex array.  
`inner`(a, b, /, *[, stream])
Ordinary inner product of vectors for 1-D arrays, in higher dimensions a sum product over the last axes.  
`isfinite`(a[, stream])
Return a boolean array indicating which elements are finite.  
`isclose`(a, b, /[, rtol, atol, equal_nan, stream])
Returns a boolean array where two arrays are element-wise equal within a tolerance.  
`isinf`(a[, stream])
Return a boolean array indicating which elements are +/- infinity.  
`isnan`(a[, stream])
Return a boolean array indicating which elements are NaN.  
`isneginf`(a[, stream])
Return a boolean array indicating which elements are negative infinity.  
`isposinf`(a[, stream])
Return a boolean array indicating which elements are positive infinity.  
`issubdtype`(arg1, arg2)
Check if a `Dtype` or `DtypeCategory` is a subtype of another.  
`kron`(a, b, *[, stream])
Compute the Kronecker product of two arrays `a` and `b`.  
`left_shift`(a, b[, stream])
Element-wise left shift.  
`less`(a, b[, stream])
Element-wise less than.  
`less_equal`(a, b[, stream])
Element-wise less than or equal.  
`linspace`(start, stop[, num, dtype, stream])
Generate `num` evenly spaced numbers over interval `[start, stop]`.  
`load`(file, /[, format, return_metadata, stream])
Load array(s) from a binary file.  
`log`(a, /, *[, stream])
Element-wise natural logarithm.  
`log2`(a, /, *[, stream])
Element-wise base-2 logarithm.  
`log10`(a, /, *[, stream])
Element-wise base-10 logarithm.  
`log1p`(a, /, *[, stream])
Element-wise natural log of one plus the array.  
`logaddexp`(a, b, /, *[, stream])
Element-wise log-add-exp.  
`logcumsumexp`(a, /[, axis, reverse, ...])
Return the cumulative logsumexp of the elements along the given axis.  
`logical_not`(a, /, *[, stream])
Element-wise logical not.  
`logical_and`(a, b, /, *[, stream])
Element-wise logical and.  
`logical_or`(a, b, /, *[, stream])
Element-wise logical or.  
`logsumexp`(a, /[, axis, keepdims, stream])
A log-sum-exp reduction over the given axes.  
`matmul`(a, b, /, *[, stream])
Matrix multiplication.  
`max`(a, /[, axis, keepdims, stream])
A max reduction over the given axes.  
`maximum`(a, b, /, *[, stream])
Element-wise maximum.  
`mean`(a, /[, axis, keepdims, stream])
Compute the mean(s) over the given axes.  
`meshgrid`(*arrays[, sparse, indexing, stream])
Generate multidimensional coordinate grids from 1-D coordinate arrays  
`min`(a, /[, axis, keepdims, stream])
A min reduction over the given axes.  
`minimum`(a, b, /, *[, stream])
Element-wise minimum.  
`moveaxis`(a, /, source, destination, *[, stream])
Move an axis to a new position.  
`multiply`(a, b[, stream])
Element-wise multiplication.  
`nan_to_num`(a[, nan, posinf, neginf, stream])
Replace NaN and Inf values with finite numbers.  
`negative`(a, /, *[, stream])
Element-wise negation.  
`not_equal`(a, b[, stream])
Element-wise not equal.  
`ones`(shape[, dtype, stream])
Construct an array of ones.  
`ones_like`(a, /, *[, stream])
An array of ones like the input.  
`outer`(a, b, /, *[, stream])
Compute the outer product of two 1-D arrays. If the arrays passed are not 1-D, a flatten op will be run beforehand.  
`partition`(a, /, kth[, axis, stream])
Returns a partitioned copy of the array such that the smallest `kth` elements are first.  
`pad`(a, pad_width[, mode, constant_values, ...])
Pad an array with a constant value  
`power`(a, b, /, *[, stream])
Element-wise power operation.  
`prod`(a, /[, axis, keepdims, stream])
A product reduction over the given axes.  
`put_along_axis`(a, /, indices, values[, ...])
Put values along an axis at the specified indices.  
`quantize`(w, /[, group_size, bits, mode, stream])
Quantize the matrix `w` using `bits` bits per element.  
`quantized_matmul`(x, w, /, scales[, biases, ...])
Perform the matrix multiplication with the quantized matrix `w`.  
`radians`(a, /, *[, stream])
Convert angles from degrees to radians.  
`real`(a, /, *[, stream])
Returns the real part of a complex array.  
`reciprocal`(a, /, *[, stream])
Element-wise reciprocal.  
`remainder`(a, b[, stream])
Element-wise remainder of division.  
`repeat`(array, repeats[, axis, stream])
Repeat an array along a specified axis.  
`reshape`(a, /, shape, *[, stream])
Reshape an array while preserving the size.  
`right_shift`(a, b[, stream])
Element-wise right shift.  
`roll`(a, shift[, axis, stream])
Roll array elements along a given axis.  
`round`(a, /[, decimals, stream])
Round to the given number of decimals.  
`rsqrt`(a, /, *[, stream])
Element-wise reciprocal of the square root.  
`save`(file, arr)
Save the array to a binary file in `.npy` format.  
`savez`(file, *args, **kwargs)
Save several arrays to a binary file in uncompressed `.npz` format.  
`savez_compressed`(file, *args, **kwargs)
Save several arrays to a binary file in compressed `.npz` format.  
`save_gguf`(file, arrays, metadata)
Save array(s) to a binary file in `.gguf` format.  
`save_safetensors`(file, arrays[, metadata])
Save array(s) to a binary file in `.safetensors` format.  
`sigmoid`(a, /, *[, stream])
Element-wise logistic sigmoid.  
`sign`(a, /, *[, stream])
Element-wise sign.  
`sin`(a, /, *[, stream])
Element-wise sine.  
`sinh`(a, /, *[, stream])
Element-wise hyperbolic sine.  
`slice`(a, start_indices, axes, slice_size, *)
Extract a sub-array from the input array.  
`slice_update`(a, update, start_indices, axes, *)
Update a sub-array of the input array.  
`softmax`(a, /[, axis, stream])
Perform the softmax along the given axis.  
`sort`(a, /[, axis, stream])
Returns a sorted copy of the array.  
`split`(a, /, indices_or_sections[, axis, stream])
Split an array along a given axis.  
`sqrt`(a, /, *[, stream])
Element-wise square root.  
`square`(a, /, *[, stream])
Element-wise square.  
`squeeze`(a, /[, axis, stream])
Remove length one axes from an array.  
`stack`(arrays[, axis, stream])
Stacks the arrays along a new axis.  
`std`(a, /[, axis, keepdims, ddof, stream])
Compute the standard deviation(s) over the given axes.  
`stop_gradient`(a, /, *[, stream])
Stop gradients from being computed.  
`subtract`(a, b[, stream])
Element-wise subtraction.  
`sum`(a, /[, axis, keepdims, stream])
Sum reduce the array over the given axes.  
`swapaxes`(a, /, axis1, axis2, *[, stream])
Swap two axes of an array.  
`take`(a, /, indices[, axis, stream])
Take elements along an axis.  
`take_along_axis`(a, /, indices[, axis, stream])
Take values along an axis at the specified indices.  
`tan`(a, /, *[, stream])
Element-wise tangent.  
`tanh`(a, /, *[, stream])
Element-wise hyperbolic tangent.  
`tensordot`(a, b, /[, axes, stream])
Compute the tensor dot product along the specified axes.  
`tile`(a, reps, /, *[, stream])
Construct an array by repeating `a` the number of times given by `reps`.  
`topk`(a, /, k[, axis, stream])
Returns the `k` largest elements from the input along a given axis.  
`trace`(a, /[, offset, axis1, axis2, dtype, ...])
Return the sum along a specified diagonal in the given array.  
`transpose`(a, /[, axes, stream])
Transpose the dimensions of the array.  
`tri`(n, m, k[, dtype, stream])
An array with ones at and below the given diagonal and zeros elsewhere.  
`tril`(x, k, *[, stream])
Zeros the array above the given diagonal.  
`triu`(x, k, *[, stream])
Zeros the array below the given diagonal.  
`unflatten`(a, /, axis, shape, *[, stream])
Unflatten an axis of an array to a shape.  
`var`(a, /[, axis, keepdims, ddof, stream])
Compute the variance(s) over the given axes.  
`view`(a, dtype[, stream])
View the array as a different type.  
`where`(condition, x, y, /, *[, stream])
Select from `x` or `y` according to `condition`.  
`zeros`(shape[, dtype, stream])
Construct an array of zeros.  
`zeros_like`(a, /, *[, stream])
An array of zeros like the input.  
# Optimizers
The optimizers in MLX can be used both with `mlx.nn` and with pure `mlx.core` functions. A typical example involves calling `Optimizer.update()` to update a model’s parameters based on the loss gradients and subsequently calling `mlx.core.eval()` to evaluate both the model’s parameters and the optimizer state.
    
    # Create a model
    model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
    mx.eval(model.parameters())
    
    # Create the gradient function and the optimizer
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
    optimizer = optim.SGD(learning_rate=learning_rate)
    
    for e in range(num_epochs):
        for X, y in batch_iterate(batch_size, train_images, train_labels):
            loss, grads = loss_and_grad_fn(model, X, y)
    
            # Update the model with the gradients. So far no computation has happened.
            optimizer.update(model, grads)
    
            # Compute the new parameters but also the optimizer state.
            mx.eval(model.parameters(), optimizer.state)
    
## Saving and Loading
To serialize an optimizer, save its state. To load an optimizer, load and set the saved state. Here’s a simple example:
    
    import mlx.core as mx
    from mlx.utils import tree_flatten, tree_unflatten
    import mlx.optimizers as optim
    
    optimizer = optim.Adam(learning_rate=1e-2)
    
    # Perform some updates with the optimizer
    model = {"w" : mx.zeros((5, 5))}
    grads = {"w" : mx.ones((5, 5))}
    optimizer.update(model, grads)
    
    # Save the state
    state = tree_flatten(optimizer.state, destination={})
    mx.save_safetensors("optimizer.safetensors", state)
    
    # Later on, for example when loading from a checkpoint,
    # recreate the optimizer and load the state
    optimizer = optim.Adam(learning_rate=1e-2)
    
    state = tree_unflatten(mx.load("optimizer.safetensors"))
    optimizer.state = state
    
Note that not every optimizer configuration parameter is saved in the state. For example, for Adam the learning rate is saved but the `betas` and `eps` parameters are not. A good rule of thumb: if a parameter can be scheduled, it is included in the optimizer state.
  * Optimizer
    * `Optimizer`
    * mlx.optimizers.Optimizer.state
      * `Optimizer.state`
    * mlx.optimizers.Optimizer.apply_gradients
      * `Optimizer.apply_gradients()`
    * mlx.optimizers.Optimizer.init
      * `Optimizer.init()`
    * mlx.optimizers.Optimizer.update
      * `Optimizer.update()`
  * Common Optimizers
    * mlx.optimizers.SGD
      * `SGD`
    * mlx.optimizers.RMSprop
      * `RMSprop`
    * mlx.optimizers.Adagrad
      * `Adagrad`
    * mlx.optimizers.Adafactor
      * `Adafactor`
    * mlx.optimizers.AdaDelta
      * `AdaDelta`
    * mlx.optimizers.Adam
      * `Adam`
    * mlx.optimizers.AdamW
      * `AdamW`
    * mlx.optimizers.Adamax
      * `Adamax`
    * mlx.optimizers.Lion
      * `Lion`
    * mlx.optimizers.MultiOptimizer
      * `MultiOptimizer`
    * mlx.optimizers.Muon
      * `Muon`
  * Schedulers
    * mlx.optimizers.cosine_decay
      * `cosine_decay()`
    * mlx.optimizers.exponential_decay
      * `exponential_decay()`
    * mlx.optimizers.join_schedules
      * `join_schedules()`
    * mlx.optimizers.linear_schedule
      * `linear_schedule()`
    * mlx.optimizers.step_decay
      * `step_decay()`


`clip_grad_norm`(grads, max_norm)
Clips the global norm of the gradients.  
# mlx.optimizers.AdaDelta
class AdaDelta(learning_rate: float | Callable[[array], array], rho: float = 0.9, eps: float = 1e-06)
    
The AdaDelta optimizer with a learning rate [1].
Our AdaDelta implementation follows the original paper. In detail,
[1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.
\\[\begin{split}v_{t+1} &= \rho v_t + (1 - \rho) g_t^2 \\\ \Delta w_{t+1} &= \frac{\sqrt{u_t + \epsilon}}{\sqrt{v_{t+1} + \epsilon}} g_t \\\ u_{t+1} &= \rho u_t + (1 - \rho) \Delta w_{t+1}^2 \\\ w_{t+1} &= w_t - \lambda \Delta w_{t+1}\end{split}\\]
Parameters:
    
  * learning_rate (float or callable) – The learning rate \\(\lambda\\).
  * rho (float, optional) – The coefficient \\(\rho\\) used for computing a running average of squared gradients. Default: `0.9`
  * eps (float, optional) – The term \\(\epsilon\\) added to the denominator to improve numerical stability. Default: `1e-6`


Methods
`__init__`(learning_rate[, rho, eps])  
`apply_single`(gradient, parameter, state)
Performs the AdaDelta parameter update and stores \\(v\\) and \\(u\\) in the optimizer state.  
`init_single`(parameter, state)
Initialize optimizer state  
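The AdaDelta equations above can be traced by hand on a single scalar weight. The following is a minimal pure-Python sketch of those four equations, not the MLX implementation; the function name `adadelta_step` and the zero initial values for `v` and `u` are assumptions made for illustration.

```python
import math

def adadelta_step(w, g, v, u, lr=1.0, rho=0.9, eps=1e-6):
    """One scalar AdaDelta step (hypothetical helper, mirrors the equations above)."""
    v = rho * v + (1 - rho) * g * g                      # running average of squared gradients
    delta = math.sqrt(u + eps) / math.sqrt(v + eps) * g  # rescaled update
    u = rho * u + (1 - rho) * delta * delta              # running average of squared updates
    w = w - lr * delta                                   # apply the update
    return w, v, u

# Assumed starting point: both running averages at zero.
w, v, u = adadelta_step(w=1.0, g=0.5, v=0.0, u=0.0)
print(w, v, u)
```

Because `u` starts at zero, the first update is scaled by roughly `sqrt(eps)`, so early steps are small regardless of the gradient magnitude.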
# mlx.optimizers.Adafactor
class Adafactor(learning_rate: float | Callable[[array], array] | None = None, eps: Tuple[float, float] = (1e-30, 0.001), clip_threshold: float = 1.0, decay_rate: float = -0.8, beta_1: float | None = None, weight_decay: float = 0.0, scale_parameter: bool = True, relative_step: bool = True, warmup_init: bool = False)
    
The Adafactor optimizer.
Our Adafactor implementation follows the original paper: Adafactor: Adaptive Learning Rates with Sublinear Memory Cost
Parameters:
    
  * learning_rate (float or callable, optional) – The learning rate. Default: `None`.
  * eps (tuple(float, float), optional) – The first term \\(\epsilon_1\\) is added to the square of the gradients to improve numerical stability, and the second term \\(\epsilon_2\\) is used for parameter scaling if `scale_parameter` is set to `True`. Default: `(1e-30, 1e-3)`.
  * clip_threshold (float, optional) – Clips the unscaled update at `clip_threshold`. Default: `1.0`.
  * decay_rate (float, optional) – Coefficient for the running average of the squared gradient. Default: `-0.8`.
  * beta_1 (float, optional) – If set to a value bigger than zero then first moment will be used. Default: `None`.
  * weight_decay (float, optional) – The weight decay \\(\lambda\\). Default: `0.0`.
  * scale_parameter (bool, optional) – If set to `True` the learning rate will be scaled by \\(\max(\epsilon_2, \text{RMS}(w_{t-1}))\\). Default: `True`.
  * relative_step (bool, optional) – If set to `True` the `learning_rate` will be ignored and relative step size will be computed. Default: `True`.
  * warmup_init (bool, optional) – If set to `True` then the relative step size will be calculated by the current step. Default: `False`.


Methods
`__init__`([learning_rate, eps, ...])  
`apply_single`(gradient, parameter, state)
Performs the Adafactor parameter and state update.  
`init_single`(parameter, state)
Initialize optimizer state  
# mlx.optimizers.Adagrad
class Adagrad(learning_rate: float | Callable[[array], array], eps: float = 1e-08)
    
The Adagrad optimizer [1].
Our Adagrad implementation follows the original paper. In detail,
[1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods for online learning and stochastic optimization. JMLR 2011.
\\[\begin{split}v_{t+1} &= v_t + g_t^2 \\\ w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}\end{split}\\]
Parameters:
    
  * learning_rate (float or callable) – The learning rate \\(\lambda\\).
  * eps (float, optional) – The term \\(\epsilon\\) added to the denominator to improve numerical stability. Default: `1e-8`


Methods
`__init__`(learning_rate[, eps])  
`apply_single`(gradient, parameter, state)
Performs the Adagrad parameter update and stores \\(v\\) in the optimizer state.  
`init_single`(parameter, state)
Initialize optimizer state  
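As a worked instance of the Adagrad equations above, here is a minimal pure-Python sketch on one scalar weight. It is not the MLX implementation; `adagrad_step` is a hypothetical name and the accumulator is assumed to start at zero.

```python
import math

def adagrad_step(w, g, v, lr=0.1, eps=1e-8):
    """One scalar Adagrad step (hypothetical helper, mirrors the equations above)."""
    v = v + g * g                          # accumulate squared gradients
    w = w - lr * g / (math.sqrt(v) + eps)  # scale the step by the accumulated magnitude
    return w, v

w, v = adagrad_step(w=1.0, g=0.5, v=0.0)
print(w, v)
```

On the very first step the gradient is divided by its own magnitude, so the weight moves by almost exactly `lr` in the direction opposite to the gradient.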
# mlx.optimizers.Adam
class Adam(learning_rate: float | Callable[[array], array], betas: List[float] = [0.9, 0.999], eps: float = 1e-08, bias_correction: bool = False)
    
The Adam optimizer [1]. In detail,
[1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic optimization. ICLR 2015.
\\[\begin{split}m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\\ v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\\ w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}\end{split}\\]
Parameters:
    
  * learning_rate (float or callable) – The learning rate \\(\lambda\\).
  * betas (Tuple[float, float], optional) – The coefficients \\((\beta_1, \beta_2)\\) used for computing running averages of the gradient and its square. Default: `(0.9, 0.999)`
  * eps (float, optional) – The term \\(\epsilon\\) added to the denominator to improve numerical stability. Default: `1e-8`
  * bias_correction (bool, optional) – If set to `True`, bias correction is applied. Default: `False`


Methods
`__init__`(learning_rate[, betas, eps, ...])  
`apply_single`(gradient, parameter, state)
Performs the Adam parameter update and stores \\(v\\) and \\(m\\) in the optimizer state.  
`init_single`(parameter, state)
Initialize optimizer state  
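The update written above (without bias correction) can be checked on a single scalar. The following pure-Python sketch is an illustration under stated assumptions, not the MLX implementation: `adam_step` is a hypothetical name and both moments are assumed to start at zero.

```python
import math

def adam_step(w, g, m, v, lr=1e-2, betas=(0.9, 0.999), eps=1e-8):
    """One scalar Adam step without bias correction (hypothetical helper)."""
    b1, b2 = betas
    m = b1 * m + (1 - b1) * g              # first moment estimate
    v = b2 * v + (1 - b2) * g * g          # second moment estimate
    w = w - lr * m / (math.sqrt(v) + eps)  # parameter update
    return w, m, v

w, m, v = adam_step(w=1.0, g=1.0, m=0.0, v=0.0)
print(1.0 - w)  # size of the first step
```

With the default `betas` and a unit gradient, the first step is about `(1 - 0.9) / sqrt(1 - 0.999)`, or roughly 3.16 times the learning rate. This is the kind of early-step distortion that setting `bias_correction=True` is meant to address.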
# mlx.optimizers.AdamW
class AdamW(learning_rate: float | Callable[[array], array], betas: List[float] = [0.9, 0.999], eps: float = 1e-08, weight_decay: float = 0.01, bias_correction: bool = False)
    
The AdamW optimizer [1]. We update the weights with a weight_decay (\\(\lambda\\)) value:
[1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay regularization. ICLR 2019.
\\[\begin{split}m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\\ v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\\ w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda w_t)\end{split}\\]
Parameters:
    
  * learning_rate (float or callable) – The learning rate \\(\alpha\\).
  * betas (Tuple[float, float], optional) – The coefficients \\((\beta_1, \beta_2)\\) used for computing running averages of the gradient and its square. Default: `(0.9, 0.999)`
  * eps (float, optional) – The term \\(\epsilon\\) added to the denominator to improve numerical stability. Default: `1e-8`
  * weight_decay (float, optional) – The weight decay \\(\lambda\\). Default: `0.01`.
  * bias_correction (bool, optional) – If set to `True`, bias correction is applied. Default: `False`


Methods
`__init__`(learning_rate[, betas, eps, ...])  
`apply_single`(gradient, parameter, state)
Performs the AdamW parameter update by modifying the parameters passed into Adam.  
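To see the decoupled weight decay term in isolation, the sketch below applies the AdamW equations above to one scalar weight with a zero gradient. It is a pure-Python illustration, not the MLX implementation; `adamw_step` is a hypothetical name and the moments are assumed to start at zero.

```python
import math

def adamw_step(w, g, m, v, lr=1e-2, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
    """One scalar AdamW step (hypothetical helper, mirrors the equations above)."""
    b1, b2 = betas
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    # The decay term is added to the Adam direction, not folded into the gradient.
    w = w - lr * (m / (math.sqrt(v) + eps) + weight_decay * w)
    return w, m, v

# With a zero gradient only the decay term acts on the weight.
w, m, v = adamw_step(w=1.0, g=0.0, m=0.0, v=0.0)
print(w)
```

With a zero gradient the weight shrinks by the factor `1 - lr * weight_decay`, independently of the moment estimates.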
# mlx.optimizers.Adamax
class Adamax(learning_rate: float | Callable[[array], array], betas: List[float] = [0.9, 0.999], eps: float = 1e-08)
    
The Adamax optimizer, a variant of Adam based on the infinity norm [1].
Our Adamax implementation follows the original paper and omits the bias correction in the first and second moment estimates. In detail,
[1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic optimization. ICLR 2015.
\\[\begin{split}m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\\ v_{t+1} &= \max(\beta_2 v_t, |g_t|) \\\ w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}\end{split}\\]
Parameters:
    
  * learning_rate (float or callable) – The learning rate \\(\lambda\\).
  * betas (Tuple[float, float], optional) – The coefficients \\((\beta_1, \beta_2)\\) used for computing running averages of the gradient and its square. Default: `(0.9, 0.999)`
  * eps (float, optional) – The term \\(\epsilon\\) added to the denominator to improve numerical stability. Default: `1e-8`


Methods
`__init__`(learning_rate[, betas, eps])  
`apply_single`(gradient, parameter, state)
Performs the Adamax parameter update and stores \\(v\\) and \\(m\\) in the optimizer state.  
`init_single`(parameter, state)
Initialize optimizer state  
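The infinity-norm variant differs from Adam only in how `v` is tracked. A minimal pure-Python sketch of the equations above on one scalar follows; it is not the MLX implementation, `adamax_step` is a hypothetical name, and both state values are assumed to start at zero.

```python
def adamax_step(w, g, m, v, lr=1e-2, betas=(0.9, 0.999), eps=1e-8):
    """One scalar Adamax step (hypothetical helper, mirrors the equations above)."""
    b1, b2 = betas
    m = b1 * m + (1 - b1) * g   # first moment estimate
    v = max(b2 * v, abs(g))     # decayed running maximum of |g|
    w = w - lr * m / (v + eps)  # parameter update
    return w, m, v

w, m, v = adamax_step(w=1.0, g=1.0, m=0.0, v=0.0)
print(w, m, v)
```

Since `v` is a running maximum rather than an average of squares, no square root appears in the denominator.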
# mlx.optimizers.Lion
class Lion(learning_rate: float | Callable[[array], array], betas: List[float] = [0.9, 0.99], weight_decay: float = 0.0)
    
The Lion optimizer [1].
Since updates are computed through the sign operation, they tend to have larger norm than for other optimizers such as SGD and Adam. We recommend a learning rate that is 3-10x smaller than AdamW and a weight decay 3-10x larger than AdamW to maintain the strength (lr * wd). Our Lion implementation follows the original paper. In detail,
[1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv preprint arXiv:2302.06675.
\\[\begin{split}c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t \\\ m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t \\\ w_{t + 1} &= w_t - \eta (\text{sign}(c_{t + 1}) + \lambda w_t)\end{split}\\]
Parameters:
    
  * learning_rate (float or callable) – The learning rate \\(\eta\\).
  * betas (Tuple[float, float], optional) – The coefficients \\((\beta_1, \beta_2)\\) used for computing the gradient momentum and update direction. Default: `(0.9, 0.99)`
  * weight_decay (float, optional) – The weight decay \\(\lambda\\). Default: `0.0`


Methods
`__init__`(learning_rate[, betas, weight_decay])  
`apply_single`(gradient, parameter, state)
Performs the Lion parameter update and stores \\(m\\) in the optimizer state.  
`init_single`(parameter, state)
Initialize optimizer state  
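A scalar walk-through makes the role of the sign operation concrete. The sketch below is pure Python and only illustrative, not the MLX implementation: `lion_step` is a hypothetical name, the momentum is assumed to start at zero, and the sign is taken of the interpolation `c` computed in the same step, as in the original paper.

```python
import math

def lion_step(w, g, m, lr=1e-3, betas=(0.9, 0.99), weight_decay=0.0):
    """One scalar Lion step (hypothetical helper)."""
    b1, b2 = betas
    c = b1 * m + (1 - b1) * g                      # interpolation used for the update direction
    sign_c = math.copysign(1.0, c) if c != 0 else 0.0
    w = w - lr * (sign_c + weight_decay * w)       # every nonzero update has magnitude lr
    m = b2 * m + (1 - b2) * g                      # momentum stored in the optimizer state
    return w, m

w, m = lion_step(w=1.0, g=-2.0, m=0.0)
print(w, m)
```

The weight moves by exactly `lr` regardless of the gradient magnitude, which is why the text above recommends a smaller learning rate than for AdamW.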
# mlx.optimizers.MultiOptimizer
class MultiOptimizer(optimizers, filters: list = [])
    
Wraps a list of optimizers with corresponding weight predicates/filters to make it easy to use different optimizers for different weights.
The predicates take the full “path” of the weight and the weight itself and return True if it should be considered for this optimizer. The last optimizer in the list is a fallback optimizer and no predicate should be given for it.
Parameters:
    
  * optimizers (list[Optimizer]) – A list of optimizers to delegate to
  * filters (list[Callable[[str, array], bool]]) – A list of predicates. Its length should be one less than the number of provided optimizers.


Methods
`__init__`(optimizers[, filters])  
`apply_gradients`(gradients, parameters)
Apply the gradients to the parameters and return the updated parameters.  
`init`(parameters)
Initialize the optimizer's state  
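The routing described above can be pictured as follows. This is a pure-Python sketch of the idea only; it assumes the first matching predicate wins, and `pick_optimizer` is a hypothetical helper rather than part of the MLX API.

```python
def pick_optimizer(path, weight, filters):
    """Return the index of the optimizer that would handle this weight.

    Assumption: predicates are tried in order and the first match wins.
    If none matches, the index of the fallback (last) optimizer is returned.
    """
    for i, predicate in enumerate(filters):
        if predicate(path, weight):
            return i
    return len(filters)

# One predicate, so two optimizers are implied: index 0 for biases, index 1 as fallback.
filters = [lambda path, weight: path.endswith("bias")]
print(pick_optimizer("layers.0.bias", None, filters))
print(pick_optimizer("layers.0.weight", None, filters))
```

With `n` optimizers there are `n - 1` predicates, so the fallback index `len(filters)` always refers to the last optimizer.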
# mlx.optimizers.Muon
class Muon(learning_rate: float | Callable[[array], array], momentum: float = 0.95, weight_decay: float = 0.01, nesterov: bool = True, ns_steps: int = 5)
    
The Muon optimizer.
Our Muon (MomentUm Orthogonalized by Newton-Schulz) optimizer follows the original implementation: Muon: An optimizer for hidden layers in neural networks
Note
  * Muon may be sub-optimal for the embedding layer, the final fully connected layer, or any 0D/1D parameters. Those should be optimized by a different method (e.g., `AdamW`).
  * For 4D convolutional filters, it works by flattening their last dimensions.


Parameters:
    
  * learning_rate (float or callable) – The learning rate.
  * momentum (float, optional) – The momentum strength. Default: `0.95`
  * weight_decay (float, optional) – The weight decay (L2 penalty). Default: `0.01`
  * nesterov (bool, optional) – Enables Nesterov momentum. Recommended for better performance. Default: `True`
  * ns_steps (int, optional) – Number of Newton-Schulz iteration steps for orthogonalization. Default: `5`


Methods
`__init__`(learning_rate[, momentum, ...])  
`apply_single`(gradient, parameter, state)
Performs the Muon parameter update  
`init_single`(parameter, state)
Initialize optimizer state  
# mlx.optimizers.Optimizer.apply_gradients
Optimizer.apply_gradients(gradients: dict, parameters: dict)
    
Apply the gradients to the parameters and return the updated parameters.
Can be used to update a model via `model.update(opt.apply_gradients(grads, model))` which is precisely how `Optimizer.update()` is implemented.
Parameters:
    
  * gradients (dict) – A Python tree of gradients.
  * parameters (dict) – A Python tree of parameters. It can be a superset of the gradients. In that case the returned Python tree will be of the same structure as the gradients.


# mlx.optimizers.Optimizer.init
Optimizer.init(parameters: dict)
    
Initialize the optimizer’s state
This function can be used to initialize optimizers which have state (like momentum in `SGD`). Using this method is optional as the optimizer will initialize itself if the state is not yet set. However, there are some cases where explicit initialization is useful in order to have access to the `Optimizer.state` before the first call to `Optimizer.update()`.
Parameters:
    
parameters (dict) – A Python tree of parameters.
Example
    
    >>> optimizer = optim.SGD(learning_rate=1e-1, momentum=0.9)
    >>> model = nn.Linear(2, 2)
    >>> optimizer.init(model.trainable_parameters())
    >>> optimizer.state.keys()
    dict_keys(['step', 'learning_rate', 'weight', 'bias'])
    
# mlx.optimizers.Optimizer.state
property Optimizer.state
    
The optimizer’s state dictionary.
# mlx.optimizers.Optimizer.update
Optimizer.update(model: Module, gradients: dict)
    
Apply the gradients to the parameters of the model and update the model with the new parameters.
Parameters:
    
  * model (Module) – An mlx module to be updated.
  * gradients (dict) – A Python tree of gradients, most likely computed via `mlx.nn.value_and_grad()`.


# mlx.optimizers.RMSprop
class RMSprop(learning_rate: float | Callable[[array], array], alpha: float = 0.99, eps: float = 1e-08)
    
The RMSprop optimizer [1].
[1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning
\\[\begin{split}v_{t+1} &= \alpha v_t + (1 - \alpha) g_t^2 \\\ w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}\end{split}\\]
Parameters:
    
  * learning_rate (float or callable) – The learning rate \\(\lambda\\).
  * alpha (float, optional) – The smoothing constant \\(\alpha\\). Default: `0.99`
  * eps (float, optional) – The term \\(\epsilon\\) added to the denominator to improve numerical stability. Default: `1e-8`


Methods
`__init__`(learning_rate[, alpha, eps])  
`apply_single`(gradient, parameter, state)
Performs the RMSprop parameter update and stores \\(v\\) in the optimizer state.  
`init_single`(parameter, state)
Initialize optimizer state  
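The RMSprop equations above can be traced on a single scalar. The following pure-Python sketch is illustrative only, not the MLX implementation; `rmsprop_step` is a hypothetical name and `v` is assumed to start at zero.

```python
import math

def rmsprop_step(w, g, v, lr=1e-2, alpha=0.99, eps=1e-8):
    """One scalar RMSprop step (hypothetical helper, mirrors the equations above)."""
    v = alpha * v + (1 - alpha) * g * g    # smoothed squared gradient
    w = w - lr * g / (math.sqrt(v) + eps)  # normalized update
    return w, v

w, v = rmsprop_step(w=1.0, g=1.0, v=0.0)
print(w, v)
```

Starting from `v = 0`, the first step is about `1 / sqrt(1 - alpha)` times the learning rate, which is 10 times `lr` for the default `alpha`.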
# mlx.optimizers.SGD
class SGD(learning_rate: float | Callable[[array], array], momentum: float = 0.0, weight_decay: float = 0.0, dampening: float = 0.0, nesterov: bool = False)
    
The stochastic gradient descent optimizer.
Updates a parameter \\(w\\) with a gradient \\(g\\) as follows
\\[\begin{split}v_{t+1} &= \mu v_t + (1 - \tau) g_t \\\ w_{t+1} &= w_t - \lambda v_{t+1}\end{split}\\]
Parameters:
    
  * learning_rate (float or callable) – The learning rate \\(\lambda\\).
  * momentum (float, optional) – The momentum strength \\(\mu\\). Default: `0`
  * weight_decay (float, optional) – The weight decay (L2 penalty). Default: `0`
  * dampening (float, optional) – Dampening for momentum \\(\tau\\). Default: `0`
  * nesterov (bool, optional) – Enables Nesterov momentum. Default: `False`


Methods
`__init__`(learning_rate[, momentum, ...])  
`apply_single`(gradient, parameter, state)
Performs the SGD parameter update and stores \\(v\\) in the optimizer state.  
`init_single`(parameter, state)
Initialize optimizer state  
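The two update equations above can be followed on one scalar weight. This pure-Python sketch covers only momentum and dampening as written in those equations (weight decay and Nesterov momentum are left out); `sgd_step` is a hypothetical name, and it is not the MLX implementation.

```python
def sgd_step(w, g, v, lr=0.1, momentum=0.9, dampening=0.0):
    """One scalar SGD step with momentum (hypothetical helper)."""
    v = momentum * v + (1 - dampening) * g  # velocity update
    w = w - lr * v                          # parameter update
    return w, v

# Two steps with a constant unit gradient, velocity assumed to start at zero.
w, v = 1.0, 0.0
for _ in range(2):
    w, v = sgd_step(w, g=1.0, v=v)
print(w, v)
```

After two steps the velocity has grown from 1.0 to 1.9, so the second step is almost twice the first.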
# mlx.optimizers.cosine_decay
cosine_decay(init: float, decay_steps: int, end: float = 0.0) → Callable
    
Make a cosine decay scheduler.
Parameters:
    
  * init (float) – Initial value.
  * decay_steps (int) – Number of steps to decay over. The decayed value is constant for steps beyond `decay_steps`.
  * end (float, optional) – Final value to decay to. Default: `0`.


Example
    
    >>> lr_schedule = optim.cosine_decay(1e-1, 1000)
    >>> optimizer = optim.SGD(learning_rate=lr_schedule)
    >>> optimizer.learning_rate
    array(0.1, dtype=float32)
    >>>
    >>> for _ in range(5): optimizer.update({}, {})
    ...
    >>> optimizer.learning_rate
    array(0.0999961, dtype=float32)
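The value printed above can be reproduced by hand. The sketch below assumes the schedule has the usual half-cosine form and that the rate shown after five updates corresponds to step index 4; `cosine_decay_sketch` is a hypothetical pure-Python stand-in, not the MLX function.

```python
import math

def cosine_decay_sketch(init, decay_steps, end=0.0):
    """Half-cosine decay from init to end (assumed form, for illustration)."""
    def schedule(step):
        s = min(step, decay_steps)  # constant beyond decay_steps
        decay = 0.5 * (1.0 + math.cos(math.pi * s / decay_steps))
        return end + decay * (init - end)
    return schedule

lr = cosine_decay_sketch(1e-1, 1000)
print(lr(0), lr(4), lr(2000))
```

Under these assumptions `lr(4)` agrees with the `0.0999961` shown above to the printed precision, and any step past `decay_steps` returns `end`.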
    
# mlx.optimizers.exponential_decay
exponential_decay(init: float, decay_rate: float) → Callable
    
Make an exponential decay scheduler.
Parameters:
    
  * init (float) – Initial value.
  * decay_rate (float) – Multiplicative factor to decay by.


Example
    
    >>> lr_schedule = optim.exponential_decay(1e-1, 0.9)
    >>> optimizer = optim.SGD(learning_rate=lr_schedule)
    >>> optimizer.learning_rate
    array(0.1, dtype=float32)
    >>>
    >>> for _ in range(5): optimizer.update({}, {})
    ...
    >>> optimizer.learning_rate
    array(0.06561, dtype=float32)
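The printed value equals `0.1 * 0.9 ** 4`. The sketch below assumes the schedule is `init * decay_rate ** step` and that the rate shown after five updates corresponds to step index 4; `exponential_decay_sketch` is a hypothetical pure-Python stand-in, not the MLX function.

```python
def exponential_decay_sketch(init, decay_rate):
    """Multiply by decay_rate once per step (assumed form, for illustration)."""
    def schedule(step):
        return init * decay_rate ** step
    return schedule

lr = exponential_decay_sketch(1e-1, 0.9)
print(lr(0), lr(4))
```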
    
# mlx.optimizers.join_schedules
join_schedules(schedules: List[Callable], boundaries: List[int]) → Callable
    
Join multiple schedules to create a new schedule.
Parameters:
    
  * schedules (list(Callable)) – A list of schedules. Schedule \\(i+1\\) receives a step count indicating the number of steps since the \\(i\\)-th boundary.
  * boundaries (list(int)) – A list of integers of length `len(schedules) - 1` that indicates when to transition between schedules.


Example
    
    >>> linear = optim.linear_schedule(0, 1e-1, steps=10)
    >>> cosine = optim.cosine_decay(1e-1, 200)
    >>> lr_schedule = optim.join_schedules([linear, cosine], [10])
    >>> optimizer = optim.Adam(learning_rate=lr_schedule)
    >>> optimizer.learning_rate
    array(0.0, dtype=float32)
    >>> for _ in range(12): optimizer.update({}, {})
    ...
    >>> optimizer.learning_rate
    array(0.0999938, dtype=float32)
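The hand-off between schedules can be sketched in plain Python. The three helpers below are hypothetical stand-ins with assumed formulas, not the MLX functions; the point is that each later schedule receives `step - boundary`, as described in the parameter list above.

```python
import math

def linear_sketch(init, end, steps):
    # Assumed form: straight line from init to end, constant afterwards.
    return lambda step: min(step, steps) * ((end - init) / steps) + init

def cosine_sketch(init, decay_steps, end=0.0):
    # Assumed form: half-cosine decay from init to end.
    def schedule(step):
        s = min(step, decay_steps)
        return end + 0.5 * (1.0 + math.cos(math.pi * s / decay_steps)) * (init - end)
    return schedule

def join_schedules_sketch(schedules, boundaries):
    def schedule(step):
        value = schedules[0](step)
        for boundary, later in zip(boundaries, schedules[1:]):
            if step >= boundary:
                value = later(step - boundary)  # later schedules count from their boundary
        return value
    return schedule

lr = join_schedules_sketch([linear_sketch(0, 1e-1, 10), cosine_sketch(1e-1, 200)], [10])
print(lr(5), lr(11))
```

Assuming the rate shown after twelve updates corresponds to step index 11, the cosine schedule is evaluated at step 1, which matches the `0.0999938` printed above.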
    
# mlx.optimizers.linear_schedule
linear_schedule(init: float, end: float, steps: int) → Callable
    
Make a linear scheduler.
Parameters:
    
  * init (float) – Initial value.
  * end (float) – Final value.
  * steps (int) – Number of steps to apply the schedule over. The value is `end` for any steps beyond `steps`.


Example
    
    >>> lr_schedule = optim.linear_schedule(0, 1e-1, 100)
    >>> optimizer = optim.Adam(learning_rate=lr_schedule)
    >>> optimizer.learning_rate
    array(0.0, dtype=float32)
    >>> for _ in range(101): optimizer.update({}, {})
    ...
    >>> optimizer.learning_rate
    array(0.1, dtype=float32)
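The schedule is a straight line that is clamped at `end`. The sketch below assumes that form; `linear_schedule_sketch` is a hypothetical pure-Python stand-in, not the MLX function.

```python
def linear_schedule_sketch(init, end, steps):
    """Linear interpolation from init to end over `steps` steps (assumed form)."""
    def schedule(step):
        s = min(step, steps)  # the value stays at `end` beyond `steps`
        return s * ((end - init) / steps) + init
    return schedule

lr = linear_schedule_sketch(0, 1e-1, 100)
print(lr(0), lr(50), lr(100), lr(500))
```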
    
# mlx.optimizers.step_decay
step_decay(init: float, decay_rate: float, step_size: int) → Callable
    
Make a step decay scheduler.
Parameters:
    
  * init (float) – Initial value.
  * decay_rate (float) – Multiplicative factor to decay by.
  * step_size (int) – Decay every `step_size` steps.


Example
    
    >>> lr_schedule = optim.step_decay(1e-1, 0.9, 10)
    >>> optimizer = optim.SGD(learning_rate=lr_schedule)
    >>> optimizer.learning_rate
    array(0.1, dtype=float32)
    >>>
    >>> for _ in range(21): optimizer.update({}, {})
    ...
    >>> optimizer.learning_rate
    array(0.081, dtype=float32)
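The printed value equals `0.1 * 0.9 ** 2`, that is, two completed decay periods. The sketch below assumes the schedule is `init * decay_rate ** (step // step_size)` and that the rate shown after 21 updates corresponds to step index 20; `step_decay_sketch` is a hypothetical pure-Python stand-in, not the MLX function.

```python
def step_decay_sketch(init, decay_rate, step_size):
    """Decay by decay_rate once every step_size steps (assumed form)."""
    def schedule(step):
        return init * decay_rate ** (step // step_size)
    return schedule

lr = step_decay_sketch(1e-1, 0.9, 10)
print(lr(9), lr(10), lr(20))
```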
    
# Common Optimizers
`SGD`(learning_rate[, momentum, weight_decay, ...])
The stochastic gradient descent optimizer.  
`RMSprop`(learning_rate[, alpha, eps])
The RMSprop optimizer [1].  
`Adagrad`(learning_rate[, eps])
The Adagrad optimizer [1].  
`Adafactor`([learning_rate, eps, ...])
The Adafactor optimizer.  
`AdaDelta`(learning_rate[, rho, eps])
The AdaDelta optimizer with a learning rate [1].  
`Adam`(learning_rate[, betas, eps, ...])
The Adam optimizer [1].  
`AdamW`(learning_rate[, betas, eps, ...])
The AdamW optimizer [1].  
`Adamax`(learning_rate[, betas, eps])
The Adamax optimizer, a variant of Adam based on the infinity norm [1].  
`Lion`(learning_rate[, betas, weight_decay])
The Lion optimizer [1].  
`MultiOptimizer`(optimizers[, filters])
Wraps a list of optimizers with corresponding weight predicates/filters to make it easy to use different optimizers for different weights.  
`Muon`(learning_rate[, momentum, ...])
The Muon optimizer.  
# Optimizer
class Optimizer(schedulers=None)
    
The base class for all optimizers. It allows us to implement an optimizer on a per-parameter basis and apply it to a parameter tree.
Attributes
`Optimizer.state`
The optimizer's state dictionary.  
Methods
`Optimizer.apply_gradients`(gradients, parameters)
Apply the gradients to the parameters and return the updated parameters.  
`Optimizer.init`(parameters)
Initialize the optimizer's state  
`Optimizer.update`(model, gradients)
Apply the gradients to the parameters of the model and update the model with the new parameters.  
# Schedulers
`cosine_decay`(init, decay_steps[, end])
Make a cosine decay scheduler.  
`exponential_decay`(init, decay_rate)
Make an exponential decay scheduler.  
`join_schedules`(schedules, boundaries)
Join multiple schedules to create a new schedule.  
`linear_schedule`(init, end, steps)
Make a linear scheduler.  
`step_decay`(init, decay_rate, step_size)
Make a step decay scheduler.  
# Random
Random sampling functions in MLX use an implicit global PRNG state by default. However, all functions take an optional `key` keyword argument for when more fine-grained control or explicit state management is needed.
For example, you can generate random numbers with:
    
    for _ in range(3):
      print(mx.random.uniform())
    
which will print a sequence of unique pseudo random numbers. Alternatively you can explicitly set the key:
    
    key = mx.random.key(0)
    for _ in range(3):
      print(mx.random.uniform(key=key))
    
which will yield the same pseudo random number at each iteration.
Following JAX’s PRNG design we use a splittable version of Threefry, which is a counter-based PRNG.
`bernoulli`([p, shape, key, stream])
Generate Bernoulli random values.  
`categorical`(logits[, axis, shape, ...])
Sample from a categorical distribution.  
`gumbel`([shape, dtype, key, stream])
Sample from the standard Gumbel distribution.  
`key`(seed)
Get a PRNG key from a seed.  
`normal`([shape, dtype, loc, scale, key, stream])
Generate normally distributed random numbers.  
`multivariate_normal`(mean, cov[, shape, ...])
Generate jointly-normal random samples given a mean and covariance.  
`randint`(low, high[, shape, dtype, key, stream])
Generate random integers from the given interval.  
`seed`(seed)
Seed the global PRNG.  
`split`(key[, num, stream])
Split a PRNG key into sub keys.  
`truncated_normal`(lower, upper[, shape, ...])
Generate values from a truncated normal distribution.  
`uniform`([low, high, shape, dtype, key, stream])
Generate uniformly distributed random numbers.  
`laplace`([shape, dtype, loc, scale, key, stream])
Sample numbers from a Laplace distribution.  
`permutation`(x[, axis, key, stream])
Generate a random permutation or permute the entries of an array.  
# Transforms
`eval`(*args)
Evaluate an `array` or tree of `array`.  
`async_eval`(*args)
Asynchronously evaluate an `array` or tree of `array`.  
`compile`(fun[, inputs, outputs, shapeless])
Returns a compiled function which produces the same output as `fun`.  
`custom_function`
Set up a function for custom gradient and vmap definitions.  
`disable_compile`()
Globally disable compilation.  
`enable_compile`()
Globally enable compilation.  
`grad`(fun[, argnums, argnames])
Returns a function which computes the gradient of `fun`.  
`value_and_grad`(fun[, argnums, argnames])
Returns a function which computes the value and gradient of `fun`.  
`jvp`(fun, primals, tangents)
Compute the Jacobian-vector product.  
`vjp`(fun, primals, cotangents)
Compute the vector-Jacobian product.  
`vmap`(fun[, in_axes, out_axes])
Returns a vectorized version of `fun`.  
# Tree Utils
In MLX we consider a Python tree to be an arbitrarily nested collection of dictionaries, lists and tuples without cycles. Functions in this module that return Python trees will be using the default Python `dict`, `list` and `tuple` but they can usually process objects that inherit from any of these.
Note
Dictionaries should have keys that are valid Python identifiers.
`tree_flatten`(tree[, prefix, is_leaf, ...])
Flattens a Python tree to a list of key, value tuples.  
`tree_unflatten`(tree)
Recreate a Python tree from its flat representation.  
`tree_map`(fn, tree, *rest[, is_leaf])
Applies `fn` to the leaves of the Python tree `tree` and returns a new collection with the results.  
`tree_map_with_path`(fn, tree, *rest[, ...])
Applies `fn` to the path and leaves of the Python tree `tree` and returns a new collection with the results.  
`tree_reduce`(fn, tree[, initializer, is_leaf])
Applies a reduction to the leaves of a Python tree.
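To make the notion of a flat representation concrete, here is a pure-Python sketch of flattening a nested tree into dotted-path keys. It assumes keys are joined with `.` and list positions become integer path components; `flatten_sketch` is a hypothetical helper, not the MLX `tree_flatten`.

```python
def flatten_sketch(tree, prefix=""):
    """Flatten nested dicts, lists and tuples into (path, leaf) pairs."""
    if isinstance(tree, dict):
        items = tree.items()
    elif isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    else:
        return [(prefix, tree)]  # anything else is treated as a leaf
    flat = []
    for key, value in items:
        path = f"{prefix}.{key}" if prefix else str(key)
        flat.extend(flatten_sketch(value, path))
    return flat

tree = {"layers": [{"w": 1}, {"w": 2}], "b": 3}
print(flatten_sketch(tree))
```

Dotted paths of this kind are only unambiguous when dictionary keys contain no dots themselves, which is one reason to keep keys to valid Python identifiers.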
