from . import types
import abc
import logging
import functools
# Module-level logger named after this module (not referenced by the code
# visible in this section).
logger = logging.getLogger(__name__)
def cache(function=None, *, types=None):
    '''
    Memoize a visitor method on ``visitor.cache``, keyed by ``id(node)``.

    This decorator stores the result of a visitor method in
    ``visitor.cache`` and returns the stored result when the method is
    called again with the same node object (same ``id``). This
    simultaneously avoids redundant computation and ensures that shared
    references in the original tree result in shared references in the new
    tree. If this is not desired (for example, if the result for a subtree
    depends on the nodes above it), please implement a custom visit_*
    method WITHOUT the @cache decorator.

    The decorator can be used in two ways:

    1) ``@cache``
       Cache all results (used almost all of the time).
    2) ``@cache(types=[...])``
       Cache only incoming nodes that are instances of ``types``; any other
       node is passed straight to the wrapped method (mostly used by
       generic_visit). ``types`` may be a single class, a tuple, or a
       list/set of classes.

    Args:
        function: The method being decorated when used as bare ``@cache``;
            ``None`` when used in the parameterized form ``@cache(...)``.
        types: Optional class or collection of classes restricting which
            nodes are cached. ``None`` or an empty collection caches every
            node.

    Returns:
        The wrapped method (bare form) or a decorator that wraps a method
        (parameterized form).

    Note:
        The object the method is bound to must expose a ``cache`` mapping.
        Keys are ``id(node)``; Python may reuse an id once its object is
        garbage collected, so cached nodes should stay alive for as long as
        the cache is in use.
    '''
    # Inside this function `types` is the parameter, not the module-level
    # `types` import. isinstance() only accepts a class or a tuple of
    # classes, so normalize the documented list form (and sets) to a tuple.
    if isinstance(types, (list, set, frozenset)):
        cached_types = tuple(types)
    else:
        cached_types = types

    def decorator(f):
        @functools.wraps(f)
        def cached_method(visitor, node):
            # Nodes outside the requested types bypass the cache entirely.
            if cached_types and not isinstance(node, cached_types):
                return f(visitor, node)
            key = id(node)
            try:
                return visitor.cache[key]
            except KeyError:
                pass
            newnode = f(visitor, node)
            visitor.cache[key] = newnode
            return newnode
        return cached_method

    if function is not None:
        return decorator(function)
    return decorator
class Visitor(object):
    """
    The Visitor base class walks the ALIGN specification tree and calls a
    visitor function for every node found. This is very similar to the
    `NodeVisitor` class implemented by the python internal `ast` module (except
    that it operates on types.BaseModel subclasses).

    This class is meant to be subclassed, with the subclass adding visitor
    methods. The visitor functions for the nodes are ``'visit_'`` + the
    class name of the node. So a `SubCircuit` node visit function would
    be `visit_SubCircuit`. If no visitor function exists for a node the
    `generic_visit` visitor is used instead.

    Don't use the `Visitor` if you want to apply changes to nodes during
    traversing. For this a special visitor exists (`NodeTransformer`) that
    allows modifications.

    Usually you use the Visitor like this::

        result = YourVisitor().visit(node)

    Where the type of result is determined by the return type of the
    root node visitor. Note that generic_visit attempts to return
    either a list or None for most nodes.
    """

    def __init__(self):
        # id(node) -> result, filled by methods decorated with @cache.
        self.cache = {}

    def visit(self, node):
        """Dispatch ``node`` to ``visit_<ClassName>`` or ``generic_visit``.

        The handler is looked up by the node's class name, so e.g. a
        ``bool`` (an ``int`` subclass) is routed to ``visit_bool`` if defined.

        Raises:
            NotImplementedError: if ``node`` is not a BaseModel, List, Dict,
                list, dict, str, int or None.
        """
        if not isinstance(node, (types.BaseModel, types.List, types.Dict, list, dict, str, int, type(None))):
            raise NotImplementedError(f'{self.__class__.__name__}.visit() does not support node of type {node.__class__.__name__}:\n{node}')
        method = 'visit_' + node.__class__.__name__
        return getattr(self, method, self.generic_visit)(node)

    @staticmethod
    def iter_fields(node):
        """Yield ``(field_name, value)`` for each field in ``node.__fields__``.

        Fields whose attribute lookup raises AttributeError (declared but not
        set on this instance) are skipped.
        """
        for field in node.__fields__.keys():
            try:
                value = getattr(node, field)
            except AttributeError:
                continue
            # Yield outside the try: a bare except around the yield would
            # swallow GeneratorExit when the generator is closed early.
            yield field, value

    @staticmethod
    def flatten(l):
        """Return a list with list items spliced in one level and Nones dropped.

        Only direct ``list`` items are expanded; deeper nesting is kept.
        """
        ret = []
        for item in l:
            if isinstance(item, list):
                ret.extend(item)
            elif item is not None:
                ret.append(item)
        return ret

    @cache(types=(types.BaseModel, types.List, types.Dict))
    def generic_visit(self, node):
        """Default visitor: visit every child and flatten the results.

        Models are traversed via ``iter_fields``, sequences element-wise and
        mappings by value. Leaf values (str, int, None) produce None. Results
        for BaseModel/List/Dict nodes are memoized by node identity.

        Raises:
            NotImplementedError: for any other node type.
        """
        if isinstance(node, types.BaseModel):
            return self.flatten(self.visit(v) for _, v in self.iter_fields(node))
        elif isinstance(node, (types.List, list)):
            return self.flatten(self.visit(v) for v in node)
        elif isinstance(node, (types.Dict, dict)):
            return self.flatten(self.visit(v) for _, v in node.items())
        elif isinstance(node, (str, int, type(None))):
            return None
        else:
            raise NotImplementedError(
                f'{self.__class__.__name__}.generic_visit() does not support node of type {node.__class__.__name__}:\n{node}')