Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[f2dace/dev, fortran] Necessary changes for solve_nh graph. #1969

Draft
wants to merge 4 commits into
base: f2dace/dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions dace/frontend/fortran/ast_transforms.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved.

import copy
import re
import warnings
from collections import namedtuple
from copy import copy
from typing import Dict, List, Optional, Tuple, Set, Union, Type

import sympy as sp
Expand All @@ -12,9 +12,9 @@
from dace.frontend.fortran import ast_internal_classes, ast_utils
from dace.frontend.fortran.ast_desugaring import ConstTypeInjection
from dace.frontend.fortran.ast_internal_classes import Var_Decl_Node, Name_Node, Int_Literal_Node, Data_Ref_Node, \
Execution_Part_Node, Array_Subscript_Node, Bool_Literal_Node
Execution_Part_Node, Array_Subscript_Node, Bool_Literal_Node, Literal
from dace.frontend.fortran.ast_utils import mywalk, iter_fields, iter_attributes, TempName, singular, atmost_one, \
match_callsite_args_to_function_args
match_callsite_args_to_function_args, duplicate_ast_element


class NeedsTypeInferenceException(BaseException):
Expand Down Expand Up @@ -69,6 +69,8 @@ def find_definition(self, scope_vars, node: ast_internal_classes.Data_Ref_Node,
while isinstance(top_ref.parent_ref, ast_internal_classes.Data_Ref_Node):
top_ref = top_ref.parent_ref

if not node.parent:
breakpoint()
struct_type = scope_vars.get_var(node.parent, ast_utils.get_name(top_ref.parent_ref)).type
struct_def = self.structures[struct_type]

Expand Down Expand Up @@ -1180,7 +1182,7 @@ def get_var(self, scope: Optional[Union[ast_internal_classes.FNode, str]],
return self.module_declarations[variable_name]
else:
raise RuntimeError(
f"Couldn't find the declaration of variable {variable_name} in function {self._scope_name(scope)}!")
f"Couldn't find the declaration of variable {variable_name} in function {self._scope_name(scope) if scope else scope}!")

def contains_var(self, scope: ast_internal_classes.FNode, variable_name: str) -> bool:
return (self._scope_name(scope), variable_name) in self.scope_vars
Expand All @@ -1199,7 +1201,7 @@ def visit_Symbol_Decl_Node(self, node: ast_internal_classes.Symbol_Decl_Node):

def _scope_name(self, scope: ast_internal_classes.FNode) -> str:
if isinstance(scope, ast_internal_classes.Main_Program_Node):
return scope.name.name.name
return scope.name.name
elif isinstance(scope, str):
return scope
else:
Expand Down Expand Up @@ -1497,26 +1499,26 @@ def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node):
abs_name = self.ast.intrinsic_handler.replace_function_name(ast_internal_classes.Name_Node(name="ABS"))

