| import argparse | |
| import torch | |
| from .make_model import make_model | |
# Fixed hyper-parameters handed to make_model() when building the behaviour
# model. Every key is also exposed as an attribute of `hparams` below.
hparams_dict = dict(
    HF_MODEL_PATH='facebook/wav2vec2-large-xlsr-53',  # Hugging Face model id (per key name)
    DATASET='recanvo',
    MAX_DURATION=4,  # NOTE(review): unit not visible here (seconds?) — confirm in make_model
    SAMPLING_RATE=16_000,
    OUTPUT_HIDDEN_STATES=True,
    CLASSIFIER_NAME='multilevel',
    CLASSIFIER_PROJ_SIZE=256,
    NUM_LABELS=3,
    # NOTE(review): a single weight is listed although NUM_LABELS is 3;
    # confirm whether make_model reads LABEL_WEIGHTS at inference time.
    LABEL_WEIGHTS=[1.0],
    LOSS='cross-entropy',
    GPU_ID=0,
    RETURN_RAW_ARRAY=False,
)

# Attribute-style view of the same settings (e.g. hparams.NUM_LABELS).
hparams = argparse.Namespace(**hparams_dict)
def get_behaviour_model(classifier_weights_path, device, model_hparams=None):
    """Build the behaviour model and load trained classifier weights into it.

    The model is constructed with ``make_model`` from the module-level
    ``hparams`` (or ``model_hparams`` when given). Only the ``classifier``
    sub-module receives weights from the checkpoint file.

    Args:
        classifier_weights_path: Path to a checkpoint readable by
            ``torch.load`` holding the ``state_dict`` of ``model.classifier``.
        device: Device the checkpoint tensors are mapped to and the returned
            model is moved to (e.g. ``"cpu"`` or ``torch.device("cuda:0")``).
        model_hparams: Optional hyper-parameter namespace passed to
            ``make_model``. Defaults to the module-level ``hparams``.

    Returns:
        The model, placed on ``device`` and switched to evaluation mode.

    Raises:
        RuntimeError: From ``load_state_dict`` (strict loading) if the
            checkpoint's keys or shapes do not match ``model.classifier``.
    """
    if model_hparams is None:
        model_hparams = hparams
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True once the minimum supported
    # torch version accepts it).
    state_dict = torch.load(classifier_weights_path, map_location=device)
    model = make_model(model_hparams)
    model.classifier.load_state_dict(state_dict)
    # load_state_dict copies values into the existing parameters, which stay
    # on whatever device make_model created them on. Move the whole model so
    # it actually lives on the requested device.
    model.to(device)
    model.eval()
    return model