Skip to content
Snippets Groups Projects
test_field_equality.ipynb 6.29 KiB
Newer Older
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pystencils.session import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test field equality behaviour\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Fields created with the same parameters are equal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two fields built from identical parameters must compare equal.\n",
    "f1, f2 = (\n",
    "    ps.Field.create_generic('f', spatial_dimensions=2, index_dimensions=0)\n",
    "    for _ in range(2)\n",
    ")\n",
    "\n",
    "assert f1 == f2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Field ids equal in accesses:  True\n",
      "Field accesses equal:  True\n"
     ]
    }
   ],
   "source": [
    "# Check whether the accesses of the two equal fields reference the very same\n",
    "# field object, and whether the accesses themselves compare equal.\n",
    "print(\"Field ids equal in accesses: \", f1.center._field is f2.center._field)\n",
    "print(\"Field accesses equal: \", f1.center == f2.center)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same name but a different number of spatial dimensions: must not be equal.\n",
    "f1, f2 = (\n",
    "    ps.Field.create_generic('f', spatial_dimensions=dim, index_dimensions=0)\n",
    "    for dim in (1, 2)\n",
    ")\n",
    "assert f1 != f2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same name and dimensionality; only f2 gets an explicit float32 dtype.\n",
    "f1, f2 = (\n",
    "    ps.Field.create_generic('f', spatial_dimensions=1, index_dimensions=0, **extra)\n",
    "    for extra in ({}, {'dtype': np.float32})\n",
    ")\n",
    "assert f1 != f2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Properties of fields:\n",
    "- `field_type`: enum distinguishing normal from index or buffer fields\n",
    "- `_dtype`: data type of field elements\n",
    "- `_layout`: tuple indicating the memory linearization order\n",
    "- `shape`: size of the field for each dimension\n",
    "- `strides`: number of elements to jump over to increase the coordinate of this dimension by one\n",
    "- `latex_name`: optional display name used when the field is printed as LaTeX\n",
    "\n",
    "Equality comparison of fields:\n",
    "- `Field` overrides `__eq__` and `__hash__`\n",
    "- all parameters except `latex_name` are considered for equality"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test field access equality behaviour"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Center accesses of two identically parameterised fields must compare equal.\n",
    "f1, f2 = (\n",
    "    ps.Field.create_generic('f', spatial_dimensions=1, index_dimensions=0)\n",
    "    for _ in range(2)\n",
    ")\n",
    "assert f1.center == f2.center"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Center accesses must differ when only f2 gets an explicit float32 dtype.\n",
    "f1, f2 = (\n",
    "    ps.Field.create_generic('f', spatial_dimensions=1, index_dimensions=0, **extra)\n",
    "    for extra in ({}, {'dtype': np.float32})\n",
    ")\n",
    "assert f1.center != f2.center"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_field_accesses_debug(expr):\n",
    "    \"\"\"Print the field accesses contained in ``expr`` and the fields behind them.\n",
    "\n",
    "    For every ``Field.Access`` atom of ``expr`` the hash, offsets, index and\n",
    "    hashable content are printed, followed by the result of ``==`` for each\n",
    "    unordered pair of accesses. The distinct fields referenced by those\n",
    "    accesses are then listed (id, shape, strides, dtype, field type, layout)\n",
    "    and compared pairwise by equality, object identity and hash.\n",
    "\n",
    "    Args:\n",
    "        expr: expression object providing ``atoms()``, e.g. a product of\n",
    "            field accesses.\n",
    "    \"\"\"\n",
    "    from pystencils import Field\n",
    "\n",
    "    fas = list(expr.atoms(Field.Access))\n",
    "    fields = list({e.field for e in fas})\n",
    "\n",
    "    print(\"Field Accesses:\")\n",
    "    for fa in fas:\n",
    "        print(f\"   - {fa}, hash {hash(fa)}, offsets {fa._offsets}, index {fa._index}, {fa._hashable_content()}\")\n",
    "    print(\"\")\n",
    "    for i in range(len(fas)):\n",
    "        for j in range(i + 1, len(fas)):\n",
    "            # Compare the two accesses directly. Extra braces around the\n",
    "            # right-hand side would build a set literal, so the printed value\n",
    "            # would not reflect access equality.\n",
    "            print(f\"   -> {i},{j}  {fas[i]} == {fas[j]}: {fas[i] == fas[j]}\")\n",
    "\n",
    "    print(\"Fields\")\n",
    "    for f in fields:\n",
    "        print(f\"  - {f}, {id(f)}, shape {f.shape}, strides {f.strides}, {f._dtype}, {f.field_type}, layout {f.layout}\")\n",
    "    print(\"\")\n",
    "    for i in range(len(fields)):\n",
    "        for j in range(i + 1, len(fields)):\n",
    "            print(f\"  - {fields[i]} == {fields[j]}: {fields[i] == fields[j]}, ids equal {id(fields[i])==id(fields[j])}, hash equal {hash(fields[i])==hash(fields[j])}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Field Accesses:\n",
      "   - f[0], hash -8859424145258271267, offsets (0,), index (), ((('f_C', ('commutative', True), ('complex', True), ('extended_real', True), ('finite', True), ('hermitian', True), ('imaginary', False), ('infinite', False), ('real', True)), 2305067722319023373), ((0,), (_size_f_0,), (_stride_f_0,), <FieldType.GENERIC: 0>, 'f', None, double), 0)\n",
      "   - f[0], hash -6454673863007224785, offsets (0,), index (), ((('f_C', ('commutative', True), ('complex', True), ('extended_real', True), ('finite', True), ('hermitian', True), ('imaginary', False), ('infinite', False), ('real', True)), 4093629613697528859), ((0,), (_size_f_0,), (_stride_f_0,), <FieldType.GENERIC: 0>, 'f', None, float), 0)\n",
      "\n",
      "   -> 0,1  f[0] == f[0]: False\n",
      "Fields\n",
      "  - f, 4881406800, shape (_size_f_0,), strides (_stride_f_0,), double, FieldType.GENERIC, layout (0,)\n",
      "  - f, 4881445024, shape (_size_f_0,), strides (_stride_f_0,), float, FieldType.GENERIC, layout (0,)\n",
      "\n",
      "  - f == f: False, ids equal False, hash equal False\n"
     ]
    }
   ],
   "source": [
    "# Inspect the product of the two center accesses created in the previous cell.\n",
    "product_expr = f1.center * f2.center\n",
    "print_field_accesses_debug(product_expr)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2