Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Nested patterns #96

Merged
merged 5 commits into from
Sep 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions changes/95.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
Partially covered https://github.com/zaibacu/rita-dsl/issues/70

Allow nested patterns, like:

.. code-block::

num_with_fractions = {NUM, WORD("-")?, IN_LIST(fractions)}
complex_number = {NUM|PATTERN(num_with_fractions)}

{PATTERN(complex_number)}->MARK("NUMBER")
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "rita-dsl"
version = "0.6.6-1"
version = "0.6.7"
description = "DSL for building language rules"
authors = [
"Šarūnas Navickas <[email protected]>"
Expand Down
2 changes: 1 addition & 1 deletion rita/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

logger = logging.getLogger(__name__)

__version__ = (0, 6, 6, os.getenv("VERSION_PATCH"))
__version__ = (0, 6, 7, os.getenv("VERSION_PATCH"))


def get_version():
Expand Down
6 changes: 4 additions & 2 deletions rita/engine/translate_spacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def generic_parse(tag, value, config, op=None):


def punct_parse(_, config, op=None):
d = {}
d = dict()
d["IS_PUNCT"] = True
if op:
d["OP"] = op
Expand Down Expand Up @@ -87,7 +87,9 @@ def tag_parse(r, config, op=None):


def nested_parse(values, config, op=None):
results = rules_to_patterns("", values, config=config)
from rita.macros import resolve_value
results = rules_to_patterns("", [resolve_value(v, config=config)
for v in values], config=config)
return results["pattern"]


Expand Down
6 changes: 4 additions & 2 deletions rita/engine/translate_standalone.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def apply_operator(syntax, op):


def any_of_parse(lst, config, op=None):
clause = r"(^|\s)(({0})\s?)".format("|".join(sorted(lst, key=lambda x: (-len(x), x))))
clause = r"((^|\s)(({0})\s?))".format("|".join(sorted(lst, key=lambda x: (-len(x), x))))
return apply_operator(clause, op)


Expand Down Expand Up @@ -67,7 +67,9 @@ def phrase_parse(value, config, op=None):


def nested_parse(values, config, op=None):
(_, patterns) = rules_to_patterns("", values, config=config)
from rita.macros import resolve_value
(_, patterns) = rules_to_patterns("", [resolve_value(v, config=config)
for v in values], config=config)
return r"(?P<g{}>{})".format(config.new_nested_group_id(), "".join(patterns))


Expand Down
37 changes: 32 additions & 5 deletions rita/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from functools import reduce

from rita.utils import Node, deaccent
from rita.macros import resolve_value

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -231,9 +230,9 @@ def gen():
if type(p) is tuple:
(k, other, op) = p
if k == "nested":
fns = other[0][1]
yield "nested", list([resolve_value(f, config)
for f in fns]), op
fn = other[0][0]
children = other[0][1]
yield fn, children, op
else:
yield p
else:
Expand All @@ -242,14 +241,42 @@ def gen():
yield group_label, list(gen())


def flatten_2nd_level_nested(rules, config):
"""
1st level of nested: use PATTERN(...) inside of your rule
2nd level of nested: use PATTERN(...) which has PATTERN(...) and so on (recursively)

we want to resolve up to 1st level
"""

for group_label, pattern in rules:
def gen():
for p in pattern:
if type(p) is list:
for item in p:
yield item
else:
yield p

yield group_label, list(gen())


def preprocess_rules(root, config):
logger.info("Preprocessing rules")

rules = [rule_tuple(doc())
for doc in root
if doc and doc()]

pipeline = [dummy, expand_patterns, handle_deaccent, handle_rule_branching, handle_multi_word, handle_prefix]
pipeline = [
dummy,
expand_patterns,
handle_deaccent,
handle_rule_branching,
flatten_2nd_level_nested,
handle_multi_word,
handle_prefix
]

if config.implicit_hyphon:
logger.info("Adding implicit Hyphons")
Expand Down
31 changes: 9 additions & 22 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,34 +389,21 @@ def test_complex_number_match(engine):
fractions={"1 / 2", "3 / 4", "1 / 8", "3 / 8", "5 / 8", "7 / 8", "1 / 16", "3 / 16", "5 / 16", "7 / 16", "9 / 16",
"11 / 16", "13 / 16", "15 / 16", "1 / 32", "3 / 32", "5 / 32", "7 / 32", "9 / 32", "11 / 32", "13 / 32", "15 / 32",
"17 / 32", "19 / 32", "21 / 32", "23 / 32", "25 / 32", "27 / 32", "29 / 32", "31 / 32"}
complex_number = {NUM, WORD("-")?, IN_LIST(fractions)?}

{PATTERN(complex_number)}->MARK("NUMBER")
""")

results = parser('length 10 1 / 2 "')

print(results)
assert len(results) == 1
assert results[0] == ("10 1 / 2", "NUMBER")


@pytest.mark.parametrize('engine', [standalone_engine])
def test_complex_number_match_2(engine):
parser = engine("""
fractions={"1 / 2", "3 / 4", "1 / 8", "3 / 8", "5 / 8", "7 / 8", "1 / 16", "3 / 16", "5 / 16", "7 / 16", "9 / 16",
"11 / 16", "13 / 16", "15 / 16", "1 / 32", "3 / 32", "5 / 32", "7 / 32", "9 / 32", "11 / 32", "13 / 32", "15 / 32",
"17 / 32", "19 / 32", "21 / 32", "23 / 32", "25 / 32", "27 / 32", "29 / 32", "31 / 32"}
complex_number = {NUM, WORD("-")?, IN_LIST(fractions)?}
num_with_fractions = {NUM, WORD("-")?, IN_LIST(fractions)}
complex_number = {NUM|PATTERN(num_with_fractions)}

{WORD("length"), PATTERN(complex_number)}->MARK("NUMBER")
""")

results = parser('length 10 1 / 2 "')
simple_number = parser("length 50 cm")
assert len(simple_number) == 1
assert ("length 50", "NUMBER") == simple_number[0]

print(results)
assert len(results) == 1
assert results[0] == ("length 10 1 / 2", "NUMBER")
complex_number = parser('length 10 1 / 2 "')

assert len(complex_number) == 1
assert ("length 10 1 / 2", "NUMBER") == complex_number[0]


@pytest.mark.parametrize('engine', [standalone_engine])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def test_optional_list(self):
print(rules)

assert len(rules) == 1
assert rules[0] == re.compile(r"(?P<OPTIONAL_LIST>(?P<s0>((^|\s)((one|two)\s?)?)))", self.flags)
assert rules[0] == re.compile(r"(?P<OPTIONAL_LIST>(?P<s0>((^|\s)((one|two)\s?))?))", self.flags)

def test_complex_list(self):
rules = self.compiler("""
Expand Down