Spaces:
Running
Running
| import os | |
| from argparse import ArgumentParser, RawDescriptionHelpFormatter | |
| from collections.abc import Mapping | |
| import yaml | |
# Names exported by ``from <this module> import *``.
__all__ = ['Config']
class ArgsParser(ArgumentParser):
    """Command-line parser for a config file plus ``-o key=value`` overrides.

    ``parse_args`` requires ``-c/--config`` and converts the ``-o/--opt``
    list into a (possibly nested) dict via :meth:`_parse_opt`.
    """

    def __init__(self):
        super().__init__(formatter_class=RawDescriptionHelpFormatter)
        # parse_args() requires ``args.config``; without declaring the option
        # here every parse failed (AttributeError / unrecognized argument).
        self.add_argument('-c', '--config', help='configuration file to use')
        self.add_argument('-o',
                          '--opt',
                          nargs='*',
                          help='set configuration options')
        self.add_argument('--local_rank')

    def parse_args(self, argv=None):
        """Parse *argv* (``sys.argv[1:]`` when None) and post-process ``opt``.

        Returns:
            argparse.Namespace whose ``opt`` attribute is a dict of overrides.

        Exits through ``ArgumentParser.error`` (SystemExit) when no config
        file is given; unlike ``assert`` this check survives ``python -O``.
        """
        args = super().parse_args(argv)
        if args.config is None:
            self.error('Please specify --config=configure_file_path.')
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn ``['a.b=1', 'c=x']`` into ``{'a': {'b': 1}, 'c': 'x'}``.

        Each value is parsed with YAML, so numbers, booleans, lists, etc. get
        their natural types. Dotted keys build nested dicts, and options that
        share a prefix are merged instead of overwriting one another.

        Args:
            opts: list of ``key=value`` strings, or None/empty.

        Returns:
            dict of overrides (empty when *opts* is falsy).

        Raises:
            ValueError: if an option has no ``=`` separator.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            if '=' not in s:
                raise ValueError(
                    "Invalid option '{}', expected key=value".format(s))
            k, v = s.split('=', 1)
            # NOTE(review): yaml.Loader can construct arbitrary Python
            # objects. Fine for the user's own command line; switch to
            # yaml.safe_load if options may come from an untrusted source.
            value = yaml.load(v, Loader=yaml.Loader)
            *parents, leaf = k.split('.')
            cur = config
            for key in parents:
                # Reuse an existing sub-dict so `-o a.b=1 a.c=2` keeps both;
                # a non-dict value at this level is replaced by a new dict.
                if not isinstance(cur.get(key), dict):
                    cur[key] = {}
                cur = cur[key]
            cur[leaf] = value
        return config
