reach-vb's picture
reach-vb HF staff
18256559666a70d7f16bb789b379220578acb9d0f204be811e183cc0a99d467b
42a2b88
raw
history blame
11.8 kB
# Copyright © 2023 Apple Inc.
import math
import mlx.core as mx
from mlx.nn.layers.base import Module
def _make_activation_module(f):
def decorator(klass):
klass.__doc__ = f.__doc__
klass.__call__ = lambda self, x: f(x)
return klass
return decorator
def sigmoid(x):
r"""Applies the element-wise function:
.. math::
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
"""
return mx.sigmoid(x)
def relu(x):
r"""Applies the Rectified Linear Unit.
Simply ``mx.maximum(x, 0)``.
"""
return mx.maximum(x, 0)
def leaky_relu(x, negative_slope=0.01):
r"""Applies the Leaky Rectified Linear Unit.
Simply ``mx.maximum(negative_slope * x, x)``.
"""
return mx.maximum(negative_slope * x, x)
def log_softmax(x, axis=-1):
r"""Applies the Log Softmax function.
Applies :math:`x + \log \sum_i e^{x_i}` element wise.
"""
return x - mx.logsumexp(x, axis=axis, keepdims=True)
def elu(x, alpha=1.0):
r"""Applies the Exponential Linear Unit.
Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
"""
return mx.where(x > 0, x, alpha * (mx.exp(x) - 1))
def relu6(x):
r"""Applies the Rectified Linear Unit 6.
Applies :math:`\min(\max(x, 0), 6)` element wise.
"""
return mx.minimum(mx.maximum(x, 0), 6.0)
def softmax(x, axis=-1):
r"""Applies the Softmax function.
Applies :math:`\frac{e^{x_i}}{\sum_j e^{x_j}}` element wise.
"""
return mx.softmax(x, axis=axis)
def softplus(x):
r"""Applies the Softplus function.
Applies :math:`\log(1 + \exp(x))` element wise.
"""
return mx.logaddexp(x, 0)
def softsign(x):
r"""Applies the Softsign function.
Applies :math:`\frac{x}{1 + |x|}` element wise.
"""
return mx.divide(x, 1 + mx.abs(x))
def celu(x, alpha=1.0):
r"""Applies the Continuously Differentiable Exponential Linear Unit.
Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
element wise.
"""
return mx.maximum(x, 0.0) + alpha * (mx.exp(mx.minimum(x, 0.0) / alpha) - 1)
def silu(x):
r"""Applies the Sigmoid Linear Unit. Also known as Swish.
Applies :math:`x \sigma(x)` element wise, where :math:`\sigma(\cdot)` is
the logistic sigmoid.
"""
return x * mx.sigmoid(x)
def log_sigmoid(x):
r"""Applies the Log Sigmoid function.
Applies :math:`\log(\sigma(x)) = -\log(1 + e^{-x})` element wise.
"""
return -softplus(-x)
def gelu(x):
r"""Applies the Gaussian Error Linear Units function.
.. math::
\\textrm{GELU}(x) = x * \Phi(x)
where :math:`\Phi(x)` is the Gaussian CDF.
See also :func:`gelu_approx` and :func:`gelu_fast_approx` for faster
approximations.
"""
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
def gelu_approx(x):
r"""An approximation to Gaussian Error Linear Unit.
See :func:`gelu` for the exact computation.
This function approximates ``gelu`` with a maximum absolute error :math:`<
0.0003` in the range :math:`[-6, 6]` using the following
.. math::
x = x \sigma\left(1.60033 x \left(1 + 0.0433603 x^2\right)\right)
where :math:`\sigma(\cdot)` is the logistic sigmoid.
"""
return x * mx.sigmoid(1.60033 * x * (1 + 0.0433603 * x.square()))
def gelu_fast_approx(x):
r"""A fast approximation to Gaussian Error Linear Unit.
See :func:`gelu` for the exact computation.
This function approximates ``gelu`` with a maximum absolute error :math:`<
0.015` in the range :math:`[-6, 6]` using the following
.. math::
x = x \sigma\left(1.773 x\right)
where :math:`\sigma(\cdot)` is the logistic sigmoid.
"""
return x * mx.sigmoid(1.773 * x)
@_make_activation_module
class Sigmoid(Module):
r"""Applies the sigmoid function, element-wise.
.. math::
\text{Sigmoid}(x) = \sigma(x) = \frac{1}{1 + \exp(-x)}
"""
pass
def step(x: mx.array, threshold: float = 0.0):
r"""Applies the Step Activation Function.
This function implements a binary step activation, where the output is set
to 1 if the input is greater than a specified threshold, and 0 otherwise.
.. math::
\text{step}(x) = \begin{cases}
0 & \text{if } x < \text{threshold} \\
1 & \text{if } x \geq \text{threshold}
\end{cases}
Args:
threshold: The value to threshold at.
"""
return mx.where(x > threshold, 1, 0)
def selu(x):
r"""Applies the Scaled Exponential Linear Unit.
.. math::
\text{selu}(x) = \begin{cases}
\lambda x & \text{if } x > 0 \\
\lambda \alpha (\exp(x) - 1) & \text{if } x \leq 0
\end{cases}
where :math:`\lambda = 1.0507` and :math:`\alpha = 1.67326`.
See also :func:`elu`.
"""
return elu(x, 1.67326) * 1.0507
def prelu(x: mx.array, alpha: mx.array) -> mx.array:
r"""Applies the element-wise parametric ReLU.
.. math::
\text{PReLU}(x) = \max(0,x) + a * \min(0,x)
where :math:`a` is an array.
"""
return mx.maximum(0, x) + alpha * mx.minimum(0, x)
def mish(x: mx.array) -> mx.array:
r"""Applies the Mish function, element-wise.
Mish: A Self Regularized Non-Monotonic Neural Activation Function.
Reference: https://arxiv.org/abs/1908.08681
.. math::
\text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
"""
return x * mx.tanh(softplus(x))
def hardswish(x):
r"""Applies the hardswish function, element-wise.
.. math::
\text{Hardswish}(x) = x * \min(\max(x + 3, 0), 6) / 6
"""
max_x_3 = mx.maximum(x + 3, 0)
return x * mx.minimum(max_x_3, 6) / 6
@_make_activation_module(mish)
class Mish(Module):
r"""Applies the Mish function, element-wise.
Reference: https://arxiv.org/abs/1908.08681
.. math::
\text{Mish}(x) = x * \text{Tanh}(\text{Softplus}(x))
"""
pass
@_make_activation_module(relu)
class ReLU(Module):
r"""Applies the Rectified Linear Unit.
Simply ``mx.maximum(x, 0)``.
See :func:`relu`, for the functional equivalent.
"""
pass
class LeakyReLU(Module):
r"""Applies the Leaky Rectified Linear Unit.
Simply ``mx.maximum(negative_slope * x, x)``.
Args:
negative_slope: Controls the angle of the negative slope. Default: 1e-2.
"""
def __init__(self, negative_slope=1e-2):
super().__init__()
self._negative_slope = negative_slope
def __call__(self, x):
return leaky_relu(x, self._negative_slope)
class ELU(Module):
r"""Applies the Exponential Linear Unit.
Simply ``mx.where(x > 0, x, alpha * (mx.exp(x) - 1))``.
See :func:`elu`, for the functional equivalent.
Args:
alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
"""
def __init__(self, alpha=1.0):
super().__init__()
self._alpha = alpha
def __call__(self, x):
return elu(x, self._alpha)
@_make_activation_module(relu6)
class ReLU6(Module):
r"""Applies the Rectified Linear Unit 6.
See :func:`relu6`, for the functional equivalent.
"""
pass
@_make_activation_module(softmax)
class Softmax(Module):
r"""Applies the Softmax function.
See :func:`softmax`, for the functional equivalent.
"""
pass
@_make_activation_module(softplus)
class Softplus(Module):
r"""Applies the Softplus function.
See :func:`softplus`, for the functional equivalent.
"""
pass
@_make_activation_module(softsign)
class Softsign(Module):
r"""Applies the Softsign function.
See :func:`softsign`, for the functional equivalent.
"""
pass
class CELU(Module):
r"""Applies the Continuously Differentiable Exponential Linear Unit.
Applies :math:`\max(0, x) + \min(0, \alpha * (\exp(x / \alpha) - 1))`
element wise.
See :func:`celu`, for the functional equivalent.
Args:
alpha: the :math:`\alpha` value for the CELU formulation. Default: 1.0
"""
def __init__(self, alpha=1.0):
super().__init__()
self._alpha = alpha
def __call__(self, x):
return celu(x, self._alpha)
@_make_activation_module(silu)
class SiLU(Module):
r"""Applies the Sigmoid Linear Unit. Also known as Swish.
See :func:`silu`, for the functional equivalent.
"""
pass
@_make_activation_module(log_softmax)
class LogSoftmax(Module):
r"""Applies the Log Softmax function.
See :func:`log_softmax`, for the functional equivalent.
"""
pass
@_make_activation_module(log_sigmoid)
class LogSigmoid(Module):
r"""Applies the Log Sigmoid function.
See :func:`log_sigmoid`, for the functional equivalent.
"""
pass
class PReLU(Module):
r"""Applies the element-wise parametric ReLU.
Applies :math:`\max(0, x) + a * \min(0, x)` element wise, where :math:`a`
is an array.
See :func:`prelu`, for the functional equivalent.
Args:
num_parameters: number of :math:`a` to learn. Default: 1
init: the initial value of :math:`a`. Default: 0.25
"""
def __init__(self, num_parameters=1, init=0.25):
super().__init__()
self.weight = mx.full([num_parameters], init)
def __call__(self, x: mx.array):
return prelu(x, self.weight)
class GELU(Module):
r"""Applies the Gaussian Error Linear Units.
.. math::
\textrm{GELU}(x) = x * \Phi(x)
where :math:`\Phi(x)` is the Gaussian CDF.
However, if ``approx`` is set to 'precise' or 'fast' it applies
.. math::
\textrm{GELUApprox}(x) &= x * \sigma\left(1.60033 * x \left(1 + 0.0433603 * x^2\right)\right) \\
\textrm{GELUFast}(x) &= x * \sigma\left(1.773 * x\right)
respectively.
See :func:`gelu`, :func:`gelu_approx` and :func:`gelu_fast_approx` for the
functional equivalents and information regarding error bounds.
Args:
approx ('none' | 'precise' | 'fast'): Which approximation to gelu to use if any.
"""
def __init__(self, approx="none"):
super().__init__()
if approx == "none":
self._act = gelu
elif approx == "precise":
self._act = gelu_approx
elif approx == "fast":
self._act = gelu_fast_approx
else:
raise ValueError(
f"The approximation should be in ['none', 'precise', 'fast'] but '{approx}' was given"
)
def __call__(self, x):
return self._act(x)
def tanh(x):
"""Applies the hyperbolic tangent function.
Simply ``mx.tanh(x)``.
"""
return mx.tanh(x)
@_make_activation_module(tanh)
class Tanh(Module):
r"""Applies the hyperbolic tangent function.
See :func:`tanh`, for the functional equivalent.
"""
pass
@_make_activation_module(hardswish)
class Hardswish(Module):
r"""Applies the hardswish function, element-wise.
See :func:`hardswish`, for the functional equivalent.
"""
pass
class Step(Module):
r"""Applies the Step Activation Function.
This function implements a binary step activation, where the output is set
to 1 if the input is greater than a specified threshold, and 0 otherwise.
.. math::
\text{step}(x) = \begin{cases}
0 & \text{if } x < \text{threshold} \\
1 & \text{if } x \geq \text{threshold}
\end{cases}
Args:
threshold: The value to threshold at.
"""
def __init__(self, threshold: float = 0.0):
super().__init__()
self.threshold = threshold
def __call__(self, x: mx.array):
return step(x, self.threshold)
@_make_activation_module(selu)
class SELU(Module):
r"""Applies the Scaled Exponential Linear Unit.
See :func:`selu`, for the functional equivalent.
"""
pass