# Source code for rascal.utils.io

import os
import importlib
from collections.abc import Iterable
import numpy as np
import json
from copy import deepcopy
from abc import ABC, abstractmethod

# Serialization format versions understood by this module. BETA_VERSION is
# the only format defined so far; CURRENT_VERSION selects the default format
# used when writing objects out.
BETA_VERSION = "0.1"

CURRENT_VERSION = BETA_VERSION

# to_dict() refuses to serialize objects nested deeper than this, as a guard
# against circular references (see the ValueError raised in to_dict()).
MAX_RECURSION_DEPTH = 20


def dump_obj(fn, instance, version=CURRENT_VERSION):
    """Save a python object that inherits from the BaseIO class

    Parameters
    ----------
    fn : string
        path to save instance
    instance : class
        python object that inherits from the BaseIO class
    version : string, optional
        serialization version to use, by default CURRENT_VERSION

    Raises
    ------
    RuntimeError
        When instance does not inherit from BaseIO
    """
    # Guard clause: only BaseIO subclasses know how to serialize themselves.
    if not isinstance(instance, BaseIO):
        raise RuntimeError(
            "The instance does not inherit from BaseIO: {}".format(
                instance.__class__.__mro__
            )
        )
    to_file(fn, instance, version)
def load_obj(fn):
    """Load a python object from a file

    Parameters
    ----------
    fn : string
        path to the file describing the saved object

    Returns
    -------
    python class that inherits from BaseIO
    """
    # Thin convenience wrapper around the generic deserialization entry point.
    return from_file(fn)
def dump_json(fn, data):
    """Utility to save a python object to a file.

    Parameters
    ----------
    fn : string
        filename to save data
    data : a json serializable python object
    """
    with open(fn, "w") as f:
        json.dump(data, f, sort_keys=True, indent=2)


def load_json(fn):
    """Utility to load a python object saved in the json format

    Parameters
    ----------
    fn : string
        filename

    Returns
    -------
    loaded python object from fn
    """

    def _decode(o):
        # JSON has no integer keys, so they were written out as strings.
        # This hook turns every dictionary key that parses as an integer
        # back into an 'int', restoring the original python structure.
        if isinstance(o, dict):
            return {_decode(k): v for k, v in o.items()}
        if isinstance(o, str):
            try:
                return int(o)
            except ValueError:
                return o
        return o

    with open(fn, "r") as f:
        return json.load(f, object_hook=_decode)


def is_npy(data):
    """is data a numpy array ?"""
    return isinstance(data, np.ndarray)


def is_large_array(data):
    """is data a numpy array larger than 50MB ?"""
    return is_npy(data) and data.nbytes > 50e6


def is_npy_filename(fn):
    """does fn string corresponds to a saved numpy array ?"""
    if not isinstance(fn, str):
        return False
    _, extension = os.path.splitext(fn)
    return extension == ".npy"


def get_class(module_name, class_name):
    """Use module_name and class_name to make an instantiable class."""
    return getattr(importlib.import_module(module_name), class_name)


def obj2dict_beta(cls, state):
    """Take a python object cls with its state and return a dictionary that
    can be used to create a copy of this object.

    Parameters
    ----------
    cls : object
    state : dictionary
        Contains the state of cls, i.e. the parameters used to initialize
        cls in a 'init_params' field and the rest of the data needed to
        recover the current state in a 'data' field.

    Returns
    -------
    dictionary
        fully serialized version of cls to a dictionary
    """
    return dict(
        version=BETA_VERSION,
        class_name=cls.__name__,
        module_name=cls.__module__,
        init_params=state["init_params"],
        data=state["data"],
    )


def dict2obj_beta(data):
    """Take data, a dictionary created by the obj2dict function, and creates
    a python object as described.

    Parameters
    ----------
    data : dictionary
        produced by obj2dict_beta

    Returns
    -------
    deserialized python object described by data
    """
    klass = get_class(data["module_name"], data["class_name"])
    instance = klass(**data["init_params"])
    instance._set_data(data["data"])
    return instance


def is_valid_object_dict_beta(data):
    """check compatibility of data to be used in dict2obj_beta"""
    required_keys = ("version", "class_name", "module_name", "init_params", "data")
    if not isinstance(data, dict):
        return False
    return all(key in data for key in required_keys)