class AttrDict(dict):
    """A ``dict`` whose top-level keys can also be read as attributes.

    Only attribute reads are redirected, and only one level deep: nested
    dict values are returned unchanged, not wrapped in ``AttrDict``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __getattr__(self, key):
        # Python only calls this after normal attribute lookup has failed.
        if key not in self:
            raise AttributeError("object has no attribute '{}'".format(key))
        return self[key]
| def _merge_dict(config, merge_dct): | |
| """Recursive dict merge. Inspired by :meth:``dict.update()``, instead of | |
| updating only top-level keys, dict_merge recurses down into dicts nested to | |
| an arbitrary depth, updating keys. The ``merge_dct`` is merged into | |
| ``dct``. | |
| Args: | |
| config: dict onto which the merge is executed | |
| merge_dct: dct merged into config | |
| Returns: dct | |
| """ | |
| for key, value in merge_dct.items(): | |
| sub_keys = key.split('.') | |
| key = sub_keys[0] | |
| if key in config and len(sub_keys) > 1: | |
| _merge_dict(config[key], {'.'.join(sub_keys[1:]): value}) | |
| elif key in config and isinstance(config[key], dict) and isinstance( | |
| value, Mapping): | |
| _merge_dict(config[key], value) | |
| else: | |
| config[key] = value | |
| return config | |
def print_dict(cfg, print_func=print, delimiter=0):
    """Recursively print a dict, indenting nested levels by four spaces.

    Keys are printed in sorted order. A dict value, or a list whose first
    element is a dict, prints its key on its own line followed by the
    nested entries; anything else prints as ``key : value``.

    Args:
        cfg: the dict to print.
        print_func: callable receiving each output line (default ``print``).
        delimiter: number of leading spaces for this nesting level.
    """
    indent = delimiter * ' '
    try:
        items = sorted(cfg.items())
    except TypeError:
        # Keys of mixed, non-comparable types (e.g. int and str): order by
        # their string form rather than crashing.
        items = sorted(cfg.items(), key=lambda kv: str(kv[0]))
    for k, v in items:
        if isinstance(v, dict):
            print_func('{}{} : '.format(indent, str(k)))
            print_dict(v, print_func, delimiter + 4)
        elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
            print_func('{}{} : '.format(indent, str(k)))
            for value in v:
                if isinstance(value, dict):
                    print_dict(value, print_func, delimiter + 4)
                else:
                    # Only v[0] is known to be a dict; print other items
                    # verbatim instead of calling .items() on them.
                    print_func('{}{}'.format((delimiter + 4) * ' ', value))
        else:
            print_func('{}{} : {}'.format(indent, k, v))
class Config(object):
    """YAML configuration with ``_BASE_`` inheritance.

    A config file may list parent files under ``BASE_KEY``. Those parents are
    loaded first (recursively) and merged in order, then the file's own keys
    are merged on top, so values in the including file win.
    """

    def __init__(self, config_path, BASE_KEY='_BASE_'):
        self.BASE_KEY = BASE_KEY
        self.cfg = self._load_config_with_base(config_path)

    @staticmethod
    def _resolve_base_paths(file_path, base_ymls):
        """Return the ``BASE_KEY`` entry of *file_path* as a list of paths.

        Accepts a single path string, a list of paths, or ``None``. ``~`` is
        expanded, and relative paths are resolved against the directory that
        contains *file_path*.
        """
        if base_ymls is None:
            return []
        if isinstance(base_ymls, str):
            # list('a.yml') would iterate characters; wrap the single path.
            base_ymls = [base_ymls]
        config_dir = os.path.dirname(file_path)
        paths = []
        for base_yml in base_ymls:
            base_yml = os.path.expanduser(base_yml)
            if not os.path.isabs(base_yml):
                base_yml = os.path.join(config_dir, base_yml)
            paths.append(base_yml)
        return paths

    def _load_config_with_base(self, file_path):
        """Load a YAML config file, resolving its ``BASE_KEY`` parents.

        Args:
            file_path (str): Path of the ``.yml``/``.yaml`` file to load.

        Returns:
            The merged config dict (an ``AttrDict`` when the file has
            parents), with a ``filename`` key holding the file's stem.
        """
        _, ext = os.path.splitext(file_path)
        assert ext.lower() in ['.yml', '.yaml'], 'only support yaml files for now'
        with open(file_path, encoding='utf-8') as f:
            # NOTE(review): yaml.Loader can instantiate arbitrary Python
            # objects; only load config files from trusted sources.
            file_cfg = yaml.load(f, Loader=yaml.Loader)
        if file_cfg is None:
            # An empty YAML document loads as None.
            file_cfg = {}
        # NOTE: cfgs outside have higher priority than cfgs in _BASE_
        if self.BASE_KEY in file_cfg:
            all_base_cfg = AttrDict()
            base_paths = self._resolve_base_paths(file_path,
                                                  file_cfg[self.BASE_KEY])
            for base_yml in base_paths:
                base_cfg = self._load_config_with_base(base_yml)
                all_base_cfg = _merge_dict(all_base_cfg, base_cfg)
            del file_cfg[self.BASE_KEY]
            file_cfg = _merge_dict(all_base_cfg, file_cfg)
        file_cfg['filename'] = os.path.splitext(
            os.path.basename(file_path))[0]
        return file_cfg

    def merge_dict(self, args):
        """Merge *args* (e.g. parsed ``-o`` overrides) into ``self.cfg``."""
        self.cfg = _merge_dict(self.cfg, args)

    def print_cfg(self, print_func=print):
        """Print the config between banner lines using ``print_dict``."""
        print_func('----------- Config -----------')
        print_dict(self.cfg, print_func)
        print_func('---------------------------------------------')

    def save(self, p, cfg=None):
        """Write *cfg* (default ``self.cfg``) to path *p* as YAML."""
        if cfg is None:
            cfg = self.cfg
        with open(p, 'w', encoding='utf-8') as f:
            yaml.dump(dict(cfg), f, default_flow_style=False, sort_keys=False)