# Source code for pyrosetta_documentarian.xml

import pyrosetta, os
from collections import defaultdict
from typing import List, Optional, Union
from warnings import warn
from bs4 import BeautifulSoup  # beautifulsoup4 lxml
import json
import pkg_resources

from .base import BaseDocumentarian

class XMLDocumentarian(BaseDocumentarian):
    """
    Finds which XML scripts in the Rosetta integration tests use a given mover
    and loads movers from those scripts.

    The scripts are looked for in ``test_folder`` inside ``self.rosetta_folder``.
    ``self.target`` and ``self.rosetta_folder`` are not defined in this class
    (presumably they come from ``BaseDocumentarian`` -- confirm there).

    The three underscore attributes below are class-level caches.
    """
    test_folder = 'main/tests/integration/tests'
    # cached properties:
    _xmlfilenames = [] #: all the xml files in the test folder
    _xmlprotocols = {} #: the loaded protocol of each of the filenames (fails!)
    _mover_directory = defaultdict(list)  #: all the filenames that have a given mover.

    def get_relevant_scripts(self) -> List[str]:
        """
        Get the paths of the XML test scripts that use the target mover (``self.target``).

        :return: list of file paths, empty if the mover name is not in ``mover_directory``.
        """
        return self.mover_directory.get(self.target.get_name(), [])

    def get_mover_from_script(self,
                              xmlfilename: str,
                              target: Optional[Union[pyrosetta.rosetta.protocols.moves.Mover, str]]=None
                              ) -> pyrosetta.rosetta.protocols.moves.Mover:
        """
        Get the mover from the script ``xmlfilename`` of the same type as ``target`` (mover or mover name)
        If omitted ``xml_documentarian.target`` is used.

        :param xmlfilename: path of the XML script to load.
        :param target: a mover instance or a mover name. Movers are matched on ``get_name()``.
        :return: the first mover in the script whose name matches.
        :raises ValueError: if no mover in the script has the requested name.
        """
        if target is None:
            target = self.target
        # Resolve the name once: a str target has no ``get_name`` method.
        target_name = target if isinstance(target, str) else target.get_name()
        protocol = self.load_xmlfilename(xmlfilename)
        for mover in self.get_movers_in_protocol(protocol):
            if mover.get_name() == target_name:
                return mover
        raise ValueError(f'Could not find mover {target_name} in {xmlfilename}')

    @property  # cached
    def mover_directory(self):
        """
        Dictionary that associates a tag name found in the ``MOVERS`` section
        with the list of XML files that contain it.

        It is built from ``xmlfilenames`` on first access (unless already filled,
        e.g. by ``fill``) and is what ``get_relevant_scripts`` reads.
        """
        if self._mover_directory:
            return self._mover_directory
        # Build into a local mapping first, so that an error while parsing a file
        # does not leave a half-filled cache that would be taken for complete.
        directory = defaultdict(list)
        for filename in self.xmlfilenames:
            # dict.fromkeys drops repeated names (keeping order): each file is listed once per name.
            for movername in dict.fromkeys(self.get_movernames_from_xmlfilename(filename)):
                directory[movername].append(filename)
        # Update in place: the cache object may be a plain dict (see ``fill``).
        self._mover_directory.update(directory)
        return self._mover_directory

    def _get_root_of_xml(self, soup: BeautifulSoup): # -> bs4.element.Tag:
        """
        Get the element under which the ``MOVERS`` and ``FILTERS`` sections are looked for.

        :param soup: the parsed XML file.
        :return: the ``ROSETTASCRIPTS`` tag, else the ``Common`` tag of ``JobDefinitionFile``,
            else (with a warning) an empty placeholder document.
        """
        if soup.ROSETTASCRIPTS is not None:
            return soup.ROSETTASCRIPTS
        job_definition = soup.JobDefinitionFile
        # A JobDefinitionFile without a Common tag is handled as non-standard instead of returning None.
        if job_definition is not None and job_definition.Common is not None:
            return job_definition.Common
        warn(f'This is not standard: {soup}')
        return BeautifulSoup('<main><MOVERS></MOVERS><FILTERS></FILTERS></main>', 'xml')

    def _get_tagnames_from_xmlfilename(self, xmlfilename: str, section: str) -> List[str]:
        """
        Get the names of all the tags within the ``section`` tag (``MOVERS`` or ``FILTERS``) of a file.

        :param xmlfilename: path of the XML file.
        :param section: name of the section tag.
        :return: tag names, empty if the section is absent.
        """
        # The context manager closes the file handle once parsed.
        with open(xmlfilename) as handle:
            soup = BeautifulSoup(handle, 'xml')
        section_tag = self._get_root_of_xml(soup).find(section)
        if section_tag is None:
            return []
        # NOTE(review): find_all() returns every descendant, so tags nested inside
        # a mover/filter definition are listed too -- confirm this is intended.
        return [tag.name for tag in section_tag.find_all()]

    def get_movernames_from_xmlfilename(self, xmlfilename: str) -> List[str]:
        """
        Get the names of the tags in the ``MOVERS`` section of the given XML file.

        :param xmlfilename: path of the XML file.
        :return: tag names, empty if there is no ``MOVERS`` section.
        """
        return self._get_tagnames_from_xmlfilename(xmlfilename, 'MOVERS')

    def get_filternames_from_xmlfilename(self, xmlfilename: str) -> List[str]:
        """
        Get the names of the tags in the ``FILTERS`` section of the given XML file.

        :param xmlfilename: path of the XML file.
        :return: tag names, empty if there is no ``FILTERS`` section.
        """
        return self._get_tagnames_from_xmlfilename(xmlfilename, 'FILTERS')

    @classmethod
    def _find_xml_files(cls, path: str) -> List[str]:
        """
        Recursively collect the paths ending in ``.xml`` under ``path``.

        :param path: a folder or a file.
        :return: list of paths (entries of each folder are visited in sorted order).
        """
        if os.path.isdir(path):
            found = []
            # os.listdir order is arbitrary: sort for a reproducible result.
            for entry in sorted(os.listdir(path)):
                found.extend(cls._find_xml_files(os.path.join(path, entry)))
            return found
        # Test the extension, not the whole path: a folder name containing '.xml' must not match.
        return [path] if path.endswith('.xml') else []

    @property  # cached
    def xmlfilenames(self):
        """
        Paths of all the XML files under ``test_folder`` within ``self.rosetta_folder``.
        """
        if self._xmlfilenames:
            return self._xmlfilenames
        self._xmlfilenames = self._find_xml_files(os.path.join(self.rosetta_folder, self.test_folder))
        return self._xmlfilenames

    @property  # cached
    def xmlprotocols(self):
        """
        Dictionary of filename to the protocol returned by ``load_xmlfilename``.

        Files that raise on loading are skipped with a warning.
        A warning that this will not work is always issued.
        """
        warn('This will not work!')
        if self._xmlprotocols:
            return self._xmlprotocols
        for xmlfilename in self.xmlfilenames:
            try:
                self._xmlprotocols[xmlfilename] = self.load_xmlfilename(xmlfilename)
            except Exception as error:
                warn(f'{xmlfilename} could not be read due to {error.__class__.__name__}: {error}')
        return self._xmlprotocols

    def load_xmlfilename(self, xmlfilename: str):
        """
        Parse the XML script with a ``RosettaScriptsParser`` and a blank pose.

        :param xmlfilename: path of the XML script.
        :return: whatever ``generate_mover_and_apply_to_pose`` returns for a new ``Pose``.
        """
        pose = pyrosetta.Pose()
        xml_obj = pyrosetta.rosetta.protocols.rosetta_scripts.RosettaScriptsParser()
        protocol = xml_obj.generate_mover_and_apply_to_pose(pose, xmlfilename)
        return protocol

    def get_movers_in_protocol(self, protocol: pyrosetta.rosetta.protocols.rosetta_scripts.ParsedProtocol
                               ) -> List[pyrosetta.rosetta.protocols.moves.Mover]:
        """
        Get the movers of a protocol, including the movers wrapped by other movers.

        :param protocol: the protocol, as returned by ``load_xmlfilename``.
        :return: list of movers; the ones obtained with ``protocol.get_mover`` come first,
            followed by those obtained by calling ``.mover()`` on a mover that has it.
        """
        # I could not figure out how to find out how many movers are in a protocol.
        # Namely how many Add tags in PROTOCOLS in the XML
        # Hence the index is increased from 1 until ``get_mover`` raises a RuntimeError.
        movers = []
        index = 1
        while True:
            try:
                movers.append(protocol.get_mover(index))
            except RuntimeError:
                break
            index += 1
        # movers may have movers themselves. All the movers in MOVERS
        # The list grows while it is scanned, so inner movers are inspected in turn.
        position = 0
        while position < len(movers):
            mover = movers[position]
            if hasattr(mover, 'mover'):
                inner = mover.mover()
                if inner is not None:
                    movers.append(inner)
            position += 1
        return movers

    @classmethod
    def fill(cls):
        """
        Fill the class attribute ``_mover_directory`` with the json info.
        This may not be up to date. For an up to date version, see ``.get_relevant_scripts() ``

        The ``${ROSETTA_PATH}`` placeholder in the stored paths is replaced with
        ``{cls.rosetta_folder}/main`` and repeated paths are dropped.
        """
        # The context manager closes the resource stream once read.
        with pkg_resources.resource_stream(__name__, 'mover2files.json') as stream:
            raw_data = json.load(stream)
        rosetta_main = f'{cls.rosetta_folder}/main'
        # dict.fromkeys removes duplicates while keeping the stored order (a set would shuffle it).
        cls._mover_directory = {name: [path.replace('${ROSETTA_PATH}', rosetta_main)
                                       for path in dict.fromkeys(paths)]
                                for name, paths in raw_data.items()}