# Public API of this module: the ALIGN model/container types defined below,
# plus names re-exported unchanged from ``typing`` and ``pydantic``.
__all__ = [
    # ALIGN types
    'BaseModel',
    'List',
    'Dict',
    # pydantic re-exports
    'validator',
    'root_validator',
    'validate_arguments',
    # typing re-exports
    'Optional',
    'Union',
    'NamedTuple',
    'Literal',
    'ClassVar',
    # pydantic re-export
    'PrivateAttr',
]
# Pass through directly from typing
from typing import \
Optional, \
Union, \
NamedTuple, \
ClassVar
try:
    # Python 3.8+
    from typing import Literal
except ImportError:
    # Python 3.7 backport. Catch only ImportError so that KeyboardInterrupt,
    # SystemExit and genuine errors inside typing are not silently masked.
    from typing_extensions import Literal
# Pass through directly from pydantic
from pydantic import \
validator, \
root_validator, \
validate_arguments, \
PrivateAttr
# Custom ALIGN types (BaseModel, List, Dict) defined below
import pydantic.generics
import typing
import collections
import random
import string
import contextvars
import contextlib
# Holds the object currently being constructed (None when nothing is).
# Model constructors read it to record their parent, then install themselves
# here while their own fields are validated.
_ctx = contextvars.ContextVar('current_constructor', default=None)


@contextlib.contextmanager
def set_context(obj):
    """Make *obj* the value of ``_ctx`` for the duration of a ``with`` block.

    The previous value is restored when the block exits, whether normally or
    through an exception, so nested uses unwind correctly.
    """
    reset_token = _ctx.set(obj)
    try:
        yield
    finally:
        # Restore exactly the value that was active before this block.
        _ctx.reset(reset_token)
def cast_to_solver(item, solver):
    """Yield the solver-annotated expressions produced by ``item.translate``.

    Objects without a ``translate`` attribute contribute nothing: the
    generator simply finishes without yielding. Otherwise
    ``item.translate(solver)`` is consumed in full, and the resulting
    expressions are passed to ``solver.annotate`` together with
    ``solver.label(item)``. Everything ``annotate`` produces is then yielded.

    Raises:
        NotImplementedError: if ``translate`` returns ``None`` or produces no
            expressions at all.
    """
    if not hasattr(item, 'translate'):
        return
    produced = item.translate(solver)
    if produced is None:
        raise NotImplementedError(f'{item}.translate() did not return a valid generator')
    assert solver is not None
    expressions = [*produced]
    if not expressions:
        raise NotImplementedError(f'{item}.translate() yielded an empty list of expressions')
    yield from solver.annotate(expressions, solver.label(item))
