Coverage for src/appl/core/config.py: 65%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-22 15:39 -0800

1import os 

2from typing import Any 

3 

4import addict 

5import yaml 

6from loguru import logger 

7 

8from .io import get_ext, load_file 

9 

# Absolute directory containing this module; package resources are located relative to it.
DIR = os.path.dirname(os.path.abspath(__file__))
# The shipped default configuration file, located one directory above this module.
DEFAULT_CONFIG_FILE = os.path.join(DIR, os.pardir, "default_configs.yaml")

12 

13 

class Configs(addict.Dict):
    """A Dictionary class that allows for dot notation access to nested dictionaries.

    ``__missing__`` is overridden so that looking up an absent key raises
    ``KeyError``.
    """

    def getattrs(self, key: str, default: Any = None) -> Any:
        """Get a value from a nested dictionary using a dot-separated key string.

        Args:
            key: Dot-separated path such as ``"a.b.c"``.
            default: Value returned (with a logged warning) when the path is
                missing. ``None`` means "no explicit default". In that case the
                same path is looked up in ``DEFAULT_CONFIGS`` as a fallback.

        Returns:
            The value stored at ``key``, or the explicit/fallback default.

        Raises:
            KeyError: If the path is missing and no non-``None`` default is found.
            AttributeError: If an intermediate value is not a mapping and has
                no attribute of the requested name.
        """
        keys = key.split(".")  # a key without dots yields [key]
        prefix = "."  # portion of the path resolved so far, used in messages
        node: Any = self
        try:
            for k in keys:
                node = self._step(node, k)
                prefix += k + "."
            return node
        except KeyError as e:
            msg = f"{e} not found in prefix '{prefix}'"

            if default is None:  # check if key exists in default configs
                default = self._fallback_default(keys)

            if default is not None:
                logger.warning(f"{msg}, using default: {default}")
                return default
            logger.error(msg)
            raise

    @staticmethod
    def _step(node: Any, key: str) -> Any:
        """Descend one level: item lookup for mappings, attribute access otherwise."""
        # Item lookup (not getattr) so that config keys named like dict methods
        # ("items", "keys", "update", ...) return the stored value instead of
        # the bound method of the same name.
        if isinstance(node, dict):
            return node[key]
        return getattr(node, key)

    def _fallback_default(self, keys: list) -> Any:
        """Return the value at ``keys`` in DEFAULT_CONFIGS, or None if unavailable."""
        try:
            defaults = DEFAULT_CONFIGS
        except NameError:  # module-level defaults not loaded yet
            return None
        if defaults is self:
            # Falling back from DEFAULT_CONFIGS to itself would recurse forever.
            return None
        node: Any = defaults
        try:
            for k in keys:
                node = self._step(node, k)
        except (KeyError, AttributeError, TypeError):
            # The fallback is best-effort: a missing default is reported by the caller.
            return None
        return node

    def to_yaml(self) -> str:
        """Convert the Configs object to a YAML string."""
        return yaml.dump(self.to_dict())

    def __missing__(self, key: str) -> None:
        # Raise on absent keys instead of returning a value for them.
        raise KeyError(key)

52 

53 

def load_config(file: str, *args: Any, **kwargs: Any) -> Configs:
    """Load a JSON, YAML or TOML config file into a ``Configs`` object.

    Args:
        file: Path to the config file. The extension selects whether it is accepted.
        *args: Forwarded unchanged to ``load_file``.
        **kwargs: Forwarded unchanged to ``load_file``.

    Returns:
        The loaded content wrapped in ``Configs``.

    Raises:
        ValueError: If the file extension is not one of the supported config types.
    """
    extension = get_ext(file)
    supported = (".json", ".yaml", ".yml", ".toml")
    if extension in supported:
        return Configs(load_file(file, *args, **kwargs))
    raise ValueError(f"Unsupported config file type {extension}")

61 

62 

# Loaded once at import time from DEFAULT_CONFIG_FILE. Also used by
# Configs.getattrs as the fallback source for missing keys.
DEFAULT_CONFIGS = load_config(DEFAULT_CONFIG_FILE)
"""The static default configs loaded from the default config file."""

# Module-level singleton: an independent deep copy of the defaults, so changes
# made to it do not touch DEFAULT_CONFIGS.
configs = DEFAULT_CONFIGS.deepcopy()
"""The global configs"""