File size: 3,529 Bytes
8b617cc fac2d98 161bcb6 37293dc 77fca25 3355706 34c0a86 3355706 d69ba2b 1648279 d69ba2b 3355706 d69ba2b 3355706 8d288a2 161bcb6 fac2d98 ef22351 fac2d98 ef22351 fac2d98 039e2a0 ef22351 039e2a0 ef22351 039e2a0 ef22351 161bcb6 8d288a2 3355706 77fca25 2bc1a5b ef22351 6c5fbe6 2bc1a5b 77fca25 3355706 77fca25 cf66547 039e2a0 cf66547 732851f 039e2a0 732851f c25ba79 039e2a0 8a49309 2bc1a5b 40a6362 05b398a 40a6362 9be92d1 5894f0e 1648279 dd449c5 77fca25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
"""setup.py for axolotl"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from setuptools import find_packages, setup
def _requirement_name(requirement):
    """Return the lower-cased project name that starts a requirement line.

    For example ``"xformers==0.0.26.post1"`` yields ``"xformers"``.
    """
    return re.split(r"[\s<>=!~;\[@(]", requirement, maxsplit=1)[0].lower()


def _without_requirement(requirements, name):
    """Return a copy of *requirements* without any entry for project *name*."""
    return [req for req in requirements if _requirement_name(req) != name]


def parse_requirements(requirements_path="./requirements.txt"):
    """Split a requirements file into ``install_requires`` and ``dependency_links``.

    Blank lines and ``#`` comment lines are skipped. So is any line mentioning
    a package that is offered through ``extras_require`` instead (flash-attn,
    flash-attention, deepspeed, mamba-ssm, lion-pytorch). ``--extra-index-url``
    lines, in either the space- or ``=``-separated form, are collected as
    dependency links.

    On macOS ("Darwin") any xformers requirement is dropped. Otherwise, if
    torch is already installed, its exact version is pinned so that the other
    dependencies do not replace it. For torch 2.2.x the xformers requirement
    is replaced by ``xformers>=0.0.25.post1``, and for torch < 2.2 it is
    replaced by ``xformers>=0.0.23.post1``. If torch is not installed, no
    torch-specific adjustment is made.

    Args:
        requirements_path: Path of the requirements file. Defaults to
            ``./requirements.txt``, relative to the current working directory.

    Returns:
        A ``(install_requires, dependency_links)`` tuple of string lists.

    Raises:
        ValueError: If the installed torch version does not start with
            ``<major>.<minor>``.
    """
    # Packages offered through ``extras_require`` rather than installed
    # unconditionally; any requirements line mentioning one is skipped.
    extras_markers = (
        "flash-attn",
        "flash-attention",
        "deepspeed",
        "mamba-ssm",
        "lion-pytorch",
    )
    index_url_flag = "--extra-index-url"

    requirements = []
    links = []
    with open(requirements_path, encoding="utf-8") as requirements_file:
        for raw_line in requirements_file:
            line = raw_line.strip()
            if line.startswith(index_url_flag):
                # Accept "--extra-index-url URL" as well as "--extra-index-url=URL"
                # and ignore anything after the URL (e.g. a trailing comment).
                url_parts = line[len(index_url_flag):].lstrip(" \t=").split()
                if url_parts:
                    links.append(url_parts[0])
            elif (
                line
                and not line.startswith("#")
                and not any(marker in line for marker in extras_markers)
            ):
                requirements.append(line)

    if "Darwin" in platform.system():
        # Don't install xformers on macOS. Matching on the project name (not an
        # exact pin string) keeps this working when the pin in
        # requirements.txt changes.
        return _without_requirement(requirements, "xformers"), links

    try:
        torch_version = version("torch")
    except PackageNotFoundError:
        # torch is not installed yet: leave the requirements as written.
        return requirements, links

    # Pin the torch already present so dependency resolution doesn't clobber it.
    requirements.append(f"torch=={torch_version}")

    version_match = re.match(r"^(\d+)\.(\d+)", torch_version)
    if version_match is None:
        raise ValueError(f"Invalid version format: {torch_version!r}")
    major_minor = (int(version_match.group(1)), int(version_match.group(2)))

    if major_minor < (2, 3):
        # Older torch releases need an older xformers than the one pinned in
        # requirements.txt, so relax the pin to a lower bound.
        if major_minor >= (2, 2):
            xformers_spec = "xformers>=0.0.25.post1"
        else:
            xformers_spec = "xformers>=0.0.23.post1"
        requirements = _without_requirement(requirements, "xformers")
        requirements.append(xformers_spec)

    return requirements, links
install_requires, dependency_links = parse_requirements()

setup(
    name="axolotl",
    version="0.4.1",
    description="LLM Trainer",
    long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
    # src layout: the importable packages live under src/. find_packages must
    # therefore search there. A bare find_packages() scans the project root
    # and misses them in a regular (non-editable) build.
    package_dir={"": "src"},
    packages=find_packages(where="src"),
    install_requires=install_requires,
    # NOTE: modern pip no longer processes dependency_links; the extra index
    # URLs collected from requirements.txt are kept only for legacy tooling.
    dependency_links=dependency_links,
    # Optional packages, installed only on request (e.g. `pip install
    # axolotl[deepspeed]`). parse_requirements() skips requirements.txt lines
    # for flash-attn/flash-attention, deepspeed, mamba-ssm and lion-pytorch so
    # that they are only pulled in through these extras.
    extras_require={
        "flash-attn": [
            "flash-attn==2.5.8",
        ],
        "fused-dense-lib": [
            # Built from the flash-attention repo at the same tag as flash-attn.
            "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.5.8#subdirectory=csrc/fused_dense_lib",
        ],
        "deepspeed": [
            "deepspeed==0.14.2",
            "deepspeed-kernels",
        ],
        "mamba-ssm": [
            "mamba-ssm==1.2.0.post1",
        ],
        "auto-gptq": [
            "auto-gptq==0.5.1",
        ],
        "mlflow": [
            "mlflow",
        ],
        "lion-pytorch": [
            "lion-pytorch==0.1.2",
        ],
        "galore": [
            "galore_torch",
        ],
    },
)
|