class BaseModel(pydantic.BaseModel):
    """pydantic model whose fields cannot be reassigned and which records its parent.

    On construction the model stores the current value of ``_ctx`` as its
    parent. It then installs itself in ``_ctx`` while pydantic validates its
    fields, so any List/Dict/BaseModel created during that validation sees
    this instance as its parent.
    """

    @property
    def parent(self):
        """Object that was being constructed when this one was created, or ``None``."""
        return self._parent

    class Config:
        validate_assignment = True
        extra = 'forbid'
        allow_mutation = False
        copy_on_model_validation = False

    def __init__(self, *args, **kwargs):
        self._parent = _ctx.get()
        with set_context(self):
            super().__init__(*args, **kwargs)

    def copy(self, include=None, exclude=None, update=None):
        """Return a new, re-validated instance built from this model's data.

        Only fields that were explicitly set and differ from their defaults
        are carried over. The copy is built via ``self.__class__(**fields)``,
        so it goes through full construction and validation.

        Args:
            include: forwarded to ``self.dict`` to select the fields to keep.
            exclude: forwarded to ``self.dict`` to select the fields to drop.
            update: optional mapping of field overrides. Models nested inside
                it (BaseModel, List, Dict, also inside lists/dicts) are
                converted to plain data first, so they are validated like any
                other input.

        Returns:
            A new instance of ``type(self)``.
        """
        if update is None:
            update = {}

        def to_dict(val):
            """Recursively convert models inside *val* into plain Python data."""
            if isinstance(val, list):
                return [to_dict(v) for v in val]
            if isinstance(val, dict):
                return {to_dict(k): to_dict(v) for k, v in val.items()}
            if isinstance(val, BaseModel):
                return val.dict(
                    exclude_unset=True,
                    exclude_defaults=True)
            if isinstance(val, (List, Dict)):
                # Custom-root models serialise as {'__root__': value}. Unwrap it
                # so the value has the same shape self.dict() gives nested
                # List/Dict fields. pydantic re-wraps mapping-shaped root input
                # unconditionally, so a still-wrapped Dict would be wrapped twice.
                return val.dict(
                    exclude_unset=True,
                    exclude_defaults=True)['__root__']
            return val

        # Parent the copy under whatever is currently being constructed;
        # outside of any construction, keep this model's own parent.
        current = _ctx.get()
        ctx = self.parent if current is None else current
        fields = {
            **self.dict(
                exclude_unset=True,
                exclude_defaults=True,
                include=include,
                exclude=exclude),
            **{key: to_dict(value) for key, value in update.items()},
        }
        with set_context(ctx):
            return self.__class__(**fields)

    @classmethod
    def _validator_ctx(cls):
        """Return the object currently being constructed (the value held in ``_ctx``).

        While a BaseModel's fields are being validated, this is that model.

        Raises:
            AssertionError: if nothing is currently being constructed.
        """
        current = _ctx.get()
        assert current is not None, 'Could not retrieve ctx'
        return current

    # Set in __init__ from _ctx; no default value.
    _parent = pydantic.PrivateAttr()
# Type parameters for the generic List/Dict models (key type, element/value type).
KeyT, DataT = typing.TypeVar('KeyT'), typing.TypeVar('DataT')
class List(pydantic.generics.GenericModel, typing.Generic[DataT]):
    """Generic, list-like pydantic model wrapping a ``__root__`` sequence.

    Offers in-place list operations that act directly on ``__root__`` and
    therefore do not re-run pydantic validation. It also provides a
    checkpoint/revert mechanism that truncates the list back to a recorded
    length.
    """

    __root__: typing.Sequence[DataT]
    _commits = pydantic.PrivateAttr()  # checkpoint id -> len(self) when recorded
    _parent = pydantic.PrivateAttr()
    _cache = pydantic.PrivateAttr()  # initialised to set(); not used within this class

    @property
    def parent(self):
        """Object that was being constructed when this List was created, or ``None``."""
        return self._parent

    class Config:
        validate_assignment = True
        extra = 'forbid'
        copy_on_model_validation = False
        allow_mutation = False

    def append(self, item: DataT):
        """Append *item* to the underlying sequence in place."""
        self.__root__.append(item)

    def extend(self, items: "List[DataT]"):
        """Append every element of *items*, one ``append`` call at a time.

        *items* is snapshotted first. Otherwise ``lst.extend(lst)`` would
        iterate the list while appending to it and never terminate.
        """
        for item in list(items):
            self.append(item)

    def remove(self, item: DataT):
        """Remove the first occurrence of *item* (raises ValueError if absent)."""
        return self.__root__.remove(item)

    def pop(self, index=-1):
        """Remove and return the element at *index* (last by default)."""
        return self.__root__.pop(index)

    def clear(self):
        """Remove all elements."""
        self.__root__.clear()

    def __len__(self):
        return len(self.__root__)

    def __iter__(self):
        return iter(self.__root__)

    def __getitem__(self, item):
        return self.__root__[item]

    def __setitem__(self, item, value):
        self.__root__[item] = value

    def __delitem__(self, sliceobj):
        del self.__root__[sliceobj]

    def __eq__(self, other):
        return self.__root__ == other

    def __init__(self, *args, **kwargs):
        # Allow List(seq) / List() as shorthand for
        # List(__root__=seq) / List(__root__=[]).
        if '__root__' not in kwargs:
            if len(args) == 1:
                kwargs['__root__'] = args[0]
                args = tuple()
            elif len(args) == 0:
                kwargs['__root__'] = []
        self._parent = _ctx.get()
        with set_context(self):
            super().__init__(*args, **kwargs)
        self._commits = collections.OrderedDict()
        self._cache = set()

    def _gen_commit_id(self, nchar=8):
        """Return a random *nchar*-character id not already used as a checkpoint name."""
        alphabet = string.ascii_uppercase + string.digits
        while True:
            id_ = ''.join(random.choices(alphabet, k=nchar))
            if id_ not in self._commits:
                return id_

    def checkpoint(self):
        """Record the current length under a fresh id and return that id."""
        commit_id = self._gen_commit_id()
        self._commits[commit_id] = len(self)
        return commit_id

    def _revert(self):
        """Pop the newest checkpoint and delete elements beyond its recorded length."""
        _, length = self._commits.popitem()
        del self[length:]

    def revert(self, name=None):
        """Undo checkpoints, deleting elements beyond the recorded length.

        With ``name=None`` only the most recent checkpoint is reverted.
        Otherwise every checkpoint newer than *name* is discarded, then
        *name* itself is reverted and removed.

        Raises:
            AssertionError: if there are no checkpoints, or *name* is unknown.
                Both checks happen before anything is modified.
        """
        assert len(self._commits) > 0, 'Top of scope. Nothing to revert'
        if name is not None and name != next(reversed(self._commits)):
            assert name in self._commits
            # Discard newer checkpoints until *name* is the most recent one.
            while next(reversed(self._commits)) != name:
                self._revert()
        self._revert()

    def translate(self, solver):
        """Yield the solver-annotated expressions of every element, in order."""
        for item in self:
            yield from cast_to_solver(item, solver)
