Source code for janis_core.utils.scatter

from enum import Enum
from typing import List, Union

from janis_core.operators.selectors import (
    Selector,
    InputNodeSelector,
    StepOutputSelector,
    InputSelector,
)
from janis_core.operators.operator import Operator


class ScatterMethod(Enum):
    """
    Ways in which several scattered inputs can be combined:
        - dot: pairs up the elements of equal-length arrays position by
          position, eg: [A, B, C] · [1, 2, 3] => [A1, B2, C3]
        - cross: cartesian product, i.e. every combination of the scattered
          inputs, eg: [A, B] x [1, 2] => [A1, A2, B1, B2]
    """

    dot = "dot"
    cross = "cross"

    def cwl(self):
        """
        Translate this member into the scatter method name written to CWL.

        :return: "dotproduct" for ``dot``, "flat_crossproduct" for ``cross``
        :raises Exception: if the member has no CWL translation
        """
        # Built per call (not as a class attribute) so the table is not
        # picked up as an extra enum member.
        # "nested_crossproduct" was noted as the alternative for ``cross``,
        # the flat variant is the one that is returned.
        cwl_names = {
            ScatterMethod.dot: "dotproduct",
            ScatterMethod.cross: "flat_crossproduct",
        }

        translated = cwl_names.get(self)
        if translated is None:
            raise Exception(f"Unrecognised scatter method: '{self.value}'")
        return translated


# Plural-named alias bound to the very same enum class as ScatterMethod, so
# either name can be used interchangeably.
# NOTE(review): presumably retained for backwards compatibility - confirm.
ScatterMethods = ScatterMethod


class ScatterDescription:
    """
    Class for keeping track of scatter information
    """

    def __init__(
        self,
        fields: List[str],
        method: ScatterMethod = None,
        labels: Union[Selector, List[str]] = None,
    ):
        """
        :param fields: The fields of the tool that should be scattered on.
        :param method: The method that should be used to scatter the arrays.
            Must be provided when more than one field is scattered on.
        :type method: ScatterMethod
        :param labels: (JANIS ONLY) - Either a list of values, each of which
            is converted with ``str`` and stored as a new list, or a selector.
            An InputNodeSelector is replaced by an InputSelector on the same
            id. A StepOutputSelector is rejected, both directly and as a leaf
            of an Operator.
        :raises Exception: if ``labels`` is (or contains) a StepOutputSelector,
            or if more than one field is given without a ``method``.
        """
        self.fields = fields
        self.method: ScatterMethod = method

        if labels is not None:
            if isinstance(labels, list):
                # Build a real list rather than a lazy ``map`` object, so the
                # labels can be measured and iterated more than once. The
                # converted value is what gets stored below.
                labels = [str(label) for label in labels]
            elif isinstance(labels, InputNodeSelector):
                labels = InputSelector(labels.id())
            elif isinstance(labels, StepOutputSelector):
                raise Exception(
                    f"Forbidden: Unable to use StepOutputSelector '{str(labels)}' for scatter label."
                )
            elif isinstance(labels, Operator):
                if any(
                    isinstance(leaf, StepOutputSelector)
                    for leaf in labels.get_leaves()
                ):
                    raise Exception(
                        f"Forbidden: There was a StepOutputSelector as a parameter to the scatter "
                        f"label operator '{str(labels)}' which is not allowed."
                    )

        # Single assignment of the normalised labels (None when not supplied).
        self.labels = labels

        if len(fields) > 1 and method is None:
            raise Exception(
                "When there is more than one field, a ScatterMethod must be selected"
            )