from typing import Dict, Union

import datasets
import evaluate
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

_DESCRIPTION = """
Perplexity metric implemented by d-Matrix.
Perplexity (PPL) is one of the most common metrics for evaluating language models.
It is defined as the exponentiated average negative log-likelihood of a sequence, calculated with exponent base `e`.
For more information, see https://huggingface.co/docs/transformers/perplexity
"""

_KWARGS_DESCRIPTION = """
Args:
    model (Union[str, AutoModelForCausalLM]): model used for calculating Perplexity.
        NOTE: Perplexity can only be calculated for causal language models.
        This includes models such as gpt2, causal variations of bert,
        causal versions of t5, and more (the full list can be found
        in the AutoModelForCausalLM documentation here:
        https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCausalLM )
    references (list of str): input text, where each text snippet is one list entry.
    device (str): device to run on; defaults to 'cuda' when available.
    max_length (int): maximum sequence length; defaults to 2048.
Returns:
    perplexity: dictionary containing the perplexity score and loss.
Examples:
    >>> from datasets import load_dataset
    >>> perplexity = evaluate.load("dmx_perplexity", module_type="metric")
    >>> input_texts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")["text"][:10]  # doctest: +SKIP
    >>> results = perplexity.compute(model='distilgpt2',
    ...                              references=input_texts)  # doctest: +SKIP
    >>> print(list(results.keys()))  # doctest: +SKIP
    ['loss', 'perplexity']
    >>> print(results['loss'])  # doctest: +SKIP
    3.8299286365509033
    >>> print(results['perplexity'])  # doctest: +SKIP
    46.05925369262695
"""


class DmxPerplexity(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation="",
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "references": datasets.Value("string"),
                }
            ),
            reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
        )

    def _compute(
        self,
        references,
        model: Union[str, AutoModelForCausalLM],
        device=None,
        max_length=None,
        **kwargs,
    ):
        # Resolve the target device; "gpu" is accepted as an alias for "cuda".
        if device is not None:
            assert device in [
                "gpu",
                "cpu",
                "cuda",
            ], "device should be either cuda, gpu or cpu."
            if device == "gpu":
                device = "cuda"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Accept either a model name/path or an already-instantiated model;
        # in the latter case, the tokenizer is loaded from the model's config.
        if isinstance(model, str):
            tokenizer = AutoTokenizer.from_pretrained(model)
            model = AutoModelForCausalLM.from_pretrained(model)
        else:
            tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path, **kwargs)

        # Determine the evaluation window length: an explicit argument wins,
        # then the model config, then a default of 2048.
        if max_length:
            max_seq_len = max_length
        elif hasattr(model.config, "max_position_embeddings"):
            max_seq_len = model.config.max_position_embeddings
        elif hasattr(model.config, "n_positions"):
            max_seq_len = model.config.n_positions
        else:
            max_seq_len = 2048
| if not hasattr(model, "hf_device_map") and ( | |
| not hasattr(model, "model_parallel") or not model.model_parallel | |
| ): | |
| model = model.to(device) | |
| model.eval() | |
| encodings = tokenizer("\n\n".join(references), return_tensors="pt") | |
| stride = max_seq_len | |
| seq_len = encodings.input_ids.size(1) | |
| seq_len = (seq_len // stride) * stride | |
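
        # Since stride == max_seq_len, windows do not overlap: every token in a
        # window is a prediction target, so the -100 label masking below is a
        # no-op here. With a smaller stride the windows would overlap, and only
        # the final trg_len tokens of each window would be scored, as in the
        # Hugging Face perplexity guide linked above.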
        nlls = []
        prev_end_loc = 0
        for begin_loc in tqdm(range(0, seq_len, stride)):
            end_loc = min(begin_loc + max_seq_len, seq_len)
            trg_len = end_loc - prev_end_loc  # tokens scored in this window
            input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
            target_ids = input_ids.clone()
            target_ids[:, :-trg_len] = -100  # ignore context-only tokens in the loss

            with torch.no_grad():
                outputs = model(input_ids, labels=target_ids)
                # The returned loss is averaged over scored tokens; scale by
                # trg_len to approximate the summed negative log-likelihood.
                if isinstance(outputs, Dict):
                    neg_log_likelihood = outputs["loss"] * trg_len
                else:
                    neg_log_likelihood = outputs.loss * trg_len

            nlls.append(neg_log_likelihood.to(device))
            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        # Average the summed NLL over all scored tokens, then exponentiate.
        loss = torch.stack(nlls).float().sum() / end_loc
        ppl = torch.exp(loss)

        return dict(
            loss=loss.item(),
            perplexity=ppl.item(),
        )
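

# A minimal usage sketch (an illustration under stated assumptions, not part of
# the metric): assumes `distilgpt2` can be fetched from the Hub and that the
# metric class is used directly rather than through `evaluate.load`.
if __name__ == "__main__":
    metric = DmxPerplexity()
    texts = [
        "Perplexity measures how well a language model predicts text.",
        "Lower perplexity indicates a better fit to the reference data.",
    ]
    results = metric.compute(model="distilgpt2", references=texts)
    print(results)  # {'loss': ..., 'perplexity': ...}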