class Dict(pydantic.generics.GenericModel, typing.Generic[KeyT, DataT]):
    """Generic, dict-like pydantic model wrapping a ``__root__`` mapping.

    Mapping operations delegate directly to the wrapped dict, so in-place
    writes do not re-run pydantic validation.
    """

    __root__: typing.Dict[KeyT, DataT]
    _parent = pydantic.PrivateAttr()

    @property
    def parent(self):
        """Object that was being constructed when this Dict was created, or ``None``."""
        return self._parent

    class Config:
        validate_assignment = True
        extra = 'forbid'
        copy_on_model_validation = False
        allow_mutation = False

    def __init__(self, *args, **kwargs):
        # Allow Dict(mapping) / Dict() as shorthand for
        # Dict(__root__=mapping) / Dict(__root__={}).
        if '__root__' not in kwargs:
            if len(args) == 1:
                kwargs['__root__'] = args[0]
                args = tuple()
            elif len(args) == 0:
                kwargs['__root__'] = {}
        self._parent = _ctx.get()
        with set_context(self):
            super().__init__(*args, **kwargs)

    def items(self):
        return self.__root__.items()

    def keys(self):
        return self.__root__.keys()

    def values(self):
        return self.__root__.values()

    def get(self, key, default=None):
        """Return ``self[key]`` if *key* is present, else *default* (like ``dict.get``)."""
        return self.__root__.get(key, default)

    def __len__(self):
        return len(self.__root__)

    def __iter__(self):
        # Iterate over keys like a dict. Without this override, pydantic's
        # BaseModel.__iter__ is inherited and yields ('__root__', {...})
        # field pairs instead.
        return iter(self.__root__)

    def __getitem__(self, item):
        return self.__root__[item]

    def __setitem__(self, item, value):
        self.__root__[item] = value

    def __delitem__(self, item):
        del self.__root__[item]

    def __eq__(self, other):
        return self.__root__ == other

    def __contains__(self, v):
        return self.__root__.__contains__(v)