# Dispatch tables mapping a serialization version string to the matching
# (de)serialization helpers.
obj2dict = {BETA_VERSION: obj2dict_beta}

dict2obj = {BETA_VERSION: dict2obj_beta}

is_valid_object_dict = {BETA_VERSION: is_valid_object_dict_beta}


def get_current_io_version():
    """Return the serialization version used by default when writing."""
    return CURRENT_VERSION


def get_supported_io_versions():
    """Return the list of serialization versions that can be read."""
    return list(dict2obj.keys())
class BaseIO(ABC):
    """Interface of a Python class serializable by to_dict()

    It corresponds to 3 methods:

    + _get_init_params is expected to return a dictionary containing all the
      parameters used by the __init__() methods.

    + _get_data is expected to return a dictionary containing all the data
      that is not set by the initialization of the class.

    + _set_data is expected to set the data that has been extracted by
      _get_data

    The underlying c++ objects are not pickle-able so deepcopy does not work
    out of the box. This class provides an override of the __deepcopy__()
    function so that classes that inherit from this base class can be
    deepcopied.
    """

    @abstractmethod
    def _get_data(self):
        return dict()

    @abstractmethod
    def _set_data(self, data):
        pass

    @abstractmethod
    def _get_init_params(self):
        return dict()

    def __deepcopy__(self, memo=None):
        """Overrides deepcopy default behaviour with custom serialization
        instead of using pickle."""
        # Round-trip through the dict representation instead of pickling.
        return from_dict(to_dict(self))

    def __setstate__(self, state):
        """Overrides default pickling behaviour passing through the dict
        representation."""
        restored = from_dict(state)
        self.__dict__.update(restored.__dict__)

    def __getstate__(self):
        """Overrides default pickling behaviour passing through the dict
        representation."""
        return to_dict(self)
def _get_state(obj):
    """Extract the serializable state of a BaseIO object.

    Parameters
    ----------
    obj : BaseIO

    Returns
    -------
    dictionary
        with a 'data' field (from obj._get_data()) and an 'init_params'
        field (from obj._get_init_params())

    Raises
    ------
    ValueError
        If obj does not inherit from BaseIO.
    """
    if isinstance(obj, BaseIO):
        state = dict(data=obj._get_data(), init_params=obj._get_init_params())
    else:
        raise ValueError(
            'input object: "{}" does not inherit from "BaseIO"'.format(obj)
        )
    return state


def to_dict(obj, version=CURRENT_VERSION, recursion_depth=0):
    """Recursively serialize to dict via the BaseIO interface.

    obj has to inherit from BaseIO.

    Parameters
    ----------
    obj : BaseIO
    version : string, optional
        serialization version to use, by default CURRENT_VERSION
    recursion_depth : int, optional
        current nesting level, used internally to detect circular references

    Returns
    -------
    dictionary
        fully serialized version of obj

    Raises
    ------
    ValueError
        If nesting exceeds MAX_RECURSION_DEPTH (suggesting a circular
        reference) or obj does not inherit from BaseIO.
    """
    if recursion_depth >= MAX_RECURSION_DEPTH:
        raise ValueError(
            "The object to be serialized to dict contains more than {}".format(
                MAX_RECURSION_DEPTH
            )
            + " levels of nested objects suggesting there is a circular reference."
            + " Objects containing a reference to themselves are not supported."
        )
    else:
        recursion_depth += 1

    state = _get_state(obj)
    # loop over the 2 fields of state ('data' and 'init_params')
    for name, entry in state.items():
        if isinstance(entry, dict):
            # case of potentially nested objects
            for k, v in entry.items():
                if isinstance(v, BaseIO):
                    state[name][k] = to_dict(v, version, recursion_depth)
                elif isinstance(v, list):
                    # make sure lists of objects are properly serialized
                    ll = []
                    for val in v:
                        if isinstance(val, BaseIO):
                            ll.append(to_dict(val, version, recursion_depth))
                        else:
                            ll.append(val)
                    state[name][k] = ll
    data = obj2dict[version](obj.__class__, state)
    return data