body_if = ast_internal_classes.Execution_Part_Node(execution=[
ast_internal_classes.BinOp_Node(lval=copy.deepcopy(lval),
ast_internal_classes.BinOp_Node(lval=duplicate_ast_element(lval),
op="=",
rval=ast_internal_classes.Call_Expr_Node(
name=abs_name,
type="DOUBLE",
args=[copy.deepcopy(args[0])],
args=[duplicate_ast_element(args[0])],
line_number=node.line_number, parent=node.parent,
subroutine=False),

line_number=node.line_number, parent=node.parent)
])
body_else = ast_internal_classes.Execution_Part_Node(execution=[
ast_internal_classes.BinOp_Node(lval=copy.deepcopy(lval),
ast_internal_classes.BinOp_Node(lval=duplicate_ast_element(lval),
op="=",
rval=ast_internal_classes.UnOp_Node(
op="-",
type="VOID",
lval=ast_internal_classes.Call_Expr_Node(
name=abs_name,
args=[copy.deepcopy(args[0])],
args=[duplicate_ast_element(args[0])],
type="DOUBLE",
subroutine=False,
line_number=node.line_number, parent=node.parent),
Expand Down Expand Up @@ -2118,7 +2120,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
newbody.append(ast_internal_classes.BinOp_Node(
lval=ast_internal_classes.Name_Node(name="_while_cond_" + str(self.count)),
op="=",
rval=copy.deepcopy(old_cond),
rval=duplicate_ast_element(old_cond),
line_number=child.line_number,
parent=child.parent))
newcond = ast_internal_classes.BinOp_Node(
Expand All @@ -2130,7 +2132,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
newwhilebody.execution.append(ast_internal_classes.BinOp_Node(
lval=ast_internal_classes.Name_Node(name="_while_cond_" + str(self.count)),
op="=",
rval=copy.deepcopy(old_cond),
rval=duplicate_ast_element(old_cond),
line_number=child.line_number,
parent=child.parent))

Expand Down Expand Up @@ -3207,7 +3209,7 @@ def visit_Data_Ref_Node(self, node: ast_internal_classes.Data_Ref_Node):
# if not isinstance(last_var.part_ref, ast_internal_classes.Array_Subscript_Node):
# return node

self.data_ref_stack.append(copy.deepcopy(node))
self.data_ref_stack.append(duplicate_ast_element(node))
node.part_ref = self.visit(node.part_ref)
self.data_ref_stack.pop()

Expand Down Expand Up @@ -3842,7 +3844,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
])
)

dest_indices = copy.deepcopy(tmp_array.indices)
dest_indices = duplicate_ast_element(tmp_array.indices)
for idx, _, _ in tmp_array.noncontig_dims:
iter_var = ast_internal_classes.Name_Node(name=f"tmp_parfor_{tmp_array.counter}_{idx}")
dest_indices[idx] = iter_var
Expand All @@ -3854,7 +3856,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No
line_numbe=child.line_number
)

source_indices = copy.deepcopy(tmp_array.indices)
source_indices = duplicate_ast_element(tmp_array.indices)
for idx, main_var, var in tmp_array.noncontig_dims:
iter_var = ast_internal_classes.Name_Node(name=f"tmp_parfor_{tmp_array.counter}_{idx}")

Expand Down Expand Up @@ -3962,7 +3964,7 @@ def visit_Data_Ref_Node(self, node: ast_internal_classes.Data_Ref_Node):
if not isinstance(last_var.part_ref, ast_internal_classes.Array_Subscript_Node):
return node

self.data_ref_stack.append(copy.deepcopy(node))
self.data_ref_stack.append(duplicate_ast_element(node))
node.part_ref = self.visit(node.part_ref)
self.data_ref_stack.pop()

Expand Down
22 changes: 20 additions & 2 deletions dace/frontend/fortran/ast_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved.
from collections import Counter
from itertools import chain
from typing import List, Set, Iterator, Type, TypeVar, Dict, Tuple, Iterable, Union, Optional
from typing import List, Set, Iterator, Type, TypeVar, Dict, Tuple, Iterable, Union, Optional, Any

import networkx as nx
from fparser.two.Fortran2003 import Module_Stmt, Name, Interface_Block, Subroutine_Stmt, Specification_Part, Module, \
Expand All @@ -21,6 +21,7 @@
from dace import subsets
from dace import symbolic as sym
from dace.frontend.fortran import ast_internal_classes
from dace.frontend.fortran.ast_internal_classes import FNode
from dace.sdfg import SDFG, SDFGState, InterstateEdge
from dace.sdfg.nodes import Tasklet

Expand Down Expand Up @@ -98,7 +99,7 @@ def finish_add_state_to_sdfg(state: SDFGState, top_sdfg: SDFG, substate: SDFGSta
state.last_sdfg_states[top_sdfg] = substate


def get_name(node: ast_internal_classes.FNode):
def get_name(node: ast_internal_classes.FNode) -> str:
if isinstance(node, ast_internal_classes.Actual_Arg_Spec_Node):
actual_node = node.arg
else:
Expand Down Expand Up @@ -989,3 +990,20 @@ def get_name(tag: str = 'tmp'):
def is_literal(node: ast_internal_classes.FNode) -> bool:
return isinstance(node, (ast_internal_classes.Int_Literal_Node, ast_internal_classes.Double_Literal_Node, ast_internal_classes.Real_Literal_Node, ast_internal_classes.Bool_Literal_Node))


def duplicate_ast_element(x: Any, **kwargs) -> Any:
return x
if isinstance(x, tuple):
return (duplicate_ast_element(v) for v in x)
elif isinstance(x, list):
return [duplicate_ast_element(v) for v in x]
elif isinstance(x, FNode):
if 'parent' not in kwargs and hasattr(x, 'parent'):
kwargs['parent'] = x.parent
kvs = {k: duplicate_ast_element(getattr(x, k), **kwargs)
for k in chain(x._fields, x._attributes)
if hasattr(x, k)}
xT = type(x)
return xT(**kvs)
else:
return x
4 changes: 4 additions & 0 deletions dace/frontend/fortran/fortran_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1849,6 +1849,9 @@ def process_variable_call(self, variable_in_calling_context: ast_internal_classe

# Get name of variable in SDFG of calling context or globalSDFG if that fails

# if 'prog/t_nh_prog' in str(variable_in_calling_context) and 'tmp_index_' in str(variable_in_calling_context):
# breakpoint()

sdfg_name = self.name_mapping.get(sdfg).get(ast_utils.get_name(variable_in_calling_context))
if sdfg_name is None:
globalsdfg_name = self.name_mapping.get(self.globalsdfg).get(
Expand Down Expand Up @@ -2586,6 +2589,7 @@ def fix_shapes_before_adding_nested(self, sdfg: SDFG, new_sdfg, sizes: List, str
self.temporary_sym_dict[new_sdfg.name]["sym_" + s.name] = s.name
i = i.subs(s, sym.symbol("sym_" + s.name))
else:
breakpoint()
print(f"Symbol {s.name} not found in arrays")
raise ValueError(f"Symbol {s.name} not found in arrays")

Expand Down
22 changes: 11 additions & 11 deletions dace/frontend/fortran/intrinsics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dace.frontend.fortran import ast_internal_classes
from dace.frontend.fortran.ast_transforms import NodeVisitor, NodeTransformer, ParentScopeAssigner, \
ScopeVarsDeclarations, TypeInference, par_Decl_Range_Finder, NeedsTypeInferenceException
from dace.frontend.fortran.ast_utils import fortrantypes2dacetypes, mywalk, is_literal
from dace.frontend.fortran.ast_utils import fortrantypes2dacetypes, mywalk, is_literal, duplicate_ast_element
from dace.libraries.blas.nodes.dot import dot_libnode
from dace.libraries.blas.nodes.gemm import gemm_libnode
from dace.libraries.standard.nodes import Transpose
Expand Down Expand Up @@ -599,7 +599,7 @@ def _parse_binary_op(self, node: ast_internal_classes.Call_Expr_Node, arg: ast_i

# replace the array subscript node in the binary operation
# ignore this when the operand is a scalar
cond = copy.deepcopy(arg)
cond = duplicate_ast_element(arg)
if first_array is not None:
cond.lval = dominant_array
if second_array is not None:
Expand All @@ -615,7 +615,7 @@ def _parse_binary_op(self, node: ast_internal_classes.Call_Expr_Node, arg: ast_i
raise TypeError("Can't parse Fortran binary op with different array ranks!")

# Now, we need to convert the array to a proper subscript node
cond = copy.deepcopy(arg)
cond = duplicate_ast_element(arg)
cond.lval = first_array
cond.rval = second_array

Expand Down Expand Up @@ -1007,7 +1007,7 @@ def _result_init_value(self):

def _result_loop_update(self, node: ast_internal_classes.FNode):
return ast_internal_classes.BinOp_Node(
lval=copy.deepcopy(node.lval),
lval=duplicate_ast_element(node.lval),
op="=",
rval=ast_internal_classes.Int_Literal_Node(value="1"),
line_number=node.line_number
Expand Down Expand Up @@ -1046,7 +1046,7 @@ def _result_init_value(self):

def _result_loop_update(self, node: ast_internal_classes.FNode):
return ast_internal_classes.BinOp_Node(
lval=copy.deepcopy(node.lval),
lval=duplicate_ast_element(node.lval),
op="=",
rval=ast_internal_classes.Int_Literal_Node(value="0"),
line_number=node.line_number
Expand Down Expand Up @@ -1090,13 +1090,13 @@ def _result_init_value(self):

def _result_loop_update(self, node: ast_internal_classes.FNode):
update = ast_internal_classes.BinOp_Node(
lval=copy.deepcopy(node.lval),
lval=duplicate_ast_element(node.lval),
op="+",
rval=ast_internal_classes.Int_Literal_Node(value="1"),
line_number=node.line_number
)
return ast_internal_classes.BinOp_Node(
lval=copy.deepcopy(node.lval),
lval=duplicate_ast_element(node.lval),
op="=",
rval=update,
line_number=node.line_number
Expand Down Expand Up @@ -1184,7 +1184,7 @@ def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_
body_if = ast_internal_classes.BinOp_Node(
lval=node.lval,
op="=",
rval=copy.deepcopy(self.argument_variable),
rval=duplicate_ast_element(self.argument_variable),
line_number=node.line_number
)
return ast_internal_classes.If_Stmt_Node(
Expand Down Expand Up @@ -1430,7 +1430,7 @@ def _summarize_args(self, exec_node: ast_internal_classes.Execution_Part_Node, n
if array_decl.sizes is None or len(array_decl.sizes) == 0:

first_input = self.get_var_declaration(node.parent, node.rval.args[0])
array_decl.sizes = copy.deepcopy(first_input.sizes)
array_decl.sizes = duplicate_ast_element(first_input.sizes)
array_decl.offsets = [1] * len(array_decl.sizes)
array_decl.type = first_input.type

Expand Down Expand Up @@ -1472,14 +1472,14 @@ def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_
"""

copy_first = ast_internal_classes.BinOp_Node(
lval=copy.deepcopy(self.destination_array),
lval=duplicate_ast_element(self.destination_array),
op="=",
rval=self.first_array,
line_number=node.line_number
)

copy_second = ast_internal_classes.BinOp_Node(
lval=copy.deepcopy(self.destination_array),
lval=duplicate_ast_element(self.destination_array),
op="=",
rval=self.second_array,
line_number=node.line_number
Expand Down
19 changes: 19 additions & 0 deletions dace/frontend/fortran/tools/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -405,4 +405,23 @@ Explanation:
-k mo_velocity_advection.velocity_tendencies \
-o ~/velocity_tendencies.sdfg \
-c ~/dace/dace/frontend/fortran/conf_files
```

## 6. What are the source dependencies when processing from scratch?

- `solve_nh` (and `velocity_tendencies`) dependencies:
```shell
-i .../icon-dace/src \
-i .../icon-dace/externals/fortran-support/src \
-i .../icon-dace/externals/cdi/src \
-i .../icon-dace/externals/mtime/src \
-i .../icon-dace/support \
-i .../icon-dace/externals/math-support/src \
-i .../icon-dace/externals/math-interpolation/src \
-i .../icon-dace/externals/ecrad/utilities \ # <- For the NetCDF stub.
```

- `radiation` dependencies:
```shell
-i .../icon-dace/externals/ecrad \
```
2 changes: 1 addition & 1 deletion dace/runtime/include/dace/reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ namespace dace {

static DACE_HDFI T reduce_atomic(T *ptr, const T& value)
{
return wcr_custom<T>::template reduce_atomic(
return wcr_custom<T>::reduce_atomic(
_wcr_fixed<REDTYPE, T>(), ptr, value);
}
};
Expand Down
2 changes: 1 addition & 1 deletion dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ def get_view_edge(state: SDFGState, view: nd.AccessNode) -> gr.MultiConnectorEdg
return out_edge

# If both access nodes reside in the same scope, the input data is viewed.
warnings.warn(f"Ambiguous view: in_edge {in_edge} -> view {view.data} -> out_edge {out_edge}")
# warnings.warn(f"Ambiguous view: in_edge {in_edge} -> view {view.data} -> out_edge {out_edge}")
return in_edge


Expand Down
3 changes: 2 additions & 1 deletion dace/sdfg/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def validate_control_flow_region(sdfg: 'SDFG',
undef_syms = set(edge.data.free_symbols) - set(symbols.keys())
if len(undef_syms) > 0:
eid = region.edge_id(edge)
breakpoint()
raise InvalidSDFGInterstateEdgeError(
f'Undefined symbols in edge: {undef_syms}. Add those with '
'`sdfg.add_symbol()` or define outside with `dace.symbol()`', sdfg, eid)
Expand Down Expand Up @@ -334,7 +335,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context
for desc in sdfg.arrays.values():
for sym in desc.free_symbols:
symbols[str(sym)] = sym.dtype
validate_control_flow_region(sdfg, sdfg, initialized_transients, symbols, references, **context)
# validate_control_flow_region(sdfg, sdfg, initialized_transients, symbols, references, **context)

except InvalidSDFGError as ex:
# If the SDFG is invalid, save it
Expand Down
3 changes: 2 additions & 1 deletion dace/transformation/passes/lift_struct_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ def __init__(self, sdfg: SDFG, element: Union[Edge[InterstateEdge], Tuple[Contro

def _handle_simple_name_access(self, node: ast.Attribute) -> Any:
struct: dt.Structure = self.data
if not node.attr in struct.members:
if node.attr not in struct.members:
breakpoint()
raise RuntimeError(
f'Structure attribute {node.attr} is not a member of the structure {struct.name} type definition'
)
Expand Down
Loading