diff --git a/src/virtmat/language/constraints/typechecks.py b/src/virtmat/language/constraints/typechecks.py index 184f163054b10b68c894a9f76b672f9d6227a48c..2bef5cc86274edf67a8b0de5c46672715c1e5921 100644 --- a/src/virtmat/language/constraints/typechecks.py +++ b/src/virtmat/language/constraints/typechecks.py @@ -332,12 +332,15 @@ def array_type(self): def get_array_type(datatype, typespec): """construct and return the proper array type depending on datatype""" + if datatype is None: + return None try: - mtype = next(m for m, d in dtypemap.items() if datatype and issubclass(datatype, d)) + mtype = next(m for m, d in dtypemap.items() if issubclass(datatype, d)) except StopIteration as err: - if hasattr(datatype, 'datatype'): + if is_array_type(datatype) and hasattr(datatype, 'datatype'): return get_array_type(datatype.datatype, typespec) - raise err + msg = 'array datatype must be numeric, boolean, string or array' + raise StaticTypeError(msg) from err typespec['arraytype'] = True return DType(mtype, (typemap[mtype],), typespec) diff --git a/src/virtmat/language/interpreter/deferred_executor.py b/src/virtmat/language/interpreter/deferred_executor.py index 29f2d880ddf238bf22ddf238b6483300733b019e..cff95012ded0b82cbabdb0cd74ca16366aef26e5 100644 --- a/src/virtmat/language/interpreter/deferred_executor.py +++ b/src/virtmat/language/interpreter/deferred_executor.py @@ -27,9 +27,8 @@ from virtmat.language.utilities.errors import RuntimeTypeError, TEXTX_WRAPPED_EX from virtmat.language.utilities.typemap import typemap, checktype, checktype_ from virtmat.language.utilities.typemap import is_table_like_type, is_table_like from virtmat.language.utilities.types import is_array, is_scalar, settype -from virtmat.language.utilities.types import ScalarNumerical, is_array_type +from virtmat.language.utilities.types import ScalarNumerical, get_datatype_name from virtmat.language.utilities.types import is_scalar_type, is_numeric_type -from virtmat.language.utilities.types import get_datatype_name from virtmat.language.utilities.lists import get_array_aslist from virtmat.language.utilities.units import get_units, get_dimensionality from virtmat.language.utilities.units import convert_series_units @@ -333,22 +332,28 @@ def iterable_property_func(self): start_ = self.start stop_ = self.stop step_ = self.step - - def get_sliced_value(value): - return value[start_:stop_:step_] if slice_ else value - - if self.array: - assert self.obj.type_.datatype is not None - if issubclass(self.obj.type_.datatype, str): - return (lambda *x: get_sliced_value(func(*x)).values.astype(str), pars) - if issubclass(self.obj.type_.datatype, bool): - return (lambda *x: get_sliced_value(func(*x)).values, pars) - if issubclass(self.obj.type_.datatype, (int, float, complex)): - return (lambda *x: get_sliced_value(func(*x)).values.quantity, pars) - if is_array_type(self.obj.type_.datatype): - return (lambda *x: get_nested_array(get_sliced_value(func(*x)).values), pars) - return (lambda *x: get_sliced_value(func(*x)).values, pars) - return (settype(lambda *x: get_sliced_value(func(*x))), pars) + array = self.array + + @settype + def get_sliced_value(*args): + value = func(*args) + if slice_: + value = value[start_:stop_:step_] + if array: + arr_val = value.values + if isinstance(arr_val, pint_pandas.PintArray): + return arr_val.quantity + assert isinstance(arr_val, numpy.ndarray) + if issubclass(arr_val.dtype.type, (numpy.str_, numpy.bool_)): + return arr_val + if isinstance(arr_val[0], str): + return arr_val.astype(str) + if is_array(arr_val[0]): + return get_nested_array(arr_val) + msg = 'array datatype must be numeric, boolean, string or array' + raise RuntimeTypeError(msg) + return value + return get_sliced_value, pars def iterable_query_func(self): diff --git a/src/virtmat/language/interpreter/instant_executor.py b/src/virtmat/language/interpreter/instant_executor.py index e289781e42378c9e23635cdf366c2661c6c6067e..549d85c4765ad4d887720b7ac81b1767654c788c 100644 --- a/src/virtmat/language/interpreter/instant_executor.py +++ b/src/virtmat/language/interpreter/instant_executor.py @@ -21,7 +21,7 @@ from virtmat.language.utilities.errors import InvalidUnitError, RuntimeTypeError from virtmat.language.utilities.errors import RuntimeValueError from virtmat.language.utilities.typemap import typemap, DType, checktype, checktype_ from virtmat.language.utilities.typemap import is_table_like, table_like_type -from virtmat.language.utilities.types import ScalarNumerical, is_array, is_array_type +from virtmat.language.utilities.types import ScalarNumerical, is_array from virtmat.language.utilities.types import is_numeric_type, is_numeric_scalar_type from virtmat.language.utilities.types import is_scalar_type, is_scalar, settype from virtmat.language.utilities.types import get_datatype_name @@ -294,7 +294,7 @@ def numeric_subarray_value(self): @settype def get_sliced_value(obj): - """return a value slice of an iterable/sequence data structure object""" + """return a slice and/or array of an iterable data structure object""" value = obj.obj.value if obj.slice: value = value[obj.start:obj.stop:obj.step] @@ -303,13 +303,14 @@ def get_sliced_value(obj): if isinstance(array, pint_pandas.PintArray): return array.quantity assert isinstance(array, numpy.ndarray) - assert obj.obj.type_.datatype is not None - if is_array_type(obj.obj.type_.datatype): - return get_nested_array(array) - if issubclass(obj.obj.type_.datatype, str): + if issubclass(array.dtype.type, (numpy.str_, numpy.bool_)): + return array + if isinstance(array[0], str): return array.astype(str) - assert issubclass(obj.obj.type_.datatype, bool) - return array + if is_array(array[0]): + return get_nested_array(array) + msg = 'array datatype must be numeric, boolean, string or array' + raise RuntimeTypeError(msg) return value diff --git a/tests/conftest.py b/tests/conftest.py index a8ec940b2edd72a71e965dd7a2155c723216723d..07dc1d27bcb1b8bf253530ea5af7361e253ba1f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -176,3 +176,11 @@ def res_config_fixture(tmp_path, monkeypatch): def lpad_fixture(): """launchpad object as fixture for all tests""" return LaunchPad.from_file(LAUNCHPAD_LOC) if LAUNCHPAD_LOC else LaunchPad() + + +@pytest.fixture(name='tmp_yaml') +def fixture_tmp_yaml(tmp_path): + """create a temporary path for yaml i/o and cleanup after use""" + path = os.path.join(tmp_path, 'tmp.yaml') + yield path + os.unlink(path) diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 7e7a0d129db49d8ec164add0f5c76d31fd9a5956..69319b2bf23386cbd976e19e146419577a4f3502 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,12 +1,14 @@ """ tests for data structures """ +import yaml import pytest import numpy from textx import get_children_of_type from textx.exceptions import TextXError from virtmat.language.utilities.typemap import typemap from virtmat.language.utilities.errors import RuntimeValueError, RuntimeTypeError +from virtmat.language.utilities.errors import StaticTypeError def test_function_call_returning_tuple(meta_model, model_kwargs): @@ -1219,7 +1221,7 @@ def test_array_in_series_from_issue(meta_model, model_kwargs): def test_array_from_series_from_issue(meta_model, model_kwargs): - """array from series (test case from isue #265)""" + """array from series (test case from issue #265)""" inp = ('time = map((x: 0.5*x), range(0 [day], 10 [day], 1 [day]));' 'print(time:array)') output = '[0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5] [day]' @@ -1227,6 +1229,45 @@ def test_array_from_series_from_issue(meta_model, model_kwargs): assert prog.value == output +def test_array_from_series_with_unknown_datatype(meta_model, model_kwargs, tmp_yaml): + """array from series with unknown datatype (test case from issue #443)""" + ser_dct = {'_fw_name': '{{virtmat.language.utilities.serializable.FWSeries}}', + '_version': 7, 'data': [1, 2], 'datatype': 'int', 'name': 'a', + 'units': 'dimensionless'} + with open(tmp_yaml, 'w', encoding='utf-8') as fh: + yaml.safe_dump(ser_dct, fh) + inp = f"s = Series from file \'{tmp_yaml}\'; print(s:array)" + assert meta_model.model_from_str(inp, **model_kwargs).value == '[1, 2]' + + +def test_array_from_series_with_wrong_dtype(meta_model, model_kwargs): + """array from series with wrong datatype""" + inp = 's = (a: (b: 1)); print(s:array)' + msg = 'array datatype must be numeric, boolean, string or array' + with pytest.raises(TextXError, match=msg) as err: + meta_model.model_from_str(inp, **model_kwargs) + assert isinstance(err.value.__cause__, StaticTypeError) + + +def test_array_from_series_with_wrong_dtype_rt(meta_model, model_kwargs, tmp_yaml): + """array from series with wrong datatype at runtime""" + ser_dct = {'_fw_name': '{{virtmat.language.utilities.serializable.FWSeries}}', + '_version': 7, + 'data': [{'_fw_name': '{{virtmat.language.utilities.serializable.FWSeries}}', + '_version': 7, 'data': [1], 'datatype': 'int', 'name': 'b', + 'units': 'dimensionless'}], 'datatype': 'object', 'name': 'a'} + with open(tmp_yaml, 'w', encoding='utf-8') as fh: + yaml.safe_dump(ser_dct, fh) + inp = f"s = Series from file \'{tmp_yaml}\'; ar = s:array" + prog = meta_model.model_from_str(inp, **model_kwargs) + var_list = get_children_of_type('Variable', prog) + var_s = next(v for v in var_list if v.name == 'ar') + msg = 'array datatype must be numeric, boolean, string or array' + with pytest.raises(TextXError, match=msg) as err: + _ = var_s.value + assert isinstance(err.value.__cause__, RuntimeTypeError) + + def test_series_of_int_arrays(meta_model, model_kwargs): """series of arrays of int type""" inp = ('series_var = (cell: [[12, 0, 0], [0, 12, 0], [0, 0, 12]] [angstrom])\n'