import sys

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Pick the best available device: CUDA, then Apple's MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
try:
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:
    # Older PyTorch builds do not expose torch.backends.mps at all.
    pass

base_model = "./models/WizardCoder-15B-V1.0"
load_8bit = False

tokenizer = AutoTokenizer.from_pretrained(base_model)

if device == "cuda":
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=load_8bit,
        torch_dtype=torch.float16,
        device_map="auto",
    )
elif device == "mps":
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map={"": device},
        torch_dtype=torch.float16,
    )
else:
    # CPU fallback; the original left `model` undefined on CPU-only machines.
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        device_map={"": device},
        low_cpu_mem_usage=True,
    )

model.config.pad_token_id = tokenizer.pad_token_id

if not load_8bit and device != "cpu":
    model.half()  # cast weights to float16; skipped on CPU, where fp16 ops are poorly supported

model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)
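

# Usage sketch (not part of the original snippet): run one instruction through the
# loaded model. The Alpaca-style prompt template below is the one commonly used
# with WizardCoder, but treat it as an assumption, and the sampling parameters
# are illustrative defaults rather than tuned values.
def generate(instruction: str, max_new_tokens: int = 512) -> str:
    prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{instruction}\n\n### Response:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    generation_config = GenerationConfig(
        temperature=0.2,
        top_p=0.95,
        do_sample=True,
    )
    with torch.no_grad():
        output = model.generate(
            **inputs,
            generation_config=generation_config,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        )
    # Strip the prompt tokens and decode only the newly generated continuation.
    decoded = tokenizer.decode(
        output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )
    return decoded.strip()


if __name__ == "__main__":
    print(generate("Write a Python function that reverses a linked list."))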