def from_dict(data):
    """Recursively deserialize from dict via the BaseIO interface.

    Parameters
    ----------
    data : dictionary
        produced by to_dict()

    Returns
    -------
    deserialized python object described by data
    """
    # temporary dictionary to hold the object being recovered
    data_obj = dict()
    version = data["version"]
    for name, entry in data.items():
        if isinstance(entry, dict):
            data_obj[name] = dict()
            for k, v in entry.items():
                if is_valid_object_dict[version](v):
                    # in case of nested objects
                    data_obj[name][k] = from_dict(v)
                elif isinstance(v, list):
                    # in case of list make sure to handle list of serialized
                    # objects
                    ll = []
                    for val in v:
                        if is_valid_object_dict[version](val):
                            ll.append(from_dict(val))
                        else:
                            ll.append(val)
                    data_obj[name][k] = ll
                else:
                    # just transfer the data
                    data_obj[name][k] = v
        else:
            # just transfer the data
            data_obj[name] = entry
    obj = dict2obj[version](data_obj)
    return obj


def to_file(fn, obj, version=CURRENT_VERSION):
    """Saves the object 'obj' to a file named 'fn'.

    It uses the to_dict() serialization procedure. Only the '.json'
    extension is supported; large numpy arrays are written to side-car
    '*.npy' files next to fn (see _dump_npy).

    Raises
    ------
    NotImplementedError
        If fn does not end in '.json'.
    """
    fn = os.path.abspath(fn)
    filename, file_extension = os.path.splitext(fn)
    data = to_dict(obj, version=version)
    class_name = data["class_name"].lower()
    if file_extension == ".json":
        _dump_npy(fn, data, class_name)
        dump_json(fn, data)
    else:
        # fixed typo in the error message ('extention' -> 'extension')
        raise NotImplementedError("Unknown file extension: {}".format(file_extension))


def from_file(fn):
    """Loads an object that was saved using to_file() from a file

    Raises
    ------
    RuntimeError
        If the file does not contain a valid object dictionary.
    NotImplementedError
        If fn does not end in '.json'.
    """
    fn = os.path.abspath(fn)
    path = os.path.dirname(fn)
    filename, file_extension = os.path.splitext(fn)
    if file_extension == ".json":
        data = load_json(fn)
        version = data["version"]
        if is_valid_object_dict[version](data):
            # resolve side-car numpy files before rebuilding the object
            _load_npy(data, path)
            return from_dict(data)
        else:
            raise RuntimeError(
                "The file: {}; does not contain a valid dictionary".format(fn)
                + " representation of an object."
            )
    else:
        # fixed typo in the error message ('extention' -> 'extension')
        raise NotImplementedError("Unknown file extension: {}".format(file_extension))


def _dump_npy(fn, data, class_name):
    """Saves numpy array to the object file, modifying data in place.

    If the array is large (>50MB) the main file contains a relative path to
    the *.npy file so that it can be loaded properly. Small numpy arrays are
    converted to lists and saved in the main file."""
    filename, file_extension = os.path.splitext(fn)
    for k, v in data.items():
        if isinstance(v, dict):
            # nested serialized object: recurse, updating the class name
            # used to build the side-car filenames
            if "class_name" in data:
                class_name = data["class_name"].lower()
            _dump_npy(fn, v, class_name)
        elif is_large_array(v):
            # large arrays go out-of-band; keep only the relative filename
            if "tag" in data:
                class_name += "-" + data["tag"]
            v_fn = filename + "-{}-{}".format(class_name, k) + ".npy"
            v_bfn = os.path.basename(v_fn)
            data[k] = v_bfn
            np.save(v_fn, v)
        elif is_npy(v):
            # small arrays are inlined as a tagged ['npy', nested-list] pair
            data[k] = ["npy", v.tolist()]


def _load_npy(data, path):
    """Loads a numpy array saved using _dump_npy(), modifying data in place.

    A large array stored in a different file is mmaped so it is physically
    loaded only when needed."""
    for k, v in data.items():
        if isinstance(v, dict):
            _load_npy(v, path)
        elif is_npy_filename(v):
            # side-car file: memory-map instead of loading eagerly
            data[k] = np.load(os.path.join(path, v), mmap_mode="r")
        elif isinstance(v, list):
            if len(v) == 2:
                if "npy" == v[0]:
                    # inlined small array tagged by _dump_npy
                    data[k] = np.array(v[1])