# model_structure_viewer / install_gpu_packages.py
# Last commit b014f93 by maomao88: "remove flash_attn from requirements" (375 bytes)
import subprocess
import sys
import torch
# Install runtime-only dependencies into the interpreter running this script,
# choosing the package set based on whether PyTorch can see a CUDA device.
# This runs as a module-level side effect (on import or direct execution).
# Any failed pip invocation raises subprocess.CalledProcessError.

# `python -m pip install` for *this* interpreter, so packages land in the
# same environment that imported torch.
_PIP_INSTALL = [sys.executable, "-m", "pip", "install"]

if torch.cuda.is_available():
    print("GPU detected. Installing GPU packages...")
    # einops is installed on its own first so it is present even if the much
    # heavier flash_attn build below fails.
    subprocess.check_call(_PIP_INSTALL + ["einops"])
    # flash-attn's documented install command disables build isolation: its
    # build needs the torch already installed here, which an isolated build
    # environment would not see.
    subprocess.check_call(_PIP_INSTALL + ["--no-build-isolation", "flash_attn"])
else:
    print("No GPU detected. Installing CPU-only packages if needed...")
    subprocess.check_call(_PIP_INSTALL + ["einops"])