From c377c5605bfd00f4102730a4d392b6cbfe3acf3c Mon Sep 17 00:00:00 2001
From: Tpt <thomaspt@hotmail.fr>
Date: Wed, 12 Oct 2022 17:17:23 +0200
Subject: [PATCH] Python type subs: validate optionals

---
 python/generate_stubs.py | 14 ++++++++++++--
 python/src/model.rs      |  4 ++--
 python/src/store.rs      |  2 +-
 3 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/python/generate_stubs.py b/python/generate_stubs.py
index 5e370804..4a28875f 100644
--- a/python/generate_stubs.py
+++ b/python/generate_stubs.py
@@ -140,12 +140,17 @@ def arguments_stub(callable, doc: str, types_to_import: Set[str], is_init: bool)
         }
 
     parsed_param_types = {}
+    optional_params = set()
     for match in re.findall(r"\n *:type *([a-z_]+): ([^\n]*) *\n", doc):
         if match[0] not in real_parameters:
             raise ValueError(
                 f"The parameter {match[0]} is defined in the documentation but not in the function signature"
             )
-        parsed_param_types[match[0]] = convert_type_from_doc(match[1], types_to_import)
+        type = match[1]
+        if type.endswith(", optional"):
+            optional_params.add(match[0])
+            type = type[:-10]
+        parsed_param_types[match[0]] = convert_type_from_doc(type, types_to_import)
 
     # we parse the parameters
     posonlyargs = []
@@ -163,9 +168,14 @@ def arguments_stub(callable, doc: str, types_to_import: Set[str], is_init: bool)
         param_ast = ast.arg(
             arg=param.name, annotation=parsed_param_types.get(param.name)
         )
+
         default_ast = None
         if param.default != param.empty:
             default_ast = ast.Constant(param.default)
+            if param.name not in optional_params:
+                raise ValueError(f"Parameter {param.name} is optional according to the type but not flagged as such in the doc")
+        elif param.name in optional_params:
+            raise ValueError(f"Parameter {param.name} is optional according to the documentation but has no default value")
 
         if param.kind == param.POSITIONAL_ONLY:
             posonlyargs.append(param_ast)
@@ -203,7 +213,7 @@ def returns_stub(doc: str, types_to_import: Set[str]):
 
 
 def convert_type_from_doc(type_str: str, types_to_import: Set[str]):
-    type_str = type_str.strip().removesuffix(", optional")
+    type_str = type_str.strip()
     return parse_type_to_ast(type_str, types_to_import)
 
 
diff --git a/python/src/model.rs b/python/src/model.rs
index 111a4b3c..95c61d53 100644
--- a/python/src/model.rs
+++ b/python/src/model.rs
@@ -114,7 +114,7 @@ impl PyNamedNode {
 /// An RDF `blank node <https://www.w3.org/TR/rdf11-concepts/#dfn-blank-node>`_.
 ///
 /// :param value: the `blank node ID <https://www.w3.org/TR/rdf11-concepts/#dfn-blank-node-identifier>`_ (if not present, a random blank node ID is automatically generated).
-/// :type value: str, optional
+/// :type value: str or None, optional
 /// :raises ValueError: if the blank node ID is invalid according to NTriples, Turtle, and SPARQL grammars.
 ///
 /// The :py:func:`str` function provides a serialization compatible with NTriples, Turtle, and SPARQL:
@@ -122,7 +122,7 @@ impl PyNamedNode {
 /// >>> str(BlankNode('ex'))
 /// '_:ex'
 #[pyclass(name = "BlankNode")]
-#[pyo3(text_signature = "(value)")]
+#[pyo3(text_signature = "(value = None)")]
 #[derive(Eq, PartialEq, Debug, Clone, Hash)]
 pub struct PyBlankNode {
     inner: BlankNode,
diff --git a/python/src/store.rs b/python/src/store.rs
index 9a6c7bd7..062fd1e4 100644
--- a/python/src/store.rs
+++ b/python/src/store.rs
@@ -102,7 +102,7 @@ impl PyStore {
     /// :param object: the quad object or :py:const:`None` to match everything.
     /// :type object: NamedNode or BlankNode or Literal or Triple or None
     /// :param graph_name: the quad graph name. To match only the default graph, use :py:class:`DefaultGraph`. To match everything use :py:const:`None`.
-    /// :type graph_name: NamedNode or BlankNode or DefaultGraph or None
+    /// :type graph_name: NamedNode or BlankNode or DefaultGraph or None, optional
     /// :return: an iterator of the quads matching the pattern.
     /// :rtype: iter(Quad)
     /// :raises IOError: if an I/O error happens during the quads lookup.