diff --git a/src/virtmat/language/cli/__init__.py b/src/virtmat/language/cli/__init__.py index 4bd33ba61a05de6ed63fa51eaf291248c936fd0c..95413bdd6c594c131f6c0b3e55f3a65f15d4dbac 100644 --- a/src/virtmat/language/cli/__init__.py +++ b/src/virtmat/language/cli/__init__.py @@ -36,3 +36,7 @@ def texts(): parser_b.set_defaults(func=run_model.main) clargs = parser.parse_args() clargs.func(clargs) +# python -m cProfile src/virtmat/language/cli/__init__.py [texts options] 2>&1 | tee cProfile.out +# uncomment for profiling +# if __name__ == '__main__': +# texts() diff --git a/src/virtmat/language/interpreter/workflow_executor.py b/src/virtmat/language/interpreter/workflow_executor.py index b5f218eeafba47ebf50a8c4e07e1f4f57f37ca71..567d0cad5bc2294b8e6b61a828fe90b935e84190 100644 --- a/src/virtmat/language/interpreter/workflow_executor.py +++ b/src/virtmat/language/interpreter/workflow_executor.py @@ -25,6 +25,7 @@ from virtmat.language.utilities.fireworks import get_nodes_providing, get_parent from virtmat.language.utilities.fireworks import safe_update, get_nodes_info, retrieve_value from virtmat.language.utilities.fireworks import get_fw_metadata, get_launches from virtmat.language.utilities.fireworks import get_representative_launch +from virtmat.language.utilities.fireworks import append_wf from virtmat.language.utilities.serializable import DATA_SCHEMA_VERSION from virtmat.language.utilities.serializable import FWDataObject, tag_serialize from virtmat.language.utilities.textx import isinstance_m, get_identifiers @@ -475,7 +476,7 @@ def append_var_nodes(model): parents = get_parent_nodes(model.lpad, model.uuid, node) if None not in parents: get_logger(__name__).debug('appending %s, parents %s', node, parents) - model.lpad.append_wf(Workflow([nodes.pop(ind)]), fw_ids=parents) + append_wf(model.lpad, Workflow([nodes.pop(ind)]), fw_ids=parents) break assert len(nodes) < num_nodes logger.debug('appended %s new variable nodes', nodes_len) @@ -495,9 +496,9 @@ def append_output_nodes(model): if nodes: assert len(nodes) == 1 obj_to.__fw_name = next(n['name'] for n in nodes) - else: # not covered + else: parents = get_nodes_providing(model.lpad, model.uuid, obj_to.ref.name) - model.lpad.append_wf(Workflow([obj_to.firework]), fw_ids=parents) + append_wf(model.lpad, Workflow([obj_to.firework]), fw_ids=parents) logger.debug('added output node for var %s', obj_to.ref.name) diff --git a/src/virtmat/language/utilities/fireworks.py b/src/virtmat/language/utilities/fireworks.py index 1a0e671c4f646a201dde9c792810c1fbb73a3b99..895cfc4ce7658c8bdb28ab8306b95224b1eaea43 100644 --- a/src/virtmat/language/utilities/fireworks.py +++ b/src/virtmat/language/utilities/fireworks.py @@ -6,6 +6,7 @@ import pandas from fireworks import Workflow, Firework, FWorker, Launch from fireworks.utilities.fw_serializers import load_object from fireworks.core.rocket_launcher import rapidfire, launch_rocket +from fireworks.core.launchpad import WFLock from virtmat.middleware.resconfig import get_default_resconfig from virtmat.middleware.utilities import get_slurm_job_state, exec_cancel from virtmat.language.utilities.errors import FILE_READ_EXCEPTIONS @@ -404,3 +405,11 @@ def get_models_overview(lpad, uuids): df_2.rename(lambda x: x[0:3], axis='columns', inplace=True) df_3 = get_models_tags(lpad, uuids) return pandas.concat([df_1, df_2, df_3], axis='columns') + + +def append_wf(lpad, new_wf, fw_ids, detour=False, pull_spec_mods=True): + """replace LaunchPad.append_wf() adding a performance optimization""" + wf = lpad.get_wf_by_fw_id_lzyfw(fw_ids[0]) + updated_ids = wf.append_wf(new_wf, fw_ids, detour=detour, pull_spec_mods=pull_spec_mods) + with WFLock(lpad, fw_ids[0]): + getattr(lpad, '_update_wf')(wf, updated_ids) # due to protected member diff --git a/tests/test_session.py b/tests/test_session.py index 9791918e98aa190aa5a29416039cf0bed95e8b6d..6178eb2ec49bb9f05191a100ef5e7f81b195c24c 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -3,6 +3,7 @@ import os import time import logging import uuid +import yaml import pytest from textx import get_children_of_type from textx.exceptions import TextXError @@ -250,6 +251,16 @@ def test_extend_model_with_large_extension(lpad): assert session_2.model.value == '1' +def test_extend_model_with_object_to(lpad, tmp_path): + """test model extension including an export (object_to) statement""" + path = os.path.join(tmp_path, 'export.yaml') + session = Session(lpad, grammar_path=GRAMMAR_LOC, model_str='a = true') + Session(lpad, uuid=session.uuid, autorun=True, model_str=f"a to file \'{path}\'") + with open(path, 'r', encoding='utf-8') as ifile: + data = yaml.safe_load(ifile) + assert data is True + + def test_error_handler_static_errors(lpad, capsys): """test the domain specific error handler for static errors""" session = Session(lpad, grammar_path=GRAMMAR_LOC, autorun=False)